{- stack
  ghci
  --resolver lts-17.9
  --package containers
  --package megaparsec
  --package parser-combinators
  --package mtl
  --package lifted-base
  --package transformers-base
  --package pretty-simple
  --ghc-options "-Wall -Wno-name-shadowing"
-}

-- start snippet imports
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}

import Control.Monad (forM_, unless, void, when)
import Control.Monad.Base (MonadBase)
import Control.Monad.Combinators.Expr (Operator (..), makeExprParser)
import Control.Monad.Cont (ContT, MonadCont, callCC, runContT)
import Control.Monad.Except (ExceptT, MonadError (..), runExceptT)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.State.Strict (MonadState, StateT, evalStateT)
import qualified Control.Monad.State.Strict as State
import Data.IORef.Lifted
import qualified Data.Map.Strict as Map
import Data.Maybe (fromMaybe)
import Data.Sequence (ViewL ((:<)), (|>))
import qualified Data.Sequence as Seq
import Data.Void (Void)
import Text.Megaparsec hiding (runParser)
import Text.Megaparsec.Char
import qualified Text.Megaparsec.Char.Lexer as L
import Text.Pretty.Simple
  ( CheckColorTty (..),
    OutputOptions (..),
    defaultOutputOptionsNoColor,
    pPrintOpt,
  )
-- end snippet imports

-- start snippet ast-expression
-- | Expressions of the language, as produced by the 'expr' parser and
-- reduced to a 'Value' by 'evaluate'.
data Expr
  = LNull -- ^ The @null@ literal.
  | LBool Bool -- ^ A @true@ or @false@ literal.
  | LStr String -- ^ A double-quoted string literal.
  | LNum Integer -- ^ An integer literal.
  | Variable Identifier -- ^ A reference to a variable by name.
  | Binary BinOp Expr Expr -- ^ A binary operator with its left and right operands.
  | Call Identifier [Expr] -- ^ A call of a named function with its argument expressions.
  | Receive Expr -- ^ Prefix @<-@: receive from the channel the operand evaluates to.
  deriving (Show, Eq)

-- | Name of a variable, function or parameter.
type Identifier = String

-- | Binary operators, written @+@, @-@, @==@, @!=@, @<@ and @>@ in source
-- (in constructor order).
data BinOp = Plus | Minus | Equals | NotEquals | LessThan | GreaterThan
  deriving (Show, Eq)
-- end snippet ast-expression

-- start snippet ast-statement
-- | Statements of the language, as produced by the 'stmt' parser and run by
-- 'execute'.
data Stmt
  = ExprStmt Expr -- ^ @expr;@: evaluate an expression and discard its value.
  | VarStmt Identifier Expr -- ^ @var name = expr;@: define a variable in the current environment.
  | AssignStmt Identifier Expr -- ^ @name = expr;@: overwrite an already defined variable.
  | IfStmt Expr [Stmt] -- ^ @if (cond) { body }@: there is no else branch.
  | WhileStmt Expr [Stmt] -- ^ @while (cond) { body }@.
  | FunctionStmt Identifier [Identifier] [Stmt] -- ^ @function name(params) { body }@.
  | ReturnStmt (Maybe Expr) -- ^ @return;@ or @return expr;@.
  | YieldStmt -- ^ @yield;@: let another queued coroutine run.
  | SpawnStmt Expr -- ^ @spawn expr;@: evaluate the expression as a new coroutine.
  | SendStmt Expr Identifier -- ^ @expr -> chan;@: send a value to the channel held in the named variable.
  deriving (Show, Eq)

-- | A program is the list of its top-level statements, in source order.
type Program = [Stmt]
-- end snippet ast-statement

-- start snippet basic-parsers
-- | Megaparsec parser over 'String' input with no custom error component.
type Parser = Parsec Void String

-- | Space consumer: skips whitespace, @//@ line comments and @/* */@ block
-- comments.
sc :: Parser ()
sc =
  L.space
    (void spaceChar)
    (L.skipLineComment "//")
    (L.skipBlockComment "/*" "*/")

-- | Run a parser, then skip any trailing whitespace and comments.
lexeme :: Parser a -> Parser a
lexeme parser = L.lexeme sc parser

-- | Match the given literal text, then skip trailing whitespace and comments.
symbol :: String -> Parser String
symbol text = L.symbol sc text

-- | Wrap a parser so that it must be enclosed in @( )@ or @{ }@ respectively.
parens, braces :: Parser a -> Parser a
parens inner = symbol "(" *> inner <* symbol ")"
braces inner = symbol "{" *> inner <* symbol "}"

-- | Token parsers: the statement terminator @;@, an identifier (a letter
-- followed by letters or digits) and a double-quoted string literal.
semi, identifier, stringLiteral :: Parser String
semi = symbol ";"
identifier = lexeme $ do
  firstChar <- letterChar
  restChars <- many alphaNumChar
  pure (firstChar : restChars)
stringLiteral = lexeme (char '"' *> manyTill L.charLiteral (char '"'))

-- | An optionally signed decimal integer token.
integer :: Parser Integer
integer = L.lexeme sc $ L.signed sc L.decimal
-- end snippet basic-parsers

-- start snippet parser-utils
-- | Run a parser over the whole input (with an empty source name), rendering
-- any parse failure as a human-readable message.
runParser :: Parser a -> String -> Either String a
runParser parser code =
  either (Left . errorBundlePretty) Right (parse parser "" code)

-- | Pretty-print a showable value without colour, using a 2-space indent and
-- the compact layout options.
pPrint :: (MonadIO m, Show a) => a -> m ()
pPrint = pPrintOpt CheckColorTty options
  where
    options =
      defaultOutputOptionsNoColor
        { outputOptionsIndentAmount = 2,
          outputOptionsCompact = True,
          outputOptionsCompactParens = True
        }
-- end snippet parser-utils

-- start snippet expr-op
-- | Operator table for 'makeExprParser', listed from highest to lowest
-- precedence: prefix receive, additive, relational, then equality operators.
-- All binary operators are left-associative.
operators :: [[Operator Parser Expr]]
operators = [receiveOps, additiveOps, relationalOps, equalityOps]
  where
    receiveOps = [Prefix (Receive <$ symbol "<-")]
    additiveOps = [infixLeft Plus (symbol "+"), infixLeft Minus minusSign]
    relationalOps =
      [infixLeft LessThan (symbol "<"), infixLeft GreaterThan (symbol ">")]
    equalityOps =
      [infixLeft Equals (symbol "=="), infixLeft NotEquals (symbol "!=")]

    -- A minus sign must not be the start of the send arrow "->".
    minusSign = try (symbol "-" <* notFollowedBy (char '>'))

    infixLeft op symP = InfixL (Binary op <$ symP)
-- end snippet expr-op

-- start snippet expr-term
-- | Parse an atomic expression: a literal, a function call, a variable
-- reference or a parenthesised expression.
--
-- The literal keywords @null@, @true@ and @false@ are only recognised when
-- they are not immediately followed by another identifier character, so
-- names such as @nullable@ or @trueish@ parse as ordinary identifiers.
term :: Parser Expr
term =
  LNull <$ keyword "null"
    <|> LBool True <$ keyword "true"
    <|> LBool False <$ keyword "false"
    <|> LStr <$> stringLiteral
    <|> LNum <$> integer
    -- 'try' lets a plain variable reference be parsed after the call
    -- alternative has already consumed the identifier.
    <|> try (Call <$> identifier <*> parens (sepBy expr (char ',' *> sc)))
    <|> Variable <$> identifier
    <|> parens expr
  where
    -- Match a whole reserved word; backtracks if it is only a prefix of a
    -- longer identifier.
    keyword :: String -> Parser String
    keyword word = lexeme (try (string word <* notFollowedBy alphaNumChar))
-- end snippet expr-term

-- start snippet expr
-- | Parse a full expression: 'term's combined with the operators from
-- 'operators'.
expr :: Parser Expr
expr = makeExprParser term operators
-- end snippet expr

-- start snippet stmt
-- | Parse a single statement.
--
-- Reserved words (@if@, @while@, @var@, @yield@, @spawn@, @return@,
-- @function@) are only recognised when they are not immediately followed by
-- another identifier character. Without that check, @variable = 1;@ would be
-- read as @var iable = 1;@ and @returnValue;@ as @return Value;@.
stmt :: Parser Stmt
stmt =
  IfStmt <$> (keyword "if" *> parens expr) <*> braces (many stmt)
    <|> WhileStmt <$> (keyword "while" *> parens expr) <*> braces (many stmt)
    <|> VarStmt <$> (keyword "var" *> identifier) <*> (symbol "=" *> expr <* semi)
    <|> YieldStmt <$ keyword "yield" <* semi
    <|> SpawnStmt <$> (keyword "spawn" *> expr <* semi)
    <|> ReturnStmt <$> (keyword "return" *> optional expr <* semi)
    <|> FunctionStmt
      <$> (keyword "function" *> identifier)
      <*> parens (sepBy identifier (char ',' *> sc))
      <*> braces (many stmt)
    -- Assignment, send and expression statements can all start with the same
    -- tokens, so the first two backtrack on failure.
    <|> try (AssignStmt <$> identifier <*> (symbol "=" *> expr <* semi))
    <|> try (SendStmt <$> expr <*> (symbol "->" *> identifier <* semi))
    <|> ExprStmt <$> expr <* semi
  where
    -- Match a whole reserved word; backtracks if it is only a prefix of a
    -- longer identifier.
    keyword :: String -> Parser String
    keyword word = lexeme (try (string word <* notFollowedBy alphaNumChar))
-- end snippet stmt

-- start snippet program
-- | Parse a whole program: leading whitespace and comments, any number of
-- statements, then end of input.
program :: Parser Program
program = between sc eof (many stmt)
-- end snippet program

-- start snippet value
-- | Runtime values produced by 'evaluate'.
data Value
  = Null -- ^ The null value.
  | Boolean Bool -- ^ A boolean.
  | Str String -- ^ A string.
  | Num Integer -- ^ An integer.
  | Function Identifier [Identifier] [Stmt] Env -- ^ Function name, parameters, body and the environment in effect when it was defined.
  | Chan Channel -- ^ A channel for passing values between coroutines.
-- end snippet value

-- start snippet value-insts
-- | Textual rendering used by @print@, string concatenation and error
-- messages. Strings are shown without quotes.
instance Show Value where
  show Null = "null"
  show (Boolean b) = show b
  show (Str s) = s
  show (Num n) = show n
  show (Function name _ _ _) = "function " <> name
  show (Chan Channel {}) = "Channel"

-- | Structural equality on null, booleans, strings and numbers. Functions
-- and channels never compare equal, not even to themselves.
instance Eq Value where
  left == right = case (left, right) of
    (Null, Null) -> True
    (Boolean b1, Boolean b2) -> b1 == b2
    (Str s1, Str s2) -> s1 == s2
    (Num n1, Num n2) -> n1 == n2
    _ -> False
-- end snippet value-insts

-- start snippet coroutine
-- | A suspended computation: the interpreter state to restore on resumption,
-- and the continuation to invoke with a value of type @a@.
data Coroutine a = Coroutine InterpreterState (a -> Interpreter ())
-- end snippet coroutine

-- start snippet channel
-- | A channel between coroutines.
--
-- 'channelSendQueue' holds suspended senders paired with the value each is
-- sending; 'channelReceiveQueue' holds suspended receivers waiting for a
-- value.
data Channel = Channel
  { channelSendQueue :: Queue (Coroutine (), Value),
    channelReceiveQueue :: Queue (Coroutine Value)
  }

-- | Create a channel with no waiting senders or receivers.
newChannel :: Interpreter Channel
newChannel = do
  sendQueue <- newIORef Seq.empty
  receiveQueue <- newIORef Seq.empty
  pure
    Channel
      { channelSendQueue = sendQueue,
        channelReceiveQueue = receiveQueue
      }
-- end snippet channel

-- start snippet state
-- | A lexical scope: variable bindings held in mutable references, plus the
-- optional enclosing scope that is searched when a name is not bound here.
data Env = Env
  { envBindings :: Map.Map Identifier (IORef Value),
    envEnclosing :: Maybe Env
  }

-- | A mutable FIFO queue: 'enqueue' appends on the right, 'dequeue' removes
-- from the left.
type Queue a = IORef (Seq.Seq a)

-- | State threaded through the interpreter: the current environment and the
-- queue of suspended coroutines waiting to be resumed.
data InterpreterState = InterpreterState
  { isEnv :: Env,
    isCoroutines :: Queue (Coroutine ())
  }

-- | Initial interpreter state: an empty global environment and an empty
-- coroutine queue.
newInterpreterState :: IO InterpreterState
newInterpreterState = do
  coroutines <- newIORef Seq.empty
  return
    InterpreterState
      { isEnv = Env Map.empty Nothing,
        isCoroutines = coroutines
      }
-- end snippet state

-- start snippet nlr
-- | Reasons for abandoning the normal flow of evaluation, carried in the
-- error channel of 'Interpreter'.
data NonLocalReturn
  = RuntimeError String -- ^ A runtime failure with its message.
  | Return (Maybe Value) -- ^ A @return@ statement unwinding to the enclosing function call.
  | CoroutineQueueEmpty -- ^ Raised by 'dequeueCoroutine' when no coroutine is left to resume.

-- | Abort evaluation with a 'RuntimeError' carrying the given message.
throw :: String -> Interpreter a
throw message = throwError (RuntimeError message)
-- end snippet nlr

-- start snippet interpreter
-- | The interpreter monad, a stack of (outermost first):
--
-- * 'ExceptT' over 'NonLocalReturn', for runtime errors and @return@;
-- * 'ContT', whose continuations ('callCC') are used to suspend and resume
--   coroutines;
-- * 'StateT' over 'InterpreterState', for the environment and coroutine queue;
-- * 'IO', for mutable references and printing.
newtype Interpreter a = Interpreter
  { runInterpreter ::
      ExceptT
        NonLocalReturn
        ( ContT
            (Either NonLocalReturn ())
            (StateT InterpreterState IO)
        )
        a
  }
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadIO,
      MonadBase IO,
      MonadState InterpreterState,
      MonadError NonLocalReturn,
      MonadCont
    )
-- end snippet interpreter

-- start snippet env
-- | Read a variable, searching from the current scope outwards through the
-- enclosing scopes. Fails with a runtime error if the name is bound nowhere.
lookupEnv :: Identifier -> Interpreter Value
lookupEnv name = State.gets isEnv >>= search
  where
    search (Env bindings parent)
      | Just ref <- Map.lookup name bindings = readIORef ref
      | Just outer <- parent = search outer
      | otherwise = throw $ "Unknown variable: " <> name

-- | Bind a name to a fresh mutable cell in the current (innermost) scope,
-- replacing any binding of the same name in that scope.
defineEnv :: Identifier -> Value -> Interpreter ()
defineEnv name value = do
  current <- State.get
  valueRef <- newIORef value
  let scope = isEnv current
      extended =
        scope {envBindings = Map.insert name valueRef (envBindings scope)}
  State.put current {isEnv = extended}

-- | Overwrite the nearest existing binding of a name, searching from the
-- current scope outwards. Fails with a runtime error if the name is bound
-- nowhere.
assignEnv :: Identifier -> Value -> Interpreter ()
assignEnv name value = State.gets isEnv >>= update
  where
    update scope =
      case (Map.lookup name (envBindings scope), envEnclosing scope) of
        (Just ref, _) -> writeIORef ref value
        (Nothing, Just outer) -> update outer
        (Nothing, Nothing) -> throw $ "Unknown variable: " <> name
-- end snippet env

-- start snippet evaluate
-- | Evaluate an expression to a 'Value'.
--
-- Literals map directly to values and variables are looked up in the current
-- environment. Calls and binary operations are delegated to 'evaluateCall'
-- and 'evaluateBinary'. A receive expression evaluates its operand and fails
-- with a runtime error unless the result is a channel.
evaluate :: Expr -> Interpreter Value
evaluate = \case
  LNull -> pure Null
  LBool bool -> pure $ Boolean bool
  LStr str -> pure $ Str str
  LNum num -> pure $ Num num
  Variable v -> lookupEnv v
  Receive expr ->
    evaluate expr >>= \case
      Chan channel -> channelReceive channel
      val -> throw $ "Cannot receive from a non-channel: " <> show val
  call@Call {} -> evaluateCall call
  binary@Binary {} -> evaluateBinary binary
-- end snippet evaluate

-- start snippet evaluate-call
-- | Evaluate a call expression. Must only be applied to a 'Call'.
--
-- @newChannel(...)@ creates a channel (its arguments are not evaluated) and
-- @print(x)@ with exactly one argument prints it and yields null. Any other
-- call evaluates the arguments, looks the callee up as a variable and runs
-- its body in a fresh scope enclosed by the function's definition
-- environment. A @return@ inside the body supplies the result; otherwise the
-- result is null. The caller's environment is restored afterwards.
evaluateCall :: Expr -> Interpreter Value
evaluateCall (Call "newChannel" _) = Chan <$> newChannel
evaluateCall (Call "print" [expr]) =
  evaluate expr >>= liftIO . print >> return Null
evaluateCall ~(Call name argEs) = do
  args <- traverse evaluate argEs
  lookupEnv name >>= \case
    func@(Function funcName params body env) -> do
      is@InterpreterState {isEnv = origEnv} <- State.get
      State.put $ is {isEnv = Env Map.empty (Just env)}
      -- The definition environment was captured before the function's own
      -- name was bound, so bind it in the call scope to allow recursion.
      -- Parameters are defined afterwards and therefore shadow it.
      defineEnv funcName func
      forM_ (zip params args) $ uncurry defineEnv
      ret <-
        (mapM_ execute body >> return Null) `catchError` \case
          Return val -> return $ fromMaybe Null val
          err -> throwError err
      State.put is {isEnv = origEnv}
      return ret
    val -> throw $ "Cannot call a non-function: " <> show val
-- end snippet evaluate-call

-- start snippet evaluate-binary
-- | Evaluate a binary operation. Must only be applied to a 'Binary'.
--
-- Both operands are evaluated, left first. @+@ adds numbers and concatenates
-- when either side is a string; @-@, @<@ and @>@ require numbers; @==@ and
-- @!=@ use the 'Eq' instance of 'Value'. Unsupported operand combinations
-- fail with a runtime error.
evaluateBinary :: Expr -> Interpreter Value
evaluateBinary ~(Binary op leftE rightE) = do
  left <- evaluate leftE
  right <- evaluate rightE
  let operands = show left <> " and " <> show right
      compareNums cmp = case (left, right) of
        (Num n1, Num n2) -> pure $ Boolean $ cmp n1 n2
        _ -> throw $ "Cannot compare non-numbers: " <> operands
  case op of
    Plus -> case (left, right) of
      (Str s1, Str s2) -> pure $ Str $ s1 <> s2
      (Num n1, Num n2) -> pure $ Num $ n1 + n2
      (Str s1, _) -> pure $ Str $ s1 <> show right
      (_, Str s2) -> pure $ Str $ show left <> s2
      _ -> throw $ "Cannot add or append: " <> operands
    Minus -> case (left, right) of
      (Num n1, Num n2) -> pure $ Num $ n1 - n2
      _ -> throw $ "Cannot subtract: " <> operands
    Equals -> pure $ Boolean $ left == right
    NotEquals -> pure $ Boolean $ left /= right
    LessThan -> compareNums (<)
    GreaterThan -> compareNums (>)
-- end snippet evaluate-binary

-- start snippet execute
-- | Execute a single statement for its effects.
--
-- Null and false are falsy in conditions; every other value is truthy.
-- @if@ and @while@ bodies run in the current environment. A @return@ is
-- raised as a 'Return' through the error channel.
execute :: Stmt -> Interpreter ()
execute statement = case statement of
  ExprStmt expr -> void (evaluate expr)
  AssignStmt name expr -> assignEnv name =<< evaluate expr
  VarStmt name expr -> defineEnv name =<< evaluate expr
  ReturnStmt mExpr -> do
    mValue <- traverse evaluate mExpr
    throwError (Return mValue)
  IfStmt expr body -> do
    cond <- evaluate expr
    when (isTruthy cond) (runBody body)
  WhileStmt expr body ->
    let loop = do
          cond <- evaluate expr
          when (isTruthy cond) (runBody body >> loop)
     in loop
  FunctionStmt name params body -> do
    definitionEnv <- State.gets isEnv
    defineEnv name (Function name params body definitionEnv)
  YieldStmt -> yield
  SpawnStmt expr -> spawn (void (evaluate expr))
  SendStmt expr chan -> do
    payload <- evaluate expr
    target <- lookupEnv chan
    case target of
      Chan channel -> channelSend payload channel
      other -> throw $ "Cannot send to a non-channel: " <> show other
  where
    runBody = mapM_ execute

    isTruthy Null = False
    isTruthy (Boolean b) = b
    isTruthy _ = True
-- end snippet execute

-- start snippet interpret
-- | Run a program: execute its statements in order, then keep yielding until
-- no queued coroutine remains.
--
-- Returns @Just message@ when the program fails and 'Nothing' otherwise.
-- A @return@ statement outside any function is reported as an error rather
-- than silently ending the program. Running out of coroutines to resume
-- ('CoroutineQueueEmpty') is treated as normal termination.
interpret :: Program -> IO (Maybe String)
interpret program = do
  state <- newInterpreterState
  ret <-
    flip evalStateT state
      . flip runContT return
      . runExceptT
      . runInterpreter
      $ (mapM_ execute program <* awaitTermination)
  return $ case ret of
    Left (RuntimeError err) -> Just err
    -- Function calls catch 'Return', so one that reaches here came from a
    -- top-level @return@ statement.
    Left (Return _) -> Just "Cannot return from outside functions"
    Left CoroutineQueueEmpty -> Nothing
    Right () -> Nothing
-- end snippet interpret

-- start snippet queue
-- | Append a value at the back of a queue.
enqueue :: a -> Queue a -> Interpreter ()
enqueue val queue = modifyIORef' queue (\vals -> vals |> val)

-- | Remove and return the value at the front of a queue, or 'Nothing' if the
-- queue is empty.
dequeue :: Queue a -> Interpreter (Maybe a)
dequeue queue = atomicModifyIORef' queue pop
  where
    pop vals = case Seq.viewl vals of
      front :< remaining -> (remaining, Just front)
      Seq.EmptyL -> (vals, Nothing)
-- end snippet queue

-- start snippet coroutine-ops
-- | Add a suspended coroutine to the back of the interpreter's coroutine
-- queue.
enqueueCoroutine :: Coroutine () -> Interpreter ()
enqueueCoroutine coroutine = do
  runQueue <- State.gets isCoroutines
  enqueue coroutine runQueue

-- | Resume the coroutine at the front of the coroutine queue: restore its
-- saved state, then invoke its continuation. Raises 'CoroutineQueueEmpty'
-- when the queue is empty.
dequeueCoroutine :: Interpreter ()
dequeueCoroutine = do
  runQueue <- State.gets isCoroutines
  next <- dequeue runQueue
  case next of
    Just (Coroutine state resume) -> do
      State.put state
      resume ()
    Nothing -> throwError CoroutineQueueEmpty

-- | Suspend the current coroutine at the back of the coroutine queue and
-- resume the one at the front.
yield :: Interpreter ()
yield =
  State.get >>= \state ->
    callCC $ \resume ->
      enqueueCoroutine (Coroutine state resume) >> dequeueCoroutine

-- | Run an action as a new coroutine. The caller is queued for later
-- resumption, the action starts running immediately, and once it finishes
-- the next queued coroutine is resumed.
spawn :: Interpreter () -> Interpreter ()
spawn action =
  State.get >>= \state ->
    callCC $ \resume ->
      enqueueCoroutine (Coroutine state resume)
        >> action
        >> dequeueCoroutine

-- | Keep yielding until the coroutine queue is observed to be empty.
awaitTermination :: Interpreter ()
awaitTermination = do
  pending <- State.gets isCoroutines >>= readIORef
  unless (null pending) $ do
    yield
    awaitTermination
-- end snippet coroutine-ops

-- start snippet channel-send
-- | Send a value on a channel.
--
-- If a receiver is already waiting, it is queued for resumption with the
-- value and the sender yields. Otherwise the sender parks itself, together
-- with the value, on the channel's send queue and another coroutine is
-- resumed.
channelSend :: Value -> Channel -> Interpreter ()
channelSend
  value
  Channel {channelSendQueue = sendQueue, channelReceiveQueue = receiveQueue} = do
    waitingReceiver <- dequeue receiveQueue
    case waitingReceiver of
      Nothing -> do
        state <- State.get
        callCC $ \resume -> do
          enqueue (Coroutine state resume, value) sendQueue
          dequeueCoroutine
      Just (Coroutine state deliver) -> do
        enqueueCoroutine $ Coroutine state (\_ -> deliver value)
        yield
-- end snippet channel-send

-- start snippet channel-receive
-- | Receive a value from a channel.
--
-- If a sender is already waiting, it is queued for resumption and the
-- receiver queues itself to be resumed with the sent value. Otherwise the
-- receiver parks itself on the channel's receive queue. In both cases
-- another coroutine is resumed next.
channelReceive :: Channel -> Interpreter Value
channelReceive
  Channel {channelSendQueue = sendQueue, channelReceiveQueue = receiveQueue} = do
    waitingSender <- dequeue sendQueue
    case waitingSender of
      Nothing ->
        suspendWith $ \state resume ->
          enqueue (Coroutine state resume) receiveQueue
      Just (sender, value) -> do
        enqueueCoroutine sender
        suspendWith $ \state resume ->
          enqueueCoroutine (Coroutine state (\_ -> resume value))
  where
    -- Capture the current state and continuation, hand them to @park@, then
    -- switch to the next queued coroutine. The trailing 'Null' only fixes the
    -- result type of the 'callCC' body.
    suspendWith ::
      (InterpreterState -> (Value -> Interpreter ()) -> Interpreter ()) ->
      Interpreter Value
    suspendWith park = do
      state <- State.get
      callCC $ \resume -> do
        park state resume
        dequeueCoroutine
        return Null
-- end snippet channel-receive

-- start snippet run-file
-- | Read, parse and interpret a source file, printing any parse error or
-- runtime error message to standard output.
runFile :: FilePath -> IO ()
runFile file = do
  code <- readFile file
  either putStrLn run (runParser program code)
  where
    run ast = interpret ast >>= maybe (return ()) putStrLn
-- end snippet run-file