{-# LANGUAGE TypeFamilies #-}
module Futhark.AD.Rev.Monad
( ADM,
RState (..),
runADM,
Adj (..),
InBounds (..),
Sparse (..),
adjFromParam,
adjFromVar,
lookupAdj,
lookupAdjVal,
adjVal,
updateAdj,
updateSubExpAdj,
updateAdjSlice,
updateAdjIndex,
setAdj,
insAdj,
adjsReps,
copyConsumedArrsInStm,
copyConsumedArrsInBody,
addSubstitution,
returnSweepCode,
adjVName,
subAD,
noAdjsFor,
subSubsts,
isActive,
tabNest,
oneExp,
zeroExp,
unitAdjOfType,
addLambda,
VjpOps (..),
setLoopTape,
lookupLoopTape,
substLoopTape,
renameLoopTape,
)
where
import Control.Monad
import Control.Monad.State.Strict
import Data.Bifunctor (second)
import Data.List (foldl')
import Data.Map qualified as M
import Data.Maybe
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.Aliases (consumedInStms)
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Substitute
import Futhark.Util (chunks)
-- | Build an expression that evaluates to the blank ("zero") value of
-- the given type.
--
-- * For a primitive type the result is the constant 'blankPrimValue'.
-- * For an array type the result replicates that constant over the
--   array's shape.
-- * Any other type is rejected with 'error', with the offending type
--   pretty-printed in the message.
zeroExp :: Type -> Exp rep
zeroExp (Prim pt) =
  BasicOp $ SubExp $ Constant $ blankPrimValue pt
zeroExp (Array pt shape _) =
  BasicOp $ Replicate shape $ Constant $ blankPrimValue pt
zeroExp t = error $ "zeroExp: " ++ prettyString t
-- | The value "one" for each primitive type: integer 1, floating-point
-- 1, 'True' for booleans, and the sole unit value for 'Unit'.
onePrim :: PrimType -> PrimValue
onePrim (IntType it) = IntValue $ intValue it (1 :: Int)
onePrim (FloatType ft) = FloatValue $ floatValue ft (1 :: Double)
onePrim Bool = BoolValue True
onePrim Unit = UnitValue
-- | Build an expression that evaluates to the "one" value of the given
-- type (see 'onePrim').
--
-- * For a primitive type the result is the constant @onePrim t@.
-- * For an array type the result replicates @onePrim pt@ over the
--   array's shape.
-- * Any other type is rejected with 'error', with the offending type
--   pretty-printed in the message.
oneExp :: Type -> Exp rep
oneExp (Prim t) = BasicOp $ SubExp $ constant $ onePrim t
oneExp (Array pt shape _) =
  BasicOp $ Replicate shape $ Constant $ onePrim pt
oneExp t = error $ "oneExp: " ++ prettyString t
-- | How an individual entry of a 'Sparse' array is to be written when
-- the array is manifested by 'sparseArray'.
data InBounds
  = -- | Write the entry with a 'Safe' update.  The optional 'SubExp' is
    -- not inspected by 'sparseArray'.
    -- NOTE(review): presumably it names a boolean that is true when
    -- the index is in bounds -- confirm against the users of this type.
    CheckBounds (Maybe SubExp)
  | -- | Write the entry with an 'Unsafe' update.
    AssumeBounds
  | -- | Do not generate any update for the entry; 'sparseArray'
    -- returns the array unchanged.
    OutOfBounds
  deriving (Eq, Ord, Show)
-- | A symbolic description of an array that 'sparseArray' manifests by
-- starting from an all-zero array and then writing a list of
-- individual entries.
data Sparse = Sparse
  { -- | The shape of the array.
    sparseShape :: Shape,
    -- | The element type of the array.
    sparseType :: PrimType,
    -- | The entries to write, as @(bounds mode, index, value)@
    -- triples.  The index is used to fix the outermost dimension of
    -- the update slice.
    sparseIdxVals :: [(InBounds, SubExp, SubExp)]
  }
  deriving (Eq, Ord, Show)
-- | The adjoint of a variable, in one of three representations.  Use
-- 'adjVal' to obtain a concrete variable for any of them.
data Adj
  = -- | A sparse array; see 'Sparse'.
    AdjSparse Sparse
  | -- | An already available value.
    AdjVal SubExp
  | -- | All zeroes, with the given shape and element type.
    AdjZero Shape PrimType
  deriving (Eq, Ord, Show)
-- | Name substitution on adjoints.  Only an 'AdjVal' that holds a
-- variable is rewritten; every other adjoint is returned unchanged.
--
-- NOTE(review): the index and value operands inside an 'AdjSparse'
-- are not substituted -- confirm that this is intended.
instance Substitute Adj where
  substituteNames m (AdjVal (Var v)) = AdjVal $ Var $ substituteNames m v
  substituteNames _ adj = adj
-- | Bind a fresh variable to an all-zero value.
--
-- The second argument is the type handed to 'zeroExp' to obtain a
-- single zero; the first is the outer shape to replicate it over.
-- When the shape has rank 0 the zero itself is bound (as @"zero"@);
-- otherwise the zero is replicated over the shape, and that
-- replication statement carries the @sequential@ attribute.
zeroArray :: (MonadBuilder m) => Shape -> Type -> m VName
zeroArray shape t
  | shapeRank shape == 0 =
      letExp "zero" $ zeroExp t
  | otherwise = do
      zero <- letSubExp "zero" $ zeroExp t
      attributing (oneAttr "sequential") $
        letExp "zeroes_" . BasicOp $
          Replicate shape zero
-- | Manifest a 'Sparse' array as a concrete variable.
--
-- Starts from an all-zero array of the requested shape and element
-- type (see 'zeroArray') and then threads it through one in-place
-- update per entry, in list order.
sparseArray :: (MonadBuilder m, Rep m ~ SOACS) => Sparse -> m VName
sparseArray (Sparse shape t ivs) = do
  zeroes <- zeroArray shape (Prim t)
  foldM f zeroes ivs
  where
    -- Type of the full array, needed to build the update slice.
    arr_t = Prim t `arrayOfShape` shape

    -- Write a single entry into the array built so far.
    f arr (check, i, se) = do
      let stm s =
            letExp "sparse" . BasicOp $
              Update s arr (fullSlice arr_t [DimFix i]) se
      case check of
        AssumeBounds -> stm Unsafe
        CheckBounds _ -> stm Safe
        -- No update is generated at all for these entries.
        OutOfBounds -> pure arr
-- | The adjoint that is simply the value of the given variable.
adjFromVar :: VName -> Adj
adjFromVar = AdjVal . Var
-- | The adjoint that is simply the value of the given parameter; see
-- 'adjFromVar'.
adjFromParam :: Param t -> Adj
adjFromParam = adjFromVar . paramName
-- | Bind the "one" value of the given type (see 'oneExp') under the
-- name @"adj_unit"@ and return it as an 'AdjVal' adjoint.
unitAdjOfType :: Type -> ADM Adj
unitAdjOfType t = AdjVal <$> letSubExp "adj_unit" (oneExp t)
-- | Flatten an adjoint into the list of 'SubExp's that represent it,
-- paired with a function that rebuilds an adjoint of the same form
-- from a replacement list.
--
-- * 'AdjVal' is represented by its single operand.
-- * 'AdjZero' is represented by the empty list.
-- * 'AdjSparse' is represented by the index and value of every entry,
--   interleaved in entry order.
--
-- The rebuilding function for 'AdjVal' and 'AdjZero' calls 'error' if
-- it is given a list of the wrong length.  For 'AdjSparse' a list that
-- is too short silently yields fewer entries.
adjRep :: Adj -> ([SubExp], [SubExp] -> Adj)
adjRep (AdjVal se) = ([se], fromVal)
  where
    fromVal [se'] = AdjVal se'
    fromVal ses =
      error $
        "adjRep: AdjVal expects exactly 1 value, got " ++ show (length ses)
adjRep (AdjZero shape pt) = ([], fromZero)
  where
    fromZero [] = AdjZero shape pt
    fromZero ses =
      error $
        "adjRep: AdjZero expects no values, got " ++ show (length ses)
adjRep (AdjSparse (Sparse shape pt ivs)) =
  (concatMap ivRep ivs, AdjSparse . Sparse shape pt . repIvs ivs)
  where
    -- Only the index and the value are part of the representation;
    -- the bounds mode is carried over by 'repIvs'.
    ivRep (_, i, v) = [i, v]

    -- Pair the old bounds modes with the new index/value operands.
    repIvs ((check, _, _) : ivs') (i : v : ses) =
      (check', i, v) : repIvs ivs' ses
      where
        check' = case check of
          AssumeBounds -> AssumeBounds
          CheckBounds b -> CheckBounds b
          -- An 'OutOfBounds' entry does not survive reassembly as
          -- such: it comes back as a checked entry whose flag is the
          -- constant False.
          -- NOTE(review): looks deliberate (the new index is not
          -- known to be out of bounds) -- confirm.
          OutOfBounds -> CheckBounds (Just (constant False))
    repIvs _ _ = []
-- | Flatten a list of adjoints into the concatenation of their
-- representations (see 'adjRep'), paired with a function that splits a
-- replacement list back into per-adjoint chunks of the original
-- lengths and rebuilds each adjoint from its chunk.
adjsReps :: [Adj] -> ([SubExp], [SubExp] -> [Adj])
adjsReps adjs =
  let (reps, fs) = unzip $ map adjRep adjs
   in (concat reps, zipWith ($) fs . chunks (map length reps))
-- | The mutable state carried by the 'ADM' monad.
data RState = RState
  { -- | The adjoint currently recorded for each variable; updated by
    -- 'setAdj'.
    stateAdjs :: M.Map VName Adj,
    -- | NOTE(review): presumably the substitutions that make up the
    -- loop tape (see the @...LoopTape@ functions in the export list)
    -- -- confirm.
    stateLoopTape :: Substitutions,
    -- | NOTE(review): presumably the substitutions registered through
    -- 'addSubstitution' -- confirm.
    stateSubsts :: Substitutions,
    -- | The source of fresh names; read and written by the
    -- 'MonadFreshNames' instance for @State RState@.
    stateNameSource :: VNameSource
  }
-- | The monad in which the AD transformation runs: a SOACS statement
-- builder layered over state of type 'RState'.
newtype ADM a = ADM (BuilderT SOACS (State RState) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState RState,
      MonadFreshNames,
      HasScope SOACS,
      LocalScope SOACS
    )
-- | 'ADM' builds SOACS programs.  Every method simply delegates to the
-- underlying 'BuilderT' and re-wraps the result.
instance MonadBuilder ADM where
  type Rep ADM = SOACS
  mkExpDecM pat e = ADM $ mkExpDecM pat e
  mkBodyM bnds res = ADM $ mkBodyM bnds res
  mkLetNamesM pat e = ADM $ mkLetNamesM pat e

  addStms = ADM . addStms
  collectStms (ADM m) = ADM $ collectStms m
-- | Fresh names are drawn from the 'stateNameSource' field of the
-- state.
instance MonadFreshNames (State RState) where
  getNameSource = gets stateNameSource
  putNameSource src = modify (\env -> env {stateNameSource = src})
-- | Run an 'ADM' action inside any monad that can supply fresh names.
--
-- The action starts with no adjoints, an empty loop tape, no
-- substitutions and an empty scope.  Its name source is taken from the
-- enclosing monad and handed back afterwards.  Statements still
-- pending in the builder when the action finishes are discarded; only
-- the action's result is returned.
runADM :: (MonadFreshNames m) => ADM a -> m a
runADM (ADM m) =
  modifyNameSource $ \vn ->
    second stateNameSource $
      runState
        (fst <$> runBuilderT m mempty)
        (RState mempty mempty mempty vn)
-- | Manifest an adjoint as a concrete variable, emitting whatever
-- statements are needed to compute it.
--
-- * 'AdjVal': the operand is bound to a fresh variable named
--   @"const_adj"@.
-- * 'AdjSparse': the array is built by 'sparseArray'.
-- * 'AdjZero': the array is built by 'zeroArray'.
adjVal :: Adj -> ADM VName
adjVal (AdjVal se) = letExp "const_adj" $ BasicOp $ SubExp se
adjVal (AdjSparse sparse) = sparseArray sparse
adjVal (AdjZero shape t) = zeroArray shape $ Prim t
-- | Record the adjoint of a variable, replacing any adjoint previously
-- recorded for it.
setAdj :: VName -> Adj -> ADM ()
setAdj v v_adj = modify $ \env ->
  env {stateAdjs = M.insert v v_adj $ stateAdjs env}
-- | Record that the adjoint of the first variable is the value of the
-- second variable.  A thin wrapper around 'setAdj'.
insAdj :: VName -> VName -> ADM ()
insAdj v = setAdj v . AdjVal . Var
-- | Generate a fresh name whose base is that of the given variable
-- with @"_adj"@ appended.  Nothing is bound or recorded.
adjVName :: VName -> ADM VName
adjVName v = newVName (baseString v <> "_adj")
-- | Copy every array that the statement @s@ consumes.
--
-- The consumed names are found by running alias analysis (with an
-- empty alias table) on @s@ alone.  For each consumed name @v@ of
-- array type, a copy named @<v>_ad_copy@ is bound, the substitution
-- @copy -> v@ is registered via 'addSubstitution', and @(v, copy)@ is
-- added to the returned map.  Consumed names of any other type are
-- ignored.
--
-- Returns the map from consumed arrays to their copies together with
-- the statements that create those copies.
copyConsumedArrsInStm :: Stm SOACS -> ADM (Substitutions, Stms SOACS)
copyConsumedArrsInStm s =
  inScopeOf s . collectStms $ do
    copies <- mapM copyIfArray consumed
    pure $ M.fromList $ concat copies
  where
    -- Names consumed by the statement according to alias analysis.
    consumed =
      namesToList . consumedInStms . fst $
        Alias.analyseStms mempty (oneStm s)

    -- Emit a copy of @v@ when it is an array; otherwise do nothing.
    copyIfArray v = inScopeOf s $ do
      v_t <- lookupType v
      case v_t of
        Array {} -> do
          v_copy <-
            letExp (baseString v <> "_ad_copy") $
              BasicOp $
                Replicate mempty (Var v)
          addSubstitution v_copy v
          pure [(v, v_copy)]
        _ -> pure []
-- | Copy every array consumed by the body @b@, except those listed in
-- @dontCopy@.
--
-- The consumed names come from alias analysis of @b@ with an empty
-- alias table.  Each consumed array @v@ gets a copy named
-- @<v>_ad_copy@, and the result maps @v@ to that copy.  Consumed names
-- of non-array, non-accumulator type contribute nothing.
--
-- Calls 'error' if a consumed name (not in @dontCopy@) has an
-- accumulator type.
copyConsumedArrsInBody :: [VName] -> Body SOACS -> ADM Substitutions
copyConsumedArrsInBody dontCopy b =
  mconcat <$> mapM onConsumed toCopy
  where
    consumed = namesToList . consumedInBody $ Alias.analyseBody mempty b
    toCopy = filter (`notElem` dontCopy) consumed

    onConsumed v = do
      v_t <- lookupType v
      case v_t of
        Acc {} -> error $ "copyConsumedArrsInBody: Acc " <> prettyString v
        Array {} -> do
          v_copy <-
            letExp (baseString v <> "_ad_copy") . BasicOp $
              Replicate mempty (Var v)
          pure $ M.singleton v v_copy
        _ -> pure mempty
-- | Run @m@, capture the statements it emits, apply the substitutions
-- currently held in 'stateSubsts' to those statements, and only then
-- add them to the enclosing builder.  The result of @m@ is returned
-- unchanged.
returnSweepCode :: ADM a -> ADM a
returnSweepCode m = do
  (res, sweep_stms) <- collectStms m
  substs <- gets stateSubsts
  addStms (substituteNames substs sweep_stms)
  pure res
-- | Register in 'stateSubsts' that @v@ maps to @v'@, replacing any
-- existing entry for @v@.
addSubstitution :: VName -> VName -> ADM ()
addSubstitution v v' =
  modify $ \st -> st {stateSubsts = M.insert v v' (stateSubsts st)}
-- | Run @m@ with the adjoints of @names@ temporarily removed from
-- 'stateAdjs', then put the removed adjoints back.
--
-- The adjoints are saved as @(name, adjoint)@ pairs.  Saving only the
-- adjoints that exist and zipping them positionally against the full
-- name list would attach adjoints to the wrong names whenever some of
-- the names had no adjoint to begin with.
--
-- On restoration the saved adjoints take precedence over anything @m@
-- stored for the same names (the map union is left-biased).  Names
-- that had no adjoint beforehand keep whatever @m@ left for them.
noAdjsFor :: Names -> ADM a -> ADM a
noAdjsFor names m = do
  saved <- gets $ \env ->
    let adjs = stateAdjs env
     in mapMaybe (\v -> (,) v <$> M.lookup v adjs) names'
  modify $ \env ->
    env {stateAdjs = foldl' (flip M.delete) (stateAdjs env) names'}
  x <- m
  modify $ \env ->
    env {stateAdjs = M.fromList saved <> stateAdjs env}
  pure x
  where
    names' = namesToList names
-- | The binary operator used to combine two values of the given
-- primitive type: wrapping 'Add' for integers, 'FAdd' for floats, and
-- 'LogAnd' for both 'Bool' and 'Unit'.
addBinOp :: PrimType -> BinOp
addBinOp pt = case pt of
  IntType it -> Add it OverflowWrap
  FloatType ft -> FAdd ft
  Bool -> LogAnd
  Unit -> LogAnd
-- | Build a nest of @n@ maps over the arrays @vs@ and run @f@ in the
-- innermost body.
--
-- Each level maps over an iota of the outer size of the current arrays
-- together with the arrays themselves.  @f@ receives the index
-- variables (outermost first) and the names of the innermost row
-- parameters, and returns the names that become the results of the
-- nest.  With @n == 0@, @f@ is applied directly to an empty index list
-- and @vs@.
tabNest :: Int -> [VName] -> ([VName] -> [VName] -> ADM [VName]) -> ADM [VName]
tabNest = go []
  where
    -- @is@ accumulates the index variables innermost-first; it is
    -- reversed before being handed to @f@.
    go is 0 vs f = f (reverse is) vs
    go is n vs f = do
      vs_ts <- mapM lookupType vs
      let w = arraysSize 0 vs_ts
      iota <-
        letExp "tab_iota" $
          BasicOp $
            Iota w (intConst Int64 0) (intConst Int64 1) Int64
      iparam <- newParam "i" (Prim int64)
      params <- forM vs $ \v -> do
        v_t <- lookupType v
        newParam (baseString v <> "_p") (rowType v_t)
      ((ret, res), stms) <-
        collectStms . localScope (scopeOfLParams (iparam : params)) $ do
          inner <- go (paramName iparam : is) (n - 1) (map paramName params) f
          inner_ts <- mapM lookupType inner
          pure (inner_ts, varsRes inner)
      let lam = Lambda (iparam : params) ret (Body () stms res)
      letTupExp "tab" . Op $ Screma w (iota : vs) (mapSOAC lam)
-- | Construct a two-parameter lambda that combines two values of the
-- given type using 'addBinOp'.
--
-- For a primitive type this is the plain binary-operator lambda.  For
-- an array type the lambda maps the row-type lambda over its two
-- array parameters.  Any other type is rejected with 'error'.
addLambda :: Type -> ADM (Lambda SOACS)
addLambda (Prim pt) = binOpLambda (addBinOp pt) pt
addLambda t@Array {} = do
  xs_p <- newParam "xs" t
  ys_p <- newParam "ys" t
  row_lam <- addLambda (rowType t)
  body <- insertStmsM $ do
    let arrs = [paramName xs_p, paramName ys_p]
    res <-
      letSubExp "lam_map" $
        Op $
          Screma (arraySize 0 t) arrs (mapSOAC row_lam)
    pure (resultBody [res])
  pure $ Lambda [xs_p, ys_p] [t] body
addLambda t = error $ "addLambda: " ++ show t
-- | Build an expression combining @x@ and @y@ according to the type
-- of @x@: a single 'BinOp' (chosen by 'addBinOp') for primitives, or a
-- map of the row-type 'addLambda' over both arrays.  Any other type is
-- rejected with 'error'.
addExp :: VName -> VName -> ADM (Exp SOACS)
addExp x y = do
  x_t <- lookupType x
  case x_t of
    Prim pt -> pure . BasicOp $ BinOp (addBinOp pt) (Var x) (Var y)
    Array {} -> do
      row_lam <- addLambda (rowType x_t)
      pure . Op $ Screma (arraySize 0 x_t) [x, y] (mapSOAC row_lam)
    _ -> error $ "addExp: unexpected type: " ++ prettyString x_t
-- | Look up the adjoint of @v@ in 'stateAdjs'.
--
-- When none is recorded, an 'AdjZero' is returned instead.  For an
-- accumulator with a single primitive element type, it uses the
-- accumulator's shape and that primitive type; otherwise it uses the
-- array shape and element type of @v@.
lookupAdj :: VName -> ADM Adj
lookupAdj v = maybe zeroAdj pure =<< gets (M.lookup v . stateAdjs)
  where
    zeroAdj = do
      v_t <- lookupType v
      pure $ case v_t of
        Acc _ shape [Prim t] _ -> AdjZero shape t
        _ -> AdjZero (arrayShape v_t) (elemType v_t)
-- | Look up the adjoint of @v@ with 'lookupAdj' and convert it to a
-- variable name with 'adjVal'.
lookupAdjVal :: VName -> ADM VName
lookupAdjVal v = lookupAdj v >>= adjVal
-- | Add the contribution @se@ to the adjoint of @v@ at index @i@.
--
-- The @InBounds@ component says how the index should be treated:
-- 'CheckBounds' leads to a 'Safe' update, 'AssumeBounds' to an
-- 'Unsafe' one, and 'OutOfBounds' leaves an existing dense adjoint
-- unchanged.
--
-- * With no adjoint, or an 'AdjZero' adjoint, a new sparse adjoint
--   holding just this entry is stored.
-- * With an existing sparse adjoint, the entry is prepended to it.
-- * With a dense ('AdjVal') adjoint, statements are emitted that
--   update it: via 'UpdateAcc' when it is an accumulator, otherwise by
--   reading the slot, combining with 'addExp', and writing it back.
updateAdjIndex :: VName -> (InBounds, SubExp) -> SubExp -> ADM ()
updateAdjIndex v (check, i) se = do
  maybeAdj <- gets $ M.lookup v . stateAdjs
  t <- lookupType v
  let iv = (check, i, se)
      freshSparse = AdjSparse $ Sparse (arrayShape t) (elemType t) [iv]
  case maybeAdj of
    Nothing -> setAdj v freshSparse
    Just AdjZero {} -> setAdj v freshSparse
    Just (AdjSparse (Sparse shape pt ivs)) ->
      setAdj v . AdjSparse $ Sparse shape pt (iv : ivs)
    Just adj@AdjVal {} -> do
      v_adj <- adjVal adj
      v_adj_t <- lookupType v_adj
      se_v <- letExp "se_v" $ BasicOp $ SubExp se
      -- Pick the safety from the bounds information; an index known to
      -- be out of bounds leaves the adjoint as it is.
      let withSafety update = case check of
            CheckBounds _ -> update Safe
            AssumeBounds -> update Unsafe
            OutOfBounds -> pure v_adj
      -- Accumulator adjoint: update element-wise inside a map nest as
      -- deep as the rank of the contribution.
      let accUpdate safety = do
            dims <- arrayDims <$> lookupType se_v
            ~[v_adj'] <-
              tabNest (length dims) [se_v, v_adj] $ \is [se_v', v_adj'] ->
                letTupExp "acc" . BasicOp $
                  UpdateAcc safety v_adj' (i : map Var is) [Var se_v']
            pure v_adj'
      -- Array adjoint: read the slot, combine, and write it back.
      let arrUpdate safety = do
            let i_slice = fullSlice v_adj_t [DimFix i]
            v_adj_i <-
              letExp (baseString v_adj <> "_i") . BasicOp $
                Index v_adj i_slice
            se_update <- letSubExp "updated_adj_i" =<< addExp se_v v_adj_i
            letExp (baseString v_adj) . BasicOp $
              Update safety v_adj i_slice se_update
      v_adj_new <- withSafety $ case v_adj_t of
        Acc {} -> accUpdate
        _ -> arrUpdate
      insAdj v v_adj_new
-- | Add the contribution @d@ to the whole adjoint of @v@.
--
-- When @v@ has no recorded adjoint, @d@ simply becomes its adjoint.
-- Otherwise the existing adjoint is converted to a variable with
-- 'adjVal'.  If that variable is an accumulator, @d@ is written into
-- it element-wise through 'UpdateAcc' with the given @safety@.
-- Otherwise the two are combined with 'addExp' and the result is
-- bound to a name based on @<v>_adj@.
updateAdjWithSafety :: VName -> VName -> Safety -> ADM ()
updateAdjWithSafety v d safety = do
  maybeAdj <- gets $ M.lookup v . stateAdjs
  case maybeAdj of
    Nothing -> insAdj v d
    Just adj -> do
      v_adj <- adjVal adj
      v_adj_t <- lookupType v_adj
      v_adj_new <- case v_adj_t of
        Acc {} -> do
          rank <- length . arrayDims <$> lookupType d
          ~[v_adj'] <-
            tabNest rank [d, v_adj] $ \is [d', v_adj'] ->
              letTupExp "acc" . BasicOp $
                UpdateAcc safety v_adj' (map Var is) [Var d']
          pure v_adj'
        _ -> letExp (baseString v <> "_adj") =<< addExp v_adj d
      insAdj v v_adj_new
-- | Add the contribution @d@ to the part of the adjoint of @v@
-- selected by the given slice.
--
-- A slice that is a single fixed index is delegated to
-- 'updateAdjIndex', translating 'Safe' to @'CheckBounds' 'Nothing'@
-- and 'Unsafe' to 'AssumeBounds'.
--
-- For any other slice the adjoint is obtained as a variable.  If it is
-- an accumulator, each element of @d@ is written through 'UpdateAcc'
-- at the position that the slice maps its index to.  Otherwise the
-- sliced adjoint (or the adjoint itself when @v@ is primitive) is
-- combined with @d@ via 'addExp' and written back in place.
updateAdjSliceWithSafety :: Slice SubExp -> VName -> VName -> Safety -> ADM ()
updateAdjSliceWithSafety (Slice [DimFix i]) v d safety =
  updateAdjIndex v (boundsFor safety, i) (Var d)
  where
    boundsFor Safe = CheckBounds Nothing
    boundsFor Unsafe = AssumeBounds
updateAdjSliceWithSafety slice v d safety = do
  t <- lookupType v
  v_adj <- lookupAdjVal v
  v_adj_t <- lookupType v_adj
  v_adj_new <- case v_adj_t of
    Acc {} -> do
      let rank = length (sliceDims slice)
          slice64 = fmap pe64 slice
      ~[v_adj'] <-
        tabNest rank [d, v_adj] $ \is [d', v_adj'] -> do
          -- Translate indices within the slice to indices into the
          -- full adjoint.
          ixs <- traverse (toSubExp "index") (fixSlice slice64 (map le64 is))
          letTupExp (baseString v_adj') . BasicOp $
            UpdateAcc safety v_adj' ixs [Var d']
      pure v_adj'
    _ -> do
      v_adjslice <-
        if primType t
          then pure v_adj
          else letExp (baseString v ++ "_slice") . BasicOp $ Index v_adj slice
      letInPlace "updated_adj" v_adj slice =<< addExp v_adjslice d
  insAdj v v_adj_new
-- | Add the contribution @d@ to the adjoint of @v@; this is
-- 'updateAdjWithSafety' with 'Unsafe' accumulator updates.
updateAdj :: VName -> VName -> ADM ()
updateAdj v d = updateAdjWithSafety v d Unsafe
-- | Add the contribution @d@ to the given slice of the adjoint of
-- @v@; this is 'updateAdjSliceWithSafety' with 'Unsafe' updates.
updateAdjSlice :: Slice SubExp -> VName -> VName -> ADM ()
updateAdjSlice slice v d = updateAdjSliceWithSafety slice v d Unsafe
-- | Add the contribution @d@ to the adjoint of a subexpression.
-- Constants are ignored; variables are forwarded to 'updateAdj'.
updateSubExpAdj :: SubExp -> VName -> ADM ()
updateSubExpAdj se d = case se of
  Constant {} -> pure ()
  Var v -> updateAdj v d
-- | Decide whether a variable is considered active.
--
-- The type of the variable is looked up in the current scope, and the result
-- is 'True' exactly when that type is anything other than @'Prim' 'Unit'@.
isActive :: VName -> ADM Bool
isActive = fmap (/= Prim Unit) . lookupType
-- | Run an action with a scoped adjoint map.
--
-- The value of 'stateAdjs' is captured before the action @m@ runs and is
-- written back once @m@ has finished, so changes that @m@ makes to
-- 'stateAdjs' do not outlive this call.  Other fields of the state are left
-- as @m@ produced them.  The result of @m@ is returned unchanged.
subAD :: ADM a -> ADM a
subAD m = do
  old_state_adjs <- gets stateAdjs
  x <- m
  -- Discard whatever adjoints the action recorded.
  modify $ \s -> s {stateAdjs = old_state_adjs}
  pure x
-- | Run an action with a scoped substitution map.
--
-- The value of 'stateSubsts' is captured before the action @m@ runs and is
-- written back once @m@ has finished, so changes that @m@ makes to
-- 'stateSubsts' do not outlive this call.  Other fields of the state are left
-- as @m@ produced them.  The result of @m@ is returned unchanged.
subSubsts :: ADM a -> ADM a
subSubsts m = do
  old_state_substs <- gets stateSubsts
  x <- m
  -- Discard whatever substitutions the action recorded.
  modify $ \s -> s {stateSubsts = old_state_substs}
  pure x
-- | A record bundling two operations that both run in 'ADM'.
--
-- NOTE(review): presumably these are the reverse-mode (VJP) transformations
-- for lambdas and statements, passed around as a record so that other
-- modules can call back into them; confirm against the call sites.
data VjpOps = VjpOps
  { -- | Given a list of adjoints, a list of variable names and a lambda,
    -- produce a new lambda.
    vjpLambda :: [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS),
    -- | Given a statement and an 'ADM' action, produce an 'ADM' action.
    vjpStm :: Stm SOACS -> ADM () -> ADM ()
  }
-- | Record @vs@ as the loop-tape entry for @v@.
--
-- The mapping is inserted into 'stateLoopTape'; an existing entry for @v@ is
-- replaced.
setLoopTape :: VName -> VName -> ADM ()
setLoopTape v vs =
  modify $ \env ->
    env {stateLoopTape = M.insert v vs $ stateLoopTape env}
-- | Look up the loop-tape entry for @v@ in 'stateLoopTape'.
--
-- Returns 'Nothing' when no entry has been recorded for @v@.
lookupLoopTape :: VName -> ADM (Maybe VName)
lookupLoopTape v = gets $ M.lookup v . stateLoopTape
-- | Make @v'@ share the loop-tape entry of @v@.
--
-- If @v@ has an entry in the loop tape, the same value is recorded under
-- @v'@ as well (the entry for @v@ itself is kept).  If @v@ has no entry,
-- nothing happens.
substLoopTape :: VName -> VName -> ADM ()
substLoopTape v v' = mapM_ (setLoopTape v') =<< lookupLoopTape v
-- | Apply 'substLoopTape' to every @(old, new)@ pair of a substitution map.
--
-- Pairs are processed one after another in the ascending key order produced
-- by 'M.toList', so an entry added for one pair is visible when later pairs
-- are processed.
renameLoopTape :: Substitutions -> ADM ()
renameLoopTape = mapM_ (uncurry substLoopTape) . M.toList