"""
Wrapper module to import distributions from torch.distributions
and make their configuration easier in the ltu-ili interface.
Specifically, if we're using a vector of parameters, we want to
be able to pass the vector to the log_prob method of the distribution
and return a scalar. This is not the default behavior of several
distributions in torch.distributions, so we wrap them here.
"""
from torch.distributions.utils import broadcast_all
from torch.distributions import constraints, Distribution
from numbers import Number
import math
import torch
from torch.distributions import Independent
from .import_utils import load_class
# Not used directly, but raises error if tried loading with wrong backend
import sbi
# These distributions will be loaded and wrapped
dist_names = [
'Uniform', 'Normal', 'Beta', 'Cauchy', 'Chi2', 'Exponential',
'FisherSnedecor', 'Gamma', 'Gumbel', 'HalfCauchy', 'HalfNormal', 'Laplace',
'LogNormal', 'Pareto', 'StudentT', 'VonMises', 'Weibull'
]
[docs]
class CustomIndependent(Independent):
def __init__(self, *args, device='cpu', **kwargs):
# Convert args and kwargs to torch tensors
args = [torch.as_tensor(v, dtype=torch.float32, device=device)
for v in args]
kwargs = {k: torch.as_tensor(v, dtype=torch.float32, device=device)
for k, v in kwargs.items()}
self.device = device
self.dist = self.Distribution(*args, **kwargs)
return super().__init__(self.dist, 1)
# Load and wrap distributions
dist_dict = {}
for name in dist_names:
dist = load_class('torch.distributions', name)
dist_dict['Independent'+name] = \
type('Independent'+name, (CustomIndependent,), {'Distribution': dist})
locals().update(dist_dict)
# Now, for all distributions in dist_names, we have a custom Independent
# version that can handle vectorized inputs. For example, if 'Normal' is in
# dist_names, then we have a 'IndependentNormal' class parameterized by a
# loc and scale vector
Uniform = IndependentUniform # Uniform is always independent
# load multivariate, continuous distributions
# this is done for API convenience, but we don't wrap them
from torch.distributions import ( # noqa
MultivariateNormal, LowRankMultivariateNormal
)
# redefining these to not require torch tensors as inputs
[docs]
class MultivariateNormal(MultivariateNormal):
def __init__(self, device='cpu', *args, **kwargs):
# Convert args and kwargs to torch tensors
args = [torch.as_tensor(v, dtype=torch.float32, device=device)
for v in args]
kwargs = {k: torch.as_tensor(v, dtype=torch.float32, device=device)
for k, v in kwargs.items()}
self.device = device
return super().__init__(*args, **kwargs)
[docs]
class LowRankMultivariateNormal(LowRankMultivariateNormal):
def __init__(self, device='cpu', *args, **kwargs):
# Convert args and kwargs to torch tensors
args = [torch.as_tensor(v, dtype=torch.float32, device=device)
for v in args]
kwargs = {k: torch.as_tensor(v, dtype=torch.float32, device=device)
for k, v in kwargs.items()}
self.device = device
return super().__init__(*args, **kwargs)
# Define TruncatedIndependentNormal to mirror pydelfi distribution
# Adapted from: https://github.com/toshas/torch_truncnorm
CONST_SQRT_2 = math.sqrt(2)
CONST_INV_SQRT_2PI = 1 / math.sqrt(2 * math.pi)
CONST_INV_SQRT_2 = 1 / math.sqrt(2)
CONST_LOG_INV_SQRT_2PI = math.log(CONST_INV_SQRT_2PI)
CONST_LOG_SQRT_2PI_E = 0.5 * math.log(2 * math.pi * math.e)
class _TruncatedStandardNormal(Distribution):
"""Truncated Standard Normal distribution.
Source: https://github.com/toshas/torch_truncnorm
Theory: https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
"""
arg_constraints = {
"a": constraints.real,
"b": constraints.real,
}
has_rsample = True
eps = 1e-6
def __init__(self, a, b, validate_args=None):
self.a, self.b = broadcast_all(a, b)
if isinstance(a, Number) and isinstance(b, Number):
batch_shape = torch.Size()
else:
batch_shape = self.a.size()
super().__init__(
batch_shape, validate_args=validate_args
)
if self.a.dtype != self.b.dtype:
raise ValueError("Truncation bounds types are different")
if any((self.a >= self.b).view(-1,).tolist()):
raise ValueError("Incorrect truncation range")
eps = self.eps
self._dtype_min_gt_0 = eps
self._dtype_max_lt_1 = 1 - eps
self._little_phi_a = self._little_phi(self.a)
self._little_phi_b = self._little_phi(self.b)
self._big_phi_a = self._big_phi(self.a)
self._big_phi_b = self._big_phi(self.b)
self._Z = (self._big_phi_b - self._big_phi_a).clamp(eps, 1 - eps)
self._log_Z = self._Z.log()
little_phi_coeff_a = torch.nan_to_num(self.a, nan=math.nan)
little_phi_coeff_b = torch.nan_to_num(self.b, nan=math.nan)
self._lpbb_m_lpaa_d_Z = (
self._little_phi_b * little_phi_coeff_b
- self._little_phi_a * little_phi_coeff_a
) / self._Z
self._mean = -(self._little_phi_b - self._little_phi_a) / self._Z
self._variance = (
1
- self._lpbb_m_lpaa_d_Z
- ((self._little_phi_b - self._little_phi_a) / self._Z) ** 2
)
self._entropy = CONST_LOG_SQRT_2PI_E + self._log_Z - 0.5 * self._lpbb_m_lpaa_d_Z
@constraints.dependent_property
def support(self):
return constraints.interval(self.a, self.b)
@property
def mean(self):
return self._mean
@property
def variance(self):
return self._variance
def entropy(self):
return self._entropy
@property
def auc(self):
return self._Z
@staticmethod
def _little_phi(x):
return (-(x**2) * 0.5).exp() * CONST_INV_SQRT_2PI
def _big_phi(self, x):
phi = 0.5 * (1 + (x * CONST_INV_SQRT_2).erf())
return phi.clamp(self.eps, 1 - self.eps)
@staticmethod
def _inv_big_phi(x):
return CONST_SQRT_2 * (2 * x - 1).erfinv()
def cdf(self, value):
if self._validate_args:
self._validate_sample(value)
return ((self._big_phi(value) - self._big_phi_a) / self._Z).clamp(0, 1)
def icdf(self, value):
y = self._big_phi_a + value * self._Z
y = y.clamp(self.eps, 1 - self.eps)
return self._inv_big_phi(y)
def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
out = CONST_LOG_INV_SQRT_2PI - self._log_Z - (value**2) * 0.5
out.masked_fill_((value < self.a) | (value > self.b), -float("inf"))
return out.squeeze()
def rsample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
p = torch.empty(shape, device=self.a.device).uniform_(
self._dtype_min_gt_0, self._dtype_max_lt_1
)
out = self.icdf(p)
if len(out.shape) == 1:
return out.unsqueeze(-1)
return out # this is a hack
class _UnivariateTruncatedNormal(_TruncatedStandardNormal):
"""Truncated Normal distribution.
Source: https://github.com/toshas/torch_truncnorm
Theory: https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
"""
has_rsample = True
def __init__(self, loc, scale, low, high, validate_args=None):
# scale = scale.clamp_min(self.eps)
self.loc, self.scale, a, b = broadcast_all(loc, scale, low, high)
self._non_std_a = a
self._non_std_b = b
a = (a - self.loc) / self.scale
b = (b - self.loc) / self.scale
super().__init__(a, b, validate_args=validate_args)
self.base = _TruncatedStandardNormal(a, b, validate_args=validate_args)
self._log_scale = self.scale.log()
self._mean = self._mean * self.scale + self.loc
self._variance = self._variance * self.scale**2
self._entropy += self._log_scale
@constraints.dependent_property
def support(self):
return constraints.interval(self._non_std_a, self._non_std_b)
def _to_std_rv(self, value):
return (value - self.loc) / self.scale
def _from_std_rv(self, value):
return value * self.scale + self.loc
def cdf(self, value):
return self.base.cdf(self._to_std_rv(value))
def icdf(self, value):
sample = self._from_std_rv(super().icdf(value))
# clamp data but keep gradients
sample_clip = torch.stack(
[sample.detach(), self._non_std_a.detach().expand_as(sample)], 0
).max(0)[0]
sample_clip = torch.stack(
[sample_clip, self._non_std_b.detach().expand_as(sample)], 0
).min(0)[0]
sample.data.copy_(sample_clip)
return sample
def log_prob(self, value):
value = self._to_std_rv(value)
return self.base.log_prob(value) - self._log_scale
# Define IndependentTruncatedNormal as a class for multivariate priors
IndependentTruncatedNormal = \
type('IndependentTruncatedNormal', (CustomIndependent,),
{'Distribution': _UnivariateTruncatedNormal})