# Source code for root_mcp.extended.analysis.expression

"""Expression evaluation utilities."""

from __future__ import annotations

import ast
import re
from typing import Any

import numpy as np


def strip_outer_parens(expr: str) -> str:
    """Strip redundant outer parentheses from an expression.

    Surrounding whitespace is removed first. A leading "(" and a trailing ")"
    are then peeled off repeatedly, but only while the leading "(" is closed
    by the very last character, i.e. the pair wraps the whole expression.
    "(a) + (b)" is therefore left untouched, while "((a + b))" becomes
    "a + b".

    Args:
        expr: Expression text, possibly wrapped in layers of parentheses.

    Returns:
        The whitespace-stripped expression without enclosing parentheses.
        If the leading "(" is never closed (unbalanced input), stripping
        stops and the text is returned as it is at that point, so malformed
        input is not silently rewritten into different malformed input.
    """
    expr = expr.strip()
    while expr.startswith("(") and expr.endswith(")"):
        depth = 0
        for index, char in enumerate(expr):
            if char == "(":
                depth += 1
            elif char == ")":
                depth -= 1
                if depth == 0:
                    # First position at which the leading "(" is closed.
                    closes_at = index
                    break
        else:
            # The leading "(" never closes: parentheses are unbalanced.
            return expr
        if closes_at != len(expr) - 1:
            # The first group ends early, e.g. "(a) + (b)": not a wrapper.
            return expr
        expr = expr[1:-1].strip()
    return expr
def translate_leaf_expr(expr: str) -> str:
    """Rewrite C++-flavoured operators and literals into Python spelling.

    The text is whitespace-stripped, then:

    * ``&&`` becomes ``&`` and ``||`` becomes ``|``;
    * every ``!`` that is not the start of ``!=`` becomes ``~``;
    * the whole words ``true`` / ``false`` (any letter case) become
      ``True`` / ``False``.

    Args:
        expr: Expression text using C++-style syntax.

    Returns:
        The translated expression text.
    """
    translated = expr.strip()

    # Plain substring swaps for the logical connectives.
    for cpp_token, py_token in (("&&", "&"), ("||", "|")):
        translated = translated.replace(cpp_token, py_token)

    # Negation: the lookahead keeps "!=" intact.
    translated = re.sub(r"!(?!=)", "~", translated)

    # Boolean literals, matched case-insensitively on word boundaries.
    literal_rules = (
        (r"\btrue\b", "True"),
        (r"\bfalse\b", "False"),
    )
    for pattern, replacement in literal_rules:
        translated = re.sub(pattern, replacement, translated, flags=re.IGNORECASE)

    return translated
class SafeExprEvaluator(ast.NodeVisitor):
    """Evaluate a limited math/boolean expression by walking its AST.

    Supported syntax is restricted to names, literal constants, the unary
    operators ``-``, ``+``, ``~``, the binary operators ``+ - * / ** % & |``,
    single (non-chained) comparisons and positional calls to a fixed set of
    NumPy functions. Any other node type raises ``ValueError``. ``eval()`` is
    never used.

    Example::

        tree = ast.parse("sqrt(x**2 + y**2) > 1", mode="eval")
        result = SafeExprEvaluator({"x": x, "y": y}).visit(tree)
    """

    # Operator node type -> implementation. Looked up after both operands
    # have been evaluated, so operand errors surface first.
    _BINARY_OPS = {
        ast.Add: lambda left, right: left + right,
        ast.Sub: lambda left, right: left - right,
        ast.Mult: lambda left, right: left * right,
        ast.Div: lambda left, right: left / right,
        ast.Pow: lambda left, right: left**right,
        ast.Mod: lambda left, right: left % right,
        ast.BitAnd: lambda left, right: left & right,
        ast.BitOr: lambda left, right: left | right,
    }

    _COMPARE_OPS = {
        ast.Lt: lambda left, right: left < right,
        ast.LtE: lambda left, right: left <= right,
        ast.Gt: lambda left, right: left > right,
        ast.GtE: lambda left, right: left >= right,
        ast.Eq: lambda left, right: left == right,
        ast.NotEq: lambda left, right: left != right,
    }

    def __init__(self, names: dict[str, Any]):
        """Initialize the evaluator.

        Args:
            names: Dictionary of allowed variable names and their values.
        """
        self.names = names
        # Whitelist of callable names; "min"/"max" are the element-wise
        # two-argument NumPy variants.
        self.funcs: dict[str, Any] = {
            "abs": np.abs,
            "sqrt": np.sqrt,
            "log": np.log,
            "exp": np.exp,
            "sin": np.sin,
            "cos": np.cos,
            "tan": np.tan,
            "min": np.minimum,
            "max": np.maximum,
            "sinh": np.sinh,
            "cosh": np.cosh,
            "tanh": np.tanh,
            "arcsin": np.arcsin,
            "arccos": np.arccos,
            "arctan": np.arctan,
            "arctan2": np.arctan2,
        }

    def visit_Expression(self, node: ast.Expression) -> Any:
        """Evaluate the body of a tree parsed with ``mode="eval"``."""
        return self.visit(node.body)

    def visit_Name(self, node: ast.Name) -> Any:
        """Resolve an identifier.

        Lookup order: user-supplied ``names``, then the literals
        ``True``/``False``, then the function whitelist.

        Raises:
            ValueError: If the identifier is not known.
        """
        if node.id in self.names:
            return self.names[node.id]
        if node.id in {"True", "False"}:
            return node.id == "True"
        if node.id in self.funcs:
            return self.funcs[node.id]
        raise ValueError(f"Unknown identifier in expression: {node.id}")

    def visit_Constant(self, node: ast.Constant) -> Any:
        """Return the literal value unchanged."""
        return node.value

    def visit_UnaryOp(self, node: ast.UnaryOp) -> Any:
        """Evaluate ``-x``, ``+x`` or ``~x``.

        Raises:
            ValueError: For any other unary operator (e.g. ``not``).
        """
        operand = self.visit(node.operand)
        if isinstance(node.op, ast.USub):
            return -operand
        if isinstance(node.op, ast.UAdd):
            return +operand
        if isinstance(node.op, ast.Invert):
            # On a plain Python bool, ``~`` is integer inversion
            # (~True == -2, which is still truthy), so a negated scalar
            # comparison would come out wrong. Negate logically instead.
            # NumPy booleans and arrays already invert correctly with ``~``.
            if isinstance(operand, bool):
                return not operand
            return ~operand
        raise ValueError("Unsupported unary operator")

    def visit_BinOp(self, node: ast.BinOp) -> Any:
        """Evaluate a binary arithmetic or bitwise operation.

        Raises:
            ValueError: If the operator is not in the supported set.
        """
        left = self.visit(node.left)
        right = self.visit(node.right)
        apply_op = self._BINARY_OPS.get(type(node.op))
        if apply_op is None:
            raise ValueError("Unsupported binary operator")
        return apply_op(left, right)

    def visit_Compare(self, node: ast.Compare) -> Any:
        """Evaluate a single comparison.

        Raises:
            ValueError: For chained comparisons (``a < b < c``) or an
                unsupported comparison operator (e.g. ``in``, ``is``).
        """
        if len(node.ops) != 1 or len(node.comparators) != 1:
            raise ValueError("Chained comparisons are not supported")
        left = self.visit(node.left)
        right = self.visit(node.comparators[0])
        apply_op = self._COMPARE_OPS.get(type(node.ops[0]))
        if apply_op is None:
            raise ValueError("Unsupported comparison operator")
        return apply_op(left, right)

    def visit_Call(self, node: ast.Call) -> Any:
        """Call a whitelisted function with positional arguments.

        Raises:
            ValueError: If the callee is not a plain name, is unknown, is
                not a whitelisted function, or keyword arguments are used.
        """
        if not isinstance(node.func, ast.Name):
            raise ValueError("Only simple function calls are allowed")
        func_name = node.func.id
        fn = self.visit_Name(node.func)
        # Compare by identity: an equality test against a NumPy array would
        # be evaluated element-wise and fail with an unrelated error.
        if not any(fn is allowed for allowed in self.funcs.values()):
            # A variable may shadow a function name (e.g. a column called
            # "abs"); in call position the whitelisted function is meant.
            fn = self.funcs.get(func_name)
            if fn is None:
                raise ValueError(f"Function '{func_name}' is not allowed")
        if node.keywords:
            raise ValueError("Keyword arguments are not supported")
        args = [self.visit(arg) for arg in node.args]
        return fn(*args)

    def generic_visit(self, node: ast.AST) -> Any:
        """Reject every node type that has no dedicated ``visit_*`` method."""
        raise ValueError(f"Unsupported syntax in expression: {type(node).__name__}")