{-# LANGUAGE TypeFamilies #-}
module Futhark.Pass.ExtractKernels.ISRWIM
( iswim,
irwim,
rwimPossible,
)
where
import Control.Arrow (first)
import Control.Monad
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Tools
-- | Interchange Scan With Inner Map.
--
-- Given a scan of width @w@ whose operator consists of nothing but a single
-- inner @map@ over the operator's own parameters (as recognised by
-- 'rwimPossible'), emit an outer @map@ (of the inner map's width) whose body
-- performs a scan of width @w@ using the inner map's function as operator.
--
-- The scanned arrays have their two outermost dimensions swapped before the
-- outer map. Each result of the map is bound under a fresh @_transposed@
-- name and then rearranged back into the corresponding name of the result
-- pattern.
--
-- Returns 'Nothing' if the operator does not have the required shape.
-- Otherwise it returns the builder action that adds the rewritten
-- statements.
iswim ::
  (MonadBuilder m, Rep m ~ SOACS) =>
  -- | Pattern bound by the original scan; its names are bound by the
  -- final rearranges.
  Pat Type ->
  -- | Width of the original scan (becomes the width of the inner scan).
  SubExp ->
  -- | The scan operator.
  Lambda SOACS ->
  -- | Neutral elements paired with the arrays being scanned.
  [(SubExp, VName)] ->
  Maybe (m ())
iswim res_pat w scan_fun scan_input
  | Just (map_pat, map_aux, map_w, map_fun) <- rwimPossible scan_fun = Just $ do
      let (accs, arrs) = unzip scan_input
      -- The outer map runs over what used to be the inner dimension, so the
      -- scanned arrays get their two outermost dimensions swapped.
      arrs' <- transposedArrays arrs
      -- Screma inputs are variable names, so bind each neutral element to
      -- a variable of its own.
      accs' <- mapM (letExp "acc" . BasicOp . SubExp) accs

      let map_arrs' = accs' ++ arrs'
          (scan_acc_params, scan_elem_params) =
            splitAt (length arrs) $ lambdaParams scan_fun
          -- Each map iteration receives one row of every accumulator and an
          -- array of length w for every element parameter.
          map_params =
            map removeParamOuterDim scan_acc_params
              ++ map (setParamOuterDimTo w) scan_elem_params
          map_rettype = map (setOuterDimTo w) $ lambdaReturnType scan_fun

          -- The inner map's function becomes the new scan operator.
          scan_params = lambdaParams map_fun
          scan_body = lambdaBody map_fun
          scan_rettype = lambdaReturnType map_fun
          scan_fun' = Lambda scan_params scan_rettype scan_body
          -- Pair every accumulator parameter (used as the neutral element)
          -- with the corresponding element parameter (the array to scan).
          scan_input' =
            map (first Var) $
              uncurry zip $
                splitAt (length arrs') $
                  map paramName map_params
          (nes', scan_arrs) = unzip scan_input'

      scan_soac <- scanSOAC [Scan scan_fun' nes']
      let map_body =
            mkBody
              ( oneStm $
                  Let (setPatOuterDimTo w map_pat) (defAux ()) $
                    Op $
                      Screma w scan_arrs scan_soac
              )
              $ varsRes
              $ patNames map_pat
          map_fun' = Lambda map_params map_rettype map_body

      -- The outer map yields transposed results; bind them to fresh names
      -- first and rearrange them into res_pat below.
      res_pat' <-
        fmap basicPat $
          mapM (newIdent' (<> "_transposed") . transposeIdentType) $
            patIdents res_pat

      addStm . Let res_pat' map_aux . Op $
        Screma map_w map_arrs' (mapSOAC map_fun')

      -- Swap the two outermost dimensions back for every result.
      forM_ (zip (patIdents res_pat) (patIdents res_pat')) $ \(to, from) -> do
        let perm :: [Int]
            perm = [1, 0] ++ [2 .. arrayRank (identType from) - 1]
        addStm $
          Let (basicPat [to]) (defAux ()) . BasicOp $
            Rearrange (identName from) perm
  | otherwise = Nothing
-- | Interchange Reduce With Inner Map.
--
-- Given a reduction of width @w@ whose operator consists of nothing but a
-- single inner @map@ over the operator's own parameters (as recognised by
-- 'rwimPossible'), emit an outer @map@ (of the inner map's width) over the
-- transposed input arrays. Each iteration of that map reduces an array of
-- length @w@ using the inner map's function as operator.
--
-- If the new operator again has the shape accepted by 'rwimPossible', the
-- interchange is applied recursively inside the map body.
--
-- Returns 'Nothing' if the operator does not have the required shape.
-- Otherwise it returns the builder action that binds the result pattern.
irwim ::
  (MonadBuilder m, Rep m ~ SOACS) =>
  -- | Pattern bound by the original reduction.
  Pat Type ->
  -- | Width of the original reduction (becomes the inner reduction width).
  SubExp ->
  -- | Commutativity of the reduction operator.
  Commutativity ->
  -- | The reduction operator.
  Lambda SOACS ->
  -- | Neutral elements paired with the arrays being reduced.
  [(SubExp, VName)] ->
  Maybe (m ())
irwim res_pat w comm red_fun red_input
  | Just (map_pat, map_aux, map_w, map_fun) <- rwimPossible red_fun = Just $ do
      let (accs, arrs) = unzip red_input
      arrs' <- transposedArrays arrs
      -- Every inner reduction uses row 0 of the original array-valued
      -- neutral element as its neutral element.
      -- NOTE(review): this is only correct if all rows of each accumulator
      -- are identical (e.g. it is a replicate). Row 0 is also indexed
      -- without checking that the array is non-empty. Confirm that callers
      -- guarantee both.
      let indexAcc (Var v) = do
            v_t <- lookupType v
            letSubExp "acc" $
              BasicOp $
                Index v $
                  fullSlice v_t [DimFix $ intConst Int64 0]
          -- NOTE(review): presumably unreachable. The operator's parameters
          -- are the inner map's input arrays, so its neutral elements
          -- should be arrays, and a constant cannot be an array. Confirm.
          indexAcc Constant {} =
            error "irwim: array accumulator is a constant."
      accs' <- mapM indexAcc accs

      let (_red_acc_params, red_elem_params) =
            splitAt (length arrs) $ lambdaParams red_fun
          -- Each map iteration produces one row of the original result.
          map_rettype = map rowType $ lambdaReturnType red_fun
          map_params = map (setParamOuterDimTo w) red_elem_params

          -- The inner map's function becomes the new reduction operator.
          red_params = lambdaParams map_fun
          red_body = lambdaBody map_fun
          red_rettype = lambdaReturnType map_fun
          red_fun' = Lambda red_params red_rettype red_body
          red_input' = zip accs' $ map paramName map_params
          red_pat = stripPatOuterDim map_pat

      -- If the new operator is itself a map, interchange recursively.
      -- Otherwise the map body is a plain reduction.
      map_body <-
        case irwim red_pat w comm red_fun' red_input' of
          Nothing -> do
            reduce_soac <- reduceSOAC [Reduce comm red_fun' $ map fst red_input']
            pure
              $ mkBody
                ( oneStm $
                    Let red_pat (defAux ()) $
                      Op $
                        Screma w (map snd red_input') reduce_soac
                )
              $ varsRes
              $ patNames map_pat
          Just m -> localScope (scopeOfLParams map_params) $ do
            -- The recursive rewrite looks up the types of the map
            -- parameters (via transposedArrays), so they must be in scope.
            map_body_stms <- collectStms_ m
            pure $ mkBody map_body_stms $ varsRes $ patNames map_pat

      let map_fun' = Lambda map_params map_rettype map_body

      addStm . Let res_pat map_aux . Op . Screma map_w arrs' $
        mapSOAC map_fun'
  | otherwise = Nothing
-- | Does this operator consist of nothing but a single inner @map@, and if
-- so, what does that map look like?
--
-- The operator qualifies when all of the following hold:
--
-- * its body is exactly one statement;
-- * the body returns that statement's results verbatim;
-- * the statement is a @map@ whose input arrays are exactly the operator's
--   parameters, in order.
--
-- On success, returns the map's pattern, the statement's auxiliary
-- information, the map's width, and the map's function.
rwimPossible ::
  Lambda SOACS ->
  Maybe (Pat Type, StmAux (), SubExp, Lambda SOACS)
rwimPossible fun
  | Body _ stms res <- lambdaBody fun,
    -- The body holds exactly one statement...
    [stm] <- stmsToList stms,
    let map_pat = stmPat stm,
    -- ...whose results are returned unchanged...
    map Var (patNames map_pat) == map resSubExp res,
    -- ...which is a map...
    Op (Screma map_w map_arrs form) <- stmExp stm,
    Just map_fun <- isMapSOAC form,
    -- ...over exactly the operator's parameters, in order.
    map paramName (lambdaParams fun) == map_arrs =
      Just (map_pat, stmAux stm, map_w, map_fun)
  | otherwise = Nothing
-- | Swap the two outermost dimensions of each array.
--
-- Each rearranged array is bound with 'letExp' under the original array's
-- base name. The resulting names are returned in input order.
--
-- NOTE(review): the permutation always mentions dimensions 0 and 1, so every
-- array is assumed to have rank at least 2.
transposedArrays :: (MonadBuilder m) => [VName] -> m [VName]
transposedArrays arrs = forM arrs $ \arr -> do
  t <- lookupType arr
  let perm = [1, 0] ++ [2 .. arrayRank t - 1]
  letExp (baseString arr) $ BasicOp $ Rearrange arr perm
-- | Replace a parameter's type with its row type (outermost dimension
-- removed).
removeParamOuterDim :: LParam SOACS -> LParam SOACS
removeParamOuterDim param =
  param {paramDec = rowType $ paramType param}
-- | Replace the outermost dimension of a parameter's type with the given
-- size (see 'setOuterDimTo').
setParamOuterDimTo :: SubExp -> LParam SOACS -> LParam SOACS
setParamOuterDimTo w param =
  param {paramDec = setOuterDimTo w $ paramType param}
-- | Replace the outermost dimension of an identifier's type with the given
-- size (see 'setOuterDimTo').
setIdentOuterDimTo :: SubExp -> Ident -> Ident
setIdentOuterDimTo w ident =
  ident {identType = setOuterDimTo w $ identType ident}
-- | Rebuild an array type with the given size as its outermost dimension.
-- The row type of the original is kept.
--
-- NOTE(review): intended for array types. For a non-array type the result
-- depends on what 'rowType' does with it. Confirm all callers pass arrays.
setOuterDimTo :: SubExp -> Type -> Type
setOuterDimTo w t = arrayOfRow (rowType t) w
-- | Replace the outermost dimension of every element of a pattern with the
-- given size, keeping the element names.
setPatOuterDimTo :: SubExp -> Pat Type -> Pat Type
setPatOuterDimTo w pat =
  basicPat $ map (setIdentOuterDimTo w) $ patIdents pat
-- | Give an identifier the 'transposeType' of its current type, keeping
-- its name.
transposeIdentType :: Ident -> Ident
transposeIdentType ident =
  ident {identType = transposeType $ identType ident}
-- | Give an identifier the row type of its current type (outermost dimension
-- removed), keeping its name.
stripIdentOuterDim :: Ident -> Ident
stripIdentOuterDim ident =
  ident {identType = rowType $ identType ident}
-- | Remove the outermost dimension from every element of a pattern, keeping
-- the element names.
stripPatOuterDim :: Pat Type -> Pat Type
stripPatOuterDim pat =
  basicPat $ map stripIdentOuterDim $ patIdents pat