Source code for pyromancy.nodes.gaussian

import math
from abc import ABC, abstractmethod

import torch
import torch.nn as nn

from ..utils import mparameters
from .base import VariationalNode


class AbstractGaussianNode(VariationalNode, ABC):
    r"""Base class for predictive coding nodes modelling Gaussian distributions.

    A multivariate Gaussian distribution is described by the following probability
    density function:

    .. math::
        f(\mathbf{z}; \boldsymbol{\mu}, \boldsymbol{\Sigma}) =
        \frac{1}{\sqrt{(2\pi)^N \lvert\boldsymbol{\Sigma}\rvert}}
        \exp \left(-\frac{1}{2} (\mathbf{z} - \boldsymbol{\mu})
        \boldsymbol{\Sigma}^{-1} (\mathbf{z} - \boldsymbol{\mu})^\intercal \right)

    where :math:`\mathbf{z}` is a sample (the node's state), :math:`\boldsymbol{\mu}`
    is the mean, and :math:`\boldsymbol{\Sigma}` is the covariance matrix, for an
    :math:`N`-dimensional distribution. Vectors are written as row vectors, hence
    the transpose on the right-hand factor of the quadratic form.

    Args:
        *shape (int | None): shape of the node's learned state.

    Attributes:
        value (~torch.nn.parameter.Parameter): current value of the node.
    """

    def __init__(self, *shape: int | None) -> None:
        VariationalNode.__init__(self, *shape)

    @property
    @abstractmethod
    def covariance(self) -> torch.Tensor:
        r"""Covariance matrix of the Gaussian distribution.

        Args:
            value (float | ~torch.Tensor): new covariance for the distribution.

        Raises:
            NotImplementedError: must be implemented by subclasses.

        Returns:
            ~torch.Tensor: covariance of the distribution.
        """
        raise NotImplementedError

    @covariance.setter
    @abstractmethod
    def covariance(self, value: float | torch.Tensor) -> None:
        raise NotImplementedError
class StandardGaussianNode(AbstractGaussianNode):
    r"""Gaussian predictive coding node with unit variance.

    Assumes the covariance matrix is an identity matrix.

    .. math::
        \boldsymbol{\Sigma} = \mathbf{I}

    Args:
        *shape (int | None): shape of the node's learned state.

    Attributes:
        value (~torch.nn.parameter.Parameter): value of the node :math:`\mathbf{z}`.
    """

    def __init__(self, *shape: int | None) -> None:
        AbstractGaussianNode.__init__(self, *shape)

    @property
    def covariance(self) -> torch.Tensor:
        r"""Covariance matrix of the Gaussian distribution.

        .. math::
            \boldsymbol{\Sigma} = \mathbf{I}

        Args:
            value (float | ~torch.Tensor): new covariance for the distribution.

        Raises:
            RuntimeError: covariance is a fixed value.

        Returns:
            ~torch.Tensor: covariance of the distribution.
        """
        return torch.eye(self.size, dtype=self.value.dtype, device=self.value.device)

    @covariance.setter
    def covariance(self, value: float | torch.Tensor) -> None:
        raise RuntimeError(f"{type(self).__name__} has fixed covariance")

    def error(self, pred: torch.Tensor) -> torch.Tensor:
        r"""Error between the prediction and node state.

        .. math::
            \boldsymbol{\varepsilon} = \mathbf{z} - \boldsymbol{\mu}

        Args:
            pred (~torch.Tensor): predicted distribution mean :math:`\boldsymbol{\mu}`.

        Returns:
            ~torch.Tensor: elementwise error :math:`\boldsymbol{\varepsilon}`.
        """
        return self.value - pred

    def energy(self, pred: torch.Tensor) -> torch.Tensor:
        r"""Variational free energy with respect to the prediction.

        .. math::
            \begin{aligned}
                \mathcal{F}
                &= \frac{1}{2} (\mathbf{z} - \boldsymbol{\mu})
                (\mathbf{z} - \boldsymbol{\mu})^\intercal \\
                &= \frac{1}{2} \lVert\mathbf{z} - \boldsymbol{\mu}\rVert_2^2
            \end{aligned}

        Args:
            pred (~torch.Tensor): predicted distribution mean :math:`\boldsymbol{\mu}`.

        Returns:
            ~torch.Tensor: variational free energy :math:`\mathcal{F}`, one value
            per entry along the first (batch) dimension.
        """
        # flatten everything but the leading dimension, then take a batched dot
        # product of each row with itself
        diff = (self.value - pred).flatten(1)
        return 0.5 * (diff.unsqueeze(1) @ diff.unsqueeze(2)).flatten()

    def sample(
        self, value: torch.Tensor, generator: torch.Generator | None = None
    ) -> torch.Tensor:
        r"""Samples from the learned variational distribution.

        Uses the reparameterization :math:`\mathbf{x} = \boldsymbol{\mu} +
        \boldsymbol{\epsilon}` with :math:`\boldsymbol{\epsilon} \sim
        \mathcal{N}(\mathbf{0}, \mathbf{I})`, so gradients flow back to ``value``.

        Args:
            value (~torch.Tensor): location parameter of the variational
                distribution for sampling.
            generator (~torch.Generator | None, optional): pseudorandom number
                generator for sampling. Defaults to None.

        Returns:
            ~torch.Tensor: samples from the variational distribution, with the
            same shape as ``value``.
        """
        mu, pragma = self.shapeobj.coalesce(value)
        # noise is created with mu's dtype/device so it can be added directly
        eps = torch.randn(
            mu.shape, generator=generator, dtype=mu.dtype, device=mu.device
        )
        # previously only the noise was returned, ignoring the location parameter
        return self.shapeobj.disperse(mu + eps, pragma)
[docs] @mparameters("logvar") class IsotropicGaussianNode(AbstractGaussianNode): r"""Gaussian predictive coding node with scalar variance. Assumes the covariance matrix is a scalar matrix. .. math:: \boldsymbol{\Sigma} = \sigma\mathbf{I} Args: *shape (int | None): shape of the node's learned state. variance (float | ~torch.Tensor, optional): initial variance. Defaults to 1.0. Attributes: value (~torch.nn.parameter.Parameter): value of the node :math:`\mathbf{z}`. logvar (~torch.nn.parameter.Parameter): log of the distribution variance :math:`\log{\sigma}`. """ logvar: nn.Parameter def __init__( self, *shape: int | None, variance: float | torch.Tensor = 1.0 ) -> None: AbstractGaussianNode.__init__(self, *shape) self.logvar = nn.Parameter(torch.empty([]), True) self.covariance = variance @property def covariance(self) -> torch.Tensor: r"""Covariance matrix of the Gaussian distribution. .. math:: \boldsymbol{\Sigma} = \sigma\mathbf{I} Args: value (float | ~torch.Tensor): new covariance for the distribution. Returns: ~torch.Tensor: covariance of the distribution. Note: Assigment of variances is performed as follows: - 0D-Tensor (or float): single variance is used. - 1D-Tensor: vector of variances are averaged. - 2D-Tensor: diagonal of the covariance matrix is averaged. 
""" return self.logvar.exp() * torch.eye( self.size, dtype=self.logvar.dtype, device=self.logvar.device ) @covariance.setter @torch.no_grad() def covariance(self, value: float | torch.Tensor) -> None: if not isinstance(value, torch.Tensor): if not value > 0: raise ValueError("variance must be positive") self.logvar.fill_(math.log(value)) else: match value.ndim: # scalar (isotropic multivariate) case 0: if not value > 0: raise ValueError("variance must be positive") self.logvar.fill_(value.log()) # vector (factorized multivariate) case 1: if not value.numel() == self.size: raise ValueError( "`covariance` must be specified as a scalar, a vector of " f"{self.size}, or a {self.size} x {self.size} matrix" ) if not torch.all(value > 0): raise ValueError( "all elements of the variance vector must be positive" ) self.logvar.fill_(value.mean().log()) # matrix (full multivariate) case 2: if not all(sz == self.size for sz in value.shape): raise ValueError( "`covariance` must be specified as a scalar, a vector of " f"{self.size}, or a {self.size} x {self.size} matrix" ) _, info = torch.linalg.cholesky_ex(value) if not info.item() == 0: raise ValueError( "the covariance matrix must be " "symmetric and positive-definite" ) self.logvar.fill_(value.diag().mean().log()) # invalid tensor dimensionality case _: raise ValueError( "`covariance` must be specified as a scalar, a vector of " f"{self.size}, or a {self.size} x {self.size} matrix" )
[docs] def error(self, pred: torch.Tensor) -> torch.Tensor: r"""Error between the prediction and node state. .. math:: \boldsymbol{\varepsilon} = \frac{\mathbf{z} - \boldsymbol{\mu}}{\sigma} Args: pred (~torch.Tensor): predicted distribution mean :math:`\boldsymbol{\mu}`. Returns: ~torch.Tensor: elementwise error :math:`\boldsymbol{\varepsilon}`. """ return (self.value - pred) / self.logvar.exp()
[docs] def energy(self, pred: torch.Tensor) -> torch.Tensor: r"""Variational free energy with respect to the prediction. .. math:: \begin{aligned} \mathcal{F} &= \frac{1}{2} \left((\mathbf{z} - \boldsymbol{\mu}) ((\mathbf{z} - \boldsymbol{\mu}) \sigma^{-1})^\intercal + N \log \sigma\right) \\ &= \frac{1}{2} \left(\frac{\lVert\mathbf{z} - \boldsymbol{\mu}\rVert_2^2}{\sigma} + N \log \sigma\right) \end{aligned} Args: pred (~torch.Tensor): predicted distribution mean :math:`\boldsymbol{\mu}`. Returns: ~torch.Tensor: variational free energy :math:`\mathcal{F}`. """ diff = (self.value - pred).flatten(1) y = diff / self.logvar.exp() logdet = self.size * self.logvar return 0.5 * (diff.unsqueeze(1) @ y.unsqueeze(2) + logdet).flatten()
[docs] def sample( self, value: torch.Tensor, generator: torch.Generator | None = None ) -> torch.Tensor: r"""Samples from the learned variational distribution. Args: value (~torch.Tensor): location parameter of the variational distribution for sampling. generator (~torch.Generator | None, optional): pseudorandom number generator for sampling. Defaults to None. Returns: ~torch.Tensor: samples from the variational distribution. """ mu, pragma = self.shapeobj.coalesce(value) std = self.logvar.exp().sqrt() x = std * torch.randn(mu.shape, generator=generator, out=torch.empty_like(mu)) return self.shapeobj.disperse(x, pragma)
[docs] @mparameters("logvar") class FactorizedGaussianNode(AbstractGaussianNode): r"""Gaussian predictive coding node with diagonal variances. Assumes the covariance matrix is a diagonal matrix. .. math:: \boldsymbol{\Sigma} = \begin{bmatrix} \sigma_1 & 0 & \cdots & 0 \\ 0 & \sigma_2 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & \sigma_N \end{bmatrix} Args: *shape (int | None): shape of the node's learned state. variance (float, optional): initial variance. Defaults to 1.0. Attributes: value (~torch.nn.parameter.Parameter): value of the node :math:`\mathbf{z}`. logvar (~torch.nn.parameter.Parameter): log of the distribution variances :math:`\log{\boldsymbol{\sigma}}`. """ logvar: nn.Parameter def __init__( self, *shape: int | None, variance: float | torch.Tensor = 1.0 ) -> None: AbstractGaussianNode.__init__(self, *shape) self.logvar = nn.Parameter(torch.empty([self.size]), True) self.covariance = variance @property def covariance(self) -> torch.Tensor: r"""Covariance matrix of the Gaussian distribution. .. math:: \boldsymbol{\Sigma} = \operatorname{diag}(\sigma_1, \sigma_2, \ldots, \sigma_N) Args: value (float | ~torch.Tensor): new covariance for the distribution. Returns: ~torch.Tensor: covariance of the distribution. Note: Assigment of variances is performed as follows: - 0D-Tensor (or float): single variance is used. - 1D-Tensor: vector of variances is used. - 2D-Tensor: diagonal of the covariance matrix is used. 
""" return torch.diag(self.logvar.exp()) @covariance.setter @torch.no_grad() def covariance(self, value: float | torch.Tensor) -> None: if not isinstance(value, torch.Tensor): if not value > 0: raise ValueError("variance must be positive") self.logvar.fill_(math.log(value)) else: match value.ndim: # scalar (isotropic multivariate) case 0: if not value > 0: raise ValueError("variance must be positive") self.logvar.fill_(value.log()) # vector (factorized multivariate) case 1: if not value.numel() == self.size: raise ValueError( "`covariance` must be specified as a scalar, a vector of " f"{self.size}, or a {self.size} x {self.size} matrix" ) if not torch.all(value > 0): raise ValueError( "all elements of the variance vector must be positive" ) self.logvar.copy_(value.log()) # matrix (full multivariate) case 2: if not all(sz == self.size for sz in value.shape): raise ValueError( "`covariance` must be specified as a scalar, a vector of " f"{self.size}, or a {self.size} x {self.size} matrix" ) _, info = torch.linalg.cholesky_ex(value) if not info.item() == 0: raise ValueError( "the covariance matrix must be " "symmetric and positive-definite" ) self.logvar.copy_(value.diag().log()) # invalid tensor dimensionality case _: raise ValueError( "`covariance` must be specified as a scalar, a vector of " f"{self.size}, or a {self.size} x {self.size} matrix" )
[docs] def error(self, pred: torch.Tensor) -> torch.Tensor: r"""Error between the prediction and node state. .. math:: \boldsymbol{\varepsilon} = (\mathbf{z} - \boldsymbol{\mu}) \oslash \boldsymbol{\sigma} Args: pred (~torch.Tensor): predicted distribution mean :math:`\boldsymbol{\mu}`. Returns: ~torch.Tensor: elementwise error :math:`\boldsymbol{\varepsilon}`. """ diff, pragma = self.shapeobj.coalesce(self.value - pred) return self.shapeobj.disperse(diff / self.logvar.exp(), pragma)
[docs] def energy(self, pred: torch.Tensor) -> torch.Tensor: r"""Variational free energy with respect to the prediction. .. math:: \mathcal{F} = \frac{1}{2} \left( (\mathbf{z} - \boldsymbol{\mu}) ((\mathbf{z} - \boldsymbol{\mu}) \oslash \boldsymbol{\sigma})^\intercal + N \log \sigma\right) Args: pred (~torch.Tensor): predicted distribution mean :math:`\boldsymbol{\mu}`. Returns: ~torch.Tensor: variational free energy :math:`\mathcal{F}`. """ diff, pragma = self.shapeobj.coalesce(self.value - pred) y = diff / self.logvar.exp() diff = self.shapeobj.disperse(diff, pragma).flatten(1) y = self.shapeobj.disperse(y, pragma).flatten(1) logdet = self.logvar.sum() return 0.5 * (diff.unsqueeze(1) @ y.unsqueeze(2) + logdet).flatten()
[docs] def sample( self, value: torch.Tensor, generator: torch.Generator | None = None ) -> torch.Tensor: r"""Samples from the learned variational distribution. Args: value (~torch.Tensor): location parameter of the variational distribution for sampling. generator (~torch.Generator | None, optional): pseudorandom number generator for sampling. Defaults to None. Returns: ~torch.Tensor: samples from the variational distribution. """ mu, pragma = self.shapeobj.coalesce(value) std = self.logvar.exp().sqrt() x = std * torch.randn(mu.shape, generator=generator, out=torch.empty_like(mu)) return self.shapeobj.disperse(x, pragma)
[docs] @mparameters("covar_cf_logdiag", "covar_cf_offtril") class MultivariateGaussianNode(AbstractGaussianNode): r"""Gaussian predictive coding node with full covariance. The covariances of the distribution are represented as a full covariance matrix, that is, a matrix that is symmetric and positive-definite. Internally, the covariance matrix is stored as two parts that can be combined into the Cholesky factor :math:`\mathbf{L}` of the covariance matrix :math:`\boldsymbol{\Sigma}`. .. math:: \boldsymbol{\Sigma} = \mathbf{L}\mathbf{L}^\ast Args: *shape (int | None): shape of the node's learned state. variance (float, optional): initial variance. Defaults to 1.0. Attributes: value (~torch.nn.parameter.Parameter): value of the node :math:`\mathbf{z}`. covar_cf_logdiag (~torch.nn.parameter.Parameter): log of the diagonal of the Cholesky factor for the distribution covariance. covar_cf_offtril (~torch.nn.parameter.Parameter): Cholesky factor for the distribution covariances, with the diagonal zeroed. """ covar_cf_logdiag: nn.Parameter covar_cf_offtril: nn.Parameter def __init__( self, *shape: int | None, covariance: float | torch.Tensor = 1.0 ) -> None: AbstractGaussianNode.__init__(self, *shape) self.covar_cf_logdiag = nn.Parameter(torch.empty([self.size]), True) self.covar_cf_offtril = nn.Parameter(torch.empty([self.size, self.size]), True) self.covariance = covariance def _cholesky_factor_l(self) -> torch.Tensor: r"""Computes the Cholesky decomposition factor :math:`L` of the covariance matrix. Returns: ~torch.Tensor: Cholesky factor :math:`L`. """ return self.covar_cf_offtril.tril(-1) + self.covar_cf_logdiag.exp().diag() @property def covariance(self) -> torch.Tensor: r"""Covariance matrix of the Gaussian distribution. .. 
math:: \boldsymbol{\Sigma} = \begin{bmatrix} \sigma_{1,1} & \sigma_{1,2} & \cdots & \sigma_{1,N} \\ \sigma_{2,1} & \sigma_{2,2} & \cdots & \sigma_{2,N} \\ \vdots & \vdots & \ddots & \vdots \\ \sigma_{N,1} & \sigma_{N,2} & \cdots & \sigma_{N,N} \\ \end{bmatrix} Args: value (float | ~torch.Tensor): new covariance for the distribution. Returns: ~torch.Tensor: covariance of the distribution. Note: Assigment of covariances is performed as follows: - 0D-Tensor (or float): single variance is used, with zero covariance. - 1D-Tensor: vector of variances is used, with zero covariance. - 2D-Tensor: covariance matrix is used. """ L = self._cholesky_factor_l() return L @ L.t() @covariance.setter @torch.no_grad() def covariance(self, value: float | torch.Tensor) -> None: if not isinstance(value, torch.Tensor): if not value > 0: raise ValueError("variance must be positive") self.covar_cf_logdiag.fill_(math.log(math.sqrt(value))) self.covar_cf_offtril.fill_(0.0) else: match value.ndim: # scalar (isotropic multivariate) case 0: if not value > 0: raise ValueError("variance must be positive") self.covar_cf_logdiag.fill_(value.sqrt().log()) self.covar_cf_offtril.fill_(0.0) # vector (factorized multivariate) case 1: if not all(sz == self.size for sz in value.shape): raise ValueError( "`covariance` must be specified as a scalar, a vector of " f"{self.size}, or a {self.size} x {self.size} matrix" ) if not torch.all(value > 0): raise ValueError( "all elements of the variance vector must be positive" ) self.covar_cf_logdiag.copy_(value.sqrt().log()) self.covar_cf_offtril.fill_(0.0) # matrix (full multivariate) case 2: if not all(sz == self.size for sz in value.shape): raise ValueError( "`covariance` must be specified as a scalar, a vector of " f"{self.size}, or a {self.size} x {self.size} matrix" ) L, info = torch.linalg.cholesky_ex(value) if not info.item() == 0: raise ValueError( "the covariance matrix must be " "symmetric and positive-definite" ) 
self.covar_cf_logdiag.copy_(L.diag().log()) self.covar_cf_offtril.copy_(L).fill_diagonal_(0.0) # invalid tensor dimensionality case _: raise ValueError( "`covariance` must be specified as a scalar, a vector of " f"{self.size}, or a {self.size} x {self.size} matrix" )
[docs] def error(self, pred: torch.Tensor) -> torch.Tensor: r"""Error between the prediction and node state. .. math:: \boldsymbol{\varepsilon} = \boldsymbol{\Sigma}^{-1} (\mathbf{z} - \boldsymbol{\mu})^\intercal Args: pred (~torch.Tensor): predicted distribution mean :math:`\boldsymbol{\mu}`. Returns: ~torch.Tensor: elementwise error :math:`\boldsymbol{\varepsilon}`. """ diff, pragma = self.shapeobj.coalesce(self.value - pred) L = self._cholesky_factor_l() u = torch.linalg.solve_triangular(L, diff.t(), upper=False) y = torch.linalg.solve_triangular(L.t(), u, upper=True) return self.shapeobj.disperse(y.t(), pragma)
[docs] def energy(self, pred: torch.Tensor) -> torch.Tensor: r"""Variational free energy with respect to the prediction. .. math:: \mathcal{F} = \frac{1}{2} \left( (\mathbf{z} - \boldsymbol{\mu}) \boldsymbol{\Sigma}^{-1} (\mathbf{z} - \boldsymbol{\mu})^\intercal + \log \lvert\boldsymbol{\Sigma}\rvert \right) Args: pred (~torch.Tensor): predicted distribution mean :math:`\boldsymbol{\mu}`. Returns: ~torch.Tensor: variational free energy :math:`\mathcal{F}`. """ diff, pragma = self.shapeobj.coalesce(self.value - pred) L = self._cholesky_factor_l() u = torch.linalg.solve_triangular(L, diff.t(), upper=False) y = torch.linalg.solve_triangular(L.t(), u, upper=True) diff = self.shapeobj.disperse(diff, pragma).flatten(1) y = self.shapeobj.disperse(y.t(), pragma).flatten(1) logdet = 2.0 * self.covar_cf_logdiag.sum() return 0.5 * (diff.unsqueeze(1) @ y.unsqueeze(2) + logdet).flatten()
[docs] def sample( self, value: torch.Tensor, generator: torch.Generator | None = None ) -> torch.Tensor: r"""Samples from the learned variational distribution. Args: value (~torch.Tensor): location parameter of the variational distribution for sampling. generator (torch.Generator | None, optional): pseudorandom number generator for sampling. Defaults to None. Returns: ~torch.Tensor: samples from the variational distribution. """ L = self._cholesky_factor_l() mu, pragma = self.shapeobj.coalesce(value) x = L @ torch.randn(mu.shape, generator=generator, out=torch.empty_like(mu)).t() return self.shapeobj.disperse(x.t(), pragma)