Source code for utils.samplers

"""
Custom samplers for sampling posteriors for Likelihood Estimation and
Ratio Estimation models. Currently supports emcee samplers for both sbi
and pydelfi backends, and pyro samplers only for the sbi backend.
"""

import os
import numpy as np
import emcee
from abc import ABC
from collections.abc import Sequence
from typing import Any
from math import ceil

try:
    import torch
    from sbi.inference.posteriors.base_posterior import NeuralPosterior
    from sbi.utils.posterior_ensemble import NeuralPosteriorEnsemble
    from sbi.inference.posteriors import (
        DirectPosterior, MCMCPosterior, VIPosterior)
    from sbi.inference.potentials.posterior_based_potential import (
        posterior_estimator_based_potential)
    ModelClass = NeuralPosterior
except ModuleNotFoundError:
    from ili.inference.pydelfi_wrappers import DelfiWrapper
    ModelClass = DelfiWrapper


class _MCMCSampler(ABC):
    """Base sampler class demonstrating the sampler functionality

    Args:
        posterior (Posterior): posterior object to sample from, must have
            a .potential method specifiying the log posterior
        num_chains (int, optional): number of chains to sample from. Defaults
            to os.cpu_count()-1
        thin (int, optional): thinning factor for the chains. Defaults to 10
        burn_in (int, optional): number of steps to discard as burn-in.
            Defaults to 100
    """

    def __init__(
            self,
            posterior: ModelClass,
            num_chains: int = -1,
            thin: int = 10,
            burn_in: int = 100,
    ) -> None:
        super().__init__()
        self.posterior = posterior
        self.num_chains = os.cpu_count()-1 if num_chains == -1 else num_chains
        self.thin = thin
        self.burn_in = burn_in


class EmceeSampler(_MCMCSampler):
    """Sampler class for emcee's EnsembleSampler

    Args:
        posterior (Posterior): posterior object to sample from, must have
            a .potential method specifying the log posterior
        num_chains (int, optional): number of chains to sample from. Defaults
            to os.cpu_count()-1
        thin (int, optional): thinning factor for the chains. Defaults to 10
        burn_in (int, optional): number of steps to discard as burn-in.
            Defaults to 100
    """

    def sample(
            self,
            nsteps: int,
            x: np.ndarray,
            progress: bool = False,
            skip_initial_state_check: bool = False
    ) -> np.ndarray:
        """
        Sample nsteps samples from the posterior, evaluated at data x.

        Args:
            nsteps (int): number of samples to draw
            x (np.ndarray): data to evaluate the posterior at. Anything
                np.asarray can convert (arrays, lists, CPU tensors) works.
            progress (bool, optional): whether to show progress bar.
                Defaults to False.
            skip_initial_state_check (bool, optional): If True, a check that
                the initial_state can fully explore the space will be
                skipped. Defaults to False.

        Returns:
            np.ndarray: the first nsteps rows of the flattened post-burn-in
                chain, one parameter vector per row.
        """
        # calculate number of samples per chain
        per_chain = ceil(nsteps / self.num_chains)

        # Convert the observation once, instead of on every log-probability
        # evaluation inside the sampler loop.
        x = np.asarray(x, dtype=np.float32)

        # build posterior to sample
        def log_target(t, x):
            res = self.posterior.potential(t.astype(np.float32), x)
            if hasattr(res, 'cpu'):
                res = np.array(res.detach().cpu())
            return res

        # Initialize walkers from independent prior draws
        theta0 = [self.posterior.prior.sample()
                  for _ in range(self.num_chains)]
        if isinstance(theta0[0], np.ndarray):
            theta0 = np.stack(theta0)
        else:
            theta0 = np.array(torch.stack(theta0).cpu())

        # Set up the sampler
        self.sampler = emcee.EnsembleSampler(
            self.num_chains,
            theta0.shape[-1],
            log_target,
            vectorize=False,
            args=(x,),
        )

        # Sample. NOTE: with thin_by, run_mcmc counts *stored* steps (each
        # one is `thin` proposals) and get_chain's discard also counts
        # stored steps, so burn_in effectively drops burn_in * thin proposals.
        self.sampler.run_mcmc(
            theta0,
            self.burn_in + per_chain,
            thin_by=self.thin,
            progress=progress,
            skip_initial_state_check=skip_initial_state_check
        )
        return self.sampler.get_chain(discard=self.burn_in, flat=True)[:nsteps]
class PyroSampler(_MCMCSampler):
    """Sampler class for pyro's samplers. Integrates with pyro through the
    sbi backend

    Args:
        posterior (Posterior): posterior object to sample from, must have
            a .potential method specifying the log posterior
        num_chains (int, optional): number of chains to sample from. Defaults
            to os.cpu_count()-1
        thin (int, optional): thinning factor for the chains. Defaults to 10
        burn_in (int, optional): number of steps to discard as burn-in.
            Defaults to 100
        method (str, optional): method to use for sampling. Defaults to
            'slice_np_vectorized'. See sbi documentation for more details.
    """

    def __init__(
            self,
            posterior: ModelClass,
            num_chains: int = -1,
            thin: int = 10,
            burn_in: int = 100,
            method='slice_np_vectorized'
    ) -> None:
        # convert DirectPosteriors (alone or inside an ensemble) to
        # MCMCPosteriors, which MCMC sampling requires
        if isinstance(posterior, DirectPosterior):
            posterior = self._Direct_to_MCMC(posterior)
        elif isinstance(posterior, NeuralPosteriorEnsemble):
            posteriors = posterior.posteriors
            posterior = NeuralPosteriorEnsemble(
                [(self._Direct_to_MCMC(p) if isinstance(p, DirectPosterior)
                  else p) for p in posteriors],
                weights=posterior.weights,
                theta_transform=posterior.theta_transform
            )
        super().__init__(posterior, num_chains, thin, burn_in)
        self.method = method

    def _Direct_to_MCMC(self, posterior: ModelClass) -> ModelClass:
        """Converts a DirectPosterior to an MCMCPosterior, which is required
        for sampling with pyro.

        Args:
            posterior (DirectPosterior): posterior object to convert

        Returns:
            MCMCPosterior: converted posterior object
        """
        potential_fn, theta_transform = posterior_estimator_based_potential(
            posterior.posterior_estimator,
            posterior.prior,
            x_o=None,
            enable_transform=True,
        )
        return MCMCPosterior(
            potential_fn=potential_fn,
            proposal=posterior.prior,
            theta_transform=theta_transform,
            device=posterior._device
        )

    def sample(
            self,
            nsteps: int,
            x: np.ndarray,
            progress: bool = False
    ) -> np.ndarray:
        """
        Sample nsteps samples from the posterior, evaluated at data x.

        Args:
            nsteps (int): number of samples to draw
            x (np.ndarray): data to evaluate the posterior at
            progress (bool, optional): whether to show progress bar.
                Defaults to False.

        Returns:
            np.ndarray: the drawn samples, moved to CPU.
        """
        # torch.as_tensor rather than the legacy torch.Tensor(x)
        # constructor: given an int, torch.Tensor allocates an uninitialized
        # tensor of that size instead of wrapping the value. Using the
        # default dtype keeps the dtype torch.Tensor would have produced.
        x = torch.as_tensor(x, dtype=torch.get_default_dtype())
        return self.posterior.sample(
            (nsteps,),
            x=x.to(self.posterior._device),
            method=self.method,
            num_chains=self.num_chains,
            thin=self.thin,
            warmup_steps=self.burn_in,
            show_progress_bars=progress
        ).detach().cpu().numpy()
class DirectSampler(ABC):
    """Sampler class for posteriors with a direct sampling method, i.e.
    amortized posterior inference models.

    Args:
        posterior (Posterior): posterior object to sample from, must have
            a .sample method allowing for direct sampling.
    """

    def __init__(self, posterior: ModelClass) -> None:
        self.posterior = posterior

    def sample(self, nsteps: int, x: Any, progress: bool = False) -> np.ndarray:
        """
        Sample nsteps samples from the posterior, evaluated at data x.

        Args:
            nsteps (int): number of samples to draw
            x (Any): data to evaluate the posterior at. Array-like data is
                converted to a tensor (and moved to the posterior's device
                when it has a ``_device``); anything torch cannot convert is
                passed to the posterior unchanged.
            progress (bool, optional): whether to show progress bar.
                Defaults to False.

        Returns:
            np.ndarray: the drawn samples, moved to CPU.
        """
        try:
            x = torch.as_tensor(x)
        except (TypeError, ValueError, RuntimeError):
            # Not array-like (torch raises RuntimeError "Could not infer
            # dtype" for generic objects, TypeError/ValueError for other
            # unconvertible data): hand x to the posterior as-is.
            pass
        else:
            if hasattr(self.posterior, '_device'):
                x = x.to(self.posterior._device)
        return self.posterior.sample(
            (nsteps,), x=x, show_progress_bars=progress
        ).detach().cpu().numpy()
class VISampler(ABC):
    """Sampler class for variational inference methods. See
    https://sbi-dev.github.io/sbi/reference/#sbi.inference.posteriors.vi_posterior.VIPosterior
    for more details.

    Args:
        posterior (Posterior): posterior object to sample from, must have
            a .potential method specifying the log posterior
        dist (str, optional): distribution to use for the variational
            inference. Defaults to 'maf'.
        **train_kwargs: extra keyword arguments forwarded to the posterior's
            train method. They override this class's defaults
            (show_progress_bar=progress, quality_control=False).
    """

    def __init__(self, posterior: ModelClass, dist: str = 'maf',
                 **train_kwargs) -> None:
        # VI training needs a VIPosterior built from the posterior's
        # potential function
        if isinstance(posterior, DirectPosterior):
            posterior = self._Direct_to_VI(posterior)
        elif isinstance(posterior, NeuralPosteriorEnsemble):
            posterior = VIPosterior(
                potential_fn=posterior.potential_fn,
                prior=posterior.prior,
                theta_transform=posterior.theta_transform,
                device=posterior._device
            )
        super().__init__()
        self.posterior = posterior
        self.dist = dist
        self.train_kwargs = train_kwargs

    def _Direct_to_VI(self, posterior: ModelClass) -> ModelClass:
        """Converts a DirectPosterior to a VIPosterior, which is required for
        sampling with variational inference.

        Args:
            posterior (DirectPosterior): posterior object to convert

        Returns:
            VIPosterior: converted posterior object
        """
        potential_fn, theta_transform = posterior_estimator_based_potential(
            posterior.posterior_estimator,
            posterior.prior,
            x_o=None,
            enable_transform=True,
        )
        return VIPosterior(
            potential_fn=potential_fn,
            prior=posterior.prior,
            theta_transform=theta_transform,
            device=posterior._device
        )

    def sample(
            self,
            nsteps: int,
            x: np.ndarray,
            progress: bool = False
    ) -> np.ndarray:
        """
        Sample nsteps samples from the posterior, evaluated at data x.

        The variational distribution is (re)trained at x on every call.

        Args:
            nsteps (int): number of samples to draw
            x (np.ndarray): data to evaluate the posterior at
            progress (bool, optional): whether to show progress bar.
                Defaults to False.

        Returns:
            np.ndarray: the drawn samples, moved to CPU.
        """
        # torch.as_tensor instead of the legacy torch.Tensor(x) constructor,
        # which turns an int into an uninitialized tensor of that size; the
        # default dtype matches what torch.Tensor would have produced.
        x = torch.as_tensor(x, dtype=torch.get_default_dtype())
        x = x.to(self.posterior._device)
        self.posterior.set_default_x(x)
        self.posterior.set_q(self.dist)
        # Merge rather than pass both: passing quality_control or
        # show_progress_bar in train_kwargs used to raise a duplicate-keyword
        # TypeError; now user-supplied values take precedence.
        train_kwargs = {'show_progress_bar': progress,
                        'quality_control': False}
        train_kwargs.update(self.train_kwargs)
        self.posterior.train(**train_kwargs)
        return self.posterior.sample((nsteps,)).detach().cpu().numpy()