"""
Module to train posterior inference models using the lampe package
"""
import json
import yaml
import time
import logging
import pickle
from copy import deepcopy
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import lampe
from pathlib import Path
from typing import Dict, List, Callable, Optional
from torch.distributions import Distribution
from ili.dataloaders import _BaseLoader
from ili.utils import load_from_config, LampeEnsemble, load_nde_lampe
# Configure the root logger at INFO level when this module is imported, so
# the runner's logging.info progress messages are emitted.
# NOTE(review): basicConfig is a no-op if the root logger already has
# handlers, and configuring logging from a library module is an import-time
# side effect on the importing application; consider moving this to the
# entry-point script.
logging.basicConfig(level=logging.INFO)
class LampeRunner():
    """Class to train NPE posterior inference models using the lampe package.

    Follows methodology of: https://arxiv.org/abs/1711.01861

    Args:
        prior (Distribution): prior on the parameters
        nets (List[Callable]): list of network constructors. Each one is
            called as ``net(x_batch, theta_batch, prior)`` to build the
            model that is then trained.
        engine (str): name of the engine class (NPE only; any other value
            is replaced by 'NPE' with a warning)
        train_args (Dict, optional): hyperparameters for training. Keys not
            given fall back to training_batch_size=50, learning_rate=5e-4,
            stop_after_epochs=30, clip_max_norm=5, max_epochs=int(1e10),
            validation_fraction=0.1.
        out_dir (Path): directory where to store outputs. If None, nothing
            is saved.
        device (str): device to run on
        proposal (Distribution): proposal distribution from which existing
            simulations were run, for single round inference only. By default,
            we will set proposal = prior unless a proposal is specified.
        name (str): name of the model (for saving purposes)
        signatures (List[str]): list of signatures for each neural net
    """

    def __init__(
        self,
        prior: Distribution,
        nets: List[Callable],
        engine: str = 'NPE',
        train_args: Optional[Dict] = None,
        out_dir: Path = None,
        device: str = 'cpu',
        proposal: Distribution = None,
        name: Optional[str] = "",
        signatures: Optional[List[str]] = None,
    ):
        self.prior = prior
        self.nets = nets
        if engine != 'NPE':
            logging.warning(
                'lampe only supports NPE engine. Engine set to NPE.')
        self.engine = 'NPE'
        # Defaults first; any user-supplied keys override them.
        self.train_args = dict(
            training_batch_size=50, learning_rate=5e-4,
            stop_after_epochs=30, clip_max_norm=5,
            max_epochs=int(1e10),
            validation_fraction=0.1)
        if train_args:
            self.train_args.update(train_args)
        self.out_dir = out_dir
        if self.out_dir is not None:
            self.out_dir = Path(self.out_dir)
            self.out_dir.mkdir(parents=True, exist_ok=True)
        self.device = device
        # When the proposal is the prior itself, _loss skips importance
        # weighting.
        self.proposal = prior if proposal is None else proposal
        self.name = name
        self.signatures = signatures
        if self.signatures is None:
            self.signatures = [""] * len(self.nets)

    @classmethod
    def from_config(cls, config_path: Path, **kwargs) -> "LampeRunner":
        """Create a lampe runner from a yaml config file.

        Args:
            config_path (Path): path to config file
            **kwargs: optional keyword arguments to overload config file

        Returns:
            LampeRunner: the lampe runner specified by the config file
        """
        with open(config_path, "r") as fd:
            config = yaml.safe_load(fd)

        # optionally overload config with kwargs
        config.update(kwargs)

        # load prior distribution on the configured device
        config['prior']['args']['device'] = config['device']
        prior = load_from_config(config["prior"])

        # load proposal distribution, if one is configured
        proposal = None
        if "proposal" in config:
            config['proposal']['args']['device'] = config['device']
            proposal = load_from_config(config["proposal"])

        # load embedding net
        if "embedding_net" in config:
            embedding_net = load_from_config(config=config["embedding_net"])
        else:
            embedding_net = nn.Identity()
        # NOTE(review): the same embedding_net instance is passed to every
        # net below; unless load_nde_lampe copies it, trainable embedding
        # weights are shared between ensemble members -- confirm.

        # load logistics
        train_args = config.get("train_args", {})
        # A missing or null out_dir disables saving; __init__ converts any
        # other value to a Path.
        out_dir = config.get("out_dir")
        if "name" in config["model"]:
            name = config["model"]["name"] + "_"
        else:
            name = ""

        # Strip the runner-only "signature" key so the remaining entries can
        # be forwarded to load_nde_lampe as keyword arguments.
        signatures = []
        for model_args in config["model"]["nets"]:
            signatures.append(model_args.pop("signature", ""))

        # load inference class and neural nets
        nets = [
            load_nde_lampe(embedding_net=embedding_net,
                           device=config["device"],
                           **model_args)
            for model_args in config['model']['nets']
        ]

        return cls(
            prior=prior,
            nets=nets,
            train_args=train_args,
            out_dir=out_dir,
            device=config["device"],
            proposal=proposal,
            name=name,
            signatures=signatures,
        )

    def _prepare_loader(self, loader: "_BaseLoader"):
        """Return ``(train_loader, val_loader)`` for training.

        If ``loader`` already exposes ``train_loader`` and ``val_loader``,
        those are returned unchanged. Otherwise the full dataset is read
        through ``get_all_data``/``get_all_parameters`` and moved to
        ``self.device``. It is then randomly split so that
        ``int(validation_fraction * N)`` pairs form the validation set.

        Raises:
            ValueError: if ``loader`` provides neither interface.
        """
        if (hasattr(loader, "train_loader") and
                hasattr(loader, "val_loader")):
            return loader.train_loader, loader.val_loader

        if not (hasattr(loader, "get_all_data") and
                hasattr(loader, "get_all_parameters")):
            raise ValueError("Loader must be a subclass of _BaseLoader.")

        x, theta = loader.get_all_data(), loader.get_all_parameters()
        # move to device
        x = torch.Tensor(x).to(self.device)
        theta = torch.Tensor(theta).to(self.device)

        # A random permutation compared against the cutoff marks exactly
        # int(validation_fraction * N) random samples for validation.
        mask = torch.randperm(len(x)) < int(
            self.train_args['validation_fraction'] * len(x))
        data_train = TensorDataset(x[~mask], theta[~mask])
        data_val = TensorDataset(x[mask], theta[mask])

        batch_size = self.train_args["training_batch_size"]
        train_loader = DataLoader(
            data_train, shuffle=True, batch_size=batch_size)
        val_loader = DataLoader(
            data_val, shuffle=False, batch_size=batch_size)
        return train_loader, val_loader

    def _loss(self, model, theta, x):
        """Return the negative importance-weighted log posterior (mean).

        When the proposal is the prior, all weights are one and this is the
        mean negative log posterior. Otherwise each sample's log posterior
        is weighted by ``prior(theta) / proposal(theta)``.
        """
        log_posterior = model(theta, x)
        if self.prior is self.proposal:
            return -log_posterior.mean()
        log_weights = (self.prior.log_prob(theta)
                       - self.proposal.log_prob(theta))
        return -(torch.exp(log_weights) * log_posterior).mean()

    def _train_epoch(self, model, train_loader, val_loader, stepper):
        """Train ``model`` for one epoch and evaluate it on validation data.

        ``stepper`` is called with each training-batch loss and must return
        that loss.

        Returns:
            (float, float): per-sample mean training loss and validation
            loss for this epoch.
        """
        model.train()
        loss_train, count = [], 0
        for x, theta in train_loader:
            x, theta = x.to(self.device), theta.to(self.device)
            # weight each batch loss by its size so the epoch average is
            # per-sample even when the last batch is smaller
            loss_train.append(
                stepper(self._loss(model, theta, x)) * len(theta))
            count += len(theta)
        loss_train = torch.stack(loss_train).sum().item() / count

        model.eval()
        with torch.no_grad():
            loss_val, count = [], 0
            for x, theta in val_loader:
                x, theta = x.to(self.device), theta.to(self.device)
                loss_val.append(self._loss(model, theta, x) * len(theta))
                count += len(theta)
            loss_val = torch.stack(loss_val).sum().item() / count
        return loss_train, loss_val

    def _fit_model(self, model, train_loader, val_loader):
        """Train one model with Adam and early stopping; return its summary.

        Training stops once the validation loss has failed to improve for
        ``stop_after_epochs`` consecutive epochs, or after ``max_epochs``.
        On return, ``model`` holds the parameters from the epoch with the
        lowest validation loss.
        """
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=self.train_args["learning_rate"]
        )
        stepper = lampe.utils.GDStep(
            optimizer, clip=self.train_args["clip_max_norm"])

        patience = self.train_args["stop_after_epochs"]
        best_val = float('inf')
        # Snapshot the initial weights so there is always a state to restore,
        # even if no epoch ever improves on best_val (e.g. NaN losses).
        best_state = deepcopy(model.state_dict())
        wait = 0
        epochs_trained = 0
        summary = {'training_log_probs': [], 'validation_log_probs': []}
        with tqdm(range(self.train_args["max_epochs"]),
                  unit=' epochs') as tq:
            for epoch in tq:
                loss_train, loss_val = self._train_epoch(
                    model=model,
                    train_loader=train_loader,
                    val_loader=val_loader,
                    stepper=stepper,
                )
                epochs_trained = epoch + 1
                tq.set_postfix(
                    loss=loss_train,
                    loss_val=loss_val,
                )
                summary['training_log_probs'].append(-loss_train)
                summary['validation_log_probs'].append(-loss_val)
                # check for convergence
                if loss_val < best_val:
                    best_val = loss_val
                    best_state = deepcopy(model.state_dict())
                    wait = 0
                else:
                    wait += 1
                    if wait >= patience:
                        break
            else:
                logging.warning(
                    "Training did not converge in "
                    f"{self.train_args['max_epochs']} epochs.")
        summary['best_validation_log_prob'] = -best_val
        summary['epochs_trained'] = epochs_trained
        # restore the best-validation weights
        model.load_state_dict(best_state)
        return summary

    def _train_round(self, models: List[Callable],
                     train_loader: DataLoader, val_loader: DataLoader):
        """Train a single round of inference for an ensemble of models.

        Each constructor in ``models`` receives the first training batch and
        the prior, and the resulting networks are trained independently.
        They are then combined into a ``LampeEnsemble`` whose weights are
        proportional to ``exp(best validation log-prob)``.

        Returns:
            (LampeEnsemble, List[Dict]): the ensemble and one training
            summary per model.
        """
        # initialize models (constructors presumably use the first batch to
        # infer input dimensions -- see load_nde_lampe)
        x_, y_ = next(iter(train_loader))
        models_rnd = [
            model(x_, y_, self.prior).to(self.device)
            for model in models
        ]

        posteriors, summaries = [], []
        for i, model in enumerate(models_rnd):
            logging.info(f"Training model {i+1} / {len(models_rnd)}.")
            summary = self._fit_model(model, train_loader, val_loader)
            posteriors.append(model)
            summaries.append(summary)

        # ensemble all trained models, weighted by validation loss
        val_logprob = torch.tensor(
            [float(s["best_validation_log_prob"]) for s in summaries]
        ).to(self.device)
        # subtract the max before exponentiating for numerical stability
        weights = torch.exp(val_logprob - val_logprob.max())
        weights /= weights.sum()
        posterior_ensemble = LampeEnsemble(posteriors, weights)

        # record the name and signatures on the ensemble
        posterior_ensemble.name = self.name
        posterior_ensemble.signatures = self.signatures
        return posterior_ensemble, summaries

    def _save_models(self, posterior_ensemble: "LampeEnsemble",
                     summaries: List[Dict]):
        """Pickle the ensemble and write the summaries as JSON to out_dir.

        Files are named ``<name>posterior.pkl`` and ``<name>summary.json``.
        """
        logging.info(f"Saving model to {self.out_dir}")
        str_p = self.name + "posterior.pkl"
        str_s = self.name + "summary.json"
        with open(self.out_dir / str_p, "wb") as handle:
            pickle.dump(posterior_ensemble, handle)
        with open(self.out_dir / str_s, "w") as handle:
            json.dump(summaries, handle)

    def __call__(self, loader: "_BaseLoader", seed: int = None):
        """Train your posterior and save it to file.

        Args:
            loader (_BaseLoader): dataloader with stored data-parameter pairs
            seed (int): torch seed for reproducibility

        Returns:
            (LampeEnsemble, List[Dict]): the trained ensemble and one
            training summary per network.
        """
        # set seed for reproducibility
        if seed is not None:
            torch.manual_seed(seed)

        # setup training engines for each model in the ensemble
        logging.info("MODEL INFERENCE CLASS: NPE")

        # load single-round data
        train_loader, val_loader = self._prepare_loader(loader)

        # train a single round of inference
        t0 = time.time()
        posterior_ensemble, summaries = self._train_round(
            models=self.nets,
            train_loader=train_loader,
            val_loader=val_loader,
        )
        logging.info(f"It took {time.time() - t0} seconds to train models.")

        # save if output path is specified
        if self.out_dir is not None:
            self._save_models(posterior_ensemble, summaries)

        return posterior_ensemble, summaries