-- | Type inference of @loop@.  This is complicated because of the
-- uniqueness and size inference, so the implementation is separate
-- from the main type checker.
module Language.Futhark.TypeChecker.Terms.Loop
  ( UncheckedLoop,
    CheckedLoop,
    checkLoop,
  )
where

import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor
import Data.Bitraversable
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.Util (nubOrd)
import Futhark.Util.Pretty hiding (group, space)
import Language.Futhark
import Language.Futhark.TypeChecker.Monad hiding (BoundV)
import Language.Futhark.TypeChecker.Terms.Monad
import Language.Futhark.TypeChecker.Terms.Pat
import Language.Futhark.TypeChecker.Types
import Language.Futhark.TypeChecker.Unify
import Prelude hiding (mod)

-- | Retrieve an oracle that can be used to decide whether two names
-- are in the same equivalence class (i.e. have been unified).  This is
-- an exotic operation.
getAreSame :: (MonadUnify m) => m (VName -> VName -> Bool)
getAreSame = sameClass <$> getConstraints
  where
    -- Two names are considered the same if following their chains of
    -- variable-to-variable size constraints ends at the same name.
    sameClass constraints x y = representative x == representative y
      where
        -- Follow constraints of the form @v = w@, where @w@ is itself
        -- a variable, until reaching a name that is not such an alias.
        representative v =
          case M.lookup v constraints of
            Just (_, Size (Just (Var v' _ _)) _) -> representative (qualLeaf v')
            _ -> v

-- | Replace specified sizes with distinct fresh size variables.
--
-- Every size in the type that is a variable in the same equivalence
-- class (per 'getAreSame') as one of the given names is replaced by a
-- newly created flexible size variable.  Each such occurrence gets its
-- own fresh variable; no sharing is attempted between occurrences.
-- All other sizes are left untouched.
someDimsFreshInType ::
  SrcLoc ->
  Name ->
  [VName] ->
  TypeBase Size als ->
  TermTypeM (TypeBase Size als)
someDimsFreshInType loc desc fresh t = do
  areSameSize <- getAreSame
  let isTarget v = any (areSameSize v) fresh
      refresh (Var d _ _)
        | isTarget (qualLeaf d) =
            (`sizeFromName` loc) . qualName <$> newFlexibleDim (mkUsage' loc) desc
      refresh d = pure d
  -- Only the sizes are rewritten; the second type component is kept.
  bitraverse refresh pure t

-- | Replace the specified sizes with fresh size variables of the
-- specified rigidity.  Returns the new fresh size variables.
--
-- A size is replaced when it is a variable in the same equivalence
-- class (per 'getAreSame') as one of the given names.  Occurrences
-- that belong to the same class share a single fresh variable.  The
-- returned list holds the fresh variables, most recently created
-- first.
freshDimsInType ::
  Usage ->
  Rigidity ->
  Name ->
  [VName] ->
  TypeBase Size u ->
  TermTypeM (TypeBase Size u, [VName])
freshDimsInType usage r desc fresh t = do
  areSameSize <- getAreSame
  let toSize v = sizeFromName (qualName v) (srclocOf usage)
      -- The state records (original size, fresh replacement) pairs,
      -- so later occurrences of the same size reuse the replacement.
      replaceDim (Var (QualName _ d) _ _)
        | any (areSameSize d) fresh = do
            seen <- get
            case L.find (areSameSize d . fst) seen of
              Just (_, d') -> pure $ toSize d'
              Nothing -> do
                v <- lift $ newDimVar usage r desc
                put $ (d, v) : seen
                pure $ toSize v
      replaceDim d = pure d
  (t', substs) <- runStateT (bitraverse replaceDim pure t) []
  pure (t', map snd substs)

-- | Where a value matched against the loop parameter pattern comes
-- from.  Only used to pick the error context and usage description.
data ArgSource
  = -- | The initial values of the loop.
    Initial
  | -- | The result of the loop body.
    BodyResult

-- | Check that a loop argument (the initial values or the body result,
-- as indicated by the 'ArgSource') matches the loop parameter pattern.
-- Before unifying, the size parameters @sparams@ occurring in the
-- pattern type are replaced by fresh nonrigid size variables.
wellTypedLoopArg :: ArgSource -> [VName] -> Pat ParamType -> Exp -> TermTypeM ()
wellTypedLoopArg src sparams pat arg = do
  (param_t, _) <-
    freshDimsInType usage Nonrigid "loop" sparams . toStruct $ patternType pat
  arg_t <- toStruct <$> expTypeFully arg
  onFailure (mkChecking param_t arg_t) $ unify usage param_t arg_t
  where
    usage = mkUsage arg what
    (mkChecking, what) = case src of
      Initial -> (CheckingLoopInitial, "matching initial loop values to pattern")
      BodyResult -> (CheckingLoopBody, "matching loop body to pattern")

-- | An un-checked loop: the loop parameter pattern, the initial
-- values, the loop form, and the loop body.
type UncheckedLoop =
  ( PatBase NoInfo VName ParamType,
    LoopInitBase NoInfo VName,
    LoopFormBase NoInfo VName,
    ExpBase NoInfo VName
  )

-- | A loop that has been type-checked: its size parameters, followed
-- by the checked parameter pattern, initial values, loop form, and
-- loop body.
type CheckedLoop =
  ( [VName],
    Pat ParamType,
    LoopInitBase Info VName,
    LoopFormBase Info VName,
    Exp
  )

-- | Report a type error if the type of a loop parameter mentions a
-- size that is not in @known_before@ and is constrained as an unknown
-- size (i.e. an existential created inside the loop body).  Only the
-- first such size found is reported; otherwise this does nothing.
checkForImpossible :: Loc -> S.Set VName -> ParamType -> TermTypeM ()
checkForImpossible loc known_before pat_t = do
  cs <- getConstraints
  let escapes v
        | v `S.member` known_before = Nothing
        | otherwise =
            case M.lookup v cs of
              Just (_, UnknownSize v_loc _) ->
                Just . typeError (srclocOf loc) mempty $
                  "Inferred type for loop parameter is"
                    </> indent 2 (pretty pat_t)
                    </> "but"
                    <+> dquotes (prettyName v)
                    <+> "is an existential size created inside the loop body at"
                    <+> pretty (locStrRel loc v_loc)
                    <> "."
              _ -> Nothing
      problems = mapMaybe escapes $ S.toList $ fvVars $ freeInType pat_t
  fromMaybe (pure ()) $ listToMaybe problems

-- | Type-check a @loop@ expression, passing in a function for
-- type-checking subexpressions.
--
-- Returns the checked loop together with an 'AppRes' describing the
-- type of the whole loop.  In that type, the inferred size parameters
-- of the loop are replaced by fresh rigid size variables, which are
-- also listed in the 'AppRes'.
checkLoop ::
  (ExpBase NoInfo VName -> TermTypeM Exp) ->
  UncheckedLoop ->
  SrcLoc ->
  TermTypeM (CheckedLoop, AppRes)
checkLoop checkExp (mergepat, loopinit, form, loopbody) loc = do
  loopinit' <- checkExp $ case loopinit of
    LoopInitExplicit e -> e
    LoopInitImplicit _ ->
      -- Should have been filled out in Names.
      error "Unexpected LoopInitImplicit"
  -- Names that already have constraints before the loop body is
  -- checked; used below to recognise sizes introduced later.
  known_before <- M.keysSet <$> getConstraints
  zeroOrderType
    (mkUsage loopinit' "use as loop variable")
    "type used as loop variable"
    . toStruct
    =<< expTypeFully loopinit'

  -- The handling of dimension sizes is a bit intricate, but very
  -- similar to checking a function, followed by checking a call to
  -- it.  The overall procedure is as follows:
  --
  -- (1) All empty dimensions in the merge pattern are instantiated
  -- with nonrigid size variables.  All explicitly specified
  -- dimensions are preserved.
  --
  -- (2) The body of the loop is type-checked.  The result type is
  -- combined with the merge pattern type to determine which sizes are
  -- variant, and these are turned into size parameters for the merge
  -- pattern.
  --
  -- (3) We now conceptually have a function parameter type and
  -- return type.  We check that it can be called with the body type
  -- as argument.
  --
  -- (4) Similarly to (3), we check that the "function" can be
  -- called with the initial merge values as argument.  The result
  -- of this is the type of the loop as a whole.

  (merge_t, new_dims_map) <-
    -- dim handling (1)
    allDimsFreshInType (mkUsage loc "loop parameter type inference") Nonrigid "loop_d"
      =<< expTypeFully loopinit'
  let new_dims_to_initial_dim = M.toList new_dims_map
      new_dims = map fst new_dims_to_initial_dim

  -- dim handling (2)
  let checkLoopReturnSize mergepat' loopbody' = do
        loopbody_t <- expTypeFully loopbody'
        mergepat_t <- normTypeFully (patternType mergepat')

        let ok_names = known_before <> S.fromList new_dims
        checkForImpossible (locOf mergepat) ok_names mergepat_t

        pat_t <- someDimsFreshInType loc "loop" new_dims mergepat_t

        -- We are ignoring the dimensions here, because any mismatches
        -- should be turned into fresh size variables.
        onFailure (CheckingLoopBody (toStruct pat_t) (toStruct loopbody_t)) $
          unify
            (mkUsage loopbody "matching loop body to loop pattern")
            (toStruct pat_t)
            (toStruct loopbody_t)

        -- Figure out which of the 'new_dims' dimensions are variant.
        -- This works because we know that each dimension from
        -- new_dims in the pattern is unique and distinct.
        areSameSize <- getAreSame
        -- Presumably @e@ is a size from the parameter type and @d@ the
        -- corresponding size from the body type (argument order of
        -- 'matchDims' below) -- TODO confirm.  The state accumulates
        -- substitutions for invariant sizes and a list of variant
        -- sizes (the future size parameters).
        let onDims _ x y
              | x == y = pure x
            onDims _ e d = do
              forM_ (fvVars $ freeInExp e) $ \v -> do
                case L.find (areSameSize v . fst) new_dims_to_initial_dim of
                  Just (_, e') ->
                    if e' == d
                      then -- Same size as on entry: invariant.
                        modify $ first $ M.insert v $ ExpSubst e'
                      else -- Changed across iterations: variant.
                        unless (v `S.member` known_before) $
                          modify (second (v :))
                  _ ->
                    pure ()
              pure e
        loopbody_t' <- normTypeFully loopbody_t
        merge_t' <- normTypeFully merge_t

        let (init_substs, sparams) =
              execState (matchDims onDims merge_t' loopbody_t') mempty

        -- Make sure that any of new_dims that are invariant will be
        -- replaced with the invariant size in the loop body.  Failure
        -- to do this can cause type annotations to still refer to
        -- new_dims.
        let dimToInit (v, ExpSubst e) =
              constrain v $ Size (Just e) (mkUsage loc "size of loop parameter")
            dimToInit _ =
              pure ()
        mapM_ dimToInit $ M.toList init_substs

        mergepat'' <- applySubst (`M.lookup` init_substs) <$> updateTypes mergepat'

        -- Eliminate those new_dims that turned into sparams so it won't
        -- look like we have ambiguous sizes lying around.
        modifyConstraints $ M.filterWithKey $ \k _ -> k `notElem` sparams

        -- dim handling (3)
        --
        -- The only trick here is that we have to turn any instances
        -- of loop parameters in the type of loopbody' rigid,
        -- because we are no longer in a position to change them,
        -- really.
        wellTypedLoopArg BodyResult sparams mergepat'' loopbody'

        pure (nubOrd sparams, mergepat'')

  -- Check the loop body and infer the size parameters and final
  -- parameter pattern.  Shared by all loop forms; must be run inside
  -- the scope of the loop parameter pattern.
  let checkBody mergepat' = do
        loopbody' <- checkExp loopbody
        (sparams, mergepat'') <- checkLoopReturnSize mergepat' loopbody'
        pure (sparams, mergepat'', loopbody')

  (sparams, mergepat', form', loopbody') <-
    case form of
      For i uboundexp -> do
        uboundexp' <-
          require "being the bound in a 'for' loop" anySignedType
            =<< checkExp uboundexp
        bound_t <- expTypeFully uboundexp'
        bindingIdent i bound_t $ \i' ->
          bindingPat [] mergepat merge_t $ \mergepat' -> incLevel $ do
            (sparams, mergepat'', loopbody') <- checkBody mergepat'
            pure (sparams, mergepat'', For i' uboundexp', loopbody')
      ForIn xpat e -> do
        (arr_t, _) <- newArrayType (mkUsage' (srclocOf e)) "e" 1
        e' <- unifies "being iterated in a 'for-in' loop" arr_t =<< checkExp e
        t <- expTypeFully e'
        case peelArray 1 t of
          Just t' ->
            bindingPat [] xpat t' $ \xpat' ->
              bindingPat [] mergepat merge_t $ \mergepat' -> incLevel $ do
                (sparams, mergepat'', loopbody') <- checkBody mergepat'
                pure
                  ( sparams,
                    mergepat'',
                    ForIn (fmap toStruct xpat') e',
                    loopbody'
                  )
          Nothing ->
            typeError (srclocOf e) mempty $
              "Iteratee of a for-in loop must be an array, but expression has type"
                <+> pretty t
      While cond ->
        bindingPat [] mergepat merge_t $ \mergepat' -> incLevel $ do
          -- The condition is checked before the body, with the loop
          -- parameters in scope.
          cond' <-
            checkExp cond
              >>= unifies "being the condition of a 'while' loop" (Scalar $ Prim Bool)
          (sparams, mergepat'', loopbody') <- checkBody mergepat'
          pure (sparams, mergepat'', While cond', loopbody')

  -- dim handling (4)
  wellTypedLoopArg Initial sparams mergepat' loopinit'

  (loopt, retext) <-
    freshDimsInType
      (mkUsage loc "inference of loop result type")
      (Rigid RigidLoop)
      "loop"
      sparams
      (patternType mergepat')
  pure
    ( (sparams, mergepat', LoopInitExplicit loopinit', form', loopbody'),
      AppRes (toStruct loopt) retext
    )