import torch.nn as nn
class LinearResidual(nn.Module):
    """
    Residual module that applies a two-layer MLP with a nonlinearity and adds the input to the
    output:

    .. math::

        f(x) = x + W_2 \\sigma(W_1 x + b_1) + b_2

    Usually, another nonlinearity is applied to the output.
    """
    def __init__(self, dim, hidden_dim, activation=nn.ReLU(), bias=True):
        """
        Initializes a new residual module.

        Parameters
        ----------
        dim: int
            The dimension of the input. Equals the dimension of the output.
        hidden_dim: int
            The hidden dimension (i.e. the output dimension of :math:`W_1`).
        activation: torch.nn.Module, default: torch.nn.ReLU()
            The activation function to use (i.e. :math:`\\sigma` in the formula above).
        bias: bool, default: True
            Whether to add biases to the linear layers (i.e. :math:`b_1` and :math:`b_2` in the
            formula above).
        """
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)
        self.activation = activation
        self.w2 = nn.Linear(hidden_dim, dim, bias=bias)
    def forward(self, x):
        """
        Computes the output of the residual module.

        Parameters
        ----------
        x: torch.Tensor [N, D]
            The input (batch size N, dimensionality D).

        Returns
        -------
        torch.Tensor [N, D]
            The processed output.
        """
        # Project to the hidden dimension and apply the nonlinearity ...
        z = self.activation(self.w1(x))
        # ... then project back to the input dimension and add the skip connection.
        return x + self.w2(z)
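

# --- Usage sketch (not part of the original module) ---
# A minimal, hedged example of running LinearResidual on a random batch; the
# sizes (dim=16, hidden_dim=64, batch of 8) are illustrative assumptions, not
# values taken from the library.
if __name__ == "__main__":
    import torch

    block = LinearResidual(dim=16, hidden_dim=64)
    x = torch.randn(8, 16)       # batch size N=8, dimensionality D=16
    with torch.no_grad():
        y = block(x)
    # A residual block preserves the input shape: output is also [N, D].
    assert y.shape == x.shape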