diff --git a/src/ast/astnode.rs b/src/ast/astnode.rs index 0688935..e69de29 100644 --- a/src/ast/astnode.rs +++ b/src/ast/astnode.rs @@ -1,196 +0,0 @@ -use crate::lexer::token; -use from_variants::FromVariants; -use match_any::match_any; -use std::fmt::{Debug, Display}; - -#[derive(FromVariants)] -pub enum ASTNode { - BinaryExpr(BinaryExpr), - GroupingExpr(GroupingExpr), - Literal(Literal), - UnaryExpr(UnaryExpr), -} - -macro_rules! all_variants { - ($expr:expr, $val_name:ident => $expr_arm:expr) => { - { - use match_any::match_any; - use $crate::ast::astnode::*; - match_any!($expr, ASTNode::BinaryExpr($val_name) | ASTNode::GroupingExpr($val_name) | ASTNode::Literal($val_name) | ASTNode::UnaryExpr($val_name) => $expr_arm) - } - - }; -} -pub(crate) use all_variants; - -impl Debug for ASTNode { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - all_variants!(self, n => n.fmt(f)) - } -} - -pub enum UnaryOperator { - Minus, - Bang, -} - -impl TryFrom for UnaryOperator { - type Error = EnumConvertError; - - fn try_from(value: token::TokenType) -> Result { - Ok(match value { - token::TokenType::Bang => Self::Bang, - token::TokenType::Minus => Self::Minus, - _ => return Err(EnumConvertError), - }) - } -} - -impl Display for UnaryOperator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match *self { - UnaryOperator::Minus => "-", - UnaryOperator::Bang => "!", - } - ) - } -} - -pub enum Operator { - BangEqual, - Equal, - EqualEqual, - Greater, - GreaterEqual, - Less, - LessEqual, -} - -#[derive(Debug)] -pub struct EnumConvertError; -impl std::fmt::Display for EnumConvertError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Couldn't convert between enums") - } -} - -impl std::error::Error for EnumConvertError {} - -impl TryFrom for Operator { - type Error = EnumConvertError; - - fn try_from(value: token::TokenType) -> Result { - Ok(match value { - token::TokenType::BangEqual => Self::BangEqual, - token::TokenType::Equal => Self::Equal, - token::TokenType::EqualEqual => Self::EqualEqual, - token::TokenType::Greater => Self::Greater, - token::TokenType::GreaterEqual => Self::GreaterEqual, - token::TokenType::Less => Self::Less, - token::TokenType::LessEqual => Self::LessEqual, - - _ => return Err(EnumConvertError), - }) - } -} - -impl Display for Operator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match *self { - Operator::Less => "<", - Operator::Equal => "=", - Operator::Greater => ">", - Operator::BangEqual => "!=", - Operator::LessEqual => "<=", - Operator::EqualEqual => "==", - Operator::GreaterEqual => ">=", - } - ) - } -} - -#[derive(Debug, Clone)] -pub enum Literal { - String(String), - Int(i32), - Float(f32), - Bool(bool), -} - -pub struct BinaryExpr { - pub left: Box, - pub operator: Operator, - pub right: Box, -} - -impl BinaryExpr { - pub fn new(left: Box, operator: Operator, right: Box) -> Self { - Self { - left, - operator, - right, - } - } -} - -pub struct GroupingExpr(pub Box); - -impl GroupingExpr { - pub(crate) fn new(expr: Box) -> Self { - Self(expr) - } -} - -pub struct UnaryExpr { - pub operator: UnaryOperator, - pub right: Box, -} - -impl UnaryExpr { - pub fn new(operator: UnaryOperator, right: Box) -> Self { - Self { operator, right } - } -} - -impl Debug for BinaryExpr { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "({} {:?} {:?})", self.operator, self.left, self.right) - } -} - -impl Debug for GroupingExpr { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "({:?})", self.0) - } -} - -impl Debug for UnaryExpr { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "({} {:?})", self.operator, self.right) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn ast_printer() { - let ast = ASTNode::BinaryExpr(BinaryExpr { - left: Box::new(ASTNode::UnaryExpr(UnaryExpr { - operator: UnaryOperator::Bang, - right: Box::new(ASTNode::Literal(Literal::Int(1))), - })), - operator: Operator::EqualEqual, - right: Box::new(ASTNode::Literal(Literal::Int(0))), - }); - let formated = format!("{:?}", ast); - assert_eq!("(== (! Int(1)) Int(0))", formated); - } -} diff --git a/src/ast/expression/expression_node.rs b/src/ast/expression/expression_node.rs new file mode 100644 index 0000000..d33596b --- /dev/null +++ b/src/ast/expression/expression_node.rs @@ -0,0 +1,195 @@ +use crate::lexer::token; +use from_variants::FromVariants; +use match_any::match_any; +use std::fmt::{Debug, Display}; + +#[derive(FromVariants)] +pub enum ExpressionNode { + BinaryExpr(BinaryExpr), + GroupingExpr(GroupingExpr), + Literal(Literal), + UnaryExpr(UnaryExpr), +} + +macro_rules! all_variants { + ($expr:expr, $val_name:ident => $expr_arm:expr) => { + { + use match_any::match_any; + use $crate::ast::expression::expression_node::*; + match_any!($expr, ExpressionNode::BinaryExpr($val_name) | ExpressionNode::GroupingExpr($val_name) | ExpressionNode::Literal($val_name) | ExpressionNode::UnaryExpr($val_name) => $expr_arm) + } + }; +} +pub(crate) use all_variants; + +impl Debug for ExpressionNode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + all_variants!(self, n => n.fmt(f)) + } +} + +pub enum UnaryOperator { + Minus, + Bang, +} + +impl TryFrom for UnaryOperator { + type Error = EnumConvertError; + + fn try_from(value: token::TokenType) -> Result { + Ok(match value { + token::TokenType::Bang => Self::Bang, + token::TokenType::Minus => Self::Minus, + _ => return Err(EnumConvertError), + }) + } +} + +impl Display for UnaryOperator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match *self { + UnaryOperator::Minus => "-", + UnaryOperator::Bang => "!", + } + ) + } +} + +pub enum Operator { + BangEqual, + Equal, + EqualEqual, + Greater, + GreaterEqual, + Less, + LessEqual, +} + +#[derive(Debug)] +pub struct EnumConvertError; +impl std::fmt::Display for EnumConvertError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Couldn't convert between enums") + } +} + +impl std::error::Error for EnumConvertError {} + +impl TryFrom for Operator { + type Error = EnumConvertError; + + fn try_from(value: token::TokenType) -> Result { + Ok(match value { + token::TokenType::BangEqual => Self::BangEqual, + token::TokenType::Equal => Self::Equal, + token::TokenType::EqualEqual => Self::EqualEqual, + token::TokenType::Greater => Self::Greater, + token::TokenType::GreaterEqual => Self::GreaterEqual, + token::TokenType::Less => Self::Less, + token::TokenType::LessEqual => Self::LessEqual, + + _ => return Err(EnumConvertError), + }) + } +} + +impl Display for Operator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match *self { + Operator::Less => "<", + Operator::Equal => "=", + Operator::Greater => ">", + Operator::BangEqual => "!=", + Operator::LessEqual => "<=", + Operator::EqualEqual => "==", + Operator::GreaterEqual => ">=", + } + ) + } +} + +#[derive(Debug, Clone)] +pub enum Literal { + String(String), + Int(i32), + Float(f32), + Bool(bool), +} + +pub struct BinaryExpr { + pub left: Box, + pub operator: Operator, + pub right: Box, +} + +impl BinaryExpr { + pub fn new(left: Box, operator: Operator, right: Box) -> Self { + Self { + left, + operator, + right, + } + } +} + +pub struct GroupingExpr(pub Box); + +impl GroupingExpr { + pub(crate) fn new(expr: Box) -> Self { + Self(expr) + } +} + +pub struct UnaryExpr { + pub operator: UnaryOperator, + pub right: Box, +} + +impl UnaryExpr { + pub fn new(operator: UnaryOperator, right: Box) -> Self { + Self { operator, right } + } +} + +impl Debug for BinaryExpr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "({} {:?} {:?})", self.operator, self.left, self.right) + } +} + +impl Debug for GroupingExpr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "({:?})", self.0) + } +} + +impl Debug for UnaryExpr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "({} {:?})", self.operator, self.right) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn expression_node_ast_printer() { + let ast = ExpressionNode::BinaryExpr(BinaryExpr { + left: Box::new(ExpressionNode::UnaryExpr(UnaryExpr { + operator: UnaryOperator::Bang, + right: Box::new(ExpressionNode::Literal(Literal::Int(1))), + })), + operator: Operator::EqualEqual, + right: Box::new(ExpressionNode::Literal(Literal::Int(0))), + }); + let formated = format!("{:?}", ast); + assert_eq!("(== (! Int(1)) Int(0))", formated); + } +} diff --git a/src/ast/expression/expression_parser.rs b/src/ast/expression/expression_parser.rs new file mode 100644 index 0000000..c102fea --- /dev/null +++ b/src/ast/expression/expression_parser.rs @@ -0,0 +1,124 @@ +use crate::lexer::token::{self, TokenType}; + +use super::super::parser::{InnerASTParsingError, Parser, Result}; +use super::expression_node::*; + +impl<'a, T: Iterator>> Parser<'a, T> { + pub fn expression(&mut self) -> Result { + self.equality() + } + + fn equality(&mut self) -> Result { + let mut node = self.comparison()?; + while let Some(o) = self + .token_iter + .next_if(|t| matches!(t.token_type, TokenType::EqualEqual | TokenType::BangEqual)) + { + node = BinaryExpr::new( + Box::new(node), + o.token_type.try_into().unwrap(), + Box::new(self.comparison()?), + ) + .into(); + } + Ok(node) + } + + fn comparison(&mut self) -> Result { + let mut node = self.term()?; + + while let Some(o) = self.token_iter.next_if(|t| { + matches!( + t.token_type, + TokenType::Greater + | TokenType::GreaterEqual + | TokenType::Less | TokenType::LessEqual + ) + }) { + node = BinaryExpr::new( + Box::new(node), + o.token_type.try_into().unwrap(), + Box::new(self.comparison()?), + ) + .into(); + } + Ok(node) + } + + fn term(&mut self) -> Result { + let mut node = self.factor()?; + + while let Some(o) = self + .token_iter + .next_if(|t| matches!(t.token_type, TokenType::Minus | TokenType::Plus)) + { + node = BinaryExpr::new( + Box::new(node), + o.token_type.try_into().unwrap(), + Box::new(self.comparison()?), + ) + .into() + } + Ok(node) + } + + fn factor(&mut self) -> Result { + let mut node = self.unary()?; + + while let Some(o) = self + .token_iter + .next_if(|t| matches!(t.token_type, TokenType::Star | TokenType::Slash)) + { + node = BinaryExpr::new( + Box::new(node), + o.token_type.try_into().unwrap(), + Box::new(self.comparison()?), + ) + .into(); + } + Ok(node) + } + + fn unary(&mut self) -> Result { + if let Some(op) = self + .token_iter + .next_if(|t| matches!(t.token_type, TokenType::Bang | TokenType::Minus)) + { + let right = Box::new(self.unary()?); + Ok(ExpressionNode::UnaryExpr(UnaryExpr::new( + op.token_type.try_into().unwrap(), + right, + ))) + } else { + self.primary() + } + } + + fn primary(&mut self) -> Result { + let node = match self.token_iter.next() { + Some(token) => match token.token_type { + TokenType::False => ExpressionNode::Literal(Literal::Bool(false)), + TokenType::True => ExpressionNode::Literal(Literal::Bool(true)), + TokenType::Int(i) => ExpressionNode::Literal(Literal::Int(i)), + TokenType::String(i) => ExpressionNode::Literal(Literal::String(i)), + TokenType::Float(f) => ExpressionNode::Literal(Literal::Float(f)), + TokenType::LeftParen => { + let expr = self.expression()?; + let group = GroupingExpr::new(Box::new(expr)); + match self + .token_iter + .next_if(|v| matches!(v.token_type, TokenType::RightParen)) + { + Some(_) => return Ok(group.into()), + None => { + return Err(token.location.wrap(InnerASTParsingError::UnmatchedBrace)) + } + } + } + a => return Err(token.location.wrap(InnerASTParsingError::IncorrectToken(a))), + }, + None => todo!(), + }; + Ok(node) + } +} diff --git a/src/ast/expression/mod.rs b/src/ast/expression/mod.rs new file mode 100644 index 0000000..8b145e4 --- /dev/null +++ b/src/ast/expression/mod.rs @@ -0,0 +1,2 @@ +pub mod expression_node; +mod expression_parser; diff --git a/src/ast/mod.rs b/src/ast/mod.rs index de6d4be..19d7ea5 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1,2 +1,2 @@ -pub mod astnode; +pub mod expression; pub mod parser; diff --git a/src/ast/parser.rs b/src/ast/parser.rs index c5e7ffc..ba879a2 100644 --- a/src/ast/parser.rs +++ b/src/ast/parser.rs @@ -1,5 +1,4 @@ -use super::astnode; -use super::astnode::{ASTNode, BinaryExpr}; +use super::expression::expression_node; use crate::error::ErrorLocationWrapper; use crate::lexer::{token, token::TokenType}; @@ -23,10 +22,10 @@ impl std::fmt::Display for InnerASTParsingError { impl std::error::Error for InnerASTParsingError {} pub type ASTParsingError = ErrorLocationWrapper; -type Result = StdResult; +pub(super) type Result = StdResult; pub struct Parser<'a, T: Iterator>> { - token_iter: iter::Peekable, + pub(super) token_iter: iter::Peekable, } impl<'a, T: Iterator>> Parser<'a, T> { @@ -35,7 +34,9 @@ impl<'a, T: Iterator>> Parser<'a, T> { token_iter: iter.peekable(), } } - pub fn scan_expressions(&mut self) -> StdResult, Vec> { + pub fn parse_all( + &mut self, + ) -> StdResult, Vec> { let mut tokens = Vec::new(); let mut errors = Vec::new(); @@ -55,122 +56,4 @@ impl<'a, T: Iterator>> Parser<'a, T> { Err(errors) } } - - fn expression(&mut self) -> Result { - self.equality() - } - - fn equality(&mut self) -> Result { - let mut node = self.comparison()?; - while let Some(o) = self - .token_iter - .next_if(|t| matches!(t.token_type, TokenType::EqualEqual | TokenType::BangEqual)) - { - node = BinaryExpr::new( - Box::new(node), - o.token_type.try_into().unwrap(), - Box::new(self.comparison()?), - ) - .into(); - } - Ok(node) - } - - fn comparison(&mut self) -> Result { - let mut node = self.term()?; - - while let Some(o) = self.token_iter.next_if(|t| { - matches!( - t.token_type, - TokenType::Greater - | TokenType::GreaterEqual - | TokenType::Less | TokenType::LessEqual - ) - }) { - node = BinaryExpr::new( - Box::new(node), - o.token_type.try_into().unwrap(), - Box::new(self.comparison()?), - ) - .into(); - } - Ok(node) - } - - fn term(&mut self) -> Result { - let mut node = self.factor()?; - - while let Some(o) = self - .token_iter - .next_if(|t| matches!(t.token_type, TokenType::Minus | TokenType::Plus)) - { - node = BinaryExpr::new( - Box::new(node), - o.token_type.try_into().unwrap(), - Box::new(self.comparison()?), - ) - .into() - } - Ok(node) - } - - fn factor(&mut self) -> Result { - let mut node = self.unary()?; - - while let Some(o) = self - .token_iter - .next_if(|t| matches!(t.token_type, TokenType::Star | TokenType::Slash)) - { - node = BinaryExpr::new( - Box::new(node), - o.token_type.try_into().unwrap(), - Box::new(self.comparison()?), - ) - .into(); - } - Ok(node) - } - - fn unary(&mut self) -> Result { - if let Some(op) = self - .token_iter - .next_if(|t| matches!(t.token_type, TokenType::Bang | TokenType::Minus)) - { - let right = Box::new(self.unary()?); - Ok(ASTNode::UnaryExpr(astnode::UnaryExpr::new( - op.token_type.try_into().unwrap(), - right, - ))) - } else { - self.primary() - } - } - - fn primary(&mut self) -> Result { - let node = match self.token_iter.next() { - Some(token) => match token.token_type { - TokenType::False => ASTNode::Literal(astnode::Literal::Bool(false)), - TokenType::True => ASTNode::Literal(astnode::Literal::Bool(true)), - TokenType::Int(i) => ASTNode::Literal(astnode::Literal::Int(i)), - TokenType::String(i) => ASTNode::Literal(astnode::Literal::String(i)), - TokenType::Float(f) => ASTNode::Literal(astnode::Literal::Float(f)), - TokenType::LeftParen => { - let expr = self.expression()?; - let group = astnode::GroupingExpr::new(Box::new(expr)); - match self - .token_iter - .next_if(|v| matches!(v.token_type, TokenType::RightParen)) - { - Some(_) => return Ok(group.into()), - None => { - return Err(token.location.wrap(InnerASTParsingError::UnmatchedBrace)) - } - } - } - a => return Err(token.location.wrap(InnerASTParsingError::IncorrectToken(a))), - }, - None => todo!(), - }; - Ok(node) - } } diff --git a/src/interpreter/ast_walker.rs b/src/interpreter/ast_walker.rs index da478d5..3a44cb4 100644 --- a/src/interpreter/ast_walker.rs +++ b/src/interpreter/ast_walker.rs @@ -1,7 +1,7 @@ use std::fmt::Display; use super::types::Value; -use crate::ast::astnode::{self, UnaryOperator}; +use crate::ast::expression::expression_node; #[derive(Debug)] pub struct RuntimeError; @@ -17,40 +17,40 @@ pub trait Interpret { fn interpret(&self) -> Result; } -impl Interpret for astnode::ASTNode { +impl Interpret for expression_node::ExpressionNode { fn interpret(&self) -> Result { - astnode::all_variants!(self, n => n.interpret()) + expression_node::all_variants!(self, n => n.interpret()) } } -impl Interpret for astnode::Literal { +impl Interpret for expression_node::Literal { fn interpret(&self) -> Result { Ok(self.clone().into()) } } -impl Interpret for astnode::BinaryExpr { +impl Interpret for expression_node::BinaryExpr { fn interpret(&self) -> Result { let left_val = self.left.interpret().expect("expected lval"); let right_val = self.right.interpret().expect("expected rval"); match self.operator { - astnode::Operator::BangEqual => Ok((left_val != right_val).into()), - astnode::Operator::Less => Ok((left_val < right_val).into()), - astnode::Operator::LessEqual => Ok((left_val <= right_val).into()), - astnode::Operator::Greater => Ok((left_val > right_val).into()), - astnode::Operator::GreaterEqual => Ok((left_val >= right_val).into()), - astnode::Operator::EqualEqual => Ok((left_val == right_val).into()), - astnode::Operator::Equal => todo!(), + expression_node::Operator::BangEqual => Ok((left_val != right_val).into()), + expression_node::Operator::Less => Ok((left_val < right_val).into()), + expression_node::Operator::LessEqual => Ok((left_val <= right_val).into()), + expression_node::Operator::Greater => Ok((left_val > right_val).into()), + expression_node::Operator::GreaterEqual => Ok((left_val >= right_val).into()), + expression_node::Operator::EqualEqual => Ok((left_val == right_val).into()), + expression_node::Operator::Equal => todo!(), } } } -impl Interpret for astnode::UnaryExpr { +impl Interpret for expression_node::UnaryExpr { fn interpret(&self) -> Result { let val = self.right.interpret()?; match self.operator { - UnaryOperator::Bang => Ok(Value::Bool(!val.truthy())), - UnaryOperator::Minus => match val { + expression_node::UnaryOperator::Bang => Ok(Value::Bool(!val.truthy())), + expression_node::UnaryOperator::Minus => match val { Value::Int(i) => Ok(Value::Int(-i)), Value::Float(f) => Ok(Value::Float(-f)), _ => Err(RuntimeError), @@ -59,7 +59,7 @@ impl Interpret for astnode::UnaryExpr { } } -impl Interpret for astnode::GroupingExpr { +impl Interpret for expression_node::GroupingExpr { fn interpret(&self) -> Result { self.0.interpret() } diff --git a/src/interpreter/types.rs b/src/interpreter/types.rs index 879dbdb..a5cd514 100644 --- a/src/interpreter/types.rs +++ b/src/interpreter/types.rs @@ -1,4 +1,4 @@ -use crate::ast::astnode; +use crate::ast::expression::expression_node; use from_variants::FromVariants; #[derive(Debug, PartialEq, PartialOrd, FromVariants)] @@ -11,10 +11,10 @@ pub enum Value { String(String), } -impl From for Value { - fn from(l: astnode::Literal) -> Self { +impl From for Value { + fn from(l: expression_node::Literal) -> Self { match_any::match_any!(l, - astnode::Literal::Int(v) | astnode::Literal::Bool(v) | astnode::Literal::Float(v) | astnode::Literal::String(v) => v.into() + expression_node::Literal::Int(v) | expression_node::Literal::Bool(v) | expression_node::Literal::Float(v) | expression_node::Literal::String(v) => v.into() ) } } diff --git a/src/lib.rs b/src/lib.rs index c4ce1d2..e02453f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,7 +3,8 @@ pub mod error; pub mod interpreter; pub mod lexer; -use ast::{astnode::ASTNode, parser::ASTParsingError}; +use ast::expression::expression_node::ExpressionNode; +use ast::parser::ASTParsingError; use interpreter::ast_walker::{Interpret, RuntimeError}; use interpreter::types::Value; use lexer::{token::Token, Lexer, LexingError}; @@ -16,12 +17,12 @@ pub fn lex<'a, 'b>( lexer.scan_tokens() } -pub fn parse(tokens: Vec) -> Result, Vec> { +pub fn parse(tokens: Vec) -> Result, Vec> { let mut parser = crate::ast::parser::Parser::new(tokens.into_iter()); - parser.scan_expressions() + parser.parse_all() } -pub fn exec(nodes: Vec) -> Result { +pub fn exec(nodes: Vec) -> Result { nodes[0].interpret() }