crftng-intrprtrs/src/ast/expression/expression_node.rs

197 lines
4.3 KiB
Rust

use crate::lexer::token;
use from_variants::FromVariants;
use match_any::match_any;
use std::fmt::{Debug, Display};
/// A node of the expression tree.
///
/// Each variant wraps the type of the same name, which holds the data for
/// that kind of expression.
// NOTE(review): `FromVariants` comes from the external `from_variants` crate;
// presumably it generates a `From<Payload>` impl per variant — confirm.
#[derive(FromVariants)]
pub enum ExpressionNode {
/// An operator applied to a left and a right operand.
BinaryExpr(BinaryExpr),
/// A single wrapped inner expression.
GroupingExpr(GroupingExpr),
/// A literal value.
Literal(Literal),
/// An operator applied to a single operand.
UnaryExpr(UnaryExpr),
}
/// Evaluates `$expr_arm` against the payload of whichever `ExpressionNode`
/// variant `$expr` holds, with that payload bound to `$val_name`.
///
/// The four variants wrap four different types, so the same arm expression is
/// applied to a differently-typed binding for each variant.
// NOTE(review): `match_any!` is from the external `match_any` crate; presumably
// it expands the or-pattern into one match arm per alternative — confirm.
macro_rules! all_variants {
($expr:expr, $val_name:ident => $expr_arm:expr) => {
{
// Imported inside the expansion, so call sites do not need these names in scope.
use match_any::match_any;
use $crate::ast::expression::expression_node::*;
// Every `ExpressionNode` variant is listed explicitly; a new variant must be added here too.
match_any!($expr, ExpressionNode::BinaryExpr($val_name) | ExpressionNode::GroupingExpr($val_name) | ExpressionNode::Literal($val_name) | ExpressionNode::UnaryExpr($val_name) => $expr_arm)
}
};
}
// Makes the macro importable by path from elsewhere in the crate.
pub(crate) use all_variants;
/// Debug-prints a node by forwarding to the `Debug` impl of the payload it
/// holds; the enum itself adds no text of its own.
impl Debug for ExpressionNode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Named through the trait so the call stays unambiguous even if a
        // payload type later gains a second `fmt` (e.g. via `Display`).
        all_variants!(self, node => Debug::fmt(node, f))
    }
}
/// Operator of a unary expression.
///
/// This is a fieldless enum, so it derives the common value-type traits:
/// operators can be debug-printed, copied freely, and compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOperator {
    /// Displayed as `-`.
    Minus,
    /// Displayed as `!`.
    Bang,
}
/// Converts a lexer token type into the unary operator of the same name.
///
/// # Errors
///
/// Returns `EnumConvertError` for every token type other than `Bang` and
/// `Minus`.
impl TryFrom<token::TokenType> for UnaryOperator {
    type Error = EnumConvertError;

    fn try_from(value: token::TokenType) -> Result<Self, Self::Error> {
        match value {
            token::TokenType::Bang => Ok(Self::Bang),
            token::TokenType::Minus => Ok(Self::Minus),
            // Any other token type has no unary-operator counterpart.
            _ => Err(EnumConvertError),
        }
    }
}
/// Renders the operator as its symbol: `-` for `Minus`, `!` for `Bang`.
impl Display for UnaryOperator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let symbol = match self {
            UnaryOperator::Minus => "-",
            UnaryOperator::Bang => "!",
        };
        // Written verbatim; width and padding flags on `f` are not applied.
        f.write_str(symbol)
    }
}
/// Operator of a binary expression.
///
/// This is a fieldless enum, so it derives the common value-type traits:
/// operators can be debug-printed, copied freely, and compared with `==`.
// NOTE(review): `Equal` displays as `=`, which in most grammars is assignment
// rather than a comparison — confirm it belongs among the binary operators.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Operator {
    /// Displayed as `!=`.
    BangEqual,
    /// Displayed as `=`.
    Equal,
    /// Displayed as `==`.
    EqualEqual,
    /// Displayed as `>`.
    Greater,
    /// Displayed as `>=`.
    GreaterEqual,
    /// Displayed as `<`.
    Less,
    /// Displayed as `<=`.
    LessEqual,
}
/// Error reported when a value has no counterpart in the target enum.
///
/// It is a unit struct: it carries no information about which value failed
/// to convert.
#[derive(Debug)]
pub struct EnumConvertError;

impl std::fmt::Display for EnumConvertError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fixed message, written verbatim (no padding or width handling).
        f.write_str("Couldn't convert between enums")
    }
}

// The default `Error` methods are used as-is, so no source is reported.
impl std::error::Error for EnumConvertError {}
/// Converts a lexer token type into the binary operator of the same name.
///
/// # Errors
///
/// Returns `EnumConvertError` for every token type that is not one of the
/// seven listed below.
impl TryFrom<token::TokenType> for Operator {
    type Error = EnumConvertError;

    fn try_from(value: token::TokenType) -> Result<Self, Self::Error> {
        match value {
            token::TokenType::BangEqual => Ok(Self::BangEqual),
            token::TokenType::Equal => Ok(Self::Equal),
            token::TokenType::EqualEqual => Ok(Self::EqualEqual),
            token::TokenType::Greater => Ok(Self::Greater),
            token::TokenType::GreaterEqual => Ok(Self::GreaterEqual),
            token::TokenType::Less => Ok(Self::Less),
            token::TokenType::LessEqual => Ok(Self::LessEqual),
            // Any other token type has no binary-operator counterpart.
            _ => Err(EnumConvertError),
        }
    }
}
/// Renders the operator as its symbol (`<`, `=`, `>`, `!=`, `<=`, `==`, `>=`).
impl Display for Operator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let symbol = match self {
            Operator::Less => "<",
            Operator::Equal => "=",
            Operator::Greater => ">",
            Operator::BangEqual => "!=",
            Operator::LessEqual => "<=",
            Operator::EqualEqual => "==",
            Operator::GreaterEqual => ">=",
        };
        // Written verbatim; width and padding flags on `f` are not applied.
        f.write_str(symbol)
    }
}
/// A literal value in the expression tree.
///
/// `Debug` and `Clone` are derived; the derived `Debug` form (for example
/// `Int(1)`) is what an `ExpressionNode::Literal` prints as.
#[derive(Debug, Clone)]
pub enum Literal {
/// A string value.
String(String),
/// A 32-bit signed integer value.
Int(i32),
/// A 32-bit floating-point value.
Float(f32),
/// A boolean value.
Bool(bool),
/// The nil literal; carries no payload.
Nil,
}
/// An expression made of two operand expressions joined by an `Operator`.
///
/// The operands are boxed because `ExpressionNode` itself can contain a
/// `BinaryExpr`, which makes the type recursive.
pub struct BinaryExpr {
    /// Left-hand operand.
    pub left: Box<ExpressionNode>,
    /// Operator placed between the two operands.
    pub operator: Operator,
    /// Right-hand operand.
    pub right: Box<ExpressionNode>,
}

impl BinaryExpr {
    /// Assembles a binary expression from its already-boxed operands and
    /// the operator between them.
    pub fn new(left: Box<ExpressionNode>, operator: Operator, right: Box<ExpressionNode>) -> Self {
        BinaryExpr {
            left,
            operator,
            right,
        }
    }
}
/// An expression that wraps exactly one inner expression.
///
/// The inner expression is public field `0`.
pub struct GroupingExpr(pub Box<ExpressionNode>);

impl GroupingExpr {
    /// Wraps an already-boxed expression; visible only inside the crate.
    pub(crate) fn new(expr: Box<ExpressionNode>) -> Self {
        GroupingExpr(expr)
    }
}
/// An expression made of a `UnaryOperator` applied to a single operand.
pub struct UnaryExpr {
    /// Operator applied to the operand.
    pub operator: UnaryOperator,
    /// The operand; boxed because `ExpressionNode` can itself contain a
    /// `UnaryExpr`.
    pub right: Box<ExpressionNode>,
}

impl UnaryExpr {
    /// Assembles a unary expression from an operator and an already-boxed
    /// operand.
    pub fn new(operator: UnaryOperator, right: Box<ExpressionNode>) -> Self {
        UnaryExpr { operator, right }
    }
}
impl Debug for BinaryExpr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "({} {:?} {:?})", self.operator, self.left, self.right)
}
}
impl Debug for GroupingExpr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "({:?})", self.0)
}
}
impl Debug for UnaryExpr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "({} {:?})", self.operator, self.right)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the tree for `!1 == 0` by hand and checks its `Debug` rendering.
    #[test]
    fn expression_node_ast_printer() {
        // Left operand: `!1`.
        let negated_one = ExpressionNode::UnaryExpr(UnaryExpr {
            operator: UnaryOperator::Bang,
            right: Box::new(ExpressionNode::Literal(Literal::Int(1))),
        });
        // Right operand: `0`.
        let zero = ExpressionNode::Literal(Literal::Int(0));

        let ast = ExpressionNode::BinaryExpr(BinaryExpr {
            left: Box::new(negated_one),
            operator: Operator::EqualEqual,
            right: Box::new(zero),
        });

        let rendered = format!("{:?}", ast);
        assert_eq!("(== (! Int(1)) Int(0))", rendered);
    }
}