qfeval_functions.functions.nanshift
- nanshift(x, shift=1, dim=-1)
Shift tensor elements along a dimension while preserving NaN positions.
This function shifts the valid (non-NaN) elements of a tensor along the specified dimension while keeping NaN values in their original positions. Unlike standard shifting operations, NaN values act as “immovable” elements that do not participate in the shifting process, allowing only valid data to be shifted around them.
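This is particularly useful for time series data where missing values (represented as NaN) should remain in their temporal positions while valid observations are shifted. For example, with shift=1 each observation is paired with the previous valid observation, skipping over missing timestamps, which is what computing differences or returns between consecutive valid observations requires.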
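The shifting rule described above can be sketched on a plain Python list. This is a minimal, hypothetical reference (`nanshift_1d` is not part of qfeval_functions) that assumes a 1D input and the non-wrapping behavior shown in the examples: the valid values are treated as a compact sequence, shifted by `shift`, padded with NaN where no source value exists, and written back into the valid slots.

```python
import math


def nanshift_1d(values, shift=1):
    """Hypothetical reference sketch of nanshift for a 1D list of floats.

    Assumes non-wrapping semantics: valid values shifted past either end
    of the valid sequence are dropped, and vacated valid slots become NaN.
    NaN slots are never moved.
    """
    # Indices of the valid (non-NaN) slots, in their original order.
    valid_idx = [i for i, v in enumerate(values) if not math.isnan(v)]
    valid_vals = [values[i] for i in valid_idx]
    n = len(valid_vals)

    out = list(values)  # copy; NaN slots keep their NaN values
    for rank, i in enumerate(valid_idx):
        src = rank - shift  # rank of the valid value that lands in slot i
        out[i] = valid_vals[src] if 0 <= src < n else math.nan
    return out


nan = math.nan
print(nanshift_1d([1.0, nan, 3.0, 4.0, nan], shift=1))
# [nan, nan, 1.0, 3.0, nan]
```

Working with ranks among the valid elements, rather than absolute indices, is what keeps the NaN slots fixed: only the mapping from rank to value changes.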
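- Parameters:
x (Tensor): The input tensor, which may contain NaN values.
shift (int, optional): The number of valid positions to shift by. Positive values move valid elements toward higher indices, negative values toward lower indices. Default: 1.
dim (int, optional): The dimension along which to shift. Default: -1.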
- Returns:
A tensor with the same shape as the input, where valid elements have been shifted along the specified dimension while NaN positions remain unchanged.
- Return type:
Tensor
Example
>>> import torch
>>> from math import nan
>>> import qfeval_functions.functions as QF
>>> # Simple 1D shift with NaN values
>>> x = torch.tensor([1.0, nan, 3.0, 4.0, nan])
>>> QF.nanshift(x, shift=1)
tensor([nan, nan, 1., 3., nan])
>>> # Negative shift
>>> x = torch.tensor([1.0, nan, 3.0, 4.0, nan])
>>> QF.nanshift(x, shift=-1)
tensor([3., nan, 4., nan, nan])
>>> # 2D tensor shift along dim 0 (down each column)
>>> x = torch.tensor([[1.0, nan, 3.0],
...                   [4.0, 5.0, nan],
...                   [nan, 8.0, 9.0]])
>>> QF.nanshift(x, shift=1, dim=0)
tensor([[nan, nan, nan],
        [1., nan, nan],
        [nan, 5., 3.]])
>>> # 2D tensor shift along dim 1 (within each row)
>>> x = torch.tensor([[1.0, nan, 3.0],
...                   [4.0, 5.0, nan]])
>>> QF.nanshift(x, shift=1, dim=1)
tensor([[nan, nan, 1.],
        [nan, 4., nan]])
>>> # Large shift (valid elements shifted past the last valid position are dropped)
>>> x = torch.tensor([1.0, nan, 3.0, 4.0, nan])
>>> QF.nanshift(x, shift=2)
tensor([nan, nan, nan, 1., nan])
>>> # All-NaN tensor (no change)
>>> x = torch.tensor([nan, nan, nan])
>>> QF.nanshift(x, shift=1)
tensor([nan, nan, nan])
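For tensors of any rank, the same rule can be applied along `dim` without Python loops. The following is a hypothetical vectorized sketch (`nanshift_sketch` is an illustrative name, not the library's implementation), assuming the non-wrapping behavior shown in the examples above. It uses only standard PyTorch operations: a cumulative sum gives each valid element its rank, a stable sort compacts the valid values to the front, and `torch.gather` fetches the value with rank `rank - shift` for each valid slot.

```python
import torch


def nanshift_sketch(x: torch.Tensor, shift: int = 1, dim: int = -1) -> torch.Tensor:
    """Hypothetical vectorized sketch of nanshift (not the library code)."""
    # Work along the last dimension; move it back at the end.
    x = x.movedim(dim, -1)
    valid = ~torch.isnan(x)
    # 0-based rank of each valid element among the valid elements of its row.
    rank = valid.long().cumsum(dim=-1) - 1
    count = valid.sum(dim=-1, keepdim=True)
    # A stable sort on "is NaN" moves valid values to the front in original order.
    order = torch.sort((~valid).to(torch.int8), dim=-1, stable=True).indices
    compact = torch.gather(x, -1, order)
    # The valid slot with rank r receives the valid value with rank r - shift.
    src = rank - shift
    keep = valid & (src >= 0) & (src < count)
    moved = torch.gather(compact, -1, src.clamp(0, x.size(-1) - 1))
    # NaN slots and slots without a source value become NaN (no wrap-around).
    out = torch.where(keep, moved, torch.full_like(x, float("nan")))
    return out.movedim(-1, dim)


nan = float("nan")
x = torch.tensor([[1.0, nan, 3.0], [4.0, 5.0, nan], [nan, 8.0, 9.0]])
print(nanshift_sketch(x, shift=1, dim=0))  # same values as the dim=0 example above
```

Out-of-range source ranks are clamped only so that `torch.gather` receives valid indices; those entries are then masked out by `keep`.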
Warning
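The shift counts valid elements, not absolute positions: with shift=1, each valid position receives the value of the preceding valid element, however many NaN values lie between them. Valid elements do not wrap around. For example, if there are 3 valid elements and shift=1, the first valid position becomes NaN, the second and third receive the first and second valid elements, and the last valid element is discarded.
See also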
shift(): Standard shift function without NaN handling.
group_shift(): Shift operation within groups.
torch.roll: PyTorch’s standard tensor rolling function (elements wrap around).