{-# LANGUAGE TypeFamilies #-}
module Futhark.AD.Fwd (fwdJVP) where
import Control.Monad
import Control.Monad.RWS.Strict
import Control.Monad.State.Strict
import Data.Bifunctor (second)
import Data.List (transpose)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map qualified as M
import Futhark.AD.Derivatives
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.Construct
import Futhark.IR.SOACS
-- | The tangent of a primitive-typed value: the 'blankPrimValue' constant of
-- that primitive type. Any non-primitive type is an internal error.
zeroTan :: Type -> ADM SubExp
zeroTan (Prim t) = pure $ constant $ blankPrimValue t
zeroTan t = error $ "zeroTan on non-primitive type: " ++ prettyString t
-- | An expression producing a zero value of the given type. Scalars become
-- the 'blankPrimValue' constant. Arrays become a 'Replicate' of that constant
-- across the array's full shape. Any other type (e.g. an accumulator) is an
-- internal error.
zeroExp :: Type -> Exp SOACS
zeroExp (Prim pt) =
  BasicOp $ SubExp $ Constant $ blankPrimValue pt
zeroExp (Array pt shape _) =
  BasicOp $ Replicate shape $ Constant $ blankPrimValue pt
-- Use 'prettyString' (like 'zeroTan') so the message shows a readable type.
zeroExp t = error $ "zeroExp: " ++ prettyString t
-- | The type of the tangent of a value of the given type.
--
-- The tangent of an accumulator is an accumulator over the same index space
-- whose element types are the primal element types followed by their tangent
-- types. Every other type is its own tangent type.
tanType :: TypeBase s u -> ADM (TypeBase s u)
tanType (Acc acc ispace ts u) = do
  ts_tan <- mapM tanType ts
  pure $ Acc acc ispace (ts ++ ts_tan) u
tanType t = pure t
-- | Run an action without changing the state first, then restore the
-- tangent map. Tangents registered inside the action do not escape it.
slocal' :: ADM a -> ADM a
slocal' = slocal id
-- | @slocal f m@ applies @f@ to the state, runs @m@, and then restores the
-- tangent map ('stateTans') to its value before the call.
--
-- Only the tangent map is restored. Other fields (notably the name source)
-- keep whatever @m@ left in them, so names generated inside stay unique.
slocal :: (RState -> RState) -> ADM a -> ADM a
slocal f m = do
  saved_tans <- gets stateTans
  modify f
  a <- m
  modify $ \s -> s {stateTans = saved_tans}
  pure a
-- | The state threaded through forward-mode differentiation.
data RState = RState
  { -- | Maps a primal variable to the variable holding its tangent.
    stateTans :: M.Map VName VName,
    -- | Source of fresh names.
    stateNameSource :: VNameSource
  }
-- | The forward-mode AD monad. It is a SOACS statement builder ('BuilderT')
-- running over a 'State' of 'RState'. Its instances are derived from that
-- wrapped stack.
newtype ADM a = ADM (BuilderT SOACS (State RState) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState RState,
      MonadFreshNames,
      HasScope SOACS,
      LocalScope SOACS
    )
-- | Statement construction is delegated to the wrapped 'BuilderT'. Each
-- method unwraps or rewraps the 'ADM' newtype around the builder's own
-- implementation.
instance MonadBuilder ADM where
  type Rep ADM = SOACS
  mkExpDecM pat e = ADM $ mkExpDecM pat e
  mkBodyM stms res = ADM $ mkBodyM stms res
  mkLetNamesM names e = ADM $ mkLetNamesM names e
  addStms = ADM . addStms
  collectStms (ADM m) = ADM $ collectStms m
-- | Fresh names are read from and written back to the 'stateNameSource'
-- field. 'runBuilderT' requires a 'MonadFreshNames' base monad, and this
-- instance provides one for the stack under 'ADM'.
instance MonadFreshNames (State RState) where
  getNameSource = gets stateNameSource
  putNameSource src = modify $ \env -> env {stateNameSource = src}
-- | Run an 'ADM' action in an empty scope with an empty tangent map. The
-- name source of the enclosing monad is threaded through the action.
--
-- Statements the action adds at its top level are discarded; only the
-- action's result is returned.
runADM :: (MonadFreshNames m) => ADM a -> m a
runADM (ADM m) =
  modifyNameSource $ \vn ->
    second stateNameSource $
      runState (fst <$> runBuilderT m mempty) (RState mempty vn)
-- | A fresh name for the tangent of @v@: @v@'s base name with a @_tan@
-- suffix.
tanVName :: VName -> ADM VName
tanVName v = newVName $ baseString v <> "_tan"
-- | Record that @v'@ holds the tangent of @v@, replacing any earlier entry
-- for @v@.
insertTan :: VName -> VName -> ADM ()
insertTan v v' =
  modify $ \env -> env {stateTans = M.insert v v' (stateTans env)}
-- | Binders (pattern elements and parameters) for which tangent binders can
-- be created.
class TanBuilder a where
  -- | Create the tangent binder corresponding to a primal binder.
  newTan :: a -> ADM a

  -- | The binders that replace a primal binder in differentiated code. In
  -- the instances in this module this is the primal followed by its tangent,
  -- or a single binder for accumulators.
  bundleNew :: a -> ADM [a]
-- | Apply 'bundleNew' to each binder in order and concatenate the results.
bundleNewList :: (TanBuilder a) => [a] -> ADM [a]
bundleNewList = fmap concat . mapM bundleNew
-- | Pattern elements.
--
-- An accumulator element is its own tangent binder. It keeps its name, is
-- registered as its own tangent, and gets the extended type from 'tanType'.
-- Any other element gets a fresh @_tan@-suffixed tangent name.
instance TanBuilder (PatElem (TypeBase s u)) where
  newTan (PatElem p t) = do
    p' <- if isAcc t then pure p else tanVName p
    insertTan p p'
    PatElem p' <$> tanType t

  -- Accumulators are not duplicated: the tangent binder replaces the primal.
  bundleNew pe@(PatElem _ t) = do
    pe' <- newTan pe
    pure $ if isAcc t then [pe'] else [pe, pe']
-- | Replace every element of a pattern by its tangent binder ('newTan').
newTanPat :: (TanBuilder (PatElem t)) => Pat t -> ADM (Pat t)
newTanPat (Pat pes) = Pat <$> mapM newTan pes
-- | Replace every element of a pattern by the binders 'bundleNew' produces
-- for it, keeping element order.
bundleNewPat :: (TanBuilder (PatElem t)) => Pat t -> ADM (Pat t)
bundleNewPat (Pat pes) = Pat <$> bundleNewList pes
-- | Parameters reuse the pattern-element logic for naming and typing. The
-- tangent parameter is created without attributes ('mempty').
instance TanBuilder (Param (TypeBase s u)) where
  newTan (Param _ p t) = do
    PatElem p' t' <- newTan $ PatElem p t
    pure $ Param mempty p' t'

  -- Unit-typed parameters get no tangent. Accumulator parameters are
  -- replaced by their tangent parameter. Everything else gets both.
  bundleNew param@(Param _ _ (Prim Unit)) = pure [param]
  bundleNew param@(Param _ _ t) = do
    param' <- newTan param
    pure $ if isAcc t then [param'] else [param, param']
-- | Parameters paired with an associated value (presumably an initial or
-- argument value; confirm with callers). The tangent parameter is paired
-- with the value's tangent.
instance (Tangent a) => TanBuilder (Param (TypeBase s u), a) where
  newTan (p, x) = (,) <$> newTan p <*> tangent x
  bundleNew (p, x) = do
    b <- bundleNew p
    x_tan <- tangent x
    -- NOTE(review): when 'bundleNew' yields a single binder (unit or
    -- accumulator parameters), 'zip' pairs it with @x@ rather than @x_tan@.
    -- @x_tan@ is then still computed, including any statements 'tangent'
    -- emits, but unused.
    pure $ zip b [x, x_tan]
-- | Things that have a tangent: types, names, subexpressions and results.
class Tangent a where
  -- | The tangent corresponding to a primal.
  tangent :: a -> ADM a

  -- | The primal bundled with its tangent. In the instances in this module
  -- this is the primal followed by its tangent, or a single entry for
  -- accumulators.
  bundleTan :: a -> ADM [a]
-- | The tangent of a type is its tangent type ('tanType'). An accumulator
-- type is bundled as its tangent type alone; other types as the primal type
-- followed by the tangent type.
instance Tangent (TypeBase s u) where
  tangent = tanType
  bundleTan t = do
    t' <- tangent t
    pure $ if isAcc t then [t'] else [t, t']
-- | Apply 'bundleTan' to each element in order and concatenate the results.
bundleTangents :: (Tangent a) => [a] -> ADM [a]
bundleTangents = fmap concat . mapM bundleTan
-- | The tangent of a variable is the variable recorded for it in the
-- tangent map.
--
-- A variable with no recorded tangent gets a freshly bound zero of its type
-- ('zeroExp'), named @<v>_implicit_tan@. That binding is not recorded in the
-- map, so each such request binds a new zero.
instance Tangent VName where
  tangent v = do
    maybe_tan <- gets $ M.lookup v . stateTans
    case maybe_tan of
      Just v_tan -> pure v_tan
      Nothing -> do
        t <- lookupType v
        letExp (baseString v <> "_implicit_tan") $ zeroExp t

  -- An accumulator variable is bundled as itself only.
  bundleTan v = do
    t <- lookupType v
    if isAcc t
      then pure [v]
      else do
        v_tan <- tangent v
        pure [v, v_tan]
-- | Constants have a zero tangent of the same primitive type ('zeroTan').
-- Variables defer to the 'VName' instance.
instance Tangent SubExp where
  tangent (Constant c) = zeroTan $ Prim $ primValueType c
  tangent (Var v) = Var <$> tangent v

  bundleTan c@Constant {} = do
    c_tan <- tangent c
    pure [c, c_tan]
  bundleTan (Var v) = map Var <$> bundleTan v
-- | Results keep their certificates. The tangent is taken of the
-- underlying subexpression.
instance Tangent SubExpRes where
  tangent (SubExpRes cs se) = SubExpRes cs <$> tangent se
  bundleTan (SubExpRes cs se) = map (SubExpRes cs) <$> bundleTan se
-- | Forward-mode differentiation of a single statement @pat = op@ where @op@
-- is a 'BasicOp'.
--
-- Fresh tangent names for @pat@ are created and registered ('newTanPat').
-- They are then bound to the tangent of @op@, using the statement's own
-- 'StmAux'. Only tangent bindings (and any implicit zero tangents) are
-- emitted here; presumably the caller emits the primal statement.
-- Unsupported operations are an internal error.
basicFwd :: Pat Type -> StmAux () -> BasicOp -> ADM ()
basicFwd pat aux op = do
  pat_tan <- newTanPat pat
  let -- Bind the tangent pattern to an expression.
      bindTan :: Exp SOACS -> ADM ()
      bindTan = addStm . Let pat_tan aux
      -- Bind the tangent pattern to a basic operation.
      bindTanOp :: BasicOp -> ADM ()
      bindTanOp = bindTan . BasicOp
      -- Bind the tangent pattern to a scalar expression. 'auxing' also
      -- covers any statements 'toExp' emits while lowering the expression.
      bindTanPrim e =
        auxing aux $ letBindNames (patNames pat_tan) <=< toExp $ e
  case op of
    SubExp se -> do
      se_tan <- tangent se
      bindTanOp $ SubExp se_tan
    Opaque opaqueop se -> do
      se_tan <- tangent se
      bindTanOp $ Opaque opaqueop se_tan
    ArrayLit ses t -> do
      ses_tan <- mapM tangent ses
      bindTanOp $ ArrayLit ses_tan t
    UnOp unop x -> do
      -- Chain rule: dx/dt * d(unop x)/dx.
      let t = unOpType unop
          dx = pdUnOp unop $ primExpFromSubExp t x
      x_tan <- primExpFromSubExp t <$> tangent x
      bindTanPrim $ x_tan ~*~ dx
    BinOp bop x y -> do
      -- Chain rule over both operands: x' * d/dx + y' * d/dy.
      let t = binOpType bop
          (wrt_x, wrt_y) =
            pdBinOp bop (primExpFromSubExp t x) (primExpFromSubExp t y)
      x_tan <- primExpFromSubExp t <$> tangent x
      y_tan <- primExpFromSubExp t <$> tangent y
      bindTanPrim $ x_tan ~*~ wrt_x ~+~ y_tan ~*~ wrt_y
    CmpOp {} ->
      -- The tangent of a comparison is the constant blank boolean.
      bindTan $ zeroExp $ Prim Bool
    ConvOp cop x -> do
      x_tan <- tangent x
      bindTanOp $ ConvOp cop x_tan
    Assert {} ->
      -- NOTE(review): 'newTanPat' above registered tangent names for this
      -- pattern, but nothing binds them. Confirm that the tangent of an
      -- assertion's result is never requested.
      pure ()
    Index arr slice -> do
      arr_tan <- tangent arr
      bindTanOp $ Index arr_tan slice
    Update safety arr slice se -> do
      arr_tan <- tangent arr
      se_tan <- tangent se
      bindTanOp $ Update safety arr_tan slice se_tan
    Concat d (arr :| arrs) w -> do
      arr_tan <- tangent arr
      arrs_tans <- mapM tangent arrs
      bindTanOp $ Concat d (arr_tan :| arrs_tans) w
    Manifest arr ds -> do
      arr_tan <- tangent arr
      bindTanOp $ Manifest arr_tan ds
    Iota n _ _ it ->
      -- An iota produces integers; its tangent is an all-zero array of the
      -- same integer type and length.
      bindTanOp $ Replicate (Shape [n]) (intConst it 0)
    Replicate n x -> do
      x_tan <- tangent x
      bindTanOp $ Replicate n x_tan
    Scratch t shape ->
      -- The tangent of a scratch array is a scratch array of the same type
      -- and shape.
      bindTanOp $ Scratch t shape
    Reshape arr reshape -> do
      arr_tan <- tangent arr
      bindTanOp $ Reshape arr_tan reshape
    Rearrange arr perm -> do
      arr_tan <- tangent arr
      bindTanOp $ Rearrange arr_tan perm
    _ -> error $ "basicFwd: Unsupported op " ++ prettyString op
-- | Forward-mode differentiation of a lambda. Parameters are bundled with
-- tangent parameters ('bundleNewList') and return types with tangent types
-- ('bundleTangents'). The body is differentiated ('fwdBody') in the scope of
-- the original lambda ('inScopeOf').
fwdLambda :: Lambda SOACS -> ADM (Lambda SOACS)
fwdLambda l@(Lambda params ret body) =
  Lambda
    <$> bundleNewList params
    <*> bundleTangents ret
    <*> inScopeOf l (fwdBody body)
-- | Like 'fwdLambda', but for stream lambdas. The first parameter is kept
-- unchanged and gets no tangent (presumably the chunk size; confirm). Only
-- the remaining parameters are bundled with tangents.
fwdStreamLambda :: Lambda SOACS -> ADM (Lambda SOACS)
fwdStreamLambda l@(Lambda params ret body) =
  Lambda
    <$> ((kept_params ++) <$> bundleNewList differentiated_params)
    <*> bundleTangents ret
    <*> inScopeOf l (fwdBody body)
  where
    (kept_params, differentiated_params) = splitAt 1 params
-- | Alternate the elements of two lists, starting with the first list.
-- When one list is longer, its leftover elements follow in order, e.g.
-- @interleave [1, 2, 3] [10] == [1, 10, 2, 3]@.
interleave :: [a] -> [a] -> [a]
interleave xs ys = concat $ transpose [xs, ys]
-- | Bind a fresh variable (base name @zero@) to a zero of the same type as
-- the given subexpression, built with 'zeroExp'.
zeroFromSubExp :: SubExp -> ADM VName
zeroFromSubExp (Constant c) =
  letExp "zero" $ zeroExp $ Prim $ primValueType c
zeroFromSubExp (Var v) = do
  t <- lookupType v
  letExp "zero" $ zeroExp t
-- | Forward-mode differentiation of a SOAC statement.
--
-- The SOAC is rebuilt so that its inputs, neutral elements, destinations
-- and results are bundled with their tangents (via 'bundleTangents',
-- 'bundleNewPat' and 'interleave'). The rebuilt statement is then added to
-- the current builder. Nested 'JVP' and 'VJP' SOACs are rejected with an
-- error.
fwdSOAC :: Pat Type -> StmAux () -> SOAC SOACS -> ADM ()
fwdSOAC pat aux (Screma size xs (ScremaForm map_lam scans reds)) = do
  bundled_pat <- bundleNewPat pat
  bundled_xs <- bundleTangents xs
  map_lam' <- fwdLambda map_lam
  scans' <- traverse onScan scans
  reds' <- traverse onRed reds
  addStm . Let bundled_pat aux . Op . Screma size bundled_xs $
    ScremaForm map_lam' scans' reds'
  where
    -- Differentiate the scan operator; every neutral element is paired
    -- with a freshly bound zero tangent.
    onScan :: Scan SOACS -> ADM (Scan SOACS)
    onScan sc = do
      lam' <- fwdLambda (scanLambda sc)
      zero_tans <- traverse zeroFromSubExp (scanNeutral sc)
      pure
        sc
          { scanLambda = lam',
            scanNeutral = interleave (scanNeutral sc) (Var <$> zero_tans)
          }

    -- Same as 'onScan', but for reductions; commutativity is unchanged.
    onRed :: Reduce SOACS -> ADM (Reduce SOACS)
    onRed red = do
      lam' <- fwdLambda (redLambda red)
      zero_tans <- traverse zeroFromSubExp (redNeutral red)
      pure
        red
          { redLambda = lam',
            redNeutral = interleave (redNeutral red) (Var <$> zero_tans)
          }
fwdSOAC pat aux (Stream size xs nes lam) = do
  bundled_pat <- bundleNewPat pat
  lam' <- fwdStreamLambda lam
  bundled_xs <- bundleTangents xs
  zero_tans <- traverse (fmap Var . zeroFromSubExp) nes
  addStm . Let bundled_pat aux . Op $
    Stream size bundled_xs (interleave nes zero_tans) lam'
fwdSOAC pat aux (Hist w arrs ops bucket_fun) = do
  bundled_pat <- bundleNewPat pat
  ops' <- traverse onHistOp ops
  bucket_fun' <- onBucketFun bucket_fun
  bundled_arrs <- bundleTangents arrs
  addStm . Let bundled_pat aux . Op $ Hist w bundled_arrs ops' bucket_fun'
  where
    -- The bucket function returns all bucket indices first (one per
    -- dimension of each operator's histogram shape), then all values.
    n_bucket_indices :: Int
    n_bucket_indices = sum [shapeRank (histShape h) | h <- ops]

    -- Only the value part of the bucket function's results gets tangents;
    -- the leading indices are left as they are.
    onBucketFun bucket_lam@(Lambda params ret body) = do
      let (ret_is, ret_vs) = splitAt n_bucket_indices ret
      params' <- bundleNewList params
      ret_vs' <- bundleTangents ret_vs
      body' <- inScopeOf bucket_lam $ onBucketBody body
      pure $ Lambda params' (ret_is ++ ret_vs') body'

    onBucketBody (Body _ stms res) = buildBody_ $ do
      forM_ stms fwdStm
      let (res_is, res_vs) = splitAt n_bucket_indices res
      res_vs' <- bundleTangents res_vs
      pure $ res_is ++ res_vs'

    -- Each destination is bundled with its tangent array, and each neutral
    -- element with a zero tangent.
    onHistOp :: HistOp SOACS -> ADM (HistOp SOACS)
    onHistOp (HistOp shape rf dest nes op) = do
      dest' <- bundleTangents dest
      zero_tans <- traverse (fmap Var . zeroFromSubExp) nes
      op' <- fwdLambda op
      pure $ HistOp shape rf dest' (interleave nes zero_tans) op'
fwdSOAC (Pat pes) aux (Scatter w ivs dests lam) = do
  dests_tan <- forM dests $ \(shp, n, arr) ->
    fmap (\arr_tan -> (shp, n, arr_tan)) (tangent arr)
  pes_tan <- traverse newTan pes
  bundled_ivs <- bundleTangents ivs
  let (dest_shapes, dest_ns, _) = unzip3 dests
      n_indices = sum $ zipWith (\n shp -> n * length shp) dest_ns dest_shapes
  lam' <- fwdScatterLambda n_indices lam
  -- Tangent destinations (and their pattern elements) come after all the
  -- primal ones.
  addStm . Let (Pat (pes ++ pes_tan)) aux . Op $
    Scatter w bundled_ivs (dests ++ dests_tan) lam'
  where
    -- A scatter lambda returns every index first and then every value. The
    -- tangent destinations are written at the same indices, so the index
    -- block is emitted twice, followed by the values and then their tangents.
    fwdScatterLambda :: Int -> Lambda SOACS -> ADM (Lambda SOACS)
    fwdScatterLambda num_indices (Lambda params ret body) = do
      params' <- bundleNewList params
      let (ret_is, ret_vs) = splitAt num_indices ret
      ret_tan <- traverse tangent ret_vs
      body' <- fwdBodyScatter num_indices body
      pure $ Lambda params' (ret_is ++ ret_is ++ ret_vs ++ ret_tan) body'

    fwdBodyScatter :: Int -> Body SOACS -> ADM (Body SOACS)
    fwdBodyScatter num_indices (Body _ stms res) = do
      let (res_is, res_vs) = splitAt num_indices res
      (res_tan, stms') <- collectStms $ do
        forM_ stms fwdStm
        traverse tangent res_vs
      pure . mkBody stms' $ res_is ++ res_is ++ res_vs ++ res_tan
fwdSOAC _ _ JVP {} = error "fwdSOAC: nested JVP not allowed."
fwdSOAC _ _ VJP {} = error "fwdSOAC: nested VJP not allowed."
-- | Forward-mode differentiation of a single statement.
--
-- The emitted code computes both the primal values and tangents for every
-- name the statement binds, and it is added to the current builder.
-- Statement forms without a rule fail with an error.
fwdStm :: Stm SOACS -> ADM ()
fwdStm (Let pat aux (BasicOp (UpdateAcc safety acc i x))) = do
  pat' <- bundleNewPat pat
  x' <- bundleTangents x
  acc_tan <- tangent acc
  -- The update is performed on the tangent accumulator, with the written
  -- values bundled with their tangents.
  addStm $ Let pat' aux $ BasicOp $ UpdateAcc safety acc_tan i x'
fwdStm stm@(Let pat aux (BasicOp e)) = do
  -- The primal statement is re-emitted unless it binds an accumulator.
  -- NOTE(review): this is a known-naive rule; presumably 'basicFwd' is
  -- responsible for binding accumulator results — confirm.
  unless (any isAcc $ patTypes pat) $ addStm stm
  basicFwd pat aux e
fwdStm stm@(Let pat _ (Apply f args _ _))
  | Just (ret, argts) <- M.lookup f builtInFunctions = do
      addStm stm
      arg_tans <-
        zipWith primExpFromSubExp argts <$> mapM (tangent . fst) args
      pat_tan <- newTanPat pat
      let arg_pes = zipWith primExpFromSubExp argts (map fst args)
      case pdBuiltin f arg_pes of
        Nothing ->
          error $ "No partial derivative defined for builtin function: " ++ prettyString f
        Just derivs -> do
          let convertTo :: PrimType -> PrimExp VName -> PrimExp VName
              convertTo tt pe
                | e_t == tt = pe
                | otherwise =
                    case (tt, e_t) of
                      (IntType tt', IntType ft) -> ConvOpExp (SExt ft tt') pe
                      (FloatType tt', FloatType ft) -> ConvOpExp (FPConv ft tt') pe
                      (Bool, FloatType ft) -> ConvOpExp (FToB ft) pe
                      (FloatType tt', Bool) -> ConvOpExp (BToF tt') pe
                      _ -> error $ "fwdStm.convertTo: " ++ prettyString (f, tt, e_t)
                where
                  e_t = primExpType pe

              -- An integer argument of a non-integer builtin (such as an
              -- integral exponent) takes discrete values and therefore
              -- contributes no tangent.
              isDiscrete :: PrimType -> Bool
              isDiscrete (IntType _) = case ret of
                IntType _ -> False
                _ -> True
              isDiscrete _ = False

              -- Chain rule: the tangent of f(x_1, ..., x_n) is the SUM over
              -- all arguments of (tangent of x_i) * (df/dx_i). Previously the
              -- products were zipped against the single result name, which
              -- silently kept only the first argument's contribution.
              contribs =
                [ convertTo ret arg_tan ~*~ deriv
                  | (argt, arg_tan, deriv) <- zip3 argts arg_tans derivs,
                    not $ isDiscrete argt
                ]
              tan_pe = case contribs of
                [] -> ValueExp $ blankPrimValue ret
                c : cs -> foldl (~+~) c cs
          tan_e <- toExp tan_pe
          zipWithM_ (letBindNames . pure) (patNames pat_tan) [tan_e]
fwdStm (Let pat aux (Match ses cases defbody (MatchDec ret ifsort))) = do
  cases' <- slocal' $ mapM (traverse fwdBody) cases
  defbody' <- slocal' $ fwdBody defbody
  pat' <- bundleNewPat pat
  ret' <- bundleTangents ret
  addStm $ Let pat' aux $ Match ses cases' defbody' $ MatchDec ret' ifsort
fwdStm (Let pat aux (Loop val_pats loop@(WhileLoop v) body)) = do
  val_pats' <- bundleNewList val_pats
  pat' <- bundleNewPat pat
  body' <-
    localScope (scopeOfFParams (map fst val_pats) <> scopeOfLoopForm loop) . slocal' $
      fwdBody body
  addStm $ Let pat' aux $ Loop val_pats' (WhileLoop v) body'
fwdStm (Let pat aux (Loop val_pats loop@(ForLoop i it bound) body)) = do
  pat' <- bundleNewPat pat
  val_pats' <- bundleNewList val_pats
  body' <-
    localScope (scopeOfFParams (map fst val_pats) <> scopeOfLoopForm loop) . slocal' $
      fwdBody body
  addStm $ Let pat' aux $ Loop val_pats' (ForLoop i it bound) body'
fwdStm (Let pat aux (WithAcc inputs lam)) = do
  inputs' <- forM inputs $ \(shape, arrs, op) -> do
    -- Each accumulated array is extended with its tangent array.
    arrs_tan <- mapM tangent arrs
    op' <- case op of
      Nothing -> pure Nothing
      Just (op_lam, nes) -> do
        nes_tan <- mapM (fmap Var . zeroFromSubExp) nes
        op_lam' <- fwdLambda op_lam
        case op_lam' of
          Lambda ps ret body -> do
            -- The leading index parameters (one per dimension of the
            -- accumulator shape) have no meaningful tangents, so drop the
            -- tangent parameter that follows each of them.
            let op_lam'' = Lambda (removeIndexTans (shapeRank shape) ps) ret body
            pure $ Just (op_lam'', interleave nes nes_tan)
    pure (shape, arrs <> arrs_tan, op')
  pat' <- bundleNewPat pat
  lam' <- fwdLambda lam
  addStm $ Let pat' aux $ WithAcc inputs' lam'
  where
    -- Keep the first of every pair for the first @i@ pairs.
    removeIndexTans :: Int -> [a] -> [a]
    removeIndexTans 0 ps = ps
    removeIndexTans i (p : _ : ps) = p : removeIndexTans (i - 1) ps
    removeIndexTans _ ps = ps
fwdStm (Let pat aux (Op soac)) = fwdSOAC pat aux soac
fwdStm stm =
  error $ "unhandled forward mode AD for Stm: " ++ prettyString stm ++ "\n" ++ show stm
-- | Differentiate every statement of a body. The resulting body's result
-- is the original result bundled with its tangents (see 'bundleTangents').
fwdBody :: Body SOACS -> ADM (Body SOACS)
fwdBody (Body _ stms res) =
  buildBody_ $ forM_ stms fwdStm >> bundleTangents res
-- | Like 'fwdBody', but the result lists every primal result first and
-- then every tangent, in the same order, rather than bundling them.
fwdBodyTansLast :: Body SOACS -> ADM (Body SOACS)
fwdBodyTansLast (Body _ stms res) = buildBody_ $ do
  forM_ stms fwdStm
  res_tans <- mapM tangent res
  pure $ res <> res_tans
-- | Build the Jacobian-vector product of a lambda via forward-mode AD.
--
-- The returned lambda has the original parameters followed by one tangent
-- parameter per original parameter. It returns the original results
-- followed by one tangent per result.
fwdJVP :: (MonadFreshNames m) => Scope SOACS -> Lambda SOACS -> m (Lambda SOACS)
fwdJVP scope lam@(Lambda params ret body) =
  runADM $ localScope scope $ inScopeOf lam $ do
    tan_params <- traverse newTan params
    tan_body <- fwdBodyTansLast body
    tan_ret <- traverse tangent ret
    pure $ Lambda (params ++ tan_params) (ret <> tan_ret) tan_body