r"""Direct access to TAU simulation data.

The DLR (Deutsches Zentrum für Luft- und Raumfahrt) TAU_ code saves
snapshots in the NetCDF format. The :class:`TAUDataloader` is a
wrapper around the NetCDF Python bindings to simplify the access
to snapshot data.

.. _TAU: https://www.dlr.de/as/desktopdefault.aspx/tabid-395/526_read-694/
"""
# standard library packages
from os.path import join, split
from glob import glob
from typing import List, Dict, Tuple, Union, Set
# third party packages
from netCDF4 import Dataset
import torch as pt
# flowtorch packages
from flowtorch import DEFAULT_DTYPE
from .dataloader import Dataloader
from .utils import check_list_or_str, check_and_standardize_path
# --- naming conventions of TAU solution and mesh files ---
# part of the solution file name between the output prefix and "i=..."
VOL_SOLUTION_NAME = ".pval.unsteady_"
# postfix appended to solution files of distributed cases, followed by
# the domain id (e.g. ".domain_0")
PSOLUTION_POSTFIX = ".domain_"
# name template of the per-domain mesh files; the placeholder is the domain id
PMESH_NAME = "domain_{:s}_grid_1"

# --- netCDF variable names in distributed (per-domain) mesh files ---
PVERTEX_KEY = "pcoord"
PWEIGHT_KEY = "pvolume"
PADD_POINTS_KEY = "addpoint_idx"
PGLOBAL_ID_KEY = "globalidx"

# --- netCDF variable names in the serial (primary) grid file ---
VERTEX_KEYS = ("points_xc", "points_yc", "points_zc")
WEIGHT_KEY = "volume"

# --- syntax of the TAU parameter file: "parameter : value  # comment" ---
COMMENT_CHAR = "#"
CONFIG_SEP = ":"

# --- keys of the configuration dictionary created by TAUConfig ---
SOLUTION_PREFIX_KEY = "solution_prefix"
GRID_FILE_KEY = "primary_grid"
GRID_PREFIX_KEY = "grid_prefix"
N_DOMAINS_KEY = "n_domains"
class TAUConfig(object):
    """Load and parse TAU parameter files.

    The class does not parse the full content of the parameter file
    but only content that is absolutely needed to load snapshot data.
    """

    def __init__(self, file_path: str):
        """Create a `TAUConfig` instance from the file path.

        The file is read once on construction; the required values are
        extracted lazily on first access of :attr:`config`.

        :param file_path: path to the parameter file
        :type file_path: str
        """
        self._path, self._file_name = split(file_path)
        with open(join(self._path, self._file_name), "r") as config:
            self._file_content = config.readlines()
        self._config = None

    def _parse_config(self, parameter: str) -> str:
        """Extract a value based on a given pattern.

        Every line of the parameter file follows the structure:

            parameter : value

        This function extracts the value as string and removes potential
        white spaces or comments (#). The first line whose non-comment
        part contains `parameter` is used; everything after the first
        separator on that line is returned as value.

        :param parameter: the parameter of which to extract the value
        :type parameter: str
        :return: extracted value or empty string
        :rtype: str
        """
        for line in self._file_content:
            # remove the comment first; otherwise a separator inside the
            # comment would be mistaken for the parameter/value separator
            content = line.split(COMMENT_CHAR)[0]
            if parameter in content:
                # split only once such that values containing the
                # separator (e.g. paths) are returned unmodified
                return content.split(CONFIG_SEP, 1)[-1].strip()
        return ""

    def _gather_config(self) -> None:
        """Gather all required configuration values.

        The values are stored in the dictionary `_config`.

        :raises ValueError: if the number of domains is missing or cannot
            be converted to an integer
        """
        config = {}
        config[SOLUTION_PREFIX_KEY] = self._parse_config("Output files prefix")
        config[GRID_FILE_KEY] = self._parse_config("Primary grid filename")
        config[GRID_PREFIX_KEY] = self._parse_config("Grid prefix")
        config[N_DOMAINS_KEY] = int(self._parse_config("Number of domains"))
        self._config = config

    @property
    def path(self) -> str:
        """Directory part of the parameter file path.

        :return: path to the folder containing the parameter file
        :rtype: str
        """
        return self._path

    @property
    def config(self) -> dict:
        """Configuration values required to locate mesh and solution files.

        The parameter file content is parsed on first access and the
        resulting dictionary is reused afterwards.

        :return: dictionary of configuration values
        :rtype: dict
        """
        if self._config is None:
            self._gather_config()
        return self._config
class TAUDataloader(Dataloader):
    """Load TAU simulation data.

    The loader is currently limited to read:

    - internal field solution, serial/reconstructed and distributed
    - mesh vertices, serial and distributed
    - cell volumes, serial (if present) and distributed

    Examples

    >>> from os.path import join
    >>> from flowtorch import DATASETS
    >>> from flowtorch.data import TAUDataloader
    >>> path = DATASETS["tau_backward_facing_step"]
    >>> loader = TAUDataloader(join(path, "simulation.para"))
    >>> times = loader.write_times
    >>> fields = loader.field_names[times[0]]
    >>> fields
    ['density', 'x_velocity', 'y_velocity', ...]
    >>> density = loader.load_snapshot("density", times)

    To load distributed simulation data, set `distributed=True`

    >>> path = DATASETS["tau_cylinder_2D"]
    >>> loader = TAUDataloader(join(path, "simulation.para"), distributed=True)
    >>> vertices = loader.vertices

    """

    def __init__(self, parameter_file: str, distributed: bool = False,
                 dtype: str = DEFAULT_DTYPE):
        """Create loader instance from TAU parameter file.

        :param parameter_file: path to TAU simulation parameter file
        :type parameter_file: str
        :param distributed: True if mesh and solution are distributed in domain
            files; defaults to False
        :type distributed: bool, optional
        :param dtype: tensor type, defaults to DEFAULT_DTYPE
        :type dtype: str, optional
        :raises FileNotFoundError: if no solution files are found
        """
        self._para = TAUConfig(parameter_file)
        self._distributed = distributed
        self._dtype = dtype
        self._time_iter = self._decompose_file_name()
        # mesh data is loaded lazily on first access of vertices/weights
        self._mesh_data = None

    def _decompose_file_name(self) -> Dict[str, str]:
        """Extract write time and iteration from file name.

        The solution files are searched with the same naming scheme that
        :func:`_file_name` uses to assemble file names. For distributed
        cases, only the files of domain 0 are inspected.

        :raises FileNotFoundError: if no solution files are found
        :return: dictionary with write times as keys and the corresponding
            iterations as values
        :rtype: Dict[str, str]
        """
        base = join(self._para.path, self._para.config[SOLUTION_PREFIX_KEY])
        suffix = f"{PSOLUTION_POSTFIX}0" if self._distributed else "e???"
        # the pattern must contain VOL_SOLUTION_NAME to be consistent with
        # the names created by _file_name
        files = glob(f"{base}{VOL_SOLUTION_NAME}i=*t=*{suffix}")
        if not files:
            raise FileNotFoundError(
                f"Could not find solution files in {self._para.path}/")
        time_iter = {}
        # for distributed files, the domain postfix follows the write time
        split_at = PSOLUTION_POSTFIX if self._distributed else " "
        for f in files:
            t = f.split("t=")[-1].split(split_at)[0]
            i = f.split("i=")[-1].split("_t=")[0]
            time_iter[t] = i
        return time_iter

    def _file_name(self, time: str, suffix: str = "") -> str:
        """Create solution file name from write time.

        :param time: snapshot write time
        :type time: str
        :param suffix: suffix to append to the file name; used for decomposed
            (distributed) solution files, defaults to ""
        :type suffix: str, optional
        :return: name of solution file
        :rtype: str
        """
        itr = self._time_iter[time]
        path = join(self._para.path, self._para.config[SOLUTION_PREFIX_KEY])
        return f"{path}{VOL_SOLUTION_NAME}i={itr}_t={time}{suffix}"

    def _load_domain_mesh_data(self, pid: str) -> pt.Tensor:
        """Load vertices and volumes for a single processor domain.

        The last entries of each domain file, whose number is given by
        the length of the `addpoint_idx` dataset, are discarded; the
        remaining entries are sorted by their global index.

        :param pid: domain id
        :type pid: str
        :return: tensor of size n_points x 4, where n_points is the number
            of unique cells in the domain, and the 4 columns contain the
            coordinates of the vertices (x, y, z) and the cell volumes
        :rtype: pt.Tensor
        """
        prefix = self._para.config[GRID_PREFIX_KEY]
        name = PMESH_NAME.format(pid)
        # "(none)" is the parameter file's marker for "no grid prefix"
        if prefix != "(none)":
            name = f"{prefix}_{name}"
        path = join(self._para.path, name)
        with Dataset(path) as data:
            vertices = pt.tensor(data[PVERTEX_KEY][:], dtype=self._dtype)
            volumes = pt.tensor(data[PWEIGHT_KEY][:], dtype=self._dtype)
            global_ids = pt.tensor(data[PGLOBAL_ID_KEY][:], dtype=pt.int64)
            n_add_points = data[PADD_POINTS_KEY].shape[0]
        # NOTE(review): the additional points are presumably copies of
        # points owned by neighboring domains - confirm against TAU docs
        n_points = volumes.shape[0] - n_add_points
        sorting = pt.argsort(global_ids[:n_points])
        mesh_data = pt.zeros((n_points, 4), dtype=self._dtype)
        mesh_data[:, :3] = vertices[:n_points, :3][sorting]
        mesh_data[:, 3] = volumes[:n_points][sorting]
        return mesh_data

    def _load_mesh_data(self):
        """Load mesh vertices and cell volumes.

        The mesh data is saved as class member `_mesh_data`. The tensor has the
        dimension n_points x 4; the first three columns correspond to the x/y/z
        coordinates, and the 4th column contains the volumes. If a serial
        grid file does not contain cell volumes, a warning is printed and
        all volumes are set to one.
        """
        if self._distributed:
            n = self._para.config[N_DOMAINS_KEY]
            self._mesh_data = pt.cat(
                [self._load_domain_mesh_data(str(pid)) for pid in range(n)],
                dim=0
            )
        else:
            path = join(self._para.path, self._para.config[GRID_FILE_KEY])
            with Dataset(path) as data:
                # one column per coordinate -> n_points x 3
                vertices = pt.stack(
                    [pt.tensor(data[key][:], dtype=self._dtype)
                     for key in VERTEX_KEYS],
                    dim=-1
                )
                if WEIGHT_KEY in data.variables.keys():
                    weights = pt.tensor(
                        data.variables[WEIGHT_KEY][:], dtype=self._dtype)
                else:
                    print(
                        f"Warning: could not find cell volumes in file {path}")
                    weights = pt.ones(vertices.shape[0], dtype=self._dtype)
            self._mesh_data = pt.cat((vertices, weights.unsqueeze(-1)), dim=-1)

    def _load_single_snapshot(self, field_name: str, time: str) -> pt.Tensor:
        """Load a single snapshot of a single field from the netCDF4 file(s).

        For distributed cases, the field is read from every domain file and
        the parts are concatenated in the order of the domain ids.

        :param field_name: name of the field
        :type field_name: str
        :param time: snapshot write time
        :type time: str
        :return: tensor holding the field values
        :rtype: pt.Tensor
        """
        if self._distributed:
            field = []
            for pid in range(self._para.config[N_DOMAINS_KEY]):
                path = self._file_name(time, f".domain_{pid}")
                with Dataset(path) as data:
                    field.append(
                        pt.tensor(
                            data.variables[field_name][:], dtype=self._dtype)
                    )
            return pt.cat(field, dim=0)
        path = self._file_name(time)
        with Dataset(path) as data:
            field = pt.tensor(
                data.variables[field_name][:], dtype=self._dtype)
        return field

    def load_snapshot(self, field_name: Union[List[str], str],
                      time: Union[List[str], str]
                      ) -> Union[List[pt.Tensor], pt.Tensor]:
        """Load one or more snapshots of one or more fields.

        :param field_name: name of the field or list of field names
        :type field_name: Union[List[str], str]
        :param time: write time or list of write times; if a list is
            given, the snapshots are stacked along the last dimension
        :type time: Union[List[str], str]
        :return: a single tensor if `field_name` is a string; a list with
            one tensor per field if `field_name` is a list
        :rtype: Union[List[pt.Tensor], pt.Tensor]
        """
        check_list_or_str(field_name, "field_name")
        check_list_or_str(time, "time")
        # load multiple fields
        if isinstance(field_name, list):
            if isinstance(time, list):
                return [
                    pt.stack([self._load_single_snapshot(field, t)
                              for t in time], dim=-1)
                    for field in field_name
                ]
            return [
                self._load_single_snapshot(field, time) for field in field_name
            ]
        # load single field
        if isinstance(time, list):
            return pt.stack(
                [self._load_single_snapshot(field_name, t) for t in time],
                dim=-1
            )
        return self._load_single_snapshot(field_name, time)

    @property
    def write_times(self) -> List[str]:
        """Available write times sorted by their numerical value.

        :return: list of write times
        :rtype: List[str]
        """
        return sorted(list(self._time_iter.keys()), key=float)

    @property
    def field_names(self) -> Dict[str, List[str]]:
        """Find available fields in solution files.

        Available fields are determined by matching the number of
        weights with the length of datasets in the available
        solution files; for distributed cases, the fields are only
        determined based on *domain_0*.

        :return: dictionary with time as key and list of
            available solution fields as value
        :rtype: Dict[str, List[str]]
        """
        self._field_names = {}
        if self._distributed:
            n_points = self._load_domain_mesh_data("0").shape[0]
            suffix = ".domain_0"
        else:
            n_points = self.vertices.shape[0]
            suffix = ""
        for time in self.write_times:
            self._field_names[time] = []
            with Dataset(self._file_name(time, suffix)) as data:
                for key in data.variables.keys():
                    shape = data[key].shape
                    # zero-dimensional variables have an empty shape
                    # and can never be a field
                    if len(shape) > 0 and shape[0] == n_points:
                        self._field_names[time].append(key)
        return self._field_names

    @property
    def vertices(self) -> pt.Tensor:
        """Coordinates of the mesh vertices.

        The mesh data is loaded on first access.

        :return: tensor of size n_points x 3 holding the x/y/z coordinates
        :rtype: pt.Tensor
        """
        if self._mesh_data is None:
            self._load_mesh_data()
        return self._mesh_data[:, :3]

    @property
    def weights(self) -> pt.Tensor:
        """Cell volumes associated with the vertices.

        The mesh data is loaded on first access.

        :return: tensor of length n_points holding the cell volumes
        :rtype: pt.Tensor
        """
        if self._mesh_data is None:
            self._load_mesh_data()
        return self._mesh_data[:, 3]