Source code for qfeval_functions.functions.nanmulmean

import typing

import torch

from .fillna import fillna
from .mulsum import mulsum


def nanmulmean(
    x: torch.Tensor,
    y: torch.Tensor,
    dim: typing.Union[int, typing.Tuple[int, ...]] = (),
    keepdim: bool = False,
    *,
    _ddof: int = 0,
) -> torch.Tensor:
    r"""Return the NaN-ignoring mean of the element-wise product ``x * y``.

    The result is intended to match ``nanmean(x * y, dim)``: every position
    where :attr:`x` or :attr:`y` is NaN is dropped from both the numerator
    and the divisor.  Both reductions are delegated to :func:`mulsum`, so
    the broadcast product tensor is never built explicitly by this function.

    .. math::
        \text{nanmulmean}(X, Y) =
            \frac{\sum_{i \text{ valid}} X_i \cdot Y_i}{
                  N_{\text{valid}} - \text{ddof}}

    Here "valid" means that neither operand is NaN at that position, and
    :math:`N_{\text{valid}}` is the number of such pairs.

    Args:
        x (Tensor): First operand.
        y (Tensor): Second operand; must be broadcastable with :attr:`x`.
        dim (int or tuple of ints, optional): Dimension(s) to reduce.  The
            default, an empty tuple, reduces over all dimensions.
        keepdim (bool, optional): Whether the reduced dimensions are kept
            in the output.  Defaults to ``False``.
        _ddof (int, optional): Internal parameter.  Delta degrees of
            freedom subtracted from the valid-pair count to form the
            divisor ``N_valid - _ddof``.  Defaults to ``0``.

    Returns:
        Tensor: Mean of the products over the valid (non-NaN) pairs.  Its
        shape follows from the broadcast inputs, :attr:`dim` and
        :attr:`keepdim`.

    Example:
        >>> # Simple element-wise product mean with NaN
        >>> x = torch.tensor([1.0, 2.0, nan, 4.0])
        >>> y = torch.tensor([2.0, nan, 4.0, 5.0])
        >>> QF.nanmulmean(x, y)
        tensor(11.)

        >>> # 2D tensors with NaN values
        >>> x = torch.tensor([[1.0, nan, 3.0],
        ...                   [4.0, 5.0, nan]])
        >>> y = torch.tensor([[2.0, 4.0, nan],
        ...                   [nan, 10.0, 12.0]])
        >>> QF.nanmulmean(x, y, dim=1)
        tensor([ 2., 50.])

        >>> # Broadcasting with NaN handling
        >>> x = torch.tensor([[1.0], [nan], [3.0]])
        >>> y = torch.tensor([2.0, 4.0, nan])
        >>> QF.nanmulmean(x, y)
        tensor(6.)

        >>> # With keepdim
        >>> x = torch.tensor([[1.0, nan], [3.0, 4.0]])
        >>> y = torch.tensor([[2.0, 4.0], [nan, 5.0]])
        >>> QF.nanmulmean(x, y, dim=1, keepdim=True)
        tensor([[ 2.],
                [20.]])

    .. seealso::
        :func:`mulmean`: Memory-efficient product mean without NaN handling.
        :func:`nanmean`: NaN-aware mean function.
        :func:`mulsum`: Memory-efficient product sum function.
    """
    # Numerator: product-sum with NaNs replaced by fillna's default value.
    # NOTE(review): presumably fillna substitutes 0 for NaN; if so, a NaN
    # paired with +/-inf yields 0 * inf = NaN instead of being skipped.
    # Confirm against fillna / mulsum before relying on inputs with inf.
    total = mulsum(fillna(x), fillna(y), dim=dim, keepdim=keepdim)

    # Validity indicators (1 where the operand is not NaN, 0 otherwise),
    # cast to the numerator's dtype and device.
    x_valid = x.isnan().logical_not().to(total)
    y_valid = y.isnan().logical_not().to(total)

    # The product-sum of the indicators counts pairs where both are valid.
    pair_count = mulsum(x_valid, y_valid, dim=dim, keepdim=keepdim)

    return total / (pair_count - _ddof)