Working with Expectation–Maximization
Background
Unlike feedforward neural networks, which are traditionally trained with backpropagation, predictive coding networks are trained using expectation–maximization (EM) based procedures (typically referred to as either inference learning or prospective configuration).
The training objective with predictive coding is to minimize the variational free energy, \(\mathcal{F}\), of the network. In a predictive coding network, the total free energy is equal to the sum of the free energies of each node in the network. In general, we can describe this as a function of the node states \(\mathbf{x}\), parameterized by \(\boldsymbol{\Theta}\).
Then, EM breaks the optimization task into two steps:
E-Step (Inference): \(\mathbf{x}^\ast = \operatorname{arg\,min}_{\mathbf{x}} \mathcal{F}(\mathbf{x}, \boldsymbol{\Theta})\)
M-Step (Learning): \(\boldsymbol{\Theta}^\ast = \operatorname{arg\,min}_{\boldsymbol{\Theta}} \mathcal{F}(\mathbf{x}, \boldsymbol{\Theta})\)
In standard inference learning (IL), each batch is processed with multiple E-steps followed by a single M-step. In incremental IL, each E-step is immediately followed by an M-step, and this pair is repeated for multiple iterations per batch.
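The two schedules can be compared on a toy problem. The following is a minimal sketch in plain Python, not Pyromancy code: the scalar free energy, the hand-written gradients, the function names, the learning rates, and the choice to reset the state to the prior mean at the start of each batch are all assumptions made for illustration.

```python
def free_energy(x, theta, y, mu):
    """Toy free energy: observation error plus prior error on the state."""
    return 0.5 * (y - theta * x) ** 2 + 0.5 * (x - mu) ** 2


def grad_x(x, theta, y, mu):
    """Gradient of the toy free energy with respect to the state x."""
    return -theta * (y - theta * x) + (x - mu)


def grad_theta(x, theta, y):
    """Gradient of the toy free energy with respect to the parameter theta."""
    return -x * (y - theta * x)


def standard_il(theta, y, mu, n_batches=50, n_esteps=20, lr_x=0.1, lr_theta=0.05):
    """Standard IL: several E-steps, then a single M-step, per batch."""
    x = mu
    for _ in range(n_batches):
        x = mu  # assumption: the state is reset for every batch
        for _ in range(n_esteps):
            x -= lr_x * grad_x(x, theta, y, mu)  # E-step (inference)
        theta -= lr_theta * grad_theta(x, theta, y)  # M-step (learning)
    return x, theta


def incremental_il(theta, y, mu, n_batches=50, n_esteps=20, lr_x=0.1, lr_theta=0.05):
    """Incremental IL: every E-step is followed by an M-step."""
    x = mu
    for _ in range(n_batches):
        x = mu  # assumption: the state is reset for every batch
        for _ in range(n_esteps):
            x -= lr_x * grad_x(x, theta, y, mu)  # E-step (inference)
            theta -= lr_theta * grad_theta(x, theta, y)  # M-step (learning)
    return x, theta


y, mu, theta0 = 2.0, 1.0, 0.5
f_initial = free_energy(mu, theta0, y, mu)
x_std, theta_std = standard_il(theta0, y, mu)
x_inc, theta_inc = incremental_il(theta0, y, mu)
f_std = free_energy(x_std, theta_std, y, mu)
f_inc = free_energy(x_inc, theta_inc, y, mu)
print(f_initial, f_std, f_inc)
```

Both schedules take gradient steps on the same objective; they differ only in how often the parameters are updated relative to the state.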
EM in Pyromancy
To work with PyTorch’s Optimizer class, Node classes store their trainable state (their values) as Parameter objects. In addition, model parameters from classes not included in Pyromancy need to be incorporated into this scheme.
To this end, Pyromancy defines some helper functions for managing these two kinds of parameters. First, it provides two class decorators, eparameters() and mparameters(), which register parameters as E-step or M-step parameters, respectively. They add the attributes _e_params_ and _m_params_ to the class and fill them with the given parameter names, plus any E-step or M-step parameters registered on its superclasses (resolved by traversing the superclasses in method resolution order via __mro__).
For example, in the following inheritance chain, MultivariateGaussianNode has value as an E-step parameter and covar_cf_logdiag and covar_cf_offtril as M-step parameters.
@eparameters("value")
class PredictiveNode(Node, ABC):
    value: nn.Parameter

    def __init__(self, ...) -> None:
        ...

class VariationalNode(PredictiveNode, ABC):
    def __init__(self, ...) -> None:
        ...

class AbstractGaussianNode(VariationalNode, ABC):
    def __init__(self, ...) -> None:
        ...

@mparameters("covar_cf_logdiag", "covar_cf_offtril")
class MultivariateGaussianNode(AbstractGaussianNode):
    covar_cf_logdiag: nn.Parameter
    covar_cf_offtril: nn.Parameter

    def __init__(self, ...) -> None:
        ...
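To make the registration behaviour concrete, here is a rough sketch of how class decorators of this kind could be written in plain Python. It is not Pyromancy's implementation: the names eparameters_sketch, mparameters_sketch, and _register, as well as the placeholder classes, are hypothetical. Only the attribute names _e_params_ and _m_params_ and the idea of walking __mro__ come from the description above.

```python
def _register(attr, *names):
    """Build a class decorator that stores parameter names under `attr`."""

    def decorator(cls):
        merged = []
        # Walk the MRO so that names registered on superclasses are kept.
        for klass in cls.__mro__:
            for name in klass.__dict__.get(attr, ()):
                if name not in merged:
                    merged.append(name)
        # Then append the names registered directly on this class.
        for name in names:
            if name not in merged:
                merged.append(name)
        setattr(cls, attr, tuple(merged))
        return cls

    return decorator


def eparameters_sketch(*names):
    """Hypothetical analogue of an E-step registration decorator."""
    return _register("_e_params_", *names)


def mparameters_sketch(*names):
    """Hypothetical analogue of an M-step registration decorator."""
    return _register("_m_params_", *names)


@eparameters_sketch("value")
class Base:
    pass


class Middle(Base):  # registers nothing itself
    pass


@mparameters_sketch("covar_cf_logdiag", "covar_cf_offtril")
class Leaf(Middle):
    pass


@eparameters_sketch("extra")
class Extended(Leaf):  # adds a second E-step name on top of the inherited one
    pass


print(Leaf._e_params_)      # ('value',)
print(Leaf._m_params_)      # ('covar_cf_logdiag', 'covar_cf_offtril')
print(Extended._e_params_)  # ('value', 'extra')
```

Because the merged tuple is stored on the decorated class, undecorated subclasses such as Middle still see the inherited names through ordinary attribute lookup.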
Then, Pyromancy provides the functions get_named_estep_params() and get_estep_params() to retrieve E-step parameters, and the functions get_named_mstep_params() and get_mstep_params() to retrieve M-step parameters. These are modelled after the named_parameters() and parameters() methods provided by PyTorch’s Module class. These functions also traverse the entire method resolution order of each class, so E-step and M-step parameters are detected even if they are not registered directly on that class.
By default, if neither a class nor any of its parent classes specifies E-step or M-step parameters, any parameters found are assumed to be M-step parameters. For example, for the following module, the named E-step parameter is 1.value and the named M-step parameters are 0.weight, 0.bias, and 1.logvar.
module = nn.ModuleList([nn.Linear(784, 256), FactorizedGaussianNode(256)])
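The naming and the default rule can be illustrated with a small stand-alone sketch. It does not call Pyromancy: ToyNode is a hypothetical stand-in for a node class, and split_named_params is a hypothetical helper that only approximates the behaviour described above by treating every parameter not listed in _e_params_ as an M-step parameter.

```python
import torch
from torch import nn


class ToyNode(nn.Module):
    """Hypothetical stand-in for a node class; not part of Pyromancy."""

    _e_params_ = ("value",)  # attribute name taken from the text above

    def __init__(self, size: int) -> None:
        super().__init__()
        self.value = nn.Parameter(torch.zeros(1, size))
        self.logvar = nn.Parameter(torch.zeros(size))


def split_named_params(module: nn.Module):
    """Split parameters into (E-step, M-step) dicts keyed by qualified name."""
    e_params, m_params = {}, {}
    for prefix, submodule in module.named_modules():
        # getattr follows the class hierarchy, so inherited names are found.
        e_names = set(getattr(type(submodule), "_e_params_", ()))
        for name, param in submodule.named_parameters(recurse=False):
            full_name = f"{prefix}.{name}" if prefix else name
            # Simplification: anything not marked as E-step counts as M-step.
            target = e_params if name in e_names else m_params
            target[full_name] = param
    return e_params, m_params


module = nn.ModuleList([nn.Linear(4, 3), ToyNode(3)])
e_params, m_params = split_named_params(module)
print(sorted(e_params))  # ['1.value']
print(sorted(m_params))  # ['0.bias', '0.weight', '1.logvar']
```

The nn.Linear layer registers nothing, so its weight and bias fall into the M-step group, which mirrors the default described above. The two groups could then be handed to separate PyTorch optimizers, one per EM step.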