import jax.numpy as jnp
import jax
import jax_tqdm
import bilby
import gwpopulation
import pandas as pd
[docs]
def selection_function(weights, total_generated):
"""
Compute the selection function given weights and total number of injections.
Parameters
----------
weights : jnp.ndarray
Array of importance weights for injection samples.
total_generated : int or float
Total number of injections.
Returns
-------
float
Estimated selection function value (mean weight normalized by total samples).
"""
return jnp.sum(weights) / total_generated
[docs]
def selection_function_log_covariance(weights_n, weights_m, total_generated):
"""
Compute the covariance of log selection functions between two weight sets.
Parameters
----------
weights_n : jnp.ndarray
First set of importance weights.
weights_m : jnp.ndarray
Second set of importance weights (must match shape of weights_n).
total_generated : int or float
Total number of injections.
Returns
-------
float
Covariance between log selection function estimates.
"""
assert weights_n.shape == weights_m.shape
mu_n, mu_m = selection_function(weights_n, total_generated), selection_function(weights_m, total_generated)
cov = jnp.sum(weights_n * weights_m) / total_generated / mu_n / mu_m - 1
return cov / (total_generated-1)
[docs]
def likelihood_log_correction(weights, total_generated, Nobs):
"""
Compute the likelihood log-correction term for variance estimation.
Parameters
----------
weights : jnp.ndarray
Importance weights for injection samples.
total_generated : int or float
Total number of injections.
Nobs : int
Number of observed events.
Returns
-------
float
Likelihood log-correction value.
"""
var = selection_function_log_covariance(weights, weights, total_generated)
return Nobs * (Nobs+1) * var / 2
[docs]
def reweighted_event_bayes_factors(event_pe_weights):
"""
Compute reweighted Bayes factors for a set of events.
Parameters
----------
event_pe_weights : jnp.ndarray
Array of shape (Nobs, NPE) with posterior sample weights per event.
Returns
-------
jnp.ndarray
Array of mean Bayes factors per event, shape (Nobs,).
"""
return jnp.mean(event_pe_weights, axis=1)
[docs]
def event_log_covariances(event_pe_weights_n, event_pe_weights_m):
"""
Compute covariances of log Bayes factors between two sets of event weights.
Parameters
----------
event_pe_weights_n : jnp.ndarray
First array of event posterior sample weights, shape (Nobs, NPE).
event_pe_weights_m : jnp.ndarray
Second array of event posterior sample weights (same shape as above).
Returns
-------
jnp.ndarray
Covariances per event, shape (Nobs,).
"""
assert event_pe_weights_m.shape == event_pe_weights_n.shape
Nobs, NPE = event_pe_weights_n.shape
mu_n = reweighted_event_bayes_factors(event_pe_weights_n)
mu_m = reweighted_event_bayes_factors(event_pe_weights_m)
cov = jnp.mean(event_pe_weights_n*event_pe_weights_m, axis=1) / mu_n / mu_m - 1
return cov / (NPE - 1)
[docs]
def log_likelihood_covariance(vt_weights_n, vt_weights_m, event_pe_weights_n, event_pe_weights_m, total_generated):
"""
Compute covariance of log-likelihood estimates from injection and event weights.
Parameters
----------
vt_weights_n : jnp.ndarray
Injection weights for the first hyperposterior sample.
vt_weights_m : jnp.ndarray
Injection weights for the second hyperposterior sample.
event_pe_weights_n : jnp.ndarray
Event posterior weights for the first sample, shape (Nobs, NPE).
event_pe_weights_m : jnp.ndarray
Event posterior weights for the second sample, shape (Nobs, NPE).
total_generated : int or float
Total number of injections.
Returns
-------
float
Log-likelihood covariance estimate.
"""
Nobs, NPE = event_pe_weights_n.shape
event_covs = event_log_covariances(event_pe_weights_n, event_pe_weights_m)
vt_cov = selection_function_log_covariance(vt_weights_n, vt_weights_m, total_generated)
return jnp.sum(event_covs) + Nobs**2 * vt_cov
[docs]
def error_statistics_from_weights(vt_weights, event_weights, total_generated, include_likelihood_correction=True):
"""
Compute error statistics for hyperposterior, Eqs. 36-39 of arxiv:2509.07221
Parameters
----------
vt_weights : jnp.ndarray
Array of shape (n_samples, n_injections), injection weights per hyperposterior sample.
event_weights : jnp.ndarray
Array of shape (n_samples, n_obs, n_pe), event posterior weights per hyperposterior sample.
total_generated : int or float
Total number of injections.
include_likelihood_correction : bool, default=True
Whether to include the likelihood correction term in the accuracy statistic. Set to True if
inference did not include the likelihood correction term, set to False if inference did
include the likelihood correction.
Returns
-------
tuple of floats
(precision, accuracy, error), where:
- precision : float
Expected information lost to uncertainty in posterior estimator.
- accuracy : float
Expected information lost to bias in posterior estimator
- error : float
Expected information lost to both bias and uncertainty in posterior estimator.
"""
length, Nobs, NPE = event_weights.shape
axis = jnp.arange(length)
arr_n, arr_m = jnp.meshgrid(axis, axis, indexing='ij')
f = lambda n, m: log_likelihood_covariance(vt_weights[n], vt_weights[m], event_weights[n], event_weights[m], total_generated)
_f = lambda x: f(x, x)
variances = jax.lax.map(_f, axis)
@jax_tqdm.scan_tqdm(length, print_rate=1, tqdm_type='std')
def weight_func(carry, n):
_f = lambda x: f(arr_n[n,x], arr_m[n,x])
meanw = jnp.mean(jax.lax.map(_f, axis), axis=0)
if include_likelihood_correction:
meanw = likelihood_log_correction(vt_weights[n], total_generated, Nobs) - meanw
return carry, meanw
weight_func = jax_tqdm.scan_tqdm(length, print_rate=1, tqdm_type='std')(weight_func)
_, weights = jax.lax.scan(weight_func, 0., xs=axis)
precision = float((jnp.mean(variances) - jnp.mean(weights)) / 2 / jnp.log(2))
accuracy = float(jnp.var(weights) / 2 / jnp.log(2))
error = float(precision + accuracy)
return {'error_statistic': error, 'precision_statistic': precision, 'accuracy_statistic': accuracy}
[docs]
def bilby_model_to_model_function(bilby_model, conversion_function=lambda args: (args, None), rate=False, rate_key='rate'):
"""
Wrap a Bilby or gwpopulation jax-compatible model into a callable function interface.
Note: if using the rate-full likelihood, this model should return dN/d\theta. It should
*not* be in comoving rate density in units Gpc^{-3} yr^{-1}. Instead, it is expected to be in
a density in redshift.
Parameters
----------
bilby_model : bilby.hyper.model.Model or callable Model object to be converted. If it is already
a callable, it is returned unchanged.
conversion_function : callable, optional
Function applied to parameter dictionaries before evaluating the model.
Should take a dict of parameters and return (modified_parameters, added_keys).
rate : bool, default=False
Whether to be used with the rate-full hierarchical likelihood.
rate_key : string, default='rate'
The key to recognize as N, where N is the total number of mergers in the Universe during the
observing time, e.g., dN/d\theta = Np(\theta | \Lambda). This is only used if rate=True and
using bilby_model as bilby.hyper.model.Model, as this only returns probability densities.
Returns
-------
callable
A function with signature (data, parameters) -> probability values,
where `data` is a dictionary of GW parameter samples and `parameters` are
hyperparameters of the population model.
"""
if not isinstance(bilby_model, (bilby.hyper.model.Model, gwpopulation.experimental.jax.NonCachingModel)):
# TODO: add some catches here, otherwise it assumes a particular form for the model
return bilby_model # function of data, parameters
from copy import copy
copy_models = [copy(m) for m in bilby_model.models]
bilby_model = bilby.hyper.model.Model(copy_models, cache=False)
def model_to_return(data, parameters):
if rate:
R = parameters.pop(rate_key)
else:
R = 1.
parameters, added_keys = conversion_function(parameters)
bilby_model.parameters.update(parameters)
return R*bilby_model.prob(data)
return model_to_return
def _compute_mean_weights_for_correction(
hyperposterior, n, bilby_model, gw_dataset, MC_integral_size=None, conversion_function=lambda args: (args, None),
MC_type='single event', verbose=True, rate=False, rate_key='rate'
):
"""
Compute mean event or selection weights integrated over hyperposterior samples.
For a Monte Carlo integral
\hat{I}(\Lambda) = \frac{1}{M}\sum_{i=1}^M \frac{p(\theta_i | \Lambda)}{p(\theta_i | {\rm draw})},
and a set of $N_{\rm samp}$ samples from the hyperposterior $\Lambda_n$, then compute
\overline{w}_i = \frac{1}{N_{\rm samp}} \sum_{n=1}^{N_{\rm samp}} \frac{1}{\hat{I}(\Lambda_n)}\frac{p(\theta_i | \Lambda_n)}{p(\theta_i | {\rm draw})},
the weight averaged over the hyperposterior and dividing out $\hat{I}$, for use in computing the error statistics.
Parameters
----------
hyperposterior : dict of jnp.ndarray
Hyperposterior samples with keys as hyperparameters, and values are jnp.ndarray
with first dimension indexing the hyperposterior sample
n : int
Number of samples in the hyperposterior
bilby_model : bilby.hyper.model.Model, or callable
Population model used to compute probabilities.
gw_dataset : dict
Dictionary of GW data samples, must include 'prior' key for sampling prior.
MC_integral_size : int, optional
Number of Monte Carlo samples. If None, inferred from `gw_dataset` or prior shape.
conversion_function : callable, optional
Function to convert hyperposterior parameters before model evaluation. For example,
gwpopulation.conversions.convert_to_beta_parameters
MC_type : str, default='single event'
Label for progress bar (e.g. 'single event' or 'selection').
verbose : bool, default=True
Whether to show a progress bar.
rate: bool, default=False
Whether to compute the integrated weight *without* dividing by the MC integral. For use with
the rate-full likelihood. If setting rate=True, then the bilby_model must return dN/d\theta. It
should *not* be in units of comoving merger rate density.
rate_key : string, default='rate'
The key to recognize as N, where N is the total number of mergers in the Universe during the
observing time, e.g., dN/d\theta = Np(\theta | \Lambda). This is only used if rate=True and
using bilby_model as bilby.hyper.model.Model, as this only returns probability densities.
Returns
-------
jnp.ndarray
Mean normalized event weights across hyperposterior samples,
shape matching the sampling prior.
"""
model_function = bilby_model_to_model_function(bilby_model, conversion_function=conversion_function, rate=rate, rate_key=rate_key)
gw_dataset = gw_dataset.copy()
sampling_prior = gw_dataset.pop('prior')
if MC_integral_size is None:
try:
MC_integral_size = gw_dataset.pop('total_generated')
except KeyError:
MC_integral_size = sampling_prior.shape[-1]
mean_event_weights = jnp.zeros_like(sampling_prior) # (Nevents, NPE)
keys = hyperposterior.keys()
def weights_for_single_sample(ii, mean_event_weights):
parameters = {k: hyperposterior[k][ii] for k in keys}
weights = model_function(gw_dataset, parameters) / sampling_prior
if rate:
expectation = jnp.ones(weights.shape[:-1])
else:
expectation = jnp.sum(weights, axis=-1) / MC_integral_size
return mean_event_weights + weights / expectation[..., None] / n
if verbose:
f = jax_tqdm.loop_tqdm(n, print_rate=1, tqdm_type='std', desc=f'Computing {MC_type} covariance weights integrated over hyperposterior samples')
else:
f = jax.jit
weights_for_single_sample = f(weights_for_single_sample)
mean_event_weights = jax.lax.fori_loop(
0,
n,
weights_for_single_sample,
mean_event_weights,
)
return mean_event_weights
def _compute_integrated_cov(integrated_weights, sample, model_function, gw_dataset, MC_integral_size=None, rate=False):
"""
Compute integrated covariance and variance for weights of a single posterior sample.
Parameters
----------
integrated_weights : jnp.ndarray
Precomputed integrated weights across hyperposterior samples.
sample : dict
A single hyperposterior parameter sample.
model_function : callable
Function mapping (dataset, parameters) -> probability values.
gw_dataset : dict
Dataset dictionary with GW parameter samples, must include 'prior'.
MC_integral_size : int, optional
Number of Monte Carlo samples. If None, inferred from dataset.
rate : bool, default=True
Whether to assume rate-full likelihood. If True, then the weights are assumed to include the overall
rate normalization, N, where dN/d\theta = Np(\theta | \Lambda). It should only be set to True when
computing the integrated covariance for the selection efficiency integral.
Returns
-------
tuple of jnp.ndarray
- integrated_cov : Integrated covariance estimate for the sample.
- var : Variance estimate for the sample.
"""
gw_dataset = gw_dataset.copy()
sampling_prior = gw_dataset.pop('prior')
if MC_integral_size is None:
try:
MC_integral_size = gw_dataset.pop('total_generated')
except KeyError:
MC_integral_size = sampling_prior.shape[-1]
weights = model_function(gw_dataset, sample) / sampling_prior
if rate:
expectation = jnp.ones(weights.shape[:-1])
else:
expectation = jnp.sum(weights, axis=-1) / MC_integral_size
var = (-1. + jnp.sum(weights**2, axis=-1) / MC_integral_size / expectation**2) / (MC_integral_size - 1)
integrated_cov = (-1. + jnp.sum(integrated_weights * weights, axis=-1) / MC_integral_size / expectation) / (MC_integral_size - 1)
return integrated_cov, var
[docs]
def format_hyperposterior(hyperposterior):
if isinstance(hyperposterior, pd.DataFrame):
hyperposterior = hyperposterior.to_dict(orient='list')
else:
if not isinstance(hyperposterior, dict):
raise IOError(f"Hyperposterior must be a dictionary or pandas.DataFrame, not {type(hyperposterior)}")
ns = []
for k in hyperposterior.keys():
hyperposterior[k] = jnp.array(hyperposterior[k])
ns.append(hyperposterior[k].shape[0])
if not jnp.all(jnp.array(ns) == ns[0]):
raise IOError(f"Hyperposterior has unequal number of samples for hyperparameters.")
n = ns[0]
return hyperposterior, n
[docs]
def error_statistics(
model_function,
injections,
event_posteriors,
hyperposterior,
vt_model_function=None,
include_likelihood_correction=True,
conversion_function=lambda args: (args, None),
nobs=None,
verbose=True,
rate=False,
rate_key='rate',
):
"""
Compute error, precision, and accuracy statistics from model, hyperposterior, and data.
Parameters
----------
model_function : bilby.hyper.model.Model, callable
Population model with interface (dataset, parameters) -> probabilities.
injections : dict
Injection dataset, including 'prior' and 'total_generated' keys.
event_posteriors : dict
Event posterior samples, including 'prior' key.
hyperposterior : pandas.DataFrame or dict of jnp.ndarray
If pandas.DataFrame, converts to appropriate format. Otherwise, hyperposterior
samples with keys as hyperparameters, and values are jnp.ndarray with first
dimension indexing the hyperposterior sample
vt_model_function : bilby.hyper.model.Model, callable, optional
Optional separate model instance for evaluating the selection function.
Population model with interface (dataset, parameters) -> probabilities. If not
included, set to model_function
include_likelihood_correction : bool, default=True
Whether to include likelihood correction in accuracy estimate.
Set to False if the hyperlikelihood for sampling from the posterior was estimated
using the unbiased likelihood of Eq. 24 of https://arxiv.org/abs/2509.07221
conversion_function : callable, optional
Function to convert hyperposterior parameters before model evaluation.
nobs : int, optional
Number of observed events. If None, inferred from `event_posteriors`.
verbose : bool, default=True
Whether to print progress and summary messages.
rate : bool, default=False
Whether to treat the VT weights as rate-weighted. TESTTHIS!!!
rate_key : string, default='rate'
The key which to access the overall merger rate within the posterior.
Returns
-------
dict
Dictionary with keys:
- 'error_statistic' : float, total information loss in bits.
- 'precision_statistic' : float, information loss due to variance.
- 'accuracy_statistic' : float, information loss due to bias.
"""
hyperposterior, n = format_hyperposterior(hyperposterior)
if nobs is None:
nobs = event_posteriors['prior'].shape[0]
if verbose:
print(f'Nobs not provided, assuming Nobs = {nobs}')
total_generated = injections['total_generated']
mean_event_weights = _compute_mean_weights_for_correction(
hyperposterior,
n,
model_function,
event_posteriors,
conversion_function=conversion_function,
MC_type='single event',
verbose=verbose
)
if vt_model_function is None:
vt_model_function = model_function
mean_vt_weights = _compute_mean_weights_for_correction(
hyperposterior,
n,
vt_model_function,
injections,
MC_integral_size=total_generated,
conversion_function=conversion_function,
MC_type='selection',
verbose=verbose,
rate=rate,
rate_key=rate_key
)
def create_loop_fn(m, p, MC_type='single event'):
if MC_type=='single event':
_rate = False
_model_function = bilby_model_to_model_function(model_function, conversion_function=conversion_function, rate=_rate, rate_key=rate_key)
else:
_rate = rate
_model_function = bilby_model_to_model_function(vt_model_function, conversion_function=conversion_function, rate=_rate, rate_key=rate_key)
loop_fn = lambda _, sample: (_, (sample[0],)+ _compute_integrated_cov(
m,
sample[1],
_model_function,
p,
rate=_rate
))
if verbose:
return jax_tqdm.scan_tqdm(n, print_rate=1, tqdm_type='std', desc=f'For each posterior sample, average {MC_type} covariance with another posterior sample')(loop_fn)
else:
return jax.jit(loop_fn)
_, (_, event_integrated_covs, event_vars) = jax.lax.scan(
create_loop_fn(mean_event_weights, event_posteriors),
0,
(jnp.arange(n), hyperposterior),
length=n
)
_, (_, vt_integrated_covs, vt_vars) = jax.lax.scan(
create_loop_fn(mean_vt_weights, injections, MC_type='selection'),
0,
(jnp.arange(n), hyperposterior),
length=n
)
if rate:
nobs = 1
var = jnp.sum(event_vars, axis=-1) + nobs**2 * vt_vars
cov = jnp.sum(event_integrated_covs, axis=-1) + nobs**2 * vt_integrated_covs
event_precision = float(jnp.mean(jnp.sum(event_vars - event_integrated_covs, axis=-1)) / 2 / jnp.log(2))
vt_precision = float(nobs**2 * jnp.mean(vt_vars - vt_integrated_covs) / 2 / jnp.log(2))
precision = float((jnp.mean(var) - jnp.mean(cov)) / 2 / jnp.log(2))
if include_likelihood_correction:
if rate:
correction = vt_vars / 2
else:
correction = nobs*(nobs+1) * vt_vars / 2
accuracy = float(jnp.var(cov - correction) / 2 / jnp.log(2))
selection_w = nobs**2 * vt_integrated_covs - correction
else:
accuracy = float(jnp.var(cov) / 2 / jnp.log(2))
selection_w = nobs**2 * vt_integrated_covs
event_w = jnp.sum(event_integrated_covs, axis=-1)
event_accuracy = float(jnp.var(event_w) / 2 / jnp.log(2))
selection_accuracy = float(jnp.var(selection_w) / 2 / jnp.log(2))
correlation_accuracy = float(jnp.mean((event_w - jnp.mean(event_w))*(selection_w - jnp.mean(selection_w))) / jnp.log(2))
error = float(precision + accuracy)
if verbose:
print(f'\nYour inference loses approximately {round(error, 3)} bits of information to Monte Carlo approximations.')
print(f'Of the total information loss')
print(f' * {round(precision, 3)} bits is from uncertainty in the posterior. Of this')
print(f' * {round(100*event_precision/precision, 1)}% is from the single-event Monte Carlo integration')
print(f' * {round(100*vt_precision/precision, 1)}% is from the selection Monte Carlo integration')
print(f' * {round(accuracy, 5)} bits is from bias in the posterior. Of the total bias')
print(f' * {round(100*event_accuracy/accuracy, 1)}% is from the single-event Monte Carlo integration')
print(f' * {round(100*selection_accuracy/accuracy, 1)}% is from the selection Monte Carlo integration')
print(f' * {round(100*correlation_accuracy/accuracy, 1)}% is from correlations in the uncertainty of the single-event and selection MC integrals')
# how much due to VT and how much due to events? We can also compute this :O I believe bc they are additive. Well,
# I don't know if we can do it necessarily for the accuracy statistic, because Var(E + V) = Var(E) + 2Cov(E,V) + Var(V), so
# we would technically have a "covariance" between event and vt terms. Still, could be interesting at least to compute precision from VT and precision from events
return {
'error_statistic': error,
'precision_statistic': precision,
'accuracy_statistic': accuracy,
'event_precision_statistic': event_precision,
'selection_precision_statistic': vt_precision,
'event_accuracy_statistic': event_accuracy,
'selection_accuracy_statistic': selection_accuracy,
'correlation_event_selection_accuracy_statistic': correlation_accuracy,
}