# Source code for pyblaze.nn.modules.made

import torch
import torch.nn as nn
import torch.nn.functional as F

class MADE(nn.Module):
    """
    Masked autoencoder for distribution estimation (MADE) as introduced in `MADE: Masked
    Autoencoder for Distribution Estimation <https://arxiv.org/abs/1502.03509>`_ (Germain et
    al., 2015). It consists of a series of masked linear layers and a given non-linearity
    between them.
    """

    def __init__(self, *dims, activation=None):
        """
        Initializes a new MADE model as a sequence of masked linear layers.

        Parameters
        ----------
        dims: varargs of int
            Dimensions of input (first), output (last) and hidden layers. At least one hidden
            layer must be defined, i.e. at least 3 dimensions must be given. All dimensions must
            be positive. The output dimension must be equal to the input dimension or a multiple
            of it. Hidden units are assigned connectivity degrees cyclically from 1 to D-1 (D
            being the input dimension), so a hidden layer should have at least D-1 units (ideally
            a multiple of D-1) for every degree to be represented evenly.
        activation: torch.nn.Module, default: None
            An activation function to be used after linear layers (except for the output layer).
            This module is shared for all hidden layers of this model. If not given, a new
            :class:`torch.nn.LeakyReLU` is created for every model.

        Raises
        ------
        ValueError
            If fewer than 3 dimensions are given, any dimension is not positive, or the output
            dimension is not a multiple of the input dimension.
        """
        super().__init__()

        if len(dims) < 3:
            raise ValueError("MADE model must have at least one hidden layer")
        # Guard against zero/negative sizes before the modulo below (which would otherwise fail
        # with an unhelpful ZeroDivisionError for a zero input dimension).
        if any(d < 1 for d in dims):
            raise ValueError("All dimensions of a MADE model must be positive")
        if dims[-1] % dims[0] != 0:
            raise ValueError("Output dimension must be multiple of the input dimension")

        # A default argument would be evaluated only once and then shared by every MADE ever
        # created; build a fresh activation per model instead.
        if activation is None:
            activation = nn.LeakyReLU()

        self.dims = dims

        degrees = _generate_sequential(dims)
        layers = []
        for i, (in_dim, out_dim) in enumerate(zip(dims, dims[1:])):
            if i > 0:
                layers.append(activation)
            # Unit k of layer i+1 may receive input from unit l of layer i only if
            # degree(k) >= degree(l). The resulting [out_dim, in_dim] mask matches the weight
            # layout of the linear layer and enforces the autoregressive property.
            mask = (degrees[i + 1].unsqueeze(-1) >= degrees[i].unsqueeze(0)).float()
            layers.append(_MaskedLinear(in_dim, out_dim, mask=mask))
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """
        Computes the outputs of the MADE model.

        Parameters
        ----------
        x: torch.Tensor [..., D]
            The input (input dimension D).

        Returns
        -------
        torch.Tensor [..., E]
            The output (output dimension E). Output unit j only depends on the inputs whose
            index is smaller than j mod D.
        """
        return self.mlp(x)
class _MaskedLinear(nn.Linear): def __init__(self, in_features, out_features, mask, bias=True): super().__init__(in_features, out_features, bias) self.register_buffer('mask', mask) def forward(self, x): return F.linear(x, self.weight * self.mask, self.bias) def __repr__(self): return f'MaskedLinear(in_features={self.in_features}, ' + \ f'out_features={self.out_features}, bias={self.bias is not None})' def _generate_sequential(dims): in_dim = dims[0] out_dim = dims[-1] degrees = [torch.arange(in_dim) + 1] for dim in dims[1:-1]: degrees += [torch.arange(dim) % (in_dim - 1) + 1] degrees += [torch.arange(out_dim) % in_dim] return degrees