{- stack
  ghci
  --resolver lts-18.7
  --package containers
  --package megaparsec
  --package parser-combinators
  --package mtl
  --package lifted-base
  --package transformers-base
  --package pretty-simple
  --ghc-options "-Wall -Wno-name-shadowing"
-}

-- start snippet imports
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}

import Control.Monad (unless, void, when)
import Control.Monad.Base (MonadBase)
import Control.Monad.Combinators.Expr (Operator (..), makeExprParser)
import Control.Monad.Cont (ContT, MonadCont, callCC, runContT)
import Control.Monad.Except (ExceptT, MonadError (..), runExceptT)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.State.Strict (MonadState, StateT, evalStateT)
import qualified Control.Monad.State.Strict as State
import Data.Foldable (for_, traverse_)
import Data.IORef.Lifted
import qualified Data.Map.Strict as Map
import Data.Maybe (fromMaybe)
import Data.Sequence (ViewL ((:<)), (|>))
import qualified Data.Sequence as Seq
import Data.Void (Void)
import Text.Megaparsec hiding (runParser)
import Text.Megaparsec.Char
import qualified Text.Megaparsec.Char.Lexer as L
import Text.Pretty.Simple
  ( CheckColorTty (..),
    OutputOptions (..),
    defaultOutputOptionsNoColor,
    pPrintOpt,
  )
-- end snippet imports

-- start snippet ast-expression
-- | Expressions of the source language, as produced by the 'expr' parser.
data Expr
  = LNull -- ^ the @null@ literal
  | LBool Bool -- ^ the @true@ / @false@ literals
  | LStr String -- ^ a double-quoted string literal
  | LNum Integer -- ^ an optionally signed integer literal
  | Variable Identifier -- ^ a reference to a variable
  | Binary BinOp Expr Expr -- ^ a binary operation on two subexpressions
  | Call Identifier [Expr] -- ^ a call @name(arg, ...)@
  | Receive Expr -- ^ @<- e@: receive a value from the channel @e@ evaluates to
  deriving (Show, Eq)

-- | Names of variables, functions and function parameters.
type Identifier = String

-- | Binary operators: @+@, @-@, @==@, @!=@, @<@ and @>@ respectively.
data BinOp = Plus | Minus | Equals | NotEquals | LessThan | GreaterThan
  deriving (Show, Eq)
-- end snippet ast-expression

-- start snippet ast-statement
-- | Statements of the source language, as produced by the 'stmt' parser.
data Stmt
  = ExprStmt Expr -- ^ @e;@ : evaluate @e@ and discard the result
  | VarStmt Identifier Expr -- ^ @var x = e;@ : define a variable
  | AssignStmt Identifier Expr -- ^ @x = e;@ : overwrite an existing variable
  | IfStmt Expr [Stmt] -- ^ @if (c) { ... }@ (there is no @else@ branch)
  | WhileStmt Expr [Stmt] -- ^ @while (c) { ... }@
  | FunctionStmt Identifier [Identifier] [Stmt] -- ^ @function f(a, b) { ... }@
  | ReturnStmt (Maybe Expr) -- ^ @return;@ or @return e;@
  | YieldStmt -- ^ @yield;@ : let other coroutines run
  | SpawnStmt Expr -- ^ @spawn e;@ : evaluate @e@ in a new coroutine
  | SendStmt Expr Identifier -- ^ @e -> ch;@ : send a value on a channel
  deriving (Show, Eq)

-- | A program is a sequence of top-level statements.
type Program = [Stmt]
-- end snippet ast-statement

-- start snippet basic-parsers
-- | Parsers over 'String' input with no custom error component.
type Parser = Parsec Void String

-- | Space consumer: skips whitespace, @//@ line comments and @/* */@ block
-- comments.
sc :: Parser ()
sc = L.space whitespace lineComment blockComment
  where
    whitespace = void spaceChar
    lineComment = L.skipLineComment "//"
    blockComment = L.skipBlockComment "/*" "*/"

-- | Run a parser, then skip any trailing whitespace and comments.
lexeme :: Parser a -> Parser a
lexeme = L.lexeme sc

-- | Match an exact string, then skip trailing whitespace and comments.
symbol :: String -> Parser String
symbol = L.symbol sc

-- | Wrap a parser in parentheses.
parens :: Parser a -> Parser a
parens = between (symbol "(") (symbol ")")

-- | Wrap a parser in curly braces.
braces :: Parser a -> Parser a
braces = between (symbol "{") (symbol "}")

-- | A statement terminator.
semi :: Parser String
semi = symbol ";"

-- | A letter followed by any number of letters or digits. Keywords are not
-- excluded here.
identifier :: Parser String
identifier = lexeme $ do
  first <- letterChar
  rest <- many alphaNumChar
  pure (first : rest)

-- | A double-quoted string literal; escape sequences are decoded by
-- 'L.charLiteral'.
stringLiteral :: Parser String
stringLiteral = lexeme (char '"' *> manyTill L.charLiteral (char '"'))

-- | A decimal integer with an optional @+@/@-@ sign (whitespace is allowed
-- between the sign and the digits).
integer :: Parser Integer
integer = lexeme $ L.signed sc L.decimal
-- end snippet basic-parsers

-- start snippet parser-utils
-- | Run a parser over the whole input, rendering any parse error as a
-- human-readable message.
runParser :: Parser a -> String -> Either String a
runParser parser code =
  either (Left . errorBundlePretty) Right (parse parser "" code)

-- | Pretty-print a value compactly, with a two-space indent.
pPrint :: (MonadIO m, Show a) => a -> m ()
pPrint = pPrintOpt CheckColorTty compactOptions
  where
    compactOptions =
      defaultOutputOptionsNoColor
        { outputOptionsIndentAmount = 2,
          outputOptionsCompactParens = True,
          outputOptionsCompact = True
        }
-- end snippet parser-utils

-- start snippet expr-op
-- | Operator table for 'makeExprParser', listed from highest to lowest
-- precedence. All binary operators are left-associative.
--
-- The receive operator @<-@ may be repeated (e.g. @<- <- ch@ receives a
-- channel from @ch@, then receives from that channel). A plain 'Prefix'
-- entry accepts a prefix operator only once per precedence level, so the
-- repeated applications are collected with 'some' and composed.
operators :: [[Operator Parser Expr]]
operators =
  [ [Prefix $ foldr (.) id <$> some (Receive <$ symbol "<-")],
    [ binary Plus $ symbol "+",
      -- Refuse "->" so the arrow of a send statement is not read as minus.
      binary Minus $ try (symbol "-" <* notFollowedBy (char '>'))
    ],
    [ binary LessThan $ symbol "<",
      binary GreaterThan $ symbol ">"
    ],
    [ binary Equals $ symbol "==",
      binary NotEquals $ symbol "!="
    ]
  ]
  where
    binary op symP = InfixL $ Binary op <$ symP
-- end snippet expr-op

-- start snippet expr-term
-- | Parse an atomic expression: a literal, a call, a variable, or a
-- parenthesised expression.
--
-- Keyword literals are matched as whole words, so identifiers that merely
-- start with a keyword (e.g. @nullable@, @trueCount@) parse as variables or
-- calls instead of a literal followed by leftover input.
term :: Parser Expr
term =
  LNull <$ keyword "null"
    <|> LBool True <$ keyword "true"
    <|> LBool False <$ keyword "false"
    <|> LStr <$> stringLiteral
    <|> LNum <$> integer
    <|> try (Call <$> identifier <*> parens (sepBy expr (char ',' *> sc)))
    <|> Variable <$> identifier
    <|> parens expr
  where
    -- 'try' makes a failed whole-word match consume no input, so the next
    -- alternative is still attempted.
    keyword w = lexeme . try $ string w <* notFollowedBy alphaNumChar
-- end snippet expr-term

-- start snippet expr
-- | Parse an expression: terms (see 'term') combined with the operators in
-- 'operators'.
expr :: Parser Expr
expr = makeExprParser term operators
-- end snippet expr

-- start snippet stmt
-- | Parse a single statement.
--
-- Keyword-led statements are tried first. Statement keywords are matched as
-- whole words, so an identifier that starts with a keyword (e.g. @variable@,
-- @ifCount@, @returnValue@) falls through to the assignment/expression
-- alternatives. Before this fix it was misparsed, or the parse failed after
-- consuming input.
stmt :: Parser Stmt
stmt =
  IfStmt <$> (keyword "if" *> parens expr) <*> braces (many stmt)
    <|> WhileStmt <$> (keyword "while" *> parens expr) <*> braces (many stmt)
    <|> VarStmt <$> (keyword "var" *> identifier) <*> (symbol "=" *> expr <* semi)
    <|> YieldStmt <$ keyword "yield" <* semi
    <|> SpawnStmt <$> (keyword "spawn" *> expr <* semi)
    <|> ReturnStmt <$> (keyword "return" *> optional expr <* semi)
    <|> FunctionStmt
      <$> (keyword "function" *> identifier)
      <*> parens (sepBy identifier (char ',' *> sc))
      <*> braces (many stmt)
    -- Assignment and send both start like an expression, so they must
    -- backtrack fully when they fail.
    <|> try (AssignStmt <$> identifier <*> (symbol "=" *> expr <* semi))
    <|> try (SendStmt <$> expr <*> (symbol "->" *> identifier <* semi))
    <|> ExprStmt <$> expr <* semi
  where
    -- A keyword must not be followed by an identifier character; 'try'
    -- ensures a failed match consumes no input.
    keyword w = lexeme . try $ string w <* notFollowedBy alphaNumChar
-- end snippet stmt

-- start snippet program
-- | Parse a complete program: leading whitespace and comments, any number
-- of statements, then the end of input.
program :: Parser Program
program = between sc eof (many stmt)
-- end snippet program

-- start snippet value
-- | Runtime values produced by evaluating expressions.
data Value
  = Null -- ^ the @null@ value
  | Boolean Bool -- ^ from @true@/@false@ literals and comparisons
  | Str String -- ^ a string
  | Num Integer -- ^ an arbitrary-precision integer
  | Function Identifier [Identifier] [Stmt] Env
  -- ^ a named function: its parameters, body, and the environment captured
  -- when its @function@ statement was executed
  | Chan Channel -- ^ a channel created by the builtin @newChannel()@
-- end snippet value

-- start snippet value-insts
-- | User-facing rendering of values, used by the builtin @print@ and in
-- runtime error messages.
instance Show Value where
  show Null = "null"
  show (Boolean b) = show b
  show (Str s) = s
  show (Num n) = show n
  show (Function name _ _ _) = "function " ++ name
  show (Chan Channel {}) = "Channel"

-- | Structural equality for null, booleans, strings and numbers. Functions
-- and channels never compare equal, not even to themselves, and values of
-- different kinds are never equal.
instance Eq Value where
  lhs == rhs = case (lhs, rhs) of
    (Null, Null) -> True
    (Boolean x, Boolean y) -> x == y
    (Str x, Str y) -> x == y
    (Num x, Num y) -> x == y
    _ -> False
-- end snippet value-insts

-- start snippet coroutine
-- | A suspended coroutine: the interpreter state to restore, and the
-- continuation to invoke with a value of type @a@ when it is resumed.
data Coroutine a = Coroutine InterpreterState (a -> Interpreter ())
-- end snippet coroutine

-- start snippet channel
-- | A channel for passing values between coroutines.
--
-- There is no separate value buffer. A sender that finds no waiting
-- receiver is parked in @channelSendQueue@ together with the value it is
-- sending. A receiver that finds no waiting sender is parked in
-- @channelReceiveQueue@ until a sender resumes it with a value.
data Channel = Channel
  { channelSendQueue :: Queue (Coroutine (), Value),
    channelReceiveQueue :: Queue (Coroutine Value)
  }

-- | Create a channel whose send and receive queues are both empty.
newChannel :: Interpreter Channel
newChannel = do
  sendQueue <- newIORef Seq.empty
  receiveQueue <- newIORef Seq.empty
  pure Channel {channelSendQueue = sendQueue, channelReceiveQueue = receiveQueue}
-- end snippet channel

-- start snippet env
-- | Variable environment. Each name maps to a mutable cell: defining a
-- variable inserts a fresh cell, while assignment writes through the
-- existing cell, so every environment sharing that cell sees the update.
type Env = Map.Map Identifier (IORef Value)
-- end snippet env

-- start snippet state
-- | A mutable FIFO queue: values are appended at the back and taken from
-- the front.
type Queue a = IORef (Seq.Seq a)

-- | Per-coroutine interpreter state. @isEnv@ is the current variable
-- environment. @isCoroutines@ is the run queue of suspended coroutines
-- waiting to be resumed; all saved states share the same queue reference.
data InterpreterState = InterpreterState
  { isEnv :: Env,
    isCoroutines :: Queue (Coroutine ())
  }

-- | Initial state: an empty environment and an empty run queue.
newInterpreterState :: IO InterpreterState
newInterpreterState =
  InterpreterState Map.empty <$> newIORef Seq.empty
-- end snippet state

-- start snippet exception
-- | Non-local exits raised through the interpreter's 'ExceptT' layer.
data Exception
  = Return Value -- ^ a @return@ statement; caught by the enclosing call
  | RuntimeError String -- ^ a fatal runtime error with its message
  | CoroutineQueueEmpty
  -- ^ raised when switching coroutines finds nothing to resume;
  -- 'interpret' treats it as normal termination
-- end snippet exception

-- start snippet interpreter
-- | The interpreter monad.
--
-- * 'ExceptT' carries the non-local exits in 'Exception' (function returns
--   and runtime errors).
-- * 'ContT' supplies 'callCC', used to capture the continuations that make
--   up coroutines.
-- * 'StateT' holds the current 'InterpreterState' (environment and run
--   queue).
-- * 'IO' backs the 'IORef's used for variables and queues.
newtype Interpreter a = Interpreter
  { runInterpreter ::
      ExceptT Exception
        (ContT
            (Either Exception ())
            (StateT InterpreterState IO))
        a
  }
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadIO,
      MonadBase IO,
      MonadState InterpreterState,
      MonadError Exception,
      MonadCont
    )
-- end snippet interpreter

-- start snippet define-env
-- | Bind a name to a fresh cell holding @value@ in the current environment,
-- shadowing any earlier binding of the same name.
defineVar :: Identifier -> Value -> Interpreter ()
defineVar name value =
  State.gets isEnv >>= defineVarEnv name value >>= setEnv

-- | Return @env@ extended with a fresh cell for @name@ holding @value@.
-- The given environment itself is not modified.
defineVarEnv :: Identifier -> Value -> Env -> Interpreter Env
defineVarEnv name value env =
  (\ref -> Map.insert name ref env) <$> newIORef value

-- | Replace the current environment.
setEnv :: Env -> Interpreter ()
setEnv env = State.modify' (\st -> st {isEnv = env})
-- end snippet define-env

-- start snippet lookup-assign-env
-- | Read the current value of a variable; unknown names raise a runtime
-- error.
lookupVar :: Identifier -> Interpreter Value
lookupVar name = do
  env <- State.gets isEnv
  ref <- findValueRef name env
  readIORef ref

-- | Overwrite the cell of an existing variable. This never creates a new
-- binding; unknown names raise a runtime error.
assignVar :: Identifier -> Value -> Interpreter ()
assignVar name value = do
  env <- State.gets isEnv
  ref <- findValueRef name env
  writeIORef ref value
-- end snippet lookup-assign-env

-- start snippet find-value-ref
-- | Look up a variable's cell in the given environment, raising
-- @"Unknown variable: <name>"@ if it is not bound.
findValueRef :: Identifier -> Env -> Interpreter (IORef Value)
findValueRef name env =
  maybe (throw $ "Unknown variable: " <> name) pure (Map.lookup name env)

-- | Abort evaluation with a 'RuntimeError' carrying the given message.
throw :: String -> Interpreter a
throw msg = throwError (RuntimeError msg)
-- end snippet find-value-ref

-- start snippet evaluate
-- | Evaluate an expression to a runtime 'Value'.
--
-- Literals map directly to values, and variables are read from the current
-- environment. Binary operations and calls are delegated to
-- 'evaluateBinaryOp' and 'evaluateFuncCall'. A receive expression (@<- e@)
-- evaluates @e@, which must be a channel, and takes a value from it via
-- 'channelReceive'; this may suspend the current coroutine until a sender
-- arrives.
evaluate :: Expr -> Interpreter Value
evaluate = \case
  LNull -> pure Null
  LBool bool -> pure $ Boolean bool
  LStr str -> pure $ Str str
  LNum num -> pure $ Num num
  Variable v -> lookupVar v
  binary@Binary {} -> evaluateBinaryOp binary
  call@Call {} -> evaluateFuncCall call
  Receive expr ->
    evaluate expr >>= \case
      Chan channel -> channelReceive channel
      val -> throw $ "Cannot receive from a non-channel: " <> show val
-- end snippet evaluate

-- start snippet evaluate-binary
-- | Evaluate a 'Binary' expression. Both operands are always evaluated,
-- left first.
--
-- * @+@ adds two numbers. If either operand is a string it concatenates,
--   rendering the other operand with 'show'.
-- * @-@, @<@ and @>@ require two numbers.
-- * @==@ and @!=@ use the 'Eq' instance of 'Value'.
--
-- Any other combination raises a runtime error that shows both operands.
evaluateBinaryOp :: Expr -> Interpreter Value
evaluateBinaryOp ~(Binary op leftE rightE) = do
  lhs <- evaluate leftE
  rhs <- evaluate rightE
  let failWith msg = throw $ msg <> ": " <> show lhs <> " and " <> show rhs
      compareNums cmp = case (lhs, rhs) of
        (Num a, Num b) -> pure . Boolean $ cmp a b
        _ -> failWith "Cannot compare non-numbers"
  case op of
    Plus -> case (lhs, rhs) of
      (Num a, Num b) -> pure . Num $ a + b
      (Str a, Str b) -> pure . Str $ a <> b
      (Str a, _) -> pure . Str $ a <> show rhs
      (_, Str b) -> pure . Str $ show lhs <> b
      _ -> failWith "Cannot add or append"
    Minus -> case (lhs, rhs) of
      (Num a, Num b) -> pure . Num $ a - b
      _ -> failWith "Cannot subtract"
    LessThan -> compareNums (<)
    GreaterThan -> compareNums (>)
    Equals -> pure . Boolean $ lhs == rhs
    NotEquals -> pure . Boolean $ lhs /= rhs
-- end snippet evaluate-binary

-- start snippet evaluate-call
-- | Evaluate a 'Call' expression.
--
-- @newChannel()@ and @print(e)@ are builtins and are checked before user
-- bindings of the same name. Any other name must be bound to a 'Function'
-- value, which is invoked through 'evaluateFuncCall''.
--
-- Both builtins validate their argument count, matching the arity check
-- applied to user-defined functions. Previously, arguments passed to
-- @newChannel@ were silently dropped without being evaluated.
evaluateFuncCall :: Expr -> Interpreter Value
evaluateFuncCall ~(Call funcName argEs) = case funcName of
  "newChannel" -> executeNewChannel argEs
  "print" -> executePrint argEs
  name -> lookupVar name >>= \case
    func@Function {} -> evaluateFuncCall' func argEs
    val -> throw $ "Cannot call a non-function: " <> show val
  where
    executeNewChannel = \case
      [] -> Chan <$> newChannel
      _ -> throw "newChannel must be called with no arguments"

    executePrint = \case
      [expr] -> evaluate expr >>= liftIO . print >> return Null
      _ -> throw "print must be called with exactly one argument"
-- end snippet evaluate-call

-- start snippet evaluate-func-call
-- | Invoke a user-defined 'Function' value with the given argument
-- expressions. Callers pass only 'Function' values (see 'evaluateFuncCall').
--
-- Steps:
--
-- 1. Check the argument count.
-- 2. Evaluate the arguments in the caller's environment.
-- 3. Run the body in the function's definition environment, extended with
--    the function itself and its parameters.
--
-- The result is the value carried by a 'Return' exception, or 'Null' if the
-- body finishes without returning. The caller's environment is restored
-- afterwards, on both the normal and the error path.
evaluateFuncCall' :: Value -> [Expr] -> Interpreter Value
evaluateFuncCall'
    ~func@(Function funcName params body funcDefEnv) argEs = do
  checkArgCount
  funcCallEnv <- State.gets isEnv
  setupFuncEnv
  retVal <- executeBody funcCallEnv
  -- Back to the caller's scope; the 'Return' path also arrives here.
  setEnv funcCallEnv
  return retVal
  where
    checkArgCount = when (length argEs /= length params) $
      throw $ funcName <> " call expected " <> show (length params)
              <> " argument(s) but received " <> show (length argEs)

    setupFuncEnv = do
      -- Arguments are evaluated before the environment is switched, so
      -- they see the caller's variables.
      args <- traverse evaluate argEs
      -- The environment captured by a function statement is taken before the
      -- function's own name is defined, so bind the name here to make
      -- recursive calls resolve.
      funcDefEnv' <- defineVarEnv funcName func funcDefEnv
      setEnv funcDefEnv'
      for_ (zip params args) $ uncurry defineVar

    executeBody funcCallEnv =
      (traverse_ execute body >> return Null) `catchError` \case
        Return val -> return val
        -- Restore the caller's scope before propagating any other error.
        err -> setEnv funcCallEnv >> throwError err
-- end snippet evaluate-func-call

-- start snippet execute
-- | Execute a single statement for its effects.
--
-- Conditions are truthy unless they are 'Null' or @false@. A @return@
-- statement is signalled by throwing 'Return', which the enclosing function
-- call catches. Sending requires the named variable to hold a channel.
execute :: Stmt -> Interpreter ()
execute statement = case statement of
  ExprStmt e -> void (evaluate e)
  VarStmt name e -> defineVar name =<< evaluate e
  AssignStmt name e -> assignVar name =<< evaluate e
  IfStmt condE body -> do
    cond <- evaluate condE
    when (isTruthy cond) (traverse_ execute body)
  WhileStmt condE body -> whileLoop condE body
  ReturnStmt mExpr ->
    traverse evaluate mExpr >>= throwError . Return . fromMaybe Null
  FunctionStmt name params body -> do
    -- The closure captures the environment as it is at this point.
    closureEnv <- State.gets isEnv
    defineVar name (Function name params body closureEnv)
  YieldStmt -> yield
  SpawnStmt e -> spawn (void (evaluate e))
  SendStmt e chan -> do
    val <- evaluate e
    target <- evaluate (Variable chan)
    case target of
      Chan channel -> channelSend val channel
      other -> throw $ "Cannot send to a non-channel: " <> show other
  where
    -- Re-evaluate the condition before every iteration.
    whileLoop condE body = do
      cond <- evaluate condE
      when (isTruthy cond) $ do
        traverse_ execute body
        whileLoop condE body

    isTruthy Null = False
    isTruthy (Boolean b) = b
    isTruthy _ = True
-- end snippet execute

-- start snippet interpret
-- | Run a parsed program to completion.
--
-- Executes every top-level statement, then waits until the run queue is
-- empty. Runtime errors and stray top-level @return@s are reported as
-- 'Left' messages. 'CoroutineQueueEmpty' counts as normal termination.
interpret :: Program -> IO (Either String ())
interpret prog = do
  initialState <- newInterpreterState
  outcome <-
    evalStateT
      (runContT (runExceptT (runInterpreter mainThread)) return)
      initialState
  pure (either report Right outcome)
  where
    mainThread = traverse_ execute prog <* awaitTermination

    report = \case
      RuntimeError err -> Left err
      Return _ -> Left "Cannot return for outside functions"
      CoroutineQueueEmpty -> Right ()
-- end snippet interpret

-- start snippet queue
-- | Append a value to the back of a queue.
enqueue :: a -> Queue a -> Interpreter ()
enqueue val queue = modifyIORef' queue (\vals -> vals |> val)

-- | Remove and return the value at the front of a queue. Returns 'Nothing'
-- and leaves the queue unchanged if it is empty.
dequeue :: Queue a -> Interpreter (Maybe a)
dequeue queue = atomicModifyIORef' queue popFront
  where
    popFront vals = case Seq.viewl vals of
      Seq.EmptyL -> (vals, Nothing)
      val :< rest -> (rest, Just val)
-- end snippet queue

-- start snippet coroutine-ops
-- | Append a suspended coroutine to the back of the run queue.
enqueueCoroutine :: Coroutine () -> Interpreter ()
enqueueCoroutine coroutine = State.gets isCoroutines >>= enqueue coroutine

-- | Switch to the coroutine at the front of the run queue: restore its saved
-- state and invoke its continuation. Throws 'CoroutineQueueEmpty' if no
-- coroutine is waiting.
dequeueCoroutine :: Interpreter ()
dequeueCoroutine =
  State.gets isCoroutines >>= dequeue >>= \case
    Nothing -> throwError CoroutineQueueEmpty
    Just (Coroutine state action) -> State.put state >> action ()

-- | Suspend the current coroutine at the back of the run queue and switch to
-- the one at the front. If the queue was empty, the current coroutine is
-- resumed immediately.
yield :: Interpreter ()
yield = do
  state <- State.get
  callCC $ \k -> do
    -- k is "the rest of the current coroutine"; resuming it returns from
    -- this yield.
    enqueueCoroutine (Coroutine state k)
    dequeueCoroutine

-- | Start @action@ as a new coroutine. The spawner is parked at the back of
-- the run queue, and @action@ runs immediately in the current state. When
-- @action@ finishes, control passes to the next queued coroutine.
spawn :: Interpreter () -> Interpreter ()
spawn action = do
  state <- State.get
  callCC $ \k -> do
    enqueueCoroutine (Coroutine state k)
    action
    -- The spawned coroutine is done; hand control to whoever is next.
    dequeueCoroutine

-- | Keep yielding until the run queue is empty. Coroutines parked on channel
-- queues are not in the run queue, so they do not keep the program alive.
awaitTermination :: Interpreter ()
awaitTermination = do
  finished <- State.gets isCoroutines >>= fmap null . readIORef
  unless finished $ yield >> awaitTermination
-- end snippet coroutine-ops

-- start snippet channel-send
-- | Send a value on a channel.
--
-- If a receiver is parked on the channel, it is moved to the run queue to
-- resume with @value@, and the sender then yields. Otherwise the sender is
-- parked, together with @value@, on the channel's send queue, and control
-- passes to the next runnable coroutine. The sender becomes runnable again
-- when a receiver takes the value in 'channelReceive'.
channelSend :: Value -> Channel -> Interpreter ()
channelSend value Channel {..} =
  dequeue channelReceiveQueue >>= \case
    Just (Coroutine state recieve) -> do
      -- Wrap the receiver's continuation so it resumes with this value.
      enqueueCoroutine $ Coroutine state (const $ recieve value)
      yield
    Nothing -> do
      state <- State.get
      callCC $ \k -> do
        -- Park the sender; k resumes it after a receiver takes the value.
        enqueue (Coroutine state k, value) channelSendQueue
        dequeueCoroutine
-- end snippet channel-send

-- start snippet channel-receive
-- | Receive a value from a channel, suspending the current coroutine.
--
-- If a sender is parked on the channel, its value is taken and the sender
-- is put back on the run queue. The receiver is queued behind it to resume
-- with that value, and control switches away. Otherwise the receiver is
-- parked on the channel's receive queue, and 'channelSend' later resumes it
-- with the sent value.
channelReceive :: Channel -> Interpreter Value
channelReceive Channel {..} =
  dequeue channelSendQueue >>= \case
    Just (coroutine, value) -> do
      enqueueCoroutine coroutine
      state <- State.get
      callCC $ \k -> do
        enqueueCoroutine (Coroutine state (const $ k value))
        dequeueCoroutine
        -- Only here to give this block type Value: dequeueCoroutine either
        -- throws or transfers control to a captured continuation.
        return Null
    Nothing -> do
      state <- State.get
      callCC $ \k -> do
        -- Park the receiver; a sender resumes it via k with the value.
        enqueue (Coroutine state k) channelReceiveQueue
        dequeueCoroutine
        -- Not reached, for the same reason as above.
        return Null
-- end snippet channel-receive

-- start snippet run-file
-- | Parse and run the program in the given file. Parse errors are printed
-- as-is; runtime errors are printed with an @ERROR: @ prefix.
runFile :: FilePath -> IO ()
runFile file = do
  source <- readFile file
  either putStrLn execProgram (runParser program source)
  where
    execProgram prog =
      interpret prog >>= either (putStrLn . ("ERROR: " <>)) return
-- end snippet run-file