Files
improvise/src/formula/parser.rs
2026-04-02 09:35:02 -07:00

462 lines
15 KiB
Rust

use anyhow::{anyhow, Result};
use super::ast::{AggFunc, BinOp, Expr, Filter, Formula};
/// Parse a formula string like "Profit = Revenue - Cost"
/// or "Tax = Revenue * 0.08 WHERE Region = \"East\""
pub fn parse_formula(raw: &str, target_category: &str) -> Result<Formula> {
let raw = raw.trim();
// Split on first `=` to get target = expression
let eq_pos = raw
.find('=')
.ok_or_else(|| anyhow!("Formula must contain '=': {raw}"))?;
let target = raw[..eq_pos].trim().to_string();
let rest = raw[eq_pos + 1..].trim();
// Check for WHERE clause at top level
let (expr_str, filter) = split_where(rest);
let filter = filter.map(parse_where).transpose()?;
let expr = parse_expr(expr_str.trim())?;
Ok(Formula::new(raw, target, target_category, expr, filter))
}
/// Splits `s` at the first top-level `WHERE` keyword (case-insensitive).
///
/// Returns `(expression, Some(clause))` with both parts trimmed when a keyword
/// is found, or `(s, None)` otherwise. A `WHERE` is only recognised when it
///
/// * is outside parentheses and outside double-quoted strings,
/// * is at the start of `s` or preceded by whitespace, and
/// * is not immediately followed by an identifier character, so a name such
///   as `Whereabouts` is not mistaken for the keyword.
///
/// The scan works on bytes and never slices `s` at an index that is not known
/// to sit on an ASCII byte, so non-ASCII names (e.g. `Café`) cannot trigger a
/// char-boundary panic.
fn split_where(s: &str) -> (&str, Option<&str>) {
    const KEYWORD: &[u8] = b"WHERE";

    let bytes = s.as_bytes();
    let mut depth = 0i32;
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'(' => depth += 1,
            b')' => depth -= 1,
            b'"' => {
                // Jump to the closing quote (or the end of input when the
                // literal is unterminated); the shared `i += 1` below then
                // steps past it.
                i += 1;
                while i < bytes.len() && bytes[i] != b'"' {
                    i += 1;
                }
            }
            _ if depth == 0 => {
                let end = i + KEYWORD.len();
                let is_keyword = bytes
                    .get(i..end)
                    .map_or(false, |word| word.eq_ignore_ascii_case(KEYWORD));
                if is_keyword {
                    // bytes[i..end] are ASCII, so both `i` and `end` are
                    // char boundaries and the slices below cannot panic.
                    let before = &s[..i];
                    let after = &s[end..];
                    let starts_word = i == 0 || before.ends_with(char::is_whitespace);
                    let ends_word = after
                        .chars()
                        .next()
                        .map_or(true, |c| !(c.is_alphanumeric() || c == '_'));
                    if starts_word && ends_word {
                        return (before.trim(), Some(after.trim()));
                    }
                }
            }
            _ => {}
        }
        i += 1;
    }
    (s, None)
}
/// Parses the body of a `WHERE` clause: `Category = "Item"` or
/// `Category = Item`.
///
/// The text before the first `=` (trimmed) is the category; the text after it
/// (trimmed, with any leading or trailing double quotes removed) is the item.
///
/// # Errors
///
/// Fails when the clause contains no `=`.
fn parse_where(s: &str) -> Result<Filter> {
    match s.split_once('=') {
        Some((lhs, rhs)) => Ok(Filter {
            category: lhs.trim().to_string(),
            item: rhs.trim().trim_matches('"').to_string(),
        }),
        None => Err(anyhow!("WHERE clause must contain '=': {s}")),
    }
}
/// Parses a single expression with a recursive-descent parser.
///
/// The input is tokenized, then parsed starting at the additive level. Every
/// token must be consumed.
///
/// # Errors
///
/// Fails when tokenizing fails, when the grammar rejects the token stream, or
/// when tokens remain after a complete expression has been read.
pub fn parse_expr(s: &str) -> Result<Expr> {
    let stream = tokenize(s)?;
    let mut pos = 0;
    let expr = parse_add_sub(&stream, &mut pos)?;

    // Anything still unread means the input had trailing garbage.
    match stream.get(pos) {
        Some(leftover) => Err(anyhow!(
            "Unexpected token at position {pos}: {:?}",
            leftover
        )),
        None => Ok(expr),
    }
}
/// Lexical token produced by `tokenize` and consumed by the `parse_*` functions.
#[derive(Debug, Clone, PartialEq)]
enum Token {
/// Numeric literal (a run of ASCII digits and `.`), stored as `f64`.
Number(f64),
/// Name starting with a letter or `_`; may contain internal spaces
/// (e.g. `Gross Revenue`). Function names such as `SUM` and `IF` also
/// arrive as identifiers and are recognised later by the parser.
Ident(String),
/// Contents of a double-quoted string literal, without the quotes.
Str(String),
/// `+`
Plus,
/// `-` (binary subtraction or unary negation; the parser decides which).
Minus,
/// `*`
Star,
/// `/`
Slash,
/// `^`
Caret,
/// `(`
LParen,
/// `)`
RParen,
/// `,`
Comma,
/// `=`
Eq,
/// `!=`
Ne,
/// `<`
Lt,
/// `>`
Gt,
/// `<=`
Le,
/// `>=`
Ge,
}
/// Converts an expression string into a flat list of [`Token`]s.
///
/// Lexical rules:
///
/// * Space, tab, `\n` and `\r` between tokens are skipped.
/// * `+ - * / ^ ( ) , = < >` are single-character tokens; `!=`, `<=` and `>=`
///   are two-character tokens. A lone `!` is an error.
/// * A double-quoted run becomes [`Token::Str`] without the quotes. There are
///   no escape sequences, and an unterminated literal extends to the end of
///   the input.
/// * A run of ASCII digits and `.` becomes [`Token::Number`].
/// * An identifier starts with a letter or `_` and may contain internal
///   spaces (`Gross Revenue`). It stops at a space that is followed by an
///   operator, `)`, `,`, the end of input, or the keyword `WHERE`. `WHERE`
///   itself is always emitted as its own identifier, so that
///   `SUM(Revenue WHERE Region = "East")` yields separate tokens for
///   `Revenue`, `WHERE` and `Region`.
///
/// # Errors
///
/// Fails on a character that starts no token and on a digit/dot run that is
/// not a valid number (for example `1.2.3`).
fn tokenize(s: &str) -> Result<Vec<Token>> {
    let chars: Vec<char> = s.chars().collect();
    let mut tokens = Vec::new();
    let mut i = 0;
    while i < chars.len() {
        let c = chars[i];
        let followed_by_eq = chars.get(i + 1) == Some(&'=');

        // Tokens with a fixed spelling: (token, number of chars consumed).
        let fixed = match c {
            '+' => Some((Token::Plus, 1)),
            '-' => Some((Token::Minus, 1)),
            '*' => Some((Token::Star, 1)),
            '/' => Some((Token::Slash, 1)),
            '^' => Some((Token::Caret, 1)),
            '(' => Some((Token::LParen, 1)),
            ')' => Some((Token::RParen, 1)),
            ',' => Some((Token::Comma, 1)),
            '!' if followed_by_eq => Some((Token::Ne, 2)),
            '<' if followed_by_eq => Some((Token::Le, 2)),
            '>' if followed_by_eq => Some((Token::Ge, 2)),
            '<' => Some((Token::Lt, 1)),
            '>' => Some((Token::Gt, 1)),
            '=' => Some((Token::Eq, 1)),
            _ => None,
        };
        if let Some((token, width)) = fixed {
            tokens.push(token);
            i += width;
            continue;
        }

        match c {
            ' ' | '\t' | '\n' | '\r' => i += 1,
            '"' => {
                let start = i + 1;
                let len = chars[start..].iter().take_while(|&&ch| ch != '"').count();
                tokens.push(Token::Str(chars[start..start + len].iter().collect()));
                // Step past the closing quote when there is one; an
                // unterminated literal simply ends at the end of input.
                i = (start + len + 1).min(chars.len());
            }
            _ if c.is_ascii_digit() || c == '.' => {
                let len = chars[i..]
                    .iter()
                    .take_while(|ch| ch.is_ascii_digit() || **ch == '.')
                    .count();
                let text: String = chars[i..i + len].iter().collect();
                let value: f64 = text
                    .parse()
                    .map_err(|_| anyhow!("Invalid number '{text}' in expression"))?;
                tokens.push(Token::Number(value));
                i += len;
            }
            _ if c.is_alphabetic() || c == '_' => {
                let mut ident = String::new();
                while i < chars.len()
                    && (chars[i].is_alphanumeric() || chars[i] == '_' || chars[i] == ' ')
                {
                    if chars[i] == ' ' && space_ends_identifier(&ident, &chars[i + 1..]) {
                        break;
                    }
                    ident.push(chars[i]);
                    i += 1;
                }
                // A space may have been absorbed just before a character that
                // ends the loop (e.g. `=` or `>`); drop it from the name.
                tokens.push(Token::Ident(ident.trim_end().to_string()));
            }
            _ => return Err(anyhow!("Unexpected character '{c}' in expression")),
        }
    }
    Ok(tokens)
}

/// Decides whether a space met while scanning a multi-word identifier ends it.
///
/// `ident` is the name collected so far and `rest` is the input after the
/// space. The identifier ends when it is itself the `WHERE` keyword, or when
/// the next non-space input is missing, an operator, `)`, `,`, or the word
/// `WHERE`.
fn space_ends_identifier(ident: &str, rest: &[char]) -> bool {
    const KEYWORD: &str = "WHERE";

    if ident.trim_end().eq_ignore_ascii_case(KEYWORD) {
        return true;
    }
    let upcoming = match rest.iter().position(|&ch| ch != ' ') {
        Some(start) => &rest[start..],
        None => return true,
    };
    if matches!(upcoming[0], '+' | '-' | '*' | '/' | '^' | ')' | ',') {
        return true;
    }
    // Compare the whole next word so names like `Whereabouts` are unaffected.
    let word_len = upcoming
        .iter()
        .take_while(|ch| ch.is_alphanumeric() || **ch == '_')
        .count();
    word_len == KEYWORD.len()
        && upcoming[..word_len]
            .iter()
            .zip(KEYWORD.chars())
            .all(|(found, expected)| found.eq_ignore_ascii_case(&expected))
}
/// Parses a left-associative chain of `+` and `-` over multiplicative terms.
///
/// `pos` is advanced past everything consumed. Parsing stops, without error,
/// at the first token that is neither `+` nor `-`.
fn parse_add_sub(tokens: &[Token], pos: &mut usize) -> Result<Expr> {
    let mut acc = parse_mul_div(tokens, pos)?;
    loop {
        let op = match tokens.get(*pos) {
            Some(Token::Plus) => BinOp::Add,
            Some(Token::Minus) => BinOp::Sub,
            _ => return Ok(acc),
        };
        *pos += 1;
        let rhs = parse_mul_div(tokens, pos)?;
        // Fold to the left: `a - b - c` becomes `(a - b) - c`.
        acc = Expr::BinOp(op, Box::new(acc), Box::new(rhs));
    }
}
/// Parses a left-associative chain of `*` and `/` over power-level operands.
///
/// `pos` is advanced past everything consumed. Parsing stops, without error,
/// at the first token that is neither `*` nor `/`.
fn parse_mul_div(tokens: &[Token], pos: &mut usize) -> Result<Expr> {
    let mut product = parse_pow(tokens, pos)?;
    while let Some(token) = tokens.get(*pos) {
        let op = match token {
            Token::Star => BinOp::Mul,
            Token::Slash => BinOp::Div,
            _ => break,
        };
        *pos += 1;
        let factor = parse_pow(tokens, pos)?;
        product = Expr::BinOp(op, Box::new(product), Box::new(factor));
    }
    Ok(product)
}
/// Parses an exponentiation: `base` or `base ^ exponent`.
///
/// `^` is right-associative, so `a ^ b ^ c` parses as `a ^ (b ^ c)`. A single
/// `^` behaves as before; chained `^` used to leave the second `^` unread and
/// caused an "Unexpected token" error in the caller.
///
/// Both operands are parsed at the unary level, so `2 ^ -3` is accepted and
/// `-2 ^ 2` is `(-2) ^ 2`.
fn parse_pow(tokens: &[Token], pos: &mut usize) -> Result<Expr> {
    let base = parse_unary(tokens, pos)?;
    if matches!(tokens.get(*pos), Some(Token::Caret)) {
        *pos += 1;
        // Recursing (rather than looping) is what makes the operator group
        // to the right.
        let exp = parse_pow(tokens, pos)?;
        return Ok(Expr::BinOp(BinOp::Pow, Box::new(base), Box::new(exp)));
    }
    Ok(base)
}
/// Parses an optional single leading `-` followed by a primary expression.
///
/// Only one minus sign is accepted; the operand is parsed directly as a
/// primary, so `--x` is rejected by the primary parser.
fn parse_unary(tokens: &[Token], pos: &mut usize) -> Result<Expr> {
    match tokens.get(*pos) {
        Some(Token::Minus) => {
            *pos += 1;
            let operand = parse_primary(tokens, pos)?;
            Ok(Expr::UnaryMinus(Box::new(operand)))
        }
        _ => parse_primary(tokens, pos),
    }
}
fn parse_primary(tokens: &[Token], pos: &mut usize) -> Result<Expr> {
if *pos >= tokens.len() {
return Err(anyhow!("Unexpected end of expression"));
}
match &tokens[*pos].clone() {
Token::Number(n) => {
*pos += 1;
Ok(Expr::Number(*n))
}
Token::Ident(name) => {
let name = name.clone();
*pos += 1;
// Check for function call
let upper = name.to_ascii_uppercase();
match upper.as_str() {
"SUM" | "AVG" | "MIN" | "MAX" | "COUNT" => {
let func = match upper.as_str() {
"SUM" => AggFunc::Sum,
"AVG" => AggFunc::Avg,
"MIN" => AggFunc::Min,
"MAX" => AggFunc::Max,
"COUNT" => AggFunc::Count,
_ => unreachable!(),
};
if *pos < tokens.len() && tokens[*pos] == Token::LParen {
*pos += 1;
let inner = parse_add_sub(tokens, pos)?;
// Optional WHERE filter
let filter = if *pos < tokens.len() {
if let Token::Ident(kw) = &tokens[*pos] {
if kw.eq_ignore_ascii_case("WHERE") {
*pos += 1;
let cat = match &tokens[*pos] {
Token::Ident(s) => {
let s = s.clone();
*pos += 1;
s
}
t => {
return Err(anyhow!(
"Expected category name, got {t:?}"
))
}
};
// expect =
if *pos < tokens.len() && tokens[*pos] == Token::Eq {
*pos += 1;
}
let item = match &tokens[*pos] {
Token::Str(s) | Token::Ident(s) => {
let s = s.clone();
*pos += 1;
s
}
t => return Err(anyhow!("Expected item name, got {t:?}")),
};
Some(Filter {
category: cat,
item,
})
} else {
None
}
} else {
None
}
} else {
None
};
// expect )
if *pos < tokens.len() && tokens[*pos] == Token::RParen {
*pos += 1;
} else {
return Err(anyhow!("Expected ')' to close aggregate function"));
}
return Ok(Expr::Agg(func, Box::new(inner), filter));
}
Ok(Expr::Ref(name))
}
"IF" => {
if *pos < tokens.len() && tokens[*pos] == Token::LParen {
*pos += 1;
let cond = parse_comparison(tokens, pos)?;
if *pos < tokens.len() && tokens[*pos] == Token::Comma {
*pos += 1;
}
let then = parse_add_sub(tokens, pos)?;
if *pos < tokens.len() && tokens[*pos] == Token::Comma {
*pos += 1;
}
let else_ = parse_add_sub(tokens, pos)?;
if *pos < tokens.len() && tokens[*pos] == Token::RParen {
*pos += 1;
} else {
return Err(anyhow!("Expected ')' to close IF(...)"));
}
return Ok(Expr::If(Box::new(cond), Box::new(then), Box::new(else_)));
}
Ok(Expr::Ref(name))
}
_ => Ok(Expr::Ref(name)),
}
}
Token::LParen => {
*pos += 1;
let e = parse_add_sub(tokens, pos)?;
if *pos < tokens.len() && tokens[*pos] == Token::RParen {
*pos += 1;
}
Ok(e)
}
t => Err(anyhow!("Unexpected token in expression: {t:?}")),
}
}
/// Parses an additive expression optionally followed by one comparison
/// operator (`=`, `!=`, `<`, `>`, `<=`, `>=`) and a second additive expression.
///
/// Comparisons do not chain: at most one operator is consumed. When no
/// comparison operator follows, the left-hand expression is returned as is.
fn parse_comparison(tokens: &[Token], pos: &mut usize) -> Result<Expr> {
    let lhs = parse_add_sub(tokens, pos)?;
    let op = match tokens.get(*pos) {
        Some(Token::Eq) => BinOp::Eq,
        Some(Token::Ne) => BinOp::Ne,
        Some(Token::Lt) => BinOp::Lt,
        Some(Token::Gt) => BinOp::Gt,
        Some(Token::Le) => BinOp::Le,
        Some(Token::Ge) => BinOp::Ge,
        // End of input or a non-comparison token: nothing more to do here.
        _ => return Ok(lhs),
    };
    *pos += 1;
    let rhs = parse_add_sub(tokens, pos)?;
    Ok(Expr::BinOp(op, Box::new(lhs), Box::new(rhs)))
}
#[cfg(test)]
mod tests {
    use super::parse_formula;
    use crate::formula::{AggFunc, BinOp, Expr};

    /// A binary subtraction keeps the target name and category it was given.
    #[test]
    fn parse_simple_subtraction() {
        let formula = parse_formula("Profit = Revenue - Cost", "Measure").unwrap();
        assert_eq!(formula.target, "Profit");
        assert_eq!(formula.target_category, "Measure");
        assert!(matches!(formula.expr, Expr::BinOp(BinOp::Sub, _, _)));
    }

    /// A top-level WHERE clause becomes the formula's filter.
    #[test]
    fn parse_where_clause() {
        let source = "EastRev = Revenue WHERE Region = \"East\"";
        let formula = parse_formula(source, "Measure").unwrap();
        assert_eq!(formula.target, "EastRev");
        let filter = formula.filter.as_ref().unwrap();
        assert_eq!(filter.category, "Region");
        assert_eq!(filter.item, "East");
    }

    /// `SUM(...)` parses to a Sum aggregate.
    #[test]
    fn parse_sum_aggregation() {
        let formula = parse_formula("Total = SUM(Revenue)", "Measure").unwrap();
        assert!(matches!(formula.expr, Expr::Agg(AggFunc::Sum, _, _)));
    }

    /// `AVG(...)` parses to an Avg aggregate.
    #[test]
    fn parse_avg_aggregation() {
        let formula = parse_formula("Avg = AVG(Revenue)", "Measure").unwrap();
        assert!(matches!(formula.expr, Expr::Agg(AggFunc::Avg, _, _)));
    }

    /// `IF(cond, then, else)` parses to an If node.
    #[test]
    fn parse_if_expression() {
        let source = "Capped = IF(Revenue > 1000, 1000, Revenue)";
        let formula = parse_formula(source, "Measure").unwrap();
        assert!(matches!(formula.expr, Expr::If(_, _, _)));
    }

    /// A bare number parses to a Number node holding that value.
    #[test]
    fn parse_numeric_literal() {
        let formula = parse_formula("Fixed = 42", "Measure").unwrap();
        assert!(matches!(formula.expr, Expr::Number(n) if (n - 42.0).abs() < 1e-10));
    }

    /// Parenthesized groups combined with arithmetic parse successfully.
    #[test]
    fn parse_chained_arithmetic() {
        let outcome = parse_formula("X = (A + B) * (C - D)", "Cat");
        assert!(outcome.is_ok());
    }

    /// Input without an `=` is rejected.
    #[test]
    fn parse_missing_equals_returns_error() {
        let outcome = parse_formula("BadFormula Revenue Cost", "Cat");
        assert!(outcome.is_err());
    }
}