"""Source code for nsbi_common_utils.models.sbi_parametric_model."""

import numpy as np
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax.tree_util import tree_map
from functools import partial
from typing import Dict, Union, Any, Optional

class sbi_parametric_model:
    """
    Statistical model for semi-parametric Simulation-Based Inference (SBI).

    Defines parameterized expected yields, density ratios, and the negative
    log-likelihood (NLL) passed to fitting algorithms. Supports both binned
    and unbinned channels with systematic uncertainties handled via
    polynomial interpolation / exponential extrapolation (HistFactory
    strategy 5).

    Two JIT-compiled entry points are built:

    * :meth:`model` — evaluates the negative log-likelihood function for fit.
    * :meth:`model_grad` — evaluates the negative log-likelihood gradient via
      JAX reverse-mode autodiff.

    Parameters
    ----------
    workspace : dict
        A workspace dictionary following the pyhf-like JSON schema. Must
        contain ``"measurements"`` and ``"channels"`` keys. Channels may be
        tagged with ``"type": "binned"`` or ``"type": "unbinned"``.
    measurement_to_fit : str
        Name of the measurement block inside ``workspace["measurements"]`` to
        use. Selects the parameter of interest (POI) and the list of
        parameters to fit.

    Raises
    ------
    ValueError
        If ``measurement_to_fit`` is not in the workspace, or if the POI is
        not a modifier of any sample.

    Attributes
    ----------
    list_parameters : list of str
        Ordered parameter names: POI first, then unconstrained norm factors,
        then constrained nuisance parameters.
    initial_parameter_values : jnp.ndarray
        Starting values for every parameter, in the same order as
        ``list_parameters``.
    num_unconstrained_param : int
        Number of leading parameters that are unconstrained (POI + free norm
        factors).
    expected_hist : jnp.ndarray
        Binned expected yields evaluated at the initial parameter values.

    See Also
    --------
    nsbi_common_utils.inference.inference : Fits and scans this model.
    """

    def __init__(self, workspace: Dict[Any, Any], measurement_to_fit: str):
        self.workspace = workspace
        self.measurements_dict: list[Dict[str, Any]] = workspace["measurements"]

        # Locate the requested measurement block. Fail early with a clear
        # message instead of an AttributeError on self.poi later on.
        selected_measurement = next(
            (m for m in self.measurements_dict if m.get("name") == measurement_to_fit),
            None,
        )
        if selected_measurement is None:
            available = [m.get("name") for m in self.measurements_dict]
            raise ValueError(
                f"Measurement {measurement_to_fit!r} not found in workspace; "
                f"available measurements: {available}"
            )
        self.measurement_name = selected_measurement.get("name")
        self.poi = selected_measurement["config"]["poi"]
        self.measurement_param_dict = selected_measurement["config"]["parameters"]

        # Parameter names of the *selected* measurement (not simply the first
        # measurement in the workspace). Used to filter norm factors.
        self.param_names = [p["name"] for p in self.measurement_param_dict]
        self.parameters_in_measurement, self.initial_values_dict = self._get_parameters_to_fit()

        self.channels_binned = self._get_channel_list(type_of_fit="binned")
        self.channels_unbinned = self._get_channel_list(type_of_fit="unbinned")
        self.all_channels = self.channels_binned + self.channels_unbinned
        self.all_samples = self._get_samples_list()

        # Norm factors are ordered before shape systematics; unknown types go last.
        sorting_order = {"normfactor": 0, "normplusshape": 1}
        (
            self.list_parameters,
            self.list_parameters_types,
            self.num_unconstrained_param,
        ) = self._get_parameters(sorting_order)

        self.list_syst_normplusshape = self._get_list_syst_for_interp()
        self.list_normfactors, self.norm_sample_map = self._get_norm_factors()
        self.has_normplusshape = len(self.list_syst_normplusshape) > 0

        self.initial_parameter_values = self._get_param_vec_initial()
        self.index_normparam_map = self._make_map_index_norm()

        self.yield_array_dict, _ = self._get_nominal_expected_arrays(type_of_fit="binned")
        self.unbinned_total_dict, self.ratios_array_dict = self._get_nominal_expected_arrays(
            type_of_fit="unbinned"
        )

        self.combined_var_up_binned, self.combined_var_dn_binned = self._get_systematic_data(
            type_of_fit="binned"
        )
        (
            self.combined_var_up_unbinned,
            self.combined_var_dn_unbinned,
            self.combined_tot_up_unbinned,
            self.combined_tot_dn_unbinned,
        ) = self._get_systematic_data(type_of_fit="unbinned")

        self.weight_arrays_unbinned = self._get_asimov_weights_array()
        self._finalize_to_device()

        # Asimov "observed" data: expectations at the initial parameter values.
        self.expected_hist = self._get_expected_hist(param_vec=self.initial_parameter_values)
        self.expected_rate_unbinned = self._get_expected_rate_unbinned(
            param_vec=self.initial_parameter_values
        )

        # Stack per-process dicts into arrays for vectorized NLL
        self._build_stacked_data()
        self._jit_nll, self._jit_val_and_grad = self._build_jit_functions()

    def get_model_parameters(self):
        """
        Return parameter names and initial values for fitting.

        The returned order matches the convention expected by
        :class:`~nsbi_common_utils.inference.inference`: POI at index 0,
        followed by unconstrained norm factors, then constrained nuisance
        parameters.

        Returns
        -------
        list_parameters : list of str
            Ordered parameter names.
        initial_parameter_values : jnp.ndarray, shape (n_params,)
            Starting values aligned with ``list_parameters``.
        """
        return self.list_parameters, self.initial_parameter_values

    def _get_expected_hist(self, param_vec):
        """
        Binned expected yields, summed over samples, at ``param_vec``.

        Called at construction time to build the Asimov binned data
        (``self.expected_hist``).
        """
        param_vec_interpolation = param_vec[self.num_unconstrained_param:]
        norm_modifiers = self._calculate_norm_variations(param_vec)
        hist_vars_binned = {}
        for process in self.all_samples:
            if self.has_normplusshape:
                hist_vars_binned[process] = _calculate_combined_var(
                    param_vec_interpolation,
                    self.combined_var_up_binned[process],
                    self.combined_var_dn_binned[process],
                )
            else:
                hist_vars_binned[process] = jnp.ones_like(self.yield_array_dict[process])
        return self._calculate_parameterized_yields(
            self.yield_array_dict, hist_vars_binned, norm_modifiers
        )

    def _get_expected_rate_unbinned(self, param_vec):
        """Compute the total expected rate in unbinned channels at ``param_vec``. Used as Asimov observed rate."""
        param_vec_interpolation = param_vec[self.num_unconstrained_param:]
        norm_modifiers = self._calculate_norm_variations(param_vec)
        hist_vars_unbinned = {}
        for process in self.all_samples:
            if self.has_normplusshape:
                hist_vars_unbinned[process] = _calculate_combined_var(
                    param_vec_interpolation,
                    self.combined_tot_up_unbinned[process],
                    self.combined_tot_dn_unbinned[process],
                )
            else:
                hist_vars_unbinned[process] = jnp.ones_like(self.unbinned_total_dict[process])
        return self._calculate_parameterized_yields(
            self.unbinned_total_dict, hist_vars_unbinned, norm_modifiers
        )

    def _make_map_index_norm(self):
        """Map each norm factor name to its index in the parameter vector."""
        return {
            normfactor: self.list_parameters.index(normfactor)
            for normfactor in self.list_normfactors
        }

    def _get_param_vec_initial(self):
        """Initial parameter vector aligned with ``self.list_parameters``."""
        initial_values_vec = np.array(
            [self.initial_values_dict[parameter] for parameter in self.list_parameters],
            dtype=float,
        )
        return jnp.asarray(initial_values_vec)

    def _get_norm_factors(self) -> tuple[list, Dict[str, list]]:
        """
        Collect the norm factors and the samples each one scales.

        Assumes the same norm factors in every channel, so only the first
        channel is inspected (TO-DO: add support for normfactor per channel).
        Norm factors that are not parameters of the selected measurement are
        dropped from the global list and from every per-sample list. They
        therefore act as fixed unit scales rather than causing a KeyError in
        ``self.index_normparam_map`` lookups.

        Returns
        -------
        list_all_norm_factors : list of str
            Norm factors in first-seen order.
        dict_sample_normfactors : dict of str -> list of str
            Sample name -> norm factors scaling it. Samples without any
            remaining norm factor are omitted.
        """
        allowed = set(self.param_names)
        dict_sample_normfactors = {sample_name: [] for sample_name in self.all_samples}
        list_all_norm_factors = []
        for channel in self.all_channels[:1]:
            channel_index = self._index_of_region(channel_name=channel)
            for sample in self.all_samples:
                sample_index = self._index_of_sample(channel_name=channel, sample_name=sample)
                modifier_list = self.workspace["channels"][channel_index]["samples"][sample_index]["modifiers"]
                for modifier in modifier_list:
                    if modifier["type"] != "normfactor":
                        continue
                    modifier_name = modifier["name"]
                    if modifier_name not in allowed:
                        continue
                    if modifier_name not in list_all_norm_factors:
                        list_all_norm_factors.append(modifier_name)
                    if modifier_name not in dict_sample_normfactors[sample]:
                        dict_sample_normfactors[sample].append(modifier_name)
        dict_sample_normfactors = {
            key: val for key, val in dict_sample_normfactors.items() if val
        }
        return list_all_norm_factors, dict_sample_normfactors

    def _get_parameters_to_fit(self) -> tuple[list[str], dict[str, float]]:
        """
        Return the parameters listed in the selected measurement, and their inits.

        Returns
        -------
        parameters_to_fit : list of str
            Parameter names in measurement order.
        initial_value_params : dict of str -> float
            First entry of each parameter's ``"inits"`` list.
        """
        parameters_to_fit = []
        initial_value_params = {}
        for parameters in self.measurement_param_dict:
            parameter_name = parameters["name"]
            parameters_to_fit.append(parameter_name)
            initial_value_params[parameter_name] = parameters["inits"][0]
        return parameters_to_fit, initial_value_params

    def _get_list_syst_for_interp(self):
        """Names of the ``normplusshape`` parameters (those needing interpolation), in parameter order."""
        return [
            name
            for name, param_type in zip(self.list_parameters, self.list_parameters_types)
            if param_type == "normplusshape"
        ]

    def _get_channel_list(self, type_of_fit: Union[str, None] = None) -> list:
        """
        Get the channel names to be used in the measurement.

        Parameters
        ----------
        type_of_fit : str or None
            If given, keep only channels whose ``"type"`` equals this value.
        """
        list_channels = []
        channels: list[Dict[str, Any]] = self.workspace["channels"]
        for channel_dict in channels:
            if type_of_fit is not None and channel_dict.get("type") != type_of_fit:
                continue
            list_channels.append(channel_dict.get("name"))
        return list_channels

    def _get_samples_list(self):
        """Get the sample names from the first channel of the workspace."""
        channels: list[Dict[str, Any]] = self.workspace["channels"]
        if not channels:
            return []
        return [sample_dict.get("name") for sample_dict in channels[0]["samples"]]

    def _get_asimov_weights_array(self):
        """
        Concatenate the per-event weight arrays of all unbinned channels.

        Each channel's ``"weights"`` entry is loaded with ``np.load``.
        """
        weight_array = np.array([])
        for channel in self.channels_unbinned:
            channel_index = self._index_of_region(channel)
            weights = np.load(self.workspace["channels"][channel_index]["weights"])
            weight_array = np.append(weight_array, weights)
        return weight_array

    def _get_parameters(self, sorting_order):
        """
        Collect, order and classify the parameters to fit.

        Parameters are the modifiers (across all channels and samples) that
        appear in the selected measurement. They are ordered by
        ``sorting_order`` rank, with unknown types ranked 999. The sort is
        stable, so ties keep their first-seen workspace order. The POI is
        then moved to index 0.

        Returns
        -------
        list_param_names : list of str
        list_param_types : list of str
        num_unconstrained_params : int
            Number of leading ``normfactor`` parameters (POI included when it
            is a norm factor).

        Raises
        ------
        ValueError
            If the POI is not a modifier of any sample.
        """
        wanted = set(self.parameters_in_measurement)
        list_param_names = []
        list_param_types = []
        channels: list[Dict[str, Any]] = self.workspace["channels"]
        for channel_dict in channels:
            for sample_dict in channel_dict["samples"]:
                for modifier in sample_dict["modifiers"]:
                    modifier_name = modifier.get("name")
                    if modifier_name not in wanted or modifier_name in list_param_names:
                        continue
                    list_param_names.append(modifier_name)
                    list_param_types.append(modifier.get("type"))

        # Python's sorted() is stable, unlike np.argsort's default quicksort.
        order = sorted(
            range(len(list_param_types)),
            key=lambda i: sorting_order.get(list_param_types[i], 999),
        )
        list_param_names = [list_param_names[i] for i in order]
        list_param_types = [list_param_types[i] for i in order]

        if self.poi not in list_param_names:
            raise ValueError(
                f"POI {self.poi!r} is not a modifier of any sample in the workspace"
            )
        index_poi = list_param_names.index(self.poi)
        if index_poi != 0:
            list_param_names.insert(0, list_param_names.pop(index_poi))
            list_param_types.insert(0, list_param_types.pop(index_poi))

        num_unconstrained_params = 0
        for param_type in list_param_types:
            if param_type != "normfactor":
                break
            num_unconstrained_params += 1
        return list_param_names, list_param_types, num_unconstrained_params

    def _calculate_parameterized_yields(self, hist_yields, hist_vars, norm_modifiers):
        """Sum over samples of ``norm_modifier * nominal_yield * systematic_variation``."""
        nu_tot = 0.0
        for process in self.all_samples:
            # This will not work in the general case where model is non-linear in POI, needs modifications (TO-DO)
            nu_tot += norm_modifiers[process] * hist_yields[process] * hist_vars[process]
        return nu_tot

    def _calculate_parameterized_ratios(self, nu_nominal, nu_vars, ratios, ratio_vars, norm_modifiers):
        """
        Log of the parameterized per-event density, summed over samples.

        Not called by the JIT-compiled NLL in this class, which computes the
        same quantity in vectorized form.
        """
        dnu_dx = jnp.zeros_like(self.weight_arrays_unbinned)
        # To-do: Generalize to any dataset, not just nominal
        for process in self.all_samples:
            dnu_dx += (
                norm_modifiers[process]
                * nu_vars[process]
                * nu_nominal[process]
                * ratios[process]
                * ratio_vars[process]
            )
        return jnp.log(dnu_dx)

    def model(self, param_array: Union[np.ndarray, jnp.ndarray, list[float]]):
        """
        High-level API that returns the full negative log-likelihood for a parameter point.

        Computes the combined NLL in all channels defined by the input
        workspace - unbinned SBI and binned Control and Signal regions. The
        value is scaled as -2 ln L, up to parameter-independent constants:
        Poisson terms, per-event log density ratios, and unit-Gaussian
        constraint terms ``alpha**2``. This callable is the function to be
        passed to :class:`~nsbi_common_utils.inference.inference` as
        ``model_nll``.

        Parameters
        ----------
        param_array : array-like, shape (n_params,)
            Parameter values in the order defined by :meth:`get_model_parameters`.

        Returns
        -------
        nll : jnp.ndarray, scalar
            The negative log-likelihood value (scalar).
        """
        param_array = jnp.asarray(param_array)
        return self._jit_nll(param_array, self._model_data)

    def model_grad(self, param_array: Union[np.ndarray, jnp.ndarray, list[float]]):
        """
        Return the gradient of the NLL with respect to all parameters.

        Uses JAX reverse-mode autodiff (JIT-compiled). Suitable as the
        ``grad`` argument to :class:`iminuit.Minuit`.

        Parameters
        ----------
        param_array : array-like, shape (n_params,)

        Returns
        -------
        grad : np.ndarray, shape (n_params,)
        """
        param_array = jnp.asarray(param_array)
        _, g = self._jit_val_and_grad(param_array, self._model_data)
        return np.asarray(g)

    def _get_nominal_expected_arrays(self, type_of_fit: str):
        """
        Get per-sample nominal yields (and, for unbinned channels, ratios).

        Channel arrays are concatenated in channel order. For ``"binned"`` the
        ratio dict holds empty arrays. For ``"unbinned"`` the ratios are
        loaded with ``np.load`` from each sample's ``"ratios"`` entry.

        Raises
        ------
        ValueError
            If ``type_of_fit`` is neither ``"binned"`` nor ``"unbinned"``.
        """
        if type_of_fit == "binned":
            channels_list = self.channels_binned
        elif type_of_fit == "unbinned":
            channels_list = self.channels_unbinned
        else:
            raise ValueError(
                f"type_of_fit must be 'binned' or 'unbinned', got {type_of_fit!r}"
            )

        data_expected = {sample_name: np.array([]) for sample_name in self.all_samples}
        ratio_expected = {sample_name: np.array([]) for sample_name in self.all_samples}
        for sample_name in self.all_samples:
            for channel_name in channels_list:
                channel_index = self._index_of_region(channel_name=channel_name)
                sample_index = self._index_of_sample(channel_name=channel_name, sample_name=sample_name)
                sample_dict = self.workspace["channels"][channel_index]["samples"][sample_index]
                sample_data = np.array(sample_dict["data"])
                if type_of_fit == "unbinned":
                    sample_ratio = np.load(sample_dict["ratios"])
                else:
                    sample_ratio = np.array([])
                data_expected[sample_name] = np.append(data_expected[sample_name], sample_data)
                ratio_expected[sample_name] = np.append(ratio_expected[sample_name], sample_ratio)
        return data_expected, ratio_expected

    def _calculate_norm_variations(self, param_vec):
        """Per-sample product of the norm-factor parameters scaling that sample (1.0 if none)."""
        norm_var = {sample_name: 1.0 for sample_name in self.all_samples}
        for sample, params_sample in self.norm_sample_map.items():
            for param in params_sample:
                norm_var[sample] *= param_vec[self.index_normparam_map[param]]
        return norm_var

    def _get_systematic_data(self, type_of_fit: str) -> tuple:
        """
        Build per-sample ``(N_syst, N_datapoints)`` arrays of up/down variations.

        N_datapoints is the number of bins in binned channels and the number
        of events in unbinned channels. All channels of the requested type
        are concatenated into one array for array-based computations. Rows
        follow ``self.list_syst_normplusshape``.

        Parameters
        ----------
        type_of_fit : {"binned", "unbinned"}
            Which channel family to build the arrays for.

        Returns
        -------
        tuple of dict
            ``(var_up, var_dn)`` for ``"binned"``, taken from
            ``hi_data``/``lo_data``. For ``"unbinned"``:
            ``(ratio_up, ratio_dn, total_up, total_dn)``, where the ratio
            arrays are loaded from the ``hi_ratio``/``lo_ratio`` files and the
            totals come from ``hi_data``/``lo_data``.

        Raises
        ------
        ValueError
            For an unknown ``type_of_fit``, or when a sample in a channel does
            not define one of the interpolated systematics.
        """
        if type_of_fit == "binned":
            base_array_for_size = self.yield_array_dict[self.all_samples[0]]
            channel_list = self.channels_binned
        elif type_of_fit == "unbinned":
            base_array_for_size = self.ratios_array_dict[self.all_samples[0]]
            base_tot_for_size = self.unbinned_total_dict[self.all_samples[0]]
            channel_list = self.channels_unbinned
        else:
            raise ValueError(
                f"type_of_fit must be 'binned' or 'unbinned', got {type_of_fit!r}"
            )
        is_unbinned = type_of_fit == "unbinned"
        n_syst = len(self.list_syst_normplusshape)

        combined_var_up = {s: np.ones((n_syst, len(base_array_for_size))) for s in self.all_samples}
        combined_var_dn = {s: np.ones((n_syst, len(base_array_for_size))) for s in self.all_samples}
        if is_unbinned:
            combined_tot_up = {s: np.ones((n_syst, len(base_tot_for_size))) for s in self.all_samples}
            combined_tot_dn = {s: np.ones((n_syst, len(base_tot_for_size))) for s in self.all_samples}

        for sample_name in self.all_samples:
            for count, systematic_name in enumerate(self.list_syst_normplusshape):
                var_up_array_syst = np.array([])
                var_dn_array_syst = np.array([])
                var_up_tot_syst = np.array([])
                var_dn_tot_syst = np.array([])
                for channel_name in channel_list:
                    channel_index = self._index_of_region(channel_name=channel_name)
                    sample_index = self._index_of_sample(channel_name=channel_name, sample_name=sample_name)
                    modifier_index = self._index_of_modifiers(
                        channel_name=channel_name,
                        sample_name=sample_name,
                        systematic_name=systematic_name,
                    )
                    if modifier_index is None:
                        raise ValueError(
                            f"Systematic {systematic_name!r} is not defined for sample "
                            f"{sample_name!r} in channel {channel_name!r}"
                        )
                    modifier_dict = self.workspace["channels"][channel_index]["samples"][sample_index]["modifiers"][modifier_index]
                    if is_unbinned:
                        var_array_up_channel = np.load(modifier_dict["data"]["hi_ratio"])
                        var_array_dn_channel = np.load(modifier_dict["data"]["lo_ratio"])
                        var_up_tot_syst = np.append(var_up_tot_syst, modifier_dict["data"]["hi_data"])
                        var_dn_tot_syst = np.append(var_dn_tot_syst, modifier_dict["data"]["lo_data"])
                    else:
                        var_array_up_channel = modifier_dict["data"]["hi_data"]
                        var_array_dn_channel = modifier_dict["data"]["lo_data"]
                    var_up_array_syst = np.append(var_up_array_syst, var_array_up_channel)
                    var_dn_array_syst = np.append(var_dn_array_syst, var_array_dn_channel)

                combined_var_up[sample_name][count] = var_up_array_syst
                combined_var_dn[sample_name][count] = var_dn_array_syst
                if is_unbinned:
                    combined_tot_up[sample_name][count] = var_up_tot_syst
                    combined_tot_dn[sample_name][count] = var_dn_tot_syst

        if is_unbinned:
            return combined_var_up, combined_var_dn, combined_tot_up, combined_tot_dn
        return combined_var_up, combined_var_dn

    def _finalize_to_device(self):
        """Convert all nominal/variation arrays to JAX arrays for the JIT-compiled functions."""
        self.yield_array_dict = tree_map(jnp.asarray, self.yield_array_dict)
        self.unbinned_total_dict = tree_map(jnp.asarray, self.unbinned_total_dict)
        self.ratios_array_dict = tree_map(jnp.asarray, self.ratios_array_dict)
        self.combined_var_up_unbinned = tree_map(jnp.asarray, self.combined_var_up_unbinned)
        self.combined_var_dn_unbinned = tree_map(jnp.asarray, self.combined_var_dn_unbinned)
        self.combined_tot_up_unbinned = tree_map(jnp.asarray, self.combined_tot_up_unbinned)
        self.combined_tot_dn_unbinned = tree_map(jnp.asarray, self.combined_tot_dn_unbinned)
        self.combined_var_up_binned = tree_map(jnp.asarray, self.combined_var_up_binned)
        self.combined_var_dn_binned = tree_map(jnp.asarray, self.combined_var_dn_binned)
        self.weight_arrays_unbinned = jnp.asarray(self.weight_arrays_unbinned)

    def _build_stacked_data(self):
        """Stack per-process dicts into arrays and bundle into a single pytree for JIT."""
        samples = self.all_samples

        # Stack nominal data: (n_samples, n_datapoints)
        yield_stacked = jnp.stack([self.yield_array_dict[s] for s in samples])
        unbinned_total_stacked = jnp.stack([self.unbinned_total_dict[s] for s in samples])
        ratios_stacked = jnp.stack([self.ratios_array_dict[s] for s in samples])

        # Stack systematic variations: (n_samples, n_syst, n_datapoints)
        var_up_binned_stacked = jnp.stack([self.combined_var_up_binned[s] for s in samples])
        var_dn_binned_stacked = jnp.stack([self.combined_var_dn_binned[s] for s in samples])
        var_up_unbinned_stacked = jnp.stack([self.combined_var_up_unbinned[s] for s in samples])
        var_dn_unbinned_stacked = jnp.stack([self.combined_var_dn_unbinned[s] for s in samples])
        tot_up_unbinned_stacked = jnp.stack([self.combined_tot_up_unbinned[s] for s in samples])
        tot_dn_unbinned_stacked = jnp.stack([self.combined_tot_dn_unbinned[s] for s in samples])

        # Norm-factor mask: (n_samples, n_params), True where param j is a
        # normfactor for sample i. prod(where(mask, param_vec, 1)) gives the
        # per-sample multiplicative modifier.
        norm_matrix = np.zeros((len(samples), len(self.list_parameters)), dtype=bool)
        for i, sample in enumerate(samples):
            for nf_name in self.norm_sample_map.get(sample, []):
                norm_matrix[i, self.index_normparam_map[nf_name]] = True

        # Bundle everything into a dict pytree passed as a *dynamic* argument
        # to the JIT-compiled NLL, so that arrays are traced as abstract inputs
        # (no constant-folding / memory blow-up).
        self._model_data = {
            'yield': yield_stacked,
            'unbinned_total': unbinned_total_stacked,
            'ratios': ratios_stacked,
            'var_up_binned': var_up_binned_stacked,
            'var_dn_binned': var_dn_binned_stacked,
            'var_up_unbinned': var_up_unbinned_stacked,
            'var_dn_unbinned': var_dn_unbinned_stacked,
            'tot_up_unbinned': tot_up_unbinned_stacked,
            'tot_dn_unbinned': tot_dn_unbinned_stacked,
            'norm_matrix': jnp.array(norm_matrix),
            'expected_hist': self.expected_hist,
            'expected_rate': self.expected_rate_unbinned,
            'weights': self.weight_arrays_unbinned,
        }

    def _build_jit_functions(self):
        """
        Create JIT-compiled NLL and value-and-grad functions.

        Both close over static configuration (number of unconstrained
        parameters, presence of shape systematics). All arrays are passed at
        call time through the ``data`` pytree.

        Returns
        -------
        jit_nll : callable
            ``(param_vec, data) -> scalar`` NLL (-2 ln L scale).
        jit_val_and_grad : callable
            ``(param_vec, data) -> (nll, grad)``, gradient w.r.t. ``param_vec``.
        """
        num_unc = self.num_unconstrained_param
        has_syst = self.has_normplusshape
        # Vectorize the interpolation over the leading (sample) axis of the
        # stacked variation arrays. Only built when shape systematics exist.
        batched_var = (
            jax.vmap(_calculate_combined_var, in_axes=(None, 0, 0)) if has_syst else None
        )

        def _nll_pure(param_vec, data):
            param_syst = param_vec[num_unc:]
            norm_mods = jnp.prod(
                jnp.where(data['norm_matrix'], param_vec[None, :], 1.0), axis=1
            )

            if has_syst:
                hist_vars_binned = batched_var(param_syst, data['var_up_binned'], data['var_dn_binned'])
                hist_vars_unbinned = batched_var(param_syst, data['tot_up_unbinned'], data['tot_dn_unbinned'])
                ratio_vars = batched_var(param_syst, data['var_up_unbinned'], data['var_dn_unbinned'])
            else:
                hist_vars_binned = jnp.ones_like(data['yield'])
                hist_vars_unbinned = jnp.ones_like(data['unbinned_total'])
                ratio_vars = jnp.ones_like(data['ratios'])

            # Binned Poisson term.
            nu_binned = jnp.sum(norm_mods[:, None] * data['yield'] * hist_vars_binned, axis=0)
            llr_binned = -2.0 * jnp.sum(data['expected_hist'] * jnp.log(nu_binned) - nu_binned)

            # Unbinned total-rate Poisson term.
            nu_unbinned = jnp.sum(
                norm_mods[:, None] * data['unbinned_total'] * hist_vars_unbinned, axis=0
            )
            llr_rate = -2.0 * jnp.sum(data['expected_rate'] * jnp.log(nu_unbinned) - nu_unbinned)

            # Per-event log density term.
            # NOTE(review): nu_unbinned has one entry per unbinned-channel
            # "data" value while dnu_dx has one entry per event; the
            # subtraction presumably assumes a single unbinned channel — confirm.
            dnu_dx = jnp.sum(
                norm_mods[:, None] * hist_vars_unbinned * data['unbinned_total']
                * data['ratios'] * ratio_vars,
                axis=0,
            )
            llr_pe = jnp.log(dnu_dx) - jnp.log(nu_unbinned)

            # Unit-Gaussian constraints on the nuisance parameters.
            llr_constraints = jnp.sum(param_syst ** 2)

            return (
                llr_binned
                + llr_rate
                - 2.0 * jnp.sum(data['weights'] * llr_pe, axis=0)
                + llr_constraints
            )

        jit_nll = jax.jit(_nll_pure)
        jit_val_and_grad = jax.jit(jax.value_and_grad(_nll_pure, argnums=0))
        return jit_nll, jit_val_and_grad

    def _index_of_modifiers(self, channel_name: str, sample_name: str, systematic_name: str) -> Optional[int]:
        """Index of a modifier in a sample of a channel, or None if absent."""
        channel_index = self._index_of_region(channel_name)
        sample_index = self._index_of_sample(channel_name, sample_name)
        modifiers: list[dict[str, Any]] = self.workspace["channels"][channel_index]["samples"][sample_index]["modifiers"]
        for count, modifier in enumerate(modifiers):
            if modifier.get("name") == systematic_name:
                return count
        return None

    def _index_of_sample(self, channel_name: str, sample_name: str) -> Optional[int]:
        """Index of a sample within a channel, or None if absent."""
        channel_index = self._index_of_region(channel_name)
        samples: list[dict[str, Any]] = self.workspace["channels"][channel_index]["samples"]
        for count, sample in enumerate(samples):
            if sample.get("name") == sample_name:
                return count
        return None

    def _index_of_region(self, channel_name: str) -> Optional[int]:
        """Index of a channel in the workspace, or None if absent."""
        channels: list[dict[str, Any]] = self.workspace["channels"]
        for count, channel in enumerate(channels):
            if channel.get("name") == channel_name:
                return count
        return None
@jax.jit def _poly_interp(tuple_input): """ Sixth-order polynomial interpolation for systematic variations. Implements the HistFactory "strategy 5" interpolation used when :math:`|\\alpha| \\le 1`. Smoothly connects the upward and downward variation multipliers using a degree-6 polynomial in the nuisance parameter :math:`\\alpha`. Parameters ---------- tuple_input : tuple of (jnp.ndarray, jnp.ndarray, jnp.ndarray) ``(alpha, pow_up, pow_down)`` where *alpha* is the nuisance parameter value, *pow_up* the upward variation ratio, and *pow_down* the downward variation ratio. Returns ------- variation : jnp.ndarray Multiplicative correction to apply to the nominal prediction. """ alpha, pow_up, pow_down = tuple_input logHi = jnp.log(pow_up) logLo = jnp.log(pow_down) pow_up_log = jnp.multiply(pow_up, logHi) pow_down_log = -jnp.multiply(pow_down, logLo) pow_up_log2 = jnp.multiply(pow_up_log, logHi) pow_down_log2 = -jnp.multiply(pow_down_log, logLo) S0 = (pow_up + pow_down) / 2.0 A0 = (pow_up - pow_down) / 2.0 S1 = (pow_up_log + pow_down_log) / 2.0 A1 = (pow_up_log - pow_down_log) / 2.0 S2 = (pow_up_log2 + pow_down_log2) / 2.0 A2 = (pow_up_log2 - pow_down_log2) / 2.0 a1 = ( 15 * A0 - 7 * S1 + A2) / 8.0 a2 = (-24 + 24 * S0 - 9 * A1 + S2) / 8.0 a3 = ( -5 * A0 + 5 * S1 - A2) / 4.0 a4 = ( 12 - 12 * S0 + 7 * A1 - S2) / 4.0 a5 = ( 3 * A0 - 3 * S1 + A2) / 8.0 a6 = ( -8 + 8 * S0 - 5 * A1 + S2) / 8.0 return alpha * (a1 + alpha * ( a2 + alpha * ( a3 + alpha * ( a4 + alpha * ( a5 + alpha * a6 ) ) ) ) ) @jax.jit def _exp_extrap(tuple_input): """ Exponential extrapolation for systematic variations. Used when :math:`|\\alpha| > 1` (outside the interpolation region). Extrapolates the variation as a power law in :math:`\\alpha`. Parameters ---------- tuple_input : tuple of (jnp.ndarray, jnp.ndarray, jnp.ndarray) ``(alpha, varUp, varDown)`` where *alpha* is the nuisance parameter value, *varUp* the upward variation ratio, and *varDown* the downward variation ratio. 
Returns ------- variation : jnp.ndarray Multiplicative correction (minus 1) to apply to the nominal prediction. See Also -------- _poly_interp : Polynomial interpolation for :math:`|\\alpha| \\le 1`. """ alpha, varUp, varDown = tuple_input return jnp.where(alpha>1.0, (varUp)**alpha, (varDown)**(-alpha)) - 1.0 @jax.jit def _calculate_combined_var(param_vec, combined_var_up, combined_var_down): """ Compute the net multiplicative effect of all systematic variations. Sequentially applies each nuisance parameter's variation using :func:`_poly_interp` (for :math:`|\\alpha| \\le 1`) or :func:`_exp_extrap` (for :math:`|\\alpha| > 1`) via ``jax.lax.scan``. Parameters ---------- param_vec : jnp.ndarray, shape (n_syst,) Values of the constrained nuisance parameters. combined_var_up : jnp.ndarray, shape (n_syst, n_datapoints) Upward variation ratios for each systematic and data point. combined_var_down : jnp.ndarray, shape (n_syst, n_datapoints) Downward variation ratios for each systematic and data point. Returns ------- combined_var : jnp.ndarray, shape (n_datapoints,) Net multiplicative variation factor across all systematics. """ def calculate_variations(carry, param_val): param, combined_var_up_NP, combined_var_down_NP = param_val combined_var_array_alpha = carry # Strategy 5 of RooFit: combined_var_array_alpha += combined_var_array_alpha * jax.lax.cond(jnp.abs(param)<=1.0, _poly_interp, _exp_extrap, (param, combined_var_up_NP, combined_var_down_NP)) return combined_var_array_alpha, None # Prepare loop_tuple for jax.lax.scan loop_tuple = (param_vec, combined_var_up, combined_var_down) # Loop over systematic variations to calculate net effect combined_var_array, _ = jax.lax.scan(calculate_variations, jnp.ones_like(combined_var_up[0]), loop_tuple) return combined_var_array