"""
Module providing wrappers for the pydelfi package to conform with the sbi
interface.
"""
import pickle
import emcee
import numpy as np
from math import ceil
from typing import Dict, List, Callable, Optional, Union
from pydelfi.delfi import Delfi
from ili.utils import load_nde_pydelfi
class DelfiWrapper(Delfi):
    """Trainer for a neural posterior ensemble using the pydelfi package.

    Wrapper for pydelfi.delfi.Delfi which adds some necessary
    functionality and interface.

    Args:
        config_ndes (List[Dict]): list with configurations for each neural
            posterior model in the ensemble
        name (str, optional): name stored on the instance and written to the
            metadata file by ``save_engine``. Defaults to ''.

    Other parameters are passed as input to the pydelfi.delfi.Delfi class
    """

    def __init__(
        self,
        config_ndes: List[Dict],
        name: Optional[str] = '',
        **kwargs
    ):
        super().__init__(**kwargs)
        # 'nde' is left out of the stored kwargs: load_engine rebuilds the
        # networks from config_ndes and passes them as `nde=` next to these
        # kwargs, so keeping it here would produce a duplicate keyword.
        kwargs.pop('nde', None)
        self.kwargs = kwargs
        self.config_ndes = config_ndes
        self.num_components = len(config_ndes)
        self.name = name
        self.prior.sample = self.prior.draw  # aliasing for consistency

    def log_posterior_stacked(
        self, theta: np.ndarray, x: np.ndarray
    ) -> np.ndarray:
        """Redefinition of Delfi.log_posterior_stacked to do consistent shape
        handling of theta and x.

        Args:
            theta (np.ndarray): parameter vector(s); promoted to 2D
            x (np.ndarray): data vector(s) to condition the inference on;
                promoted to 2D

        Returns:
            np.ndarray: stacked log-likelihood plus log-prior, of shape
                (# of theta rows, # of x rows)
        """
        theta = np.atleast_2d(theta)
        x = np.atleast_2d(x)
        n_theta, n_x = theta.shape[0], x.shape[0]
        lik = self.log_likelihood_stacked(theta, x).reshape(n_theta, n_x)
        # The prior depends on theta only; a trailing axis of length 1 lets
        # it broadcast across the x axis of the likelihood.
        pri = self.prior.logpdf(theta).reshape(n_theta, 1)
        return lik + pri

    def potential(self, theta: np.ndarray, x: np.ndarray) -> np.ndarray:
        """Returns the log posterior probability of parameters given a data
        vector. Modification of Delfi.log_prob designed to conform
        with the form of sbi.utils.posterior_ensemble

        Args:
            theta (np.ndarray): parameter vector
            x (np.ndarray): data vector to condition the inference on

        Returns:
            np.ndarray: log posterior probability, as returned by
                ``log_posterior_stacked``
        """
        return self.log_posterior_stacked(theta, x)

    def sample(
        self,
        sample_shape: Union[int, tuple],
        x: np.ndarray,
        show_progress_bars=False,
        num_chains: int = 10,
        burn_in=200,
        thin=3,
        skip_initial_state_check: bool = False
    ) -> np.ndarray:
        """Samples from the posterior distribution using emcee ensemble MCMC.
        Modification of Delfi.emcee_sample designed to conform with the
        form of sbi.utils.posterior_ensemble

        Args:
            sample_shape (int, tuple[int]): shape of the batch of samples to
                return. The total number of samples is the product of its
                entries and is split evenly (rounded up) across the chains.
                An empty tuple requests a single sample.
            x (np.ndarray): data vector to condition the inference on
            show_progress_bars (bool): whether to print sampling progress
            num_chains (int): number of MCMC chains to run in parallel
            burn_in (int): number of stored steps per chain discarded as
                burn-in
            thin (int): thinning factor, passed to emcee as ``thin_by``;
                only every ``thin``-th proposal is stored
            skip_initial_state_check (bool): whether to skip the initial state
                check for the MCMC sampler

        Returns:
            np.ndarray: array of samples of shape
                (*sample_shape, # of parameters)
        """
        if isinstance(sample_shape, (int, np.integer)):
            sample_shape = (int(sample_shape),)
        else:
            sample_shape = tuple(sample_shape)

        # calculate number of samples per chain.
        # np.prod of an empty shape is the float 1.0, which cannot be used
        # as a slice bound below, hence the explicit int cast.
        num_samples = int(np.prod(sample_shape))
        per_chain = ceil(num_samples / num_chains)

        # Initialize walkers with independent draws from the prior
        theta0 = np.stack([self.prior.sample()
                           for _ in range(num_chains)])

        # Set up the sampler; emcee calls the target as potential(theta, x)
        sampler = emcee.EnsembleSampler(
            num_chains,
            self.npar,
            self.potential,
            vectorize=False,
            args=(x,),
        )

        # Sample
        sampler.run_mcmc(
            theta0,
            burn_in + per_chain,
            thin_by=thin,
            progress=show_progress_bars,
            skip_initial_state_check=skip_initial_state_check
        )

        # Drop the burn-in, flatten over walkers and trim the surplus that
        # results from rounding per_chain up.
        chain = sampler.get_chain(discard=burn_in, flat=True)[:num_samples]
        return chain.reshape((*sample_shape, self.npar))

    @staticmethod
    def load_ndes(
        config_ndes: List[Dict],
        n_params: int,
        n_data: int,
    ) -> List[Callable]:
        """Initialize the neural density estimators from configuration yamls.

        Args:
            config_ndes(List[Dict]): list with configurations for each neural
                posterior model in the ensemble
            n_params (int): dimensionality of each parameter vector
            n_data (int): dimensionality of each datapoint

        Returns:
            List[Callable]: list of neural posterior models with forward
                methods
        """
        # The position in the list is forwarded as `index` to each network.
        return [
            load_nde_pydelfi(
                n_params=n_params, n_data=n_data,
                index=i, **model_args)
            for i, model_args in enumerate(config_ndes)
        ]

    def save_engine(
        self,
        meta_filename: str,
    ):
        """Save necessary metadata for reloading to file

        Args:
            meta_filename (str): filename of saved metadata. It is appended
                to ``self.results_dir`` by plain string concatenation, so
                ``results_dir`` must already end with a path separator for
                the file to land inside that directory.
        """
        metadata = {
            'n_data': self.D,
            'n_params': self.npar,
            'name': self.name,
            'config_ndes': self.config_ndes,
            'kwargs': self.kwargs
        }
        with open(self.results_dir + meta_filename, 'wb') as f:
            pickle.dump(metadata, f)

    @classmethod
    def load_engine(
        cls,
        meta_path: str,
    ):
        """Load a DelfiWrapper from metadata file

        Args:
            meta_path (str): path to saved metadata

        Returns:
            DelfiWrapper: a Delfi inference model rebuilt from the saved
                metadata and constructed with ``restore=True``
        """
        # NOTE: pickle.load can execute arbitrary code; only load metadata
        # files that come from a trusted source.
        with open(meta_path, 'rb') as f:
            metadata = pickle.load(f)

        ndes = cls.load_ndes(
            n_params=metadata['n_params'],
            n_data=metadata['n_data'],
            config_ndes=metadata['config_ndes']
        )

        # Any saved 'restore' flag is dropped because restore=True is
        # passed explicitly below.
        metadata['kwargs'].pop('restore', None)
        return cls(
            **metadata['kwargs'],
            nde=ndes,
            config_ndes=metadata['config_ndes'],
            name=metadata['name'],
            restore=True
        )