module Futhark.Optimise.ReduceDeviceSyncs (reduceDeviceSyncs) where
import Control.Monad
import Control.Monad.Reader
import Control.Monad.State hiding (State)
import Data.Bifunctor (second)
import Data.Foldable
import Data.IntMap.Strict qualified as IM
import Data.List (transpose, zip4)
import Data.Map.Strict qualified as M
import Data.Sequence ((><), (|>))
import Data.Text qualified as T
import Futhark.Construct (fullSlice, mkBody, sliceDim)
import Futhark.Error
import Futhark.IR.GPU
import Futhark.MonadFreshNames
import Futhark.Optimise.ReduceDeviceSyncs.MigrationTable
import Futhark.Pass
import Futhark.Transform.Substitute
-- | A pass that moves host statements to the device, as dictated by a
-- 'MigrationTable', to reduce blocking memory operations.
--
-- The constants are analysed once with 'analyseConsts'; each function is
-- then analysed with 'analyseFunDef' and its table is combined with the
-- constants' table before the function is rewritten. Functions are
-- processed with 'parPass'.
reduceDeviceSyncs :: Pass GPU GPU
reduceDeviceSyncs =
  Pass
    "reduce device synchronizations"
    "Move host statements to device to reduce blocking memory operations."
    $ \prog -> do
      let hof = hostOnlyFunDefs $ progFuns prog
          consts_mt = analyseConsts hof (progFuns prog) (progConsts prog)
      consts <- onConsts consts_mt $ progConsts prog
      funs <- parPass (onFun hof consts_mt) (progFuns prog)
      pure $ prog {progConsts = consts, progFuns = funs}
  where
    -- Rewrite the top-level constants according to their migration table.
    onConsts ::
      (MonadFreshNames m) =>
      MigrationTable ->
      Stms GPU ->
      m (Stms GPU)
    onConsts consts_mt stms =
      runReduceM consts_mt (optimizeStms stms)

    -- Rewrite one function. Its migration table is the constants' table
    -- combined ('<>') with the table computed for the function itself.
    onFun ::
      (MonadFreshNames m) =>
      HostOnlyFuns ->
      MigrationTable ->
      FunDef GPU ->
      m (FunDef GPU)
    onFun hof consts_mt fd = do
      let mt = consts_mt <> analyseFunDef hof fd
      runReduceM mt (optimizeFunDef fd)
-- | Optimize a function definition. Only the statements of its body are
-- rewritten; the rest of the definition, including the body result, is
-- kept as is.
optimizeFunDef :: FunDef GPU -> ReduceM (FunDef GPU)
optimizeFunDef fd = do
  let body = funDefBody fd
  stms' <- optimizeStms (bodyStms body)
  pure $ fd {funDefBody = body {bodyStms = stms'}}
-- | Optimize a body: rewrite its statements with 'optimizeStms' and pass its
-- result through 'resolveResult'.
optimizeBody :: Body GPU -> ReduceM (Body GPU)
optimizeBody (Body _ stms res) = do
  stms' <- optimizeStms stms
  res' <- resolveResult res
  pure (Body () stms' res')
-- | Optimize a sequence of statements. Each statement is handed to
-- 'optimizeStm' from left to right together with the output produced so far.
optimizeStms :: Stms GPU -> ReduceM (Stms GPU)
optimizeStms = foldM optimizeStm mempty
-- | Optimize a single statement, appending the resulting statement(s) to
-- @out@.
--
-- Statements for which 'shouldMoveStm' holds are handed to 'moveStm'.
-- Otherwise, a few expression forms are rewritten so that they use device
-- arrays for migrated scalars, and nested bodies ('Match', 'Loop',
-- 'WithAcc', host ops) are optimized recursively.
optimizeStm :: Stms GPU -> Stm GPU -> ReduceM (Stms GPU)
optimizeStm out stm = do
  move <- asks (shouldMoveStm stm)
  if move
    then moveStm out stm
    else case stmExp stm of
      BasicOp (Update safety arr slice (Var v))
        | Just _ <- sliceIndices slice,
          -- Split off the innermost index. Matching it here instead of with
          -- a lazy let-pattern means an empty slice falls through to the
          -- generic case rather than failing at runtime.
          DimFix i : outer_rev <- reverse (unSlice slice) -> do
            dev <- storedScalar (Var v)
            case dev of
              Nothing -> pure (out |> stm)
              Just dst -> do
                -- Write from the stored array instead: the single-element
                -- update becomes an update of a one-element slice.
                let one = intConst Int64 1
                    slice' = Slice $ reverse outer_rev ++ [DimSlice i one one]
                    e = BasicOp (Update safety arr slice' (Var dst))
                pure (out |> stm {stmExp = e})
      BasicOp (Replicate (Shape dims) (Var v))
        | Pat [PatElem n arr_t] <- stmPat stm -> do
            -- A resolved name different from v is taken to mean that v is
            -- kept on the device.
            v' <- resolveName v
            let v_kept_on_device = v /= v'
            gpubody_ok <- gets stateGPUBodyOk
            case v_kept_on_device of
              False -> pure (out |> stm)
              True
                | all (== intConst Int64 1) dims,
                  Just t' <- peelArray 1 arr_t,
                  gpubody_ok -> do
                    -- All replicated dimensions are 1: compute the inner
                    -- replicate in a GPUBody and give it the original pattern.
                    let n' = VName (baseName n `withSuffix` "_inner") 0
                        pat' = Pat [PatElem n' t']
                        e' = BasicOp $ Replicate (Shape $ tail dims) (Var v)
                        stm' = Let pat' (stmAux stm) e'
                    gpubody <- inGPUBody (rewriteStm stm')
                    pure (out |> gpubody {stmPat = stmPat stm})
                | not (null dims),
                  last dims == intConst Int64 1 ->
                    -- The innermost replicated dimension is 1: replicate the
                    -- device copy over the remaining dimensions. The null
                    -- check keeps 'last'/'init' total; an empty shape is
                    -- handled by the general case below.
                    let e' = BasicOp $ Replicate (Shape $ init dims) (Var v')
                     in pure (out |> stm {stmExp = e'})
                | otherwise -> do
                    -- Replicate the device copy into an array with an extra
                    -- innermost dimension of size 1, then index that
                    -- dimension away into the original pattern.
                    n' <- newName n
                    let dims' = dims ++ [intConst Int64 1]
                        arr_t' = Array (elemType arr_t) (Shape dims') NoUniqueness
                        repl =
                          Let (Pat [PatElem n' arr_t']) (stmAux stm) $
                            BasicOp (Replicate (Shape dims) (Var v'))
                        slice' =
                          map sliceDim (arrayDims arr_t)
                            ++ [DimFix $ intConst Int64 0]
                        index =
                          Let (stmPat stm) (defAux ()) $
                            BasicOp (Index n' (Slice slice'))
                    pure (out |> repl |> index)
      BasicOp {} ->
        pure (out |> stm)
      Apply {} ->
        pure (out |> stm)
      Match ses cases defbody (MatchDec btypes sort) -> do
        cases_stms <- mapM (optimizeStms . bodyStms . caseBody) cases
        let cases_res = map (bodyResult . caseBody) cases
        defbody_stms <- optimizeStms $ bodyStms defbody
        let defbody_res = bodyResult defbody

        -- Process one pattern element together with its results from every
        -- branch (default body first) and its branch type. If some result
        -- is not on host, every branch stores its result with 'storeScalar'
        -- and the pattern element is array-ized.
        let bmerge (acc, all_stms) (pe, reses, bt) = do
              let onHost (Var v) = (v ==) <$> resolveName v
                  onHost _ = pure True
              on_host <- and <$> mapM (onHost . resSubExp) reses
              if on_host
                then pure ((pe, reses, bt) : acc, all_stms)
                else do
                  (all_stms', arrs) <-
                    fmap unzip $
                      forM (zip all_stms reses) $ \(stms, res) ->
                        storeScalar stms (resSubExp res) (patElemType pe)
                  pe' <- arrayizePatElem pe
                  let bt' = staticShapes1 (patElemType pe')
                      reses' = zipWith SubExpRes (map resCerts reses) (map Var arrs)
                  pure ((pe', reses', bt') : acc, all_stms')
        let pes = patElems (stmPat stm)

        (acc, ~(defbody_stms' : cases_stms')) <-
          foldM bmerge ([], defbody_stms : cases_stms) $
            zip3 pes (transpose $ defbody_res : cases_res) btypes
        let (pes', reses, btypes') = unzip3 (reverse acc)

        -- Results per branch, default body first. Padding with empty
        -- results keeps every case when the pattern has no elements, where
        -- 'transpose' alone would yield no rows and drop all cases.
        let branch_reses = transpose reses ++ repeat []
            cases' =
              zipWith Case (map casePat cases) $
                zipWith mkBody cases_stms' (drop 1 branch_reses)
            -- Every row of 'reses' starts with the default body's result.
            defbody' = mkBody defbody_stms' $ map head reses
            e' = Match ses cases' defbody' (MatchDec btypes' sort)
            stm' = Let (Pat pes') (stmAux stm) e'

        foldM addRead (out |> stm') (zip pes pes')
      Loop params lform body -> do
        -- Loop parameters that do not 'StayOnHost' get their initial value
        -- stored with 'storeScalar'. The parameter is renamed and becomes
        -- array-typed, and the old name is rebound via 'migratedTo' in
        -- statements prepended to the body.
        let lmerge ::
              ([(PatElem Type, (Param DeclType, SubExp))], Stms GPU, Stms GPU) ->
              (PatElem Type, (Param DeclType, SubExp), MigrationStatus) ->
              ReduceM ([(PatElem Type, (Param DeclType, SubExp))], Stms GPU, Stms GPU)
            lmerge (res, stms, rebinds) (pe, param, StayOnHost) =
              pure ((pe, param) : res, stms, rebinds)
            lmerge (res, stms, rebinds) (pe, (Param _ pn pt, pval), _) = do
              pe' <- arrayizePatElem pe
              (stms', arr) <- storeScalar stms pval (fromDecl pt)
              pn' <- newName pn
              let pt' = toDecl (patElemType pe') Nonunique
                  param' = (Param mempty pn' pt', Var arr)
              rebinds' <- (pe {patElemName = pn}) `migratedTo` (pn', rebinds)
              pure ((pe', param') : res, stms', rebinds')

        mt <- ask
        let pes = patElems (stmPat stm)
            mss = map (\(Param _ n _, _) -> statusOf n mt) params
        (zipped', out', rebinds) <-
          foldM lmerge ([], out, mempty) (zip3 pes params mss)
        let (pes', params') = unzip (reverse zipped')

        let body1 = body {bodyStms = rebinds >< bodyStms body}
        body2 <- optimizeBody body1

        -- Results of migrated parameters are stored the same way, using the
        -- result expressions of the original (unoptimized) body.
        let zipped =
              zip4
                mss
                (bodyResult body2)
                (map resSubExp $ bodyResult body)
                (map patElemType pes)
            rstore (bstms, res) (StayOnHost, r, _, _) =
              pure (bstms, r : res)
            rstore (bstms, res) (_, SubExpRes certs _, se, t) = do
              (bstms', dev) <- storeScalar bstms se t
              pure (bstms', SubExpRes certs (Var dev) : res)
        (bstms, res) <- foldM rstore (bodyStms body2, []) zipped

        let body3 = body2 {bodyStms = bstms, bodyResult = reverse res}
            e' = Loop params' lform body3
            stm' = Let (Pat pes') (stmAux stm) e'
        foldM addRead (out' |> stm') (zip pes pes')
      WithAcc inputs lmd -> do
        let getAcc (Acc a _ _ _) = a
            getAcc _ =
              compilerBugS
                "Type error: WithAcc expression did not return accumulator."
            accs = zipWith (\t i -> (getAcc t, i)) (lambdaReturnType lmd) inputs
        inputs' <- mapM (uncurry optimizeWithAccInput) accs

        let body = lambdaBody lmd
        stms' <- optimizeStms (bodyStms body)

        -- A non-accumulator result whose 'resolveSubExp' differs is replaced
        -- by the resolved value, and its pattern element is array-ized.
        let rewrite (SubExpRes certs se, t, pe) = do
              se' <- resolveSubExp se
              if se == se'
                then pure (SubExpRes certs se, t, pe)
                else do
                  pe' <- arrayizePatElem pe
                  pure (SubExpRes certs se', patElemType pe', pe')

        -- The first (length inputs) results and return types belong to the
        -- accumulators and are left untouched.
        let len = length inputs
            (res0, res1) = splitAt len (bodyResult body)
            (rts0, rts1) = splitAt len (lambdaReturnType lmd)
            pes = patElems (stmPat stm)
            (pes0, pes1) = splitAt (length pes - length res1) pes
        (res1', rts1', pes1') <- unzip3 <$> mapM rewrite (zip3 res1 rts1 pes1)

        let pes' = pes0 ++ pes1'
            body' = Body () stms' (res0 ++ res1')
            lmd' = lmd {lambdaBody = body', lambdaReturnType = rts0 ++ rts1'}
            e' = WithAcc inputs' lmd'
            stm' = Let (Pat pes') (stmAux stm) e'
        foldM addRead (out |> stm') (zip pes pes')
      Op op -> do
        op' <- optimizeHostOp op
        pure (out |> stm {stmExp = Op op'})
  where
    -- For a pattern element whose name was changed by the rewrite, bind the
    -- original name from the new one via 'migratedTo'; unchanged elements
    -- need nothing.
    addRead :: Stms GPU -> (PatElem Type, PatElem dec) -> ReduceM (Stms GPU)
    addRead stms (pe@(PatElem n _), PatElem dev _)
      | n == dev = pure stms
      | otherwise = pe `migratedTo` (dev, stms)
-- | Optimize the input of an accumulator named by the first argument.
--
-- Inputs without an operator are returned unchanged. For an input with an
-- operator lambda: if 'shouldMove' holds for the accumulator, reads are added
-- to the lambda with 'addReadsToLambda'; otherwise the lambda's statements
-- are optimized under 'noGPUBody'.
optimizeWithAccInput :: VName -> WithAccInput GPU -> ReduceM (WithAccInput GPU)
optimizeWithAccInput _ (shape, arrs, Nothing) = pure (shape, arrs, Nothing)
optimizeWithAccInput acc (shape, arrs, Just (op, nes)) = do
  device_only <- asks (shouldMove acc)
  op' <-
    if device_only
      then addReadsToLambda op
      else optimizeOpStms op
  pure (shape, arrs, Just (op', nes))
  where
    -- Optimize the statements of the operator lambda without using GPUBody.
    optimizeOpStms :: Lambda GPU -> ReduceM (Lambda GPU)
    optimizeOpStms lam = do
      let body = lambdaBody lam
      stms' <- noGPUBody $ optimizeStms (bodyStms body)
      pure lam {lambdaBody = body {bodyStms = stms'}}
-- | Add reads of migrated scalars to the kernel bodies and operators of a
-- host operation.
--
-- 'SizeOp's are returned as they are, and encountering an 'OtherOp' is
-- reported as a compiler bug. For segmented operations the operators are
-- processed before the kernel body.
optimizeHostOp :: HostOp op GPU -> ReduceM (HostOp op GPU)
optimizeHostOp hostop = case hostop of
  SegOp (SegMap lvl space types kbody) ->
    SegOp . SegMap lvl space types <$> addReadsToKernelBody kbody
  SegOp (SegRed lvl space types kbody ops) -> do
    ops' <- traverse addReadsToSegBinOp ops
    kbody' <- addReadsToKernelBody kbody
    pure $ SegOp $ SegRed lvl space types kbody' ops'
  SegOp (SegScan lvl space types kbody ops) -> do
    ops' <- traverse addReadsToSegBinOp ops
    kbody' <- addReadsToKernelBody kbody
    pure $ SegOp $ SegScan lvl space types kbody' ops'
  SegOp (SegHist lvl space types kbody ops) -> do
    ops' <- traverse addReadsToHistOp ops
    kbody' <- addReadsToKernelBody kbody
    pure $ SegOp $ SegHist lvl space types kbody' ops'
  SizeOp sizeop ->
    pure (SizeOp sizeop)
  OtherOp {} ->
    compilerBugS "optimizeHostOp: unhandled OtherOp"
  GPUBody types body ->
    GPUBody types <$> addReadsToBody body
-- | Append the string @sfx@ to the text of a 'Name'.
withSuffix :: Name -> String -> Name
withSuffix name sfx = nameFromText (nameToText name <> T.pack sfx)
-- | The monad in which the rewriting part of the pass runs.
--
-- It holds mutable 'State' (name source, migrated variables and whether
-- 'GPUBody' kernels may be created) on top of read-only access to the
-- 'MigrationTable' that it is run with.
newtype ReduceM a = ReduceM (StateT State (Reader MigrationTable) a)
  deriving (Functor, Applicative, Monad, MonadState State, MonadReader MigrationTable)
-- | Run a 'ReduceM' computation against the migration table @mt@.
--
-- The name source of the enclosing monad seeds the initial 'State', and the
-- name source found in the final state is handed back to it.
runReduceM :: (MonadFreshNames m) => MigrationTable -> ReduceM a -> m a
runReduceM mt (ReduceM m) = modifyNameSource run
  where
    run src =
      let (x, final) = runReader (runStateT m (initialState src)) mt
       in (x, stateNameSource final)
-- | Fresh names are drawn from the name source stored in 'State'.
instance MonadFreshNames ReduceM where
  getNameSource = stateNameSource <$> get
  putNameSource src = modify (\st -> st {stateNameSource = src})
-- | The mutable state of a 'ReduceM' computation.
data State = State
  { -- | Source of fresh 'VName's.
    stateNameSource :: VNameSource,
    -- | Variables that have been migrated to device arrays, keyed by the
    -- 'baseTag' of the original variable. Each entry holds the original
    -- base name, the original type, the name of the device array and a flag.
    -- The flag is 'True' when the original variable is still bound on the
    -- host, re-read from the array; it is 'False' when the variable was only
    -- moved and must be read back before use.
    stateMigrated :: IM.IntMap (Name, Type, VName, Bool),
    -- | Whether new 'GPUBody' kernels may be created at the current location.
    -- 'noGPUBody' clears it temporarily.
    stateGPUBodyOk :: Bool
  }
-- | The starting state: the given name source, nothing migrated yet, and
-- 'GPUBody' creation allowed.
initialState :: VNameSource -> State
initialState ns =
  State
    { stateGPUBodyOk = True,
      stateMigrated = IM.empty,
      stateNameSource = ns
    }
-- | Run @m@ with 'GPUBody' creation disabled, then restore the previous
-- setting of 'stateGPUBodyOk' and return the result of @m@.
noGPUBody :: ReduceM a -> ReduceM a
noGPUBody m = do
  saved <- gets stateGPUBodyOk
  setOk False
  m <* setOk saved
  where
    setOk ok = modify $ \st -> st {stateGPUBodyOk = ok}
-- | Create a pattern element for a device array that holds the value bound
-- by the given element as its single row.
--
-- The array gets a fresh name derived from the original with the suffix
-- @_dev@. Its type is the original type with an outer dimension of size 1.
arrayizePatElem :: PatElem Type -> ReduceM (PatElem Type)
arrayizePatElem (PatElem n t) = do
  dev <- newName $ VName (withSuffix (baseName n) "_dev") 0
  pure $ PatElem dev $ arrayOfRow t (intConst Int64 1)
-- | Record that a variable now lives only in the given device array, with no
-- host binding of the original name left.
movedTo :: Ident -> VName -> ReduceM ()
movedTo ident arr = recordMigration False ident arr
-- | Record that a variable is stored in the given device array while its
-- original name also remains bound on the host.
aliasedBy :: Ident -> VName -> ReduceM ()
aliasedBy ident arr = recordMigration True ident arr
-- | Register in 'stateMigrated' that the variable @x@ of type @t@ has been
-- migrated to the device array @arr@.
--
-- @host@ states whether @x@ is still bound on the host. Any previous entry
-- for the same 'baseTag' is overwritten.
recordMigration :: Bool -> Ident -> VName -> ReduceM ()
recordMigration host (Ident x t) arr = modify insertEntry
  where
    insertEntry st =
      st {stateMigrated = IM.insert (baseTag x) (baseName x, t, arr, host) (stateMigrated st)}
-- | Record that the value bound by @pe@ has been migrated to the device
-- array @dev@, extending @stms@ as needed.
--
-- If the migration table says the variable is used on the host, it stays
-- bound there: a statement reading index 0 of @dev@ into @pe@ is appended.
-- Otherwise the variable is only recorded as moved and @stms@ is returned
-- unchanged.
migratedTo :: PatElem Type -> (VName, Stms GPU) -> ReduceM (Stms GPU)
migratedTo pe (dev, stms) = do
  host_use <- asks $ usedOnHost (patElemName pe)
  if host_use
    then do
      aliasedBy ident dev
      pure (stms |> bind pe (eIndex dev))
    else do
      movedTo ident dev
      pure stms
  where
    ident = patElemIdent pe
-- | Make the scalar variable @n@ usable by host code.
--
-- If @n@ was migrated without keeping a host binding, a fresh variable is
-- bound to index 0 of its device array. That statement is appended to
-- @stms@ and the fresh name is returned. In every other case @stms@ and @n@
-- are returned unchanged.
useScalar :: Stms GPU -> VName -> ReduceM (Stms GPU, VName)
useScalar stms n = do
  migrated <- gets stateMigrated
  case IM.lookup (baseTag n) migrated of
    Just (name, t, arr, False) -> do
      n' <- newName $ VName name 0
      pure (stms |> bind (PatElem n' t) (eIndex arr), n')
    _ -> pure (stms, n)
-- | An expression indexing @arr@ at position 0 of its outermost dimension.
eIndex :: VName -> Exp GPU
eIndex arr = BasicOp (Index arr firstRow)
  where
    firstRow = Slice [DimFix (intConst Int64 0)]
-- | A statement binding a single pattern element to an expression, with
-- default auxiliary information.
bind :: PatElem Type -> Exp GPU -> Stm GPU
bind pe e = Let (Pat [pe]) (defAux ()) e
-- | The device array holding a migrated variable, if there is one.
-- Constants are never migrated and always yield 'Nothing'.
storedScalar :: SubExp -> ReduceM (Maybe VName)
storedScalar se = case se of
  Constant _ -> pure Nothing
  Var n -> gets $ \st -> arrOf <$> IM.lookup (baseTag n) (stateMigrated st)
  where
    arrOf (_, _, arr, _) = arr
-- | Ensure that the scalar @se@ of type @t@ is available as a single-element
-- device array.
--
-- Any statements needed for that are appended to @stms@. The function
-- returns the extended statements and the name of the array.
--
-- * If @se@ is a variable that has already been migrated, its existing
--   device array is reused and no statements are added.
-- * A variable is copied inside a new 'GPUBody' kernel when those are
--   allowed at this point (see 'noGPUBody'). Otherwise it is replicated into
--   a one-element array.
-- * A constant becomes a one-element array literal.
storeScalar :: Stms GPU -> SubExp -> Type -> ReduceM (Stms GPU, VName)
storeScalar stms se t = do
  stored <- storedScalar se
  case stored of
    Just arr -> pure (stms, arr)
    Nothing -> do
      gpubody_ok <- gets stateGPUBodyOk
      case se of
        Var n
          | gpubody_ok -> do
              n' <- newName n
              let stm = bind (PatElem n' t) (BasicOp $ SubExp se)
              gpubody <- inGPUBody (pure stm)
              -- 'inGPUBody' produces one result array per element of the
              -- wrapped statement's pattern; here there is exactly one.
              -- Match on it explicitly instead of using the partial 'head'.
              case patElems (stmPat gpubody) of
                pe : _ -> pure (stms |> gpubody, patElemName pe)
                [] -> compilerBugS "storeScalar: GPUBody has no result"
          | otherwise -> do
              pe <- arrayizePatElem (PatElem n t)
              let shape = Shape [intConst Int64 1]
              let stm = bind pe (BasicOp $ Replicate shape se)
              pure (stms |> stm, patElemName pe)
        Constant _ -> do
          let n = VName (nameFromString "const") 0
          pe <- arrayizePatElem (PatElem n t)
          let stm = bind pe (BasicOp $ ArrayLit [se] t)
          pure (stms |> stm, patElemName pe)
-- | The name that should be used in place of @n@.
--
-- This is the device array if @n@ was migrated without keeping a host
-- binding; otherwise it is @n@ itself.
resolveName :: VName -> ReduceM VName
resolveName n = gets (pick . IM.lookup (baseTag n) . stateMigrated)
  where
    pick (Just (_, _, arr, False)) = arr
    pick _ = n
-- | Apply 'resolveName' to a variable; constants are returned unchanged.
resolveSubExp :: SubExp -> ReduceM SubExp
resolveSubExp se = case se of
  Var n -> fmap Var (resolveName n)
  Constant _ -> pure se
-- | Resolve the value of a result, keeping its certificates.
resolveSubExpRes :: SubExpRes -> ReduceM SubExpRes
resolveSubExpRes (SubExpRes cs se) = do
  se' <- resolveSubExp se
  pure (SubExpRes cs se')
-- | Resolve every element of a 'Result', in order.
resolveResult :: Result -> ReduceM Result
resolveResult = traverse resolveSubExpRes
-- | Move a statement into a 'GPUBody' kernel and append the result to @out@.
--
-- Special case: a statement with a single pattern element that binds a
-- one-element array literal. The kernel computes the element directly (bound
-- to a name with suffix @_inner@). The kernel then binds the original
-- pattern, so no read-back is needed.
--
-- In general, each pattern element of the original statement is made
-- available again from the corresponding kernel result array:
--
-- * rank-1 results of unit type are read from index 0;
-- * other rank-1 results are handled by 'migratedTo';
-- * results of higher rank bind the original name to row 0 of the array;
-- * a rank-0 result is bound directly.
moveStm :: Stms GPU -> Stm GPU -> ReduceM (Stms GPU)
moveStm out stm
  | Let pat@(Pat [PatElem n _]) aux (BasicOp (ArrayLit [se] t')) <- stm = do
      let inner = VName (withSuffix (baseName n) "_inner") 0
          stm' = Let (Pat [PatElem inner t']) aux (BasicOp (SubExp se))
      gpubody <- inGPUBody (rewriteStm stm')
      pure $ out |> gpubody {stmPat = pat}
  | otherwise = do
      gpubody <- inGPUBody (rewriteStm stm)
      let pairs = zip (patElems (stmPat stm)) (patElems (stmPat gpubody))
      foldM readBack (out |> gpubody) pairs
  where
    readBack :: Stms GPU -> (PatElem Type, PatElem Type) -> ReduceM (Stms GPU)
    readBack stms (pe@(PatElem _ t), PatElem dev dev_t) =
      case arrayRank dev_t of
        0 -> bindTo $ BasicOp $ SubExp $ Var dev
        1
          | t == Prim Unit -> bindTo $ eIndex dev
          | otherwise -> pe `migratedTo` (dev, stms)
        _ -> bindTo $ BasicOp $ Index dev $ fullSlice dev_t [DimFix (intConst Int64 0)]
      where
        bindTo e = pure (stms |> bind pe e)
-- | Run a rewrite producing a single statement, and wrap that statement
-- (preceded by the prologue the rewrite accumulated) in a 'GPUBody' kernel.
--
-- The kernel returns every value bound by the statement. The pattern of the
-- returned statement binds fresh single-row arrays created with
-- 'arrayizePatElem', one per original pattern element.
inGPUBody :: RewriteM (Stm GPU) -> ReduceM (Stm GPU)
inGPUBody m = do
  (stm, rst) <- runStateT m initialRState
  let pes = patElems (stmPat stm)
      results = [SubExpRes mempty (Var (patElemName pe)) | pe <- pes]
      kernel = Body () (rewritePrologue rst |> stm) results
  pes' <- traverse arrayizePatElem pes
  pure $ Let (Pat pes') (defAux ()) $ Op $ GPUBody (map patElemType pes) kernel
-- | The monad used to rewrite code that is moved into device kernels:
-- 'ReduceM' extended with a rewrite state ('RState').
type RewriteM = StateT RState ReduceM
-- | The state threaded through 'RewriteM'.
data RState = RState
  { -- | New names given to variables during rewriting, keyed by the
    -- 'baseTag' of the original variable.
    rewriteRenames :: IM.IntMap VName,
    -- | Statements to place before the rewritten code.
    rewritePrologue :: Stms GPU
  }
-- | A rewrite state with no renames and an empty prologue.
initialRState :: RState
initialRState =
  RState
    { rewritePrologue = mempty,
      rewriteRenames = IM.empty
    }
-- | Add reads of migrated scalars to the operator lambda of a 'SegBinOp'.
addReadsToSegBinOp :: SegBinOp GPU -> ReduceM (SegBinOp GPU)
addReadsToSegBinOp op =
  (\lam -> op {segBinOpLambda = lam}) <$> addReadsToLambda (segBinOpLambda op)
-- | Add reads of migrated scalars to the operator lambda of a 'HistOp'.
addReadsToHistOp :: HistOp GPU -> ReduceM (HistOp GPU)
addReadsToHistOp op =
  (\lam -> op {histOp = lam}) <$> addReadsToLambda (histOp op)
-- | Add reads of migrated scalars to the body of a lambda.
addReadsToLambda :: Lambda GPU -> ReduceM (Lambda GPU)
addReadsToLambda f =
  (\body -> f {lambdaBody = body}) <$> addReadsToBody (lambdaBody f)
-- | Rewrite the free variables of a body that refer to migrated scalars.
-- The statements reading them are prepended to the body's statements.
addReadsToBody :: Body GPU -> ReduceM (Body GPU)
addReadsToBody body = prepend <$> addReadsHelper body
  where
    prepend (b, prologue) = b {bodyStms = prologue >< bodyStms b}
-- | Rewrite the free variables of a kernel body that refer to migrated
-- scalars. The statements reading them are prepended to the kernel body's
-- statements.
addReadsToKernelBody :: KernelBody GPU -> ReduceM (KernelBody GPU)
addReadsToKernelBody kbody = prepend <$> addReadsHelper kbody
  where
    prepend (kb, prologue) = kb {kernelBodyStms = prologue >< kernelBodyStms kb}
-- | Rename every free variable of @x@ with 'rename', starting from an empty
-- rewrite state, and substitute the new names into @x@.
--
-- Returns the substituted value and the prologue accumulated by the renames.
-- The caller must place that prologue in scope of the returned value.
addReadsHelper :: (FreeIn a, Substitute a) => a -> ReduceM (a, Stms GPU)
addReadsHelper x = do
  let free = namesToList (freeIn x)
  (renamed, rst) <- runStateT (traverse rename free) initialRState
  let subst = M.fromList (zip free renamed)
  pure (substituteNames subst x, rewritePrologue rst)
-- | Create a fresh name for @n@ and record it in 'rewriteRenames' under the
-- 'baseTag' of @n@.
rewriteName :: VName -> RewriteM VName
rewriteName n = do
  fresh <- lift (newName n)
  modify (recordRename fresh)
  pure fresh
  where
    recordRename fresh rst =
      rst {rewriteRenames = IM.insert (baseTag n) fresh (rewriteRenames rst)}
-- | Rewrite the statements of a body, then rename its result.
rewriteBody :: Body GPU -> RewriteM (Body GPU)
rewriteBody (Body _ stms res) =
  Body () <$> rewriteStms stms <*> renameResult res
-- | Rewrite a sequence of statements in order, accumulating the output.
--
-- A statement whose rewritten expression is a 'GPUBody' is flattened.
-- The body's statements are spliced into the output, followed by one
-- binding per pattern element that connects each 'GPUBody' result to the
-- name the pattern expected.
rewriteStms :: Stms GPU -> RewriteM (Stms GPU)
rewriteStms = foldM step mempty
  where
    step :: Stms GPU -> Stm GPU -> RewriteM (Stms GPU)
    step acc stm = emit acc <$> rewriteStm stm

    -- Append one rewritten statement, inlining it if it is a GPUBody.
    -- NOTE(review): certificates/attributes on the GPUBody statement
    -- itself are not carried over to the inlined statements; confirm
    -- this is intended.
    emit :: Stms GPU -> Stm GPU -> Stms GPU
    emit acc stm' = case stmExp stm' of
      Op (GPUBody _ (Body _ inner results)) ->
        foldl' bindResult (acc >< inner) $
          zip (patElems (stmPat stm')) results
      _ -> acc |> stm'

    -- Bind one GPUBody result to its pattern element. If the element's
    -- type has an outer array dimension, the result is wrapped in a
    -- one-element array literal whose row type is that type with one
    -- dimension peeled off. Otherwise it is bound directly.
    bindResult :: Stms GPU -> (PatElem Type, SubExpRes) -> Stms GPU
    bindResult acc (pe, SubExpRes cs se) =
      acc |> Let (Pat [pe]) (StmAux cs mempty mempty ()) (BasicOp op)
      where
        op = maybe (SubExp se) (ArrayLit [se]) (peelArray 1 (typeOf pe))
-- | Rewrite a single statement.
--
-- Effects are sequenced as: expression first, then the pattern (whose
-- names are freshened), then the certificates in the auxiliary data.
-- This order determines when fresh names and prologue additions occur,
-- so it must be kept.
rewriteStm :: Stm GPU -> RewriteM (Stm GPU)
rewriteStm (Let pat aux e) =
  assemble <$> rewriteExp e <*> rewritePat pat <*> rewriteStmAux aux
  where
    assemble e' pat' aux' = Let pat' aux' e'
-- | Freshen every element of a pattern, left to right, via
-- 'rewritePatElem'.
rewritePat :: Pat Type -> RewriteM (Pat Type)
rewritePat = fmap Pat . traverse rewritePatElem . patElems
-- | Give a pattern element a fresh name, then rename any sizes that
-- occur in its type.
rewritePatElem :: PatElem Type -> RewriteM (PatElem Type)
rewritePatElem (PatElem n t) = PatElem <$> rewriteName n <*> renameType t
-- | Rename the certificates of a statement's auxiliary data. All other
-- fields are left unchanged.
rewriteStmAux :: StmAux () -> RewriteM (StmAux ())
rewriteStmAux aux = withCerts <$> renameCerts (stmAuxCerts aux)
  where
    withCerts cs = aux {stmAuxCerts = cs}
-- | Rewrite an expression through a generic 'Mapper':
--
-- * sub-expressions and variable uses are redirected via 'renameSubExp'
--   and 'rename';
-- * nested bodies are rewritten with 'rewriteBody' (the scope is ignored);
-- * function and lambda parameters get fresh names via 'rewriteParam';
-- * return and branch types have their sizes renamed.
--
-- Reaching an 'Op' is reported as a compiler bug, because host-only
-- operations cannot be migrated to the device.
rewriteExp :: Exp GPU -> RewriteM (Exp GPU)
rewriteExp = mapExpM mapper
  where
    mapper =
      Mapper
        { mapOnSubExp = renameSubExp,
          mapOnBody = \_scope body -> rewriteBody body,
          mapOnVName = rename,
          mapOnRetType = renameExtType,
          mapOnBranchType = renameExtType,
          mapOnFParam = rewriteParam,
          mapOnLParam = rewriteParam,
          mapOnOp = \_ -> compilerBugS "Cannot migrate a host-only operation to device."
        }
-- | Give a parameter a fresh name, then rename the sizes in its type.
-- Attributes are kept as-is.
rewriteParam :: Param (TypeBase Shape u) -> RewriteM (Param (TypeBase Shape u))
rewriteParam (Param attrs n t) = Param attrs <$> rewriteName n <*> renameType t
-- | Return the name to use in place of a variable occurrence.
--
-- If the name (by base tag) already has a recorded replacement, that
-- replacement is returned. Otherwise the name is treated as a dependency
-- defined outside the rewritten code. It is handed to 'useScalar'
-- together with the current prologue statements; 'useScalar' returns an
-- extended prologue and the replacement name. Both are stored, so later
-- occurrences reuse the same replacement.
rename :: VName -> RewriteM VName
rename n = gets (IM.lookup tag . rewriteRenames) >>= maybe migrate pure
  where
    tag = baseTag n

    migrate = do
      prologue <- gets rewritePrologue
      (prologue', n') <- lift (useScalar prologue n)
      modify $ \st ->
        st
          { rewriteRenames = IM.insert tag n' (rewriteRenames st),
            rewritePrologue = prologue'
          }
      pure n'
-- | Rename every entry of a body result, left to right.
renameResult :: Result -> RewriteM Result
renameResult res = traverse renameSubExpRes res
-- | Rename a single result: first its certificates, then its value.
renameSubExpRes :: SubExpRes -> RewriteM SubExpRes
renameSubExpRes (SubExpRes certs se) =
  SubExpRes <$> renameCerts certs <*> renameSubExp se
-- | Rename every certificate name via 'rename', preserving order.
renameCerts :: Certs -> RewriteM Certs
renameCerts = fmap Certs . mapM rename . unCerts
-- | Rename a sub-expression. Variables are redirected via 'rename';
-- any other sub-expression is returned unchanged.
renameSubExp :: SubExp -> RewriteM SubExp
renameSubExp se = case se of
  Var n -> Var <$> rename n
  _ -> pure se
-- | Rename the sizes occurring in a type's shape via 'renameSubExp'.
renameType :: TypeBase Shape u -> RewriteM (TypeBase Shape u)
renameType t = mapOnType renameSubExp t
-- | Rename the known sizes occurring in an existential type's shape via
-- 'renameSubExp'.
renameExtType :: TypeBase ExtShape u -> RewriteM (TypeBase ExtShape u)
renameExtType t = mapOnExtType renameSubExp t