Source code for qfeval_functions.functions.nanvar

import math
import typing

import torch

from .nanmean import nanmean
from .nansum import nansum


[docs] def nanvar( x: torch.Tensor, dim: typing.Union[int, typing.Tuple[int, ...]] = (), unbiased: bool = True, keepdim: bool = False, ) -> torch.Tensor: r"""Compute the variance of a tensor, ignoring NaN values. This function calculates the variance of tensor elements along specified dimensions while excluding NaN values. The variance can be computed using either the unbiased estimator (dividing by N-1) or the biased estimator (dividing by N), where N is the number of non-NaN elements. Args: x (Tensor): The input tensor. dim (int or tuple of ints, optional): The dimension(s) along which to compute the variance. If not specified (default is an empty tuple), the variance is computed over all elements. Can be a single dimension or multiple dimensions. unbiased (bool, optional): If ``True`` (default), uses Bessel's correction and divides by ``N-1`` where ``N`` is the number of non-NaN elements. If ``False``, divides by ``N``. keepdim (bool, optional): If ``True``, the output tensor has the same number of dimensions as the input, with the reduced dimensions having size 1. If ``False`` (default), the reduced dimensions are removed. Returns: Tensor: The variance of non-NaN elements. The shape depends on the :attr:`dim` and :attr:`keepdim` parameters. Example: >>> x = torch.tensor([1.0, 2.0, nan, 4.0, 5.0]) >>> QF.nanvar(x) tensor(3.3333) >>> # With biased estimator >>> QF.nanvar(x, unbiased=False) tensor(2.5000) >>> # 2D tensor with dimension specification >>> x = torch.tensor([[1.0, 2.0, nan], ... [4.0, nan, 6.0]]) >>> QF.nanvar(x, dim=1) tensor([0.5000, 2.0000]) >>> # Keep dimensions >>> QF.nanvar(x, dim=1, keepdim=True) tensor([[0.5000], [2.0000]]) """ n = (~x.isnan()).to(x).sum(dim=dim, keepdim=True) if unbiased: n = n - 1 # Prevent gradients from being divided by zeros. x = torch.where(n <= 0, torch.as_tensor(math.nan).to(x), x) x2 = (x - nanmean(x, dim=dim, keepdim=True)) ** 2 r = nansum(x2, dim=dim, keepdim=True) / n return r.sum(dim=dim, keepdim=keepdim)