Source code for qfeval_functions.functions.stable_sort

import math
import typing

import torch


class StableSortResult(typing.NamedTuple):
    """Pair of tensors returned by the stable sorting functions.

    As a named tuple it can be unpacked positionally
    (``values, indices = result``) or read by field name.
    """

    # Elements of the input tensor in sorted order.
    values: torch.Tensor
    # Index tensor produced alongside ``values`` by the sorting routine.
    indices: torch.Tensor


def _unsafe_stable_sort(x: torch.Tensor, dim: int) -> StableSortResult:
    r"""Sorts a NaN-free tensor along ``dim``, keeping ties in input order.

    The function is "unsafe" because it relies on ``torch.ne`` to detect
    where the sorted value changes.  NaN never compares equal to itself, so
    every NaN would form a group of its own and the original order among
    NaNs would not be restored.  Callers must remove NaNs beforehand.

    Args:
        x: Tensor to sort.  Must not contain NaN.
        dim: Dimension to sort along.

    Returns:
        ``StableSortResult(values, indices)`` where ``indices`` lists, along
        ``dim``, the source positions in ascending value order with equal
        values kept in their original relative order, and ``values`` is
        ``x`` gathered through ``indices``.
    """
    size = x.shape[dim]

    # 1. Sort with the built-in sort, which may reorder equal elements.
    sorted_values, order = x.sort(dim=dim)

    # 2. Give every run of equal sorted values a non-decreasing group id.
    #    The rolled comparison also compares the first element with the
    #    last one; that only shifts all ids by a constant, which is harmless.
    group = torch.ne(sorted_values, sorted_values.roll(1, dim)).cumsum(dim=dim)

    # 3. Sort the source indices inside each group.  ``group * size + order``
    #    is ordered by group first and by source index second, and the
    #    source index is recovered afterwards with the modulo.
    keys = group * size + order
    indices = keys.sort(dim=dim).values % size

    # 4. Read the values through the repaired indices instead of reusing the
    #    output of step 1.  Elements that compare equal are not necessarily
    #    identical (e.g. 0.0 and -0.0), so the values of an unstable sort may
    #    not line up with the stable indices.
    return StableSortResult(x.gather(dim, indices), indices)


def stable_sort(x: torch.Tensor, dim: int = -1) -> StableSortResult:
    r"""Sorts the given tensor preserving the order of equivalent elements.

    Values are sorted in ascending order along ``dim``.  NaNs are placed at
    the end, after ``+inf``, and keep their original relative order like any
    other group of equivalent elements.

    Args:
        x: Tensor to sort.
        dim: Dimension to sort along.  Defaults to the last dimension.

    Returns:
        ``StableSortResult(values, indices)``: the sorted values and, for
        each of them, its position in ``x`` along ``dim``.
    """
    # If x has no NaN values, it is okay to apply _unsafe_stable_sort.
    if not x.isnan().any():
        return _unsafe_stable_sort(x, dim)

    # If x has NaN values, computes the stable result using two results.
    # The first one treats NaN as +inf, so it is already correct for every
    # element that is neither NaN nor +inf.
    safe_x = x.nan_to_num(math.inf, math.inf, -math.inf)
    result = _unsafe_stable_sort(safe_x, dim)

    # Mapping NaN to 0.0, +inf to -1.0 and other values to -2.0, so that the
    # second sort orders the tail as "+inf first, NaN last".
    # NOTE: Multiplying -inf to the result should recover NaN and +inf.
    nan_result = _unsafe_stable_sort(
        torch.where(
            safe_x.isposinf(),
            x.nan_to_num(0.0, -1.0),
            torch.full_like(x, -2.0),
        ),
        dim,
    )

    # Positions holding +inf in the first result are exactly the tail made of
    # +inf and NaN; take those positions from the second result.
    is_tail = result.values.isposinf()
    return StableSortResult(
        torch.where(is_tail, nan_result.values * -math.inf, result.values),
        torch.where(is_tail, nan_result.indices, result.indices),
    )