Source code for qfeval_functions.functions.group_shift

import math
import typing

import torch

from .shift import shift as _shift

AggregateFunction = typing.Literal["any", "all"]


def reduce_nan_patterns(
    x: torch.Tensor,
    dim: int = -1,
    refdim: int = 0,
    agg_f: AggregateFunction = "any",
) -> torch.Tensor:
    r"""Creates a mask for group shift.

    A mask is a one-dimensional boolean tensor that represents the pattern
    of observed values in a reference dimension (`refdim`).
    i.e., `True` values correspond to the locations of non-nan values.

    Args:
        - x: The input tensor. This should have at least 2 dimensions.
        - dim: The dimension along which `x` will be shifted.
        - refdim: The reference dimension to extract a pattern of (non-)nans.
        - agg_f: The function for aggregating all other dimensions.
    Returns:
        - mask: a 1-D boolean tensor

    Examples:
        >>> x = torch.tensor([
        ...     [1.0, 2.0, nan, 1.0],
        ...     [2.0, 4.0, nan, 2.0],
        ...     [3.0, nan, nan, 3.0],
        ...     [1.0, 1.0, 1.0, 1.0],
        ... ])
        >>> reduce_nan_patterns(x, -1, 0)
        tensor([ True, False, False,  True])
        >>> # As x.dim() == 2, agg_f does not affect the results
        >>> reduce_nan_patterns(x, -1, 0, agg_f="all")
        tensor([ True, False, False,  True])
        >>> reduce_nan_patterns(x, 0, 1)
        tensor([False, False, False,  True])
    """
    # transpose x so that the first dimension is dim and the second is refdim
    # NOTE: currently, dimensions added here will not be squeezed, since
    # they do not affect "any" and "all" aggregations
    dim = dim % len(x.shape) + 2
    refdim = refdim % len(x.shape) + 2
    x = x[None, None, ...].transpose(0, dim).transpose(1, refdim)

    # aggregate values in all dimensions but dim and refdim
    reduced = (~x.isnan()).flatten(2)
    if agg_f == "any":
        reduced = reduced.any(dim=-1)
    elif agg_f == "all":
        reduced = reduced.all(dim=-1)
    else:
        raise NotImplementedError("Unknown aggregate function.")

    # return True if all (aggregated) values in redim are True
    return reduced.all(dim=1)


[docs] def group_shift( x: torch.Tensor, shift: int = 1, dim: int = 0, mask: typing.Optional[torch.Tensor] = None, refdim: typing.Optional[int] = -1, agg_f: AggregateFunction = "any", ) -> torch.Tensor: r"""Shift tensor elements along a dimension, skipping masked positions. This function performs a selective shift operation where only elements at positions marked as ``True`` in the mask are shifted, while positions marked as ``False`` are skipped. This is particularly useful for time series data where certain time points should be excluded from the shift operation, such as weekends in financial data or missing observations. The function works by reordering elements based on the mask, applying the shift only to valid positions, and then restoring the original order. Elements at masked-out positions are replaced with NaN values. .. warning:: Unmasked values (where mask is ``False``) are replaced with NaN before shifting. This means that even a zero shift (``shift=0``) will not return the original input, as unmasked values will be NaN in the output. Args: x (Tensor): The input tensor to be shifted. shift (int, optional): The number of positions to shift. Positive values shift forward (toward higher indices), negative values shift backward. Default is 1. dim (int, optional): The dimension along which to perform the shift. Default is 0. mask (Tensor, optional): A 1D boolean tensor where ``True`` indicates positions to include in the shift, and ``False`` indicates positions to skip. Length must equal ``x.shape[dim]``. If not provided, :attr:`refdim` must be specified. refdim (int, optional): If specified, automatically generates a mask using :func:`reduce_nan_patterns`. This identifies valid positions based on non-NaN patterns in the reference dimension. Default is -1. agg_f (str, optional): Aggregation function used when generating mask from :attr:`refdim`. Can be "any" or "all". Default is "any". Returns: Tensor: A tensor of the same shape as the input, with elements shifted according to the mask. Unmasked positions contain NaN. Example: >>> # Basic masked shift >>> x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) >>> mask = torch.tensor([True, False, True, False, True]) >>> shifted = QF.group_shift(x, shift=1, dim=0, mask=mask) >>> shifted tensor([nan, nan, 1., nan, 3.]) >>> # Shift with existing NaN values >>> x = torch.tensor([1.0, nan, 2.0, nan, 3.0]) >>> mask = torch.tensor([True, False, True, False, True]) >>> shifted = QF.group_shift(x, shift=1, dim=0, mask=mask) >>> shifted tensor([nan, nan, 1., nan, 2.]) >>> # 2D tensor with automatic mask generation >>> x = torch.tensor([[1.0, 2.0, nan, 4.0], ... [5.0, 6.0, nan, 8.0], ... [9.0, 10., nan, 12.]]) >>> # Use refdim=0 to generate mask from first row's NaN pattern >>> shifted = QF.group_shift(x, shift=1, dim=1, refdim=0) >>> shifted tensor([[nan, 1., nan, 2.], [nan, 5., nan, 6.], [nan, 9., nan, 10.]]) .. seealso:: :func:`reduce_nan_patterns`: For understanding mask generation from reference dimensions. """ n = x.shape[dim] # 1. Create a priority index from the mask priority = torch.arange(n, device=x.device) if mask is not None: mask = mask.squeeze() assert mask.shape == (n,) elif refdim is not None: mask = reduce_nan_patterns(x, dim, refdim, agg_f) else: raise ValueError("Either mask or refdim must be set.") priority = priority + (~mask).int() * n index = torch.argsort(priority) reversed_index = torch.argsort(index) # 2. Move the target dimension to the top to make data manipulation easier. x = x.transpose(0, dim) shape = x.shape x = x.reshape((n, -1)) # 3. Gather valid values, apply shift, and scatter them. # fill unmasked indices with nans y = torch.where( mask[:, None].expand(n, x.shape[1]), x, torch.tensor(math.nan).to(x), ) # sort y so that masked values come first y = y[index, :] # apply shift y = torch.where( mask[index, None].expand(n, x.shape[1]), _shift(y, shift, 0), y, ) x = y[reversed_index, :] # 4. Restore the shape and the order of dimensions. return x.reshape(shape).transpose(0, dim)