Source code for hermespy.core.factory

# -*- coding: utf-8 -*-

from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Mapping, MutableMapping
from enum import Enum
from inspect import getmembers, isabstract, isclass
from importlib import import_module
import os
from pkgutil import iter_modules
from time import time
from types import UnionType
from typing import Any, Callable, KeysView, Literal, Sequence, Type, TypeVar, ValuesView
from typing_extensions import override, overload

import numpy as np
from h5py import ExternalLink, File, HardLink, Group, SoftLink, string_dtype
from numpy.typing import DTypeLike

import hermespy

__author__ = "Jan Adler"
__copyright__ = "Copyright 2024, Barkhausen Institut gGmbH"
__credits__ = ["Jan Adler"]
__license__ = "AGPLv3"
__version__ = "1.5.0"
__maintainer__ = "Jan Adler"
__email__ = "jan.adler@barkhauseninstitut.org"
__status__ = "Prototype"


# Literal type of the sentinel below. A dedicated sentinel is required because
# ``None`` is itself a meaningful default value for several deserialization methods.
UNDEF_TYPE = Literal["UNDEFINED"]
"""Type of an undefined value representing optional arguments that include :py:obj:`None` as a default value."""


# Sentinel marking an omitted optional argument. It is a plain string, so code in
# this module compares against it both by identity (``is``) and by equality (``==``).
UNDEF: UNDEF_TYPE = "UNDEFINED"
"""Undefined value for optional arguments that include :py:obj:`None` as a default value."""


# Type variable bound to Serializable, used to type classmethods returning instances of ``cls``.
SerializableType = TypeVar("SerializableType", bound="Serializable")
"""Type of Serializable Class."""


class Serializable(object):
    """Base class of every object the factory is able to (de)serialize.

    The :class:`Factory<hermespy.core.factory.Factory>` only considers classes
    deriving from `Serializable`. Subclasses write their state in
    :meth:`serialize` and restore it in :meth:`Deserialize`.
    """

    # Note: this class does not use ``ABCMeta``, so the ``@abstractmethod``
    # markers below document the interface but are not enforced when
    # instantiating `Serializable` itself.

    @classmethod
    def serialization_tag(cls: Type[SerializableType]) -> str:
        """Tag identifying the class during serialization.

        Returns:
            The class's module path and class name, joined by a dot.
        """
        module_path: str = cls.__module__
        class_name: str = cls.__name__
        return f"{module_path}.{class_name}"

    @abstractmethod
    def serialize(self, process: SerializationProcess) -> None:
        """Write this object's state.

        Objects are never serialized directly. Instead, a
        :class:`Factory<hermespy.core.factory.Factory>` carries out the
        serialization process and invokes this method.

        Args:

            process:
                Current stage of the serialization process.
                Generated by the :class:`Factory<hermespy.core.factory.Factory>`,
                it offers serialization methods supporting multiple backends.
        """
        ...  # pragma: no cover

    @classmethod
    @abstractmethod
    def Deserialize(
        cls: Type[SerializableType], process: DeserializationProcess
    ) -> SerializableType:
        """Restore an object from its serialized state.

        Objects are never deserialized directly. Instead, a
        :class:`Factory<hermespy.core.factory.Factory>` carries out the
        deserialization process and invokes this method.

        Args:

            process:
                Current stage of the deserialization process.
                Generated by the :class:`Factory<hermespy.core.factory.Factory>`,
                it offers deserialization methods supporting multiple backends.

        Returns: The deserialized object.
        """
        ...  # pragma: no cover
# Type variable bound to SerializableEnum, letting its classmethods return the
# concrete enumeration subclass they are invoked on.
SET = TypeVar("SET", bound="SerializableEnum")
"""Type of serializable enumeration."""
class SerializableEnum(Serializable, Enum):
    """Enumeration base class whose members can be (de)serialized.

    A member is stored as its ``value`` under the integer dataset ``"value"``.
    """

    @classmethod
    def from_parameters(cls: Type[SET], enum: SET | int | str) -> SET:
        """Resolve an enumeration member from one of several representations.

        Args:

            enum:
                Either a member of this enumeration (returned unchanged),
                an integer interpreted as a member value,
                or a string interpreted as a member name.

        Returns: The matching enumeration member.

        Raises:

            ValueError: If ``enum`` is of none of the supported types,
                or an integer that is no member value.
            KeyError: If ``enum`` is a string that names no member.
        """
        # Members of this very enumeration pass through untouched
        if isinstance(enum, cls):
            return enum

        # Integers are looked up by value
        if isinstance(enum, int):
            return cls(enum)

        # Strings are looked up by member name
        if isinstance(enum, str):
            return cls[enum]

        raise ValueError("Unknown serializable enumeration type")

    @override
    def serialize(self, process: SerializationProcess) -> None:
        process.serialize_integer(self.value, "value")

    @override
    @classmethod
    def Deserialize(cls: Type[SET], process: DeserializationProcess) -> SET:
        stored_value = process.deserialize_integer("value")
        return cls(stored_value)
class SerializationBackend(SerializableEnum):
    """Backend of the serialization process.

    Only :attr:`HDF` is currently handled by the factory's ``serialize`` and
    ``deserialize`` methods; the other backends are placeholders.
    """

    # Member values are persisted as integers, so they must remain stable.
    HDF = 0
    """HDF5 serialization backend."""

    MAT = 1
    """MATLAB serialization backend. Not implemented yet."""

    PICKLE = 2
    """Pickle serialization backend. Not implemented yet."""
class Factory:
    """Helper class to serialize Hermespy's runtime objects.

    On construction, the factory imports the packages found within the
    :mod:`hermespy` namespace and registers every non-abstract
    :class:`Serializable` class they expose, keyed by its
    :meth:`Serializable.serialization_tag`.
    """

    __tag_registry: dict[str, Type[Serializable]]

    def __init__(self) -> None:
        """
        Raises:

            RuntimeError: If two different classes share the same serialization tag.
        """
        self.__tag_registry = dict()

        # Directories scanned for hermespy sub-packages: the package path(s) of hermespy
        # plus the parent of the directory containing this file
        lookup_paths = list(hermespy.__path__) + [
            os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
        ]

        # The third field yielded by iter_modules is ``ispkg``: plain modules are skipped,
        # only packages are imported and the members they expose are inspected
        for _, name, is_package in iter_modules(lookup_paths, hermespy.__name__ + "."):
            if not is_package:
                continue  # pragma: no cover

            module = import_module(name)
            for _, serializable_class in getmembers(module):
                if not isclass(serializable_class) or not issubclass(
                    serializable_class, Serializable
                ):
                    continue

                # Only register non-abstract classes
                if isabstract(serializable_class):
                    continue

                # Generate a unique serialization UID
                tag = serializable_class.serialization_tag()

                # The same class may be exposed by several packages (e.g. when the lookup
                # paths overlap); a different class under an already used tag is an error
                registered_class = self.__tag_registry.get(tag)
                if registered_class is not None:
                    if registered_class != serializable_class:
                        raise RuntimeError(
                            f"Error attempting to register class '{serializable_class.__name__}' under the tag '{tag}': Tag is already registered with class '{registered_class.__name__}'"
                        )
                    continue

                self.__tag_registry[tag] = serializable_class

    @property
    def registered_classes(self) -> ValuesView[Type[Serializable]]:
        """Classes registered for serialization within the factory."""

        return self.__tag_registry.values()

    @property
    def registered_tags(self) -> KeysView[str]:
        """Serialization tags registered within the factory."""

        return self.__tag_registry.keys()

    @property
    def tag_registry(self) -> Mapping[str, Type[Serializable]]:
        """Mapping of registered serialization tags to their respective classes."""

        return self.__tag_registry

    def to_HDF(self, target: str | Group, serializable: Serializable) -> None:
        """Dump a runtime object to a single HDF5 file or HDF5 file-like object.

        Args:

            target:
                Path to the HDF5 file or HDF5 file-like object.
                Paths are opened in write mode, truncating existing files, and closed afterwards.

            serializable: Object to be serialized.

        Raises:

            RuntimeError: If a nested object is not a registered serializable class.
        """

        # Open HDF5 file (if required)
        _target = File(target, mode="w") if isinstance(target, str) else target

        try:
            # Generate a new serialization process
            process = HDFSerializationProcess.New(self.tag_registry, _target)

            # Serialize the object
            serializable.serialize(process)

        finally:
            # Only close handles opened by this method
            if isinstance(target, str):
                _target.close()

    def from_HDF(self, target: str | Group, type: Type[Serializable]) -> Any:
        """Load a runtime object from a single HDF5 file or HDF5 file-like object.

        Args:

            target:
                Path to the HDF5 file or HDF5 file-like object.
                Paths are opened in read mode and closed afterwards.

            type: Class whose :meth:`Serializable.Deserialize` restores the object.

        Returns: The deserialized object.

        Raises:

            RuntimeError: If a nested object is not a registered class.
        """

        # Open HDF5 file (if required)
        _target = File(target, mode="r") if isinstance(target, str) else target

        try:
            # Generate a new deserialization process
            process = HDFDeserializationProcess.New(self.tag_registry, _target)

            # Deserialize the object
            deserialization = type.Deserialize(process)

        finally:
            # Only close handles opened by this method
            if isinstance(target, str):
                _target.close()

        return deserialization

    def serialize(
        self,
        target: str,
        campaign: str | None = None,
        backend: SerializationBackend = SerializationBackend.HDF,
    ) -> SerializationProcess:
        """Open a serialization process writing to a file.

        Args:

            target: Path to the target file. Opened in append mode.

            campaign:
                Name of the serialization campaign.
                Can be set to store multiple related serializations within the same file.

            backend:
                Serialization backend to be used.
                Defaults to :class:`SerializationBackend.HDF`.

        Returns:
            The serialization process writing to ``target``.
            Call its ``finalize`` method once serialization is complete.

        Raises:

            RuntimeError: If ``backend`` is not supported.
        """

        if backend != SerializationBackend.HDF:
            raise RuntimeError(f"Unsupported serialization backend '{backend}'")

        file_handle = File(target, mode="a")
        try:
            return HDFSerializationProcess.New(self.tag_registry, file_handle, campaign)

        except BaseException:
            # The process could not be created, so nobody else will ever close the file
            file_handle.close()
            raise

    def deserialize(
        self,
        target: str,
        campaign: str | None = None,
        backend: SerializationBackend = SerializationBackend.HDF,
    ) -> DeserializationProcess:
        """Open a deserialization process reading from a file.

        Args:

            target: Path to the target file. Opened in read mode.

            campaign:
                Name of the serialization campaign.
                Can be set to read one of multiple related serializations within the same file.

            backend:
                Serialization backend to be used.
                Defaults to :class:`SerializationBackend.HDF`.

        Returns:
            The deserialization process reading from ``target``.
            Call its ``finalize`` method once deserialization is complete.

        Raises:

            RuntimeError: If ``backend`` is not supported.
            RuntimeError: If ``campaign`` does not exist within the file.
        """

        if backend != SerializationBackend.HDF:
            raise RuntimeError(f"Unsupported serialization backend '{backend}'")

        file_handle = File(target, mode="r")
        try:
            return HDFDeserializationProcess.New(self.tag_registry, file_handle, campaign)

        except BaseException:
            # E.g. a missing campaign: release the file before propagating the error
            file_handle.close()
            raise
class ProcessBase(ABC):
    """Shared foundation of serialization and deserialization processes.

    Every process carries the registry mapping serialization tags to the
    classes registered for (de)serialization.
    """

    __tag_registry: Mapping[str, Type[Serializable]]

    @property
    def _tag_registry(self) -> Mapping[str, Type[Serializable]]:
        """Registry of serialization tags and classes handed to the constructor."""

        return self.__tag_registry

    def __init__(self, tag_registry: Mapping[str, Type[Serializable]]) -> None:
        """
        Args:

            tag_registry: Mapping of serialization tags to registered classes. Stored without copying.
        """

        self.__tag_registry = tag_registry
# Type variable of (de)serialized objects, bound to Serializable.
_OT = TypeVar("_OT", bound=Serializable)
# Unconstrained type variable, used for default values and callback results.
_RT = TypeVar("_RT")
class SerializationProcess(ProcessBase):
    """Base class for all serialization processes.

    A process is handed to :meth:`Serializable.serialize` and offers
    backend-independent methods for writing values, arrays and nested objects.
    """

    @abstractmethod
    def serialize_array(self, array: np.ndarray, name: str, cache: bool = True) -> None:
        """Serialize a numpy array.

        .. warning::
            Arrays to be serialized should be immutable during the serialization process.
            This is to ensure that the array's memory address remains constant and can be used as a unique identifier.
            If this can't be guaranteed, caching for the respective array should be disabled.

        Args:

            array: The numpy array to be serialized.
            name: Name of the dataset.
            cache: Cache the array for future reference. Defaults to :py:obj:`True`.
        """
        ...  # pragma: no cover

    @abstractmethod
    def serialize_floating(self, value: float, name: str) -> None:
        """Serialize a floating point value.

        Args:

            value: The floating point value to be serialized.
            name: Name of the dataset.
        """
        ...  # pragma: no cover

    @abstractmethod
    def serialize_complex(self, value: complex, name: str) -> None:
        """Serialize a complex value.

        Args:

            value: The complex value to be serialized.
            name: Name of the dataset.
        """
        ...  # pragma: no cover

    @abstractmethod
    def serialize_integer(self, value: int, name: str) -> None:
        """Serialize an integer value.

        Args:

            value: The integer value to be serialized.
            name: Name of the dataset.
        """
        ...  # pragma: no cover

    @abstractmethod
    def serialize_string(self, value: str, name: str) -> None:
        """Serialize a string value.

        Args:

            value: The string value to be serialized.
            name: Name of the dataset.
        """
        ...  # pragma: no cover

    def serialize_object(self, obj: Serializable, name: str, root: bool = False) -> None:
        """Serialize an object.

        Args:

            obj: The object to be serialized.
            name: Name of the dataset.
            root: Serialize the object as a root object. Defaults to :py:obj:`False`.

        Raises:

            RuntimeError: If the object is not serializable or not a registered class.
        """

        if not isinstance(obj, Serializable):
            raise RuntimeError(f"Object '{obj}' is not a serializable class")

        tag = obj.serialization_tag()
        if tag not in self._tag_registry.keys():
            raise RuntimeError(
                f"Object '{obj}' is not a registered class and therfore cannot be serialized"
            )

        self._serialize_object(obj, name, tag, root)

    @abstractmethod
    def _serialize_object(self, obj: Serializable, name: str, type: str, root: bool) -> None:
        """Serialize an object.

        Args:

            obj: The object to be serialized.
            name: Name of the dataset.
            type: The object's serialization tag.
            root: Serialize the object as a root object.
        """
        ...  # pragma: no cover

    def serialize_object_sequence(
        self, objects: Sequence[Serializable], name: str, append: bool = False, root: bool = False
    ) -> None:
        """Serialize a sequence of objects.

        This base implementation does nothing; backends are expected to override it.

        Args:

            objects: The sequence of objects to be serialized.
            name: Name of the dataset.
            append: Append the objects to the dataset if it already exists. Defaults to :py:obj:`False`.
            root: Serialize each object as a root object. Defaults to :py:obj:`False`.
        """
        ...  # pragma: no cover

    def serialize_object_mapping(
        self,
        objects: Mapping[str, Serializable],
        name: str,
        append: bool = True,
        root: bool = False,
    ) -> None:
        """Serialize a mapping of string keys to objects.

        The number of entries is stored as ``{name}_count``, each object as
        ``{name}_{index:02d}`` and its key as ``{name}_{index:02d}_key``.

        Args:

            objects: The mapping of objects to be serialized.
            name: Name of the dataset.
            append:
                Append the objects to the dataset if it already exists. Defaults to :py:obj:`True`.
                Currently not evaluated by this generic implementation.
            root: Serialize each object as a root object. Defaults to :py:obj:`False`.
        """

        # Serialize the number of objects
        self.serialize_integer(len(objects), f"{name}_count")

        # Serialize each object alongside its key
        for index, (key, obj) in enumerate(objects.items()):
            self.serialize_object(obj, f"{name}_{index:02d}", root)
            self.serialize_string(key, f"{name}_{index:02d}_key")

    def serialize_range(self, value: float | tuple[float, float] | None, name: str) -> None:
        """Serialize a scalar value or a ``(min, max)`` interval.

        Scalars are stored as a single floating point value named ``name``,
        intervals as two floating point values named ``{name}_min`` and ``{name}_max``.
        :py:obj:`None` represents an undefined value and is not serialized at all.

        Args:

            value: Real scalar, pair of real numbers, or :py:obj:`None`.
            name: Name of the dataset.

        Raises:

            ValueError: If ``value`` is a tuple that is not a pair, or of an unsupported type.
        """

        # None indicates an undefined value that should not be serialized
        if value is None:
            return

        # Serialize the range value
        if isinstance(value, tuple):
            if len(value) != 2:
                raise ValueError(
                    f"Range value must be a (min, max) pair, got {len(value)} elements"
                )
            self.serialize_floating(float(value[0]), f"{name}_min")
            self.serialize_floating(float(value[1]), f"{name}_max")

        # Serialize the scalar value; booleans are deliberately rejected
        elif isinstance(value, (int, float, np.integer, np.floating)) and not isinstance(
            value, bool
        ):
            self.serialize_floating(float(value), name)

        else:
            raise ValueError("Invalid range value type")

    @abstractmethod
    def serialize_custom(
        self, name: str, serializer: Callable[[SerializationProcess], None]
    ) -> None:
        """Serialize anything using a custom function.

        Args:

            name: Name of the dataset.
            serializer: Serialization function to be executed.
        """
        ...  # pragma: no cover

    def finalize(self) -> None:
        """Finalize the serialization process.

        After finalization, the serialization process is considered complete and no further serialization is possible.
        Closes any open file handles and releases any resources reserved during the serialization process.
        This base implementation does nothing.
        """
        pass  # pragma: no cover
class DeserializationProcess(ProcessBase):
    """Base class for all deserialization processes.

    A process is handed to :meth:`Serializable.Deserialize` and offers
    backend-independent methods for reading values, arrays and nested objects.
    """

    @abstractmethod
    def deserialize_array(
        self, name: str, dtype: DTypeLike, default: np.ndarray | UNDEF_TYPE = UNDEF
    ) -> np.ndarray:
        """Deserialize a numpy array.

        Args:

            name: Name of the dataset.
            dtype: Expected data type of the array.
            default: Default value to be returned if the value does not exist.

        Returns: The deserialized numpy array.
        """
        ...  # pragma: no cover

    @abstractmethod
    def deserialize_floating(self, name: str, default: float | UNDEF_TYPE = UNDEF) -> float:
        """Deserialize a floating point value.

        Args:

            name: Name of the dataset.
            default: Default value to be returned if the value does not exist.

        Returns: The deserialized floating point value.
        """
        ...  # pragma: no cover

    @abstractmethod
    def deserialize_complex(self, name: str, default: complex | UNDEF_TYPE = UNDEF) -> complex:
        """Deserialize a complex value.

        Args:

            name: Name of the dataset.
            default: Default value to be returned if the value does not exist.

        Returns: The deserialized complex value.
        """
        ...  # pragma: no cover

    @abstractmethod
    def deserialize_integer(self, name: str, default: int | UNDEF_TYPE = UNDEF) -> int:
        """Deserialize an integer value.

        Args:

            name: Name of the dataset.
            default: Default value to be returned if the value does not exist.

        Returns: The deserialized integer value.
        """
        ...  # pragma: no cover

    @abstractmethod
    def deserialize_string(self, name: str, default: str | UNDEF_TYPE = UNDEF) -> str:
        """Deserialize a string value.

        Args:

            name: Name of the dataset.
            default: Default value to be returned if the value does not exist.

        Returns: The deserialized string value.
        """
        ...  # pragma: no cover

    @abstractmethod
    def _deserialize_object(self, name: str, expected_type: Type[_OT], default: _RT) -> _OT | _RT:
        """Deserialize an object.

        Args:

            name: Name of the dataset.
            expected_type: Expected type (or tuple / union of types) of the deserialized object.
            default: Default value to be returned if the object does not exist.

        Returns: The deserialized object.
        """
        ...  # pragma: no cover

    @overload
    def deserialize_object(self, name: str, type: Type[_OT] | Sequence[Type[_OT]], /) -> _OT:
        """Safely deserialize an object.

        Args:

            name: Name of the dataset.
            type: Expected type of the deserialized object.

        Returns: The deserialized object.

        Raises:

            RuntimeError: If the object is not a registered class.
            RuntimeError: If the deserialized object is not of the expected type.
            RuntimeError: If the dataset does not exist.
        """
        ...  # pragma: no cover

    @overload
    def deserialize_object(
        self, name: str, type: Type[_OT] | Sequence[Type[_OT]], default: _RT, /
    ) -> _OT | _RT:
        """Safely deserialize an object.

        Args:

            name: Name of the dataset.
            type: Expected type of the deserialized object.
            default: Default value to be returned if the dataset does not exist.

        Returns: The deserialized object.

        Raises:

            RuntimeError: If the object is not a registered class.
            RuntimeError: If the deserialized object is not of the expected type.
        """
        ...  # pragma: no cover

    @overload
    def deserialize_object(self, name: str, /) -> Any:
        """Unsafely deserialize an object.

        Args:

            name: Name of the dataset.

        Returns: The deserialized object.

        Raises:

            RuntimeError: If the object is not a registered class.
            RuntimeError: If the dataset does not exist.
        """
        ...  # pragma: no cover

    def deserialize_object(self, name: str, /, *args) -> Any:
        r"""Deserialize an object.

        Args:

            name: Name of the dataset.

            \*args:
                Optional arguments for method overloading:
                an expected type (a class, a union of classes, or a sequence of classes),
                optionally followed by a default value; or only a default value.

        Returns: The deserialized object.

        Raises:

            RuntimeError: If the object is not a registered class.
            RuntimeError: If the dataset does not exist and no default value was provided.
            RuntimeError: If the deserialized object is not of the expected type.
        """

        default: Any
        expected_type: Any
        if len(args) < 1:
            default = UNDEF
            expected_type = Serializable

        elif self.__is_type_specification(args[0]):
            default = args[1] if len(args) > 1 else UNDEF
            # isinstance / issubclass accept tuples, but not lists, of types
            expected_type = tuple(args[0]) if isinstance(args[0], (list, tuple)) else args[0]

        else:
            default = args[0]
            expected_type = Serializable

        # Deserialize the object
        deserialized_object: Any = self._deserialize_object(name, expected_type, UNDEF)

        if deserialized_object is UNDEF:
            # The default may be an arbitrary object (e.g. an array whose == is elementwise),
            # so it is only compared against the sentinel if it is a string
            if isinstance(default, str) and default == UNDEF:
                raise RuntimeError(
                    f"Object '{name}' does not exist and no default value was provided"
                )
            return default

        # Check if the object is of the expected type
        if not isinstance(deserialized_object, expected_type):
            raise RuntimeError(
                f"Deserialized object '{name}' is not of the expected type ({deserialized_object.__class__} instead of {expected_type})"
            )

        # Check that the object is a registered class
        if deserialized_object.serialization_tag() not in self._tag_registry.keys():
            raise RuntimeError(f"Object '{deserialized_object}' is not a registered class")

        return deserialized_object

    @staticmethod
    def __is_type_specification(candidate: Any) -> bool:
        """Check whether an overload argument specifies expected types rather than a default value.

        Args:

            candidate: First optional argument passed to :meth:`deserialize_object`.

        Returns:
            :py:obj:`True` for classes, type unions and non-empty sequences of classes.
        """

        if isinstance(candidate, (type, UnionType)):
            return True

        # Sequences of types are permitted by the overload signatures
        return (
            isinstance(candidate, (list, tuple))
            and len(candidate) > 0
            and all(isinstance(t, type) for t in candidate)
        )

    @abstractmethod
    def deserialize_object_sequence(
        self, name: str, type: Type[_OT], start: int = 0, stop: int | None = None
    ) -> Sequence[_OT]:
        """Deserialize a sequence of objects.

        Args:

            name: Name of the dataset.
            type: Expected type of the deserialized objects.
            start: Index of the first object in the sequence. Defaults to zero.
            stop: Index of the last object in the sequence. Defaults to :py:obj:`None`, i.e. the end of the sequence.

        Returns: The deserialized sequence of objects.

        Raises:

            IndexError: If the start index is out of bounds.
        """
        ...  # pragma: no cover

    @abstractmethod
    def sequence_length(self, name: str) -> int:
        """Deserialize the length of a sequence.

        Args:

            name: Name of the dataset.

        Returns: The length of the sequence. Zero if the dataset does not exist.
        """
        ...  # pragma: no cover

    def deserialize_range(
        self, name: str, default: _RT | UNDEF_TYPE = UNDEF
    ) -> float | tuple[float, float] | _RT:
        """Deserialize a scalar value or a ``(min, max)`` interval.

        A scalar stored as ``name`` takes precedence over an interval stored as
        ``{name}_min`` and ``{name}_max``.

        Args:

            name: Name of the dataset.
            default: Default value to be returned if the dataset does not exist.

        Returns: The deserialized range value.

        Raises:

            RuntimeError: If neither representation exists and no default value was provided.
        """

        # Attempt deserializing the scalar value
        scalar_value = self.deserialize_floating(name, None)
        if scalar_value is not None:
            return scalar_value

        # Attempt deserializing the range value
        range_min = self.deserialize_floating(f"{name}_min", None)
        range_max = self.deserialize_floating(f"{name}_max", None)
        if range_min is not None and range_max is not None:
            return (range_min, range_max)

        # Only compare string defaults against the sentinel, others may not support == safely
        if isinstance(default, str) and default == UNDEF:
            raise RuntimeError(f"Range value '{name}' does not exist")

        return default

    @abstractmethod
    def deserialize_custom(
        self, name: str, callback: Callable[[DeserializationProcess], _RT]
    ) -> _RT:
        """Deserialize anything using a custom callback function.

        Args:

            name: Name of the dataset.
            callback: Callback function to be executed.

        Returns: The result of the callback function.
        """
        ...  # pragma: no cover

    def finalize(self) -> None:
        """Finalize the deserialization process.

        After finalization, the deserialization process is considered complete and no further deserialization is possible.
        Closes any open file handles and releases any resources reserved during the deserialization process.
        This base implementation does nothing.
        """
        pass  # pragma: no cover
class HDFSerializationProcess(SerializationProcess):
    """A process representation for serializing objects to HDF5.

    During a serialization run, arrays and objects are cached by their :func:`id`.
    Serializing the same array or object a second time creates an HDF5 soft link
    to the first copy instead of writing it again.
    """

    __base_group: Group
    __group: Group
    __arrays: MutableMapping[int, str]  # id of cached arrays -> HDF5 path of their dataset
    __objects: MutableMapping[int, str]  # id of serialized objects -> HDF5 path of their group
    __references: list[object]  # keeps every cached array and object alive
    __file: File | None  # file handle to be closed on finalization

    def __init__(
        self,
        tag_registry: Mapping[str, Type[Serializable]],
        base_group: Group,
        group: Group,
        arrays: MutableMapping[int, str],
        objects: MutableMapping[int, str],
        references: list[object] | None = None,
        file: File | None = None,
    ) -> None:
        """
        Args:

            tag_registry: Mapping of serialization tags to registered classes.
            base_group: Root group of the serialization run.
            group: Group this process writes into.
            arrays: Array cache shared by all processes of a serialization run.
            objects: Object cache shared by all processes of a serialization run.

            references:
                List shared by all processes of a serialization run, holding every cached array and object.
                Both caches are keyed by :func:`id`, which Python may reuse once an object is garbage
                collected; keeping the referents alive prevents a different, later object from hitting a
                stale cache entry and being soft-linked to the wrong data. A new list is created if omitted.

            file:
                File handle closed by :meth:`finalize`.
                If omitted, :meth:`finalize` only closes ``base_group`` if it is a file itself.
        """

        # Init base class
        SerializationProcess.__init__(self, tag_registry)

        # Init class attributes
        self.__base_group = base_group
        self.__group = group
        self.__arrays = arrays
        self.__objects = objects
        self.__references = [] if references is None else references
        self.__file = file

    @staticmethod
    def New(
        tag_registry: Mapping[str, Type[Serializable]],
        base_group: Group,
        campaign: str | None = None,
    ) -> HDFSerializationProcess:
        """Create a new serialization process writing into ``base_group``.

        Args:

            tag_registry: Mapping of serialization tags to registered classes.
            base_group: HDF5 file or group to write into.
            campaign: Name of a sub-group the serialization is written into. Created if missing.

        Returns: The new serialization process.
        """

        # Remember the file handle before base_group is possibly replaced by the campaign group,
        # so that finalize is still able to close it
        file = base_group if isinstance(base_group, File) else None

        # Add meta information to the base group
        base_group.attrs.create("version", __version__, dtype=string_dtype(encoding="utf-8"))
        base_group.attrs.create("timestamp", time())

        # Create the campaign group
        if campaign is not None:
            base_group = HDFSerializationProcess._create_group(base_group, campaign)

        # Create the base group for caching arrays
        # NOTE(review): serialize_array writes into the current group, not into this one
        HDFSerializationProcess._create_group(base_group, "arrays")

        # Initialize the serialization process
        return HDFSerializationProcess(
            tag_registry, base_group, base_group, dict(), dict(), file=file
        )

    @override
    def serialize_array(self, array: np.ndarray, name: str, cache: bool = True) -> None:
        # Check if the dataset already exists
        # This might be the case for diamond inheritance
        if name in self.__group:
            return

        # The arrays are assumed to be immutable during the serialization process
        # Therefore, the unique identifier is generated from the array's memory address
        uid: int = id(array)

        # If the array has already been serialized, just create a soft link to it
        if cache and uid in self.__arrays:
            self.__group[name] = SoftLink(self.__arrays[uid])
            return

        # Write the array to the group and record its exact location to the array cache
        new = self.__group.create_dataset(name, array.shape, array.dtype, array)
        if cache:
            self.__arrays[uid] = new.name
            # Keep the array alive so that its id cannot be reused by another array
            self.__references.append(array)

    @override
    def serialize_floating(self, value: float, name: str) -> None:
        self.__group.attrs.create(name, value, dtype=np.float64)

    @override
    def serialize_complex(self, value: complex, name: str) -> None:
        self.__group.attrs.create(name, value, dtype=np.complex128)

    @override
    def serialize_integer(self, value: int, name: str) -> None:
        self.__group.attrs.create(name, value, dtype=np.int64)

    @override
    def serialize_string(self, value: str, name: str) -> None:
        self.__group.attrs.create(name, value, dtype=string_dtype(encoding="utf-8"))

    @override
    def _serialize_object(self, obj: Serializable, name: str, type: str, root: bool) -> None:
        # Check if the object's group already exists
        # This might be the case for diamond inheritance
        if name in self.__group:
            return

        # If a root object is being serialized, create a new process with cleared caches
        if root:
            new_process = HDFSerializationProcess(
                self._tag_registry, self.__base_group, self.__group, dict(), dict()
            )
            new_process._serialize_object(obj, name, type, False)
            return

        obj_key: int = id(obj)

        # If the object has already been serialized, just create a soft link to it
        obj_path = self.__objects.get(obj_key, None)
        if obj_path is not None:
            self.__group[name] = SoftLink(obj_path)
            return

        # Create a new group for the object
        new_group = self._create_group(self.__group, name)

        # Save the object's path for future reference and keep the object alive,
        # so that its id cannot be reused by another object during this run
        self.__objects[obj_key] = new_group.name
        self.__references.append(obj)

        # Create a new serialization process sharing this run's caches
        process = HDFSerializationProcess(
            self._tag_registry,
            self.__base_group,
            new_group,
            self.__arrays,
            self.__objects,
            self.__references,
        )

        # Serialize the object
        obj.serialize(process)

        # Write the object's tag to the group
        # Warning: This MUST be done after calling the serialize method for improved security
        new_group.attrs.create("type", type, dtype=string_dtype(encoding="utf-8"))

    @override
    def serialize_object_sequence(
        self,
        objects: Sequence[Serializable] | set[Serializable],
        name: str,
        append: bool = False,
        root: bool = False,
    ) -> None:
        # Check if the sequence's group already exists
        if name not in self.__group:
            self.__group.create_group(name)
            offset = 0

        elif append:
            offset = len(self.__group[name])

        # The group already exists and appending was not requested (e.g. the same sequence was
        # serialized before through diamond inheritance). The group is reused as-is:
        # _serialize_object skips indices that already exist.
        # NOTE(review): deleting and rewriting the group instead could invalidate soft links
        # already pointing into it, since the caches store absolute HDF5 paths.
        else:
            offset = 0

        # Serialize each object in the sequence
        process = HDFSerializationProcess(
            self._tag_registry,
            self.__base_group,
            self.__group[name],
            self.__arrays,
            self.__objects,
            self.__references,
        )
        for index, obj in enumerate(objects, offset):
            process.serialize_object(obj, f"{index:04d}", root)

    @override
    def serialize_custom(
        self, name: str, serializer: Callable[[SerializationProcess], None]
    ) -> None:
        new_group = self._create_group(self.__group, name)
        process = HDFSerializationProcess(
            self._tag_registry,
            self.__base_group,
            new_group,
            self.__arrays,
            self.__objects,
            self.__references,
        )
        serializer(process)

    @override
    def finalize(self) -> None:
        # Close the file handle recorded on creation; a campaign replaces base_group by a sub-group
        if self.__file is not None:
            self.__file.close()

        # Without recorded handle, close the base group if it is a file itself
        elif isinstance(self.__base_group, File):
            self.__base_group.close()

    @staticmethod
    def _create_group(group: Group, name: str) -> Group:
        """Return the sub-group ``name`` of ``group``, creating it if it does not exist."""

        if name not in group:
            return group.create_group(name)

        else:
            return group[name]
class HDFDeserializationProcess(DeserializationProcess):
    """A process representation for deserializing objects from HDF5.

    Deserialized arrays and objects are cached by their HDF5 path, so that soft links
    to an already deserialized dataset or group return the cached Python object.
    """

    __base_group: Group
    __group: Group
    __arrays: MutableMapping[str, np.ndarray]
    __objects: MutableMapping[str, Serializable]
    __file: File | None  # file handle to be closed on finalization

    def __init__(
        self,
        tag_registry: Mapping[str, Type[Serializable]],
        base_group: Group,
        group: Group,
        objects: MutableMapping[str, Serializable],
        arrays: MutableMapping[str, np.ndarray],
        file: File | None = None,
    ) -> None:
        """
        Args:

            tag_registry: Mapping of serialization tags to registered classes.
            base_group: Root group soft link targets are resolved against.
            group: Group this process reads from.
            objects: Object cache shared by all processes of a deserialization run.
            arrays: Array cache shared by all processes of a deserialization run.

            file:
                File handle closed by :meth:`finalize`.
                If omitted, :meth:`finalize` only closes ``base_group`` if it is a file itself.
        """

        # Init base class
        DeserializationProcess.__init__(self, tag_registry)

        # Init class attributes
        self.__base_group = base_group
        self.__group = group
        self.__objects = objects
        self.__arrays = arrays
        self.__file = file

    @staticmethod
    def New(
        tag_registry: Mapping[str, Type[Serializable]],
        base_group: Group,
        campaign: str | None = None,
    ) -> HDFDeserializationProcess:
        """Create a new deserialization process reading from ``base_group``.

        Args:

            tag_registry: Mapping of serialization tags to registered classes.
            base_group: HDF5 file or group to read from.
            campaign: Name of the sub-group the serialization was written into.

        Returns: The new deserialization process.

        Raises:

            RuntimeError: If ``campaign`` does not exist within ``base_group``.
        """

        # Remember the file handle before base_group is possibly replaced by the campaign group,
        # so that finalize is still able to close it
        file = base_group if isinstance(base_group, File) else None

        if campaign is not None:
            if campaign not in base_group:
                raise RuntimeError(f"Campaign '{campaign}' does not exist in the HDF5 file")
            base_group = base_group[campaign]

        return HDFDeserializationProcess(
            tag_registry, base_group, base_group, dict(), dict(), file=file
        )

    @override
    def deserialize_array(
        self, name: str, dtype: DTypeLike, default: np.ndarray | UNDEF_TYPE = UNDEF
    ) -> np.ndarray:
        # Get the link from the group (positional arguments: default, getclass, getlink)
        link: SoftLink | HardLink | ExternalLink | None = self.__group.get(name, None, False, True)

        # If the array does not exist, attempt to return the default value
        if link is None:
            # The default may be an array whose == is elementwise,
            # so it is only compared against the sentinel if it is a string
            if isinstance(default, str) and default == UNDEF:
                raise RuntimeError(
                    f"Array '{name}' does not exist in the HDF5 file and no default value was provided"
                )
            return default  # type: ignore[return-value]

        # If the array has already been deserialized, just return a reference to it
        if isinstance(link, SoftLink):
            if link.path in self.__arrays:
                return self.__arrays[link.path]
            array_dataset = self.__base_group[link.path]

        else:
            array_dataset = self.__group[name]

        deserialized_array = np.asarray(array_dataset, dtype=dtype)
        self.__arrays[array_dataset.name] = deserialized_array
        return deserialized_array

    def __fetch_attribute(self, name: str, default: _RT | UNDEF_TYPE) -> Any:
        """Fetch a raw attribute value of the current group.

        Args:

            name: Name of the attribute.
            default: Default value the caller falls back to, or :py:obj:`UNDEF` if none was provided.

        Returns:
            The raw attribute value, or :py:obj:`UNDEF` if the attribute does not exist
            and a default value was provided.

        Raises:

            RuntimeError: If the attribute does not exist and no default value was provided.
        """

        value = self.__group.attrs.get(name, UNDEF)
        if value is UNDEF and default is UNDEF:
            raise RuntimeError(
                f"Attribute '{name}' does not exist in the HDF5 file and no default value was provided"
            )

        return value

    @override
    def deserialize_floating(self, name: str, default: _RT | UNDEF_TYPE = UNDEF) -> float | _RT:
        deserialization = self.__fetch_attribute(name, default)

        if deserialization is UNDEF:
            return default  # type: ignore[return-value]

        if isinstance(deserialization, np.floating):
            return float(deserialization)

        raise RuntimeError(f"Attribute '{name}' is not a floating point value")

    @override
    def deserialize_complex(self, name: str, default: _RT | UNDEF_TYPE = UNDEF) -> complex | _RT:
        deserialization = self.__fetch_attribute(name, default)

        if deserialization is UNDEF:
            return default  # type: ignore[return-value]

        if isinstance(deserialization, np.complexfloating):
            return complex(deserialization)

        raise RuntimeError(f"Attribute '{name}' is not a complex value")

    @override
    def deserialize_integer(self, name: str, default: _RT | UNDEF_TYPE = UNDEF) -> int | _RT:
        deserialization = self.__fetch_attribute(name, default)

        if deserialization is UNDEF:
            return default  # type: ignore[return-value]

        if isinstance(deserialization, np.integer):
            return int(deserialization)

        raise RuntimeError(f"Attribute '{name}' is not an integer value")

    @override
    def deserialize_string(self, name: str, default: _RT | UNDEF_TYPE = UNDEF) -> str | _RT:
        deserialization = self.__fetch_attribute(name, default)
        return default if deserialization is UNDEF else str(deserialization)  # type: ignore[return-value]

    @override
    def _deserialize_object(self, name: str, expected_type: Type[_OT], default: _RT) -> _OT | _RT:
        # Get the object's link
        link: SoftLink | HardLink | ExternalLink | None = self.__group.get(name, None, False, True)

        # Return the default value if the object does not exist
        if link is None:
            return default

        # A softlink indicates that the object is stored in a different group
        if isinstance(link, SoftLink):
            # If the object has already been deserialized, just return a reference to it
            if link.path in self.__objects:
                return self.__objects[link.path]  # type: ignore[return-value]

            # Otherwise, load the object from the SoftLink's group
            object_group = self.__base_group[link.path]

        else:
            object_group = self.__group[name]

        # Get the object's type
        object_type: str | None = object_group.attrs.get("type", default=None)
        if object_type is None:
            raise RuntimeError(
                f"Object '{name}' does not have a type attribute and therfore cannot be deserialized"
            )

        # Get the object's class (type)
        object_class: Type[Serializable] | None = self._tag_registry.get(str(object_type), None)

        # Raise an exception if the object is not a registered class
        if object_class is None:
            raise RuntimeError(
                f"Object '{str(object_type)}' is not a registered class and therfore cannot be deserialized"
            )

        # Raise an exception if the object is not of the expected type
        if not issubclass(object_class, expected_type):
            raise RuntimeError(
                f"Object '{name}' is not of the expected type ({object_class} instead of {expected_type})"
            )

        # Initialize the object
        process = HDFDeserializationProcess(
            self._tag_registry, self.__base_group, object_group, self.__objects, self.__arrays
        )
        deserialized_object = object_class.Deserialize(process)

        # Store the object in the internal cache
        self.__objects[object_group.name] = deserialized_object

        # Return the deserialized object
        return deserialized_object

    @override
    def deserialize_object_sequence(
        self, name: str, type: Type[_OT], start: int = 0, stop: int | None = None
    ) -> Sequence[_OT]:
        # Check if the sequence's group exists
        if name not in self.__group:
            if stop is None:
                return []
            raise RuntimeError(f"Object sequence '{name}' does not exist in the HDF5 file")

        # Get the sequence's group
        sequence_group = self.__group[name]
        sequence_process = HDFDeserializationProcess(
            self._tag_registry, self.__base_group, sequence_group, self.__objects, self.__arrays
        )

        # Get the number of objects in the sequence
        count = len(sequence_group)

        # Deserialize each object in the sequence
        objects: list[_OT] = []
        for index in range(start, count if stop is None else min(stop, count)):
            deserialized_object = sequence_process.deserialize_object(f"{index:04d}", type, None)
            if deserialized_object is None:
                raise IndexError(f"Index {index} out of bounds in object sequence '{name}'")
            objects.append(deserialized_object)

        return objects

    @override
    def sequence_length(self, name: str) -> int:
        if name not in self.__group:
            return 0

        return len(self.__group[name])

    @override
    def deserialize_custom(
        self, name: str, callback: Callable[[DeserializationProcess], _RT] | None = None
    ) -> _RT:
        process = HDFDeserializationProcess(
            self._tag_registry, self.__base_group, self.__group[name], self.__objects, self.__arrays
        )
        return callback(process)

    @override
    def finalize(self) -> None:
        # Close the file handle recorded on creation; a campaign replaces base_group by a sub-group
        if self.__file is not None:
            self.__file.close()

        # Without recorded handle, close the base group if it is a file itself
        elif isinstance(self.__base_group, File):
            self.__base_group.close()