{- stack
  ghci
  --resolver lts-18.7
  --package containers
  --package megaparsec
  --package parser-combinators
  --package mtl
  --package lifted-base
  --package transformers-base
  --package pretty-simple
  --ghc-options "-Wall -Wno-name-shadowing"
-}

-- start snippet imports
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}

import Control.Monad (unless, void, when)
import Control.Monad.Base (MonadBase)
import Control.Monad.Combinators.Expr (Operator (..), makeExprParser)
import Control.Monad.Cont (ContT, MonadCont, callCC, runContT)
import Control.Monad.Except (ExceptT, MonadError (..), runExceptT)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.State.Strict (MonadState, StateT, evalStateT)
import qualified Control.Monad.State.Strict as State
import Data.Foldable (for_, traverse_)
import Data.IORef.Lifted
import qualified Data.Map.Strict as Map
import Data.Maybe (fromMaybe)
import Data.Sequence (ViewL ((:<)), (|>))
import qualified Data.Sequence as Seq
import Data.Void (Void)
import Text.Megaparsec hiding (runParser)
import Text.Megaparsec.Char
import qualified Text.Megaparsec.Char.Lexer as L
import Text.Pretty.Simple
  ( CheckColorTty (..),
    OutputOptions (..),
    defaultOutputOptionsNoColor,
    pPrintOpt,
  )
-- end snippet imports

-- start snippet ast-expression
-- | Expressions of the language parsed by 'expr' / 'term'.
data Expr
  = LNull -- ^ the @null@ literal
  | LBool Bool -- ^ a boolean literal, written @true@ or @false@
  | LStr String -- ^ a double-quoted string literal
  | LNum Integer -- ^ an integer literal (optionally signed)
  | Variable Identifier -- ^ a reference to a name
  | Binary BinOp Expr Expr -- ^ a binary operation: operator, left operand, right operand
  | Call Identifier [Expr] -- ^ @name(arg, ...)@: a named callee applied to its argument list
  | Receive Expr -- ^ the prefix @<-@ operator; presumably receives from a channel (TODO confirm)
  deriving (Show, Eq)

-- | Names of variables, functions and parameters (letter, then letters/digits).
type Identifier = String

-- | Operators usable in a 'Binary' expression; precedence lives in 'operators'.
data BinOp = Plus | Minus | Equals | NotEquals | LessThan | GreaterThan
  deriving (Show, Eq)
-- end snippet ast-expression

-- start snippet ast-statement
-- | Statements of the language parsed by 'stmt'. The concrete syntax of each
-- form is shown next to its constructor.
data Stmt
  = ExprStmt Expr -- ^ an expression used as a statement: @expr;@
  | VarStmt Identifier Expr -- ^ a declaration with initializer: @var name = expr;@
  | AssignStmt Identifier Expr -- ^ an assignment: @name = expr;@
  | IfStmt Expr [Stmt] -- ^ @if (cond) { body }@; there is no @else@ branch
  | WhileStmt Expr [Stmt] -- ^ @while (cond) { body }@
  | FunctionStmt Identifier [Identifier] [Stmt] -- ^ @function name(params) { body }@
  | ReturnStmt (Maybe Expr) -- ^ @return;@ ('Nothing') or @return expr;@
  | YieldStmt -- ^ @yield;@
  | SpawnStmt Expr -- ^ @spawn expr;@
  | SendStmt Expr Identifier -- ^ @expr -> name;@: the sent value, then the target name (presumably a channel; TODO confirm)
  deriving (Show, Eq)

-- | A whole program: a sequence of top-level statements.
type Program = [Stmt]
-- end snippet ast-statement

-- start snippet basic-parsers
-- | Parsers over 'String' input with no custom error component ('Void').
type Parser = Parsec Void String

-- | Space consumer: skips any mix of whitespace characters, @//@ line
-- comments and @/* ... */@ block comments (possibly none of them).
sc :: Parser ()
sc = L.space whitespace lineComment blockComment
  where
    whitespace = void spaceChar
    lineComment = L.skipLineComment "//"
    blockComment = L.skipBlockComment "/*" "*/"

-- | Run a parser, then discard any whitespace and comments that follow it.
lexeme :: Parser a -> Parser a
lexeme p = p <* sc

-- | Match an exact string, then skip trailing whitespace and comments.
-- Nothing checks what follows the string, so on its own this cannot tell a
-- keyword apart from the start of a longer identifier.
symbol :: String -> Parser String
symbol s = string s <* sc

-- | Run a parser between delimiters: @( ... )@ for 'parens' and
-- @{ ... }@ for 'braces'. Each delimiter also skips trailing space.
parens, braces :: Parser a -> Parser a
parens p = symbol "(" *> p <* symbol ")"
braces p = symbol "{" *> p <* symbol "}"

-- | Small lexemes:
--
-- * 'semi': the statement terminator @;@.
-- * 'identifier': a letter followed by letters or digits. Reserved words
--   are not rejected here.
-- * 'stringLiteral': a double-quoted string whose characters (including
--   escape sequences) are decoded by 'L.charLiteral'.
semi, identifier, stringLiteral :: Parser String
semi = symbol ";"
identifier = lexeme $ do
  firstChar <- letterChar
  restChars <- many alphaNumChar
  pure (firstChar : restChars)
stringLiteral = lexeme (char '"' *> manyTill L.charLiteral (char '"'))

-- | A decimal integer literal with an optional leading @+@ or @-@. Because
-- 'sc' is passed to 'L.signed', whitespace and comments may appear between
-- the sign and the digits.
integer :: Parser Integer
integer = L.signed sc L.decimal <* sc
-- end snippet basic-parsers

-- start snippet parser-utils
-- | Run a parser on source text, using an empty source name. A failure is
-- rendered to a human-readable message with 'errorBundlePretty'. Whether the
-- whole input must be consumed is up to the parser that is passed in.
runParser :: Parser a -> String -> Either String a
runParser parser code =
  either (Left . errorBundlePretty) Right (parse parser "" code)

-- | Pretty-print a 'Show'able value without colour, using compact layout
-- and a two-space indent.
pPrint :: (MonadIO m, Show a) => a -> m ()
pPrint = pPrintOpt CheckColorTty compactNoColor
  where
    compactNoColor =
      defaultOutputOptionsNoColor
        { outputOptionsIndentAmount = 2,
          outputOptionsCompact = True,
          outputOptionsCompactParens = True
        }
-- end snippet parser-utils

-- start snippet expr-op
-- | Operator precedence table for 'makeExprParser', with the tightest-binding
-- row first:
--
-- 1. prefix receive @<-@
-- 2. @+@ and @-@
-- 3. @<@ and @>@
-- 4. @==@ and @!=@
--
-- Every binary operator is left-associative.
operators :: [[Operator Parser Expr]]
operators =
  [ [Prefix (Receive <$ symbol "<-")],
    [leftAssoc Plus (symbol "+"), leftAssoc Minus minus],
    [leftAssoc LessThan (symbol "<"), leftAssoc GreaterThan (symbol ">")],
    [leftAssoc Equals (symbol "=="), leftAssoc NotEquals (symbol "!=")]
  ]
  where
    leftAssoc op opParser = InfixL (Binary op <$ opParser)
    -- A '-' that begins the send arrow "->" is not subtraction; 'try' undoes
    -- the consumed "-" so the statement parser can still see the arrow.
    minus = try (symbol "-" <* notFollowedBy (char '>'))
-- end snippet expr-op

-- start snippet expr-term
-- | Parse an atomic expression. The alternatives are tried in order:
--
-- * a literal (@null@, @true@, @false@, a string, an integer);
-- * a function call @name(args)@, backtracking if the parenthesised
--   argument list does not follow;
-- * a variable reference;
-- * a parenthesised expression.
term :: Parser Expr
term =
  LNull <$ keyword "null"
    <|> LBool True <$ keyword "true"
    <|> LBool False <$ keyword "false"
    <|> LStr <$> stringLiteral
    <|> LNum <$> integer
    <|> try (Call <$> identifier <*> parens (sepBy expr (char ',' *> sc)))
    <|> Variable <$> identifier
    <|> parens expr
  where
    -- Match a reserved word only when it is not the prefix of a longer
    -- identifier. A bare 'symbol' accepted the "null" of "nullable",
    -- committed to 'LNull', and left "able" unparsed. 'try' is needed
    -- because the boundary check can fail after the word was consumed.
    keyword :: String -> Parser String
    keyword word = lexeme (try (string word <* notFollowedBy alphaNumChar))
-- end snippet expr-term

-- start snippet expr
-- | Parse a full expression: 'term's combined through the precedence table
-- in 'operators'. Earlier rows of the table bind tighter.
expr :: Parser Expr
expr = term `makeExprParser` operators
-- end snippet expr

-- start snippet stmt
-- | Parse one statement.
--
-- Forms introduced by a keyword (@if@, @while@, @var@, @yield@, @spawn@,
-- @return@, @function@) are tried first. The remaining forms all begin with
-- an identifier or an expression, so they are separated by backtracking,
-- most specific first:
--
-- 1. assignment @name = expr;@
-- 2. send @expr -> name;@
-- 3. bare expression statement @expr;@
stmt :: Parser Stmt
stmt =
  IfStmt <$> (keyword "if" *> parens expr) <*> braces (many stmt)
    <|> WhileStmt <$> (keyword "while" *> parens expr) <*> braces (many stmt)
    <|> VarStmt <$> (keyword "var" *> identifier) <*> (symbol "=" *> expr <* semi)
    <|> YieldStmt <$ keyword "yield" <* semi
    <|> SpawnStmt <$> (keyword "spawn" *> expr <* semi)
    <|> ReturnStmt <$> (keyword "return" *> optional expr <* semi)
    <|> FunctionStmt
      <$> (keyword "function" *> identifier)
      <*> parens (sepBy identifier (char ',' *> sc))
      <*> braces (many stmt)
    <|> try (AssignStmt <$> identifier <*> (symbol "=" *> expr <* semi))
    <|> try (SendStmt <$> expr <*> (symbol "->" *> identifier <* semi))
    <|> ExprStmt <$> expr <* semi
  where
    -- Match a reserved word only when it is not the prefix of a longer
    -- identifier. With a bare 'symbol', "variable = 5;" was parsed as
    -- "var iable = 5;", and "iffy(1);" failed after committing to "if".
    -- 'try' lets a failed boundary check fall through to the later
    -- identifier-based alternatives.
    keyword :: String -> Parser String
    keyword word = lexeme (try (string word <* notFollowedBy alphaNumChar))
-- end snippet stmt

-- start snippet program