PredictiveNode
- class PredictiveNode(*shape: int | None)
Base class for predictive coding nodes that generate predictions.
- Parameters:
*shape (int | None) – base shape of the node’s state.
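The class is abstract, so concrete nodes must supply `energy()` and `error()`. Below is a minimal sketch of that contract. It assumes a standard-library `abc` base and substitutes plain Python lists for `Tensor` so it runs without PyTorch. `PredictiveNodeSketch` is a hypothetical name, not the library class.

```python
from __future__ import annotations

from abc import ABC, abstractmethod


class PredictiveNodeSketch(ABC):
    """Hypothetical mirror of the PredictiveNode contract (not the library class)."""

    def __init__(self, *shape: int | None) -> None:
        # Base shape of the node's state, stored exactly as given.
        self.shape = shape

    @abstractmethod
    def energy(self, pred: list[float]) -> float:
        # Variational free energy for a prediction of the node state.
        raise NotImplementedError("must be implemented by subclasses")

    @abstractmethod
    def error(self, pred: list[float]) -> list[float]:
        # Elementwise error for a prediction of the node state.
        raise NotImplementedError("must be implemented by subclasses")


try:
    PredictiveNodeSketch(3, None)
except TypeError:
    print("abstract: energy() and error() must be overridden")
```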
- abstractmethod energy(pred: Tensor) → Tensor
Computes the variational free energy for a prediction of the node state.
- Parameters:
pred (Tensor) – prediction of the node state.
- Raises:
NotImplementedError – must be implemented by subclasses.
- Returns:
the variational free energy between the state and the prediction.
- Return type:
Tensor
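The free energy a subclass returns depends on its generative model. One hedged illustration assumes a Gaussian node with unit variance. Under that model the constant terms drop out and the energy reduces to half the sum of squared errors. `UnitGaussianNode` is hypothetical, and it uses plain lists in place of `Tensor`.

```python
from __future__ import annotations


class UnitGaussianNode:
    """Hypothetical concrete node; its state is a flat list of floats."""

    def __init__(self, value: list[float]) -> None:
        self.value = value

    def energy(self, pred: list[float]) -> float:
        if len(pred) != len(self.value):
            raise ValueError("prediction must match the node state's shape")
        # Unit-variance Gaussian free energy with constants dropped:
        # F = 0.5 * sum_i (x_i - pred_i) ** 2
        return 0.5 * sum((x - p) ** 2 for x, p in zip(self.value, pred))


node = UnitGaussianNode([1.0, 2.0, 3.0])
print(node.energy([1.0, 1.0, 1.0]))  # 0.5 * (0 + 1 + 4) = 2.5
```

In this sketch, `energy()` collapses the mismatch to a single scalar, whereas `error()` keeps one entry per element.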
- abstractmethod error(pred: Tensor) → Tensor
Computes the elementwise error for a prediction of the node state.
- Parameters:
pred (Tensor) – prediction of the node state.
- Raises:
NotImplementedError – must be implemented by subclasses.
- Returns:
the elementwise error between the state and the prediction.
- Return type:
Tensor
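For the elementwise error, one common convention is state minus prediction. It is assumed here and not confirmed by this page. Under that convention the result has the same shape as the state. `elementwise_error` is a hypothetical helper that operates on plain lists rather than `Tensor`.

```python
from __future__ import annotations


def elementwise_error(value: list[float], pred: list[float]) -> list[float]:
    """Hypothetical elementwise error: node state minus prediction."""
    if len(value) != len(pred):
        raise ValueError("prediction must match the node state's shape")
    # One entry per element, so the output has the same shape as the state.
    return [v - p for v, p in zip(value, pred)]


state = [0.5, -1.0, 2.0]
print(elementwise_error(state, [0.0, -1.0, 3.0]))  # [0.5, 0.0, -1.0]
```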
- forward(inputs: Tensor, **kwargs) → Tensor
Computes a forward pass on the node.
When self.training is True, the prediction is assigned to the node's value, and the value is then returned. When self.training is False, the prediction is returned directly (i.e. the node acts as the identity operation).
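The training/inference switch can be sketched as follows. `ForwardSketch` is a hypothetical stand-in: `training` mirrors `torch.nn.Module.training`, `value` is assumed to be a plain attribute holding the node state, and lists replace `Tensor`.

```python
from __future__ import annotations


class ForwardSketch:
    """Hypothetical node mirroring the documented forward() behavior."""

    def __init__(self) -> None:
        self.training = True  # mirrors torch.nn.Module.training
        self.value: list[float] | None = None  # assumed state attribute

    def forward(self, inputs: list[float], **kwargs) -> list[float]:
        # kwargs are accepted for signature parity but ignored in this sketch.
        if self.training:
            # Training: assign the prediction to the value, then return the value.
            self.value = list(inputs)
            return self.value
        # Inference: identity operation; this sketch leaves value untouched.
        return inputs


node = ForwardSketch()
print(node.forward([1.0, 2.0]))  # [1.0, 2.0], and node.value now holds it
node.training = False
print(node.forward([3.0]))  # [3.0]; node.value is still [1.0, 2.0]
```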