{-# LANGUAGE TypeFamilies #-}
module Futhark.AD.Rev.Reduce
( diffReduce,
diffMinMaxReduce,
diffVecReduce,
diffMulReduce,
)
where
import Control.Monad
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename
-- | Bind and return a reversed view of the array @arr@ along its outermost
-- dimension.
--
-- The reversal is a strided slice over the outer dimension.  It starts at
-- index @w - 1@ (where @w@ is the outer size of @arr@), has length @w@ and
-- has stride @-1@.  The remaining dimensions are completed by 'fullSlice'.
-- The result is bound to a fresh name formed from @arr@'s base name plus
-- @_rev@.
eReverse :: (MonadBuilder m) => VName -> m VName
eReverse arr = do
  arr_t <- lookupType arr
  let w = arraySize 0 arr_t
  -- Index of the last element, i.e. where the reversed view starts.
  start <-
    letSubExp "rev_start" . BasicOp $
      BinOp (Sub Int64 OverflowUndef) w (intConst Int64 1)
  let stride = intConst Int64 (-1)
      slice = fullSlice arr_t [DimSlice start w stride]
  letExp (baseString arr <> "_rev") $ BasicOp $ Index arr slice
-- | Exclusive scan of @arrs@ using the operator and neutral elements of
-- @scan@.
--
-- The scan is built in two steps:
--
-- 1. Bind an inclusive scan as @<desc>_incl@.
-- 2. Map over @iota w@.  Position 0 yields the scan's neutral elements, and
--    position @i > 0@ yields element @i - 1@ of each inclusive result.
--
-- @w@ is the outer size of @arrs@.  The returned names are bound under
-- @desc@, one per result of the inclusive scan.
scanExc ::
  (MonadBuilder m, Rep m ~ SOACS) =>
  String ->
  Scan SOACS ->
  [VName] ->
  m [VName]
scanExc desc scan arrs = do
  w <- arraysSize 0 <$> mapM lookupType arrs
  form <- scanSOAC [scan]
  res_incl <- letTupExp (desc <> "_incl") $ Op $ Screma w arrs form
  iota <-
    letExp "iota" . BasicOp $
      Iota w (intConst Int64 0) (intConst Int64 1) Int64
  iparam <- newParam "iota_param" $ Prim int64
  lam <- mkLambda [iparam] $ do
    let i = paramName iparam
        -- The first position has no predecessor, so it gets the neutral
        -- element; every other position reads the inclusive result one
        -- step earlier.
        first_elem =
          eCmpOp
            (CmpEq int64)
            (eSubExp (Var i))
            (eSubExp (intConst Int64 0))
        prev = toExp $ le64 i - 1
    fmap subExpsRes . letTupExp' "scan_ex_res"
      =<< eIf
        first_elem
        (resultBodyM $ scanNeutral scan)
        (eBody $ map (`eIndex` [prev]) res_incl)
  letTupExp desc $ Op $ Screma w [iota] (mapSOAC lam)
-- | Build the per-element function used by the general reduce adjoint.
--
-- Let @q@ be the number of results of the operator @lam@.  Two renamed copies
-- of @lam@ are used.  The new lambda takes parameters @lps <> aps <> rps@:
--
-- * @lps@ and @aps@ are the first @q@ and the remaining parameters of the
--   left copy;
-- * @rps@ are the parameters after the first @q@ of the right copy.
--
-- The lambda's body evaluates the left copy on @(lps, aps)@.  It binds those
-- results to the first @q@ parameters of the right copy, then evaluates the
-- right copy.  Informally this computes @lam (lam lps aps) rps@.
--
-- Returns the names of the @aps@ parameters together with the new lambda.
mkF :: Lambda SOACS -> ADM ([VName], Lambda SOACS)
mkF lam = do
  lam_l <- renameLambda lam
  lam_r <- renameLambda lam
  let q = length $ lambdaReturnType lam
      (lps, aps) = splitAt q $ lambdaParams lam_l
      (ips, rps) = splitAt q $ lambdaParams lam_r
  lam' <- mkLambda (lps <> aps <> rps) $ do
    lam_l_res <- bodyBind $ lambdaBody lam_l
    -- Feed the left application's results into the right application's
    -- first q parameters, keeping any certificates attached to them.
    forM_ (zip ips lam_l_res) $ \(ip, SubExpRes cs se) ->
      certifying cs $ letBindNames [paramName ip] $ BasicOp $ SubExp se
    bodyBind $ lambdaBody lam_r
  pure (map paramName aps, lam')
-- | Reverse-mode differentiation of a (possibly multi-result) reduction.
--
-- The arguments are:
--
-- * @pat_adj@: the adjoints of the reduction results;
-- * @w@: the length of the input arrays @as@;
-- * @red@: the reduction.
--
-- Special case: a single-input, single-result reduction whose operator is an
-- addition ('Add' or 'FAdd').  Every input element then receives the result
-- adjoint unchanged.  The adjoint of @a@ is updated with @w@ replicated
-- copies of @adj@.
--
-- General case:
--
-- * @ls@ is the exclusive left-to-right scan of @as@.
-- * @rs@ is the exclusive right-to-left scan.  It is computed by reversing
--   the inputs, scanning with the operator whose two parameter halves are
--   swapped, and reversing the result.
--
-- The function built by 'mkF' combines @ls[i]@, @as[i]@ and @rs[i]@.  It is
-- differentiated with respect to its middle parameters and mapped over
-- @ls ++ as ++ rs@, which gives the adjoints of @as@.
diffReduce :: VjpOps -> [VName] -> SubExp -> [VName] -> Reduce SOACS -> ADM ()
diffReduce _ops [adj] w [a] red
  | Just [(op, _, _, _)] <- lamIsBinOp $ redLambda red,
    isAdd op = do
      adj_rep <-
        letExp (baseString adj <> "_rep") $
          BasicOp $
            Replicate (Shape [w]) $
              Var adj
      updateAdj a adj_rep
  where
    isAdd :: BinOp -> Bool
    isAdd FAdd {} = True
    isAdd Add {} = True
    isAdd _ = False
diffReduce ops pat_adj w as red = do
  red' <- renameRed red
  flip_red <- renameRed =<< flipReduce red
  ls <- scanExc "ls" (redToScan red') as
  rs <-
    mapM eReverse
      =<< scanExc "rs" (redToScan flip_red)
      =<< mapM eReverse as
  (as_params, f) <- mkF $ redLambda red
  f_adj <- vjpLambda ops (map adjFromVar pat_adj) as_params f
  as_adj <- letTupExp "adjs" $ Op $ Screma w (ls ++ as ++ rs) (mapSOAC f_adj)
  zipWithM_ updateAdj as as_adj
  where
    -- Give the operator fresh names so each use binds distinct variables.
    renameRed :: Reduce SOACS -> ADM (Reduce SOACS)
    renameRed (Reduce comm lam nes) =
      Reduce comm <$> renameLambda lam <*> pure nes

    redToScan :: Reduce SOACS -> Scan SOACS
    redToScan (Reduce _ lam nes) = Scan lam nes

    -- Swap the operator's arguments: op'(x, y) = op(y, x).
    flipReduce :: Reduce SOACS -> ADM (Reduce SOACS)
    flipReduce (Reduce comm lam nes) = do
      lam' <- renameLambda lam {lambdaParams = flipParams $ lambdaParams lam}
      pure $ Reduce comm lam' nes

    -- Exchange the first and second halves of a parameter list.
    flipParams :: [a] -> [a]
    flipParams ps = uncurry (flip (++)) $ splitAt (length ps `div` 2) ps
-- | Reverse-mode differentiation of a min/max reduction.
--
-- The arguments are:
--
-- * @x@: the reduction result;
-- * @aux@: its statement annotations;
-- * @w@: the length of @as@;
-- * @minmax@: the binary operator;
-- * @ne@: its neutral element.
--
-- The forward reduction is replaced by a reduction over @(as, iota w)@ that
-- binds both @x@ and the index @x_ind@ of the selected element.  When two
-- values compare equal, the smaller index wins.  The neutral index is @-1@.
--
-- After running the continuation @m@, the adjoint of @x@ is added at position
-- @x_ind@ of the adjoint of @as@.  'updateAdjIndex' is passed the flag
-- @0 < w@ via 'CheckBounds'.
diffMinMaxReduce ::
  VjpOps -> VName -> StmAux () -> SubExp -> BinOp -> SubExp -> VName -> ADM () -> ADM ()
diffMinMaxReduce _ops x aux w minmax ne as m = do
  let t :: PrimType
      t = binOpType minmax

  -- Reduction operator over (value, index) pairs.
  acc_v_p <- newParam "acc_v" $ Prim t
  acc_i_p <- newParam "acc_i" $ Prim int64
  v_p <- newParam "v" $ Prim t
  i_p <- newParam "i" $ Prim int64
  red_lam <-
    mkLambda [acc_v_p, acc_i_p, v_p, i_p] $
      fmap varsRes . letTupExp "idx_res"
        =<< eIf
          (eCmpOp (CmpEq t) (eParam acc_v_p) (eParam v_p))
          -- Equal values: keep the smaller index so the result is
          -- independent of evaluation order.
          ( eBody
              [ eParam acc_v_p,
                eBinOp (SMin Int64) (eParam acc_i_p) (eParam i_p)
              ]
          )
          -- Otherwise keep whichever pair the min/max operator selects.
          ( eBody
              [ eIf
                  ( eCmpOp
                      (CmpEq t)
                      (eParam acc_v_p)
                      (eBinOp minmax (eParam acc_v_p) (eParam v_p))
                  )
                  (eBody [eParam acc_v_p, eParam acc_i_p])
                  (eBody [eParam v_p, eParam i_p])
              ]
          )

  red_iota <-
    letExp "red_iota" $
      BasicOp $
        Iota w (intConst Int64 0) (intConst Int64 1) Int64
  form <- reduceSOAC [Reduce Commutative red_lam [ne, intConst Int64 (-1)]]
  x_ind <- newVName (baseString x <> "_ind")
  auxing aux $ letBindNames [x, x_ind] $ Op $ Screma w [as, red_iota] form

  m

  -- The whole adjoint flows to the single selected element.
  x_adj <- lookupAdjVal x
  in_bounds <-
    letSubExp "minmax_in_bounds" . BasicOp $
      CmpOp (CmpSlt Int64) (intConst Int64 0) w
  updateAdjIndex as (CheckBounds (Just in_bounds), Var x_ind) (Var x_adj)
-- | Reverse-mode differentiation of a reduction whose elements are arrays.
--
-- The reduction over the outer dimension of @as@ (length @w@) is rewritten:
--
-- 1. @as@ is rearranged to swap its two outermost dimensions.
-- 2. A map is taken over the rearranged array paired with @ne@.  Each row is
--    reduced with @lam@ (commutativity @iscomm@), using the matching row of
--    @ne@ as the neutral element.
--
-- The rewritten statement binds the pattern @x@ with annotations @aux@.  It
-- is collected rather than added directly, and the collected statements are
-- then differentiated with 'vjpStm' around the continuation @m@.
diffVecReduce ::
  VjpOps -> Pat Type -> StmAux () -> SubExp -> Commutativity -> Lambda SOACS -> VName -> VName -> ADM () -> ADM ()
diffVecReduce ops x aux w iscomm lam ne as m = do
  stms <- collectStms_ $ do
    rank <- arrayRank <$> lookupType as
    -- Swap the two outermost dimensions and keep the rest in place.
    -- NOTE(review): only a valid permutation when rank >= 2; presumably
    -- guaranteed by the caller — confirm.
    let rear = [1, 0] ++ drop 2 [0 .. rank - 1]
    tran_as <- letExp "tran_as" $ BasicOp $ Rearrange as rear
    ts <- lookupType tran_as
    t_ne <- lookupType ne

    as_param <- newParam "as_param" $ rowType ts
    ne_param <- newParam "ne_param" $ rowType t_ne

    reduce_form <- reduceSOAC [Reduce iscomm lam [Var $ paramName ne_param]]

    map_lam <-
      mkLambda [as_param, ne_param] $
        fmap varsRes . letTupExp "idx_res" $
          Op $
            Screma w [paramName as_param] reduce_form
    addStm $ Let x aux $ Op $ Screma (arraySize 0 ts) [tran_as, ne] $ mapSOAC map_lam

  foldr (vjpStm ops) m stms
-- | Reverse-mode differentiation of a multiplicative reduction.
--
-- The arguments are:
--
-- * @x@: the reduction result;
-- * @aux@: the statement annotations;
-- * @w@: the length of @as@;
-- * @mul@: the multiplication operator;
-- * @ne@: its neutral element.
--
-- So that the adjoint never divides by zero, the forward pass computes two
-- quantities:
--
-- * @non_zero_prod@: the product of @as@ with every zero replaced by one;
-- * @zero_count@: the number of zero elements.
--
-- @x@ is @non_zero_prod@ when there are no zeros, and zero otherwise.
-- After running the continuation @m@, the contribution to the adjoint of an
-- element @a@ of @as@ is:
--
-- * @x_adj * (non_zero_prod / a)@ when there are no zeros;
-- * @x_adj * non_zero_prod@ when there is exactly one zero and it is @a@;
-- * zero in all remaining cases.
diffMulReduce ::
  VjpOps -> VName -> StmAux () -> SubExp -> BinOp -> SubExp -> VName -> ADM () -> ADM ()
diffMulReduce _ops x aux w mul ne as m = do
  let t :: PrimType
      t = binOpType mul
  let const_zero :: ADM (Exp (Rep ADM))
      const_zero = eSubExp $ Constant $ blankPrimValue t

  -- Map each element to (value-or-one, is-zero-as-int).
  a_param <- newParam "a" $ Prim t
  map_lam <-
    mkLambda [a_param] $
      fmap varsRes . letTupExp "map_res"
        =<< eIf
          (eCmpOp (CmpEq t) (eParam a_param) const_zero)
          (eBody $ fmap eSubExp [Constant $ onePrimValue t, intConst Int64 1])
          (eBody [eParam a_param, eSubExp $ intConst Int64 0])

  ps <- newVName "ps"
  zs <- newVName "zs"
  auxing aux $
    letBindNames [ps, zs] $
      Op $
        Screma w [as] $
          mapSOAC map_lam

  -- Product of the non-zero elements and number of zeros.
  red_lam_mul <- binOpLambda mul t
  red_lam_add <- binOpLambda (Add Int64 OverflowUndef) int64

  red_form_mul <- reduceSOAC $ pure $ Reduce Commutative red_lam_mul $ pure ne
  red_form_add <- reduceSOAC $ pure $ Reduce Commutative red_lam_add $ pure $ intConst Int64 0

  nz_prods <- newVName "non_zero_prod"
  zr_count <- newVName "zero_count"
  auxing aux $ letBindNames [nz_prods] $ Op $ Screma w [ps] red_form_mul
  auxing aux $ letBindNames [zr_count] $ Op $ Screma w [zs] red_form_add

  -- The original result: zero as soon as any element is zero.
  auxing aux $
    letBindNames [x]
      =<< eIf
        (toExp $ 0 .==. le64 zr_count)
        (eBody $ pure $ eSubExp $ Var nz_prods)
        (eBody $ pure const_zero)

  m

  x_adj <- lookupAdjVal x

  a_param_rev <- newParam "a" $ Prim t
  map_lam_rev <-
    mkLambda [a_param_rev] $
      fmap varsRes . letTupExp "adj_res"
        =<< eIf
          (toExp $ 0 .==. le64 zr_count)
          -- No zeros at all, so a /= 0 and the division is safe.
          ( eBody $
              pure $
                eBinOp mul (eSubExp $ Var x_adj) $
                  eBinOp (getDiv t) (eSubExp $ Var nz_prods) $
                    eParam a_param_rev
          )
          ( eBody $
              pure $
                eIf
                  (toExp $ 1 .==. le64 zr_count)
                  -- Exactly one zero: only that element gets a non-zero
                  -- adjoint, the product of all the others.
                  ( eBody $
                      pure $
                        eIf
                          (eCmpOp (CmpEq t) (eParam a_param_rev) const_zero)
                          ( eBody $
                              pure $
                                eBinOp mul (eSubExp $ Var x_adj) $
                                  eSubExp $
                                    Var nz_prods
                          )
                          (eBody $ pure const_zero)
                  )
                  -- Two or more zeros: every partial derivative is zero.
                  (eBody $ pure const_zero)
          )

  as_adjup <- letExp "adjs" $ Op $ Screma w [as] $ mapSOAC map_lam_rev

  updateAdj as as_adjup
  where
    -- Division for the element type: 'SDiv' for integers, 'FDiv' for floats.
    -- 'Unsafe' is used because the only call site divides by an element
    -- known to be non-zero.
    -- NOTE(review): 'SDiv' is used for every integer type.  If this path can
    -- be reached for unsigned data, or the product wrapped around, the
    -- quotient need not equal the product of the remaining elements — confirm.
    getDiv :: PrimType -> BinOp
    getDiv (IntType t) = SDiv t Unsafe
    getDiv (FloatType t) = FDiv t
    getDiv _ = error "In getDiv, Reduce.hs: input not supported"