-- A near-direct port of Matt Might's CPS transformation code
-- (https://matt.might.net/articles/cps-conversion/) to Haskell.
-- Differences from the original: Let is non-recursive, and alphatization
-- code is added for correctness.
module CPSTransformer (transform) where
import Control.Monad (foldM, (>=>))
import Control.Monad.State.Strict (MonadState (..), State)
import Control.Monad.State.Strict qualified as State
import Data.List.NonEmpty (NonEmpty (..))
import Data.List.NonEmpty qualified as NEL
import Data.Map.Strict qualified as Map
import Data.Text (Text)
import Data.Text qualified as Text
-- | Name of a variable, as it appears in both the source and the CPS
-- language.
type Identifier = Text
-- | Atomic terms shared by the source and the CPS language.
--
-- The type parameter @e@ is the expression type used for lambda bodies:
-- it is 'Expr' in the source language and 'CExpr' after conversion.
data Atomic e
  = Lambda [Identifier] e -- ^ Parameter names and body.
  | Variable Identifier -- ^ Reference to a variable by name.
  | Boolean Bool -- ^ Boolean literal.
  | Number Integer -- ^ Integer literal.
  | String Text -- ^ Text literal.
  | Void -- ^ The void value (passed to the continuation of a 'Set').
  deriving (Show)
-- | Atomic expressions of the source language.
data AExpr
  = Atomic (Atomic Expr) -- ^ An atomic term whose lambda bodies are source expressions.
  | CallCC -- ^ The call/cc primitive; 'atomify' expands it into a two-parameter lambda.
  deriving (Show)
-- | Expressions of the source (direct-style) language.
data Expr
  = AExpr AExpr -- ^ An atomic expression.
  | Begin (NonEmpty Expr) -- ^ Sequence of one or more expressions; only the last one's value is passed on.
  | If Expr Expr Expr -- ^ Condition, then-branch, else-branch.
  | Set Identifier Expr -- ^ Assignment of an expression's value to a variable.
  | Let [(Identifier, AExpr)] Expr -- ^ Bindings of names to atomic expressions, and the body they scope over.
  | Prim PrimOp [Expr] -- ^ Primitive operator applied to argument expressions.
  | Call Expr [Expr] -- ^ Function expression applied to argument expressions.
  deriving (Show)
-- | Primitive operators that can appear in 'Prim' and 'CPrim' nodes.
data PrimOp = Plus | Minus | Slash | Star | Equals deriving (Show)
-- | Atomic expressions of the CPS language.
data CAExpr
  = CAtomic (Atomic CExpr) -- ^ An atomic term whose lambda bodies are CPS expressions.
  | CPrim PrimOp [CAExpr] -- ^ Primitive operator applied to atomic arguments.
  deriving (Show)
-- | Complex expressions of the CPS language, as produced by 'transformK'
-- and 'transformC'.
data CExpr
  = CIf CAExpr CExpr CExpr -- ^ Atomic condition, then-branch, else-branch.
  | CSetThen Identifier CAExpr CExpr -- ^ Assign the atomic value to the variable, then continue with the given expression.
  | CLet [(Identifier, CAExpr)] CExpr -- ^ Bindings of names to atomic values, and the body.
  | CCall CAExpr [CAExpr] -- ^ Call of an atomic function with atomic arguments.
  deriving (Show)
-- | Build a CPS-level lambda from its parameter names and its body.
cLambda :: [Identifier] -> CExpr -> CAExpr
cLambda params body = CAtomic (Lambda params body)
-- | Build a CPS-level variable reference from a name.
cVariable :: Identifier -> CAExpr
cVariable name = CAtomic (Variable name)
-- | The void value as a CPS-level atomic expression.
cVoid :: CAExpr
cVoid = CAtomic Void
-- | Monad in which the transformation runs. The 'Int' state is the counter
-- that 'fresh' reads and increments to generate new names.
type Transform = State Int
-- | Convert a source expression to continuation-passing style.
--
-- The expression is first alphatized starting from an empty renaming, and
-- the result is then converted with the variable @return@ as its
-- continuation. The name counter starts at 0.
transform :: Expr -> CExpr
transform expr = State.evalState pipeline 0
  where
    pipeline = do
      renamed <- alphatize Map.empty expr
      transformC renamed (cVariable "return")
-- | Rename every bound variable in an expression to a fresh name.
--
-- @mapping@ takes each variable currently in scope to its fresh name.
-- Variable references are rewritten through it by 'alphatizeAtomic', which
-- calls 'error' for a variable that is not in the mapping.
--
-- 'Let' is non-recursive: each right-hand side is renamed in the scope of
-- the bindings that precede it, so a binding never sees its own name. The
-- body is renamed with all bindings in scope.
alphatize :: Map.Map Identifier Identifier -> Expr -> Transform Expr
alphatize mapping = \case
  AExpr a -> AExpr <$> alphatizeA mapping a
  Begin es -> Begin <$> traverse (alphatize mapping) es
  If cond then_ else_ ->
    If
      <$> alphatize mapping cond
      <*> alphatize mapping then_
      <*> alphatize mapping else_
  Set var e ->
    -- The lazy pattern is safe: for a 'Variable' input, 'alphatizeAtomic'
    -- either returns a 'Variable' or fails with 'error'.
    (\ ~(Variable v) -> Set v)
      <$> alphatizeAtomic mapping (Variable var)
      <*> alphatize mapping e
  Prim op es -> Prim op <$> traverse (alphatize mapping) es
  Call f es -> Call <$> alphatize mapping f <*> traverse (alphatize mapping) es
  Let bindings body -> do
    (mapping', bindings') <- foldM renameBinding (mapping, []) bindings
    body' <- alphatize mapping' body
    -- The fold accumulates bindings in reverse; restore source order.
    pure $ Let (reverse bindings') body'
  where
    -- Rename one binding. The right-hand side is renamed under @m@, the
    -- scope *before* this binding is added, so it cannot refer to itself.
    renameBinding (m, bs) (v, e) = do
      v' <- fresh v
      e' <- alphatizeA m e
      pure (Map.insert v v' m, (v', e') : bs)
-- | Alphatize an atomic source expression under the given renaming.
-- 'CallCC' contains no variables and is returned unchanged.
alphatizeA :: Map.Map Identifier Identifier -> AExpr -> Transform AExpr
alphatizeA mapping aexpr = case aexpr of
  CallCC -> pure CallCC
  Atomic atom -> fmap Atomic (alphatizeAtomic mapping atom)
-- | Alphatize an atomic term under the given renaming.
--
-- Literals and 'Void' are returned unchanged. A variable is replaced by its
-- entry in the mapping; a variable with no entry aborts with 'error'. A
-- lambda gets fresh names for all its parameters, and its body is renamed
-- with those names shadowing any existing entries.
alphatizeAtomic :: Map.Map Identifier Identifier -> Atomic Expr -> Transform (Atomic Expr)
alphatizeAtomic mapping atom = case atom of
  Lambda params body -> do
    params' <- traverse fresh params
    -- 'Map.union' is left-biased, so the parameters shadow outer names.
    let scope = Map.fromList (zip params params') `Map.union` mapping
    Lambda params' <$> alphatize scope body
  Variable name ->
    maybe (unbound name) (pure . Variable) (Map.lookup name mapping)
  Boolean flag -> pure (Boolean flag)
  Number value -> pure (Number value)
  String text -> pure (String text)
  Void -> pure Void
  where
    unbound name =
      error $ "Variable " <> Text.unpack name <> " not found in mapping"
-- | CPS-convert an expression against a Haskell-level continuation.
--
-- @k@ receives the atomic CPS expression holding the value of @expr@ and
-- produces the rest of the program. For 'If' and for the cases delegated to
-- 'transformC', @k@ is first reified into a one-parameter CPS lambda so its
-- output is generated only once.
transformK :: Expr -> (CAExpr -> Transform CExpr) -> Transform CExpr
transformK expr k = case expr of
  AExpr _ -> k =<< atomify expr
  Begin (only :| []) -> transformK only k
  Begin (e :| next : rest) ->
    -- The value of every expression but the last is discarded.
    transformK e (const (transformK (Begin (next :| rest)) k))
  If cond then_ else_ -> do
    joinPoint <- reifyK
    transformK cond $ \test -> do
      onTrue <- transformC then_ joinPoint
      onFalse <- transformC else_ joinPoint
      pure (CIf test onTrue onFalse)
  Set var e ->
    transformK e $ \value -> CSetThen var value <$> k cVoid
  Let bindings body -> do
    values <- traverse (atomify . AExpr . snd) bindings
    CLet (zip (map fst bindings) values) <$> transformK body k
  _ -> reifyK >>= transformC expr
  where
    -- Turn @k@ into a CPS lambda taking the value as a fresh variable.
    reifyK = do
      v <- fresh "v"
      cLambda [v] <$> k (cVariable v)
-- | CPS-convert an expression against a syntactic continuation.
--
-- @cont@ is an atomic CPS expression; the generated code calls it with the
-- value of @expr@. For 'Call', @cont@ is appended as the last argument of
-- the call.
transformC :: Expr -> CAExpr -> Transform CExpr
transformC expr cont = case expr of
  AExpr _ -> do
    atom <- atomify expr
    pure (CCall cont [atom])
  Begin (only :| []) -> transformC only cont
  Begin (e :| next : rest) ->
    -- The value of every expression but the last is discarded.
    transformK e (const (transformC (Begin (next :| rest)) cont))
  If cond then_ else_ -> do
    -- Bind the continuation to a fresh variable so both branches refer to
    -- it by name rather than each containing a copy of @cont@.
    k <- fresh "k"
    let kVar = cVariable k
    body <- transformK cond $ \test ->
      CIf test <$> transformC then_ kVar <*> transformC else_ kVar
    pure (CCall (cLambda [k] body) [cont])
  Set var e ->
    transformK e $ \value ->
      pure (CSetThen var value (CCall cont [cVoid]))
  Let bindings body -> do
    values <- traverse (atomify . AExpr . snd) bindings
    CLet (zip (map fst bindings) values) <$> transformC body cont
  Prim op args ->
    transformManyK args $ \args' ->
      pure (CCall cont [CPrim op args'])
  Call f args ->
    transformK f $ \f' ->
      transformManyK args $ \args' ->
        pure (CCall f' (args' ++ [cont]))
-- | CPS-convert a list of expressions from left to right with 'transformK',
-- then pass the resulting atomic expressions, in the same order, to @k@.
transformManyK :: [Expr] -> ([CAExpr] -> Transform CExpr) -> Transform CExpr
transformManyK [] k = k []
transformManyK (e : rest) k =
  transformK e $ \atom ->
    transformManyK rest (k . (atom :))
-- | Convert an atomic source expression to an atomic CPS expression.
--
-- A lambda gains a fresh continuation parameter at the end of its parameter
-- list, and its body is converted with 'transformC' against that parameter.
-- 'CallCC' becomes a lambda of a function @f@ and a continuation @cc@ that
-- calls @f@ with an escape procedure and @cc@; the escape procedure passes
-- its argument to @cc@ and ignores its own continuation.
--
-- Calls 'error' when given anything other than an 'AExpr'.
atomify :: Expr -> Transform CAExpr
atomify expr = case expr of
  AExpr CallCC -> do
    f <- fresh "f"
    cc <- fresh "cc"
    x <- fresh "v"
    k <- fresh "k"
    let escape = cLambda [x, k] (CCall (cVariable cc) [cVariable x])
    pure (cLambda [f, cc] (CCall (cVariable f) [escape, cVariable cc]))
  AExpr (Atomic atom) -> case atom of
    Lambda params body -> do
      k <- fresh "k"
      body' <- transformC body (cVariable k)
      pure (cLambda (params ++ [k]) body')
    Variable name -> pure (cVariable name)
    Boolean flag -> pure (CAtomic (Boolean flag))
    Number value -> pure (CAtomic (Number value))
    String text -> pure (CAtomic (String text))
    Void -> pure cVoid
  _ -> error $ "not an AExpr: " <> show expr
-- | Generate a new identifier from a prefix and the current counter value,
-- then increment the counter.
--
-- The result has the form @prefix_N@. The non-digit separator makes the
-- trailing run of digits exactly the counter value, so two calls can never
-- return the same name. Without it, prefix @v1@ at counter 0 and prefix @v@
-- at counter 10 would both yield @v10@.
fresh :: Identifier -> Transform Identifier
fresh prefix = do
  i <- get
  put (i + 1)
  pure $ prefix <> "_" <> Text.pack (show i)