"""Class and tools to read Visualization Toolkit (VTK_) data.
.. _VTK: https://vtk.org/
"""
# standard library packages
from glob import glob
from os.path import basename
from typing import Callable, Union, List, Dict
# third party packages
import torch as pt
from vtk import vtkUnstructuredGridReader, vtkXMLUnstructuredGridReader
from vtk.numpy_interface.dataset_adapter import WrapDataObject, UnstructuredGrid
# flowtorch packages
from flowtorch import DEFAULT_DTYPE
from .dataloader import Dataloader
from .utils import check_and_standardize_path, check_list_or_str
class VTKDataloader(Dataloader):
    """Load unstructured VTK files and time series.

    The loader assumes that snapshots are stored in individual VTK files.
    Currently, only unstructured mesh data are supported.

    Examples

    >>> from flowtorch import DATASETS
    >>> from flowtorch.data import VTKDataloader
    >>> path = DATASETS["vtk_cylinder_re200_flexi"]
    >>> loader = VTKDataloader.from_flexi(path, "Cylinder_Re200_Solution_")
    >>> loader.write_times
    ['0000000', '0000005', '0000300']
    >>> loader.field_names
    {'0000000': ['Density', 'MomentumX', 'MomentumY', 'MomentumZ']}
    >>> density = loader.load_snapshot("Density", loader.write_times)
    >>> density.shape
    torch.Size([729000, 3])

    >>> from flowtorch import DATASETS
    >>> from flowtorch.data import VTKDataloader
    >>> path = DATASETS["vtk_su2_airfoil_2D"]
    >>> loader = VTKDataloader.from_su2(path, "flow_")
    >>> p, U = loader.load_snapshot(["Pressure", "Velocity"], loader.write_times[0])
    >>> U.shape
    torch.Size([214403, 3])
    """

    def __init__(self, path: str,
                 vtk_reader: Union[vtkUnstructuredGridReader, vtkXMLUnstructuredGridReader],
                 prefix: str = "", suffix: str = "", dtype: str = DEFAULT_DTYPE):
        """Create a VTKDataloader instance from a folder of VTK files.

        The loader assumes that the write time is encoded in the file name
        as ``<prefix><time><suffix>``.

        :param path: path to folder containing VTK files
        :type path: str
        :param vtk_reader: unstructured VTK reader class for XML or legacy VTK
            format; the class (not an instance) is expected, since a new
            reader is instantiated for every file that is read
        :type vtk_reader: Union[vtkUnstructuredGridReader, vtkXMLUnstructuredGridReader]
        :param prefix: part of file name before time value, defaults to ""
        :type prefix: str, optional
        :param suffix: part of file name after time value, defaults to ""
        :type suffix: str, optional
        :param dtype: tensor type, defaults to DEFAULT_DTYPE
        :type dtype: str, optional
        """
        self._path = path
        self._vtk_reader = vtk_reader
        self._prefix = prefix
        self._suffix = suffix
        self._dtype = dtype
        # lazily populated on first access of write_times / field_names
        self._write_times = None
        self._field_names = None

    @classmethod
    def from_flexi(cls, path: str, prefix: str = "", suffix: str = ".000000000.vtu",
                   dtype: str = DEFAULT_DTYPE):
        """Create loader instance from VTK files generated by Flexi_.

        Flexi supports the output of field and surface data as unstructured
        XML-based VTK files.

        .. _Flexi: https://www.flexi-project.org/

        :param path: path to folder containing VTK files
        :type path: str
        :param prefix: part of file name before time value, defaults to ""
        :type prefix: str, optional
        :param suffix: part of file name after time value, defaults to ".000000000.vtu"
        :type suffix: str, optional
        :param dtype: tensor type, defaults to DEFAULT_DTYPE
        :type dtype: str, optional
        """
        return cls(path, vtkXMLUnstructuredGridReader, prefix, suffix, dtype)

    @classmethod
    def from_su2(cls, path: str, prefix: str = "", suffix: str = ".vtk",
                 dtype: str = DEFAULT_DTYPE):
        """Create loader instance from VTK files generated by SU2_.

        .. _SU2: https://su2code.github.io/

        :param path: path to folder containing VTK files
        :type path: str
        :param prefix: part of file name before time value, defaults to ""
        :type prefix: str, optional
        :param suffix: part of file name after time value, defaults to ".vtk"
        :type suffix: str, optional
        :param dtype: tensor type, defaults to DEFAULT_DTYPE
        :type dtype: str, optional
        """
        return cls(path, vtkUnstructuredGridReader, prefix, suffix, dtype)

    def _create_vtk_reader(self, file_path: str) -> UnstructuredGrid:
        """Read a VTK file and return its content wrapped for numpy access.

        :param file_path: location of the VTK file
        :type file_path: str
        :return: unstructured grid read from the file, wrapped by
            ``WrapDataObject``
        :rtype: UnstructuredGrid
        """
        reader = self._vtk_reader()
        reader.SetFileName(file_path)
        # only the legacy reader offers these switches; without them, it
        # loads just the first scalar/vector field of a file
        if hasattr(reader, "ReadAllVectorsOn"):
            reader.ReadAllVectorsOn()
        if hasattr(reader, "ReadAllScalarsOn"):
            reader.ReadAllScalarsOn()
        reader.Update()
        return WrapDataObject(reader.GetOutput())

    def _build_file_path(self, time: str) -> str:
        """Create the path of the VTK file belonging to a write time.

        :param time: snapshot write time (may also be a glob pattern like "*")
        :type time: str
        :return: VTK file location
        :rtype: str
        """
        return f"{self._path}/{self._prefix}{time}{self._suffix}"

    def _read_point_data(self, time: str):
        """Read the snapshot at ``time`` and return its point data arrays."""
        return self._create_vtk_reader(self._build_file_path(time)).PointData

    def _to_tensor(self, array) -> pt.Tensor:
        """Convert a VTK/numpy array into a tensor of the loader's dtype."""
        return pt.tensor(array, dtype=self._dtype)

    def load_snapshot(self,
                      field_name: Union[List[str], str],
                      time: Union[List[str], str]) -> Union[List[pt.Tensor], pt.Tensor]:
        """Load one or more point fields at one or more write times.

        :param field_name: name of a single point field or a list of names
        :type field_name: Union[List[str], str]
        :param time: single write time or a list of write times
        :type time: Union[List[str], str]
        :return: a single tensor if ``field_name`` is a string, otherwise a
            list with one tensor per field name (same order); if ``time`` is
            a list, the snapshots of each field are stacked along the last
            dimension
        :rtype: Union[List[pt.Tensor], pt.Tensor]
        """
        check_list_or_str(field_name, "field_name")
        check_list_or_str(time, "time")
        # load multiple fields
        if isinstance(field_name, list):
            if isinstance(time, list):
                # read every file exactly once and keep only the converted
                # tensors, so at most one VTK dataset is alive at a time
                columns = [[] for _ in field_name]
                for t in time:
                    point_data = self._read_point_data(t)
                    for column, name in zip(columns, field_name):
                        column.append(self._to_tensor(point_data[name]))
                return [pt.stack(column, dim=-1) for column in columns]
            point_data = self._read_point_data(time)
            return [self._to_tensor(point_data[name]) for name in field_name]
        # load single field
        if isinstance(time, list):
            return pt.stack(
                [self._to_tensor(self._read_point_data(t)[field_name]) for t in time],
                dim=-1
            )
        return self._to_tensor(self._read_point_data(time)[field_name])

    @property
    def write_times(self) -> List[str]:
        """Write times found in the data folder, sorted numerically.

        A write time is the part of a matching file name between prefix and
        suffix. Sorting uses ``float``, so every extracted time must be
        convertible to a number (``ValueError`` otherwise).
        """
        if self._write_times is None:
            files = glob(self._build_file_path("*"))
            start = len(self._prefix)
            n_suffix = len(self._suffix)
            # basename handles both "/" and the OS separator used by glob;
            # an explicit end index is needed because [start:-0] would be
            # empty when the suffix is ""
            times = [
                name[start:len(name) - n_suffix]
                for name in (basename(f) for f in files)
            ]
            self._write_times = sorted(times, key=float)
        return self._write_times

    @property
    def field_names(self) -> Dict[str, List[str]]:
        """Point field names, read from the first write time only.

        :return: dictionary with the first write time as the only key
        """
        if self._field_names is None:
            first_time = self.write_times[0]
            snapshot = self._create_vtk_reader(self._build_file_path(first_time))
            self._field_names = {first_time: snapshot.PointData.keys()}
        return self._field_names

    @property
    def vertices(self) -> pt.Tensor:
        """Point coordinates of the mesh, taken from the first write time."""
        snapshot = self._create_vtk_reader(
            self._build_file_path(self.write_times[0])
        )
        return pt.tensor(snapshot.Points, dtype=self._dtype)

    @property
    def weights(self) -> pt.Tensor:
        """Not implemented for VTK data yet.

        :raises NotImplementedError: always
        """
        raise NotImplementedError(
            "The weights property is not yet implemented.")