From c14cd20d8c65efcc056ae62ec5a647ee7bf66d74 Mon Sep 17 00:00:00 2001 From: bad Date: Thu, 24 Mar 2022 16:33:10 +0100 Subject: [PATCH] AST: Init --- src/ast/astnode.rs | 226 +++++++++++++++++++++++++++++++++++++++++++++ src/ast/mod.rs | 2 + src/ast/parser.rs | 164 ++++++++++++++++++++++++++++++++++ src/lexer/mod.rs | 10 +-- src/lexer/token.rs | 1 - src/main.rs | 8 +- 6 files changed, 404 insertions(+), 7 deletions(-) create mode 100644 src/ast/astnode.rs create mode 100644 src/ast/mod.rs create mode 100644 src/ast/parser.rs diff --git a/src/ast/astnode.rs b/src/ast/astnode.rs new file mode 100644 index 0000000..73d62e7 --- /dev/null +++ b/src/ast/astnode.rs @@ -0,0 +1,226 @@ +use std::fmt::{Debug, Display}; + +use crate::lexer::token; + +/** Expression tree node; each variant wraps the type for one expression kind. */ pub enum ASTNode { + BinaryExpr(BinaryExpr), + GroupingExpr(GroupingExpr), + Literal(Literal), + UnaryExpr(UnaryExpr), +} + +impl From<BinaryExpr> for ASTNode { + fn from(b: BinaryExpr) -> Self { + ASTNode::BinaryExpr(b) + } +} + +impl From<GroupingExpr> for ASTNode { + fn from(g: GroupingExpr) -> Self { + Self::GroupingExpr(g) + } +} + +impl From<Literal> for ASTNode { + fn from(l: Literal) -> Self { + Self::Literal(l) + } +} + +impl From<UnaryExpr> for ASTNode { + fn from(u: UnaryExpr) -> Self { + Self::UnaryExpr(u) + } +} + +/** Prints the wrapped value's own `Debug` form, without the `ASTNode` variant name. */ impl Debug for ASTNode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::BinaryExpr(b) => b.fmt(f), + Self::GroupingExpr(g) => g.fmt(f), + Self::Literal(l) => l.fmt(f), + Self::UnaryExpr(u) => u.fmt(f), + } + } +} + +/** Prefix operators, displayed as `-` and `!`. */ pub enum UnaryOperator { + Minus, + Bang, +} + +/** Accepts `Bang` and `Minus`; any other token type yields `EnumConvertError`. */ impl TryFrom<token::TokenType> for UnaryOperator { + type Error = EnumConvertError; + + fn try_from(value: token::TokenType) -> Result<Self, Self::Error> { + Ok(match value { + token::TokenType::Bang => Self::Bang, + token::TokenType::Minus => Self::Minus, + _ => return Err(EnumConvertError), + }) + } +} + +impl Display for UnaryOperator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match *self { + UnaryOperator::Minus => "-", + 
+ UnaryOperator::Bang => "!", + } + ) + } +} + +/** Infix operators. `Minus`, `Plus`, `Slash` and `Star` are included because the parser's `term` and `factor` levels convert those token types and unwrap the result. */ pub enum Operator { + Bang, + BangEqual, + Equal, + EqualEqual, + Greater, + GreaterEqual, + Less, + LessEqual, + Minus, + Plus, + Slash, + Star, +} + +/** Returned when a token type has no counterpart in the target operator enum. */ #[derive(Debug)] +pub struct EnumConvertError; +impl std::fmt::Display for EnumConvertError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Couldn't convert between enums") + } +} + +impl std::error::Error for EnumConvertError {} + +/** Maps each operator token type to its `Operator`; any other token type yields `EnumConvertError`. */ impl TryFrom<token::TokenType> for Operator { + type Error = EnumConvertError; + + fn try_from(value: token::TokenType) -> Result<Self, Self::Error> { + Ok(match value { + token::TokenType::Bang => Self::Bang, + token::TokenType::BangEqual => Self::BangEqual, + token::TokenType::Equal => Self::Equal, + token::TokenType::EqualEqual => Self::EqualEqual, + token::TokenType::Greater => Self::Greater, + token::TokenType::GreaterEqual => Self::GreaterEqual, + token::TokenType::Less => Self::Less, + token::TokenType::LessEqual => Self::LessEqual, + token::TokenType::Minus => Self::Minus, + token::TokenType::Plus => Self::Plus, + token::TokenType::Slash => Self::Slash, + token::TokenType::Star => Self::Star, + + _ => return Err(EnumConvertError), + }) + } +} + +impl Display for Operator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match *self { + Operator::Bang => "!", + Operator::Less => "<", + Operator::Equal => "=", + Operator::Greater => ">", + Operator::BangEqual => "!=", + Operator::LessEqual => "<=", + Operator::EqualEqual => "==", + Operator::GreaterEqual => ">=", + Operator::Minus => "-", + Operator::Plus => "+", + Operator::Slash => "/", + Operator::Star => "*", + } + ) + } +} + +/** Literal values an expression can contain. */ #[derive(Debug)] +pub enum Literal { + String(String), + Int(i64), + Float(f64), + Bool(bool), +} + +/** Two operands joined by an infix `Operator`. */ pub struct BinaryExpr { + left: Box<ASTNode>, + operator: Operator, + right: Box<ASTNode>, +} + +impl BinaryExpr { + pub fn new(left: Box<ASTNode>, operator: Operator, right: Box<ASTNode>) -> Self { + Self { + left, + operator, + right, + } + } +} + +/** A single wrapped sub-expression; its `Debug` output is parenthesized. */ pub struct GroupingExpr(Box<ASTNode>); + +impl GroupingExpr { + pub(crate) fn new(expr: Box<ASTNode>) -> Self { + Self(expr) + } +} + +/** A prefix `UnaryOperator` applied to one operand. */ pub struct UnaryExpr { + operator: UnaryOperator, + right: Box<ASTNode>, +} + +impl UnaryExpr { + pub fn new(operator: UnaryOperator, right: Box<ASTNode>) -> Self { + Self { operator, right } + } +} 
+ +/** Prefix form: `(operator left right)`. */ impl Debug for BinaryExpr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "({} {:?} {:?})", self.operator, self.left, self.right) + } +} + +/** Wraps the inner expression in parentheses. */ impl Debug for GroupingExpr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "({:?})", self.0) + } +} + +/** Prefix form: `(operator operand)`. */ impl Debug for UnaryExpr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "({} {:?})", self.operator, self.right) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn ast_printer() { + let ast = ASTNode::BinaryExpr(BinaryExpr { + left: Box::new(ASTNode::UnaryExpr(UnaryExpr { + operator: UnaryOperator::Bang, + right: Box::new(ASTNode::Literal(Literal::Int(1))), + })), + operator: Operator::EqualEqual, + right: Box::new(ASTNode::Literal(Literal::Int(0))), + }); + let formated = format!("{:?}", ast); + assert_eq!("(== (! Int(1)) Int(0))", formated); + } +} diff --git a/src/ast/mod.rs b/src/ast/mod.rs new file mode 100644 index 0000000..de6d4be --- /dev/null +++ b/src/ast/mod.rs @@ -0,0 +1,2 @@ +pub mod astnode; +pub mod parser; diff --git a/src/ast/parser.rs b/src/ast/parser.rs new file mode 100644 index 0000000..28e6fed --- /dev/null +++ b/src/ast/parser.rs @@ -0,0 +1,164 @@ +use super::astnode; +use super::astnode::{ASTNode, BinaryExpr}; +use crate::lexer::{token, token::TokenType}; +use std::iter; +use std::result::Result as StdResult; + +/** Parser result: the error side is always `ASTParsingError`. */ type Result<T> = StdResult<T, ASTParsingError>; + +/** Errors the parser can report. */ #[derive(Debug)] +pub enum ASTParsingError { + /** A `LeftParen` group was not followed by `RightParen`. */ UnmatchedBrace, +} +impl std::fmt::Display for ASTParsingError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { Self::UnmatchedBrace => write!(f, "unmatched brace: expected ')'") } + } +} +impl std::error::Error for ASTParsingError {} + +/** Recursive-descent expression parser over a peekable token stream. */ pub struct Parser<T: Iterator<Item = token::Token>> { + token_iter: iter::Peekable<T>, +} + +impl<T: Iterator<Item = token::Token>> Parser<T> { + /** Wraps `iter` so the grammar functions can peek one token ahead. */ pub fn new(iter: T) -> Parser<T> { + Parser { + token_iter: iter.peekable(), + } + } + /** Parses until the token stream is exhausted; returns all expressions, or all errors if any occurred. */ pub fn scan_expressions(&mut self) -> StdResult<Vec<ASTNode>, Vec<ASTParsingError>> { + let mut tokens = Vec::new(); + let mut errors = Vec::new(); + + while 
self.token_iter.peek().is_some() { + match self.expression() { + Ok(token) => { + if errors.is_empty() { + tokens.push(token) + } + } + Err(e) => errors.push(e), + } + } + if errors.is_empty() { + Ok(tokens) + } else { + Err(errors) + } + } + + /** Grammar entry point: the lowest-precedence level. */ fn expression(&mut self) -> Result<ASTNode> { + self.equality() + } + + /** `EqualEqual` / `BangEqual` level, left-associative over `comparison` operands. */ fn equality(&mut self) -> Result<ASTNode> { + let mut node = self.comparison()?; + while let Some(o) = self + .token_iter + .next_if(|t| matches!(t.token_type, TokenType::EqualEqual | TokenType::BangEqual)) + { + node = BinaryExpr::new( + Box::new(node), + o.token_type.try_into().unwrap(), + Box::new(self.comparison()?), + ) + .into(); + } + Ok(node) + } + + /** Ordering-comparison level, left-associative over `term` operands. */ fn comparison(&mut self) -> Result<ASTNode> { + let mut node = self.term()?; + + while let Some(o) = self.token_iter.next_if(|t| { + matches!( + t.token_type, + TokenType::Greater + | TokenType::GreaterEqual + | TokenType::Less | TokenType::LessEqual + ) + }) { + node = BinaryExpr::new( + Box::new(node), + o.token_type.try_into().unwrap(), + Box::new(self.term()?), /* right operand comes from the next-higher level so precedence holds */ + ) + .into(); + } + Ok(node) + } + + /** `Minus` / `Plus` level, left-associative over `factor` operands. */ fn term(&mut self) -> Result<ASTNode> { + let mut node = self.factor()?; + + while let Some(o) = self + .token_iter + .next_if(|t| matches!(t.token_type, TokenType::Minus | TokenType::Plus)) + { + node = BinaryExpr::new( + Box::new(node), + o.token_type.try_into().unwrap(), + Box::new(self.factor()?), + ) + .into() + } + Ok(node) + } + + /** `Star` / `Slash` level, left-associative over `unary` operands. */ fn factor(&mut self) -> Result<ASTNode> { + let mut node = self.unary()?; + + while let Some(o) = self + .token_iter + .next_if(|t| matches!(t.token_type, TokenType::Star | TokenType::Slash)) + { + node = BinaryExpr::new( + Box::new(node), + o.token_type.try_into().unwrap(), + Box::new(self.unary()?), + ) + .into(); + } + Ok(node) + } + + /** Prefix `Bang` / `Minus`, recursing for stacked prefixes; otherwise falls through to `primary`. */ fn unary(&mut self) -> Result<ASTNode> { + if let Some(op) = self + .token_iter + .next_if(|t| matches!(t.token_type, TokenType::Bang | TokenType::Minus)) + { + let right = Box::new(self.unary()?); + Ok(ASTNode::UnaryExpr(astnode::UnaryExpr::new( + 
op.token_type.try_into().unwrap(), + right, + ))) + } else { + self.primary() + } + } + + /** Literals and parenthesized groups. NOTE(review): an unexpected token or end of input still panics here; `ASTParsingError` has no variant for those cases. */ fn primary(&mut self) -> Result<ASTNode> { + let node = match self.token_iter.next().map(|it| it.token_type) { + Some(TokenType::False) => ASTNode::Literal(astnode::Literal::Bool(false)), + Some(TokenType::True) => ASTNode::Literal(astnode::Literal::Bool(true)), + Some(TokenType::Int(i)) => ASTNode::Literal(astnode::Literal::Int(i)), + Some(TokenType::String(i)) => ASTNode::Literal(astnode::Literal::String(i)), + Some(TokenType::Float(f)) => ASTNode::Literal(astnode::Literal::Float(f)), + Some(TokenType::LeftParen) => { + let expr = self.expression()?; + let group = astnode::GroupingExpr::new(Box::new(expr)); + match self + .token_iter + .next_if(|v| matches!(v.token_type, TokenType::RightParen)) + { + Some(_) => return Ok(group.into()), + None => return Err(ASTParsingError::UnmatchedBrace), + } + } + Some(a) => panic!("{:#?}", a), + None => todo!(), + }; + Ok(node) + } +} diff --git a/src/lexer/mod.rs b/src/lexer/mod.rs index 8a9ab4c..bbedc08 100644 --- a/src/lexer/mod.rs +++ b/src/lexer/mod.rs @@ -10,7 +10,7 @@ use self::token::Token; #[derive(Debug)] pub struct Lexer<'a> { - source: &'a str, + _source: &'a str, source_iter: iter::Peekable>, line: usize, } @@ -18,7 +18,7 @@ pub struct Lexer<'a> { impl<'a> Lexer<'a> { pub fn new(code: &'a str) -> Lexer<'a> { return Lexer { - source: code, + _source: code, source_iter: code.chars().peekable(), line: 0, }; @@ -39,7 +39,6 @@ impl<'a> Lexer<'a> { None => (), } } - tokens.push(self.get_token(token::TokenType::Eof)); if errors.is_empty() { Ok(tokens) @@ -81,6 +80,7 @@ impl<'a> Lexer<'a> { '+' => self.get_token(token::TokenType::Plus), ';' => self.get_token(token::TokenType::Semicolon), '*' => self.get_token(token::TokenType::Star), + '/' => self.get_token(token::TokenType::Slash), '!' 
=> self.get_token_if_next_eq_or( '=', token::TokenType::BangEqual, @@ -94,7 +94,7 @@ impl<'a> Lexer<'a> { '<' => self.get_token_if_next_eq_or( '=', token::TokenType::LessEqual, - token::TokenType::Equal, + token::TokenType::Less, ), '>' => self.get_token_if_next_eq_or( '=', @@ -111,7 +111,7 @@ impl<'a> Lexer<'a> { loop { let next_char = self.source_iter.next(); if let Some('\n') = next_char { - self.line+=1; + self.line += 1; } match next_char { Some('"') => break, diff --git a/src/lexer/token.rs b/src/lexer/token.rs index e86519d..18c11c6 100644 --- a/src/lexer/token.rs +++ b/src/lexer/token.rs @@ -45,5 +45,4 @@ pub enum TokenType { True, Let, While, - Eof, } diff --git a/src/main.rs b/src/main.rs index 1326fa9..cdbe11f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,4 @@ +mod ast; mod lexer; use lexer::Lexer; @@ -27,7 +28,12 @@ fn run_repl() { fn run(code: &str) -> Result<(), Vec> { let mut lexer = Lexer::new(code); - println!("{:?}", lexer.scan_tokens()?); + let tokens = lexer.scan_tokens()?; + println!("{:?}", tokens); + let mut parser = crate::ast::parser::Parser::new(tokens.into_iter()); + let expressions = parser.scan_expressions().unwrap(); + println!("{:?}", expressions); + Ok(()) }