462 lines
15 KiB
Rust
462 lines
15 KiB
Rust
use anyhow::{anyhow, Result};
|
|
|
|
use super::ast::{AggFunc, BinOp, Expr, Filter, Formula};
|
|
|
|
/// Parse a formula string like "Profit = Revenue - Cost"
|
|
/// or "Tax = Revenue * 0.08 WHERE Region = \"East\""
|
|
pub fn parse_formula(raw: &str, target_category: &str) -> Result<Formula> {
|
|
let raw = raw.trim();
|
|
|
|
// Split on first `=` to get target = expression
|
|
let eq_pos = raw
|
|
.find('=')
|
|
.ok_or_else(|| anyhow!("Formula must contain '=': {raw}"))?;
|
|
let target = raw[..eq_pos].trim().to_string();
|
|
let rest = raw[eq_pos + 1..].trim();
|
|
|
|
// Check for WHERE clause at top level
|
|
let (expr_str, filter) = split_where(rest);
|
|
let filter = filter.map(parse_where).transpose()?;
|
|
|
|
let expr = parse_expr(expr_str.trim())?;
|
|
|
|
Ok(Formula::new(raw, target, target_category, expr, filter))
|
|
}
|
|
|
|
/// Split `s` at its first top-level `WHERE` keyword.
///
/// The keyword is matched ASCII case-insensitively and only as a whole word. It must
/// be at the start of `s` or preceded by whitespace. It must be followed by whitespace
/// or by the end of `s`. Occurrences inside parentheses (such as `SUM(x WHERE ...)`) or
/// inside double-quoted strings are ignored.
///
/// Returns the trimmed text before the keyword and the trimmed text after it. If there
/// is no top-level `WHERE`, returns `(s, None)` with `s` unchanged.
fn split_where(s: &str) -> (&str, Option<&str>) {
    const KEYWORD: &[u8] = b"WHERE";

    let bytes = s.as_bytes();
    let mut depth = 0i32;
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'(' => depth += 1,
            b')' => depth -= 1,
            b'"' => {
                // Skip to the closing quote so a quoted "WHERE" is not the keyword.
                i += 1;
                while i < bytes.len() && bytes[i] != b'"' {
                    i += 1;
                }
            }
            // Compare raw bytes. Slicing `s` at an arbitrary byte offset could land
            // inside a multi-byte UTF-8 character and panic. A match means
            // bytes[i..i + 5] are ASCII, so `i` and `i + 5` are both char boundaries.
            _ if depth == 0
                && bytes.len() - i >= KEYWORD.len()
                && bytes[i..i + KEYWORD.len()].eq_ignore_ascii_case(KEYWORD) =>
            {
                let before = &s[..i];
                let after = &s[i + KEYWORD.len()..];
                let starts_word = i == 0 || before.ends_with(char::is_whitespace);
                // Reject names that merely begin with "where", e.g. "Wherehouse".
                let ends_word = after.is_empty() || after.starts_with(char::is_whitespace);
                if starts_word && ends_word {
                    return (before.trim(), Some(after.trim()));
                }
            }
            _ => {}
        }
        i += 1;
    }
    (s, None)
}
|
|
|
|
fn parse_where(s: &str) -> Result<Filter> {
|
|
// Format: Category = "Item" or Category = Item
|
|
let eq_pos = s
|
|
.find('=')
|
|
.ok_or_else(|| anyhow!("WHERE clause must contain '=': {s}"))?;
|
|
let category = s[..eq_pos].trim().to_string();
|
|
let item_raw = s[eq_pos + 1..].trim();
|
|
let item = item_raw.trim_matches('"').to_string();
|
|
Ok(Filter { category, item })
|
|
}
|
|
|
|
/// Parse an expression using recursive descent
|
|
pub fn parse_expr(s: &str) -> Result<Expr> {
|
|
let tokens = tokenize(s)?;
|
|
let mut pos = 0;
|
|
let expr = parse_add_sub(&tokens, &mut pos)?;
|
|
if pos < tokens.len() {
|
|
return Err(anyhow!(
|
|
"Unexpected token at position {pos}: {:?}",
|
|
tokens[pos]
|
|
));
|
|
}
|
|
Ok(expr)
|
|
}
|
|
|
|
/// A lexical token produced by `tokenize` and consumed by the parser functions.
#[derive(Clone, Debug, PartialEq)]
enum Token {
    /// Numeric literal, e.g. `42` or `0.08`.
    Number(f64),
    /// A name: a (possibly multi-word) reference, a function name, or the `WHERE` keyword.
    Ident(String),
    /// Contents of a double-quoted string literal, without the quotes.
    Str(String),

    /// `+`
    Plus,
    /// `-` (binary subtraction or unary negation)
    Minus,
    /// `*`
    Star,
    /// `/`
    Slash,
    /// `^`
    Caret,

    /// `(`
    LParen,
    /// `)`
    RParen,
    /// `,`
    Comma,

    /// `=`
    Eq,
    /// `!=`
    Ne,
    /// `<`
    Lt,
    /// `>`
    Gt,
    /// `<=`
    Le,
    /// `>=`
    Ge,
}
|
|
|
|
fn tokenize(s: &str) -> Result<Vec<Token>> {
|
|
let mut tokens = Vec::new();
|
|
let chars: Vec<char> = s.chars().collect();
|
|
let mut i = 0;
|
|
|
|
while i < chars.len() {
|
|
match chars[i] {
|
|
' ' | '\t' | '\n' => i += 1,
|
|
'+' => {
|
|
tokens.push(Token::Plus);
|
|
i += 1;
|
|
}
|
|
'-' => {
|
|
tokens.push(Token::Minus);
|
|
i += 1;
|
|
}
|
|
'*' => {
|
|
tokens.push(Token::Star);
|
|
i += 1;
|
|
}
|
|
'/' => {
|
|
tokens.push(Token::Slash);
|
|
i += 1;
|
|
}
|
|
'^' => {
|
|
tokens.push(Token::Caret);
|
|
i += 1;
|
|
}
|
|
'(' => {
|
|
tokens.push(Token::LParen);
|
|
i += 1;
|
|
}
|
|
')' => {
|
|
tokens.push(Token::RParen);
|
|
i += 1;
|
|
}
|
|
',' => {
|
|
tokens.push(Token::Comma);
|
|
i += 1;
|
|
}
|
|
'!' if chars.get(i + 1) == Some(&'=') => {
|
|
tokens.push(Token::Ne);
|
|
i += 2;
|
|
}
|
|
'<' if chars.get(i + 1) == Some(&'=') => {
|
|
tokens.push(Token::Le);
|
|
i += 2;
|
|
}
|
|
'>' if chars.get(i + 1) == Some(&'=') => {
|
|
tokens.push(Token::Ge);
|
|
i += 2;
|
|
}
|
|
'<' => {
|
|
tokens.push(Token::Lt);
|
|
i += 1;
|
|
}
|
|
'>' => {
|
|
tokens.push(Token::Gt);
|
|
i += 1;
|
|
}
|
|
'=' => {
|
|
tokens.push(Token::Eq);
|
|
i += 1;
|
|
}
|
|
'"' => {
|
|
i += 1;
|
|
let mut s = String::new();
|
|
while i < chars.len() && chars[i] != '"' {
|
|
s.push(chars[i]);
|
|
i += 1;
|
|
}
|
|
if i < chars.len() {
|
|
i += 1;
|
|
}
|
|
tokens.push(Token::Str(s));
|
|
}
|
|
c if c.is_ascii_digit() || c == '.' => {
|
|
let mut num = String::new();
|
|
while i < chars.len() && (chars[i].is_ascii_digit() || chars[i] == '.') {
|
|
num.push(chars[i]);
|
|
i += 1;
|
|
}
|
|
tokens.push(Token::Number(num.parse()?));
|
|
}
|
|
c if c.is_alphabetic() || c == '_' => {
|
|
let mut ident = String::new();
|
|
while i < chars.len()
|
|
&& (chars[i].is_alphanumeric() || chars[i] == '_' || chars[i] == ' ')
|
|
{
|
|
// Don't consume trailing spaces if next non-space is operator
|
|
if chars[i] == ' ' {
|
|
// Peek ahead
|
|
let j = i + 1;
|
|
let next_nonspace = chars[j..].iter().find(|&&c| c != ' ');
|
|
if matches!(
|
|
next_nonspace,
|
|
Some('+')
|
|
| Some('-')
|
|
| Some('*')
|
|
| Some('/')
|
|
| Some('^')
|
|
| Some(')')
|
|
| Some(',')
|
|
| None
|
|
) {
|
|
break;
|
|
}
|
|
}
|
|
ident.push(chars[i]);
|
|
i += 1;
|
|
}
|
|
let ident = ident.trim_end().to_string();
|
|
tokens.push(Token::Ident(ident));
|
|
}
|
|
c => return Err(anyhow!("Unexpected character '{c}' in expression")),
|
|
}
|
|
}
|
|
Ok(tokens)
|
|
}
|
|
|
|
fn parse_add_sub(tokens: &[Token], pos: &mut usize) -> Result<Expr> {
|
|
let mut left = parse_mul_div(tokens, pos)?;
|
|
while *pos < tokens.len() {
|
|
let op = match &tokens[*pos] {
|
|
Token::Plus => BinOp::Add,
|
|
Token::Minus => BinOp::Sub,
|
|
_ => break,
|
|
};
|
|
*pos += 1;
|
|
let right = parse_mul_div(tokens, pos)?;
|
|
left = Expr::BinOp(op, Box::new(left), Box::new(right));
|
|
}
|
|
Ok(left)
|
|
}
|
|
|
|
fn parse_mul_div(tokens: &[Token], pos: &mut usize) -> Result<Expr> {
|
|
let mut left = parse_pow(tokens, pos)?;
|
|
while *pos < tokens.len() {
|
|
let op = match &tokens[*pos] {
|
|
Token::Star => BinOp::Mul,
|
|
Token::Slash => BinOp::Div,
|
|
_ => break,
|
|
};
|
|
*pos += 1;
|
|
let right = parse_pow(tokens, pos)?;
|
|
left = Expr::BinOp(op, Box::new(left), Box::new(right));
|
|
}
|
|
Ok(left)
|
|
}
|
|
|
|
fn parse_pow(tokens: &[Token], pos: &mut usize) -> Result<Expr> {
|
|
let base = parse_unary(tokens, pos)?;
|
|
if *pos < tokens.len() && tokens[*pos] == Token::Caret {
|
|
*pos += 1;
|
|
let exp = parse_unary(tokens, pos)?;
|
|
return Ok(Expr::BinOp(BinOp::Pow, Box::new(base), Box::new(exp)));
|
|
}
|
|
Ok(base)
|
|
}
|
|
|
|
fn parse_unary(tokens: &[Token], pos: &mut usize) -> Result<Expr> {
|
|
if *pos < tokens.len() && tokens[*pos] == Token::Minus {
|
|
*pos += 1;
|
|
let e = parse_primary(tokens, pos)?;
|
|
return Ok(Expr::UnaryMinus(Box::new(e)));
|
|
}
|
|
parse_primary(tokens, pos)
|
|
}
|
|
|
|
fn parse_primary(tokens: &[Token], pos: &mut usize) -> Result<Expr> {
|
|
if *pos >= tokens.len() {
|
|
return Err(anyhow!("Unexpected end of expression"));
|
|
}
|
|
match &tokens[*pos].clone() {
|
|
Token::Number(n) => {
|
|
*pos += 1;
|
|
Ok(Expr::Number(*n))
|
|
}
|
|
Token::Ident(name) => {
|
|
let name = name.clone();
|
|
*pos += 1;
|
|
// Check for function call
|
|
let upper = name.to_ascii_uppercase();
|
|
match upper.as_str() {
|
|
"SUM" | "AVG" | "MIN" | "MAX" | "COUNT" => {
|
|
let func = match upper.as_str() {
|
|
"SUM" => AggFunc::Sum,
|
|
"AVG" => AggFunc::Avg,
|
|
"MIN" => AggFunc::Min,
|
|
"MAX" => AggFunc::Max,
|
|
"COUNT" => AggFunc::Count,
|
|
_ => unreachable!(),
|
|
};
|
|
if *pos < tokens.len() && tokens[*pos] == Token::LParen {
|
|
*pos += 1;
|
|
let inner = parse_add_sub(tokens, pos)?;
|
|
// Optional WHERE filter
|
|
let filter = if *pos < tokens.len() {
|
|
if let Token::Ident(kw) = &tokens[*pos] {
|
|
if kw.eq_ignore_ascii_case("WHERE") {
|
|
*pos += 1;
|
|
let cat = match &tokens[*pos] {
|
|
Token::Ident(s) => {
|
|
let s = s.clone();
|
|
*pos += 1;
|
|
s
|
|
}
|
|
t => {
|
|
return Err(anyhow!(
|
|
"Expected category name, got {t:?}"
|
|
))
|
|
}
|
|
};
|
|
// expect =
|
|
if *pos < tokens.len() && tokens[*pos] == Token::Eq {
|
|
*pos += 1;
|
|
}
|
|
let item = match &tokens[*pos] {
|
|
Token::Str(s) | Token::Ident(s) => {
|
|
let s = s.clone();
|
|
*pos += 1;
|
|
s
|
|
}
|
|
t => return Err(anyhow!("Expected item name, got {t:?}")),
|
|
};
|
|
Some(Filter {
|
|
category: cat,
|
|
item,
|
|
})
|
|
} else {
|
|
None
|
|
}
|
|
} else {
|
|
None
|
|
}
|
|
} else {
|
|
None
|
|
};
|
|
// expect )
|
|
if *pos < tokens.len() && tokens[*pos] == Token::RParen {
|
|
*pos += 1;
|
|
} else {
|
|
return Err(anyhow!("Expected ')' to close aggregate function"));
|
|
}
|
|
return Ok(Expr::Agg(func, Box::new(inner), filter));
|
|
}
|
|
Ok(Expr::Ref(name))
|
|
}
|
|
"IF" => {
|
|
if *pos < tokens.len() && tokens[*pos] == Token::LParen {
|
|
*pos += 1;
|
|
let cond = parse_comparison(tokens, pos)?;
|
|
if *pos < tokens.len() && tokens[*pos] == Token::Comma {
|
|
*pos += 1;
|
|
}
|
|
let then = parse_add_sub(tokens, pos)?;
|
|
if *pos < tokens.len() && tokens[*pos] == Token::Comma {
|
|
*pos += 1;
|
|
}
|
|
let else_ = parse_add_sub(tokens, pos)?;
|
|
if *pos < tokens.len() && tokens[*pos] == Token::RParen {
|
|
*pos += 1;
|
|
} else {
|
|
return Err(anyhow!("Expected ')' to close IF(...)"));
|
|
}
|
|
return Ok(Expr::If(Box::new(cond), Box::new(then), Box::new(else_)));
|
|
}
|
|
Ok(Expr::Ref(name))
|
|
}
|
|
_ => Ok(Expr::Ref(name)),
|
|
}
|
|
}
|
|
Token::LParen => {
|
|
*pos += 1;
|
|
let e = parse_add_sub(tokens, pos)?;
|
|
if *pos < tokens.len() && tokens[*pos] == Token::RParen {
|
|
*pos += 1;
|
|
}
|
|
Ok(e)
|
|
}
|
|
t => Err(anyhow!("Unexpected token in expression: {t:?}")),
|
|
}
|
|
}
|
|
|
|
fn parse_comparison(tokens: &[Token], pos: &mut usize) -> Result<Expr> {
|
|
let left = parse_add_sub(tokens, pos)?;
|
|
if *pos >= tokens.len() {
|
|
return Ok(left);
|
|
}
|
|
let op = match &tokens[*pos] {
|
|
Token::Eq => BinOp::Eq,
|
|
Token::Ne => BinOp::Ne,
|
|
Token::Lt => BinOp::Lt,
|
|
Token::Gt => BinOp::Gt,
|
|
Token::Le => BinOp::Le,
|
|
Token::Ge => BinOp::Ge,
|
|
_ => return Ok(left),
|
|
};
|
|
*pos += 1;
|
|
let right = parse_add_sub(tokens, pos)?;
|
|
Ok(Expr::BinOp(op, Box::new(left), Box::new(right)))
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::parse_formula;
    use crate::formula::{AggFunc, BinOp, Expr};

    /// Binary subtraction; target and category are carried through.
    #[test]
    fn parse_simple_subtraction() {
        let formula = parse_formula("Profit = Revenue - Cost", "Measure").unwrap();
        assert_eq!(formula.target, "Profit");
        assert_eq!(formula.target_category, "Measure");
        let is_sub = matches!(formula.expr, Expr::BinOp(BinOp::Sub, _, _));
        assert!(is_sub);
    }

    /// A top-level WHERE becomes the formula's filter, with quotes stripped.
    #[test]
    fn parse_where_clause() {
        let formula =
            parse_formula("EastRev = Revenue WHERE Region = \"East\"", "Measure").unwrap();
        assert_eq!(formula.target, "EastRev");
        let filter = formula.filter.as_ref().unwrap();
        assert_eq!(filter.category, "Region");
        assert_eq!(filter.item, "East");
    }

    /// SUM(...) parses to an aggregate node.
    #[test]
    fn parse_sum_aggregation() {
        let formula = parse_formula("Total = SUM(Revenue)", "Measure").unwrap();
        let is_sum = matches!(formula.expr, Expr::Agg(AggFunc::Sum, _, _));
        assert!(is_sum);
    }

    /// AVG(...) parses to an aggregate node.
    #[test]
    fn parse_avg_aggregation() {
        let formula = parse_formula("Avg = AVG(Revenue)", "Measure").unwrap();
        let is_avg = matches!(formula.expr, Expr::Agg(AggFunc::Avg, _, _));
        assert!(is_avg);
    }

    /// IF(cond, then, else) with a comparison condition.
    #[test]
    fn parse_if_expression() {
        let formula =
            parse_formula("Capped = IF(Revenue > 1000, 1000, Revenue)", "Measure").unwrap();
        let is_if = matches!(formula.expr, Expr::If(_, _, _));
        assert!(is_if);
    }

    /// A bare number is a numeric literal.
    #[test]
    fn parse_numeric_literal() {
        let formula = parse_formula("Fixed = 42", "Measure").unwrap();
        let is_42 = matches!(formula.expr, Expr::Number(n) if (n - 42.0).abs() < 1e-10);
        assert!(is_42);
    }

    /// Parenthesized groups combine with multiplication.
    #[test]
    fn parse_chained_arithmetic() {
        let parsed = parse_formula("X = (A + B) * (C - D)", "Cat");
        parsed.unwrap();
    }

    /// A formula without '=' is rejected.
    #[test]
    fn parse_missing_equals_returns_error() {
        let parsed = parse_formula("BadFormula Revenue Cost", "Cat");
        assert!(parsed.is_err());
    }
}
|