{-# LANGUAGE TypeFamilies #-}
module Futhark.CodeGen.ImpGen.GPU.SegScan.SinglePass (compileSegScan) where
import Control.Monad
import Data.List (zip4, zip7)
import Data.Map qualified as M
import Data.Maybe
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.IR.GPUMem
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.Transform.Rename
import Futhark.Util (mapAccumLM, takeLast)
import Futhark.Util.IntegralExp (IntegralExp (mod, rem), divUp, nextMul, quot)
import Prelude hiding (mod, quot, rem)
-- | The \"x\" parameters of a scan operator's lambda: the first half of
-- the parameter list (one per neutral element), which carry the
-- accumulated value.
xParams, yParams :: SegBinOp GPUMem -> [LParam GPUMem]
xParams scan =
  take (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan))

-- | The \"y\" parameters: the second half of the lambda parameter list,
-- which carry the incoming element being folded in.
yParams scan =
  drop (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan))
-- | Allocate a single block of shared memory and carve out of it all the
-- scratch arrays needed by the single-pass scan kernel.  The same bytes
-- are reused for three mutually exclusive purposes (transposition
-- buffers, per-block prefix arrays, and the lookback warp-scan arrays),
-- so the allocation is sized as the maximum of the three layouts.
--
-- Returns, in order: the shared \"dynamic id\" cell, the transposed
-- input arrays, the per-block prefix arrays, the warp-scan status array,
-- and the warp exchange arrays.
createLocalArrays ::
  Count BlockSize SubExp ->
  SubExp ->
  [PrimType] ->
  InKernelGen (VName, [VName], [VName], VName, [VName])
createLocalArrays (Count block_size) chunk types = do
  let block_sizeE = pe64 block_size
      workSize = pe64 chunk * block_sizeE
      -- Total bytes for one aligned prefix array per element type.
      prefixArraysSize =
        foldl (\acc tySize -> nextMul acc tySize + tySize * block_sizeE) 0 $
          map primByteSize types
      -- The transposition buffers are used one type at a time, so only
      -- the largest one determines the required size.
      maxTransposedArraySize =
        foldl1 sMax64 $ map (\ty -> workSize * primByteSize ty) types

      warp_size :: (Num a) => a
      warp_size = 32
      -- Bytes for one aligned warp-exchange array per element type.
      maxWarpExchangeSize =
        foldl (\acc tySize -> nextMul acc tySize + tySize * fromInteger warp_size) 0 $
          map primByteSize types
      -- Warp-exchange arrays plus the int8 warp-scan status array.
      maxLookbackSize = maxWarpExchangeSize + warp_size
      size =
        Imp.bytes $
          maxLookbackSize `sMax64` prefixArraysSize `sMax64` maxTransposedArraySize

  -- Byte offset of each prefix array within the shared block, aligned
  -- to its element size.  mapAccumLM yields the offset *before* each
  -- array (hence returning @(off', off)@).
  (_, byteOffsets) <-
    mapAccumLM
      ( \off tySize -> do
          off' <-
            dPrimVE "byte_offsets" $
              nextMul off tySize + pe64 block_size * tySize
          pure (off', off)
      )
      0
      $ map primByteSize types

  -- Byte offset of each warp-exchange array; the first warp_size bytes
  -- are reserved for the warp-scan status array.
  (_, warpByteOffsets) <-
    mapAccumLM
      ( \off tySize -> do
          off' <-
            dPrimVE "warp_byte_offset" $
              nextMul off tySize + warp_size * tySize
          pure (off', off)
      )
      warp_size
      $ map primByteSize types

  sComment "Allocate reusable shared memory" $ pure ()

  localMem <- sAlloc "local_mem" size (Space "shared")
  transposeArrayLength <- dPrimV "trans_arr_len" workSize

  -- Single-element array holding this block's dynamically assigned id.
  sharedId <- sArrayInMem "shared_id" int32 (Shape [constant (1 :: Int32)]) localMem

  transposedArrays <-
    forM types $ \ty ->
      sArrayInMem
        "local_transpose_arr"
        ty
        (Shape [tvSize transposeArrayLength])
        localMem

  prefixArrays <-
    forM (zip byteOffsets types) $ \(off, ty) -> do
      -- Convert the byte offset into an element offset for the LMAD.
      let off' = off `quot` primByteSize ty
      sArray
        "local_prefix_arr"
        ty
        (Shape [block_size])
        localMem
        $ LMAD.iota off' [pe64 block_size]

  warpscan <- sArrayInMem "warpscan" int8 (Shape [constant (warp_size :: Int64)]) localMem
  warpExchanges <-
    forM (zip warpByteOffsets types) $ \(off, ty) -> do
      let off' = off `quot` primByteSize ty
      sArray
        "warp_exchange"
        ty
        (Shape [constant (warp_size :: Int64)])
        localMem
        $ LMAD.iota off' [warp_size]

  pure (sharedId, transposedArrays, prefixArrays, warpscan, warpExchanges)
-- | Block status flags for the decoupled-lookback scheme: 'statusX'
-- means no information has been published yet, 'statusA' means the
-- block's local aggregate is available, and 'statusP' means its full
-- inclusive prefix is available.
statusX, statusA, statusP :: (Num a) => a
statusX = 0
statusA = 1
statusP = 2
-- | Perform the in-warp scan over (status flag, aggregate) pairs used
-- by the lookback phase of the single-pass scan.  The scan operator is
-- modified so that a 'statusP' or 'statusX' flag on the right operand
-- makes the right operand win outright; otherwise the user's scan
-- lambda combines the two values.  All reads/writes are volatile since
-- other blocks update the arrays concurrently.
inBlockScanLookback ::
  KernelConstants ->
  Imp.TExp Int64 ->
  VName ->
  [VName] ->
  Lambda GPUMem ->
  InKernelGen ()
inBlockScanLookback constants arrs_full_size flag_arr arrs scan_lam = everythingVolatile $ do
  flg_x :: TV Int8 <- dPrim "flg_x"
  flg_y :: TV Int8 <- dPrim "flg_y"
  let flg_param_x = Param mempty (tvVar flg_x) (MemPrim p_int8)
      flg_param_y = Param mempty (tvVar flg_y) (MemPrim p_int8)
      flg_y_exp = tvExp flg_y
      statusP_e = statusP :: Imp.TExp Int8
      statusX_e = statusX :: Imp.TExp Int8

  dLParams (lambdaParams scan_lam)

  skip_threads <- dPrim "skip_threads"
  let in_block_thread_active =
        tvExp skip_threads .<=. in_block_id
      actual_params = lambdaParams scan_lam
      -- The lambda takes the x (accumulator) parameters followed by the
      -- y (element) parameters.
      (x_params, y_params) =
        splitAt (length actual_params `div` 2) actual_params
      -- Copy the y operand (and its flag) into the x slot, i.e. the
      -- right operand wins.
      y_to_x =
        forM_ (zip x_params y_params) $ \(x, y) ->
          when (primType (paramType x)) $
            copyDWIM (paramName x) [] (Var (paramName y)) []
      y_to_x_flg =
        copyDWIM (tvVar flg_x) [] (Var (tvVar flg_y)) []

  sComment "read input for in-block scan" $ do
    zipWithM_ readInitial (flg_param_y : y_params) (flag_arr : arrs)
    sWhen (in_block_id .==. 0) $ do
      y_to_x
      y_to_x_flg

  when array_scan barrier

  -- Combine into x: if the right flag is P or X the right operand is
  -- taken as-is; otherwise apply the user's scan operator.
  let op_to_x = do
        sIf
          (flg_y_exp .==. statusP_e .||. flg_y_exp .==. statusX_e)
          ( do
              y_to_x_flg
              y_to_x
          )
          (compileBody' x_params $ lambdaBody scan_lam)

  sComment "in-block scan (hopefully no barriers needed)" $ do
    skip_threads <-- 1
    -- Classic Hillis-Steele doubling scan within one warp.
    sWhile (tvExp skip_threads .<. block_size) $ do
      sWhen in_block_thread_active $ do
        sComment "read operands" $
          zipWithM_
            (readParam (sExt64 $ tvExp skip_threads))
            (flg_param_x : x_params)
            (flag_arr : arrs)
        sComment "perform operation" op_to_x

        sComment "write result" $
          sequence_ $
            zipWith3
              writeResult
              (flg_param_x : x_params)
              (flg_param_y : y_params)
              (flag_arr : arrs)
      skip_threads <-- tvExp skip_threads * 2
  where
    p_int8 :: PrimType
    p_int8 = IntType Int8
    -- The scan is performed within a single warp of 32 threads.
    block_size :: TExp Int32
    block_size = 32
    block_id :: TExp Int32
    block_id = ltid32 `quot` block_size
    in_block_id :: TExp Int32
    in_block_id = ltid32 - block_id * block_size
    ltid32 :: TExp Int32
    ltid32 = kernelLocalThreadId constants
    ltid :: TExp Int64
    ltid = sExt64 ltid32
    gtid :: TExp Int64
    gtid = sExt64 $ kernelGlobalThreadId constants
    -- Non-prim results live in global memory and need stronger fences.
    array_scan :: Bool
    array_scan = not $ all primType $ lambdaReturnType scan_lam
    barrier :: ImpM GPUMem KernelEnv KernelOp ()
    barrier
      | array_scan =
          sOp $ Imp.Barrier Imp.FenceGlobal
      | otherwise =
          sOp $ Imp.Barrier Imp.FenceLocal

    -- Prim values are indexed by local thread id; array values by the
    -- global thread id (they are stored per-thread in global memory).
    readInitial p arr
      | primType $ paramType p =
          copyDWIMFix (paramName p) [] (Var arr) [ltid]
      | otherwise =
          copyDWIMFix (paramName p) [] (Var arr) [gtid]

    -- Read the operand @behind@ positions back; for array values the
    -- secondary buffer at offset arrs_full_size holds the partner.
    readParam behind p arr
      | primType $ paramType p =
          copyDWIMFix (paramName p) [] (Var arr) [ltid - behind]
      | otherwise =
          copyDWIMFix (paramName p) [] (Var arr) [gtid - behind + arrs_full_size]

    -- Publish the combined value and also move it into the y slot for
    -- the next doubling round.
    writeResult x y arr = do
      when (isPrimParam x) $
        copyDWIMFix arr [ltid] (Var $ paramName x) []
      copyDWIM (paramName y) [] (Var $ paramName x) []
compileSegScan ::
Pat LetDecMem ->
SegLevel ->
SegSpace ->
SegBinOp GPUMem ->
KernelBody GPUMem ->
CallKernelGen ()
compileSegScan :: Pat (MemInfo SubExp NoUniqueness MemBind)
-> SegLevel
-> SegSpace
-> SegBinOp GPUMem
-> KernelBody GPUMem
-> CallKernelGen ()
compileSegScan Pat (MemInfo SubExp NoUniqueness MemBind)
pat SegLevel
lvl SegSpace
space SegBinOp GPUMem
scan_op KernelBody GPUMem
map_kbody = do
attrs <- SegLevel -> CallKernelGen KernelAttrs
lvlKernelAttrs SegLevel
lvl
let Pat all_pes = pat
scanop_nes = SegBinOp GPUMem -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp GPUMem
scan_op
n = [TExp Int64] -> TExp Int64
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$ (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 ([SubExp] -> [TExp Int64]) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [SubExp]
segSpaceDims SegSpace
space
tys' = Lambda GPUMem -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall rep.
Lambda rep -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType (Lambda GPUMem -> [TypeBase (ShapeBase SubExp) NoUniqueness])
-> Lambda GPUMem -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> Lambda GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scan_op
tys = (TypeBase (ShapeBase SubExp) NoUniqueness -> PrimType)
-> [TypeBase (ShapeBase SubExp) NoUniqueness] -> [PrimType]
forall a b. (a -> b) -> [a] -> [b]
map TypeBase (ShapeBase SubExp) NoUniqueness -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType [TypeBase (ShapeBase SubExp) NoUniqueness]
tys'
tblock_size_e = SubExp -> TExp Int64
pe64 (SubExp -> TExp Int64) -> SubExp -> TExp Int64
forall a b. (a -> b) -> a -> b
$ Count BlockSize SubExp -> SubExp
forall {k} (u :: k) e. Count u e -> e
unCount (Count BlockSize SubExp -> SubExp)
-> Count BlockSize SubExp -> SubExp
forall a b. (a -> b) -> a -> b
$ KernelAttrs -> Count BlockSize SubExp
kAttrBlockSize KernelAttrs
attrs
num_phys_blocks_e = SubExp -> TExp Int64
pe64 (SubExp -> TExp Int64) -> SubExp -> TExp Int64
forall a b. (a -> b) -> a -> b
$ Count NumBlocks SubExp -> SubExp
forall {k} (u :: k) e. Count u e -> e
unCount (Count NumBlocks SubExp -> SubExp)
-> Count NumBlocks SubExp -> SubExp
forall a b. (a -> b) -> a -> b
$ KernelAttrs -> Count NumBlocks SubExp
kAttrNumBlocks KernelAttrs
attrs
let chunk_const = [TypeBase (ShapeBase SubExp) NoUniqueness] -> KernelConstExp
getChunkSize [TypeBase (ShapeBase SubExp) NoUniqueness]
tys'
chunk_v <- dPrimV "chunk_size" . isInt64 =<< kernelConstToExp chunk_const
let chunk = TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
chunk_v
num_virt_blocks <-
tvSize <$> dPrimV "num_virt_blocks" (n `divUp` (tblock_size_e * chunk))
let num_virt_blocks_e = SubExp -> TExp Int64
pe64 SubExp
num_virt_blocks
num_virt_threads <-
dPrimVE "num_virt_threads" $ num_virt_blocks_e * tblock_size_e
let (gtids, dims) = unzip $ unSegSpace space
dims' = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 [SubExp]
dims
segmented = [TExp Int64] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TExp Int64]
dims' Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
1
not_segmented_e = Bool -> TPrimExp Bool VName
forall v. Bool -> TPrimExp Bool v
fromBool (Bool -> TPrimExp Bool VName) -> Bool -> TPrimExp Bool VName
forall a b. (a -> b) -> a -> b
$ Bool -> Bool
not Bool
segmented
segment_size = [TExp Int64] -> TExp Int64
forall a. HasCallStack => [a] -> a
last [TExp Int64]
dims'
emit $ Imp.DebugPrint "Sequential elements per thread (chunk)" $ Just $ untyped chunk
statusFlags <- sAllocArray "status_flags" int8 (Shape [num_virt_blocks]) (Space "device")
sReplicate statusFlags $ intConst Int8 statusX
(aggregateArrays, incprefixArrays) <-
fmap unzip $
forM tys $ \PrimType
ty ->
(,)
(VName -> VName -> (VName, VName))
-> ImpM GPUMem HostEnv HostOp VName
-> ImpM GPUMem HostEnv HostOp (VName -> (VName, VName))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM GPUMem HostEnv HostOp VName
forall rep r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray String
"aggregates" PrimType
ty ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [SubExp
num_virt_blocks]) (String -> Space
Space String
"device")
ImpM GPUMem HostEnv HostOp (VName -> (VName, VName))
-> ImpM GPUMem HostEnv HostOp VName
-> ImpM GPUMem HostEnv HostOp (VName, VName)
forall a b.
ImpM GPUMem HostEnv HostOp (a -> b)
-> ImpM GPUMem HostEnv HostOp a -> ImpM GPUMem HostEnv HostOp b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> String
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM GPUMem HostEnv HostOp VName
forall rep r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray String
"incprefixes" PrimType
ty ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [SubExp
num_virt_blocks]) (String -> Space
Space String
"device")
global_id <- genZeroes "global_dynid" 1
let attrs' = KernelAttrs
attrs {kAttrConstExps = M.singleton (tvVar chunk_v) chunk_const}
sKernelThread "segscan" (segFlat space) attrs' $ do
chunk32 <- dPrimVE "chunk_size_32b" $ sExt32 $ tvExp chunk_v
constants <- kernelConstants <$> askEnv
let ltid32 = KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
ltid = TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
ltid32
(sharedId, transposedArrays, prefixArrays, warpscan, exchanges) <-
createLocalArrays (kAttrBlockSize attrs) (tvSize chunk_v) tys
phys_block_id <- dPrim "phys_block_id"
sOp $ Imp.GetBlockId (tvVar phys_block_id) 0
iters <-
dPrimVE "virtloop_bound" $
(num_virt_blocks_e - tvExp phys_block_id)
`divUp` num_phys_blocks_e
sFor "virtloop_i" iters $ const $ do
dyn_id <- dPrim "dynamic_id"
sComment "First thread in block fetches this block's dynamic_id" $ do
sWhen (ltid32 .==. 0) $ do
(globalIdMem, _, globalIdOff) <- fullyIndexArray global_id [0]
sOp $
Imp.Atomic DefaultSpace $
Imp.AtomicAdd
Int32
(tvVar dyn_id)
globalIdMem
(Count $ unCount globalIdOff)
(untyped (1 :: Imp.TExp Int32))
sComment "Set dynamic id for this block" $ do
copyDWIMFix sharedId [0] (tvSize dyn_id) []
sComment "First thread in last (virtual) block resets global dynamic_id" $ do
sWhen (tvExp dyn_id .==. num_virt_blocks_e - 1) $
copyDWIMFix global_id [0] (intConst Int32 0) []
let local_barrier = Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
local_fence = Fence -> KernelOp
Imp.MemFence Fence
Imp.FenceLocal
global_fence = Fence -> KernelOp
Imp.MemFence Fence
Imp.FenceGlobal
sOp local_barrier
copyDWIMFix (tvVar dyn_id) [] (Var sharedId) [0]
sOp local_barrier
block_offset <-
dPrimVE "block_offset" $
sExt64 (tvExp dyn_id) * chunk * tblock_size_e
sgm_idx <- dPrimVE "sgm_idx" $ block_offset `mod` segment_size
boundary <-
dPrimVE "boundary" $
sExt32 $
sMin64 (chunk * tblock_size_e) (segment_size - sgm_idx)
segsize_compact <-
dPrimVE "segsize_compact" $
sExt32 $
sMin64 (chunk * tblock_size_e) segment_size
private_chunks <-
forM tys $ \PrimType
ty ->
String
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM GPUMem KernelEnv KernelOp VName
forall rep r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray
String
"private"
PrimType
ty
([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [TV Int64 -> SubExp
forall {k} (t :: k). TV t -> SubExp
tvSize TV Int64
chunk_v])
([SubExp] -> PrimType -> Space
ScalarSpace [TV Int64 -> SubExp
forall {k} (t :: k). TV t -> SubExp
tvSize TV Int64
chunk_v] PrimType
ty)
thd_offset <- dPrimVE "thd_offset" $ block_offset + ltid
sComment "Load and map" $
sFor "i" chunk $ \TExp Int64
i -> do
virt_tid <- String -> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64)
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"virt_tid" (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int64
thd_offset TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
tblock_size_e
dIndexSpace (zip gtids dims') virt_tid
let in_bounds =
Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody GPUMem -> Stms GPUMem
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
map_kbody) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
let ([KernelResult]
all_scan_res, [KernelResult]
map_res) =
Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp GPUMem] -> Int
forall rep. [SegBinOp rep] -> Int
segBinOpResults [SegBinOp GPUMem
scan_op]) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody GPUMem -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
map_kbody
[(PatElem (MemInfo SubExp NoUniqueness MemBind), KernelResult)]
-> ((PatElem (MemInfo SubExp NoUniqueness MemBind), KernelResult)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem (MemInfo SubExp NoUniqueness MemBind)]
-> [KernelResult]
-> [(PatElem (MemInfo SubExp NoUniqueness MemBind), KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Int
-> [PatElem (MemInfo SubExp NoUniqueness MemBind)]
-> [PatElem (MemInfo SubExp NoUniqueness MemBind)]
forall a. Int -> [a] -> [a]
takeLast ([KernelResult] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [KernelResult]
map_res) [PatElem (MemInfo SubExp NoUniqueness MemBind)]
all_pes) [KernelResult]
map_res) (((PatElem (MemInfo SubExp NoUniqueness MemBind), KernelResult)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((PatElem (MemInfo SubExp NoUniqueness MemBind), KernelResult)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(PatElem (MemInfo SubExp NoUniqueness MemBind)
dest, KernelResult
src) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (PatElem (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (MemInfo SubExp NoUniqueness MemBind)
dest) ((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
gtids) (KernelResult -> SubExp
kernelResultSubExp KernelResult
src) []
[(VName, SubExp)]
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
private_chunks ([SubExp] -> [(VName, SubExp)]) -> [SubExp] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ (KernelResult -> SubExp) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map KernelResult -> SubExp
kernelResultSubExp [KernelResult]
all_scan_res) (((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, SubExp
src) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest [TExp Int64
i] SubExp
src []
out_of_bounds =
[(VName, SubExp)]
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
private_chunks [SubExp]
scanop_nes) (((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, SubExp
ne) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest [TExp Int64
i] SubExp
ne []
sIf (virt_tid .<. n) in_bounds out_of_bounds
sOp $ Imp.ErrorSync Imp.FenceLocal
sComment "Transpose scan inputs" $ do
forM_ (zip transposedArrays private_chunks) $ \(VName
trans, VName
priv) -> do
String
-> TExp Int64
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall {k} (t :: k) rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TExp Int64
chunk ((TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
sharedIdx <- String -> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64)
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"sharedIdx" (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int64
ltid TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
tblock_size_e
copyDWIMFix trans [sharedIdx] (Var priv) [i]
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
local_barrier
String
-> TExp Int64
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall {k} (t :: k) rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TExp Int64
chunk ((TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
sharedIdx <- String -> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"sharedIdx" (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int64
ltid TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
chunk TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
i
copyDWIMFix priv [sExt64 i] (Var trans) [sExt64 $ tvExp sharedIdx]
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
local_barrier
sComment "Per thread scan" $ do
sFor "i" (chunk - 1) $ \TExp Int64
i -> do
let xs :: [VName]
xs = (LParam GPUMem -> VName) -> [LParam GPUMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map LParam GPUMem -> VName
Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName ([LParam GPUMem] -> [VName]) -> [LParam GPUMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> [LParam GPUMem]
xParams SegBinOp GPUMem
scan_op
ys :: [VName]
ys = (LParam GPUMem -> VName) -> [LParam GPUMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map LParam GPUMem -> VName
Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName ([LParam GPUMem] -> [VName]) -> [LParam GPUMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> [LParam GPUMem]
yParams SegBinOp GPUMem
scan_op
new_sgm <-
if Bool
segmented
then do
gidx <- String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"gidx" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ (TExp Int32
ltid32 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
chunk32) TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
1
dPrimVE "new_sgm" $ (gidx + sExt32 i - boundary) `mod` segsize_compact .==. 0
else TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Bool VName)
forall a. a -> ImpM GPUMem KernelEnv KernelOp a
forall (f :: * -> *) a. Applicative f => a -> f a
pure TPrimExp Bool VName
forall v. TPrimExp Bool v
false
sUnless new_sgm $ do
forM_ (zip4 private_chunks xs ys tys) $ \(VName
src, VName
x, VName
y, PrimType
ty) -> do
VName -> PrimType -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
x PrimType
ty
VName -> PrimType -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
y PrimType
ty
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
x [] (VName -> SubExp
Var VName
src) [TExp Int64
i]
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
y [] (VName -> SubExp
Var VName
src) [TExp Int64
i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
1]
compileStms mempty (bodyStms $ lambdaBody $ segBinOpLambda scan_op) $
forM_ (zip private_chunks $ map resSubExp $ bodyResult $ lambdaBody $ segBinOpLambda scan_op) $ \(VName
dest, SubExp
res) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest [TExp Int64
i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
1] SubExp
res []
sComment "Publish results in shared memory" $ do
forM_ (zip prefixArrays private_chunks) $ \(VName
dest, VName
src) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest [TExp Int64
ltid] (VName -> SubExp
Var VName
src) [TExp Int64
chunk TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1]
sOp local_barrier
let crossesSegment = do
Bool -> Maybe ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard Bool
segmented
(TExp Int32 -> TExp Int32 -> TPrimExp Bool VName)
-> Maybe (TExp Int32 -> TExp Int32 -> TPrimExp Bool VName)
forall a. a -> Maybe a
Just ((TExp Int32 -> TExp Int32 -> TPrimExp Bool VName)
-> Maybe (TExp Int32 -> TExp Int32 -> TPrimExp Bool VName))
-> (TExp Int32 -> TExp Int32 -> TPrimExp Bool VName)
-> Maybe (TExp Int32 -> TExp Int32 -> TPrimExp Bool VName)
forall a b. (a -> b) -> a -> b
$ \TExp Int32
from TExp Int32
to ->
let from' :: TExp Int32
from' = (TExp Int32
from TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
1) TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
chunk32 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1
to' :: TExp Int32
to' = (TExp Int32
to TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
1) TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
chunk32 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1
in (TExp Int32
to' TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
from') TExp Int32 -> TExp Int32 -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. (TExp Int32
to' TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
segsize_compact TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
boundary) TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`mod` TExp Int32
segsize_compact
scan_op1 <- renameLambda $ segBinOpLambda scan_op
accs <- mapM (dPrimSV "acc") tys
sComment "Scan results (with warp scan)" $ do
blockScan
crossesSegment
tblock_size_e
num_virt_threads
scan_op1
prefixArrays
sOp $ Imp.ErrorSync Imp.FenceLocal
let firstThread TV (ZonkAny 1)
acc VName
prefixes =
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV (ZonkAny 1) -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV (ZonkAny 1)
acc) [] (VName -> SubExp
Var VName
prefixes) [TExp Int64 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
tblock_size_e TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1]
notFirstThread TV (ZonkAny 1)
acc VName
prefixes =
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV (ZonkAny 1) -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV (ZonkAny 1)
acc) [] (VName -> SubExp
Var VName
prefixes) [TExp Int64
ltid TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1]
sIf
(ltid32 .==. 0)
(zipWithM_ firstThread accs prefixArrays)
(zipWithM_ notFirstThread accs prefixArrays)
sOp local_barrier
prefixes <- forM (zip scanop_nes tys) $ \(SubExp
ne, PrimType
ty) ->
String
-> TExp (ZonkAny 3)
-> ImpM GPUMem KernelEnv KernelOp (TV (ZonkAny 3))
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"prefix" (TExp (ZonkAny 3)
-> ImpM GPUMem KernelEnv KernelOp (TV (ZonkAny 3)))
-> TExp (ZonkAny 3)
-> ImpM GPUMem KernelEnv KernelOp (TV (ZonkAny 3))
forall a b. (a -> b) -> a -> b
$ PrimExp VName -> TExp (ZonkAny 3)
forall {k} (t :: k) v. PrimExp v -> TPrimExp t v
TPrimExp (PrimExp VName -> TExp (ZonkAny 3))
-> PrimExp VName -> TExp (ZonkAny 3)
forall a b. (a -> b) -> a -> b
$ PrimType -> SubExp -> PrimExp VName
forall a. ToExp a => PrimType -> a -> PrimExp VName
toExp' PrimType
ty SubExp
ne
blockNewSgm <- dPrimVE "block_new_sgm" $ sgm_idx .==. 0
sComment "Perform lookback" $ do
sWhen (blockNewSgm .&&. ltid32 .==. 0) $ do
everythingVolatile $
forM_ (zip accs incprefixArrays) $ \(TV (ZonkAny 1)
acc, VName
incprefixArray) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
incprefixArray [TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dyn_id] (TV (ZonkAny 1) -> SubExp
forall {k} (t :: k). TV t -> SubExp
tvSize TV (ZonkAny 1)
acc) []
sOp global_fence
everythingVolatile $
copyDWIMFix statusFlags [tvExp dyn_id] (intConst Int8 statusP) []
forM_ (zip scanop_nes accs) $ \(SubExp
ne, TV (ZonkAny 1)
acc) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV (ZonkAny 1) -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV (ZonkAny 1)
acc) [] SubExp
ne []
let warp_size = KernelConstants -> TExp Int32
kernelWaveSize KernelConstants
constants
sWhen (bNot blockNewSgm .&&. ltid32 .<. warp_size) $ do
sWhen (ltid32 .==. 0) $ do
sIf
(not_segmented_e .||. boundary .==. sExt32 (tblock_size_e * chunk))
( do
everythingVolatile $
forM_ (zip aggregateArrays accs) $ \(VName
aggregateArray, TV (ZonkAny 1)
acc) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
aggregateArray [TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dyn_id] (TV (ZonkAny 1) -> SubExp
forall {k} (t :: k). TV t -> SubExp
tvSize TV (ZonkAny 1)
acc) []
sOp global_fence
everythingVolatile $
copyDWIMFix statusFlags [tvExp dyn_id] (intConst Int8 statusA) []
)
( do
everythingVolatile $
forM_ (zip incprefixArrays accs) $ \(VName
incprefixArray, TV (ZonkAny 1)
acc) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
incprefixArray [TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dyn_id] (TV (ZonkAny 1) -> SubExp
forall {k} (t :: k). TV t -> SubExp
tvSize TV (ZonkAny 1)
acc) []
sOp global_fence
everythingVolatile $
copyDWIMFix statusFlags [tvExp dyn_id] (intConst Int8 statusP) []
)
everythingVolatile $
copyDWIMFix warpscan [0] (Var statusFlags) [tvExp dyn_id - 1]
sOp local_fence
status :: TV Int8 <- dPrim "status"
copyDWIMFix (tvVar status) [] (Var warpscan) [0]
sIf
(tvExp status .==. statusP)
( sWhen (ltid32 .==. 0) $
everythingVolatile $
forM_ (zip prefixes incprefixArrays) $ \(TV (ZonkAny 3)
prefix, VName
incprefixArray) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV (ZonkAny 3) -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV (ZonkAny 3)
prefix) [] (VName -> SubExp
Var VName
incprefixArray) [TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dyn_id TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1]
)
( do
readOffset <-
dPrimV "readOffset" $
sExt32 $
tvExp dyn_id - sExt64 (kernelWaveSize constants)
let loopStop = TExp Int32
warp_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* (-TExp Int32
1)
sameSegment TV Int32
readIdx
| Bool
segmented =
let startIdx :: TExp Int64
startIdx = TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int32 -> TExp Int32
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
readIdx TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
1) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
tblock_size_e TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
chunk TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1
in TExp Int64
block_offset TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
startIdx TExp Int64 -> TExp Int64 -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TExp Int64
sgm_idx
| Bool
otherwise = TPrimExp Bool VName
forall v. TPrimExp Bool v
true
sWhile (tvExp readOffset .>. loopStop) $ do
readI <- dPrimV "read_i" $ tvExp readOffset + ltid32
aggrs <- forM (zip scanop_nes tys) $ \(SubExp
ne, PrimType
ty) ->
String
-> TExp (ZonkAny 5)
-> ImpM GPUMem KernelEnv KernelOp (TV (ZonkAny 5))
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"aggr" (TExp (ZonkAny 5)
-> ImpM GPUMem KernelEnv KernelOp (TV (ZonkAny 5)))
-> TExp (ZonkAny 5)
-> ImpM GPUMem KernelEnv KernelOp (TV (ZonkAny 5))
forall a b. (a -> b) -> a -> b
$ PrimExp VName -> TExp (ZonkAny 5)
forall {k} (t :: k) v. PrimExp v -> TPrimExp t v
TPrimExp (PrimExp VName -> TExp (ZonkAny 5))
-> PrimExp VName -> TExp (ZonkAny 5)
forall a b. (a -> b) -> a -> b
$ PrimType -> SubExp -> PrimExp VName
forall a. ToExp a => PrimType -> a -> PrimExp VName
toExp' PrimType
ty SubExp
ne
flag <- dPrimV "flag" (statusX :: Imp.TExp Int8)
everythingVolatile . sWhen (tvExp readI .>=. 0) $ do
sIf
(sameSegment readI)
( do
copyDWIMFix (tvVar flag) [] (Var statusFlags) [sExt64 $ tvExp readI]
sIf
(tvExp flag .==. statusP)
( forM_ (zip incprefixArrays aggrs) $ \(VName
incprefix, TV (ZonkAny 5)
aggr) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV (ZonkAny 5) -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV (ZonkAny 5)
aggr) [] (VName -> SubExp
Var VName
incprefix) [TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
readI]
)
( sWhen (tvExp flag .==. statusA) $ do
forM_ (zip aggrs aggregateArrays) $ \(TV (ZonkAny 5)
aggr, VName
aggregate) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV (ZonkAny 5) -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV (ZonkAny 5)
aggr) [] (VName -> SubExp
Var VName
aggregate) [TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
readI]
)
)
(copyDWIMFix (tvVar flag) [] (intConst Int8 statusP) [])
forM_ (zip exchanges aggrs) $ \(VName
exchange, TV (ZonkAny 5)
aggr) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
exchange [TExp Int64
ltid] (TV (ZonkAny 5) -> SubExp
forall {k} (t :: k). TV t -> SubExp
tvSize TV (ZonkAny 5)
aggr) []
copyDWIMFix warpscan [ltid] (tvSize flag) []
copyDWIMFix (tvVar flag) [] (Var warpscan) [sExt64 warp_size - 1]
sWhen (tvExp flag .<. statusP) $ do
lam' <- renameLambda scan_op1
inBlockScanLookback
constants
num_virt_threads
warpscan
exchanges
lam'
copyDWIMFix (tvVar flag) [] (Var warpscan) [sExt64 warp_size - 1]
forM_ (zip aggrs exchanges) $ \(TV (ZonkAny 5)
aggr, VName
exchange) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV (ZonkAny 5) -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV (ZonkAny 5)
aggr) [] (VName -> SubExp
Var VName
exchange) [TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
warp_size TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1]
sIf
(tvExp flag .==. statusP)
(readOffset <-- loopStop)
( sWhen (tvExp flag .==. statusA) $ do
readOffset <-- tvExp readOffset - zExt32 warp_size
)
sWhen (tvExp flag .>. statusX) $ do
lam <- renameLambda scan_op1
let (xs, ys) = splitAt (length tys) $ map paramName $ lambdaParams lam
forM_ (zip xs aggrs) $ \(VName
x, TV (ZonkAny 5)
aggr) -> VName -> TExp (ZonkAny 5) -> ImpM GPUMem KernelEnv KernelOp ()
forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
x (TV (ZonkAny 5) -> TExp (ZonkAny 5)
forall {k} (t :: k). TV t -> TExp t
tvExp TV (ZonkAny 5)
aggr)
forM_ (zip ys prefixes) $ \(VName
y, TV (ZonkAny 3)
prefix) -> VName -> TExp (ZonkAny 3) -> ImpM GPUMem KernelEnv KernelOp ()
forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
y (TV (ZonkAny 3) -> TExp (ZonkAny 3)
forall {k} (t :: k). TV t -> TExp t
tvExp TV (ZonkAny 3)
prefix)
compileStms mempty (bodyStms $ lambdaBody lam) $
forM_ (zip3 prefixes tys $ map resSubExp $ bodyResult $ lambdaBody lam) $
\(TV (ZonkAny 3)
prefix, PrimType
ty, SubExp
res) -> TV (ZonkAny 3)
prefix TV (ZonkAny 3)
-> TExp (ZonkAny 3) -> ImpM GPUMem KernelEnv KernelOp ()
forall {k} (t :: k) rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- PrimExp VName -> TExp (ZonkAny 3)
forall {k} (t :: k) v. PrimExp v -> TPrimExp t v
TPrimExp (PrimType -> SubExp -> PrimExp VName
forall a. ToExp a => PrimType -> a -> PrimExp VName
toExp' PrimType
ty SubExp
res)
sOp local_fence
)
sWhen (ltid32 .==. 0) $ do
scan_op2 <- renameLambda scan_op1
let xs = (Param (MemInfo SubExp NoUniqueness MemBind) -> VName)
-> [Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName ([Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName])
-> [Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName]
forall a b. (a -> b) -> a -> b
$ Int
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
forall a. Int -> [a] -> [a]
take ([PrimType] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)])
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
scan_op2
ys = (Param (MemInfo SubExp NoUniqueness MemBind) -> VName)
-> [Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName ([Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName])
-> [Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName]
forall a b. (a -> b) -> a -> b
$ Int
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
forall a. Int -> [a] -> [a]
drop ([PrimType] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)])
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
scan_op2
sWhen (boundary .==. sExt32 (tblock_size_e * chunk)) $ do
forM_ (zip xs prefixes) $ \(VName
x, TV (ZonkAny 3)
prefix) -> VName -> TExp (ZonkAny 3) -> ImpM GPUMem KernelEnv KernelOp ()
forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
x (TExp (ZonkAny 3) -> ImpM GPUMem KernelEnv KernelOp ())
-> TExp (ZonkAny 3) -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ TV (ZonkAny 3) -> TExp (ZonkAny 3)
forall {k} (t :: k). TV t -> TExp t
tvExp TV (ZonkAny 3)
prefix
forM_ (zip ys accs) $ \(VName
y, TV (ZonkAny 1)
acc) -> VName -> TExp (ZonkAny 1) -> ImpM GPUMem KernelEnv KernelOp ()
forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
y (TExp (ZonkAny 1) -> ImpM GPUMem KernelEnv KernelOp ())
-> TExp (ZonkAny 1) -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ TV (ZonkAny 1) -> TExp (ZonkAny 1)
forall {k} (t :: k). TV t -> TExp t
tvExp TV (ZonkAny 1)
acc
compileStms mempty (bodyStms $ lambdaBody scan_op2) $
everythingVolatile $
forM_ (zip incprefixArrays $ map resSubExp $ bodyResult $ lambdaBody scan_op2) $
\(VName
incprefixArray, SubExp
res) -> VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
incprefixArray [TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dyn_id] SubExp
res []
sOp global_fence
everythingVolatile $ copyDWIMFix statusFlags [tvExp dyn_id] (intConst Int8 statusP) []
forM_ (zip exchanges prefixes) $ \(VName
exchange, TV (ZonkAny 3)
prefix) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
exchange [TExp Int64
0] (TV (ZonkAny 3) -> SubExp
forall {k} (t :: k). TV t -> SubExp
tvSize TV (ZonkAny 3)
prefix) []
forM_ (zip3 accs tys scanop_nes) $ \(TV (ZonkAny 1)
acc, PrimType
ty, SubExp
ne) ->
TV (ZonkAny 1) -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV (ZonkAny 1)
acc VName -> PrimExp VName -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> PrimExp VName -> ImpM rep r op ()
<~~ PrimType -> SubExp -> PrimExp VName
forall a. ToExp a => PrimType -> a -> PrimExp VName
toExp' PrimType
ty SubExp
ne
sWhen (bNot $ tvExp dyn_id .==. 0) $ do
sOp local_barrier
forM_ (zip exchanges prefixes) $ \(VName
exchange, TV (ZonkAny 3)
prefix) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV (ZonkAny 3) -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV (ZonkAny 3)
prefix) [] (VName -> SubExp
Var VName
exchange) [TExp Int64
0]
sOp local_barrier
scan_op3 <- renameLambda scan_op1
scan_op4 <- renameLambda scan_op1
sComment "Distribute results" $ do
let (xs, ys) = splitAt (length tys) $ map paramName $ lambdaParams scan_op3
(xs', ys') = splitAt (length tys) $ map paramName $ lambdaParams scan_op4
forM_ (zip7 prefixes accs xs xs' ys ys' tys) $
\(TV (ZonkAny 3)
prefix, TV (ZonkAny 1)
acc, VName
x, VName
x', VName
y, VName
y', PrimType
ty) -> do
VName -> PrimType -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
x PrimType
ty
VName -> PrimType -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
y PrimType
ty
VName -> TExp (ZonkAny 3) -> ImpM GPUMem KernelEnv KernelOp ()
forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
x' (TExp (ZonkAny 3) -> ImpM GPUMem KernelEnv KernelOp ())
-> TExp (ZonkAny 3) -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ TV (ZonkAny 3) -> TExp (ZonkAny 3)
forall {k} (t :: k). TV t -> TExp t
tvExp TV (ZonkAny 3)
prefix
VName -> TExp (ZonkAny 1) -> ImpM GPUMem KernelEnv KernelOp ()
forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
y' (TExp (ZonkAny 1) -> ImpM GPUMem KernelEnv KernelOp ())
-> TExp (ZonkAny 1) -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ TV (ZonkAny 1) -> TExp (ZonkAny 1)
forall {k} (t :: k). TV t -> TExp t
tvExp TV (ZonkAny 1)
acc
sIf
(ltid32 * chunk32 .<. boundary .&&. bNot blockNewSgm)
( compileStms mempty (bodyStms $ lambdaBody scan_op4) $
forM_ (zip3 xs tys $ map resSubExp $ bodyResult $ lambdaBody scan_op4) $
\(VName
x, PrimType
ty, SubExp
res) -> VName
x VName -> PrimExp VName -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> PrimExp VName -> ImpM rep r op ()
<~~ PrimType -> SubExp -> PrimExp VName
forall a. ToExp a => PrimType -> a -> PrimExp VName
toExp' PrimType
ty SubExp
res
)
(forM_ (zip xs accs) $ \(VName
x, TV (ZonkAny 1)
acc) -> VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
x [] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ TV (ZonkAny 1) -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV (ZonkAny 1)
acc) [])
stop <-
dPrimVE "stopping_point" $
segsize_compact - (ltid32 * chunk32 - 1 + segsize_compact - boundary) `rem` segsize_compact
sFor "i" chunk $ \TExp Int64
i -> do
TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
i TExp Int32 -> TExp Int32 -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
stop TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
[(VName, VName)]
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
private_chunks [VName]
ys) (((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
src, VName
y) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
y [] (VName -> SubExp
Var VName
src) [TExp Int64
i]
Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (Body GPUMem -> Stms GPUMem
forall rep. Body rep -> Stms rep
bodyStms (Body GPUMem -> Stms GPUMem) -> Body GPUMem -> Stms GPUMem
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scan_op3) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[(VName, SubExp)]
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
private_chunks ([SubExp] -> [(VName, SubExp)]) -> [SubExp] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> SubExp) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp ([SubExpRes] -> [SubExp]) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Body GPUMem -> [SubExpRes]
forall rep. Body rep -> [SubExpRes]
bodyResult (Body GPUMem -> [SubExpRes]) -> Body GPUMem -> [SubExpRes]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scan_op3) (((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
\(VName
dest, SubExp
res) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest [TExp Int64
i] SubExp
res []
sComment "Transpose scan output and Write it to global memory in coalesced fashion" $ do
forM_ (zip3 transposedArrays private_chunks $ map patElemName all_pes) $ \(VName
locmem, VName
priv, VName
dest) -> do
String
-> TExp Int64
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall {k} (t :: k) rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TExp Int64
chunk ((TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
sharedIdx <-
String -> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"sharedIdx" (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int64
ltid TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
chunk) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
i
copyDWIMFix locmem [tvExp sharedIdx] (Var priv) [i]
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
local_barrier
String
-> TExp Int64
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall {k} (t :: k) rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TExp Int64
chunk ((TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
flat_idx <- String -> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64)
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"flat_idx" (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int64
thd_offset TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
tblock_size_e
dIndexSpace (zip gtids dims') flat_idx
sWhen (flat_idx .<. n) $ do
copyDWIMFix
dest
(map Imp.le64 gtids)
(Var locmem)
[sExt64 $ flat_idx - block_offset]
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
local_barrier
{-# NOINLINE compileSegScan #-}