Source code for pyromancy.infra

import math
from typing import Any, Iterator, overload

import einops as ein
import torch


[docs] class Shape: r"""Tensor shape with support for placeholder dimensions. Args: *shape (int | None): dimensions of the tensor, either positive integers for fixed dimensions or none for unspecified dimensions. Important: Scalar tensors (i.e. tensors with no dimensions) are unsupported, as are tensors with any dimension of size 0. """ _rawshape: tuple[int | None, ...] _concrete_dims: tuple[int, ...] _virtual_dims: tuple[int, ...] _parseshp_str: str _coalesce_str: str _disperse_str: str def __init__(self, *shape: int | None) -> None: if not len(shape) > 0: raise ValueError("`shape` must contain at least one element") if not all(isinstance(s, int | None) for s in shape): raise TypeError("all elements of `shape` must be of type `int` or `None`") if not all(s > 0 for s in shape if s is not None): raise ValueError("all integer elements of `shape` must be positive") self._rawshape = tuple(int(s) if s is not None else None for s in shape) self._concrete_dims = tuple( d for d, s in enumerate(self._rawshape) if s is not None ) self._virtual_dims = tuple(d for d, s in enumerate(self._rawshape) if s is None) dims = tuple(f"d{d}" for d in range(len(self._rawshape))) cdims = tuple(f"d{d}" for d in self._concrete_dims) vdims = tuple(f"d{d}" for d in self._virtual_dims) self._parseshp_str = " ".join(dims) self._coalesce_str = ( f"{' '.join(dims)} -> ({' '.join(vdims)}) ({' '.join(cdims)})" ) self._disperse_str = ( f"({' '.join(vdims)}) ({' '.join(cdims)}) -> {' '.join(dims)}" ) def __repr__(self) -> str: return f"{type(self).__name__}({', '.join(str(d) for d in self._rawshape)})" def __eq__(self, other: Any) -> bool: if isinstance(other, type(self)): return self._rawshape == other._rawshape elif isinstance(other, tuple): return self._rawshape == other else: return False @overload def __getitem__(self, index: int) -> int | None: ... @overload def __getitem__(self, index: slice) -> tuple[int | None, ...]: ... def __getitem__(self, index: int | slice) -> int | None | tuple[int | None, ...]: return self._rawshape[index] def __len__(self) -> int: return len(self._rawshape) def __iter__(self) -> Iterator[int | None]: return iter(self._rawshape) @property def rshape(self) -> tuple[int | None, ...]: r"""Tensor shape, including placeholder dimensions. Returns: tuple[int | None, ...]: raw tensor shape. """ return self._rawshape @property def bshape(self) -> tuple[int, ...]: r"""Tensor shape, with placeholder dimensions set to unit length. Returns: tuple[int | None, ...]: broadcastable tensor shape. """ return tuple(1 if s is None else s for s in self._rawshape) @property def size(self) -> int: r"""Number of elements specified by the shape. Returns: int: minimal number of tensor elements. """ return math.prod(self.bshape) @property def ndim(self) -> int: r"""Number of dimensions specified by the shape. Returns: int: dimensionality of a compatible tensor. """ return len(self._rawshape) @property def nconcrete(self) -> int: r"""Number of fixed dimensions. Returns: int: number of concrete dimensions. """ return len(self._concrete_dims) @property def nvirtual(self) -> int: r"""Number of placeholder dimensions. Returns: int: number of virtual dimensions. """ return len(self._virtual_dims)
[docs] def compat(self, *shape: int) -> bool: r"""Tests if a shape is compatible with the specified constraints. Args: *shape (int | None): dimensions of the tensor. Returns: bool: if the shape is compatible. """ if not all(isinstance(d, int) for d in shape): raise TypeError("all elements of `shape` must be of type `int`") if not all(d > 0 for d in shape): raise ValueError("all elements of `shape` must be positive") if len(shape) != len(self._rawshape): return False for dx, di in zip(shape, self._rawshape): if di is not None and dx != di: return False return True
[docs] def filled(self, *fill: int) -> tuple[int, ...]: r"""Fills placeholder dimensions with specified values. Returns: tuple[int, ...]: shape with the placeholder dimensions filled. """ if not len(fill) == self.nvirtual: raise ValueError( "`fill` must contain exactly the required number of placeholder elements" ) if not all(isinstance(d, int) for d in fill): raise TypeError("all elements of `fill` must be of type `int`") if not all(d > 0 for d in fill): raise ValueError("all elements of `fill` must be positive") shape = [*self._rawshape] for n, d in enumerate(self._virtual_dims): shape[d] = fill[n] return tuple(shape) # type: ignore
[docs] def coalesce(self, tensor: torch.Tensor) -> tuple[torch.Tensor, dict[str, int]]: r"""Coalesces a tensor into a matrix, with placeholder dimensions first and fixed dimensions second. For a tensor with :math:`V_1, \ldots, V_m` placeholder dimensions and :math:`C_1, \ldots, C_n` fixed dimensions, the output matrix will have a shape of :math:`(V_1 \times \cdots \times V_m) \times (C_1 \times \cdots \times C_n)`, and dimensions of unit length will used if the tensor has no placeholder/fixed dimensions. Args: tensor (~torch.Tensor): tensor to coalesce. Returns: tuple[~torch.Tensor, dict[str, int]]: tuple of the coalesced tensor and the required shape information to revert it. """ pragma = ein.parse_shape(tensor, self._parseshp_str) return ein.rearrange(tensor, self._coalesce_str), pragma
[docs] def disperse(self, tensor: torch.Tensor, pragma: dict[str, int]) -> torch.Tensor: r"""Disperses dimensions of a coalesced tensor to their original positions. Args: tensor (~torch.Tensor): tensor to disperse. pragma (dict[str, int]): shape information to revert the tensor. Returns: ~torch.Tensor: dispersed tensor. """ return ein.rearrange(tensor, self._disperse_str, **pragma)