Calculate error statistics from arrays of weights

Import packages. The package only requires JAX to work; NumPy is imported here only to generate the example arrays.

import jax.numpy as jnp
import numpy as np

# load in population-error package
from population_error import error_statistics_from_weights

For each found VT injection \(\theta_j\) and each hyperposterior sample \(\Lambda_n\), the vt_weights array holds the weight \(p(\theta_j | \Lambda_n) / p(\theta_j | {\rm draw})\), where \(p(\theta_j | {\rm draw})\) is the distribution the injections were drawn from. Its shape is vt_weights.shape = (Nsamp, Nfound).
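As an illustration, the sketch below builds a vt_weights array for a hypothetical one-parameter population model: a power law \(p(m | \alpha) \propto m^{-\alpha}\) on \([m_{\rm min}, m_{\rm max}]\), with injections drawn from a fixed power law of slope 2.35. The helper power_law_pdf, the mass range, the slopes and the array sizes are all assumptions made for illustration, and for simplicity every drawn injection is treated as found. The point is the broadcasting pattern that yields shape (Nsamp, Nfound).

```python
import numpy as np

def power_law_pdf(m, alpha, m_min, m_max):
    """Hypothetical helper: normalized p(m) proportional to m**(-alpha) on
    [m_min, m_max], zero outside. Assumes alpha != 1."""
    norm = (1.0 - alpha) / (m_max ** (1.0 - alpha) - m_min ** (1.0 - alpha))
    in_range = (m >= m_min) & (m <= m_max)
    return np.where(in_range, norm * m ** (-alpha), 0.0)

rng = np.random.default_rng(0)
m_min, m_max = 5.0, 100.0
n_samp, n_found = 500, 2_000

# Assumed injection distribution p(m | draw): power law with slope 2.35,
# sampled by inverting its CDF. All draws are kept as "found" here.
alpha_draw = 2.35
a1 = 1.0 - alpha_draw
u = rng.uniform(size=n_found)
found_m = (m_min ** a1 + u * (m_max ** a1 - m_min ** a1)) ** (1.0 / a1)
found_m = np.clip(found_m, m_min, m_max)  # guard against rounding at the edges

# Stand-in hyperposterior samples Lambda_n = {alpha_n}
alphas = rng.uniform(1.5, 3.5, size=n_samp)

# (Nsamp, 1) broadcast against (1, Nfound) -> (Nsamp, Nfound)
p_pop = power_law_pdf(found_m[None, :], alphas[:, None], m_min, m_max)
p_draw = power_law_pdf(found_m, alpha_draw, m_min, m_max)
vt_weights = p_pop / p_draw[None, :]
print(vt_weights.shape)  # (500, 2000)
```

Because these weights are an importance-sampling ratio and no detection cut was applied, their mean over injections should be close to 1 for every hyperposterior sample, which makes a quick sanity check.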

Similarly, for the \(j^{\rm th}\) posterior sample \(\theta_{ij}\) of the \(i^{\rm th}\) event, event_weights holds the weight \(p(\theta_{ij} | \Lambda_n) / \pi(\theta_{ij}|{\rm PE})\), where \(\pi(\theta_{ij}|{\rm PE})\) is the prior used in parameter estimation. Its shape is event_weights.shape = (Nsamp, Nobs, NPE).
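The same broadcasting pattern extends to event_weights with one extra axis. In this hypothetical sketch the stand-in PE samples are plain uniform random values on \([m_{\rm min}, m_{\rm max}]\) (not real posteriors), the PE prior \(\pi(m|{\rm PE})\) is assumed uniform on that range, and power_law_pdf is an illustrative helper, not part of the population-error package.

```python
import numpy as np

def power_law_pdf(m, alpha, m_min, m_max):
    """Hypothetical helper: normalized p(m) proportional to m**(-alpha) on
    [m_min, m_max], zero outside. Assumes alpha != 1."""
    norm = (1.0 - alpha) / (m_max ** (1.0 - alpha) - m_min ** (1.0 - alpha))
    in_range = (m >= m_min) & (m <= m_max)
    return np.where(in_range, norm * m ** (-alpha), 0.0)

rng = np.random.default_rng(1)
m_min, m_max = 5.0, 100.0
n_samp, n_obs, n_pe = 500, 10, 200

alphas = rng.uniform(1.5, 3.5, size=n_samp)  # stand-in hyperposterior samples

# Stand-in PE samples theta_ij, shape (Nobs, NPE)
pe_m = rng.uniform(m_min, m_max, size=(n_obs, n_pe))
pe_prior = 1.0 / (m_max - m_min)  # assumed uniform pi(m | PE)

# (Nsamp, 1, 1) broadcast against (1, Nobs, NPE) -> (Nsamp, Nobs, NPE)
event_weights = power_law_pdf(pe_m[None, :, :], alphas[:, None, None], m_min, m_max) / pe_prior
print(event_weights.shape)  # (500, 10, 200)
```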

Finally, total_samples is the total number of injections drawn, \(N_{\rm inj}\), including those that were not found.
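The total count matters because missed injections still enter the Monte Carlo normalization; they simply carry zero weight. A common estimator of the detectable fraction at each hyperposterior sample is \(\hat\mu_n = N_{\rm inj}^{-1} \sum_j w_{nj}\), with variance estimate \(\hat\sigma_n^2 = N_{\rm inj}^{-2} \sum_j w_{nj}^2 - \hat\mu_n^2 / N_{\rm inj}\), where both sums run over found injections only. The sketch below implements these two formulas to show the role of \(N_{\rm inj}\); the function name selection_estimate is hypothetical, and this is not necessarily what error_statistics_from_weights computes internally.

```python
import numpy as np

def selection_estimate(vt_weights, total_samples):
    """Hypothetical helper: Monte Carlo estimate of the detectable fraction for
    each hyperposterior sample, and the variance of that estimate.

    vt_weights has shape (Nsamp, Nfound); total_samples is N_inj, which also
    counts missed injections (they contribute zero weight to the sums)."""
    mu = vt_weights.sum(axis=-1) / total_samples
    var = (vt_weights ** 2).sum(axis=-1) / total_samples ** 2 - mu ** 2 / total_samples
    return mu, var

# Toy example: 4 found injections out of 10 drawn, all with weight 1
toy_weights = np.ones((3, 4))
mu, var = selection_estimate(toy_weights, total_samples=10)
print(mu)  # [0.4 0.4 0.4]
```

For the toy input each row gives \(\hat\mu = 4/10 = 0.4\) and \(\hat\sigma^2 = 4/100 - 0.16/10 = 0.024\), the familiar binomial variance \(p(1-p)/N_{\rm inj}\) for a found fraction \(p = 0.4\).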

# Random arrays stand in for real weights, for simplicity. Holding full weight arrays of this size in memory often leads to memory errors.
vt_weights = jnp.array(np.random.uniform(size=(500, 100_000)))
event_weights = jnp.array(np.random.uniform(size=(500, 100, 1000)))
total_samples = 5e6
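The two arrays above each hold \(5 \times 10^7\) entries (500 × 100,000 and 500 × 100 × 1000), which is why memory becomes a concern. np.random.uniform returns float64 values, while JAX uses single precision unless 64-bit mode is enabled, so jnp.array converts them to float32 and the float64 intermediate briefly coexists with that copy. The sketch below does the arithmetic; the helper footprint_gb is hypothetical, and drawing float32 values directly with NumPy's Generator.random is one possible way to skip the float64 step.

```python
import numpy as np

def footprint_gb(shape, dtype):
    """Hypothetical helper: memory for a dense array of this shape and dtype, in GB."""
    return np.prod(shape, dtype=np.int64) * np.dtype(dtype).itemsize / 1e9

for name, shape in [("vt_weights", (500, 100_000)), ("event_weights", (500, 100, 1000))]:
    # 0.4 GB as float64, 0.2 GB as float32 for both arrays
    print(name, footprint_gb(shape, np.float64), footprint_gb(shape, np.float32))

# Draw single-precision values directly (small shape here, for illustration)
rng = np.random.default_rng(0)
demo = rng.random((500, 1_000), dtype=np.float32)
print(demo.dtype, demo.nbytes)  # float32 2000000
```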

Compute error statistics

statistics = error_statistics_from_weights(vt_weights, event_weights, total_samples, include_likelihood_correction=True)
print(statistics)
Running for 500 iterations: 100%|██████████| 500/500 [00:10<00:00, 45.96it/s]
Running for 500 iterations: 100%|██████████| 500/500 [00:10<00:00, 45.94it/s]
{'error_statistic': 0.14175374994676887, 'precision_statistic': 0.14175374564949028, 'accuracy_statistic': 4.297278581024669e-09}
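In the printed result, error_statistic equals precision_statistic plus accuracy_statistic to within floating-point rounding, and the accuracy term is about seven orders of magnitude smaller, so for this run the error statistic is dominated by the precision term. The check below uses the values copied from the printed output; treating this additive split as holding in general is an assumption drawn from this single run.

```python
import math

# Values copied from the printed output of error_statistics_from_weights
printed_statistics = {
    "error_statistic": 0.14175374994676887,
    "precision_statistic": 0.14175374564949028,
    "accuracy_statistic": 4.297278581024669e-09,
}

total = printed_statistics["precision_statistic"] + printed_statistics["accuracy_statistic"]
print(math.isclose(total, printed_statistics["error_statistic"], rel_tol=1e-12))  # True
```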