qfeval_functions.functions.nanmulsum

nanmulsum(x, y, dim=(), keepdim=False)[source]

Compute the sum of element-wise product, ignoring NaN values, in a memory-efficient way.

This function calculates the sum of the element-wise product of two tensors nansum(x * y, dim) without creating the intermediate product tensor in memory, while also excluding NaN values from the computation. If either x or y has a NaN at a given position, that pair is excluded from the sum calculation.

The function is mathematically equivalent to nansum(x * y, dim) but uses a more memory-efficient implementation that avoids materializing the full product tensor, making it suitable for large tensor operations and complex broadcasting patterns.

The NaN-aware sum is computed as:

\[\text{nanmulsum}(X, Y) = \sum_{i \text{ valid}} X_i \cdot Y_i\]

where the sum is over valid (non-NaN) pairs only.

Parameters:
  • x (Tensor) – The first input tensor.

  • y (Tensor) – The second input tensor. Must be broadcastable with x.

  • dim (Union[int, Tuple[int, ...]]) – The dimension(s) along which to compute the sum. If not specified (default is empty tuple), the sum is computed over all dimensions.

  • keepdim (bool) – Whether the output tensor has dim retained or not. Default is False.

Returns:

The sum of the element-wise product computed only over valid (non-NaN) pairs. If no valid pairs exist along a dimension, the result is NaN. The shape depends on the input dimensions, dim, and keepdim parameters.

Return type:

Tensor

Example

>>> # Simple element-wise product sum with NaN
>>> x = torch.tensor([1.0, 2.0, nan, 4.0])
>>> y = torch.tensor([2.0, nan, 4.0, 5.0])
>>> QF.nanmulsum(x, y)
tensor(22.)
>>> # 2D tensors with NaN values
>>> x = torch.tensor([[1.0, nan, 3.0],
...                   [4.0, 5.0, nan]])
>>> y = torch.tensor([[2.0, 4.0, nan],
...                   [nan, 10.0, 12.0]])
>>> QF.nanmulsum(x, y, dim=1)
tensor([ 2., 50.])
>>> # Broadcasting with NaN handling
>>> x = torch.tensor([[1.0], [nan], [3.0]])
>>> y = torch.tensor([2.0, 4.0, nan])
>>> QF.nanmulsum(x, y)
tensor(24.)
>>> # With keepdim
>>> x = torch.tensor([[1.0, nan], [3.0, 4.0]])
>>> y = torch.tensor([[2.0, 4.0], [nan, 5.0]])
>>> QF.nanmulsum(x, y, dim=1, keepdim=True)
tensor([[ 2.],
        [20.]])
>>> # All NaN pairs
>>> x = torch.tensor([nan, nan])
>>> y = torch.tensor([1.0, 2.0])
>>> QF.nanmulsum(x, y)
tensor(nan)
>>> # Mixed valid and invalid pairs
>>> x = torch.tensor([1.0, nan, 3.0, 4.0])
>>> y = torch.tensor([2.0, 5.0, nan, 6.0])
>>> QF.nanmulsum(x, y)
tensor(26.)

Warning

If all pairs along a dimension contain at least one NaN value, the result for that dimension is NaN. This differs from standard summation where NaN values would propagate through the entire calculation.

See also

mulsum(): Memory-efficient product sum without NaN handling. nansum(): NaN-aware sum function. nanmulmean(): NaN-aware memory-efficient product mean.