"""Wasserstein GAN loss modules (pyblaze.nn.modules.wasserstein)."""

import torch.nn as nn

class WassersteinLossGenerator(nn.Module):
    """
    Computes the loss of the generator in the Wasserstein GAN setting.
    """

    def forward(self, out):
        """
        Computes the loss for the generator given the outputs of the critic.

        Parameters
        ----------
        out: torch.Tensor [N]
            The output values of the critic (batch size N).

        Returns
        -------
        torch.Tensor [1]
            The loss incurred for the generator.
        """
        # The generator wants the critic to assign high scores to its fakes,
        # so it minimizes the negated mean critic output.
        return out.mean().neg()
class WassersteinLossCritic(nn.Module):
    """
    Computes the loss of the critic in the Wasserstein GAN setting. This loss
    optionally includes a gradient penalty that should be used if no other
    regularization methods (weight clipping, spectral normalization, ...) are
    used.
    """

    def __init__(self, gradient_penalty=None):
        """
        Initializes a new Wasserstein loss for a critic.

        Parameters
        ----------
        gradient_penalty: nn.Module, default: None
            A gradient penalty object that accepts fake and real inputs to the
            critic and computes the gradient penalty for them. If None, no
            gradient penalty is applied.
        """
        super().__init__()
        self.gradient_penalty = gradient_penalty

    def forward(self, out_fake, out_real, *inputs):
        """
        Computes the loss for the critic given the outputs of itself and
        potentially a tuple of inputs.

        Parameters
        ----------
        out_fake: torch.Tensor [N]
            The critic's output for the fake inputs (batch size N).
        out_real: torch.Tensor [N]
            The critic's output for the real inputs.
        inputs: tuple of (torch.Tensor [N, ...], torch.Tensor [N, ...])
            A tuple of `(in_fake, in_real)` that must be given if a gradient
            penalty is used.

        Returns
        -------
        torch.Tensor [1]
            The loss incurred for the critic.
        torch.Tensor [1]
            The estimated Earth mover's (Wasserstein-1) distance (equal to the
            detached negative loss if there is no gradient penalty).
        """
        # The critic maximizes E[critic(real)] - E[critic(fake)]; minimizing
        # the negation achieves the same.
        loss = out_fake.mean() - out_real.mean()
        # The Wasserstein distance estimate deliberately excludes the penalty
        # term and is detached so it is reported without affecting gradients.
        wass_dist = -loss.detach()
        penalty = 0
        if self.gradient_penalty is not None:
            # `inputs` must carry (in_fake, in_real) when a penalty is set.
            penalty = self.gradient_penalty(*inputs)
        return loss + penalty, wass_dist