import torch.nn as nn

class VAELoss(nn.Module):
    """
    Loss for the reconstruction error of a variational autoencoder when the encoder parametrizes a
    Gaussian distribution. Taken from "Auto-Encoding Variational Bayes" (Kingma and Welling, 2014).

    The loss of a single sample is the reconstruction loss summed over all non-batch dimensions
    plus the (optionally weighted) KL-divergence between the encoder's diagonal Gaussian
    `N(mu, diag(exp(logvar)))` and the standard normal prior. The per-sample losses are averaged
    over the batch.
    """

    def __init__(self, loss, kld_weight=1.0):
        """
        Initializes a new loss for a variational autoencoder.

        Parameters
        ----------
        loss: torch.nn.Module
            The loss to incur for the decoder's output given `(x_pred, x_true)`. This might e.g.
            be a BCE loss. **The reduction must be 'none'.**
        kld_weight: float, default: 1.0
            The factor to scale the KL-divergence with before adding it to the reconstruction
            loss. The default of 1 yields the plain (unweighted) VAE objective.

        Raises
        ------
        ValueError
            If the given loss exposes a `reduction` attribute that is not 'none'.
        """
        super().__init__()

        # A reduced loss cannot be summed per sample anymore, so we fail early with a clear
        # message. Losses without a `reduction` attribute are accepted as they are.
        reduction = getattr(loss, "reduction", "none")
        if reduction != "none":
            raise ValueError(
                f"The reduction of the loss must be 'none', but got '{reduction}'."
            )

        self.loss = loss
        self.kld_weight = kld_weight

    def forward(self, x_pred, mu, logvar, x_true):
        """
        Computes the loss of the decoder's output.

        Parameters
        ----------
        x_pred: torch.Tensor [N, ...]
            The outputs of the decoder (batch size N).
        mu: torch.Tensor [N, D]
            The output for the means from the encoder (dimensionality D).
        logvar: torch.Tensor [N, D]
            The output for the log-values of the diagonal entries of the covariance matrix.
        x_true: torch.Tensor [N, ...]
            The target outputs for the decoder.

        Returns
        -------
        torch.Tensor
            A scalar (zero-dimensional) tensor: the mean over the batch of the reconstruction loss
            plus the weighted KL-divergence.
        """
        reconstruction = self.loss(x_pred, x_true)

        # We want to sum over all dimensions but the batch dimension. The dimensions are taken
        # from the loss output since that is the tensor being reduced. When there is no non-batch
        # dimension, the sum must be skipped: an empty dimension tuple makes `sum` reduce over
        # every dimension, which would collapse the batch dimension as well.
        non_batch_dims = tuple(range(1, reconstruction.dim()))
        if non_batch_dims:
            reconstruction = reconstruction.sum(non_batch_dims)

        # Closed-form KL-divergence between N(mu, diag(exp(logvar))) and N(0, I) for each sample.
        kld = -0.5 * (1 + logvar - mu * mu - logvar.exp()).sum(-1)

        return (reconstruction + self.kld_weight * kld).mean()