Source code for pyblaze.nn.modules.distribution

import torch.nn as nn
import pyblaze.nn.functional as X

class TransformedNormalLoss(nn.Module):
    """
    Negative log-likelihood (NLL) of data that has been passed through invertible
    transformations.

    The NLL is the negative sum of the log-determinant of the transformations and the
    log-probability of the transformed output under a standard Normal distribution. This
    loss is typically used to fit a normalizing flow.
    """

    def __init__(self, reduction='mean'):
        """
        Initializes a new NLL loss.

        Parameters
        ----------
        reduction: str, default: 'mean'
            The kind of reduction to perform. Must be one of ['mean', 'sum', 'none'].

        Raises
        ------
        ValueError
            If ``reduction`` is not one of the supported values.
        """
        super().__init__()

        # Fail at construction time so that a typo does not silently select the
        # unreduced branch in `forward`.
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(f"Invalid reduction {reduction}")
        self.reduction = reduction

    def forward(self, z, log_det):
        """
        Computes the NLL for the given transformed values.

        Parameters
        ----------
        z: torch.Tensor [N, D]
            The output values of the transformations (batch size N, dimensionality D).
        log_det: torch.Tensor [N]
            The log-determinants of the transformations for all values.

        Returns
        -------
        torch.Tensor
            The mean of the NLL values for ``reduction='mean'``, their sum for
            ``reduction='sum'``, and the unreduced NLL values otherwise (presumably of
            shape [N], following ``log_det`` -- confirm against
            ``log_prob_standard_normal``).
        """
        # Change of variables: log p(x) = log p(z) + log|det J|, negated to obtain a loss.
        nll = -X.log_prob_standard_normal(z) - log_det

        if self.reduction == 'mean':
            return nll.mean()
        if self.reduction == 'sum':
            return nll.sum()
        return nll
class TransformedGmmLoss(nn.Module):
    """
    Negative log-likelihood (NLL) of data that has been passed through invertible
    transformations.

    The NLL is the negative sum of the log-determinant of the transformations and the
    log-probability of the transformed output under a GMM with predefined means and unit
    variances. The simple alternative to this loss is the
    :class:`TransformedNormalLoss`.
    """

    def __init__(self, means, trainable=False, reduction='mean'):
        """
        Initializes a new GMM loss.

        Parameters
        ----------
        means: torch.Tensor [K, D]
            The means of the GMM (K means, dimensionality D). For random initialization
            of the means, consider using :meth:`pyblaze.nn.functional.random_gmm`.
        trainable: bool, default: False
            Whether the means are trainable.
        reduction: str, default: 'mean'
            The kind of reduction to perform. Must be one of ['mean', 'sum', 'none'].

        Raises
        ------
        ValueError
            If ``reduction`` is not one of the supported values.
        """
        super().__init__()

        # Same check as in TransformedNormalLoss: without it, an unknown value would
        # silently fall through to the unreduced branch in `forward`.
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(f"Invalid reduction {reduction}")

        if trainable:
            # Registered as a parameter so that it is returned by `parameters()`.
            self.means = nn.Parameter(means)
        else:
            # A buffer is part of the module state but not of `parameters()`.
            self.register_buffer('means', means)

        self.reduction = reduction

    def forward(self, z, log_det):
        """
        Computes the NLL for the given transformed values.

        Parameters
        ----------
        z: torch.Tensor [N, D]
            The output values of the transformations (batch size N, dimensionality D).
        log_det: torch.Tensor [N]
            The log-determinants of the transformations for all values.

        Returns
        -------
        torch.Tensor
            The mean of the NLL values for ``reduction='mean'``, their sum for
            ``reduction='sum'``, and the unreduced NLL values otherwise (presumably of
            shape [N], following ``log_det`` -- confirm against
            ``log_prob_standard_gmm``).
        """
        # Change of variables: log p(x) = log p(z) + log|det J|, negated to obtain a loss.
        nll = -X.log_prob_standard_gmm(z, self.means) - log_det

        if self.reduction == 'mean':
            return nll.mean()
        if self.reduction == 'sum':
            return nll.sum()
        return nll