{-# LANGUAGE TypeFamilies #-}

module Futhark.AD.Rev.Loop (diffLoop, stripmineStms) where

import Control.Monad
import Data.Foldable (toList)
import Data.List ((\\))
import Data.Map qualified as M
import Data.Maybe
import Futhark.AD.Rev.Monad
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.Aliases (consumedInStms)
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util (nubOrd, traverseFold)

-- | A convenience function to bring the components of a for-loop into
-- scope. The continuation receives, in order, the loop's value
-- parameters with their initial values, the whole 'LoopForm', the loop
-- index, its integer type, the iteration bound, and the loop body.
-- Calls 'error' if the passed 'Exp' is not a 'Loop' with a 'ForLoop'
-- form.
bindForLoop ::
  (PrettyRep rep) =>
  Exp rep ->
  ( [(Param (FParamInfo rep), SubExp)] ->
    LoopForm ->
    VName ->
    IntType ->
    SubExp ->
    Body rep ->
    a
  ) ->
  a
bindForLoop (Loop val_pats form@(ForLoop i it bound) body) f =
  f val_pats form i it bound body
bindForLoop e _ = error $ "bindForLoop: not a for-loop:\n" <> prettyString e

-- | A convenience function to rename a for-loop and then bind the
-- renamed components. The loop is first passed through 'renameExp',
-- and the continuation receives the renamed expression itself followed
-- by the same components 'bindForLoop' provides for it. Calls 'error'
-- (via 'bindForLoop') if the expression is not a for-loop.
renameForLoop ::
  (MonadFreshNames m, Renameable rep, PrettyRep rep) =>
  Exp rep ->
  ( Exp rep ->
    [(Param (FParamInfo rep), SubExp)] ->
    LoopForm ->
    VName ->
    IntType ->
    SubExp ->
    Body rep ->
    m a
  ) ->
  m a
renameForLoop loop f = do
  loop' <- renameExp loop
  bindForLoop loop' (f loop')

-- | Is the expression a 'Loop' whose form is a 'WhileLoop'? Returns
-- 'False' for for-loops and for every non-loop expression.
isWhileLoop :: Exp rep -> Bool
isWhileLoop (Loop _ WhileLoop {} _) = True
isWhileLoop _ = False

-- | Augments a while-loop to also compute the number of iterations.
--
-- The loop gets an extra @i64@ loop parameter, prepended to the others.
-- It starts at zero and is incremented by one in every iteration. The
-- augmented loop is emitted as a new statement. The final value of the
-- counter, which is the number of iterations executed, is returned.
-- Calls 'error' if the expression is not a while-loop.
computeWhileIters :: Exp SOACS -> ADM SubExp
computeWhileIters (Loop val_pats (WhileLoop b) body) = do
  bound_v <- newVName "bound"
  let t = Prim $ IntType Int64
      bound_param = Param mempty bound_v t
  bound_init <- letSubExp "bound_init" $ zeroExp t
  body' <- localScope (scopeOfFParams [bound_param]) $
    buildBody_ $ do
      bound_plus_one <-
        let one = Constant $ IntValue $ intValue Int64 (1 :: Int)
         in letSubExp "bound+1" $ BasicOp $ BinOp (Add Int64 OverflowUndef) (Var bound_v) one
      addStms $ bodyStms body
      -- The counter is the first loop parameter, so its update must be
      -- the first body result.
      pure (pure (subExpRes bound_plus_one) <> bodyResult body)
  res <- letTupExp' "loop" $ Loop ((bound_param, bound_init) : val_pats) (WhileLoop b) body'
  -- The counter is the first loop result. The result list cannot be
  -- empty because 'bound_param' was just prepended.
  case res of
    iters : _ -> pure iters
    [] -> error "computeWhileIters: augmented loop produced no results"
computeWhileIters e = error $ "computeWhileIters: not a while-loop:\n" <> prettyString e

-- | @convertWhileLoop n loop@ converts the while-loop @loop@ into a
-- 'ForLoop' with an @i64@ index that always executes @n@ iterations.
--
-- Each iteration runs the original body only while the loop's condition
-- variable is still true. Once it is false, the remaining iterations
-- pass the loop parameters through unchanged. The result is therefore
-- equivalent to the original loop only if @n@ is an upper bound on its
-- number of iterations, and the tighter the bound the less work is
-- wasted.
--
-- The bound is passed in explicitly; this function does not read any
-- loop attribute such as @#[bound(n)]@ itself. NOTE(review): @n@ is used
-- directly as the bound of an @i64@ loop, so presumably it must be an
-- @i64@ — confirm at call sites. Calls 'error' if the expression is not
-- a while-loop.
convertWhileLoop :: SubExp -> Exp SOACS -> ADM (Exp SOACS)
convertWhileLoop bound_se (Loop val_pats (WhileLoop cond) body) =
  localScope (scopeOfFParams $ map fst val_pats) $ do
    i <- newVName "i"
    body' <-
      eBody
        [ eIf
            (pure $ BasicOp $ SubExp $ Var cond)
            (pure body)
            -- Condition already false: keep the current loop values.
            (resultBodyM $ map (Var . paramName . fst) val_pats)
        ]
    pure $ Loop val_pats (ForLoop i Int64 bound_se) body'
convertWhileLoop _ e =
  error $ "convertWhileLoop: not a while-loop:\n" <> prettyString e

-- | @nestifyLoop bound n loop@ transforms the for-loop @loop@ into a
-- depth-@n@ loop nest of @bound@-iteration loops.
--
-- Inside the innermost body, the original loop index is replaced by the
-- flattened index
-- @i_1 * bound^(n-1) + i_2 * bound^(n-2) + ... + i_n@,
-- so the nest runs the original body for indices @0 .. bound^n - 1@ in
-- order.
--
-- This transformation does not preserve the original semantics of the
-- loop: @n@ and @bound@ may be arbitrary and have no relation to the
-- number of iterations of @loop@. With @n == 1@, only the loop's bound
-- is replaced by @bound@. With @n < 1@, the loop is returned unchanged.
nestifyLoop ::
  SubExp ->
  Integer ->
  Exp SOACS ->
  ADM (Exp SOACS)
nestifyLoop bound_se = nestifyLoop'
  where
    nestifyLoop' :: Integer -> Exp SOACS -> ADM (Exp SOACS)
    nestifyLoop' n loop = bindForLoop loop nestify
      where
        nestify val_pats _form i it _bound body
          | n > 1 =
              renameForLoop loop $ \_loop' val_pats' _form' i' it' _bound' body' -> do
                let loop_params = map fst val_pats
                    loop_params' = map fst val_pats'
                    -- The inner loop starts from the outer loop's
                    -- current parameter values.
                    loop_inits' = map (Var . paramName) loop_params
                    val_pats'' = zip loop_params' loop_inits'
                outer_body <-
                  buildBody_ $ do
                    -- Each iteration of this loop covers a contiguous
                    -- block of bound^(n-1) original iterations. A fixed
                    -- per-level stride is needed so the offsets of the
                    -- nested levels add up to the flattened index.
                    stride <-
                      letSubExp "stride" . BasicOp $
                        BinOp (Pow it) bound_se (Constant $ IntValue $ intValue it (n - 1))
                    offset <-
                      letSubExp "offset" . BasicOp $
                        BinOp (Mul it OverflowUndef) stride (Var i)

                    inner_body <- insertStmsM $ do
                      i_inner <-
                        letExp "i_inner" . BasicOp $
                          BinOp (Add it OverflowUndef) offset (Var i')
                      pure $ substituteNames (M.singleton i' i_inner) body'

                    inner_loop <-
                      letTupExp "inner_loop"
                        =<< nestifyLoop'
                          (n - 1)
                          (Loop val_pats'' (ForLoop i' it' bound_se) inner_body)
                    pure $ varsRes inner_loop
                pure $ Loop val_pats (ForLoop i it bound_se) outer_body
          | n == 1 =
              pure $ Loop val_pats (ForLoop i it bound_se) body
          | otherwise = pure loop

-- | @stripmine n pat loop@ stripmines the for-loop @loop@, whose results
-- are bound to @pat@, into a depth-@n@ loop nest.
--
-- Every loop in the nest runs @floor(bound^(1/n))@ iterations; the root
-- is computed in @f64@. An additional remainder loop of
-- @bound - (floor(bound^(1/n)))^n@ iterations is inserted after the
-- stripmined nest and executes the remaining iterations, so the pair is
-- semantically equivalent to the original loop. The nest's results are
-- bound to a renamed copy of @pat@, which provides the remainder loop's
-- initial values. The remainder loop's results are bound to @pat@
-- itself.
--
-- NOTE(review): for very large bounds (roughly beyond 2^50), f64
-- rounding could make the computed root overshoot, which would make
-- @remain_iters@ negative. Confirm that bounds stay well below that.
stripmine :: Integer -> Pat Type -> Exp SOACS -> ADM (Stms SOACS)
stripmine n pat loop =
  bindForLoop loop $ \_val_pats _form _i it bound _body -> do
    let n_root = Constant $ FloatValue $ floatValue Float64 (1 / fromIntegral n :: Double)
    bound_float <- letSubExp "bound_f64" $ BasicOp $ ConvOp (UIToFP it Float64) bound
    bound' <- letSubExp "bound" $ BasicOp $ BinOp (FPow Float64) bound_float n_root
    bound_int <- letSubExp "bound_int" $ BasicOp $ ConvOp (FPToUI Float64 it) bound'
    total_iters <-
      letSubExp "total_iters" . BasicOp $
        BinOp (Pow it) bound_int (Constant $ IntValue $ intValue it n)
    remain_iters <-
      letSubExp "remain_iters" $ BasicOp $ BinOp (Sub it OverflowUndef) bound total_iters
    mined_loop <- nestifyLoop bound_int n loop
    pat' <- renamePat pat
    renameForLoop loop $ \_loop val_pats' _form' i' it' _bound' body' -> do
      -- The remainder loop resumes at iteration 'total_iters'.
      remain_body <- insertStmsM $ do
        i_remain <-
          letExp "i_remain" . BasicOp $
            BinOp (Add it OverflowUndef) total_iters (Var i')
        pure $ substituteNames (M.singleton i' i_remain) body'
      let loop_params_rem = map fst val_pats'
          loop_inits_rem = map (Var . patElemName) $ patElems pat'
          val_pats_rem = zip loop_params_rem loop_inits_rem
          remain_loop = Loop val_pats_rem (ForLoop i' it' remain_iters) remain_body
      collectStms_ $ do
        letBind pat' mined_loop
        letBind pat remain_loop

-- | Stripmines a statement. Only has an effect when the statement's
-- expression is a for-loop with a @#[stripmine(n)]@ attribute, where
-- @n@ is the nesting depth. If several such attributes are present, the
-- first one is used. Any other statement is returned unchanged as a
-- singleton.
stripmineStm :: Stm SOACS -> ADM (Stms SOACS)
stripmineStm stm@(Let pat aux loop@(Loop _ ForLoop {} _)) =
  case depths of
    n : _ -> stripmine n pat loop
    [] -> pure $ oneStm stm
  where
    stripmineDepth :: Attr -> Maybe Integer
    stripmineDepth (AttrComp "stripmine" [AttrInt n]) = Just n
    stripmineDepth _ = Nothing

    depths :: [Integer]
    depths = catMaybes $ mapAttrs stripmineDepth $ stmAuxAttrs aux
stripmineStm stm = pure $ oneStm stm

-- | Applies 'stripmineStm' to every statement and concatenates the
-- resulting statement sequences in order.
stripmineStms :: Stms SOACS -> ADM (Stms SOACS)
stripmineStms = traverseFold stripmineStm

-- | Forward pass transformation of a for-loop. This includes modifying
-- the loop to save the loop values at each iteration onto a tape, and
-- copying any consumed arrays in the loop's body and consuming said
-- copies in lieu of the originals (which will be consumed later in the
-- reverse pass).
--
-- Every loop parameter that is not marked @#[true_dep]@ gets an extra
-- loop parameter: a unique array of length @bound@ that starts out
-- blank. In iteration @i@, the parameter's value on entry to that
-- iteration is written at index @i@. The final arrays become extra loop
-- results, bound to fresh pattern elements appended to @pat@, and each
-- is registered with 'setLoopTape'. The rewritten loop is emitted with
-- the original 'StmAux'.
fwdLoop :: Pat Type -> StmAux () -> Exp SOACS -> ADM ()
fwdLoop pat aux loop =
  bindForLoop loop $ \val_pats form i _it bound body -> do
    bound64 <- asIntS Int64 bound
    let loop_params = map fst val_pats
        is_true_dep = inAttrs (AttrName "true_dep") . paramAttrs
        dont_copy_params = filter is_true_dep loop_params
        dont_copy = map paramName dont_copy_params
        loop_params_to_copy = loop_params \\ dont_copy_params

    -- Initial (blank) tape arrays, one per saved loop parameter.
    empty_saved_array <-
      forM loop_params_to_copy $ \p ->
        letSubExp (baseString (paramName p) <> "_empty_saved")
          =<< eBlank (arrayOf (paramDec p) (Shape [bound64]) NoUniqueness)

    (body', (saved_pats, saved_params)) <- buildBody $
      localScope (scopeOfFParams loop_params) $
        localScope (scopeOfLoopForm form) $ do
          -- NOTE(review): presumably maps each consumed array to its
          -- fresh copy; see 'copyConsumedArrsInBody'.
          copy_substs <- copyConsumedArrsInBody dont_copy body
          addStms $ bodyStms body
          i_i64 <- asIntS Int64 $ Var i
          (saved_updates, saved_pats_params) <- fmap unzip $
            forM loop_params_to_copy $ \p -> do
              let v = paramName p
                  t = paramDec p
              saved_param_v <- newVName $ baseString v <> "_saved"
              saved_pat_v <- newVName $ baseString v <> "_saved"
              setLoopTape v saved_pat_v
              let saved_param = Param mempty saved_param_v $ arrayOf t (Shape [bound64]) Unique
                  saved_pat = PatElem saved_pat_v $ arrayOf t (Shape [bound64]) NoUniqueness
              -- Write the parameter's iteration-entry value at index i,
              -- reading from the copy if the body consumed the original.
              saved_update <-
                localScope (scopeOfFParams [saved_param])
                  $ letInPlace
                    (baseString v <> "_saved_update")
                    saved_param_v
                    (fullSlice (fromDecl $ paramDec saved_param) [DimFix i_i64])
                  $ substituteNames copy_substs
                  $ BasicOp
                  $ SubExp
                  $ Var v
              pure (saved_update, (saved_pat, saved_param))
          pure (bodyResult body <> varsRes saved_updates, unzip saved_pats_params)

    let pat' = pat <> Pat saved_pats
        val_pats' = val_pats <> zip saved_params empty_saved_array
    addStm $ Let pat' aux $ Loop val_pats' form body'

-- | Construct a loop value-pattern for the adjoint of the given
-- variable. The parameter is named by 'adjVName' and has the type of
-- the variable's current adjoint, marked 'Unique'. The initial value is
-- that current adjoint, as returned by 'lookupAdjVal'.
valPatAdj :: VName -> ADM (Param DeclType, SubExp)
valPatAdj v = do
  v_adj <- adjVName v
  init_adj <- lookupAdjVal v
  t <- lookupType init_adj
  pure (Param mempty v_adj (toDecl t Unique), Var init_adj)

-- | Applies 'valPatAdj' to every variable in a 'LoopInfo', preserving
-- its structure and the order of each variable list.
valPatAdjs :: LoopInfo [VName] -> ADM (LoopInfo [(Param DeclType, SubExp)])
valPatAdjs = traverse (traverse valPatAdj)

-- | Builds the index substitution that makes a for-loop run backwards.
--
-- For a loop over index @i@ with upper bound @bound@, @i@ is mapped to a
-- fresh variable (named after @i@ with a @_rev@ suffix) that is bound to
-- @(bound - 1) - i@.
--
-- The returned statements compute only that reversed index. The
-- @bound - 1@ statement is not part of them; it is emitted directly into
-- the enclosing builder context.
reverseIndices :: Exp SOACS -> ADM (Substitutions, Stms SOACS)
reverseIndices loop =
  bindForLoop loop $ \_ form i it bound _ -> do
    let loop_scope = scopeOfLoopForm form
        one_it = Constant $ IntValue $ intValue it (1 :: Int)
        rev_name = baseString i <> "_rev"

    last_idx <-
      localScope loop_scope . letSubExp "bound-1" . BasicOp $
        BinOp (Sub it OverflowUndef) bound one_it

    (rev_i, rev_stms) <-
      collectStms . localScope loop_scope . letExp rev_name . BasicOp $
        BinOp (Sub it OverflowWrap) last_idx (Var i)

    pure (M.singleton i rev_i, rev_stms)

-- | Returns a substitution that replaces loop parameters in the reverse
-- loop body with the values recorded on the loop tape for iteration @i'@.
--
-- A parameter in @loop_params'@ is restored only if it is not tagged
-- @true_dep@ and 'lookupLoopTape' finds a tape for it. Restoring means
-- emitting a statement that indexes the tape at @i'@.
--
-- If the parameter has an array type and is consumed by @stms_adj@, the
-- restored value is also copied. Presumably this keeps the consumption
-- from reaching the tape array itself.
--
-- All other parameters are left out of the substitution.
restore :: Stms SOACS -> [Param DeclType] -> VName -> ADM Substitutions
restore stms_adj loop_params' i' =
  M.fromList . catMaybes <$> mapM f loop_params'
  where
    -- Parameters tagged as true dependencies are never restored.
    dont_copy :: [VName]
    dont_copy =
      map paramName $ filter (inAttrs (AttrName "true_dep") . paramAttrs) loop_params'

    -- Variables consumed by the adjoint statements. This does not depend
    -- on the parameter being restored, so it lives here and not in the
    -- per-parameter 'where' of @f@. That way the alias analysis runs at
    -- most once per call, not once per loop parameter.
    consumed :: [VName]
    consumed =
      namesToList $ consumedInStms $ fst $ Alias.analyseStms mempty stms_adj

    f :: Param DeclType -> ADM (Maybe (VName, VName))
    f p
      | v `notElem` dont_copy = do
          m_vs <- lookupLoopTape v
          case m_vs of
            Nothing -> pure Nothing
            Just vs -> do
              vs_t <- lookupType vs
              i_i64' <- asIntS Int64 $ Var i'
              v' <- letExp "restore" $ BasicOp $ Index vs $ fullSlice vs_t [DimFix i_i64']
              t <- lookupType v
              v'' <- case (t, v `elem` consumed) of
                (Array {}, True) ->
                  letExp "restore_copy" $ BasicOp $ Replicate mempty $ Var v'
                _ -> pure v'
              pure $ Just (v, v'')
      | otherwise = pure Nothing
      where
        v = paramName p

-- | Keeps the values that belong to different parts of a loop separate.
--
-- The three components are traversed, folded, and mapped in field order:
-- 'loopRes', then 'loopFree', then 'loopVals'.
data LoopInfo a = LoopInfo
  { -- | Values associated with the loop body's results.
    loopRes :: a,
    -- | Values associated with variables free in the loop.
    loopFree :: a,
    -- | Values associated with the loop's initial values.
    loopVals :: a
  }
  deriving (Show, Functor, Foldable, Traversable)

-- | Transforms a for-loop into its reverse-mode derivative.
--
-- The steps are:
--
-- 1. Rename the loop.
-- 2. Seed the adjoints of the variables returned by the renamed body
--    from the adjoints of the forward loop's pattern.
-- 3. Build an adjoint loop over the same loop form. Its body
--    * computes the reversed index ('reverseIndices'),
--    * restores values from the loop tape ('restore'), and
--    * differentiates the body statements with @diffStms@.
--
-- The adjoint loop's parameters hold the adjoints of, in order, the body
-- results, the loop's free variables, and its initial values (see
-- 'LoopInfo'). Its results are finally folded back into the adjoints
-- inside 'returnSweepCode'.
revLoop :: (Stms SOACS -> ADM ()) -> Pat Type -> Exp SOACS -> ADM ()
revLoop diffStms pat loop =
  bindForLoop loop $ \val_pats _form _i _it _bound _body ->
    renameForLoop loop $
      \loop' val_pats' form' i' _it' _bound' body' -> do
        let loop_params = map fst val_pats
            (loop_params', loop_vals') = unzip val_pats'
            getVName Constant {} = Nothing
            getVName (Var v) = Just v
            isTrueDep = inAttrs (AttrName "true_dep") . paramAttrs
            loop_vnames =
              LoopInfo
                { loopRes = mapMaybe subExpResVName $ bodyResult body',
                  loopFree =
                    namesToList (freeIn loop') \\ mapMaybe getVName loop_vals',
                  loopVals = nubOrd $ mapMaybe getVName loop_vals'
                }

        -- Map the original parameter names to their renamed counterparts
        -- for the loop tape.
        renameLoopTape $ M.fromList $ zip (map paramName loop_params) (map paramName loop_params')

        -- Seed the adjoints of the body's result variables with the
        -- adjoints of the corresponding forward-loop pattern elements.
        forM_ (zip (bodyResult body') $ patElems pat) $ \(se_res, pe) ->
          case subExpResVName se_res of
            Just v -> setAdj v =<< lookupAdj (patElemName pe)
            Nothing -> pure ()

        (i_subst, i_stms) <- reverseIndices loop'

        val_pat_adjs <- valPatAdjs loop_vnames
        let val_pat_adjs_list = concat $ toList val_pat_adjs

        (loop_adjs, stms_adj) <- collectStms $
          localScope (scopeOfLoopForm form' <> scopeOfFParams (map fst val_pat_adjs_list <> loop_params')) $ do
            addStms i_stms
            (loop_adjs, stms_adj) <- collectStms $
              subAD $ do
                -- Inside the adjoint body, each tracked variable's adjoint
                -- is the matching adjoint loop parameter.
                zipWithM_
                  (\val_pat v -> insAdj v (paramName $ fst val_pat))
                  val_pat_adjs_list
                  (concat $ toList loop_vnames)
                diffStms $ bodyStms body'

                loop_res_adjs <- mapM (lookupAdjVal . paramName) loop_params'
                loop_free_adjs <- mapM lookupAdjVal $ loopFree loop_vnames
                loop_vals_adjs <- mapM lookupAdjVal $ loopVals loop_vnames

                pure $
                  LoopInfo
                    { loopRes = loop_res_adjs,
                      loopFree = loop_free_adjs,
                      loopVals = loop_vals_adjs
                    }
            (substs, restore_stms) <-
              collectStms $ restore stms_adj loop_params' i'
            addStms $ substituteNames i_subst restore_stms
            addStms $ substituteNames i_subst $ substituteNames substs stms_adj
            pure loop_adjs

        inScopeOf stms_adj $
          localScope (scopeOfFParams $ map fst val_pat_adjs_list) $ do
            let body_adj = mkBody stms_adj $ varsRes $ concat $ toList loop_adjs
                -- In the adjoint loop, each parameter tagged @true_dep@ is
                -- replaced by the corresponding pattern element of the
                -- forward loop. The tag is checked directly on @p@, a
                -- linear pass, not by searching a filtered parameter list.
                restore_true_deps =
                  M.fromList
                    [ (paramName p, patElemName pe)
                      | (p, pe) <- zip loop_params' $ patElems pat,
                        isTrueDep p
                    ]
            adjs' <-
              letTupExp "loop_adj" $
                substituteNames restore_true_deps $
                  Loop val_pat_adjs_list form' body_adj
            -- Split the adjoint loop's results back into the 'LoopInfo'
            -- components, in the order they were concatenated.
            let (loop_res_adjs, loop_free_var_val_adjs) =
                  splitAt (length $ loopRes loop_adjs) adjs'
                (loop_free_adjs, loop_val_adjs) =
                  splitAt (length $ loopFree loop_adjs) loop_free_var_val_adjs
            returnSweepCode $ do
              zipWithM_ updateSubExpAdj loop_vals' loop_res_adjs
              zipWithM_ insAdj (loopFree loop_vnames) loop_free_adjs
              zipWithM_ updateAdj (loopVals loop_vnames) loop_val_adjs

-- | Transforms a loop into its reverse-mode derivative.
--
-- A for-loop is handled in three steps: the forward sweep ('fwdLoop'),
-- then the continuation @m@, then the reverse sweep ('revLoop').
--
-- A while-loop is first converted into a for-loop and then differentiated
-- recursively. The trip count comes from the first @bound@ attribute on
-- the statement that carries a single integer. If there is no such
-- attribute, the trip count is computed with 'computeWhileIters', and
-- the loop is renamed before conversion.
diffLoop :: (Stms SOACS -> ADM ()) -> Pat Type -> StmAux () -> Exp SOACS -> ADM () -> ADM ()
diffLoop diffStms pat aux loop m
  | not (isWhileLoop loop) = do
      fwdLoop pat aux loop
      m
      revLoop diffStms pat loop
  | otherwise = do
      for_loop <- case annotated_bounds of
        b : _ ->
          convertWhileLoop (Constant $ IntValue $ intValue Int64 b) loop
        [] -> do
          iters <- computeWhileIters loop
          renameExp loop >>= convertWhileLoop iters
      diffLoop diffStms pat aux for_loop m
  where
    boundOf (AttrComp "bound" [AttrInt b]) = Just b
    boundOf _ = Nothing
    annotated_bounds = catMaybes $ mapAttrs boundOf $ stmAuxAttrs aux