diff --git a/Cargo.lock b/Cargo.lock index a3d10bf..ad2f76e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -756,6 +756,8 @@ name = "improvise-formula" version = "0.1.0-rc2" dependencies = [ "anyhow", + "pest", + "pest_derive", "serde", ] diff --git a/crates/improvise-formula/Cargo.toml b/crates/improvise-formula/Cargo.toml index 338e0e6..7782815 100644 --- a/crates/improvise-formula/Cargo.toml +++ b/crates/improvise-formula/Cargo.toml @@ -8,4 +8,6 @@ repository = "https://github.com/fiddlerwoaroof/improvise" [dependencies] anyhow = "1" +pest = "2.8.6" +pest_derive = "2.8.6" serde = { version = "1", features = ["derive"] } diff --git a/crates/improvise-formula/src/formula.pest b/crates/improvise-formula/src/formula.pest new file mode 100644 index 0000000..e698d3a --- /dev/null +++ b/crates/improvise-formula/src/formula.pest @@ -0,0 +1,91 @@ +// Formula grammar for improvise. +// +// A formula has the form: TARGET = EXPR [WHERE filter] +// See parser.rs for the tree walker that produces a Formula AST. +// +// Identifier rules (bare_ident / pipe_quoted) mirror `bare_name` and +// `pipe_quoted` in src/persistence/improv.pest: bare identifiers are +// alphanumeric plus `_` and `-`, with no internal spaces; multi-word +// names must be pipe-quoted. + +// Auto-skip horizontal whitespace between tokens in non-atomic rules. +WHITESPACE = _{ " " | "\t" } + +// ---- top-level ---------------------------------------------------------- + +formula = { SOI ~ target ~ "=" ~ expr ~ where_clause? ~ EOI } + +// The target keeps its raw text (including pipes, if any) — we capture +// the span directly rather than walking into its children. +target = { identifier } + +where_clause = { ^"WHERE" ~ identifier ~ "=" ~ filter_value } + +// ---- expressions -------------------------------------------------------- + +// Used by parse_expr() — forces a standalone expression to consume the +// whole input, so `1 + 2 3` fails instead of silently dropping " 3". +expr_eoi = { SOI ~ expr ~ EOI } + +expr = { add_expr } + +add_expr = { mul_expr ~ (add_op ~ mul_expr)* } +add_op = { "+" | "-" } + +mul_expr = { pow_expr ~ (mul_op ~ pow_expr)* } +mul_op = { "*" | "/" } + +pow_expr = { unary ~ (pow_op ~ unary)? } +pow_op = { "^" } + +unary = { unary_minus | primary } +unary_minus = { "-" ~ primary } + +primary = { + number + | agg_call + | if_expr + | paren_expr + | ref_expr +} + +paren_expr = { "(" ~ expr ~ ")" } + +// Aggregates with optional inline WHERE filter inside the parens. +agg_call = { agg_func ~ "(" ~ expr ~ inline_where? ~ ")" } +agg_func = { ^"SUM" | ^"AVG" | ^"MIN" | ^"MAX" | ^"COUNT" } +inline_where = { ^"WHERE" ~ identifier ~ "=" ~ filter_value } + +// IF(cond, then, else). Comparison is a standalone rule because comparison +// operators are not valid in general expressions — only inside an IF condition. +if_expr = { ^"IF" ~ "(" ~ comparison ~ "," ~ expr ~ "," ~ expr ~ ")" } +comparison = { expr ~ cmp_op ~ expr } +cmp_op = { "!=" | "<=" | ">=" | "<" | ">" | "=" } + +// A reference to an item. `SUM` and `IF` without parens fall through to +// this rule because agg_call / if_expr require a "(" and otherwise fail. +ref_expr = { identifier } + +// ---- identifiers -------------------------------------------------------- +// +// Mirror of improv.pest's bare_name / pipe_quoted. + +identifier = ${ pipe_quoted | bare_ident } + +// Backslash escapes inside pipes: \| literal pipe, \\ backslash, \n newline. +pipe_quoted = @{ "|" ~ ("\\" ~ ANY | !"|" ~ ANY)* ~ "|" } + +bare_ident = @{ + (ASCII_ALPHA | "_") ~ (ASCII_ALPHANUMERIC | "_" | "-")* +} + +// ---- literal values ----------------------------------------------------- + +filter_value = { string | pipe_quoted | bare_ident } + +string = @{ "\"" ~ (!"\"" ~ ANY)* ~ "\"" } + +number = @{ + ASCII_DIGIT+ ~ ("." ~ ASCII_DIGIT*)? + | "." ~ ASCII_DIGIT+ +} diff --git a/crates/improvise-formula/src/parser.rs b/crates/improvise-formula/src/parser.rs index 01df646..0e7238c 100644 --- a/crates/improvise-formula/src/parser.rs +++ b/crates/improvise-formula/src/parser.rs @@ -1,462 +1,321 @@ use anyhow::{Result, anyhow}; +use pest::Parser as _; +use pest::iterators::Pair; +use pest_derive::Parser; use super::ast::{AggFunc, BinOp, Expr, Filter, Formula}; +#[derive(Parser)] +#[grammar = "formula.pest"] +struct FormulaParser; + /// Parse a formula string like "Profit = Revenue - Cost" /// or "Tax = Revenue * 0.08 WHERE Region = \"East\"" pub fn parse_formula(raw: &str, target_category: &str) -> Result { - let raw = raw.trim(); + let input = raw.trim(); + let formula_pair = FormulaParser::parse(Rule::formula, input) + .map_err(|e| anyhow!("{}", e))? + .next() + .ok_or_else(|| anyhow!("empty parse result"))?; + build_formula(formula_pair, input, target_category) +} - // Split on first `=` to get target = expression - let eq_pos = raw - .find('=') - .ok_or_else(|| anyhow!("Formula must contain '=': {raw}"))?; - let target = raw[..eq_pos].trim().to_string(); - let rest = raw[eq_pos + 1..].trim(); +/// Parse a bare expression (no target, no top-level WHERE clause). +/// Fails if the input contains trailing tokens after a complete expression. +pub fn parse_expr(s: &str) -> Result { + let input = s.trim(); + let expr_pair = FormulaParser::parse(Rule::expr_eoi, input) + .map_err(|e| anyhow!("{}", e))? + .next() + .ok_or_else(|| anyhow!("empty parse result"))? + .into_inner() + .next() + .ok_or_else(|| anyhow!("missing expression in expr_eoi"))?; + build_expr(expr_pair) +} - // Check for WHERE clause at top level - let (expr_str, filter) = split_where(rest); - let filter = filter.map(parse_where).transpose()?; - - let expr = parse_expr(expr_str.trim())?; +// ---- tree walkers ------------------------------------------------------- +fn build_formula(pair: Pair, raw: &str, target_category: &str) -> Result { + let mut target = None; + let mut expr = None; + let mut filter = None; + for inner in pair.into_inner() { + match inner.as_rule() { + Rule::target => target = Some(inner.as_str().trim().to_string()), + Rule::expr => expr = Some(build_expr(inner)?), + Rule::where_clause => filter = Some(build_filter(inner)?), + Rule::EOI => {} + r => return Err(anyhow!("unexpected rule in formula: {:?}", r)), + } + } + let target = target.ok_or_else(|| anyhow!("missing target in formula"))?; + let expr = expr.ok_or_else(|| anyhow!("missing expression in formula"))?; Ok(Formula::new(raw, target, target_category, expr, filter)) } -fn split_where(s: &str) -> (&str, Option<&str>) { - // Find WHERE not inside parens or quotes - let bytes = s.as_bytes(); - let mut depth = 0i32; - let mut i = 0; - while i < bytes.len() { - match bytes[i] { - b'(' => depth += 1, - b')' => depth -= 1, - b'"' => { - i += 1; - while i < bytes.len() && bytes[i] != b'"' { - i += 1; - } - } - b'|' => { - i += 1; - while i < bytes.len() && bytes[i] != b'|' { - i += 1; - } - } - _ if depth == 0 => { - if s[i..].to_ascii_uppercase().starts_with("WHERE") { - let before = &s[..i]; - let after = &s[i + 5..]; - if before.ends_with(char::is_whitespace) || i == 0 { - return (before.trim(), Some(after.trim())); - } - } - } - _ => {} - } - i += 1; - } - (s, None) +fn build_expr(pair: Pair) -> Result { + // expr = { add_expr } + build_add_expr(first_inner(pair, "expr")?) } -/// Strip pipe or double-quote delimiters from a value. -fn unquote(s: &str) -> String { - let s = s.trim(); - if (s.starts_with('"') && s.ends_with('"')) || (s.starts_with('|') && s.ends_with('|')) { - s[1..s.len() - 1].to_string() +fn build_add_expr(pair: Pair) -> Result { + fold_left_binop(pair, build_mul_expr, |s| match s { + "+" => Some(BinOp::Add), + "-" => Some(BinOp::Sub), + _ => None, + }) +} + +fn build_mul_expr(pair: Pair) -> Result { + fold_left_binop(pair, build_pow_expr, |s| match s { + "*" => Some(BinOp::Mul), + "/" => Some(BinOp::Div), + _ => None, + }) +} + +fn build_pow_expr(pair: Pair) -> Result { + // pow_expr = { unary ~ (pow_op ~ unary)? } + let mut pairs = pair.into_inner(); + let base_pair = pairs + .next() + .ok_or_else(|| anyhow!("empty pow_expr"))?; + let base = build_unary(base_pair)?; + match pairs.next() { + None => Ok(base), + Some(op_pair) => { + debug_assert_eq!(op_pair.as_rule(), Rule::pow_op); + let exp_pair = pairs + .next() + .ok_or_else(|| anyhow!("missing exponent in pow_expr"))?; + let exp = build_unary(exp_pair)?; + Ok(Expr::BinOp(BinOp::Pow, Box::new(base), Box::new(exp))) + } + } +} + +fn build_unary(pair: Pair) -> Result { + // unary = { unary_minus | primary } + let inner = first_inner(pair, "unary")?; + match inner.as_rule() { + Rule::unary_minus => { + let prim = first_inner(inner, "unary_minus")?; + Ok(Expr::UnaryMinus(Box::new(build_primary(prim)?))) + } + Rule::primary => build_primary(inner), + r => Err(anyhow!("unexpected rule in unary: {:?}", r)), + } +} + +fn build_primary(pair: Pair) -> Result { + // primary = { number | agg_call | if_expr | paren_expr | ref_expr } + let inner = first_inner(pair, "primary")?; + match inner.as_rule() { + Rule::number => Ok(Expr::Number(inner.as_str().parse()?)), + Rule::agg_call => build_agg_call(inner), + Rule::if_expr => build_if_expr(inner), + Rule::paren_expr => build_expr(first_inner(inner, "paren_expr")?), + Rule::ref_expr => { + let id_pair = first_inner(inner, "ref_expr")?; + Ok(Expr::Ref(identifier_to_string(id_pair))) + } + r => Err(anyhow!("unexpected rule in primary: {:?}", r)), + } +} + +fn build_agg_call(pair: Pair) -> Result { + // agg_call = { agg_func ~ "(" ~ expr ~ inline_where? ~ ")" } + let mut pairs = pair.into_inner(); + let func_pair = pairs + .next() + .ok_or_else(|| anyhow!("missing agg_func"))?; + let func = parse_agg_func(func_pair.as_str())?; + let expr_pair = pairs + .next() + .ok_or_else(|| anyhow!("missing aggregate argument"))?; + let inner_expr = build_expr(expr_pair)?; + let filter = match pairs.next() { + Some(p) if p.as_rule() == Rule::inline_where => Some(build_filter(p)?), + _ => None, + }; + Ok(Expr::Agg(func, Box::new(inner_expr), filter)) +} + +fn parse_agg_func(s: &str) -> Result { + match s.to_ascii_uppercase().as_str() { + "SUM" => Ok(AggFunc::Sum), + "AVG" => Ok(AggFunc::Avg), + "MIN" => Ok(AggFunc::Min), + "MAX" => Ok(AggFunc::Max), + "COUNT" => Ok(AggFunc::Count), + f => Err(anyhow!("unknown aggregate function: {}", f)), + } +} + +fn build_if_expr(pair: Pair) -> Result { + // if_expr = { ^"IF" ~ "(" ~ comparison ~ "," ~ expr ~ "," ~ expr ~ ")" } + let mut pairs = pair.into_inner(); + let cond_pair = pairs + .next() + .ok_or_else(|| anyhow!("missing IF condition"))?; + let cond = build_comparison(cond_pair)?; + let then_pair = pairs + .next() + .ok_or_else(|| anyhow!("missing IF then-branch"))?; + let then_e = build_expr(then_pair)?; + let else_pair = pairs + .next() + .ok_or_else(|| anyhow!("missing IF else-branch"))?; + let else_e = build_expr(else_pair)?; + Ok(Expr::If( + Box::new(cond), + Box::new(then_e), + Box::new(else_e), + )) +} + +fn build_comparison(pair: Pair) -> Result { + // comparison = { expr ~ cmp_op ~ expr } + let mut pairs = pair.into_inner(); + let lhs_pair = pairs + .next() + .ok_or_else(|| anyhow!("missing comparison lhs"))?; + let lhs = build_expr(lhs_pair)?; + let op_pair = pairs + .next() + .ok_or_else(|| anyhow!("missing comparison operator"))?; + let op = parse_cmp_op(op_pair.as_str())?; + let rhs_pair = pairs + .next() + .ok_or_else(|| anyhow!("missing comparison rhs"))?; + let rhs = build_expr(rhs_pair)?; + Ok(Expr::BinOp(op, Box::new(lhs), Box::new(rhs))) +} + +fn parse_cmp_op(s: &str) -> Result { + match s { + "=" => Ok(BinOp::Eq), + "!=" => Ok(BinOp::Ne), + "<" => Ok(BinOp::Lt), + ">" => Ok(BinOp::Gt), + "<=" => Ok(BinOp::Le), + ">=" => Ok(BinOp::Ge), + o => Err(anyhow!("unknown comparison operator: {}", o)), + } +} + +fn build_filter(pair: Pair) -> Result { + // where_clause / inline_where both have shape: + // ^"WHERE" ~ identifier ~ "=" ~ filter_value + let mut pairs = pair.into_inner(); + let cat_pair = pairs + .next() + .ok_or_else(|| anyhow!("missing WHERE category"))?; + let category = identifier_to_string(cat_pair); + let val_pair = pairs + .next() + .ok_or_else(|| anyhow!("missing WHERE value"))?; + let item = filter_value_to_string(val_pair); + Ok(Filter { category, item }) +} + +fn filter_value_to_string(pair: Pair) -> String { + // filter_value = { string | pipe_quoted | bare_ident } + let inner = pair + .into_inner() + .next() + .expect("filter_value must have an inner pair"); + let s = inner.as_str(); + match inner.as_rule() { + Rule::string => strip_string_quotes(s), + Rule::pipe_quoted => unquote_pipe(s), + _ => s.to_string(), + } +} + +/// Convert an identifier pair (identifier, pipe_quoted, or bare_ident) to +/// its content string. Pipe-quoted identifiers have their delimiters +/// stripped and backslash escapes applied; bare identifiers are returned +/// verbatim. +fn identifier_to_string(pair: Pair) -> String { + let s = pair.as_str(); + if is_pipe_quoted(s) { + unquote_pipe(s) } else { s.to_string() } } -fn parse_where(s: &str) -> Result { - // Format: Category = "Item" or Category = |Item| or Category = Item - let eq_pos = s - .find('=') - .ok_or_else(|| anyhow!("WHERE clause must contain '=': {s}"))?; - let category = unquote(&s[..eq_pos]); - let item = unquote(&s[eq_pos + 1..]); - Ok(Filter { category, item }) +fn is_pipe_quoted(s: &str) -> bool { + s.len() >= 2 && s.starts_with('|') && s.ends_with('|') } -/// Parse an expression using recursive descent -pub fn parse_expr(s: &str) -> Result { - let tokens = tokenize(s)?; - let mut pos = 0; - let expr = parse_add_sub(&tokens, &mut pos)?; - if pos < tokens.len() { - return Err(anyhow!( - "Unexpected token at position {pos}: {:?}", - tokens[pos] - )); - } - Ok(expr) +fn strip_string_quotes(s: &str) -> String { + debug_assert!(s.len() >= 2 && s.starts_with('"') && s.ends_with('"')); + s[1..s.len() - 1].to_string() } -#[derive(Debug, Clone, PartialEq)] -enum Token { - Number(f64), - Ident(String), - Str(String), - Plus, - Minus, - Star, - Slash, - Caret, - LParen, - RParen, - Comma, - Eq, - Ne, - Lt, - Gt, - Le, - Ge, -} - -fn tokenize(s: &str) -> Result> { - let mut tokens = Vec::new(); - let chars: Vec = s.chars().collect(); - let mut i = 0; - - while i < chars.len() { - match chars[i] { - ' ' | '\t' | '\n' => i += 1, - '+' => { - tokens.push(Token::Plus); - i += 1; - } - '-' => { - tokens.push(Token::Minus); - i += 1; - } - '*' => { - tokens.push(Token::Star); - i += 1; - } - '/' => { - tokens.push(Token::Slash); - i += 1; - } - '^' => { - tokens.push(Token::Caret); - i += 1; - } - '(' => { - tokens.push(Token::LParen); - i += 1; - } - ')' => { - tokens.push(Token::RParen); - i += 1; - } - ',' => { - tokens.push(Token::Comma); - i += 1; - } - '!' if chars.get(i + 1) == Some(&'=') => { - tokens.push(Token::Ne); - i += 2; - } - '<' if chars.get(i + 1) == Some(&'=') => { - tokens.push(Token::Le); - i += 2; - } - '>' if chars.get(i + 1) == Some(&'=') => { - tokens.push(Token::Ge); - i += 2; - } - '<' => { - tokens.push(Token::Lt); - i += 1; - } - '>' => { - tokens.push(Token::Gt); - i += 1; - } - '=' => { - tokens.push(Token::Eq); - i += 1; - } - '"' => { - i += 1; - let mut s = String::new(); - while i < chars.len() && chars[i] != '"' { - s.push(chars[i]); - i += 1; +/// Strip surrounding pipes and apply backslash escapes: `\|` → `|`, +/// `\\` → `\`, `\n` → newline. Matches the escape semantics documented +/// in src/persistence/improv.pest. +fn unquote_pipe(s: &str) -> String { + debug_assert!(is_pipe_quoted(s)); + let inner = &s[1..s.len() - 1]; + let mut out = String::with_capacity(inner.len()); + let mut chars = inner.chars(); + while let Some(c) = chars.next() { + if c == '\\' { + match chars.next() { + Some('|') => out.push('|'), + Some('\\') => out.push('\\'), + Some('n') => out.push('\n'), + Some(other) => { + out.push('\\'); + out.push(other); } - if i < chars.len() { - i += 1; - } - tokens.push(Token::Str(s)); + None => out.push('\\'), } - '|' => { - i += 1; - let mut s = String::new(); - while i < chars.len() && chars[i] != '|' { - s.push(chars[i]); - i += 1; - } - if i < chars.len() { - i += 1; - } - tokens.push(Token::Ident(s)); - } - c if c.is_ascii_digit() || c == '.' => { - let mut num = String::new(); - while i < chars.len() && (chars[i].is_ascii_digit() || chars[i] == '.') { - num.push(chars[i]); - i += 1; - } - tokens.push(Token::Number(num.parse()?)); - } - c if c.is_alphabetic() || c == '_' => { - let mut ident = String::new(); - while i < chars.len() - && (chars[i].is_alphanumeric() || chars[i] == '_' || chars[i] == ' ') - { - // Don't consume trailing spaces if next non-space is operator - if chars[i] == ' ' { - // Peek ahead past spaces to find the next word/token - let j = i + 1; - let next_nonspace = chars[j..].iter().find(|&&c| c != ' '); - if matches!( - next_nonspace, - Some('+') - | Some('-') - | Some('*') - | Some('/') - | Some('^') - | Some(')') - | Some(',') - | Some('<') - | Some('>') - | Some('=') - | Some('!') - | Some('"') - | None - ) { - break; - } - // Break if the identifier collected so far is a keyword - let trimmed = ident.trim_end().to_ascii_uppercase(); - if matches!( - trimmed.as_str(), - "WHERE" | "SUM" | "AVG" | "MIN" | "MAX" | "COUNT" | "IF" - ) { - break; - } - // Also break if the next word is a keyword - let rest: String = chars[j..].iter().collect(); - let next_word: String = rest - .trim_start() - .chars() - .take_while(|c| c.is_alphanumeric() || *c == '_') - .collect(); - let upper = next_word.to_ascii_uppercase(); - if matches!( - upper.as_str(), - "WHERE" | "SUM" | "AVG" | "MIN" | "MAX" | "COUNT" | "IF" - ) { - break; - } - } - ident.push(chars[i]); - i += 1; - } - let ident = ident.trim_end().to_string(); - tokens.push(Token::Ident(ident)); - } - c => return Err(anyhow!("Unexpected character '{c}' in expression")), + } else { + out.push(c); } } - Ok(tokens) + out } -fn parse_add_sub(tokens: &[Token], pos: &mut usize) -> Result { - let mut left = parse_mul_div(tokens, pos)?; - while *pos < tokens.len() { - let op = match &tokens[*pos] { - Token::Plus => BinOp::Add, - Token::Minus => BinOp::Sub, - _ => break, - }; - *pos += 1; - let right = parse_mul_div(tokens, pos)?; +// ---- small helpers ------------------------------------------------------ + +fn first_inner<'a>(pair: Pair<'a, Rule>, ctx: &str) -> Result> { + pair.into_inner() + .next() + .ok_or_else(|| anyhow!("empty rule: {}", ctx)) +} + +/// Fold a left-associative binary-operator rule of the shape +/// `rule = { child ~ (op ~ child)* }` into a left-leaning BinOp tree. +fn fold_left_binop(pair: Pair, mut build_child: F, match_op: M) -> Result +where + F: FnMut(Pair) -> Result, + M: Fn(&str) -> Option, +{ + let mut pairs = pair.into_inner(); + let first = pairs + .next() + .ok_or_else(|| anyhow!("empty binop rule"))?; + let mut left = build_child(first)?; + while let Some(op_pair) = pairs.next() { + let op = match_op(op_pair.as_str()).ok_or_else(|| { + anyhow!("unexpected operator token: {:?}", op_pair.as_str()) + })?; + let right_pair = pairs + .next() + .ok_or_else(|| anyhow!("missing rhs for operator"))?; + let right = build_child(right_pair)?; left = Expr::BinOp(op, Box::new(left), Box::new(right)); } Ok(left) } -fn parse_mul_div(tokens: &[Token], pos: &mut usize) -> Result { - let mut left = parse_pow(tokens, pos)?; - while *pos < tokens.len() { - let op = match &tokens[*pos] { - Token::Star => BinOp::Mul, - Token::Slash => BinOp::Div, - _ => break, - }; - *pos += 1; - let right = parse_pow(tokens, pos)?; - left = Expr::BinOp(op, Box::new(left), Box::new(right)); - } - Ok(left) -} - -fn parse_pow(tokens: &[Token], pos: &mut usize) -> Result { - let base = parse_unary(tokens, pos)?; - if *pos < tokens.len() && tokens[*pos] == Token::Caret { - *pos += 1; - let exp = parse_unary(tokens, pos)?; - return Ok(Expr::BinOp(BinOp::Pow, Box::new(base), Box::new(exp))); - } - Ok(base) -} - -fn parse_unary(tokens: &[Token], pos: &mut usize) -> Result { - if *pos < tokens.len() && tokens[*pos] == Token::Minus { - *pos += 1; - let e = parse_primary(tokens, pos)?; - return Ok(Expr::UnaryMinus(Box::new(e))); - } - parse_primary(tokens, pos) -} - -fn parse_primary(tokens: &[Token], pos: &mut usize) -> Result { - if *pos >= tokens.len() { - return Err(anyhow!("Unexpected end of expression")); - } - match &tokens[*pos].clone() { - Token::Number(n) => { - *pos += 1; - Ok(Expr::Number(*n)) - } - Token::Ident(name) => { - let name = name.clone(); - *pos += 1; - // Check for function call - let upper = name.to_ascii_uppercase(); - match upper.as_str() { - "SUM" | "AVG" | "MIN" | "MAX" | "COUNT" => { - let func = match upper.as_str() { - "SUM" => AggFunc::Sum, - "AVG" => AggFunc::Avg, - "MIN" => AggFunc::Min, - "MAX" => AggFunc::Max, - "COUNT" => AggFunc::Count, - _ => unreachable!(), - }; - if *pos < tokens.len() && tokens[*pos] == Token::LParen { - *pos += 1; - let inner = parse_add_sub(tokens, pos)?; - // Optional WHERE filter - let filter = if *pos < tokens.len() { - if let Token::Ident(kw) = &tokens[*pos] { - if kw.eq_ignore_ascii_case("WHERE") { - *pos += 1; - let cat = match &tokens[*pos] { - Token::Ident(s) => { - let s = s.clone(); - *pos += 1; - s - } - t => { - return Err(anyhow!( - "Expected category name, got {t:?}" - )); - } - }; - // expect = - if *pos < tokens.len() && tokens[*pos] == Token::Eq { - *pos += 1; - } - let item = match &tokens[*pos] { - Token::Str(s) | Token::Ident(s) => { - let s = s.clone(); - *pos += 1; - s - } - t => return Err(anyhow!("Expected item name, got {t:?}")), - }; - Some(Filter { - category: cat, - item, - }) - } else { - None - } - } else { - None - } - } else { - None - }; - // expect ) - if *pos < tokens.len() && tokens[*pos] == Token::RParen { - *pos += 1; - } else { - return Err(anyhow!("Expected ')' to close aggregate function")); - } - return Ok(Expr::Agg(func, Box::new(inner), filter)); - } - Ok(Expr::Ref(name)) - } - "IF" => { - if *pos < tokens.len() && tokens[*pos] == Token::LParen { - *pos += 1; - let cond = parse_comparison(tokens, pos)?; - if *pos < tokens.len() && tokens[*pos] == Token::Comma { - *pos += 1; - } - let then = parse_add_sub(tokens, pos)?; - if *pos < tokens.len() && tokens[*pos] == Token::Comma { - *pos += 1; - } - let else_ = parse_add_sub(tokens, pos)?; - if *pos < tokens.len() && tokens[*pos] == Token::RParen { - *pos += 1; - } else { - return Err(anyhow!("Expected ')' to close IF(...)")); - } - return Ok(Expr::If(Box::new(cond), Box::new(then), Box::new(else_))); - } - Ok(Expr::Ref(name)) - } - _ => Ok(Expr::Ref(name)), - } - } - Token::LParen => { - *pos += 1; - let e = parse_add_sub(tokens, pos)?; - if *pos < tokens.len() && tokens[*pos] == Token::RParen { - *pos += 1; - } - Ok(e) - } - t => Err(anyhow!("Unexpected token in expression: {t:?}")), - } -} - -fn parse_comparison(tokens: &[Token], pos: &mut usize) -> Result { - let left = parse_add_sub(tokens, pos)?; - if *pos >= tokens.len() { - return Ok(left); - } - let op = match &tokens[*pos] { - Token::Eq => BinOp::Eq, - Token::Ne => BinOp::Ne, - Token::Lt => BinOp::Lt, - Token::Gt => BinOp::Gt, - Token::Le => BinOp::Le, - Token::Ge => BinOp::Ge, - _ => return Ok(left), - }; - *pos += 1; - let right = parse_add_sub(tokens, pos)?; - Ok(Expr::BinOp(op, Box::new(left), Box::new(right))) -} - #[cfg(test)] mod tests { use super::parse_formula; @@ -544,8 +403,8 @@ mod tests { assert_eq!(filter.item, "East"); } - /// Regression: WHERE inside aggregate parens must tokenize correctly. - /// The tokenizer must not merge "Revenue WHERE" into a single identifier. + /// Regression: WHERE inside aggregate parens must parse as the + /// aggregate's inline filter, not as a top-level WHERE clause. #[test] fn parse_sum_with_inline_where_filter() { let f = parse_formula("EastTotal = SUM(Revenue WHERE Region = \"East\")", "Foo").unwrap(); @@ -562,7 +421,6 @@ mod tests { #[test] fn parse_if_with_comparison_operators() { - // Test each comparison operator in an IF expression let f = parse_formula("X = IF(A != 0, A, 1)", "Cat").unwrap(); assert!(matches!(f.expr, Expr::If(_, _, _))); @@ -583,7 +441,6 @@ mod tests { #[test] fn parse_where_with_quoted_string_inside_expression() { - // WHERE inside a formula string with quotes let f = parse_formula("X = Revenue WHERE Region = \"West Coast\"", "Foo").unwrap(); let filter = f.filter.as_ref().unwrap(); assert_eq!(filter.item, "West Coast"); @@ -650,11 +507,10 @@ mod tests { } } - // ── Quoted string in tokenizer ────────────────────────────────────── + // ── Quoted string in WHERE ────────────────────────────────────────── #[test] fn parse_quoted_string_in_where() { - // Quoted strings work in top-level WHERE clauses let f = parse_formula("X = Revenue WHERE Region = \"East\"", "Cat").unwrap(); let filter = f.filter.as_ref().unwrap(); assert_eq!(filter.item, "East"); @@ -681,27 +537,21 @@ mod tests { assert!(parse_expr("").is_err()); } + // ── Multi-word identifiers must be pipe-quoted ────────────────────── + #[test] - fn tokenizer_breaks_at_where_keyword() { - use super::tokenize; - let tokens = tokenize("Revenue WHERE Region").unwrap(); - // Should produce 3 tokens: Ident("Revenue"), Ident("WHERE"), Ident("Region") - assert_eq!(tokens.len(), 3, "Expected 3 tokens, got: {tokens:?}"); + fn multi_word_bare_identifier_is_rejected() { + // Multi-word identifiers must be pipe-quoted; bare multi-word fails + // the `bare_name`-compatible grammar rule. + assert!(parse_formula("Total Revenue = Base Revenue + Bonus", "Foo").is_err()); } - // ── Multi-word identifiers ────────────────────────────────────────── + // ── WHERE inside quotes in the expression ─────────────────────────── #[test] - fn parse_multi_word_identifier() { - let f = parse_formula("Total Revenue = Base Revenue + Bonus", "Foo").unwrap(); - assert_eq!(f.target, "Total Revenue"); - } - - // ── WHERE inside quotes in split_where ────────────────────────────── - - #[test] - fn split_where_ignores_where_inside_quotes() { - // WHERE inside quotes should not be treated as a keyword + fn where_inside_quotes_is_not_a_keyword() { + // A filter value containing the literal text "WHERE" is parsed as + // a string, not as a nested WHERE keyword. let f = parse_formula("X = Revenue WHERE Region = \"WHERE\"", "Foo").unwrap(); let filter = f.filter.as_ref().unwrap(); assert_eq!(filter.item, "WHERE"); @@ -773,4 +623,17 @@ mod tests { panic!("Expected SUM with WHERE filter, got: {:?}", f.expr); } } + + // ── Pipe-quoted escape semantics ──────────────────────────────────── + + #[test] + fn pipe_quoted_escape_literal_pipe() { + // \| inside a pipe-quoted identifier is a literal pipe + let f = parse_formula("X = |A\\|B|", "Cat").unwrap(); + if let Expr::Ref(ref s) = f.expr { + assert_eq!(s, "A|B"); + } else { + panic!("Expected Ref, got: {:?}", f.expr); + } + } }