Plotting

import pyblaze.plot as P

The plotting module provides high-level plotting utilities. It is fully compatible with matplotlib and merely adds methods that integrate more seamlessly with PyTorch.

Visualizing Densities

pyblaze.plot.density.density_plot2d(callback, x_range=(-1, 1), y_range=(-1, 1), resolution=500, cmap='Blues', threshold=None, ax=None, **kwargs)

Generates a scatter plot visualizing a distribution’s density in the given 2D region.

Parameters
  • callback (callable) – Evaluates the density of the distribution to visualize, e.g. a PyTorch module. It must accept a torch.Tensor of shape [N, 2] (N evaluation points) and return a torch.Tensor of shape [N].

  • x_range (tuple of (float, float), default: (-1, 1)) – The range to visualize along the x-axis.

  • y_range (tuple of (float, float), default: (-1, 1)) – The range to visualize along the y-axis.

  • resolution (int, default: 500) – The number of evaluation points along each dimension. The total number of evaluation points is therefore resolution², i.e. 250,000 for the default.

  • cmap (str, default: 'Blues') – The matplotlib colormap to use for visualization.

  • threshold (float, default: None) – The minimum density a point must exceed to be plotted. If None, no threshold is applied.

  • ax (matplotlib.axes.Axes, default: None) – The axes to plot on, or None to use matplotlib's global imperative (pyplot) API.

  • kwargs (keyword arguments) – Additional keyword arguments passed to matplotlib's scatter method.
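The behaviour described above can be sketched as follows. This is a hypothetical re-implementation for illustration only: `density_scatter_sketch` is not part of pyblaze, and it assumes the function lays out a regular resolution x resolution grid with `torch.linspace`, evaluates `callback` on all grid points at once, drops points whose density does not exceed `threshold`, and forwards the rest to matplotlib's `scatter`. The actual library code may differ, for example in what it returns.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import torch


def density_scatter_sketch(callback, x_range=(-1, 1), y_range=(-1, 1),
                           resolution=500, cmap="Blues", threshold=None,
                           ax=None, **kwargs):
    """Hypothetical sketch of the documented behaviour (not pyblaze's code)."""
    # Build a resolution x resolution grid covering the requested region.
    xs = torch.linspace(x_range[0], x_range[1], resolution)
    ys = torch.linspace(y_range[0], y_range[1], resolution)
    grid_x, grid_y = torch.meshgrid(xs, ys, indexing="ij")
    # Flatten to [N, 2] with N = resolution ** 2, matching the callback contract.
    points = torch.stack([grid_x.reshape(-1), grid_y.reshape(-1)], dim=1)

    # Evaluate the density on all points without tracking gradients.
    with torch.no_grad():
        values = callback(points)  # shape [N]

    # Keep only points whose density exceeds the threshold, if one is given.
    if threshold is not None:
        mask = values > threshold
        points, values = points[mask], values[mask]

    # Fall back to the current axes when no explicit axes object is passed.
    if ax is None:
        ax = plt.gca()
    # Color each point by its density; extra kwargs (e.g. marker size s) go to scatter.
    return ax.scatter(points[:, 0].numpy(), points[:, 1].numpy(),
                      c=values.numpy(), cmap=cmap, **kwargs)


# Example callback: density of a standard 2D Gaussian, mapping [N, 2] -> [N].
gaussian = torch.distributions.MultivariateNormal(torch.zeros(2), torch.eye(2))


def density(points):
    return gaussian.log_prob(points).exp()


fig, ax = plt.subplots()
collection = density_scatter_sketch(density, x_range=(-3, 3), y_range=(-3, 3),
                                    resolution=100, threshold=0.01, ax=ax, s=1)
```

With pyblaze installed, the equivalent documented call would be `P.density_plot2d(density, x_range=(-3, 3), y_range=(-3, 3), resolution=100, threshold=0.01, ax=ax, s=1)`. Passing an explicit `ax` keeps the plot independent of matplotlib's global state, which is convenient when drawing several densities side by side.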