Source code for hermespy.core.factory

# -*- coding: utf-8 -*-

from __future__ import annotations

import re
from abc import ABCMeta, abstractmethod
from collections.abc import Iterable
from enum import Enum
from inspect import getmembers, isclass, signature
from importlib import import_module
from io import TextIOBase, StringIO
import os
from pkgutil import iter_modules
from re import compile, Pattern, Match
from typing import (
    Any,
    Dict,
    Set,
    Sequence,
    Mapping,
    Union,
    KeysView,
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
    ValuesView,
)

import numpy as np
from h5py import Group
from ruamel.yaml import (
    YAML,
    SafeConstructor,
    SafeRepresenter,
    ScalarNode,
    Node,
    MappingNode,
    SequenceNode,
)
from ruamel.yaml.constructor import ConstructorError

import hermespy
from .logarithmic import Logarithmic, LogarithmicSequence

__author__ = "Jan Adler"
__copyright__ = "Copyright 2024, Barkhausen Institut gGmbH"
__credits__ = ["Jan Adler"]
__license__ = "AGPLv3"
__version__ = "1.4.0"
__maintainer__ = "Jan Adler"
__email__ = "jan.adler@barkhauseninstitut.org"
__status__ = "Prototype"


SerializableType = TypeVar("SerializableType", bound="Serializable")
"""Type of Serializable Class."""


[docs] class Serializable(object): """Base class for serializable classes. Only classes inheriting from `Serializable` will be serialized by the factory. """ yaml_tag: Optional[str] = None """YAML serialization tag. :meta private: """ property_blacklist: Set[str] = set() """Set of properties to be ignored during serialization. :meta private: """ serialized_attributes: Set[str] = set() """Set of object attributes to be serialized. :meta private: """ @staticmethod def _arg_signature() -> Set[str]: """Argument signature. Returns: Additional arguments not inferable from the init signature. :meta private: """ return set() @classmethod def _serializable_attributes( cls: Type[Serializable], blacklist: Optional[Set[str]] = None ) -> Set[str]: """Extract the set of serializable class attributes. Args: cls (Type[Serializable]): Class of the object to be serialized. blacklist (Set[str], optional): List of attribute names to be ignored during extraction. Returns: Set of serializable attribute names. :meta private: """ if blacklist: blacklist = blacklist.copy() blacklist.update(cls.property_blacklist) else: blacklist = cls.property_blacklist # Extract initialization signature init_signature = set(signature(cls.__init__).parameters.keys()) # Query serializable properties attributes = set() for attribute_key, attribute_type in getmembers(cls): # Prevent the access to protected or private attributes if attribute_key.startswith("_"): continue # Only add attribute if it isn't blacklisted if attribute_key in blacklist: continue # Make sure the attribute is a property if not isinstance(attribute_type, property): continue # Don't serialize if the property isn't settable if attribute_type.fset is None and attribute_key not in init_signature: continue attributes.add(attribute_key) # Add forced attributes attributes.update(cls.serialized_attributes) return attributes @classmethod def to_yaml( cls: Type[SerializableType], representer: SafeRepresenter, node: SerializableType ) -> Node: """Serialize a serializable object to YAML. Args: representer (SafeRepresenter): A handle to a representer used to generate valid YAML code. The representer gets passed down the serialization tree to each node. node (Serializable): The channel instance to be serialized. Returns: The serialized YAML node. :meta private: """ return node._mapping_serialization_wrapper(representer) def _mapping_serialization_wrapper( self, representer: SafeRepresenter, blacklist: Optional[Set[str]] = None, additional_fields: Optional[Dict[str, Any]] = None, ) -> MappingNode: """Conveniently serializes the class to a YAML mapping node. Args: blacklist (Set[str], optional): Properties to be ignored during serialization. additional_fields (Dict[str, Any], optional): Additional fields to be serialized. Returns: A YAML mapping node representing this object. :meta private: """ # Init additional fields additional_fields = additional_fields if additional_fields else {} # Query serializable properties serializable_atrributes = self._serializable_attributes(blacklist) # Construct state dictionary by querying serializable attributes state: Dict[str, Any] = {} for attribute_key in serializable_atrributes: attribute_value = getattr(self, attribute_key) # Don't serialize attribute if it is None if attribute_value is None: continue state[attribute_key] = attribute_value # Add additional fields to state if additional_fields: state.update(additional_fields) # Create YAML mapping return representer.represent_mapping(self.yaml_tag, state) @classmethod def from_yaml( cls: Type[SerializableType], constructor: SafeConstructor, node: Node ) -> SerializableType: """Recall a new serializable class instance from YAML. Args: constructor (SafeConstructor): A handle to the constructor extracting the YAML information. node (Node): YAML node representing the `Serializable` serialization. Returns: The de-serialized object. :meta private: """ # Handle empty yaml nodes if isinstance(node, ScalarNode): return cls() return cls.InitializationWrapper(constructor.construct_mapping(node, deep=True)) @classmethod def InitializationWrapper( cls: Type[SerializableType], configuration: Dict[str, Any] ) -> SerializableType: """Conveniently initializes serializable classes. Args: configuration (Dict[str, Any]): Configuration parameter dictionary. Returns: SerializableArray: Initialized class instance. :meta private: """ # Extract initialization signature init_signature = list(signature(cls.__init__).parameters.keys()) arg_signature = cls._arg_signature() init_signature.remove("self") # Extract settable class properties properties = cls._serializable_attributes() init_parameters: Dict[str, Any] = {} init_properties: Dict[str, Any] = {} for configuration_key in list(configuration.keys()): if configuration_key in init_signature or configuration_key in arg_signature: init_parameters[configuration_key] = configuration.pop(configuration_key) continue lower_key = configuration_key.lower() if lower_key in init_signature or lower_key in arg_signature: # pragma: no cover init_parameters[lower_key] = configuration.pop(configuration_key) continue if configuration_key in properties: init_properties[configuration_key] = configuration.pop(configuration_key) continue if lower_key in properties: # pragma: no cover init_properties[lower_key] = configuration.pop(configuration_key) continue # Initialize class # Remaining configuration fields get treated as kwargs init_parameters.update(configuration) try: instance = cls(**init_parameters) except TypeError as e: raise TypeError(f"Error while attempting to initialize '{cls.__name__}', {str(e)}") # Configure properties for property_name, property_value in init_properties.items(): try: setattr(instance, property_name, property_value) except AttributeError as e: raise AttributeError( f"Error while attempting to configure '{property_name}', {str(e)}" ) # Return configured class instance return instance
SET = TypeVar("SET", bound="SerializableEnum") """Type of serializable enumeration."""
[docs] class SerializableEnum(Serializable, Enum): """Base class for serializable enumerations."""
[docs] @classmethod def from_parameters(cls: Type[SET], enum: SET | int | str) -> SET: """Initialize enumeration from multiple parameters. Args: enum (SET | int | str): The parameter from which the enum should be initialized. Returns: The initialized enumeration. """ if isinstance(enum, cls): return enum elif isinstance(enum, int): return cls(enum) elif isinstance(enum, str): return cls[enum] else: raise ValueError("Unknown serializable enumeration type")
@classmethod def from_yaml(cls: Type[SerializableEnum], _: SafeConstructor, node: Node) -> SerializableEnum: # Convert scalar string representation back to enum return cls[node.value] @classmethod def to_yaml( cls: Type[SerializableEnum], representer: SafeRepresenter, node: SerializableEnum ) -> ScalarNode: # Convert enum to scalar string representation return representer.represent_scalar(cls.yaml_tag, "{.name}".format(node)) @classmethod # type: ignore @property def yaml_tag(cls) -> str: # type: ignore return cls.__name__
[docs] class Factory: """Helper class to load HermesPy simulation scenarios from YAML configuration files.""" extensions: Set[str] = {".yml", ".yaml", ".cfg"} """List of recognized filename extensions for serialization files.""" __yaml: YAML __clean: bool __db_regex: Pattern __tag_registry: Mapping[str, Type[Serializable]] def __init__(self) -> None: # YAML dumper configuration self.__yaml = YAML(typ="safe", pure=True) self.__yaml.default_flow_style = False self.__yaml.compact(seq_seq=False, seq_map=False) self.__yaml.encoding = None self.__yaml.indent(mapping=4, sequence=4, offset=2) self.__clean = True self.__tag_registry = {} # Add custom representers self.__yaml.representer.add_representer(complex, Factory.__complex128representer) self.__yaml.representer.add_representer(np.ndarray, Factory.__array_representer) self.__yaml.representer.add_representer(np.float64, Factory.__numpy_float64representer) # Add custom constructors self.__yaml.constructor.add_constructor("complex", Factory.__complex128constructor) self.__yaml.constructor.add_constructor("array", Factory.__array_constructor) self.__yaml.constructor.add_constructor("dB", Factory.__logarithmic_constructor) # Iterate over all modules within the hermespy namespace # Scan for serializable classes lookup_paths = list(hermespy.__path__) + [ os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) ] for _, name, is_module in iter_modules(lookup_paths, hermespy.__name__ + "."): if not is_module: continue # pragma: no cover module = import_module(name) for _, serializable_class in getmembers(module): if not isclass(serializable_class) or not issubclass( serializable_class, Serializable ): continue # Register serializable class at the YAML factory self.__yaml.register_class(serializable_class) # Remember tag for tagged classes if serializable_class.yaml_tag is not None: self.__tag_registry[serializable_class.yaml_tag] = serializable_class # Construct regular expressions for purging self.__range_regex = compile( r"([0-9.e-]*)[ ]*,[ ]*([0-9.e-]*)[ ]*,[ ]*\.\.\.[ ]*,[ ]*([0-9.e-]*)" ) self.__db_regex = compile(r"\[([ 0-9.,-]*)\][ ]*dB") @property def clean(self) -> bool: """Use clean YAML standard. Disabling the clean flag will deactivate additional text processing to the YAML configuration files done by Hermes, such as dB conversion or linear number spaces. Returns: Clean flag. """ return self.__clean @clean.setter def clean(self, flag: bool) -> None: self.__clean = flag @property def registered_classes(self) -> ValuesView[Type[Serializable]]: """Classes registered for serialization within the factory.""" return self.__tag_registry.values() @property def registered_tags(self) -> KeysView[str]: """Read registered YAML tags.""" return self.__tag_registry.keys() @property def tag_registry(self) -> Mapping[str, Type[Serializable]]: """Read registered YAML tags.""" return self.__tag_registry @staticmethod def __complex128representer(representer: SafeRepresenter, value: complex) -> ScalarNode: """Represent complex numbers as strings. Args: representer (SafeRepresenter): YAML representer. value (complex): The complex number to be transformed to a string. Returns: Scalar yaml node. """ complex128string = str(value)[1:-1] return representer.represent_scalar("complex", complex128string) @staticmethod def __complex128constructor(constructor: SafeConstructor, node: ScalarNode) -> complex: """Construct a complex number from YAML. Args: constructor (SafeConstructor): YAML constructor. node (ScalarNode): The YAML node representing the complex number. Returns: A complex number. """ complex128number = complex(constructor.construct_scalar(node)) return complex128number @staticmethod def __array_representer(representer: SafeRepresenter, array: np.ndarray) -> SequenceNode: """Represent numpy arrays as lists. Args: representer (SafeRepresenter): YAML representer. array (numpy.ndarray): The array to be transformed to a sequence. Returns: Sequence yaml node. """ # Transform complex numpy arrays to their string representation if array.dtype in [np.complex64, np.complex128]: object_array = np.empty(array.shape, dtype=object) for index, number in np.ndenumerate(array): object_array[index] = str(number).replace("(", "").replace(")", "") list = object_array.tolist() else: list = array.tolist() sequence = representer.represent_sequence("array", list, flow_style=True) return sequence @staticmethod def __numpy_float64representer(representer: SafeRepresenter, value: np.float64) -> ScalarNode: """Represent numy floating point scalar numbers as strings. Args: representer (SafeRepresenter): YAML representer. value (np.float64): The number to be transformed to a string. Returns: Scalar yaml node. """ return representer.represent_float(float(value)) @staticmethod def __array_constructor(constructor: SafeConstructor, node: SequenceNode) -> np.ndarray: """Construct a numpy array from YAML. Args: constructor (SafeConstructor): YAML constructor. node (ScalarNode): The YAML node representing the array. Returns: A numpy array. """ if isinstance(node, SequenceNode): return np.array([Factory.__array_constructor(constructor, n) for n in node.value]) if "j" in node.value: return Factory.__complex128constructor(constructor, node) else: return constructor.construct_object(node) @staticmethod def __logarithmic_constructor( constructor: SafeConstructor, node: Union[ScalarNode, SequenceNode] ) -> Union[Logarithmic, LogarithmicSequence]: """Construct a logarithmic value or sequence from YAML. Args: constructor (SafeConstructor): YAML constructor. node (Union[ScalarNode, SequenceNode]): The YAML node representing the array. Returns: A logarithmic representation. """ if isinstance(node, ScalarNode): return Logarithmic(float(constructor.construct_scalar(node))) if isinstance(node, SequenceNode): return LogarithmicSequence(constructor.construct_sequence(node)) @staticmethod def __decibel_conversion(match: re.Match) -> str: """Convert YAML sequences with dB annotations to tagged sequences. Args: match (re.Match): The serialization sequence to be converted. Returns: str: The purged sequence. """ linear_values = [float(str_rep) for str_rep in match[1].replace(" ", "").split(",")] string_replacement = "!<dB> [" for linear_value in linear_values: string_replacement += str(linear_value) + ", " string_replacement += "]" return string_replacement
[docs] def from_path(self, paths: Union[str, Set[str]]) -> Sequence[Any]: """Load a configuration from an arbitrary file system path. Args: paths (Union[str, Set[str]]): Paths to a file or a folder featuring .yml config files. Returns: Serializable objects recalled from `paths`. Raises: ValueError: If the provided `path` does not exist on the filesystem. """ # Convert single path to a set if required if isinstance(paths, str): paths = {paths} hermes_objects = [] for path in paths: if not os.path.exists(path): raise ValueError(f"Lookup path '{path}' not found") if os.path.isdir(path): deserialization = self.from_folder(path) else: deserialization = self.from_file(path) if isinstance(deserialization, list): hermes_objects += deserialization else: hermes_objects.append(deserialization) # pragma: no cover return hermes_objects
[docs] def from_folder( self, path: str, recurse: bool = True, follow_links: bool = False ) -> Sequence[Any] | Any: """Load a configuration from a folder. Args: path (str): Path to the folder configuration. recurse (bool, optional): Recurse into sub-folders within `path`. follow_links (bool, optional): Follow links within `path`. Returns: Serializable objects recalled from `path`. Raises: ValueError: If `path` is not a directory. """ if not os.path.exists(path): raise ValueError("Lookup path '{}' not found".format(path)) if not os.path.isdir(path): raise ValueError("Lookup path '{}' is not a directory".format(path)) hermes_objects: List[Any] = [] for directory, _, files in os.walk(path, followlinks=follow_links): for file in files: _, extension = os.path.splitext(file) if extension in self.extensions: deserialization = self.from_file(os.path.join(directory, file)) hermes_objects += ( deserialization if isinstance(deserialization, list) else [deserialization] ) if not recurse: break return hermes_objects
[docs] def to_folder(self, path: str, *args: Any) -> None: """Dump a configuration to a folder. Args: path (str): Path to the folder configuration. *args (Any): Configuration objects to be dumped. """ pass # pragma: no cover
[docs] def from_str(self, config: str) -> Sequence[Any] | Any: """Load a configuration from a string object. Args: config (str): The configuration to be loaded. Returns: List of objects or object from `config`. """ stream = StringIO(config) return self.from_stream(stream)
[docs] def to_str(self, *args: Any) -> str: """Dump a configuration to a folder. Args: *args (Any): Configuration objects to be dumped. Returns: str: String containing full YAML configuration. Raises: RepresenterError: If objects in ``*args`` are unregistered classes. """ stream = StringIO() self.to_stream(stream, args) return stream.getvalue()
[docs] def from_file(self, file: str) -> Sequence[Any] | Any: """Load a configuration from a single YAML file. Args: file (str): Path to the folder configuration. Returns: Serialized objects within `path`. """ with open(file, mode="r") as file_stream: try: return self.from_stream(file_stream) # Re-raise constructor errors with the correct file name except ConstructorError as constructor_error: constructor_error.problem_mark.name = file raise constructor_error
[docs] def to_file(self, path: str, *args: Any) -> None: """Dump a configuration to a single YML file. Args: path (str): Path to the configuration file. *args (Any): Configuration objects to be dumped. Raises: RepresenterError: If objects in ``*args`` are unregistered classes. """ pass # pragma: no cover
@staticmethod def __range_restore_callback(m: Match) -> str: """Internal regular expression callback. Args: m (Match): Regular expression match. Returns: str: The processed match line. """ # Extract range parameters start = float(m.group(1)) step = float(m.group(2)) - start stop = float(m.group(3)) + step range = np.arange(start=start, stop=stop, step=step) replacement = "" for step in range[:-1]: replacement += str(step) + ", " replacement += str(range[-1]) return replacement
[docs] def from_stream(self, stream: TextIOBase) -> Sequence[Any] | Any: """Load a configuration from an arbitrary text stream. Args: stream (TextIOBase): Text stream containing the configuration. Returns: List of deserialized objects or object within `stream`. Raises: ConstructorError: If YAML parsing fails. """ if not self.__clean: return self.__yaml.load(stream) clean_stream = "" for line in stream.readlines(): clean_line = self.__range_regex.sub(self.__range_restore_callback, line) clean_line = self.__db_regex.sub(self.__decibel_conversion, clean_line) clean_stream += clean_line hermes_objects = self.__yaml.load(StringIO(clean_stream)) # If the deserialization is empty, return an empty list if hermes_objects is None: return [] # If the deserialization is a single item, return just the item if isinstance(hermes_objects, Sequence) and len(hermes_objects) == 1: return hermes_objects[0] return hermes_objects
[docs] def to_stream(self, stream: TextIOBase, *args: Iterable[Any]) -> None: """Dump a configuration to an arbitrary text stream. Args: stream (TextIOBase): Text stream to the configuration. *args (Any): Configuration objects to be dumped. Raises: RepresenterError: If objects in ``*args`` are unregistered classes. """ for serializable_object in args: self.__yaml.dump(serializable_object, stream)
HDFSerializableType = TypeVar("HDFSerializableType", bound="HDFSerializable") """Type of HDF Serializable Class"""
[docs] class HDFSerializable(metaclass=ABCMeta): """Base class for object serializable to the HDF5 format. Structures are serialized to HDF5 files by the :meth:`to_HDF<HDFSerializable.to_HDF>` routine and de-serialized by the :meth:`from_HDF<HDFSerializable.from_HDF>` method, respectively. """ @abstractmethod def to_HDF(self, group: Group) -> None: """Serialize the object state to HDF5. Dumps the object's state and additional information to a HDF5 group. Args: group (h5py.Group): The HDF5 group to which the object is serialized. :meta private: """ ... # pragma no cover @classmethod @abstractmethod def from_HDF(cls: Type[HDFSerializableType], group: Group) -> HDFSerializableType: """De-Serialized the object state from HDF5. Recalls the object's state from a HDF5 group. Args: group (h5py.Group): The HDF5 group from which the object state is recalled. Returns: The object initialized from the HDF5 group state. :meta private: """ ... # pragma no cover @staticmethod def _create_group(group: Group, name: str) -> Group: """Create an HDF5 group if it does not exist yet. Args: group (h5py.Group): The HDF5 group from which the object state is recalled. name (str): Name of the group to be created. Returns: A handle to group `name`. :meta private: """ if name not in group: return group.create_group(name) else: return group[name] @staticmethod def _write_dataset(group: Group, dataset: str, data: Any | None) -> None: """Write to a dataset. Args: group (h5py.Group): The HDF5 group from which the object state is recalled. dataset (str): The dataset name. data (Any | None): The data to be written to `dataset`. :meta private: """ if dataset in group: del group[dataset] group.create_dataset(dataset, data=data) @staticmethod def _range_to_HDF(group: Group, id: str, value: float | Tuple[float, float]) -> None: """Serialize a range variable to HDF5. Args: group (h5py.Group): The HDF5 group to which the range value is serialized. id (str): Identifier string of the range value. value (float | Tuple[float, float]): The range value to be serialized. Can either be a scalar or a tuple of two values indicating maximum and minimum. """ if isinstance(value, tuple): group.attrs[id + "_min"] = value[0] group.attrs[id + "_max"] = value[1] else: group.attrs[id] = value @staticmethod def _range_from_HDF(group: Group, id: str) -> float | Tuple[float, float]: """Deserialize a range variable from HDF5. Args: group (h5py.Group): The HDF5 group from which the range value is deserialized. id (str): Identifier string of the range value. Returns: The deserialized range value. Can either be a scalar or a tuple of two values indicating maximum and minimum. """ if id in group.attrs: return float(group.attrs[id]) else: return (float(group.attrs[id + "_min"]), float(group.attrs[id + "_max"]))