{-# LANGUAGE Strict #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
--
-- Perform general rule-based simplification based on data dependency
-- information.  This module will:
--
--    * Perform common-subexpression elimination (CSE).
--
--    * Hoist expressions out of loops (including lambdas) and
--    branches.  This is done as aggressively as possible.
--
--    * Apply simplification rules (see
--    "Futhark.Optimise.Simplify.Rules").
--
-- If you just want to run the simplifier as simply as possible, you
-- may prefer to use the "Futhark.Optimise.Simplify" module.
module Futhark.Optimise.Simplify.Engine
  ( -- * Monadic interface
    SimpleM,
    runSimpleM,
    SimpleOps (..),
    SimplifyOp,
    bindableSimpleOps,
    Env (envHoistBlockers, envRules),
    emptyEnv,
    HoistBlockers (..),
    neverBlocks,
    noExtraHoistBlockers,
    neverHoist,
    BlockPred,
    orIf,
    hasFree,
    isConsumed,
    isConsuming,
    isFalse,
    isOp,
    isNotSafe,
    isDeviceMigrated,
    asksEngineEnv,
    askVtable,
    localVtable,

    -- * Building blocks
    SimplifiableRep,
    Simplifiable (..),
    simplifyFun,
    simplifyStms,
    simplifyStmsWithUsage,
    simplifyLambda,
    simplifyLambdaNoHoisting,
    bindLParams,
    simplifyBody,
    ST.SymbolTable,
    hoistStms,
    blockIf,
    blockMigrated,
    enterLoop,
    constructBody,
    module Futhark.Optimise.Simplify.Rep,
  )
where

import Control.Monad
import Control.Monad.Reader
import Control.Monad.State.Strict
import Data.Bitraversable
import Data.Either
import Data.List (find, foldl', inits, mapAccumL)
import Data.Map qualified as M
import Data.Maybe
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.Optimise.Simplify.Rep
import Futhark.Optimise.Simplify.Rule
import Futhark.Util (nubOrd)

-- | Predicates controlling which statements the simplifier may hoist
-- out of loops and branches.  A 'BlockPred' that yields 'True' for a
-- statement blocks it from being hoisted in that context.
data HoistBlockers rep = HoistBlockers
  { -- | Blocker for hoisting out of parallel loops.
    blockHoistPar :: BlockPred (Wise rep),
    -- | Blocker for hoisting out of sequential loops.
    blockHoistSeq :: BlockPred (Wise rep),
    -- | Blocker for hoisting out of branches.
    blockHoistBranch :: BlockPred (Wise rep),
    -- | Predicate classifying a statement as an allocation (going by
    -- its name).  Both 'noExtraHoistBlockers' and 'neverHoist' set it
    -- to @const False@.
    isAllocation :: Stm (Wise rep) -> Bool
  }

-- | Hoisting blockers that add no restrictions of their own.  Every
-- hoisting context uses 'neverBlocks', and no statement is classified
-- as an allocation.
noExtraHoistBlockers :: HoistBlockers rep
noExtraHoistBlockers =
  HoistBlockers
    { blockHoistPar = neverBlocks,
      blockHoistSeq = neverBlocks,
      blockHoistBranch = neverBlocks,
      isAllocation = const False
    }

-- | Hoisting blockers that use 'alwaysBlocks' for parallel loops,
-- sequential loops and branches alike, and classify no statement as an
-- allocation.
neverHoist :: HoistBlockers rep
neverHoist =
  HoistBlockers
    { blockHoistPar = alwaysBlocks,
      blockHoistSeq = alwaysBlocks,
      blockHoistBranch = alwaysBlocks,
      isAllocation = const False
    }

-- | The environment consulted by the simplifier.  It is the reader
-- component of 'SimpleM'.
data Env rep = Env
  { -- | The simplification rules available to the engine.
    envRules :: RuleBook (Wise rep),
    -- | Restrictions on what may be hoisted, and out of what.
    envHoistBlockers :: HoistBlockers rep,
    -- | Symbol table for the variables currently in scope.
    envVtable :: ST.SymbolTable (Wise rep)
  }

-- | Build an environment from a rule book and a set of hoisting
-- blockers.  The symbol table starts out empty ('mempty').
emptyEnv :: RuleBook (Wise rep) -> HoistBlockers rep -> Env rep
emptyEnv rulebook hoistBlockers = Env rulebook hoistBlockers mempty

-- | A function that protects a hoisted operation (if possible).
--
-- The first operand is the condition of the 'Case' we have hoisted out
-- of.  Equivalently, it is a boolean indicating whether a loop has a
-- nonzero trip count.
--
-- The result is 'Nothing' when the operation cannot be protected.
-- Otherwise it is an action to run in place of directly binding the
-- pattern to the operation.
type Protect m =
  SubExp ->
  Pat (LetDec (Rep m)) ->
  Op (Rep m) ->
  Maybe (m ())

-- | Simplifier for a representation-specific operation.  It returns
-- the simplified operation together with a collection of statements
-- produced alongside it.
type SimplifyOp rep op =
  op -> SimpleM rep (op, Stms (Wise rep))

-- | The representation-specific operations that the simplifier engine
-- needs in order to construct and inspect code.
data SimpleOps rep = SimpleOps
  { -- | Compute the expression decoration for a pattern bound to an
    -- expression, given the current symbol table.
    mkExpDecS ::
      ST.SymbolTable (Wise rep) ->
      Pat (LetDec (Wise rep)) ->
      Exp (Wise rep) ->
      SimpleM rep (ExpDec (Wise rep)),
    -- | Construct a body from statements and a result, given the
    -- current symbol table.
    mkBodyS ::
      ST.SymbolTable (Wise rep) ->
      Stms (Wise rep) ->
      Result ->
      SimpleM rep (Body (Wise rep)),
    -- | Make a hoisted Op safe.  The SubExp is a boolean
    -- that is true when the value of the statement will
    -- actually be used.
    protectHoistedOpS :: Protect (Builder (Wise rep)),
    -- | Compute usage information for an operation.
    opUsageS :: Op (Wise rep) -> UT.UsageTable,
    -- | Simplify an operation.
    simplifyOpS :: SimplifyOp rep (Op (Wise rep))
  }

-- | Default 'SimpleOps' for a 'Buildable' representation.
--
-- * Decorations and bodies are built with 'mkExpDec' and 'mkBody'; the
--   symbol table is ignored.
-- * Hoisted operations are never protected.
-- * Operations report empty usage.
-- * The supplied function simplifies operations.
bindableSimpleOps ::
  (SimplifiableRep rep, Buildable rep) =>
  SimplifyOp rep (Op (Wise rep)) ->
  SimpleOps rep
bindableSimpleOps simplifyOp =
  SimpleOps
    { mkExpDecS = \_vtable pat e -> pure (mkExpDec pat e),
      mkBodyS = \_vtable stms res -> pure (mkBody stms res),
      protectHoistedOpS = \_cond _pat _op -> Nothing,
      opUsageS = const mempty,
      simplifyOpS = simplifyOp
    }

-- | The simplifier monad.
--
-- It reads the representation-specific operations and the engine
-- environment.  It threads a triple of state:
--
-- * a name source;
-- * a flag that 'changed' sets to 'True';
-- * the certificates recorded by 'usedCerts'.
newtype SimpleM rep a
  = SimpleM
      ( ReaderT
          (SimpleOps rep, Env rep)
          (State (VNameSource, Bool, Certs))
          a
      )
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader (SimpleOps rep, Env rep),
      MonadState (VNameSource, Bool, Certs)
    )

-- | The name source is the first component of the simplifier state.
instance MonadFreshNames (SimpleM rep) where
  getNameSource = do
    (namesrc, _, _) <- get
    pure namesrc

  putNameSource namesrc = do
    (_, didChange, certs) <- get
    put (namesrc, didChange, certs)

-- | Scope queries are answered from the symbol table in the engine
-- environment.  Looking up a variable that is absent from the symbol
-- table aborts with 'error'.
instance (SimplifiableRep rep) => HasScope (Wise rep) (SimpleM rep) where
  askScope = fmap ST.toScope askVtable

  lookupType name = do
    vtable <- askVtable
    maybe
      ( error $
          "SimpleM.lookupType: cannot find variable "
            ++ prettyString name
            ++ " in symbol table."
      )
      pure
      (ST.lookupType name vtable)

-- | Running an action in an extended scope combines the current symbol
-- table with one built from the given scope, using '<>'.
instance
  (SimplifiableRep rep) =>
  LocalScope (Wise rep) (SimpleM rep)
  where
  localScope scope =
    localVtable $ \vtable -> vtable <> ST.fromScope scope

-- | Run a simplifier action with the given operations, environment and
-- name source.
--
-- The "changed" flag starts out 'False' and the certificate set starts
-- out empty.  The result pairs the action's value with the final
-- "changed" flag, alongside the updated name source.  Any certificates
-- still recorded at the end are discarded.
runSimpleM ::
  SimpleM rep a ->
  SimpleOps rep ->
  Env rep ->
  VNameSource ->
  ((a, Bool), VNameSource)
runSimpleM (SimpleM action) ops env namesrc =
  case runState (runReaderT action (ops, env)) (namesrc, False, mempty) of
    (result, (namesrc', didChange, _)) -> ((result, didChange), namesrc')

-- | Retrieve the engine environment, i.e. the second component of the
-- reader value.
askEngineEnv :: SimpleM rep (Env rep)
askEngineEnv = snd <$> ask

-- | Project a value out of the engine environment.
asksEngineEnv :: (Env rep -> a) -> SimpleM rep a
asksEngineEnv f = asks (f . snd)

-- | Retrieve the symbol table from the engine environment.
askVtable :: SimpleM rep (ST.SymbolTable (Wise rep))
askVtable = envVtable <$> askEngineEnv

-- | Run an action with the symbol table transformed by the given
-- function.  The 'SimpleOps' and the other environment fields are left
-- untouched.
localVtable ::
  (ST.SymbolTable (Wise rep) -> ST.SymbolTable (Wise rep)) ->
  SimpleM rep a ->
  SimpleM rep a
localVtable f =
  -- 'fmap' on the (ops, env) pair updates only the environment.
  local $ fmap $ \env -> env {envVtable = f (envVtable env)}

-- | Run an action, then take every certificate recorded in the state
-- and reset the recorded set to empty.
--
-- NOTE(review): the state is not cleared before the action runs.  Any
-- certificates recorded earlier and not yet collected are therefore
-- returned here as well.  Confirm callers expect this.
collectCerts :: SimpleM rep a -> SimpleM rep (a, Certs)
collectCerts m = do
  result <- m
  used <- state takeCerts
  pure (result, used)
  where
    takeCerts (namesrc, didChange, certs) = (certs, (namesrc, didChange, mempty))

-- | Mark that we have changed something and it would be a good idea
-- to re-run the simplifier.  This sets the flag that 'runSimpleM'
-- returns.
changed :: SimpleM rep ()
changed = modify markChanged
  where
    markChanged (namesrc, _, certs) = (namesrc, True, certs)

-- | Record that the given certificates have been used.  They are
-- combined with the already-recorded ones as @new <> old@.
usedCerts :: Certs -> SimpleM rep ()
usedCerts new = modify addCerts
  where
    addCerts (namesrc, didChange, old) = (namesrc, didChange, new <> old)

-- | Indicate in the symbol table that we have descended into a loop.
-- The action runs with its symbol table passed through 'ST.deepen'.
enterLoop :: SimpleM rep a -> SimpleM rep a
enterLoop = localVtable ST.deepen

-- | Run an action with the given function parameters inserted into the
-- symbol table (via 'ST.insertFParams').
bindFParams :: (SimplifiableRep rep) => [FParam (Wise rep)] -> SimpleM rep a -> SimpleM rep a
bindFParams params = localVtable (ST.insertFParams params)

-- | Run an action with the given lambda parameters inserted into the
-- symbol table, one at a time, with 'ST.insertLParam'.  The
-- insertions are folded from the right.
bindLParams :: (SimplifiableRep rep) => [LParam (Wise rep)] -> SimpleM rep a -> SimpleM rep a
bindLParams params = localVtable insertAll
  where
    insertAll vtable = foldr ST.insertLParam vtable params

-- | Run an action with loop merge parameters inserted into the symbol
-- table (via 'ST.insertLoopMerge').  Each triple holds a parameter,
-- a 'SubExp' and a 'SubExpRes'; presumably these are the initial value
-- and the loop-body result.  Confirm against 'ST.insertLoopMerge'.
bindMerge ::
  (SimplifiableRep rep) =>
  [(FParam (Wise rep), SubExp, SubExpRes)] ->
  SimpleM rep a ->
  SimpleM rep a
bindMerge = localVtable . ST.insertLoopMerge

-- | Run an action with a loop variable of the given integer type and
-- bound inserted into the symbol table (via 'ST.insertLoopVar').
bindLoopVar :: (SimplifiableRep rep) => VName -> IntType -> SubExp -> SimpleM rep a -> SimpleM rep a
bindLoopVar var it bound = localVtable (ST.insertLoopVar var it bound)

-- | Rewrite a division- or modulo-like binary operation into its
-- 'Safe' variant.  The operation and its operands are unchanged
-- otherwise.  Presumably the 'Safe' variant does not fail on a zero
-- divisor; confirm against 'Safety'.  Every other expression yields
-- 'Nothing'.
makeSafe :: Exp rep -> Maybe (Exp rep)
makeSafe (BasicOp (BinOp op x y)) =
  (\op' -> BasicOp (BinOp op' x y)) <$> safeVariant op
  where
    safeVariant (SDiv t _) = Just (SDiv t Safe)
    safeVariant (SDivUp t _) = Just (SDivUp t Safe)
    safeVariant (SQuot t _) = Just (SQuot t Safe)
    safeVariant (UDiv t _) = Just (UDiv t Safe)
    safeVariant (UDivUp t _) = Just (UDivUp t Safe)
    safeVariant (SMod t _) = Just (SMod t Safe)
    safeVariant (SRem t _) = Just (SRem t Safe)
    safeVariant (UMod t _) = Just (UMod t Safe)
    safeVariant _ = Nothing
makeSafe _ = Nothing

-- | Produce a placeholder expression of the given type.
--
-- * Primitive types get the 'blankPrimValue' constant.
-- * Array types get a 'Scratch' array.  Each dimension that is a
--   variable named in the first argument is replaced by an 'Int64'
--   zero.
-- * Memory and accumulator types abort with 'error'.
emptyOfType :: (MonadBuilder m) => [VName] -> Type -> m (Exp (Rep m))
emptyOfType ctx_names t =
  case t of
    Mem {} ->
      error "emptyOfType: Cannot hoist non-existential memory."
    Acc {} ->
      error "emptyOfType: Cannot hoist accumulator."
    Prim pt ->
      pure . BasicOp . SubExp . Constant $ blankPrimValue pt
    Array et shape _ ->
      let dims = map zeroIfContext (shapeDims shape)
       in pure . BasicOp $ Scratch et dims
  where
    zeroIfContext :: SubExp -> SubExp
    zeroIfContext se@(Var v)
      | v `elem` ctx_names = intConst Int64 0
      | otherwise = se
    zeroIfContext se = se

-- | Emit a statement such that the parts of it that may be unsafe only
-- take effect when @taken@ is true.
--
-- * A two-way 'Match' (a single @True@ case plus a 'MatchFallback'
--   default) has its condition conjoined with @taken@, regardless of
--   the predicate.
--
-- * An 'Assert' is weakened to @not taken || cond@.
--
-- * An 'Op' is offered to the 'Protect' callback; if the callback
--   returns an action, that action is run instead.  If it declines, the
--   statement is treated like any other expression below.
--
-- * Any other expression satisfying the predicate is replaced by its
--   'makeSafe' variant when one exists, and is otherwise wrapped in a
--   'Match' on @taken@ whose default branch yields empty values of the
--   pattern's types.
--
-- All remaining statements are emitted unchanged.
protectIf ::
  (MonadBuilder m) =>
  Protect m ->
  (Exp (Rep m) -> Bool) ->
  SubExp ->
  Stm (Rep m) ->
  m ()
protectIf _ _ taken (Let pat aux (Match [cond] [Case [Just (BoolValue True)] taken_body] untaken_body (MatchDec if_ts MatchFallback))) = do
  guarded_cond <- letSubExp "protect_cond_conj" (BasicOp (BinOp LogAnd taken cond))
  auxing aux $
    letBind pat $
      Match [guarded_cond] [Case [Just (BoolValue True)] taken_body] untaken_body (MatchDec if_ts MatchFallback)
protectIf _ _ taken (Let pat aux (BasicOp (Assert cond msg))) = do
  not_taken <- letSubExp "loop_not_taken" (BasicOp (UnOp (Neg Bool) taken))
  weakened <- letSubExp "protect_assert_disj" (BasicOp (BinOp LogOr not_taken cond))
  auxing aux $ letBind pat (BasicOp (Assert weakened msg))
protectIf protect _ taken (Let pat aux (Op op))
  | Just action <- protect taken pat op = auxing aux action
protectIf _ f taken (Let pat aux e)
  | f e =
      case makeSafe e of
        Just safe_e ->
          auxing aux $ letBind pat safe_e
        Nothing -> do
          -- No safe variant: evaluate @e@ only when @taken@ holds, and
          -- produce empty values of the right types otherwise.
          taken_body <- eBody [pure e]
          untaken_body <- eBody (map (emptyOfType (patNames pat)) (patTypes pat))
          if_ts <- expTypesFromPat pat
          auxing aux $
            letBind pat $
              Match [taken] [Case [Just (BoolValue True)] taken_body] untaken_body (MatchDec if_ts MatchFallback)
protectIf _ _ _ stm = addStm stm

-- | We are willing to hoist potentially unsafe statements out of
-- loops, but they must be protected by adding a branch on top of
-- them.
--
-- Runs the given action.  If every statement it returns is
-- 'safeExp', the statements are passed through unchanged.  Otherwise
-- each statement is re-emitted via 'protectIf' (guarding the unsafe
-- ones), conditioned on the loop executing at least once.
protectLoopHoisted ::
  (SimplifiableRep rep) =>
  [(FParam (Wise rep), SubExp)] ->
  LoopForm ->
  SimpleM rep (a, b, Stms (Wise rep)) ->
  SimpleM rep (a, b, Stms (Wise rep))
protectLoopHoisted merge form m = do
  (x, y, hoisted) <- m
  ops <- asks (protectHoistedOpS . fst)
  hoisted' <- runBuilder_ $
    if all (safeExp . stmExp) hoisted
      then addStms hoisted
      else do
        runs_at_least_once <- loopRunsAtLeastOnce
        mapM_ (protectIf ops (not . safeExp) runs_at_least_once) hoisted
  pure (x, y, hoisted')
  where
    loopRunsAtLeastOnce =
      case form of
        -- A while loop runs at least once iff its condition parameter is
        -- initially true.  If the condition is not among the loop
        -- parameters, it is assumed to run (an infinite loop).
        WhileLoop cond ->
          pure . maybe (constant True) snd $
            find ((== cond) . paramName . fst) merge
        -- A for loop runs at least once iff 0 < bound (signed).
        ForLoop _ it bound ->
          letSubExp "loop_nonempty" . BasicOp $
            CmpOp (CmpSlt it) (intConst it 0) bound

-- | Produces a true subexpression if the pattern (as in a 'Case')
-- matches the subexpression.
--
-- Each scrutinee is paired with an optional value.  'Nothing' imposes
-- no test, @True@ uses the (boolean) scrutinee itself as the test, and
-- any other value is compared for equality with the scrutinee.  The
-- result is the conjunction of all tests.
matching ::
  (BuilderOps rep) =>
  [(SubExp, Maybe PrimValue)] ->
  Builder rep SubExp
matching alts = do
  tests <- sequence (mapMaybe condFor alts)
  letSubExp "match" =<< eAll tests
  where
    condFor (_, Nothing) = Nothing
    condFor (se, Just (BoolValue True)) = Just (pure se)
    condFor (se, Just v) =
      Just . letSubExp "match_val" . BasicOp $
        CmpOp (CmpEq (primValueType v)) se (Constant v)

-- | Produces a true subexpression iff the scrutinee matches the
-- pattern @this@ and matches none of the patterns in @prior@.
matchingExactlyThis ::
  (BuilderOps rep) =>
  [SubExp] ->
  [[Maybe PrimValue]] ->
  [Maybe PrimValue] ->
  Builder rep SubExp
matchingExactlyThis ses prior this = do
  prior_matches <- mapM (matching . zip ses) prior
  -- These are only descriptions of builder actions; 'eBinOp' runs them
  -- in argument order.
  let none_prior = eUnOp (Neg Bool) (eAny prior_matches)
      this_matches = eSubExp =<< matching (zip ses this)
  letSubExp "matching_just_this" =<< eBinOp LogAnd none_prior this_matches

-- | We are willing to hoist potentially unsafe statements out of
-- matches, but they must be protected by adding a branch on top of
-- them.  (This means such hoisting is not worth it unless they are in
-- turn hoisted out of a loop somewhere.)
--
-- If every hoisted statement is 'safeExp', the statements are passed
-- through unchanged.  Otherwise every statement is re-emitted via
-- 'protectIf', guarding those that are unsafe or not 'cheapExp' by the
-- condition that the scrutinee matches this case and none of the
-- previous ones.
protectCaseHoisted ::
  (SimplifiableRep rep) =>
  -- | Scrutinee.
  [SubExp] ->
  -- | Patterns of previous cases.
  [[Maybe PrimValue]] ->
  -- | Pattern of this case.
  [Maybe PrimValue] ->
  SimpleM rep (Stms (Wise rep), a) ->
  SimpleM rep (Stms (Wise rep), a)
protectCaseHoisted ses prior vs m = do
  (hoisted, x) <- m
  ops <- asks $ protectHoistedOpS . fst
  hoisted' <- runBuilder_ $
    if all (safeExp . stmExp) hoisted
      then addStms hoisted
      else do
        this_case <- matchingExactlyThis ses prior vs
        mapM_ (protectIf ops unsafeOrCostly this_case) hoisted
  pure (hoisted', x)
  where
    -- Guard statements that are unsafe, or safe but not cheap.
    unsafeOrCostly e = not (safeExp e && cheapExp e)

-- | Statements that are not worth hoisting out of loops, because they
-- are unsafe, and added safety (by 'protectLoopHoisted') may inhibit
-- further optimisation.
--
-- Concretely, a statement is blocked when its expression is not
-- 'safeExp' and at least one of the values it binds is an array.
notWorthHoisting :: (ASTRep rep) => BlockPred rep
notWorthHoisting _ _ stm =
  not (safeExp (stmExp stm)) && any isArray (patTypes (stmPat stm))
  where
    isArray t = arrayRank t > 0

-- | Top-down simplify a statement (including copy propagation into the
-- pattern and such).  Does not recurse into any sub-Bodies or Ops.
--
-- The certificates, expression and pattern are simplified in that
-- order.  Certificates collected while simplifying the pattern are
-- appended to the statement's simplified certificates, and the
-- statement is rebuilt with fresh wisdom.
nonrecSimplifyStm ::
  (SimplifiableRep rep) =>
  Stm (Wise rep) ->
  SimpleM rep (Stm (Wise rep))
nonrecSimplifyStm (Let pat (StmAux cs attrs loc (_, dec)) e) = do
  cs' <- simplify cs
  e' <- simplifyExpBase e
  (pat', pat_cs) <- collectCerts . traverse simplify $ removePatWisdom pat
  let certs = cs' <> pat_cs
  pure . mkWiseStm pat' (StmAux certs attrs loc dec) $ e'

-- | Bottom-up simplify a statement.  Recurses into sub-Bodies and Ops.
-- Does not copy-propagate into the pattern and similar, as it is
-- assumed 'nonrecSimplifyStm' has already touched it (and worst case,
-- it'll get it on the next round of the overall fixpoint iteration.)
--
-- The expression is simplified with the given usage table extended by
-- the usage in the pattern.  Returns the statements hoisted out of the
-- expression together with the rebuilt statement, whose certificates
-- are extended with those collected during simplification.
recSimplifyStm ::
  (SimplifiableRep rep) =>
  Stm (Wise rep) ->
  UT.UsageTable ->
  SimpleM rep (Stms (Wise rep), Stm (Wise rep))
recSimplifyStm (Let pat (StmAux cs attrs loc (_, dec)) e) usage = do
  let usage' = usage <> UT.usageInPat pat
  ((e', e_hoisted), e_cs) <- collectCerts (simplifyExp usage' pat e)
  let aux' = StmAux (cs <> e_cs) attrs loc dec
  pure (e_hoisted, mkWiseStm (removePatWisdom pat) aux' e')

-- | Simplify statements and split them into those that must stay where
-- they are (blocked) and those that may be hoisted out of the
-- enclosing construct.
--
-- Statements are first simplified top-down, with each one added to the
-- symbol table for the rest.  @final@ is then run with all of them in
-- scope, and the usage table it returns seeds a bottom-up pass.
--
-- On the way back up:
--
-- * Statements none of whose results are used directly are dropped.
--
-- * Bottom-up rules are tried on each remaining statement.
--
-- * A statement that no rule rewrites is tagged blocked if @block@
--   says so, and hoistable otherwise.
--
-- A final pass ('blockUnhoistedDeps') blocks every hoistable statement
-- that depends on a blocked one.  Signals 'changed' if any rule fired
-- or anything was hoisted.
hoistStms ::
  (SimplifiableRep rep) =>
  RuleBook (Wise rep) ->
  BlockPred (Wise rep) ->
  Stms (Wise rep) ->
  SimpleM rep (a, UT.UsageTable) ->
  SimpleM rep (a, Stms (Wise rep), Stms (Wise rep))
hoistStms rules block orig_stms final = do
  (a, blocked, hoisted) <- simplifyStmsBottomUp orig_stms
  unless (null hoisted) changed
  pure (a, stmsFromList blocked, stmsFromList hoisted)
  where
    simplifyStmsBottomUp stms = do
      opUsage <- asks $ opUsageS . fst
      let usageInStm stm = UT.usageInStm stm <> extraUsage (stmExp stm)
          extraUsage (Op op) = opUsage op
          extraUsage _ = mempty
      (x, _, tagged) <- hoistableStms usageInStm stms
      -- We need to do a final pass to ensure that nothing is
      -- hoisted past something that it depends on.
      let (blocked, hoisted) = partitionEithers $ blockUnhoistedDeps tagged
      pure (x, blocked, hoisted)

    -- Bring the statements into scope one at a time, run @m@ at the
    -- bottom, and 'process' each statement on the way back up.
    descend usageInStm stms m =
      case stmsHead stms of
        Nothing -> m
        Just (stm, stms_rest) -> localVtable (ST.insertStm stm) $ do
          (x, usage, stms_rest') <- descend usageInStm stms_rest m
          process usageInStm stm stms_rest' usage x

    -- Try bottom-up rules on a statement.  If none apply, tag it as
    -- blocked ('Left') or hoistable ('Right') and extend the usage table.
    process usageInStm stm stms usage x = do
      vtable <- askVtable
      res <- bottomUpSimplifyStm rules (vtable, usage) stm
      case res of
        Just optimstms -> do
          changed
          descend usageInStm optimstms $ pure (x, usage, stms)
        Nothing
          | block vtable usage stm ->
              -- Not hoistable; its bound names are removed from the
              -- usage table passed further up.
              pure
                ( x,
                  expandUsage usageInStm vtable usage stm `UT.without` provides stm,
                  Left stm : stms
                )
          | otherwise ->
              pure
                ( x,
                  expandUsage usageInStm vtable usage stm,
                  Right stm : stms
                )

    -- Top-down pass: simplify each statement and apply top-down rules,
    -- recurse into the remaining statements, then handle this one
    -- bottom-up (or drop it if none of its results are used directly).
    hoistableStms usageInStm stms =
      case stmsHead stms of
        Nothing -> do
          (x, usage) <- final
          pure (x, usage, mempty)
        Just (stm, stms_rest) -> do
          stm' <- nonrecSimplifyStm stm
          vtable <- askVtable
          simplified <- topDownSimplifyStm rules vtable stm'
          case simplified of
            Just newstms -> do
              changed
              hoistableStms usageInStm (newstms <> stms_rest)
            Nothing -> do
              (x, usage, stms_rest') <-
                localVtable (ST.insertStm stm') $
                  hoistableStms usageInStm stms_rest
              if any (`UT.isUsedDirectly` usage) (provides stm')
                then do
                  (stm_hoisted, stm'') <- recSimplifyStm stm' usage
                  descend usageInStm stm_hoisted $
                    process usageInStm stm'' stms_rest' usage x
                else -- Dead statement.
                  pure (x, usage, stms_rest')

-- | Walk statements tagged 'Left' (blocked) or 'Right' (hoistable)
-- in order.  Every hoistable statement that mentions a name bound by an
-- earlier blocked statement is re-tagged as blocked, and a statement
-- blocked this way blocks its own dependents in turn.
blockUnhoistedDeps ::
  (ASTRep rep) =>
  [Either (Stm rep) (Stm rep)] ->
  [Either (Stm rep) (Stm rep)]
blockUnhoistedDeps = snd . mapAccumL step mempty
  where
    step blocked (Left stm) = markBlocked blocked stm
    step blocked (Right stm)
      | blocked `namesIntersect` freeIn stm = markBlocked blocked stm
      | otherwise = (blocked, Right stm)
    markBlocked blocked stm =
      (blocked <> namesFromList (provides stm), Left stm)

-- | The names bound by the pattern of a statement.
provides :: Stm rep -> [VName]
provides stm = patNames (stmPat stm)

-- | Add the usages arising from a statement to a usage table.
--
-- The statement's own usage (as given by @usageInStm@) is combined
-- with usage inherited by the names aliased by its bound names, and
-- the result is expanded through the aliases in the symbol table.  If
-- any bound name is used as a size, everything free in the statement's
-- certificates and expression is also recorded as size usage.
expandUsage ::
  (Aliased rep) =>
  (Stm rep -> UT.UsageTable) ->
  ST.SymbolTable rep ->
  UT.UsageTable ->
  Stm rep ->
  UT.UsageTable
expandUsage usageInStm vtable utable stm@(Let pat aux e) =
  (ownUsage <> sizeUsage) <> utable
  where
    ownUsage =
      UT.expand (`ST.lookupAliases` vtable) (usageInStm stm <> aliasUsage)
    sizeUsage
      | any (`UT.isSize` utable) (patNames pat) =
          UT.sizeUsages (freeIn (stmAuxCerts aux) <> freeIn e)
      | otherwise = mempty
    aliasUsage =
      mconcat . mapMaybe inheritedUsage $ zip (patNames pat) (patAliases pat)
    -- However a bound name is used (beyond merely being present) is
    -- also attributed to each name it aliases.
    inheritedUsage (name, aliases) = do
      uses <- UT.lookup name utable
      let inherited = uses `UT.withoutU` UT.presentU
      pure $ mconcat [UT.usage alias inherited | alias <- namesToList aliases]

-- | A predicate over a statement, given the current symbol table and
-- usage table.  'True' means the statement is blocked, i.e. it is kept
-- where it is rather than hoisted.
type BlockPred rep = ST.SymbolTable rep -> UT.UsageTable -> Stm rep -> Bool

-- | A 'BlockPred' that lets every statement through.
neverBlocks :: BlockPred rep
neverBlocks _ _ = const False

-- | A 'BlockPred' that blocks every statement.
alwaysBlocks :: BlockPred rep
alwaysBlocks _ _ = const True

-- | @isFalse b@ ignores the tables and the statement.  It blocks
-- everything when @b@ is 'False' and nothing when @b@ is 'True'.
isFalse :: Bool -> BlockPred rep
isFalse b _ _ = const (not b)

-- | Block a statement if either predicate blocks it.  The first
-- predicate is consulted first; the second only if the first says no.
orIf :: BlockPred rep -> BlockPred rep -> BlockPred rep
orIf p1 p2 vtable usage stm
  | p1 vtable usage stm = True
  | otherwise = p2 vtable usage stm

-- | Block a statement only if both predicates block it.  The second
-- predicate is consulted only when the first one blocks.
andAlso :: BlockPred rep -> BlockPred rep -> BlockPred rep
andAlso p1 p2 vtable usage stm
  | p1 vtable usage stm = p2 vtable usage stm
  | otherwise = False

-- | Block a statement if any name bound by its pattern is recorded as
-- consumed in the usage table.
isConsumed :: BlockPred rep
isConsumed _ utable stm =
  any (`UT.isConsumed` utable) (patNames (stmPat stm))

-- The main purpose of this rule is to avoid hoisting 'inblock' SegOps
-- out of their enclosing SegOp, *including* when those are present in
-- nested Bodies.
--
-- A statement is blocked if its expression is an 'Op'.  It is also
-- blocked if it is a 'Match' (default branch or any case) or a 'Loop'
-- whose body contains a statement that this same predicate blocks.
isOp :: BlockPred rep
isOp vtable utable stm =
  case stmExp stm of
    Op {} -> True
    Match _ cs def_body _ -> any bodyHasOp (def_body : map caseBody cs)
    Loop _ _ body -> bodyHasOp body
    _ -> False
  where
    bodyHasOp = any (isOp vtable utable) . bodyStms

-- | Assemble a 'Body' from already-simplified statements and a result.
-- It runs a builder that adds the statements and returns the result.
-- The extra statements that 'runBuilder' reports alongside the body
-- are dropped.
constructBody ::
  (SimplifiableRep rep) =>
  Stms (Wise rep) ->
  Result ->
  SimpleM rep (Body (Wise rep))
constructBody stms res =
  fst <$> runBuilder (buildBody_ (res <$ addStms stms))

-- | Run 'hoistStms' using the rule book from the engine environment,
-- with the given blocker, statements and inner action.
blockIf ::
  (SimplifiableRep rep) =>
  BlockPred (Wise rep) ->
  Stms (Wise rep) ->
  SimpleM rep (a, UT.UsageTable) ->
  SimpleM rep (a, Stms (Wise rep), Stms (Wise rep))
blockIf block stms m =
  asksEngineEnv envRules >>= \rules -> hoistStms rules block stms m

-- | Block any statement whose free variables intersect the given names.
hasFree :: (ASTRep rep) => Names -> BlockPred rep
hasFree ks _ _ = namesIntersect ks . freeIn

-- | Block any statement whose expression 'safeExp' does not consider safe.
isNotSafe :: (ASTRep rep) => BlockPred rep
isNotSafe _ _ stm = not (safeExp (stmExp stm))

-- | Block any statement whose expression consumes at least one name.
isConsuming :: (Aliased rep) => BlockPred rep
isConsuming _ _ stm = consumedInExp (stmExp stm) /= mempty

-- | Block any statement that 'cheapStm' does not consider cheap.
isNotCheap :: (ASTRep rep) => BlockPred rep
isNotCheap _ _ stm = not (cheapStm stm)

-- | A statement is cheap when its expression is (see 'cheapExp').
cheapStm :: (ASTRep rep) => Stm rep -> Bool
cheapStm stm = cheapExp (stmExp stm)

-- | Is the expression cheap to compute?
--
-- * 'Replicate', 'Concat' and 'Manifest' basic operations are not cheap.
-- * Loops are not cheap.
-- * A 'Match' is cheap if every statement in every case, and in the
--   default branch, is cheap.
-- * An 'Op' defers to 'cheapOp'.
-- * Everything else is considered cheap.  This default used to be
--   'False'; 'True' is being tried out.
cheapExp :: (ASTRep rep) => Exp rep -> Bool
cheapExp e =
  case e of
    BasicOp bop -> cheapBasicOp bop
    Loop {} -> False
    Match _ cases defbranch _ ->
      -- Cases first, then the default branch, as before.
      all cheapBody (map caseBody cases ++ [defbranch])
    Op op -> cheapOp op
    _ -> True
  where
    cheapBody = all cheapStm . bodyStms

    cheapBasicOp Replicate {} = False
    cheapBasicOp Concat {} = False
    cheapBasicOp Manifest {} = False
    -- BinOp, SubExp, UnOp, CmpOp, ConvOp, Assert, and everything else.
    cheapBasicOp _ = True

-- | Are all free variables of the statement available at the closest
-- enclosing loop, according to the symbol table?
loopInvariantStm :: (ASTRep rep) => ST.SymbolTable rep -> Stm rep -> Bool
loopInvariantStm vtable stm =
  all (`nameIn` outside) (namesToList (freeIn stm))
  where
    outside = ST.availableAtClosestLoop vtable

-- | Construct the 'BlockPred' used when simplifying the branches of a
-- 'Match' with scrutinees @cond@ and the given 'MatchDec'.
--
-- A statement is blocked if any of the following holds:
--
-- * the engine's branch blocker blocks it;
-- * it consumes something;
-- * it is unsafe, not cheap, or not hoistable, *and* hoisting it is not
--   considered desirable.
matchBlocker ::
  (SimplifiableRep rep) =>
  [SubExp] ->
  MatchDec rt ->
  SimpleM rep (BlockPred (Wise rep))
matchBlocker cond (MatchDec _ ifsort) = do
  blockers <- asksEngineEnv envHoistBlockers
  vtable <- askVtable
  let is_alloc_fun = isAllocation blockers
      branch_blocker = blockHoistBranch blockers

      -- Do all free variables of the condition exist outside the
      -- closest enclosing loop?
      cond_loop_invariant =
        all (`nameIn` ST.availableAtClosestLoop vtable) . namesToList $
          freeIn cond

      isManifest (BasicOp Manifest {}) = True
      isManifest _ = False

      -- We are unwilling to hoist things that are unsafe or costly,
      -- except if they are invariant to the most enclosing loop,
      -- because in that case they will also be hoisted past that
      -- loop.
      hoistablePastLoop stm =
        ST.loopDepth vtable > 0
          && cond_loop_invariant
          && ifsort /= MatchFallback
          && loopInvariantStm vtable stm
          -- Avoid hoisting out something that might change the
          -- asymptotics of the program.
          && ( all primType (patTypes (stmPat stm))
                 || (ifsort == MatchEquiv && isManifest (stmExp stm))
             )

      -- We also try very hard to hoist allocations or anything that
      -- contributes to memory or array size, because that will allow
      -- allocations to be hoisted.
      contributesToSize usage stm =
        ifsort /= MatchFallback
          && any (`UT.isSize` usage) (patNames (stmPat stm))
          && all primType (patTypes (stmPat stm))

      desirableToHoist usage stm =
        is_alloc_fun stm
          || hoistablePastLoop stm
          || contributesToSize usage stm

      notDesirableToHoist _ usage stm =
        not (desirableToHoist usage stm)

      -- No matter what, we always want to hoist constants as much as
      -- possible, as well as things that are free (reshapes,
      -- rearranges, and indexing that yields a slice).  Allocations
      -- are always hoistable.  Anything else is hoisted aggressively
      -- only out of versioning (MatchEquiv) branches.
      isNotHoistableBnd _ _ stm =
        case stmExp stm of
          BasicOp ArrayLit {} -> False
          BasicOp SubExp {} -> False
          BasicOp Reshape {} -> False
          BasicOp Rearrange {} -> False
          BasicOp (Index _ slice) -> null (sliceDims slice)
          _ -> not (is_alloc_fun stm) && ifsort /= MatchEquiv

      costlyOrUnhoistable =
        (isNotSafe `orIf` isNotCheap `orIf` isNotHoistableBnd)
          `andAlso` notDesirableToHoist

  pure $ branch_blocker `orIf` costlyOrUnhoistable `orIf` isConsuming

-- | Simplify a single body.
--
-- The body's statements go through 'blockIf' with the given blocker.
-- The result is simplified with 'simplifyResult' using the per-result
-- usages, and the given usage table is added to the usage the result
-- produces.  Returns the hoisted statements together with the rebuilt
-- body.
simplifyBody ::
  (SimplifiableRep rep) =>
  BlockPred (Wise rep) ->
  UT.UsageTable ->
  [UT.Usages] ->
  Body (Wise rep) ->
  SimpleM rep (Stms (Wise rep), Body (Wise rep))
simplifyBody blocker usage res_usages (Body _ stms res) = do
  (res', kept, hoisted) <-
    blockIf blocker stms $ withBodyUsage <$> simplifyResult res_usages res
  (,) hoisted <$> constructBody kept res'
  where
    withBodyUsage (simple_res, res_usage) = (simple_res, res_usage <> usage)

-- | Simplify a single body without hoisting.
--
-- Calls 'simplifyBody' with a blocker that blocks every statement and
-- returns only the simplified body, discarding the hoisted-statements
-- component.
simplifyBodyNoHoisting ::
  (SimplifiableRep rep) =>
  UT.UsageTable ->
  [UT.Usages] ->
  Body (Wise rep) ->
  SimpleM rep (Body (Wise rep))
simplifyBodyNoHoisting usage res_usages body =
  fmap snd $ simplifyBody alwaysBlocks usage res_usages body

-- | A consuming diet maps to 'UT.consumedU'; any other diet carries no usage.
usageFromDiet :: Diet -> UT.Usages
usageFromDiet d =
  case d of
    Consume -> UT.consumedU
    _ -> mempty

-- | Simplify a single 'Result'.
--
-- Besides the simplified result, returns a usage table that:
--
-- * marks every free variable of the simplified result as used and as
--   an in-result usage;
-- * for each pre-simplification result that is a variable, records the
--   matching entry of @usages@ on that variable;
-- * records the same entry, minus 'UT.presentU', on each of that
--   variable's aliases from the symbol table.
simplifyResult ::
  (SimplifiableRep rep) => [UT.Usages] -> Result -> SimpleM rep (Result, UT.UsageTable)
simplifyResult usages res = do
  res' <- traverse simplify res
  vtable <- askVtable
  let withAliases u v =
        UT.usage v u
          <> foldMap
            (`UT.usage` (u `UT.withoutU` UT.presentU))
            (namesToList (ST.lookupAliases v vtable))
      more_usages =
        mconcat [withAliases u v | (u, Var v) <- zip usages (map resSubExp res)]
      res_free = freeIn res'
  pure
    ( res',
      UT.usages res_free
        <> foldMap UT.inResultUsage (namesToList res_free)
        <> more_usages
    )

-- | Mark every variable that appears directly in the result as an
-- in-result usage.  Constant results contribute nothing.
isLoopResult :: Result -> UT.UsageTable
isLoopResult res = mconcat [UT.inResultUsage v | SubExpRes _ (Var v) <- res]

-- | Simplify a sequence of statements, treating every name they bind as
-- both used and consumed.
simplifyStms ::
  (SimplifiableRep rep) =>
  Stms (Wise rep) ->
  SimpleM rep (Stms (Wise rep))
simplifyStms stms = simplifyStmsWithUsage everything_used stms
  where
    -- XXX: treat everything as consumed, because when these are
    -- constants it is otherwise complicated to ensure we do not
    -- introduce more aliasing than specified by the return types.
    -- CSE has the same problem.
    bound :: [VName]
    bound = M.keys $ scopeOf stms

    everything_used :: UT.UsageTable
    everything_used =
      UT.usages (namesFromList bound) <> foldMap UT.consumedUsage bound

-- | Simplify a sequence of statements under the given usage table.
-- Every statement is blocked, so nothing is hoisted; only the
-- statements kept in place are returned.
simplifyStmsWithUsage ::
  (SimplifiableRep rep) =>
  UT.UsageTable ->
  Stms (Wise rep) ->
  SimpleM rep (Stms (Wise rep))
simplifyStmsWithUsage usage stms =
  kept <$> blockIf alwaysBlocks stms (pure ((), usage))
  where
    kept ((), stms', _) = stms'

-- | Simplify an 'Op' using the 'simplifyOpS' function of the
-- 'SimpleOps' held in the engine's reader environment.
simplifyOp :: Op (Wise rep) -> SimpleM rep (Op (Wise rep), Stms (Wise rep))
simplifyOp op = asks (simplifyOpS . fst) >>= \simplifier -> simplifier op

simplifyExp ::
  (SimplifiableRep rep) =>
  UT.UsageTable ->
  Pat (LetDec (Wise rep)) ->
  Exp (Wise rep) ->
  SimpleM rep (Exp (Wise rep), Stms (Wise rep))
simplifyExp :: forall rep.
SimplifiableRep rep =>
UsageTable
-> Pat (LetDec (Wise rep))
-> Exp (Wise rep)
-> SimpleM rep (Exp (Wise rep), Stms (Wise rep))
simplifyExp UsageTable
usage (Pat [PatElem (LetDec (Wise rep))]
pes) (Match [SubExp]
ses [Case (Body (Wise rep))]
cases Body (Wise rep)
defbody ifdec :: MatchDec (BranchType (Wise rep))
ifdec@(MatchDec [BranchType (Wise rep)]
ts MatchSort
ifsort)) = do
  let pes_usages :: [Usages]
pes_usages = (PatElem (VarWisdom, LetDec rep) -> Usages)
-> [PatElem (VarWisdom, LetDec rep)] -> [Usages]
forall a b. (a -> b) -> [a] -> [b]
map (Usages -> Maybe Usages -> Usages
forall a. a -> Maybe a -> a
fromMaybe Usages
forall a. Monoid a => a
mempty (Maybe Usages -> Usages)
-> (PatElem (VarWisdom, LetDec rep) -> Maybe Usages)
-> PatElem (VarWisdom, LetDec rep)
-> Usages
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> UsageTable -> Maybe Usages
`UT.lookup` UsageTable
usage) (VName -> Maybe Usages)
-> (PatElem (VarWisdom, LetDec rep) -> VName)
-> PatElem (VarWisdom, LetDec rep)
-> Maybe Usages
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatElem (VarWisdom, LetDec rep) -> VName
forall dec. PatElem dec -> VName
patElemName) [PatElem (VarWisdom, LetDec rep)]
[PatElem (LetDec (Wise rep))]
pes
  ses' <- (SubExp -> SimpleM rep SubExp) -> [SubExp] -> SimpleM rep [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM SubExp -> SimpleM rep SubExp
forall rep. SimplifiableRep rep => SubExp -> SimpleM rep SubExp
forall e rep.
(Simplifiable e, SimplifiableRep rep) =>
e -> SimpleM rep e
simplify [SubExp]
ses
  ts' <- mapM simplify ts
  let pats = (Case (Body (Wise rep)) -> [Maybe PrimValue])
-> [Case (Body (Wise rep))] -> [[Maybe PrimValue]]
forall a b. (a -> b) -> [a] -> [b]
map Case (Body (Wise rep)) -> [Maybe PrimValue]
forall body. Case body -> [Maybe PrimValue]
casePat [Case (Body (Wise rep))]
cases
  block <- matchBlocker ses ifdec
  (cases_hoisted, cases') <-
    unzip <$> zipWithM (simplifyCase block ses' pes_usages) (inits pats) cases
  (defbody_hoisted, defbody') <-
    protectCaseHoisted ses' pats [] $
      simplifyBody block usage pes_usages defbody
  pure
    ( Match ses' cases' defbody' $ MatchDec ts' ifsort,
      mconcat $ defbody_hoisted : cases_hoisted
    )
  where
    simplifyCase :: BlockPred (Wise rep)
-> [SubExp]
-> [Usages]
-> [[Maybe PrimValue]]
-> Case (Body (Wise rep))
-> SimpleM rep (Stms (Wise rep), Case (Body (Wise rep)))
simplifyCase BlockPred (Wise rep)
block [SubExp]
ses' [Usages]
pes_usages [[Maybe PrimValue]]
prior (Case [Maybe PrimValue]
vs Body (Wise rep)
body) = do
      (hoisted, body') <-
        [SubExp]
-> [[Maybe PrimValue]]
-> [Maybe PrimValue]
-> SimpleM rep (Stms (Wise rep), Body (Wise rep))
-> SimpleM rep (Stms (Wise rep), Body (Wise rep))
forall rep a.
SimplifiableRep rep =>
[SubExp]
-> [[Maybe PrimValue]]
-> [Maybe PrimValue]
-> SimpleM rep (Stms (Wise rep), a)
-> SimpleM rep (Stms (Wise rep), a)
protectCaseHoisted [SubExp]
ses' [[Maybe PrimValue]]
prior [Maybe PrimValue]
vs (SimpleM rep (Stms (Wise rep), Body (Wise rep))
 -> SimpleM rep (Stms (Wise rep), Body (Wise rep)))
-> SimpleM rep (Stms (Wise rep), Body (Wise rep))
-> SimpleM rep (Stms (Wise rep), Body (Wise rep))
forall a b. (a -> b) -> a -> b
$
          BlockPred (Wise rep)
-> UsageTable
-> [Usages]
-> Body (Wise rep)
-> SimpleM rep (Stms (Wise rep), Body (Wise rep))
forall rep.
SimplifiableRep rep =>
BlockPred (Wise rep)
-> UsageTable
-> [Usages]
-> Body (Wise rep)
-> SimpleM rep (Stms (Wise rep), Body (Wise rep))
simplifyBody BlockPred (Wise rep)
block UsageTable
usage [Usages]
pes_usages Body (Wise rep)
body
      pure (hoisted, Case vs body')
-- Loops.  The merge parameters, their initial values and the loop form
-- are simplified first.  The body is then simplified inside the loop
-- environment ('enterLoop').  Hoisting is blocked for statements that:
--
--   * mention the loop-bound names (the loop variable, if any, and the
--     merge parameters);
--   * are consumed;
--   * are stopped by the sequential hoist blocker; or
--   * are not worth hoisting.
--
-- Initial values passed to uniquely typed merge parameters are marked
-- as consumed in the symbol table while the body is simplified.
simplifyExp _ _ (Loop merge form loopbody) = do
  let (fparams, initvals) = unzip merge
  fparams' <- mapM (traverse simplify) fparams
  initvals' <- mapM simplify initvals
  let merge' = zip fparams' initvals'
  (form', boundnames, wrapbody) <-
    case form of
      WhileLoop cond -> do
        cond' <- simplify cond
        let while_form = WhileLoop cond'
        pure (while_form, fparamnames, protectLoopHoisted merge' while_form)
      ForLoop i it bound -> do
        bound' <- simplify bound
        let for_form = ForLoop i it bound'
        pure
          ( for_form,
            oneName i <> fparamnames,
            bindLoopVar i it bound' . protectLoopHoisted merge' for_form
          )
  seq_blocker <- asksEngineEnv $ blockHoistSeq . envHoistBlockers
  let body_res = bodyResult loopbody
      hoist_blocker =
        hasFree boundnames
          `orIf` isConsumed
          `orIf` seq_blocker
          `orIf` notWorthHoisting
      -- A result flowing into a uniquely typed parameter is recorded as
      -- consumed; all other results get no special usage.
      res_usages =
        [ if unique (paramDeclType p) then UT.consumedU else mempty
          | p <- fparams'
        ]
      simplifyLoopRes = do
        (res, uses) <- simplifyResult res_usages body_res
        pure (res, uses <> isLoopResult res)
  (loopres, loopstms, hoisted) <-
    enterLoop
      . consumeMerge
      . bindMerge (zipWith withRes merge' body_res)
      . wrapbody
      $ blockIf hoist_blocker (bodyStms loopbody) simplifyLoopRes
  loopbody' <- constructBody loopstms loopres
  pure (Loop merge' form' loopbody', hoisted)
  where
    -- Names of all merge parameters.
    fparamnames = namesFromList [paramName p | (p, _) <- merge]
    -- Free names of the initial values bound to unique parameters.
    consumed_by_merge =
      freeIn [arg | (p, arg) <- merge, unique (paramDeclType p)]
    -- Mark every consumed initial value as consumed in the vtable.
    consumeMerge =
      localVtable $ \vtable ->
        foldl' (flip ST.consume) vtable (namesToList consumed_by_merge)
    withRes (p, x) y = (p, x, y)
-- Operations are delegated to the representation-specific 'simplifyOp',
-- and the statements it produces are passed on as hoisted statements.
simplifyExp _ _ (Op op) =
  fmap (\(op', op_stms) -> (Op op', op_stms)) (simplifyOp op)
-- Accumulators.  Each input is simplified in this order:
--
--   1. its optional combining operator and neutral elements (the operator
--      lambda with migration-introduced reads blocked from hoisting);
--   2. its shape;
--   3. its arrays.
--
-- The body lambda is then simplified with the accumulator tokens noted in
-- the symbol table and all input arrays marked as consumed.
simplifyExp usage _ (WithAcc inputs lam) = do
  simplified_inputs <- forM inputs $ \(shape, arrs, op) -> do
    (op', op_stms) <-
      case op of
        Just (op_lam, nes) -> do
          (op_lam', op_lam_stms) <- blockMigrated $ simplifyLambda mempty op_lam
          nes' <- simplify nes
          pure (Just (op_lam', nes'), op_lam_stms)
        Nothing ->
          pure (Nothing, mempty)
    shape' <- simplify shape
    arrs' <- simplify arrs
    pure ((shape', arrs', op'), op_stms)
  let (inputs', inputs_stms) = unzip simplified_inputs
      acc_names = map paramName $ lambdaParams lam
      noteAcc = ST.noteAccTokens $ zip acc_names inputs'
  (lam', lam_stms) <-
    consumeInput inputs' $
      simplifyLambdaWith noteAcc (isFalse True) usage lam
  pure (WithAcc inputs' lam', mconcat inputs_stms <> lam_stms)
  where
    -- Run an action with every array of every input marked as consumed.
    consumeInput accs =
      localVtable $ \vtable ->
        foldl' (flip ST.consume) vtable [arr | (_, arrs, _) <- accs, arr <- arrs]
-- Every other expression is handled by the non-recursive
-- 'simplifyExpBase' and hoists nothing.
simplifyExp _ _ e = (,mempty) <$> simplifyExpBase e

-- | Run a lambda simplification with the parallel hoist blocker
-- ('blockHoistPar') extended so that statements satisfying
-- 'isDeviceMigrated' are also kept from being hoisted.  These are the
-- 'Index' statements introduced by migration.
blockMigrated ::
  (SimplifiableRep rep) =>
  SimpleM rep (Lambda (Wise rep), Stms (Wise rep)) ->
  SimpleM rep (Lambda (Wise rep), Stms (Wise rep))
blockMigrated = local $ \(ops, env) ->
  let old_blockers = envHoistBlockers env
      new_blockers =
        old_blockers
          { blockHoistPar = blockHoistPar old_blockers `orIf` isDeviceMigrated
          }
   in (ops, env {envHoistBlockers = new_blockers})

-- | Holds for a statement that is a scalar read of the only element of a
-- rank-one array.  Concretely:
--
--   * the statement is an 'Index' with exactly one fixed index;
--   * that index is the constant @0@ ('Int64');
--   * the array's type is known to the symbol table;
--   * the array has a single dimension, which is the constant @1@
--     ('Int64').
--
-- The usage table is not consulted.
isDeviceMigrated :: (SimplifiableRep rep) => BlockPred (Wise rep)
isDeviceMigrated vtable _ stm =
  case stmExp stm of
    BasicOp (Index arr slice)
      | [DimFix idx] <- unSlice slice,
        idx == intConst Int64 0 ->
          maybe False ((== [intConst Int64 1]) . arrayDims) $
            ST.lookupType arr vtable
    _ -> False

-- The simple, non-recursive case: simplification that needs no bottom-up
-- information, only the symbol table.
simplifyExpBase :: (SimplifiableRep rep) => Exp (Wise rep) -> SimpleM rep (Exp (Wise rep))
-- For a commutative binary operator, put the simplified operands in
-- sorted order (smaller one first).  Syntactically different but
-- equivalent expressions then become identical, which helps CSE.
simplifyExpBase (BasicOp (BinOp op x y))
  | commutativeBinOp op =
      sortedBinOp <$> simplify x <*> simplify y
  where
    sortedBinOp a b = BasicOp $ BinOp op (min a b) (max a b)
-- Otherwise simplify every subexpression, name, return type and branch
-- type that appears directly in the expression.
simplifyExpBase e =
  mapExpM
    ( identityMapper
        { mapOnSubExp = simplify,
          mapOnVName = simplify,
          mapOnRetType = simplify,
          mapOnBranchType = simplify
        }
    )
    e

-- | The constraints a representation must satisfy for this simplifier
-- engine to operate on it.  Each constraint is listed once; the
-- representation's decorations must all be 'Simplifiable'.
type SimplifiableRep rep =
  ( ASTRep rep,
    Simplifiable (LetDec rep),
    Simplifiable (FParamInfo rep),
    Simplifiable (LParamInfo rep),
    Simplifiable (RetType rep),
    Simplifiable (BranchType rep),
    TraverseOpStms (Wise rep),
    CanBeWise (OpC rep),
    ST.IndexOp (Op (Wise rep)),
    IsOp (OpC rep),
    ASTConstraints (OpC rep (Wise rep)),
    AliasedOp (OpC (Wise rep)),
    RephraseOp (OpC rep),
    BuilderOps (Wise rep)
  )

-- | Values that can be simplified on their own, for example by
-- replacing names with what the symbol table knows about them.
class Simplifiable a where
  -- | Simplify a value in the 'SimpleM' monad.
  simplify :: (SimplifiableRep rep) => a -> SimpleM rep a

-- | Simplify both components, first then second.
instance (Simplifiable a, Simplifiable b) => Simplifiable (a, b) where
  simplify = bitraverse simplify simplify

-- | Simplify all three components, left to right.
instance
  (Simplifiable a, Simplifiable b, Simplifiable c) =>
  Simplifiable (a, b, c)
  where
  simplify (x, y, z) = do
    x' <- simplify x
    y' <- simplify y
    z' <- simplify z
    pure (x', y', z')

-- Integers contain no names, so there is nothing to simplify.  This
-- instance is convenient for Scatter.
instance Simplifiable Int where
  simplify n = pure n

-- | Simplify the contained value, if there is one.
instance (Simplifiable a) => Simplifiable (Maybe a) where
  simplify = traverse simplify

-- | Simplify every element, front to back.
instance (Simplifiable a) => Simplifiable [a] where
  simplify = traverse simplify

-- | Constants are left alone.  A variable is replaced by whatever
-- subexpression the symbol table says it is equal to.  In that case the
-- change is recorded and the certificates of the lookup are marked as
-- used.
instance Simplifiable SubExp where
  simplify (Constant v) = pure $ Constant v
  simplify (Var name) = do
    vtable <- askVtable
    case ST.lookupSubExp name vtable of
      Nothing -> pure $ Var name
      Just (replacement, cs) -> do
        changed
        usedCerts cs
        pure replacement

-- | Simplify the certificates and the subexpression.  Any certificates
-- picked up while simplifying the subexpression are prepended to the
-- result's own (simplified) certificates.
instance Simplifiable SubExpRes where
  simplify (SubExpRes certs res_se) = do
    certs' <- simplify certs
    (res_se', picked_up) <- collectCerts (simplify res_se)
    pure $ SubExpRes (picked_up <> certs') res_se'

-- | The unit value has nothing to simplify.
instance Simplifiable () where
  simplify () = pure ()

-- | A name is replaced only when the symbol table says it equals another
-- variable; replacement by a constant is not possible here.  A
-- replacement records the change and marks the lookup's certificates as
-- used.
instance Simplifiable VName where
  simplify v = do
    vtable <- askVtable
    case ST.lookupSubExp v vtable of
      Just (Var v', cs) -> v' <$ (changed >> usedCerts cs)
      _ -> pure v

-- | Simplify each dimension of the shape.
instance (Simplifiable d) => Simplifiable (ShapeBase d) where
  simplify = mapM simplify

-- | Free sizes are simplified as subexpressions.  Existential sizes are
-- only indices and are kept as they are.
instance Simplifiable ExtSize where
  simplify size =
    case size of
      Free se -> Free <$> simplify se
      Ext i -> pure (Ext i)

-- | Only the dimensions of a 'ScalarSpace' are simplified.  All other
-- spaces are returned unchanged.
instance Simplifiable Space where
  simplify space =
    case space of
      ScalarSpace ds t -> flip ScalarSpace t <$> simplify ds
      _ -> pure space

-- | Primitive types contain no names and are returned unchanged.
instance Simplifiable PrimType where
  simplify pt = pure pt

-- | Simplify the components of a type that may mention names:
--
--   * array element types and shapes;
--   * accumulator names, index spaces and element types;
--   * memory spaces.
--
-- Uniqueness annotations and primitive types are kept as they are.
instance (Simplifiable shape) => Simplifiable (TypeBase shape u) where
  simplify t =
    case t of
      Prim pt ->
        pure $ Prim pt
      Mem space ->
        Mem <$> simplify space
      Array pt shape u -> do
        pt' <- simplify pt
        shape' <- simplify shape
        pure $ Array pt' shape' u
      Acc acc ispace ts u -> do
        acc' <- simplify acc
        ispace' <- simplify ispace
        ts' <- simplify ts
        pure $ Acc acc' ispace' ts' u

-- | Simplify every component of an index.  For a slice the start, count and
-- stride are simplified in that order, left to right.
instance (Simplifiable d) => Simplifiable (DimIndex d) where
  simplify idx =
    case idx of
      DimFix i -> fmap DimFix (simplify i)
      DimSlice i n s -> do
        i' <- simplify i
        n' <- simplify n
        s' <- simplify s
        pure $ DimSlice i' n' s'

-- | Simplify a slice by simplifying each element it contains, via the
-- slice's 'Traversable' instance, in traversal order.
instance (Simplifiable d) => Simplifiable (Slice d) where
  simplify slice = mapM simplify slice

-- | Simplify a lambda, allowing statements to be hoisted out of its body.
--
-- Hoisting is blocked by the environment's 'blockHoistPar' predicate,
-- combined (via 'orIf') with @'hasFree' extra_bound@.  Returns the simplified
-- lambda together with the statements produced by
-- 'simplifyLambdaMaybeHoist' (the hoisted statements).  The usage table
-- passed on is empty.
simplifyLambda ::
  (SimplifiableRep rep) =>
  Names ->
  Lambda (Wise rep) ->
  SimpleM rep (Lambda (Wise rep), Stms (Wise rep))
simplifyLambda extra_bound lam = do
  blockers <- asksEngineEnv envHoistBlockers
  let blocker = blockHoistPar blockers `orIf` hasFree extra_bound
  simplifyLambdaMaybeHoist blocker mempty lam

-- | Simplify a lambda and return only the simplified lambda.
--
-- The blocking predicate is @'isFalse' False@, and any statements returned
-- alongside the lambda are discarded.  NOTE(review): this relies on
-- @'isFalse' False@ blocking every statement, so that nothing is actually
-- hoisted.  Confirm against the definition of 'isFalse'.
simplifyLambdaNoHoisting ::
  (SimplifiableRep rep) =>
  Lambda (Wise rep) ->
  SimpleM rep (Lambda (Wise rep))
simplifyLambdaNoHoisting lam = do
  (lam', _) <- simplifyLambdaMaybeHoist (isFalse False) mempty lam
  pure lam'

-- | 'simplifyLambdaWith' with the symbol table left unchanged.
--
-- @blocked@ decides which statements may not be hoisted out of the body.
-- @usage@ is the usage table handed on to body simplification.
simplifyLambdaMaybeHoist ::
  (SimplifiableRep rep) =>
  BlockPred (Wise rep) ->
  UT.UsageTable ->
  Lambda (Wise rep) ->
  SimpleM rep (Lambda (Wise rep), Stms (Wise rep))
simplifyLambdaMaybeHoist blocked usage lam =
  simplifyLambdaWith (\vtable -> vtable) blocked usage lam

-- | Simplify a lambda.
--
-- The steps are:
--
--   * The parameter types are simplified first.
--
--   * The body is simplified with the simplified parameters bound (via
--     'bindLParams') and the symbol table transformed by @f@.
--
--   * A statement may not leave the body if it is blocked by @blocked@, by
--     @'hasFree'@ of the names bound by the lambda, or by 'isConsumed'.
--
--   * Every result of the body is given an empty usage; @usage@ is passed
--     through unchanged.
--
-- Returns the rebuilt lambda (with simplified return types) together with
-- the statements hoisted out of its body.
simplifyLambdaWith ::
  (SimplifiableRep rep) =>
  (ST.SymbolTable (Wise rep) -> ST.SymbolTable (Wise rep)) ->
  BlockPred (Wise rep) ->
  UT.UsageTable ->
  Lambda (Wise rep) ->
  SimpleM rep (Lambda (Wise rep), Stms (Wise rep))
simplifyLambdaWith f blocked usage lam@(Lambda params rettype body) = do
  params' <- traverse (traverse simplify) params
  let paramnames = namesFromList (boundByLambda lam)
      block = blocked `orIf` hasFree paramnames `orIf` isConsumed
      res_usages = replicate (length rettype) mempty
  (hoisted, body') <-
    bindLParams params' $
      localVtable f $
        simplifyBody block usage res_usages body
  rettype' <- simplify rettype
  pure (Lambda params' rettype' body', hoisted)

-- | Simplify a set of certificates using the symbol table.
--
-- Each certificate @c@ is handled according to how it is bound:
--
--   * Bound (per 'ST.lookupSubExp') to a constant: @c@ is replaced by the
--     certificates of that binding.
--
--   * Bound to another variable @c'@: @c@ is replaced by @c'@ together with
--     the certificates of that binding.
--
--   * Anything else: @c@ is kept as-is.
--
-- Duplicates are removed from the result with 'nubOrd'.
instance Simplifiable Certs where
  simplify (Certs ocs) = Certs . nubOrd . concat <$> mapM check ocs
    where
      check idd = do
        vv <- ST.lookupSubExp idd <$> askVtable
        case vv of
          -- The constant itself is dropped; only the certificates guarding
          -- its binding are kept.
          Just (Constant _, Certs cs) -> pure cs
          -- Keep the certificates guarding the copy as well, exactly like
          -- the constant case above.  Previously they were discarded here,
          -- which lost dependencies.  Extra certificates are always safe.
          Just (Var idd', Certs cs) -> pure (idd' : cs)
          _ -> pure [idd]

-- | Simplify a function definition.
--
-- The return types, parameter types and body are all simplified.  The body
-- is simplified without hoisting.  The usage of each result is derived from
-- its declared return type and alias annotation (see @usageFromRet@).
simplifyFun ::
  (SimplifiableRep rep) =>
  FunDef (Wise rep) ->
  SimpleM rep (FunDef (Wise rep))
simplifyFun (FunDef entry attrs fname rettype params body) = do
  -- Only the type half of each return is simplified; the alias annotation
  -- is carried over untouched.
  rettype' <- forM rettype $ \(ret, als) -> do
    ret' <- simplify ret
    pure (ret', als)
  params' <- forM params (traverse simplify)
  -- NOTE(review): the body is simplified with the original (not the
  -- simplified) parameters bound; confirm that this is intentional.
  body' <-
    bindFParams params $
      simplifyBodyNoHoisting mempty (map usageFromRet rettype') body
  pure $ FunDef entry attrs fname rettype' params' body'
  where
    -- Only array-typed values can take part in aliasing here.
    aliasable :: TypeBase shape u -> Bool
    aliasable Array {} = True
    aliasable _ = False

    -- Positions of the array-typed parameters.
    aliasable_params :: [Int]
    aliasable_params =
      [i | (p, i) <- zip params [0 ..], aliasable (paramType p)]

    -- Positions of the array-typed return values.
    aliasable_rets :: [Int]
    aliasable_rets =
      [i | ((t, _), i) <- zip rettype [0 ..], aliasable (extTypeOf t)]

    -- True when some position in @wanted@ is missing from @als@.
    restricted als wanted = not (all (`elem` als) wanted)

    -- The usage of a result starts from the diet of its declared type.  It is
    -- additionally marked as consumed when its alias sets do not cover every
    -- aliasable parameter or every aliasable return value.
    usageFromRet (t, RetAls pals rals) =
      usageFromDiet (diet (declExtTypeOf t)) <> extra
      where
        extra
          | restricted pals aliasable_params || restricted rals aliasable_rets =
              UT.consumedU
          | otherwise = mempty