use indexmap::IndexMap; use serde::{Deserialize, Serialize}; use anyhow::{anyhow, Result}; use super::category::{Category, CategoryId}; use super::cell::{CellKey, CellValue, DataStore}; use crate::formula::Formula; use crate::view::View; const MAX_CATEGORIES: usize = 12; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Model { pub name: String, pub categories: IndexMap, pub data: DataStore, pub formulas: Vec, pub views: IndexMap, pub active_view: String, next_category_id: CategoryId, } impl Model { pub fn new(name: impl Into) -> Self { let name = name.into(); let default_view = View::new("Default"); let mut views = IndexMap::new(); views.insert("Default".to_string(), default_view); Self { name, categories: IndexMap::new(), data: DataStore::new(), formulas: Vec::new(), views, active_view: "Default".to_string(), next_category_id: 0, } } pub fn add_category(&mut self, name: impl Into) -> Result { let name = name.into(); if self.categories.len() >= MAX_CATEGORIES { return Err(anyhow!("Maximum of {MAX_CATEGORIES} categories reached")); } if self.categories.contains_key(&name) { return Ok(self.categories[&name].id); } let id = self.next_category_id; self.next_category_id += 1; self.categories.insert(name.clone(), Category::new(id, name.clone())); // Add to all views for view in self.views.values_mut() { view.on_category_added(&name); } Ok(id) } pub fn category_mut(&mut self, name: &str) -> Option<&mut Category> { self.categories.get_mut(name) } pub fn category(&self, name: &str) -> Option<&Category> { self.categories.get(name) } pub fn set_cell(&mut self, key: CellKey, value: CellValue) { self.data.set(key, value); } pub fn get_cell(&self, key: &CellKey) -> &CellValue { self.data.get(key) } pub fn add_formula(&mut self, formula: Formula) { // Replace if same target if let Some(pos) = self.formulas.iter().position(|f| f.target == formula.target) { self.formulas[pos] = formula; } else { self.formulas.push(formula); } } pub fn remove_formula(&mut self, target: &str) { self.formulas.retain(|f| f.target != target); } pub fn active_view(&self) -> Option<&View> { self.views.get(&self.active_view) } pub fn active_view_mut(&mut self) -> Option<&mut View> { self.views.get_mut(&self.active_view) } pub fn create_view(&mut self, name: impl Into) -> &mut View { let name = name.into(); let mut view = View::new(name.clone()); // Copy category assignments from default if any for cat_name in self.categories.keys() { view.on_category_added(cat_name); } self.views.insert(name.clone(), view); self.views.get_mut(&name).unwrap() } pub fn switch_view(&mut self, name: &str) -> Result<()> { if self.views.contains_key(name) { self.active_view = name.to_string(); Ok(()) } else { Err(anyhow!("View '{name}' not found")) } } pub fn delete_view(&mut self, name: &str) -> Result<()> { if self.views.len() <= 1 { return Err(anyhow!("Cannot delete the last view")); } self.views.shift_remove(name); if self.active_view == name { self.active_view = self.views.keys().next().unwrap().clone(); } Ok(()) } /// Return all category names pub fn category_names(&self) -> Vec<&str> { self.categories.keys().map(|s| s.as_str()).collect() } /// Evaluate a computed value at a given key, considering formulas pub fn evaluate(&self, key: &CellKey) -> CellValue { // Check if the last category dimension in the key corresponds to a formula target for formula in &self.formulas { if let Some(item_val) = key.get(&formula.target_category) { if item_val == formula.target { return self.eval_formula(formula, key); } } } self.data.get(key).clone() } fn eval_formula(&self, formula: &Formula, context: &CellKey) -> CellValue { use crate::formula::{Expr, AggFunc}; // Check WHERE filter first if let Some(filter) = &formula.filter { if let Some(item_val) = context.get(&filter.category) { if item_val != filter.item.as_str() { return self.data.get(context).clone(); } } } fn find_item_category<'a>(model: &'a Model, item_name: &str) -> Option<&'a str> { for (cat_name, cat) in &model.categories { if cat.items.contains_key(item_name) { return Some(cat_name.as_str()); } } None } fn eval_expr( expr: &Expr, context: &CellKey, model: &Model, target_category: &str, ) -> Option { match expr { Expr::Number(n) => Some(*n), Expr::Ref(name) => { let cat = find_item_category(model, name)?; let new_key = context.clone().with(cat, name); model.evaluate(&new_key).as_f64() } Expr::BinOp(op, l, r) => { let lv = eval_expr(l, context, model, target_category)?; let rv = eval_expr(r, context, model, target_category)?; Some(match op.as_str() { "+" => lv + rv, "-" => lv - rv, "*" => lv * rv, "/" => if rv == 0.0 { 0.0 } else { lv / rv }, "^" => lv.powf(rv), _ => return None, }) } Expr::UnaryMinus(e) => Some(-eval_expr(e, context, model, target_category)?), Expr::Agg(func, _inner, _filter) => { let partial = context.without(target_category); let values: Vec = model.data.matching_cells(&partial.0) .into_iter() .filter_map(|(_, v)| v.as_f64()) .collect(); match func { AggFunc::Sum => Some(values.iter().sum()), AggFunc::Avg => { if values.is_empty() { None } else { Some(values.iter().sum::() / values.len() as f64) } } AggFunc::Min => values.iter().cloned().reduce(f64::min), AggFunc::Max => values.iter().cloned().reduce(f64::max), AggFunc::Count => Some(values.len() as f64), } } Expr::If(cond, then, else_) => { let cv = eval_bool(cond, context, model, target_category)?; if cv { eval_expr(then, context, model, target_category) } else { eval_expr(else_, context, model, target_category) } } } } fn eval_bool( expr: &Expr, context: &CellKey, model: &Model, target_category: &str, ) -> Option { match expr { Expr::BinOp(op, l, r) => { let lv = eval_expr(l, context, model, target_category)?; let rv = eval_expr(r, context, model, target_category)?; Some(match op.as_str() { "=" | "==" => (lv - rv).abs() < 1e-10, "!=" => (lv - rv).abs() >= 1e-10, "<" => lv < rv, ">" => lv > rv, "<=" => lv <= rv, ">=" => lv >= rv, _ => return None, }) } _ => None, } } match eval_expr(&formula.expr, context, self, &formula.target_category) { Some(n) => CellValue::Number(n), None => CellValue::Empty, } } } #[cfg(test)] mod model_tests { use super::Model; use crate::model::cell::{CellKey, CellValue}; use crate::view::Axis; fn coord(pairs: &[(&str, &str)]) -> CellKey { CellKey::new(pairs.iter().map(|(c, i)| (c.to_string(), i.to_string())).collect()) } #[test] fn new_model_has_default_view() { let m = Model::new("Test"); assert!(m.active_view().is_some()); } #[test] fn add_category_creates_entry() { let mut m = Model::new("Test"); m.add_category("Region").unwrap(); assert!(m.category("Region").is_some()); } #[test] fn add_category_duplicate_is_idempotent() { let mut m = Model::new("Test"); let id1 = m.add_category("Region").unwrap(); let id2 = m.add_category("Region").unwrap(); assert_eq!(id1, id2); assert_eq!(m.categories.len(), 1); } #[test] fn add_category_max_limit() { let mut m = Model::new("Test"); for i in 0..12 { m.add_category(format!("Cat{i}")).unwrap(); } assert!(m.add_category("TooMany").is_err()); } #[test] fn add_category_notifies_existing_views() { let mut m = Model::new("Test"); m.add_category("Region").unwrap(); assert_ne!(m.active_view().unwrap().axis_of("Region"), Axis::Unassigned); } #[test] fn set_and_get_cell_roundtrip() { let mut m = Model::new("Test"); m.add_category("Region").unwrap(); m.add_category("Measure").unwrap(); let k = coord(&[("Region", "East"), ("Measure", "Revenue")]); m.set_cell(k.clone(), CellValue::Number(500.0)); assert_eq!(m.get_cell(&k), &CellValue::Number(500.0)); } #[test] fn get_unset_cell_returns_empty() { let m = Model::new("Test"); let k = coord(&[("Region", "East")]); assert_eq!(m.get_cell(&k), &CellValue::Empty); } #[test] fn overwrite_cell() { let mut m = Model::new("Test"); let k = coord(&[("Region", "East")]); m.set_cell(k.clone(), CellValue::Number(1.0)); m.set_cell(k.clone(), CellValue::Number(2.0)); assert_eq!(m.get_cell(&k), &CellValue::Number(2.0)); } #[test] fn three_category_model_independent_cells() { let mut m = Model::new("Test"); m.add_category("Region").unwrap(); m.add_category("Product").unwrap(); m.add_category("Measure").unwrap(); let k1 = coord(&[("Region", "East"), ("Product", "Shirts"), ("Measure", "Revenue")]); let k2 = coord(&[("Region", "West"), ("Product", "Shirts"), ("Measure", "Revenue")]); let k3 = coord(&[("Region", "East"), ("Product", "Pants"), ("Measure", "Revenue")]); let k4 = coord(&[("Region", "East"), ("Product", "Shirts"), ("Measure", "Cost")]); m.set_cell(k1.clone(), CellValue::Number(100.0)); m.set_cell(k2.clone(), CellValue::Number(200.0)); m.set_cell(k3.clone(), CellValue::Number(300.0)); m.set_cell(k4.clone(), CellValue::Number(40.0)); assert_eq!(m.get_cell(&k1), &CellValue::Number(100.0)); assert_eq!(m.get_cell(&k2), &CellValue::Number(200.0)); assert_eq!(m.get_cell(&k3), &CellValue::Number(300.0)); assert_eq!(m.get_cell(&k4), &CellValue::Number(40.0)); } #[test] fn create_view_copies_category_structure() { let mut m = Model::new("Test"); m.add_category("Region").unwrap(); m.add_category("Product").unwrap(); m.create_view("Secondary"); let v = m.views.get("Secondary").unwrap(); assert_ne!(v.axis_of("Region"), Axis::Unassigned); assert_ne!(v.axis_of("Product"), Axis::Unassigned); } #[test] fn switch_view_changes_active_view() { let mut m = Model::new("Test"); m.create_view("Other"); m.switch_view("Other").unwrap(); assert_eq!(m.active_view, "Other"); } #[test] fn switch_view_unknown_returns_error() { let mut m = Model::new("Test"); assert!(m.switch_view("NoSuchView").is_err()); } #[test] fn delete_view_removes_it() { let mut m = Model::new("Test"); m.create_view("Extra"); m.delete_view("Extra").unwrap(); assert!(!m.views.contains_key("Extra")); } #[test] fn delete_last_view_returns_error() { let mut m = Model::new("Test"); assert!(m.delete_view("Default").is_err()); } #[test] fn delete_active_view_switches_to_another() { let mut m = Model::new("Test"); m.create_view("Other"); m.switch_view("Other").unwrap(); m.delete_view("Other").unwrap(); assert_ne!(m.active_view, "Other"); } #[test] fn first_category_goes_to_row_second_to_column_rest_to_page() { let mut m = Model::new("Test"); m.add_category("Region").unwrap(); m.add_category("Product").unwrap(); m.add_category("Time").unwrap(); let v = m.active_view().unwrap(); assert_eq!(v.axis_of("Region"), Axis::Row); assert_eq!(v.axis_of("Product"), Axis::Column); assert_eq!(v.axis_of("Time"), Axis::Page); } #[test] fn data_is_shared_across_views() { let mut m = Model::new("Test"); m.create_view("Second"); let k = coord(&[("Region", "East")]); m.set_cell(k.clone(), CellValue::Number(77.0)); assert_eq!(m.get_cell(&k), &CellValue::Number(77.0)); } } #[cfg(test)] mod formula_tests { use super::Model; use crate::model::cell::{CellKey, CellValue}; use crate::formula::parse_formula; fn coord(pairs: &[(&str, &str)]) -> CellKey { CellKey::new(pairs.iter().map(|(c, i)| (c.to_string(), i.to_string())).collect()) } fn approx_eq(a: f64, b: f64) -> bool { (a - b).abs() < 1e-9 } fn revenue_cost_model() -> Model { let mut m = Model::new("Test"); m.add_category("Measure").unwrap(); m.add_category("Region").unwrap(); if let Some(cat) = m.category_mut("Measure") { cat.add_item("Revenue"); cat.add_item("Cost"); cat.add_item("Profit"); } if let Some(cat) = m.category_mut("Region") { cat.add_item("East"); cat.add_item("West"); } m.set_cell(coord(&[("Measure", "Revenue"), ("Region", "East")]), CellValue::Number(1000.0)); m.set_cell(coord(&[("Measure", "Cost"), ("Region", "East")]), CellValue::Number(600.0)); m.set_cell(coord(&[("Measure", "Revenue"), ("Region", "West")]), CellValue::Number(800.0)); m.set_cell(coord(&[("Measure", "Cost"), ("Region", "West")]), CellValue::Number(500.0)); m } #[test] fn profit_equals_revenue_minus_cost() { let mut m = revenue_cost_model(); m.add_formula(parse_formula("Profit = Revenue - Cost", "Measure").unwrap()); let k = coord(&[("Measure", "Profit"), ("Region", "East")]); assert_eq!(m.evaluate(&k), CellValue::Number(400.0)); } #[test] fn formula_evaluates_per_region() { let mut m = revenue_cost_model(); m.add_formula(parse_formula("Profit = Revenue - Cost", "Measure").unwrap()); let east = m.evaluate(&coord(&[("Measure", "Profit"), ("Region", "East")])); let west = m.evaluate(&coord(&[("Measure", "Profit"), ("Region", "West")])); assert_eq!(east, CellValue::Number(400.0)); assert_eq!(west, CellValue::Number(300.0)); } #[test] fn formula_multiplication() { let mut m = revenue_cost_model(); m.add_formula(parse_formula("Tax = Revenue * 0.1", "Measure").unwrap()); if let Some(cat) = m.category_mut("Measure") { cat.add_item("Tax"); } let val = m.evaluate(&coord(&[("Measure", "Tax"), ("Region", "East")])).as_f64().unwrap(); assert!(approx_eq(val, 100.0)); } #[test] fn formula_division() { let mut m = revenue_cost_model(); m.add_formula(parse_formula("Profit = Revenue - Cost", "Measure").unwrap()); m.add_formula(parse_formula("Margin = Profit / Revenue", "Measure").unwrap()); if let Some(cat) = m.category_mut("Measure") { cat.add_item("Profit"); cat.add_item("Margin"); } let val = m.evaluate(&coord(&[("Measure", "Margin"), ("Region", "East")])).as_f64().unwrap(); assert!(approx_eq(val, 0.4)); } #[test] fn division_by_zero_yields_zero() { let mut m = Model::new("Test"); m.add_category("Measure").unwrap(); m.add_category("Region").unwrap(); if let Some(cat) = m.category_mut("Measure") { cat.add_item("Revenue"); cat.add_item("Zero"); cat.add_item("Result"); } m.set_cell(coord(&[("Measure", "Revenue"), ("Region", "East")]), CellValue::Number(100.0)); m.set_cell(coord(&[("Measure", "Zero"), ("Region", "East")]), CellValue::Number(0.0)); m.add_formula(parse_formula("Result = Revenue / Zero", "Measure").unwrap()); assert_eq!(m.evaluate(&coord(&[("Measure", "Result"), ("Region", "East")])), CellValue::Number(0.0)); } #[test] fn unary_minus() { let mut m = revenue_cost_model(); m.add_formula(parse_formula("NegRevenue = -Revenue", "Measure").unwrap()); if let Some(cat) = m.category_mut("Measure") { cat.add_item("NegRevenue"); } let k = coord(&[("Measure", "NegRevenue"), ("Region", "East")]); assert_eq!(m.evaluate(&k), CellValue::Number(-1000.0)); } #[test] fn power_operator() { let mut m = Model::new("Test"); m.add_category("Measure").unwrap(); if let Some(cat) = m.category_mut("Measure") { cat.add_item("Base"); cat.add_item("Squared"); } m.set_cell(coord(&[("Measure", "Base")]), CellValue::Number(4.0)); m.add_formula(parse_formula("Squared = Base ^ 2", "Measure").unwrap()); assert_eq!(m.evaluate(&coord(&[("Measure", "Squared")])), CellValue::Number(16.0)); } #[test] fn formula_with_missing_ref_returns_empty() { let mut m = revenue_cost_model(); m.add_formula(parse_formula("Ghost = NoSuchField - Cost", "Measure").unwrap()); if let Some(cat) = m.category_mut("Measure") { cat.add_item("Ghost"); } let k = coord(&[("Measure", "Ghost"), ("Region", "East")]); assert_eq!(m.evaluate(&k), CellValue::Empty); } #[test] fn formula_where_applied_to_matching_region() { let mut m = revenue_cost_model(); m.add_formula(parse_formula("EastOnly = Revenue WHERE Region = \"East\"", "Measure").unwrap()); if let Some(cat) = m.category_mut("Measure") { cat.add_item("EastOnly"); } let val = m.evaluate(&coord(&[("Measure", "EastOnly"), ("Region", "East")])).as_f64().unwrap(); assert!(approx_eq(val, 1000.0)); } #[test] fn formula_where_not_applied_to_non_matching_region() { let mut m = revenue_cost_model(); m.add_formula(parse_formula("EastOnly = Revenue WHERE Region = \"East\"", "Measure").unwrap()); if let Some(cat) = m.category_mut("Measure") { cat.add_item("EastOnly"); } assert_eq!(m.evaluate(&coord(&[("Measure", "EastOnly"), ("Region", "West")])), CellValue::Empty); } #[test] fn add_formula_replaces_same_target() { let mut m = revenue_cost_model(); m.add_formula(parse_formula("Profit = Revenue - Cost", "Measure").unwrap()); m.add_formula(parse_formula("Profit = Revenue - Cost - 100", "Measure").unwrap()); assert_eq!(m.formulas.len(), 1); let k = coord(&[("Measure", "Profit"), ("Region", "East")]); assert_eq!(m.evaluate(&k), CellValue::Number(300.0)); } #[test] fn remove_formula() { let mut m = revenue_cost_model(); m.add_formula(parse_formula("Profit = Revenue - Cost", "Measure").unwrap()); m.remove_formula("Profit"); assert!(m.formulas.is_empty()); let k = coord(&[("Measure", "Profit"), ("Region", "East")]); assert_eq!(m.evaluate(&k), CellValue::Empty); } #[test] fn sum_aggregation_across_region() { let mut m = revenue_cost_model(); m.add_formula(parse_formula("Total = SUM(Revenue)", "Measure").unwrap()); if let Some(cat) = m.category_mut("Measure") { cat.add_item("Total"); } let val = m.evaluate(&coord(&[("Measure", "Total"), ("Region", "East")])).as_f64().unwrap(); assert!(val > 0.0, "SUM should be positive, got {val}"); } #[test] fn count_aggregation() { let mut m = Model::new("Test"); m.add_category("Measure").unwrap(); m.add_category("Region").unwrap(); if let Some(cat) = m.category_mut("Measure") { cat.add_item("Sales"); cat.add_item("Count"); } for region in ["East", "West", "North"] { m.set_cell(coord(&[("Measure", "Sales"), ("Region", region)]), CellValue::Number(100.0)); } m.add_formula(parse_formula("Count = COUNT(Sales)", "Measure").unwrap()); let val = m.evaluate(&coord(&[("Measure", "Count"), ("Region", "East")])).as_f64().unwrap(); assert!(val >= 1.0); } #[test] fn if_true_branch() { let mut m = Model::new("Test"); m.add_category("Measure").unwrap(); if let Some(cat) = m.category_mut("Measure") { cat.add_item("X"); cat.add_item("Result"); } m.set_cell(coord(&[("Measure", "X")]), CellValue::Number(10.0)); m.add_formula(parse_formula("Result = IF(X > 5, 1, 0)", "Measure").unwrap()); assert_eq!(m.evaluate(&coord(&[("Measure", "Result")])), CellValue::Number(1.0)); } #[test] fn if_false_branch() { let mut m = Model::new("Test"); m.add_category("Measure").unwrap(); if let Some(cat) = m.category_mut("Measure") { cat.add_item("X"); cat.add_item("Result"); } m.set_cell(coord(&[("Measure", "X")]), CellValue::Number(3.0)); m.add_formula(parse_formula("Result = IF(X > 5, 1, 0)", "Measure").unwrap()); assert_eq!(m.evaluate(&coord(&[("Measure", "Result")])), CellValue::Number(0.0)); } } #[cfg(test)] mod five_category { use super::Model; use crate::model::cell::{CellKey, CellValue}; use crate::formula::parse_formula; use crate::view::Axis; const DATA: &[(&str, &str, &str, &str, f64, f64)] = &[ ("East", "Shirts", "Online", "Q1", 1_000.0, 600.0), ("East", "Shirts", "Online", "Q2", 1_200.0, 700.0), ("East", "Shirts", "Retail", "Q1", 800.0, 500.0), ("East", "Shirts", "Retail", "Q2", 900.0, 540.0), ("East", "Pants", "Online", "Q1", 500.0, 300.0), ("East", "Pants", "Online", "Q2", 600.0, 360.0), ("East", "Pants", "Retail", "Q1", 400.0, 240.0), ("East", "Pants", "Retail", "Q2", 450.0, 270.0), ("West", "Shirts", "Online", "Q1", 700.0, 420.0), ("West", "Shirts", "Online", "Q2", 750.0, 450.0), ("West", "Shirts", "Retail", "Q1", 600.0, 360.0), ("West", "Shirts", "Retail", "Q2", 650.0, 390.0), ("West", "Pants", "Online", "Q1", 300.0, 180.0), ("West", "Pants", "Online", "Q2", 350.0, 210.0), ("West", "Pants", "Retail", "Q1", 250.0, 150.0), ("West", "Pants", "Retail", "Q2", 280.0, 168.0), ]; fn coord(region: &str, product: &str, channel: &str, time: &str, measure: &str) -> CellKey { CellKey::new(vec![ ("Channel".to_string(), channel.to_string()), ("Measure".to_string(), measure.to_string()), ("Product".to_string(), product.to_string()), ("Region".to_string(), region.to_string()), ("Time".to_string(), time.to_string()), ]) } fn build_model() -> Model { let mut m = Model::new("Sales"); for cat in ["Region", "Product", "Channel", "Time", "Measure"] { m.add_category(cat).unwrap(); } for cat in ["Region", "Product", "Channel", "Time"] { let items: &[&str] = match cat { "Region" => &["East", "West"], "Product" => &["Shirts", "Pants"], "Channel" => &["Online", "Retail"], "Time" => &["Q1", "Q2"], _ => &[], }; if let Some(c) = m.category_mut(cat) { for &item in items { c.add_item(item); } } } if let Some(c) = m.category_mut("Measure") { for &item in &["Revenue", "Cost", "Profit", "Margin", "Total"] { c.add_item(item); } } for &(region, product, channel, time, rev, cost) in DATA { m.set_cell(coord(region, product, channel, time, "Revenue"), CellValue::Number(rev)); m.set_cell(coord(region, product, channel, time, "Cost"), CellValue::Number(cost)); } m.add_formula(parse_formula("Profit = Revenue - Cost", "Measure").unwrap()); m.add_formula(parse_formula("Margin = Profit / Revenue", "Measure").unwrap()); m.add_formula(parse_formula("Total = SUM(Revenue)", "Measure").unwrap()); m } fn approx(a: f64, b: f64) -> bool { (a - b).abs() < 1e-9 } #[test] fn all_sixteen_revenue_cells_stored() { let m = build_model(); let count = DATA.iter() .filter(|&&(r, p, c, t, _, _)| !m.get_cell(&coord(r, p, c, t, "Revenue")).is_empty()) .count(); assert_eq!(count, 16); } #[test] fn all_sixteen_cost_cells_stored() { let m = build_model(); let count = DATA.iter() .filter(|&&(r, p, c, t, _, _)| !m.get_cell(&coord(r, p, c, t, "Cost")).is_empty()) .count(); assert_eq!(count, 16); } #[test] fn spot_check_raw_revenue() { let m = build_model(); assert_eq!(m.get_cell(&coord("East", "Shirts", "Online", "Q1", "Revenue")), &CellValue::Number(1_000.0)); assert_eq!(m.get_cell(&coord("West", "Pants", "Retail", "Q2", "Revenue")), &CellValue::Number(280.0)); } #[test] fn distinct_cells_do_not_alias() { let m = build_model(); let a = m.get_cell(&coord("East", "Shirts", "Online", "Q1", "Revenue")).clone(); let b = m.get_cell(&coord("West", "Pants", "Retail", "Q2", "Revenue")).clone(); assert_ne!(a, b); } #[test] fn profit_formula_correct_at_every_intersection() { let m = build_model(); for &(region, product, channel, time, rev, cost) in DATA { let expected = rev - cost; let actual = m.evaluate(&coord(region, product, channel, time, "Profit")) .as_f64() .unwrap_or_else(|| panic!("Profit empty at {region}/{product}/{channel}/{time}")); assert!(approx(actual, expected), "Profit at {region}/{product}/{channel}/{time}: expected {expected}, got {actual}"); } } #[test] fn margin_formula_correct_at_every_intersection() { let m = build_model(); for &(region, product, channel, time, rev, cost) in DATA { let expected = (rev - cost) / rev; let actual = m.evaluate(&coord(region, product, channel, time, "Margin")) .as_f64() .unwrap_or_else(|| panic!("Margin empty at {region}/{product}/{channel}/{time}")); assert!(approx(actual, expected), "Margin at {region}/{product}/{channel}/{time}: expected {expected:.4}, got {actual:.4}"); } } #[test] fn chained_formula_profit_feeds_margin() { let m = build_model(); let margin = m.evaluate(&coord("East", "Shirts", "Online", "Q1", "Margin")).as_f64().unwrap(); assert!(approx(margin, 0.4), "expected 0.4, got {margin}"); } #[test] fn update_revenue_updates_profit_and_margin() { let mut m = build_model(); m.set_cell(coord("East", "Shirts", "Online", "Q1", "Revenue"), CellValue::Number(1_500.0)); let profit = m.evaluate(&coord("East", "Shirts", "Online", "Q1", "Profit")).as_f64().unwrap(); assert!(approx(profit, 900.0), "expected 900, got {profit}"); let margin = m.evaluate(&coord("East", "Shirts", "Online", "Q1", "Margin")).as_f64().unwrap(); assert!(approx(margin, 0.6), "expected 0.6, got {margin}"); } #[test] fn sum_revenue_for_east_region() { let m = build_model(); let partial = vec![ ("Measure".to_string(), "Revenue".to_string()), ("Region".to_string(), "East".to_string()), ]; let total = m.data.sum_matching(&partial); let expected: f64 = DATA.iter().filter(|&&(r, _, _, _, _, _)| r == "East").map(|&(_, _, _, _, rev, _)| rev).sum(); assert!(approx(total, expected), "expected {expected}, got {total}"); } #[test] fn sum_revenue_for_online_channel() { let m = build_model(); let partial = vec![ ("Channel".to_string(), "Online".to_string()), ("Measure".to_string(), "Revenue".to_string()), ]; let total = m.data.sum_matching(&partial); let expected: f64 = DATA.iter().filter(|&&(_, _, ch, _, _, _)| ch == "Online").map(|&(_, _, _, _, rev, _)| rev).sum(); assert!(approx(total, expected), "expected {expected}, got {total}"); } #[test] fn sum_revenue_for_shirts_q1() { let m = build_model(); let partial = vec![ ("Measure".to_string(), "Revenue".to_string()), ("Product".to_string(), "Shirts".to_string()), ("Time".to_string(), "Q1".to_string()), ]; let total = m.data.sum_matching(&partial); let expected: f64 = DATA.iter().filter(|&&(_, p, _, t, _, _)| p == "Shirts" && t == "Q1").map(|&(_, _, _, _, rev, _)| rev).sum(); assert!(approx(total, expected), "expected {expected}, got {total}"); } #[test] fn sum_all_revenue_equals_grand_total() { let m = build_model(); let partial = vec![("Measure".to_string(), "Revenue".to_string())]; let total = m.data.sum_matching(&partial); let expected: f64 = DATA.iter().map(|&(_, _, _, _, rev, _)| rev).sum(); assert!(approx(total, expected), "expected {expected}, got {total}"); } #[test] fn default_view_first_two_on_axes_rest_on_page() { let m = build_model(); let v = m.active_view().unwrap(); assert_eq!(v.axis_of("Region"), Axis::Row); assert_eq!(v.axis_of("Product"), Axis::Column); assert_eq!(v.axis_of("Channel"), Axis::Page); assert_eq!(v.axis_of("Time"), Axis::Page); assert_eq!(v.axis_of("Measure"), Axis::Page); } #[test] fn rearranging_axes_does_not_affect_data() { let mut m = build_model(); if let Some(v) = m.active_view_mut() { v.set_axis("Region", Axis::Page); v.set_axis("Product", Axis::Page); v.set_axis("Channel", Axis::Row); v.set_axis("Time", Axis::Column); v.set_axis("Measure", Axis::Page); } assert_eq!(m.get_cell(&coord("East", "Shirts", "Online", "Q1", "Revenue")), &CellValue::Number(1_000.0)); } #[test] fn two_views_have_independent_axis_assignments() { let mut m = build_model(); m.create_view("Pivot"); if let Some(v) = m.views.get_mut("Pivot") { v.set_axis("Time", Axis::Row); v.set_axis("Channel", Axis::Column); v.set_axis("Region", Axis::Page); v.set_axis("Product", Axis::Page); v.set_axis("Measure", Axis::Page); } assert_eq!(m.views.get("Default").unwrap().axis_of("Region"), Axis::Row); assert_eq!(m.views.get("Pivot").unwrap().axis_of("Time"), Axis::Row); assert_eq!(m.views.get("Pivot").unwrap().axis_of("Channel"), Axis::Column); } #[test] fn page_selections_are_per_view() { let mut m = build_model(); m.create_view("West only"); if let Some(v) = m.views.get_mut("West only") { v.set_page_selection("Region", "West"); } assert_eq!(m.views.get("Default").unwrap().page_selection("Region"), None); assert_eq!(m.views.get("West only").unwrap().page_selection("Region"), Some("West")); } #[test] fn five_categories_well_within_limit() { let m = build_model(); assert_eq!(m.categories.len(), 5); let mut m2 = build_model(); for i in 0..7 { m2.add_category(format!("Extra{i}")).unwrap(); } assert_eq!(m2.categories.len(), 12); assert!(m2.add_category("OneMore").is_err()); } }