# Source code for pyblaze.nn.functional.densities

import math
import torch

def log_prob_standard_normal(x):
    """
    Computes the log-probability of observing the given data under a (multivariate) standard Normal
    distribution. Although this function is equivalent to the :code:`log_prob` method of the
    :class:`torch.distributions.MultivariateNormal` class, this implementation is much more
    efficient due to the restriction to standard Normal.

    Parameters
    ----------
    x: torch.Tensor [..., D]
        The samples whose log-probability shall be computed. The last dimension is the
        dimensionality D of the distribution; all leading dimensions (typically a single batch
        dimension of N samples) index independent samples.

    Returns
    -------
    torch.Tensor [...]
        The log-probabilities for all samples, i.e. the input shape with the last dimension
        removed (shape [N] for an input of shape [N, D]).
    """
    dim = x.size(-1)
    # Normalizing term of N(0, I) in D dimensions: D * log(2 * pi).
    const = dim * math.log(2 * math.pi)
    # Squared Euclidean norm of every sample. Reducing over the last axis (rather than the
    # 2-D-only einsum 'ij,ij->i') lets arbitrary leading batch dimensions pass through.
    norm = (x * x).sum(-1)
    return -0.5 * (const + norm)

[docs]def log_prob_standard_gmm(x, means):
"""
Computes the log-probability of observing the given data under a GMM consisting of
(multivariate) standard normal distributions. Each component is assigned the same weight.

Parameters
----------
x: torch.Tensor [N, D]
The samples whose log-probability shall be computed (number of samples N,
dimensionality D).
means: torch.Tensor [M, D]
The means of the GMM.

Returns
-------
torch.Tensor [N]
The log-probabilities for all samples.
"""
num_datapoints, dim = x.size()
num_components = means.size(0)

const = dim * math.log(2 * math.pi)
xx = torch.einsum('ij,ij->i', x, x).view(num_datapoints, 1)
mm = torch.einsum('ij,ij->i', means, means).view(1, num_components)
xm = x.matmul(means.t())
log_probs = -0.5 * (const + xx - 2 * xm + mm)