{-# LANGUAGE TypeFamilies #-}

-- Naming scheme:
--
-- An adjoint-related object for "x" is named "x_adj".  This means
-- both actual adjoints and statements.
--
-- Do not assume "x'" means anything related to derivatives.
module Futhark.AD.Rev (revVJP) where

import Control.Monad
import Data.List ((\\))
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map qualified as M
import Futhark.AD.Derivatives
import Futhark.AD.Rev.Loop
import Futhark.AD.Rev.Monad
import Futhark.AD.Rev.SOAC
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util (takeLast)

-- | Extract the name bound by a pattern that has exactly one
-- element.  Any other pattern is treated as an internal error and
-- aborts with a message that includes the pretty-printed pattern.
patName :: Pat Type -> ADM VName
patName (Pat [pe]) = pure $ patElemName pe
patName pat = error $ "Expected single-element pattern: " ++ prettyString pat

-- | If the variable has array type, bind a new name (suffixed
-- @_copy@) to a @Replicate@ of it with an empty shape and return
-- that name; variables of any other type are returned unchanged.
copyIfArray :: VName -> ADM VName
copyIfArray v = do
  v_t <- lookupType v
  case v_t of
    Array {} ->
      letExp (baseString v <> "_copy") . BasicOp $
        Replicate mempty (Var v)
    _ -> pure v

-- The vast majority of BasicOps require no special treatment in the
-- forward pass and produce one value (and hence one adjoint).  We
-- deal with that case here.
--
-- The original statement is re-emitted unchanged, then the
-- continuation @m@ is run, and only afterwards is the adjoint of the
-- (single) pattern name looked up.  The result is the pair of the
-- pattern name and the name of its adjoint value.
commonBasicOp :: Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM (VName, VName)
commonBasicOp pat aux op m = do
  addStm $ Let pat aux $ BasicOp op
  m
  pat_v <- patName pat
  pat_adj <- lookupAdjVal pat_v
  pure (pat_v, pat_adj)

-- | Differentiate a single 'BasicOp' statement.
--
-- Every handled case first re-emits the statement and runs the
-- continuation @m@ (via 'commonBasicOp', except for 'UpdateAcc' which
-- may bind several names and does so by hand).  Cases whose operands
-- receive an adjoint contribution then emit that contribution inside
-- 'returnSweepCode', passing it to 'updateSubExpAdj', 'updateAdj' or
-- 'updateAdjSlice'.
--
-- 'FlatIndex' and 'FlatUpdate' are not supported and abort with
-- 'error'.
diffBasicOp :: Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM ()
diffBasicOp pat aux e m =
  case e of
    -- No adjoint contribution is emitted for comparisons.
    CmpOp {} ->
      void $ commonBasicOp pat aux e m
    --
    ConvOp op x -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        -- The adjoint is converted back with the flipped conversion.
        contrib <-
          letExp "contrib" $ BasicOp $ ConvOp (flipConvOp op) $ Var pat_adj
        updateSubExpAdj x contrib
    --
    UnOp op x -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m

      returnSweepCode $ do
        let t = unOpType op
        contrib <- do
          let x_pe = primExpFromSubExp t x
              pat_adj' = primExpFromSubExp t (Var pat_adj)
              dx = pdUnOp op x_pe
          letExp "contrib" <=< toExp $ pat_adj' ~*~ dx

        updateSubExpAdj x contrib
    --
    BinOp op x y -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m

      returnSweepCode $ do
        let t = binOpType op
            (wrt_x, wrt_y) =
              pdBinOp op (primExpFromSubExp t x) (primExpFromSubExp t y)

            pat_adj' = primExpFromSubExp t $ Var pat_adj

        adj_x <- letExp "binop_x_adj" <=< toExp $ pat_adj' ~*~ wrt_x
        adj_y <- letExp "binop_y_adj" <=< toExp $ pat_adj' ~*~ wrt_y
        updateSubExpAdj x adj_x
        updateSubExpAdj y adj_y
    --
    SubExp se -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ updateSubExpAdj se pat_adj
    --
    Assert {} ->
      void $ commonBasicOp pat aux e m
    --
    ArrayVal {} ->
      void $ commonBasicOp pat aux e m
    --
    ArrayLit elems _ -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      t <- lookupType pat_adj
      returnSweepCode $ do
        -- Element i of the literal receives element i of the adjoint.
        forM_ (zip [(0 :: Int64) ..] elems) $ \(i, se) -> do
          let slice = fullSlice t [DimFix (constant i)]
          updateSubExpAdj se <=< letExp "elem_adj" $ BasicOp $ Index pat_adj slice
    --
    Index arr slice -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ void $ updateAdjSlice slice arr pat_adj
    FlatIndex {} -> error "FlatIndex not handled by AD yet."
    FlatUpdate {} -> error "FlatUpdate not handled by AD yet."
    --
    Opaque _ se -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ updateSubExpAdj se pat_adj
    --
    Reshape arr newshape -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        -- Reshape the adjoint back to the shape of the source array.
        arr_shape <- arrayShape <$> lookupType arr
        void $
          updateAdj arr <=< letExp "adj_reshape" . BasicOp $
            Reshape pat_adj (reshapeAll (newShape newshape) arr_shape)
    --
    Rearrange arr perm -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $
        void $
          updateAdj arr <=< letExp "adj_rearrange" . BasicOp $
            Rearrange pat_adj (rearrangeInverse perm)
    --
    -- A replicate with empty shape of a variable: the adjoint is
    -- passed on to the source variable as-is.
    Replicate (Shape []) (Var se) -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ void $ updateAdj se pat_adj
    --
    Replicate (Shape ns) x -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        -- Flatten the replicated dimensions into one of size
        -- product(ns), then reduce over it with the lambda from
        -- 'addLambda' to obtain the contribution to x.
        x_t <- subExpType x
        lam <- addLambda x_t
        ne <- letSubExp "zero" $ zeroExp x_t
        n <- letSubExp "rep_size" =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) ns
        pat_adj_flat <-
          letExp (baseString pat_adj <> "_flat") . BasicOp $
            Reshape pat_adj (reshapeAll (Shape ns) (Shape $ n : arrayDims x_t))
        reduce <- reduceSOAC [Reduce Commutative lam [ne]]
        updateSubExpAdj x
          =<< letExp "rep_contrib" (Op $ Screma n [pat_adj_flat] reduce)
    --
    Concat d (arr :| arrs) _ -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        -- Walk over the operands, cutting the matching slice out of
        -- the result adjoint along the concatenation dimension d.
        let sliceAdj _ [] = pure []
            sliceAdj start (v : vs) = do
              v_t <- lookupType v
              -- The extent of this operand must be taken along
              -- dimension d, since that is the dimension 'sliceAt'
              -- places the slice at (not unconditionally dimension 0).
              let w = arraySize d v_t
                  slice = DimSlice start w (intConst Int64 1)
              pat_adj_slice <-
                letExp (baseString pat_adj <> "_slice") $
                  BasicOp $
                    Index pat_adj (sliceAt v_t d [slice])
              start' <- letSubExp "start" $ BasicOp $ BinOp (Add Int64 OverflowUndef) start w
              slices <- sliceAdj start' vs
              pure $ pat_adj_slice : slices

        slices <- sliceAdj (intConst Int64 0) $ arr : arrs

        zipWithM_ updateAdj (arr : arrs) slices
    --
    Manifest se _ -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ void $ updateAdj se pat_adj
    --
    Scratch {} ->
      void $ commonBasicOp pat aux e m
    --
    Iota n _ _ t -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        ne <- letSubExp "zero" $ zeroExp $ Prim $ IntType t
        lam <- addLambda $ Prim $ IntType t
        reduce <- reduceSOAC [Reduce Commutative lam [ne]]
        updateSubExpAdj n
          =<< letExp "iota_contrib" (Op $ Screma n [pat_adj] reduce)
    --
    Update safety arr slice v -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        -- The written value receives the adjoint at the updated
        -- slice; the source array receives the adjoint with that
        -- slice overwritten by zeroes.
        v_adj <- letExp "update_val_adj" $ BasicOp $ Index pat_adj slice
        v_adj_copy <- copyIfArray v_adj
        updateSubExpAdj v v_adj_copy
        zeroes <- letSubExp "update_zero" . zeroExp =<< subExpType v
        void $
          updateAdj arr
            =<< letExp "update_src_adj" (BasicOp $ Update safety pat_adj slice zeroes)
    -- See Note [Adjoints of accumulators]
    UpdateAcc _ _ is vs -> do
      addStm $ Let pat aux $ BasicOp e
      m
      pat_adjs <- mapM lookupAdjVal (patNames pat)
      returnSweepCode $ do
        -- Each written value receives the corresponding adjoint
        -- indexed at the update position.
        forM_ (zip pat_adjs vs) $ \(adj, v) -> do
          adj_i <- letExp "updateacc_val_adj" $ BasicOp $ Index adj $ Slice $ map DimFix is
          updateSubExpAdj v adj_i

-- | The operations record handed to 'vjpSOAC', wired to this
-- module's 'diffLambda' and 'diffStm'.
vjpOps :: VjpOps
vjpOps =
  VjpOps
    { vjpLambda = diffLambda,
      vjpStm = diffStm
    }

-- | Differentiate a single statement.  The second argument is the
-- continuation that is run after the statement itself has been
-- emitted; the statement's adjoint code comes after it.
--
-- Handled expression forms: 'BasicOp' (delegated to 'diffBasicOp'),
-- 'Apply' of a function found in 'builtInFunctions', 'Match', SOAC
-- 'Op's (delegated to 'vjpSOAC'), 'Loop' (delegated to 'diffLoop')
-- and 'WithAcc'.  Anything else aborts with 'error'.
diffStm :: Stm SOACS -> ADM () -> ADM ()
diffStm (Let pat aux (BasicOp e)) m =
  diffBasicOp pat aux e m
diffStm stm@(Let pat _ (Apply f args _ _)) m
  | Just (ret, argts) <- M.lookup f builtInFunctions = do
      addStm stm
      m

      pat_adj <- lookupAdjVal =<< patName pat
      let arg_pes = zipWith primExpFromSubExp argts (map fst args)
          pat_adj' = primExpFromSubExp ret (Var pat_adj)
          -- Convert a contribution from the return type of the
          -- builtin to the type of the argument it belongs to.
          convert ft tt
            | ft == tt = id
          convert (IntType ft) (IntType tt) = ConvOpExp (SExt ft tt)
          convert (FloatType ft) (FloatType tt) = ConvOpExp (FPConv ft tt)
          convert Bool (FloatType tt) = ConvOpExp (BToF tt)
          convert (FloatType ft) Bool = ConvOpExp (FToB ft)
          convert ft tt = error $ "diffStm.convert: " ++ prettyString (f, ft, tt)

      contribs <-
        case pdBuiltin f arg_pes of
          Nothing ->
            error $ "No partial derivative defined for builtin function: " ++ prettyString f
          Just derivs ->
            forM (zip derivs argts) $ \(deriv, argt) ->
              letExp "contrib" <=< toExp . convert ret argt $ pat_adj' ~*~ deriv

      zipWithM_ updateSubExpAdj (map fst args) contribs
diffStm stm@(Let pat _ (Match ses cases defbody _)) m = do
  addStm stm
  m
  returnSweepCode $ do
    -- Adjoints are produced for every variable free in any branch.
    let cases_free = map freeIn cases
        defbody_free = freeIn defbody
        branches_free = namesToList $ mconcat $ defbody_free : cases_free

    adjs <- mapM lookupAdj $ patNames pat

    -- Build a Match whose branches are the differentiated bodies;
    -- the trailing results of it are the adjoints of 'branches_free'.
    branches_free_adj <-
      ( pure . takeLast (length branches_free)
          <=< letTupExp "branch_adj"
          <=< renameExp
        )
        =<< eMatch
          ses
          (map (fmap $ diffBody adjs branches_free) cases)
          (diffBody adjs branches_free defbody)
    -- See Note [Array Adjoints of Match]
    forM_ (zip branches_free branches_free_adj) $ \(v, v_adj) ->
      insAdj v =<< copyIfArray v_adj
diffStm (Let pat aux (Op soac)) m =
  -- We add the attributes from 'aux' to every SOAC (but only SOAC) produced. We
  -- could do this on *every* stm, but it would be very verbose.
  censorStms (fmap addAttrs) $ vjpSOAC vjpOps pat aux soac m
  where
    addAttrs stm
      | Op _ <- stmExp stm =
          attribute (stmAuxAttrs aux) stm
      | otherwise = stm
diffStm (Let pat aux loop@Loop {}) m =
  diffLoop diffStms pat aux loop m
-- See Note [Adjoints of accumulators]
diffStm stm@(Let pat _aux (WithAcc inputs lam)) m = do
  addStm stm
  m
  returnSweepCode $ do
    adjs <- mapM lookupAdj $ patNames pat
    lam' <- renameLambda lam
    free_vars <- filterM isActive $ namesToList $ freeIn lam'
    free_accs <- filterM (fmap isAcc . lookupType) free_vars
    -- Adjoints are only requested for the active free variables that
    -- are not themselves accumulators.
    let free_vars' = free_vars \\ free_accs
    lam'' <- diffLambda' adjs free_vars' lam'
    inputs' <- mapM renameInputLambda inputs
    free_adjs <- letTupExp "with_acc_contrib" $ WithAcc inputs' lam''
    zipWithM_ insAdj (arrs <> free_vars') free_adjs
  where
    -- All arrays that back the accumulator inputs.
    arrs = concatMap (\(_, as, _) -> as) inputs
    -- Rename the combining lambda of an input, when it has one.
    renameInputLambda (shape, as, Just (f, nes)) = do
      f' <- renameLambda f
      pure (shape, as, Just (f', nes))
    renameInputLambda input = pure input
    -- Differentiate the body of the lambda, keeping only the first
    -- @length inputs@ results followed by the adjoints of
    -- @get_adjs_for@ (the trailing results of 'diffBody').
    diffLambda' res_adjs get_adjs_for (Lambda params ts body) =
      localScope (scopeOfLParams params) $ do
        Body () stms res <- diffBody res_adjs get_adjs_for body
        let body' = Body () stms $ take (length inputs) res <> takeLast (length get_adjs_for) res
        ts' <- mapM lookupType get_adjs_for
        pure $ Lambda params (take (length inputs) ts <> ts') body'
diffStm stm _ = error $ "diffStm unhandled:\n" ++ prettyString stm

-- | Differentiate a sequence of statements, front to back.  Each
-- statement is differentiated with the differentiation of the
-- remaining statements as its continuation.
--
-- Before a statement is handled, 'copyConsumedArrsInStm' yields a
-- substitution and a set of copy statements; the copies are
-- differentiated first and the substitution is applied to the
-- statement and everything after it.  Once the statement is done,
-- each substituted-away name is given the adjoint of its
-- replacement.
diffStms :: Stms SOACS -> ADM ()
diffStms all_stms
  | Just (stm, stms) <- stmsHead all_stms = do
      (subst, copy_stms) <- copyConsumedArrsInStm stm
      let (stm', stms') = substituteNames subst (stm, stms)
      diffStms copy_stms >> diffStm stm' (diffStms stms')
      forM_ (M.toList subst) $ \(from, to) ->
        setAdj from =<< lookupAdj to
  | otherwise =
      pure ()

-- | Preprocess statements before differentiating.
-- For now, it's just stripmining (via 'stripmineStms').
preprocess :: Stms SOACS -> ADM (Stms SOACS)
preprocess = stripmineStms

-- | Differentiate a body.
--
-- The adjoints in @res_adjs@ are paired with the /last/
-- @length res_adjs@ results of the body and used to seed the
-- adjoints of those results (constant results are ignored).  The
-- statements are then preprocessed and differentiated, and the
-- adjoints of the variables in @get_adjs_for@ are looked up.
--
-- The returned body consists of the statements collected during
-- this process and has the original results followed by the adjoints
-- of @get_adjs_for@.  Everything runs inside 'subAD' and
-- 'subSubsts'.
diffBody :: [Adj] -> [VName] -> Body SOACS -> ADM (Body SOACS)
diffBody res_adjs get_adjs_for (Body () stms res) = subAD $
  subSubsts $ do
    let onResult (SubExpRes _ (Constant _)) _ = pure ()
        onResult (SubExpRes _ (Var v)) v_adj = void $ updateAdj v =<< adjVal v_adj
    (adjs, stms') <- collectStms $ do
      zipWithM_ onResult (takeLast (length res_adjs) res) res_adjs
      diffStms =<< preprocess stms
      mapM lookupAdjVal get_adjs_for
    pure $ Body () stms' $ res <> varsRes adjs

-- | Differentiate a lambda via 'diffBody', with the lambda parameters
-- brought into scope.
--
-- Unlike 'diffBody', the resulting lambda returns /only/ the adjoints:
-- of the differentiated body's results, just the last
-- @length get_adjs_for@ are kept.  The declared return types are the
-- types looked up for the @get_adjs_for@ variables themselves.  The
-- parameter list is unchanged.
diffLambda :: [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS)
diffLambda res_adjs get_adjs_for (Lambda params _ body) =
  localScope (scopeOfLParams params) $ do
    Body () stms res <- diffBody res_adjs get_adjs_for body
    -- Drop the primal results; keep the trailing adjoint results.
    let body' = Body () stms $ takeLast (length get_adjs_for) res
    ts' <- mapM lookupType get_adjs_for
    pure $ Lambda params ts' body'

-- | Transform a lambda into its reverse-mode (vector-Jacobian product)
-- counterpart.
--
-- The resulting lambda takes the original parameters followed by one
-- adjoint parameter per result of the original body.  It returns the
-- original results followed by one value per original parameter,
-- typed like that parameter (the adjoints requested from 'diffBody').
--
-- @scope@ is the scope in which the lambda occurs; it is extended
-- with the lambda's own parameters while differentiating.
revVJP :: (MonadFreshNames m) => Scope SOACS -> Lambda SOACS -> m (Lambda SOACS)
revVJP scope (Lambda params ts body) =
  runADM . localScope (scope <> scopeOfLParams params) $ do
    -- One adjoint parameter per body result.  If the result is a
    -- variable, the parameter name is derived from it with
    -- 'adjVName'; constant results get a fresh "const_adj" name.
    params_adj <-
      forM (zip (map resSubExp (bodyResult body)) ts) $ \(se, t) ->
        Param mempty
          <$> maybe (newVName "const_adj") adjVName (subExpVar se)
          <*> pure t

    body' <-
      localScope (scopeOfLParams params_adj) $
        diffBody
          (map adjFromParam params_adj)
          (map paramName params)
          body

    pure $ Lambda (params ++ params_adj) (ts <> map paramType params) body'

-- Note [Adjoints of accumulators]
--
-- The general case of taking adjoints of WithAcc is tricky.  We make
-- some assumptions and lay down a basic design.
--
-- First, we assume that any WithAccs that occur in the program are
-- the result of previous invocations of VJP.  This means we can rely
-- on the operator having a constant adjoint (it's some kind of
-- addition).
--
-- Second, the adjoint of an accumulator is an array of the same type
-- as the underlying array.  For example, the adjoint type of the
-- primal type 'acc(c, [n], {f64})' is '[n]f64'.  In principle the
-- adjoint of 'acc(c, [n], {f64,f32})' should be two arrays of type
-- '[]f64', '[]f32'.  Our current design assumes that adjoints are
-- single variables.  This is fixable.
--
-- # Adjoint of UpdateAcc
--
--   Consider primal code
--
--     update_acc(acc, i, v)
--
--   Interpreted as an imperative statement, this means
--
--     acc[i] ⊕= v
--
--   for some '⊕'.  Normally all the compiler knows of '⊕' is that it
--   is associative and commutative, but because we assume that all
--   accumulators are the result of previous AD transformations, we
--   can assume that '⊕' actually behaves like addition - that is, has
--   unit partial derivatives.  So the return sweep is
--
--     v += acc_adj[i]
--
-- # Adjoint of Map
--
-- Suppose we have primal code
--
--   let acc' =
--     map (...) acc
--
-- where "acc : acc(c, [n], {f64})" and the width of the Map is "w".
-- Our normal transformation for Map input arrays is to similarly map
-- their adjoint, but clearly this doesn't work here because the
-- semantics of mapping an adjoint is an "implicit replicate".  So
-- when generating the return sweep we actually perform that
-- replication:
--
--   map (...) (replicate w acc_adj)
--
-- But what about the contributions to "acc'"?  Those we also have to
-- take special care of.  The result of the map itself is actually a
-- multidimensional array:
--
--   let acc_contribs =
--     map (...) (replicate w acc'_adj)
--
-- which we must then sum to add to the contribution.
--
--   acc_adj += sum(acc_contribs)
--
-- I'm slightly worried about the asymptotics of this, since my
-- intuition of this is that the contributions might be rather sparse.
-- (Maybe completely zero?  If so it will be simplified away
-- entirely.)  Perhaps a better solution is to treat
-- accumulator-inputs in the primal code as we do free variables, and
-- create accumulators for them in the return sweep.
--
-- # Consumption
--
-- A minor problem is that our usual way of handling consumption (Note
-- [Consumption]) is not viable, because accumulators are not
-- copyable.  Fortunately, while the accumulators that are consumed in
-- the forward sweep will also be present in the return sweep given
-- our current translation rules, they will be dead code.  As long as
-- we are careful to run dead code elimination after revVJP, we should
-- be good.

-- Note [Array Adjoints of Match]
--
-- Some unusual, but sadly not completely contrived, programs contain
-- Match expressions that return multiple arrays, where the arrays
-- returned by one branch have overlapping aliases with another
-- branch, although in different places. As an example consider this:
--
--   let (X,Y) = if c
--               then (A, B)
--               else (B, A)
--
-- Because our aliasing representation cannot express mutually
-- exclusive aliases, we will consider X and Y to be aliased to each
-- other. In practice, this means it is unlikely for X or Y to be
-- consumed, because it would also consume the other (although it's
-- possible for carefully written code).
--
-- When producing adjoints for this, it will be something like
--
--   let (X_adj,Y_adj) = if c
--                       then (A_adj, B_adj)
--                       else (B_adj, A_adj)
--
-- which completely reflects the primal code. However, while it is
-- unlikely that any consumption takes place for the original primal
-- variables, it is almost guaranteed that X_adj and Y_adj will be
-- consumed (that is the main way we use adjoints after all), and due
-- to the conservative aliasing, when one is consumed, so is the
-- other! To avoid this tragic fate, we are forced to copy any
-- array-typed adjoints returned by a Match. This can be quite costly.
-- However:
--
-- 1) Futhark has pretty OK copy removal, so maybe it can get rid of
--    these by using information not available to the AD pass.
--
-- 2) In many cases, arrays will have accumulator adjoints, which are
--    not subject to this problem.
--
-- Issue #2228 was caused by neglecting to do this.