module Language.Haskell.Liquid.TransformRec (
transformRecExpr
) where
import Bag
import Coercion
import Control.Arrow (second, (***))
import Control.Monad.State
import CoreLint
import CoreSyn
import qualified Data.HashMap.Strict as M
import ErrUtils
import Id (idOccInfo, setIdInfo)
import IdInfo
import MkCore (mkCoreLams)
import Outputable
import SrcLoc
import Type (mkForAllTys)
import TypeRep
import Unique hiding (deriveUnique)
import Var
import Language.Haskell.Liquid.GhcMisc
-- | Entry point of the module: rewrite a Core program and sanity-check it.
--
-- The bindings are run through 'transPg' starting from 'initEnv', the
-- result is post-processed by 'scopeTr', and the outcome is handed to
-- 'lintCoreBindings'. When the second bag returned by the linter is empty
-- the rewritten program is returned; otherwise 'error' is raised with the
-- pretty-printed program, the literal @Type-check@ and the printed bag.
transformRecExpr :: CoreProgram -> CoreProgram
transformRecExpr cbs =
  if isEmptyBag lintMsgs
    then rewritten
    else error (showPpr rewritten ++ "Type-check" ++ showSDoc (pprMessageBag lintMsgs))
  where
    rewritten     = scopeTr $ evalState (transPg cbs) initEnv
    -- Only the second component of the linter result is inspected
    -- (presumably the errors, the first being warnings -- confirm against
    -- the GHC version in use).
    (_, lintMsgs) = lintCoreBindings rewritten
-- | Apply 'innerScTr' to every top-level binding, then run 'outerScTr'
-- over the resulting list.
scopeTr :: CoreProgram -> CoreProgram
scopeTr binds = outerScTr (map innerScTr binds)
-- | Walk the list of top-level bindings.
--
-- A non-recursive binding of @x@ is emitted and the rest of the list is
-- handed to 'mkOuterScTr' for @x@; any other binding is emitted unchanged
-- and the walk continues with the remainder.
outerScTr :: CoreProgram -> CoreProgram
outerScTr binds = case binds of
  []                    -> []
  b@(NonRec x _) : rest -> b : mkOuterScTr x [] rest
  b : rest              -> b : outerScTr rest
-- | Gather the run of bindings of the shape @y = case x of ...@ that
-- directly follows the binding of @x@.
--
-- Every match is consed onto the accumulator, so once the run ends the
-- gathered bindings are emitted in the reverse of the order in which they
-- appeared, followed by 'outerScTr' applied to whatever remains.
mkOuterScTr :: Id -> [CoreBind] -> [CoreBind] -> [CoreBind]
mkOuterScTr x acc (sel@(NonRec _ (Case (Var scrut) _ _ _)) : rest)
  | scrut == x = mkOuterScTr x (sel : acc) rest
mkOuterScTr _ acc rest = acc ++ outerScTr rest
-- | Run 'scTrans' over the right-hand side(s) of one binding by way of
-- 'mapBnd'.
innerScTr :: CoreBind -> CoreBind
innerScTr bind = mapBnd scTrans bind
-- | Callback given to 'mapExpr': @x@ is the binder of an enclosing
-- non-recursive @let@ and @e@ is that let's body.
--
-- 'collectBnds' strips the leading lets in @e@ that scrutinise @x@; they
-- are wrapped back around the remaining expression with 'foldr', and the
-- traversal then continues on the rebuilt expression.
scTrans :: Id -> CoreExpr -> CoreExpr
scTrans x e = mapExpr scTrans rebuilt
  where
    (collected, rest) = collectBnds x [] e
    rebuilt           = foldr Let rest collected
-- | Strip the leading non-recursive lets whose right-hand side is a @case@
-- on the variable @x@.
--
-- Each stripped binding is consed onto the accumulator, so the returned
-- list holds them in reverse order of appearance. A 'Tick' is looked
-- through: collection continues underneath it and the tick is put back
-- around the remaining expression. Anything else stops the collection.
collectBnds :: Id -> [Bind t] -> Expr t -> ([Bind t], Expr t)
collectBnds x acc expr = case expr of
  Let b@(NonRec _ (Case (Var v) _ _ _)) rest
    | x == v -> collectBnds x (b : acc) rest
  Tick t inner ->
    let (acc', inner') = collectBnds x acc inner
    in  (acc', Tick t inner')
  _ -> (acc, expr)
-- | Monad in which the transformation runs: 'State' over 'TrEnv'.
type TE = State TrEnv

-- | State threaded through the transformation.
data TrEnv = Tr
  { freshIndex :: !Int
    -- ^ Next integer handed out by 'freshInt'.
  , loc        :: SrcSpan
    -- ^ A source span; set by 'initEnv' and not otherwise referenced in
    -- the code visible in this file.
  }
-- | Initial state: the counter starts at zero and no source span is set.
initEnv :: TrEnv
initEnv = Tr { freshIndex = 0, loc = noSrcSpan }
-- | Transform a whole program by applying 'transBd' to each top-level
-- binding in order.
transPg :: CoreProgram -> TE CoreProgram
transPg binds = mapM transBd binds
-- | Switch read by 'transBd': when 'True', bindings nested inside
-- right-hand sides are visited too (through 'mapBdM'); when 'False', only
-- top-level non-recursive bindings are transformed.
applyTransToAllBds :: Bool
applyTransToAllBds = False

-- | Transform one binding.
--
-- * A non-recursive binding has its right-hand side passed to 'transExpr',
--   after first visiting nested bindings if 'applyTransToAllBds' is set.
-- * A recursive group is returned unchanged unless 'applyTransToAllBds' is
--   set, in which case nested bindings in each right-hand side are visited
--   (the right-hand sides themselves are not passed to 'transExpr').
transBd :: CoreBind -> TE CoreBind
transBd bind = case bind of
  NonRec x rhs -> do
    rhs' <- if applyTransToAllBds
              then mapBdM transBd rhs
              else return rhs
    liftM (NonRec x) (transExpr rhs')
  Rec pairs
    | applyTransToAllBds -> liftM Rec (mapM descend pairs)
    | otherwise          -> return bind
  where
    descend (x, rhs) = do
      rhs' <- mapBdM transBd rhs
      return (x, rhs')
-- | Transform a right-hand side if it has the expected shape.
--
-- The leading type and value binders are peeled off with
-- 'collectTyAndValBinders', then the leading non-recursive lets with
-- 'collectNonRecLets'. If what remains is a recursive @let@ ('chkRec'),
-- the pieces are handed to 'trans'; otherwise the original expression is
-- returned untouched.
transExpr :: CoreExpr -> TE CoreExpr
transExpr e =
  if chkRec inner
    then trans tyBinders valBinders nonRecs inner
    else return e
  where
    (tyBinders, valBinders, underLams) = collectTyAndValBinders e
    (nonRecs, inner)                   = collectNonRecLets underLams
-- | Split an expression into its leading run of non-recursive @let@
-- bindings (in order of appearance) and the expression underneath them.
collectNonRecLets :: Expr b -> ([Bind b], Expr b)
collectNonRecLets (Let b@(NonRec _ _) rest) = (b : bs, inner)
  where
    (bs, inner) = collectNonRecLets rest
collectNonRecLets e = ([], e)
-- | Build the expression @x \@tv1 ... \@tvn id1 ... idm@: the variable @x@
-- applied first to the type variables (as types) and then to the value
-- variables.
appTysAndIds :: [TyVar] -> [Id] -> Id -> Expr b
appTysAndIds tvs ids x = mkApps tyApplied valArgs
  where
    tyApplied = mkTyApps (Var x) (map TyVarTy tvs)
    valArgs   = map Var ids
-- | Transform the body with 'trans_' and put the outer lambdas back.
--
-- @vs@ and @ids@ are the type and value binders peeled off by 'transExpr'.
-- The value binders are passed through 'mkAlive' before use; the same
-- adjusted binders are used both inside 'trans_' and for the rebuilt
-- lambdas.
trans vs ids bs e = do
  body <- trans_ vs aliveIds bs e
  return (foldr Lam body (vs ++ aliveIds))
  where
    aliveIds = map mkAlive ids
-- | Core of the transformation; only defined for a recursive @let@.
--
-- First equation (no pending non-recursive lets):
--
-- 1. For every recursive binder, 'mkFreshIds' yields a fresh copy of @ids@
--    and a retyped copy of the binder.
-- 2. 'mkFreshBdrs' yields a substitution for the let body (old binder to a
--    fresh alias) together with the bindings that define those aliases.
-- 3. Each right-hand side is rewritten with the substitution from
--    'mkSubs' and then abstracted over the type variables and that
--    binding's fresh copy of @ids@.
--
-- The result is one recursive group holding the rewritten bindings
-- followed by the alias bindings, with the substituted body underneath.
--
-- Second equation (pending lets @bs@): the lets are wrapped around every
-- right-hand side, the first equation is applied, and the same lets are
-- wrapped around the result.
trans_ vs ids [] (Let (Rec xes) body) = do
  freshened <- mapM (mkFreshIds vs ids) binders
  let (newTvs, newArgs, lifted) = unzip3 freshened
  (aliasSubst, aliasBinds) <- mkFreshBdrs vs ids binders lifted
  let close tvs args rhs =
        mkCoreLams (tvs ++ args)
                   (sub (mkSubs ids tvs args (zip binders lifted)) rhs)
      closedPairs = zip lifted (zipWith3 close newTvs newArgs rhss)
  return (Let (Rec (closedPairs ++ aliasBinds)) (sub aliasSubst body))
  where
    (binders, rhss) = unzip xes
trans_ vs ids bs (Let (Rec xes) body) =
  liftM underLets (trans_ vs ids [] (Let (Rec (map (second underLets) xes)) body))
  where
    -- Re-introduce the pending non-recursive lets around an expression.
    underLets inner = foldr Let inner bs
-- | Build the substitution applied to one rewritten right-hand side.
--
-- It maps every old recursive binder to its replacement applied to @tvs@
-- and @ids0@ (see 'mkSub'), and every element of @ids@ to the variable at
-- the same position in @ids0@. The second group is listed last, so it is
-- the one 'M.fromList' keeps should a key occur in both.
mkSubs ids tvs ids0 xxs' = M.fromList (binderSubst ++ argSubst)
  where
    binderSubst = [ mkSub ids tvs ids0 pair | pair <- xxs' ]
    argSubst    = zipWith (\old new -> (old, Var new)) ids ids0
-- | Turn a pair @(x, x')@ into a substitution entry sending @x@ to @x'@
-- applied to @tvs@ and @ids0@. The first argument is ignored.
mkSub _ tvs ids0 = second (appTysAndIds tvs ids0)
-- | Create a fresh alias for every binder in @xs@.
--
-- Returns a substitution sending each original binder to its alias, and a
-- list of bindings defining each alias as the corresponding element of
-- @xs'@ applied to @tvs@ and @ids@.
mkFreshBdrs tvs ids xs xs' = do
  aliases <- mapM fresh xs
  let renaming   = M.fromList (zip xs (map Var aliases))
      aliasBinds = [ (alias, appTysAndIds tvs ids target)
                   | (alias, target) <- zip aliases xs' ]
  return (renaming, aliasBinds)
-- | Prepare one recursive binder for abstraction over @tvs@ and @ids@.
--
-- Fresh copies of @ids@ are generated, and @x@ is given a new type: its
-- old type preceded by the types of the fresh ids as arguments, quantified
-- over @tvs@. Returns the type variables (unchanged), the fresh ids and
-- the retyped binder.
mkFreshIds tvs ids x = do
  args <- mapM fresh ids
  let liftedTy = mkForAllTys tvs (mkType (reverse args) (varType x))
  return (tvs, args, setVarType x liftedTy)
-- | Wrap @ty@ in one function arrow per element of @ids@, using each
-- variable's type as the argument type. The list is consumed left to
-- right, so its last element ends up as the outermost argument.
mkType ids ty = foldl addArg ty ids
  where
    addArg res x = FunTy (varType x) res
-- | 'True' exactly when the expression is a @let@ whose binding group is
-- recursive.
chkRec e = case e of
  Let (Rec _) _ -> True
  _             -> False
-- | Types for which a fresh value can be produced in the 'TE' monad.
class Freshable a where
  -- | Produce a fresh value; what the argument is used for depends on the
  -- instance.
  fresh :: a -> TE a

-- The argument is ignored; the next counter value is returned.
instance Freshable Int where
  fresh = const freshInt

-- The argument is ignored; a unique built from the counter is returned.
instance Freshable Unique where
  fresh = const freshUnique

-- The given variable is returned with its unique replaced by a fresh one.
instance Freshable Var where
  fresh v = do
    u <- freshUnique
    return (setVarUnique v u)
-- | Return the current value of the 'freshIndex' counter and store the
-- counter incremented by one.
freshInt :: TE Int
freshInt = do
  n <- gets freshIndex
  modify (\env -> env { freshIndex = n + 1 })
  return n
-- | Make a 'Unique' from the next counter value, tagged with the
-- character @\'X\'@.
freshUnique :: TE Unique
freshUnique = do
  n <- freshInt
  return (mkUnique 'X' n)
-- | If @x@ is an 'Id' whose occurrence info says it is dead, reset that
-- info to 'NoOccInfo'; any other variable is returned unchanged.
--
-- NOTE(review): presumably needed because the transformation introduces
-- new references to binders that were previously unused -- confirm.
mkAlive :: Var -> Var
mkAlive x =
  if isId x && isDeadOcc (idOccInfo x)
    then setIdInfo x (setOccInfo (idInfo x) NoOccInfo)
    else x
-- | Syntax that supports term-level and type-level substitution.
class Subable a where
  -- | Apply a substitution from term variables to Core expressions.
  sub   :: M.HashMap CoreBndr CoreExpr -> a -> a
  -- | Apply a substitution from type variables to types.
  subTy :: M.HashMap TyVar Type -> a -> a
-- | Substitution over Core expressions.
--
-- 'sub' replaces a variable occurrence by its image in the map (or leaves
-- it alone when unmapped) and otherwise recurses structurally. Lambda
-- binders are left untouched, whereas let bindings, case binders and case
-- alternatives go through their own 'sub' instances. Literals, types,
-- coercions, and the coercion of a cast are returned as they are.
--
-- 'subTy' pushes a type substitution through the expression, rewriting
-- the types of variables and binders, the result type of a @case@, and
-- embedded types. Casts and coercions are delegated to the 'Coercion'
-- instance.
instance Subable CoreExpr where
  sub s (Var v)        = M.lookupDefault (Var v) v s
  sub _ (Lit l)        = Lit l
  sub s (App e1 e2)    = App (sub s e1) (sub s e2)
  sub s (Lam b e)      = Lam b (sub s e)
  sub s (Let b e)      = Let (sub s b) (sub s e)
  sub s (Case e b t a) = Case (sub s e) (sub s b) t (map (sub s) a)
  sub s (Cast e c)     = Cast (sub s e) c
  sub s (Tick t e)     = Tick t (sub s e)
  sub _ (Type t)       = Type t
  sub _ (Coercion c)   = Coercion c

  subTy s (Var v)      = Var (subTy s v)
  subTy _ (Lit l)      = Lit l
  subTy s (App e1 e2)  = App (subTy s e1) (subTy s e2)
  subTy s (Lam b e)
    | isTyVar b        = Lam (renameTyBinder b) (subTy s e)
    | otherwise        = Lam (subTy s b) (subTy s e)
    where
      -- A type-lambda binder has to stay a variable, so it can only be
      -- renamed. A mapping to any other type cannot be expressed as a
      -- binder and is reported explicitly instead of failing with an
      -- anonymous pattern-match error.
      renameTyBinder tv = case M.lookup tv s of
        Nothing            -> tv
        Just (TyVarTy tv') -> tv'
        Just _             ->
          error "subTy Lam: type binder mapped to a non-variable type"
  subTy s (Let b e)      = Let (subTy s b) (subTy s e)
  subTy s (Case e b t a) = Case (subTy s e) (subTy s b) (subTy s t) (map (subTy s) a)
  subTy s (Cast e c)     = Cast (subTy s e) (subTy s c)
  subTy s (Tick t e)     = Tick t (subTy s e)
  subTy s (Type t)       = Type (subTy s t)
  subTy s (Coercion c)   = Coercion (subTy s c)
-- | Coercions are untouched by term substitution; type substitution on a
-- coercion is not implemented and raises an error.
instance Subable Coercion where
  sub _     = id
  subTy _ _ = error "subTy Coercion"
-- | Case alternatives: the constructor is kept, and both the pattern
-- binders and the right-hand side are substituted.
instance Subable (Alt Var) where
  sub s (con, binders, rhs) =
    (con, [ sub s b | b <- binders ], sub s rhs)
  subTy s (con, binders, rhs) =
    (con, [ subTy s b | b <- binders ], subTy s rhs)
-- | Variables in binder position.
--
-- 'sub' renames the variable when the map sends it to another variable,
-- leaves it alone when it is unmapped, and raises an error when it is
-- mapped to any other kind of expression. 'subTy' rewrites the variable's
-- type.
instance Subable Var where
  sub s v = case M.lookup v s of
    Nothing      -> v
    Just (Var x) -> x
    Just _       -> error "sub Var"
  subTy s v = setVarType v (subTy s (varType v))
-- | Bindings: both the binder and the right-hand side are substituted,
-- pairwise for recursive groups.
instance Subable (Bind Var) where
  sub s bind = case bind of
    NonRec x e -> NonRec (sub s x) (sub s e)
    Rec pairs  -> Rec [ (sub s x, sub s e) | (x, e) <- pairs ]
  subTy s bind = case bind of
    NonRec x e -> NonRec (subTy s x) (subTy s e)
    Rec pairs  -> Rec [ (subTy s x, subTy s e) | (x, e) <- pairs ]
-- | Types are unaffected by term substitution; type substitution is
-- carried out by 'substTysWith'.
instance Subable Type where
  sub _     = id
  subTy s t = substTysWith s t
-- | Apply a type-variable substitution to a type.
--
-- A mapped type variable is replaced by its image, an unmapped one is
-- kept. Under a @forall@ the bound variable is removed from the
-- substitution before descending, so it is not replaced in the body.
-- Only the five constructors listed are handled; there is no catch-all
-- alternative.
substTysWith s ty = case ty of
  TyVarTy v       -> M.lookupDefault ty v s
  FunTy arg res   -> FunTy (go arg) (go res)
  ForAllTy v body -> ForAllTy v (substTysWith (M.delete v s) body)
  TyConApp c args -> TyConApp c (map go args)
  AppTy fun arg   -> AppTy (go fun) (go arg)
  where
    go t = substTysWith s t
-- | Run 'mapExpr' over the right-hand side(s) of a binding; binders are
-- left as they are.
mapBnd :: (b -> Expr b -> Expr b) -> Bind b -> Bind b
mapBnd f bind = case bind of
  NonRec b rhs -> NonRec b (mapExpr f rhs)
  Rec pairs    -> Rec [ (b, mapExpr f rhs) | (b, rhs) <- pairs ]
-- | Traverse an expression, invoking @f@ at every non-recursive @let@.
--
-- At @let x = rhs in body@ the binding is kept as is and the body is
-- replaced by @f x body@; neither @rhs@ nor @body@ is traversed further
-- by this function. Applications, lambdas, recursive lets, case
-- alternatives and ticks are descended into. A case scrutinee and every
-- other form of expression are returned unchanged.
mapExpr :: (b -> Expr b -> Expr b) -> Expr b -> Expr b
mapExpr f expr = case expr of
  Let b@(NonRec x _) body -> Let b (f x body)
  Let grp body            -> Let (mapBnd f grp) (mapExpr f body)
  App fun arg             -> App (mapExpr f fun) (mapExpr f arg)
  Lam b body              -> Lam b (mapExpr f body)
  Case scrut b t alts     -> Case scrut b t (map (mapAlt f) alts)
  Tick t body             -> Tick t (mapExpr f body)
  _                       -> expr
-- | Apply 'mapExpr' to the right-hand side of a case alternative, keeping
-- the constructor and the pattern binders.
mapAlt f (con, binders, rhs) = (con, binders, rhs')
  where
    rhs' = mapExpr f rhs
-- | Monadic traversal that applies @f@ to every @let@ binding reached.
--
-- At a @let@ the binding is transformed first, then the body is
-- traversed. Applications (function before argument), lambdas, case
-- alternatives and ticks are descended into. A case scrutinee and every
-- other form of expression are returned unchanged.
mapBdM f expr = case expr of
  Let b body -> do
    b'    <- f b
    body' <- mapBdM f body
    return (Let b' body')
  App fun arg -> do
    fun' <- mapBdM f fun
    arg' <- mapBdM f arg
    return (App fun' arg')
  Lam b body          -> Lam b `liftM` mapBdM f body
  Case scrut b t alts -> Case scrut b t `liftM` mapM (mapBdAltM f) alts
  Tick t body         -> Tick t `liftM` mapBdM f body
  _                   -> return expr
-- | Apply 'mapBdM' to the right-hand side of a case alternative, keeping
-- the constructor and the pattern binders.
mapBdAltM f (con, binders, rhs) = do
  rhs' <- mapBdM f rhs
  return (con, binders, rhs')