module Language.Futhark.Interpreter.AD
( Op (..),
ADVariable (..),
ADValue (..),
Tape (..),
VJPValue (..),
JVPValue (..),
Counter (..),
Depth (..),
doOp,
addFor,
tapePrimal,
primitive,
varPrimal,
deriveTape,
unionWithM,
unionsWithM,
)
where
import Control.Monad (foldM, zipWithM)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Except (ExceptT, catchE, runExceptT, throwE)
import Control.Monad.Trans.State (State, get, modify, runState)
import Data.Either (isRight)
import Data.Foldable (find, foldlM)
import Data.Functor ((<&>))
import Data.Map qualified as M
import Data.Maybe (fromJust, fromMaybe)
import Data.Text qualified as T
import Futhark.AD.Derivatives (pdBinOp, pdBuiltin, pdUnOp)
import Futhark.Analysis.PrimExp (PrimExp (..))
import Language.Futhark.Core (VName (..), nameFromString, nameFromText)
import Language.Futhark.Primitive
( BinOp (Add, FAdd, FMul, LogAnd, LogOr, Mul),
CmpOp,
ConvOp,
Overflow (OverflowWrap),
PrimType (Bool, FloatType, IntType),
PrimValue (BoolValue),
UnOp,
binOpType,
blankPrimValue,
cmpOpType,
convOpType,
doBinOp,
doCmpOp,
doConvOp,
doUnOp,
flipConvOp,
primFuns,
primValueType,
unOpType,
)
-- | A supply of identifiers threaded through 'ADMonad'.  'incCounter'
-- advances it by one, and 'vjpHandleOp' stamps each new tape node with
-- its current value.
newtype Counter = Counter Int
  deriving (Eq, Ord, Num, Show)
-- | The monad AD operations run in: failures are reported as 'String'
-- errors via 'ExceptT', and a 'Counter' is threaded through as state.
type ADMonad = ExceptT String (State Counter)

-- | Advance the counter by one.
incCounter :: ADMonad ()
incCounter = lift $ modify (+ 1)
-- | A primitive operation that can be applied to 'ADValue's.
data Op
  = -- | A binary operator.
    OpBin BinOp
  | -- | A comparison operator.
    OpCmp CmpOp
  | -- | A unary operator.
    OpUn UnOp
  | -- | A builtin function, looked up by name in 'primFuns'.
    OpFn T.Text
  | -- | A type conversion.
    OpConv ConvOp
  deriving (Show)
-- | Check that a list of operand types fits an operator: the right
-- number of operands, each of the type the operator expects.
--
-- Builtin functions must match their declared parameter list exactly,
-- and a builtin that is not in 'primFuns' does not match at all.  The
-- caller then reports a proper error instead of crashing.
opTypeMatch :: Op -> [PrimType] -> Bool
opTypeMatch op p = case op of
  OpBin bop -> uniform 2 (binOpType bop)
  OpCmp cop -> uniform 2 (cmpOpType cop)
  OpUn uop -> uniform 1 (unOpType uop)
  OpConv cvop -> uniform 1 (fst (convOpType cvop))
  OpFn fn -> case M.lookup fn primFuns of
    -- Whole-list equality also rejects too few or too many arguments.
    Just (t, _, _) -> t == p
    Nothing -> False
  where
    -- Exactly @n@ operands, all of type @t@.
    uniform :: Int -> PrimType -> Bool
    uniform n t = length p == n && all (== t) p
-- | The type of the value an operator produces.
--
-- Comparisons always produce a boolean.  'cmpOpType' describes the
-- operands, not the result.  Fails with 'error' for a builtin function
-- that is not in 'primFuns'.
opReturnType :: Op -> PrimType
opReturnType (OpBin op) = binOpType op
opReturnType (OpCmp _) = Bool
opReturnType (OpUn op) = unOpType op
opReturnType (OpConv op) = snd $ convOpType op
opReturnType (OpFn fn) = case M.lookup fn primFuns of
  Just (_, t, _) -> t
  Nothing -> error $ "opReturnType: unknown function " <> T.unpack fn
-- | The "addition" operator for a primitive type: wrapping integer
-- addition, float addition, or logical or for 'Bool'.  Fails with
-- 'error' for any other type.
addFor :: PrimType -> BinOp
addFor (IntType t) = Add t OverflowWrap
addFor (FloatType t) = FAdd t
addFor Bool = LogOr
addFor t = error $ "addFor: " ++ show t
-- | The "multiplication" operator for a primitive type: wrapping
-- integer multiplication, float multiplication, or logical and for
-- 'Bool'.  Fails with 'error' for any other type.
mulFor :: PrimType -> BinOp
mulFor (IntType t) = Mul t OverflowWrap
mulFor (FloatType t) = FMul t
mulFor Bool = LogAnd
mulFor t = error $ "mulFor: " ++ show t
-- | The nesting level an AD variable belongs to.  Constants are
-- treated as depth 0 (see 'depth').
newtype Depth = Depth Int
  deriving (Ord, Eq, Show)
-- | A value in the AD interpreter: either a plain primitive constant,
-- or an AD variable tagged with the depth it belongs to.
data ADValue
  = Variable Depth ADVariable
  | Constant PrimValue
  deriving (Show)
-- | An AD variable.  It is either tape-based ('VJP', which records
-- operations on a 'Tape') or a 'JVP' value carrying its primal alongside
-- a second component.
data ADVariable
  = VJP VJPValue
  | JVP JVPValue
  deriving (Show)
-- | The depth of a value.  Variables carry their own depth, and
-- constants are at depth 0.
depth :: ADValue -> Depth
depth (Variable d _) = d
depth (Constant _) = Depth 0
-- | The primal of a value.  A 'VJP' variable yields the value stored on
-- its tape (without recursing further).  A 'JVP' variable recurses into
-- its primal component.  Constants are returned as-is.
primal :: ADValue -> ADValue
primal (Variable _ (VJP (VJPValue t))) = tapePrimal t
primal (Variable _ (JVP (JVPValue v _))) = primal v
primal (Constant v) = Constant v
-- | Peel off the AD layer belonging to depth @cur@.
--
-- A variable whose depth differs from @cur@ is returned unchanged.  For
-- a variable at @cur@, a 'VJP' variable yields its tape's primal, and a
-- 'JVP' variable recurses into its primal component.  Constants are
-- returned as-is.
primalFor :: Depth -> ADValue -> ADValue
primalFor cur v@(Variable tag _) | cur /= tag = v
primalFor _ (Variable _ (VJP (VJPValue t))) = tapePrimal t
primalFor cur (Variable _ (JVP (JVPValue v _))) = primalFor cur v
primalFor _ (Constant v) = Constant v
-- | The underlying primitive value, with all AD structure removed.
primitive :: ADValue -> PrimValue
primitive (Variable _ v) = varPrimal v
primitive (Constant v) = v
-- | The underlying primitive value of an AD variable.
varPrimal :: ADVariable -> PrimValue
varPrimal (VJP (VJPValue t)) = primitive $ tapePrimal t
varPrimal (JVP (JVPValue v _)) = primitive $ primal v
-- | Evaluate a 'PrimExp' over AD values.
--
-- Leaves are looked up in @m@, and an unknown name is reported with
-- 'throwE'.  Every operator node is applied through 'doOp'', so the
-- result takes part in differentiation.  Operands are evaluated left to
-- right.
evalPrimExp :: M.Map VName ADValue -> PrimExp VName -> ADMonad ADValue
evalPrimExp m (LeafExp n _) =
  maybe (throwE $ "Unknown variable " <> show n) pure $ M.lookup n m
evalPrimExp _ (ValueExp pv) =
  pure $ Constant pv
evalPrimExp m (BinOpExp op x y) = do
  x' <- evalPrimExp m x
  y' <- evalPrimExp m y
  doOp' (OpBin op) [x', y']
evalPrimExp m (CmpOpExp op x y) = do
  x' <- evalPrimExp m x
  y' <- evalPrimExp m y
  doOp' (OpCmp op) [x', y']
evalPrimExp m (UnOpExp op x) = do
  x' <- evalPrimExp m x
  doOp' (OpUn op) [x']
evalPrimExp m (ConvOpExp op x) = do
  x' <- evalPrimExp m x
  doOp' (OpConv op) [x']
evalPrimExp m (FunExp fn p _) = do
  p' <- mapM (evalPrimExp m) p
  doOp' (OpFn fn) p'
-- | Symbolic partial derivatives of an operator with respect to each
-- of its arguments, given as expressions over those arguments.
--
-- Binary operators (two arguments), unary operators (one argument) and
-- builtins known to 'pdBuiltin' are supported.  Anything else,
-- including comparisons and conversions, yields 'Nothing'.
lookupPDs :: Op -> [PrimExp VName] -> Maybe [PrimExp VName]
lookupPDs (OpBin op) [x, y] =
  let (a, b) = pdBinOp op x y
   in Just [a, b]
lookupPDs (OpUn op) [x] = Just [pdUnOp op x]
lookupPDs (OpFn fn) p = pdBuiltin (nameFromText fn) p
lookupPDs _ _ = Nothing
-- | Apply an operator to AD values, passing the counter explicitly.
--
-- On success, returns the result together with the updated counter.
-- On failure, returns the error message; the counter state reached so
-- far is discarded.
doOp :: Op -> [ADValue] -> Counter -> Either String (ADValue, Counter)
doOp op o uid =
  case runState (runExceptT $ doOp' op o) uid of
    (Left s, _) -> Left s
    (Right v, uid') -> Right (v, uid')
-- | Apply an operator to a list of AD values inside 'ADMonad'.
--
-- The operands' primitive types are checked first with 'opTypeMatch',
-- and a mismatch is reported with 'throwE'.
--
-- The working depth is the greatest operand depth.  Comparisons are an
-- exception: they always work at depth 0.  The two depth cases are:
--
-- * At depth 0 the operator is evaluated directly on the primitive
--   values, and the counter is advanced afterwards.
-- * Otherwise the layer at the working depth is peeled from every
--   operand, and the operation is re-run on those primals.  The result
--   is then wrapped as a 'VJP' or 'JVP' variable at the working depth.
doOp' :: Op -> [ADValue] -> ADMonad ADValue
doOp' op o
  | not $ opTypeMatch op (map primValueType pv) =
      throwE $ unwords ["invalid types for op", show op, "and operands", show o]
  | otherwise = do
      let dep :: Depth
          dep = case op of
            OpCmp _ -> Depth 0
            _ -> case map depth o of
              -- 'maximum' is partial.  With no operands there is nothing
              -- to differentiate, so evaluate as a constant.
              [] -> Depth 0
              ds -> maximum ds
      if dep == Depth 0
        then maybe (throwE "failed to evaluate const") pure constCase <* incCounter
        else nonconstCase dep
  where
    pv :: [PrimValue]
    pv = map primitive o

    -- Operands below depth @d@ act as constants at this level ('Left').
    -- Operands at depth @d@ expose their AD variable ('Right').
    divideDepths :: Depth -> ADValue -> Either ADValue ADVariable
    divideDepths _ v@(Constant {}) = Left v
    divideDepths d v@(Variable d' v')
      | d' < d = Left v
      | otherwise = Right v'

    extractVJP :: Either ADValue ADVariable -> Either ADValue VJPValue
    extractVJP (Right (VJP v)) = Right v
    extractVJP (Left v) = Left v
    extractVJP _ = error "extractVJP"

    extractJVP :: Either ADValue ADVariable -> Either ADValue JVPValue
    extractJVP (Right (JVP v)) = Right v
    extractJVP (Left v) = Left v
    extractJVP _ = error "extractJVP"

    -- Direct evaluation on primitive values.  'Nothing' if the
    -- underlying primitive operation fails.
    constCase :: Maybe ADValue
    constCase =
      Constant <$> case (op, pv) of
        (OpBin op', [x, y]) -> doBinOp op' x y
        (OpCmp op', [x, y]) -> BoolValue <$> doCmpOp op' x y
        (OpUn op', [x]) -> doUnOp op' x
        (OpConv op', [x]) -> doConvOp op' x
        (OpFn fn, _) -> do
          (_, _, f) <- M.lookup fn primFuns
          f pv
        _ -> error "doOp': opTypeMatch"

    nonconstCase :: Depth -> ADMonad ADValue
    nonconstCase d = do
      -- Redo the operation on the operands with the layer at @d@ removed.
      vprev <- doOp' op $ map (primalFor d) o
      let o' = map (divideDepths d) o
      -- The first operand at depth @d@ decides the AD mode.
      case find isRight o' of
        Just (Right (VJP {})) ->
          Variable d . VJP . VJPValue
            <$> vjpHandleOp op (map extractVJP o') vprev
        Just (Right (JVP {})) ->
          Variable d . JVP . JVPValue vprev
            <$> jvpHandleOp op (map extractJVP o')
        _ -> error "find isRight"
-- | Evaluate the partial derivatives of an operator with respect to
-- each argument, at the given argument values.
--
-- The arguments are bound to fresh names @x1@, @x2@, ..., and the
-- symbolic derivatives from 'lookupPDs' are then evaluated with
-- 'evalPrimExp'.  Both an unsupported operator and a failed evaluation
-- crash with 'error' rather than being reported in 'ADMonad'.
calculatePDs :: Op -> [ADValue] -> ADMonad [ADValue]
calculatePDs op args = mapM evalPD pde
  where
    names = [VName (nameFromString $ "x" ++ show i) i | i <- [1 .. length args]]

    env = M.fromList $ zip names args

    pde =
      fromMaybe (error "lookupPDs failed") . lookupPDs op $
        zipWith (\v val -> LeafExp v $ primValueType $ primitive val) names args

    evalPD x = catchE (evalPrimExp env x) $ error . ("evalPrimExp failed: " <>)
-- | A tape-based ('VJP') AD variable, represented by its 'Tape'.
newtype VJPValue = VJPValue Tape
  deriving (Show)
-- | The recorded history of a 'VJP' variable.  Every node carries its
-- primal value (see 'tapePrimal').
data Tape
  = -- | A leaf identified by a counter value, with its primal.
    TapeID Counter ADValue
  | -- | A leaf for an operand that was constant at this depth.
    TapeConst ADValue
  | -- | An operator applied to the given sub-tapes.  The node is stamped
    -- with a counter value and holds the primal result.
    TapeOp Op [Tape] Counter ADValue
  deriving (Show)
-- | The primal value stored in a tape node.
tapePrimal :: Tape -> ADValue
tapePrimal (TapeID _ v) = v
tapePrimal (TapeConst v) = v
tapePrimal (TapeOp _ _ _ v) = v
-- | Record an operator application as a new tape node.
--
-- Operands that are constants at this depth ('Left') become 'TapeConst'
-- leaves, and 'VJP' operands contribute their own tapes.  The node is
-- stamped with the current counter value (read, not advanced) and stores
-- the primal result @v@.
vjpHandleOp :: Op -> [Either ADValue VJPValue] -> ADValue -> ADMonad Tape
vjpHandleOp op p v = do
  i <- lift get
  pure $ TapeOp op (map toTape p) i v
  where
    toTape :: Either ADValue VJPValue -> Tape
    toTape (Left v') = TapeConst v'
    toTape (Right (VJPValue t)) = t
-- | Monadic 'M.unionWith': keys present in only one map keep their
-- value.  For keys present in both, the values are combined with
-- @f (value from m1) (value from m2)@.  The combining actions run in
-- ascending key order.
unionWithM :: (Monad m, Ord k) => (a -> a -> m a) -> M.Map k a -> M.Map k a -> m (M.Map k a)
unionWithM f m1 m2 = do
  -- Pairing via intersectionWith avoids partial lookups, and the
  -- Traversable instance of Map visits keys in ascending order.
  merged <- traverse (uncurry f) (M.intersectionWith (,) m1 m2)
  -- The three key sets are disjoint, so the union order is irrelevant.
  pure $ M.unions [merged, M.difference m1 m2, M.difference m2 m1]
-- | Monadic variant of 'M.unionsWith'. It starts from the empty map and
-- merges the maps of the container in one at a time, with a left fold over
-- 'unionWithM'.
unionsWithM :: (Foldable f, Monad m, Ord k) => (a -> a -> m a) -> f (M.Map k a) -> m (M.Map k a)
unionsWithM combine maps = foldlM step M.empty maps
  where
    step acc next = unionWithM combine acc next
-- | Pure entry point for reverse-mode differentiation of a tape.
--
-- Runs 'deriveTape'' with seed adjoint @s@ and the counter initialised to
-- @uid@. On success it returns the adjoint map together with the final
-- counter. Otherwise it returns the error message raised during the run.
deriveTape :: Tape -> ADValue -> Counter -> Either String (M.Map Counter ADValue, Counter)
deriveTape tp s uid =
  let (outcome, uid') = runState (runExceptT (deriveTape' tp s)) uid
   in (,uid') <$> outcome
-- | Reverse-mode propagation over a tape.
--
-- The seed adjoint is attached to the root of the tape and pushed towards
-- its leaves. The result maps every reached 'TapeID' leaf to its
-- accumulated adjoint. Inner 'TapeOp' nodes are also recorded in the
-- result, under the key @-uid - 1@. This keeps them apart from leaf IDs
-- provided counters are non-negative (NOTE(review): assumed, not
-- established here).
--
-- An operation node passes its adjoint on to its arguments only after
-- every reference to it has been processed. 'countReferences' supplies
-- how many references to wait for.
deriveTape' :: Tape -> ADValue -> ADMonad (M.Map Counter ADValue)
deriveTape' (TapeID leaf _) seed = pure $ M.singleton leaf seed
deriveTape' (TapeConst _) _ = pure M.empty
deriveTape' root@(TapeOp rootOp rootArgs rootUid _) seed =
  fst <$> derive root seed M.empty initialRefs
  where
    -- The root starts with a single reference (the seed itself).
    initialRefs :: M.Map Counter Int
    initialRefs = countReferences rootArgs $ M.singleton (nodeKey rootUid) 1

    -- Key under which an operation node's adjoint and reference count live.
    nodeKey :: Counter -> Counter
    nodeKey n = negate n - 1

    -- NOTE(review): add/mul always use the ROOT operation's result type,
    -- even for inner nodes. If an inner node has a different primitive
    -- type (e.g. below an 'OpConv'), this looks like it selects the wrong
    -- BinOp — confirm.
    binop :: (PrimType -> BinOp) -> ADValue -> ADValue -> ADMonad ADValue
    binop pick a b = doOp' (OpBin $ pick $ opReturnType rootOp) [a, b]

    add :: ADValue -> ADValue -> ADMonad ADValue
    add = binop addFor

    mul :: ADValue -> ADValue -> ADMonad ADValue
    mul = binop mulFor

    -- Add @adj@ to whatever is already stored under @key@, or store it
    -- as-is if the key is new.
    madd :: Counter -> ADValue -> M.Map Counter ADValue -> ADMonad (M.Map Counter ADValue)
    madd key adj acc =
      maybe (pure adj) (add adj) (M.lookup key acc) <&> \v -> M.insert key v acc

    -- Propagate adjoint @adj@ into a tape node, threading the accumulated
    -- adjoints and the outstanding reference counts.
    derive ::
      Tape ->
      ADValue ->
      M.Map Counter ADValue ->
      M.Map Counter Int ->
      ADMonad (M.Map Counter ADValue, M.Map Counter Int)
    derive (TapeID leaf _) adj adjs refs = (,refs) <$> madd leaf adj adjs
    derive (TapeConst _) _ adjs refs = pure (adjs, refs)
    derive (TapeOp nodeOp args nodeUid _) adj adjs refs = do
      let key = nodeKey nodeUid
          remaining = fromJust (M.lookup key refs) - 1
          refs' = M.insert key remaining refs
      adjs' <- madd key adj adjs
      case compare remaining 0 of
        -- Other references to this node are still pending; wait for them.
        GT -> pure (adjs', refs')
        -- Last reference seen: the node's adjoint is complete, so push it
        -- on to the arguments.
        EQ -> do
          let total = fromJust (M.lookup key adjs')
          argAdjs <- case nodeOp of
            -- A conversion is undone by the reverse conversion.
            OpConv conv -> (: []) <$> doOp' (OpConv $ flipConvOp conv) [total]
            _ -> traverse (mul total) =<< calculatePDs nodeOp (map tapePrimal args)
          foldlM
            (\(accAdjs, accRefs) (arg, argAdj) -> derive arg argAdj accAdjs accRefs)
            (adjs', refs')
            (zip args argAdjs)
        LT -> error "TODO: This branch is unreachable unless `countReferences` undercounts"

    -- Count, for every operation node reachable from the given tapes, how
    -- many parent edges point at it. Each distinct node's own arguments are
    -- walked only the first time that node is met.
    countReferences :: [Tape] -> M.Map Counter Int -> M.Map Counter Int
    countReferences args counts = foldl visit counts args
      where
        visit acc (TapeOp _ subArgs subUid _) =
          let key = nodeKey subUid
           in case M.lookup key acc of
                Just n -> M.insert key (n + 1) acc
                Nothing -> countReferences subArgs (M.insert key 1 acc)
        visit acc _ = acc
-- | A value paired with its tangent, used in forward-mode (JVP)
-- differentiation.
data JVPValue
  = JVPValue
      ADValue
      -- ^ Primal value.
      ADValue
      -- ^ Tangent (directional derivative).
  deriving (Show)
-- | Forward-mode (JVP) step for a single primitive operation.
--
-- Each argument is either a plain value ('Left'), which has a zero
-- tangent, or a primal/tangent pair ('Right'). The result is the tangent of
-- the operation's output:
--
-- * For a conversion, the same conversion is applied to the tangent of the
--   first argument. An empty argument list is reported as an error instead
--   of crashing.
-- * For any other operation, the partial derivatives are computed at the
--   primal arguments by 'calculatePDs'. Each is multiplied by the matching
--   argument tangent, and the products are summed starting from zero.
--
-- Arithmetic and the zero tangent use the operation's result type.
jvpHandleOp :: Op -> [Either ADValue JVPValue] -> ADMonad ADValue
jvpHandleOp op p =
  case op of
    OpConv _ -> case p of
      arg : _ -> doOp' op [tangent arg]
      [] -> throwE "jvpHandleOp: conversion applied to an empty argument list"
    _ -> do
      pds <- calculatePDs op $ map primal' p
      vs <- zipWithM mul pds $ map tangent p
      foldM add zero vs
  where
    op_t :: PrimType
    op_t = opReturnType op

    -- Additive identity at the operation's result type.
    zero :: ADValue
    zero = Constant $ blankPrimValue op_t

    primal' :: Either ADValue JVPValue -> ADValue
    primal' (Left v) = v
    primal' (Right (JVPValue v _)) = v

    -- Plain values contribute nothing to the derivative.
    tangent :: Either a JVPValue -> ADValue
    tangent (Left _) = zero
    tangent (Right (JVPValue _ d)) = d

    add :: ADValue -> ADValue -> ADMonad ADValue
    add x y = doOp' (OpBin $ addFor op_t) [x, y]

    mul :: ADValue -> ADValue -> ADMonad ADValue
    mul x y = doOp' (OpBin $ mulFor op_t) [x, y]