qfeval_functions.functions.nanshift

nanshift(x, shift=1, dim=-1)

Shift tensor elements along a dimension while preserving NaN positions.

This function shifts the valid (non-NaN) elements of a tensor along the specified dimension while keeping NaN values in their original positions. Unlike a standard shift, NaN values act as “immovable” elements: they do not take part in the shift, and valid values move only among the valid positions. Within each 1-D slice along dim, the shift is counted in valid positions, valid positions vacated by the shift are filled with NaN, and values shifted past either end are discarded.
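
As a minimal sketch of these semantics on a plain Python list, assuming the behavior shown in the examples below, the hypothetical helper nanshift_list collects the valid slots, shifts the compacted values with NaN fill, and scatters them back. It is an illustration only, not the library's implementation.

```python
import math
from typing import List

NAN = float("nan")


def nanshift_list(values: List[float], shift: int = 1) -> List[float]:
    """Hypothetical pure-Python sketch of nanshift on a 1-D list.

    Assumes the semantics shown in the documented examples: valid values move
    only among valid slots, vacated valid slots become NaN, values shifted
    past either end are dropped, and NaN slots never move.
    """
    # Indices of the valid (non-NaN) slots, in order.
    valid_idx = [i for i, v in enumerate(values) if not math.isnan(v)]
    compact = [values[i] for i in valid_idx]
    n = len(compact)

    # Shift the compacted values; the vacated side is filled with NaN.
    shifted = [NAN] * n
    for j in range(n):
        src = j - shift
        if 0 <= src < n:
            shifted[j] = compact[src]

    # Scatter the shifted values back into the valid slots only.
    out = [NAN] * len(values)
    for i, v in zip(valid_idx, shifted):
        out[i] = v
    return out


print(nanshift_list([1.0, NAN, 3.0, 4.0, NAN], shift=1))
# [nan, nan, 1.0, 3.0, nan]
```

The compact, shift, scatter structure makes it explicit why NaN slots are never written with a value: only indices in valid_idx are ever assigned from the shifted sequence.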

This is particularly useful for time series in which missing values (NaN) must stay at their original time steps while valid observations are shifted, for example to pair each observation with the previous valid observation (shift=1).
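
For the time-series case with shift=1, the same result can be produced by a single pass that remembers the last valid observation. The sketch below assumes the semantics shown in the examples; previous_valid and the sample prices are hypothetical, and the change computation is only one possible use.

```python
import math
from typing import List

NAN = float("nan")


def previous_valid(values: List[float]) -> List[float]:
    """Hypothetical helper: for each valid observation, the previous valid one.

    Per the documented examples, this matches nanshift(x, shift=1) on a 1-D
    series. Missing observations (NaN) stay NaN in place.
    """
    out = []
    last = NAN  # no valid observation seen yet
    for v in values:
        if math.isnan(v):
            out.append(NAN)  # a gap stays a gap
        else:
            out.append(last)  # lag by one *valid* observation
            last = v
    return out


prices = [100.0, NAN, 102.0, 101.0, NAN, 104.0]
lagged = previous_valid(prices)
# Change since the previous valid observation (NaN at gaps and at the start).
changes = [p - q for p, q in zip(prices, lagged)]
print(changes)  # [nan, nan, 2.0, -1.0, nan, 3.0]
```

Note that the change at index 2 is measured against index 0, skipping the gap at index 1, which is exactly what keeping NaN positions fixed buys for gap-aware lags.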

Parameters:
  • x (Tensor) – The input tensor to be shifted.

  • shift (int) – Number of valid positions to shift; NaN positions are skipped when counting. Positive values shift towards higher indices, negative values towards lower indices. Default is 1.

  • dim (int) – The dimension along which to perform the shift. Default is -1 (last dimension).
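
A vectorized sketch of how x, shift, and dim could fit together in PyTorch is shown below, assuming the semantics of the documented examples. nanshift_sketch is a hypothetical name and this is not the library's implementation: it moves dim to the last axis, ranks each valid element within its slice, packs valid values to the front with a stable sort, and gathers the value from rank minus shift.

```python
import torch


def nanshift_sketch(x: torch.Tensor, shift: int = 1, dim: int = -1) -> torch.Tensor:
    """Hypothetical vectorized sketch of NaN-preserving shift along `dim`."""
    # Move the target dimension last so each 1-D slice is a row.
    x = x.movedim(dim, -1)
    valid = ~torch.isnan(x)
    n = x.shape[-1]

    # 0-based rank of each valid element among the valid elements of its slice.
    rank = valid.long().cumsum(-1) - 1
    count = valid.sum(-1, keepdim=True)

    # Pack valid values to the front of each slice, keeping their order.
    _, order = torch.sort((~valid).long(), dim=-1, stable=True)
    packed = x.gather(-1, order)

    # A valid element of rank r takes packed[r - shift] when that index exists;
    # otherwise (or at NaN slots) the output is NaN.
    src = rank - shift
    ok = valid & (src >= 0) & (src < count)
    gathered = packed.gather(-1, src.clamp(0, n - 1))
    out = torch.where(ok, gathered, torch.full_like(x, float("nan")))
    return out.movedim(-1, dim)


nan = float("nan")
x = torch.tensor([[1.0, nan, 3.0],
                  [4.0, 5.0, nan]])
print(nanshift_sketch(x, shift=1, dim=1))  # same values as the dim=1 example
```

Clamping src before gather only keeps the indices in range; the ok mask is what decides which positions actually receive a shifted value.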

Returns:

A tensor with the same shape as the input, where valid elements have been shifted along the specified dimension, NaN positions remain unchanged, and valid positions left empty by the shift are filled with NaN.

Return type:

Tensor

Example

>>> import torch
>>> import qfeval_functions.functions as QF
>>> nan = float("nan")
>>> # Simple 1D shift with NaN values
>>> x = torch.tensor([1.0, nan, 3.0, 4.0, nan])
>>> QF.nanshift(x, shift=1)
tensor([nan, nan, 1., 3., nan])
>>> # Negative shift
>>> x = torch.tensor([1.0, nan, 3.0, 4.0, nan])
>>> QF.nanshift(x, shift=-1)
tensor([3., nan, 4., nan, nan])
>>> # 2D tensor shift along rows
>>> x = torch.tensor([[1.0, nan, 3.0],
...                   [4.0, 5.0, nan],
...                   [nan, 8.0, 9.0]])
>>> QF.nanshift(x, shift=1, dim=0)
tensor([[nan, nan, nan],
        [1., nan, nan],
        [nan, 5., 3.]])
>>> # 2D tensor shift along columns
>>> x = torch.tensor([[1.0, nan, 3.0],
...                   [4.0, 5.0, nan]])
>>> QF.nanshift(x, shift=1, dim=1)
tensor([[nan, nan, 1.],
        [nan, 4., nan]])
>>> # Shift by two valid positions (no wrap-around: 3. and 4. are dropped)
>>> x = torch.tensor([1.0, nan, 3.0, 4.0, nan])
>>> QF.nanshift(x, shift=2)
tensor([nan, nan, nan, 1., nan])
>>> # All-NaN tensor (returned unchanged)
>>> x = torch.tensor([nan, nan, nan])
>>> QF.nanshift(x, shift=1)
tensor([nan, nan, nan])

Warning

The shift does not wrap around the valid elements. For example, with 3 valid elements and shift=1, the first valid position becomes NaN, each of the other two valid positions receives the preceding valid value, and the last valid value is discarded.

See also

shift(): Standard shift function without NaN handling.
group_shift(): Shift operation within groups.
torch.roll: PyTorch’s standard tensor rolling function.
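
To make the contrast with torch.roll concrete, the sketch below applies both operations to the compacted valid values of the first example with a shift of 2. Rolling wraps values around, whereas the documented nanshift output for shift=2 fills vacated valid positions with NaN. The variable names are illustrative only.

```python
import torch

nan = float("nan")
x = torch.tensor([1.0, nan, 3.0, 4.0, nan])
valid = ~torch.isnan(x)
compact = x[valid]  # the valid values, in order: 1., 3., 4.
k = 2  # shift amount (assumed 0 < k < number of valid values here)

# Wrap-around: roll the compacted values and scatter them back.
rolled = x.clone()
rolled[valid] = torch.roll(compact, shifts=k)

# NaN-filling shift of the compacted values (matches the shift=2 example).
shifted = torch.full_like(compact, nan)
shifted[k:] = compact[:-k]
filled = x.clone()
filled[valid] = shifted

print(rolled)  # tensor([3., nan, 4., 1., nan])
print(filled)  # tensor([nan, nan, nan, 1., nan])
```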