def error(self, pred: torch.Tensor) -> torch.Tensor:
    r"""Computes the elementwise difference between the bias and a prediction.

    .. math::
        \boldsymbol{\varepsilon} = \mathbf{b} - \boldsymbol{\mu}

    Args:
        pred (~torch.Tensor): predicted bias :math:`\boldsymbol{\mu}`.

    Returns:
        ~torch.Tensor: elementwise error :math:`\boldsymbol{\varepsilon}`.
    """
    # The node state (the bias) is the minuend; the prediction is subtracted from it.
    state = self.bias
    residual = state - pred
    return residual
def forward(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor:
    r"""Broadcasts the bias tensor to the shape of ``inputs``.

    Args:
        inputs (~torch.Tensor): tensor whose shape the returned bias takes on.

    Returns:
        ~torch.Tensor: expanded bias tensor.

    Tip:
        ``inputs`` should have the desired shape (including the batch dimension)
        so the returned bias can be used as a prediction for
        initialization/inference. Only the shape of ``inputs`` is read; its
        contents, device, and data type are unused.
    """
    # Add a leading dimension of size one, then expand to the shape of ``inputs``.
    with_leading_dim = self.bias.unsqueeze(0)
    expanded = with_leading_dim.expand_as(inputs)
    return expanded
class FixedNode(Node):
    r"""Input node with a fixed value.

    Args:
        *shape (int | None): base shape of the node's state.

    Attributes:
        value (~torch.nn.parameter.Buffer): current value of the node.

    Hint:
        This is primarily useful when performing *query by conditioning* from
        an input, where the value is not updated on E-steps.
    """

    value: nn.Buffer

    def __init__(self, *shape: int | None) -> None:
        Node.__init__(self, *shape)
        # Start from an empty (0-element) placeholder; ``init`` assigns the state.
        self.value = nn.Buffer(torch.empty(0))

    @torch.no_grad()
    def reset(self) -> None:
        r"""Resets the node state.

        This operation is typically executed after each new batch. With
        inference learning, this is done after M-step. With incremental
        inference learning, this is done after the *final* M-step.
        """
        self.zero_grad()
        # Swap the stored state for an empty (0-element) tensor.
        self.value.data = self.value.new_empty(0)

    @torch.no_grad()
    def init(self, value: torch.Tensor) -> nn.Buffer:
        r"""Initializes the node's state to a new value.

        Args:
            value (~torch.Tensor): value to initialize to.

        Returns:
            ~torch.nn.parameter.Buffer: the reinitialized value.

        Raises:
            ValueError: shape of ``value`` is incompatible with the node.
        """
        if not self.shapeobj.compat(*value.shape):
            raise ValueError(
                f"shape of `value` {(*value.shape,)} is incompatible "
                f"with node shape {(*self.shapeobj,)}"
            )

        # Allocate storage from the existing state, then copy the contents in.
        # The shape is passed as a single size argument (not unpacked) so a
        # zero-dimensional ``value`` does not turn into a call with no size.
        self.value.data = self.value.data.new_empty(value.shape)
        self.value.copy_(value)
        return self.value

    def error(self, pred: torch.Tensor) -> torch.Tensor:
        r"""Error between the prediction and node state.

        .. math::
            \boldsymbol{\varepsilon} = \mathbf{z} - \boldsymbol{\mu}

        Args:
            pred (~torch.Tensor): predicted value :math:`\boldsymbol{\mu}`.

        Returns:
            ~torch.Tensor: elementwise error :math:`\boldsymbol{\varepsilon}`.
        """
        return self.value - pred

    def forward(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor:
        r"""Computes a forward pass on the node.

        When ``self.training`` is True, the prediction is assigned to the value
        and then value is returned. When ``self.training`` is False, the
        prediction is directly returned (i.e. this acts as the identity
        operation).

        Args:
            inputs (~torch.Tensor): prediction of the value.

        Returns:
            ~torch.Tensor: value of the node.
        """
        if self.training:
            return self.init(inputs)
        return inputs
@eparameters("value")
class FloatNode(Node):
    r"""Input node with a trainable value.

    Args:
        *shape (int | None): base shape of the node's state.

    Attributes:
        value (~torch.nn.parameter.Parameter): current value of the node.

    Hint:
        This is primarily useful when performing *query by initialization*
        from an input, where the value is updated on E-steps.
    """

    value: nn.Parameter

    def __init__(self, *shape: int | None) -> None:
        Node.__init__(self, *shape)
        # Start from an empty (0-element) placeholder that requires gradients;
        # ``init`` assigns the actual state.
        self.value = nn.Parameter(torch.empty(0), True)

    @torch.no_grad()
    def reset(self) -> None:
        r"""Resets the node state.

        This operation is typically executed after each new batch. With
        inference learning, this is done after M-step. With incremental
        inference learning, this is done after the *final* M-step.
        """
        self.zero_grad()
        # Swap the stored state for an empty (0-element) tensor.
        self.value.data = self.value.new_empty(0)

    @torch.no_grad()
    def init(self, value: torch.Tensor) -> nn.Parameter:
        r"""Initializes the node's state to a new value.

        Args:
            value (~torch.Tensor): value to initialize to.

        Returns:
            ~torch.nn.parameter.Parameter: the reinitialized value.

        Raises:
            ValueError: shape of ``value`` is incompatible with the node.
        """
        if not self.shapeobj.compat(*value.shape):
            raise ValueError(
                f"shape of `value` {(*value.shape,)} is incompatible "
                f"with node shape {(*self.shapeobj,)}"
            )

        # Allocate storage from the existing state, then copy the contents in.
        # The shape is passed as a single size argument (not unpacked) so a
        # zero-dimensional ``value`` does not turn into a call with no size.
        self.value.data = self.value.data.new_empty(value.shape)
        self.value.copy_(value)
        return self.value

    def error(self, pred: torch.Tensor) -> torch.Tensor:
        r"""Error between the prediction and node state.

        .. math::
            \boldsymbol{\varepsilon} = \mathbf{z} - \boldsymbol{\mu}

        Args:
            pred (~torch.Tensor): predicted value :math:`\boldsymbol{\mu}`.

        Returns:
            ~torch.Tensor: elementwise error :math:`\boldsymbol{\varepsilon}`.
        """
        return self.value - pred

    def forward(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor:
        r"""Computes a forward pass on the node.

        When ``self.training`` is True, the prediction is assigned to the value
        and then value is returned. When ``self.training`` is False, the
        prediction is directly returned (i.e. this acts as the identity
        operation).

        Args:
            inputs (~torch.Tensor): prediction of the value.

        Returns:
            ~torch.Tensor: value of the node.
        """
        if self.training:
            return self.init(inputs)
        return inputs