FloatNode
- class FloatNode(*shape: int | None)
Bases: Node
Input node with a trainable value.
- Parameters:
*shape (int | None) – base shape of the node’s state.
Hint
This is primarily useful when performing query by initialization from an input, where the value is updated on E-steps.
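The hint can be made concrete with a small sketch of query by initialization. It assumes that an E-step is a plain gradient-descent update of the node's value on the squared error between the value and a fixed prediction. The variable names, the zero prediction, the learning rate, and the use of `torch.optim.SGD` are illustrative choices, not taken from this library.

```python
import torch

# Hypothetical setup: an observed input x initializes the node's value z,
# and mu is a fixed prediction for z produced elsewhere in the network.
x = torch.tensor([1.0, -2.0, 0.5])
mu = torch.zeros(3)

# Query by initialization: the trainable value starts as a copy of the input.
z = torch.nn.Parameter(x.clone())

# Assumption: E-steps minimize E(z) = 0.5 * ||z - mu||^2 by gradient descent,
# updating only the node value (no weights are optimized here).
optimizer = torch.optim.SGD([z], lr=0.1)
for _ in range(10):
    optimizer.zero_grad()
    energy = 0.5 * (z - mu).pow(2).sum()
    energy.backward()  # dE/dz = z - mu
    optimizer.step()   # z <- z - lr * (z - mu)

# Each step scales the residual z - mu by (1 - lr), so z drifts toward mu.
residual = (z - mu).detach()
```

Because the gradient of this energy is just the error, each E-step shrinks the residual by a constant factor; a real network would also include error terms from the layers this node predicts.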
- error(pred: Tensor) → Tensor
Error between the node state \(\mathbf{z}\) and the prediction \(\boldsymbol{\mu}\) (pred).
\[\boldsymbol{\varepsilon} = \mathbf{z} - \boldsymbol{\mu}\]
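As a worked example with made-up values, the sign convention makes each component of the error positive where the state exceeds the prediction and negative where it falls short:

\[\mathbf{z} = (1,\ 2,\ 3),\quad \boldsymbol{\mu} = (0.5,\ 2,\ 4) \;\Rightarrow\; \boldsymbol{\varepsilon} = \mathbf{z} - \boldsymbol{\mu} = (0.5,\ 0,\ -1)\]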
- forward(inputs: Tensor, **kwargs) → Tensor
Computes a forward pass on the node.
When self.training is True, the prediction (inputs) is assigned to the value and the value is returned. When self.training is False, the prediction is returned directly (i.e., this acts as the identity operation).
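The train/eval behavior of forward can be illustrated with a hypothetical stand-in class. The sketch assumes the value is an `nn.Parameter` with a fixed integer shape (the None entries that FloatNode accepts are not modeled) and that the assignment is done without autograd tracking. `SketchFloatNode` is an illustrative name, not part of the library.

```python
import torch
from torch import nn


class SketchFloatNode(nn.Module):
    """Hypothetical stand-in for FloatNode; not the library's implementation."""

    def __init__(self, *shape: int) -> None:
        super().__init__()
        # Assumption: the trainable value is a Parameter of the given shape.
        self.value = nn.Parameter(torch.zeros(*shape))

    def error(self, pred: torch.Tensor) -> torch.Tensor:
        # epsilon = z - mu: node state minus prediction.
        return self.value - pred

    def forward(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor:
        # Extra keyword arguments are accepted for signature parity and ignored.
        if self.training:
            # Overwrite the stored value in place, outside of autograd.
            with torch.no_grad():
                self.value.copy_(inputs)
            return self.value
        # Evaluation mode: the inputs pass through unchanged.
        return inputs


node = SketchFloatNode(3)
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.full((3,), 9.0)

node.train()
out_train = node(x)  # value now holds x; the Parameter itself is returned

node.eval()
out_eval = node(y)   # y is returned as-is; the stored value is untouched
```

Returning the Parameter in training mode means downstream errors can send gradients back to the value, which is what lets later E-steps refine it.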