Statistical Models

This module provides the statistical models available for NSBI. It defines the JIT-compiled negative log-likelihood (ratio) function, written in JAX, for use by fitting algorithms.

The two main entry points for downstream code are:

  • model() — the NLL callable (pass to inference as model_nll).

  • model_grad() — the NLL gradient callable (pass to inference as model_grad).

class sbi_parametric_model(workspace, measurement_to_fit)[source]

Bases: object

Statistical model for semi-parametric Simulation-Based Inference (SBI).

Defines parameterized expected yields, density ratios, and the negative log-likelihood (NLL) passed to fitting algorithms. Supports both binned and unbinned channels with systematic uncertainties handled via polynomial interpolation / exponential extrapolation (HistFactory strategy 5).
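The interpolation scheme named above can be illustrated with a minimal NumPy sketch for a single normalisation factor. It assumes the commonly used form: a sixth-order polynomial for |alpha| < 1, matched in value, first and second derivative to exponential branches outside that range. The function names and the `up`/`down` ratio arguments are hypothetical and are not taken from this package, whose implementation may differ in detail.

```python
import numpy as np

POWERS = np.arange(1, 7)  # exponents of the sixth-order polynomial


def poly_coefficients(up, down):
    """Solve for a_1..a_6 in f(alpha) = 1 + sum_i a_i * alpha**i.

    The six conditions match value, first and second derivative of the
    exponential branches at alpha = +1 (up ratio) and alpha = -1 (down ratio).
    """
    rows, rhs = [], []
    for a0, ratio, sign in ((1.0, up, 1.0), (-1.0, down, -1.0)):
        log_r = np.log(ratio)
        rows.append(a0 ** POWERS)                       # value
        rhs.append(ratio - 1.0)
        rows.append(POWERS * a0 ** (POWERS - 1))        # first derivative
        rhs.append(sign * ratio * log_r)
        rows.append(POWERS * (POWERS - 1) * a0 ** np.maximum(POWERS - 2, 0))
        rhs.append(ratio * log_r ** 2)                  # second derivative
    return np.linalg.solve(np.array(rows), np.array(rhs))


def interp_factor(alpha, up, down):
    """Multiplicative yield factor for nuisance-parameter value alpha.

    up and down are the +1 sigma and -1 sigma yields divided by the nominal.
    """
    alpha = np.asarray(alpha, dtype=float)
    coeffs = poly_coefficients(up, down)
    inside = 1.0 + np.sum(coeffs * alpha[..., None] ** POWERS, axis=-1)
    above = up ** alpha          # exponential extrapolation, alpha >= 1
    below = down ** (-alpha)     # exponential extrapolation, alpha <= -1
    return np.where(alpha >= 1.0, above, np.where(alpha <= -1.0, below, inside))
```

By construction the factor equals 1 at alpha = 0 and reproduces the up and down ratios at alpha = +1 and -1, and it stays smooth across those boundaries, which is what a gradient-based minimiser needs.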

Two JIT-compiled entry points are built:

  • model() – evaluates the negative log-likelihood for a fit.

  • model_grad() – evaluates the gradient of the negative log-likelihood via JAX reverse-mode autodiff.
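The pattern behind these two entry points can be sketched with a toy binned Poisson NLL. This is not the class's actual likelihood: the yields, the single signal-strength parameter and the names `toy_model` and `toy_model_grad` are stand-ins chosen for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical per-bin inputs for a one-parameter toy model.
signal = jnp.array([5.0, 10.0, 3.0])
background = jnp.array([50.0, 40.0, 20.0])
observed = jnp.array([57.0, 49.0, 22.0])


def _toy_nll(params):
    """Poisson NLL up to a parameter-independent constant."""
    mu = params[0]
    expected = mu * signal + background
    return jnp.sum(expected - observed * jnp.log(expected))


# Compile the NLL and its reverse-mode gradient once; later calls reuse them.
_toy_nll_jit = jax.jit(_toy_nll)
_toy_grad_jit = jax.jit(jax.grad(_toy_nll))


def toy_model(param_array):
    """NLL callable accepting a list, NumPy array or JAX array."""
    return _toy_nll_jit(jnp.asarray(param_array, dtype=jnp.float32))


def toy_model_grad(param_array):
    """Gradient callable returning a plain NumPy array for the minimiser."""
    grad = _toy_grad_jit(jnp.asarray(param_array, dtype=jnp.float32))
    return np.asarray(grad)
```

Converting the gradient to a NumPy array at the boundary keeps the minimiser independent of JAX array types, while the JIT-compiled functions stay pure so they can be traced and differentiated.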

Parameters:
  • workspace (dict) – A workspace dictionary following the pyhf-like JSON schema. Must contain "measurements" and "channels" keys. Channels may be tagged with "type": "binned" or "type": "unbinned".

  • measurement_to_fit (str) – Name of the measurement block inside workspace["measurements"] to use. Selects the parameter of interest (POI) and the list of parameters to fit.
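As a rough illustration of the expected input, the sketch below builds a minimal workspace-like dictionary and looks up a measurement by name. Only the "measurements" and "channels" keys and the "type" tag come from the description above. The inner keys ("name", "config", "poi", "parameters", "samples") follow the public pyhf schema and are assumed to carry over; the measurement, channel and parameter names, and the `select_measurement` helper, are hypothetical.

```python
# Minimal workspace-like dictionary (illustrative values only).
workspace = {
    "measurements": [
        {
            "name": "main_fit",  # hypothetical measurement name
            "config": {
                "poi": "mu",
                "parameters": [{"name": "mu", "inits": [1.0]}],
            },
        }
    ],
    "channels": [
        {"name": "SR_unbinned", "type": "unbinned", "samples": []},
        {"name": "CR", "type": "binned", "samples": []},
    ],
}


def select_measurement(workspace, measurement_to_fit):
    """Return the measurement block whose name matches measurement_to_fit."""
    for measurement in workspace["measurements"]:
        if measurement["name"] == measurement_to_fit:
            return measurement
    raise KeyError(f"measurement {measurement_to_fit!r} not found in workspace")


measurement = select_measurement(workspace, "main_fit")
poi_name = measurement["config"]["poi"]
channel_types = {ch["name"]: ch["type"] for ch in workspace["channels"]}
```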

Attributes:
list_parameters

Ordered parameter names: POI first, then unconstrained norm factors, then constrained nuisance parameters.

Type:

list of str

initial_parameter_values

Starting values for every parameter, in the same order as list_parameters.

Type:

jnp.ndarray

num_unconstrained_param

Number of leading parameters that are unconstrained (POI + free norm factors).

Type:

int

expected_hist

Binned expected yields evaluated at the initial parameter values.

Type:

jnp.ndarray

See also

nsbi_common_utils.inference.inference

Fits and scans this model.

get_model_parameters()[source]

Return parameter names and initial values for fitting.

The returned order matches the convention expected by inference: POI at index 0, followed by unconstrained norm factors, then constrained nuisance parameters.

Returns:

  • list_parameters (list of str) – Ordered parameter names.

  • initial_parameter_values (jnp.ndarray, shape (n_params,)) – Starting values aligned with list_parameters.
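The ordering convention can be made concrete with a small sketch. The parameter names and starting values are invented for illustration (norm factors starting at 1 and nuisance parameters at 0 is an assumption, not something stated here), and the variables only mirror the attribute names documented above.

```python
import numpy as np

# Hypothetical parameter groups.
poi = "mu"
norm_factors = ["ttbar_norm"]
nuisance_parameters = ["jes", "lumi"]

# POI at index 0, then unconstrained norm factors, then constrained nuisances.
list_parameters = [poi] + norm_factors + nuisance_parameters
initial_parameter_values = np.array([1.0, 1.0, 0.0, 0.0])

# The number of leading unconstrained entries splits the vector in two.
num_unconstrained_param = 1 + len(norm_factors)
unconstrained = dict(
    zip(
        list_parameters[:num_unconstrained_param],
        initial_parameter_values[:num_unconstrained_param],
    )
)
constrained = dict(
    zip(
        list_parameters[num_unconstrained_param:],
        initial_parameter_values[num_unconstrained_param:],
    )
)
```

Because names and values share one ordering, any entry of a fitted parameter vector can be looked up with `list_parameters.index(name)`.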

model(param_array)[source]

High-level API that returns the full negative log-likelihood for a parameter point.

Computes the combined NLL across all channels defined in the input workspace: unbinned SBI channels as well as binned control and signal regions. This callable is the function to pass to inference as model_nll.

Parameters:

param_array (numpy.array | jax.numpy.array | list[float], shape (n_params,)) – Parameter values in the order defined by get_model_parameters().

Returns:

nll (jnp.ndarray, scalar) – The negative log-likelihood value.

model_grad(param_array)[source]

Return the gradient of the NLL with respect to all parameters.

Uses JAX reverse-mode autodiff (JIT-compiled). Suitable as the grad argument to iminuit.Minuit.

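Parameters:

param_array (numpy.array | jax.numpy.array | list[float], shape (n_params,)) – Parameter values in the order defined by get_model_parameters().

Returns:

grad (np.ndarray, shape (n_params,)) – Gradient of the NLL with respect to each parameter, in the same order as param_array.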