Source code for qfeval_functions.functions.cumcount

import torch

from .stable_sort import stable_sort


def cumcount(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    r"""Number each occurrence of unique values along a dimension.

    This function assigns a cumulative count to each occurrence of unique
    values along the specified dimension.  For each unique value, the first
    occurrence is numbered 0, the second occurrence is numbered 1, and so on.
    This is similar to the behavior of ``pandas.GroupBy.cumcount()``.

    Args:
        x (Tensor):
            The input tensor containing values to be counted.
        dim (int, optional):
            The dimension along which to perform cumulative counting.
            Default is -1 (the last dimension).

    Returns:
        Tensor:
            A tensor of the same shape as the input, where each element
            contains the cumulative count (0-indexed) of that value's
            occurrence along the specified dimension.  If the input has no
            elements, an empty int64 tensor of the same shape is returned.

    Example:
        >>> x = torch.tensor([1, 2, 1, 3, 2, 1, 3])
        >>> QF.cumcount(x)
        tensor([0, 0, 1, 0, 1, 2, 1])

        >>> x = torch.tensor([[1, 2, 1, 2],
        ...                   [3, 3, 4, 3]])
        >>> QF.cumcount(x, dim=1)
        tensor([[0, 0, 1, 1],
                [0, 1, 0, 2]])

        >>> x = torch.tensor([[1, 2, 3],
        ...                   [1, 2, 3],
        ...                   [1, 2, 3]])
        >>> QF.cumcount(x, dim=0)
        tensor([[0, 0, 0],
                [1, 1, 1],
                [2, 2, 2]])
    """
    # 1. Flatten the input tensor: bring `dim` to the front, then collapse
    #    every other dimension into a single trailing one.
    dim = len(x.shape) + dim if dim < 0 else dim
    x = x.transpose(0, dim)
    shape = x.shape
    if x.numel() == 0:
        # Nothing to count.  Short-circuit here because `reshape(0, -1)`
        # below cannot infer the trailing size when the leading size is 0
        # and would raise instead of producing an empty result.
        return torch.zeros(
            shape, dtype=torch.int64, device=x.device
        ).transpose(0, dim)
    x = x.reshape(x.shape[0], -1)

    # 2. Computes the index of a group for each sorted element.
    #    After a stable sort, equal values form contiguous runs in which
    #    the original order is preserved, so an element's offset from the
    #    start of its run is its occurrence number.
    v, idx = stable_sort(x, dim=0)
    a = torch.arange(x.shape[0], device=x.device)[:, None]
    # `b` holds the row position wherever a new run starts and 0 elsewhere.
    # `roll` makes row 0 compare against the last row, which is harmless:
    # both branches of the `where` yield 0 for row 0.
    b = torch.where(
        torch.eq(v, v.roll(1, 0)),
        torch.zeros_like(a),
        a,
    )
    # The running maximum of `b` is the start position of the current run.
    g_idx = a - b.cummax(0).values

    # 3. Distribute the indexes to the original locations.
    #    `idx` serves both as the scatter index and as a correctly shaped
    #    and typed base tensor; every slot of the result is overwritten.
    g_idx = idx.scatter(0, idx, g_idx)

    # 4. Restore the shape.
    return g_idx.reshape(shape).transpose(0, dim)  # type: ignore[no-any-return]