module Futhark.AD.Rev.Scan (diffScan, diffScanVec, diffScanAdd) where

import Control.Monad
import Data.List (transpose)
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.IR.SOACS.Simplify (simplifyLambda)
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util (chunk)

-- | Which operand of the binary scan operator a derivative is taken
-- with respect to.
data FirstOrSecond
  = -- | The first operand (the first half of the operator's parameters).
    WrtFirst
  | -- | The second operand (the second half of the operator's parameters).
    WrtSecond

-- | Bind the entries of an @n@-by-@n@ identity matrix with element type
-- @t@ (one @\"id\"@ binding per entry, row by row) and return the bound
-- sub-expressions as a list of rows.
identityM :: Int -> Type -> ADM [[SubExp]]
identityM n t =
  forM [1 .. n] $ \row ->
    forM [1 .. n] $ \col ->
      letSubExp "id" $ if row == col then oneExp t else zeroExp t

-- | Symbolic matrix product @m1 * m2@ of two matrices given as lists of
-- rows, with entries of primitive type @t@.  Each entry is a dot product
-- folded from an explicit zero (see 'matrixVecMul').
matrixMul :: [[PrimExp VName]] -> [[PrimExp VName]] -> PrimType -> [[PrimExp VName]]
matrixMul m1 m2 t = map (\row -> matrixVecMul cols row t) m1
  where
    -- Row r of the product is every column of m2 dotted with r.
    cols = transpose m2

-- | Symbolic matrix-vector product: each result entry is the dot product
-- of one row of @m@ with @v@, folded left from an explicit zero of
-- primitive type @t@.
matrixVecMul :: [[PrimExp VName]] -> [PrimExp VName] -> PrimType -> [PrimExp VName]
matrixVecMul m v t = map dot m
  where
    zero = primExpFromSubExp t $ Constant $ blankPrimValue t
    dot row = foldl (~+~) zero $ zipWith (~*~) v row

-- | Element-wise symbolic sum of two vectors.  As with 'zipWith', the
-- result is truncated to the shorter input.
vectorAdd :: [PrimExp VName] -> [PrimExp VName] -> [PrimExp VName]
vectorAdd xs ys = [x ~+~ y | (x, y) <- zip xs ys]

-- | Split @lst@ into consecutive chunks of length
-- @length lst `div` specialScans s@.  This yields one group per scan when
-- the length is a multiple of 'specialScans' (presumably always the case
-- for callers; not checked here).
orderArgs :: Special -> [a] -> [[a]]
orderArgs s lst = chunk groupSize lst
  where
    groupSize = length lst `div` specialScans s

-- | Build the vector-Jacobian product of the scan operator @lam0@ with
-- respect to one operand: @d(x op y)/dx@ ('WrtFirst') or @d(x op y)/dy@
-- ('WrtSecond').  The (renamed) operator's parameters are split into two
-- halves, each as long as its list of results.  @adjs@ are passed to
-- 'vjpLambda' as the adjoints of the results.
mkScanAdjointLam :: VjpOps -> Lambda SOACS -> FirstOrSecond -> [SubExp] -> ADM (Lambda SOACS)
mkScanAdjointLam ops lam0 which adjs = do
  lam <- renameLambda lam0
  let arity = length (lambdaReturnType lam0)
      (xParams, yParams) = splitAt arity (lambdaParams lam)
      wrt = case which of
        WrtFirst -> xParams
        WrtSecond -> yParams
  vjpLambda ops (map AdjVal adjs) (map paramName wrt) lam

-- | Build the map lambda that is fused into the return sweep.  For an
-- index @i@ (drawn from @iota w@) it works on the reversed position
-- @j = w - 1 - i@.  For each scan group (see 'orderArgs') it returns the
-- adjoints @ys_adj[j]@ followed by the rows of a Jacobian, roughly:
--
-- > \i -> let j = w - 1 - i
-- >       in if i == 0
-- >            then (ys_adj[j], I)
-- >            else (ys_adj[j], d(ys[j] op xs[j+1]) / d ys[j])
--
-- where @ys@ is the result of the scan, @xs@ its input and @I@ the
-- identity.  Only the part of each Jacobian selected by the special
-- case (see 'caseJac') is emitted.
mkScanFusedMapLam ::
  -- | (ops) helper functions
  VjpOps ->
  -- | (w) length of the arrays, e.g. @xs@
  SubExp ->
  -- | (scn_lam) the scan operator being differentiated
  Lambda SOACS ->
  -- | (xs) input arrays of the scan
  [VName] ->
  -- | (ys) result arrays of the scan
  [VName] ->
  -- | (ys_adj) adjoints of @ys@
  [VName] ->
  -- | (s) which special case of the scan derivative applies
  Special ->
  -- | (d) number of elements in the operator's input tuple
  Int ->
  ADM (Lambda SOACS)
mkScanFusedMapLam ops w scn_lam xs ys ys_adj s d = do
  ys_ts <- traverse lookupType ys
  idmat <- identityM (length ys) $ rowType $ head ys_ts
  -- Seeding the operator's VJP (w.r.t. its first operand) with each unit
  -- vector yields one row of the transposed Jacobian.
  lams <- traverse (mkScanAdjointLam ops scn_lam WrtFirst) idmat
  par_i <- newParam "i" $ Prim int64
  let i = paramName par_i
      -- j = w - 1 - i: the position in the original (unreversed) order.
      bindJ = letSubExp "j" =<< toExp (pe64 w - (le64 i + 1))
      adjointsAt j =
        forM ys_adj $ \y_ ->
          letSubExp (baseString y_ ++ "_j") =<< eIndex y_ [eSubExp j]
      -- i == 0: the last element of the scan; its Jacobian is the identity.
      lastBody = buildBody_ $ do
        j <- bindJ
        y_s <- adjointsAt j
        let jac = orderArgs s $ caseJac k sc idmat
        pure $ subExpsRes $ interleave (orderArgs s y_s) jac
      -- Otherwise: the Jacobian of ys[j] op xs[j+1] w.r.t. ys[j].
      innerBody = buildBody_ $ do
        j <- bindJ
        j1 <- letSubExp "j1" =<< toExp (pe64 w - le64 i)
        y_s <- adjointsAt j
        let args = map (`eIndex` [eSubExp j]) ys ++ map (`eIndex` [eSubExp j1]) xs
        lam_rs <- traverse (`eLambda` args) lams
        let jac = orderArgs s $ caseJac k sc $ transpose lam_rs
        pure $ interleave (orderArgs s $ subExpsRes y_s) jac
  mkLambda [par_i] $
    fmap varsRes . letTupExp "x"
      =<< eIf (toExp $ le64 i .==. 0) lastBody innerBody
  where
    sc = specialCase s
    k = specialSubSize s

    -- Per scan group: the adjoint vector followed by its Jacobian rows.
    interleave :: [[a]] -> [[[a]]] -> [a]
    interleave vecs mats = concat $ zipWith (++) vecs $ map concat mats

    -- Keep only the part of the Jacobian that the special case needs.
    caseJac :: Int -> Maybe SpecialCase -> [[a]] -> [[a]]
    caseJac _ Nothing jac = jac
    caseJac bs (Just ZeroQuadrant) jac =
      -- The bs-by-bs diagonal blocks, one per group of bs rows.
      concat $
        zipWith (\blk -> map (take bs . drop (blk * bs))) [0 .. d `div` bs] $
          chunk bs jac
    caseJac bs (Just MatrixMul) jac =
      -- All diagonal blocks are equal, so the top-left one suffices.
      map (take bs) (take bs jac)

-- | The adjoint-vector half of the linear scan operator: @a2 + b * a1@.
-- In the 'MatrixMul' case @b@ is a single 'specialSubSize'-sized
-- diagonal block.  It is applied to each consecutive block of that many
-- elements of @a1@.
linFunT0 :: [PrimExp VName] -> [PrimExp VName] -> [[PrimExp VName]] -> Special -> PrimType -> [PrimExp VName]
linFunT0 a1 a2 b s pt = vectorAdd a2 scaled
  where
    scaled
      | Just MatrixMul <- specialCase s =
          concatMap (\blk -> matrixVecMul b blk pt) $ chunk (specialSubSize s) a1
      | otherwise = matrixVecMul b a1 pt

-- | Build the linear scan operator used in the return sweep.  Each scan
-- element is a pair @(a, b)@ of an adjoint vector @a@ (an instance of
-- y_bar) and a flattened, row-major square Jacobian @b@ (a "c" in the
-- 2023 paper).  The operator is
--
-- > \(a1, b1) (a2, b2) -> (a2 + b2 * a1, b2 * b1)
--
-- and the neutral element is @(0, I)@.  @t@ is the type of the scanned
-- arrays; every lambda parameter has its row type.
mkScanLinFunO :: Type -> Special -> ADM (Scan SOACS)
mkScanLinFunO t s = do
  neu_elm <- mkNeutral $ specialNeutral s
  -- Number of adjoint-vector parameters and of flattened Jacobian
  -- parameters per operand.
  let (as, bs) = specialParams s
  (a1s, b1s, a2s, b2s) <- mkParams (as, bs)
  -- Side length of the square Jacobian; used to split the flat Jacobian
  -- parameters back into rows.
  let (_, n) = specialNeutral s
      params = map (\v -> Param mempty v (rowType t)) (a1s ++ b1s ++ a2s ++ b2s)
  lam <- mkLambda params . fmap subExpsRes $ do
    -- Bind each component separately (instead of through a partial list
    -- pattern), so the bindings cannot fail.
    let a1s' = map pet a1s
        a2s' = map pet a2s
        b1sm = chunk n $ map pet b1s
        b2sm = chunk n $ map pet b2s
        t0 = linFunT0 a1s' a2s' b2sm s pt
        t1 = concat $ matrixMul b2sm b1sm pt
    traverse (letSubExp "r" <=< toExp) $ t0 ++ t1
  pure $ Scan lam neu_elm
  where
    pt = elemType t

    -- Manifest a variable name as a primitive expression of type pt.
    pet :: VName -> PrimExp VName
    pet = primExpFromSubExp pt . Var

    -- The neutral element: a zeros followed by a b-by-b identity matrix.
    mkNeutral :: (Int, Int) -> ADM [SubExp]
    mkNeutral (a, b) = do
      zeros <- replicateM a $ letSubExp "zeros" $ zeroExp $ rowType t
      idmat <- identityM b $ Prim pt
      pure $ zeros ++ concat idmat

    -- Fresh names for the vector and Jacobian parameters of both operands.
    mkParams :: (Int, Int) -> ADM ([VName], [VName], [VName], [VName])
    mkParams (a, b) = do
      a1s <- replicateM a $ newVName "a1"
      b1s <- replicateM b $ newVName "b1"
      a2s <- replicateM a $ newVName "a2"
      b2s <- replicateM b $ newVName "b2"
      pure (a1s, b1s, a2s, b2s)

-- | Perform the final map of the return sweep, computing the adjoint
-- contributions for the scan's inputs @xs@:
--
-- > let xs_contribs =
-- >    map3 (\i a r -> if i == 0 then r else (df2dy (ys[i-1]) a) \bar{*} r)
-- >         (iota n) xs (reverse ds)
--
-- Element @i@ reads @ds[w-1-i]@.  Except for the first element, it
-- applies the VJP of the operator w.r.t. its second operand, evaluated
-- at @(ys[i-1], xs[i])@.
mkScanFinalMap :: VjpOps -> SubExp -> Lambda SOACS -> [VName] -> [VName] -> [VName] -> ADM [VName]
mkScanFinalMap ops w scan_lam xs ys ds = do
  par_i <- newParam "i" $ Prim int64
  par_x <- zipWithM (newParam . (++ "_par_x") . baseString) xs (lambdaReturnType scan_lam)
  let i = paramName par_i

  map_lam <- mkLambda (par_i : par_x) $ do
    -- Read ds in reverse: position j = w - 1 - i.
    j <- letSubExp "j" =<< toExp (pe64 w - (le64 i + 1))
    dj <- forM ds $ \dd ->
      letExp (baseString dd ++ "_dj") =<< eIndex dd [eSubExp j]

    let firstBody = resultBodyM $ map Var dj
        restBody = buildBody_ $ do
          lam <- mkScanAdjointLam ops scan_lam WrtSecond $ map Var dj
          im1 <- letSubExp "im1" =<< toExp (le64 i - 1)
          ys_im1 <- forM ys $ \y ->
            letSubExp (baseString y <> "_im1") =<< eIndex y [eSubExp im1]
          eLambda lam $ map eSubExp $ ys_im1 ++ map (Var . paramName) par_x

    fmap varsRes . letTupExp "scan_contribs"
      =<< eIf (toExp $ le64 i .==. 0) firstBody restBody

  iota <- letExp "iota" $ BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64
  letTupExp "scan_contribs" $ Op $ Screma w (iota : xs) $ mapSOAC map_lam

-- | Recognised shapes of the scan operator's Jacobian.  Both cases are
-- block diagonal with square blocks; they differ in whether the diagonal
-- blocks are all identical.
data SpecialCase
  = -- | Block diagonal, with diagonal blocks that are not all equal.
    ZeroQuadrant
  | -- | Block diagonal, with every diagonal block identical.
    MatrixMul
  deriving (Show)

-- | Metadata for how to perform the scan for the return sweep.
--
-- The return sweep scans pairs of an adjoint vector and a flattened
-- square Jacobian (see 'mkScanLinFunO'); these fields fix their sizes
-- and how the work is split.
data Special = Special
  { -- | @(a, n)@: the scan's neutral element is @a@ zeros (the adjoint
    -- vector) followed by an @n@-by-@n@ identity matrix, so @n@ is the
    -- side length of the square Jacobian.
    specialNeutral :: (Int, Int),
    -- | @(a, b)@: the number of adjoint-vector parameters and of
    -- flattened Jacobian parameters (@b = n * n@) per operand of the
    -- linear scan operator.
    specialParams :: (Int, Int),
    -- | The number of scans to do: 1 in most cases, @d `div` k@ in the
    -- 'ZeroQuadrant' case.  Per-element values are split into this many
    -- groups (see 'orderArgs').
    specialScans :: Int,
    -- | The side length @k@ of the diagonal blocks of the Jacobian in
    -- the 'ZeroQuadrant' and 'MatrixMul' cases.  In the generic case it
    -- is the operator's arity, which is always 1 there.
    specialSubSize :: Int,
    -- | Which special case applies, if any.
    specialCase :: Maybe SpecialCase
  }
  deriving (Show)

-- | The strategies available for differentiating a scan.  'cases' picks
-- one from the shape of the operator's Jacobian.
data ScanAlgo
  = -- | Construct and compose the Jacobians, as presented in
    -- /Reverse-Mode AD of Multi-Reduce and Scan in Futhark/.  The
    -- 'Special' records the detected Jacobian shape.
    GenericIFL23 Special
  | -- | The approach from /Parallelism-preserving automatic
    -- differentiation for second-order array languages/.
    GenericPPAD
  deriving (Show)

-- | Find the smallest divisor @m@ of @d@ with @m <= d `div` 2@ for which
-- the @d@-by-@d@ matrix @mat@ is block diagonal with @m@-by-@m@ blocks.
-- That means every entry outside those diagonal blocks equals @zero@.
-- Returns 'Nothing' if no such divisor exists (always the case for
-- @d < 2@).
subMats :: Int -> [[Exp SOACS]] -> Exp SOACS -> Maybe Int
subMats d mat zero =
  -- Total pattern match instead of a null-guarded 'head'.
  case filter isBlockDiagonal candidates of
    m : _ -> Just m
    [] -> Nothing
  where
    -- Block sizes that evenly divide d, smallest first; d itself is
    -- excluded since every matrix is trivially one block.
    candidates :: [Int]
    candidates = filter (\x -> d `mod` x == 0) [1 .. d `div` 2]

    isBlockDiagonal :: Int -> Bool
    isBlockDiagonal m = all (rowOk m) $ zip mat [0 .. d - 1]

    -- Entry (i, j) may be non-zero only if i and j lie in the same block.
    rowOk :: Int -> ([Exp SOACS], Int) -> Bool
    rowOk m (row, i) =
      all (\(v, j) -> v == zero || i `div` m == j `div` m) $
        zip row [0 .. d - 1]

-- | Pick a scan-differentiation strategy for an operator with @d@
-- operands of type @t@, given its symbolic @d@-by-@d@ Jacobian @mat@:
--
-- * block diagonal with identical @k@-by-@k@ blocks: 'MatrixMul';
--
-- * block diagonal with differing blocks: 'ZeroQuadrant', split into
--   @d `div` k@ scans;
--
-- * otherwise the generic IFL23 method when @d == 1@, else 'GenericPPAD'.
cases :: Int -> Type -> [[Exp SOACS]] -> ScanAlgo
cases d t mat =
  case subMats d mat $ zeroExp t of
    Just k
      | allSame (diagonalBlocks k) ->
          GenericIFL23 $ Special (d, k) (d, k * k) 1 k $ Just MatrixMul
      | otherwise ->
          GenericIFL23 $ Special (k, k) (k, k * k) (d `div` k) k $ Just ZeroQuadrant
    Nothing
      | d == 1 -> GenericIFL23 $ Special (d, d) (d, d * d) 1 d Nothing
      | otherwise -> GenericPPAD
  where
    -- The k-by-k blocks on the diagonal of mat.  The index range may be
    -- one longer than the number of blocks; zipWith truncates it.
    diagonalBlocks :: Int -> [[[Exp SOACS]]]
    diagonalBlocks k =
      zipWith (\blk -> map (take k . drop (blk * k))) [0 .. d `div` k] $
        chunk k mat

    -- Whether all elements are equal.  Total: an empty list counts as
    -- all-equal instead of failing on 'head'/'tail'.
    allSame :: Eq a => [a] -> Bool
    allSame (b : bs) = all (== b) bs
    allSame [] = True

-- | Construct and optimise a temporary lambda that computes the
-- Jacobian of the scan operator @lam@.  Use it to figure out whether the
-- Jacobian has a special shape (see 'cases'), then discard it.
--
-- Each row of a @d@-by-@d@ identity matrix (@d@ being the number of
-- results of @lam@) seeds one first-argument adjoint of @lam@
-- ('mkScanAdjointLam' with 'WrtFirst').  The adjoint results are
-- transposed, flattened and simplified.  They are then chunked into
-- @d@-wide rows and handed to 'cases'.
--
-- NOTE(review): both the identity matrix and the element type passed to
-- 'cases' are built from the type of the /first/ result of @lam@ only.
identifyCase :: VjpOps -> Lambda SOACS -> ADM ScanAlgo
identifyCase ops lam = do
  let res_ts = lambdaReturnType lam
      d = length res_ts
      elem_t = head res_ts
  adj_lams <- traverse (mkScanAdjointLam ops lam WrtFirst) =<< identityM d elem_t
  xs <- traverse (newParam "tmp1") res_ts
  ys <- traverse (newParam "tmp2") res_ts
  jac_lam <- mkLambda (xs ++ ys) $ do
    rows <- forM adj_lams $ \adj_lam -> eLambda adj_lam $ map eParam (xs ++ ys)
    pure . concat $ transpose rows
  simplified <- simplifyLambda jac_lam
  let entries =
        map (BasicOp . SubExp . resSubExp) . bodyResult $ lambdaBody simplified
  pure $ cases d elem_t (chunk d entries)

-- | Perform @scan@ over the arrays @as@ (each of outer size @w@) from
-- right to left.
--
-- The inputs are read in reverse order (index @w - i - 1@) and scanned
-- with the operator's two arguments swapped, keeping the same neutral
-- element.  The scanned arrays are then reversed back.  Returns the
-- names of the final, re-reversed arrays.
scanRight :: [VName] -> SubExp -> Scan SOACS -> ADM [VName]
scanRight as w scan = do
  row_ts <- map rowType <$> mapM lookupType as
  let mkParams suffix =
        zipWithM (\x -> newParam (baseString x <> suffix)) as row_ts
  lhs <- mkParams "_par_a1"
  rhs <- mkParams "_par_a2"
  -- The original operator, applied to (rhs, lhs) instead of (lhs, rhs).
  flipped_op <- mkLambda (lhs <> rhs) $ do
    op <- renameLambda $ scanLambda scan
    eLambda op $ map (toExp . paramName) (rhs <> lhs)
  let flipped_scan = Scan flipped_op (scanNeutral scan)

  iota <-
    letExp "iota" . BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64
  -- Scan over the inputs read back to front...
  reversed_in <- reverseLam as
  scanned <-
    letTupExp "adj_ctrb_scan" . Op . Screma w [iota] $
      scanomapSOAC [flipped_scan] reversed_in
  -- ...and flip the scanned arrays back into the original order.
  reversed_out <- reverseLam scanned
  letTupExp "reverse_scan_result" . Op . Screma w [iota] $ mapSOAC reversed_out
  where
    -- A map lambda over an index @i@ that reads element @w - i - 1@ of
    -- every array in @arrs@.
    reverseLam :: [VName] -> ADM (Lambda SOACS)
    reverseLam arrs = do
      i <- newParam "i" $ Prim int64
      let rev_idx = pe64 w - le64 (paramName i) - 1
      mkLambda [i] . forM arrs $ \arr -> do
        v <- letExp "ys_bar_rev" =<< eIndex arr [toExp rev_idx]
        pure $ varRes v

-- | Build the lifted operator used by the generic (PPAD) scan adjoint.
--
-- The lambda takes six parameter groups, each with one parameter per
-- array in @as@:
--
-- > (x1, a1, y1_h, x2_unused, a2, y2_h)
--
-- It returns, in order:
--
-- * @x1@ (re-bound, value unchanged);
-- * the original scan operator applied to @a1@ and @a2@;
-- * the first-argument adjoint of the scan operator ('mkScanAdjointLam'
--   with 'WrtFirst', seeded with @y2_h@), evaluated at @(x1, a1)@ and
--   combined with @y1_h@ via 'addLambda'.
--
-- The @x2_unused@ group is accepted but never read.
mkPPADOpLifted :: VjpOps -> [VName] -> Scan SOACS -> ADM (Lambda SOACS)
mkPPADOpLifted ops as scan = do
  row_ts <- map rowType <$> mapM lookupType as
  let mkParams suffix =
        zipWithM (\x -> newParam (baseString x ++ suffix)) as row_ts
  x1 <- mkParams "_par_x1"
  x2_unused <- mkParams "_par_x2_unused"
  a1 <- mkParams "_par_a1"
  a2 <- mkParams "_par_a2"
  y1_h <- mkParams "_par_y1_h"
  y2_h <- mkParams "_par_y2_h"
  adds <- mapM addLambda row_ts
  mkLambda (x1 ++ a1 ++ y1_h ++ x2_unused ++ a2 ++ y2_h) $
    liftedBody x1 a1 y1_h a2 y2_h adds
  where
    liftedBody px1 pa1 py1 pa2 py2 adds = do
      -- First-argument adjoint seeded with y2_h, evaluated at (x1, a1).
      op_bar_1 <-
        mkScanAdjointLam ops (scanLambda scan) WrtFirst (map (Var . paramName) py2)
      z_term <-
        map resSubExp
          <$> eLambda op_bar_1 (map (toExp . Var . paramName) (px1 ++ pa1))
      op <- renameLambda $ scanLambda scan
      -- The result groups are emitted in the order x1, a3, z.
      x1_res <- subExpsRes <$> mapM (toSubExp "x1" . Var . paramName) px1
      a3_res <- eLambda op (map (toExp . paramName) (pa1 ++ pa2))
      z_res <-
        forM (zip3 z_term py1 adds) $ \(z_t, y1, add) ->
          head <$> eLambda add [toExp z_t, toExp (Var (paramName y1))]
      pure $ x1_res ++ a3_res ++ z_res

-- | Pair each array @arr@ in @as@ (outer size @w@) with its element
-- @arr_e@ of @e@, and build a new array of size @w@.  Element @i@ of the
-- new array is @arr[i+1]@ when @i < w - 1@, and @arr_e@ otherwise.  In
-- other words, @arr@ is shifted left by one and padded with @arr_e@.
-- 'diffScan' passes the scan's neutral element as @e@.
asLiftPPAD :: [VName] -> SubExp -> [SubExp] -> ADM [VName]
asLiftPPAD as w e = do
  i <- newParam "i" $ Prim int64
  shift_lam <- mkLambda [i] . forM (zip as e) $ \(arr, arr_e) -> do
    let not_last = do
          nm1 <- toSubExp "n_minus_one" $ pe64 w - 1
          pure $ BasicOp $ CmpOp (CmpSlt Int64) (Var $ paramName i) nm1
        next_elem = buildBody_ $ do
          v <- letSubExp "val" =<< eIndex arr [toExp $ le64 (paramName i) + 1]
          pure [subExpRes v]
        pad = buildBody_ $ pure [subExpRes arr_e]
    varRes <$> (letExp "a_lift" =<< eIf not_last next_elem pad)
  iota <- letExp "iota" $ BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64
  letTupExp "as_lift" $ Op $ Screma w [iota] $ mapSOAC shift_lam

-- | Pair each array @arr@ in @ys@ (outer size @w@) with its element
-- @arr_e@ of @e@, and build a new array of size @w@.  Element @i@ of the
-- new array is @arr_e@ when @i == 0@, and @arr[i-1]@ otherwise.  In
-- other words, @arr@ is shifted right by one with @arr_e@ placed in
-- front.  'diffScan' passes the scan results as @ys@ and the neutral
-- element as @e@.
ysRightPPAD :: [VName] -> SubExp -> [SubExp] -> ADM [VName]
ysRightPPAD ys w e = do
  i <- newParam "i" $ Prim int64
  shift_lam <- mkLambda [i] . forM (zip ys e) $ \(arr, arr_e) -> do
    let is_first =
          pure . BasicOp $
            CmpOp (CmpEq int64) (Var $ paramName i) (constant (0 :: Int64))
        front = buildBody_ $ pure [subExpRes arr_e]
        prev_elem = buildBody_ $ do
          v <- letSubExp "val" =<< eIndex arr [toExp $ le64 (paramName i) - 1]
          pure [subExpRes v]
    varRes <$> (letExp "y_right" =<< eIf is_first front prev_elem)
  iota <- letExp "iota" $ BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64
  letTupExp "ys_right" $ Op $ Screma w [iota] $ mapSOAC shift_lam

-- | Build the lambda for the final map of the generic (PPAD) scan
-- adjoint.
--
-- The lambda takes three parameter groups, each with one parameter per
-- array in @as@: @y_right@, @a@ and @r_adj@.  It returns the
-- second-argument adjoint of the scan operator ('mkScanAdjointLam' with
-- 'WrtSecond', seeded with @r_adj@), evaluated at @(y_right, a)@.
finalMapPPAD :: VjpOps -> [VName] -> Scan SOACS -> ADM (Lambda SOACS)
finalMapPPAD ops as scan = do
  row_ts <- map rowType <$> mapM lookupType as
  let mkParams suffix =
        zipWithM (\x -> newParam (baseString x ++ suffix)) as row_ts
  y_right <- mkParams "_par_y_right"
  a_params <- mkParams "_par_a"
  r_adj <- mkParams "_par_r_adj"
  let asSubExps = map (Var . paramName)
  mkLambda (y_right ++ a_params ++ r_adj) $ do
    op_bar_2 <- mkScanAdjointLam ops (scanLambda scan) WrtSecond (asSubExps r_adj)
    eLambda op_bar_2 . map toExp $ asSubExps (y_right ++ a_params)

-- | Reverse-mode differentiation of a scan.
--
-- The arguments are:
--
-- * @ys@: the results of the scan;
-- * @w@: the outer size of the input arrays;
-- * @as@: the (unzipped) input arrays;
-- * @scan@: the operator together with its neutral element.
--
-- The adjoints of @ys@ are looked up, and contributions to the adjoints
-- of @as@ are computed.  The contributions are then added to the
-- adjoints of @as@ with 'updateAdj' (i.e. @as_bar += contribution@).
--
-- The strategy is chosen by 'identifyCase':
--
-- * 'GenericPPAD': a right-to-left scan ('scanRight') with a lifted
--   operator ('mkPPADOpLifted'), followed by a map applying the
--   operator's second-argument adjoint ('finalMapPPAD').
--
-- * 'GenericIFL23': a fused map ('mkScanFusedMapLam'), several scans
--   with a linear operator ('mkScanLinFunO'), and a final map
--   ('mkScanFinalMap').
diffScan :: VjpOps -> [VName] -> SubExp -> [VName] -> Scan SOACS -> ADM ()
diffScan ops ys w as scan = do
  scan_case <- identifyCase ops $ scanLambda scan
  let d = length as
  ys_adj <- mapM lookupAdjVal ys
  as_ts <- mapM lookupType as

  as_contribs <-
    case scan_case of
      GenericPPAD -> do
        let ne = scanNeutral scan
        as_lift <- asLiftPPAD as w ne
        lifted_op <- mkPPADOpLifted ops as scan
        zeros <- forM as_ts $ fmap Var . letExp "rscan_zero" . zeroExp . rowType
        let lifted_scan = Scan lifted_op (ne ++ ne ++ zeros)
        scan_out <- scanRight (ys ++ as_lift ++ ys_adj) w lifted_scan
        -- The lifted operator yields three groups of d values; the third
        -- group is the one carried into the final map.
        let rs_adj = chunk d scan_out !! 2
        ys_right <- ysRightPPAD ys w ne
        final_lam <- finalMapPPAD ops as scan
        letTupExp "as_bar" . Op . Screma w (ys_right ++ as ++ rs_adj) $
          mapSOAC final_lam
      GenericIFL23 sc -> do
        map1_lam <- mkScanFusedMapLam ops w (scanLambda scan) as ys ys_adj sc d
        lin_scan <- mkScanLinFunO (head as_ts) sc
        scan_lams <- replicateScans (specialScans sc) lin_scan
        iota <-
          letExp "iota" $ BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64
        r_scan <-
          letTupExp "adj_ctrb_scan" . Op . Screma w [iota] $
            scanomapSOAC scan_lams map1_lam
        mkScanFinalMap ops w (scanLambda scan) as ys (splitScanRes sc r_scan d)

  zipWithM_ updateAdj as as_contribs
  where
    -- @n@ copies of @s@, each with a freshly renamed lambda.
    replicateScans :: Int -> Scan SOACS -> ADM [Scan SOACS]
    replicateScans n s =
      replicateM n $ do
        lam' <- renameLambda $ scanLambda s
        pure $ Scan lam' $ scanNeutral s

    -- Regroup the scan results with 'orderArgs' and keep the first
    -- @n `div` specialScans sc@ entries of each group.
    splitScanRes :: Special -> [a] -> Int -> [a]
    splitScanRes sc res n =
      concatMap (take (n `div` specialScans sc)) (orderArgs sc res)

-- | Reverse-mode derivative of a scan whose operands are arrays of rank at
-- least two (a "vectorised" scan along the outermost dimension).
--
-- The primal scan is re-expressed as a map over the second dimension of a
-- scan along the original outer dimension:
--
--   1. every input array in @as@ has its first two dimensions swapped;
--   2. a 'mapSOAC' over the (new) outer dimension @n@ runs an ordinary
--      scan of width @w@ with operator @lam@ on each row, using the
--      corresponding slice of the neutral elements @ne@;
--   3. the results are swapped back and bound to the original result names
--      @ys@ (under the statement annotation @aux@).
--
-- The statements produced by this rewrite are then handed to 'vjpStm' with a
-- right fold, so the VJP of the first statement encloses those of the later
-- ones, and @m@ is the innermost continuation.
--
-- NOTE(review): @as@ is assumed to be non-empty (the rank is read from its
-- first element with 'head') and every element of @ne@ is assumed to be a
-- variable ('subExpVars' is applied to it) — confirm against callers.
diffScanVec ::
  VjpOps ->
  [VName] ->
  StmAux () ->
  SubExp ->
  Lambda SOACS ->
  [SubExp] ->
  [VName] ->
  ADM () ->
  ADM ()
diffScanVec ops ys aux w lam ne as m = do
  stmts <- collectStms_ $ do
    rank <- arrayRank <$> lookupType (head as)
    -- Swap the two outermost dimensions; this permutation is its own
    -- inverse, so the same one is used to transpose the results back.
    let rear = [1, 0] ++ drop 2 [0 .. rank - 1]

    transp_as <-
      forM as $ \a ->
        letExp (baseString a ++ "_transp") $ BasicOp $ Rearrange a rear

    ts <- traverse lookupType transp_as
    -- Outer size after transposition, i.e. the original second dimension.
    let n = arraysSize 0 ts

    as_par <- traverse (newParam "as_par" . rowType) ts
    ne_par <- traverse (newParam "ne_par") $ lambdaReturnType lam

    scan_form <- scanSOAC [Scan lam (map (Var . paramName) ne_par)]

    -- Per-row scan of width @w@, parameterised by the row's neutral element.
    map_lam <-
      mkLambda (as_par ++ ne_par) . fmap varsRes . letTupExp "map_res" . Op $
        Screma w (map paramName as_par) scan_form

    transp_ys <-
      letTupExp "trans_ys" . Op $
        Screma n (transp_as ++ subExpVars ne) (mapSOAC map_lam)

    -- Only the bindings matter here, so discard the per-element results.
    forM_ (zip ys transp_ys) $ \(y, x) ->
      auxing aux $ letBindNames [y] $ BasicOp $ Rearrange x rear

  foldr (vjpStm ops) m stmts

-- | Reverse-mode derivative of a scan whose operator is (presumably) addition,
-- as the name suggests — the caller is responsible for only using this on
-- such scans.
--
-- The adjoint of @ys@ is read back to front, scanned with a fresh copy of the
-- operator @lam'@ and neutral element @ne@, and the scan result is read back
-- to front again. When the operator is addition this yields, for each @i@,
-- the sum of the @ys@ adjoints at positions @j >= i@. That is added to the
-- adjoint of @as@ with 'updateAdj'.
--
-- @n@ is used both as the iteration width and for computing reversed
-- indices. NOTE(review): it is assumed to be the length of @ys@ and @as@.
diffScanAdd :: VjpOps -> VName -> SubExp -> Lambda SOACS -> SubExp -> VName -> ADM ()
diffScanAdd _ops ys n lam' ne as = do
  scanLam <- renameLambda lam'
  ysAdj <- lookupAdjVal ys

  readRevAdj <- reversedReader ysAdj

  idxs <-
    letExp "iota" . BasicOp $ Iota n (intConst Int64 0) (intConst Int64 1) Int64

  revScanned <-
    letExp "res_rev" . Op . Screma n [idxs] $
      scanomapSOAC [Scan scanLam [ne]] readRevAdj

  readRevScanned <- reversedReader revScanned
  contrib <- letExp "contrb" . Op . Screma n [idxs] $ mapSOAC readRevScanned

  updateAdj as contrib
  where
    -- Build a lambda that, given an index @i@, reads @arr[n - i - 1]@.
    reversedReader :: VName -> ADM (Lambda SOACS)
    reversedReader arr = do
      idxParam <- newParam "i" (Prim int64)
      let revIdx = pe64 n - le64 (paramName idxParam) - 1
      mkLambda [idxParam] $ do
        elemName <- letExp "ys_bar_rev" =<< eIndex arr [toExp revIdx]
        pure [varRes elemName]