{- stack
  ghci
  --package containers
  --package megaparsec
  --package parser-combinators
  --package mtl
  --package lifted-base
  --package transformers-base
  --package pretty-simple
  --ghc-options "-Wall -Wno-name-shadowing"
-}

-- start snippet imports
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}
module CoInterpreter where

import Control.Monad (unless, void, when)
import Control.Monad.Base (MonadBase)
import Control.Monad.Combinators.Expr (Operator (..), makeExprParser)
import Control.Monad.Cont (ContT, MonadCont, callCC, runContT)
import Control.Monad.Except (ExceptT, MonadError (..), runExceptT)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.State.Strict (MonadState, StateT, evalStateT)
import qualified Control.Monad.State.Strict as State
import Data.Foldable (for_, traverse_)
import Data.IORef.Lifted
import qualified Data.Map.Strict as Map
import Data.Maybe (fromMaybe)
import Data.Sequence (ViewL ((:<)), (|>))
import qualified Data.Sequence as Seq
import Data.Void (Void)
import System.IO (hPutStrLn, stderr)
import Text.Megaparsec hiding (runParser)
import Text.Megaparsec.Char
import qualified Text.Megaparsec.Char.Lexer as L
import Text.Pretty.Simple
  ( CheckColorTty (..),
    OutputOptions (..),
    defaultOutputOptionsNoColor,
    pPrintOpt,
  )
-- end snippet imports

-- start snippet ast-expression
-- | Abstract syntax of expressions.
data Expr
  = LNull -- ^ The @null@ literal.
  | LBool Bool -- ^ A boolean literal (@true@ / @false@).
  | LStr String -- ^ A double-quoted string literal.
  | LNum Integer -- ^ An integer literal.
  | Variable Identifier -- ^ A reference to a variable by name.
  | Binary BinOp Expr Expr -- ^ A binary operation with its left and right operands.
  | Call Identifier [Expr] -- ^ A call of a named function with argument expressions.
  | Receive Expr -- ^ A receive (@<-@ prefix) from the channel the operand evaluates to.
  deriving (Show, Eq)

-- | Names of variables, functions and parameters.
type Identifier = String

-- | Binary operators supported in expressions.
data BinOp = Plus | Minus | Div | Equals | NotEquals | LessThan | GreaterThan
  deriving (Show, Eq)
-- end snippet ast-expression

-- start snippet ast-statement
-- | Abstract syntax of statements.
data Stmt
  = ExprStmt Expr -- ^ An expression evaluated for its effects.
  | VarStmt Identifier Expr -- ^ @var name = expr;@ — defines a new variable.
  | AssignStmt Identifier Expr -- ^ @name = expr;@ — assigns to an existing variable.
  | IfStmt Expr [Stmt] -- ^ @if (cond) { body }@ (there is no else branch).
  | WhileStmt Expr [Stmt] -- ^ @while (cond) { body }@.
  | FunctionStmt Identifier [Identifier] [Stmt] -- ^ Function name, parameter names and body.
  | ReturnStmt (Maybe Expr) -- ^ @return;@ or @return expr;@.
  | YieldStmt -- ^ @yield;@ — hands control to the next scheduled coroutine.
  | SpawnStmt Expr -- ^ @spawn expr;@ — evaluates the expression as a new coroutine.
  | SendStmt Expr Identifier -- ^ @expr -> chan;@ — sends a value to the named channel variable.
  deriving (Show, Eq)

-- | A program is a sequence of top-level statements.
type Program = [Stmt]
-- end snippet ast-statement

-- start snippet basic-parsers
-- | Megaparsec parser over 'String' input with no custom error component.
type Parser = Parsec Void String

-- | Space consumer: skips whitespace characters, @//@ line comments and
-- @/* ... */@ block comments.
sc :: Parser ()
sc =
  L.space
    (void spaceChar)
    (L.skipLineComment "//")
    (L.skipBlockComment "/*" "*/")

-- | Run a parser and then skip any trailing whitespace/comments via 'sc'.
lexeme :: Parser a -> Parser a
lexeme p = L.lexeme sc p

-- | Parse the given literal text and then skip trailing whitespace/comments
-- via 'sc'.
symbol :: String -> Parser String
symbol text = L.symbol sc text

-- | Wrap a parser so that it must be enclosed in @( )@ or @{ }@ respectively.
parens, braces :: Parser a -> Parser a
parens inner = symbol "(" *> inner <* symbol ")"
braces inner = symbol "{" *> inner <* symbol "}"

-- | Statement terminator @;@.
semi :: Parser String
semi = symbol ";"

-- | An identifier: a letter followed by any number of alphanumeric characters.
identifier :: Parser String
identifier = lexeme $ do
  firstChar <- letterChar
  restChars <- many alphaNumChar
  pure (firstChar : restChars)

-- | A double-quoted string literal; the characters are read with
-- 'L.charLiteral' up to the closing quote.
stringLiteral :: Parser String
stringLiteral = lexeme (char '"' *> manyTill L.charLiteral (char '"'))

-- | An optionally signed decimal integer; whitespace is allowed between the
-- sign and the digits, and trailing whitespace is skipped.
integer :: Parser Integer
integer = L.signed sc L.decimal <* sc
-- end snippet basic-parsers

-- start snippet parser-utils
-- | Run a parser over the given source text (with an empty source name),
-- returning either a pretty-printed parse error or the parsed result.
runParser :: Parser a -> String -> Either String a
runParser parser code =
  either (Left . errorBundlePretty) Right (parse parser "" code)

-- | Pretty-print any 'Show'able value using pretty-simple's no-colour
-- defaults, with a 2-space indent and compact output enabled.
pPrint :: (MonadIO m, Show a) => a -> m ()
pPrint = pPrintOpt CheckColorTty printOptions
  where
    printOptions =
      defaultOutputOptionsNoColor
        { outputOptionsIndentAmount = 2,
          outputOptionsCompact = True,
          outputOptionsCompactParens = True
        }
-- end snippet parser-utils

-- start snippet expr-op
-- | Operator table for 'makeExprParser', highest precedence first:
-- prefix receive (@<-@), then @/@, then @+@ and @-@, then @<@ and @>@,
-- then @==@ and @!=@. All binary operators are left-associative.
operators :: [[Operator Parser Expr]]
operators =
  [ [Prefix (Receive <$ symbol "<-")],
    [infixLeft Div (symbol "/")],
    [infixLeft Plus (symbol "+"), infixLeft Minus minusSign],
    [infixLeft LessThan (symbol "<"), infixLeft GreaterThan (symbol ">")],
    [infixLeft Equals (symbol "=="), infixLeft NotEquals (symbol "!=")]
  ]
  where
    -- Left-associative binary operator recognised by the given parser.
    infixLeft op opParser = InfixL (Binary op <$ opParser)

    -- A minus sign that is not the start of the send arrow "->"; 'try'
    -- makes the parser backtrack when the lookahead check fails.
    minusSign = try (symbol "-" <* notFollowedBy (char '>'))
-- end snippet expr-op

-- start snippet expr-term
-- | Parse an atomic expression: a literal, a function call, a variable
-- reference, or a parenthesised expression.
--
-- The literal keywords @null@, @true@ and @false@ are matched as whole words
-- only, so identifiers that merely start with one of them (for example
-- @truest@ or @nullable@) are parsed as variables or calls rather than as a
-- literal followed by unparseable input.
--
-- A call is attempted before a plain variable; 'try' lets the parser
-- backtrack to the variable alternative when no argument list follows.
term :: Parser Expr
term =
  LNull <$ keyword "null"
    <|> LBool True <$ keyword "true"
    <|> LBool False <$ keyword "false"
    <|> LStr <$> stringLiteral
    <|> LNum <$> integer
    <|> try (Call <$> identifier <*> parens (sepBy expr (char ',' *> sc)))
    <|> Variable <$> identifier
    <|> parens expr
  where
    -- Match a reserved word that is not immediately followed by another
    -- identifier character; consumes nothing when it fails.
    keyword :: String -> Parser String
    keyword word = lexeme . try $ string word <* notFollowedBy alphaNumChar
-- end snippet expr-term

-- start snippet expr
-- | Parse a full expression by combining 'term' with the 'operators'
-- precedence table.
expr :: Parser Expr
expr = makeExprParser term operators
-- end snippet expr

-- start snippet stmt
-- | Parse a single statement.
--
-- Keyword-introduced statements are tried first. Keywords are matched as
-- whole words and without consuming input on failure, so a statement that
-- begins with an identifier sharing a keyword prefix (for example
-- @iffy = 1;@ or @returned;@) falls through to the assignment, send and
-- expression alternatives instead of being misread as a keyword statement.
--
-- Assignment and send are wrapped in 'try' because both begin like an
-- ordinary expression statement and need to backtrack when their
-- distinguishing token (@=@ or @->@) is absent.
stmt :: Parser Stmt
stmt =
  IfStmt <$> (keyword "if" *> parens expr) <*> braces (many stmt)
    <|> WhileStmt <$> (keyword "while" *> parens expr) <*> braces (many stmt)
    <|> VarStmt <$> (keyword "var" *> identifier) <*> (symbol "=" *> expr <* semi)
    <|> YieldStmt <$ keyword "yield" <* semi
    <|> SpawnStmt <$> (keyword "spawn" *> expr <* semi)
    <|> ReturnStmt <$> (keyword "return" *> optional expr <* semi)
    <|> FunctionStmt
      <$> (keyword "function" *> identifier)
      <*> parens (sepBy identifier (char ',' *> sc))
      <*> braces (many stmt)
    <|> try (AssignStmt <$> identifier <*> (symbol "=" *> expr <* semi))
    <|> try (SendStmt <$> expr <*> (symbol "->" *> identifier <* semi))
    <|> ExprStmt <$> expr <* semi
  where
    -- Match a reserved word that is not immediately followed by another
    -- identifier character; consumes nothing when it fails.
    keyword :: String -> Parser String
    keyword word = lexeme . try $ string word <* notFollowedBy alphaNumChar
-- end snippet stmt

-- start snippet program
-- | Parse a whole program: leading whitespace/comments, any number of
-- statements, then end of input.
program :: Parser Program
program = between sc eof (many stmt)
-- end snippet program

-- start snippet value
-- | Runtime values produced by evaluating expressions.
data Value
  = Null -- ^ The null value.
  | Boolean Bool -- ^ A boolean.
  | Str String -- ^ A string.
  | Num Integer -- ^ An integer.
  | Function Identifier [Identifier] [Stmt] Env -- ^ User function: name, parameters, body and the environment captured at definition.
  | BuiltinFunction Identifier Int ([Expr] -> Interpreter Value) -- ^ Built-in: name, arity and an implementation that receives the unevaluated argument expressions.
  | Chan Channel -- ^ A channel.
-- end snippet value

-- start snippet value-insts
-- | Textual rendering of values (used by @print@ and in error messages).
-- Strings are shown without quotes; functions show only their name.
instance Show Value where
  show Null = "null"
  show (Boolean b) = show b
  show (Str s) = s
  show (Num n) = show n
  show (Function name _ _ _) = "function " <> name
  show (BuiltinFunction name _ _) = "function " <> name
  show (Chan Channel {}) = "Channel"

-- | Structural equality on nulls, booleans, strings and numbers. Any other
-- combination (including functions and channels) compares as unequal.
instance Eq Value where
  lhs == rhs = case (lhs, rhs) of
    (Null, Null) -> True
    (Boolean b1, Boolean b2) -> b1 == b2
    (Str s1, Str s2) -> s1 == s2
    (Num n1, Num n2) -> n1 == n2
    _ -> False
-- end snippet value-insts

-- start snippet coroutine
-- | A suspended computation: the environment to restore when it is resumed,
-- together with the continuation to invoke with a value of type @a@.
data Coroutine a = Coroutine Env (a -> Interpreter ())
-- end snippet coroutine

-- start snippet channel
-- | A channel, represented by two queues of suspended coroutines.
data Channel = Channel
  { -- Senders waiting for a receiver, each paired with the value it is sending.
    channelSendQueue :: Queue (Coroutine (), Value),
    -- Receivers waiting for a value to be sent.
    channelReceiveQueue :: Queue (Coroutine Value)
  }

-- | Create a channel whose send and receive queues are both empty.
newChannel :: Interpreter Channel
newChannel = do
  sendQueue <- newIORef Seq.empty
  receiveQueue <- newIORef Seq.empty
  return $ Channel sendQueue receiveQueue
-- end snippet channel

-- start snippet env
-- | Variable environment: maps each name to a mutable reference holding its
-- current value.
type Env = Map.Map Identifier (IORef Value)
-- end snippet env

-- start snippet state
-- | A mutable queue backed by a 'Seq.Seq' stored in an 'IORef'.
type Queue a = IORef (Seq.Seq a)

-- | State threaded through the interpreter.
data InterpreterState = InterpreterState
  { -- The environment currently in effect.
    isEnv :: Env,
    -- Coroutines that are scheduled to run.
    isCoroutines :: Queue (Coroutine ())
  }

-- | Build the initial interpreter state: the built-in environment and an
-- empty coroutine queue.
initInterpreterState :: IO InterpreterState
initInterpreterState = do
  env <- builtinEnv
  coroutines <- newIORef Seq.empty
  return $ InterpreterState env coroutines

-- | The environment containing the built-in functions: @print@ (one
-- argument) and @newChannel@ (no arguments).
builtinEnv :: IO Env
builtinEnv = do
  printRef <- newIORef (BuiltinFunction "print" 1 executePrint)
  newChannelRef <- newIORef (BuiltinFunction "newChannel" 0 makeChannel)
  pure $
    Map.fromList
      [ ("print", printRef),
        ("newChannel", newChannelRef)
      ]
  where
    -- Ignores its (empty) argument list and wraps a fresh channel as a value.
    makeChannel _ = Chan <$> newChannel
-- end snippet state

-- start snippet exception
-- | Non-local exits raised through the interpreter's error channel.
data Exception
  = Return Value -- ^ Raised by a @return@ statement, carrying the returned value.
  | RuntimeError String -- ^ A runtime error with a message.
  | CoroutineQueueEmpty -- ^ Raised when asked to run the next coroutine but none is queued.
-- end snippet exception

-- start snippet interpreter
-- | The interpreter monad: error handling ('ExceptT') layered over
-- continuations ('ContT'), over mutable interpreter state ('StateT'), over
-- 'IO'. The final answer type of the continuation layer is
-- @Either Exception ()@.
newtype Interpreter a = Interpreter
  { runInterpreter ::
      ExceptT Exception
        (ContT
            (Either Exception ())
            (StateT InterpreterState IO))
        a
  }
  -- Instances are derived from the underlying transformer stack.
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadIO,
      MonadBase IO,
      MonadState InterpreterState,
      MonadError Exception,
      MonadCont
    )
-- end snippet interpreter

-- start snippet define-env
-- | Define (or shadow) a variable in the current environment and make the
-- extended environment current.
defineVar :: Identifier -> Value -> Interpreter ()
defineVar name value =
  State.gets isEnv >>= defineVarEnv name value >>= setEnv

-- | Return the given environment extended with a binding of the name to a
-- fresh reference holding the value. The interpreter state is not modified.
defineVarEnv :: Identifier -> Value -> Env -> Interpreter Env
defineVarEnv name value env =
  (\valueRef -> Map.insert name valueRef env) <$> newIORef value

-- | Replace the current environment in the interpreter state.
setEnv :: Env -> Interpreter ()
setEnv env = State.modify' replaceEnv
  where
    replaceEnv st = st {isEnv = env}
-- end snippet define-env

-- start snippet lookup-assign-env
-- | Read the current value of a variable; raises a runtime error if the
-- name is not bound in the current environment.
lookupVar :: Identifier -> Interpreter Value
lookupVar name = do
  env <- State.gets isEnv
  ref <- findValueRef name env
  readIORef ref

-- | Overwrite the value of an existing variable; raises a runtime error if
-- the name is not bound in the current environment.
assignVar :: Identifier -> Value -> Interpreter ()
assignVar name value = do
  env <- State.gets isEnv
  ref <- findValueRef name env
  writeIORef ref value
-- end snippet lookup-assign-env

-- start snippet find-value-ref
-- | Find the reference bound to a name in the given environment, raising a
-- runtime error when the name is unbound.
findValueRef :: Identifier -> Env -> Interpreter (IORef Value)
findValueRef name env =
  maybe (throw $ "Unknown variable: " <> name) return (Map.lookup name env)

-- | Raise a 'RuntimeError' with the given message.
throw :: String -> Interpreter a
throw msg = throwError (RuntimeError msg)
-- end snippet find-value-ref

-- start snippet evaluate
-- | Evaluate an expression to a runtime 'Value'.
--
-- Literals map directly to their value; variables are looked up in the
-- current environment; binary operations and calls are delegated to
-- 'evaluateBinaryOp' and 'evaluateFuncCall'. A 'Receive' evaluates its
-- operand and receives from it if it is a channel, otherwise it raises a
-- runtime error.
evaluate :: Expr -> Interpreter Value
evaluate = \case
  LNull -> pure Null
  LBool bool -> pure $ Boolean bool
  LStr str -> pure $ Str str
  LNum num -> pure $ Num num
  Variable v -> lookupVar v
  binary@Binary {} -> evaluateBinaryOp binary
  call@Call {} -> evaluateFuncCall call
  Receive expr ->
    evaluate expr >>= \case
      Chan channel -> channelReceive channel
      val -> throw $ "Cannot receive from a non-channel: " <> show val
-- end snippet evaluate

-- start snippet evaluate-binary
-- | Evaluate a 'Binary' expression. Must only be called with a 'Binary'
-- node; the lazy pattern exists to keep the match irrefutable for the
-- compiler.
--
-- Both operands are evaluated, left first, before the operator is applied.
--
-- * @+@ adds numbers, concatenates strings, and when exactly one side is a
--   string appends the 'show'n form of the other side.
-- * @-@ and @/@ require numbers; @/@ is integer division ('div').
--   Dividing by zero raises a runtime error instead of letting the
--   underlying arithmetic exception escape the interpreter.
-- * @<@ and @>@ require numbers.
-- * @==@ and @!=@ accept any operands and use the 'Eq' instance of 'Value'.
evaluateBinaryOp :: Expr -> Interpreter Value
evaluateBinaryOp ~(Binary op leftE rightE) = do
  left <- evaluate leftE
  right <- evaluate rightE
  let errMsg msg = msg <> ": " <> show left <> " and " <> show right
  case (op, left, right) of
    (Plus, Num n1, Num n2) -> pure $ Num $ n1 + n2
    (Plus, Str s1, Str s2) -> pure $ Str $ s1 <> s2
    (Plus, Str s1, _) -> pure $ Str $ s1 <> show right
    (Plus, _, Str s2) -> pure $ Str $ show left <> s2
    (Plus, _, _) -> throw $ errMsg "Cannot add or append"

    (Minus, Num n1, Num n2) -> pure $ Num $ n1 - n2
    (Minus, _, _) -> throw $ errMsg "Cannot subtract non-numbers"

    -- `div` by zero throws a Haskell ArithException that would bypass the
    -- interpreter's error handling, so it is reported as a RuntimeError.
    (Div, Num _, Num 0) -> throw $ errMsg "Cannot divide by zero"
    (Div, Num n1, Num n2) -> pure $ Num $ n1 `div` n2
    (Div, _, _) -> throw $ errMsg "Cannot divide non-numbers"

    (LessThan, Num n1, Num n2) -> pure $ Boolean $ n1 < n2
    (LessThan, _, _) -> throw $ errMsg "Cannot compare non-numbers"
    (GreaterThan, Num n1, Num n2) -> pure $ Boolean $ n1 > n2
    (GreaterThan, _, _) -> throw $ errMsg "Cannot compare non-numbers"

    (Equals, _, _) -> pure $ Boolean $ left == right
    (NotEquals, _, _) -> pure $ Boolean $ left /= right
-- end snippet evaluate-binary

-- start snippet evaluate-call
-- | Evaluate a 'Call' expression. Must only be called with a 'Call' node.
--
-- The callee is looked up by name. Built-ins have their argument count
-- checked and then receive the unevaluated argument expressions; user
-- functions are delegated to 'evaluateFuncCall''. Calling anything else is
-- a runtime error.
evaluateFuncCall :: Expr -> Interpreter Value
evaluateFuncCall ~(Call funcName argEs) = do
  callee <- lookupVar funcName
  case callee of
    BuiltinFunction _ arity builtin ->
      checkArgCount funcName argEs arity >> builtin argEs
    Function {} -> evaluateFuncCall' callee argEs
    _ -> throw $ "Cannot call a non-function: " <> show callee

-- | Raise a runtime error unless the number of argument expressions equals
-- the expected arity of the named function.
checkArgCount :: Identifier -> [Expr] -> Int -> Interpreter ()
checkArgCount funcName argEs arity =
  unless (argCount == arity) $ throw message
  where
    argCount = length argEs
    message =
      funcName <> " call expected " <> show arity
        <> " argument(s) but received " <> show argCount

-- | Implementation of the @print@ built-in: evaluates its first argument,
-- prints the resulting value on its own line, and returns 'Null'.
--
-- An empty argument list raises a runtime error rather than crashing on a
-- partial 'head'. The caller is expected to have checked the arity already.
executePrint :: [Expr] -> Interpreter Value
executePrint = \case
  argE : _ -> evaluate argE >>= liftIO . print >> return Null
  [] -> throw "print call expected 1 argument(s) but received 0"
-- end snippet evaluate-call

-- start snippet evaluate-func-call
-- | Call a user-defined function value with the given argument expressions.
-- Must only be called with a 'Function' value.
--
-- Steps, in order: check the argument count; remember the caller's
-- environment; evaluate the arguments in the caller's environment; switch
-- to the function's definition environment extended with the function
-- itself (so the body can refer to it by name) and its parameters; run the
-- body; restore the caller's environment.
--
-- A 'Return' raised by the body supplies the result; a body that finishes
-- without returning yields 'Null'. Any other exception is re-raised after
-- the caller's environment has been restored.
evaluateFuncCall' :: Value -> [Expr] -> Interpreter Value
evaluateFuncCall'
    ~func@(Function funcName params body funcDefEnv) argEs = do
  checkArgCount funcName argEs (length params)
  callerEnv <- State.gets isEnv
  argVals <- traverse evaluate argEs
  defineVarEnv funcName func funcDefEnv >>= setEnv
  traverse_ (uncurry defineVar) (zip params argVals)
  result <- catchError runBody (handleExit callerEnv)
  setEnv callerEnv
  return result
  where
    runBody = traverse_ execute body >> return Null

    handleExit callerEnv = \case
      Return val -> return val
      err -> setEnv callerEnv >> throwError err
-- end snippet evaluate-func-call

-- start snippet execute
-- | Execute a single statement.
--
-- Conditions of @if@ and @while@ are truthy unless they evaluate to 'Null'
-- or @Boolean False@. Their bodies run in the current environment. A
-- @return@ is implemented by raising a 'Return' exception carrying the
-- value ('Null' when no expression is given). A function definition
-- captures the environment in effect at the point of definition.
execute :: Stmt -> Interpreter ()
execute = \case
  ExprStmt expr -> () <$ evaluate expr
  VarStmt name expr -> defineVar name =<< evaluate expr
  AssignStmt name expr -> assignVar name =<< evaluate expr
  IfStmt condE body -> whenTruthy condE $ runBlock body
  loop@(WhileStmt condE body) ->
    whenTruthy condE $ runBlock body >> execute loop
  ReturnStmt mExpr -> do
    retVal <- maybe (pure Null) evaluate mExpr
    throwError (Return retVal)
  FunctionStmt name params body -> do
    closureEnv <- State.gets isEnv
    defineVar name (Function name params body closureEnv)
  YieldStmt -> yield
  SpawnStmt expr -> spawn (void (evaluate expr))
  SendStmt expr chan -> do
    val <- evaluate expr
    target <- evaluate (Variable chan)
    case target of
      Chan channel -> channelSend val channel
      _ -> throw $ "Cannot send to a non-channel: " <> show target
  where
    -- Execute the statements of a block in order.
    runBlock stmts = traverse_ execute stmts

    -- Evaluate the condition and run the action only when it is truthy.
    whenTruthy condE action = do
      cond <- evaluate condE
      when (isTruthy cond) action

    isTruthy = \case
      Null -> False
      Boolean b -> b
      _ -> True
-- end snippet execute

-- start snippet queue
-- | Append a value to the back of the queue.
enqueue :: a -> Queue a -> Interpreter ()
enqueue val queue = modifyIORef' queue (\vals -> vals |> val)

-- | Remove and return the value at the front of the queue, or 'Nothing'
-- (leaving the queue unchanged) when it is empty.
dequeue :: Queue a -> Interpreter (Maybe a)
dequeue queue = atomicModifyIORef' queue popFront
  where
    popFront vals = case Seq.viewl vals of
      val :< rest -> (rest, Just val)
      Seq.EmptyL -> (vals, Nothing)
-- end snippet queue

-- start snippet coroutine-ops
-- | Add a coroutine to the back of the interpreter's coroutine queue.
scheduleCoroutine :: Coroutine () -> Interpreter ()
scheduleCoroutine coroutine = do
  queue <- State.gets isCoroutines
  enqueue coroutine queue

-- | Take the coroutine at the front of the queue, restore its environment
-- and resume it. Raises 'CoroutineQueueEmpty' when nothing is queued.
runNextCoroutine :: Interpreter ()
runNextCoroutine = do
  next <- State.gets isCoroutines >>= dequeue
  case next of
    Just (Coroutine env action) -> setEnv env >> action ()
    Nothing -> throwError CoroutineQueueEmpty
-- end snippet coroutine-ops

-- start snippet yield
-- | Suspend the current computation: capture its continuation, queue it as a
-- coroutine together with the current environment, and run the next queued
-- coroutine.
yield :: Interpreter ()
yield =
  State.gets isEnv >>= \env ->
    callCC $ \resume ->
      scheduleCoroutine (Coroutine env resume) >> runNextCoroutine
-- end snippet yield

-- start snippet spawn
-- | Run an action as a new coroutine: the continuation of the @spawn@ call
-- is queued (with the current environment) to be resumed later, the action
-- runs immediately, and once it finishes the next queued coroutine is run.
spawn :: Interpreter () -> Interpreter ()
spawn action = do
  spawnerEnv <- State.gets isEnv
  callCC $ \resumeSpawner ->
    scheduleCoroutine (Coroutine spawnerEnv resumeSpawner)
      >> action
      >> runNextCoroutine
-- end snippet spawn

-- start snippet await-term
-- | Keep yielding to queued coroutines until the coroutine queue is observed
-- to be empty.
awaitTermination :: Interpreter ()
awaitTermination = do
  queue <- State.gets isCoroutines
  pending <- readIORef queue
  if null pending
    then return ()
    else yield >> awaitTermination
-- end snippet await-term

-- start snippet channel-send
-- | Send a value on a channel.
--
-- If a receiver is already waiting, it is scheduled to resume with the
-- value and the sender yields. Otherwise the sender's continuation is
-- parked on the channel's send queue together with the value, and the next
-- queued coroutine is run.
channelSend :: Value -> Channel -> Interpreter ()
channelSend value Channel {..} = do
  waitingReceiver <- dequeue channelReceiveQueue
  case waitingReceiver of
    Nothing -> do
      senderEnv <- State.gets isEnv
      callCC $ \resumeSender -> do
        enqueue (Coroutine senderEnv resumeSender, value) channelSendQueue
        runNextCoroutine
    Just (Coroutine receiverEnv deliver) -> do
      scheduleCoroutine $ Coroutine receiverEnv (\_ -> deliver value)
      yield
-- end snippet channel-send

-- start snippet channel-receive
-- | Receive a value from a channel.
--
-- If a sender is already waiting, that sender is scheduled to resume, and
-- the receiver is scheduled to resume with the sent value; the next queued
-- coroutine is then run. Otherwise the receiver's continuation is parked on
-- the channel's receive queue and the next queued coroutine is run.
channelReceive :: Channel -> Interpreter Value
channelReceive Channel {..} = do
  pendingSend <- dequeue channelSendQueue
  case pendingSend of
    Nothing -> do
      receiverEnv <- State.gets isEnv
      suspendWith $ \resume ->
        enqueue (Coroutine receiverEnv resume) channelReceiveQueue
    Just (sender, value) -> do
      scheduleCoroutine sender
      receiverEnv <- State.gets isEnv
      suspendWith $ \resume ->
        scheduleCoroutine (Coroutine receiverEnv (\_ -> resume value))
  where
    -- Capture the current continuation, hand it to the registration action,
    -- then switch to the next queued coroutine. The trailing 'Null' only
    -- gives the block the required result type.
    suspendWith register = callCC $ \resume -> do
      register resume
      runNextCoroutine
      return Null
-- end snippet channel-receive

-- start snippet interpret
-- | Run a program from a fresh interpreter state.
--
-- All statements are executed in order, after which the main computation
-- waits for the remaining coroutines via 'awaitTermination'.
--
-- Returns @Left message@ for a runtime error or for a @return@ that escapes
-- to the top level. Running out of queued coroutines
-- ('CoroutineQueueEmpty') counts as normal termination.
interpret :: Program -> IO (Either String ())
interpret program = do
  state <- initInterpreterState
  retVal <-
    flip evalStateT state
      . flip runContT return
      . runExceptT
      . runInterpreter
      $ (traverse_ execute program >> awaitTermination)
  return $ case retVal of
    Left (RuntimeError err) -> Left err
    Left (Return _) -> Left "Cannot return from outside functions"
    Left CoroutineQueueEmpty -> Right ()
    Right () -> Right ()
-- end snippet interpret

-- start snippet run-file
-- | Read, parse and interpret the program in the given file. Parse errors
-- are written to stderr as-is; interpreter errors are written to stderr
-- prefixed with @ERROR: @.
runFile :: FilePath -> IO ()
runFile file = do
  code <- readFile file
  either (hPutStrLn stderr) runProgram (runParser program code)
  where
    runProgram prog =
      interpret prog >>= either (hPutStrLn stderr . ("ERROR: " <>)) return
-- end snippet run-file