-- cps-transform.hs

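-- A near-literal port of Matt Might's CPS transformation code
-- (https://matt.might.net/articles/cps-conversion/) to Haskell.
-- Differences from the original: `let` is non-recursive, and an
-- alphatization (alpha-renaming) pass that gives every bound variable a
-- fresh name is added for correctness.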
module CPSTransformer (transform) where

import Control.Monad (foldM, (>=>))
import Control.Monad.State.Strict (MonadState (..), State)
import Control.Monad.State.Strict qualified as State
import Data.List.NonEmpty (NonEmpty (..))
import Data.List.NonEmpty qualified as NEL
import Data.Map.Strict qualified as Map
import Data.Text (Text)
import Data.Text qualified as Text

-- | Variable names, used by both the source and the CPS language.
type Identifier = Text

-- | Atomic terms shared by the source and CPS languages.  The parameter @e@
-- is the type of lambda bodies: 'Expr' for the source language (see
-- 'AExpr') and 'CExpr' for the CPS language (see 'CAExpr').
data Atomic e
  = Lambda [Identifier] e -- ^ Parameter names and body.
  | Variable Identifier
  | Boolean Bool
  | Number Integer
  | String Text
  | Void -- ^ The void value, e.g. the result of an assignment.
  deriving (Show)

-- | Source-language atomic expressions: those that 'atomify' converts
-- directly into a 'CAExpr'.
data AExpr
  = Atomic (Atomic Expr)
  | CallCC -- ^ @call/cc@; 'atomify' expands it into an explicit CPS lambda.
  deriving (Show)

-- | Source-language expressions.
data Expr
  = AExpr AExpr -- ^ An atomic expression.
  | Begin (NonEmpty Expr)
  -- ^ Evaluate in order; the value is that of the last expression.
  | If Expr Expr Expr -- ^ Condition, then-branch, else-branch.
  | Set Identifier Expr
  -- ^ Assign to a variable; the value of the whole expression is 'Void'.
  | Let [(Identifier, AExpr)] Expr
  -- ^ Non-recursive bindings (right-hand sides must be atomic) and a body.
  | Prim PrimOp [Expr] -- ^ Apply a primitive operator to arguments.
  | Call Expr [Expr] -- ^ Apply a function to arguments.
  deriving (Show)

-- | Primitive operators.  NOTE(review): the constructor names suggest
-- @+@, @-@, @/@, @*@ and @=@; this module never interprets them, it only
-- carries them from 'Prim' through to 'CPrim'.
data PrimOp = Plus | Minus | Slash | Star | Equals deriving (Show)

-- | Atomic expressions of the CPS language.
data CAExpr
  = CAtomic (Atomic CExpr) -- ^ Atomic value; lambda bodies are CPS expressions.
  | CPrim PrimOp [CAExpr]
  -- ^ Primitive application over atomic arguments.  It counts as atomic
  -- here: the converter passes it directly as an argument to a continuation.
  deriving (Show)

-- | Expressions of the CPS language.  A 'CCall' can only occur in tail
-- position: call arguments are atomic, and every 'CExpr' sub-term is a
-- branch, the body that follows a binding/assignment, or a lambda body.
data CExpr
  = CIf CAExpr CExpr CExpr -- ^ Branch on an atomic condition.
  | CSetThen Identifier CAExpr CExpr
  -- ^ Assign an atomic value to a variable, then continue with the body.
  | CLet [(Identifier, CAExpr)] CExpr
  -- ^ Bind atomic values, then continue with the body.
  | CCall CAExpr [CAExpr]
  -- ^ Call a function with atomic arguments.  For calls to source-level
  -- functions the continuation is passed as the last argument.
  deriving (Show)

-- | Build a CPS lambda from its parameter names and body.
cLambda :: [Identifier] -> CExpr -> CAExpr
cLambda params body = CAtomic (Lambda params body)

-- | A variable reference in the CPS language.
cVariable :: Identifier -> CAExpr
cVariable name = CAtomic (Variable name)

-- | The CPS language's void value.
cVoid :: CAExpr
cVoid = CAtomic Void

-- | The monad shared by the renaming and CPS passes: the 'Int' state is the
-- counter from which 'fresh' builds unique names.
type Transform = State Int

-- | Convert a source expression to CPS.  The expression is first
-- alpha-renamed, then converted with the free variable @return@ as its
-- top-level continuation.  The fresh-name counter starts at 0.
--
-- Every variable in the input must be bound within it: renaming starts from
-- an empty scope, so a free variable makes the renaming pass call 'error'.
transform :: Expr -> CExpr
transform expr = State.evalState (alphatize Map.empty expr >>= toCPS) 0
  where
    toCPS renamed = transformC renamed (cVariable "return")

-- | Alpha-rename an expression.  Every binder ('Lambda' parameters and 'Let'
-- variables) gets a fresh name from 'fresh', and every variable occurrence is
-- rewritten through @mapping@ (original name -> fresh name).
--
-- Calls 'error' (via 'alphatizeAtomic') on a variable that is not in
-- @mapping@.
alphatize :: Map.Map Identifier Identifier -> Expr -> Transform Expr
alphatize mapping = \case
  AExpr a -> AExpr <$> alphatizeA mapping a
  Begin es -> Begin <$> traverse (alphatize mapping) es
  If cond then_ else_ ->
    If <$> alphatize mapping cond <*> alphatize mapping then_ <*> alphatize mapping else_
  Set var e -> Set <$> renameVar var <*> alphatize mapping e
  Prim op es -> Prim op <$> traverse (alphatize mapping) es
  Call f es -> Call <$> alphatize mapping f <*> traverse (alphatize mapping) es
  Let bindings body -> do
    -- Bindings are non-recursive.  Each right-hand side is renamed in the
    -- scope built so far (@m@), which does NOT yet contain its own variable.
    -- Renaming it after inserting @v@ would turn e.g. @let x = x@ into a
    -- self-reference to the new, unbound name.
    -- NOTE(review): later right-hand sides still see earlier bindings, and the
    -- CPS pass keeps this order in a 'CLet', so the consumer of 'CLet' must
    -- bind sequentially (let*-style) -- confirm.
    (mapping', bindings') <-
      foldM
        ( \(m, bs) (v, e) -> do
            v' <- fresh v
            e' <- alphatizeA m e
            pure (Map.insert v v' m, (v', e') : bs)
        )
        (mapping, [])
        bindings
    body' <- alphatize mapping' body
    pure $ Let (reverse bindings') body'
  where
    -- Rename an assignment target through the same lookup as variable
    -- references, so an unbound target reports the same error.
    renameVar var =
      alphatizeAtomic mapping (Variable var) >>= \case
        Variable v' -> pure v'
        other -> error $ "alphatize: impossible, renaming a variable yielded " <> show other

-- | Alpha-rename an atomic source expression under @mapping@; @call/cc@
-- contains no variables and is returned unchanged.
alphatizeA :: Map.Map Identifier Identifier -> AExpr -> Transform AExpr
alphatizeA mapping aexpr = case aexpr of
  CallCC -> pure CallCC
  Atomic atom -> fmap Atomic (alphatizeAtomic mapping atom)

-- | Alpha-rename an atomic term under @mapping@.
--
-- * Variables are looked up in @mapping@; an unknown variable is an 'error'.
-- * Lambda parameters get fresh names (in order), which shadow any existing
--   entries for the same names while renaming the body.
-- * Literals and 'Void' are returned unchanged.
alphatizeAtomic :: Map.Map Identifier Identifier -> Atomic Expr -> Transform (Atomic Expr)
alphatizeAtomic mapping atom = case atom of
  Lambda params body -> do
    params' <- mapM fresh params
    -- Map's (<>) is a left-biased union, so the new names win.
    let scope = Map.fromList (zip params params') <> mapping
    Lambda params' <$> alphatize scope body
  Variable name ->
    maybe
      (error $ "Variable " <> Text.unpack name <> " not found in mapping")
      (pure . Variable)
      (Map.lookup name mapping)
  Boolean b -> pure (Boolean b)
  Number n -> pure (Number n)
  String s -> pure (String s)
  Void -> pure Void

-- | Convert @expr@ to CPS, handing the atomic term for its value to the
-- meta-level continuation @k@, which builds the rest of the CPS program.
transformK :: Expr -> (CAExpr -> Transform CExpr) -> Transform CExpr
transformK expr k = case expr of
  AExpr _ -> atomify expr >>= k
  -- Use the total 'NEL.nonEmpty' rather than the partial 'NEL.fromList'.
  Begin (e :| es) -> case NEL.nonEmpty es of
    -- Last expression: its value is the value of the whole block.
    Nothing -> transformK e k
    -- Otherwise evaluate @e@ only for its effects, then the rest.
    Just rest -> transformK e $ \_ -> transformK (Begin rest) k
  If cond then_ else_ -> do
    -- Reify @k@ as a one-parameter lambda; the same lambda is used as the
    -- continuation of both branches.
    v <- fresh "v"
    cont' <- cLambda [v] <$> k (cVariable v)
    transformK cond $ \cond' ->
      CIf cond' <$> transformC then_ cont' <*> transformC else_ cont'
  -- An assignment's value is void.
  Set var e -> transformK e $ \e' -> CSetThen var e' <$> k cVoid
  Let bindings body -> do
    let (vs, es) = unzip bindings
    es' <- traverse (atomify . AExpr) es
    CLet (zip vs es') <$> transformK body k
  -- 'Prim' and 'Call': reify @k@ as a lambda and convert with an
  -- object-level continuation.
  _ -> do
    v <- fresh "v"
    cont' <- cLambda [v] <$> k (cVariable v)
    transformC expr cont'

-- | Convert @expr@ to CPS given an atomic, object-level continuation
-- @cont@: the resulting 'CExpr' passes the value of @expr@ to @cont@.
transformC :: Expr -> CAExpr -> Transform CExpr
transformC expr cont = case expr of
  AExpr _ -> atomify expr >>= \e -> pure $ CCall cont [e]
  -- Use the total 'NEL.nonEmpty' rather than the partial 'NEL.fromList'.
  Begin (e :| es) -> case NEL.nonEmpty es of
    -- Last expression: its value goes to @cont@.
    Nothing -> transformC e cont
    -- Otherwise evaluate @e@ only for its effects, then the rest.
    Just rest -> transformK e $ \_ -> transformC (Begin rest) cont
  If cond then_ else_ -> do
    -- Bind @cont@ once to a fresh @k@ through an immediately applied lambda,
    -- so both branches refer to @k@ instead of each embedding @cont@.
    k <- fresh "k"
    lam <-
      cLambda [k]
        <$> transformK
          cond
          ( \cond' ->
              CIf cond'
                <$> transformC then_ (cVariable k)
                <*> transformC else_ (cVariable k)
          )
    pure $ CCall lam [cont]
  -- An assignment's value is void.
  Set var e -> transformK e $ \e' -> pure $ CSetThen var e' (CCall cont [cVoid])
  Let bindings body -> do
    let (vs, es) = unzip bindings
    es' <- traverse (atomify . AExpr) es
    CLet (zip vs es') <$> transformC body cont
  Prim op es -> transformManyK es $ \es' ->
    pure $ CCall cont [CPrim op es']
  -- Source-level functions take their continuation as the last argument.
  Call f es -> transformK f $ \f' -> transformManyK es $ \es' ->
    pure $ CCall f' (es' <> [cont])

-- | Convert a list of expressions left to right with 'transformK' and pass
-- their atomic results, in the original order, to @k@.
transformManyK :: [Expr] -> ([CAExpr] -> Transform CExpr) -> Transform CExpr
transformManyK exprs k = go exprs []
  where
    -- @done@ holds the results converted so far, most recent first.
    go [] done = k (reverse done)
    go (e : rest) done = transformK e $ \e' -> go rest (e' : done)

-- | Convert an atomic source expression into an atomic CPS term.
--
-- * A lambda gains a fresh trailing continuation parameter, and its body is
--   converted with 'transformC' against that parameter.
-- * @call/cc@ becomes the lambda @(λ (f cc) (f (λ (x _) (cc x)) cc))@.
-- * Variables, literals and void map across unchanged.
--
-- Calls 'error' when given a non-atomic expression.
atomify :: Expr -> Transform CAExpr
atomify expr = case expr of
  AExpr CallCC -> callCCLambda
  AExpr (Atomic atom) -> convertAtom atom
  _ -> error $ "not an AExpr: " <> show expr
  where
    callCCLambda = do
      fn <- fresh "f"
      cc <- fresh "cc"
      arg <- fresh "v"
      ignoredK <- fresh "k"
      -- The escape procedure ignores its own continuation and jumps to @cc@.
      let escape = cLambda [arg, ignoredK] (CCall (cVariable cc) [cVariable arg])
      pure $ cLambda [fn, cc] (CCall (cVariable fn) [escape, cVariable cc])
    convertAtom = \case
      Lambda params body -> do
        k <- fresh "k"
        cLambda (params ++ [k]) <$> transformC body (cVariable k)
      Variable name -> pure (cVariable name)
      Boolean b -> pure (CAtomic (Boolean b))
      Number n -> pure (CAtomic (Number n))
      String s -> pure (CAtomic (String s))
      Void -> pure cVoid

-- | Generate a fresh identifier: @prefix@, a @'_'@ separator, and the current
-- counter value.  The counter is then incremented.
--
-- The separator is what makes names unique.  Plain concatenation is ambiguous
-- when @prefix@ ends in a digit: @"k1"@ at counter 3 and @"k"@ at counter 13
-- would both yield @"k13"@, letting two distinct binders capture each other.
-- The counter's digits never contain @'_'@, so the last @'_'@ in a generated
-- name splits it uniquely into prefix and counter.  Distinct counter values
-- therefore always give distinct names.  Generated names also always end in a
-- digit, so they never equal the unsuffixed @"return"@ continuation.
fresh :: Identifier -> Transform Identifier
fresh prefix = do
  i <- get
  -- Force the counter so the state does not accumulate a chain of thunks.
  put $! i + 1
  pure $ prefix <> "_" <> Text.pack (show i)