{-# LANGUAGE TypeFamilies #-}
module Futhark.CodeGen.ImpGen.GPU.Block
( sKernelBlock,
compileBlockResult,
blockOperations,
Precomputed,
precomputeConstants,
precomputedConstants,
atomicUpdateLocking,
)
where
import Control.Monad
import Data.Bifunctor
import Data.List (partition, zip4)
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.Construct (fullSliceNum)
import Futhark.Error
import Futhark.IR.GPUMem
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.Transform.Rename
import Futhark.Util (chunks, mapAccumLM, takeLast)
import Futhark.Util.IntegralExp (divUp, rem)
import Prelude hiding (quot, rem)
-- | Create a new array view of @arr@ in which the outer @k@ dimensions
-- have been collapsed into a single dimension of size @flat@.  The
-- underlying memory is unchanged; only the index function (LMAD) is
-- reshaped.  Fails with an internal error if the LMAD cannot be
-- reshaped, which should be impossible for the flattenings we perform.
flattenArray :: Int -> TV Int64 -> VName -> ImpM rep r op VName
flattenArray k flat arr = do
  ArrayEntry arr_loc pt <- lookupArray arr
  let flat_shape = Shape $ Var (tvVar flat) : drop k (memLocShape arr_loc)
  sArray (baseString arr ++ "_flat") pt flat_shape (memLocName arr_loc) $
    fromMaybe (error "flattenArray") $
      LMAD.reshape (memLocLMAD arr_loc) (map pe64 $ shapeDims flat_shape)
-- | Create a new array view that is a slice of the outer dimension of
-- @arr@, starting at @start@ and running for @size@ elements.  The
-- remaining dimensions are taken in full.  The result shares memory
-- with @arr@; only the index function is restricted.
sliceArray :: Imp.TExp Int64 -> TV Int64 -> VName -> ImpM rep r op VName
sliceArray start size arr = do
  MemLoc mem _ lmad <- entryArrayLoc <$> lookupArray arr
  arr_t <- lookupType arr
  let slice =
        fullSliceNum
          (map Imp.pe64 (arrayDims arr_t))
          [DimSlice start (tvExp size) 1]
  sArray
    (baseString arr ++ "_chunk")
    (elemType arr_t)
    (arrayShape arr_t `setOuterDim` Var (tvVar size))
    mem
    $ LMAD.slice lmad slice
-- | Compile an application of @lam@: bind the lambda parameters to the
-- (indexed) @args@, compile the lambda body, and write the body results
-- to the (indexed) @dests@.  Arguments and destinations are paired with
-- the slice at which they should be read/written.
--
-- Note: the lambda is compiled as-is; use 'applyRenamedLambda' if the
-- same lambda may be compiled more than once (to avoid duplicate name
-- bindings).
applyLambda ::
  (Mem rep inner) =>
  Lambda rep ->
  [(VName, [DimIndex (Imp.TExp Int64)])] ->
  [(SubExp, [DimIndex (Imp.TExp Int64)])] ->
  ImpM rep r op ()
applyLambda lam dests args = do
  -- Declare the lambda parameters, then copy each argument into the
  -- corresponding parameter.
  dLParams $ lambdaParams lam
  forM_ (zip (lambdaParams lam) args) $ \(p, (arg, arg_slice)) ->
    copyDWIM (paramName p) [] arg arg_slice
  -- Compile the body statements, then write each result to its
  -- destination slice.
  compileStms mempty (bodyStms $ lambdaBody lam) $ do
    let res = map resSubExp $ bodyResult $ lambdaBody lam
    forM_ (zip dests res) $ \((dest, dest_slice), se) ->
      copyDWIM dest dest_slice se []
-- | Like 'applyLambda', but first renames the lambda so it can safely
-- be applied multiple times in the same scope without duplicate
-- bindings.
applyRenamedLambda ::
  (Mem rep inner) =>
  Lambda rep ->
  [(VName, [DimIndex (Imp.TExp Int64)])] ->
  [(SubExp, [DimIndex (Imp.TExp Int64)])] ->
  ImpM rep r op ()
applyRenamedLambda lam dests args = do
  lam_renamed <- renameLambda lam
  applyLambda lam_renamed dests args
-- | Iterate over @w@ elements in chunks of at most the thread block
-- size.  The continuation @m@ is invoked once per chunk with the chunk
-- start offset and the (dynamically computed) chunk size; the final
-- chunk may be smaller than the block size.
blockChunkLoop ::
  Imp.TExp Int32 ->
  (Imp.TExp Int32 -> TV Int64 -> InKernelGen ()) ->
  InKernelGen ()
blockChunkLoop w m = do
  constants <- kernelConstants <$> askEnv
  let max_chunk_size = sExt32 $ kernelBlockSize constants
  num_chunks <- dPrimVE "num_chunks" $ w `divUp` max_chunk_size
  sFor "chunk_i" num_chunks $ \chunk_i -> do
    chunk_start <- dPrimVE "chunk_start" $ chunk_i * max_chunk_size
    -- Clamp the chunk end so the last chunk does not run past w.
    chunk_end <- dPrimVE "chunk_end" $ sMin32 w (chunk_start + max_chunk_size)
    chunk_size <- dPrimV "chunk_size" $ sExt64 $ chunk_end - chunk_start
    m chunk_start chunk_size
-- | Perform a scan of @w@ elements with a thread block, even when @w@
-- exceeds the block size, by scanning block-sized chunks and carrying
-- the running result into the next chunk.  @seg_flag@, when present,
-- tells whether two flat indices belong to different segments; a carry
-- is not incorporated across a segment boundary.
virtualisedBlockScan ::
  Maybe (Imp.TExp Int32 -> Imp.TExp Int32 -> Imp.TExp Bool) ->
  Imp.TExp Int32 ->
  Lambda GPUMem ->
  [VName] ->
  InKernelGen ()
virtualisedBlockScan seg_flag w lam arrs = do
  blockChunkLoop w $ \chunk_start chunk_size -> do
    constants <- kernelConstants <$> askEnv
    let ltid = kernelLocalThreadId constants
        crosses_segment =
          case seg_flag of
            Nothing -> false
            Just flag_true ->
              -- Does the boundary between the previous chunk's last
              -- element and this chunk's first element cross a segment?
              flag_true (sExt32 (chunk_start - 1)) (sExt32 chunk_start)
    sComment "possibly incorporate carry" $
      sWhen (chunk_start .>. 0 .&&. ltid .==. 0 .&&. bNot crosses_segment) $ do
        carry_idx <- dPrimVE "carry_idx" $ sExt64 chunk_start - 1
        -- Combine the carry (last element of previous chunk) with the
        -- first element of this chunk, in place.
        applyRenamedLambda
          lam
          (map (,[DimFix $ sExt64 chunk_start]) arrs)
          ( map ((,[DimFix carry_idx]) . Var) arrs
              ++ map ((,[DimFix $ sExt64 chunk_start]) . Var) arrs
          )
    arrs_chunks <- mapM (sliceArray (sExt64 chunk_start) chunk_size) arrs
    -- Make sure the carry is visible to all threads before scanning.
    sOp $ Imp.ErrorSync Imp.FenceLocal
    blockScan seg_flag (sExt64 w) (tvExp chunk_size) lam arrs_chunks
-- | A 'CopyCompiler' for in-kernel copies performed cooperatively by an
-- entire thread block.
--
-- When both source and destination are in 'ScalarSpace' (i.e. thread
-- registers), every thread performs the full copy itself, with leading
-- dimensions fixed to zero so only the register-resident trailing
-- dimensions are touched.  Otherwise the iteration space is distributed
-- across the threads of the block, followed by a local barrier so the
-- result is visible block-wide.
copyInBlock :: CopyCompiler GPUMem KernelEnv Imp.KernelOp
copyInBlock pt destloc srcloc = do
  dest_space <- entryMemSpace <$> lookupMemory (memLocName destloc)
  src_space <- entryMemSpace <$> lookupMemory (memLocName srcloc)
  let src_lmad = memLocLMAD srcloc
      dims = LMAD.shape src_lmad
      rank = length dims
  case (dest_space, src_space) of
    (ScalarSpace destds _, ScalarSpace srcds _) -> do
      -- A copy between scalar spaces; per-thread registers.  Copy the
      -- trailing dimensions in full, and fix the leading ones at zero.
      let fullDim d = DimSlice 0 d 1
          destslice' =
            Slice $
              replicate (rank - length destds) (DimFix 0)
                ++ takeLast (length destds) (map fullDim dims)
          srcslice' =
            Slice $
              replicate (rank - length srcds) (DimFix 0)
                ++ takeLast (length srcds) (map fullDim dims)
      lmadCopy
        pt
        (sliceMemLoc destloc destslice')
        (sliceMemLoc srcloc srcslice')
    _ -> do
      -- Distribute the element-wise copies across the thread block.
      blockCoverSpace (map sExt32 dims) $ \is ->
        lmadCopy
          pt
          (sliceMemLoc destloc (Slice $ map (DimFix . sExt64) is))
          (sliceMemLoc srcloc (Slice $ map (DimFix . sExt64) is))
      -- Barrier so every thread sees the completed copy.
      sOp $ Imp.Barrier Imp.FenceLocal
-- | Compute the multi-dimensional local thread indices for an
-- iteration space with the given dimensions.  If the kernel constants
-- already carry precomputed local ids for exactly these dimensions
-- (see 'precomputedConstants'), reuse them; otherwise unflatten the
-- flat local thread id on the fly.
localThreadIDs :: [SubExp] -> InKernelGen [Imp.TExp Int64]
localThreadIDs dims = do
  ltid <- sExt64 . kernelLocalThreadId . kernelConstants <$> askEnv
  let dims' = map pe64 dims
  maybe (dIndexSpace' "ltid" dims' ltid) (pure . map sExt64)
    . M.lookup dims
    . kernelLocalIdMap
    . kernelConstants
    =<< askEnv
-- | Split the dimensions of a 'SegSpace' into those whose positions are
-- listed in the 'SegSeqDims' (to be executed sequentially) and the
-- rest (to be parallelised), preserving the original order within each
-- partition.
partitionSeqDims :: SegSeqDims -> SegSpace -> ([(VName, SubExp)], [(VName, SubExp)])
partitionSeqDims (SegSeqDims seq_is) space =
  bimap (map fst) (map fst) $
    partition ((`elem` seq_is) . snd) (zip (unSegSpace space) [0 ..])
-- | Bind the flat index variable of the 'SegSpace' to the local thread
-- id of the current thread.
compileFlatId :: SegSpace -> InKernelGen ()
compileFlatId space = do
  ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
  dPrimV_ (segFlat space) $ sExt64 ltid
-- | For each histogram operation, produce an action that performs one
-- atomic update of the block-local (shared memory) subhistograms at a
-- given index.  Operators that can be implemented with primitive
-- atomics or CAS need no extra state.  Locking-based operators share a
-- single lock array, allocated in shared memory on first use and
-- threaded through the fold via the 'Maybe Locking' accumulator.
prepareIntraBlockSegHist ::
  Shape ->
  Count BlockSize SubExp ->
  [HistOp GPUMem] ->
  InKernelGen [[Imp.TExp Int64] -> InKernelGen ()]
prepareIntraBlockSegHist segments tblock_size =
  fmap snd . mapAccumLM onOp Nothing
  where
    onOp l op = do
      constants <- kernelConstants <$> askEnv
      atomicBinOp <- kernelAtomics <$> askEnv
      let local_subhistos = histDest op
      case (l, atomicUpdateLocking atomicBinOp $ histOp op) of
        (_, AtomicPrim f) -> pure (l, f (Space "shared") local_subhistos)
        (_, AtomicCAS f) -> pure (l, f (Space "shared") local_subhistos)
        -- A lock array already exists; reuse it.
        (Just l', AtomicLocking f) -> pure (l, f l' (Space "shared") local_subhistos)
        -- First locking operator: allocate and initialise the locks.
        (Nothing, AtomicLocking f) -> do
          locks <- newVName "locks"
          let num_locks = pe64 $ unCount tblock_size
              dims =
                map pe64 $
                  shapeDims (segments <> histOpShape op <> histShape op)
              -- Hash the flat destination index into the lock array.
              l' = Locking locks 0 1 0 (pure . (`rem` num_locks) . flattenIndex dims)
              locks_t = Array int32 (Shape [unCount tblock_size]) NoUniqueness
          locks_mem <- sAlloc "locks_mem" (typeSize locks_t) $ Space "shared"
          dArray locks int32 (arrayShape locks_t) locks_mem $
            LMAD.iota 0 . map pe64 . arrayDims $
              locks_t
          sComment "All locks start out unlocked" $
            blockCoverSpace [kernelBlockSize constants] $ \is ->
              copyDWIMFix locks is (intConst Int32 0) []
          pure (Just l', f l' (Space "shared") local_subhistos)
-- NOTE(review): this chunk is a type-annotated source rendering (the
-- interleaved type-signature lines are tooling artifacts, not code).
-- Only comments are added here; all other tokens are untouched.
--
-- | Cover a 'SegSpace' with the threads of a single block, according to
-- the requested virtualisation strategy, running the continuation 'm'
-- once per point of the space with the space's index variables in scope.
blockCoverSegSpace :: SegVirt -> SegSpace -> InKernelGen () -> InKernelGen ()
blockCoverSegSpace :: SegVirt -> SegSpace -> InKernelGen () -> InKernelGen ()
blockCoverSegSpace SegVirt
virt SegSpace
space InKernelGen ()
m = do
-- Split the space into its index variables ('ltids') and dimensions.
let ([VName]
ltids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
-- Dimensions as 64-bit index expressions.
dims' :: [TPrimExp Int64 VName]
dims' = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
dims
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv
let tblock_size = KernelConstants -> TPrimExp Int64 VName
kernelBlockSize KernelConstants
constants
-- If the space is exactly one block wide, virtualisation is pointless;
-- downgrade to full coverage with no sequentialised dimensions.
let virt' = if [TPrimExp Int64 VName]
dims' [TPrimExp Int64 VName] -> [TPrimExp Int64 VName] -> Bool
forall a. Eq a => a -> a -> Bool
== [TPrimExp Int64 VName
tblock_size] then SegSeqDims -> SegVirt
SegNoVirtFull ([Int] -> SegSeqDims
SegSeqDims []) else SegVirt
virt
case virt' of
-- Virtualised: each thread may handle several points of the space.
SegVirt
SegVirt -> do
-- Look for a precomputed per-thread chunk count for these dimensions.
iters <- [SubExp] -> Map [SubExp] (TExp Int32) -> Maybe (TExp Int32)
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup [SubExp]
dims (Map [SubExp] (TExp Int32) -> Maybe (TExp Int32))
-> (KernelEnv -> Map [SubExp] (TExp Int32))
-> KernelEnv
-> Maybe (TExp Int32)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelConstants -> Map [SubExp] (TExp Int32)
kernelChunkItersMap (KernelConstants -> Map [SubExp] (TExp Int32))
-> (KernelEnv -> KernelConstants)
-> KernelEnv
-> Map [SubExp] (TExp Int32)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> Maybe (TExp Int32))
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp (Maybe (TExp Int32))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv
case iters of
-- No precomputed chunk count: loop over the whole flattened space.
Maybe (TExp Int32)
Nothing -> do
iterations <- SpaceId
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall {k} (t :: k) rep r op.
SpaceId -> TExp t -> ImpM rep r op (TExp t)
dPrimVE SpaceId
"iterations" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ [TExp Int32] -> TExp Int32
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TExp Int32] -> TExp Int32) -> [TExp Int32] -> TExp Int32
forall a b. (a -> b) -> a -> b
$ (TPrimExp Int64 VName -> TExp Int32)
-> [TPrimExp Int64 VName] -> [TExp Int32]
forall a b. (a -> b) -> [a] -> [b]
map TPrimExp Int64 VName -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 [TPrimExp Int64 VName]
dims'
-- Unflatten the loop counter back into per-dimension indices.
blockLoop iterations $ \TExp Int32
i -> do
[(VName, TPrimExp Int64 VName)]
-> TPrimExp Int64 VName -> InKernelGen ()
forall rep r op.
[(VName, TPrimExp Int64 VName)]
-> TPrimExp Int64 VName -> ImpM rep r op ()
dIndexSpace ([VName]
-> [TPrimExp Int64 VName] -> [(VName, TPrimExp Int64 VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
ltids [TPrimExp Int64 VName]
dims') (TPrimExp Int64 VName -> InKernelGen ())
-> TPrimExp Int64 VName -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
i
InKernelGen ()
m
-- Precomputed chunk count: each thread handles one point per chunk,
-- at flat index chunk_i * tblock_size + ltid, guarded by a bounds check.
Just TExp Int32
num_chunks -> Operations GPUMem KernelEnv KernelOp
-> InKernelGen () -> InKernelGen ()
forall rep r op a.
Operations rep r op -> ImpM rep r op a -> ImpM rep r op a
localOps Operations GPUMem KernelEnv KernelOp
threadOperations (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let ltid :: TExp Int32
ltid = KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
SpaceId
-> TExp Int32 -> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall {k} (t :: k) rep r op.
SpaceId
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor SpaceId
"chunk_i" TExp Int32
num_chunks ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
chunk_i -> do
i <- SpaceId
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall {k} (t :: k) rep r op.
SpaceId -> TExp t -> ImpM rep r op (TExp t)
dPrimVE SpaceId
"i" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int32
chunk_i TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 VName
tblock_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
ltid
dIndexSpace (zip ltids dims') $ sExt64 i
sWhen (inBounds (Slice (map (DimFix . le64) ltids)) dims') m
-- No virtualisation: one thread per point, inactive threads masked out.
SegVirt
SegNoVirt -> Operations GPUMem KernelEnv KernelOp
-> InKernelGen () -> InKernelGen ()
forall rep r op a.
Operations rep r op -> ImpM rep r op a -> ImpM rep r op a
localOps Operations GPUMem KernelEnv KernelOp
threadOperations (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
(VName -> TPrimExp Int64 VName -> InKernelGen ())
-> [VName] -> [TPrimExp Int64 VName] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TPrimExp Int64 VName -> InKernelGen ()
forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ [VName]
ltids ([TPrimExp Int64 VName] -> InKernelGen ())
-> InKernelGen [TPrimExp Int64 VName] -> InKernelGen ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< [SubExp] -> InKernelGen [TPrimExp Int64 VName]
localThreadIDs [SubExp]
dims
TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen ([(VName, SubExp)] -> TExp Bool
isActive ([(VName, SubExp)] -> TExp Bool) -> [(VName, SubExp)] -> TExp Bool
forall a b. (a -> b) -> a -> b
$ [VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
ltids [SubExp]
dims) InKernelGen ()
m
-- Full coverage with some dimensions sequentialised: a sequential loop
-- nest over 'seq_dims', with the remaining dimensions covered by threads.
SegNoVirtFull SegSeqDims
seq_dims -> do
let (([VName]
ltids_seq, [SubExp]
dims_seq), ([VName]
ltids_par, [SubExp]
dims_par)) =
([(VName, SubExp)] -> ([VName], [SubExp]))
-> ([(VName, SubExp)] -> ([VName], [SubExp]))
-> ([(VName, SubExp)], [(VName, SubExp)])
-> (([VName], [SubExp]), ([VName], [SubExp]))
forall a b c d. (a -> b) -> (c -> d) -> (a, c) -> (b, d)
forall (p :: * -> * -> *) a b c d.
Bifunctor p =>
(a -> b) -> (c -> d) -> p a c -> p b d
bimap [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip (([(VName, SubExp)], [(VName, SubExp)])
-> (([VName], [SubExp]), ([VName], [SubExp])))
-> ([(VName, SubExp)], [(VName, SubExp)])
-> (([VName], [SubExp]), ([VName], [SubExp]))
forall a b. (a -> b) -> a -> b
$ SegSeqDims -> SegSpace -> ([(VName, SubExp)], [(VName, SubExp)])
partitionSeqDims SegSeqDims
seq_dims SegSpace
space
ShapeBase SubExp
-> ([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ()
forall rep r op.
ShapeBase SubExp
-> ([TPrimExp Int64 VName] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [SubExp]
dims_seq) (([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ())
-> ([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TPrimExp Int64 VName]
is_seq -> do
-- Bind the sequential indices for this iteration of the loop nest.
(VName -> TPrimExp Int64 VName -> InKernelGen ())
-> [VName] -> [TPrimExp Int64 VName] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TPrimExp Int64 VName -> InKernelGen ()
forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ [VName]
ltids_seq [TPrimExp Int64 VName]
is_seq
Operations GPUMem KernelEnv KernelOp
-> InKernelGen () -> InKernelGen ()
forall rep r op a.
Operations rep r op -> ImpM rep r op a -> ImpM rep r op a
localOps Operations GPUMem KernelEnv KernelOp
threadOperations (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
-- The parallel indices come from the thread IDs.
(VName -> TPrimExp Int64 VName -> InKernelGen ())
-> [VName] -> [TPrimExp Int64 VName] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TPrimExp Int64 VName -> InKernelGen ()
forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ [VName]
ltids_par ([TPrimExp Int64 VName] -> InKernelGen ())
-> InKernelGen [TPrimExp Int64 VName] -> InKernelGen ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< [SubExp] -> InKernelGen [TPrimExp Int64 VName]
localThreadIDs [SubExp]
dims_par
InKernelGen ()
m
-- NOTE(review): type-annotated source rendering; interleaved type lines
-- are tooling artifacts.  Only comments are added; other tokens untouched.
--
-- | Expression compiler used inside block-level kernels.  Handles a few
-- constructs specially (Opaque, ArrayVal, ArrayLit, UpdateAcc, Replicate,
-- Iota, scalar Update); everything else falls through to 'defCompileExp'.
compileBlockExp :: ExpCompiler GPUMem KernelEnv Imp.KernelOp
compileBlockExp :: ExpCompiler GPUMem KernelEnv KernelOp
-- An Opaque is semantically just its payload; copy it to the pattern.
compileBlockExp (Pat [PatElem (LetDec GPUMem)
pe]) (BasicOp (Opaque OpaqueOp
_ SubExp
se)) =
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> InKernelGen ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM (PatElem LetDecMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec GPUMem)
PatElem LetDecMem
pe) [] SubExp
se []
-- An ArrayVal is compiled by re-dispatching as the equivalent ArrayLit
-- of constants.
compileBlockExp (Pat [PatElem (LetDec GPUMem)
dest]) (BasicOp (ArrayVal [PrimValue]
vs PrimType
t)) =
ExpCompiler GPUMem KernelEnv KernelOp
compileBlockExp ([PatElem LetDecMem] -> Pat LetDecMem
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem (LetDec GPUMem)
PatElem LetDecMem
dest]) (BasicOp -> Exp GPUMem
forall rep. BasicOp -> Exp rep
BasicOp ([SubExp] -> TypeBase (ShapeBase SubExp) NoUniqueness -> BasicOp
ArrayLit ((PrimValue -> SubExp) -> [PrimValue] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map PrimValue -> SubExp
Constant [PrimValue]
vs) (PrimType -> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t)))
-- ArrayLit: write each literal element to its position in the result.
compileBlockExp (Pat [PatElem (LetDec GPUMem)
dest]) (BasicOp (ArrayLit [SubExp]
es TypeBase (ShapeBase SubExp) NoUniqueness
_)) =
[(Int64, SubExp)]
-> ((Int64, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Int64] -> [SubExp] -> [(Int64, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int64
0 ..] [SubExp]
es) (((Int64, SubExp) -> InKernelGen ()) -> InKernelGen ())
-> ((Int64, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Int64
i, SubExp
e) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (PatElem LetDecMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec GPUMem)
PatElem LetDecMem
dest) [Int64 -> TPrimExp Int64 VName
forall a b. (Integral a, Num b) => a -> b
fromIntegral (Int64
i :: Int64)] SubExp
e []
-- UpdateAcc: only local thread 0 performs the update, followed by a
-- local barrier so all threads observe the result.
compileBlockExp Pat (LetDec GPUMem)
_ (BasicOp (UpdateAcc Safety
safety VName
acc [SubExp]
is [SubExp]
vs)) = do
ltid <- KernelConstants -> TExp Int32
kernelLocalThreadId (KernelConstants -> TExp Int32)
-> (KernelEnv -> KernelConstants) -> KernelEnv -> TExp Int32
forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> TExp Int32)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv
sWhen (ltid .==. 0) $ updateAcc safety acc is vs
sOp $ Imp.Barrier Imp.FenceLocal
-- Non-scalar Replicate: all threads cooperate over a virtualised space
-- covering the destination, then synchronise.
compileBlockExp (Pat [PatElem (LetDec GPUMem)
dest]) (BasicOp (Replicate ShapeBase SubExp
ds SubExp
se)) | ShapeBase SubExp
ds ShapeBase SubExp -> ShapeBase SubExp -> Bool
forall a. Eq a => a -> a -> Bool
/= ShapeBase SubExp
forall a. Monoid a => a
mempty = do
flat <- SpaceId -> ImpM GPUMem KernelEnv KernelOp VName
forall (m :: * -> *). MonadFreshNames m => SpaceId -> m VName
newVName SpaceId
"rep_flat"
is <- replicateM (arrayRank dest_t) (newVName "rep_i")
let is' = (VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 [VName]
is
-- The trailing indices (beyond the replicated dims) index the source.
blockCoverSegSpace SegVirt (SegSpace flat $ zip is $ arrayDims dest_t) $
copyDWIMFix (patElemName dest) is' se (drop (shapeRank ds) is')
sOp $ Imp.Barrier Imp.FenceLocal
where
dest_t :: TypeBase (ShapeBase SubExp) NoUniqueness
dest_t = PatElem LetDecMem -> TypeBase (ShapeBase SubExp) NoUniqueness
forall dec.
Typed dec =>
PatElem dec -> TypeBase (ShapeBase SubExp) NoUniqueness
patElemType PatElem (LetDec GPUMem)
PatElem LetDecMem
dest
-- Iota: each covered index i' gets e + i' * s (at type 'it', with
-- undefined-overflow arithmetic), followed by a local barrier.
compileBlockExp (Pat [PatElem (LetDec GPUMem)
dest]) (BasicOp (Iota SubExp
n SubExp
e SubExp
s IntType
it)) = do
n' <- SubExp -> ImpM GPUMem KernelEnv KernelOp Exp
forall a rep r op. ToExp a => a -> ImpM rep r op Exp
forall rep r op. SubExp -> ImpM rep r op Exp
toExp SubExp
n
e' <- toExp e
s' <- toExp s
blockLoop (TPrimExp n') $ \TPrimExp Int64 VName
i' -> do
x <-
SpaceId
-> TExp (ZonkAny 1)
-> ImpM GPUMem KernelEnv KernelOp (TV (ZonkAny 1))
forall {k} (t :: k) rep r op.
SpaceId -> TExp t -> ImpM rep r op (TV t)
dPrimV SpaceId
"x" (TExp (ZonkAny 1)
-> ImpM GPUMem KernelEnv KernelOp (TV (ZonkAny 1)))
-> TExp (ZonkAny 1)
-> ImpM GPUMem KernelEnv KernelOp (TV (ZonkAny 1))
forall a b. (a -> b) -> a -> b
$
Exp -> TExp (ZonkAny 1)
forall {k} (t :: k) v. PrimExp v -> TPrimExp t v
TPrimExp (Exp -> TExp (ZonkAny 1)) -> Exp -> TExp (ZonkAny 1)
forall a b. (a -> b) -> a -> b
$
BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (IntType -> Overflow -> BinOp
Add IntType
it Overflow
OverflowUndef) Exp
e' (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$
BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (IntType -> Overflow -> BinOp
Mul IntType
it Overflow
OverflowUndef) (TPrimExp Int64 VName -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
i') Exp
s'
copyDWIMFix (patElemName dest) [i'] (Var (tvVar x)) []
sOp $ Imp.Barrier Imp.FenceLocal
-- Scalar in-place Update (slice has no dims): barrier, then thread 0
-- writes the element (bounds-checked when 'Safe'), then barrier again.
compileBlockExp (Pat [PatElem (LetDec GPUMem)
pe]) (BasicOp (Update Safety
safety VName
_ Slice SubExp
slice SubExp
se))
| [SubExp] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null ([SubExp] -> Bool) -> [SubExp] -> Bool
forall a b. (a -> b) -> a -> b
$ Slice SubExp -> [SubExp]
forall d. Slice d -> [d]
sliceDims Slice SubExp
slice = do
KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
ltid <- KernelConstants -> TExp Int32
kernelLocalThreadId (KernelConstants -> TExp Int32)
-> (KernelEnv -> KernelConstants) -> KernelEnv -> TExp Int32
forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> TExp Int32)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv
sWhen (ltid .==. 0) $
case safety of
Safety
Unsafe -> InKernelGen ()
write
Safety
Safe -> TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (Slice (TPrimExp Int64 VName) -> [TPrimExp Int64 VName] -> TExp Bool
inBounds Slice (TPrimExp Int64 VName)
slice' [TPrimExp Int64 VName]
dims) InKernelGen ()
write
sOp $ Imp.Barrier Imp.FenceLocal
where
slice' :: Slice (TPrimExp Int64 VName)
slice' = (SubExp -> TPrimExp Int64 VName)
-> Slice SubExp -> Slice (TPrimExp Int64 VName)
forall a b. (a -> b) -> Slice a -> Slice b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TPrimExp Int64 VName
pe64 Slice SubExp
slice
dims :: [TPrimExp Int64 VName]
dims = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 ([SubExp] -> [TPrimExp Int64 VName])
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ TypeBase (ShapeBase SubExp) NoUniqueness -> [SubExp]
forall u. TypeBase (ShapeBase SubExp) u -> [SubExp]
arrayDims (TypeBase (ShapeBase SubExp) NoUniqueness -> [SubExp])
-> TypeBase (ShapeBase SubExp) NoUniqueness -> [SubExp]
forall a b. (a -> b) -> a -> b
$ PatElem LetDecMem -> TypeBase (ShapeBase SubExp) NoUniqueness
forall dec.
Typed dec =>
PatElem dec -> TypeBase (ShapeBase SubExp) NoUniqueness
patElemType PatElem (LetDec GPUMem)
PatElem LetDecMem
pe
write :: InKernelGen ()
write = VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> InKernelGen ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM (PatElem LetDecMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec GPUMem)
PatElem LetDecMem
pe) (Slice (TPrimExp Int64 VName) -> [DimIndex (TPrimExp Int64 VName)]
forall d. Slice d -> [DimIndex d]
unSlice Slice (TPrimExp Int64 VName)
slice') SubExp
se []
-- Fallback: synchronise before compiling control flow (Loop/Match),
-- then defer to the default expression compiler.
compileBlockExp Pat (LetDec GPUMem)
dest Exp GPUMem
e = do
Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Exp GPUMem -> Bool
forall {rep}. Exp rep -> Bool
doSync Exp GPUMem
e) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
ExpCompiler GPUMem KernelEnv KernelOp
forall rep (inner :: * -> *) r op.
Mem rep inner =>
Pat (LetDec rep) -> Exp rep -> ImpM rep r op ()
defCompileExp Pat (LetDec GPUMem)
dest Exp GPUMem
e
where
-- Only Loop and Match require a synchronisation point beforehand.
doSync :: Exp rep -> Bool
doSync Loop {} = Bool
True
doSync Match {} = Bool
True
doSync Exp rep
_ = Bool
False
-- NOTE(review): type-annotated source rendering; interleaved type lines
-- are tooling artifacts.  Only comments are added; other tokens untouched.
--
-- | Handle a memory allocation inside a block-level kernel.  Scalar
-- space needs no action; "shared" space uses 'allocLocal'; any other
-- named space is a compiler limitation.  A pattern binding anything but
-- a single memory block is an internal error.
blockAlloc ::
Pat LetDecMem ->
SubExp ->
Space ->
InKernelGen ()
blockAlloc :: Pat LetDecMem -> SubExp -> Space -> InKernelGen ()
-- Scalar space: nothing to allocate.
blockAlloc (Pat [PatElem LetDecMem
_]) SubExp
_ ScalarSpace {} =
() -> InKernelGen ()
forall a. a -> ImpM GPUMem KernelEnv KernelOp a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
-- Shared memory: allocate 'size' bytes via the local allocator.
blockAlloc (Pat [PatElem LetDecMem
mem]) SubExp
size (Space SpaceId
"shared") =
AllocCompiler GPUMem KernelEnv KernelOp
forall r. AllocCompiler GPUMem r KernelOp
allocLocal (PatElem LetDecMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LetDecMem
mem) (Count Bytes (TPrimExp Int64 VName) -> InKernelGen ())
-> Count Bytes (TPrimExp Int64 VName) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Count Bytes (TPrimExp Int64 VName)
forall a. a -> Count Bytes a
Imp.bytes (TPrimExp Int64 VName -> Count Bytes (TPrimExp Int64 VName))
-> TPrimExp Int64 VName -> Count Bytes (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ SubExp -> TPrimExp Int64 VName
pe64 SubExp
size
-- Any other space cannot be allocated from inside a kernel block.
blockAlloc (Pat [PatElem LetDecMem
mem]) SubExp
_ Space
_ =
SpaceId -> InKernelGen ()
forall a. SpaceId -> a
compilerLimitationS (SpaceId -> InKernelGen ()) -> SpaceId -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ SpaceId
"Cannot allocate memory block " SpaceId -> SpaceId -> SpaceId
forall a. [a] -> [a] -> [a]
++ PatElem LetDecMem -> SpaceId
forall a. Pretty a => a -> SpaceId
prettyString PatElem LetDecMem
mem SpaceId -> SpaceId -> SpaceId
forall a. [a] -> [a] -> [a]
++ SpaceId
" in kernel block."
-- An allocation must bind exactly one memory block.
blockAlloc Pat LetDecMem
dest SubExp
_ Space
_ =
SpaceId -> InKernelGen ()
forall a. HasCallStack => SpaceId -> a
error (SpaceId -> InKernelGen ()) -> SpaceId -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ SpaceId
"Invalid target for in-kernel allocation: " SpaceId -> SpaceId -> SpaceId
forall a. [a] -> [a] -> [a]
++ Pat LetDecMem -> SpaceId
forall a. Show a => a -> SpaceId
show Pat LetDecMem
dest
compileBlockOp :: OpCompiler GPUMem KernelEnv Imp.KernelOp
compileBlockOp :: OpCompiler GPUMem KernelEnv KernelOp
compileBlockOp Pat (LetDec GPUMem)
pat (Alloc SubExp
size Space
space) =
Pat LetDecMem -> SubExp -> Space -> InKernelGen ()
blockAlloc Pat (LetDec GPUMem)
Pat LetDecMem
pat SubExp
size Space
space
compileBlockOp Pat (LetDec GPUMem)
pat (Inner (SegOp (SegMap SegLevel
lvl SegSpace
space [TypeBase (ShapeBase SubExp) NoUniqueness]
_ KernelBody GPUMem
body))) = do
SegSpace -> InKernelGen ()
compileFlatId SegSpace
space
SegVirt -> SegSpace -> InKernelGen () -> InKernelGen ()
blockCoverSegSpace (SegLevel -> SegVirt
segVirt SegLevel
lvl) SegSpace
space (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
Names -> Stms GPUMem -> InKernelGen () -> InKernelGen ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody GPUMem -> Stms GPUMem
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
body) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
(PatElem LetDecMem -> KernelResult -> InKernelGen ())
-> [PatElem LetDecMem] -> [KernelResult] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ (SegSpace -> PatElem LetDecMem -> KernelResult -> InKernelGen ()
compileThreadResult SegSpace
space) (Pat LetDecMem -> [PatElem LetDecMem]
forall dec. Pat dec -> [PatElem dec]
patElems Pat (LetDec GPUMem)
Pat LetDecMem
pat) ([KernelResult] -> InKernelGen ())
-> [KernelResult] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
KernelBody GPUMem -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
body
KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
compileBlockOp Pat (LetDec GPUMem)
pat (Inner (SegOp (SegScan SegLevel
lvl SegSpace
space [TypeBase (ShapeBase SubExp) NoUniqueness]
_ KernelBody GPUMem
body [SegBinOp GPUMem]
scans))) = do
SegSpace -> InKernelGen ()
compileFlatId SegSpace
space
let ([VName]
ltids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
dims' :: [TPrimExp Int64 VName]
dims' = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
dims
SegVirt -> SegSpace -> InKernelGen () -> InKernelGen ()
blockCoverSegSpace (SegLevel -> SegVirt
segVirt SegLevel
lvl) SegSpace
space (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
Names -> Stms GPUMem -> InKernelGen () -> InKernelGen ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody GPUMem -> Stms GPUMem
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
body) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(VName, KernelResult)]
-> ((VName, KernelResult) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [KernelResult] -> [(VName, KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Pat LetDecMem -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat (LetDec GPUMem)
Pat LetDecMem
pat) ([KernelResult] -> [(VName, KernelResult)])
-> [KernelResult] -> [(VName, KernelResult)]
forall a b. (a -> b) -> a -> b
$ KernelBody GPUMem -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
body) (((VName, KernelResult) -> InKernelGen ()) -> InKernelGen ())
-> ((VName, KernelResult) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, KernelResult
res) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
VName
dest
((VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
ltids)
(KernelResult -> SubExp
kernelResultSubExp KernelResult
res)
[]
fence <- [VName] -> InKernelGen Fence
fenceForArrays ([VName] -> InKernelGen Fence) -> [VName] -> InKernelGen Fence
forall a b. (a -> b) -> a -> b
$ Pat LetDecMem -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat (LetDec GPUMem)
Pat LetDecMem
pat
sOp $ Imp.ErrorSync fence
let segment_size = [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a. HasCallStack => [a] -> a
last [TPrimExp Int64 VName]
dims'
crossesSegment TExp Int32
from TExp Int32
to =
(TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
to TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
from) TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. (TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
to TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`rem` TPrimExp Int64 VName
segment_size)
dims_flat <- dPrimV "dims_flat" $ product dims'
let scan = [SegBinOp GPUMem] -> SegBinOp GPUMem
forall a. HasCallStack => [a] -> a
head [SegBinOp GPUMem]
scans
num_scan_results = [SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int) -> [SubExp] -> Int
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp GPUMem
scan
arrs_flat <-
mapM (flattenArray (length dims') dims_flat) $
take num_scan_results $
patNames pat
case segVirt lvl of
SegVirt
SegVirt ->
Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TExp Int32 -> Lambda GPUMem -> [VName] -> InKernelGen ()
virtualisedBlockScan
((TExp Int32 -> TExp Int32 -> TExp Bool)
-> Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
forall a. a -> Maybe a
Just TExp Int32 -> TExp Int32 -> TExp Bool
crossesSegment)
(TPrimExp Int64 VName -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TV Int64 -> TPrimExp Int64 VName
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dims_flat)
(SegBinOp GPUMem -> Lambda GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scan)
[VName]
arrs_flat
SegVirt
_ ->
Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> Lambda GPUMem
-> [VName]
-> InKernelGen ()
blockScan
((TExp Int32 -> TExp Int32 -> TExp Bool)
-> Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
forall a. a -> Maybe a
Just TExp Int32 -> TExp Int32 -> TExp Bool
crossesSegment)
([TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
dims')
([TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
dims')
(SegBinOp GPUMem -> Lambda GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scan)
[VName]
arrs_flat
compileBlockOp Pat (LetDec GPUMem)
pat (Inner (SegOp (SegRed SegLevel
lvl SegSpace
space [TypeBase (ShapeBase SubExp) NoUniqueness]
_ KernelBody GPUMem
body [SegBinOp GPUMem]
ops))) = do
SegSpace -> InKernelGen ()
compileFlatId SegSpace
space
let dims' :: [TPrimExp Int64 VName]
dims' = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
dims
mkTempArr :: TypeBase (ShapeBase SubExp) NoUniqueness
-> ImpM GPUMem KernelEnv KernelOp VName
mkTempArr TypeBase (ShapeBase SubExp) NoUniqueness
t =
SpaceId
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM GPUMem KernelEnv KernelOp VName
forall rep r op.
SpaceId
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray SpaceId
"red_arr" (TypeBase (ShapeBase SubExp) NoUniqueness -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType TypeBase (ShapeBase SubExp) NoUniqueness
t) ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [SubExp]
dims ShapeBase SubExp -> ShapeBase SubExp -> ShapeBase SubExp
forall a. Semigroup a => a -> a -> a
<> TypeBase (ShapeBase SubExp) NoUniqueness -> ShapeBase SubExp
forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape TypeBase (ShapeBase SubExp) NoUniqueness
t) (Space -> ImpM GPUMem KernelEnv KernelOp VName)
-> Space -> ImpM GPUMem KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$ SpaceId -> Space
Space SpaceId
"shared"
tmp_arrs <- (TypeBase (ShapeBase SubExp) NoUniqueness
-> ImpM GPUMem KernelEnv KernelOp VName)
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> ImpM GPUMem KernelEnv KernelOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM TypeBase (ShapeBase SubExp) NoUniqueness
-> ImpM GPUMem KernelEnv KernelOp VName
mkTempArr ([TypeBase (ShapeBase SubExp) NoUniqueness]
-> ImpM GPUMem KernelEnv KernelOp [VName])
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> ImpM GPUMem KernelEnv KernelOp [VName]
forall a b. (a -> b) -> a -> b
$ (SegBinOp GPUMem -> [TypeBase (ShapeBase SubExp) NoUniqueness])
-> [SegBinOp GPUMem] -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (Lambda GPUMem -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall rep.
Lambda rep -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType (Lambda GPUMem -> [TypeBase (ShapeBase SubExp) NoUniqueness])
-> (SegBinOp GPUMem -> Lambda GPUMem)
-> SegBinOp GPUMem
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOp GPUMem -> Lambda GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda) [SegBinOp GPUMem]
ops
blockCoverSegSpace (segVirt lvl) space $
compileStms mempty (kernelBodyStms body) $ do
let (red_res, map_res) =
splitAt (segBinOpResults ops) $ kernelBodyResult body
forM_ (zip tmp_arrs red_res) $ \(VName
dest, KernelResult
res) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
dest ((VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
ltids) (KernelResult -> SubExp
kernelResultSubExp KernelResult
res) []
zipWithM_ (compileThreadResult space) map_pes map_res
sOp $ Imp.ErrorSync Imp.FenceLocal
let tmps_for_ops = [Int] -> [VName] -> [[VName]]
forall a. [Int] -> [a] -> [[a]]
chunks ((SegBinOp GPUMem -> Int) -> [SegBinOp GPUMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (SegBinOp GPUMem -> [SubExp]) -> SegBinOp GPUMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOp GPUMem -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral) [SegBinOp GPUMem]
ops) [VName]
tmp_arrs
case segVirt lvl of
SegVirt
SegVirt -> [TPrimExp Int64 VName] -> [[VName]] -> InKernelGen ()
virtCase [TPrimExp Int64 VName]
dims' [[VName]]
tmps_for_ops
SegVirt
_ -> [TPrimExp Int64 VName] -> [[VName]] -> InKernelGen ()
nonvirtCase [TPrimExp Int64 VName]
dims' [[VName]]
tmps_for_ops
where
([VName]
ltids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
([PatElem LetDecMem]
red_pes, [PatElem LetDecMem]
map_pes) = Int
-> [PatElem LetDecMem]
-> ([PatElem LetDecMem], [PatElem LetDecMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp GPUMem] -> Int
forall rep. [SegBinOp rep] -> Int
segBinOpResults [SegBinOp GPUMem]
ops) ([PatElem LetDecMem] -> ([PatElem LetDecMem], [PatElem LetDecMem]))
-> [PatElem LetDecMem]
-> ([PatElem LetDecMem], [PatElem LetDecMem])
forall a b. (a -> b) -> a -> b
$ Pat LetDecMem -> [PatElem LetDecMem]
forall dec. Pat dec -> [PatElem dec]
patElems Pat (LetDec GPUMem)
Pat LetDecMem
pat
virtCase :: [TPrimExp Int64 VName] -> [[VName]] -> InKernelGen ()
virtCase [TPrimExp Int64 VName
dim'] [[VName]]
tmps_for_ops = do
ltid <- KernelConstants -> TExp Int32
kernelLocalThreadId (KernelConstants -> TExp Int32)
-> (KernelEnv -> KernelConstants) -> KernelEnv -> TExp Int32
forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> TExp Int32)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv
blockChunkLoop (sExt32 dim') $ \TExp Int32
chunk_start TV Int64
chunk_size -> do
Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"possibly incorporate carry" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
chunk_start TExp Int32 -> TExp Int32 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. TExp Int32
0 TExp Bool -> TExp Bool -> TExp Bool
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int32
ltid TExp Int32 -> TExp Int32 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(SegBinOp GPUMem, [VName])]
-> ((SegBinOp GPUMem, [VName]) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOp GPUMem] -> [[VName]] -> [(SegBinOp GPUMem, [VName])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp GPUMem]
ops [[VName]]
tmps_for_ops) (((SegBinOp GPUMem, [VName]) -> InKernelGen ()) -> InKernelGen ())
-> ((SegBinOp GPUMem, [VName]) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOp GPUMem
op, [VName]
tmps) ->
Lambda GPUMem
-> [(VName, [DimIndex (TPrimExp Int64 VName)])]
-> [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
-> InKernelGen ()
forall rep (inner :: * -> *) r op.
Mem rep inner =>
Lambda rep
-> [(VName, [DimIndex (TPrimExp Int64 VName)])]
-> [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
-> ImpM rep r op ()
applyRenamedLambda
(SegBinOp GPUMem -> Lambda GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
op)
((VName -> (VName, [DimIndex (TPrimExp Int64 VName)]))
-> [VName] -> [(VName, [DimIndex (TPrimExp Int64 VName)])]
forall a b. (a -> b) -> [a] -> [b]
map (,[TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix (TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName))
-> TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chunk_start]) [VName]
tmps)
( (PatElem LetDecMem -> (SubExp, [DimIndex (TPrimExp Int64 VName)]))
-> [PatElem LetDecMem]
-> [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
forall a b. (a -> b) -> [a] -> [b]
map ((,[]) (SubExp -> (SubExp, [DimIndex (TPrimExp Int64 VName)]))
-> (PatElem LetDecMem -> SubExp)
-> PatElem LetDecMem
-> (SubExp, [DimIndex (TPrimExp Int64 VName)])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var (VName -> SubExp)
-> (PatElem LetDecMem -> VName) -> PatElem LetDecMem -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatElem LetDecMem -> VName
forall dec. PatElem dec -> VName
patElemName) [PatElem LetDecMem]
red_pes
[(SubExp, [DimIndex (TPrimExp Int64 VName)])]
-> [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
-> [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
forall a. [a] -> [a] -> [a]
++ (VName -> (SubExp, [DimIndex (TPrimExp Int64 VName)]))
-> [VName] -> [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
forall a b. (a -> b) -> [a] -> [b]
map ((,[TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix (TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName))
-> TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chunk_start]) (SubExp -> (SubExp, [DimIndex (TPrimExp Int64 VName)]))
-> (VName -> SubExp)
-> VName
-> (SubExp, [DimIndex (TPrimExp Int64 VName)])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName]
tmps
)
KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
[(SegBinOp GPUMem, [VName])]
-> ((SegBinOp GPUMem, [VName]) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOp GPUMem] -> [[VName]] -> [(SegBinOp GPUMem, [VName])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp GPUMem]
ops [[VName]]
tmps_for_ops) (((SegBinOp GPUMem, [VName]) -> InKernelGen ()) -> InKernelGen ())
-> ((SegBinOp GPUMem, [VName]) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOp GPUMem
op, [VName]
tmps) -> do
tmps_chunks <- (VName -> ImpM GPUMem KernelEnv KernelOp VName)
-> [VName] -> ImpM GPUMem KernelEnv KernelOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (TPrimExp Int64 VName
-> TV Int64 -> VName -> ImpM GPUMem KernelEnv KernelOp VName
forall rep r op.
TPrimExp Int64 VName -> TV Int64 -> VName -> ImpM rep r op VName
sliceArray (TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chunk_start) TV Int64
chunk_size) [VName]
tmps
blockReduce (sExt32 (tvExp chunk_size)) (segBinOpLambda op) tmps_chunks
KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Save result of reduction." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(PatElem LetDecMem, VName)]
-> ((PatElem LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem LetDecMem] -> [VName] -> [(PatElem LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LetDecMem]
red_pes ([VName] -> [(PatElem LetDecMem, VName)])
-> [VName] -> [(PatElem LetDecMem, VName)]
forall a b. (a -> b) -> a -> b
$ [[VName]] -> [VName]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
tmps_for_ops) (((PatElem LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((PatElem LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElem LetDecMem
pe, VName
arr) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (PatElem LetDecMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LetDecMem
pe) [] (VName -> SubExp
Var VName
arr) [TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chunk_start]
virtCase [TPrimExp Int64 VName]
dims' [[VName]]
tmps_for_ops = do
dims_flat <- SpaceId
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall {k} (t :: k) rep r op.
SpaceId -> TExp t -> ImpM rep r op (TV t)
dPrimV SpaceId
"dims_flat" (TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$ [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
dims'
let segment_size = [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a. HasCallStack => [a] -> a
last [TPrimExp Int64 VName]
dims'
crossesSegment TExp Int32
from TExp Int32
to =
(TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
to TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
from) TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. (TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
to TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`rem` TPrimExp Int64 VName -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int64 VName
segment_size)
forM_ (zip ops tmps_for_ops) $ \(SegBinOp GPUMem
op, [VName]
tmps) -> do
tmps_flat <- (VName -> ImpM GPUMem KernelEnv KernelOp VName)
-> [VName] -> ImpM GPUMem KernelEnv KernelOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (Int -> TV Int64 -> VName -> ImpM GPUMem KernelEnv KernelOp VName
forall rep r op. Int -> TV Int64 -> VName -> ImpM rep r op VName
flattenArray ([TPrimExp Int64 VName] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TPrimExp Int64 VName]
dims') TV Int64
dims_flat) [VName]
tmps
virtualisedBlockScan
(Just crossesSegment)
(sExt32 $ tvExp dims_flat)
(segBinOpLambda op)
tmps_flat
sOp $ Imp.ErrorSync Imp.FenceLocal
sComment "Save result of reduction." $
forM_ (zip red_pes $ concat tmps_for_ops) $ \(PatElem LetDecMem
pe, VName
arr) ->
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> InKernelGen ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM
(PatElem LetDecMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LetDecMem
pe)
[]
(VName -> SubExp
Var VName
arr)
((TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName))
-> [TPrimExp Int64 VName] -> [DimIndex (TPrimExp Int64 VName)]
forall a b. (a -> b) -> [a] -> [b]
map (TPrimExp Int64 VName
-> TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. Num d => d -> d -> DimIndex d
unitSlice TPrimExp Int64 VName
0) ([TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. HasCallStack => [a] -> [a]
init [TPrimExp Int64 VName]
dims') [DimIndex (TPrimExp Int64 VName)]
-> [DimIndex (TPrimExp Int64 VName)]
-> [DimIndex (TPrimExp Int64 VName)]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix (TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName))
-> TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a. HasCallStack => [a] -> a
last [TPrimExp Int64 VName]
dims' TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1])
sOp $ Imp.Barrier Imp.FenceLocal
nonvirtCase :: [TPrimExp Int64 VName] -> [[VName]] -> InKernelGen ()
nonvirtCase [TPrimExp Int64 VName
dim'] [[VName]]
tmps_for_ops = do
[(SegBinOp GPUMem, [VName])]
-> ((SegBinOp GPUMem, [VName]) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOp GPUMem] -> [[VName]] -> [(SegBinOp GPUMem, [VName])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp GPUMem]
ops [[VName]]
tmps_for_ops) (((SegBinOp GPUMem, [VName]) -> InKernelGen ()) -> InKernelGen ())
-> ((SegBinOp GPUMem, [VName]) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOp GPUMem
op, [VName]
tmps) ->
TExp Int32 -> Lambda GPUMem -> [VName] -> InKernelGen ()
blockReduce (TPrimExp Int64 VName -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 VName
dim') (SegBinOp GPUMem -> Lambda GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
op) [VName]
tmps
KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Save result of reduction." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(PatElem LetDecMem, VName)]
-> ((PatElem LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem LetDecMem] -> [VName] -> [(PatElem LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LetDecMem]
red_pes ([VName] -> [(PatElem LetDecMem, VName)])
-> [VName] -> [(PatElem LetDecMem, VName)]
forall a b. (a -> b) -> a -> b
$ [[VName]] -> [VName]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
tmps_for_ops) (((PatElem LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((PatElem LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElem LetDecMem
pe, VName
arr) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (PatElem LetDecMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LetDecMem
pe) [] (VName -> SubExp
Var VName
arr) [TPrimExp Int64 VName
0]
KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
nonvirtCase [TPrimExp Int64 VName]
dims' [[VName]]
tmps_for_ops = do
dims_flat <- SpaceId
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall {k} (t :: k) rep r op.
SpaceId -> TExp t -> ImpM rep r op (TV t)
dPrimV SpaceId
"dims_flat" (TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$ [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
dims'
let segment_size = [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a. HasCallStack => [a] -> a
last [TPrimExp Int64 VName]
dims'
crossesSegment TExp Int32
from TExp Int32
to =
(TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
to TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
from) TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. (TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
to TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`rem` TPrimExp Int64 VName -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int64 VName
segment_size)
forM_ (zip ops tmps_for_ops) $ \(SegBinOp GPUMem
op, [VName]
tmps) -> do
tmps_flat <- (VName -> ImpM GPUMem KernelEnv KernelOp VName)
-> [VName] -> ImpM GPUMem KernelEnv KernelOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (Int -> TV Int64 -> VName -> ImpM GPUMem KernelEnv KernelOp VName
forall rep r op. Int -> TV Int64 -> VName -> ImpM rep r op VName
flattenArray ([TPrimExp Int64 VName] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TPrimExp Int64 VName]
dims') TV Int64
dims_flat) [VName]
tmps
blockScan
(Just crossesSegment)
(product dims')
(product dims')
(segBinOpLambda op)
tmps_flat
sOp $ Imp.ErrorSync Imp.FenceLocal
sComment "Save result of reduction." $
forM_ (zip red_pes $ concat tmps_for_ops) $ \(PatElem LetDecMem
pe, VName
arr) ->
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> InKernelGen ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM
(PatElem LetDecMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LetDecMem
pe)
[]
(VName -> SubExp
Var VName
arr)
((TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName))
-> [TPrimExp Int64 VName] -> [DimIndex (TPrimExp Int64 VName)]
forall a b. (a -> b) -> [a] -> [b]
map (TPrimExp Int64 VName
-> TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. Num d => d -> d -> DimIndex d
unitSlice TPrimExp Int64 VName
0) ([TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. HasCallStack => [a] -> [a]
init [TPrimExp Int64 VName]
dims') [DimIndex (TPrimExp Int64 VName)]
-> [DimIndex (TPrimExp Int64 VName)]
-> [DimIndex (TPrimExp Int64 VName)]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix (TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName))
-> TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a. HasCallStack => [a] -> a
last [TPrimExp Int64 VName]
dims' TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1])
sOp $ Imp.Barrier Imp.FenceLocal
compileBlockOp Pat (LetDec GPUMem)
pat (Inner (SegOp (SegHist SegLevel
lvl SegSpace
space [TypeBase (ShapeBase SubExp) NoUniqueness]
_ KernelBody GPUMem
kbody [HistOp GPUMem]
ops))) = do
SegSpace -> InKernelGen ()
compileFlatId SegSpace
space
let ([VName]
ltids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
let num_red_res :: Int
num_red_res = [HistOp GPUMem] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp GPUMem]
ops Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Int] -> Int
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((HistOp GPUMem -> Int) -> [HistOp GPUMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (HistOp GPUMem -> [SubExp]) -> HistOp GPUMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp GPUMem -> [SubExp]
forall rep. HistOp rep -> [SubExp]
histNeutral) [HistOp GPUMem]
ops)
([PatElem LetDecMem]
_red_pes, [PatElem LetDecMem]
map_pes) =
Int
-> [PatElem LetDecMem]
-> ([PatElem LetDecMem], [PatElem LetDecMem])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
num_red_res ([PatElem LetDecMem] -> ([PatElem LetDecMem], [PatElem LetDecMem]))
-> [PatElem LetDecMem]
-> ([PatElem LetDecMem], [PatElem LetDecMem])
forall a b. (a -> b) -> a -> b
$ Pat LetDecMem -> [PatElem LetDecMem]
forall dec. Pat dec -> [PatElem dec]
patElems Pat (LetDec GPUMem)
Pat LetDecMem
pat
tblock_size <- KernelConstants -> Count BlockSize SubExp
kernelBlockSizeCount (KernelConstants -> Count BlockSize SubExp)
-> (KernelEnv -> KernelConstants)
-> KernelEnv
-> Count BlockSize SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> Count BlockSize SubExp)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp (Count BlockSize SubExp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv
ops' <- prepareIntraBlockSegHist (Shape $ init dims) tblock_size ops
sOp $ Imp.Barrier Imp.FenceLocal
blockCoverSegSpace (segVirt lvl) space $
compileStms mempty (kernelBodyStms kbody) $ do
let (red_res, map_res) = splitAt num_red_res $ kernelBodyResult kbody
(red_is, red_vs) = splitAt (length ops) $ map kernelResultSubExp red_res
zipWithM_ (compileThreadResult space) map_pes map_res
let vs_per_op = [Int] -> [SubExp] -> [[SubExp]]
forall a. [Int] -> [a] -> [[a]]
chunks ((HistOp GPUMem -> Int) -> [HistOp GPUMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([VName] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([VName] -> Int)
-> (HistOp GPUMem -> [VName]) -> HistOp GPUMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp GPUMem -> [VName]
forall rep. HistOp rep -> [VName]
histDest) [HistOp GPUMem]
ops) [SubExp]
red_vs
forM_ (zip4 red_is vs_per_op ops' ops) $
\(SubExp
bin, [SubExp]
op_vs, [TPrimExp Int64 VName] -> InKernelGen ()
do_op, HistOp ShapeBase SubExp
dest_shape SubExp
_ [VName]
_ [SubExp]
_ ShapeBase SubExp
shape Lambda GPUMem
lam) -> do
let bin' :: TPrimExp Int64 VName
bin' = SubExp -> TPrimExp Int64 VName
pe64 SubExp
bin
dest_shape' :: [TPrimExp Int64 VName]
dest_shape' = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 ([SubExp] -> [TPrimExp Int64 VName])
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ ShapeBase SubExp -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims ShapeBase SubExp
dest_shape
bin_in_bounds :: TExp Bool
bin_in_bounds = Slice (TPrimExp Int64 VName) -> [TPrimExp Int64 VName] -> TExp Bool
inBounds ([DimIndex (TPrimExp Int64 VName)] -> Slice (TPrimExp Int64 VName)
forall d. [DimIndex d] -> Slice d
Slice [TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix TPrimExp Int64 VName
bin']) [TPrimExp Int64 VName]
dest_shape'
bin_is :: [TPrimExp Int64 VName]
bin_is = (VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 ([VName] -> [VName]
forall a. HasCallStack => [a] -> [a]
init [VName]
ltids) [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName
bin']
vs_params :: [Param LetDecMem]
vs_params = Int -> [Param LetDecMem] -> [Param LetDecMem]
forall a. Int -> [a] -> [a]
takeLast ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
op_vs) ([Param LetDecMem] -> [Param LetDecMem])
-> [Param LetDecMem] -> [Param LetDecMem]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam
Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"perform atomic updates" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
bin_in_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
[LParam GPUMem] -> InKernelGen ()
forall rep (inner :: * -> *) r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams ([LParam GPUMem] -> InKernelGen ())
-> [LParam GPUMem] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam
ShapeBase SubExp
-> ([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ()
forall rep r op.
ShapeBase SubExp
-> ([TPrimExp Int64 VName] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest ShapeBase SubExp
shape (([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ())
-> ([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TPrimExp Int64 VName]
is -> do
[(Param LetDecMem, SubExp)]
-> ((Param LetDecMem, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem] -> [SubExp] -> [(Param LetDecMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetDecMem]
vs_params [SubExp]
op_vs) (((Param LetDecMem, SubExp) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetDecMem, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
p, SubExp
v) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) [] SubExp
v [TPrimExp Int64 VName]
is
[TPrimExp Int64 VName] -> InKernelGen ()
do_op ([TPrimExp Int64 VName]
bin_is [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
is)
sOp $ Imp.ErrorSync Imp.FenceLocal
compileBlockOp Pat (LetDec GPUMem)
pat Op GPUMem
_ =
SpaceId -> InKernelGen ()
forall a. SpaceId -> a
compilerBugS (SpaceId -> InKernelGen ()) -> SpaceId -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ SpaceId
"compileBlockOp: cannot compile rhs of binding " SpaceId -> SpaceId -> SpaceId
forall a. [a] -> [a] -> [a]
++ Pat LetDecMem -> SpaceId
forall a. Pretty a => a -> SpaceId
prettyString Pat (LetDec GPUMem)
Pat LetDecMem
pat
blockOperations :: Operations GPUMem KernelEnv Imp.KernelOp
blockOperations :: Operations GPUMem KernelEnv KernelOp
blockOperations =
(OpCompiler GPUMem KernelEnv KernelOp
-> Operations GPUMem KernelEnv KernelOp
forall rep (inner :: * -> *) op r.
(Mem rep inner, FreeIn op) =>
OpCompiler rep r op -> Operations rep r op
defaultOperations OpCompiler GPUMem KernelEnv KernelOp
compileBlockOp)
{ opsCopyCompiler = copyInBlock,
opsExpCompiler = compileBlockExp,
opsStmsCompiler = \Names
_ -> Names -> Stms GPUMem -> InKernelGen () -> InKernelGen ()
forall rep (inner :: * -> *) op r.
(Mem rep inner, FreeIn op) =>
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
defCompileStms Names
forall a. Monoid a => a
mempty,
opsAllocCompilers =
M.fromList [(Space "shared", allocLocal)]
}
arrayInSharedMemory :: SubExp -> InKernelGen Bool
arrayInSharedMemory :: SubExp -> InKernelGen Bool
arrayInSharedMemory (Var VName
name) = do
res <- VName -> ImpM GPUMem KernelEnv KernelOp (VarEntry GPUMem)
forall rep r op. VName -> ImpM rep r op (VarEntry rep)
lookupVar VName
name
case res of
ArrayVar Maybe (Exp GPUMem)
_ ArrayEntry
entry ->
(SpaceId -> Space
Space SpaceId
"shared" Space -> Space -> Bool
forall a. Eq a => a -> a -> Bool
==) (Space -> Bool) -> (MemEntry -> Space) -> MemEntry -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. MemEntry -> Space
entryMemSpace
(MemEntry -> Bool)
-> ImpM GPUMem KernelEnv KernelOp MemEntry -> InKernelGen Bool
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM GPUMem KernelEnv KernelOp MemEntry
forall rep r op. VName -> ImpM rep r op MemEntry
lookupMemory (MemLoc -> VName
memLocName (ArrayEntry -> MemLoc
entryArrayLoc ArrayEntry
entry))
VarEntry GPUMem
_ -> Bool -> InKernelGen Bool
forall a. a -> ImpM GPUMem KernelEnv KernelOp a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Bool
False
arrayInSharedMemory Constant {} = Bool -> InKernelGen Bool
forall a. a -> ImpM GPUMem KernelEnv KernelOp a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Bool
False
sKernelBlock ::
String ->
VName ->
KernelAttrs ->
InKernelGen () ->
CallKernelGen ()
sKernelBlock :: SpaceId
-> VName -> KernelAttrs -> InKernelGen () -> CallKernelGen ()
sKernelBlock = Operations GPUMem KernelEnv KernelOp
-> (KernelConstants -> TPrimExp Int64 VName)
-> SpaceId
-> VName
-> KernelAttrs
-> InKernelGen ()
-> CallKernelGen ()
sKernel Operations GPUMem KernelEnv KernelOp
blockOperations ((KernelConstants -> TPrimExp Int64 VName)
-> SpaceId
-> VName
-> KernelAttrs
-> InKernelGen ()
-> CallKernelGen ())
-> (KernelConstants -> TPrimExp Int64 VName)
-> SpaceId
-> VName
-> KernelAttrs
-> InKernelGen ()
-> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TPrimExp Int64 VName)
-> (KernelConstants -> TExp Int32)
-> KernelConstants
-> TPrimExp Int64 VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelConstants -> TExp Int32
kernelBlockId
compileBlockResult ::
SegSpace ->
PatElem LetDecMem ->
KernelResult ->
InKernelGen ()
compileBlockResult :: SegSpace -> PatElem LetDecMem -> KernelResult -> InKernelGen ()
compileBlockResult SegSpace
_ PatElem LetDecMem
pe (TileReturns Certs
_ [(SubExp
w, SubExp
per_block_elems)] VName
what) = do
n <- SubExp -> TPrimExp Int64 VName
pe64 (SubExp -> TPrimExp Int64 VName)
-> (TypeBase (ShapeBase SubExp) NoUniqueness -> SubExp)
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> TPrimExp Int64 VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> TypeBase (ShapeBase SubExp) NoUniqueness -> SubExp
forall u. Int -> TypeBase (ShapeBase SubExp) u -> SubExp
arraySize Int
0 (TypeBase (ShapeBase SubExp) NoUniqueness -> TPrimExp Int64 VName)
-> ImpM
GPUMem
KernelEnv
KernelOp
(TypeBase (ShapeBase SubExp) NoUniqueness)
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName
-> ImpM
GPUMem
KernelEnv
KernelOp
(TypeBase (ShapeBase SubExp) NoUniqueness)
forall rep (m :: * -> *).
HasScope rep m =>
VName -> m (TypeBase (ShapeBase SubExp) NoUniqueness)
lookupType VName
what
constants <- kernelConstants <$> askEnv
let ltid = TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TPrimExp Int64 VName)
-> TExp Int32 -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
offset =
SubExp -> TPrimExp Int64 VName
pe64 SubExp
per_block_elems
TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelBlockId KernelConstants
constants)
localOps threadOperations $
if pe64 per_block_elems == kernelBlockSize constants
then
sWhen (ltid + offset .<. pe64 w) $
copyDWIMFix (patElemName pe) [ltid + offset] (Var what) [ltid]
else sFor "i" (n `divUp` kernelBlockSize constants) $ \TPrimExp Int64 VName
i -> do
j <- SpaceId
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 VName)
forall {k} (t :: k) rep r op.
SpaceId -> TExp t -> ImpM rep r op (TExp t)
dPrimVE SpaceId
"j" (TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TPrimExp Int64 VName
kernelBlockSize KernelConstants
constants TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
ltid
sWhen (j + offset .<. pe64 w) $
copyDWIMFix (patElemName pe) [j + offset] (Var what) [j]
compileBlockResult SegSpace
space PatElem LetDecMem
pe (TileReturns Certs
_ [(SubExp, SubExp)]
dims VName
what) = do
let gids :: [VName]
gids = ((VName, SubExp) -> VName) -> [(VName, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst ([(VName, SubExp)] -> [VName]) -> [(VName, SubExp)] -> [VName]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
out_tile_sizes :: [TPrimExp Int64 VName]
out_tile_sizes = ((SubExp, SubExp) -> TPrimExp Int64 VName)
-> [(SubExp, SubExp)] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> TPrimExp Int64 VName
pe64 (SubExp -> TPrimExp Int64 VName)
-> ((SubExp, SubExp) -> SubExp)
-> (SubExp, SubExp)
-> TPrimExp Int64 VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (SubExp, SubExp) -> SubExp
forall a b. (a, b) -> b
snd) [(SubExp, SubExp)]
dims
block_is :: [TPrimExp Int64 VName]
block_is = (TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName)
-> [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
(*) ((VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
gids) [TPrimExp Int64 VName]
out_tile_sizes
local_is <- [SubExp] -> InKernelGen [TPrimExp Int64 VName]
localThreadIDs ([SubExp] -> InKernelGen [TPrimExp Int64 VName])
-> [SubExp] -> InKernelGen [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ ((SubExp, SubExp) -> SubExp) -> [(SubExp, SubExp)] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp, SubExp) -> SubExp
forall a b. (a, b) -> b
snd [(SubExp, SubExp)]
dims
is_for_thread <-
mapM (dPrimV "thread_out_index") $
zipWith (+) block_is local_is
localOps threadOperations $
sWhen (isActive $ zip (map tvVar is_for_thread) $ map fst dims) $
copyDWIMFix (patElemName pe) (map tvExp is_for_thread) (Var what) local_is
compileBlockResult SegSpace
space PatElem LetDecMem
pe (RegTileReturns Certs
_ [(SubExp, SubExp, SubExp)]
dims_n_tiles VName
what) = do
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv
let gids = ((VName, SubExp) -> VName) -> [(VName, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst ([(VName, SubExp)] -> [VName]) -> [(VName, SubExp)] -> [VName]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
(dims, block_tiles, reg_tiles) = unzip3 dims_n_tiles
block_tiles' = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
block_tiles
reg_tiles' = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
reg_tiles
let block_tile_is = (VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
gids
reg_tile_is <-
dIndexSpace' "reg_tile_i" block_tiles' $ sExt64 $ kernelLocalThreadId constants
let regTileSliceDim (TExp t
block_tile, TExp t
block_tile_i) (TExp t
reg_tile, TExp t
reg_tile_i) = do
tile_dim_start <-
SpaceId -> TExp t -> ImpM rep r op (TExp t)
forall {k} (t :: k) rep r op.
SpaceId -> TExp t -> ImpM rep r op (TExp t)
dPrimVE SpaceId
"tile_dim_start" (TExp t -> ImpM rep r op (TExp t))
-> TExp t -> ImpM rep r op (TExp t)
forall a b. (a -> b) -> a -> b
$
TExp t
reg_tile TExp t -> TExp t -> TExp t
forall a. Num a => a -> a -> a
* (TExp t
block_tile TExp t -> TExp t -> TExp t
forall a. Num a => a -> a -> a
* TExp t
block_tile_i TExp t -> TExp t -> TExp t
forall a. Num a => a -> a -> a
+ TExp t
reg_tile_i)
pure $ DimSlice tile_dim_start reg_tile 1
reg_tile_slices <-
Slice
<$> zipWithM
regTileSliceDim
(zip block_tiles' block_tile_is)
(zip reg_tiles' reg_tile_is)
localOps threadOperations $
sLoopNest (Shape reg_tiles) $ \[TPrimExp Int64 VName]
is_in_reg_tile -> do
let dest_is :: [TPrimExp Int64 VName]
dest_is = Slice (TPrimExp Int64 VName)
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall d. Num d => Slice d -> [d] -> [d]
fixSlice Slice (TPrimExp Int64 VName)
reg_tile_slices [TPrimExp Int64 VName]
is_in_reg_tile
src_is :: [TPrimExp Int64 VName]
src_is = [TPrimExp Int64 VName]
reg_tile_is [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
is_in_reg_tile
TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen ((TExp Bool -> TExp Bool -> TExp Bool) -> [TExp Bool] -> TExp Bool
forall a. (a -> a -> a) -> [a] -> a
forall (t :: * -> *) a. Foldable t => (a -> a -> a) -> t a -> a
foldl1 TExp Bool -> TExp Bool -> TExp Bool
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
(.&&.) ([TExp Bool] -> TExp Bool) -> [TExp Bool] -> TExp Bool
forall a b. (a -> b) -> a -> b
$ (TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool)
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName] -> [TExp Bool]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
(.<.) [TPrimExp Int64 VName]
dest_is ([TPrimExp Int64 VName] -> [TExp Bool])
-> [TPrimExp Int64 VName] -> [TExp Bool]
forall a b. (a -> b) -> a -> b
$ (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
dims) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (PatElem LetDecMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LetDecMem
pe) [TPrimExp Int64 VName]
dest_is (VName -> SubExp
Var VName
what) [TPrimExp Int64 VName]
src_is
compileBlockResult SegSpace
space PatElem LetDecMem
pe (Returns ResultManifest
_ Certs
_ SubExp
what) = do
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv
in_shared_memory <- arrayInSharedMemory what
let gids = ((VName, SubExp) -> TPrimExp Int64 VName)
-> [(VName, SubExp)] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 (VName -> TPrimExp Int64 VName)
-> ((VName, SubExp) -> VName)
-> (VName, SubExp)
-> TPrimExp Int64 VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst) ([(VName, SubExp)] -> [TPrimExp Int64 VName])
-> [(VName, SubExp)] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
if not in_shared_memory
then
localOps threadOperations $
sWhen (kernelLocalThreadId constants .==. 0) $
copyDWIMFix (patElemName pe) gids what []
else
copyDWIMFix (patElemName pe) gids what []
compileBlockResult SegSpace
_ PatElem LetDecMem
_ WriteReturns {} =
SpaceId -> InKernelGen ()
forall a. SpaceId -> a
compilerLimitationS SpaceId
"compileBlockResult: WriteReturns not handled yet."
type SegOpSizes = S.Set [SubExp]
data Precomputed = Precomputed
{ Precomputed -> SegOpSizes
pcSegOpSizes :: SegOpSizes,
Precomputed -> Map [SubExp] (TExp Int32)
pcChunkItersMap :: M.Map [SubExp] (Imp.TExp Int32)
}
segOpSizes :: Stms GPUMem -> SegOpSizes
segOpSizes :: Stms GPUMem -> SegOpSizes
segOpSizes = Stms GPUMem -> SegOpSizes
onStms
where
onStms :: Stms GPUMem -> SegOpSizes
onStms = (Stm GPUMem -> SegOpSizes) -> Stms GPUMem -> SegOpSizes
forall m a. Monoid m => (a -> m) -> Seq a -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap Stm GPUMem -> SegOpSizes
onStm
onStm :: Stm GPUMem -> SegOpSizes
onStm (Let Pat (LetDec GPUMem)
_ StmAux (ExpDec GPUMem)
_ (Op (Inner (SegOp SegOp SegLevel GPUMem
op)))) =
case SegLevel -> SegVirt
segVirt (SegLevel -> SegVirt) -> SegLevel -> SegVirt
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel GPUMem -> SegLevel
forall lvl rep. SegOp lvl rep -> lvl
segLevel SegOp SegLevel GPUMem
op of
SegNoVirtFull SegSeqDims
seq_dims ->
[SubExp] -> SegOpSizes
forall a. a -> Set a
S.singleton ([SubExp] -> SegOpSizes) -> [SubExp] -> SegOpSizes
forall a b. (a -> b) -> a -> b
$ ((VName, SubExp) -> SubExp) -> [(VName, SubExp)] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> SubExp
forall a b. (a, b) -> b
snd ([(VName, SubExp)] -> [SubExp]) -> [(VName, SubExp)] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ ([(VName, SubExp)], [(VName, SubExp)]) -> [(VName, SubExp)]
forall a b. (a, b) -> b
snd (([(VName, SubExp)], [(VName, SubExp)]) -> [(VName, SubExp)])
-> ([(VName, SubExp)], [(VName, SubExp)]) -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ SegSeqDims -> SegSpace -> ([(VName, SubExp)], [(VName, SubExp)])
partitionSeqDims SegSeqDims
seq_dims (SegSpace -> ([(VName, SubExp)], [(VName, SubExp)]))
-> SegSpace -> ([(VName, SubExp)], [(VName, SubExp)])
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel GPUMem -> SegSpace
forall lvl rep. SegOp lvl rep -> SegSpace
segSpace SegOp SegLevel GPUMem
op
SegVirt
_ -> [SubExp] -> SegOpSizes
forall a. a -> Set a
S.singleton ([SubExp] -> SegOpSizes) -> [SubExp] -> SegOpSizes
forall a b. (a -> b) -> a -> b
$ ((VName, SubExp) -> SubExp) -> [(VName, SubExp)] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> SubExp
forall a b. (a, b) -> b
snd ([(VName, SubExp)] -> [SubExp]) -> [(VName, SubExp)] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace (SegSpace -> [(VName, SubExp)]) -> SegSpace -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel GPUMem -> SegSpace
forall lvl rep. SegOp lvl rep -> SegSpace
segSpace SegOp SegLevel GPUMem
op
onStm (Let (Pat [PatElem (LetDec GPUMem)
pe]) StmAux (ExpDec GPUMem)
_ (BasicOp (Replicate {}))) =
[SubExp] -> SegOpSizes
forall a. a -> Set a
S.singleton ([SubExp] -> SegOpSizes) -> [SubExp] -> SegOpSizes
forall a b. (a -> b) -> a -> b
$ TypeBase (ShapeBase SubExp) NoUniqueness -> [SubExp]
forall u. TypeBase (ShapeBase SubExp) u -> [SubExp]
arrayDims (TypeBase (ShapeBase SubExp) NoUniqueness -> [SubExp])
-> TypeBase (ShapeBase SubExp) NoUniqueness -> [SubExp]
forall a b. (a -> b) -> a -> b
$ PatElem (LetDec GPUMem) -> TypeBase (ShapeBase SubExp) NoUniqueness
forall dec.
Typed dec =>
PatElem dec -> TypeBase (ShapeBase SubExp) NoUniqueness
patElemType PatElem (LetDec GPUMem)
pe
onStm (Let (Pat [PatElem (LetDec GPUMem)
pe]) StmAux (ExpDec GPUMem)
_ (BasicOp (Iota {}))) =
[SubExp] -> SegOpSizes
forall a. a -> Set a
S.singleton ([SubExp] -> SegOpSizes) -> [SubExp] -> SegOpSizes
forall a b. (a -> b) -> a -> b
$ TypeBase (ShapeBase SubExp) NoUniqueness -> [SubExp]
forall u. TypeBase (ShapeBase SubExp) u -> [SubExp]
arrayDims (TypeBase (ShapeBase SubExp) NoUniqueness -> [SubExp])
-> TypeBase (ShapeBase SubExp) NoUniqueness -> [SubExp]
forall a b. (a -> b) -> a -> b
$ PatElem (LetDec GPUMem) -> TypeBase (ShapeBase SubExp) NoUniqueness
forall dec.
Typed dec =>
PatElem dec -> TypeBase (ShapeBase SubExp) NoUniqueness
patElemType PatElem (LetDec GPUMem)
pe
onStm (Let (Pat [PatElem (LetDec GPUMem)
pe]) StmAux (ExpDec GPUMem)
_ (BasicOp (Manifest {}))) =
[SubExp] -> SegOpSizes
forall a. a -> Set a
S.singleton ([SubExp] -> SegOpSizes) -> [SubExp] -> SegOpSizes
forall a b. (a -> b) -> a -> b
$ TypeBase (ShapeBase SubExp) NoUniqueness -> [SubExp]
forall u. TypeBase (ShapeBase SubExp) u -> [SubExp]
arrayDims (TypeBase (ShapeBase SubExp) NoUniqueness -> [SubExp])
-> TypeBase (ShapeBase SubExp) NoUniqueness -> [SubExp]
forall a b. (a -> b) -> a -> b
$ PatElem (LetDec GPUMem) -> TypeBase (ShapeBase SubExp) NoUniqueness
forall dec.
Typed dec =>
PatElem dec -> TypeBase (ShapeBase SubExp) NoUniqueness
patElemType PatElem (LetDec GPUMem)
pe
onStm (Let Pat (LetDec GPUMem)
_ StmAux (ExpDec GPUMem)
_ (Match [SubExp]
_ [Case (Body GPUMem)]
cases Body GPUMem
defbody MatchDec (BranchType GPUMem)
_)) =
(Case (Body GPUMem) -> SegOpSizes)
-> [Case (Body GPUMem)] -> SegOpSizes
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap (Stms GPUMem -> SegOpSizes
onStms (Stms GPUMem -> SegOpSizes)
-> (Case (Body GPUMem) -> Stms GPUMem)
-> Case (Body GPUMem)
-> SegOpSizes
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Body GPUMem -> Stms GPUMem
forall rep. Body rep -> Stms rep
bodyStms (Body GPUMem -> Stms GPUMem)
-> (Case (Body GPUMem) -> Body GPUMem)
-> Case (Body GPUMem)
-> Stms GPUMem
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Case (Body GPUMem) -> Body GPUMem
forall body. Case body -> body
caseBody) [Case (Body GPUMem)]
cases SegOpSizes -> SegOpSizes -> SegOpSizes
forall a. Semigroup a => a -> a -> a
<> Stms GPUMem -> SegOpSizes
onStms (Body GPUMem -> Stms GPUMem
forall rep. Body rep -> Stms rep
bodyStms Body GPUMem
defbody)
onStm (Let Pat (LetDec GPUMem)
_ StmAux (ExpDec GPUMem)
_ (Loop [(FParam GPUMem, SubExp)]
_ LoopForm
_ Body GPUMem
body)) =
Stms GPUMem -> SegOpSizes
onStms (Body GPUMem -> Stms GPUMem
forall rep. Body rep -> Stms rep
bodyStms Body GPUMem
body)
onStm Stm GPUMem
_ = SegOpSizes
forall a. Monoid a => a
mempty
precomputeConstants :: Count BlockSize (Imp.TExp Int64) -> Stms GPUMem -> CallKernelGen Precomputed
precomputeConstants :: Count BlockSize (TPrimExp Int64 VName)
-> Stms GPUMem -> CallKernelGen Precomputed
precomputeConstants Count BlockSize (TPrimExp Int64 VName)
tblock_size Stms GPUMem
stms = do
let sizes :: SegOpSizes
sizes = Stms GPUMem -> SegOpSizes
segOpSizes Stms GPUMem
stms
iters_map <- [([SubExp], TExp Int32)] -> Map [SubExp] (TExp Int32)
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([([SubExp], TExp Int32)] -> Map [SubExp] (TExp Int32))
-> ImpM GPUMem HostEnv HostOp [([SubExp], TExp Int32)]
-> ImpM GPUMem HostEnv HostOp (Map [SubExp] (TExp Int32))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ([SubExp] -> ImpM GPUMem HostEnv HostOp ([SubExp], TExp Int32))
-> [[SubExp]]
-> ImpM GPUMem HostEnv HostOp [([SubExp], TExp Int32)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM [SubExp] -> ImpM GPUMem HostEnv HostOp ([SubExp], TExp Int32)
mkMap (SegOpSizes -> [[SubExp]]
forall a. Set a -> [a]
S.toList SegOpSizes
sizes)
pure $ Precomputed sizes iters_map
where
mkMap :: [SubExp] -> ImpM GPUMem HostEnv HostOp ([SubExp], TExp Int32)
mkMap [SubExp]
dims = do
let n :: TPrimExp Int64 VName
n = [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TPrimExp Int64 VName] -> TPrimExp Int64 VName)
-> [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
Imp.pe64 [SubExp]
dims
num_chunks <- SpaceId -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall {k} (t :: k) rep r op.
SpaceId -> TExp t -> ImpM rep r op (TExp t)
dPrimVE SpaceId
"num_chunks" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
n TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`divUp` Count BlockSize (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall {k} (u :: k) e. Count u e -> e
unCount Count BlockSize (TPrimExp Int64 VName)
tblock_size
pure (dims, num_chunks)
precomputedConstants :: Precomputed -> InKernelGen a -> InKernelGen a
precomputedConstants :: forall a. Precomputed -> InKernelGen a -> InKernelGen a
precomputedConstants Precomputed
pre InKernelGen a
m = do
ltid <- KernelConstants -> TExp Int32
kernelLocalThreadId (KernelConstants -> TExp Int32)
-> (KernelEnv -> KernelConstants) -> KernelEnv -> TExp Int32
forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> TExp Int32)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv
new_ids <- M.fromList <$> mapM (mkMap ltid) (S.toList (pcSegOpSizes pre))
let f KernelEnv
env =
KernelEnv
env
{ kernelConstants =
(kernelConstants env)
{ kernelLocalIdMap = new_ids,
kernelChunkItersMap = pcChunkItersMap pre
}
}
localEnv f m
where
mkMap :: TPrimExp t VName
-> [SubExp] -> ImpM rep r op ([SubExp], [TExp Int32])
mkMap TPrimExp t VName
ltid [SubExp]
dims = do
let dims' :: [TPrimExp Int64 VName]
dims' = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
dims
ids' <- SpaceId
-> [TPrimExp Int64 VName]
-> TPrimExp Int64 VName
-> ImpM rep r op [TPrimExp Int64 VName]
forall rep r op.
SpaceId
-> [TPrimExp Int64 VName]
-> TPrimExp Int64 VName
-> ImpM rep r op [TPrimExp Int64 VName]
dIndexSpace' SpaceId
"ltid_pre" [TPrimExp Int64 VName]
dims' (TPrimExp t VName -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp t VName
ltid)
pure (dims, map sExt32 ids')