{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.AD.Rev.Hist
  ( diffMinMaxHist,
    diffMulHist,
    diffAddHist,
    diffVecHist,
    diffHist,
  )
where

import Control.Monad
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename

-- | The addition operator for a primitive type: integer addition with
-- 'OverflowUndef' for integral types, floating-point addition for float
-- types. Any other primitive type is rejected with 'error'.
getBinOpPlus :: PrimType -> BinOp
getBinOpPlus (IntType x) = Add x OverflowUndef
getBinOpPlus (FloatType f) = FAdd f
getBinOpPlus _ = error "In getBinOpPlus, Hist.hs: input not supported"

-- | The division operator for a primitive type: signed integer division
-- marked 'Unsafe' for integral types, floating-point division for float
-- types. Any other primitive type is rejected with 'error'.
getBinOpDiv :: PrimType -> BinOp
getBinOpDiv pt =
  case pt of
    IntType it -> SDiv it Unsafe
    FloatType ft -> FDiv ft
    _ -> error "In getBinOpDiv, Hist.hs: input not supported"

-- | Conjunction of bounds checks: for every pair @(q, i)@ the check is
-- @-1 < i < q@. The conjunction is right-nested; a singleton list yields
-- just its own check, and the empty list yields the constant @true@.
withinBounds :: [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds [] = TPrimExp $ ValueExp $ BoolValue True
withinBounds pairs = foldr1 (.&&.) $ map inRange pairs
  where
    -- Single-dimension check: the index is strictly below the bound and
    -- strictly above -1.
    inRange (q, i) =
      (le64 i .<. pe64 q) .&&. (pe64 (intConst Int64 (-1)) .<. le64 i)

-- | Build a chain of equality tests. The @k@-th pair of expressions is
-- compared with 'CmpEq' at type @t@, and the first equal pair selects the
-- @k@-th body. The body list must be exactly one longer than the condition
-- list, because the last body is the final fallback. Any other shape,
-- including an empty condition list, ends in 'error'.
elseIf ::
  (MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
  PrimType ->
  [(m (Exp (Rep m)), m (Exp (Rep m)))] ->
  [m (Body (Rep m))] ->
  m (Exp (Rep m))
elseIf t conds bodies =
  case (conds, bodies) of
    ([(lhs, rhs)], [onTrue, onFalse]) ->
      eIf (isEq lhs rhs) onTrue onFalse
    ((lhs, rhs) : moreConds, onTrue : moreBodies) ->
      -- The else-branch is itself the rest of the chain.
      eIf (isEq lhs rhs) onTrue $
        eBody [elseIf t moreConds moreBodies]
    _ -> error "In elseIf, Hist.hs: input not supported"
  where
    isEq = eCmpOp (CmpEq t)

-- | Bind each 'SubExpRes' to a fresh variable whose base name is @s@. Each
-- binding is made under that result's certificates. The new names are
-- returned in input order.
bindSubExpRes :: (MonadBuilder m) => String -> [SubExpRes] -> m [VName]
bindSubExpRes s = mapM bindOne
  where
    bindOne (SubExpRes cs se) = do
      v <- newVName s
      certifying cs . letBindNames [v] . BasicOp $ SubExp se
      pure v

-- | Lift a lambda so it applies elementwise over arrays of shape @dims@.
-- One 'Screma' map is wrapped around @lam@ per dimension, outermost first.
-- @pt@ gives the element type of each lambda argument; at every level the
-- arguments are arrays of those element types with the remaining shape.
-- An empty @dims@ returns @lam@ unchanged.
nestedmap :: [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [] _ lam = pure lam
nestedmap dims@(outer : rest) pt lam = do
  params <- forM pt $ \tp -> newParam "x" $ Array tp (Shape dims) NoUniqueness
  inner <- nestedmap rest pt lam
  mkLambda params $
    fmap varsRes . letTupExp "res" . Op . Screma outer (map paramName params) $
      mapSOAC inner

-- \ds hs -> map2 lam ds hs
-- | Build a lambda taking two groups of parameters, "ds" and "hs", with one
-- parameter of each type in @tps@ per group. The lambda maps a renamed copy
-- of @lam@ over all of them with outer size @n@. Returns the names of the
-- "ds" parameters, the names of the "hs" parameters, and the lambda.
mkF' :: Lambda SOACS -> [Type] -> SubExp -> ADM ([VName], [VName], Lambda SOACS)
mkF' lam tps n = do
  lam_copy <- renameLambda lam
  ds_params <- mapM (newParam "ds_param") tps
  hs_params <- mapM (newParam "hs_param") tps
  let ds_names = map paramName ds_params
      hs_names = map paramName hs_params
  lam_map <-
    mkLambda (ds_params ++ hs_params) $
      varsRes
        <$> letTupExp "map_f'" (Op $ Screma n (ds_names ++ hs_names) $ mapSOAC lam_copy)
  pure (ds_names, hs_names, lam_map)

-- \ls as rs -> map3 (\li ai ri -> li `lam` ai `lam` ri) ls as rs
-- | Build a three-argument mapped lambda from the binary operator @lam@.
-- The elementwise function computes @(l `lam` a) `lam` r@. It uses two
-- renamed copies of @lam@, and the results of the first copy are bound to
-- the accumulator parameters of the second. The map takes groups "ls",
-- "as" and "rs", one parameter per entry of @tps@ in each group, and has
-- outer size @n@. Returns the names of the "as" parameters and the lambda.
mkF :: Lambda SOACS -> [Type] -> SubExp -> ADM ([VName], Lambda SOACS)
mkF lam tps n = do
  lhs_lam <- renameLambda lam
  rhs_lam <- renameLambda lam
  let arity = length $ lambdaReturnType lam
      (l_params, a_params) = splitAt arity $ lambdaParams lhs_lam
      (link_params, r_params) = splitAt arity $ lambdaParams rhs_lam
  combined <- mkLambda (l_params ++ a_params ++ r_params) $ do
    lhs_res <- bodyBind $ lambdaBody lhs_lam
    -- Feed the left application's results into the right application.
    forM_ (zip link_params lhs_res) $ \(p, SubExpRes cs se) ->
      certifying cs . letBindNames [paramName p] . BasicOp $ SubExp se
    bodyBind $ lambdaBody rhs_lam

  ls_params <- mapM (newParam "ls_param") tps
  as_params <- mapM (newParam "as_param") tps
  rs_params <- mapM (newParam "rs_param") tps
  let map_params = ls_params ++ as_params ++ rs_params
  lam_map <-
    mkLambda map_params $
      varsRes
        <$> letTupExp "map_f" (Op $ Screma n (map paramName map_params) $ mapSOAC combined)
  pure (map paramName as_params, lam_map)

-- | Map over the index array @is@ (outer size @n@). Every index satisfying
-- @-1 < i < w@ is kept unchanged; every other index is replaced by @w@.
mapout :: VName -> SubExp -> SubExp -> ADM VName
mapout is n w = do
  idx <- newParam "is" $ Prim int64
  clamp_lam <- mkLambda [idx] $ do
    e <-
      eIf
        (toExp $ withinBounds [(w, paramName idx)])
        (eBody [eParam idx])
        (eBody [eSubExp w])
    varsRes <$> letTupExp "is'" e
  letExp "is'" . Op . Screma n [is] $ mapSOAC clamp_lam

-- | Scatter each value array in @vs@ into the corresponding destination in
-- @dst@. All writes happen in one 'Scatter' of size @n@ that uses the
-- single index array @is@ for every destination. Each destination is
-- written along one dimension.
--
-- NOTE(review): the destination shape is taken from the outer size of each
-- value array, not from the destination array. This presumably relies on
-- callers passing destinations of that same outer size; confirm at the
-- call sites.
multiScatter :: SubExp -> [VName] -> VName -> [VName] -> ADM [VName]
multiScatter n dst is vs = do
  tps <- mapM lookupType vs
  i_param <- newParam "i" $ Prim int64
  val_params <- mapM (newParam "scatter_param" . rowType) tps
  scatter_lam <-
    mkLambda (i_param : val_params) $ do
      -- One copy of the index per destination, followed by the values.
      idx_exps <- replicateM (length val_params) (eParam i_param)
      val_exps <- mapM eParam val_params
      subExpsRes <$> mapM (letSubExp "scatter_map_res") (idx_exps ++ val_exps)
  let dest_spec t d = (Shape [arraySize 0 t], 1, d)
  letTupExp "scatter_res" . Op $
    Scatter n (is : vs) (zipWith dest_spec tps dst) scatter_lam

-- | Index every array in @vs@ with the same leading slice @s@. Each slice is
-- completed to a full slice for that array with 'fullSlice'. Results are
-- bound to fresh variables and returned in input order.
multiIndex :: [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex vs s =
  forM vs $ \arr -> do
    arr_t <- lookupType arr
    letExp "sorted" . BasicOp . Index arr $ fullSlice arr_t s

--
-- special case of histogram with min/max as operator.
-- Original, assuming `is: [n]i64` and `dst: [w]btp`
--     let x = reduce_by_index dst minmax ne is vs
-- Forward sweep:
--     need to copy dst: the reverse sweep might use it
--       (see ex. in reducebyindexminmax6.fut where the first map requires the original dst to be differentiated).
--     let dst_cpy = copy dst
--     let (x, x_inds) = zip vs (iota n)
--                       |> reduce_by_index (dst_cpy,-1s) argminmax (ne,-1) is
--
-- Reverse sweep:
--     dst_bar += map2 (\i b -> if i == -1
--                              then b
--                              else 0
--                     ) x_inds x_bar

--     vs_ctrbs = map2 (\i b -> if i == -1
--                              then 0
--                              else vs_bar[i] + b
--                     ) x_inds x_bar
--     vs_bar <- scatter vs_bar x_inds vs_ctrbs
diffMinMaxHist ::
  VjpOps -> VName -> StmAux () -> SubExp -> BinOp -> SubExp -> VName -> VName -> SubExp -> SubExp -> VName -> ADM () -> ADM ()
diffMinMaxHist _ops x aux n minmax ne is vs w rf dst m = do
  let t = binOpType minmax
      minus_one = intConst Int64 (-1)
  vs_type <- lookupType vs
  let vs_elm_type = elemType vs_type
      vs_dims = arrayDims vs_type
      inner_dims = tail vs_dims
      nr_dims = length vs_dims
  dst_type <- lookupType dst
  let dst_dims = arrayDims dst_type

  ---- Forward sweep ----

  -- The histogram is computed into a copy of dst, because the reverse
  -- sweep may need the original dst.
  dst_cpy <-
    letExp (baseString dst <> "_copy") . BasicOp $ Replicate mempty (Var dst)

  -- Operator on (value, index) pairs. On equal values it keeps the smaller
  -- index; otherwise it keeps whichever pair minmax selects.
  acc_v_p <- newParam "acc_v" $ Prim t
  acc_i_p <- newParam "acc_i" $ Prim int64
  v_p <- newParam "v" $ Prim t
  i_p <- newParam "i" $ Prim int64
  hist_lam_inner <-
    mkLambda [acc_v_p, acc_i_p, v_p, i_p] $ do
      let acc_v = eParam acc_v_p
          acc_i = eParam acc_i_p
          v = eParam v_p
          i = eParam i_p
          acc_wins = eCmpOp (CmpEq t) acc_v (eBinOp minmax acc_v v)
      res <-
        eIf
          (eCmpOp (CmpEq t) acc_v v)
          (eBody [acc_v, eBinOp (SMin Int64) acc_i i])
          (eBody [eIf acc_wins (eBody [acc_v, acc_i]) (eBody [v, i])])
      varsRes <$> letTupExp "idx_res" res
  hist_lam <- nestedmap inner_dims [vs_elm_type, int64, vs_elm_type, int64] hist_lam_inner

  -- Index accumulators start at -1, meaning "no element of vs landed here".
  dst_minus_ones <-
    letExp "minus_ones" . BasicOp $ Replicate (Shape dst_dims) minus_one
  ne_minus_ones <-
    letSubExp "minus_ones" . BasicOp $ Replicate (Shape inner_dims) minus_one
  iota_n <-
    letExp "red_iota" . BasicOp $ Iota n (intConst Int64 0) (intConst Int64 1) Int64

  -- For multi-dimensional vs, replicate each outer index over the inner
  -- dimensions so the index array has the same shape as vs.
  inp_iota <-
    if nr_dims == 1
      then pure iota_n
      else do
        i <- newParam "i" $ Prim int64
        rep_lam <-
          mkLambda [i] $
            varsRes
              <$> letTupExp "res" (BasicOp $ Replicate (Shape inner_dims) $ Var $ paramName i)
        letExp "res" . Op . Screma n [iota_n] $ mapSOAC rep_lam

  let idx_ne = if nr_dims == 1 then minus_one else ne_minus_ones
      hist_op = HistOp (Shape [w]) rf [dst_cpy, dst_minus_ones] [ne, idx_ne] hist_lam
  f' <- mkIdentityLambda [Prim int64, rowType vs_type, rowType $ Array int64 (Shape vs_dims) NoUniqueness]
  x_inds <- newVName (baseString x <> "_inds")
  auxing aux . letBindNames [x, x_inds] . Op $
    Hist n [is, vs, inp_iota] [hist_op] f'

  m

  ---- Reverse sweep ----

  x_bar <- lookupAdjVal x

  -- dst_bar += map2 (\i b -> if i == -1 then b else 0) x_inds x_bar
  x_ind_dst <- newParam (baseString x <> "_ind_param") $ Prim int64
  x_bar_dst <- newParam (baseString x <> "_bar_param") $ Prim t
  dst_lam_inner <-
    mkLambda [x_ind_dst, x_bar_dst] $
      fmap varsRes . letTupExp "dst_bar"
        =<< eIf
          (toExp $ le64 (paramName x_ind_dst) .==. -1)
          (eBody [eParam x_bar_dst])
          (eBody [eSubExp $ Constant $ blankPrimValue t])
  dst_lam <- nestedmap inner_dims [int64, vs_elm_type] dst_lam_inner
  dst_bar <-
    letExp (baseString dst <> "_bar") . Op $
      Screma w [x_inds, x_bar] (mapSOAC dst_lam)
  updateAdj dst dst_bar

  vs_bar <- lookupAdjVal vs

  -- Full index tuples into vs: the winning outer index, followed by the
  -- inner-dimension indices (replicated over the w bins).
  inner_inds <- mk_indices inner_dims []
  inds' <- mapM (letExp "inds" . BasicOp . Replicate (Shape [w]) . Var) inner_inds
  let inds = x_inds : inds'

  -- Contribution per bin: 0 if no element won, else vs_bar[idx] + x_bar.
  par_x_ind_vs <- replicateM nr_dims $ newParam (baseString x <> "_ind_param") $ Prim int64
  par_x_bar_vs <- newParam (baseString x <> "_bar_param") $ Prim t
  vs_lam_inner <-
    mkLambda (par_x_bar_vs : par_x_ind_vs) $
      fmap varsRes . letTupExp "res"
        =<< eIf
          (toExp $ le64 (paramName $ head par_x_ind_vs) .==. -1)
          (eBody [eSubExp $ Constant $ blankPrimValue t])
          ( eBody
              [ do
                  vs_bar_i <-
                    letSubExp (baseString vs_bar <> "_el") . BasicOp . Index vs_bar . Slice $
                      map (DimFix . Var . paramName) par_x_ind_vs
                  eBinOp (getBinOpPlus t) (eParam par_x_bar_vs) (eSubExp vs_bar_i)
              ]
          )
  vs_lam <- nestedmap inner_dims (vs_elm_type : replicate nr_dims int64) vs_lam_inner
  vs_bar_p <-
    letExp (baseString vs <> "_partial") . Op $
      Screma w (x_bar : inds) (mapSOAC vs_lam)

  -- Flatten indices and contributions to one dimension of size q, the
  -- product of dst's dimensions, and scatter them into vs_bar. Bins with
  -- index -1 have an out-of-range first index.
  q <-
    letSubExp "q"
      =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) dst_dims

  -- ToDo: Cosmin asks whether this is the correct translation of the
  -- former `traverse (letExp "flat" . BasicOp . Reshape [DimNew q]) $ inds ++ [vs_bar_p]`.
  scatter_inps <- forM (inds ++ [vs_bar_p]) $ \v -> do
    v_t <- lookupType v
    letExp "flat" . BasicOp . Reshape v $
      reshapeAll (arrayShape v_t) (Shape [q])

  f'' <- mkIdentityLambda $ replicate nr_dims (Prim int64) ++ [Prim t]
  vs_bar' <-
    letExp (baseString vs <> "_bar") . Op $
      Scatter q scatter_inps [(Shape vs_dims, 1, vs_bar)] f''
  insAdj vs vs_bar'
  where
    -- Build index arrays of shape @dims@. The arrays in the prefix @iotas@
    -- (scalar indices from enclosing levels) come first, replicated over
    -- the shape. They are followed by one array per dimension of @dims@,
    -- whose elements hold their own index along that dimension.
    mk_indices :: [SubExp] -> [SubExp] -> ADM [VName]
    mk_indices [] _ = pure []
    mk_indices [d] iotas = do
      reps <- mapM (letExp "rep" . BasicOp . Replicate (Shape [d])) iotas
      iota_d <-
        letExp "red_iota" . BasicOp $
          Iota d (intConst Int64 0) (intConst Int64 1) Int64
      pure $ reps ++ [iota_d]
    mk_indices (d : dims) iotas = do
      iota_d <-
        letExp "red_iota" . BasicOp $
          Iota d (intConst Int64 0) (intConst Int64 1) Int64
      i_param <- newParam "i" $ Prim int64
      lam <-
        mkLambda [i_param] $
          varsRes <$> mk_indices dims (iotas ++ [Var $ paramName i_param])
      letTupExp "res" . Op . Screma d [iota_d] $ mapSOAC lam

--
-- special case of histogram with multiplication as operator.
-- Original, assuming `is: [n]i64` and `dst: [w]btp`
--     let x = reduce_by_index dst (*) ne is vs
-- Forward sweep:
--     dst does not need to be copied: dst is not overwritten
--     let (ps, zs) = map (\v -> if v == 0 then (1,1) else (v,0)) vs
--     let non_zero_prod = reduce_by_index nes (*) ne is ps
--     let zero_count = reduce_by_index 0s (+) 0 is zs
--     let h_part = map2 (\p c -> if c == 0 then p else 0
--                       ) non_zero_prod zero_count
--     let x = map2 (*) dst h_part
--
-- Reverse sweep:
--     dst_bar += map2 (*) h_part x_bar

--     let part_bar = map2 (*) dst x_bar
--     vs_bar += map2 (\i v -> let zr_cts = zero_count[i]
--                             let pr_bar = part_bar[i]
--                             let nz_prd = non_zero_prod[i]
--                             in if zr_cts == 0
--                             then pr_bar * (nz_prd / v)
--                             else if zr_cts == 1 and v == 0
--                             then nz_prd * pr_bar
--                             else 0
--                    ) is vs
diffMulHist ::
  VjpOps -> VName -> StmAux () -> SubExp -> BinOp -> SubExp -> VName -> VName -> SubExp -> SubExp -> VName -> ADM () -> ADM ()
diffMulHist _ops x aux n mul ne is vs w rf dst m = do
  let t = binOpType mul
      zero_t = Constant $ blankPrimValue t
  vs_type <- lookupType vs
  let vs_dims = arrayDims vs_type
      vs_elm_type = elemType vs_type
  dst_type <- lookupType dst
  let dst_dims = arrayDims dst_type
      inner_dims = tail vs_dims

  ---- Forward sweep ----

  -- (ps, zs) = map (\v -> if v == 0 then (1, 1) else (v, 0)) vs
  v_param <- newParam "v" $ Prim t
  lam_ps_zs_inner <-
    mkLambda [v_param] $
      fmap varsRes . letTupExp "map_res"
        =<< eIf
          (eCmpOp (CmpEq t) (eParam v_param) (eSubExp zero_t))
          (eBody [eSubExp $ Constant $ onePrimValue t, eSubExp $ intConst Int64 1])
          (eBody [eParam v_param, eSubExp $ intConst Int64 0])
  lam_ps_zs <- nestedmap vs_dims [vs_elm_type] lam_ps_zs_inner
  ps_zs <- bindSubExpRes "ps_zs" =<< eLambda lam_ps_zs [eSubExp $ Var vs]
  let [ps, zs] = ps_zs

  -- Histogram of the non-zero factors, combined with mul.
  lam_mul_inner <- binOpLambda mul t
  lam_mul <- nestedmap inner_dims [vs_elm_type, vs_elm_type] lam_mul_inner
  nz_prods0 <- letExp "nz_prd" . BasicOp $ Replicate (Shape [w]) ne
  let hist_nzp = HistOp (Shape [w]) rf [nz_prods0] [ne] lam_mul

  -- Histogram counting the zero factors per bin, combined with (+).
  lam_add_inner <- binOpLambda (Add Int64 OverflowUndef) int64
  lam_add <- nestedmap inner_dims [int64, int64] lam_add_inner
  zr_counts0 <- letExp "zr_cts" . BasicOp $ Replicate (Shape dst_dims) (intConst Int64 0)
  zrn_ne <- letSubExp "zr_ne" . BasicOp $ Replicate (Shape inner_dims) (intConst Int64 0)
  let zr_ne = if length vs_dims == 1 then intConst Int64 0 else zrn_ne
      hist_zrn = HistOp (Shape [w]) rf [zr_counts0] [zr_ne] lam_add

  f' <- mkIdentityLambda [Prim int64, Prim int64, rowType vs_type, rowType $ Array int64 (Shape vs_dims) NoUniqueness]
  nz_prods <- newVName "non_zero_prod"
  zr_counts <- newVName "zero_count"
  auxing aux . letBindNames [nz_prods, zr_counts] . Op $
    Hist n [is, is, ps, zs] [hist_nzp, hist_zrn] f'

  -- h_part = map2 (\p c -> if c == 0 then p else 0) non_zero_prod zero_count
  p_param <- newParam "prod" $ Prim t
  c_param <- newParam "count" $ Prim int64
  lam_h_part_inner <-
    mkLambda [p_param, c_param] $
      fmap varsRes . letTupExp "h_part"
        =<< eIf
          (toExp $ 0 .==. le64 (paramName c_param))
          (eBody [eParam p_param])
          (eBody [eSubExp zero_t])
  lam_h_part <- nestedmap dst_dims [vs_elm_type, int64] lam_h_part_inner
  h_part' <-
    bindSubExpRes "h_part"
      =<< eLambda lam_h_part (map (eSubExp . Var) [nz_prods, zr_counts])
  let [h_part] = h_part'

  -- x = map2 mul dst h_part
  lam_mul_inner' <- binOpLambda mul t
  lam_mul' <- nestedmap dst_dims [vs_elm_type, vs_elm_type] lam_mul_inner'
  x' <- bindSubExpRes "x" =<< eLambda lam_mul' (map (eSubExp . Var) [dst, h_part])
  auxing aux . letBindNames [x] . BasicOp . SubExp . Var $ head x'

  m

  ---- Reverse sweep ----

  x_bar <- lookupAdjVal x

  -- dst_bar += map2 mul h_part x_bar
  lam_mul'' <- renameLambda lam_mul'
  dst_bar <-
    bindSubExpRes (baseString dst <> "_bar")
      =<< eLambda lam_mul'' (map (eSubExp . Var) [h_part, x_bar])
  updateAdj dst $ head dst_bar

  -- part_bar = map2 mul dst x_bar
  lam_mul''' <- renameLambda lam_mul'
  part_bar' <-
    bindSubExpRes "part_bar"
      =<< eLambda lam_mul''' (map (eSubExp . Var) [dst, x_bar])
  let [part_bar] = part_bar'

  -- Per-element adjoint of vs:
  --   if zr_cts == 0 then pr_bar * (nz_prd / a)
  --   else if zr_cts == 1 && a == 0 then nz_prd * pr_bar
  --   else 0
  inner_params <- zipWithM newParam ["zr_cts", "pr_bar", "nz_prd", "a"] $ map Prim [int64, t, t, t]
  let [zr_cts, pr_bar, nz_prd, a_param] = inner_params
  lam_vsbar_inner <- mkLambda inner_params $ do
    let no_zeros = eCmpOp (CmpEq int64) (eSubExp $ intConst Int64 0) (eParam zr_cts)
        lone_zero_here =
          eBinOp
            LogAnd
            (eCmpOp (CmpEq int64) (eSubExp $ intConst Int64 1) (eParam zr_cts))
            (eCmpOp (CmpEq t) (eSubExp zero_t) (eParam a_param))
        quotient_case =
          eBinOp mul (eParam pr_bar) $
            eBinOp (getBinOpDiv t) (eParam nz_prd) (eParam a_param)
        lone_zero_case = eBinOp mul (eParam nz_prd) (eParam pr_bar)
    res <-
      eIf no_zeros (eBody [quotient_case]) $
        eBody [eIf lone_zero_here (eBody [lone_zero_case]) (eBody [eSubExp zero_t])]
    varsRes <$> letTupExp "vs_bar" res

  lam_vsbar_middle <- nestedmap inner_dims [int64, t, t, t] lam_vsbar_inner

  -- For each (i, a) in zip is vs: if i is a valid bin, read that bin's rows
  -- of zero_count, part_bar and non_zero_prod and apply the elementwise
  -- adjoint; otherwise the contribution is zero.
  i_param <- newParam "i" $ Prim int64
  a_param' <- newParam "a" $ rowType vs_type
  lam_vsbar <-
    mkLambda [i_param, a_param'] $
      fmap varsRes . letTupExp "vs_bar"
        =<< eIf
          (toExp $ withinBounds [(w, paramName i_param)])
          ( buildBody_ $ do
              let row_slice = fullSlice vs_type [DimFix $ Var $ paramName i_param]
              names <- mapM newVName ["zr_cts", "pr_bar", "nz_prd"]
              zipWithM_
                (\name arr -> letBindNames [name] . BasicOp $ Index arr row_slice)
                names
                [zr_counts, part_bar, nz_prods]
              eLambda lam_vsbar_middle $ map (eSubExp . Var) names ++ [eParam a_param']
          )
          (eBody [pure $ zeroExp $ rowType dst_type])

  vs_bar <-
    letExp (baseString vs <> "_bar") . Op . Screma n [is, vs] $ mapSOAC lam_vsbar

  updateAdj vs vs_bar

--
-- special case of histogram with add as operator.
-- Original, assuming `is: [n]i64` and `dst: [w]btp`
--     let x = reduce_by_index dst (+) ne is vs
-- Forward sweep:
--     need to copy dst: the reverse sweep might use it
--       (see ex. in reducebyindexminmax6.fut where the first map requires the original dst to be differentiated).
--     let dst_cpy = copy dst
--     let x = reduce_by_index dst_cpy (+) ne is vs
--
-- Reverse sweep:
--     dst_bar += x_bar
--
--     vs_bar += map (\i -> x_bar[i]) is
diffAddHist ::
  VjpOps -> VName -> StmAux () -> SubExp -> Lambda SOACS -> SubExp -> VName -> VName -> SubExp -> SubExp -> VName -> ADM () -> ADM ()
diffAddHist :: VjpOps
-> VName
-> StmAux ()
-> SubExp
-> Lambda SOACS
-> SubExp
-> VName
-> VName
-> SubExp
-> SubExp
-> VName
-> ADM ()
-> ADM ()
diffAddHist VjpOps
_ops VName
x StmAux ()
aux SubExp
n Lambda SOACS
add SubExp
ne VName
is VName
vs SubExp
w SubExp
rf VName
dst ADM ()
m = do
  let t :: Type
t = Param Type -> Type
forall dec. Param dec -> dec
paramDec (Param Type -> Type) -> Param Type -> Type
forall a b. (a -> b) -> a -> b
$ [Param Type] -> Param Type
forall a. HasCallStack => [a] -> a
head ([Param Type] -> Param Type) -> [Param Type] -> Param Type
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
add

  dst_cpy <-
    String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
dst String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_copy") (Exp SOACS -> ADM VName)
-> (BasicOp -> Exp SOACS) -> BasicOp -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> ADM VName) -> BasicOp -> ADM VName
forall a b. (a -> b) -> a -> b
$
      Shape -> SubExp -> BasicOp
Replicate Shape
forall a. Monoid a => a
mempty (VName -> SubExp
Var VName
dst)

  f <- mkIdentityLambda [Prim int64, t]
  auxing aux . letBindNames [x] . Op $
    Hist n [is, vs] [HistOp (Shape [w]) rf [dst_cpy] [ne] add] f

  m

  x_bar <- lookupAdjVal x

  updateAdj dst x_bar

  x_type <- lookupType x
  i_param <- newParam (baseString vs <> "_i") $ Prim int64
  let i = Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
i_param
  lam_vsbar <-
    mkLambda [i_param] $
      fmap varsRes . letTupExp "vs_bar"
        =<< eIf
          (toExp $ withinBounds $ pure (w, i))
          (eBody $ pure $ pure $ BasicOp $ Index x_bar $ fullSlice x_type [DimFix $ Var i])
          (eBody $ pure $ eSubExp ne)

  vs_bar <- letExp (baseString vs <> "_bar") $ Op $ Screma n [is] $ mapSOAC lam_vsbar
  updateAdj vs vs_bar

-- Special case for a vectorised combining operator. The statement
--
--   reduce_by_index dst (map2 op) nes is vss
--
-- is rewritten to
--
--   map3 (\dst_col vss_col ne ->
--           reduce_by_index dst_col op ne is vss_col
--        ) (transpose dst) (transpose vss) nes |> transpose
--
-- and the rewritten statements are then differentiated with 'vjpStm'.
diffVecHist ::
  VjpOps ->
  VName ->
  StmAux () ->
  SubExp ->
  Lambda SOACS ->
  VName ->
  VName ->
  VName ->
  SubExp ->
  SubExp ->
  VName ->
  ADM () ->
  ADM ()
diffVecHist ops x aux n op nes is vss w rf dst m = do
  rewritten <- collectStms_ $ do
    vss_rank <- arrayRank <$> lookupType vss
    -- Swap the two outermost dimensions, leaving the rest in place. This
    -- permutation is its own inverse, so it also transposes the result back.
    let perm = 1 : 0 : drop 2 [0 .. vss_rank - 1]

    dst_tr <- letExp "dstT" $ BasicOp $ Rearrange dst perm
    vss_tr <- letExp "vssT" $ BasicOp $ Rearrange vss perm
    dst_tr_type <- lookupType dst_tr
    vss_tr_type <- lookupType vss_tr
    nes_type <- lookupType nes

    dst_col <- newParam "dst_col" $ rowType dst_tr_type
    vss_col <- newParam "vss_col" $ rowType vss_tr_type
    ne_col <- newParam "ne" $ rowType nes_type

    bucket_lam <- mkIdentityLambda $ Prim int64 : lambdaReturnType op
    col_lam <- mkLambda [dst_col, vss_col, ne_col] $ do
      -- TODO: the column is copied before the histogram consumes it; it is
      -- unclear whether this is needed, as dst_col may already be unique.
      dst_col_cpy <-
        letExp "dst_col_cpy" . BasicOp $ Replicate mempty (Var $ paramName dst_col)
      col_hist <-
        letExp "col_res" . Op $
          Hist
            n
            [is, paramName vss_col]
            [HistOp (Shape [w]) rf [dst_col_cpy] [Var $ paramName ne_col] op]
            bucket_lam
      pure $ varsRes [col_hist]

    hist_tr <-
      letExp "histT" . Op $
        Screma (arraySize 0 dst_tr_type) [dst_tr, vss_tr, nes] (mapSOAC col_lam)
    auxing aux . letBindNames [x] . BasicOp $ Rearrange hist_tr perm

  foldr (vjpStm ops) m rewritten

--
-- One step of the radix sort. The arrays in `xs` are reordered by the
-- two-bit digit found at bit positions `bit` and `bit + 1` of the key.
-- The key is the first array of `xs` and must be [n]i64. It is first passed
-- through `mapout` with `n` and `w` (defined elsewhere in this module).
-- `tps` are the types of the arrays in `xs`; they size the scatter
-- destinations.
--
-- Reference implementation:
--
-- local def radix_sort_step [n] 't (xs: [n]t) (get_bit: i32 -> t -> i32)
--                                  (digit_n: i32): [n]t =
--   let num x = get_bit (digit_n+1) x * 2 + get_bit digit_n x
--   let pairwise op (a1,b1,c1,d1) (a2,b2,c2,d2) =
--     (a1 `op` a2, b1 `op` b2, c1 `op` c2, d1 `op` d2)
--   let bins = xs |> map num
--   let flags = bins |> map (\x -> if x == 0 then (1,0,0,0)
--                                  else if x == 1 then (0,1,0,0)
--                                  else if x == 2 then (0,0,1,0)
--                                  else (0,0,0,1))
--   let offsets = scan (pairwise (+)) (0,0,0,0) flags
--   let (na,nb,nc,_nd) = last offsets
--   let f bin (a,b,c,d) = match bin
--                         case 0 -> a-1
--                         case 1 -> na+b-1
--                         case 2 -> na+nb+c-1
--                         case _ -> na+nb+nc+d-1
--   let is = map2 f bins offsets
--   in scatter scratch is xs
radixSortStep :: [VName] -> [Type] -> SubExp -> SubExp -> SubExp -> ADM [VName]
radixSortStep xs tps bit n w = do
  keys <- mapout (head xs) n w

  -- digit k = get_bit (bit + 1) k * 2 + get_bit bit k
  key_p <- newParam "num" $ Prim int64
  let bitOf pos =
        eBinOp (And Int64) (eBinOp (AShr Int64) (eParam key_p) pos) (constI64 1)
  digit_lam <-
    mkLambda [key_p] $
      fmap varsRes . letTupExp "num_res"
        =<< eBinOp
          add64
          (bitOf $ eSubExp bit)
          ( eBinOp
              (Mul Int64 OverflowUndef)
              (constI64 2)
              (bitOf $ eBinOp add64 (eSubExp bit) (constI64 1))
          )
  digits <- letExp "bins" $ Op $ Screma n [keys] $ mapSOAC digit_lam

  -- One-hot encoding of each digit as four i64 flags.
  flag_p <- newParam "flag" $ Prim int64
  onehot_lam <-
    mkLambda [flag_p] $
      fmap varsRes . letTupExp "flag_res"
        =<< elseIf
          int64
          [(eParam flag_p, constI64 d) | d <- [0 .. 2]]
          [ eBody [constI64 (if d == e then 1 else 0) | e <- [0 .. 3]]
            | d <- [0 .. 3 :: Integer]
          ]
  onehots <- letTupExp "flags" $ Op $ Screma n [digits] $ mapSOAC onehot_lam

  -- Inclusive prefix sums of the four flag arrays, computed in one scan.
  scan_ps <- traverse (`newParam` Prim int64) ["a1", "b1", "c1", "d1", "a2", "b2", "c2", "d2"]
  scan_lam <- mkLambda scan_ps $ do
    let (lhs, rhs) = splitAt 4 $ map eParam scan_ps
    sums <- zipWithM (eBinOp add64) lhs rhs
    subExpsRes <$> mapM (letSubExp "scan_res") sums
  scan_form <- scanSOAC [Scan scan_lam $ replicate 4 $ intConst Int64 0]
  offsets <- letTupExp "offsets" $ Op $ Screma n onehots scan_form

  -- Per-digit totals: the last element of each offset array.
  last_ix <- letSubExp "ind_last" =<< eBinOp (Sub Int64 OverflowUndef) (eSubExp n) (constI64 1)
  totals <- traverse newVName ["na", "nb", "nc", "nd"]
  forM_ (zip totals offsets) $ \(total, arr) ->
    letBindNames [total] $ BasicOp $ Index arr $ Slice [DimFix last_ix]

  let total_ses = map Var totals
  map_ps <- traverse (`newParam` Prim int64) ["bin", "a", "b", "c", "d"]
  let bin_p = head map_ps
      count_ps = tail map_ps
      -- Destination for digit d with inclusive count c:
      -- (sum of the totals of all smaller digits) + c - 1.
      destBody d count_p = eBody $ pure $ do
        c_m1 <- letSubExp "t" =<< eBinOp (Sub Int64 OverflowUndef) (eParam count_p) (constI64 1)
        foldBinOp add64 (intConst Int64 0) (c_m1 : take d total_ses)
  dest_lam <-
    mkLambda map_ps $
      fmap varsRes . letTupExp "map_res"
        =<< elseIf
          int64
          [(eParam bin_p, constI64 d) | d <- [0 .. 2]]
          (zipWith destBody [0 .. 3] count_ps)
  dests <- letExp "nis" $ Op $ Screma n (digits : offsets) $ mapSOAC dest_lam

  scratch <- forM tps $ \tp ->
    letExp "scatter_dst" $ BasicOp $ Scratch (elemType tp) (arrayDims tp)
  multiScatter n scratch dests xs
  where
    constI64 :: Integer -> ADM (Exp SOACS)
    constI64 = eSubExp . intConst Int64

    add64 :: BinOp
    add64 = Add Int64 OverflowUndef

--
-- The radix sort driver. Sorts the arrays in `xs` by the key in the first of
-- them, two bits per 'radixSortStep'. Reference:
--
-- def radix_sort [n] 't (xs: [n]i64) =
--   let iters = if n == 0 then 0 else ceil (bitlength (w + 1) / 2)
--   in loop xs for i < iters do radix_sort_step xs i64.get_bit (i*2)
--
-- Only as many bits as the bit length of `w + 1` are processed, two per step.
-- Presumably the keys (after `mapout`) lie in [0, w]; confirm against mapout.
-- For empty input no step is run at all. Each step reads the last element of
-- its per-digit offset arrays, which would be an out-of-bounds index when
-- n == 0.
radixSort :: [VName] -> SubExp -> SubExp -> ADM [VName]
radixSort xs n w = do
  logw <- log2 =<< letSubExp "w1" =<< toExp (pe64 w + 1)
  -- ceil logw by (logw + 1) / 2, or 0 iterations for an empty input.
  iters <-
    letSubExp "iters"
      =<< eIf
        (toExp $ pe64 n .==. 0)
        (eBody [eSubExp $ intConst Int64 0])
        (eBody [toExp $ untyped (pe64 logw + 1) ~/~ untyped (pe64 (intConst Int64 2))])

  types <- traverse lookupType xs
  params <- zipWithM (\x -> newParam (baseString x) . flip toDecl Nonunique) xs types
  i <- newVName "i"
  loopbody <- buildBody_ . localScope (scopeOfFParams params) $
    fmap varsRes $ do
      -- Step i sorts by bits 2*i and 2*i + 1.
      bit <- letSubExp "bit" =<< toExp (le64 i * 2)
      radixSortStep (map paramName params) types bit n w

  letTupExp "sorted" $
    Loop
      (zip params $ map Var xs)
      (ForLoop i Int64 iters)
      loopbody
  where
    -- Bit length of m: the number of right shifts needed to reach zero
    -- (0 when m == 0), emitted as a while loop.
    log2 :: SubExp -> ADM SubExp
    log2 m = do
      params <- zipWithM newParam ["cond", "r", "i"] $ map Prim [Bool, int64, int64]
      let [cond, r, i] = params

      body <- buildBody_ . localScope (scopeOfFParams params) $ do
        r' <- letSubExp "r'" =<< toExp (le64 (paramName r) .>>. 1)
        cond' <- letSubExp "cond'" =<< toExp (bNot $ pe64 r' .==. 0)
        i' <- letSubExp "i'" =<< toExp (le64 (paramName i) + 1)
        pure $ subExpsRes [cond', r', i']

      cond_init <- letSubExp "test" =<< toExp (bNot $ pe64 m .==. 0)

      l <-
        letTupExp' "log2res" $
          Loop
            (zip params [cond_init, m, Constant $ blankPrimValue int64])
            (WhileLoop $ paramName cond)
            body

      let [_, _, res] = l
      pure res

-- Sorts the arrays in `xs` by the key array `head xs`, which is the only
-- array that goes through 'radixSort'. The key is sorted together with its
-- original positions, and the remaining arrays are gathered through the
-- resulting permutation.
--
-- Result: [permutation, sorted keys, sorted (tail xs)...].
radixSort' :: [VName] -> SubExp -> SubExp -> ADM [VName]
radixSort' xs n w = do
  let keys = head xs
      payload = tail xs
  positions <-
    letExp "red_iota" $ BasicOp $ Iota n (intConst Int64 0) (intConst Int64 1) Int64

  sorted_pair <- radixSort [keys, positions] n w
  let [sorted_keys, perm] = sorted_pair

  row <- newParam "i" $ Prim int64
  gather_lam <-
    mkLambda [row] $ varsRes <$> multiIndex payload [DimFix $ Var $ paramName row]
  gathered <- letTupExp "sorted" $ Op $ Screma n [perm] $ mapSOAC gather_lam

  pure $ [perm, sorted_keys] ++ gathered

--
-- generic case of histogram.
-- Original, assuming `is: [n]i64` and `dst: [w]btp`
--   let xs = reduce_by_index dst odot ne is as
-- Forward sweep:
-- let h_part = reduce_by_index (replicate w ne) odot ne is as
-- let xs = map2 odot dst h_part
-- Reverse sweep:
-- h_part_bar += f'' dst h_part
-- dst_bar += f' dst h_part

-- let flag = map (\i -> i == 0 || sis[i] != sis[i-1]) (iota n)
-- let flag_rev = map (\i -> i==0 || flag[n-i]) (iota n)
-- let ls = seg_scan_exc odot ne flag sas
-- let rs = reverse sas |>
--          seg_scan_exc odot ne flag_rev |> reverse
-- let f_bar = map (\i -> if i < w && -1 < i
--                        then h_part_bar[i]
--                        else 0s
--                 ) sis
-- let sas_bar = f f_dst ls sas rs
-- as_bar += scatter (Scratch alpha n) siota sas_bar
-- Where:
--  siota: 'iota n' sorted wrt 'is'
--  sis: 'is' sorted wrt 'is'
--  sas: 'as' sorted wrt 'is'
--  f'' = vjpLambda xs_bar h_part (map2 odot)
--  f' = vjpLambda xs_bar dst (map2 odot)
--  f  = vjpLambda f_bar sas (map4 (\di li ai ri -> di odot li odot ai odot ri))
--  0s is an alpha-dimensional array with 0 (possibly 0-dim)
diffHist :: VjpOps -> [VName] -> StmAux () -> SubExp -> Lambda SOACS -> [SubExp] -> [VName] -> [SubExp] -> SubExp -> [VName] -> ADM () -> ADM ()
diffHist :: VjpOps
-> [VName]
-> StmAux ()
-> SubExp
-> Lambda SOACS
-> [SubExp]
-> [VName]
-> [SubExp]
-> SubExp
-> [VName]
-> ADM ()
-> ADM ()
diffHist VjpOps
ops [VName]
xs StmAux ()
aux SubExp
n Lambda SOACS
lam0 [SubExp]
ne [VName]
as [SubExp]
w SubExp
rf [VName]
dst ADM ()
m = do
  as_type <- (VName -> ADM Type) -> [VName] -> ADM [Type]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType ([VName] -> ADM [Type]) -> [VName] -> ADM [Type]
forall a b. (a -> b) -> a -> b
$ [VName] -> [VName]
forall a. HasCallStack => [a] -> [a]
tail [VName]
as
  dst_type <- traverse lookupType dst

  nes <- traverse (letExp "new_dst" . BasicOp . Replicate (Shape $ pure $ head w)) ne

  h_map <- mkIdentityLambda $ Prim int64 : map rowType as_type
  h_part <- traverse (newVName . flip (<>) "_h_part" . baseString) xs
  auxing aux . letBindNames h_part . Op $
    Hist n as [HistOp (Shape w) rf nes ne lam0] h_map

  lam0' <- renameLambda lam0
  auxing aux . letBindNames xs . Op $
    Screma (head w) (dst <> h_part) (mapSOAC lam0')

  m

  xs_bar <- traverse lookupAdjVal xs

  (dst_params, hp_params, f') <- mkF' lam0 dst_type $ head w
  f'_adj_dst <- vjpLambda ops (map adjFromVar xs_bar) dst_params f'
  f'_adj_hp <- vjpLambda ops (map adjFromVar xs_bar) hp_params f'

  dst_bar' <- eLambda f'_adj_dst $ map (eSubExp . Var) $ dst <> h_part
  dst_bar <- bindSubExpRes "dst_bar" dst_bar'
  zipWithM_ updateAdj dst dst_bar

  h_part_bar' <- eLambda f'_adj_hp $ map (eSubExp . Var) $ dst <> h_part
  h_part_bar <- bindSubExpRes "h_part_bar" h_part_bar'

  lam <- renameLambda lam0
  lam' <- renameLambda lam0

  -- is' <- mapout (head as) n (head w)
  -- sorted <- radixSort' (is' : tail as) n $ head w
  sorted <- radixSort' as n $ head w
  let siota = [VName] -> VName
forall a. HasCallStack => [a] -> a
head [VName]
sorted
  let sis = [VName] -> VName
forall a. HasCallStack => [a] -> a
head ([VName] -> VName) -> [VName] -> VName
forall a b. (a -> b) -> a -> b
$ [VName] -> [VName]
forall a. HasCallStack => [a] -> [a]
tail [VName]
sorted
  let sas = Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
drop Int
2 [VName]
sorted

  iota_n <-
    letExp "iota" $ BasicOp $ Iota n (intConst Int64 0) (intConst Int64 1) Int64

  par_i <- newParam "i" $ Prim int64
  flag_lam <- mkFlagLam par_i sis
  flag <- letExp "flag" $ Op $ Screma n [iota_n] $ mapSOAC flag_lam

  -- map (\i -> (if flag[i] then (true,ne) else (false,vs[i-1]), if i==0 || flag[n-i] then (true,ne) else (false,vs[n-i]))) (iota n)
  par_i' <- newParam "i" $ Prim int64
  let i' = Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
par_i'
  g_lam <-
    mkLambda [par_i'] $
      fmap subExpsRes . mapM (letSubExp "scan_inps") =<< do
        im1 <- letSubExp "i_1" =<< toExp (le64 i' - 1)
        nmi <- letSubExp "n_i" =<< toExp (pe64 n - le64 i')
        let s1 = [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
im1]
        let s2 = [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
nmi]

        -- flag array for left scan
        f1 <- letSubExp "f1" $ BasicOp $ Index flag $ Slice [DimFix $ Var i']

        -- array for left scan
        r1 <-
          letTupExp' "r1"
            =<< eIf
              (eSubExp f1)
              (eBody $ fmap eSubExp ne)
              (eBody . fmap (eSubExp . Var) =<< multiIndex sas s1)

        -- array for right scan inc flag
        r2 <-
          letTupExp' "r2"
            =<< eIf
              (toExp $ le64 i' .==. 0)
              (eBody $ fmap eSubExp $ Constant (onePrimValue Bool) : ne)
              ( eBody $
                  pure $ do
                    eIf
                      (pure $ BasicOp $ Index flag $ Slice s2)
                      (eBody $ fmap eSubExp $ Constant (onePrimValue Bool) : ne)
                      ( eBody . fmap eSubExp . (Constant (blankPrimValue Bool) :) . fmap Var
                          =<< multiIndex sas s2
                      )
              )

        traverse eSubExp $ f1 : r1 ++ r2

  -- scan (\(f1,v1) (f2,v2) ->
  --   let f = f1 || f2
  --   let v = if f2 then v2 else g v1 v2
  --   in (f,v) ) (false,ne) (zip flags vals)
  scan_lams <-
    traverse
      ( \Lambda SOACS
l -> do
          f1 <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"f1" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
Bool
          f2 <- newParam "f2" $ Prim Bool
          ps <- lambdaParams <$> renameLambda lam0
          let (p1, p2) = splitAt (length ne) ps

          mkLambda (f1 : p1 ++ f2 : p2) $
            fmap varsRes . letTupExp "scan_res" =<< do
              let f = BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp BinOp
LogOr (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
f1) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
f2)
              eIf
                (eParam f2)
                (eBody $ f : fmap eParam p2)
                ( eBody . (f :) . fmap (eSubExp . Var)
                    =<< bindSubExpRes "gres"
                    =<< eLambda l (fmap eParam ps)
                )
      )
      [lam, lam']

  let ne' = PrimValue -> SubExp
Constant (Bool -> PrimValue
BoolValue Bool
False) SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: [SubExp]
ne

  scansres <-
    letTupExp "adj_ctrb_scan" . Op $
      Screma n [iota_n] (scanomapSOAC (map (`Scan` ne') scan_lams) g_lam)

  let (_ : ls_arr, _ : rs_arr_rev) = splitAt (length ne + 1) scansres

  -- map (\i -> if i < w && -1 < w then (xs_bar[i], dst[i]) else (0,ne)) sis
  par_i'' <- newParam "i" $ Prim int64
  let i'' = Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
par_i''
  map_lam <-
    mkLambda [par_i''] $
      fmap varsRes . letTupExp "scan_res"
        =<< eIf
          (toExp $ withinBounds $ pure (head w, i''))
          (eBody . fmap (eSubExp . Var) =<< multiIndex h_part_bar [DimFix $ Var i''])
          ( eBody $ do
              map (\Type
t -> Exp SOACS -> ADM (Exp SOACS)
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp SOACS -> ADM (Exp SOACS)) -> Exp SOACS -> ADM (Exp SOACS)
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> Shape) -> [SubExp] -> Shape
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [SubExp]
forall a. HasCallStack => [a] -> [a]
tail ([SubExp] -> [SubExp]) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
t) (PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue (PrimType -> PrimValue) -> PrimType -> PrimValue
forall a b. (a -> b) -> a -> b
$ Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t)) as_type
          )

  f_bar <- letTupExp "f_bar" $ Op $ Screma n [sis] $ mapSOAC map_lam

  (as_params, f) <- mkF lam0 as_type n
  f_adj <- vjpLambda ops (map adjFromVar f_bar) as_params f

  -- map (\i -> rs_arr_rev[n-i-1]) (iota n)
  par_i''' <- newParam "i" $ Prim int64
  let i''' = Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
par_i'''
  rev_lam <- mkLambda [par_i'''] $ do
    nmim1 <- letSubExp "n_i_1" =<< toExp (pe64 n - le64 i''' - 1)
    varsRes <$> multiIndex rs_arr_rev [DimFix nmim1]

  rs_arr <- letTupExp "rs_arr" $ Op $ Screma n [iota_n] $ mapSOAC rev_lam

  sas_bar <-
    bindSubExpRes "sas_bar"
      =<< eLambda f_adj (map (eSubExp . Var) $ ls_arr <> sas <> rs_arr)

  scatter_dst <- traverse (\Type
t -> String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"scatter_dst" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ PrimType -> [SubExp] -> BasicOp
Scratch (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t) (Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
t)) as_type
  as_bar <- multiScatter n scatter_dst siota sas_bar

  zipWithM_ updateAdj (tail as) as_bar
  where
    -- map (\i -> if i == 0 then true else is[i] != is[i-1]) (iota n)
    mkFlagLam :: LParam SOACS -> VName -> ADM (Lambda SOACS)
    mkFlagLam :: LParam SOACS -> VName -> ADM (Lambda SOACS)
mkFlagLam LParam SOACS
par_i VName
sis =
      [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [LParam (Rep ADM)
LParam SOACS
par_i] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
        ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"flag" (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
          let i :: VName
i = Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
LParam SOACS
par_i
          ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
            (TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
0))
            ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
onePrimValue PrimType
Bool)
            ( [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$
                ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ do
                  i_p <- String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"i_p" (Exp SOACS -> ADM VName) -> ADM (Exp SOACS) -> ADM VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1)
                  vs <- traverse (letExp "vs" . BasicOp . Index sis . Slice . pure . DimFix . Var) [i, i_p]
                  let [vs_i, vs_p] = vs
                  toExp $ bNot $ le64 vs_i .==. le64 vs_p
            )