module Language.Futhark.Interpreter.AD
  ( Op (..),
    ADVariable (..),
    ADValue (..),
    Tape (..),
    VJPValue (..),
    JVPValue (..),
    Counter (..),
    Depth (..),
    doOp,
    addFor,
    tapePrimal,
    primitive,
    varPrimal,
    deriveTape,
    unionWithM,
    unionsWithM,
  )
where

import Control.Monad (foldM, zipWithM)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Except (ExceptT, catchE, runExceptT, throwE)
import Control.Monad.Trans.State (State, get, modify, runState)
import Data.Either (isRight)
import Data.Foldable (find, foldlM)
import Data.Functor ((<&>))
import Data.Map qualified as M
import Data.Maybe (fromJust, fromMaybe)
import Data.Text qualified as T
import Futhark.AD.Derivatives (pdBinOp, pdBuiltin, pdUnOp)
import Futhark.Analysis.PrimExp (PrimExp (..))
import Language.Futhark.Core (VName (..), nameFromString, nameFromText)
import Language.Futhark.Primitive
  ( BinOp (Add, FAdd, FMul, LogAnd, LogOr, Mul),
    CmpOp,
    ConvOp,
    Overflow (OverflowWrap),
    PrimType (Bool, FloatType, IntType),
    PrimValue (BoolValue),
    UnOp,
    binOpType,
    blankPrimValue,
    cmpOpType,
    convOpType,
    doBinOp,
    doCmpOp,
    doConvOp,
    doUnOp,
    flipConvOp,
    primFuns,
    primValueType,
    unOpType,
  )

-- | A counter from which unique identifiers for values are drawn.
-- 'Num' is derived so that the next identifier can be obtained by
-- plain addition.
newtype Counter = Counter Int
  deriving (Show, Eq, Ord, Num)

-- The monad in which AD operations run: failures are reported as
-- 'String' errors through 'ExceptT', and a 'Counter' is threaded
-- through as state.
type ADMonad = ExceptT String (State Counter)

-- | Advance the counter held in the 'ADMonad' state by one.
incCounter :: ADMonad ()
incCounter = lift (modify (+ 1))

-- | The mathematical operations that the AD machinery can evaluate
-- (and, where a derivative rule exists, differentiate).
data Op
  = -- | A binary operator.
    OpBin BinOp
  | -- | A comparison operator.
    OpCmp CmpOp
  | -- | A unary operator.
    OpUn UnOp
  | -- | A named primitive function, looked up in 'primFuns'.
    OpFn T.Text
  | -- | A type conversion.
    OpConv ConvOp
  deriving (Show)

-- | Checks whether an operation can be applied to operands of the
-- given types.
--
-- Both the number of operands and their types must agree with the
-- operation:
--
--   * binary and comparison operators take two operands of the
--     operator's type;
--   * unary operators take a single operand of the operator's type;
--   * conversions take a single operand of their source type;
--   * primitive functions take exactly the parameter types listed in
--     'primFuns'.
--
-- Checking the arity matters because 'doOp'' uses this test to reject
-- malformed calls with a proper error. Without it, a call with the
-- wrong number of operands slips through and later crashes, either in
-- the @error@ of its constant case or in @maximum@ on an empty list.
opTypeMatch :: Op -> [PrimType] -> Bool
opTypeMatch op p = case op of
  OpBin o -> homogeneous 2 (binOpType o)
  OpCmp o -> homogeneous 2 (cmpOpType o)
  OpUn o -> homogeneous 1 (unOpType o)
  OpConv o -> homogeneous 1 (fst (convOpType o))
  OpFn fn -> case M.lookup fn primFuns of
    -- Comparing the whole lists also rejects a wrong number of
    -- arguments, which zipWith would silently truncate.
    Just (ts, _, _) -> ts == p
    Nothing -> error "opTypeMatch" -- It is assumed that the function exists
  where
    -- Exactly @n@ operands, all of type @t@.
    homogeneous :: Int -> PrimType -> Bool
    homogeneous n t = p == replicate n t

-- | The type of the value produced by an operation.
opReturnType :: Op -> PrimType
opReturnType (OpBin op) = binOpType op
-- A comparison always yields a boolean: 'cmpOpType' describes the type
-- of the operands being compared, not the type of the result.
opReturnType (OpCmp _) = Bool
opReturnType (OpUn op) = unOpType op
-- Conversions produce a value of their target type.
opReturnType (OpConv op) = snd $ convOpType op
opReturnType (OpFn fn) = case M.lookup fn primFuns of
  Just (_, t, _) -> t
  Nothing -> error "opReturnType" -- It is assumed that the function exists

-- | The binary operator that plays the role of addition on the given
-- type:
--
--   * wrapping addition for integers;
--   * ordinary addition for floats;
--   * logical disjunction for booleans.
--
-- Any other type is an error.
addFor :: PrimType -> BinOp
addFor t = case t of
  IntType it -> Add it OverflowWrap
  FloatType ft -> FAdd ft
  Bool -> LogOr
  _ -> error ("addFor: " ++ show t)

-- | The binary operator that plays the role of multiplication on the
-- given type:
--
--   * wrapping multiplication for integers;
--   * ordinary multiplication for floats;
--   * logical conjunction for booleans.
--
-- Any other type is an error.
mulFor :: PrimType -> BinOp
mulFor t = case t of
  IntType it -> Mul it OverflowWrap
  FloatType ft -> FMul ft
  Bool -> LogAnd
  _ -> error ("mulFor: " ++ show t)

-- | The nesting depth of automatic differentiation. AD variables are
-- tagged with a depth so that nested differentiations keep their
-- perturbations apart, avoiding perturbation confusion.
newtype Depth = Depth Int
  deriving (Show, Eq, Ord)

-- Types and utility functions --

-- | A value that may take part in automatic differentiation. When
-- taking a partial derivative, we must distinguish the values that are
-- being varied from those that are kept constant.
data ADValue
  = -- | A value tracked by AD, tagged with its AD nesting depth.
    Variable Depth ADVariable
  | -- | A plain primitive value, treated as a constant by AD.
    Constant PrimValue
  deriving (Show)

-- | The extra data carried by a differentiated variable. It records
-- which mode of AD tracks the variable, together with that mode's
-- bookkeeping. The variable's primal value can be recovered from this
-- bookkeeping (see 'varPrimal').
data ADVariable
  = -- | Reverse mode (vector-Jacobian product); the data is a tape.
    VJP VJPValue
  | -- | Forward mode (Jacobian-vector product).
    JVP JVPValue
  deriving (Show)

-- | The AD nesting depth of a value; constants are at depth 0.
depth :: ADValue -> Depth
depth v = case v of
  Variable d _ -> d
  Constant _ -> Depth 0

-- | Recover the primal (undifferentiated) value of an 'ADValue'.
--
--   * Forward-mode variables are unwrapped recursively.
--   * Reverse-mode variables yield the primal recorded on their tape.
--   * Constants are returned unchanged.
primal :: ADValue -> ADValue
primal v = case v of
  Constant c -> Constant c
  Variable _ (VJP (VJPValue t)) -> tapePrimal t
  Variable _ (JVP (JVPValue inner _)) -> primal inner

-- | Recover the primal of a value with respect to the AD layer at the
-- given depth. Variables tagged with any other depth belong to a
-- different layer and are returned untouched.
primalFor :: Depth -> ADValue -> ADValue
primalFor cur v = case v of
  Variable tag var
    | tag /= cur -> v
    | otherwise -> case var of
        VJP (VJPValue t) -> tapePrimal t
        JVP (JVPValue inner _) -> primalFor cur inner
  Constant c -> Constant c

-- | The underlying 'PrimValue' of an 'ADValue'. For variables this
-- defers to 'varPrimal', which peels off AD layers until a constant is
-- reached.
primitive :: ADValue -> PrimValue
primitive v = case v of
  Constant c -> c
  Variable _ var -> varPrimal var

-- | The underlying 'PrimValue' of an AD variable. The primal is first
-- taken from the variable's bookkeeping (the tape for reverse mode, or
-- the wrapped value for forward mode), then reduced with 'primitive'.
varPrimal :: ADVariable -> PrimValue
varPrimal var = primitive $ case var of
  VJP (VJPValue t) -> tapePrimal t
  JVP (JVPValue v _) -> primal v

-- | Evaluate a 'PrimExp' whose leaves are looked up in the given
-- environment. Every operation is performed with 'doOp'', so AD
-- information is propagated through the expression. Operands are
-- evaluated left to right, and evaluation stops at the first error.
evalPrimExp :: M.Map VName ADValue -> PrimExp VName -> ADMonad ADValue
evalPrimExp env = eval
  where
    eval (LeafExp n _) =
      maybe (throwE $ "Unknown variable " <> show n) pure $ M.lookup n env
    eval (ValueExp pv) = pure $ Constant pv
    eval (BinOpExp op x y) = apply (OpBin op) [x, y]
    eval (CmpOpExp op x y) = apply (OpCmp op) [x, y]
    eval (UnOpExp op x) = apply (OpUn op) [x]
    eval (ConvOpExp op x) = apply (OpConv op) [x]
    eval (FunExp fn args _) = apply (OpFn fn) args

    -- Evaluate the operands in order, then apply the operation to them.
    apply op args = mapM eval args >>= doOp' op

-- | The partial derivatives of an operation with respect to each of its
-- operands, expressed as 'PrimExp's over the operand expressions.
-- Yields 'Nothing' in two cases:
--
--   * no derivative rule is available for the operation (for example
--     comparisons and conversions);
--   * a binary or unary operator is given the wrong number of operands.
lookupPDs :: Op -> [PrimExp VName] -> Maybe [PrimExp VName]
lookupPDs op args = case (op, args) of
  (OpBin o, [x, y]) ->
    -- Lazy binding: the derivatives are only computed when demanded.
    let (dx, dy) = pdBinOp o x y
     in Just [dx, dy]
  (OpUn o, [x]) -> Just [pdUnOp o x]
  (OpFn fn, _) -> pdBuiltin (nameFromText fn) args
  _ -> Nothing

-- Shared AD logic --

-- | Perform a mathematical operation on a list of operands, starting
-- from the given counter. Automatic differentiation is performed if
-- one or more operands is a 'Variable' (of depth > 0).
--
-- On success, the result is returned together with the updated
-- counter. On failure, only the error message is returned.
doOp :: Op -> [ADValue] -> Counter -> Either String (ADValue, Counter)
doOp op o uid =
  let (result, uid') = runState (runExceptT (doOp' op o)) uid
   in fmap (\v -> (v, uid')) result

doOp' :: Op -> [ADValue] -> ADMonad ADValue
doOp' :: Op -> [ADValue] -> ADMonad ADValue
doOp' Op
op [ADValue]
o
  | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ Op -> [PrimType] -> Bool
opTypeMatch Op
op ((PrimValue -> PrimType) -> [PrimValue] -> [PrimType]
forall a b. (a -> b) -> [a] -> [b]
map PrimValue -> PrimType
primValueType [PrimValue]
pv) =
      -- This function may be called with arguments of invalid types,
      -- because it is used as part of an overloaded operator.
      String -> ADMonad ADValue
forall (m :: * -> *) e a. Monad m => e -> ExceptT e m a
throwE (String -> ADMonad ADValue) -> String -> ADMonad ADValue
forall a b. (a -> b) -> a -> b
$ [String] -> String
unwords [String
"invalid types for op", Op -> String
forall a. Show a => a -> String
show Op
op, String
"and operands", [ADValue] -> String
forall a. Show a => a -> String
show [ADValue]
o]
  | Bool
otherwise = do
      let dep :: Depth
dep = case Op
op of
            OpCmp CmpOp
_ -> Int -> Depth
Depth Int
0 -- AD is not well-defined for comparason operations
            -- There are no derivatives for those written in
            -- PrimExp (check lookupPDs)
            Op
_ -> [Depth] -> Depth
forall a. Ord a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Ord a) => t a -> a
maximum ((ADValue -> Depth) -> [ADValue] -> [Depth]
forall a b. (a -> b) -> [a] -> [b]
map ADValue -> Depth
depth [ADValue]
o)
      if Depth
dep Depth -> Depth -> Bool
forall a. Eq a => a -> a -> Bool
== Int -> Depth
Depth Int
0
        then ADMonad ADValue
-> (ADValue -> ADMonad ADValue) -> Maybe ADValue -> ADMonad ADValue
forall b a. b -> (a -> b) -> Maybe a -> b
maybe (String -> ADMonad ADValue
forall (m :: * -> *) e a. Monad m => e -> ExceptT e m a
throwE String
"failed to evaluate const") ADValue -> ADMonad ADValue
forall a. a -> ExceptT String (State Counter) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe ADValue
constCase ADMonad ADValue -> ADMonad () -> ADMonad ADValue
forall a b.
ExceptT String (State Counter) a
-> ExceptT String (State Counter) b
-> ExceptT String (State Counter) a
forall (f :: * -> *) a b. Applicative f => f a -> f b -> f a
<* ADMonad ()
incCounter
        else Depth -> ADMonad ADValue
nonconstCase Depth
dep
  where
    pv :: [PrimValue]
pv = (ADValue -> PrimValue) -> [ADValue] -> [PrimValue]
forall a b. (a -> b) -> [a] -> [b]
map ADValue -> PrimValue
primitive [ADValue]
o

    divideDepths :: Depth -> ADValue -> Either ADValue ADVariable
    divideDepths :: Depth -> ADValue -> Either ADValue ADVariable
divideDepths Depth
_ v :: ADValue
v@(Constant {}) = ADValue -> Either ADValue ADVariable
forall a b. a -> Either a b
Left ADValue
v
    divideDepths Depth
d v :: ADValue
v@(Variable Depth
d' ADVariable
v') = if Depth
d' Depth -> Depth -> Bool
forall a. Ord a => a -> a -> Bool
< Depth
d then ADValue -> Either ADValue ADVariable
forall a b. a -> Either a b
Left ADValue
v else ADVariable -> Either ADValue ADVariable
forall a b. b -> Either a b
Right ADVariable
v'

    -- TODO: There may be a more graceful way of
    -- doing this
    extractVJP :: Either ADValue ADVariable -> Either ADValue VJPValue
    extractVJP :: Either ADValue ADVariable -> Either ADValue VJPValue
extractVJP (Right (VJP VJPValue
v)) = VJPValue -> Either ADValue VJPValue
forall a b. b -> Either a b
Right VJPValue
v
    extractVJP (Left ADValue
v) = ADValue -> Either ADValue VJPValue
forall a b. a -> Either a b
Left ADValue
v
    extractVJP Either ADValue ADVariable
_ =
      -- This will never be called when the maximum depth layer is JVP
      String -> Either ADValue VJPValue
forall a. HasCallStack => String -> a
error String
"extractVJP"

    -- TODO: There may be a more graceful way of
    -- doing this
    extractJVP :: Either ADValue ADVariable -> Either ADValue JVPValue
    extractJVP :: Either ADValue ADVariable -> Either ADValue JVPValue
extractJVP (Right (JVP JVPValue
v)) = JVPValue -> Either ADValue JVPValue
forall a b. b -> Either a b
Right JVPValue
v
    extractJVP (Left ADValue
v) = ADValue -> Either ADValue JVPValue
forall a b. a -> Either a b
Left ADValue
v
    extractJVP Either ADValue ADVariable
_ =
      -- This will never be called when the maximum depth layer is VJP
      String -> Either ADValue JVPValue
forall a. HasCallStack => String -> a
error String
"extractJVP"

    -- Every operand is a constant, so the operation is evaluated directly
    -- on the primitive values, exactly as it would be without AD. The
    -- result is 'Nothing' when the primitive evaluator (or the builtin
    -- function lookup) produces no value.
    constCase :: Maybe ADValue
    constCase = Constant <$> evalConst
      where
        evalConst :: Maybe PrimValue
        evalConst = case (op, pv) of
          (OpBin binop, [x, y]) -> doBinOp binop x y
          (OpCmp cmpop, [x, y]) -> BoolValue <$> doCmpOp cmpop x y
          (OpUn unop, [x]) -> doUnOp unop x
          (OpConv convop, [x]) -> doConvOp convop x
          (OpFn fn, _) -> M.lookup fn primFuns >>= \(_, _, impl) -> impl pv
          _ -> error "doOp': opTypeMatch"

    -- Some operands are variables, so AD bookkeeping is required on top
    -- of the plain computation. @dep@ is the maximum depth among the
    -- operands.
    nonconstCase :: Depth -> ADMonad ADValue
    nonconstCase dep = do
      -- The primal of the result at this depth is the operation applied
      -- to the operands' values one depth down.
      vprev <- doOp' op (map (primalFor dep) o)

      -- Split the operands: 'Left' for values below depth @dep@ (treated
      -- as constants), 'Right' for variables living at depth @dep@.
      let divided = map (divideDepths dep) o

      -- The first variable at this depth decides which AD mode applies.
      case find isRight divided of
        Just (Right (VJP {})) -> do
          tape <- vjpHandleOp op (map extractVJP divided) vprev
          pure . Variable dep . VJP $ VJPValue tape
        Just (Right (JVP {})) -> do
          tangentVal <- jvpHandleOp op (map extractJVP divided)
          pure . Variable dep . JVP $ JVPValue vprev tangentVal
        _ ->
          -- The maximum depth is non-zero, so at least one operand must
          -- be a variable at depth @dep@; anything else is a bug.
          error "find isRight"

-- | Evaluates the partial derivatives of an operation with respect to
-- each of its operands, at the point given by @args@. The result has
-- one entry per partial derivative returned by 'lookupPDs', in the
-- order it returns them.
--
-- Failures (no derivative known for the operation, or evaluation of a
-- derivative failing) are reported through the 'ADMonad' error channel
-- rather than by crashing the interpreter with 'error'.
calculatePDs :: Op -> [ADValue] -> ADMonad [ADValue]
calculatePDs op args = do
  -- Look up the symbolic partial derivatives, expressed over the
  -- operand names bound below.
  pdes <- maybe (throwE "lookupPDs failed") pure $ lookupPDs op leaves
  -- Evaluate each of them at the actual operand values.
  mapM evalPD pdes
  where
    -- A unique name for each operand: x1, x2, ...
    names :: [VName]
    names = [VName (nameFromString $ "x" ++ show i) i | i <- [1 .. length args]]

    -- Environment binding every operand name to its value.
    env :: M.Map VName ADValue
    env = M.fromList $ zip names args

    -- Each operand as a typed leaf expression.
    leaves :: [PrimExp VName]
    leaves =
      zipWith (\name val -> LeafExp name $ primValueType $ primitive val) names args

    -- Evaluate one derivative, prefixing any error with context.
    evalPD :: PrimExp VName -> ADMonad ADValue
    evalPD e =
      evalPrimExp env e `catchE` (\err -> throwE ("evalPrimExp failed: " <> err))

-- VJP / Reverse mode automatic differentiation --

-- | A value taking part in reverse-mode AD. Reverse mode needs the
-- entire computation that led up to the value, so that it can later be
-- walked backwards; that history is recorded as a 'Tape'.
newtype VJPValue = VJPValue Tape
  deriving (Show)

-- | A computation tree, together with every intermediate value that was
-- produced while evaluating it.
data Tape
  = TapeID Counter ADValue
  -- ^ A variable: its identifying ID and its initial value.
  | TapeConst ADValue
  -- ^ A constant.
  | TapeOp Op [Tape] Counter ADValue
  -- ^ An application of an operation: the operation, the tapes of its
  -- operands, the ID of this node, and the value the operation returned.
  deriving (Show)

-- | The primal value stored at the root of a 'Tape': the initial value
-- of a variable, the constant itself, or the saved result of an
-- operation.
tapePrimal :: Tape -> ADValue
tapePrimal tape = case tape of
  TapeID _ v -> v
  TapeConst v -> v
  TapeOp _ _ _ v -> v

-- | Records the application of an operation on a VJP tape. Operands of
-- a lower depth are treated as constants ('TapeConst'); VJP operands
-- contribute their existing tapes. The new node is tagged with a fresh
-- ID taken from the counter, and @v@ is stored as its result.
vjpHandleOp :: Op -> [Either ADValue VJPValue] -> ADValue -> ADMonad Tape
vjpHandleOp op p v = do
  i <- lift get
  -- Advance the counter so every operation node gets a distinct ID.
  -- 'deriveTape' keys reference counts and sensitivities on this ID, so
  -- two nodes sharing one would be conflated during differentiation.
  lift $ modify (\(Counter n) -> Counter (n + 1))
  pure $ TapeOp op (map toTape p) i v
  where
    toTape :: Either ADValue VJPValue -> Tape
    toTape (Left v') = TapeConst v'
    toTape (Right (VJPValue t)) = t

-- | Monadic union of two maps. Keys present in only one map keep their
-- value. For keys present in both, the combining action @f@ is run on
-- the value from the first map and the value from the second map; these
-- actions run in ascending key order.
unionWithM :: (Monad m, Ord k) => (a -> a -> m a) -> M.Map k a -> M.Map k a -> m (M.Map k a)
unionWithM f m1 m2 = do
  -- Combine the overlapping keys. Traversing a 'M.Map' visits keys in
  -- ascending order, which fixes the order of @f@'s effects.
  combined <- sequence $ M.intersectionWith f m1 m2
  -- 'M.unions' is left-biased: combined values win for shared keys, and
  -- every other key comes from whichever map holds it.
  pure $ M.unions [combined, m1, m2]

-- | Monadic union of any number of maps, combining the values of keys
-- that occur in more than one map with the given action. The maps are
-- merged one at a time from the left, starting from the empty map.
unionsWithM :: (Foldable f, Monad m, Ord k) => (a -> a -> m a) -> f (M.Map k a) -> m (M.Map k a)
unionsWithM combine maps = foldlM (unionWithM combine) M.empty maps

-- | Computes the partial derivatives recorded in a 'Tape', seeding its
-- root with the sensitivity @s@ and starting the ID counter at @uid@.
-- On success, returns the accumulated sensitivities together with the
-- final counter. Free variables appear keyed by their ID (see
-- 'TapeID'); sensitivities of intermediate operation nodes may also be
-- present under their (negative) internal keys.
deriveTape :: Tape -> ADValue -> Counter -> Either String (M.Map Counter ADValue, Counter)
deriveTape tp s uid =
  case runState (runExceptT (deriveTape' tp s)) uid of
    (result, uid') -> (\sens -> (sens, uid')) <$> result

-- | Worker for 'deriveTape'. Seeds the root of the tape with the
-- sensitivity @s@ and propagates sensitivities back towards the leaves,
-- returning every sensitivity that was accumulated. Variables
-- ('TapeID') are keyed by their own ID; operation nodes ('TapeOp') are
-- keyed by @-uid - 1@, which keeps the two key spaces apart (assuming
-- IDs are non-negative).
--
-- Sensitivity arithmetic is performed at the type of the values
-- involved, not at the root operation's type: type conversions inside
-- the tape change the type of the sensitivity flowing past them.
deriveTape' :: Tape -> ADValue -> ADMonad (M.Map Counter ADValue)
deriveTape' (TapeID i _) s = pure $ M.singleton i s
deriveTape' (TapeConst _) _ = pure M.empty
deriveTape' tp@(TapeOp _ p uid _) s =
  fst <$> derive tp s M.empty (countReferences p $ M.singleton (opKey uid) 1)
  where
    -- Key under which an operation node's sensitivity and reference
    -- count are stored.
    opKey :: Counter -> Counter
    opKey u = -u - 1

    -- Primitive type of a value, used to pick the matching operator.
    typeOf :: ADValue -> PrimType
    typeOf = primValueType . primitive

    add :: ADValue -> ADValue -> ADMonad ADValue
    add x y = doOp' (OpBin $ addFor $ typeOf x) [x, y]

    mul :: ADValue -> ADValue -> ADMonad ADValue
    mul x y = doOp' (OpBin $ mulFor $ typeOf x) [x, y]

    -- Add a sensitivity contribution for key @i@, summing it with any
    -- contribution already recorded there.
    madd :: Counter -> ADValue -> M.Map Counter ADValue -> ADMonad (M.Map Counter ADValue)
    madd i a m = case M.lookup i m of
      Just b -> (\x -> M.insert i x m) <$> add a b
      Nothing -> pure $ M.insert i a m

    -- Push sensitivity @s'@ into a node. An operation node only
    -- propagates to its operands once all of its references have
    -- contributed, so each node is differentiated exactly once.
    derive ::
      Tape ->
      ADValue ->
      M.Map Counter ADValue ->
      M.Map Counter Int ->
      ADMonad (M.Map Counter ADValue, M.Map Counter Int)
    derive (TapeID i _) s' ss rs = (,rs) <$> madd i s' ss
    derive (TapeConst _) _ ss rs = pure (ss, rs)
    derive (TapeOp op' p' uid' _) s' ss rs = do
      let key = opKey uid'
          -- One fewer outstanding reference to this node.
          r = rs M.! key - 1
          rs' = M.insert key r rs
      -- Accumulate the incoming sensitivity.
      ss' <- madd key s' ss
      case compare r 0 of
        -- Other references still have to contribute; wait for them.
        GT -> pure (ss', rs')
        -- All contributions are in: propagate to the operands.
        EQ -> do
          let s'' = ss' M.! key
          s''' <- case op' of
            -- A type conversion just converts the sensitivity back.
            OpConv convop -> (: []) <$> doOp' (OpConv $ flipConvOp convop) [s'']
            -- Otherwise apply the chain rule with the partial derivatives.
            _ -> calculatePDs op' (map tapePrimal p') >>= mapM (mul s'')
          foldlM
            (\(ss'', rs'') (child, sc) -> derive child sc ss'' rs'')
            (ss', rs')
            (zip p' s''')
        LT -> error "TODO: This branch is unreachable unless `countReferences` undercounts"

    -- For every operation node reachable from the given operands, count
    -- how many edges of the tape point at it. A node's own operands are
    -- only traversed on its first visit.
    countReferences :: [Tape] -> M.Map Counter Int -> M.Map Counter Int
    countReferences p' d' = foldl visit d' p'

    visit :: M.Map Counter Int -> Tape -> M.Map Counter Int
    visit d'' (TapeOp _ p'' uid'' _) =
      case M.lookup (opKey uid'') d'' of
        Just v -> M.insert (opKey uid'') (v + 1) d''
        Nothing -> countReferences p'' $ M.insert (opKey uid'') 1 d''
    visit d'' _ = d''

-- JVP / Forward mode automatic differentiation --

-- | A value taking part in forward-mode AD: the primal value (first
-- field) paired with its tangent, the derivative carried alongside it
-- (second field).
data JVPValue = JVPValue ADValue ADValue
  deriving (Show)

-- | Computes the tangent of the result of applying an operation to JVP
-- operands. Operands of a lower depth ('Left') are constants and so
-- contribute a zero tangent.
jvpHandleOp :: Op -> [Either ADValue JVPValue] -> ADMonad ADValue
jvpHandleOp op p
  | OpConv _ <- op =
      -- A type conversion carries the incoming tangent across, merely
      -- converting it to the new type.
      doOp' op [tangent (head p)]
  | otherwise = do
      -- Chain rule: sum over operands of (partial derivative * tangent).
      pds <- calculatePDs op (map primal' p)
      terms <- zipWithM mul pds (map tangent p)
      foldM add zero terms
  where
    op_t :: PrimType
    op_t = opReturnType op

    -- The zero of the operation's result type.
    zero :: ADValue
    zero = Constant $ blankPrimValue op_t

    primal' :: Either ADValue JVPValue -> ADValue
    primal' = either id (\(JVPValue v _) -> v)

    tangent :: Either ADValue JVPValue -> ADValue
    tangent = either (const zero) (\(JVPValue _ d) -> d)

    -- Binary arithmetic at the operation's result type.
    arith :: (PrimType -> BinOp) -> ADValue -> ADValue -> ADMonad ADValue
    arith mk x y = doOp' (OpBin $ mk op_t) [x, y]

    add :: ADValue -> ADValue -> ADMonad ADValue
    add = arith addFor

    mul :: ADValue -> ADValue -> ADMonad ADValue
    mul = arith mulFor