use anyhow::{anyhow, Result}; use super::ast::{AggFunc, BinOp, Expr, Filter, Formula}; /// Parse a formula string like "Profit = Revenue - Cost" /// or "Tax = Revenue * 0.08 WHERE Region = \"East\"" pub fn parse_formula(raw: &str, target_category: &str) -> Result { let raw = raw.trim(); // Split on first `=` to get target = expression let eq_pos = raw .find('=') .ok_or_else(|| anyhow!("Formula must contain '=': {raw}"))?; let target = raw[..eq_pos].trim().to_string(); let rest = raw[eq_pos + 1..].trim(); // Check for WHERE clause at top level let (expr_str, filter) = split_where(rest); let filter = filter.map(parse_where).transpose()?; let expr = parse_expr(expr_str.trim())?; Ok(Formula::new(raw, target, target_category, expr, filter)) } fn split_where(s: &str) -> (&str, Option<&str>) { // Find WHERE not inside parens or quotes let bytes = s.as_bytes(); let mut depth = 0i32; let mut i = 0; while i < bytes.len() { match bytes[i] { b'(' => depth += 1, b')' => depth -= 1, b'"' => { i += 1; while i < bytes.len() && bytes[i] != b'"' { i += 1; } } _ if depth == 0 => { if s[i..].to_ascii_uppercase().starts_with("WHERE") { let before = &s[..i]; let after = &s[i + 5..]; if before.ends_with(char::is_whitespace) || i == 0 { return (before.trim(), Some(after.trim())); } } } _ => {} } i += 1; } (s, None) } fn parse_where(s: &str) -> Result { // Format: Category = "Item" or Category = Item let eq_pos = s .find('=') .ok_or_else(|| anyhow!("WHERE clause must contain '=': {s}"))?; let category = s[..eq_pos].trim().to_string(); let item_raw = s[eq_pos + 1..].trim(); let item = item_raw.trim_matches('"').to_string(); Ok(Filter { category, item }) } /// Parse an expression using recursive descent pub fn parse_expr(s: &str) -> Result { let tokens = tokenize(s)?; let mut pos = 0; let expr = parse_add_sub(&tokens, &mut pos)?; if pos < tokens.len() { return Err(anyhow!( "Unexpected token at position {pos}: {:?}", tokens[pos] )); } Ok(expr) } #[derive(Debug, Clone, PartialEq)] enum Token { Number(f64), Ident(String), Str(String), Plus, Minus, Star, Slash, Caret, LParen, RParen, Comma, Eq, Ne, Lt, Gt, Le, Ge, } fn tokenize(s: &str) -> Result> { let mut tokens = Vec::new(); let chars: Vec = s.chars().collect(); let mut i = 0; while i < chars.len() { match chars[i] { ' ' | '\t' | '\n' => i += 1, '+' => { tokens.push(Token::Plus); i += 1; } '-' => { tokens.push(Token::Minus); i += 1; } '*' => { tokens.push(Token::Star); i += 1; } '/' => { tokens.push(Token::Slash); i += 1; } '^' => { tokens.push(Token::Caret); i += 1; } '(' => { tokens.push(Token::LParen); i += 1; } ')' => { tokens.push(Token::RParen); i += 1; } ',' => { tokens.push(Token::Comma); i += 1; } '!' if chars.get(i + 1) == Some(&'=') => { tokens.push(Token::Ne); i += 2; } '<' if chars.get(i + 1) == Some(&'=') => { tokens.push(Token::Le); i += 2; } '>' if chars.get(i + 1) == Some(&'=') => { tokens.push(Token::Ge); i += 2; } '<' => { tokens.push(Token::Lt); i += 1; } '>' => { tokens.push(Token::Gt); i += 1; } '=' => { tokens.push(Token::Eq); i += 1; } '"' => { i += 1; let mut s = String::new(); while i < chars.len() && chars[i] != '"' { s.push(chars[i]); i += 1; } if i < chars.len() { i += 1; } tokens.push(Token::Str(s)); } c if c.is_ascii_digit() || c == '.' => { let mut num = String::new(); while i < chars.len() && (chars[i].is_ascii_digit() || chars[i] == '.') { num.push(chars[i]); i += 1; } tokens.push(Token::Number(num.parse()?)); } c if c.is_alphabetic() || c == '_' => { let mut ident = String::new(); while i < chars.len() && (chars[i].is_alphanumeric() || chars[i] == '_' || chars[i] == ' ') { // Don't consume trailing spaces if next non-space is operator if chars[i] == ' ' { // Peek ahead let j = i + 1; let next_nonspace = chars[j..].iter().find(|&&c| c != ' '); if matches!( next_nonspace, Some('+') | Some('-') | Some('*') | Some('/') | Some('^') | Some(')') | Some(',') | None ) { break; } } ident.push(chars[i]); i += 1; } let ident = ident.trim_end().to_string(); tokens.push(Token::Ident(ident)); } c => return Err(anyhow!("Unexpected character '{c}' in expression")), } } Ok(tokens) } fn parse_add_sub(tokens: &[Token], pos: &mut usize) -> Result { let mut left = parse_mul_div(tokens, pos)?; while *pos < tokens.len() { let op = match &tokens[*pos] { Token::Plus => BinOp::Add, Token::Minus => BinOp::Sub, _ => break, }; *pos += 1; let right = parse_mul_div(tokens, pos)?; left = Expr::BinOp(op, Box::new(left), Box::new(right)); } Ok(left) } fn parse_mul_div(tokens: &[Token], pos: &mut usize) -> Result { let mut left = parse_pow(tokens, pos)?; while *pos < tokens.len() { let op = match &tokens[*pos] { Token::Star => BinOp::Mul, Token::Slash => BinOp::Div, _ => break, }; *pos += 1; let right = parse_pow(tokens, pos)?; left = Expr::BinOp(op, Box::new(left), Box::new(right)); } Ok(left) } fn parse_pow(tokens: &[Token], pos: &mut usize) -> Result { let base = parse_unary(tokens, pos)?; if *pos < tokens.len() && tokens[*pos] == Token::Caret { *pos += 1; let exp = parse_unary(tokens, pos)?; return Ok(Expr::BinOp(BinOp::Pow, Box::new(base), Box::new(exp))); } Ok(base) } fn parse_unary(tokens: &[Token], pos: &mut usize) -> Result { if *pos < tokens.len() && tokens[*pos] == Token::Minus { *pos += 1; let e = parse_primary(tokens, pos)?; return Ok(Expr::UnaryMinus(Box::new(e))); } parse_primary(tokens, pos) } fn parse_primary(tokens: &[Token], pos: &mut usize) -> Result { if *pos >= tokens.len() { return Err(anyhow!("Unexpected end of expression")); } match &tokens[*pos].clone() { Token::Number(n) => { *pos += 1; Ok(Expr::Number(*n)) } Token::Ident(name) => { let name = name.clone(); *pos += 1; // Check for function call let upper = name.to_ascii_uppercase(); match upper.as_str() { "SUM" | "AVG" | "MIN" | "MAX" | "COUNT" => { let func = match upper.as_str() { "SUM" => AggFunc::Sum, "AVG" => AggFunc::Avg, "MIN" => AggFunc::Min, "MAX" => AggFunc::Max, "COUNT" => AggFunc::Count, _ => unreachable!(), }; if *pos < tokens.len() && tokens[*pos] == Token::LParen { *pos += 1; let inner = parse_add_sub(tokens, pos)?; // Optional WHERE filter let filter = if *pos < tokens.len() { if let Token::Ident(kw) = &tokens[*pos] { if kw.eq_ignore_ascii_case("WHERE") { *pos += 1; let cat = match &tokens[*pos] { Token::Ident(s) => { let s = s.clone(); *pos += 1; s } t => { return Err(anyhow!( "Expected category name, got {t:?}" )) } }; // expect = if *pos < tokens.len() && tokens[*pos] == Token::Eq { *pos += 1; } let item = match &tokens[*pos] { Token::Str(s) | Token::Ident(s) => { let s = s.clone(); *pos += 1; s } t => return Err(anyhow!("Expected item name, got {t:?}")), }; Some(Filter { category: cat, item, }) } else { None } } else { None } } else { None }; // expect ) if *pos < tokens.len() && tokens[*pos] == Token::RParen { *pos += 1; } else { return Err(anyhow!("Expected ')' to close aggregate function")); } return Ok(Expr::Agg(func, Box::new(inner), filter)); } Ok(Expr::Ref(name)) } "IF" => { if *pos < tokens.len() && tokens[*pos] == Token::LParen { *pos += 1; let cond = parse_comparison(tokens, pos)?; if *pos < tokens.len() && tokens[*pos] == Token::Comma { *pos += 1; } let then = parse_add_sub(tokens, pos)?; if *pos < tokens.len() && tokens[*pos] == Token::Comma { *pos += 1; } let else_ = parse_add_sub(tokens, pos)?; if *pos < tokens.len() && tokens[*pos] == Token::RParen { *pos += 1; } else { return Err(anyhow!("Expected ')' to close IF(...)")); } return Ok(Expr::If(Box::new(cond), Box::new(then), Box::new(else_))); } Ok(Expr::Ref(name)) } _ => Ok(Expr::Ref(name)), } } Token::LParen => { *pos += 1; let e = parse_add_sub(tokens, pos)?; if *pos < tokens.len() && tokens[*pos] == Token::RParen { *pos += 1; } Ok(e) } t => Err(anyhow!("Unexpected token in expression: {t:?}")), } } fn parse_comparison(tokens: &[Token], pos: &mut usize) -> Result { let left = parse_add_sub(tokens, pos)?; if *pos >= tokens.len() { return Ok(left); } let op = match &tokens[*pos] { Token::Eq => BinOp::Eq, Token::Ne => BinOp::Ne, Token::Lt => BinOp::Lt, Token::Gt => BinOp::Gt, Token::Le => BinOp::Le, Token::Ge => BinOp::Ge, _ => return Ok(left), }; *pos += 1; let right = parse_add_sub(tokens, pos)?; Ok(Expr::BinOp(op, Box::new(left), Box::new(right))) } #[cfg(test)] mod tests { use super::parse_formula; use crate::formula::{AggFunc, BinOp, Expr}; #[test] fn parse_simple_subtraction() { let f = parse_formula("Profit = Revenue - Cost", "Measure").unwrap(); assert_eq!(f.target, "Profit"); assert_eq!(f.target_category, "Measure"); assert!(matches!(f.expr, Expr::BinOp(BinOp::Sub, _, _))); } #[test] fn parse_where_clause() { let f = parse_formula("EastRev = Revenue WHERE Region = \"East\"", "Measure").unwrap(); assert_eq!(f.target, "EastRev"); let filter = f.filter.as_ref().unwrap(); assert_eq!(filter.category, "Region"); assert_eq!(filter.item, "East"); } #[test] fn parse_sum_aggregation() { let f = parse_formula("Total = SUM(Revenue)", "Measure").unwrap(); assert!(matches!(f.expr, Expr::Agg(AggFunc::Sum, _, _))); } #[test] fn parse_avg_aggregation() { let f = parse_formula("Avg = AVG(Revenue)", "Measure").unwrap(); assert!(matches!(f.expr, Expr::Agg(AggFunc::Avg, _, _))); } #[test] fn parse_if_expression() { let f = parse_formula("Capped = IF(Revenue > 1000, 1000, Revenue)", "Measure").unwrap(); assert!(matches!(f.expr, Expr::If(_, _, _))); } #[test] fn parse_numeric_literal() { let f = parse_formula("Fixed = 42", "Measure").unwrap(); assert!(matches!(f.expr, Expr::Number(n) if (n - 42.0).abs() < 1e-10)); } #[test] fn parse_chained_arithmetic() { parse_formula("X = (A + B) * (C - D)", "Cat").unwrap(); } #[test] fn parse_missing_equals_returns_error() { assert!(parse_formula("BadFormula Revenue Cost", "Cat").is_err()); } }