# Source code for hermespy.modem.symbols

# -*- coding: utf-8 -*-

from __future__ import annotations
from enum import Enum
from typing import Optional, Union, Iterable, Type

import matplotlib.pyplot as plt
import numpy as np
from h5py import Group
from matplotlib import rcParams
from sparse import SparseArray  # type: ignore

from hermespy.core import HDFSerializable, VisualizableAttribute, ScatterVisualization, VAT

# Module authorship, licensing and release metadata
__author__ = "Jan Adler"
__copyright__ = "Copyright 2024, Barkhausen Institut gGmbH"
__credits__ = ["Jan Adler", "Tobias Kronauer"]
__license__ = "AGPLv3"
__version__ = "1.3.0"
__maintainer__ = "Jan Adler"
__email__ = "jan.adler@barkhauseninstitut.org"
__status__ = "Prototype"


class SymbolType(Enum):
    """Flag classifying the role a communication symbol plays within a frame.

    The members are declared in the order ``DATA``, ``REFERENCE``, ``PILOT``
    and carry the integer values ``0``, ``1`` and ``2``, respectively.
    """

    #: Data symbol transmitting information.
    DATA = 0

    #: Reference symbol for channel estimation.
    REFERENCE = 1

    #: Pilot symbol for frame detection.
    PILOT = 2


class _ConstellationPlot(VisualizableAttribute[ScatterVisualization]):
    """Visualizable attribute plotting the constellation of a symbol series.

    Essentially projects the time-series of symbols of all streams and blocks onto
    a single complex plane, with the real part on the horizontal axis and the
    imaginary part on the vertical axis.
    """

    __symbols: Symbols  # Symbol series to be visualized

    def __init__(self, symbols: Symbols) -> None:
        """
        Args:

            symbols (Symbols):
                The symbols to be plotted.
        """

        # Initialize the base class
        super().__init__()

        # Initialize attributes
        self.__symbols = symbols

    @property
    def title(self) -> str:
        """Title of the constellation plot."""

        return "Symbol Constellation"

    def _prepare_visualization(
        self, figure: plt.Figure | None, axes: VAT, **kwargs
    ) -> ScatterVisualization:
        """Configure the axes and create a placeholder scatter point per symbol.

        Args:

            figure (plt.Figure | None): Figure hosting the axes, if any.
            axes (VAT): Axes grid, only the first axes is used.

        Returns: The prepared scatter visualization.
        """

        ax: plt.Axes = axes.flat[0]
        ax.set(ylabel="Imag")
        ax.set(xlabel="Real")
        ax.grid(True, which="both")
        ax.axhline(y=0, color=rcParams["grid.color"])
        ax.axvline(x=0, color=rcParams["grid.color"])
        ax.set_xlim(-1.25, 1.25)
        ax.set_ylim(-1.25, 1.25)

        # One scatter point per symbol of every stream and block,
        # actual positions are assigned in _update_visualization
        num_symbols = (
            self.__symbols.num_symbols * self.__symbols.num_blocks * self.__symbols.num_streams
        )

        # np.float_ was removed in NumPy 2.0; np.float64 is the portable equivalent
        zeros = np.zeros(num_symbols, dtype=np.float64)
        path_collection = np.empty((1, 1), dtype=np.object_)
        path_collection[0, 0] = ax.scatter(zeros, zeros)

        return ScatterVisualization(figure, axes, path_collection)

    def _update_visualization(self, visualization: ScatterVisualization, **kwargs) -> None:
        """Move the scatter points to the current symbol positions.

        Args:

            visualization (ScatterVisualization): Visualization created by _prepare_visualization.
        """

        symbols = self.__symbols.raw.flatten()
        path: plt.PathCollection = visualization.paths[0, 0]

        # Offsets are expected as an (N, 2) array of (x, y) = (real, imag) pairs
        path.set_offsets(np.column_stack((symbols.real, symbols.imag)))
class Symbol:
    """A single communication symbol located somewhere on the complex plane.

    Attributes:

        value (complex): Value of the symbol.
        flag (SymbolType): Type of the symbol.
    """

    #: Value of the symbol.
    value: complex

    #: Type of the symbol.
    flag: SymbolType

    def __init__(self, value: complex, flag: SymbolType = SymbolType.DATA) -> None:
        """
        Args:

            value (complex):
                Complex-valued position of the symbol.

            flag (SymbolType, optional):
                Role of the symbol, :attr:`SymbolType.DATA` unless specified otherwise.
        """

        self.flag = flag
        self.value = value
class Symbols(HDFSerializable):
    """A time-series of communication symbols located somewhere on the complex plane.

    Symbols are stored as a three-dimensional array whose dimensions denote the
    number of streams, the number of symbol blocks per stream and the number of
    symbols per block, respectively.
    """

    __symbols: np.ndarray  # Internal symbol storage
    __constellation_plot: _ConstellationPlot  # Symbol constellation plot

    def __init__(self, symbols: Optional[Union[Iterable, np.ndarray]] = None) -> None:
        """
        Args:

            symbols (Union[Iterable, numpy.ndarray], optional):
                A three-dimensional array of complex-valued communication symbols.
                The first dimension denotes the number of streams,
                the second dimension the number of symbol blocks per stream,
                the third dimension the number of symbols per block.
                One-dimensional inputs are interpreted as a single stream of blocks
                containing one symbol each, two-dimensional inputs as multiple streams
                of blocks containing one symbol each.
                By default, an empty symbol series is created.

        Raises:

            ValueError: If `symbols` has more than three dimensions.
        """

        if symbols is None:
            symbols = np.empty((0, 0, 0), dtype=complex)

        elif not isinstance(symbols, np.ndarray):
            symbols = np.array(symbols)

        # Make sure the initialization is a valid symbol sequence
        if symbols.ndim > 3:
            raise ValueError("Symbols initialization array may have a maximum of three dimensions")

        # Expand the dimensions if required
        self.__symbols = self._expand_dimensions(symbols)
        self.__constellation_plot = _ConstellationPlot(self)

    @staticmethod
    def _expand_dimensions(symbols: np.ndarray) -> np.ndarray:
        """Expand one- and two-dimensional arrays to the three-dimensional symbol layout.

        One-dimensional arrays of length N become shape (1, N, 1), two-dimensional
        arrays of shape (A, B) become shape (A, B, 1).
        Arrays of any other dimensionality are returned unchanged.

        Args:

            symbols (np.ndarray): Array to be expanded.

        Returns: The (possibly) expanded array.
        """

        if symbols.ndim == 1:
            return symbols[np.newaxis, :, np.newaxis]

        if symbols.ndim == 2:
            return symbols[:, :, np.newaxis]

        return symbols

    @staticmethod
    def _to_symbol_array(symbols: Union[Symbols, np.ndarray, Iterable]) -> np.ndarray:
        """Convert a symbol argument to be appended into a three-dimensional array.

        Args:

            symbols (Union[Symbols, np.ndarray, Iterable]): Symbols to be converted.

        Returns: Three-dimensional symbol array.

        Raises:

            ValueError: If `symbols` can't be expanded to three dimensions.
        """

        if isinstance(symbols, Symbols):
            array = symbols.raw

        elif isinstance(symbols, np.ndarray):
            array = symbols

        else:
            array = np.asarray(symbols)

        array = Symbols._expand_dimensions(array)

        if array.ndim != 3:
            raise ValueError("Symbols must be an array of one, two or three dimensions")

        return array

    @property
    def num_streams(self) -> int:
        """Number of streams within this symbol series.

        Returns:
            int: Number of streams.
        """

        return self.__symbols.shape[0]

    @property
    def num_blocks(self) -> int:
        """Number of symbol blocks within this symbol series.

        Returns:
            int: Number of blocks.
        """

        return self.__symbols.shape[1]

    @property
    def num_symbols(self) -> int:
        """Number of symbols per block within this symbol series.

        Returns:
            int: Number of symbols
        """

        return self.__symbols.shape[2]

    def append_stream(self, symbols: Union[Symbols, np.ndarray]) -> None:
        """Append a new symbol stream to this symbol series.

        Represents a matrix concatenation in the first dimension.

        Args:

            symbols (Union[Symbols, np.ndarray]):
                Symbol stream to be appended to this symbol series.
                One- and two-dimensional arrays are expanded as in the constructor.

        Raises:

            ValueError: If `symbols` can't be expanded to three dimensions.
            ValueError: If the number of symbols per block do not match.
            ValueError: If the number of blocks do not match.
        """

        symbols = self._to_symbol_array(symbols)

        # A series without symbols (and at most one stream) simply adopts the appended symbols
        if self.num_symbols < 1 and self.num_streams <= 1:
            self.__symbols = symbols
            return

        if self.num_symbols != symbols.shape[2]:
            raise ValueError("Symbol models to be concatenated do not match in time-domain")

        if self.num_blocks != symbols.shape[1]:
            raise ValueError("Symbol models to be concatenated do not match in block-domain")

        self.__symbols = np.append(self.__symbols, symbols, axis=0)

    def append_symbols(self, symbols: Union[Symbols, np.ndarray]) -> None:
        """Append a new symbol sequence to this symbol series.

        Represents a matrix concatenation in the second dimension.

        Args:

            symbols (Union[Symbols, np.ndarray]):
                Symbol sequence to be appended to this symbol series.
                One- and two-dimensional arrays are expanded as in the constructor.

        Raises:

            ValueError: If `symbols` can't be expanded to three dimensions.
            ValueError: If the number of symbol streams do not match.
            ValueError: If the number of symbols per block do not match.
        """

        symbols = self._to_symbol_array(symbols)

        # A series without symbols (and at most one stream) simply adopts the appended symbols
        if self.num_symbols < 1 and self.num_streams <= 1:
            self.__symbols = symbols
            return

        if self.num_streams != symbols.shape[0]:
            raise ValueError("Symbol models to be concatenated do not match in stream-domain")

        # Blocks can only be concatenated if they contain the same number of symbols
        if self.num_symbols != symbols.shape[2]:
            raise ValueError("Symbol models to be concatenated do not match in time-domain")

        self.__symbols = np.append(self.__symbols, symbols, axis=1)

    @property
    def raw(self) -> np.ndarray:
        """Access the raw symbol array.

        Return:
            np.ndarray: The raw symbol array

        Raises:
            ValueError: If a non-three-dimensional array is assigned.
        """

        return self.__symbols

    @raw.setter
    def raw(self, value: np.ndarray) -> None:
        if value.ndim != 3:
            raise ValueError("Raw symbols must be a three-dimensionall array")

        self.__symbols = value

    def copy(self) -> Symbols:
        """Create a deep copy of this symbol sequence.

        Returns:
            Symbols: Copied sequence.
        """

        return Symbols(self.__symbols.copy())

    def __getitem__(self, section: slice) -> Symbols:
        """Slice this symbol series.

        Args:
            section (slice): Slice symbol selection.

        Returns:
            Symbols: New Symbols object representing the selected `section`.
        """

        # NOTE(review): selections dropping dimensions (e.g. integer indices) yield one- or
        # two-dimensional arrays, which the constructor re-expands to (1, N, 1) / (A, B, 1).
        # The resulting axes may therefore not correspond to streams / blocks / symbols.
        return Symbols(self.__symbols[section])

    def __setitem__(self, section: slice, value: Union[Symbols, np.ndarray]) -> None:
        """Set symbols within this series.

        Args:
            section (slice): Slice pointing to the symbol positions to be updated.
            value (Union[Symbols, np.ndarray]): The symbols to be set.
        """

        if isinstance(value, Symbols):
            self.__symbols[section] = value.__symbols

        else:
            self.__symbols[section] = value

    @property
    def plot_constellation(self) -> _ConstellationPlot:
        """Plot the symbol constellation."""

        return self.__constellation_plot

    @classmethod
    def from_HDF(cls: Type[Symbols], group: Group) -> Symbols:
        # Recall datasets
        symbols = np.array(group["symbols"])  # dtype=complex

        # Initialize object from recalled state
        return cls(symbols=symbols)

    def to_HDF(self, group: Group) -> None:
        # Serialize datasets
        group.create_dataset("symbols", data=self.__symbols)

        # Serialize attributes
        group.attrs["num_streams"] = self.num_streams
        group.attrs["num_blocks"] = self.num_blocks
        group.attrs["num_symbols"] = self.num_symbols
class StatedSymbols(Symbols):
    """A time-series of communication symbols and channel states located somewhere on the complex plane."""

    __states: np.ndarray | SparseArray  # Symbol states, four-dimensional array

    def __init__(self, symbols: Iterable | np.ndarray, states: np.ndarray | SparseArray) -> None:
        """
        Args:

            symbols (Union[Iterable, numpy.ndarray]):
                A three-dimensional array of complex-valued communication symbols.
                The first dimension denotes the number of streams,
                the second dimension the number of symbol blocks per stream,
                the third dimension the number of symbols per block.

            states (np.ndarray | SparseArray):
                Four-dimensional numpy array with the first two dimensions indicating the
                MIMO receive and transmit streams, respectively and the last two dimensions
                indicating the number of symbol blocks and symbols per block.

        Raises:

            ValueError: If `states` does not match the dimensions of `symbols`.
        """

        super().__init__(symbols)
        self.states = states

    @property
    def states(self) -> np.ndarray | SparseArray:
        """Symbol state information.

        Four-dimensional numpy array with the first two dimensions indicating the
        MIMO receive and transmit streams, respectively and the last two dimensions
        indicating the number of symbol blocks and symbols per block.
        Assigned states are copied.

        Raises:

            ValueError: If the state array is not four-dimensional.
            ValueError: If the state dimensions don't match the symbol dimensions.
        """

        return self.__states

    @states.setter
    def states(self, value: np.ndarray | SparseArray) -> None:
        if value.ndim != 4:
            raise ValueError("State must be a four-dimensional numpy array")

        if value.shape[0] != self.num_streams:
            raise ValueError(
                f"Number of received streams don't match, expected {self.num_streams} instead of {value.shape[0]}"
            )

        if value.shape[2] != self.num_blocks:
            raise ValueError(
                f"Number of received blocks don't match, expected {self.num_blocks} instead of {value.shape[2]}"
            )

        if value.shape[3] != self.num_symbols:
            raise ValueError(
                f"Symbol block sizes don't match, expected {self.num_symbols} instead of {value.shape[3]}"
            )

        self.__states = value.copy()

    def dense_states(self) -> np.ndarray:
        """Return the channel state in dense format.

        Note that this method will convert the channel state to dense format if it is currently in sparse format.
        This operation may be computationally expensive and should be avoided if possible.

        Returns: The channel state tensor in dense format.
        """

        if isinstance(self.__states, SparseArray):
            return self.__states.todense()

        return self.__states

    @property
    def num_transmit_streams(self) -> int:
        """Number of impinging transmit streams.

        Returns: Number of streams.
        """

        return self.__states.shape[1]

    def copy(self) -> StatedSymbols:
        """Create a deep copy of this stated symbol sequence.

        Returns: Copied sequence, including an independent copy of the channel states.
        """

        # The states setter already copies its argument,
        # copying here as well would allocate the state tensor twice
        return StatedSymbols(self.raw.copy(), self.states)

    @classmethod
    def from_HDF(cls: Type[StatedSymbols], group: Group) -> StatedSymbols:
        # Recall datasets
        symbols = np.array(group["symbols"], dtype=complex)
        states = np.array(group["states"], dtype=complex)

        # Initialize object from recalled state
        return cls(symbols=symbols, states=states)

    def to_HDF(self, group: Group) -> None:
        # Serialize base class
        super().to_HDF(group)

        # Serialize datasets, sparse states are stored densely
        group.create_dataset("states", data=self.dense_states())