Source code for nsbi_common_utils.inference

from __future__ import annotations

import numpy as np
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)

from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from iminuit import Minuit

[docs] def plot_NLL_scans(parameter_label: str, list_scan_points: list[list[float]], list_nll_values: list[list[float]], list_labels: list[str], list_linestyles: list[str], list_colors: list[str], ax: plt.Axes | None = None): """ Plot one or more NLL profile scan curves on a single axes. Draws each scan as a line on a shared axes, adds horizontal reference lines at :math:`\\Delta NLL = 1, 4, 9` corresponding to the :math:`1\\sigma`, :math:`2\\sigma`, and :math:`3\\sigma` confidence intervals, and annotates them accordingly. Parameters ---------- parameter_label : str LaTeX-formatted label for the x-axis, e.g. ``r"$\\mu$"``. If an empty string is passed the raw ``parameter_name`` variable is used as a fallback (note: ``parameter_name`` must be defined in the calling scope in that case). list_scan_points : list of list of float Scan point coordinates for each curve. Each inner list must have the same length as the corresponding entry in ``list_nll_values``. Typically the first return value of :meth:`inference.perform_profile_scan`. list_nll_values : list of list of float :math:`\\Delta NLL` values for each curve, evaluated at the corresponding scan points. Values should already be minimum-subtracted (i.e. the minimum of each curve sits at 0). list_labels : list of str Legend labels for each curve, e.g. ``["Stat + Syst", "Stat Only"]``. list_linestyles : list of str Matplotlib linestyle strings for each curve, e.g. ``["solid", "dashed"]``. list_colors : list of str Matplotlib colour strings for each curve, e.g. ``["black", "red"]``. ax : matplotlib.axes.Axes or None, optional Axes object to draw on. If ``None`` (default), a new figure and axes are created internally. Pass an existing axes to embed the plot in a larger figure layout. Notes ----- * All lists (``list_scan_points``, ``list_nll_values``, ``list_labels``, ``list_linestyles``, ``list_colors``) must have the same length; no length validation is performed. * The y-axis lower limit is fixed at ``0.0``; there is no upper limit set, so matplotlib will auto-scale to the data. * Reference lines at :math:`\\Delta NLL = 1, 4, 9` assume that the profile likelihood ratio test statistic :math:`t_\\mu = -2\\Delta\\ln L` is used, so confidence intervals are valid under Wilks' theorem. * If ``ax`` is ``None`` the created figure is not returned; call ``plt.savefig`` or ``plt.show`` after this function if needed. Examples -------- .. code-block:: python scan_pts, nll_vals = fitter.perform_profile_scan("mu", (0.0, 3.0)) plot_NLL_scans( parameter_label=r"$\\mu$", list_scan_points=[scan_pts], list_nll_values=[nll_vals], list_labels=["Stat + Syst"], list_linestyles=["solid"], list_colors=["black"] ) plt.show() See Also -------- inference.perform_profile_scan : Produces the scan arrays consumed by this function. """ if ax is None: fig, ax = plt.subplots() for count in range(len(list_labels)): ax.plot( list_scan_points[count], list_nll_values[count], linestyle=list_linestyles[count], label=list_labels[count], color=list_colors[count]) ax.legend() ax.set_ylim(bottom=0.0) ax.set_xlabel(parameter_label or parameter_name) ax.set_ylabel(r"$t_\mu$") ax.axhline(y=1.0, color='gray', linestyle='dotted', alpha=0.5) ax.text(1.0, 1.02, r"$1\sigma$ ", transform=ax.get_yaxis_transform(), ha='right', va='bottom', color='gray', fontsize=9) ax.axhline(y=4.0, color='gray', linestyle='dotted', alpha=0.5) ax.text(1.0, 4.02, r"$2\sigma$ ", transform=ax.get_yaxis_transform(), ha='right', va='bottom', color='gray', fontsize=9) ax.axhline(y=9.0, color='gray', linestyle='dotted', alpha=0.5) ax.text(1.0, 9.02, r"$3\sigma$ ", transform=ax.get_yaxis_transform(), ha='right', va='bottom', color='gray', fontsize=9)
[docs] class inference: def __init__(self, model_nll, initial_values: list[float], list_parameters: list[str], num_unconstrained_params: int, model_grad=None): """ Initialise the inference engine around a callable NLL function. Parameters ---------- model_nll : callable A scalar-valued function representing the negative log-likelihood to minimise. The signature must be compatible with iminuit's expectation: either a single array argument (used here via ``Minuit(f, values, name=names)``) or explicit keyword arguments. The function must return a scalar NLL value. JAX-compiled functions are supported and recommended for performance. initial_values : list of float or jnp.ndarray, shape (n_params,) Starting values for all parameters passed to MIGRAD. The order must match ``list_parameters`` exactly. Typically obtained from :meth:`sbi_parametric_model.get_model_parameters`. list_parameters : list of str Names of all parameters in the model, in the same order as ``initial_values``. The parameter of interest (POI) is expected at index ``0``, followed by unconstrained norm factors, then constrained nuisance parameters. num_unconstrained_params : int Number of leading parameters (starting from index ``0``) that are treated as unconstrained (i.e. parameters of interest and free norm factors with no Gaussian penalty). Parameters from index ``num_unconstrained_params`` onwards are treated as constrained nuisance parameters and will be fixed when constructing a stat-only NLL curve in :meth:`perform_profile_scan`. model_grad : callable or None, optional A function that returns the gradient of the NLL with respect to all parameters. Signature: ``model_grad(param_array) -> ndarray``. If provided, iminuit uses analytical gradients instead of finite-difference approximations, reducing the number of NLL evaluations by a factor of ~(n_params + 1). Notes ----- * ``pulls_global_fit`` is initialised to ``None`` and populated only after :meth:`perform_fit` is called. Methods that depend on global-fit values (e.g. ``doStatOnly=True`` in :meth:`perform_profile_scan`) will raise a ``RuntimeError`` if called before :meth:`perform_fit`. See Also -------- perform_fit : Run the global MIGRAD minimisation. perform_profile_scan : Compute a profiled NLL scan over one parameter. """ self.model_nll = model_nll self.initial_values = initial_values self.list_parameters = list_parameters self.num_unconstrained_params = num_unconstrained_params self.model_grad = model_grad self.pulls_global_fit = None
[docs] def perform_fit(self, fit_strategy=2, freeze_params=[]): """ Run MIGRAD and store best-fit parameter values. Parameters ---------- fit_strategy : int Minuit strategy (0 = fast, 1 = default, 2 = robust). freeze_params : list[str] | None List of parameter names to fix during the global fit. Notes ----- After a successful fit, ``self.pulls_global_fit`` is set to a :class:`numpy.ndarray` of best-fit values. This is required before calling :meth:`perform_profile_scan` with ``doStatOnly=True``. """ # Instantiate the iminuit object m = Minuit(self.model_nll, self.initial_values, grad=self.model_grad, name=tuple(self.list_parameters)) m.errordef = Minuit.LEAST_SQUARES strategy = fit_strategy # Freeze parameters in freeze_params list to initial values if len(freeze_params)>=1: for param in freeze_params: m.fixed[param] = True m.strategy = strategy # Run the fit with MIGRAD mg = m.migrad() # Store best-fit values in parameter order self.pulls_global_fit = np.array(m.values) # Displays results of the global fit print(f'fit: \n {mg}')
[docs] def perform_profile_scan(self, parameter_name: str = '', bound_range: tuple[float] = (0.0, 3.0), fit_strategy: int = 2, freeze_params: list[str] =[], doStatOnly: bool = False, isConstrainedNP: bool = False, size: int = 100) -> tuple[list[float]]: """ Profile the NLL along `parameter_name` and plot the scan. Parameters ---------- parameter_name : str Name of the parameter to scan (must be in `list_parameters`). bound_range : (float, float) Scan bounds for profile fit. fit_strategy : int Minuit strategy for the profile scans. freeze_params : list[str] | None Parameters to fix for both scans (Stat+Syst and StatOnly). doStatOnly : bool If True, also produce a "Stat Only" curve by fixing nuisance params (those after `num_unconstrained_params`) at their global-fit values. isConstrainedNP : bool If True, change the y-axis label to t_alpha; else use t_mu. size : int Number of scan points. Returns ------- scan_points : array-like Parameter values at which the NLL was evaluated. NLL_value : array-like :math:`\\Delta NLL` values (minimum-subtracted). scan_points_StatOnly : array-like, optional Returned only when ``doStatOnly=True``. Scan points for the stat-only curve. NLL_value_StatOnly : array-like, optional Returned only when ``doStatOnly=True``. :math:`\\Delta NLL` for the stat-only curve. Raises ------ RuntimeError If ``doStatOnly=True`` but :meth:`perform_fit` has not been called yet. """ m = Minuit(self.model_nll, self.initial_values, grad=self.model_grad, name=tuple(self.list_parameters)) m.errordef = Minuit.LEAST_SQUARES m.strategy = fit_strategy for param in freeze_params: m.fixed[param] = True # Profile fit: subtract_min=True returns \Delta NLL scan_points, NLL_value, _ = m.mnprofile(parameter_name, bound=bound_range, subtract_min=True, size = size) # Optionally plot a stat-only NLL curve if doStatOnly: if self.pulls_global_fit is None: raise RuntimeError( "perform_fit() must be called before doStatOnly=True, " "so nuisance parameters can be fixed at their global-fit values." ) # Re-initialize with global fit pulls so that fixed params are at the best-fit point m_StatOnly = Minuit(self.model_nll, self.pulls_global_fit, grad=self.model_grad, name=tuple(self.list_parameters)) m_StatOnly.errordef = Minuit.LEAST_SQUARES m_StatOnly.strategy = fit_strategy # Fix the same globally-frozen params for param in freeze_params: m_StatOnly.fixed[param] = True # Additionally fix nuisance parameters (constrained NPs) for param in self.list_parameters[self.num_unconstrained_params:]: m_StatOnly.fixed[param] = True scan_points_StatOnly, NLL_value_StatOnly, _ = m_StatOnly.mnprofile(parameter_name, bound=bound_range, subtract_min=True, size = size) return scan_points, NLL_value, scan_points_StatOnly, NLL_value_StatOnly else: return scan_points, NLL_value