Source code for qfeval_functions.functions.pca

from dataclasses import dataclass

import torch

from .covar import covar


@dataclass
class PcaResult:
    """Container for the outputs of a principal component analysis.

    Both fields are filled by ``pca_cov`` from a singular value
    decomposition of a covariance matrix of shape ``(*, D, D)``.
    """

    # Principal axes, shape (*, D, D): row ``i`` (``components[*, i, :]``)
    # holds the per-feature weights of the i-th component.
    components: torch.Tensor
    # Singular values of the covariance matrix, shape (*, D): entry ``i`` is
    # the variance explained by the i-th component.
    explained_variance: torch.Tensor


def pca(x: torch.Tensor) -> PcaResult:
    """Computes principal components on the given input `x`.

    The covariance between every pair of features is computed over the
    section axis, and the resulting matrix is decomposed by `pca_cov`.
    Specifically, `result.components[*, i, :]` represents the
    :math:`(i+1)`-th largest principal component of the batch specified by
    `*`.

    Shape:
        - x: :math:`(*, S, D)` where `*` means any number of additional
          dimensions, `S` means the number of sections, and `D` means the
          number of dimensions.  In qfeval, dimensions and sections often
          represent symbols and timestamps respectively.

    Return:
        - components (Tensor): `components[*, i, j]` represents the
          :math:`i`-th component's weight for the `j`-th feature.
        - explained_variance (Tensor): `explained_variance[*, i]`
          represents the eigenvalue of the :math:`i`-th component.
    """
    # Broadcast x against itself as (*, S, D, 1) x (*, S, 1, D) and reduce
    # over the section axis (dim=-3).
    # NOTE(review): assumes `covar` broadcasts its two inputs and removes
    # `dim`, yielding a (*, D, D) covariance matrix -- confirm in covar.py.
    return pca_cov(covar(x[..., None], x[..., None, :], dim=-3))


def pca_cov(cov: torch.Tensor) -> PcaResult:
    """Computes principal components on the given covariance `cov`.

    The components are obtained from a singular value decomposition of
    `cov`.  Specifically, `result.components[*, i, :]` represents the
    :math:`(i+1)`-th largest principal component of the batch specified by
    `*`.

    Shape:
        - cov: :math:`(*, D, D)` where `*` means any number of additional
          dimensions, and `D` means the number of dimensions.

    Return:
        - components (Tensor): `components[*, i, j]` represents the
          :math:`i`-th component's weight for the `j`-th feature.  Its
          shape is :math:`(*, D, D)`.
        - explained_variance (Tensor): `explained_variance[*, i]`
          represents the eigenvalue of the :math:`i`-th component (the
          :math:`i`-th singular value of `cov`, in descending order).  Its
          shape is :math:`(*, D)`.
    """
    batch_shape = cov.shape[:-2]
    # Collapse all leading batch dimensions into one; unsqueeze(0) guarantees
    # there is at least one batch dimension even for a plain (D, D) input.
    flat_cov = cov.unsqueeze(0).flatten(end_dim=-3)
    # torch.linalg.svd returns (U, S, Vh): the rows of Vh are the principal
    # axes and S holds the singular values sorted in descending order.
    _, s, vh = torch.linalg.svd(flat_cov)
    # Restore the caller's original batch dimensions.
    return PcaResult(
        components=vh.reshape(batch_shape + vh.shape[1:]),
        explained_variance=s.reshape(batch_shape + s.shape[1:]),
    )