{- stack
ghci
--package containers
--package megaparsec
--package parser-combinators
--package mtl
--package lifted-base
--package transformers-base
--package pretty-simple
--package pqueue
--package clock
--ghc-options "-Wall -Wno-name-shadowing"
-}
-- start snippet imports
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}
module CoInterpreter where
import Control.Concurrent (forkIO, threadDelay)
import Control.Concurrent.MVar.Lifted
import Control.Monad (unless, void, when)
import Control.Monad.Base (MonadBase, liftBase)
import Control.Monad.Combinators.Expr (Operator (..), makeExprParser)
import Control.Monad.Cont (ContT, MonadCont, callCC, runContT)
import Control.Monad.Except (ExceptT, MonadError (..), runExceptT)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.State.Strict (MonadState, StateT, evalStateT)
import qualified Control.Monad.State.Strict as State
import Data.Foldable (for_, traverse_)
import Data.IORef.Lifted
import qualified Data.Map.Strict as Map
import Data.Maybe (fromMaybe)
import qualified Data.PQueue.Prio.Min as PQ
import Data.Time.Clock.POSIX (getPOSIXTime)
import Data.Void (Void)
import System.Clock (Clock (Monotonic), fromNanoSecs, getTime, TimeSpec)
import System.Environment (getArgs, getProgName)
import System.IO (hPutStrLn, stderr)
import Text.Megaparsec hiding (runParser)
import Text.Megaparsec.Char
import qualified Text.Megaparsec.Char.Lexer as L
import Text.Pretty.Simple
( CheckColorTty (..),
OutputOptions (..),
defaultOutputOptionsNoColor,
pPrintOpt,
)
-- end snippet imports
-- start snippet ast-expression
-- | Expressions of the toy language, as produced by the parser.
data Expr
  = LNull                      -- ^ the @null@ literal
  | LBool Bool                 -- ^ @true@ / @false@ literals
  | LStr String                -- ^ string literal
  | LNum Integer               -- ^ integer literal
  | Variable Identifier        -- ^ variable reference
  | Binary BinOp Expr Expr     -- ^ binary operator applied to two operands
  | Call Expr [Expr]           -- ^ callee followed by argument expressions
  | Lambda [Identifier] [Stmt] -- ^ anonymous function: parameters and body
  | Receive Expr               -- ^ @<- e@: receive from a channel
  deriving (Show, Eq)

-- | Names of variables, functions and parameters.
type Identifier = String

-- | Binary operators available in expressions.
data BinOp
  = Plus        -- ^ @+@
  | Minus       -- ^ @-@
  | Slash       -- ^ @/@
  | Star        -- ^ @*@
  | Equals      -- ^ @==@
  | NotEquals   -- ^ @!=@
  | LessThan    -- ^ @<@
  | GreaterThan -- ^ @>@
  deriving (Show, Eq)
-- end snippet ast-expression
-- start snippet ast-statement
-- | Statements of the toy language.
data Stmt
  = ExprStmt Expr            -- ^ expression evaluated for its effects
  | VarStmt Identifier Expr  -- ^ @var x = e;@
  | AssignStmt Identifier Expr -- ^ @x = e;@
  | IfStmt Expr [Stmt]       -- ^ @if (c) { ... }@ (there is no else branch)
  | WhileStmt Expr [Stmt]    -- ^ @while (c) { ... }@
  | FunctionStmt Identifier [Identifier] [Stmt]
      -- ^ named function declaration: name, parameters, body
  | ReturnStmt (Maybe Expr)  -- ^ @return;@ or @return e;@
  | YieldStmt                -- ^ @yield;@
  | SpawnStmt Expr           -- ^ @spawn e;@
  | SendStmt Expr Identifier -- ^ @e -> chan;@
  deriving (Show, Eq)

-- | A program is a sequence of top-level statements.
type Program = [Stmt]
-- end snippet ast-statement
-- start snippet basic-parsers
-- | Parsers read 'String' input and use no custom error component.
type Parser = Parsec Void String

-- | Skips whitespace, @//@ line comments and @/* ... */@ block comments.
sc :: Parser ()
sc = L.space space1 (L.skipLineComment "//") (L.skipBlockComment "/*" "*/")

-- | Runs a parser, then skips any trailing whitespace and comments.
lexeme :: Parser a -> Parser a
lexeme p = L.lexeme sc p

-- | Matches the exact string, then skips trailing whitespace and comments.
symbol :: String -> Parser String
symbol s = L.symbol sc s
-- | Matches a keyword that is not immediately followed by an alphanumeric
-- character (so @reserved "if"@ does not accept the prefix of @iffy@).
-- Backtracks on failure.
reserved :: String -> Parser ()
reserved keyword = lexeme $ try $ do
  _ <- string keyword
  notFollowedBy alphaNumChar

-- | Wraps a parser between round brackets / curly braces.
parens, braces :: Parser a -> Parser a
parens p = symbol "(" *> p <* symbol ")"
braces p = symbol "{" *> p <* symbol "}"

semi, identifier, stringLiteral :: Parser String
-- Statement terminator.
semi = symbol ";"
-- A letter followed by any number of letters or digits.
identifier = lexeme $ do
  first <- letterChar
  rest <- many alphaNumChar
  pure (first : rest)
-- A double-quoted string; escape sequences are decoded by 'L.charLiteral'.
stringLiteral = lexeme (char '"' *> manyTill L.charLiteral (char '"'))

-- | A decimal integer with an optional sign.
integer :: Parser Integer
integer = lexeme $ L.signed sc L.decimal
-- end snippet basic-parsers
-- start snippet parser-utils
-- | Runs a parser over the given input, rendering any parse error as a
-- human-readable message.
runParser :: Parser a -> String -> Either String a
runParser parser code =
  either (Left . errorBundlePretty) Right (parse parser "" code)

-- | Pretty-prints a 'Show'able value starting from pretty-simple's no-colour
-- defaults, with compact layout and a two-space indent.
pPrint :: (MonadIO m, Show a) => a -> m ()
pPrint = pPrintOpt CheckColorTty options
  where
    options =
      defaultOutputOptionsNoColor
        { outputOptionsIndentAmount = 2
        , outputOptionsCompact = True
        , outputOptionsCompactParens = True
        }
-- end snippet parser-utils
-- start snippet expr-op
-- | Operator table for 'makeExprParser', from highest to lowest precedence.
-- Every binary operator is left-associative.
operators :: [[Operator Parser Expr]]
operators =
  [receiveOps, multiplicativeOps, additiveOps, comparisonOps, equalityOps]
  where
    receiveOps = [Prefix (Receive <$ symbol "<-")]
    multiplicativeOps = [leftAssoc Slash "/", leftAssoc Star "*"]
    additiveOps =
      [ leftAssoc Plus "+"
        -- a '-' directly followed by '>' is the send arrow, not subtraction
      , InfixL (Binary Minus <$ try (symbol "-" <* notFollowedBy (char '>')))
      ]
    comparisonOps = [leftAssoc LessThan "<", leftAssoc GreaterThan ">"]
    equalityOps = [leftAssoc Equals "==", leftAssoc NotEquals "!="]
    leftAssoc op sym = InfixL (Binary op <$ symbol sym)
-- end snippet expr-op
-- start snippet expr-term
-- | A primary expression followed by any number of argument lists, so
-- @f(1)(2)@ parses as @Call (Call f [1]) [2]@.
term :: Parser Expr
term = primary >>= callSuffixes
  where
    -- Keep wrapping the expression in 'Call' while an argument list follows.
    -- No lookahead is needed: @symbol "("@ fails without consuming input.
    callSuffixes callee =
      ( do
          args <- parens (sepBy expr (symbol ","))
          callSuffixes (Call callee args)
      )
        <|> pure callee
    primary =
      LNull <$ reserved "null"
        <|> LBool True <$ reserved "true"
        <|> LBool False <$ reserved "false"
        <|> LStr <$> stringLiteral
        <|> LNum <$> integer
        <|> Lambda
          <$> (reserved "function" *> parens (sepBy identifier (symbol ",")))
          <*> braces (many stmt)
        <|> Variable <$> identifier
        <|> parens expr
-- end snippet expr-term
-- start snippet expr
-- | Full expression parser: 'term's combined according to the precedence
-- and associativity given by 'operators'.
expr :: Parser Expr
expr = makeExprParser term operators
-- end snippet expr
-- start snippet stmt
-- | Parses a single statement. Alternatives are tried in order: keyword
-- statements come first ('reserved' backtracks on a mismatch), then the
-- assignment and send forms (wrapped in 'try' because they start like an
-- expression), and finally a plain expression statement.
stmt :: Parser Stmt
stmt =
  IfStmt <$> (reserved "if" *> parens expr) <*> braces (many stmt)
    <|> WhileStmt <$> (reserved "while" *> parens expr) <*> braces (many stmt)
    <|> VarStmt <$> (reserved "var" *> identifier) <*> (symbol "=" *> expr <* semi)
    <|> YieldStmt <$ (reserved "yield" <* semi)
    <|> SpawnStmt <$> (reserved "spawn" *> expr <* semi)
    <|> ReturnStmt <$> (reserved "return" *> optional expr <* semi)
    -- `try` covers only "function <name>", so an anonymous
    -- `function (...) {...}` falls through to the expression alternatives
    <|> FunctionStmt
      <$> (try $ reserved "function" *> identifier)
      <*> parens (sepBy identifier $ symbol ",")
      <*> braces (many stmt)
    <|> try (AssignStmt <$> identifier <*> (symbol "=" *> expr <* semi))
    <|> try (SendStmt <$> expr <*> (symbol "->" *> identifier <* semi))
    <|> ExprStmt <$> expr <* semi
-- end snippet stmt
-- start snippet program
-- | Parses a whole source file: leading whitespace/comments, then statements
-- up to the end of input.
program :: Parser Program
program = between sc eof (many stmt)
-- end snippet program
-- start snippet value
-- | Runtime values produced by evaluation.
data Value
  = Null          -- ^ @null@
  | Boolean Bool  -- ^ boolean
  | Str String    -- ^ string
  | Num Integer   -- ^ arbitrary-precision integer
  | Function Identifier [Identifier] [Stmt] Env
      -- ^ user-defined function: name, parameters, body and the environment
      -- captured where it was defined
  | BuiltinFunction Identifier Int ([Expr] -> Interpreter Value)
      -- ^ native function: name, arity and implementation; the
      -- implementation receives the unevaluated argument expressions
  | Chan Channel  -- ^ channel for communication between coroutines
-- end snippet value
-- start snippet value-insts
-- | Renders values the way the @print@ builtin displays them: strings appear
-- without quotes, functions by name only.
instance Show Value where
  show value = case value of
    Null -> "null"
    Boolean flag -> show flag
    Str text -> text
    Num number -> show number
    Function funcName _ _ _ -> "function " <> funcName
    BuiltinFunction funcName _ _ -> "function " <> funcName
    Chan Channel {} -> "Channel"
-- | Structural equality for primitive values. Functions and channels never
-- compare equal to anything, including themselves.
instance Eq Value where
  left == right = case (left, right) of
    (Null, Null) -> True
    (Boolean a, Boolean b) -> a == b
    (Str a, Str b) -> a == b
    (Num a, Num b) -> a == b
    _ -> False
-- end snippet value-insts
-- start snippet coroutine
-- | A suspended computation waiting to be resumed by the scheduler.
data Coroutine a = Coroutine
  { corEnv :: Env
    -- ^ environment restored before the coroutine is resumed
  , corCont :: a -> Interpreter ()
    -- ^ continuation that resumes the computation with a value of type @a@
  , corReady :: MVar TimeSpec
    -- ^ holds a timestamp once the coroutine may run; the scheduler takes
    -- it before resuming
  }

-- | Creates a coroutine that is ready to run immediately.
newCoroutine :: Env -> (a -> Interpreter ()) -> Interpreter (Coroutine a)
newCoroutine env cont = do
  now <- currentSystemTime
  readiness <- newMVar now
  pure (Coroutine env cont readiness)
-- end snippet coroutine
-- start snippet delayed-coroutine
-- | Creates a coroutine that becomes ready only after @millis@ milliseconds.
--
-- A background thread fills the readiness 'MVar' once the delay has
-- elapsed, so the scheduler's 'takeMVar' blocks until then.
newDelayedCoroutine ::
  Integer -> Env -> (a -> Interpreter ()) -> Interpreter (Coroutine a)
newDelayedCoroutine millis env cont = do
  ready <- newEmptyMVar
  -- The ThreadId is not needed; discard it explicitly so -Wall
  -- (-Wunused-do-bind) does not warn.
  void . liftIO . forkIO $ do
    threadDelay $ fromIntegral millis * 1000
    now <- currentSystemTime
    putMVar ready now
  return $ Coroutine env cont ready
-- end snippet delayed-coroutine
-- start snippet env
-- | Variable environment. Each name maps to a mutable cell, so environments
-- (and closures) that share a cell also see assignments made through it.
type Env = Map.Map Identifier (IORef Value)
-- end snippet env
-- start snippet queue
-- | A mutable priority queue keyed by time, paired with the largest key ever
-- inserted (initially the creation time). Dequeuing never lowers that
-- recorded maximum.
type Queue a = IORef (PQ.MinPQueue TimeSpec a, TimeSpec)

-- | Creates an empty queue whose recorded maximum starts at the current time.
newQueue :: MonadBase IO m => m (Queue a)
newQueue = liftBase currentSystemTime >>= \now -> newIORef (PQ.empty, now)
-- end snippet queue
-- start snippet state
-- | Mutable state threaded through the interpreter.
data InterpreterState = InterpreterState
  { isEnv :: Env
    -- ^ environment of the currently running code
  , isCoroutines :: Queue (Coroutine ())
    -- ^ coroutines waiting to be scheduled
  }

-- | Initial state: the builtin environment and an empty coroutine queue.
initInterpreterState :: IO InterpreterState
initInterpreterState = do
  env <- builtinEnv
  queue <- newQueue
  pure InterpreterState {isEnv = env, isCoroutines = queue}
-- end snippet state
-- start snippet builtin-env
-- | The global environment, pre-populated with the builtin functions.
builtinEnv :: IO Env
builtinEnv = do
  cells <- traverse (\(name, value) -> (,) name <$> newIORef value) builtins
  pure (Map.fromList cells)
  where
    builtins =
      [ ("print", BuiltinFunction "print" 1 executePrint)
      , ( "newChannel"
        , BuiltinFunction "newChannel" 0 $ \_ -> Chan <$> newChannel 0
        )
      , ( "newBufferedChannel"
        , BuiltinFunction "newBufferedChannel" 1 executeNewBufferedChannel
        )
      , ("sleep", BuiltinFunction "sleep" 1 executeSleep)
      , ( "getCurrentMillis"
        , BuiltinFunction "getCurrentMillis" 0 executeGetCurrentMillis
        )
      ]
-- end snippet builtin-env
-- start snippet channel
-- | A channel for communication between coroutines.
data Channel = Channel
  { channelSize :: Int
    -- ^ buffer capacity; 0 means unbuffered
  , channelBuffer :: Queue Value
    -- ^ values sent but not yet received
  , channelSendQueue :: Queue (Coroutine (), Value)
    -- ^ blocked senders, each with the value it is sending
  , channelReceiveQueue :: Queue (Coroutine Value)
    -- ^ blocked receivers, resumed with the received value
  }

-- | Creates an empty channel with the given buffer capacity.
newChannel :: Int -> Interpreter Channel
newChannel size = do
  buffer <- newQueue
  sendQueue <- newQueue
  receiveQueue <- newQueue
  pure (Channel size buffer sendQueue receiveQueue)
-- end snippet channel
-- start snippet exception
-- | Non-local exits raised through the interpreter's 'ExceptT' layer.
data Exception
  = Return Value
    -- ^ a @return@ statement unwinding to the enclosing function call
  | RuntimeError String
    -- ^ an error reported to the user
  | CoroutineQueueEmpty
    -- ^ the scheduler found no coroutine to run
-- end snippet exception
-- start snippet interpreter
-- | The interpreter monad. From the outside in: 'ExceptT' carries
-- 'Exception's (runtime errors, @return@ unwinding, an empty coroutine
-- queue); 'ContT' provides 'callCC', which the scheduler uses to capture
-- coroutine continuations; 'StateT' holds the 'InterpreterState'; 'IO' is
-- the base monad.
newtype Interpreter a = Interpreter
  { runInterpreter ::
      ExceptT Exception
        (ContT
          (Either Exception ())
          (StateT InterpreterState IO))
        a
  }
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadIO,
      MonadBase IO,
      MonadState InterpreterState,
      MonadError Exception,
      MonadCont
    )
-- end snippet interpreter
-- start snippet define-env
-- | Binds a new variable in the current environment, shadowing any existing
-- binding with the same name.
defineVar :: Identifier -> Value -> Interpreter ()
defineVar name value =
  State.gets isEnv >>= defineVarEnv name value >>= setEnv

-- | Returns @env@ extended with @name@ bound to a fresh mutable cell that
-- holds @value@.
defineVarEnv :: Identifier -> Value -> Env -> Interpreter Env
defineVarEnv name value env =
  (\cell -> Map.insert name cell env) <$> newIORef value

-- | Replaces the current environment.
setEnv :: Env -> Interpreter ()
setEnv env = State.modify' (\st -> st {isEnv = env})
-- end snippet define-env
-- start snippet lookup-assign-env
-- | Reads the current value of a variable; fails if it is not defined.
lookupVar :: Identifier -> Interpreter Value
lookupVar name = do
  env <- State.gets isEnv
  cell <- findValueRef name env
  readIORef cell

-- | Overwrites an existing variable's cell in place; fails if it is not
-- defined. Every environment sharing that cell observes the new value.
assignVar :: Identifier -> Value -> Interpreter ()
assignVar name value = do
  env <- State.gets isEnv
  cell <- findValueRef name env
  writeIORef cell value
-- end snippet lookup-assign-env
-- start snippet find-value-ref
-- | Finds the mutable cell bound to a name, raising a runtime error for an
-- unknown variable.
findValueRef :: Identifier -> Env -> Interpreter (IORef Value)
findValueRef name env =
  maybe (throw $ "Unknown variable: " <> name) return (Map.lookup name env)

-- | Raises a user-facing runtime error.
throw :: String -> Interpreter a
throw msg = throwError (RuntimeError msg)
-- end snippet find-value-ref
-- start snippet evaluate
-- | Evaluates an expression to a value.
evaluate :: Expr -> Interpreter Value
evaluate expression = case expression of
  LNull -> return Null
  LBool b -> return (Boolean b)
  LStr s -> return (Str s)
  LNum n -> return (Num n)
  Variable name -> lookupVar name
  -- a lambda closes over the environment in effect when it is evaluated
  Lambda params body -> do
    env <- State.gets isEnv
    return (Function "<lambda>" params body env)
  Binary {} -> evaluateBinaryOp expression
  Call {} -> evaluateFuncCall expression
  Receive chanExpr -> do
    chanVal <- evaluate chanExpr
    case chanVal of
      Chan channel -> channelReceive channel
      other -> throw $ "Cannot receive from a non-channel: " <> show other
-- end snippet evaluate
-- start snippet evaluate-binary
-- | Evaluates a 'Binary' expression; the left operand is evaluated first.
--
-- Division uses 'div' (rounds towards negative infinity). Dividing by zero
-- raises a runtime error instead of letting Haskell's 'DivideByZero'
-- exception crash the host program.
evaluateBinaryOp :: Expr -> Interpreter Value
evaluateBinaryOp ~(Binary op leftE rightE) = do
  left <- evaluate leftE
  right <- evaluate rightE
  let errMsg msg = msg <> ": " <> show left <> " and " <> show right
  case (op, left, right) of
    (Plus, Num n1, Num n2) -> pure $ Num $ n1 + n2
    (Plus, Str s1, Str s2) -> pure $ Str $ s1 <> s2
    (Plus, Str s1, _) -> pure $ Str $ s1 <> show right
    (Plus, _, Str s2) -> pure $ Str $ show left <> s2
    (Plus, _, _) -> throw $ errMsg "Cannot add or append"
    (Minus, Num n1, Num n2) -> pure $ Num $ n1 - n2
    (Minus, _, _) -> throw $ errMsg "Cannot subtract non-numbers"
    -- must precede the general division case: `div` by zero is an
    -- uncatchable-by-the-interpreter Haskell exception
    (Slash, Num _, Num 0) -> throw $ errMsg "Cannot divide by zero"
    (Slash, Num n1, Num n2) -> pure $ Num $ n1 `div` n2
    (Slash, _, _) -> throw $ errMsg "Cannot divide non-numbers"
    (Star, Num n1, Num n2) -> pure $ Num $ n1 * n2
    (Star, _, _) -> throw $ errMsg "Cannot multiply non-numbers"
    (LessThan, Num n1, Num n2) -> pure $ Boolean $ n1 < n2
    (LessThan, _, _) -> throw $ errMsg "Cannot compare non-numbers"
    (GreaterThan, Num n1, Num n2) -> pure $ Boolean $ n1 > n2
    (GreaterThan, _, _) -> throw $ errMsg "Cannot compare non-numbers"
    (Equals, _, _) -> pure $ Boolean $ left == right
    (NotEquals, _, _) -> pure $ Boolean $ left /= right
-- end snippet evaluate-binary
-- start snippet evaluate-call
-- | Evaluates a 'Call' expression. Builtins have their arity checked here and
-- receive the unevaluated argument expressions; user-defined functions are
-- delegated to 'evaluateFuncCall''.
evaluateFuncCall :: Expr -> Interpreter Value
evaluateFuncCall ~(Call callee argEs) = do
  calleeVal <- evaluate callee
  case calleeVal of
    BuiltinFunction name arity impl ->
      checkArgCount name argEs arity >> impl argEs
    Function {} -> evaluateFuncCall' calleeVal argEs
    other ->
      throw $ "Cannot call a non-function: " <> show callee <> " is " <> show other

-- | Raises a runtime error unless exactly @arity@ arguments were supplied.
checkArgCount :: Identifier -> [Expr] -> Int -> Interpreter ()
checkArgCount funcName argEs arity =
  unless (argCount == arity) $
    throw $ funcName <> " call expected " <> show arity
      <> " argument(s) but received " <> show argCount
  where
    argCount = length argEs

-- | Builtin @print@: evaluates its single argument, prints it with 'show' and
-- a trailing newline, and returns @null@.
executePrint :: [Expr] -> Interpreter Value
executePrint argEs = do
  value <- evaluate (head argEs)
  liftIO (print value)
  return Null
-- end snippet evaluate-call
-- start snippet execute-new-buffered-channel
-- | Builtin @newBufferedChannel(size)@: creates a channel with the given
-- non-negative buffer capacity.
executeNewBufferedChannel :: [Expr] -> Interpreter Value
executeNewBufferedChannel argEs = do
  sizeVal <- evaluate (head argEs)
  case sizeVal of
    Num size
      | size < 0 -> throw $ "Channel size must be non-negative: " <> show size
      | otherwise -> Chan <$> newChannel (fromIntegral size)
    _ -> throw "newBufferedChannel call expected a number argument"
-- end snippet execute-new-buffered-channel
-- start snippet execute-sleep
-- | Builtin @sleep(millis)@: suspends the current coroutine for the given
-- non-negative number of milliseconds, then returns @null@.
executeSleep :: [Expr] -> Interpreter Value
executeSleep argEs = evaluate (head argEs) >>= \case
  Num n | n >= 0 -> sleep n >> return Null
  Num n -> throw $ "Sleep time must be non-negative: " <> show n
  _ -> throw "sleep call expected a number argument"

-- | Builtin @getCurrentMillis()@: POSIX wall-clock time in whole
-- milliseconds, rounded down.
executeGetCurrentMillis :: [Expr] -> Interpreter Value
executeGetCurrentMillis _ =
  -- 'floor' produces the 'Integer' that 'Num' expects directly; an extra
  -- 'fromIntegral' would leave the intermediate type ambiguous (defaulted,
  -- warned about by -Wtype-defaults under -Wall).
  Num . floor . (* 1000) <$> liftIO getPOSIXTime
-- end snippet execute-sleep
-- start snippet evaluate-func-call
-- | Calls a user-defined 'Function' with the given argument expressions.
--
-- The arguments are evaluated in the caller's environment. The body then runs
-- in the function's definition environment, extended with the function itself
-- (so it can call itself recursively) and with its parameters. A 'Return'
-- exception ends the body with its value; falling off the end yields @null@.
-- The caller's environment is restored afterwards, and also before any other
-- exception raised by the body is re-thrown.
evaluateFuncCall' :: Value -> [Expr] -> Interpreter Value
evaluateFuncCall'
  ~func@(Function funcName params body funcDefEnv) argEs = do
    checkArgCount funcName argEs (length params)
    funcCallEnv <- State.gets isEnv
    setupFuncEnv
    retVal <- executeBody funcCallEnv
    setEnv funcCallEnv
    return retVal
  where
    -- Evaluate arguments first (still in the caller's env), then switch to
    -- the definition env with the function and parameters bound.
    setupFuncEnv = do
      args <- traverse evaluate argEs
      funcDefEnv' <- defineVarEnv funcName func funcDefEnv
      setEnv funcDefEnv'
      for_ (zip params args) $ uncurry defineVar
    executeBody funcCallEnv =
      (traverse_ execute body >> return Null) `catchError` \case
        Return val -> return val
        err -> setEnv funcCallEnv >> throwError err
-- end snippet evaluate-func-call
-- start snippet execute
-- | Executes a single statement for its effects.
execute :: Stmt -> Interpreter ()
execute statement = case statement of
  ExprStmt e -> void (evaluate e)
  VarStmt name e -> evaluate e >>= defineVar name
  AssignStmt name e -> evaluate e >>= assignVar name
  IfStmt cond body -> do
    condVal <- evaluate cond
    when (truthy condVal) $ traverse_ execute body
  WhileStmt cond body ->
    -- re-evaluate the condition before every iteration
    let loop = do
          condVal <- evaluate cond
          when (truthy condVal) $ traverse_ execute body >> loop
     in loop
  ReturnStmt mExpr -> do
    mVal <- traverse evaluate mExpr
    throwError (Return (fromMaybe Null mVal))
  FunctionStmt name params body -> do
    -- the function closes over the environment at its declaration point
    env <- State.gets isEnv
    defineVar name (Function name params body env)
  YieldStmt -> yield
  SpawnStmt e -> spawn e
  SendStmt e chanName -> do
    val <- evaluate e
    chanVal <- evaluate (Variable chanName)
    case chanVal of
      Chan channel -> channelSend val channel
      other -> throw $ "Cannot send to a non-channel: " <> show other
  where
    -- null and false are falsy; every other value is truthy
    truthy Null = False
    truthy (Boolean b) = b
    truthy _ = True
-- end snippet execute
-- start snippet queue-ops
-- | Inserts a value with the given time as its priority and raises the
-- queue's recorded maximum time if needed.
enqueueAt :: TimeSpec -> a -> Queue a -> Interpreter ()
enqueueAt time val queue =
  modifyIORef' queue $ \(pq, maxWakeupTime) ->
    (PQ.insert time val pq, max time maxWakeupTime)

-- | Inserts a value stamped with the current time, so values come out roughly
-- in insertion order.
enqueue :: a -> Queue a -> Interpreter ()
enqueue val queue = currentSystemTime >>= \now -> enqueueAt now val queue

-- | Reads the monotonic clock.
currentSystemTime :: MonadIO m => m TimeSpec
currentSystemTime = liftIO (getTime Monotonic)

-- | Removes and returns the value with the earliest time, or 'Nothing' if the
-- queue is empty. The recorded maximum time is left unchanged.
dequeue :: Queue a -> Interpreter (Maybe a)
dequeue queue = atomicModifyIORef' queue $ \(pq, maxWakeupTime) ->
  case PQ.minView pq of
    Nothing -> ((pq, maxWakeupTime), Nothing)
    Just (val, rest) -> ((rest, maxWakeupTime), Just val)
-- end snippet queue-ops
-- start snippet schedule-delayed-coroutine
-- | Queues a coroutine using an explicit wake-up time as its priority.
scheduleDelayedCoroutine :: TimeSpec -> Coroutine () -> Interpreter ()
scheduleDelayedCoroutine wakeupTime coroutine =
  enqueueAt wakeupTime coroutine =<< State.gets isCoroutines
-- end snippet schedule-delayed-coroutine
-- start snippet coroutine-ops
-- | Queues a coroutine to run as soon as possible (stamped with the current
-- time).
scheduleCoroutine :: Coroutine () -> Interpreter ()
scheduleCoroutine coroutine =
  State.gets isCoroutines >>= enqueue coroutine

-- | Dequeues the earliest coroutine, waits until it is ready, restores its
-- environment and resumes it. Throws 'CoroutineQueueEmpty' when nothing is
-- queued.
runNextCoroutine :: Interpreter ()
runNextCoroutine =
  State.gets isCoroutines >>= dequeue >>= \case
    Nothing -> throwError CoroutineQueueEmpty
    Just Coroutine {..} -> do
      -- Blocks until the readiness MVar is filled (immediately for regular
      -- coroutines, after the delay for delayed ones). The timestamp is not
      -- needed; discard it explicitly to avoid -Wunused-do-bind.
      void $ takeMVar corReady
      setEnv corEnv
      corCont ()
-- end snippet coroutine-ops
-- start snippet yield
-- | Suspends the current computation: its continuation is queued as a new
-- coroutine and the next queued coroutine is resumed.
yield :: Interpreter ()
yield = do
  env <- State.gets isEnv
  callCC $ \resume -> do
    coroutine <- newCoroutine env resume
    scheduleCoroutine coroutine
    runNextCoroutine
-- end snippet yield
-- start snippet spawn
-- | Queues a new coroutine that evaluates the expression and then hands
-- control to the next queued coroutine. The caller continues immediately.
spawn :: Expr -> Interpreter ()
spawn spawned = do
  env <- State.gets isEnv
  let body _ = evaluate spawned >> runNextCoroutine
  newCoroutine env body >>= scheduleCoroutine
-- end snippet spawn
-- start snippet sleep
-- | Suspends the current computation for @millis@ milliseconds. The
-- continuation is queued with the computed wake-up time as its priority and
-- becomes ready only after the delay; meanwhile the next coroutine runs.
sleep :: Integer -> Interpreter ()
sleep millis = do
  now <- currentSystemTime
  env <- State.gets isEnv
  let wakeupTime = now + fromNanoSecs (millis * 1000000)
  callCC $ \resume -> do
    coroutine <- newDelayedCoroutine millis env resume
    scheduleDelayedCoroutine wakeupTime coroutine
    runNextCoroutine
-- end snippet sleep
-- start snippet await-term
-- | Keeps the main computation alive until the coroutine queue is empty.
-- While coroutines remain, it sleeps when 'calcSleepDuration' reports a
-- positive duration and yields otherwise, then checks again.
awaitTermination :: Interpreter ()
awaitTermination = do
  queueRef <- State.gets isCoroutines
  (coroutines, maxWakeupTime) <- readIORef queueRef
  dur <- calcSleepDuration maxWakeupTime
  unless (PQ.null coroutines) $ do
    if dur > 0 then sleep dur else yield
    awaitTermination
-- end snippet await-term
-- start snippet calc-sleep-duration
-- | Whole milliseconds from now until the given time (floored), plus one.
-- A time that has not yet passed gives a positive result; a time in the past
-- gives zero or a negative result.
calcSleepDuration :: TimeSpec -> Interpreter Integer
calcSleepDuration maxWakeupTime = do
  now <- currentSystemTime
  let remainingNanos = fromIntegral (maxWakeupTime - now) :: Integer
  return (1 + remainingNanos `div` 1000000)
-- end snippet calc-sleep-duration
-- start snippet channel-send
-- | Sends a value on a channel.
--
-- * If a receiver is blocked, it is rescheduled to resume with the value.
-- * Otherwise, on a buffered channel with spare capacity, the value is added
--   to the buffer and the sender carries on.
-- * Otherwise (unbuffered channel, or full buffer) the sender blocks: its
--   continuation is queued together with the value and the next coroutine
--   runs.
channelSend :: Value -> Channel -> Interpreter ()
channelSend value Channel {..} = do
  -- A 'Queue' is an IORef of (priority queue, max wake-up time). Calling
  -- 'length' on that pair always gives 1 (the Foldable instance for tuples
  -- counts only the second component), so measure the priority queue itself.
  bufferSize <- PQ.size . fst <$> readIORef channelBuffer
  dequeue channelReceiveQueue >>= \case
    -- there are pending receives
    Just coroutine@Coroutine {..} ->
      scheduleCoroutine $ coroutine {corCont = const $ corCont value}
    -- there are no pending receives and the buffer is not full
    Nothing
      | channelSize > 0 && bufferSize < channelSize ->
          enqueue value channelBuffer
    -- there are no pending receives and the buffer is full
    -- or the channel is unbuffered
    Nothing -> do
      env <- State.gets isEnv
      callCC $ \cont -> do
        coroutine <- newCoroutine env cont
        enqueue (coroutine, value) channelSendQueue
        runNextCoroutine
-- end snippet channel-send
-- start snippet channel-receive
-- | Receives a value from a channel. When neither a buffered value nor a
-- blocked sender is available, the current computation blocks until a later
-- 'channelSend' resumes it with a value.
channelReceive :: Channel -> Interpreter Value
channelReceive Channel {..} = do
  mSend <- dequeue channelSendQueue
  mBufferedValue <- dequeue channelBuffer
  case (mSend, mBufferedValue) of
    -- the channel is unbuffered and there are pending sends:
    -- wake the sender and take its value directly
    (Just (sendCoroutine, sendValue), Nothing) -> do
      scheduleCoroutine sendCoroutine
      return sendValue
    -- the buffer is full and there are pending sends:
    -- take the oldest buffered value and move the sender's value into the
    -- freed buffer slot
    (Just (sendCoroutine, sendValue), Just bufferedValue) -> do
      scheduleCoroutine sendCoroutine
      enqueue sendValue channelBuffer
      return bufferedValue
    -- the buffer is empty and there are no pending sends:
    -- park this continuation in the receive queue and run the next coroutine
    (Nothing, Nothing) -> do
      env <- State.gets isEnv
      callCC $ \receive -> do
        coroutine <- newCoroutine env receive
        enqueue coroutine channelReceiveQueue
        runNextCoroutine
        -- NOTE(review): presumably unreachable, since runNextCoroutine either
        -- throws or jumps into another continuation; confirm before relying
        -- on this value.
        return Null
    -- the buffer is not empty and there are no pending sends
    (Nothing, Just bufferedValue) -> return bufferedValue
-- end snippet channel-receive
-- start snippet interpret
-- | Runs a parsed program, then waits for spawned coroutines to finish.
-- Running out of runnable coroutines counts as normal termination; a
-- top-level @return@ is reported as an error.
interpret :: Program -> IO (Either String ())
interpret stmts = do
  initialState <- initInterpreterState
  let action = traverse_ execute stmts >> awaitTermination
  result <-
    evalStateT (runContT (runExceptT (runInterpreter action)) return) initialState
  return $ case result of
    Left (RuntimeError err) -> Left err
    Left (Return _) -> Left "Cannot return from outside functions"
    Left CoroutineQueueEmpty -> Right ()
    Right _ -> Right ()
-- end snippet interpret
-- start snippet run-file
-- | Parses and runs a source file, reporting parse errors and runtime errors
-- on stderr.
runFile :: FilePath -> IO ()
runFile file = do
  code <- readFile file
  either (hPutStrLn stderr) runProgram (runParser program code)
  where
    runProgram prog =
      interpret prog
        >>= either (\err -> hPutStrLn stderr ("ERROR: " <> err)) return
-- end snippet run-file
-- start snippet main
-- | Entry point: expects exactly one argument, the path of the file to run,
-- and prints usage to stderr otherwise.
main :: IO ()
main = do
  args <- getArgs
  case args of
    [file] -> runFile file
    _ -> do
      progName <- getProgName
      hPutStrLn stderr ("Usage: " <> progName <> " <file>")
-- end snippet main