Source code for pyblaze.nn.modules.view

import torch.nn as nn

class View(nn.Module):
    """
    Utility module that views its input tensor with a new shape.

    This module is usually used inside :code:`torch.nn.Sequential`, where a
    reshaping step is needed between layers, e.g. to turn the output of a
    linear layer into a 2D input or the like.
    """

    def __init__(self, *dim):
        """
        Initializes a new view module.

        Parameters
        ----------
        dim: varargs of int
            The new shape. May contain no more than one -1.
        """
        super().__init__()
        # Stored as the tuple of varargs and unpacked again in ``forward``.
        self.dim = dim

    def forward(self, x):
        """
        Views the input with this module's target shape.

        Parameters
        ----------
        x: torch.Tensor
            The tensor to view differently.

        Returns
        -------
        torch.Tensor
            The result of ``x.view(*self.dim)``.
        """
        return x.view(*self.dim)

    def extra_repr(self):
        """
        Returns the configured shape for display in the module's ``repr``.

        Returns
        -------
        str
            The entries of ``dim``, separated by commas.
        """
        return ", ".join(str(d) for d in self.dim)