Source code for qfeval_functions.functions.mulsum

import string
import typing

import numpy as np
import torch

from .einsum import einsum


[docs] def mulsum( x: torch.Tensor, y: torch.Tensor, dim: typing.Union[int, typing.Tuple[int, ...]] = (), keepdim: bool = False, mean: bool = False, *, _ddof: int = 0, ) -> torch.Tensor: r"""Compute sum or mean of element-wise product in a memory-efficient way. This function calculates the sum (or mean) of the element-wise product of two tensors ``(x * y).sum(dim)`` or ``(x * y).mean(dim)`` without creating the intermediate product tensor in memory. This is particularly crucial when working with large tensors or when broadcasting would result in a very large intermediate tensor that could exceed available memory. Args: x (Tensor): The first input tensor. y (Tensor): The second input tensor. Must be broadcastable with :attr:`x`. dim (int or tuple of ints, optional): The dimension(s) along which to compute the sum or mean. If not specified (default is empty tuple), the operation is computed over all dimensions. keepdim (bool, optional): Whether the output tensor has :attr:`dim` retained or not. Default is False. mean (bool, optional): If True, computes the mean instead of sum. Default is False (computes sum). _ddof (int, optional): Delta degrees of freedom for mean calculations. The divisor used is ``N - _ddof``, where ``N`` is the number of elements. Default is 0. This is an internal parameter. Returns: Tensor: The sum or mean of the element-wise product. The shape depends on the input dimensions, :attr:`dim`, and :attr:`keepdim` parameters. Example: >>> # Simple element-wise product sum >>> x = torch.tensor([1.0, 2.0, 3.0]) >>> y = torch.tensor([2.0, 3.0, 4.0]) >>> QF.mulsum(x, y) tensor(20.) >>> # Equivalent to (x * y).sum() >>> torch.allclose(QF.mulsum(x, y), (x * y).sum()) True >>> # Mean instead of sum >>> QF.mulsum(x, y, mean=True) tensor(6.6667) >>> # 2D tensors with specific dimension >>> x = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) >>> y = torch.tensor([[2.0, 1.0], [1.0, 2.0]]) >>> QF.mulsum(x, y, dim=1) tensor([ 4., 11.]) >>> # Memory-efficient broadcasting >>> x = torch.randn(1000, 1) >>> y = torch.randn(1, 1000) >>> # Efficiently computes without creating 1000x1000 intermediate tensor >>> result = QF.mulsum(x, y) >>> # With keepdim >>> x = torch.tensor([[[1.0, 2.0]], [[3.0, 4.0]]]) >>> y = torch.tensor([[[2.0, 1.0]], [[1.0, 2.0]]]) >>> QF.mulsum(x, y, dim=(1, 2), keepdim=True) tensor([[[ 4.]], <BLANKLINE> [[11.]]]) .. seealso:: :func:`mulmean`: Convenience function for computing means. :func:`einsum`: The underlying Einstein summation function. :func:`covar`: Uses this function for efficient covariance calculations. """ # 1. Align the number of dimensions (c.f., NumPy broadcasting). length = max(len(x.shape), len(y.shape)) x = x.reshape((1,) * (length - len(x.shape)) + x.shape) y = y.reshape((1,) * (length - len(y.shape)) + y.shape) # 2. Figure out the final shape. # NOTE: x[None][:0] is a trick to call a torch function without wasting # cpu/memory resources. mul_shape = (x[None][:0] * y[None][:0]).shape[1:] # 3. Parse dim. # `mask` should be a tuple of ints, each of which should represent whether # the dimension is aggregated (1) or not (0). mask = torch.sum(torch.zeros((0,) * length), dim=dim, keepdim=True).shape # 4. Prepare einsum parameters and apply them to einsum. input_eq = string.ascii_lowercase[: len(x.shape)] result_eq = "".join(c for c, m in zip(input_eq, mask) if m == 0) result = einsum(f"{input_eq},{input_eq}->{result_eq}", x, y) # 5. (If keepdim is enabled,) Restore aggregated dimensions. if keepdim: result = result.reshape( tuple(1 if m == 1 else s for s, m in zip(mul_shape, mask)) ) # 6. (If mean is enabled,) Divide the result by the number of aggregated # elements. if mean: result = result / ( int(np.prod([s if m == 1 else 1 for s, m in zip(mul_shape, mask)])) - _ddof ) return result