Source code for pyblaze.nn.engine.autoencoder

from .mle import MLEEngine
from ._utils import forward

class AutoencoderEngine(MLEEngine):
    """
    Convenience wrapper around :class:`pyblaze.nn.MLEEngine` preconfigured for autoencoders,
    variational autoencoders included.

    The underlying engine is set up with ``uses_data_target=False`` and ``uses_data_input=True``.
    To call :meth:`predict`, the wrapped model must provide a submodule named :code:`decoder`.

    :meth:`predict` understands one extra keyword argument:

    Parameters
    ----------
    reconstruct: bool, default: False
        When true, the data handed to :meth:`predict` is treated as encoder input and run through
        the complete model. When false, the data is treated as points in the latent space and is
        fed directly to the decoder. In either case the caller receives the decoder's output.
    """

    def __init__(self, model, expects_data_target=False):
        """
        Creates an engine for training or evaluating an autoencoder.

        Parameters
        ----------
        model: torch.nn.Module
            The autoencoder to train or evaluate.
        expects_data_target: bool, default: False
            Set to true if each item supplied to this engine is a ``(data, target)`` tuple, and
            false if it consists of the data only.
        """
        super().__init__(
            model,
            expects_data_target=expects_data_target,
            uses_data_target=False,
            uses_data_input=True,
        )

    def predict_batch(self, data, reconstruct=False):
        """
        Runs a single batch through the autoencoder (or only its decoder).

        Parameters
        ----------
        data: object
            The batch to process. When ``reconstruct`` is true, it is split via
            :meth:`_get_x_target` and only the input part is used. Otherwise it is passed to
            ``self.model.decoder`` unchanged (presumably latent codes — confirm with callers).
        reconstruct: bool, default: False
            Whether to run the full model (true) or only the decoder (false).

        Returns
        -------
        object
            Whatever ``forward`` returns for the chosen module.
        """
        # Default path: the batch already lives in the latent space, so only decode it.
        if not reconstruct:
            return forward(self.model.decoder, data)

        # Reconstruction path: drop any target and push the inputs through the whole model.
        inputs, _ = self._get_x_target(data)
        return forward(self.model, inputs)