From 599f1adcbd7cf779d4cf9674b1b5ecfccea56bf7 Mon Sep 17 00:00:00 2001 From: Ed L Date: Tue, 24 Mar 2026 00:20:08 -0700 Subject: [PATCH] refactor: replace BinOp string with typed enum in Expr AST Previously Expr::BinOp(String, ...) accepted any string as an operator. Invalid operators (e.g. "diagonal") would compile fine and silently return CellValue::Empty at eval time. Now BinOp is an enum with variants Add/Sub/Mul/Div/Pow/Eq/Ne/Lt/Gt/Le/Ge. The parser produces enum variants directly; the evaluator pattern-matches exhaustively with no wildcard (`_`) fallback branch. An invalid operator is now a compile error at the call site, and the compiler ensures every variant is handled in both eval_expr and eval_bool. Operators of the wrong kind (a comparison reaching eval_expr, or an arithmetic operator reaching eval_bool) are listed explicitly and still evaluate to None. Co-Authored-By: Claude Sonnet 4.6 --- src/formula/ast.rs | 21 ++++++++++++++++++++- src/formula/mod.rs | 2 +- src/formula/parser.rs | 34 +++++++++++++++++----------------- src/model/model.rs | 37 ++++++++++++++++++++++--------------- 4 files changed, 60 insertions(+), 34 deletions(-) diff --git a/src/formula/ast.rs b/src/formula/ast.rs index 4bb680d..30095ec 100644 --- a/src/formula/ast.rs +++ b/src/formula/ast.rs @@ -9,6 +9,25 @@ pub enum AggFunc { Count, } +/// Arithmetic and comparison operators used in binary expressions. +/// Having an enum (rather than a raw String) means the parser must +/// produce a valid operator; invalid operators are caught at parse +/// time rather than silently returning Empty at eval time. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum BinOp { + Add, + Sub, + Mul, + Div, + Pow, + Eq, + Ne, + Lt, + Gt, + Le, + Ge, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Filter { pub category: String, @@ -19,7 +38,7 @@ pub struct Filter { pub enum Expr { Number(f64), Ref(String), - BinOp(String, Box, Box), + BinOp(BinOp, Box, Box), UnaryMinus(Box), Agg(AggFunc, Box, Option), If(Box, Box, Box), diff --git a/src/formula/mod.rs b/src/formula/mod.rs index 5118069..56607af 100644 --- a/src/formula/mod.rs +++ b/src/formula/mod.rs @@ -1,5 +1,5 @@ pub mod parser; pub mod ast; -pub use ast::{AggFunc, Expr, Formula}; +pub use ast::{AggFunc, BinOp, Expr, Formula}; pub use parser::parse_formula; diff --git a/src/formula/parser.rs b/src/formula/parser.rs index 25ab00b..b5f5cb3 100644 --- a/src/formula/parser.rs +++ b/src/formula/parser.rs @@ -1,6 +1,6 @@ use anyhow::{anyhow, Result}; -use super::ast::{AggFunc, Expr, Filter, Formula}; +use super::ast::{AggFunc, BinOp, Expr, Filter, Formula}; /// Parse a formula string like "Profit = Revenue - Cost" /// or "Tax = Revenue * 0.08 WHERE Region = \"East\"" @@ -161,13 +161,13 @@ fn parse_add_sub(tokens: &[Token], pos: &mut usize) -> Result { let mut left = parse_mul_div(tokens, pos)?; while *pos < tokens.len() { let op = match &tokens[*pos] { - Token::Plus => "+", - Token::Minus => "-", + Token::Plus => BinOp::Add, + Token::Minus => BinOp::Sub, _ => break, }; *pos += 1; let right = parse_mul_div(tokens, pos)?; - left = Expr::BinOp(op.to_string(), Box::new(left), Box::new(right)); + left = Expr::BinOp(op, Box::new(left), Box::new(right)); } Ok(left) } @@ -176,13 +176,13 @@ fn parse_mul_div(tokens: &[Token], pos: &mut usize) -> Result { let mut left = parse_pow(tokens, pos)?; while *pos < tokens.len() { let op = match &tokens[*pos] { - Token::Star => "*", - Token::Slash => "/", + Token::Star => BinOp::Mul, + Token::Slash => BinOp::Div, _ => break, }; *pos += 1; let right = 
parse_pow(tokens, pos)?; - left = Expr::BinOp(op.to_string(), Box::new(left), Box::new(right)); + left = Expr::BinOp(op, Box::new(left), Box::new(right)); } Ok(left) } @@ -192,7 +192,7 @@ fn parse_pow(tokens: &[Token], pos: &mut usize) -> Result { if *pos < tokens.len() && tokens[*pos] == Token::Caret { *pos += 1; let exp = parse_unary(tokens, pos)?; - return Ok(Expr::BinOp("^".to_string(), Box::new(base), Box::new(exp))); + return Ok(Expr::BinOp(BinOp::Pow, Box::new(base), Box::new(exp))); } Ok(base) } @@ -298,30 +298,30 @@ fn parse_comparison(tokens: &[Token], pos: &mut usize) -> Result { let left = parse_add_sub(tokens, pos)?; if *pos >= tokens.len() { return Ok(left); } let op = match &tokens[*pos] { - Token::Eq => "=", - Token::Ne => "!=", - Token::Lt => "<", - Token::Gt => ">", - Token::Le => "<=", - Token::Ge => ">=", + Token::Eq => BinOp::Eq, + Token::Ne => BinOp::Ne, + Token::Lt => BinOp::Lt, + Token::Gt => BinOp::Gt, + Token::Le => BinOp::Le, + Token::Ge => BinOp::Ge, _ => return Ok(left), }; *pos += 1; let right = parse_add_sub(tokens, pos)?; - Ok(Expr::BinOp(op.to_string(), Box::new(left), Box::new(right))) + Ok(Expr::BinOp(op, Box::new(left), Box::new(right))) } #[cfg(test)] mod tests { use super::parse_formula; - use crate::formula::{Expr, AggFunc}; + use crate::formula::{AggFunc, BinOp, Expr}; #[test] fn parse_simple_subtraction() { let f = parse_formula("Profit = Revenue - Cost", "Measure").unwrap(); assert_eq!(f.target, "Profit"); assert_eq!(f.target_category, "Measure"); - assert!(matches!(f.expr, Expr::BinOp(ref op, _, _) if op == "-")); + assert!(matches!(f.expr, Expr::BinOp(BinOp::Sub, _, _))); } #[test] diff --git a/src/model/model.rs b/src/model/model.rs index 58ffa79..3efda80 100644 --- a/src/model/model.rs +++ b/src/model/model.rs @@ -193,15 +193,19 @@ impl Model { model.evaluate(&new_key).as_f64() } Expr::BinOp(op, l, r) => { + use crate::formula::BinOp; let lv = eval_expr(l, context, model, target_category)?; let rv = eval_expr(r, 
context, model, target_category)?; - Some(match op.as_str() { - "+" => lv + rv, - "-" => lv - rv, - "*" => lv * rv, - "/" => { if rv == 0.0 { return None; } lv / rv } - "^" => lv.powf(rv), - _ => return None, + Some(match op { + BinOp::Add => lv + rv, + BinOp::Sub => lv - rv, + BinOp::Mul => lv * rv, + BinOp::Div => { if rv == 0.0 { return None; } lv / rv } + BinOp::Pow => lv.powf(rv), + // Comparison operators are handled by eval_bool; reaching + // here means a comparison was used where a number is expected. + BinOp::Eq | BinOp::Ne | BinOp::Lt | + BinOp::Gt | BinOp::Le | BinOp::Ge => return None, }) } Expr::UnaryMinus(e) => Some(-eval_expr(e, context, model, target_category)?), @@ -247,18 +251,21 @@ impl Model { model: &Model, target_category: &str, ) -> Option { + use crate::formula::BinOp; match expr { Expr::BinOp(op, l, r) => { let lv = eval_expr(l, context, model, target_category)?; let rv = eval_expr(r, context, model, target_category)?; - Some(match op.as_str() { - "=" | "==" => (lv - rv).abs() < 1e-10, - "!=" => (lv - rv).abs() >= 1e-10, - "<" => lv < rv, - ">" => lv > rv, - "<=" => lv <= rv, - ">=" => lv >= rv, - _ => return None, + Some(match op { + BinOp::Eq => (lv - rv).abs() < 1e-10, + BinOp::Ne => (lv - rv).abs() >= 1e-10, + BinOp::Lt => lv < rv, + BinOp::Gt => lv > rv, + BinOp::Le => lv <= rv, + BinOp::Ge => lv >= rv, + // Arithmetic operators are not comparisons + BinOp::Add | BinOp::Sub | BinOp::Mul | + BinOp::Div | BinOp::Pow => return None, }) } _ => None,