# Source code for utils.ndes_tf

"""
Module to provide loading functions for ndes in pydelfi.

All Mixture Density Networks (mdn) have the configuration:
    hidden_features (int): width of hidden layers (each MDN has 3 hidden layers)
    num_components (int): number of Gaussian components in the mixture model

All flow-based models (maf) have the configuration:
    hidden_features (int): width of each hidden layer in the MADE blocks
    num_transforms (int): number of MADE blocks in the flow; the same value
        is also used as the number of hidden layers inside each block

"""

import pydelfi
import tensorflow as tf
import logging


def load_nde_pydelfi(
    n_params: int,
    n_data: int,
    model: str,
    index: int = 0,
    engine: str = 'NLE',
    **model_args
):
    """Load a neural density estimator (nde) from pydelfi.

    Args:
        n_params (int): dimensionality of parameters
        n_data (int): dimensionality of data points
        model (str): model to use, case-insensitive. One of: mdn, maf.
        index (int, optional): index of the nde in the ensemble.
            Defaults to 0.
        engine (str, optional): dummy argument to match sbi interface.
            Must contain 'NLE'; any other value raises a ValueError.
        **model_args: additional arguments to pass to the model. Allowed
            keys (and defaults) are ``hidden_features`` (16) and
            ``num_components`` (3) for mdn, and ``hidden_features`` (16)
            and ``num_transforms`` (2) for maf.

    Returns:
        A ``pydelfi.ndes.MixtureDensityNetwork`` when ``model`` is 'mdn',
        or a ``pydelfi.ndes.ConditionalMaskedAutoregressiveFlow`` when
        ``model`` is 'maf'.

    Raises:
        ValueError: if ``engine`` does not contain 'NLE', or if
            ``model_args`` holds keys the chosen model does not accept.
        NotImplementedError: if ``model`` is neither 'mdn' nor 'maf'.
    """
    if 'NLE' not in engine:
        raise ValueError(
            f'Engine {engine} not supported in pydelfi backend. '
            'You probably meant to specify engine="NLE" or to use the NPE or NRE'
            ' engines in the sbi or lampe backends.')

    model = model.lower()

    defaults_by_model = {
        'mdn': dict(hidden_features=16, num_components=3),
        'maf': dict(hidden_features=16, num_transforms=2),
    }

    # Reject unknown models *before* validating their arguments; otherwise
    # an unknown model would be checked against the maf defaults and could
    # fail with a misleading "arguments mispecified" ValueError.
    if model not in defaults_by_model:
        raise NotImplementedError(f"Model {model} not implemented for pydelfi")
    model_defaults = defaults_by_model[model]

    # check the model parameterizations
    extra_args = set(model_args) - set(model_defaults)
    if extra_args:
        raise ValueError(
            f"Model {model} arguments mispecified. Extra arguments found: "
            f"{extra_args}.")

    # user-supplied values take precedence over the defaults
    model_args = {**model_defaults, **model_args}

    # setup models
    if model == 'mdn':
        # every MDN has three hidden layers of equal width
        n_hidden = [model_args['hidden_features']] * 3
        activations = [tf.tanh] * 3
        return pydelfi.ndes.MixtureDensityNetwork(
            n_parameters=n_params,
            n_data=n_data,
            n_components=model_args['num_components'],
            n_hidden=n_hidden,
            activations=activations,
            index=index,
        )

    # model == 'maf' (the only other entry in defaults_by_model).
    # num_transforms sets both the length of the n_hiddens list and n_mades.
    n_hidden = [model_args['hidden_features']] * model_args['num_transforms']
    return pydelfi.ndes.ConditionalMaskedAutoregressiveFlow(
        n_parameters=n_params,
        n_data=n_data,
        n_hiddens=n_hidden,
        n_mades=model_args['num_transforms'],
        act_fun=tf.tanh,
        index=index,
    )