from typing import List, Tuple, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.jit as jit
class StackedLSTM(nn.Module):
    """
    The stacked LSTM is an extension to PyTorch's native LSTM that allows LSTM layers with
    different hidden dimensions to be stacked. Furthermore, it allows using an LSTM on a GPU
    without cuDNN. This is useful when higher-order gradients are required. In all other cases,
    it is best to use PyTorch's builtin LSTM.
    """

    batch_first: jit.Final[bool]

    def __init__(self, input_size, hidden_sizes, bias=True, batch_first=False, cudnn=True):
        """
        Initializes a new stacked LSTM according to the given parameters.

        Parameters
        ----------
        input_size: int
            The dimension of the sequence's elements.
        hidden_sizes: list of int
            The dimensions of the stacked LSTM's layers.
        bias: bool, default: True
            Whether to use biases in the LSTM.
        batch_first: bool, default: False
            Whether the batch (True) or the sequence (False) can be found in the first
            dimension of the inputs.
        cudnn: bool, default: True
            Whether to use PyTorch's LSTM cell implementation, which uses cuDNN on Nvidia GPUs.
            You usually don't want to change the default value. However, PyTorch's default
            implementation does not allow higher-order gradients of the LSTMCell as of version
            1.1.0. When this value is set to False, we therefore use a (slower) implementation
            of an LSTM cell which allows higher-order gradients.
        """
        super().__init__()
        self.batch_first = batch_first
        self.stacked_cell = StackedLSTMCell(input_size, hidden_sizes, bias=bias, cudnn=cudnn)

    def forward(self, inputs: torch.Tensor,
                initial_states: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                return_sequence: bool = True):
        """
        Computes the forward pass through the stacked LSTM.

        Parameters
        ----------
        inputs: torch.Tensor [S, B, N]
            The inputs fed to the LSTM one after the other. Sequence length S, batch size B, and
            input size N. If `batch_first` is set to True, the first and second dimension should
            be swapped.
        initial_states: list of tuple of (torch.Tensor [B, H_i], torch.Tensor [B, H_i]), default: None
            The initial (hidden, cell) states for all LSTM layers. The length of the list must
            match the number of layers in the LSTM and the sizes of the states must match the
            hidden sizes of the LSTM layers. If None is given, the initial states default to all
            zeros. The given list is not modified.
        return_sequence: bool, default: True
            Whether to return the outputs of the last LSTM layer for all time steps or only for
            the last time step.

        Returns
        -------
        torch.Tensor [S, B, K] or torch.Tensor [B, K]
            Depending on whether sequences are returned, either all outputs or only the output
            from the last time step, where K is the hidden size of the last layer. If the stacked
            LSTM was initialized with `batch_first`, the first and second dimension are swapped
            when sequences are returned.

        Raises
        ------
        ValueError
            If the input sequence is empty.
        """
        if self.batch_first:
            inputs = inputs.transpose(1, 0)
        sequence_length = inputs.size(0)
        if sequence_length == 0:
            # Without this check, an empty sequence fails later with an obscure error from
            # `torch.stack` or from indexing an empty output list.
            raise ValueError("StackedLSTM received an empty input sequence")

        # An empty-tensor placeholder (rather than None) is used so that TorchScript can compile
        # the loop below; the cell receives None while the placeholder is still in place.
        if initial_states is None:
            states = [(torch.empty(0), torch.empty(0))]
        else:
            # Pass a copy so that the caller's list is never modified by the cell.
            states = list(initial_states)

        # Iterate over the sequence, feeding the states of one step into the next one.
        outputs = []
        for n in range(sequence_length):
            output, states = self.stacked_cell(
                inputs[n], None if states[0][0].size(0) == 0 else states
            )
            # Collecting into a list (instead of keeping a loop variable) keeps TorchScript
            # happy, as variables first defined inside a loop are not visible after it.
            if return_sequence or n == sequence_length - 1:
                outputs.append(output)

        if return_sequence:
            result = torch.stack(outputs)
            if self.batch_first:
                # Batch first, sequence length second.
                result = result.transpose(1, 0)
            return result
        return outputs[0]
class StackedLSTMCell(nn.Module):
    """
    A single time step of a stack of LSTM cells.

    Actually, the StackedLSTMCell can easily be constructed from existing modules. However, a bug
    in PyTorch's JIT compiler prevents implementing anything where a stacked LSTM is used within a
    loop (see the following issue: https://github.com/pytorch/pytorch/issues/18143). Hence, this
    class provides a single time step for a stacked LSTM.
    """

    # NOTE(review): `cells` holds an nn.ModuleList, not an int. Presumably this annotation only
    # serves to mark `cells` as a TorchScript constant (older PyTorch versions required that in
    # order to iterate over a ModuleList) -- confirm before changing or removing it.
    cells: jit.Final[int]
    num_stacked: jit.Final[int]

    def __init__(self, input_size, hidden_sizes, bias=True, cudnn=True):
        """
        Initializes a new stacked LSTM cell.

        Parameters
        ----------
        input_size: int
            The dimension of the input variables.
        hidden_sizes: sequence of int
            The hidden dimensions of the stacked LSTMs (a list or a tuple).
        bias: bool, default: True
            Whether to use a bias term for the LSTM implementation.
        cudnn: bool, default: True
            Whether to use PyTorch's `nn.LSTMCell` (which uses cuDNN). In almost all cases you
            don't want to set this value to False. However, you will need to change it if you
            want to compute higher-order derivatives of a network with a stacked LSTM cell.
        """
        super().__init__()
        # Normalize to a list so that tuples (or other sequences) are accepted as well.
        hidden_sizes = list(hidden_sizes)
        self.num_stacked = len(hidden_sizes)
        cell_class = nn.LSTMCell if cudnn else _LSTMCell
        # Layer k maps the output dimension of layer k-1 (or the input size) to hidden_sizes[k].
        dims = zip([input_size] + hidden_sizes, hidden_sizes)
        self.cells = nn.ModuleList(
            [cell_class(in_dim, out_dim, bias=bias) for in_dim, out_dim in dims]
        )

    # pylint: disable=arguments-differ
    def forward(self, x: torch.Tensor,
                initial_states: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None):
        """
        Computes the new hidden states and cell states for each stacked cell.

        Parameters
        ----------
        x: torch.Tensor [B, N]
            The input with batch size B and dimension N.
        initial_states: list of tuple of (torch.Tensor [B, D], torch.Tensor [B, D]), default: None
            The (hidden, cell) states for each of the cells, where each state is expected to have
            batch size B and the (respective) hidden dimension D. An entry whose hidden state has
            size 0 in its first dimension is treated like a missing state. If None is given, every
            cell starts from its default (zero) state. The given list is not modified.

        Returns
        -------
        torch.Tensor [B, D]
            The output, i.e. the hidden state of the deepest cell. Only given for convenience as it
            can be extracted from the other return value.
        list of tuple of (torch.Tensor [B, D], torch.Tensor [B, D])
            The new hidden states and cell states for all cells (a new list).
        """
        # Build a fresh list instead of writing into `initial_states`, which would otherwise
        # silently overwrite the caller's states.
        new_states: List[Tuple[torch.Tensor, torch.Tensor]] = []
        # A manual counter is used instead of enumerate() to keep the ModuleList loop
        # TorchScript-friendly.
        i = 0
        for cell in self.cells:
            if initial_states is None:
                x, next_cell = cell(x, None)
            else:
                state = initial_states[i]
                x, next_cell = cell(x, None if state[0].size(0) == 0 else state)
            new_states.append((x, next_cell))
            i += 1
        return x, new_states
class _LSTMCell(nn.Module):
"""
LSTMCell which does not have cuDNN support but allows for higher-order gradients.
Consult PyTorch's LSTMCell for documentation on the class's initialization parameters and how
to call it.
"""
hidden_size: jit.Final[int]
has_bias: jit.Final[bool]
def __init__(self, input_size, hidden_size, bias=True):
super().__init__()
self.hidden_size = hidden_size
self.input_weight = nn.Parameter(
torch.FloatTensor(input_size, 4 * hidden_size)
)
self.hidden_weight = nn.Parameter(
torch.FloatTensor(hidden_size, 4 * hidden_size)
)
if bias:
self.bias = nn.Parameter(torch.FloatTensor(4 * hidden_size))
self.has_bias = True
else:
self.has_bias = False
self.reset_parameters()
def reset_parameters(self):
"""
Resets the parameters of the model.
"""
sqrt_hidden = np.sqrt(1 / self.hidden_size)
init_from = (-sqrt_hidden, sqrt_hidden)
for p in self.parameters():
nn.init.uniform_(p, *init_from)
# pylint: disable=arguments-differ,missing-function-docstring
def forward(self, x_in: torch.Tensor,
state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
if state is None:
size = (x_in.size(0), self.hidden_size)
hidden_state = torch.zeros(
*size, dtype=torch.float, device=x_in.device
)
cell_state = torch.zeros(
*size, dtype=torch.float, device=x_in.device
)
else:
hidden_state, cell_state = state
# 1) Perform matrix multiplications for input and last hidden state
if self.has_bias:
x = torch.addmm(self.bias, x_in, self.input_weight)
h = torch.addmm(self.bias, hidden_state, self.hidden_weight)
else:
x = x_in.matmul(self.input_weight)
h = hidden_state.matmul(self.hidden_weight)
forget_gate_x, input_gate_x_1, input_gate_x_2, output_gate_x = \
x.split(self.hidden_size, dim=1)
forget_gate_h, input_gate_h_1, input_gate_h_2, output_gate_h = \
h.split(self.hidden_size, dim=1)
# 2) Forget gate
forget_gate = torch.sigmoid(forget_gate_x + forget_gate_h)
# 3) Input gate
input_gate_1 = torch.sigmoid(input_gate_x_1 + input_gate_h_1)
input_gate_2 = torch.tanh(input_gate_x_2 + input_gate_h_2)
input_gate = forget_gate * cell_state + input_gate_1 * input_gate_2
# 4) Output gate
output_gate_1 = torch.sigmoid(output_gate_x + output_gate_h)
output_gate = output_gate_1 * torch.tanh(input_gate)
return output_gate, input_gate