Source code for flowtorch.analysis.svd

"""Classes and functions wrapping around *torch.linalg.svd*.
"""

# third party packages
import torch as pt
# flowtorch packages
from flowtorch.data.utils import format_byte_size


[docs]class SVD(object): """Compute and analyze the SVD of a data matrix. :param U: left singular vectors :type U: pt.Tensor :param s: singular values :type s: pt.Tensor :param s_rel: singular values normalized with their sum in percent :type s_rel: pt.Tensor :param s_cum: cumulative normalized singular values in percent :type s_cum: pt.Tensor :param V: right singular values :type V: pt.Tensor :param rank: rank used for truncation :type rank: int :param opt_rank: optimal rank according to SVHT :type opt_rank: int :param required_memory: memory required to store the truncation U, s, and V :type required_memory: int Examples >>> from flowtorch import DATASETS >>> from flowtorch.data import FOAMDataloader >>> from flowtorch.analysis import SVD >>> loader = FOAMDataloader(DATASETS["of_cavity_ascii"]) >>> data = loader.load_snapshot("p", loader.write_times[1:]) >>> svd = SVD(data, rank=100) >>> print(svd) SVD of a 400x5 data matrix Selected/optimal rank: 5/2 data type: torch.float32 (4b) truncated SVD size: 7.9297Kb >>> svd.s_rel tensor([9.9969e+01, 3.0860e-02, 3.0581e-04, 7.8097e-05, 3.2241e-05]) >>> svd.s_cum tensor([ 99.9687, 99.9996, 99.9999, 100.0000, 100.0000]) >>> svd.U.shape torch.Size([400, 5]) """ def __init__(self, data_matrix: pt.Tensor, rank: int = None): shape = data_matrix.shape assert len(shape) == 2, ( f"The data matrix must be a 2D tensor.\ The provided data matrix has {len(shape)} dimensions." ) self._rows, self._cols = shape U, s, VH = pt.linalg.svd(data_matrix, full_matrices=False) self._opt_rank = self._optimal_rank(s) self.rank = self.opt_rank if rank is None else rank self._U = U[:, :self.rank] self._s = s[:self.rank] self._V = VH.conj().T[:, :self.rank] def _optimal_rank(self, s: pt.Tensor) -> int: """Compute the optimal singular value hard threshold. This function implements the svht_ rank estimation. .. _svht: https://doi.org/10.1109/TIT.2014.2323359 :param s: sorted singular values :type s: pt.Tensor :return: optimal rank for truncation :rtype: int """ beta = min(self._rows, self._cols) / max(self._rows, self._cols) omega = 0.56*beta**3 - 0.95*beta**2 + 1.82*beta + 1.43 tau_star = omega * pt.median(s) closest = pt.argmin((s - tau_star).abs()).item() if s[closest] > tau_star: return closest + 1 else: return closest
[docs] def reconstruct(self, rank: int = None) -> pt.Tensor: """Reconstruct the data matrix for a given rank. :param rank: rank used to compute a truncated reconstruction :type rank: int, optional :return: reconstruction of the input data matrix :rtype: pt.Tensor """ r_rank = self.rank if rank is None else max(min(rank, self.rank), 1) return self.U[:, :r_rank] @ pt.diag(self.s[:r_rank]) @ self.V[:, :r_rank].conj().T
@property def U(self) -> pt.Tensor: return self._U @property def s(self) -> pt.Tensor: return self._s @property def s_rel(self) -> pt.Tensor: return self._s / self._s.sum() * 100.0 @property def s_cum(self) -> pt.Tensor: s_sum = self._s.sum().item() return pt.tensor( [self._s[:i].sum().item() / s_sum * 100.0 for i in range(1, self._s.shape[0]+1)], dtype=self._s.dtype ) @property def V(self) -> pt.Tensor: return self._V @property def rank(self) -> int: return self._rank @rank.setter def rank(self, value: int): self._rank = max(min(self._cols, value), 1) self._cols = self._rank @property def opt_rank(self) -> int: return self._opt_rank @property def required_memory(self) -> int: """Compute the memory size in bytes of the truncated SVD. :return: cumulative size of truncated U, s, and V tensors in bytes :rtype: int """ return (self.U.element_size() * self.U.nelement() + self.s.element_size() * self.s.nelement() + self.V.element_size() * self.V.nelement()) def __repr__(self) -> str: return f"{self.__class__.__qualname__}(data_matrix, rank={self.rank})" def __str__(self) -> str: ms = [] ms.append(f"SVD of a {self._rows}x{self._cols} data matrix") ms.append(f"Selected/optimal rank: {self.rank}/{self.opt_rank}") ms.append(f"data type: {self.U.dtype} ({self.U.element_size()}b)") size, unit = format_byte_size(self.required_memory) ms.append("truncated SVD size: {:1.4f}{:s}".format(size, unit)) return "\n".join(ms)