Source code for qfeval_functions.functions.mmax

import torch

from .rcummax import rcummax


def mmax(x: torch.Tensor, span: int, dim: int = -1) -> torch.Tensor:
    r"""Compute the moving (sliding window) maximum of a tensor.

    This function calculates the maximum value within a sliding window of
    size :attr:`span` along the specified dimension. The output tensor has
    the same shape as the input tensor. For positions where the sliding
    window cannot fully cover preceding elements (i.e., the first
    ``span - 1`` elements along the selected dimension), the result is
    computed using available values by padding with the first element.

    The moving maximum is computed as:

    .. math::
        \text{MMAX}[i] = \max_{j=\max(0, i-\text{span}+1)}^{i} x[j]

    Args:
        x (Tensor):
            The input tensor containing values.
        span (int):
            The size of the sliding window. Must be positive.
        dim (int, optional):
            The dimension along which to compute the moving maximum.
            Default is -1 (the last dimension).

    Returns:
        Tensor:
            A tensor of the same shape as the input, containing the moving
            maximum values. If the input has no elements, an empty tensor
            of the same shape is returned.

    Raises:
        ValueError: If :attr:`span` is not positive.

    Example:

        >>> # Simple moving maximum with window size 3
        >>> x = torch.tensor([1.0, 5.0, 3.0, 8.0, 2.0])
        >>> QF.mmax(x, span=3)
        tensor([1., 5., 5., 8., 8.])

        >>> # 2D tensor with moving maximum along columns
        >>> x = torch.tensor([[1.0, 4.0, 2.0, 6.0],
        ...                   [3.0, 1.0, 5.0, 2.0],
        ...                   [2.0, 7.0, 3.0, 4.0]])
        >>> QF.mmax(x, span=2, dim=1)
        tensor([[1., 4., 4., 6.],
                [3., 3., 5., 5.],
                [2., 7., 7., 4.]])

        >>> # Moving maximum along the first dimension
        >>> x = torch.tensor([[1.0, 2.0],
        ...                   [4.0, 1.0],
        ...                   [2.0, 5.0],
        ...                   [3.0, 2.0]])
        >>> QF.mmax(x, span=3, dim=0)
        tensor([[1., 2.],
                [4., 2.],
                [4., 5.],
                [4., 5.]])

        >>> # Handling negative values
        >>> x = torch.tensor([-2.0, -1.0, -5.0, -3.0, -1.0])
        >>> QF.mmax(x, span=2)
        tensor([-2., -1., -1., -3., -1.])

    .. seealso::
        :func:`mmin`: Moving minimum function.
        :func:`msum`: Moving sum function.
    """
    # Reject non-positive windows up front: `span == 0` would otherwise
    # surface as a ZeroDivisionError from the modulo below, and a negative
    # span as an obscure expand/reshape failure.
    if span < 1:
        raise ValueError(f"span must be a positive integer, got {span}")

    # 1. Move the target dimension to the top to make data manipulation
    # easier.
    x = x.transpose(0, dim)
    shape = x.shape

    # An empty tensor has nothing to reduce. The padding step below cannot
    # expand a zero-length edge row, and `reshape(..., -1)` is ambiguous for
    # zero elements, so return an empty result of the input's shape.
    if x.numel() == 0:
        return x.new_empty(shape).transpose(0, dim)

    x = x.reshape((shape[0], -1))

    # 2. Reshape the target dimension into `(*, span)`, extending the leading
    # edge with copies of the first row. `pad_len` lies in
    # `[span + 1, 2 * span]`, which makes the padded length a multiple of
    # `span` and guarantees at least one full block precedes the data.
    pad_len = span * 2 - x.shape[0] % span
    pad = x[:1, :].expand(pad_len, -1)
    x = torch.cat((pad, x), dim=0)
    x = x.reshape((-1, span, x.shape[-1]))

    # 3. Calculate `max(x[i:i+span])` by splitting it into
    # `max(x[i:s], x[s:i+span])` where `s` is a multiple of `span`. They can
    # be calculated by cummax (prefix maximum inside the current block) and
    # rcummax (used here as the suffix maximum inside the previous block).
    a = x.cummax(dim=1).values
    b = rcummax(x, dim=1).values
    # The last offset of each block covers exactly that block, so it needs
    # only the block's own full prefix maximum `a[1:, -1:]`.
    x = torch.cat((torch.max(a[1:, :-1], b[:-1, 1:]), a[1:, -1:]), dim=1)
    # Drop the rows that correspond to padding, keeping the last `shape[0]`.
    x = x.reshape((-1, x.shape[-1]))[-shape[0] :]

    # 4. Restore the shape and the order of dimensions.
    return x.reshape(shape).transpose(0, dim)