"""
Module to train posterior inference models using the sbi package
"""
import json
import yaml
import time
import logging
import pickle
import torch
import torch.nn as nn
from pathlib import Path
from typing import Dict, List, Callable, Optional, Union
from torch.distributions import Distribution
from sbi.inference import NeuralInference
from sbi.utils.posterior_ensemble import NeuralPosteriorEnsemble
from .base import _BaseRunner
from ili.dataloaders import _BaseLoader
from ili.utils import load_class, load_from_config, load_nde_sbi, update
logging.basicConfig(level=logging.INFO)
[docs]
class SBIRunner(_BaseRunner):
"""Class to train posterior inference models using the sbi package.
Follows methodology of:
* engine='NPE': https://arxiv.org/abs/1905.07488
* engine='NLE': https://arxiv.org/abs/1805.07226
* engine='NRE': https://arxiv.org/pdf/2002.03712
Args:
prior (Distribution): prior on the parameters
engine (str): type of inference engine to use (NPE, NLE, NRE, or
any sbi inference engine; see _setup_engine)
nets (List[Callable]): list of neural nets for amortized posteriors,
likelihood models, or ratio classifiers
embedding_net (nn.Module): neural network to compress high
dimensional data into lower dimensionality
train_args (Dict): dictionary of hyperparameters for training
out_dir (str, Path): directory where to store outputs
proposal (Distribution): proposal distribution from which existing
simulations were run, for single round inference only. By default,
sbi will set proposal = prior unless a proposal is specified.
name (str): name of the model (for saving purposes)
signatures (List[str]): list of signatures for each neural net
"""
def __init__(
self,
prior: Distribution,
engine: str,
nets: List[Callable],
train_args: Dict = {},
out_dir: Union[str, Path] = None,
device: str = 'cpu',
embedding_net: nn.Module = None,
proposal: Distribution = None,
name: Optional[str] = "",
signatures: Optional[List[str]] = None,
):
super().__init__(
prior=prior,
train_args=train_args,
out_dir=out_dir,
device=device,
name=name,
)
if proposal is None:
self.proposal = prior
else:
self.proposal = proposal
self.engine = engine
self.nets = nets
self.embedding_net = embedding_net
self.num_rounds = self.train_args.pop("num_round", 1)
train_default = dict(
training_batch_size=50,
learning_rate=5e-4,
validation_fraction=0.1,
stop_after_epochs=20,
clip_max_norm=5,
)
train_default.update(self.train_args)
self.train_args = train_default
self.signatures = signatures
if self.signatures is None:
self.signatures = [""]*len(self.nets)
[docs]
@classmethod
def from_config(cls, config_path: Path, **kwargs) -> "SBIRunner":
"""Create an sbi runner from a yaml config file
Args:
config_path (Path, optional): path to config file
**kwargs: optional keyword arguments to overload config file
Returns:
SBIRunner: the sbi runner specified by the config file
"""
with open(config_path, "r") as fd:
config = yaml.safe_load(fd)
# optionally overload config with kwargs
update(config, **kwargs)
# load prior distribution
config['prior']['args']['device'] = config['device']
prior = load_from_config(config["prior"])
# load proposal distributions
proposal = None
if "proposal" in config:
config['proposal']['args']['device'] = config['device']
proposal = load_from_config(config["proposal"])
# load embedding net
if "embedding_net" in config:
embedding_net = load_from_config(
config=config["embedding_net"],
)
else:
embedding_net = nn.Identity()
# load logistics
train_args = config["train_args"]
out_dir = Path(config["out_dir"])
if "name" in config["model"]:
name = config["model"]["name"]+"_"
else:
name = ""
signatures = []
for type_nn in config["model"]["nets"]:
signatures.append(type_nn.pop("signature", ""))
# load inference class and neural nets
engine = config["model"]["engine"]
nets = [load_nde_sbi(config['model']['engine'],
embedding_net=embedding_net,
**model_args)
for model_args in config['model']['nets']]
# initialize
return cls(
prior=prior,
proposal=proposal,
engine=engine,
nets=nets,
device=config["device"],
embedding_net=embedding_net,
train_args=train_args,
out_dir=out_dir,
signatures=signatures,
name=name,
)
def _setup_engine(self, net: nn.Module):
"""Instantiate an sbi inference engine (SNPE/SNLE/SNRE)."""
if self.engine[0] == 'S':
engine_name = self.engine
else:
engine_name = 'S'+self.engine
try:
inference_class = load_class('sbi.inference', engine_name)
except ImportError:
raise ValueError(
f"Model class {self.engine} not supported. "
"Please choose one of NPE/NLE/NRE or SNPE/SNLE/SNRE or "
"an inference class in sbi.inference."
)
if ("NPE" in self.engine) or ("NLE" in self.engine):
return inference_class(
prior=self.prior,
density_estimator=net,
device=self.device,
)
elif ("NRE" in self.engine):
return inference_class(
prior=self.prior,
classifier=net,
device=self.device,
)
else:
raise ValueError(
f"Model class {self.engine} not supported with SBIRunner.")
def _train_round(self, models: List[NeuralInference],
x: torch.Tensor, theta: torch.Tensor,
proposal: Optional[Distribution]):
"""Train a single round of inference for an ensemble of models."""
# append data to models
for model in models:
if ("NPE" in self.engine):
model = model.append_simulations(theta, x, proposal=proposal)
else:
model = model.append_simulations(theta, x)
# get all previous simulations
starting_round = 0 # NOTE: won't work for SNPE_A, but we don't use it
x, _, _ = model.get_simulations(starting_round)
# split into training and validation randomly
num_examples = x.shape[0]
permuted_indices = torch.randperm(num_examples)
num_training_examples = int(
(1 - self.train_args['validation_fraction']) * num_examples)
train_indices, val_indices = (
permuted_indices[:num_training_examples],
permuted_indices[num_training_examples:],
)
posteriors, summaries = [], []
for i, model in enumerate(models):
logging.info(f"Training model {i+1} / {len(models)}.")
# hack to initialize sbi model without training (ref. issue #127)
first_round = False
if model._neural_net is None:
model.train(learning_rate=self.train_args['learning_rate'],
resume_training=False,
max_num_epochs=-1)
model._epochs_since_last_improvement = 0
first_round = True
# set train/validation splits
model.train_indices = train_indices
model.val_indices = val_indices
# train
if ("NPE" in self.engine) & first_round:
model.train(**self.train_args, resume_training=True,
force_first_round_loss=True)
else:
model.epoch, model._val_log_prob = 0, float("-Inf")
model.train(**self.train_args, resume_training=True)
# save model
posteriors.append(model.build_posterior())
summaries.append(model.summary)
# ensemble all trained models, weighted by validation loss
val_logprob = torch.tensor(
[float(x["best_validation_log_prob"][-1]) for x in summaries]
).to(self.device)
# Exponentiate with numerical stability
weights = torch.exp(val_logprob - val_logprob.max())
weights /= weights.sum()
posterior_ensemble = NeuralPosteriorEnsemble(
posteriors=posteriors,
weights=weights,
theta_transform=posteriors[0].theta_transform
) # raises warning due to bug in sbi
# record the name of the ensemble
posterior_ensemble.name = self.name
posterior_ensemble.signatures = self.signatures
return posterior_ensemble, summaries
def _save_models(self, posterior_ensemble: NeuralPosteriorEnsemble,
summaries: List[Dict]):
"""Save models to file."""
logging.info(f"Saving model to {self.out_dir}")
str_p = self.name + "posterior.pkl"
str_s = self.name + "summary.json"
with open(self.out_dir / str_p, "wb") as handle:
pickle.dump(posterior_ensemble, handle)
with open(self.out_dir / str_s, "w") as handle:
json.dump(summaries, handle)
[docs]
def __call__(self, loader: _BaseLoader, seed: int = None):
"""Train your posterior and save it to file
Args:
loader (_BaseLoader): dataloader with stored data-parameter pairs
seed (int): torch seed for reproducibility
"""
# set seed for reproducibility
if seed is not None:
torch.manual_seed(seed)
# setup training engines for each model in the ensemble
logging.info(f"MODEL INFERENCE CLASS: {self.engine}")
models = [self._setup_engine(net) for net in self.nets]
# load single-round data
x = torch.Tensor(loader.get_all_data()).to(self.device)
theta = torch.Tensor(loader.get_all_parameters()).to(self.device)
# instantiate embedding_net architecture, if necessary
if self.embedding_net and hasattr(self.embedding_net, 'initalize_model'):
self.embedding_net.initalize_model(n_input=x.shape[-1])
# train a single round of inference
t0 = time.time()
posterior_ensemble, summaries = self._train_round(
models=models,
x=x,
theta=theta,
proposal=self.proposal,
)
logging.info(f"It took {time.time() - t0} seconds to train models.")
# save if output path is specified
if self.out_dir is not None:
self._save_models(posterior_ensemble, summaries)
return posterior_ensemble, summaries
[docs]
class SBIRunnerSequential(SBIRunner):
"""
Class to train posterior inference models using the sbi package with
multiple rounds.
Follows methodology of:
* engine='SNPE': https://arxiv.org/abs/1905.07488
* engine='SNLE': https://arxiv.org/abs/1805.07226
* engine='SNRE': https://arxiv.org/pdf/2002.03712
"""
[docs]
def __call__(self, loader: _BaseLoader, seed: int = None):
"""Train your posterior and save it to file
Args:
loader (_BaseLoader): data loader with ability to simulate
data-parameter pairs
"""
# Check arguments
if not hasattr(loader, "get_obs_data"):
raise ValueError(
"For sequential inference, the loader must have a method "
"get_obs_data() that returns the observed data."
)
if not hasattr(loader, "simulate"):
raise ValueError(
"For sequential inference, the loader must have a method "
"simulate() that returns simulated data-parameter pairs."
)
# set seed for reproducibility
if seed is not None:
torch.manual_seed(seed)
# setup training engines for each model in the ensemble
logging.info(f"MODEL INFERENCE CLASS: {self.engine}")
models = [self._setup_engine(net) for net in self.nets]
# load observed and pre-run data
x_obs = loader.get_obs_data()
# pre-run data
if len(loader) > 0:
logging.info(
"The first round of inference will use existing sims from the "
"loader. Make sure that the simulations were run from the "
"given proposal distribution for consistency.")
x = torch.Tensor(loader.get_all_data()).to(self.device)
theta = torch.Tensor(loader.get_all_parameters()).to(self.device)
# no pre-run data
else:
logging.info(
"The first round of inference will simulate from the given "
"proposal or prior.")
theta, x = loader.simulate(self.proposal)
x = torch.Tensor(x).to(self.device)
theta = torch.Tensor(theta).to(self.device)
# instantiate embedding_net architecture, if necessary
if self.embedding_net and hasattr(self.embedding_net, 'initalize_model'):
self.embedding_net.initalize_model(n_input=x.shape[-1])
# train multiple rounds of inference
t0 = time.time()
for rnd in range(self.num_rounds):
logging.info(f"Running round {rnd+1} / {self.num_rounds}")
# train a round of inference
posterior_ensemble, summaries = self._train_round(
models=models,
x=x,
theta=theta,
proposal=self.proposal,
)
# update proposal for next round
self.proposal = posterior_ensemble.set_default_x(x_obs)
if rnd < self.num_rounds - 1:
# simulate new data for next round
theta, x = loader.simulate(self.proposal)
x = torch.Tensor(x).to(self.device)
theta = torch.Tensor(theta).to(self.device)
logging.info(f"It took {time.time() - t0} seconds to train models.")
if self.out_dir is not None:
self._save_models(posterior_ensemble, summaries)
return posterior_ensemble, summaries
[docs]
class ABCRunner(_BaseRunner):
"""Class to run ABC inference models using the sbi package"""
def __init__(
self,
prior: Distribution,
engine: str,
train_args: Dict = {},
out_dir: Union[str, Path] = None,
device: str = 'cpu',
name: Optional[str] = "",
):
super().__init__(
prior=prior,
train_args=train_args,
out_dir=out_dir,
device=device,
name=name,
)
self.engine = engine
[docs]
@classmethod
def from_config(cls, config_path: Path, **kwargs) -> "ABCRunner":
"""Create an sbi runner from a yaml config file
Args:
config_path (Path, optional): path to config file
**kwargs: optional keyword arguments to overload config file
Returns:
SBIRunner: the sbi runner specified by the config file
"""
with open(config_path, "r") as fd:
config = yaml.safe_load(fd)
# optionally overload config with kwargs
update(config, **kwargs)
# load prior distribution
prior = load_from_config(config["prior"])
# parse inference engine
engine = config["model"]["engine"]
# load logistics
train_args = config["train_args"]
out_dir = Path(config["out_dir"])
name = ""
if "name" in config["model"]:
name = config["model"]["name"]+"_"
return cls(
prior=prior,
engine=engine,
device=config["device"],
train_args=train_args,
out_dir=out_dir,
name=name,
)
[docs]
def __call__(self, loader: _BaseLoader, seed: int = None):
"""Train your posterior and save it to file
Args:
loader (_BaseLoader): dataloader with stored data-parameter pairs
seed (int): torch seed for reproducibility
"""
t0 = time.time()
logging.info(f"MODEL INFERENCE CLASS: {self.engine}")
x_obs = loader.get_obs_data()
# setup and train each architecture
inference_class = load_class('sbi.inference', self.engine)
model = inference_class(
prior=self.prior,
simulator=loader.simulator
)
samples = model(x_obs, return_summary=False, **self.train_args)
# save if output path is specified
if self.out_dir is not None:
str_p = self.name + "samples.pkl"
with open(self.out_dir / str_p, "wb") as handle:
pickle.dump(samples, handle)
logging.info(
f"It took {time.time() - t0} seconds to run the model.")
return samples