"""Readers for TTree and histogram data."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
import awkward as ak
import numpy as np
if TYPE_CHECKING:
from root_mcp.config import Config
from root_mcp.core.io.file_manager import FileManager
logger = logging.getLogger(__name__)
[docs]
class TreeReader:
    """
    High-level interface for reading TTree data.

    Provides safe, efficient access to TTree branches with chunking,
    filtering, and pagination.

    Attributes:
        MAX_SAMPLE_SIZE: Hard upper bound on the number of entries that
            ``sample_tree`` returns, regardless of the requested size.
    """

    MAX_SAMPLE_SIZE = 10_000

    def __init__(self, config: Config, file_manager: FileManager):
        """
        Initialize TreeReader.

        Args:
            config: Server configuration
            file_manager: File manager instance
        """
        self.config = config
        self.file_manager = file_manager

    def read_branches(
        self,
        path: str,
        tree_name: str,
        branches: list[str],
        selection: str | None = None,
        limit: int | None = None,
        offset: int = 0,
        flatten: bool = False,
        defines: dict[str, str] | None = None,
    ) -> dict[str, Any]:
        """
        Read branch data from a TTree.

        Args:
            path: File path
            tree_name: Tree name
            branches: List of branch names to read (can include derived branches from defines)
            selection: Optional ROOT-style cut expression
            limit: Maximum number of entries to scan; defaults to the configured
                read limit and is always capped by the configured per-call maximum
            offset: Number of entries to skip; an offset past the end of the
                tree yields an empty result
            flatten: Flatten jagged arrays
            defines: Optional derived variable definitions {name: expression}

        Returns:
            Dictionary with data and metadata

        Raises:
            ValueError: If ``offset`` or ``limit`` is negative, or the selection
                expression cannot be evaluated.
            KeyError: If a requested branch is neither in the tree nor in ``defines``.
        """
        # Negative values would otherwise be forwarded to the reader as entry
        # bounds and make the reported "entries_scanned" meaningless.
        if offset < 0:
            raise ValueError(f"offset must be non-negative, got {offset}")
        if limit is not None and limit < 0:
            raise ValueError(f"limit must be non-negative, got {limit}")

        tree = self.file_manager.get_tree(path, tree_name)
        available_branches = set(tree.keys())
        defined_branches = set(defines) if defines else set()

        # Split the request into branches stored in the tree and derived ones.
        physical_branches_requested = []
        derived_branches_requested = []
        for branch in branches:
            if branch in available_branches:
                physical_branches_requested.append(branch)
            elif branch in defined_branches:
                derived_branches_requested.append(branch)
            else:
                # Sorted so the suggestion and the preview are deterministic.
                known = sorted(available_branches)
                similar = self._find_similar_branches(branch, known)
                suggestion = f"Did you mean: {similar[:3]}?" if similar else ""
                raise KeyError(
                    f"Branch '{branch}' not found in tree '{tree_name}' or defines. "
                    f"Available: {known[:10]}... {suggestion}"
                )

        # Physical branches to load: the requested ones plus everything that
        # the define expressions and the selection refer to.
        branches_to_read = set(physical_branches_requested)
        expressions = list(defines.values()) if defines else []
        if selection:
            expressions.append(selection)
        if expressions:
            from root_mcp.extended.analysis.operations import _extract_branches_from_expression

            candidates = list(available_branches)
            for expression in expressions:
                branches_to_read.update(_extract_branches_from_expression(expression, candidates))

        # Apply limit bounds
        if limit is None:
            limit = self.config.analysis.default_read_limit
        limit = min(limit, self.config.limits.max_rows_per_call)

        # Clamp the window to the tree so entry_stop can never precede entry_start.
        total_entries = tree.num_entries
        entry_start = min(offset, total_entries)
        entry_stop = min(entry_start + limit, total_entries)

        logger.info(
            "Reading %s physical branches from %s (entries %s:%s/%s), with %s derived branches",
            len(branches_to_read),
            tree_name,
            entry_start,
            entry_stop,
            total_entries,
            len(derived_branches_requested),
        )

        try:
            arrays = tree.arrays(
                filter_name=list(branches_to_read),
                cut=None,  # Selection is applied after derived branches exist
                entry_start=entry_start,
                entry_stop=entry_stop,
                library="ak",
            )
        except Exception as e:
            logger.error("Failed to read branches: %s", e)
            raise

        # Compute derived branches
        if defines:
            from root_mcp.extended.analysis.operations import AnalysisOperations

            analysis_ops = AnalysisOperations(self.config, self.file_manager)
            arrays = analysis_ops._process_defines(arrays, defines)

        # Apply the selection once derived branches are available
        if selection:
            # Imported outside the try so that an import problem is not
            # reported as an invalid selection expression.
            from root_mcp.extended.analysis.operations import _evaluate_selection_any

            try:
                mask = _evaluate_selection_any(arrays, selection)
                arrays = arrays[mask]
            except Exception as e:
                logger.error("Failed to apply selection: %s", e)
                raise ValueError(
                    f"Invalid selection expression: {selection}. "
                    "Use ROOT-style syntax (e.g., 'pt > 20 && abs(eta) < 2.4')"
                ) from e

        # Number of entries surviving the selection
        entries_returned = len(arrays)

        # Keep only the requested branches (drop selection/define inputs)
        try:
            arrays = arrays[branches]
        except Exception as e:
            logger.error("Failed to select requested branches: %s", e)
            raise

        if flatten:
            # NOTE(review): axis=None collapses the whole array, so after this
            # step the array exposes no fields - confirm this is the intended
            # output shape for callers.
            arrays = ak.flatten(arrays, axis=None)

        records = self._arrays_to_records(arrays)
        is_jagged = self._check_if_jagged(arrays)
        entries_scanned = entry_stop - entry_start

        return {
            "data": {
                "branches": branches,
                "entries": entries_returned,
                "is_jagged": is_jagged,
                "records": records,
            },
            "metadata": {
                "operation": "read_branches",
                "entries_scanned": entries_scanned,
                "entries_selected": entries_returned,
                "entries_returned": entries_returned,
                "truncated": entry_stop < total_entries or entries_returned < entries_scanned,
                "selection": selection,
                "defines": list(defines.keys()) if defines else None,
            },
        }

    def sample_tree(
        self,
        path: str,
        tree_name: str,
        size: int = 100,
        method: str = "first",
        branches: list[str] | None = None,
        seed: int | None = None,
    ) -> dict[str, Any]:
        """
        Get a sample from a tree.

        Args:
            path: File path
            tree_name: Tree name
            size: Sample size (capped at ``MAX_SAMPLE_SIZE``)
            method: "first" or "random"
            branches: Branches to include (None = all)
            seed: Random seed for reproducibility

        Returns:
            Sample data and metadata

        Raises:
            ValueError: If ``method`` is unknown or ``size`` is negative.
        """
        if method not in ("first", "random"):
            raise ValueError(f"Unknown sampling method: {method}. Use 'first' or 'random'")
        if size < 0:
            raise ValueError(f"size must be non-negative, got {size}")

        tree = self.file_manager.get_tree(path, tree_name)
        if branches is None:
            branches = list(tree.keys())
        size = min(size, self.MAX_SAMPLE_SIZE)

        if method == "first":
            return self.read_branches(
                path=path,
                tree_name=tree_name,
                branches=branches,
                limit=size,
                offset=0,
            )

        total_entries = tree.num_entries
        sample_size = min(size, total_entries)
        if sample_size:
            # A private RandomState draws the same sequence as seeding the
            # global generator would, without disturbing process-wide state.
            rng = np.random.RandomState(seed)
            indices = np.sort(rng.choice(total_entries, size=sample_size, replace=False))
            first, last = int(indices[0]), int(indices[-1]) + 1
        else:
            indices = np.empty(0, dtype=np.int64)
            first = last = 0

        # Read only the window that spans the sampled entries, then pick the
        # sampled rows relative to the start of that window.
        window = tree.arrays(
            filter_name=branches,
            entry_start=first,
            entry_stop=last,
            library="ak",
        )
        arrays = window[indices - first]

        return {
            "data": {
                "branches": branches,
                "entries": len(arrays),
                "is_jagged": self._check_if_jagged(arrays),
                "records": self._arrays_to_records(arrays),
            },
            "metadata": {
                "operation": "sample_tree",
                "method": method,
                "seed": seed,
            },
        }

    def get_branch_info(
        self, path: str, tree_name: str, pattern: str | None = None
    ) -> list[dict[str, Any]]:
        """
        Get information about branches in a tree.

        Args:
            path: File path
            tree_name: Tree name
            pattern: Optional glob pattern to filter branches

        Returns:
            List of branch info dictionaries with the keys ``name``, ``type``,
            ``title`` and ``is_jagged``
        """
        tree = self.file_manager.get_tree(path, tree_name)
        branches = []
        for name in tree.keys():
            if pattern and not self._matches_glob(name, pattern):
                continue
            branch = tree[name]
            typename = str(getattr(branch, "typename", "unknown"))
            branches.append(
                {
                    "name": name,
                    "type": typename,
                    "title": str(getattr(branch, "title", "")),
                    # Heuristic on the type name: array suffix or a vector type
                    "is_jagged": "[]" in typename or "vector" in typename.lower(),
                }
            )
        return branches

    def compute_branch_stats(
        self,
        path: str,
        tree_name: str,
        branches: list[str],
        selection: str | None = None,
    ) -> dict[str, dict[str, float]]:
        """
        Compute statistics for branches.

        Args:
            path: File path
            tree_name: Tree name
            branches: Branches to analyze
            selection: Optional cut expression

        Returns:
            Dictionary mapping branch names to statistics (``count``, ``mean``,
            ``std``, ``min``, ``max``, ``median``). A branch with no values left
            after the selection reports ``count`` 0 and NaN for the rest.
        """
        tree = self.file_manager.get_tree(path, tree_name)
        arrays = self._read_selected(tree, branches, selection)

        stats = {}
        for branch in branches:
            data = arrays[branch]
            if _is_list_like(data):
                # axis=None also handles more than one level of nesting
                data = ak.flatten(data, axis=None)
            stats[branch] = self._summarize(ak.to_numpy(data))
        return stats

    def stream_branches(
        self,
        path: str,
        tree_name: str,
        branches: list[str],
        chunk_size: int = 10_000,
        selection: str | None = None,
    ):
        """
        Stream branch data in chunks for large files.

        Args:
            path: File path
            tree_name: Tree name
            branches: Branches to read
            chunk_size: Number of entries per chunk (must be positive)
            selection: Optional cut expression

        Yields:
            Chunks of data as awkward arrays

        Raises:
            ValueError: If ``chunk_size`` is not positive (raised when the
                generator is first advanced).
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")

        tree = self.file_manager.get_tree(path, tree_name)
        total_entries = tree.num_entries
        logger.info(
            "Streaming %s branches from %s (%s entries, chunk_size=%s)",
            len(branches),
            tree_name,
            total_entries,
            chunk_size,
        )

        for entry_start in range(0, total_entries, chunk_size):
            entry_stop = min(entry_start + chunk_size, total_entries)
            try:
                arrays = self._read_selected(
                    tree,
                    branches,
                    selection,
                    entry_start=entry_start,
                    entry_stop=entry_stop,
                )
            except Exception as e:
                logger.error("Failed to read chunk %s:%s: %s", entry_start, entry_stop, e)
                raise
            # Yield outside the try so errors thrown in by the consumer are
            # not logged as read failures.
            yield arrays

    def _read_selected(
        self,
        tree: Any,
        branches: list[str],
        selection: str | None,
        **entry_range: int,
    ) -> ak.Array:
        """Read ``branches`` from ``tree`` with ``selection`` applied.

        For RNTuple objects the cut is not handed to the reader; the mask is
        evaluated here instead. Extra keyword arguments (``entry_start`` /
        ``entry_stop``) are forwarded to the reader unchanged.
        """
        manual_cut = self._is_rntuple(tree)
        arrays = tree.arrays(
            filter_name=branches,
            cut=None if manual_cut else selection,
            library="ak",
            **entry_range,
        )
        if not (selection and manual_cut):
            return arrays

        from root_mcp.extended.analysis.operations import (
            _evaluate_selection_any,
            _extract_branches_from_expression,
        )

        # The selection may use branches that were not requested; load them
        # for the same entries so the mask can be evaluated, while the
        # returned array keeps only the requested fields.
        needed = set(_extract_branches_from_expression(selection, list(tree.keys())))
        missing = sorted(needed.difference(arrays.fields))
        probe = arrays
        if missing:
            extra = tree.arrays(filter_name=missing, cut=None, library="ak", **entry_range)
            for name in missing:
                probe = ak.with_field(probe, extra[name], name)
        mask = _evaluate_selection_any(probe, selection)
        return arrays[mask]

    @staticmethod
    def _is_rntuple(tree: Any) -> bool:
        """Return True when the type name of ``tree`` mentions RNTuple."""
        return "RNTuple" in str(type(tree))

    @staticmethod
    def _summarize(values: np.ndarray) -> dict[str, float]:
        """Return count/mean/std/min/max/median for a flat numeric array."""
        if len(values) == 0:
            # Reductions such as min/max raise on empty input
            nan = float("nan")
            return {"count": 0, "mean": nan, "std": nan, "min": nan, "max": nan, "median": nan}
        return {
            "count": len(values),
            "mean": float(np.mean(values)),
            "std": float(np.std(values)),
            "min": float(np.min(values)),
            "max": float(np.max(values)),
            "median": float(np.median(values)),
        }

    @staticmethod
    def _arrays_to_records(arrays: ak.Array) -> list[dict[str, Any]]:
        """Convert awkward arrays to list of records."""
        records = ak.to_list(arrays)
        return records if isinstance(records, list) else []

    @staticmethod
    def _check_if_jagged(arrays: ak.Array) -> bool:
        """Check if arrays contain jagged (variable-length) data."""
        return any(_is_variable_length_list(arrays[field]) for field in arrays.fields)

    @staticmethod
    def _matches_glob(text: str, pattern: str) -> bool:
        """Check if text matches glob pattern (case-sensitive on every platform)."""
        import fnmatch

        # fnmatch.fnmatch normalizes case on some platforms; branch names must
        # be compared exactly.
        return fnmatch.fnmatchcase(text, pattern)

    @staticmethod
    def _find_similar_branches(target: str, available: list[str]) -> list[str]:
        """Find branches with similar names using simple heuristics."""
        from difflib import get_close_matches

        return get_close_matches(target, available, n=3, cutoff=0.6)
def _unwrap_awkward_layout(layout: Any) -> Any:
    """Strip wrapper nodes from a layout and return the first non-wrapper node.

    A node counts as a wrapper when its class name is ``IndexedArray`` or
    ``UnmaskedArray``, or ends in ``OptionArray`` or ``MaskedArray``, and it
    exposes a ``content`` attribute. Matching is done on class names only.
    """
    exact_names = ("IndexedArray", "UnmaskedArray")
    # Covers IndexedOptionArray, ByteMaskedArray and BitMaskedArray as well
    suffixes = ("OptionArray", "MaskedArray")

    node = layout
    while hasattr(node, "content"):
        kind = type(node).__name__
        if kind not in exact_names and not kind.endswith(suffixes):
            break
        node = node.content
    return node
def _is_list_like(array: ak.Array) -> bool:
    """Tell whether the unwrapped layout of ``array`` is a list node.

    Regular (fixed-size) and variable-length list layouts both count. Any
    failure while obtaining the layout is treated as "not list-like".
    """
    try:
        node = _unwrap_awkward_layout(ak.to_layout(array))
    except Exception:
        return False

    kind = type(node).__name__
    if kind in ("RegularArray", "ListArray"):
        return True
    # Substring match, so class names that merely contain it also qualify
    return "ListOffsetArray" in kind
def _is_variable_length_list(array: ak.Array) -> bool:
    """Tell whether the unwrapped layout of ``array`` is a variable-length list.

    ``RegularArray`` layouts are excluded. Any failure while obtaining the
    layout is treated as "not variable-length".
    """
    try:
        node = _unwrap_awkward_layout(ak.to_layout(array))
    except Exception:
        return False

    kind = type(node).__name__
    # "RegularArray" matches neither test below, so it needs no special case
    return kind == "ListArray" or "ListOffsetArray" in kind
[docs]
class HistogramReader:
    """
    High-level interface for reading histograms.

    Provides access to TH1, TH2, TH3, and TProfile objects.
    """

    def __init__(self, config: Config, file_manager: FileManager):
        """
        Initialize HistogramReader.

        Args:
            config: Server configuration
            file_manager: File manager instance
        """
        self.config = config
        self.file_manager = file_manager

    def read_histogram(self, path: str, hist_name: str) -> dict[str, Any]:
        """
        Read a histogram from a ROOT file.

        Args:
            path: File path
            hist_name: Histogram name or path

        Returns:
            Histogram data and metadata

        Raises:
            KeyError: If ``hist_name`` is not present in the file.
            ValueError: If the object's class is not a supported histogram type.
        """
        file_obj = self.file_manager.open(path)
        try:
            hist = file_obj[hist_name]
        except KeyError as e:
            available = [h["name"] for h in self.file_manager.list_histograms(path)]
            raise KeyError(f"Histogram '{hist_name}' not found. Available: {available}") from e

        classname = hist.classname

        # Higher dimensions are tested first: "TProfile" is a substring of
        # "TProfile2D" and "TProfile3D", which must not reach the 1D reader.
        if "TH3" in classname or "TProfile3D" in classname:
            return self._read_3d_histogram(hist, classname)
        if "TH2" in classname or "TProfile2D" in classname:
            return self._read_2d_histogram(hist, classname)
        if "TH1" in classname or "TProfile" in classname:
            return self._read_1d_histogram(hist, classname)
        raise ValueError(f"Unsupported histogram type: {classname}")

    def _read_1d_histogram(self, hist: Any, classname: str) -> dict[str, Any]:
        """Read 1D histogram.

        The ``entries`` field is the truncated sum of the in-range bin
        contents; ``underflow`` and ``overflow`` are the first and last values
        of the flow-inclusive contents.
        """
        values = hist.values()
        edges = hist.axis().edges()
        # Fall back to sqrt(contents) when the object has no errors() method
        errors = hist.errors() if hasattr(hist, "errors") else np.sqrt(values)
        values_with_flow = hist.values(flow=True)
        return {
            "type": classname,
            "bin_edges": edges.tolist(),
            "bin_counts": values.tolist(),
            "bin_errors": errors.tolist(),
            "entries": int(values.sum()),
            "underflow": float(values_with_flow[0]),
            "overflow": float(values_with_flow[-1]),
        }

    def _read_2d_histogram(self, hist: Any, classname: str) -> dict[str, Any]:
        """Read 2D histogram."""
        return self._read_nd_histogram(hist, classname, ("x", "y"))

    def _read_3d_histogram(self, hist: Any, classname: str) -> dict[str, Any]:
        """Read 3D histogram."""
        return self._read_nd_histogram(hist, classname, ("x", "y", "z"))

    @staticmethod
    def _read_nd_histogram(
        hist: Any, classname: str, axis_names: tuple[str, ...]
    ) -> dict[str, Any]:
        """Build the result for a multi-dimensional histogram.

        One ``<axis>_edges`` entry is emitted per name in ``axis_names``, taken
        from the axis at the same position, followed by ``counts`` and
        ``entries`` (the truncated sum of the bin contents).
        """
        values = hist.values()
        result: dict[str, Any] = {"type": classname}
        for index, axis_name in enumerate(axis_names):
            result[f"{axis_name}_edges"] = hist.axis(index).edges().tolist()
        result["counts"] = values.tolist()
        result["entries"] = int(values.sum())
        return result