# Source code for pyblaze.nn.modules.gp

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd

class GradientPenalty(nn.Module):
    """
    Implementation of the gradient penalty as presented in "Improved Training of Wasserstein
    GANs" (Gulrajani et al., 2017). It pushes the norm of the critic's gradient (evaluated at
    random interpolations between fake and real samples) towards 1, encouraging Lipschitz
    continuity of the critic.

    Optionally, the gradient penalty can be replaced by a Lipschitz penalty which does not
    penalize gradient norms smaller than one. It is taken from "On the Regularization of
    Wasserstein GANs" (Petzka et al., 2018).
    """

    def __init__(self, module, coefficient=10, lipschitz=False):
        """
        Initializes a new gradient penalty for the given module.

        Parameters
        ----------
        module: torch.nn.Module
            The module whose gradient norm should be penalized.
        coefficient: float, default: 10
            The coefficient for the gradient penalty. The default value is taken from the
            original WGAN-GP paper.
        lipschitz: boolean, default: False
            Whether to use the Lipschitz penalty instead of the plain gradient penalty (i.e.
            not penalizing gradient norms smaller than 1).
        """
        super().__init__()
        self.module = module
        self.coefficient = coefficient
        self.lipschitz = lipschitz

    def forward(self, fake, real):
        """
        Computes the loss incurred on the penalized module based on a batch of fake and real
        instances.

        Parameters
        ----------
        fake: torch.Tensor [N, ...]
            The fake instances (batch size N).
        real: torch.Tensor [N, ...]
            The real instances.

        Returns
        -------
        torch.Tensor []
            A scalar (0-dim) tensor: the gradient penalty times the penalty coefficient. It is
            differentiable w.r.t. the penalized module's parameters.
        """
        interpolation, out = self.interpolate(fake, real)
        # Seeding the backward pass with ones is equivalent to differentiating out.sum(). This
        # yields per-sample gradients only if the module treats samples independently
        # (e.g. no batch norm) -- NOTE(review): assumed, not enforced here.
        gradients = autograd.grad(
            out,
            interpolation,
            grad_outputs=torch.ones_like(out),
            create_graph=True,  # the penalty itself must be backpropagated into the module
            retain_graph=True,
        )[0]
        # One gradient vector per sample; reshape also handles non-contiguous gradients.
        gradients = gradients.reshape(gradients.size(0), -1)
        target = gradients.norm(2, dim=1) - 1
        if self.lipschitz:
            # One-sided penalty: only gradient norms above 1 are penalized.
            target = F.relu(target)
        return self.coefficient * (target ** 2).mean()

    def interpolate(self, fake, real):
        """
        Interpolates the given fake and real instances, using an independent random weight
        alpha ~ U(0, 1) for each sample in the batch. By default, it assumes that fake and real
        instances are tensors with the batch along the first dimension, so that they can be
        mixed elementwise. This method may be overridden by subclasses for more complicated
        models.

        Parameters
        ----------
        fake: torch.Tensor [N, ...]
            The fake instances passed to the module (batch size N).
        real: torch.Tensor [N, ...]
            The real instances passed to the module.

        Returns
        -------
        torch.Tensor [N, ...]
            The interpolation (which must have `requires_grad` set to `True`).
        torch.Tensor [N]
            The module's output for the interpolated fake and real instances.
        """
        batch_size = fake.size(0)
        # One mixing weight per sample, broadcast over all remaining dimensions.
        alpha_shape = (batch_size,) + (1,) * (fake.dim() - 1)
        alpha = torch.rand(alpha_shape, dtype=fake.dtype, device=fake.device)
        # Detach the inputs so the interpolation is a fresh leaf: the penalty must only
        # regularize the penalized module and must not backpropagate into whatever produced
        # `fake` (e.g. a generator whose output was not detached by the caller).
        interpolation = alpha * fake.detach() + (1 - alpha) * real.detach()
        interpolation.requires_grad_()
        return interpolation, self.module(interpolation)