Source code for qfeval_functions.functions.nanmulsum

import math
import typing

import torch

from .fillna import fillna
from .mulsum import mulsum


[docs] def nanmulsum( x: torch.Tensor, y: torch.Tensor, dim: typing.Union[int, typing.Tuple[int, ...]] = (), keepdim: bool = False, ) -> torch.Tensor: r"""Compute the sum of element-wise product, ignoring NaN values, in a memory-efficient way. This function calculates the sum of the element-wise product of two tensors ``nansum(x * y, dim)`` without creating the intermediate product tensor in memory, while also excluding NaN values from the computation. If either :attr:`x` or :attr:`y` has a NaN at a given position, that pair is excluded from the sum calculation. The function is mathematically equivalent to ``nansum(x * y, dim)`` but uses a more memory-efficient implementation that avoids materializing the full product tensor, making it suitable for large tensor operations and complex broadcasting patterns. The NaN-aware sum is computed as: .. math:: \text{nanmulsum}(X, Y) = \sum_{i \text{ valid}} X_i \cdot Y_i where the sum is over valid (non-NaN) pairs only. Args: x (Tensor): The first input tensor. y (Tensor): The second input tensor. Must be broadcastable with :attr:`x`. dim (int or tuple of ints, optional): The dimension(s) along which to compute the sum. If not specified (default is empty tuple), the sum is computed over all dimensions. keepdim (bool, optional): Whether the output tensor has :attr:`dim` retained or not. Default is False. Returns: Tensor: The sum of the element-wise product computed only over valid (non-NaN) pairs. If no valid pairs exist along a dimension, the result is NaN. The shape depends on the input dimensions, :attr:`dim`, and :attr:`keepdim` parameters. Example: >>> # Simple element-wise product sum with NaN >>> x = torch.tensor([1.0, 2.0, nan, 4.0]) >>> y = torch.tensor([2.0, nan, 4.0, 5.0]) >>> QF.nanmulsum(x, y) tensor(22.) >>> # 2D tensors with NaN values >>> x = torch.tensor([[1.0, nan, 3.0], ... [4.0, 5.0, nan]]) >>> y = torch.tensor([[2.0, 4.0, nan], ... [nan, 10.0, 12.0]]) >>> QF.nanmulsum(x, y, dim=1) tensor([ 2., 50.]) >>> # Broadcasting with NaN handling >>> x = torch.tensor([[1.0], [nan], [3.0]]) >>> y = torch.tensor([2.0, 4.0, nan]) >>> QF.nanmulsum(x, y) tensor(24.) >>> # With keepdim >>> x = torch.tensor([[1.0, nan], [3.0, 4.0]]) >>> y = torch.tensor([[2.0, 4.0], [nan, 5.0]]) >>> QF.nanmulsum(x, y, dim=1, keepdim=True) tensor([[ 2.], [20.]]) >>> # All NaN pairs >>> x = torch.tensor([nan, nan]) >>> y = torch.tensor([1.0, 2.0]) >>> QF.nanmulsum(x, y) tensor(nan) >>> # Mixed valid and invalid pairs >>> x = torch.tensor([1.0, nan, 3.0, 4.0]) >>> y = torch.tensor([2.0, 5.0, nan, 6.0]) >>> QF.nanmulsum(x, y) tensor(26.) .. warning:: If all pairs along a dimension contain at least one NaN value, the result for that dimension is NaN. This differs from standard summation where NaN values would propagate through the entire calculation. .. seealso:: :func:`mulsum`: Memory-efficient product sum without NaN handling. :func:`nansum`: NaN-aware sum function. :func:`nanmulmean`: NaN-aware memory-efficient product mean. """ result = mulsum(fillna(x), fillna(y), dim=dim, keepdim=keepdim) x_mask = (~x.isnan()).to(result) y_mask = (~y.isnan()).to(result) count = mulsum(x_mask, y_mask, dim=dim, keepdim=keepdim) return torch.where(count > 0, result, torch.as_tensor(math.nan).to(result))