Source code for utils.ndes_pt

"""
Module to provide loading functions for ndes in various backends.

All Mixture Density Networks (mdn) have the configuration:
    hidden_features (int): width of hidden layers (each MDN has 3 hidden layers)
    num_components (int): number of Gaussian components in the mixture model

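All flow-based models (maf, nsf, made; lampe additionally offers nice, gf,
sospf, naf, unaf and cnf) have the configuration:
    hidden_features (int): width of the hidden layers used by each transform
    num_transforms (int): number of stacked transforms (for the lampe cnf
        model this instead sets the number of time-embedding frequencies)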

Linear classifiers (linear) have no arguments.

MLP and ResNet classifiers (mlp, resnet) have the configuration:
    hidden_features (int): width of hidden layers (each has 2 hidden layers)
"""

import logging

import numpy as np
import sbi
import torch
from torch import nn
import lampe
import zuko
from tqdm import tqdm
from typing import List, Any, Optional
from copy import deepcopy
from torch.distributions import Distribution
from torch.distributions.transforms import (
    identity_transform, AffineTransform, Transform)


def load_nde_sbi(
    engine: str,
    model: str,
    embedding_net: nn.Module = nn.Identity(),
    **model_args
):
    """Build an sbi neural density estimator (or classifier) factory.

    Args:
        engine (str): inference engine. Matched by substring, so e.g. 'SNPE'
            takes the NPE branch. Supported families: NRE, NPE, NLE.
        model (str): architecture name. NRE engines accept linear, mlp and
            resnet; NPE/NLE engines accept mdn, maf, nsf and made.
        embedding_net (nn.Module, optional): embedding network. For NRE it is
            passed as the x-embedding of the classifier; for NLE it ends up
            compressing theta (a warning is logged). Defaults to
            nn.Identity().
        **model_args: forwarded to the sbi builder. For mdn only
            hidden_features and num_components are accepted; for the flows
            only hidden_features and num_transforms. Not checked for NRE.

    Returns:
        Whatever sbi.utils.classifier_nn / posterior_nn / likelihood_nn
        returns for the selected engine.

    Raises:
        ValueError: if the model is not available for the engine, if
            model_args contains unexpected keys, or if the engine matches
            none of NRE/NPE/NLE.
    """
    # Ratio estimators are classifiers with their own set of architectures.
    if 'NRE' in engine:
        if model in ('linear', 'mlp', 'resnet'):
            return sbi.utils.classifier_nn(
                model=model, embedding_net_x=embedding_net, **model_args)
        raise ValueError(f"Model {model} not implemented for {engine}.")

    if model not in ('mdn', 'maf', 'nsf', 'made'):
        raise ValueError(f"Model {model} not implemented for {engine}.")

    # Mixture density networks and flows take different hyperparameters.
    if model == 'mdn':
        accepted = {'hidden_features', 'num_components'}
    else:
        accepted = {'hidden_features', 'num_transforms'}
    if not set(model_args).issubset(accepted):
        raise ValueError(f"Model {model} arguments mispecified.")

    if 'NPE' in engine:
        return sbi.utils.posterior_nn(
            model=model, embedding_net=embedding_net, **model_args)

    if 'NLE' in engine:
        uses_embedding = not isinstance(embedding_net, nn.Identity)
        if uses_embedding:
            logging.warning(
                "Using an embedding_net with NLE models compresses theta, not "
                "x as might be expected.")
        return sbi.utils.likelihood_nn(
            model=model, embedding_net=embedding_net, **model_args)

    raise ValueError(f"Engine {engine} not implemented.")
class LampeNPE(nn.Module):
    """Simple wrapper to add an embedding network to an NPE model.

    The wrapped ``nde`` works in a transformed space: ``theta`` is mapped
    through ``theta_transform.inv`` and ``x`` through ``x_transform.inv``
    followed by ``embedding_net``. This wrapper performs those mappings and
    corrects log-densities back to the original ``theta`` space.

    Args:
        nde (nn.Module): density estimator, called as ``nde(theta, z)`` to
            get log-densities and exposing ``nde.flow(z)`` for sampling.
        prior (Distribution): prior whose ``support`` is used to reject
            out-of-support samples in :meth:`sample`.
        embedding_net (nn.Module, optional): network applied to the
            transformed ``x``. Defaults to a fresh ``nn.Identity()``.
        x_transform (Transform, optional): transform whose inverse maps raw
            ``x`` into the space fed to ``embedding_net``.
        theta_transform (Transform, optional): transform whose inverse maps
            raw ``theta`` into the space modeled by ``nde``.
    """

    def __init__(
        self,
        nde: nn.Module,
        prior: Distribution,
        embedding_net: Optional[nn.Module] = None,
        x_transform: Transform = identity_transform,
        theta_transform: Transform = identity_transform
    ):
        super().__init__()
        self.nde = nde
        self.prior = prior
        # Build the default per instance rather than sharing one module
        # object between every LampeNPE created with the default.
        self.embedding_net = (
            nn.Identity() if embedding_net is None else embedding_net)
        self.x_transform = x_transform
        self.theta_transform = theta_transform
        self._device = 'cpu'
        # number of candidates drawn per accept-reject round in sample()
        self.max_sample_size = 1000

    def forward(
        self,
        theta: torch.Tensor,
        x: Any
    ) -> torch.Tensor:
        """Evaluate log p(theta | x) in the original (untransformed) space.

        Args:
            theta: parameters; lists/arrays are converted to float tensors.
            x: data; lists/arrays are converted to float tensors.

        Returns:
            torch.Tensor: log-densities with the batch shape of ``nde``'s
            output.
        """
        if isinstance(x, (list, np.ndarray)):
            x = torch.Tensor(x)
        if isinstance(theta, (list, np.ndarray)):
            theta = torch.Tensor(theta)
        x = x.to(self._device)
        theta = theta.to(self._device)

        logprob = self.nde(
            self.theta_transform.inv(theta),
            self.embedding_net(self.x_transform.inv(x)))

        # Change-of-variables correction. theta is passed twice only for its
        # shape; for Affine/Identity transforms the value is constant.
        log_abs_det_jacobian = self.theta_transform.log_abs_det_jacobian(
            theta, theta)
        # Element-wise transforms (event_dim 0, e.g. identity_transform)
        # return one term per component of theta; reduce them over the
        # parameter dimension so the correction matches logprob's shape.
        if self.theta_transform.domain.event_dim == 0:
            log_abs_det_jacobian = log_abs_det_jacobian.sum(dim=-1)
        return logprob - log_abs_det_jacobian

    potential = forward

    def flow(self, x: torch.Tensor):  # -> Distribution
        """Return the nde's conditional flow given (transformed, embedded) x."""
        if hasattr(x, 'float'):
            x = x.float()
        return self.nde.flow(
            self.embedding_net(self.x_transform.inv(x)).float())

    def sample(
        self,
        shape: tuple,
        x: torch.Tensor,
        show_progress_bars: bool = True
    ) -> torch.Tensor:
        """Accept-reject sampling.

        Candidates are drawn from the conditional flow in batches of at most
        ``max_sample_size``, mapped back through ``theta_transform``, and
        kept only if they lie in the prior's support. The loop runs until
        enough samples are accepted, so it will not terminate if the flow
        never lands inside the support.

        Args:
            shape (tuple): sample shape; ``prod(shape)`` samples are drawn.
            x: conditioning data; lists/arrays are converted to tensors.
            show_progress_bars (bool, optional): show a tqdm progress bar.

        Returns:
            torch.Tensor: samples reshaped to ``(*shape, -1)``, or
            ``torch.empty(shape)`` when ``shape`` describes zero samples.
        """
        if isinstance(x, (list, np.ndarray)):
            x = torch.Tensor(x)
        x = x.to(self._device)

        # np.prod(()) is the float 1.0, so cast before using it as a size.
        num_samples = int(np.prod(shape))
        if num_samples == 0:
            return torch.empty(shape)

        # The conditional distribution depends only on x: build it once.
        posterior = self.flow(x)
        batch_size = min(self.max_sample_size, num_samples)
        num_remaining = num_samples
        accepted = []
        with tqdm(
            disable=not show_progress_bars,
            total=num_samples,
            desc=f"Drawing {num_samples} posterior samples",
        ) as pbar:
            while num_remaining > 0:
                candidates = self.theta_transform(
                    posterior.sample((batch_size,)))
                are_accepted = self.prior.support.check(candidates)
                samples = candidates[are_accepted]
                accepted.append(samples)
                num_remaining -= len(samples)
                pbar.update(len(samples))

        samples = torch.cat(accepted, dim=0)[:num_samples]
        return samples.reshape(*shape, -1)

    def to(self, device):
        """Move the module and remember the device used for inputs."""
        self._device = device
        return super().to(device)
class LampeEnsemble(nn.Module):
    """Simple module to wrap an ensemble of NPE models.

    Args:
        posteriors (List[LampeNPE]): ensemble members. The prior and the
            initial device are taken from the first member.
        weights (torch.Tensor): one weight per member. They are normalized
            by their sum when splitting samples in :meth:`sample`, and used
            as given in :meth:`forward` / :meth:`log_prob`.

    Raises:
        ValueError: if ``posteriors`` is empty or the number of weights does
            not match the number of posteriors.
    """

    def __init__(
        self,
        posteriors: "List[LampeNPE]",
        weights: torch.Tensor
    ):
        super().__init__()
        self.posteriors = nn.ModuleList(posteriors)
        self.weights = weights
        # Explicit checks instead of `assert`, which is stripped under -O.
        if len(self.posteriors) == 0:
            raise ValueError("LampeEnsemble needs at least one posterior.")
        if len(self.posteriors) != len(self.weights):
            raise ValueError(
                f"Got {len(self.weights)} weights for "
                f"{len(self.posteriors)} posteriors; expected one weight per "
                "posterior.")
        self.prior = self.posteriors[0].prior
        self._device = self.posteriors[0]._device
        self.num_components = len(self.posteriors)

    def forward(self, theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Stack ``weight_i * log p_i(theta | x)`` along the last dimension."""
        return torch.stack([
            weight * npe(theta, x)
            for weight, npe in zip(self.weights, self.posteriors)
        ], dim=-1)

    potential = forward

    def _allocate_samples(self, num_samples: int) -> torch.Tensor:
        """Split num_samples across members proportionally to the weights.

        Uses largest-remainder rounding so the counts sum exactly to
        ``num_samples`` (plain rounding can over- or under-shoot).
        """
        weights = torch.as_tensor(self.weights, dtype=torch.float64)
        expected = num_samples * weights / weights.sum()
        counts = torch.floor(expected).long()
        remainder = num_samples - int(counts.sum())
        if remainder > 0:
            # hand the leftover samples to the largest fractional parts
            order = torch.argsort(expected - counts, descending=True)
            counts[order[:remainder]] += 1
        return counts

    def sample(
        self,
        shape: tuple,
        x: Any,
        show_progress_bars: bool = True
    ):
        """Draw samples from the members, split proportionally to weights.

        Args:
            shape (tuple): sample shape; ``prod(shape)`` samples are drawn.
            x: conditioning data, passed unchanged to each member.
            show_progress_bars (bool, optional): forwarded to the members;
                also enables logging of the per-member sample counts.

        Returns:
            torch.Tensor: samples reshaped to ``(*shape, -1)``, grouped by
            member in ensemble order (not shuffled), or
            ``torch.empty(shape)`` when ``shape`` describes zero samples.
        """
        num_samples = int(np.prod(shape))
        if num_samples == 0:
            return torch.empty(shape)

        per_model = self._allocate_samples(num_samples)
        if show_progress_bars:
            logging.info(f"Sampling models with {per_model} samples each.")

        # Members allocated zero samples are skipped entirely.
        samples = torch.cat([
            nde.sample((int(n),), x, show_progress_bars=show_progress_bars)
            for nde, n in zip(self.posteriors, per_model)
            if n > 0
        ], dim=0)
        return samples.reshape(*shape, -1)

    def log_prob(self, theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return the (detached) weighted sum of member log-densities.

        NOTE(review): this computes sum_i w_i * log p_i(theta | x), not the
        log-density of the weighted mixture that :meth:`sample` draws from
        (log sum_i w_i p_i). Confirm this is intended.
        """
        return self.forward(theta, x).sum(dim=-1).detach()

    def to(self, device):
        """Move the ensemble and record the device on every member."""
        self._device = device
        # nn.Module.to() moves child parameters via _apply() but never calls
        # the children's own .to(), so each member's _device would go stale.
        for posterior in self.posteriors:
            posterior.to(device)
        return super().to(device)
def load_nde_lampe(
    model: str,
    embedding_net: nn.Module = nn.Identity(),
    device: Optional[str] = 'cpu',
    x_normalize: bool = True,
    theta_normalize: bool = True,
    **model_args
):
    """Load an nde from lampe.

    Models include:
    - mdn: Mixture Density Network
        (https://publications.aston.ac.uk/id/eprint/373/1/NCRG_94_004.pdf)
    - maf: Masked Autoregressive Flow (https://arxiv.org/abs/1705.07057)
    - nsf: Neural Spline Flow (https://arxiv.org/abs/1906.04032)
    - cnf: Continuous Normalizing Flow (https://arxiv.org/abs/1810.01367)
    - nice: Non-linear Independent Components Estimation
        (https://arxiv.org/abs/1410.8516)
    - gf: Gaussianization Flow (https://arxiv.org/abs/2003.01941)
    - sospf: Sum-of-Squares Polynomial Flow (https://arxiv.org/abs/1905.02325)
    - naf: Neural Autoregressive Flow (https://arxiv.org/abs/1804.00779)
    - unaf: Unconstrained Neural Autoregressive Flow
        (https://arxiv.org/abs/1908.05164)

    For more info, see zuko at https://zuko.readthedocs.io/en/stable/index.html

    Args:
        model (str): model to use. One of: mdn, maf, nsf, cnf, nice, gf,
            sospf, naf, unaf.
        embedding_net (nn.Module, optional): embedding network to use. A
            private deep copy is taken here, and every call of the returned
            constructor builds its model around its own copy. Defaults to
            nn.Identity().
        device (str, optional): device to use. Defaults to 'cpu'.
        x_normalize (bool, optional): whether to z-normalize x.
            Defaults to True.
        theta_normalize (bool, optional): whether to z-normalize theta.
            Defaults to True.
        **model_args: model configuration. mdn accepts hidden_features and
            num_components; the discrete flows accept hidden_features and
            num_transforms; cnf maps num_transforms to zuko's ``freqs`` and
            passes any other keys through. If hidden_features is omitted,
            zuko's own default is used.

    Returns:
        Callable: ``net_constructor(x_batch, theta_batch, prior)``, which
        sizes and normalizes the model from a batch of training data and
        returns a ``LampeNPE``.

    Raises:
        ValueError: if ``model`` is not supported, or if model_args contains
            unexpected keys (mdn and discrete flows).
    """
    supported = (
        'mdn', 'cnf', 'maf', 'nsf', 'nice', 'gf', 'sospf', 'naf', 'unaf')
    # Fail fast: otherwise flow_class is never bound and the error only
    # surfaces later, inside net_constructor.
    if model not in supported:
        raise ValueError(f"Model {model} not implemented for lampe.")

    if model == 'mdn':  # for mixture density networks
        if not (set(model_args.keys()) <=
                {'hidden_features', 'num_components'}):
            raise ValueError(f"Model {model} arguments mispecified.")
        if 'hidden_features' in model_args:
            # MDNs use 3 hidden layers of the given width
            model_args['hidden_features'] = [model_args['hidden_features']] * 3
        model_args['components'] = model_args.pop('num_components', 2)
        flow_class = zuko.flows.mixture.GMM
    elif model == 'cnf':  # for continuous flow models
        if 'hidden_features' in model_args:
            model_args['hidden_features'] = [model_args['hidden_features']] * 2
        # number of time embeddings
        model_args['freqs'] = model_args.pop('num_transforms', 2)
        flow_class = zuko.flows.continuous.CNF
    else:  # for all discrete flow models
        if not (set(model_args.keys()) <=
                {'hidden_features', 'num_transforms'}):
            raise ValueError(f"Model {model} arguments mispecified.")
        if 'hidden_features' in model_args:
            model_args['hidden_features'] = [model_args['hidden_features']] * 2
        model_args['transforms'] = model_args.pop('num_transforms', 2)
        if model == 'maf':
            flow_class = zuko.flows.autoregressive.MAF
        elif model == 'nsf':
            flow_class = zuko.flows.spline.NSF
        elif model == 'nice':
            flow_class = zuko.flows.coupling.NICE
        elif model == 'gf':
            flow_class = zuko.flows.gaussianization.GF
        elif model == 'sospf':
            flow_class = zuko.flows.polynomial.SOSPF
        elif model == 'naf':
            flow_class = zuko.flows.neural.NAF
        elif model == 'unaf':
            flow_class = zuko.flows.neural.UNAF

    # Snapshot now so later changes to the caller's network have no effect.
    embedding_template = deepcopy(embedding_net)

    def _z_score(batch):
        """AffineTransform whose inverse z-normalizes `batch` along dim 0."""
        mean = batch.mean(dim=0).to(device)
        # avoid division by zero for constant features
        std = torch.clamp(batch.std(dim=0).to(device), min=1e-16)
        return AffineTransform(loc=mean, scale=std, event_dim=1)

    def net_constructor(x_batch, theta_batch, prior):
        """Build a LampeNPE sized and normalized from a training batch.

        Raises:
            ValueError: if the embedding output or theta is not a vector per
                sample.
        """
        # Fresh copy per call, so repeatedly calling the constructor does
        # not yield models that share (and co-train) one embedding network.
        net = deepcopy(embedding_template)
        # NOTE(review): method name is spelled 'initalize_model'; presumably
        # it matches the embedding nets' own spelling — confirm.
        if hasattr(net, 'initalize_model'):
            net.initalize_model(x_batch.shape[-1])

        # pass data through embedding network to infer its output size
        z_batch = net(x_batch)
        z_shape = z_batch.shape[1:]
        theta_shape = theta_batch.shape[1:]
        if len(z_shape) != 1:
            raise ValueError("Embedding network must return a vector.")
        if len(theta_shape) != 1:
            raise ValueError("Parameters theta must be a vector.")

        # instantiate a neural density estimator
        nde = lampe.inference.NPE(
            theta_dim=theta_shape[0],
            x_dim=z_shape[0],
            build=flow_class,
            **model_args
        ).to(device)

        # determine transformations
        x_transform = _z_score(x_batch) if x_normalize else identity_transform
        theta_transform = (
            _z_score(theta_batch) if theta_normalize else identity_transform)

        return LampeNPE(
            nde=nde,
            embedding_net=net,
            prior=prior,
            x_transform=x_transform,
            theta_transform=theta_transform
        ).to(device)

    return net_constructor