Source code for qfeval_functions.functions.bincount

import torch

from .apply_for_axis import apply_for_axis


[docs] def bincount( x: torch.Tensor, minlength: int = 0, dim: int = -1 ) -> torch.Tensor: r"""Count number of occurrences of each value in a tensor. This function computes the frequency of each non-negative integer value in the input tensor along the specified dimension. The output tensor's size along the counted dimension is the maximum of :attr:`minlength` and the largest value in the input plus one. Each position in the output contains the count of that index value in the input. Args: x (Tensor): The input tensor containing non-negative integer values. minlength (int, optional): Minimum length of the output tensor along the counted dimension. If the maximum value in :attr:`x` is less than :attr:`minlength` - 1, the output is padded with zeros. Default is 0. dim (int, optional): The dimension along which to count values. Default is -1 (the last dimension). Returns: Tensor: A tensor where the value at index `i` along the specified dimension contains the count of occurrences of value `i` in the corresponding slice of the input tensor. Example: >>> x = torch.tensor([1, 2, 2, 3, 3, 1]) >>> QF.bincount(x) tensor([0, 2, 2, 2]) >>> x = torch.tensor([[1, 2, 2], [3, 3, 1]]) >>> QF.bincount(x, dim=1) tensor([[0, 1, 2, 0], [0, 1, 0, 2]]) >>> x = torch.tensor([1, 2]) >>> QF.bincount(x, minlength=5) tensor([0, 1, 1, 0, 0]) """ def _bincount(x: torch.Tensor) -> torch.Tensor: if x.numel() == 0: # handle edge case where x is empty # https://github.com/pytorch/pytorch/blob/5fb11cda4fe60c1a7b30e6c844f84ce8933ef953/torch/_numpy/_funcs_impl.py#L630 return torch.zeros((1, minlength), dtype=torch.int64) n = max(minlength, int(torch.amax(x)) + 1) zeros = torch.zeros_like(x[:1, :1]).expand(x.shape[0], n) ones = torch.ones_like(x[:1, :1]).expand(x.shape) return torch.scatter_add(zeros, dim, x, ones) return apply_for_axis(_bincount, x, dim)