# Source code for root_mcp.extended.analysis.plotting

"""Plotting module for ROOT-MCP."""

from __future__ import annotations

import base64
import io
import logging
from typing import Any

import matplotlib.pyplot as plt
import numpy as np

# Module-level logger; used to report plotting failures.
logger = logging.getLogger(__name__)

# Use matplotlib's non-interactive "Agg" backend so figures can be rendered
# straight into in-memory PNG buffers without a display. This runs once at
# import time; pyplot's backend is process-global state, so it also applies
# to any other pyplot code running in the same process.
plt.switch_backend("Agg")


def generate_plot(
    data: dict[str, Any],
    plot_type: str = "histogram",
    fit_data: dict[str, Any] | None = None,
    options: dict[str, Any] | None = None,
    config: Any | None = None,
) -> dict[str, str]:
    """
    Generate a plot from analysis data and return it as a base64-encoded PNG.

    Args:
        data: Analysis result (histogram data). For ``"histogram"`` an optional
            ``"metadata"`` dict may supply ``"branch"``, which is used for the
            default x label and title.
        plot_type: Type of plot: ``"histogram"`` or ``"histogram_2d"``.
        fit_data: Optional fit results to overlay (1D histograms only).
        options: Plotting options (``title``, ``xlabel``, ``ylabel``, ``unit``,
            ``log_x``, ``log_y``, ``grid``, ...). Not modified.
        config: Configuration object. When it exposes
            ``config.analysis.plotting``, the figure size, grid settings and
            DPI are read from it; otherwise built-in defaults are used.

    Returns:
        ``{"image_type": "png", "image_data": <base64-encoded PNG str>}``.

    Raises:
        RuntimeError: If drawing or rendering fails, including for an
            unsupported ``plot_type``. The underlying exception is chained
            as ``__cause__``.
    """
    if options is None:
        options = {}

    # Resolve the plotting section of the config once; None means "use defaults".
    plot_cfg = None
    if config and hasattr(config, "analysis") and hasattr(config.analysis, "plotting"):
        plot_cfg = config.analysis.plotting

    if plot_cfg is not None:
        figsize = (plot_cfg.figure_width, plot_cfg.figure_height)
    else:
        figsize = (10, 6)

    fig, ax = plt.subplots(figsize=figsize)
    try:
        if plot_type == "histogram":
            plot_metadata = _plot_histogram(ax, data, fit_data, options, config)

            # "metadata" may be absent or explicitly None.
            branch_name = (data.get("metadata") or {}).get("branch", "Value")
            unit = options.get("unit", "")

            xlabel = options.get("xlabel", branch_name)
            if unit:
                xlabel += f" [{unit}]"
            ax.set_xlabel(xlabel)

            ylabel = options.get("ylabel")
            if not ylabel:
                # Auto-generate "Entries / <bin width> [unit]"; fall back to
                # plain "Entries" when the helper reports no usable bin width.
                bin_width = plot_metadata.get("bin_width")
                if bin_width:
                    width_str = f"{bin_width:.3g}"  # e.g. "0.5" or "10"
                    if unit:
                        ylabel = f"Entries / {width_str} {unit}"
                    else:
                        ylabel = f"Entries / {width_str}"
                else:
                    ylabel = "Entries"
            ax.set_ylabel(ylabel)
            ax.set_title(options.get("title", f"{branch_name} Distribution"))

            if options.get("log_y"):
                ax.set_yscale("log")
            if options.get("log_x"):
                ax.set_xscale("log")

            if plot_cfg is not None:
                grid_alpha = plot_cfg.grid_alpha
                grid_enabled = plot_cfg.grid_enabled
            else:
                grid_alpha = 0.3
                grid_enabled = True
            if options.get("grid", grid_enabled):
                # Minor grid lines are only drawn when the y axis is log-scaled.
                ax.grid(True, alpha=grid_alpha, which="both" if options.get("log_y") else "major")
            ax.legend()
        elif plot_type == "histogram_2d":
            _plot_histogram_2d(fig, ax, data, options, config)
        else:
            raise ValueError(f"Unsupported plot type: {plot_type}")

        dpi = plot_cfg.dpi if plot_cfg is not None else 100

        buf = io.BytesIO()
        fig.tight_layout()
        fig.savefig(buf, format="png", dpi=dpi)
        img_str = base64.b64encode(buf.getvalue()).decode("utf-8")
        return {"image_type": "png", "image_data": img_str}
    except Exception as e:
        logger.exception("Plotting failed: %s", e)
        raise RuntimeError(f"Plotting failed: {e}") from e
    finally:
        # Always release the figure: pyplot keeps every open figure in a
        # global registry, so skipping this on any path leaks memory.
        plt.close(fig)
def _plot_histogram(
    ax: plt.Axes,
    data: dict[str, Any],
    fit_data: dict[str, Any] | None,
    options: dict[str, Any],
    config: Any | None = None,
) -> dict[str, Any]:
    """
    Draw a 1D histogram (error-bar points plus a filled step outline) on *ax*.

    Args:
        ax: Axes to draw on.
        data: Either a full histogram result ``{"data": {...}, "metadata": {...}}``
            or the inner data dict itself, holding ``"bin_edges"``,
            ``"bin_counts"`` and optionally ``"bin_errors"`` (defaults to
            ``sqrt(counts)``).
        fit_data: Optional fit result. If it carries ``"fitted_values"`` they
            are drawn as a line at the bin centres, labelled with ``"model"``.
        options: Plot options; only ``"color"`` (data-point colour) is read here.
        config: Configuration object; ``config.analysis.plotting`` supplies
            colours, marker and line styles when present.

    Returns:
        ``{"bin_width": w}``, where ``w`` is the common bin width as a float,
        or ``None`` when the bins are not uniform (or there are none), because
        then no single width describes the binning.
    """
    # Accept both a full histogram result ({"data": {...}, "metadata": {...}})
    # and the bare data dict ({"bin_edges": [...], "bin_counts": [...]}).
    if "data" in data and "bin_edges" not in data:
        hist_data = data["data"]
    else:
        hist_data = data

    edges = np.asarray(hist_data["bin_edges"])
    counts = np.asarray(hist_data["bin_counts"])
    if "bin_errors" in hist_data:
        errors = np.asarray(hist_data["bin_errors"])
    else:
        # Default to Poisson (sqrt(N)) uncertainties.
        errors = np.sqrt(counts)

    centers = (edges[:-1] + edges[1:]) / 2
    widths = np.diff(edges)
    # Report a width only when every bin shares it. A single number would
    # mislabel variable-width binnings in the caller's "Entries / <w>" label.
    if widths.size and np.allclose(widths, widths[0], rtol=1e-6, atol=0.0):
        bin_width = float(widths[0])
    else:
        bin_width = None

    if config and hasattr(config, "analysis") and hasattr(config.analysis, "plotting"):
        plot_cfg = config.analysis.plotting
        data_color = plot_cfg.data_color
        marker_size = plot_cfg.marker_size
        marker_style = plot_cfg.marker_style
        cap_size = plot_cfg.error_bar_cap_size
        hist_alpha = plot_cfg.hist_fill_alpha
        hist_color = plot_cfg.hist_fill_color
        line_width = plot_cfg.line_width
        fit_color = plot_cfg.fit_line_color
        fit_style = plot_cfg.fit_line_style
    else:
        data_color = "black"
        marker_size = 4.0
        marker_style = "o"
        cap_size = 2.0
        hist_alpha = 0.2
        hist_color = "blue"
        line_width = 2.0
        fit_color = "red"
        fit_style = "-"

    # Data points with error bars.
    ax.errorbar(
        centers,
        counts,
        yerr=errors,
        fmt=marker_style,
        color=options.get("color", data_color),
        label="Data",
        markersize=marker_size,
        capsize=cap_size,
    )

    # Filled step outline of the histogram.
    ax.stairs(counts, edges, fill=True, alpha=hist_alpha, color=hist_color, label="Hist")

    if fit_data:
        fitted_values = fit_data.get("fitted_values")
        # Explicit None/length test: a NumPy array has no unambiguous truth value.
        if fitted_values is not None and len(fitted_values) > 0:
            model = fit_data.get("model")
            # Fitted values are assumed to be evaluated at the bin centres.
            ax.plot(
                centers,
                fitted_values,
                fit_style,
                linewidth=line_width,
                color=fit_color,
                label=f"Fit ({model})" if model else "Fit",
            )

    return {"bin_width": bin_width}


def _first_present(mapping: dict[str, Any], *keys: str) -> Any:
    """
    Return the value of the first key in *keys* that is present and not None.

    Raises:
        KeyError: If none of *keys* maps to a non-None value.
    """
    for key in keys:
        value = mapping.get(key)
        if value is not None:
            return value
    raise KeyError(f"none of {keys!r} found in histogram data")


def _plot_histogram_2d(
    fig: plt.Figure,
    ax: plt.Axes,
    data: dict[str, Any],
    options: dict[str, Any],
    config: Any | None = None,
) -> None:
    """
    Draw a 2D histogram on *ax* as a colour mesh, with a colour bar on *fig*.

    Args:
        fig: Figure that receives the colour bar.
        ax: Axes to draw on.
        data: Full result ``{"data": {...}}`` or the inner dict. Edges and
            counts are read from ``bin_edges_x``/``bin_edges_y``/``bin_counts``
            or, failing that, ``x_edges``/``y_edges``/``counts``.
        options: ``colormap`` (default ``"viridis"``), ``log_z``, ``title``,
            ``xlabel``, ``ylabel``.
        config: Accepted for signature parity with the 1D helper; not read here.

    Raises:
        KeyError: If edges or counts are missing under both naming conventions.
    """
    # Unwrap {"data": {...}} unless the dict already holds the histogram itself.
    if "data" in data and "bin_edges_x" not in data and "x_edges" not in data:
        hist_data = data["data"]
    else:
        hist_data = data

    # HistogramOperations uses: bin_edges_x, bin_edges_y, bin_counts
    # AnalysisOperations uses: x_edges, y_edges, counts
    edges_x = np.asarray(_first_present(hist_data, "bin_edges_x", "x_edges"))
    edges_y = np.asarray(_first_present(hist_data, "bin_edges_y", "y_edges"))
    counts = np.asarray(_first_present(hist_data, "bin_counts", "counts"))

    colormap = options.get("colormap", "viridis")
    log_z = options.get("log_z", False)
    title = options.get("title", "2D Histogram")
    xlabel = options.get("xlabel", "X")
    ylabel = options.get("ylabel", "Y")

    # pcolormesh wants C indexed [y, x]; counts are presumably indexed [x, y]
    # (numpy.histogram2d convention), hence the transpose.
    plot_counts = counts.T
    if log_z:
        if np.any(plot_counts > 0):
            # np.ma.log10 masks non-positive bins (no divide-by-zero warnings).
            # Masked empty bins are left undrawn instead of being painted as 0,
            # which made them indistinguishable from single-entry bins.
            plot_counts = np.ma.log10(plot_counts)
        else:
            # Nothing positive to log; keep an all-zero map so the colour bar
            # still has finite limits.
            plot_counts = np.zeros(plot_counts.shape)

    im = ax.pcolormesh(edges_x, edges_y, plot_counts, cmap=colormap, shading="auto")

    cbar = fig.colorbar(im, ax=ax)
    if log_z:
        cbar.set_label("log10(Entries)")
    else:
        cbar.set_label("Entries")

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)