import pyblaze.plot as P
The plotting module provides high-level plotting utilities. It is fully compatible with matplotlib and merely adds higher-level functions that integrate more seamlessly with PyTorch.
density_plot2d(callback, x_range=(-1, 1), y_range=(-1, 1), resolution=500, cmap='Blues', threshold=None, ax=None, **kwargs)
Generates a scatter plot visualizing a distribution’s density in the given 2D region.
callback (callable) – The distribution whose density is evaluated, optionally given as a PyTorch module. It must take a torch.Tensor of shape [N, 2] (N evaluation points) as input and return a torch.Tensor of shape [N].
x_range (tuple of (float, float), default: (-1, 1)) – The range to visualize in x-dimension.
y_range (tuple of (float, float), default: (-1, 1)) – The range to visualize in y-dimension.
resolution (int, default: 500) – The number of evaluation points for each dimension. The total number of evaluation points is therefore given by the squared resolution.
cmap (str, default: 'Blues') – The matplotlib colormap used to color the points by density.
threshold (float, default: None) – A minimum value that needs to be surpassed in order for a scatter point to be plotted.
ax (matplotlib.axes.Axes, default: None) – The axes to plot on, or None to use matplotlib's global imperative (pyplot) API.
kwargs (keyword arguments) – Additional arguments passed to the
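The callback contract and the role of resolution and threshold can be illustrated with a small standalone sketch. The helper density_grid_sketch below is hypothetical and not part of pyblaze; it assumes density_plot2d builds an evenly spaced resolution × resolution grid over x_range × y_range, evaluates the callback once on the whole [N, 2] batch, and keeps only points whose value exceeds threshold. The real implementation may differ. The example callback, a standard 2D Gaussian from torch.distributions, is likewise only an illustration.

```python
import torch
from torch.distributions import MultivariateNormal


def density_grid_sketch(callback, x_range=(-1, 1), y_range=(-1, 1),
                        resolution=500, threshold=None, ax=None,
                        cmap="Blues", **kwargs):
    """Hypothetical sketch of the evaluation step behind density_plot2d.

    Not pyblaze code: builds a resolution x resolution grid, evaluates
    `callback` on it as one [N, 2] batch, optionally drops points whose
    value does not exceed `threshold`, and scatters the rest if `ax` is given.
    """
    xs = torch.linspace(x_range[0], x_range[1], resolution)
    ys = torch.linspace(y_range[0], y_range[1], resolution)
    # All grid coordinates as an [N, 2] tensor with N = resolution ** 2.
    grid_x, grid_y = torch.meshgrid(xs, ys, indexing="ij")
    points = torch.stack([grid_x.reshape(-1), grid_y.reshape(-1)], dim=1)

    with torch.no_grad():  # plotting needs no gradients
        values = callback(points)  # expected shape: [N]

    if threshold is not None:
        # Keep only points whose value surpasses the threshold.
        mask = values > threshold
        points, values = points[mask], values[mask]

    if ax is not None:
        # Any matplotlib Axes; extra kwargs are forwarded to Axes.scatter.
        ax.scatter(points[:, 0].numpy(), points[:, 1].numpy(),
                   c=values.numpy(), cmap=cmap, **kwargs)
    return points, values


# Example callback: density of a standard 2D Gaussian, [N, 2] -> [N].
dist = MultivariateNormal(torch.zeros(2), torch.eye(2))


def density(x):
    return dist.log_prob(x).exp()


points, values = density_grid_sketch(density, resolution=50, threshold=0.1)
```

With pyblaze installed, the same density function satisfies the documented callback contract and could be passed directly, e.g. P.density_plot2d(density, threshold=0.1). Note that the default resolution of 500 means 250,000 evaluation points in a single call.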