module Futhark.AD.Rev.Scan (diffScan, diffScanVec, diffScanAdd) where
import Control.Monad
import Data.List (transpose)
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.IR.SOACS.Simplify (simplifyLambda)
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util (chunk)
-- | Selects which operand of the binary scan operator 'mkScanAdjointLam'
-- differentiates with respect to: the first group of lambda parameters
-- ('WrtFirst') or the second group ('WrtSecond').
data FirstOrSecond = WrtFirst | WrtSecond
-- | Bind the entries of an @n@-by-@n@ identity matrix whose elements have
-- type @t@: 'oneExp' on the diagonal and 'zeroExp' everywhere else.  Every
-- entry is bound with 'letSubExp' under the name hint @"id"@, one row after
-- the other.
identityM :: Int -> Type -> ADM [[SubExp]]
identityM n t =
  forM [1 .. n] $ \j ->
    forM [1 .. n] $ \i ->
      letSubExp "id" $ if i == j then oneExp t else zeroExp t
-- | Symbolic product of two matrices of 'PrimExp's, both given as lists of
-- rows.  Each entry of the result is the left fold with '~+~', seeded with
-- the blank value of primitive type @t@, over the element-wise '~*~'
-- products of a row of @m1@ and a column of @m2@.
matrixMul :: [[PrimExp VName]] -> [[PrimExp VName]] -> PrimType -> [[PrimExp VName]]
matrixMul m1 m2 t = map rowTimesCols m1
  where
    cols = transpose m2
    zero = primExpFromSubExp t . Constant $ blankPrimValue t
    dot xs ys = foldl (~+~) zero $ zipWith (~*~) xs ys
    rowTimesCols r = map (dot r) cols
-- | Symbolic matrix-vector product.  For every row of @m@ the result holds
-- the left fold with '~+~', seeded with the blank value of primitive type
-- @t@, over the element-wise '~*~' products of @v@ and that row.
matrixVecMul :: [[PrimExp VName]] -> [PrimExp VName] -> PrimType -> [PrimExp VName]
matrixVecMul m v t = map (weightedSum v) m
  where
    zero = primExpFromSubExp t . Constant $ blankPrimValue t
    weightedSum xs row = foldl (~+~) zero $ zipWith (~*~) xs row
-- | Element-wise symbolic sum ('~+~') of two vectors.  As with any
-- 'zipWith', the result is as long as the shorter argument.
vectorAdd :: [PrimExp VName] -> [PrimExp VName] -> [PrimExp VName]
vectorAdd xs ys = zipWith (~+~) xs ys
-- | Split a list into consecutive pieces with 'chunk', where each piece has
-- length @length lst `div` specialScans s@.
orderArgs :: Special -> [a] -> [[a]]
orderArgs s lst = chunk group_size lst
  where
    group_size = length lst `div` specialScans s
-- | Build the adjoint (vector-Jacobian product) lambda of a binary operator
-- with respect to one of its two operands.
--
-- The operator @lam0@ is renamed first.  Its parameters are then split after
-- as many entries as the operator has results: 'WrtFirst' differentiates with
-- respect to the leading group, 'WrtSecond' with respect to the remainder.
-- @adjs@ are the adjoints of the results, passed to 'vjpLambda' as 'AdjVal's.
mkScanAdjointLam :: VjpOps -> Lambda SOACS -> FirstOrSecond -> [SubExp] -> ADM (Lambda SOACS)
mkScanAdjointLam ops lam0 which adjs = do
  lam <- renameLambda lam0
  let n_res = length (lambdaReturnType lam0)
      (fst_params, snd_params) = splitAt n_res (lambdaParams lam)
      wrt = case which of
        WrtFirst -> fst_params
        WrtSecond -> snd_params
  vjpLambda ops (map AdjVal adjs) (map paramName wrt) lam
-- | Build a lambda over a single @i64@ index @i@ that produces, for that
-- index, the adjoint values read from @ys_adj@ together with (part of) a
-- Jacobian of the scan operator @scn_lam@.
--
-- With @j = w - (i + 1)@ and @j1 = w - i@:
--
-- * for @i == 0@ the Jacobian part is taken from an identity matrix;
--
-- * otherwise the operator's adjoint lambdas with respect to the first
--   operand (one per row of the identity matrix) are applied to @ys[j]@ and
--   @xs[j1]@, and their results are transposed.
--
-- In both cases the matrix is first reduced by @caseJac@ according to the
-- special case in @s@, then both the adjoints and the matrix rows are grouped
-- by 'orderArgs' and interleaved group by group.  @d@ is only used by
-- @caseJac@ to enumerate the diagonal blocks.
mkScanFusedMapLam ::
  VjpOps ->
  SubExp ->
  Lambda SOACS ->
  [VName] ->
  [VName] ->
  [VName] ->
  Special ->
  Int ->
  ADM (Lambda SOACS)
mkScanFusedMapLam ops w scn_lam xs ys ys_adj s d = do
  let sc = specialCase s
      k = specialSubSize s
  ys_ts <- traverse lookupType ys
  idmat <- identityM (length ys) $ rowType $ head ys_ts
  lams <- traverse (mkScanAdjointLam ops scn_lam WrtFirst) idmat
  par_i <- newParam "i" $ Prim int64
  let i = paramName par_i
      -- Body for the first index: no operator application is needed.
      firstBody = buildBody_ $ do
        j <- letSubExp "j" =<< toExp (pe64 w - (le64 i + 1))
        y_s <- adjsAt j
        let zso = orderArgs s y_s
            ido = orderArgs s $ caseJac k sc idmat
        pure $ subExpsRes $ interleave zso ido
      -- Body for every other index: evaluate the adjoint lambdas.
      restBody = buildBody_ $ do
        j <- letSubExp "j" =<< toExp (pe64 w - (le64 i + 1))
        j1 <- letSubExp "j1" =<< toExp (pe64 w - le64 i)
        y_s <- adjsAt j
        let ys_at_j = map (`eIndex` [eSubExp j]) ys
            xs_at_j1 = map (`eIndex` [eSubExp j1]) xs
        lam_rs <- traverse (`eLambda` (ys_at_j ++ xs_at_j1)) lams
        let yso = orderArgs s $ subExpsRes y_s
            jaco = orderArgs s $ caseJac k sc $ transpose lam_rs
        pure $ interleave yso jaco
  mkLambda [par_i] $ do
    branch <- eIf (toExp $ le64 i .==. 0) firstBody restBody
    varsRes <$> letTupExp "x" branch
  where
    -- Read every adjoint array at index @j@.
    adjsAt :: SubExp -> ADM [SubExp]
    adjsAt j = forM ys_adj $ \y_ ->
      letSubExp (baseString y_ ++ "_j") =<< eIndex y_ [eSubExp j]

    -- Pair each group of values with the flattened rows of its group.
    interleave :: [[a]] -> [[[a]]] -> [a]
    interleave vals rows = concat $ zipWith (++) vals $ map concat rows

    -- Reduce a Jacobian to the entries that matter for the special case:
    -- everything when there is none, the @k@-wide diagonal blocks for
    -- 'ZeroQuadrant', and the top-left @k@-by-@k@ block for 'MatrixMul'.
    caseJac :: Int -> Maybe SpecialCase -> [[a]] -> [[a]]
    caseJac _ Nothing jac = jac
    caseJac k (Just ZeroQuadrant) jac =
      concat $
        zipWith
          (\blk -> map (take k . drop (blk * k)))
          [0 .. d `div` k]
          (chunk k jac)
    caseJac k (Just MatrixMul) jac =
      map (take k) (take k jac)
-- | Compute @a2 + b * a1@ symbolically.
--
-- In the 'MatrixMul' special case @a1@ is first cut into pieces of
-- @specialSubSize s@ elements, @b@ is applied to every piece with
-- 'matrixVecMul', and the products are concatenated.  In every other case
-- @b@ is applied to @a1@ as a whole.  @pt@ is the primitive element type.
linFunT0 :: [PrimExp VName] -> [PrimExp VName] -> [[PrimExp VName]] -> Special -> PrimType -> [PrimExp VName]
linFunT0 a1 a2 b s pt = a2 `vectorAdd` applied
  where
    mulB v = matrixVecMul b v pt
    applied = case specialCase s of
      Just MatrixMul -> concatMap mulB $ chunk (specialSubSize s) a1
      _ -> mulB a1
-- | Build the scan (operator and neutral element) over pairs @(a, B)@ of a
-- vector and a flattened square matrix.
--
-- For operands @(a1, B1)@ and @(a2, B2)@ the operator returns the elements of
-- @'linFunT0' a1 a2 B2@ (that is, @a2 + B2 * a1@) followed by the flattened
-- product @B2 * B1@; this is the composition of the functions
-- @x -> a + B * x@.  The neutral element is a run of zeros followed by a
-- flattened identity matrix.  The sizes come from @s@: 'specialParams' gives
-- the number of @a@ and @B@ parameters, 'specialNeutral' the number of zeros
-- and the identity's width.  All parameters have type @'rowType' t@.
mkScanLinFunO :: Type -> Special -> ADM (Scan SOACS)
mkScanLinFunO t s = do
  let pt = elemType t
      -- Row length of the matrices stored flattened in the @b@ parameters.
      n = snd $ specialNeutral s
      pet = primExpFromSubExp pt . Var
  neu_elm <- mkNeutral $ specialNeutral s
  (a1s, b1s, a2s, b2s) <- mkParams $ specialParams s
  let params = map (\v -> Param mempty v (rowType t)) (a1s ++ b1s ++ a2s ++ b2s)
  lam <- mkLambda params . fmap subExpsRes $ do
    -- Total bindings instead of a refutable list pattern.
    let a1s' = map pet a1s
        a2s' = map pet a2s
        b1sm = chunk n $ map pet b1s
        b2sm = chunk n $ map pet b2s
        t0 = linFunT0 a1s' a2s' b2sm s pt
        t1 = concat $ matrixMul b2sm b1sm pt
    traverse (letSubExp "r" <=< toExp) $ t0 ++ t1
  pure $ Scan lam neu_elm
  where
    -- @a@ zeros followed by a flattened @b@-by-@b@ identity matrix.
    mkNeutral :: (Int, Int) -> ADM [SubExp]
    mkNeutral (a, b) = do
      zeros <- replicateM a $ letSubExp "zeros" $ zeroExp $ rowType t
      idmat <- identityM b $ Prim $ elemType t
      pure $ zeros ++ concat idmat

    -- Fresh names for both operands: @a@ vector and @b@ matrix entries each.
    mkParams :: (Int, Int) -> ADM ([VName], [VName], [VName], [VName])
    mkParams (a, b) = do
      a1s <- replicateM a $ newVName "a1"
      b1s <- replicateM b $ newVName "b1"
      a2s <- replicateM a $ newVName "a2"
      b2s <- replicateM b $ newVName "b2"
      pure (a1s, b1s, a2s, b2s)
-- | Emit the final map of the scan differentiation and return the names of
-- its results.
--
-- The map runs over an iota of length @w@ together with @xs@.  For index @i@
-- and the matching element of @xs@ it reads @ds@ at @j = w - (i + 1)@.  For
-- @i == 0@ those values are the result as they are; otherwise they are the
-- result adjoints given to the adjoint of @scan_lam@ with respect to its
-- second operand, which is applied to @ys[i - 1]@ and the element of @xs@.
mkScanFinalMap :: VjpOps -> SubExp -> Lambda SOACS -> [VName] -> [VName] -> [VName] -> ADM [VName]
mkScanFinalMap ops w scan_lam xs ys ds = do
  par_i <- newParam "i" $ Prim int64
  par_x <-
    zipWithM
      (\x -> newParam (baseString x ++ "_par_x"))
      xs
      (lambdaReturnType scan_lam)
  let i = paramName par_i

  map_lam <- mkLambda (par_i : par_x) $ do
    j <- letSubExp "j" =<< toExp (pe64 w - (le64 i + 1))
    dj <- forM ds $ \dd ->
      letExp (baseString dd ++ "_dj") =<< eIndex dd [eSubExp j]
    let dj_se = map Var dj
        -- Branch for every index except the first one.
        restBody = buildBody_ $ do
          lam <- mkScanAdjointLam ops scan_lam WrtSecond dj_se
          im1 <- letSubExp "im1" =<< toExp (le64 i - 1)
          ys_im1 <- forM ys $ \y ->
            letSubExp (baseString y <> "_im1") =<< eIndex y [eSubExp im1]
          eLambda lam $ map eSubExp $ ys_im1 ++ map (Var . paramName) par_x
    branch <- eIf (toExp $ le64 i .==. 0) (resultBodyM dj_se) restBody
    varsRes <$> letTupExp "scan_contribs" branch

  iota <- letExp "iota" $ BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64
  letTupExp "scan_contribs" $ Op $ Screma w (iota : xs) $ mapSOAC map_lam
-- | Structure that 'cases' can detect in a Jacobian whose entries outside
-- the diagonal blocks are all zero.
data SpecialCase
  = -- | Chosen by 'cases' when the diagonal blocks are not all equal.
    ZeroQuadrant
  | -- | Chosen by 'cases' when all diagonal blocks are equal.
    MatrixMul
  deriving (Show)
-- | Sizes and structure used by the 'GenericIFL23' strategy; values are
-- produced by 'cases'.
data Special = Special
  { -- | @(a, b)@: the neutral element built by 'mkScanLinFunO' is @a@ zeros
    -- followed by a flattened @b@-by-@b@ identity matrix.
    specialNeutral :: (Int, Int),
    -- | @(a, b)@: every operand of the operator built by 'mkScanLinFunO'
    -- consists of @a@ vector parameters and @b@ matrix parameters.
    specialParams :: (Int, Int),
    -- | Number of groups into which 'orderArgs' divides a list.
    specialScans :: Int,
    -- | Block size used by 'linFunT0' and by the Jacobian reduction in
    -- 'mkScanFusedMapLam'.
    specialSubSize :: Int,
    -- | The detected Jacobian structure, if there is one.
    specialCase :: Maybe SpecialCase
  }
  deriving (Show)
-- | Strategy for differentiating a scan, as selected by 'cases'.
data ScanAlgo
  = -- | Selected when the operator has a single result or its Jacobian is
    -- block diagonal; carries the sizes needed by that strategy.
    -- NOTE(review): the name presumably refers to an IFL 2023 paper - confirm.
    GenericIFL23 Special
  | -- | Fallback for every other operator.
    -- NOTE(review): presumably "parallel prefix AD" - confirm.
    GenericPPAD
  deriving (Show)
-- | Find the smallest block size @m@ for which the @d@-by-@d@ matrix @mat@
-- is block diagonal, i.e. every entry outside the @m@-by-@m@ diagonal blocks
-- is equal to @zero@.
--
-- Only proper divisors of @d@ up to @d `div` 2@ are tried, in increasing
-- order.  Returns 'Nothing' when none of them fits; this is always the
-- result for @d < 2@.
subMats :: Int -> [[Exp SOACS]] -> Exp SOACS -> Maybe Int
subMats d mat zero =
  -- A total match on the candidates replaces the earlier @null@/@head@ pair.
  case filter blockDiagonal divisors of
    [] -> Nothing
    m : _ -> Just m
  where
    divisors = filter (\x -> d `mod` x == 0) [1 .. d `div` 2]

    -- Does every row respect the block structure for block size @m@?
    blockDiagonal m = all (ok m) $ zip mat [0 .. d - 1]

    -- An entry may be non-zero only if its row and column indices fall into
    -- the same block.
    ok m (row, i) =
      all (\(v, j) -> v == zero || i `div` m == j `div` m) $
        zip row [0 .. d - 1]
-- | Choose the differentiation strategy from the @d@-by-@d@ Jacobian @mat@
-- of the scan operator, whose entries have type @t@.
--
-- * If 'subMats' finds a block size @k@ and all @k@-by-@k@ diagonal blocks
--   are equal, the result is the 'MatrixMul' special case.
--
-- * If 'subMats' finds a block size but the blocks differ, the result is the
--   'ZeroQuadrant' special case with @d `div` k@ groups.
--
-- * Without a block structure, a one-result operator still uses
--   'GenericIFL23' (with no special case); anything else is 'GenericPPAD'.
cases :: Int -> Type -> [[Exp SOACS]] -> ScanAlgo
cases d t mat =
  case subMats d mat (zeroExp t) of
    Just k
      | all (== head blocks) (tail blocks) ->
          GenericIFL23 $ Special (d, k) (d, k * k) 1 k (Just MatrixMul)
      | otherwise ->
          GenericIFL23 $ Special (k, k) (k, k * k) (d `div` k) k (Just ZeroQuadrant)
      where
        -- The diagonal blocks: rows grouped by @k@, keeping for group
        -- number @blk@ the columns starting at @blk * k@.
        blocks =
          zipWith
            (\blk -> map (take k . drop (blk * k)))
            [0 .. d `div` k]
            (chunk k mat)
    Nothing
      | d == 1 -> GenericIFL23 $ Special (d, d) (d, d * d) 1 d Nothing
      | otherwise -> GenericPPAD
-- | Decide how a scan with operator @lam@ should be differentiated.
--
-- A lambda is built that applies the operator's adjoint lambdas with respect
-- to the first operand, one per row of an identity matrix, to fresh
-- parameters; their results are transposed and flattened.  After
-- simplification, the results of that lambda are regrouped into rows of
-- length @d@ (the number of operator results) and classified by 'cases'.
-- The type of the first operator result is used as the element type.
identifyCase :: VjpOps -> Lambda SOACS -> ADM ScanAlgo
identifyCase ops lam = do
  let ret_ts = lambdaReturnType lam
      d = length ret_ts
  idmat <- identityM d $ head ret_ts
  lams <- traverse (mkScanAdjointLam ops lam WrtFirst) idmat
  par1 <- traverse (newParam "tmp1") ret_ts
  par2 <- traverse (newParam "tmp2") ret_ts
  jac_lam <- mkLambda (par1 ++ par2) $ do
    let args = map eParam (par1 ++ par2)
    concat . transpose <$> traverse (`eLambda` args) lams
  simp <- simplifyLambda jac_lam
  let entries = map (BasicOp . SubExp . resSubExp) . bodyResult $ lambdaBody simp
  pure $ cases d (head ret_ts) (chunk d entries)
-- | Scan the arrays @as@ (of outer size @w@) from right to left.
--
-- This is done in three steps, all driven by an @iota w@:
--
--   1. index the inputs at @w - i - 1@, i.e. read them in reverse order;
--   2. scan with the operator of @scan@ applied to swapped operands and
--      the same neutral element;
--   3. reverse the scan result with the same index computation.
--
-- Returns the names of the (re-reversed) result arrays.
scanRight :: [VName] -> SubExp -> Scan SOACS -> ADM [VName]
scanRight as w scan = do
  as_types <- mapM lookupType as
  let arg_type_row = map rowType as_types
      mkParams suffix =
        zipWithM (\x -> newParam (baseString x <> suffix)) as arg_type_row
  par_a1 <- mkParams "_par_a1"
  par_a2 <- mkParams "_par_a2"
  -- The original operator, but with its two operand groups swapped.
  rev_op <- mkLambda (par_a1 <> par_a2) $ do
    op <- renameLambda $ scanLambda scan
    eLambda op (map (toExp . paramName) (par_a2 <> par_a1))
  let rev_scan = Scan rev_op (scanNeutral scan)
  iota <-
    letExp "iota" $ BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64
  map_scan <- revArrLam as
  scan_res <-
    letTupExp "adj_ctrb_scan" . Op . Screma w [iota] $
      scanomapSOAC [rev_scan] map_scan
  rev_lam <- revArrLam scan_res
  letTupExp "reverse_scan_result" $ Op $ Screma w [iota] $ mapSOAC rev_lam
  where
    -- Lambda taking an index @i@ and returning, for every array in
    -- @arrs@, the element at position @w - i - 1@.
    revArrLam :: [VName] -> ADM (Lambda SOACS)
    revArrLam arrs = do
      par_i <- newParam "i" $ Prim int64
      let rev_idx = pe64 w - le64 (paramName par_i) - 1
      mkLambda [par_i] . forM arrs $ \arr ->
        fmap varRes . letExp "ys_bar_rev" =<< eIndex arr [toExp rev_idx]
-- | Build the lifted scan operator used by the generic (PPAD) case.
--
-- The resulting lambda works on two triples of operand groups, laid out
-- as @(x1, a1, y1_h, x2, a2, y2_h)@; each group has one parameter per
-- array in @as@, typed with that array's row type.  The @x2@ group is
-- accepted but never read.  The lambda returns three result groups:
--
--   * @x1@, passed through unchanged;
--   * the original scan operator applied to @(a1, a2)@;
--   * the adjoint lambda of the operator w.r.t. its first operand (seeded
--     with @y2_h@) evaluated at @(x1, a1)@, added element-wise to @y1_h@
--     using the lambdas produced by 'addLambda'.
mkPPADOpLifted :: VjpOps -> [VName] -> Scan SOACS -> ADM (Lambda SOACS)
mkPPADOpLifted ops as scan = do
  as_types <- mapM lookupType as
  let arg_type_row = map rowType as_types
      mkParams suffix =
        zipWithM (\x -> newParam (baseString x ++ suffix)) as arg_type_row
  par_x1 <- mkParams "_par_x1"
  par_x2_unused <- mkParams "_par_x2_unused"
  par_a1 <- mkParams "_par_a1"
  par_a2 <- mkParams "_par_a2"
  par_y1_h <- mkParams "_par_y1_h"
  par_y2_h <- mkParams "_par_y2_h"
  add_lams <- mapM addLambda arg_type_row
  mkLambda
    (par_x1 ++ par_a1 ++ par_y1_h ++ par_x2_unused ++ par_a2 ++ par_y2_h)
    (op_lift par_x1 par_a1 par_y1_h par_a2 par_y2_h add_lams)
  where
    -- Body of the lifted operator; see the description above.
    op_lift ::
      [Param dec] ->
      [Param dec] ->
      [Param dec] ->
      [Param dec] ->
      [Param dec] ->
      [Lambda SOACS] ->
      ADM Result
    op_lift px1 pa1 py1 pa2 py2 adds = do
      op_bar_1 <-
        mkScanAdjointLam ops (scanLambda scan) WrtFirst (map paramVar py2)
      z_term <-
        map resSubExp
          <$> eLambda op_bar_1 (map (toExp . paramVar) (px1 ++ pa1))
      let addOne (z_t, y_1, add) =
            singleRes <$> eLambda add [toExp z_t, toExp y_1]
          z = mapM addOne (zip3 z_term (map paramVar py1) adds)
          x1 = subExpsRes <$> mapM (toSubExp "x1" . paramVar) px1
      -- The renaming happens before any of the three result groups are
      -- emitted, matching the original statement order.
      op <- renameLambda $ scanLambda scan
      let a3 = eLambda op (map (toExp . paramName) (pa1 ++ pa2))
      concat <$> sequence [x1, a3, z]

    paramVar :: Param dec -> SubExp
    paramVar = Var . paramName

    -- Each addition lambda is used for exactly one value, so only its
    -- first result is kept.
    singleRes :: Result -> SubExpRes
    singleRes (r : _) = r
    singleRes [] = error "mkPPADOpLifted: addition lambda returned no result"
-- | Shift every array in @as@ one position to the left.
--
-- For each pair of an array and its corresponding element of @e@, the
-- produced array holds @arr[i + 1]@ at every index @i < w - 1@ and the
-- element of @e@ at the last index.  The result is computed by a map over
-- @iota w@; arrays are paired with @e@ positionally via 'zip'.
asLiftPPAD :: [VName] -> SubExp -> [SubExp] -> ADM [VName]
asLiftPPAD as w e = do
  par_i <- newParam "i" $ Prim int64
  let i = paramName par_i
      -- Condition @i < w - 1@: is there still an element to the right?
      hasNext = do
        nm1 <- toSubExp "n_minus_one" $ pe64 w - 1
        pure $ BasicOp $ CmpOp (CmpSlt Int64) (Var i) nm1
      nextElem arr =
        buildBody_ $
          (\x -> [subExpRes x])
            <$> (letSubExp "val" =<< eIndex arr [toExp $ le64 i + 1])
      padding arr_e = buildBody_ $ pure [subExpRes arr_e]
  lmb <- mkLambda [par_i] $
    forM (zip as e) $ \(arr, arr_e) -> do
      a_lift <- letExp "a_lift" =<< eIf hasNext (nextElem arr) (padding arr_e)
      pure $ varRes a_lift
  iota <-
    letExp "iota" $ BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64
  letTupExp "as_lift" $ Op $ Screma w [iota] $ mapSOAC lmb
-- | Shift every array in @ys@ one position to the right.
--
-- For each pair of an array and its corresponding element of @e@, the
-- produced array holds the element of @e@ at index 0 and @arr[i - 1]@ at
-- every other index @i@.  The result is computed by a map over @iota w@;
-- arrays are paired with @e@ positionally via 'zip'.
ysRightPPAD :: [VName] -> SubExp -> [SubExp] -> ADM [VName]
ysRightPPAD ys w e = do
  par_i <- newParam "i" $ Prim int64
  let i = paramName par_i
      -- Condition @i == 0@: the first position has no left neighbour.
      isFirst =
        pure $
          BasicOp $
            CmpOp (CmpEq int64) (Var i) (constant (0 :: Int64))
      padding arr_e = buildBody_ $ pure [subExpRes arr_e]
      prevElem arr =
        buildBody_ $
          (\x -> [subExpRes x])
            <$> (letSubExp "val" =<< eIndex arr [toExp $ le64 i - 1])
  lmb <- mkLambda [par_i] $
    forM (zip ys e) $ \(arr, arr_e) -> do
      y_right <- letExp "y_right" =<< eIf isFirst (padding arr_e) (prevElem arr)
      pure $ varRes y_right
  iota <-
    letExp "iota" $ BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64
  letTupExp "ys_right" $ Op $ Screma w [iota] $ mapSOAC lmb
-- | Build the lambda for the final map of the generic (PPAD) case.
--
-- The lambda takes three parameter groups @(y_right, a, r_adj)@, each with
-- one parameter per array in @as@ (typed with that array's row type).  Its
-- body is the adjoint lambda of the scan operator w.r.t. its second
-- operand, seeded with @r_adj@ and evaluated at @(y_right, a)@.
finalMapPPAD :: VjpOps -> [VName] -> Scan SOACS -> ADM (Lambda SOACS)
finalMapPPAD ops as scan = do
  as_types <- mapM lookupType as
  let arg_type_row = map rowType as_types
      mkParams suffix =
        zipWithM (\x -> newParam (baseString x ++ suffix)) as arg_type_row
      paramVar = Var . paramName
  par_y_right <- mkParams "_par_y_right"
  par_a <- mkParams "_par_a"
  par_r_adj <- mkParams "_par_r_adj"
  mkLambda (par_y_right ++ par_a ++ par_r_adj) $ do
    op_bar_2 <-
      mkScanAdjointLam ops (scanLambda scan) WrtSecond (map paramVar par_r_adj)
    eLambda op_bar_2 $ map (toExp . paramVar) (par_y_right ++ par_a)
-- | Reverse-mode differentiation of a scan.
--
-- @ys@ are the names whose adjoints are looked up, @w@ is the size used
-- for every generated @iota@ and SOAC, @as@ are the arrays whose adjoints
-- are updated, and @scan@ supplies the operator and neutral element.
--
-- The operator is first classified with 'identifyCase'; the per-array
-- adjoint contributions are then computed by either the generic PPAD
-- pipeline or the specialised 'GenericIFL23' pipeline, and finally added
-- to the adjoints of @as@.
diffScan :: VjpOps -> [VName] -> SubExp -> [VName] -> Scan SOACS -> ADM ()
diffScan ops ys w as scan = do
  scan_case <- identifyCase ops $ scanLambda scan
  let d = length as
  ys_adj <- mapM lookupAdjVal ys
  as_ts <- mapM lookupType as
  as_contribs <- case scan_case of
    GenericPPAD -> do
      let e = scanNeutral scan
      as_lift <- asLiftPPAD as w e
      let m = ys ++ as_lift ++ ys_adj
      op_lft <- mkPPADOpLifted ops as scan
      a_zero <- mapM (fmap Var . letExp "rscan_zero" . zeroExp . rowType) as_ts
      let lft_scan = Scan op_lft $ e ++ e ++ a_zero
      -- The lifted scan returns three groups of @d@ arrays, mirroring the
      -- layout of @m@; only the third group is needed.  Selecting it with
      -- take/drop avoids the partial @(!! 2)@.
      rs_adj <- take d . drop (2 * d) <$> scanRight m w lft_scan
      ys_right <- ysRightPPAD ys w e
      final_lmb <- finalMapPPAD ops as scan
      letTupExp "as_bar" $
        Op $
          Screma w (ys_right ++ as ++ rs_adj) $
            mapSOAC final_lmb
    GenericIFL23 sc -> do
      map1_lam <- mkScanFusedMapLam ops w (scanLambda scan) as ys ys_adj sc d
      scans_lin_fun_o <- mkScanLinFunO (head as_ts) sc
      scan_lams <- mkScans (specialScans sc) scans_lin_fun_o
      iota <-
        letExp "iota" $
          BasicOp $
            Iota w (intConst Int64 0) (intConst Int64 1) Int64
      r_scan <-
        letTupExp "adj_ctrb_scan" . Op . Screma w [iota] $
          scanomapSOAC scan_lams map1_lam
      mkScanFinalMap ops w (scanLambda scan) as ys (splitScanRes sc r_scan d)
  zipWithM_ updateAdj as as_contribs
  where
    -- @k@ copies of the scan @s@, each with a freshly renamed lambda.
    mkScans :: Int -> Scan SOACS -> ADM [Scan SOACS]
    mkScans k s =
      replicateM k $ do
        lam' <- renameLambda $ scanLambda s
        pure $ Scan lam' $ scanNeutral s

    -- Regroup the scan results with 'orderArgs' and keep the first
    -- @n `div` specialScans sc@ entries of every group.
    splitScanRes :: Special -> [a] -> Int -> [a]
    splitScanRes sc res n =
      concatMap (take (n `div` specialScans sc)) (orderArgs sc res)
-- | Differentiate a scan by rewriting it as a map of scans.
--
-- Every input array in @as@ is rearranged so that its first two
-- dimensions are swapped; an outer map (whose size is the outer size of
-- the rearranged arrays) then runs a scan of size @w@ with operator @lam@
-- on each row, receiving the neutral elements @ne@ as additional mapped
-- inputs.  The results are rearranged back with the same permutation and
-- bound to @ys@ under the auxiliary information @aux@.
--
-- The statements generated this way are collected and differentiated in
-- order with 'vjpStm', with @m@ as the innermost continuation.
diffScanVec ::
  VjpOps ->
  [VName] ->
  StmAux () ->
  SubExp ->
  Lambda SOACS ->
  [SubExp] ->
  [VName] ->
  ADM () ->
  ADM ()
diffScanVec ops ys aux w lam ne as m = do
  stmts <- collectStms_ $ do
    rank <- arrayRank <$> lookupType (head as)
    -- Swap the two outermost dimensions, keep the rest in place.
    let rear = [1, 0] ++ drop 2 [0 .. rank - 1]
    transp_as <-
      forM as $ \a ->
        letExp (baseString a ++ "_transp") $ BasicOp $ Rearrange a rear
    ts <- traverse lookupType transp_as
    let n = arraysSize 0 ts
    as_par <- traverse (newParam "as_par" . rowType) ts
    ne_par <- traverse (newParam "ne_par") $ lambdaReturnType lam
    scan_form <- scanSOAC [Scan lam (map (Var . paramName) ne_par)]
    map_lam <-
      mkLambda (as_par ++ ne_par) . fmap varsRes . letTupExp "map_res" . Op $
        Screma w (map paramName as_par) scan_form
    transp_ys <-
      letTupExp "trans_ys" . Op $
        Screma n (transp_as ++ subExpVars ne) (mapSOAC map_lam)
    -- Only the bindings matter here, so the unit results are discarded.
    forM_ (zip ys transp_ys) $ \(y, x) ->
      auxing aux $ letBindNames [y] $ BasicOp $ Rearrange x rear
  foldr (vjpStm ops) m stmts
-- | Differentiate a scan over a single array @as@ with result @ys@.
--
-- The adjoint of @ys@ is read in reverse order (index @n - i - 1@),
-- scanned with a renamed copy of @lam'@ and neutral element @ne@, and the
-- scan result is reversed again; the outcome is added to the adjoint of
-- @as@.  The 'VjpOps' argument is not used.
diffScanAdd :: VjpOps -> VName -> SubExp -> Lambda SOACS -> SubExp -> VName -> ADM ()
diffScanAdd _ops ys n lam' ne as = do
  lam <- renameLambda lam'
  ys_bar <- lookupAdjVal ys
  map_scan <- rev_arr_lam ys_bar
  iota <-
    letExp "iota" $ BasicOp $ Iota n (intConst Int64 0) (intConst Int64 1) Int64
  scan_res <-
    letExp "res_rev" $
      Op $
        Screma n [iota] $
          scanomapSOAC [Scan lam [ne]] map_scan
  rev_lam <- rev_arr_lam scan_res
  contrb <- letExp "contrb" $ Op $ Screma n [iota] $ mapSOAC rev_lam
  updateAdj as contrb
  where
    -- Lambda taking an index @i@ and returning @arr[n - i - 1]@.
    rev_arr_lam :: VName -> ADM (Lambda SOACS)
    rev_arr_lam arr = do
      par_i <- newParam "i" $ Prim int64
      let rev_idx = pe64 n - le64 (paramName par_i) - 1
      mkLambda [par_i] $ do
        a <- letExp "ys_bar_rev" =<< eIndex arr [toExp rev_idx]
        pure [varRes a]