Source code for qfeval_functions.functions.shift

import math
import typing

import torch


@typing.overload
def shift(
    x: torch.Tensor,
    shifts: int,
    dims: int,
) -> torch.Tensor:
    # Typing-only signature: one integer shift applied along one integer
    # dimension.  The body is never executed.
    ...


@typing.overload
def shift(
    x: torch.Tensor,
    shifts: typing.Tuple[int, ...],
    dims: typing.Tuple[int, ...],
) -> torch.Tensor:
    # Typing-only signature: a tuple of shifts paired with a tuple of
    # dimensions.  The body is never executed.
    ...


def shift(
    x: torch.Tensor,
    shifts: typing.Union[int, typing.Tuple[int, ...]],
    dims: typing.Union[int, typing.Tuple[int, ...]],
) -> torch.Tensor:
    r"""Shifts array elements along specified dimensions.

    The elements are moved with :meth:`torch.Tensor.roll`; the positions that
    ``roll`` would fill with wrapped-around values are overwritten with NaN
    instead.

    Args:
        x: The tensor to shift.
        shifts: Number of places to shift.  A single int is applied to every
            dimension in ``dims``; a tuple must have the same length as
            ``dims``.  A positive value moves elements toward higher indices
            (the leading positions become NaN), a negative value moves them
            toward lower indices (the trailing positions become NaN).  Values
            whose magnitude exceeds the dimension size are clamped to it, so
            the whole dimension becomes NaN.
        dims: The dimension, or tuple of dimensions, to shift along.  Negative
            values count from the last dimension.

    Returns:
        A tensor with the same shape as ``x`` holding the shifted values.

    Raises:
        RuntimeError: If ``shifts`` and ``dims`` have different lengths.
        IndexError: If an entry of ``dims`` is out of range for ``x``.
    """
    # 1. Force dims/shifts to be tuples.
    if isinstance(dims, int):
        dims = (dims,)
    if isinstance(shifts, int):
        shifts = (shifts,) * len(dims)
    if len(shifts) != len(dims):
        raise RuntimeError(
            f"Inconsistent number of dimensions: shifts={shifts} dims={dims}"
        )

    # Look the sizes up with the caller's dims first: indexing ``x.shape``
    # raises IndexError for an out-of-range dimension, which the modulo
    # normalisation below would otherwise silently wrap around.
    sizes = tuple(x.shape[dim] for dim in dims)
    # Prevent shifts from causing an out-of-index error.
    shifts = tuple(max(min(s, size), -size) for s, size in zip(shifts, sizes))
    # Map negative dims to their non-negative equivalents.  The mask key
    # below is built positionally, so a negative dim would otherwise not
    # select the intended axis.
    ndim = x.dim()
    dims = tuple(dim % ndim for dim in dims)

    # 2. Build a mask to fill NaNs.
    mask = torch.zeros_like(x, dtype=torch.bool, device=x.device)
    for amount, dim in zip(shifts, dims):
        # Positions vacated by the shift: the leading ``amount`` entries for
        # a positive shift, the trailing ``-amount`` entries for a negative
        # one.  A zero shift yields an empty slice.
        vacated = slice(amount, None) if amount < 0 else slice(0, amount)
        # Full slices for the axes before ``dim``; axes after it are taken
        # whole implicitly.
        key = (slice(None),) * dim + (vacated,)
        mask[key] = True

    # 3. Apply torch.roll and fill rolled values with NaNs using the mask.
    rolled = x.roll(shifts, dims)
    # NOTE(review): NaN is cast to the dtype of ``x``; for non-floating
    # dtypes the fill value is not a meaningful NaN -- confirm callers only
    # pass floating-point tensors.
    return torch.where(mask, torch.as_tensor(math.nan).to(rolled), rolled)