"""Formula parsing utilities for building design matrices from strings.
This module provides a lightweight, NumPy/Pandas-friendly parser for model
formulas. It supports:
- Intercept handling via the special column "(intercept)".
- Basic terms (column names), interactions with ":" and products with "*".
- Inclusion/exclusion using "+" and "-" (handled via `make_pretty_minus`).
- Parentheses grouping.
- Powers using "**" (or "^" if `USE_CARET=True`).
- "As-is" expressions via `I(<python expression>)`, evaluated against `data`.
- A curated set of NumPy elementwise functions (see `NUMPY_FUNCS`).
- Special polynomial generators such as "quadratic(x1,x2)", "cubic(...)", etc.
Key entry points:
- parse(): returns x_vars, y_var, conditionals, and the processed DataFrame.
- design helpers such as generate_interactions().
Notes:
- The parser mutates a *copy* of the input DataFrame and returns it.
- (intercept) is always added as a column with value 1 unless excluded.
"""
import numpy as np
import pandas as pd
from typing import Optional, Tuple, Dict, Any, Callable
import re
from itertools import chain, combinations, combinations_with_replacement
from functools import partial
################################################################################
################################# Constants ##################################
################################################################################
# Module-wide switch read by parse_expression(): when True, every '^' in an
# expression is rewritten to '**' before any parsing takes place.
USE_CARET: bool = False
"""Whether to accept '^' as a power operator (translated to '**')."""
# Module-wide switch read by parse_expression(): when True, each expression
# that parses successfully is printed with its included/excluded term lists.
LOGGING: bool = False
"""Enable debug prints from parse steps when True."""
def _numpy_unary(name: str) -> Callable[[np.ndarray], np.ndarray]:
    """Build a one-argument wrapper that calls ``np.<name>`` when invoked.

    The attribute is resolved at call time (not when the registry is built),
    so a name missing from the installed NumPy only fails if it is used.

    Args:
        name (str): Attribute name of the NumPy function.
    Returns:
        Callable: Function taking ``x`` and returning ``np.<name>(x)``.
    """
    def call(x):
        return getattr(np, name)(x)
    return call


# Registry of single-argument NumPy functions that formulas may use as
# "<name>(column)". Keys are the names accepted in formula strings.
# NOTE(review): 'diff' and 'ediff1d' return fewer elements than their input,
# so storing their result as a DataFrame column presumably fails - confirm.
NUMPY_FUNCS: Dict[str, Callable[[np.ndarray], np.ndarray]] = {
    name: _numpy_unary(name)
    for name in (
        # rounding
        'around', 'rint', 'fix', 'floor', 'ceil', 'trunc',
        # differences
        'diff', 'ediff1d',
        # exponentials and logarithms
        'exp', 'expm1', 'exp2', 'log', 'log10', 'log2', 'log1p',
        # special functions
        'i0', 'sinc',
        # floating point and arithmetic helpers
        'signbit', 'spacing', 'reciprocal', 'positive', 'negative',
        # complex numbers
        'angle', 'real', 'imag', 'conj', 'conjugate',
        # roots, powers and magnitudes
        'sqrt', 'cbrt', 'square', 'absolute', 'fabs', 'sign',
        # clean-up helpers
        'nan_to_num', 'real_if_close',
    )
}
################################################################################
############################## Helper functions ##############################
################################################################################
def bind(x: list) -> list:
    """Concatenate the elements of `x` into one list (flattens one level).

    Args:
        x (list): Sequence whose elements are themselves iterables.
    Returns:
        list: All inner items, in order of appearance.
    """
    flattened: list = []
    for group in x:
        flattened.extend(group)
    return flattened
def unique(x: list) -> list:
    """Drop repeated entries while keeping the first occurrence of each.

    Membership is tested with ``in`` against the growing result list, so
    unhashable items (for example nested lists) are supported.

    Args:
        x (list): Input list.
    Returns:
        list: New list with duplicates removed, original order preserved.
    """
    kept: list = []
    for candidate in x:
        if candidate in kept:
            continue
        kept.append(candidate)
    return kept
def split_expression(expression: str, delimiter: str = '+') -> list[str]:
    """Split `expression` on `delimiter`, ignoring delimiters inside parentheses.

    All spaces are removed before splitting. An empty fragment is kept when
    it is followed by a delimiter (e.g. a leading delimiter), but a trailing
    empty fragment is dropped.

    Args:
        expression (str): Expression string to split.
        delimiter (str): Single delimiter character (default: '+').
    Returns:
        list[str]: Fragments found at parenthesis depth zero.
    """
    compact = expression.replace(' ', '')
    pieces: list[str] = []
    buffer: list[str] = []
    nesting = 0
    for symbol in compact:
        # A delimiter only separates fragments outside every parenthesis.
        if nesting == 0 and symbol == delimiter:
            pieces.append(''.join(buffer))
            buffer = []
            continue
        if symbol == '(':
            nesting += 1
        elif symbol == ')':
            nesting -= 1
        buffer.append(symbol)
    # Flush whatever follows the last delimiter, unless it is empty.
    if buffer:
        pieces.append(''.join(buffer))
    return pieces
def match_parens(expression: str) -> Optional[str]:
    """Return the inside of `expression` if one pair of parentheses wraps it all.

    The first and last characters must be a matching '(' ... ')' pair, and the
    text between them must itself be balanced. "(a)+(b)" is rejected because
    its outer characters belong to different groups, and "((a)" is rejected
    because the inner text leaves a parenthesis open.

    Args:
        expression (str): Candidate string like "(a+b)". May be empty.
    Returns:
        Optional[str]: Text between the outer parentheses (possibly ''), or
        None when `expression` is not a single fully parenthesized group.
    """
    # Anything shorter than "()" cannot be wrapped; this also keeps the
    # indexing below safe for the empty string.
    if len(expression) < 2 or expression[0] != '(' or expression[-1] != ')':
        return None
    inner = expression[1:-1]
    depth = 0
    for char in inner:
        if char == '(':
            depth += 1
        elif char == ')':
            depth -= 1
            # Closing more than was opened means the outer '(' was already
            # closed before the end of the string.
            if depth < 0:
                return None
    # A parenthesis left open inside means the trailing ')' closes an inner
    # group rather than the leading '('.
    if depth != 0:
        return None
    return inner
def make_pretty_minus(expression: str) -> str:
    """Rewrite every top-level '-' as '+-' so terms can be split on '+' alone.

    Minus signs nested inside parentheses are left untouched. After the
    rewrite, any '++' is collapsed to '+' and a single leading '+' is removed.

    Example:
        "x - y + z" -> "x+-y+z"
    Args:
        expression (str): Raw expression.
    Returns:
        str: Normalized expression (spaces are removed by the splitting step).
    """
    fragments = split_expression(expression, '-')
    normalized = '+-'.join(fragments).replace('++', '+')
    if normalized.startswith('+'):
        return normalized[1:]
    return normalized
def bin(x: list) -> dict:
    """Count occurrences of each element in a list.

    Keys appear in order of first occurrence. Counting is done in a single
    pass instead of rescanning the whole list once per distinct item.

    NOTE: this module-level name shadows the builtin ``bin`` inside this
    module; it is kept because other code refers to it by this name.

    Args:
        x (list): Input list of hashable items.
    Returns:
        dict: Mapping item -> number of occurrences.
    """
    counts: dict = {}
    for item in x:
        counts[item] = counts.get(item, 0) + 1
    return counts
################################################################################
############################# Expression Parsing #############################
################################################################################
def parse(formula: str, data: pd.DataFrame) -> Tuple[list[str], str, list[str], pd.DataFrame]:
    """Parse a full formula into design metadata and a processed DataFrame.

    Syntax:
        y ~ rhs [| conditionals]
    Where `rhs` can contain:
        - column names
        - interactions with ':' (e.g., a:b)
        - products with '*' (expanded to main effects + interactions unless
          distribution is detected)
        - '+' and '-' to include/exclude terms
        - powers via '**' (or '^' if `USE_CARET=True`)
        - '.' to include all columns
        - special generators like 'quadratic(a,b)', etc.
        - 'I(expr)' to evaluate raw Python/NumPy expressions on columns
    The function:
        - works on a copy of `data` and adds an '(intercept)' column of ones,
        - parses the left-hand side (y),
        - returns selected x-vars (after inclusion/exclusion),
        - returns optional `conditionals` to be passed downstream.

    Args:
        formula (str): Full formula string.
        data (pd.DataFrame): Source data (not modified).
    Returns:
        Tuple[list[str], str, list[str], pd.DataFrame]:
            - x_vars: Ordered unique design column names (includes
              '(intercept)' unless removed).
            - y_var: Dependent variable name (surrounding whitespace removed).
            - conditionals: Parsed conditional columns (after
              inclusion/exclusion).
            - processed: A copy of data with derived columns added, restricted
              to [y_var] + x_vars + conditionals.
    Raises:
        ValueError: If the formula does not contain exactly one '~', contains
            more than one '|', or a side of it cannot be parsed.
    """
    lhs, tilde, rhs = formula.partition('~')
    if not tilde or '~' in rhs:
        raise ValueError(f"Formula {formula!r} must contain exactly one '~'")
    # Split only the right-hand side on '|'. Splitting the whole formula would
    # leave "y ~" glued to the first predictor, which then fails to parse.
    rhs, _, conditionals = rhs.partition('|')
    if '|' in conditionals:
        raise ValueError(f"Formula {formula!r} must contain at most one '|'")
    # The name is used verbatim as a column key below, so it must not carry
    # the whitespace that usually surrounds '~'.
    y_var = lhs.strip()

    data = data.copy()
    data['(intercept)'] = 1
    # Creates the response column when it is a derived term such as "log(y)".
    # If the term cannot be resolved, the final column selection raises
    # KeyError.
    parse_term(y_var, data)

    x_result = parse_expression(rhs, data)
    if x_result is None:
        raise ValueError(f'Right-hand side {rhs.strip()!r} could not be parsed')
    cond_result = parse_expression(conditionals, data)
    if cond_result is None:
        raise ValueError(f'Conditionals {conditionals.strip()!r} could not be parsed')

    # list(...) guards against array-likes (e.g. a pandas Index), for which
    # `list + Index` would concatenate the strings element-wise.
    x_vars_included = ['(intercept)'] + list(x_result[0])
    x_vars_excluded = list(x_result[1])
    conditionals_included = list(cond_result[0])
    conditionals_excluded = list(cond_result[1])

    x_vars = unique([
        item for item in x_vars_included
        if item not in x_vars_excluded and item != y_var
    ])
    conditional_vars = unique([
        item for item in conditionals_included
        if item not in conditionals_excluded and item != y_var
    ])
    return x_vars, y_var, conditional_vars, data[[y_var] + x_vars + conditional_vars]
def parse_expression(expression: str, data: pd.DataFrame) -> Optional[tuple[list[str], list[str]]]:
    """Parse a right-hand-side-like expression into included and excluded terms.

    Steps:
        1) normalization: strip, `make_pretty_minus`, drop "np." / "numpy."
           qualifiers, and translate '^' to '**' when `USE_CARET` is set,
        2) try `parse_basic`,
        3) fall back to `parse_complex`.

    Args:
        expression (str): RHS-like expression (may be empty).
        data (pd.DataFrame): DataFrame to mutate with derived columns.
    Returns:
        Optional[tuple[list[str], list[str]]]: (included_terms,
        excluded_terms), or None when neither parser recognizes the expression.
    """
    expression = make_pretty_minus(expression.strip())
    # Remove the module qualifier only where it starts a name. A plain
    # substring replace would also eat the tail of identifiers such as
    # "snp.x".
    expression = re.sub(r'(?<![\w.])(?:numpy|np)\.', '', expression)
    if USE_CARET:
        expression = expression.replace('^', '**')
    for label, parser in (('simple', parse_basic), ('complex', parse_complex)):
        result = parser(expression, data)
        if result:
            if LOGGING:
                print(label, expression, result[0], result[1])
            return result
    return None
def parse_basic(expression: str, data: pd.DataFrame) -> Optional[tuple[list[str], list[str]]]:
    """Handle simple cases: literals, parentheses, single terms, all-cols, powers, sums.

    Rules (order matters):
        - Empty string   -> ([], [])
        - Parenthesized  -> parse inner
        - "1"            -> intercept
        - Valid term     -> [term]
        - "."            -> all current columns
        - "(... )**k"    -> the inner terms plus their interactions up to order k
        - "a+b+..."      -> sum of sub-expressions (recursively parsed)
        - "-expr"        -> swap included/excluded sets (for minus handling)
        - Special power funcs (e.g., "quadratic(a,b)")

    Args:
        expression (str): Candidate expression (already normalized).
        data (pd.DataFrame): Data to mutate with derived columns.
    Returns:
        Optional[tuple[list[str], list[str]]]: (included, excluded) or None
        when no rule applies.
    Raises:
        ValueError: If a power is applied to an expression that has excluded
            terms or cannot be parsed, or if a negated expression cannot be
            parsed.
    """
    if expression == '':
        return [], []
    if inner := match_parens(expression):
        return parse_expression(inner, data)
    if expression == '1':
        return ['(intercept)'], []
    if parse_term(expression, data):
        return [expression], []
    if expression == '.':
        # Return a real list: callers concatenate with `+` / `+=`, which on a
        # pandas Index means element-wise string addition.
        return list(data.columns), []
    # Require at least one digit so "(a+b)**" is not sent to int('').
    if power_match := re.match(r'^\((.*)\)\*\*(\d+)$', expression):
        base = power_match.group(1)
        # The greedy group may span several groups, e.g. "(a)*(b)**2"; only
        # continue when the base really is one parenthesized unit.
        if match_parens(f'({base})'):
            power: int = int(power_match.group(2))
            result = parse_expression(base, data)
            if result is None:
                raise ValueError(f'Expression {base} inside a power could not be parsed')
            included, excluded = result
            if excluded:
                raise ValueError(f'Expression {base} seems to have excluded terms in a power')
            included = list(included) + generate_interactions(included, power=power, data=data)
            return included, excluded
    terms = split_expression(expression)
    if len(terms) > 1:
        sum_included: list[str] = []
        sum_excluded: list[str] = []
        for term in terms:
            result = parse_expression(term, data)
            # NOTE(review): a term that cannot be parsed is skipped silently
            # here (pre-existing behavior) - confirm this is intended.
            if result:
                sum_included.extend(result[0])
                sum_excluded.extend(result[1])
        return sum_included, sum_excluded
    if expression.startswith('-'):
        result = parse_expression(expression[1:], data)
        if not result:
            raise ValueError(f'Expression {expression} cannot be inverted')
        # Negation turns inclusions into exclusions and vice versa.
        return result[1], result[0]
    if generated := check_special_power_funcs(expression, data):
        return generated, []
    return None
def parse_complex(expression: str, data: pd.DataFrame) -> Optional[Tuple[list[str], list[str]]]:
    """Handle products '*' and interactions ':' with distribution/expansion logic.

    Strategy:
        - Split on '*' via `parse_complex_expression_by_splitting_on_string`.
          If the product was distributed over a sum, the returned terms are
          final; otherwise all interactions of the collected terms are added.
        - If there is no top-level '*', split on ':' the same way. Distributed
          results are final; otherwise the terms are joined into a single
          interaction column.

    Args:
        expression (str): Candidate expression with '*' or ':'.
        data (pd.DataFrame): Data to mutate.
    Returns:
        Optional[Tuple[list[str], list[str]]]: (included, excluded) or None
        when the expression has neither a top-level '*' nor ':'.
    Raises:
        ValueError: If negations are detected in product/interaction contexts
            or the joined interaction cannot be created.
    """
    product = parse_complex_expression_by_splitting_on_string(expression, data)
    if product:
        if product['error']:
            raise ValueError(f'Expression {expression} found negations in product')
        included: list[str] = product['terms']
        if not product['distributed']:
            included = included + generate_interactions(included, data=data)
        return included, []

    interaction = parse_complex_expression_by_splitting_on_string(expression, data, delimiter=':')
    if interaction:
        if interaction['error']:
            raise ValueError(f'Expression {expression} found negations in interaction')
        if interaction['distributed']:
            # "(a+b):c" has already been expanded to ["a:c", "b:c"]; joining
            # those again would build the bogus term "a:c:b:c".
            return interaction['terms'], []
        joined = ':'.join(interaction['terms'])
        # is_interaction needs the DataFrame so it can create the column.
        if is_interaction(joined, data):
            return [joined], []
        raise ValueError(f'Expression {joined} cannot parse an interaction')
    return None
def parse_complex_expression_by_splitting_on_string(
    expression: str,
    data: pd.DataFrame,
    delimiter: str = '*'
) -> Optional[Dict[str, Any]]:
    """Split by a delimiter ('*' or ':') and attempt recursive parsing/distribution.

    Each top-level fragment is parsed with `parse_basic`:
        - If a fragment yields excluded terms, the result is flagged as an error.
        - If a fragment yields several included terms (e.g. "(a+b)"), the
          operator is distributed over them: every alternative is substituted
          for the fragment, the resulting expressions are parsed with
          `parse_expression`, and their included terms are concatenated.
        - If a fragment yields exactly one term, that parsed term is collected.
        - Fragments that yield nothing are skipped.

    Args:
        expression (str): Input expression.
        data (pd.DataFrame): Data to mutate during parsing/evaluation.
        delimiter (str): Either '*' or ':'.
    Returns:
        Optional[Dict[str, Any]]: None when `expression` has no top-level
        `delimiter`; otherwise a dictionary with keys:
            - 'error' (bool): True if excluded terms invalidated the operation.
            - 'distributed' (bool): True if distribution over a sum occurred.
            - 'terms' (list[str] | None): Collected terms (None on error).
    Raises:
        ValueError: If an expression produced during distribution cannot be
            parsed.
    """
    # Hide '**' so splitting on '*' does not cut power operators in half.
    placeholder = '<SPECIAL_DELIMITER>'
    masked = expression.replace('**', placeholder)
    terms = split_expression(masked, delimiter=delimiter)
    if len(terms) <= 1:
        return None

    simple_terms: list = []
    for i, term in enumerate(terms):
        restored = term.replace(placeholder, '**')
        result = parse_basic(restored, data)
        if not result:
            # NOTE(review): fragments parse_basic cannot handle are skipped
            # silently (pre-existing behavior) - confirm this is intended.
            continue
        included, excluded = result
        if excluded:
            return {'error': True, 'distributed': False, 'terms': None}
        if len(included) > 1:
            expanded = []
            for alternative in included:
                candidate = delimiter.join(terms[:i] + [alternative] + terms[(i + 1):])
                candidate = candidate.replace(placeholder, '**')
                parsed = parse_expression(candidate, data)
                if parsed is None:
                    raise ValueError(f'Expression {candidate} could not be parsed while distributing {expression}')
                expanded.append(parsed)
            if bind([item[1] for item in expanded]):
                return {'error': True, 'distributed': False, 'terms': None}
            return {'error': False, 'distributed': True, 'terms': bind([item[0] for item in expanded])}
        if len(included) == 1:
            # Keep the parsed column name, not the raw fragment: "(a)" parses
            # to "a" and "1" to "(intercept)", and only the parsed names exist
            # as columns when interactions are built from this list.
            simple_terms.append(included[0])
    return {'error': False, 'distributed': False, 'terms': simple_terms}
################################################################################
################################ Term parsing ################################
################################################################################
def parse_term(term: str, data: pd.DataFrame) -> bool:
    """Try to resolve one term, creating its column in `data` when needed.

    The recognizers are tried in this order and the first success wins:
        1) in_numpy()       - existing column or "<func>(column)"
        2) is_interaction() - "a:b[:c...]"
        3) is_as_is()       - "I(<expression>)"

    Args:
        term (str): A candidate term string (surrounding whitespace ignored).
        data (pd.DataFrame): Data to be mutated if the term is derived.
    Returns:
        bool: True if any recognizer accepted the term, False otherwise.
    """
    candidate = term.strip()
    for recognizer in (in_numpy, is_interaction, is_as_is):
        if recognizer(candidate, data):
            return True
    return False
def is_interaction(expression: str, data: pd.DataFrame) -> bool:
    """Create the interaction column for colon-separated terms.

    Example:
        "a:b:c" -> data["a:b:c"] = data["a"] * data["b"] * data["c"]

    Each component is resolved with `parse_term` first, so components may be
    derived terms themselves. Columns created for components that parsed
    before a failing one are left in `data`.

    Args:
        expression (str): Interaction expression with ':'.
        data (pd.DataFrame): Data to mutate.
    Returns:
        bool: True if the interaction column was created; False if
        `expression` has no top-level ':' or a component cannot be parsed.
    """
    terms = split_expression(expression, ':')
    if len(terms) <= 1:
        return False
    # Start from a fresh column of ones rather than data['(intercept)'], so the
    # helper also works on frames that were not prepared by parse().
    product = pd.Series(1, index=data.index)
    for term in terms:
        if not parse_term(term, data):
            return False
        product = product * data[term]
    data[expression] = product
    return True
def in_numpy(expression: str, data: pd.DataFrame) -> bool:
    """Accept an existing column, or build "<func>(col)" from `NUMPY_FUNCS`.

    An expression that already names a column is accepted as-is. Otherwise
    each registered function name is tried: when the expression has the form
    "<name>(<inside>)" and <inside> is a column, the function result is stored
    under the full expression text.

    Args:
        expression (str): Either a column name or "<func>(col)".
        data (pd.DataFrame): Data to mutate.
    Returns:
        bool: True if the expression is a known column or was created.
    """
    if expression in data.columns:
        return True
    for name, function in NUMPY_FUNCS.items():
        found = re.match(r'^' + re.escape(name) + r'\((.*)\)$', expression)
        if not found:
            continue
        argument: str = found.group(1)
        # Only a plain column is supported as the argument.
        if argument not in data.columns:
            continue
        data[expression] = function(data[argument])
        return True
    return False
def is_as_is(expression: str, data: pd.DataFrame) -> bool:
    """Evaluate a raw Python/NumPy expression written as `I(...)`.

    Inside the parentheses, names of registered NumPy functions that are
    directly followed by '(' are qualified with `np.`, and column names are
    replaced by `data['col']`. Both substitutions happen in a single pass over
    the text, so inserted text is never rewritten again and names are only
    matched as whole tokens (column "t" does not touch "sqrt").

    Example:
        I((x1 + x2)**2) or I(sqrt(x))

    SECURITY: the rewritten text is passed to `eval`. Only use this with
    formulas from a trusted source.

    Args:
        expression (str): Expression of the form 'I(<python expression>)'.
        data (pd.DataFrame): Data to evaluate against and to store the result in.
    Returns:
        bool: True if evaluated and assigned to data[expression]; False if
        `expression` is not a non-empty, fully parenthesized `I(...)`.
    Raises:
        Exception: Whatever evaluating the rewritten expression raises.
    """
    # Also covers '' and a bare 'I', which must not reach the indexing done
    # by the parenthesis check.
    if not expression.startswith('I('):
        return False
    inner = match_parens(expression[1:])
    if not inner:
        return False

    # Longest names first so "x10" wins over "x1" at the same position.
    func_names = sorted(NUMPY_FUNCS, key=lambda name: (-len(name), name))
    col_names = sorted(
        {col for col in data.columns if isinstance(col, str) and col},
        key=lambda name: (-len(name), name),
    )
    alternatives = []
    if func_names:
        # A registered name directly followed by '(' is a function call, unless
        # it is already qualified ("np.log(") or part of a longer identifier.
        alternatives.append(
            r'(?<![\w.])(?P<func>'
            + '|'.join(re.escape(name) for name in func_names)
            + r')(?=\()'
        )
    if col_names:
        # Columns must stand alone: not inside an identifier, an attribute
        # access, a quoted string, or directly before a call.
        alternatives.append(
            r'''(?<![\w.'"])(?P<col>'''
            + '|'.join(re.escape(name) for name in col_names)
            + r''')(?![\w('"])'''
        )

    def _substitute(found):
        column = found.group('col')
        if column is not None:
            return f'data[{column!r}]'
        return 'np.' + found.group('func')

    code = re.sub('|'.join(alternatives), _substitute, inner) if alternatives else inner
    # `data` and `np` are passed explicitly so the evaluated text does not
    # depend on this function's local variable names.
    data[expression] = eval(code, globals(), {'data': data, 'np': np})
    return True
################################################################################
######################### Generation of Interactions #########################
################################################################################
def generate_interactions(x: list[str], data: pd.DataFrame, power: Optional[int] = None) -> list[str]:
    """Build every interaction of the given terms from order 2 up to `power`.

    Each combination is joined with ':', then split into its components again
    so that repeated components are merged and the components are sorted; this
    gives every interaction one canonical name. Each name is then created in
    `data` via `is_interaction`.

    Args:
        x (list[str]): Base term names (already validated/created in `data`).
        data (pd.DataFrame): Data to mutate with interactions.
        power (Optional[int]): Maximum interaction order. Defaults to the
            number of distinct terms in `x`.
    Returns:
        list[str]: Unique canonical interaction names, in generation order.
    Raises:
        ValueError: If any generated term fails to create an interaction in `data`.
    """
    base = unique(x.copy())
    limit = len(base) if power is None else power
    names: list[str] = []
    for order in range(2, limit + 1):
        for combo in combinations(base, r=order):
            components = unique(split_expression(":".join(combo), delimiter=":"))
            components.sort()
            names.append(":".join(components))
    names = unique(names)
    # Attempt every interaction before judging the outcome, so all creatable
    # columns exist even when one of them fails.
    outcomes = [is_interaction(name, data) for name in names]
    if not all(outcomes):
        raise ValueError(f'Terms {names} failed to generate interactions')
    return names
################################################################################
########################## Special power functions ###########################
################################################################################
def special_power(terms: list[str], data: pd.DataFrame, power: int = 1) -> Optional[list[str]]:
    r"""Generate special-power terms (linear, quadratic, etc.) for a set of terms.

    Every multiset of size ``power`` is drawn from ``terms`` plus the constant
    ``'1'``. The constant is dropped from each draw, so all monomials of total
    degree 1..``power`` are produced. A lone term is used as-is; anything else
    is wrapped in ``I(...)`` and created through ``parse_term``.

    Args:
        terms (list[str]): Base terms to expand. Not modified.
        data (pd.DataFrame): DataFrame where generated columns will be stored.
        power (int, optional): Maximum total degree to generate. Defaults to 1.
    Returns:
        list[str]: Names of generated terms stored in ``data``.
    Raises:
        ValueError: If a generated expression cannot be parsed as a term.
    Notes:
        - For a single term ``x``, quadratic produces ``I(x**2)`` and ``x``.
        - For multiple terms, products like ``I(x*y)`` and ``I((x**2)*y)`` are
          created.
    """
    outp: list[str] = []
    # Work on a copy: extending `terms` in place would leak '1' into the
    # caller's list.
    candidates = list(terms) + ['1']
    for pairing in combinations_with_replacement(candidates, r=power):
        factors = sorted(item for item in pairing if item != '1')
        if not factors:
            continue
        if len(factors) == 1:
            expression = factors[0]
        else:
            # Multiplicity of each factor; keys follow the sorted order.
            counts: Dict[str, int] = {}
            for factor in factors:
                counts[factor] = counts.get(factor, 0) + 1
            if len(counts) == 1:
                # A single base repeated: the exponent is at least 2 here.
                (term, exponent), = counts.items()
                if len(split_expression(term)) > 1:
                    body = f'({term})**{exponent}'
                else:
                    body = f'{term}**{exponent}'
            else:
                pieces = []
                for key, value in counts.items():
                    if len(split_expression(key)) > 1:
                        # A sum must be parenthesized before it is raised.
                        pieces.append(f'(({key})**{value})')
                    elif value != 1:
                        pieces.append(f'({key}**{value})')
                    else:
                        pieces.append(f'{key}')
                body = '*'.join(pieces)
            expression = f'I({body})'
        if not parse_term(expression, data):
            raise ValueError(f'Expression: {expression} could not be parsed as a term')
        outp.append(expression)
    return outp
# Keyword -> polynomial degree; each keyword becomes `special_power` with its
# `power` argument pre-bound. 'sextic' and 'hexic' are synonyms (degree 6).
special_power_funcs = {
    keyword: partial(special_power, power=degree)
    for keyword, degree in (
        ('linear', 1),
        ('quadratic', 2),
        ('cubic', 3),
        ('quartic', 4),
        ('quintic', 5),
        ('sextic', 6),
        ('hexic', 6),
        ('septic', 7),
        ('octic', 8),
        ('nonic', 9),
        ('decic', 10),
        ('duodecic', 12),
        ('vigintic', 20),
    )
}
"""Registry mapping special polynomial keywords to generator callables."""
def check_special_power_funcs(expression: str, data: pd.DataFrame) -> Optional[list[str]]:
    """Detect and expand special polynomial helpers like 'quadratic(...)'.

    The expression must be a registered keyword immediately followed by one
    parenthesized, comma-separated argument list, e.g. "cubic(a,b)".

    Args:
        expression (str): Candidate expression.
        data (pd.DataFrame): Data to mutate with generated terms.
    Returns:
        Optional[list[str]]: Generated term list if matched; otherwise None
        (including a bare keyword or an empty argument list).
    """
    for key, func in special_power_funcs.items():
        if not expression.startswith(key):
            continue
        remainder = expression[len(key):]
        # A bare keyword (or a longer identifier that merely starts with one)
        # is not a generator call; checking here also keeps an empty string
        # away from the parenthesis matcher.
        if not remainder.startswith('('):
            return None
        inner = match_parens(remainder)
        if not inner:
            return None
        return func(terms=split_expression(inner, delimiter=','), data=data)
    return None