Source code for qfeval_functions.functions.soft_topk_bottomk

import logging
import typing

import torch
import torch.nn.functional as F

logger = logging.getLogger(__name__)


def soft_topk_bottomk(
    x: torch.Tensor,
    k: int,
    dim: int = -1,
    *,
    epsilon: float = 0.1,
    max_iter: int = 200,
    topk_only: bool = False,
) -> torch.Tensor:
    r"""Apply the SoftTopKBottomK operator along the given dimension.

    The target dimension is swapped into the last position, the tensor is
    flattened into a 2-D matrix whose rows are processed independently by
    `_soft_topk_bottomk`, and the result is restored to the layout of `x`.

    See `qfeval.extension.SoftTopKBottomK` for further information.

    Args:
        x (torch.Tensor): Input scores.
        k (int): Size of the top group (and of the bottom group unless
            `topk_only` is set).
        dim (int): Dimension along which the operator is applied.
        epsilon (float): Entropic-regularization parameter forwarded to
            the Sinkhorn solver.
        max_iter (int): Maximum number of Sinkhorn iterations.
        topk_only (bool): If True, only the soft top-k membership is
            returned instead of the signed top-k minus bottom-k value.

    Returns:
        torch.Tensor: Tensor with the same shape as `x`.

    Examples:
        >>> x = torch.tensor([[1., 2., 3., 4., 5.], [6., 7., 8., 9., 10.]])
        >>> soft_topk_bottomk(x, k=1, dim=1)
        tensor([[-0.7624, -0.2123,  0.0000,  0.2123,  0.7624],
                [-0.7624, -0.2123,  0.0000,  0.2123,  0.7624]])
        >>> soft_topk_bottomk(x, k=1, dim=0)
        tensor([[-0.9999, -0.9999, -0.9999, -0.9999, -0.9999],
                [ 0.9999,  0.9999,  0.9999,  0.9999,  0.9999]])
        >>> soft_topk_bottomk(x, k=1, dim=1, epsilon=1e-3)
        tensor([[-0.9965, -0.0035,  0.0000,  0.0035,  0.9965],
                [-0.9965, -0.0035,  0.0000,  0.0035,  0.9965]])
    """
    # Bring the target dimension to the end so every row of the flattened
    # matrix is one independent slice along `dim`.
    swapped = x.transpose(dim, -1)
    swapped_shape = swapped.shape
    rows = swapped.reshape(-1, swapped_shape[-1])

    weights = _soft_topk_bottomk(
        rows, k, epsilon=epsilon, max_iter=max_iter, topk_only=topk_only
    )

    # Undo the flattening first, then the swap, to recover the layout of `x`.
    return weights.reshape(swapped_shape).transpose(dim, -1)
def soft_topk(
    x: torch.Tensor,
    k: int,
    dim: int = -1,
    *,
    epsilon: float = 0.1,
    max_iter: int = 200,
) -> torch.Tensor:
    r"""Apply the soft top-k operator along the given dimension.

    Thin wrapper that calls `soft_topk_bottomk` with `topk_only=True`.
    See `qfeval.extension.SoftTopk` for further information.

    Args:
        x (torch.Tensor): Input scores.
        k (int): Size of the top group.
        dim (int): Dimension along which the operator is applied.
        epsilon (float): Entropic-regularization parameter.
        max_iter (int): Maximum number of Sinkhorn iterations.

    Returns:
        torch.Tensor: Tensor with the same shape as `x`.

    Examples:
        >>> x = torch.tensor([[1., 2., 3., 4., 5.], [6., 7., 8., 9., 10.]])
        >>> soft_topk(x, k=1, dim=0)
        tensor([[0.0001, 0.0001, 0.0001, 0.0001, 0.0001],
                [0.9999, 0.9999, 0.9999, 0.9999, 0.9999]])
        >>> soft_topk(x, k=2, dim=1)
        tensor([[0.0000, 0.0006, 0.0783, 0.9217, 0.9994],
                [0.0000, 0.0006, 0.0783, 0.9217, 0.9994]])
        >>> soft_topk(x, k=1, dim=1, epsilon=1e-3)
        tensor([[0.0000, 0.0000, 0.0000, 0.5000, 0.5000],
                [0.0000, 0.0000, 0.0000, 0.5000, 0.5000]])
    """
    return soft_topk_bottomk(
        x, k, dim, epsilon=epsilon, max_iter=max_iter, topk_only=True
    )


def _soft_topk_bottomk(
    scores: torch.Tensor,
    k: int,
    *,
    epsilon: float = 0.1,
    max_iter: int = 200,
    topk_only: bool = False,
) -> torch.Tensor:
    """Compute soft top-k (minus bottom-k) values for each row of `scores`.

    Each row is standardized, a cost matrix of squared distances between the
    standardized scores and a small set of anchors is built, and an
    entropy-regularized transport plan between the row elements and the
    anchors is computed with `_sinkhorn`.

    Args:
        scores (torch.Tensor): Two dimensional tensor; the operator is
            applied along the last dimension.
        k (int): Size of the top group (and of the bottom group unless
            `topk_only` is set).
        epsilon (float): Entropic-regularization parameter; must be positive.
        max_iter (int): Maximum number of Sinkhorn iterations.
        topk_only (bool): If True, use the two anchors {rest, top-k} and
            return the top-k column of the plan; otherwise use the three
            anchors {bottom-k, middle, top-k} and return top-k minus
            bottom-k.

    Returns:
        torch.Tensor: Tensor with the same shape as `scores`.

    Raises:
        AssertionError: If `epsilon` is not positive, or if the requested
            groups do not fit into a row.
        ValueError: If `scores` contains NaN or infinite values.
    """
    # Raised explicitly rather than through `assert` so that the checks are
    # still performed under `python -O`; the exception type is unchanged.
    if not epsilon > 0:
        raise AssertionError(
            f"epsilon must be greather than 0, but: {epsilon}"
        )
    if not scores.isfinite().all():
        raise ValueError("Input tensor has nan or inf elements.")

    _, dim = scores.size()
    num_groups = 1 if topk_only else 2
    if dim - num_groups * k < 0:
        raise AssertionError(
            f"{num_groups} * k must not exceed the row length {dim}, "
            f"but k: {k}"
        )

    # Standardize each row; the clamp avoids dividing by a zero deviation.
    scores = scores - scores.mean(dim=-1, keepdim=True)
    scores = scores / scores.std(dim=-1, keepdim=True).clamp(min=1e-6)

    if topk_only:
        # Anchors / target masses for {rest, top-k}.
        anchors = torch.tensor([-1, 1]).to(scores)
        nu = torch.tensor([dim - k, k]).to(scores) / dim
    else:
        # Each element represents anchors for {bottom-k, middle, top-k}.
        anchors = torch.tensor([-1, 0, 1]).to(scores)
        nu = torch.tensor([k, dim - 2 * k, k]).to(scores) / dim
    # Every element of a row carries the same source mass.
    mu = torch.ones(dim).to(scores) / dim

    # Squared distance to each anchor, rescaled per row so that the largest
    # cost is 1. The scale is detached, so it is treated as a constant in
    # the backward pass.
    C = (scores[:, :, None] - anchors[None, None, :]) ** 2
    C = C / C.amax(dim=(1, 2), keepdim=True).detach()

    Gamma: torch.Tensor = _sinkhorn(C, mu, nu, epsilon, max_iter)

    # Multiplying by `dim` undoes the 1 / dim source mass of each element.
    if topk_only:
        return Gamma[:, :, 1] * dim
    return (Gamma[:, :, 2] - Gamma[:, :, 0]) * dim


class _Sinkhorn(torch.autograd.Function):
    """Sinkhorn algorithm for regularized optimal transport."""

    @staticmethod
    def forward(  # type: ignore[override]
        ctx: typing.Any,
        C: torch.Tensor,
        mu: torch.Tensor,
        nu: torch.Tensor,
        epsilon: float,
        max_iter: int,
    ) -> torch.Tensor:
        """Returns optimal transport plan.

        For `epsilon > 1e-2` the plain iteration is tried first and the
        log-domain variant is used only if the result contains NaN; for
        smaller `epsilon` the log-domain variant is used directly.

        Args:
            ctx (typing.Any): Context object.
            C (torch.Tensor): Cost matrix in the shape of `(B, N, M)`.
            mu (torch.Tensor): Source vector in the shape of `(1, N, 1)`.
            nu (torch.Tensor): Target vector in the shape of `(1, 1, M)`.
            epsilon (float): Entropic-regularization parameter.
            max_iter (int): Maximum number of iterations.

        Returns:
            torch.Tensor: Optimal transport plan in the shape of `(B, N, M)`.
        """
        with torch.no_grad():  # type: ignore[no-untyped-call]
            if epsilon > 1e-2:
                Gamma = _sinkhorn_forward(C, mu, nu, epsilon, max_iter)
                if bool(torch.isnan(Gamma).any()):
                    logger.info("Nan appeared in Gamma, re-computing...")
                    Gamma = _sinkhorn_forward_stabilized(
                        C, mu, nu, epsilon, max_iter
                    )
            else:
                Gamma = _sinkhorn_forward_stabilized(
                    C, mu, nu, epsilon, max_iter
                )
            # The backward pass is computed from the plan itself, so only
            # the marginals, the plan and epsilon need to be kept.
            ctx.save_for_backward(mu, nu, Gamma)
            ctx.epsilon = epsilon
        return Gamma

    @staticmethod
    def backward(  # type: ignore[override]
        ctx: typing.Any, grad_output_Gamma: torch.Tensor
    ) -> typing.Any:
        """Returns gradient with respect to cost matrix.

        Only `C` receives a gradient; `None` is returned for `mu`, `nu`,
        `epsilon` and `max_iter`.
        """
        epsilon = ctx.epsilon
        mu, nu, Gamma = ctx.saved_tensors
        with torch.no_grad():  # type: ignore[no-untyped-call]
            grad_C = _sinkhorn_backward(
                grad_output_Gamma, Gamma, mu, nu, epsilon
            )
        return grad_C, None, None, None, None


def _sinkhorn_forward(
    C: torch.Tensor,
    mu: torch.Tensor,
    nu: torch.Tensor,
    epsilon: float,
    max_iter: int,
) -> torch.Tensor:
    """Run the plain Sinkhorn iteration and return the transport plan.

    Args:
        C (torch.Tensor): Cost matrix in the shape of `(B, N, M)`.
        mu (torch.Tensor): Source vector in the shape of `(1, N, 1)`.
        nu (torch.Tensor): Target vector in the shape of `(1, 1, M)`.
        epsilon (float): Entropic-regularization parameter.
        max_iter (int): Number of iterations.

    Returns:
        torch.Tensor: Transport plan in the shape of `(B, N, M)`.
    """
    bs, n, k_ = C.size()
    # Scaling vectors follow the dtype of `C` so that the plan is not
    # silently promoted to the global default dtype.
    # `u` starts at one so the result is defined even when `max_iter == 0`.
    u = torch.ones([bs, n, 1], dtype=C.dtype, device=C.device)
    v = torch.ones([bs, 1, k_], dtype=C.dtype, device=C.device) / k_
    G = torch.exp(-C / epsilon)

    for _ in range(max_iter):
        # Alternately rescale rows to match `mu` and columns to match `nu`.
        u = mu / (G * v).sum(-1, keepdim=True)
        v = nu / (G * u).sum(-2, keepdim=True)

    return u * G * v


def _sinkhorn_forward_stabilized(
    C: torch.Tensor,
    mu: torch.Tensor,
    nu: torch.Tensor,
    epsilon: float,
    max_iter: int,
) -> torch.Tensor:
    """Run the Sinkhorn iteration in the log domain.

    The row/column updates are expressed through `torch.logsumexp` on the
    potentials `f` and `g` instead of multiplying exponentials directly.

    Args:
        C (torch.Tensor): Cost matrix in the shape of `(B, N, M)`.
        mu (torch.Tensor): Source vector in the shape of `(1, N, 1)`.
        nu (torch.Tensor): Target vector in the shape of `(1, 1, M)`.
        epsilon (float): Entropic-regularization parameter.
        max_iter (int): Number of iterations.

    Returns:
        torch.Tensor: Transport plan in the shape of `(B, N, M)`.
    """
    bs, n, k_ = C.size()
    # Both potentials share the dtype and device of `C`.
    f = torch.zeros([bs, n, 1], dtype=C.dtype, device=C.device)
    g = torch.zeros([bs, 1, k_], dtype=C.dtype, device=C.device)

    epsilon_log_mu = epsilon * torch.log(mu)
    epsilon_log_nu = epsilon * torch.log(nu)

    def soft_min(Z: torch.Tensor, axis: int) -> torch.Tensor:
        # Smooth minimum of `Z` along `axis` with temperature `epsilon`.
        return -epsilon * torch.logsumexp(-Z / epsilon, axis, keepdim=True)

    for _ in range(max_iter):
        f = soft_min(C - g, -1) + epsilon_log_mu
        g = soft_min(C - f, -2) + epsilon_log_nu

    return torch.exp((-C + f + g) / epsilon)


def _sinkhorn_backward(
    grad_output_Gamma: torch.Tensor,
    Gamma: torch.Tensor,
    mu: torch.Tensor,
    nu: torch.Tensor,
    epsilon: float,
) -> torch.Tensor:
    """Compute the gradient with respect to the cost matrix from the plan.

    The gradient is obtained in closed form from the transport plan, without
    back-propagating through the individual Sinkhorn iterations.
    NOTE(review): the expressions appear to follow the analytic gradient of
    the entropic optimal-transport plan; confirm against the reference
    derivation before modifying any term.

    Args:
        grad_output_Gamma (torch.Tensor): Gradient with respect to the plan,
            in the shape of `(B, N, M)`.
        Gamma (torch.Tensor): Transport plan in the shape of `(B, N, M)`.
        mu (torch.Tensor): Source vector in the shape of `(1, N, 1)`.
        nu (torch.Tensor): Target vector in the shape of `(1, 1, M)`.
        epsilon (float): Entropic-regularization parameter.

    Returns:
        torch.Tensor: Gradient with respect to the cost matrix, in the shape
        of `(B, N, M)`.
    """
    # Work with all but the last target column; the dropped column is
    # re-inserted as zeros via `F.pad` below.
    nu_ = nu[:, :, :-1]
    Gamma_ = Gamma[:, :, :-1]

    inv_mu = 1.0 / (mu.view([1, -1]))
    # Kappa = diag(nu_) - Gamma_^T diag(1 / mu) Gamma_
    Kappa = torch.diag_embed(nu_.squeeze(-2)) - torch.matmul(
        Gamma_.transpose(-1, -2) * inv_mu.unsqueeze(-2), Gamma_
    )
    inv_Kappa = torch.inverse(Kappa)

    Gamma_mu = inv_mu.unsqueeze(-1) * Gamma_
    L = Gamma_mu.matmul(inv_Kappa)
    G1 = grad_output_Gamma * Gamma

    # Terms driven by the row sums of G1.
    g1 = G1.sum(-1)
    G21 = (g1 * inv_mu).unsqueeze(-1) * Gamma
    g1_L = g1.unsqueeze(-2).matmul(L)
    G22 = g1_L.matmul(Gamma_mu.transpose(-1, -2)).transpose(-1, -2) * Gamma
    G23 = -F.pad(g1_L, pad=(0, 1), mode="constant", value=0) * Gamma
    G2 = G21 + G22 + G23
    # Release intermediates that are no longer needed.
    del g1, G21, G22, G23, Gamma_mu

    # Terms driven by the column sums of G1 (last column dropped).
    g2 = G1.sum(-2).unsqueeze(-1)
    g2 = g2[:, :-1, :]
    G31 = -L.matmul(g2) * Gamma
    G32 = (
        F.pad(
            inv_Kappa.matmul(g2).transpose(-1, -2),
            pad=(0, 1),
            mode="constant",
            value=0,
        )
        * Gamma
    )
    G3 = G31 + G32

    grad_C = (-G1 + G2 + G3) / epsilon
    return typing.cast(torch.Tensor, grad_C)


def _sinkhorn(
    C: torch.Tensor,
    mu: torch.Tensor,
    nu: torch.Tensor,
    epsilon: float = 0.1,
    max_iter: int = 200,
) -> torch.Tensor:
    """Returns optimal transport plan.

    Args:
        C (torch.Tensor): Cost matrix in the shape of `(B, N, M)`.
        mu (torch.Tensor): Source vector in the shape of `(N)`.
        nu (torch.Tensor): Target vector in the shape of `(M)`.
        epsilon (float): Entropic-regularization parameter.
        max_iter (int): Maximum number of iterations.

    Returns:
        torch.Tensor: Optimal transport plan in the shape of `(B, N, M)`.
    """
    # Reshape the marginals so that they broadcast against `(B, N, M)`.
    result: torch.Tensor = _Sinkhorn.apply(  # type:ignore[no-untyped-call]
        C, mu[None, :, None], nu[None, None, :], epsilon, max_iter
    )
    return result