"""Optimized DMD via gradient descent and backpropagation.
"""
# standard library packages
from typing import Union, Set, Callable, Tuple
from collections import defaultdict
from math import sqrt
# third party packages
import torch as pt
from numpy import pi
# flowtorch packages
from .dmd import DMD
DEFAULT_SCHEDULER_OPT = {"mode": "min", "factor": 0.5, "patience": 20, "min_lr": 1.0e-6}
def _create_conj_complex_pairs(ev: pt.Tensor) -> Tuple[pt.Tensor, pt.Tensor]:
"""Create indices to preserve conjugate complex eigenvalues.
For real input data, eivenvalues and eigenvectors must come as
conjugate complex pairs. If there is on odd number of eigenvalues,
one eigenvalue is reserved to be on the real axis. If there are
multiple eigenvalues without imaginary part and the number of eigenvalues
is odd, only the one with real part closest to unity is preserved.
:param ev: tensor of eigenvalues
:type ev: pt.Tensor
:return: indices of eigenvalues to keep and indices of eigenvalues
with conjugate complex pair
:rtype: Tuple[pt.Tensor, pt.Tensor]
"""
sort_imag = ev.imag.sort(descending=True).indices
zero_imag = ev.imag == 0
n = ev.size(0)
indices = pt.tensor(range(n), dtype=pt.int64)
pairs = indices[: n // 2]
if zero_imag.sum() <= 1: # unique real ev
keep = indices[sort_imag][: n // 2 + n % 2]
else: # multiple real ev
if n % 2 == 1: # preserve ev with real part closest to unity
keep_real = (ev[zero_imag].real - 1.0).abs().sort().indices[0]
keep_real = indices[zero_imag][keep_real]
keep = pt.cat((indices[sort_imag][: n // 2], keep_real.unsqueeze(-1)))
else: # do not preserve any real ev
keep = indices[sort_imag][: n // 2]
return keep, pairs
[docs]
def l2_loss(
label: pt.Tensor, prediction: pt.Tensor, eigvecs: pt.Tensor, eigvals: pt.Tensor
) -> pt.Tensor:
"""Compute the L2 norm of the prediction error.
Note: in contrast to the default norm function in PyTorch,
the norm is normalized by the number of elements in the
prediction tensor.
Note: the argument list is more generic than necessary for this
loss function; the eigenvectors and eigenvalues may be used for
regularization or to enforce physical constraints in more advanced
loss functions.
:param label: ground truth
:type label: pt.Tensor
:param prediction: predictions
:type prediction: pt.Tensor
:return: L2 norm normalized by number of elements
:rtype: pt.Tensor
"""
return (label - prediction).norm() / sqrt(prediction.numel())
[docs]
class EarlyStopping:
"""Provide stopping control for iterative optimization tasks."""
def __init__(
self,
patience: int = 40,
min_delta: float = 0.0,
checkpoint: str = None,
model: pt.nn.Module = None,
):
"""Initialize a new controller instance.
:param patience: number of iterations to wait for an improved
loss value before stopping; defaults to 40
:type patience: int, optional
:param min_delta: minimum reduction in the loss values to be
considered an improvement; avoids overly long optimization
with marginal improvements per iteration; defaults to 0.0
:type min_delta: float, optional
:param checkpoint: path at which to store the best known state
of the model; the state is not saved if None; defaults to None
:type checkpoint: str, optional
:param model: instance of PyTorch model; the model's state dict is
saved upon improvement of the loss function is a valid checkpoint
is provided; defaults to None
:type model: pt.nn.Module, optional
"""
self._patience = patience
self._min_delta = min_delta
self._chp = checkpoint
self._model = model
self._best_loss = float("inf")
self._counter = 0
self._stop = False
def __call__(self, loss: float) -> bool:
"""_summary_
:param loss: new loss value
:type loss: float
:return: boolean flag indicating if the optimization can be stopped
:rtype: bool
"""
if loss < self._best_loss - self._min_delta:
self._best_loss = loss
self._counter = 0
if self._chp is not None and self._model is not None:
pt.save(self._model.state_dict(), self._chp)
else:
self._counter += 1
if self._counter >= self._patience:
self._stop = True
return self._stop
[docs]
class OptDMD(pt.nn.Module):
"""Optimized DMD based on backpropagation and gradient descent.
For a detailed description this DMD variant refer to
`Weiner and Semaan (2023) <https://arxiv.org/abs/2312.12928>`_.
Examples
>>> import torch as pt
>>> from flowtorch.analysis import OptDMD
>>> dm = pt.rand((200, 100))
>>> dmd = DMD(dm, dt=1.0)
>>> dmd.train(stopping_options={'patience' : 80, 'checkpoint' : '/tmp/best_model.pt'})
>>> dmd.load_state_dict(pt.load('/tmp/best_model.pt'))
>>> train_loss = dmd.log['train_loss']
>>> val_loss = dmd.log['val_loss']
"""
def __init__(self, *dmd_args, **dmd_kwargs):
"""Construct an optimized DMD instance.
The arguments and keyword arguments passed to the constructor are
used to create a regular `DMD` instance for initialization. The
arguments specific to the optimization of eigenvectors and eigenvalues
are passed to the `train` method.
Warning: the implementation was only tested rigorously for real input
data; eigenvectors and eigenvalues are enforced to have complex conjugate
pairs in case of real input data. If there is an odd number of modes,
the mode-eigenvalue-pair whose imaginary part
"""
super(OptDMD, self).__init__()
self._dmd = DMD(*dmd_args, **dmd_kwargs)
if self._dmd._complex:
n_modes = self._dmd.eigvals.size(0)
keep = pt.tensor(range(n_modes), dtype=pt.int64)
self._conj_indices = pt.zeros(n_modes, dtype=pt.bool)
else:
keep, pairs = _create_conj_complex_pairs(self._dmd.eigvals)
self._conj_indices = pairs
scaled_modes = self._dmd._modes[:, keep] * self._dmd.amplitude[keep]
self._eigvecs = pt.nn.Parameter(scaled_modes)
self._eigvals = pt.nn.Parameter(self._dmd._eigvals[keep].clone())
self._log = defaultdict(list)
def _create_train_val_split(
self, train_size: Union[int, float], val_size: Union[int, float]
) -> tuple:
"""Split the data matrix into training and validation data.
In contrast to general machine learning tasks, the data is not
split randomly. The validation data are taken from the end of
the snapshot sequence. The provided values for training and validation
size are weighted by their sum such that always a valid split results.
:param train_size: size of the training data set; can be the number
of the snapshots or a fraction of the overall number of snapshots
:type train_size: Union[int, float]
:param val_size: size of the validation data set; can be the number
of the snapshots or a fraction of the overall number of snapshots
:type val_size: Union[int, float]
:return: training and validation data wrapped as `TensorDataset`
:rtype: tuple
"""
data = pt.utils.data.TensorDataset(
pt.tensor(range(self._dmd._cols), dtype=pt.int64))
n_train = int(len(data) * train_size / (train_size + val_size))
n_val = len(data) - n_train
if n_val > 0:
return pt.utils.data.TensorDataset(
data[:n_train][0]
), pt.utils.data.TensorDataset(data[n_train:][0])
else:
return data, None
[docs]
def forward(self, time_indices: pt.Tensor) -> pt.Tensor:
evals = pt.cat((self._eigvals, self._eigvals[self._conj_indices].conj()))
evecs = pt.cat(
(self._eigvecs, self._eigvecs[:, self._conj_indices].conj()), dim=1
)
vander = pt.vstack([evals ** n.item() for n in time_indices[0]]).T
return evecs @ vander
[docs]
def train(
self,
epochs: int = 1000,
lr: float = 1.0e-3,
batch_size: int = None,
train_size: Union[int, float] = 0.75,
val_size: Union[int, float] = 0.25,
loss_function: Callable = l2_loss,
scheduler_options: dict = {},
stopping_options: dict = {},
loss_key: str = "val_loss"
):
"""Optimize modes and dynamics based on gradient descent.
:param epochs: number of training iterations, defaults to 1000
:type epochs: int, optional
:param lr: initial learning rate, defaults to 1.0e-3
:type lr: float, optional
:param batch_size: batch size for batch training, defaults to None
:type batch_size: int, optional
:param train_size: fraction or number of snapshots to use for
training; defaults to 0.75
:type train_size: Union[int, float], optional
:param val_size: fraction or number of snapshots to use for
validation; defaults to 0.25
:type val_size: Union[int, float], optional
:param loss_function: user-defined loss function, e.g., to add
sparsity promotion, defaults to l2_loss
:type loss_function: Callable, optional
:param scheduler_options: options passed to learning rate scheduler;
refer to PyTorch's `ReduceLROnPlateau` documentation; defaults to {}
:type scheduler_options: dict, optional
:param stopping_options: options to modify early stopping behavior;
refer to the `EarlyStopping` class; defaults to {}
:type stopping_options: dict, optional
:param loss_key: key of loss value based on which the learning rate schedule
and early stopping criteria are evaluated; can be 'train_loss',
'val_loss', or 'full_loss'; defaults to "val_loss"
:type loss_key: str, optional
"""
optimizer = pt.optim.AdamW(self.parameters(), lr=lr)
options = {
key: scheduler_options[key] if key in scheduler_options else val
for key, val in DEFAULT_SCHEDULER_OPT.items()
}
scheduler = pt.optim.lr_scheduler.ReduceLROnPlateau(
optimizer=optimizer, **options
)
train_set, val_set = self._create_train_val_split(train_size, val_size)
batch_size, shuffle = (
(batch_size, True) if batch_size else (len(train_set), False)
)
train_loader = pt.utils.data.DataLoader(
train_set, batch_size=batch_size, shuffle=shuffle
)
stopper = EarlyStopping(model=self, **stopping_options)
e, stop = 0, False
while e < epochs and not stop:
train_loss = 0.0
for batch in train_loader:
pred = self.forward(batch)
loss = loss_function(
self._dmd._dm[:, batch[0]], pred, self._eigvecs, self._eigvals
)
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss += loss.item()
self._log["train_loss"].append(train_loss / len(train_loader))
if val_set is not None:
with pt.no_grad():
batch = val_set[:]
self._log["val_loss"].append(
loss_function(
self._dmd._dm[:, batch[0]],
self.forward(batch),
self._eigvecs,
self._eigvals,
).item()
)
self._log["lr"].append(optimizer.param_groups[0]["lr"])
val_loss = self._log["val_loss"][-1] if val_set is not None else 0.0
self._log["full_loss"].append(val_loss + self._log["train_loss"][-1])
scheduler.step(self._log[loss_key][-1])
print(
"\rEpoch {:4d} - train loss: {:1.6e}, val loss: {:1.6e}, lr: {:1.6e}".format(
e, self._log["train_loss"][-1], val_loss, self._log["lr"][-1]
),
end="",
)
e += 1
stop = stopper(self._log[loss_key][-1])
[docs]
def partial_reconstruction(self, mode_indices: Set[int]) -> pt.Tensor:
modes = self.modes
mode_mask = pt.zeros(modes.shape[1], dtype=modes.dtype)
mode_indices = pt.tensor(list(mode_indices), dtype=pt.int64)
mode_mask[mode_indices] = 1.0
rec = (modes * mode_mask) @ self.dynamics
if not self._dmd._complex:
rec = rec.real
return rec.type(self._dmd._dm.dtype)
[docs]
def top_modes(
self,
n: int = 10,
integral: bool = False,
f_min: float = -float("inf"),
f_max: float = float("inf"),
) -> pt.Tensor:
importance = self.integral_contribution if integral else self.amplitude
modes_in_range = pt.logical_and(self.frequency >= f_min, self.frequency < f_max)
mode_indices = pt.tensor(range(modes_in_range.shape[0]), dtype=pt.int64)[
modes_in_range
]
n = min(n, mode_indices.shape[0])
top_n = importance[mode_indices].abs().topk(n).indices
return mode_indices[top_n]
[docs]
def predict(self, initial_condition: pt.Tensor, n_steps: int) -> pt.Tensor:
"""Predict evolution over N steps starting from used-defined initial conditions.
:param initial_condition: initial state vector
:type initial_condition: pt.Tensor
:param n_steps: number of steps to predict
:type n_steps: int
:return: predicted evolution including the initial state (N+1 states are returned)
:rtype: pt.Tensor
"""
modes = self.eigvecs / self.amplitude
b = pt.linalg.pinv(modes) @ initial_condition.type(modes.dtype)
prediction = modes @ pt.diag(b) @ pt.linalg.vander(self.eigvals, N=n_steps+1)
if not self._dmd._complex:
prediction = prediction.real
return prediction
@property
def dmd_init(self) -> DMD:
return self._dmd
@property
def log(self) -> dict:
return self._log
@property
def modes(self) -> pt.Tensor:
return self.eigvecs / self.amplitude
@property
def eigvals(self) -> pt.Tensor:
ev = self._eigvals.detach()
return pt.cat((ev, ev[self._conj_indices].conj()))
@property
def eigvals_cont(self) -> pt.Tensor:
return pt.log(self.eigvals) / self._dmd._dt
@property
def eigvecs(self) -> pt.Tensor:
ev = self._eigvecs.detach()
return pt.cat((ev, ev[:, self._conj_indices].conj()), dim=1)
@property
def frequency(self) -> pt.Tensor:
return pt.log(self.eigvals).imag / (2.0 * pi * self._dmd._dt)
@property
def growth_rate(self) -> pt.Tensor:
return (pt.log(self.eigvals) / self._dmd._dt).real
@property
def amplitude(self) -> pt.Tensor:
return self.eigvecs.norm(dim=0)
@property
def dynamics(self) -> pt.Tensor:
vander = pt.linalg.vander(self.eigvals, N=self._dmd._cols)
return pt.diag(self.amplitude.type(vander.dtype)) @ vander
@property
def integral_contribution(self) -> pt.Tensor:
return self.amplitude**2 * self.dynamics.abs().sum(dim=1)
@property
def reconstruction(self) -> pt.Tensor:
return (self.modes @ self.dynamics).real
@property
def reconstruction_error(self) -> pt.Tensor:
return self._dmd._dm - self.reconstruction