# -*- coding: utf-8 -*-
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Mapping, MutableMapping
from enum import Enum
from inspect import getmembers, isabstract, isclass
from importlib import import_module
import os
from pkgutil import iter_modules
from time import time
from types import UnionType
from typing import Any, Callable, KeysView, Literal, Sequence, Type, TypeVar, ValuesView
from typing_extensions import override, overload
import numpy as np
from h5py import ExternalLink, File, HardLink, Group, SoftLink, string_dtype
from numpy.typing import DTypeLike
import hermespy
__author__ = "Jan Adler"
__copyright__ = "Copyright 2024, Barkhausen Institut gGmbH"
__credits__ = ["Jan Adler"]
__license__ = "AGPLv3"
# Also written to the "version" attribute of HDF5 files by HDFSerializationProcess.New
__version__ = "1.5.0"
__maintainer__ = "Jan Adler"
__email__ = "jan.adler@barkhauseninstitut.org"
__status__ = "Prototype"
# Sentinel marking "no default value provided" for arguments where None is itself a valid default.
# NOTE(review): UNDEF is a plain string, so a caller-supplied string equal to "UNDEFINED"
# cannot be told apart from the sentinel by == comparison; prefer identity or type-guarded checks.
UNDEF_TYPE = Literal["UNDEFINED"]
"""Type of an undefined value representing optional arguments that include :py:obj:`None` as a default value."""
UNDEF: UNDEF_TYPE = "UNDEFINED"
"""Undefined value for optional arguments that include :py:obj:`None` as a default value."""
# Type variable for classmethods returning an instance of the Serializable subclass they are invoked on
SerializableType = TypeVar("SerializableType", bound="Serializable")
"""Type of Serializable Class."""
class Serializable:
    """Common base of every class the serialization factory can handle.

    The factory only serializes instances of classes derived from `Serializable`.

    NOTE(review): this class does not use ``ABCMeta`` (presumably to stay compatible with
    :class:`enum.Enum` subclasses), so ``@abstractmethod`` is not enforced on instantiation.
    """

    @abstractmethod
    def serialize(self, process: SerializationProcess) -> None:
        """Write this object's state into a serialization process.

        Objects are never serialized on their own; a :class:`Factory<hermespy.core.factory.Factory>`
        drives the serialization and calls this method.

        Args:
            process:
                Serialization stage generated by the :class:`Factory<hermespy.core.factory.Factory>`,
                offering backend-independent serialization methods.
        """
        ...  # pragma: no cover

    @classmethod
    @abstractmethod
    def Deserialize(
        cls: Type[SerializableType], process: DeserializationProcess
    ) -> SerializableType:
        """Reconstruct an object from a deserialization process.

        Objects are never deserialized on their own; a :class:`Factory<hermespy.core.factory.Factory>`
        drives the deserialization and calls this method.

        Args:
            process:
                Deserialization stage generated by the :class:`Factory<hermespy.core.factory.Factory>`,
                offering backend-independent deserialization methods.

        Returns:
            The reconstructed object.
        """
        ...  # pragma: no cover

    @classmethod
    def serialization_tag(cls: Type[SerializableType]) -> str:
        """Identifier of this class within serialized data.

        Returns: The fully qualified ``<module>.<class name>`` of the class.
        """
        module_name = cls.__module__
        class_name = cls.__name__
        return f"{module_name}.{class_name}"
# Type variable for SerializableEnum classmethods that return a member of the enum they are invoked on
SET = TypeVar("SET", bound="SerializableEnum")
"""Type of serializable enumeration."""
[docs]
class SerializableEnum(Serializable, Enum):
"""Base class for serializable enumerations."""
[docs]
@classmethod
def from_parameters(cls: Type[SET], enum: SET | int | str) -> SET:
"""Initialize enumeration from multiple parameters.
Args:
enum:
The parameter from which the enum should be initialized.
Returns: The initialized enumeration.
"""
if isinstance(enum, cls):
return enum
elif isinstance(enum, int):
return cls(enum)
elif isinstance(enum, str):
return cls[enum]
else:
raise ValueError("Unknown serializable enumeration type")
[docs]
@override
def serialize(self, process: SerializationProcess) -> None:
process.serialize_integer(self.value, "value")
[docs]
@override
@classmethod
def Deserialize(cls: Type[SET], process: DeserializationProcess) -> SET:
return cls(process.deserialize_integer("value"))
[docs]
class SerializationBackend(SerializableEnum):
    """Backend of the serialization process.

    Only :attr:`HDF` is accepted by :meth:`Factory.serialize` and :meth:`Factory.deserialize`;
    any other member makes them raise a :class:`RuntimeError`.
    """

    # Member values are persisted as integers by SerializableEnum.serialize,
    # so existing values must not be changed or reused.
    HDF = 0
    """HDF5 serialization backend."""
    MAT = 1
    """MATLAB serialization backend.
    Not implemented yet.
    """
    PICKLE = 2
    """Pickle serialization backend.
    Not implemented yet.
    """
class Factory:
    """Helper class to serialize Hermespy's runtime objects."""

    # Maps serialization tags to the classes registered under them
    __tag_registry: dict[str, Type[Serializable]]

    def __init__(self) -> None:
        """Scan the hermespy package for serializable classes and register them.

        Raises:
            RuntimeError: If two different classes share the same serialization tag.
        """
        self.__tag_registry = {}

        # Search the hermespy package path(s) as well as the parent of this module's directory
        lookup_paths = [
            *hermespy.__path__,
            os.path.abspath(os.path.join(os.path.dirname(__file__), "..")),
        ]

        # The third tuple element of iter_modules is the "is package" flag:
        # only subpackages are imported, and only their top-level namespace is scanned
        for _, module_name, is_package in iter_modules(lookup_paths, hermespy.__name__ + "."):
            if not is_package:
                continue  # pragma: no cover

            module = import_module(module_name)
            for _, candidate in getmembers(module):
                if not isclass(candidate) or not issubclass(candidate, Serializable):
                    continue

                # Only register non-abstract classes.
                # NOTE(review): isabstract only reports classes whose metaclass tracks abstract
                # methods (ABCMeta); Serializable itself does not, so plain subclasses are never
                # considered abstract here.
                if isabstract(candidate):
                    continue

                tag = candidate.serialization_tag()
                registered = self.__tag_registry.get(tag)
                if registered is None:
                    self.__tag_registry[tag] = candidate
                elif registered is not candidate:
                    # The same class may be re-exported by several subpackages,
                    # but a tag must never map to two different classes
                    raise RuntimeError(
                        f"Error attempting to register class '{candidate.__name__}' under the tag '{tag}': Tag is already registered with class '{registered.__name__}'"
                    )

    @property
    def registered_classes(self) -> ValuesView[Type[Serializable]]:
        """Classes registered for serialization within the factory."""
        return self.__tag_registry.values()

    @property
    def registered_tags(self) -> KeysView[str]:
        """Read registered serialization tags."""
        return self.__tag_registry.keys()

    @property
    def tag_registry(self) -> Mapping[str, Type[Serializable]]:
        """Mapping of registered serialization tags to their classes."""
        return self.__tag_registry

    def to_HDF(self, target: str | os.PathLike | Group, serializable: Serializable) -> None:
        """Dump a runtime object to a single HDF5 file or HDF5 file-like object.

        Args:
            target:
                Path to the HDF5 file, or an open HDF5 group to write into.
                Paths are opened in write mode (``"w"``), truncating existing files.
            serializable: Object to be serialized.

        Raises:
            RuntimeError: If a nested object is not a registered class or not serializable.
        """
        # Only files opened here are closed here; caller-provided groups stay open
        owns_file = isinstance(target, (str, os.PathLike))
        group = File(target, mode="w") if owns_file else target

        try:
            process = HDFSerializationProcess.New(self.tag_registry, group)
            serializable.serialize(process)
        finally:
            if owns_file:
                group.close()

    def from_HDF(self, target: str | os.PathLike | Group, type: Type[Serializable]) -> Any:
        """Load a runtime object from a single HDF5 file or HDF5 file-like object.

        Args:
            target: Path to the HDF5 file, or an open HDF5 group to read from.
            type: Class whose :meth:`Serializable.Deserialize` reconstructs the object.

        Returns:
            Serializable: The deserialized object.

        Raises:
            RuntimeError: If a nested object is not a registered class.
        """
        # Only files opened here are closed here; caller-provided groups stay open
        owns_file = isinstance(target, (str, os.PathLike))
        group = File(target, mode="r") if owns_file else target

        try:
            process = HDFDeserializationProcess.New(self.tag_registry, group)
            return type.Deserialize(process)
        finally:
            if owns_file:
                group.close()

    def serialize(
        self,
        target: str,
        campaign: str | None = None,
        backend: SerializationBackend = SerializationBackend.HDF,
    ) -> SerializationProcess:
        """Open a serialization process writing to a file.

        Args:
            target:
                Path to the target file. The file is opened in append mode (``"a"``).
            campaign:
                Name of the serialization campaign.
                Can be set to store multiple related serializations within the same file.
            backend:
                Serialization backend to be used.
                Defaults to :class:`SerializationBackend.HDF`.

        Returns:
            The serialization process. Call its ``finalize`` method once done to release the file.

        Raises:
            RuntimeError: If the backend is not supported.
        """
        if backend != SerializationBackend.HDF:
            raise RuntimeError(f"Unsupported serialization backend '{backend}'")

        file = File(target, mode="a")
        try:
            return HDFSerializationProcess.New(self.tag_registry, file, campaign)
        except BaseException:
            # Do not leak the file handle if the process could not be created
            file.close()
            raise

    def deserialize(
        self,
        target: str,
        campaign: str | None = None,
        backend: SerializationBackend = SerializationBackend.HDF,
    ) -> DeserializationProcess:
        """Open a deserialization process reading from a file.

        Args:
            target:
                Path to the target file. The file is opened in read mode (``"r"``).
            campaign:
                Name of the serialization campaign.
                Can be set to store multiple related serializations within the same file.
            backend:
                Serialization backend to be used.
                Defaults to :class:`SerializationBackend.HDF`.

        Returns:
            The deserialization process. Call its ``finalize`` method once done to release the file.

        Raises:
            RuntimeError: If the backend is not supported or the campaign does not exist.
        """
        if backend != SerializationBackend.HDF:
            raise RuntimeError(f"Unsupported serialization backend '{backend}'")

        file = File(target, mode="r")
        try:
            return HDFDeserializationProcess.New(self.tag_registry, file, campaign)
        except BaseException:
            # E.g. a missing campaign: close the file instead of leaking the handle
            file.close()
            raise
class ProcessBase(ABC):
    """Shared foundation of serialization and deserialization processes.

    Every process carries the tag registry mapping serialization tags to the classes
    that may be (de)serialized.
    """

    __tag_registry: Mapping[str, Type[Serializable]]

    def __init__(self, tag_registry: Mapping[str, Type[Serializable]]) -> None:
        # The registry is kept by reference, not copied
        self.__tag_registry = tag_registry

    @property
    def _tag_registry(self) -> Mapping[str, Type[Serializable]]:
        """Registry of serialization tags and their associated classes."""
        registry = self.__tag_registry
        return registry
# Type of a (de)serialized object, restricted to Serializable subclasses
_OT = TypeVar("_OT", bound=Serializable)
# Unconstrained type, used for default values and custom callback results
_RT = TypeVar("_RT")
class SerializationProcess(ProcessBase):
    """Base class for all serialization processes.

    Backends implement the abstract primitives (arrays, scalars, strings, objects and custom
    callbacks); helpers such as :meth:`serialize_object_mapping` and :meth:`serialize_range`
    are composed from these primitives.
    """

    @abstractmethod
    def serialize_array(self, array: np.ndarray, name: str, cache: bool = True) -> None:
        """Serialize a numpy array.

        .. warning::
            Arrays to be serialized should be immutable during the serialization process.
            This is to ensure that the array's memory address remains constant and can be used as a unique identifier.
            If this can't be guaranteed, caching for the respective array should be disabled.

        Args:
            array: The numpy array to be serialized.
            name: Name of the dataset.
            cache: Cache the array for future reference. Defaults to :py:obj:`True`.
        """
        ...  # pragma: no cover

    @abstractmethod
    def serialize_floating(self, value: float, name: str) -> None:
        """Serialize a floating point value.

        Args:
            value: The floating point value to be serialized.
            name: Name of the dataset.
        """
        ...  # pragma: no cover

    @abstractmethod
    def serialize_complex(self, value: complex, name: str) -> None:
        """Serialize a complex value.

        Args:
            value: The complex value to be serialized.
            name: Name of the dataset.
        """
        ...  # pragma: no cover

    @abstractmethod
    def serialize_integer(self, value: int, name: str) -> None:
        """Serialize an integer value.

        Args:
            value: The integer value to be serialized.
            name: Name of the dataset.
        """
        ...  # pragma: no cover

    @abstractmethod
    def serialize_string(self, value: str, name: str) -> None:
        """Serialize a string value.

        Args:
            value: The string value to be serialized.
            name: Name of the dataset.
        """
        ...  # pragma: no cover

    def serialize_object(self, obj: Serializable, name: str, root: bool = False) -> None:
        """Serialize an object.

        Args:
            obj: The object to be serialized.
            name: Name of the dataset.
            root: Serialize the object as a root object. Defaults to :py:obj:`False`.

        Raises:
            RuntimeError: If the object is not serializable or not a registered class.
        """
        if not isinstance(obj, Serializable):
            raise RuntimeError(f"Object '{obj}' is not a serializable class")

        tag = obj.serialization_tag()
        if tag not in self._tag_registry:
            raise RuntimeError(
                f"Object '{obj}' is not a registered class and therfore cannot be serialized"
            )

        self._serialize_object(obj, name, tag, root)

    @abstractmethod
    def _serialize_object(self, obj: Serializable, name: str, type: str, root: bool) -> None:
        """Serialize an already validated object.

        Args:
            obj: The object to be serialized.
            name: Name of the dataset.
            type: The object's serialization tag.
            root: Serialize the object as a root object.
        """
        ...  # pragma: no cover

    def serialize_object_sequence(
        self, objects: Sequence[Serializable], name: str, append: bool = False, root: bool = False
    ) -> None:
        """Serialize a sequence of objects.

        The base implementation does nothing; backends are expected to override it.

        Args:
            objects: The sequence of objects to be serialized.
            name: Name of the dataset.
            append: Append the objects to the dataset if it already exists. Defaults to :py:obj:`False`.
            root: Serialize each object as a root object. Defaults to :py:obj:`False`.
        """
        ...  # pragma: no cover

    def serialize_object_mapping(
        self,
        objects: Mapping[str, Serializable],
        name: str,
        append: bool = True,
        root: bool = False,
    ) -> None:
        """Serialize a mapping of string keys to objects.

        The mapping is stored as an integer ``{name}_count`` followed by, for each entry at
        position ``i``, the object ``{name}_{i:02d}`` and its key ``{name}_{i:02d}_key``.

        Args:
            objects: The mapping of objects to be serialized.
            name: Name of the dataset.
            append:
                Append the objects to the dataset if it already exists. Defaults to :py:obj:`True`.
                NOTE(review): currently not evaluated by this implementation.
            root: Serialize each object as a root object. Defaults to :py:obj:`False`.
        """
        self.serialize_integer(len(objects), f"{name}_count")

        for index, (key, obj) in enumerate(objects.items()):
            # Forward the root flag, which was previously accepted but silently ignored
            self.serialize_object(obj, f"{name}_{index:02d}", root)
            self.serialize_string(key, f"{name}_{index:02d}_key")

    def serialize_range(self, value: float | tuple[float, float] | None, name: str) -> None:
        """Serialize a scalar or an interval value.

        Scalars are stored as a single floating point value ``name``.
        Intervals ``(min, max)`` are stored as floating point values ``{name}_min`` and ``{name}_max``.
        :py:obj:`None` marks an undefined value, for which nothing is written.

        Args:
            value: Scalar (any real number except booleans), ``(min, max)`` tuple or :py:obj:`None`.
            name: Name of the dataset.

        Raises:
            ValueError: If the value is neither a real number, a two-element tuple nor :py:obj:`None`.
        """
        # None indicates an undefined value that should not be serialized
        if value is None:
            return

        # Serialize the range value
        if isinstance(value, tuple):
            if len(value) != 2:
                raise ValueError("Range tuples must contain exactly two values (min, max)")
            self.serialize_floating(float(value[0]), f"{name}_min")
            self.serialize_floating(float(value[1]), f"{name}_max")

        # Serialize the scalar value; integers are valid where floats are expected,
        # booleans are rejected as they are almost certainly a caller mistake
        elif isinstance(value, (int, float, np.integer, np.floating)) and not isinstance(
            value, bool
        ):
            self.serialize_floating(float(value), name)

        else:
            raise ValueError("Invalid range value type")

    @abstractmethod
    def serialize_custom(
        self, name: str, serializer: Callable[[SerializationProcess], None]
    ) -> None:
        """Serialize anything using a custom function.

        Args:
            name: Name of the dataset.
            serializer: Serialization function to be executed.
        """
        ...  # pragma: no cover

    def finalize(self) -> None:
        """Finalize the serialization process.

        After finalization, the serialization process is considered complete and no further serialization is possible.
        Closes any open file handles and releases any resources reserved during the serialization process.
        """
        pass  # pragma: no cover
class DeserializationProcess(ProcessBase):
    """Base class for all deserialization processes."""

    @staticmethod
    def _is_undefined(value: object) -> bool:
        """Check whether ``value`` is the :py:obj:`UNDEF` sentinel.

        Only strings are compared by value, so defaults with element-wise ``==`` semantics
        (e.g. numpy arrays) never raise ambiguous truth-value errors.
        """
        return isinstance(value, str) and value == UNDEF

    @staticmethod
    def _is_type_specification(candidate: object) -> bool:
        """Check whether ``candidate`` specifies expected types rather than a default value.

        Accepted are classes, ``X | Y`` unions and non-empty tuples/lists of classes.
        """
        if isinstance(candidate, (type, UnionType)):
            return True
        return (
            isinstance(candidate, (tuple, list))
            and len(candidate) > 0
            and all(isinstance(entry, type) for entry in candidate)
        )

    @abstractmethod
    def deserialize_array(
        self, name: str, dtype: DTypeLike, default: np.ndarray | UNDEF_TYPE = UNDEF
    ) -> np.ndarray:
        """Deserialize a numpy array.

        Args:
            name: Name of the dataset.
            dtype: Expected data type of the array.
            default: Default value to be returned if the value does not exist.

        Returns:
            numpy.ndarray: The deserialized numpy array.
        """
        ...  # pragma: no cover

    @abstractmethod
    def deserialize_floating(self, name: str, default: float | UNDEF_TYPE = UNDEF) -> float:
        """Deserialize a floating point value.

        Args:
            name: Name of the dataset.
            default: Default value to be returned if the value does not exist.

        Returns:
            float: The deserialized floating point value.
        """
        ...  # pragma: no cover

    @abstractmethod
    def deserialize_complex(self, name: str, default: complex | UNDEF_TYPE = UNDEF) -> complex:
        """Deserialize a complex value.

        Args:
            name: Name of the dataset.
            default: Default value to be returned if the value does not exist.

        Returns:
            complex: The deserialized complex value.
        """
        ...  # pragma: no cover

    @abstractmethod
    def deserialize_integer(self, name: str, default: int | UNDEF_TYPE = UNDEF) -> int:
        """Deserialize an integer value.

        Args:
            name: Name of the dataset.
            default: Default value to be returned if the value does not exist.

        Returns:
            int: The deserialized integer value.
        """
        ...  # pragma: no cover

    @abstractmethod
    def deserialize_string(self, name: str, default: str | UNDEF_TYPE = UNDEF) -> str:
        """Deserialize a string value.

        Args:
            name: Name of the dataset.
            default: Default value to be returned if the value does not exist.

        Returns:
            str: The deserialized string value.
        """
        ...  # pragma: no cover

    @abstractmethod
    def _deserialize_object(self, name: str, expected_type: Type[_OT], default: _RT) -> _OT | _RT:
        """Deserialize an object.

        Args:
            name: Name of the dataset.
            expected_type: Expected type (or tuple/union of types) of the deserialized object.
            default: Default value to be returned if the object does not exist.

        Returns: The deserialized object.
        """
        ...  # pragma: no cover

    @overload
    def deserialize_object(self, name: str, type: Type[_OT] | Sequence[Type[_OT]], /) -> _OT:
        """Safely deserialize an object.

        Args:
            name: Name of the dataset.
            type: Expected type(s) of the deserialized object.

        Returns: The deserialized object.

        Raises:
            RuntimeError: If the object is not a registered class.
            RuntimeError: If the deserialized object is not of the expected type.
            RuntimeError: If the dataset does not exist.
        """
        ...  # pragma: no cover

    @overload
    def deserialize_object(
        self, name: str, type: Type[_OT] | Sequence[Type[_OT]], default: _RT, /
    ) -> _OT | _RT:
        """Safely deserialize an object.

        Args:
            name: Name of the dataset.
            type: Expected type(s) of the deserialized object.
            default: Default value to be returned if the dataset does not exist.

        Returns: The deserialized object.

        Raises:
            RuntimeError: If the object is not a registered class.
            RuntimeError: If the deserialized object is not of the expected type.
        """
        ...  # pragma: no cover

    @overload
    def deserialize_object(self, name: str, /) -> Any:
        """Unsafely deserialize an object.

        Args:
            name: Name of the dataset.

        Returns: The deserialized object.

        Raises:
            RuntimeError: If the object is not a registered class.
            RuntimeError: If the dataset does not exist.
        """
        ...  # pragma: no cover

    def deserialize_object(self, name: str, /, *args) -> Any:
        r"""Deserialize an object.

        Args:
            name: Name of the dataset.
            \*args:
                Optional arguments for method overloading: either an expected type
                specification (a class, an ``X | Y`` union or a non-empty sequence of classes),
                optionally followed by a default value, or only a default value.

        Returns: The deserialized object.

        Raises:
            TypeError: If more than two optional arguments are provided.
            RuntimeError: If the object is not a registered class.
            RuntimeError: If the dataset does not exist and no default value was provided.
            RuntimeError: If the deserialized object is not of the expected type.
        """
        if len(args) > 2:
            raise TypeError(
                f"deserialize_object expects at most two optional arguments, got {len(args)}"
            )

        expected_type: Any = Serializable
        default: Any = UNDEF
        if len(args) > 0:
            if self._is_type_specification(args[0]):
                # Sequences of types become a tuple, which isinstance and issubclass accept
                expected_type = tuple(args[0]) if isinstance(args[0], (tuple, list)) else args[0]
                if len(args) > 1:
                    default = args[1]
            else:
                default = args[0]

        # Deserialize the object, using the sentinel object itself to detect absence
        deserialized_object = self._deserialize_object(name, expected_type, UNDEF)

        if deserialized_object is UNDEF:
            if self._is_undefined(default):
                raise RuntimeError(
                    f"Object '{name}' does not exist and no default value was provided"
                )
            return default

        # Check if the object is of the expected type
        if not isinstance(deserialized_object, expected_type):
            raise RuntimeError(
                f"Deserialized object '{name}' is not of the expected type ({deserialized_object.__class__} instead of {expected_type})"
            )

        # Check that the object is a registered class
        if deserialized_object.serialization_tag() not in self._tag_registry:
            raise RuntimeError(f"Object '{deserialized_object}' is not a registered class")

        return deserialized_object

    @abstractmethod
    def deserialize_object_sequence(
        self, name: str, type: Type[_OT], start: int = 0, stop: int | None = None
    ) -> Sequence[_OT]:
        """Deserialize a sequence of objects.

        Args:
            name: Name of the dataset.
            type: Expected type of the deserialized objects.
            start: Index of the first object in the sequence. Defaults to zero.
            stop: Index of the last object in the sequence. Defaults to :py:obj:`None`, i.e. the end of the sequence.

        Returns:
            Sequence[_OT]: The deserialized sequence of objects.

        Raises:
            IndexError: If the start index is out of bounds.
        """
        ...  # pragma: no cover

    @abstractmethod
    def sequence_length(self, name: str) -> int:
        """Deserialize the length of a sequence.

        Args:
            name: Name of the dataset.

        Returns: The length of the sequence. Zero if the dataset does not exist.
        """
        ...  # pragma: no cover

    def deserialize_range(
        self, name: str, default: _RT | UNDEF_TYPE = UNDEF
    ) -> float | tuple[float, float] | _RT:
        """Deserialize a range value.

        Reads either the scalar ``name`` or the interval ``{name}_min`` / ``{name}_max``.

        Args:
            name: Name of the dataset.
            default: Default value to be returned if the dataset does not exist.

        Returns: The deserialized range value.

        Raises:
            RuntimeError: If neither representation exists and no default value was provided.
        """
        # Attempt deserializing the scalar value
        scalar_value = self.deserialize_floating(name, None)
        if scalar_value is not None:
            return scalar_value

        # Attempt deserializing the range value
        range_min = self.deserialize_floating(f"{name}_min", None)
        range_max = self.deserialize_floating(f"{name}_max", None)
        if range_min is not None and range_max is not None:
            return (range_min, range_max)

        # Return the default value if the dataset does not exist
        if self._is_undefined(default):
            raise RuntimeError(f"Range value '{name}' does not exist")
        return default

    @abstractmethod
    def deserialize_custom(
        self, name: str, callback: Callable[[DeserializationProcess], _RT]
    ) -> _RT:
        """Deserialize anything using a custom callback function.

        Args:
            name: Name of the dataset.
            callback: Callback function to be executed.

        Returns:
            _RT: The result of the callback function.
        """
        ...  # pragma: no cover

    def finalize(self) -> None:
        """Finalize the deserialization process.

        After finalization, the deserialization process is considered complete and no further deserialization is possible.
        Closes any open file handles and releases any resources reserved during the deserialization process.
        """
        pass  # pragma: no cover
class HDFSerializationProcess(SerializationProcess):
    """A process representation for serializing objects to HDF5."""

    __base_group: Group
    __group: Group
    # Array cache: id(array) -> (array, absolute HDF5 path of its dataset).
    # The array is stored next to its path to keep it alive while cached: otherwise a temporary
    # array could be garbage-collected and a different array allocated at the same address
    # (hence with the same id) would silently be linked to the wrong dataset.
    __arrays: MutableMapping[int, tuple[np.ndarray, str]]
    # Object cache: id(obj) -> (obj, absolute HDF5 path of its group), kept alive for the same reason
    __objects: MutableMapping[int, tuple[Serializable, str]]
    # File handle to be closed by finalize, if this process owns one
    __file: File | None

    def __init__(
        self,
        tag_registry: Mapping[str, Type[Serializable]],
        base_group: Group,
        group: Group,
        arrays: MutableMapping[int, tuple[np.ndarray, str]],
        objects: MutableMapping[int, tuple[Serializable, str]],
        *,
        file: File | None = None,
    ) -> None:
        """
        Args:
            tag_registry: Mapping of serialization tags to registered classes.
            base_group: Root group of this serialization (file or campaign group).
            group: Group this process writes into.
            arrays: Array cache shared between the processes of one serialization.
            objects: Object cache shared between the processes of one serialization.
            file: File handle closed by :meth:`finalize`. Defaults to :py:obj:`None`.
        """
        # Init base class
        SerializationProcess.__init__(self, tag_registry)

        # Init class attributes
        self.__base_group = base_group
        self.__group = group
        self.__arrays = arrays
        self.__objects = objects
        self.__file = file

    @staticmethod
    def New(
        tag_registry: Mapping[str, Type[Serializable]],
        base_group: Group,
        campaign: str | None = None,
    ) -> HDFSerializationProcess:
        """Create a new root serialization process.

        Writes ``version`` and ``timestamp`` attributes to ``base_group``, creates the campaign
        group (if requested) and an ``arrays`` group within the resulting base group.

        Args:
            tag_registry: Mapping of serialization tags to registered classes.
            base_group: HDF5 file or group to serialize into.
            campaign: Name of the campaign group to serialize into. Defaults to :py:obj:`None`.

        Returns: The new serialization process.
        """
        # Remember the file handle before descending into the campaign group,
        # so that finalize closes the file even if a campaign is used
        file = base_group if isinstance(base_group, File) else None

        # Add meta information to the base group
        base_group.attrs.create("version", __version__, dtype=string_dtype(encoding="utf-8"))
        base_group.attrs.create("timestamp", time())

        # Create the campaign group
        if campaign is not None:
            base_group = HDFSerializationProcess._create_group(base_group, campaign)

        # Create the base group for caching arrays
        HDFSerializationProcess._create_group(base_group, "arrays")

        # Initialize the serialization process
        return HDFSerializationProcess(
            tag_registry, base_group, base_group, dict(), dict(), file=file
        )

    @override
    def serialize_array(self, array: np.ndarray, name: str, cache: bool = True) -> None:
        # Check if the dataset already exists
        # This might be the case for diamond inheritance
        if name in self.__group:
            return

        # The arrays are assumed to be immutable during the serialization process.
        # Therefore, the unique identifier is generated from the array's memory address;
        # the cache keeps the array alive so that the identifier cannot be recycled.
        uid: int = id(array)

        # If the array has already been serialized, just create a soft link to it
        if cache:
            cached = self.__arrays.get(uid)
            if cached is not None and cached[0] is array:
                self.__group[name] = SoftLink(cached[1])
                return

        # Write the array to the group and record its exact location to the array cache
        dataset = self.__group.create_dataset(name, array.shape, array.dtype, array)
        if cache:
            self.__arrays[uid] = (array, dataset.name)

    @override
    def serialize_floating(self, value: float, name: str) -> None:
        self.__group.attrs.create(name, value, dtype=np.float64)

    @override
    def serialize_complex(self, value: complex, name: str) -> None:
        self.__group.attrs.create(name, value, dtype=np.complex128)

    @override
    def serialize_integer(self, value: int, name: str) -> None:
        self.__group.attrs.create(name, value, dtype=np.int64)

    @override
    def serialize_string(self, value: str, name: str) -> None:
        self.__group.attrs.create(name, value, dtype=string_dtype(encoding="utf-8"))

    @override
    def _serialize_object(self, obj: Serializable, name: str, type: str, root: bool) -> None:
        # Check if the object's group already exists
        # This might be the case for diamond inheritance
        if name in self.__group:
            return

        # If a root object is being serialized, create a new process with cleared caches
        if root:
            new_process = HDFSerializationProcess(
                self._tag_registry, self.__base_group, self.__group, dict(), dict()
            )
            new_process._serialize_object(obj, name, type, False)
            return

        obj_key: int = id(obj)

        # If the object has already been serialized, just create a soft link to it
        cached = self.__objects.get(obj_key)
        if cached is not None and cached[0] is obj:
            self.__group[name] = SoftLink(cached[1])
            return

        # Create a new group for the object
        new_group = self._create_group(self.__group, name)

        # Save the object's path before serializing it, so that references from within the
        # object's own state (e.g. cycles) resolve to soft links; the object is kept alive
        self.__objects[obj_key] = (obj, new_group.name)

        # Create a new serialization process
        process = HDFSerializationProcess(
            self._tag_registry, self.__base_group, new_group, self.__arrays, self.__objects
        )

        # Serialize the object
        obj.serialize(process)

        # Write the object's tag to the group
        # Warning: This MUST be done after calling the serialize method for improved security
        new_group.attrs.create("type", type, dtype=string_dtype(encoding="utf-8"))

    @override
    def serialize_object_sequence(
        self,
        objects: Sequence[Serializable] | set[Serializable],
        name: str,
        append: bool = False,
        root: bool = False,
    ) -> None:
        if name not in self.__group:
            self.__group.create_group(name)
            offset = 0

        elif append:
            offset = len(self.__group[name])

        # The group already exists and is not appended to, e.g. because a diamond inheritance
        # hierarchy serializes the same sequence twice. Entries are written from index zero;
        # entries whose name already exists are skipped by _serialize_object, so the first write wins.
        else:
            offset = 0

        # Serialize each object in the sequence
        process = HDFSerializationProcess(
            self._tag_registry, self.__base_group, self.__group[name], self.__arrays, self.__objects
        )
        for index, obj in enumerate(objects, offset):
            process.serialize_object(obj, f"{index:04d}", root)

    @override
    def serialize_custom(
        self, name: str, serializer: Callable[[SerializationProcess], None]
    ) -> None:
        new_group = self._create_group(self.__group, name)
        process = HDFSerializationProcess(
            self._tag_registry, self.__base_group, new_group, self.__arrays, self.__objects
        )
        serializer(process)

    @override
    def finalize(self) -> None:
        # Close the file handle owned by this process (also when a campaign group is used)
        if self.__file is not None:
            self.__file.close()
        elif isinstance(self.__base_group, File):
            self.__base_group.close()

    @staticmethod
    def _create_group(group: Group, name: str) -> Group:
        """Return the subgroup ``name`` of ``group``, creating it if it does not exist."""
        if name not in group:
            return group.create_group(name)
        return group[name]
class HDFDeserializationProcess(DeserializationProcess):
    """A process representation for deserializing objects from HDF5."""

    __base_group: Group
    __group: Group
    # Cache of deserialized arrays, keyed by the absolute HDF5 path of their dataset
    __arrays: MutableMapping[str, np.ndarray]
    # Cache of deserialized objects, keyed by the absolute HDF5 path of their group
    __objects: MutableMapping[str, Serializable]
    # File handle to be closed by finalize, if this process owns one
    __file: File | None

    def __init__(
        self,
        tag_registry: Mapping[str, Type[Serializable]],
        base_group: Group,
        group: Group,
        objects: MutableMapping[str, Serializable],
        arrays: MutableMapping[str, np.ndarray],
        *,
        file: File | None = None,
    ) -> None:
        """
        Args:
            tag_registry: Mapping of serialization tags to registered classes.
            base_group: Group against which soft link paths are resolved.
            group: Group this process reads from.
            objects: Object cache shared between the processes of one deserialization.
            arrays: Array cache shared between the processes of one deserialization.
            file: File handle closed by :meth:`finalize`. Defaults to :py:obj:`None`.
        """
        # Init base class
        DeserializationProcess.__init__(self, tag_registry)

        # Init class attributes
        self.__base_group = base_group
        self.__group = group
        self.__objects = objects
        self.__arrays = arrays
        self.__file = file

    @staticmethod
    def New(
        tag_registry: Mapping[str, Type[Serializable]],
        base_group: Group,
        campaign: str | None = None,
    ) -> HDFDeserializationProcess:
        """Create a new root deserialization process.

        Args:
            tag_registry: Mapping of serialization tags to registered classes.
            base_group: HDF5 file or group to deserialize from.
            campaign: Name of the campaign group to deserialize from. Defaults to :py:obj:`None`.

        Returns: The new deserialization process.

        Raises:
            RuntimeError: If the campaign does not exist.
        """
        # Remember the file handle before descending into the campaign group,
        # so that finalize closes the file even if a campaign is used
        file = base_group if isinstance(base_group, File) else None

        if campaign is not None:
            if campaign not in base_group:
                raise RuntimeError(f"Campaign '{campaign}' does not exist in the HDF5 file")
            base_group = base_group[campaign]

        return HDFDeserializationProcess(
            tag_registry, base_group, base_group, dict(), dict(), file=file
        )

    @override
    def deserialize_array(
        self, name: str, dtype: DTypeLike, default: np.ndarray | UNDEF_TYPE = UNDEF
    ) -> np.ndarray:
        # Get the link from the group (getlink=True yields the link object instead of its target)
        link: SoftLink | HardLink | ExternalLink | None = self.__group.get(name, None, False, True)

        # If the array does not exist, attempt to return the default value.
        # The type guard avoids element-wise comparisons when the default is an array.
        if link is None:
            if isinstance(default, str) and default == UNDEF:
                raise RuntimeError(
                    f"Array '{name}' does not exist in the HDF5 file and no default value was provided"
                )
            return default  # type: ignore[return-value]

        # Resolve the dataset's absolute path without loading its data
        if isinstance(link, SoftLink):
            array_path = link.path
            array_dataset = None
        else:
            array_dataset = self.__group[name]
            array_path = array_dataset.name

        # If the array has already been deserialized, return a reference to it.
        # np.asarray returns the cached array itself if it already has the requested dtype.
        cached_array = self.__arrays.get(array_path)
        if cached_array is not None:
            return np.asarray(cached_array, dtype=dtype)

        if array_dataset is None:
            array_dataset = self.__base_group[array_path]

        deserialized_array = np.asarray(array_dataset, dtype=dtype)
        self.__arrays[array_dataset.name] = deserialized_array
        return deserialized_array

    def __fetch_attribute(self, name: str, default: _RT | UNDEF_TYPE) -> UNDEF_TYPE | Any:
        """Fetch the raw value of the attribute ``name`` of the current group.

        Returns:
            The stored value, or the :py:obj:`UNDEF` sentinel object itself if the attribute
            does not exist and a default value was provided.

        Raises:
            RuntimeError: If the attribute does not exist and no default value was provided.
        """
        value = self.__group.attrs.get(name, UNDEF)
        if value is UNDEF and isinstance(default, str) and default == UNDEF:
            raise RuntimeError(
                f"Attribute '{name}' does not exist in the HDF5 file and no default value was provided"
            )
        return value

    @override
    def deserialize_floating(self, name: str, default: _RT | UNDEF_TYPE = UNDEF) -> float | _RT:
        value = self.__fetch_attribute(name, default)
        if value is UNDEF:
            return default  # type: ignore[return-value]
        if isinstance(value, np.floating):
            return float(value)
        raise RuntimeError(f"Attribute '{name}' is not a floating point value")

    @override
    def deserialize_complex(self, name: str, default: _RT | UNDEF_TYPE = UNDEF) -> complex | _RT:
        value = self.__fetch_attribute(name, default)
        if value is UNDEF:
            return default  # type: ignore[return-value]
        if isinstance(value, np.complexfloating):
            return complex(value)
        raise RuntimeError(f"Attribute '{name}' is not a complex value")

    @override
    def deserialize_integer(self, name: str, default: _RT | UNDEF_TYPE = UNDEF) -> int | _RT:
        value = self.__fetch_attribute(name, default)
        if value is UNDEF:
            return default  # type: ignore[return-value]
        if isinstance(value, np.integer):
            return int(value)
        raise RuntimeError(f"Attribute '{name}' is not an integer value")

    @override
    def deserialize_string(self, name: str, default: _RT | UNDEF_TYPE = UNDEF) -> str | _RT:
        value = self.__fetch_attribute(name, default)
        return default if value is UNDEF else str(value)

    @override
    def _deserialize_object(self, name: str, expected_type: Type[_OT], default: _RT) -> _OT | _RT:
        # Get the object's link (getlink=True yields the link object instead of its target)
        link: SoftLink | HardLink | ExternalLink | None = self.__group.get(name, None, False, True)

        # Return the default value if the object does not exist
        if link is None:
            return default

        # A soft link indicates that the object is stored in a different group.
        # In both cases the absolute group path is the cache key, so an object reached via a
        # hard link and via soft links always yields the same instance.
        if isinstance(link, SoftLink):
            if link.path in self.__objects:
                return self.__objects[link.path]  # type: ignore[return-value]
            object_group = self.__base_group[link.path]
        else:
            object_group = self.__group[name]
            if object_group.name in self.__objects:
                return self.__objects[object_group.name]  # type: ignore[return-value]

        # Get the object's type
        object_type: str | None = object_group.attrs.get("type", default=None)
        if object_type is None:
            raise RuntimeError(
                f"Object '{name}' does not have a type attribute and therfore cannot be deserialized"
            )

        # Get the object's class (type); only registered classes are ever instantiated
        object_class: Type[Serializable] | None = self._tag_registry.get(str(object_type), None)

        # Raise an exception if the object is not a registered class
        if object_class is None:
            raise RuntimeError(
                f"Object '{str(object_type)}' is not a registered class and therfore cannot be deserialized"
            )

        # Raise an exception if the object is not of the expected type
        if not issubclass(object_class, expected_type):
            raise RuntimeError(
                f"Object '{name}' is not of the expected type ({object_class} instead of {expected_type})"
            )

        # Initialize the object
        process = HDFDeserializationProcess(
            self._tag_registry, self.__base_group, object_group, self.__objects, self.__arrays
        )
        deserialized_object = object_class.Deserialize(process)

        # Store the object in the internal cache
        self.__objects[object_group.name] = deserialized_object

        return deserialized_object

    @override
    def deserialize_object_sequence(
        self, name: str, type: Type[_OT], start: int = 0, stop: int | None = None
    ) -> Sequence[_OT]:
        # A missing sequence is treated as empty, unless an explicit stop index was requested
        if name not in self.__group:
            if stop is None:
                return []
            raise RuntimeError(f"Object sequence '{name}' does not exist in the HDF5 file")

        sequence_group = self.__group[name]
        sequence_process = HDFDeserializationProcess(
            self._tag_registry, self.__base_group, sequence_group, self.__objects, self.__arrays
        )

        # Entries are named by their zero-padded index
        count = len(sequence_group)
        end = count if stop is None else min(stop, count)

        objects: list[_OT] = []
        for index in range(start, end):
            deserialized_object = sequence_process.deserialize_object(f"{index:04d}", type, None)
            if deserialized_object is None:
                raise IndexError(f"Index {index} out of bounds in object sequence '{name}'")
            objects.append(deserialized_object)

        return objects

    @override
    def sequence_length(self, name: str) -> int:
        if name not in self.__group:
            return 0
        return len(self.__group[name])

    @override
    def deserialize_custom(
        self, name: str, callback: Callable[[DeserializationProcess], _RT] | None = None
    ) -> _RT:
        # NOTE(review): callback defaults to None, which fails when called; callers must pass one
        process = HDFDeserializationProcess(
            self._tag_registry, self.__base_group, self.__group[name], self.__objects, self.__arrays
        )
        return callback(process)

    @override
    def finalize(self) -> None:
        # Close the file handle owned by this process (also when a campaign group is used)
        if self.__file is not None:
            self.__file.close()
        elif isinstance(self.__base_group, File):
            self.__base_group.close()