use anyhow::{anyhow, Result}; use super::ast::{AggFunc, BinOp, Expr, Filter, Formula}; /// Parse a formula string like "Profit = Revenue - Cost" /// or "Tax = Revenue * 0.08 WHERE Region = \"East\"" pub fn parse_formula(raw: &str, target_category: &str) -> Result { let raw = raw.trim(); // Split on first `=` to get target = expression let eq_pos = raw .find('=') .ok_or_else(|| anyhow!("Formula must contain '=': {raw}"))?; let target = raw[..eq_pos].trim().to_string(); let rest = raw[eq_pos + 1..].trim(); // Check for WHERE clause at top level let (expr_str, filter) = split_where(rest); let filter = filter.map(parse_where).transpose()?; let expr = parse_expr(expr_str.trim())?; Ok(Formula::new(raw, target, target_category, expr, filter)) } fn split_where(s: &str) -> (&str, Option<&str>) { // Find WHERE not inside parens or quotes let bytes = s.as_bytes(); let mut depth = 0i32; let mut i = 0; while i < bytes.len() { match bytes[i] { b'(' => depth += 1, b')' => depth -= 1, b'"' => { i += 1; while i < bytes.len() && bytes[i] != b'"' { i += 1; } } _ if depth == 0 => { if s[i..].to_ascii_uppercase().starts_with("WHERE") { let before = &s[..i]; let after = &s[i + 5..]; if before.ends_with(char::is_whitespace) || i == 0 { return (before.trim(), Some(after.trim())); } } } _ => {} } i += 1; } (s, None) } fn parse_where(s: &str) -> Result { // Format: Category = "Item" or Category = Item let eq_pos = s .find('=') .ok_or_else(|| anyhow!("WHERE clause must contain '=': {s}"))?; let category = s[..eq_pos].trim().to_string(); let item_raw = s[eq_pos + 1..].trim(); let item = item_raw.trim_matches('"').to_string(); Ok(Filter { category, item }) } /// Parse an expression using recursive descent pub fn parse_expr(s: &str) -> Result { let tokens = tokenize(s)?; let mut pos = 0; let expr = parse_add_sub(&tokens, &mut pos)?; if pos < tokens.len() { return Err(anyhow!( "Unexpected token at position {pos}: {:?}", tokens[pos] )); } Ok(expr) } #[derive(Debug, Clone, PartialEq)] enum Token { Number(f64), Ident(String), Str(String), Plus, Minus, Star, Slash, Caret, LParen, RParen, Comma, Eq, Ne, Lt, Gt, Le, Ge, } fn tokenize(s: &str) -> Result> { let mut tokens = Vec::new(); let chars: Vec = s.chars().collect(); let mut i = 0; while i < chars.len() { match chars[i] { ' ' | '\t' | '\n' => i += 1, '+' => { tokens.push(Token::Plus); i += 1; } '-' => { tokens.push(Token::Minus); i += 1; } '*' => { tokens.push(Token::Star); i += 1; } '/' => { tokens.push(Token::Slash); i += 1; } '^' => { tokens.push(Token::Caret); i += 1; } '(' => { tokens.push(Token::LParen); i += 1; } ')' => { tokens.push(Token::RParen); i += 1; } ',' => { tokens.push(Token::Comma); i += 1; } '!' if chars.get(i + 1) == Some(&'=') => { tokens.push(Token::Ne); i += 2; } '<' if chars.get(i + 1) == Some(&'=') => { tokens.push(Token::Le); i += 2; } '>' if chars.get(i + 1) == Some(&'=') => { tokens.push(Token::Ge); i += 2; } '<' => { tokens.push(Token::Lt); i += 1; } '>' => { tokens.push(Token::Gt); i += 1; } '=' => { tokens.push(Token::Eq); i += 1; } '"' => { i += 1; let mut s = String::new(); while i < chars.len() && chars[i] != '"' { s.push(chars[i]); i += 1; } if i < chars.len() { i += 1; } tokens.push(Token::Str(s)); } c if c.is_ascii_digit() || c == '.' => { let mut num = String::new(); while i < chars.len() && (chars[i].is_ascii_digit() || chars[i] == '.') { num.push(chars[i]); i += 1; } tokens.push(Token::Number(num.parse()?)); } c if c.is_alphabetic() || c == '_' => { let mut ident = String::new(); while i < chars.len() && (chars[i].is_alphanumeric() || chars[i] == '_' || chars[i] == ' ') { // Don't consume trailing spaces if next non-space is operator if chars[i] == ' ' { // Peek ahead past spaces to find the next word/token let j = i + 1; let next_nonspace = chars[j..].iter().find(|&&c| c != ' '); if matches!( next_nonspace, Some('+') | Some('-') | Some('*') | Some('/') | Some('^') | Some(')') | Some(',') | Some('<') | Some('>') | Some('=') | Some('!') | Some('"') | None ) { break; } // Break if the identifier collected so far is a keyword let trimmed = ident.trim_end().to_ascii_uppercase(); if matches!( trimmed.as_str(), "WHERE" | "SUM" | "AVG" | "MIN" | "MAX" | "COUNT" | "IF" ) { break; } // Also break if the next word is a keyword let rest: String = chars[j..].iter().collect(); let next_word: String = rest .trim_start() .chars() .take_while(|c| c.is_alphanumeric() || *c == '_') .collect(); let upper = next_word.to_ascii_uppercase(); if matches!( upper.as_str(), "WHERE" | "SUM" | "AVG" | "MIN" | "MAX" | "COUNT" | "IF" ) { break; } } ident.push(chars[i]); i += 1; } let ident = ident.trim_end().to_string(); tokens.push(Token::Ident(ident)); } c => return Err(anyhow!("Unexpected character '{c}' in expression")), } } Ok(tokens) } fn parse_add_sub(tokens: &[Token], pos: &mut usize) -> Result { let mut left = parse_mul_div(tokens, pos)?; while *pos < tokens.len() { let op = match &tokens[*pos] { Token::Plus => BinOp::Add, Token::Minus => BinOp::Sub, _ => break, }; *pos += 1; let right = parse_mul_div(tokens, pos)?; left = Expr::BinOp(op, Box::new(left), Box::new(right)); } Ok(left) } fn parse_mul_div(tokens: &[Token], pos: &mut usize) -> Result { let mut left = parse_pow(tokens, pos)?; while *pos < tokens.len() { let op = match &tokens[*pos] { Token::Star => BinOp::Mul, Token::Slash => BinOp::Div, _ => break, }; *pos += 1; let right = parse_pow(tokens, pos)?; left = Expr::BinOp(op, Box::new(left), Box::new(right)); } Ok(left) } fn parse_pow(tokens: &[Token], pos: &mut usize) -> Result { let base = parse_unary(tokens, pos)?; if *pos < tokens.len() && tokens[*pos] == Token::Caret { *pos += 1; let exp = parse_unary(tokens, pos)?; return Ok(Expr::BinOp(BinOp::Pow, Box::new(base), Box::new(exp))); } Ok(base) } fn parse_unary(tokens: &[Token], pos: &mut usize) -> Result { if *pos < tokens.len() && tokens[*pos] == Token::Minus { *pos += 1; let e = parse_primary(tokens, pos)?; return Ok(Expr::UnaryMinus(Box::new(e))); } parse_primary(tokens, pos) } fn parse_primary(tokens: &[Token], pos: &mut usize) -> Result { if *pos >= tokens.len() { return Err(anyhow!("Unexpected end of expression")); } match &tokens[*pos].clone() { Token::Number(n) => { *pos += 1; Ok(Expr::Number(*n)) } Token::Ident(name) => { let name = name.clone(); *pos += 1; // Check for function call let upper = name.to_ascii_uppercase(); match upper.as_str() { "SUM" | "AVG" | "MIN" | "MAX" | "COUNT" => { let func = match upper.as_str() { "SUM" => AggFunc::Sum, "AVG" => AggFunc::Avg, "MIN" => AggFunc::Min, "MAX" => AggFunc::Max, "COUNT" => AggFunc::Count, _ => unreachable!(), }; if *pos < tokens.len() && tokens[*pos] == Token::LParen { *pos += 1; let inner = parse_add_sub(tokens, pos)?; // Optional WHERE filter let filter = if *pos < tokens.len() { if let Token::Ident(kw) = &tokens[*pos] { if kw.eq_ignore_ascii_case("WHERE") { *pos += 1; let cat = match &tokens[*pos] { Token::Ident(s) => { let s = s.clone(); *pos += 1; s } t => { return Err(anyhow!( "Expected category name, got {t:?}" )) } }; // expect = if *pos < tokens.len() && tokens[*pos] == Token::Eq { *pos += 1; } let item = match &tokens[*pos] { Token::Str(s) | Token::Ident(s) => { let s = s.clone(); *pos += 1; s } t => return Err(anyhow!("Expected item name, got {t:?}")), }; Some(Filter { category: cat, item, }) } else { None } } else { None } } else { None }; // expect ) if *pos < tokens.len() && tokens[*pos] == Token::RParen { *pos += 1; } else { return Err(anyhow!("Expected ')' to close aggregate function")); } return Ok(Expr::Agg(func, Box::new(inner), filter)); } Ok(Expr::Ref(name)) } "IF" => { if *pos < tokens.len() && tokens[*pos] == Token::LParen { *pos += 1; let cond = parse_comparison(tokens, pos)?; if *pos < tokens.len() && tokens[*pos] == Token::Comma { *pos += 1; } let then = parse_add_sub(tokens, pos)?; if *pos < tokens.len() && tokens[*pos] == Token::Comma { *pos += 1; } let else_ = parse_add_sub(tokens, pos)?; if *pos < tokens.len() && tokens[*pos] == Token::RParen { *pos += 1; } else { return Err(anyhow!("Expected ')' to close IF(...)")); } return Ok(Expr::If(Box::new(cond), Box::new(then), Box::new(else_))); } Ok(Expr::Ref(name)) } _ => Ok(Expr::Ref(name)), } } Token::LParen => { *pos += 1; let e = parse_add_sub(tokens, pos)?; if *pos < tokens.len() && tokens[*pos] == Token::RParen { *pos += 1; } Ok(e) } t => Err(anyhow!("Unexpected token in expression: {t:?}")), } } fn parse_comparison(tokens: &[Token], pos: &mut usize) -> Result { let left = parse_add_sub(tokens, pos)?; if *pos >= tokens.len() { return Ok(left); } let op = match &tokens[*pos] { Token::Eq => BinOp::Eq, Token::Ne => BinOp::Ne, Token::Lt => BinOp::Lt, Token::Gt => BinOp::Gt, Token::Le => BinOp::Le, Token::Ge => BinOp::Ge, _ => return Ok(left), }; *pos += 1; let right = parse_add_sub(tokens, pos)?; Ok(Expr::BinOp(op, Box::new(left), Box::new(right))) } #[cfg(test)] mod tests { use super::parse_formula; use crate::formula::{AggFunc, BinOp, Expr}; #[test] fn parse_simple_subtraction() { let f = parse_formula("Profit = Revenue - Cost", "Measure").unwrap(); assert_eq!(f.target, "Profit"); assert_eq!(f.target_category, "Measure"); assert!(matches!(f.expr, Expr::BinOp(BinOp::Sub, _, _))); } #[test] fn parse_where_clause() { let f = parse_formula("EastRev = Revenue WHERE Region = \"East\"", "Measure").unwrap(); assert_eq!(f.target, "EastRev"); let filter = f.filter.as_ref().unwrap(); assert_eq!(filter.category, "Region"); assert_eq!(filter.item, "East"); } #[test] fn parse_sum_aggregation() { let f = parse_formula("Total = SUM(Revenue)", "Measure").unwrap(); assert!(matches!(f.expr, Expr::Agg(AggFunc::Sum, _, _))); } #[test] fn parse_avg_aggregation() { let f = parse_formula("Avg = AVG(Revenue)", "Measure").unwrap(); assert!(matches!(f.expr, Expr::Agg(AggFunc::Avg, _, _))); } #[test] fn parse_if_expression() { let f = parse_formula("Capped = IF(Revenue > 1000, 1000, Revenue)", "Measure").unwrap(); assert!(matches!(f.expr, Expr::If(_, _, _))); } #[test] fn parse_numeric_literal() { let f = parse_formula("Fixed = 42", "Measure").unwrap(); assert!(matches!(f.expr, Expr::Number(n) if (n - 42.0).abs() < 1e-10)); } #[test] fn parse_chained_arithmetic() { parse_formula("X = (A + B) * (C - D)", "Cat").unwrap(); } #[test] fn parse_missing_equals_returns_error() { assert!(parse_formula("BadFormula Revenue Cost", "Cat").is_err()); } // ── Aggregate functions ───────────────────────────────────────────── #[test] fn parse_min_aggregation() { let f = parse_formula("Lo = MIN(Revenue)", "Measure").unwrap(); assert!(matches!(f.expr, Expr::Agg(AggFunc::Min, _, _))); } #[test] fn parse_max_aggregation() { let f = parse_formula("Hi = MAX(Revenue)", "Measure").unwrap(); assert!(matches!(f.expr, Expr::Agg(AggFunc::Max, _, _))); } #[test] fn parse_count_aggregation() { let f = parse_formula("N = COUNT(Revenue)", "Measure").unwrap(); assert!(matches!(f.expr, Expr::Agg(AggFunc::Count, _, _))); } // ── Aggregate with WHERE filter ───────────────────────────────────── #[test] fn parse_sum_with_top_level_where_works() { let f = parse_formula( "EastTotal = SUM(Revenue) WHERE Region = \"East\"", "Measure", ) .unwrap(); assert!(matches!(f.expr, Expr::Agg(AggFunc::Sum, _, _))); let filter = f.filter.as_ref().unwrap(); assert_eq!(filter.category, "Region"); assert_eq!(filter.item, "East"); } /// Regression: WHERE inside aggregate parens must tokenize correctly. /// The tokenizer must not merge "Revenue WHERE" into a single identifier. #[test] fn parse_sum_with_inline_where_filter() { let f = parse_formula( "EastTotal = SUM(Revenue WHERE Region = \"East\")", "Measure", ) .unwrap(); if let Expr::Agg(AggFunc::Sum, inner, Some(filter)) = &f.expr { assert!(matches!(**inner, Expr::Ref(_))); assert_eq!(filter.category, "Region"); assert_eq!(filter.item, "East"); } else { panic!("Expected SUM with inline WHERE filter, got: {:?}", f.expr); } } // ── Comparison operators ──────────────────────────────────────────── #[test] fn parse_if_with_comparison_operators() { // Test each comparison operator in an IF expression let f = parse_formula("X = IF(A != 0, A, 1)", "Cat").unwrap(); assert!(matches!(f.expr, Expr::If(_, _, _))); let f = parse_formula("X = IF(A < 10, A, 10)", "Cat").unwrap(); assert!(matches!(f.expr, Expr::If(_, _, _))); let f = parse_formula("X = IF(A <= 10, A, 10)", "Cat").unwrap(); assert!(matches!(f.expr, Expr::If(_, _, _))); let f = parse_formula("X = IF(A >= 10, 10, A)", "Cat").unwrap(); assert!(matches!(f.expr, Expr::If(_, _, _))); let f = parse_formula("X = IF(A = B, 1, 0)", "Cat").unwrap(); assert!(matches!(f.expr, Expr::If(_, _, _))); } // ── Quoted strings in WHERE ───────────────────────────────────────── #[test] fn parse_where_with_quoted_string_inside_expression() { // WHERE inside a formula string with quotes let f = parse_formula( "X = Revenue WHERE Region = \"West Coast\"", "Measure", ) .unwrap(); let filter = f.filter.as_ref().unwrap(); assert_eq!(filter.item, "West Coast"); } // ── Power operator ────────────────────────────────────────────────── #[test] fn parse_power_operator() { let f = parse_formula("Sq = X ^ 2", "Cat").unwrap(); assert!(matches!(f.expr, Expr::BinOp(BinOp::Pow, _, _))); } // ── Unary minus ───────────────────────────────────────────────────── #[test] fn parse_unary_minus() { let f = parse_formula("Neg = -Revenue", "Measure").unwrap(); assert!(matches!(f.expr, Expr::UnaryMinus(_))); } // ── Division and multiplication ───────────────────────────────────── #[test] fn parse_multiplication() { let f = parse_formula("Double = Revenue * 2", "Measure").unwrap(); assert!(matches!(f.expr, Expr::BinOp(BinOp::Mul, _, _))); } #[test] fn parse_division() { let f = parse_formula("Half = Revenue / 2", "Measure").unwrap(); assert!(matches!(f.expr, Expr::BinOp(BinOp::Div, _, _))); } // ── Parenthesized expression ──────────────────────────────────────── #[test] fn parse_nested_parens() { let f = parse_formula("X = ((A + B))", "Cat").unwrap(); assert!(matches!(f.expr, Expr::BinOp(BinOp::Add, _, _))); } // ── Aggregate function name used as ref (no parens) ───────────────── #[test] fn parse_aggregate_name_without_parens_is_ref() { // "SUM" without parens should be treated as a reference, not a function let f = parse_formula("X = SUM + 1", "Cat").unwrap(); assert!(matches!(f.expr, Expr::BinOp(BinOp::Add, _, _))); if let Expr::BinOp(_, lhs, _) = &f.expr { assert!(matches!(**lhs, Expr::Ref(_))); } } #[test] fn parse_if_without_parens_is_ref() { // "IF" without parens should be treated as a reference let f = parse_formula("X = IF + 1", "Cat").unwrap(); if let Expr::BinOp(BinOp::Add, lhs, _) = &f.expr { assert!(matches!(**lhs, Expr::Ref(_))); } else { panic!("Expected BinOp(Add), got: {:?}", f.expr); } } // ── Quoted string in tokenizer ────────────────────────────────────── #[test] fn parse_quoted_string_in_where() { // Quoted strings work in top-level WHERE clauses let f = parse_formula("X = Revenue WHERE Region = \"East\"", "Cat").unwrap(); let filter = f.filter.as_ref().unwrap(); assert_eq!(filter.item, "East"); } // ── Error paths ───────────────────────────────────────────────────── #[test] fn parse_unexpected_token_error() { use super::parse_expr; // Extra tokens after a valid expression assert!(parse_expr("1 + 2 3").is_err()); } #[test] fn parse_unexpected_character_error() { use super::parse_expr; assert!(parse_expr("@invalid").is_err()); } #[test] fn parse_empty_expression_error() { use super::parse_expr; assert!(parse_expr("").is_err()); } #[test] fn tokenizer_breaks_at_where_keyword() { use super::tokenize; let tokens = tokenize("Revenue WHERE Region").unwrap(); // Should produce 3 tokens: Ident("Revenue"), Ident("WHERE"), Ident("Region") assert_eq!(tokens.len(), 3, "Expected 3 tokens, got: {tokens:?}"); } // ── Multi-word identifiers ────────────────────────────────────────── #[test] fn parse_multi_word_identifier() { let f = parse_formula("Total Revenue = Base Revenue + Bonus", "Measure").unwrap(); assert_eq!(f.target, "Total Revenue"); } // ── WHERE inside quotes in split_where ────────────────────────────── #[test] fn split_where_ignores_where_inside_quotes() { // WHERE inside quotes should not be treated as a keyword let f = parse_formula( "X = Revenue WHERE Region = \"WHERE\"", "Measure", ) .unwrap(); let filter = f.filter.as_ref().unwrap(); assert_eq!(filter.item, "WHERE"); } }