# -*- coding: utf-8 -*-
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Generic, List, Sequence, Type, TypeVar, Union
import matplotlib.pyplot as plt
import numpy as np
from ruamel.yaml import MappingNode, Node, SafeConstructor, SafeRepresenter
from hermespy.core import Serializable, SerializableEnum, Signal, VisualizableAttribute
from hermespy.core.visualize import ImageVisualization, VAT
from ...symbols import Symbols, StatedSymbols
from ...tools import PskQamMapping
from ...waveform import (
# Authorship and licensing metadata of this module.
__author__ = __maintainer__ = "Jan Adler"
__email__ = "jan.adler@barkhauseninstitut.org"
__credits__ = ["Jan Adler", "Tobias Kronauer"]
__copyright__ = "Copyright 2024, Barkhausen Institut gGmbH"
__license__ = "AGPLv3"

# Release information.
__version__, __status__ = "1.4.0", "Prototype"
class ElementType(SerializableEnum):
    """Type of resource element."""

    # Member values double as row indices into the boolean resource masks,
    # which are allocated with `len(ElementType)` rows. They must therefore be
    # the contiguous integers 0 .. len(ElementType) - 1.

    REFERENCE = 0
    """Reference element within the resource grid"""

    DATA = 1
    """Data element within the resource grid"""

    NULL = 2
    """Empty element within the resource grid"""
class PrefixType(SerializableEnum):
    """Type of prefix applied to the grid resource in time-domain."""

    CYCLIC = 0
    """Cyclic prefix repeating the resource waveform in time-domain"""

    ZEROPAD = 1
    """Zero-padded prefix in time-domain"""

    NONE = 2
    """No prefix applied"""
class GridElement(Serializable):
    """A run of consecutive, identically typed resource elements within a grid resource."""

    yaml_tag = "Element"
    serialized_attributes = {"type", "repetitions"}

    type: ElementType  # kind of resource element this run consists of
    repetitions: int = 1  # number of consecutive subcarriers covered by this run

    def __init__(self, type: str | ElementType, repetitions: int = 1) -> None:
        """
        Args:

            type (str | ElementType):
                Element type, given either as an `ElementType` member or as its member name.

            repetitions (int, optional):
                Number of consecutive subcarriers this run covers. Defaults to one.
        """

        # Resolve member names to the enumeration member, pass everything else through
        if isinstance(type, str):
            type = ElementType[type]

        self.type = type
        self.repetitions = repetitions
class ReferencePosition(SerializableEnum):
    """Applied channel estimation algorithm after reception."""

    # NOTE(review): no enumeration members are visible in this view of the
    # file, and the docstring above does not obviously match the class name.
    # Confirm the member list and the intended description against upstream.
class GridResource(Serializable):
    """Configures one sub-section of a resource grid in both dimensions.

    A resource is a block of `GridElement` runs along the frequency axis.
    The block is repeated `repetitions` times. In time-domain, the resource
    may be preceded by a prefix of type `prefix_type`, whose length relative
    to the block is `prefix_ratio`.
    """

    yaml_tag = "Resource"
    serialized_attributes = {"prefix_type", "elements"}

    __repetitions: int
    __prefix_ratio: float

    prefix_type: PrefixType
    """Prefix type of the frame resource"""

    elements: List[GridElement]
    """Individual resource elements"""

    def __init__(
        self,
        repetitions: int = 1,
        prefix_type: Union[PrefixType, str] = PrefixType.CYCLIC,
        prefix_ratio: float = 0.0,
        elements: List[GridElement] | None = None,
    ) -> None:
        """
        Args:

            repetitions (int, optional):
                Number of block repetitions along the frequency axis. Defaults to one.

            prefix_type (Union[PrefixType, str], optional):
                Prefix type, as `PrefixType` member or member name. Defaults to `PrefixType.CYCLIC`.

            prefix_ratio (float, optional):
                Ratio between prefix length and block length, within [0, 1]. Defaults to zero.

            elements (List[GridElement], optional):
                Element runs forming a single block. Defaults to an empty list.

        Raises:

            ValueError: If `repetitions` is smaller than one or `prefix_ratio` is outside [0, 1].
        """

        self.repetitions = repetitions
        self.prefix_ratio = prefix_ratio
        self.prefix_type = PrefixType[prefix_type] if isinstance(prefix_type, str) else prefix_type
        self.elements = elements if elements is not None else []

    @property
    def repetitions(self) -> int:
        """Number of block repetitions along the frequency axis.

        Returns:
            int: Number of repetitions.

        Raises:
            ValueError: If set to a value smaller than one.
        """

        return self.__repetitions

    @repetitions.setter
    def repetitions(self, reps: int) -> None:
        if reps < 1:
            raise ValueError("Number of frame resource repetitions must be greater or equal to one")

        self.__repetitions = reps

    @property
    def prefix_ratio(self) -> float:
        """Ratio between full block length and prefix length.

        Returns:
            float: The ratio between zero and one.

        Raises:
            ValueError: If ratio is less than zero or larger than one.
        """

        return self.__prefix_ratio

    @prefix_ratio.setter
    def prefix_ratio(self, ratio: float) -> None:
        if ratio < 0.0 or ratio > 1.0:
            raise ValueError(f"Cyclic prefix ratio must be between zero and one, not {ratio}")

        self.__prefix_ratio = ratio

    def _count_elements(self, element_type: ElementType | None = None) -> int:
        """Number of subcarriers within a single (unrepeated) block.

        Args:
            element_type (ElementType, optional):
                Only count runs of this type. By default, all runs are counted.

        Returns: The summed repetitions of the selected element runs.
        """

        return sum(
            element.repetitions
            for element in self.elements
            if element_type is None or element.type == element_type
        )

    @property
    def num_subcarriers(self) -> int:
        """Number of occupied subcarriers.

        Returns:
            int: Number of occupied subcarriers.
        """

        return self.__repetitions * self._count_elements()

    @property
    def num_symbols(self) -> int:
        """Number of data symbols this resource can modulate.

        Returns:
            Number of modulated symbols.
        """

        return self.__repetitions * self._count_elements(ElementType.DATA)

    @property
    def num_references(self) -> int:
        """Number of references symbols this resource can modulate.

        Returns:
            Number of modulated symbols.
        """

        return self.__repetitions * self._count_elements(ElementType.REFERENCE)

    @property
    def mask(self) -> np.ndarray:
        """Boolean mask selecting a specific type of element from the OFDM grid.

        Returns:
            Mask of dimension `len(ElementType)`x`num_subcarriers`.
            Row `t.value` flags the subcarriers occupied by elements of type `t`.
        """

        # Build the mask of a single block, initialized to all-false
        block_width = self._count_elements()
        block_mask = np.zeros((len(ElementType), block_width), dtype=bool)

        offset = 0
        for element in self.elements:
            block_mask[element.type.value, offset : offset + element.repetitions] = True
            offset += element.repetitions

        # Repeat the block mask according to the configured number of repetitions
        return np.tile(block_mask, (1, self.__repetitions))
# Type variable for orthogonal waveform types. The bound is given by name, as a
# forward reference to the `OrthogonalWaveform` class.
OWT = TypeVar("OWT", bound="OrthogonalWaveform")
class GridSection(Generic[OWT], ABC):
    """Description of a part of a grid's time domain."""

    __wave: OWT | None
    __num_repetitions: int
    __sample_offset: int

    def __init__(
        self, num_repetitions: int = 1, sample_offset: int = 0, wave: OWT | None = None
    ) -> None:
        """
        Args:

            num_repetitions (int, optional): Number of times this section is repeated in time-domain.
            sample_offset (int, optional): Offset in samples to the start of the section.
            wave (OWT, optional): Waveform this section is associated with. Defaults to None.

        Raises:

            ValueError: If `num_repetitions` is smaller than one.
        """

        # Initialize class attributes.
        # The assignment order is kept, since subclasses may override the setters.
        self.wave = wave
        self.sample_offset = sample_offset
        self.num_repetitions = num_repetitions

    @property
    def wave(self) -> OWT | None:
        """Waveform this section is associated with."""

        return self.__wave

    @wave.setter
    def wave(self, value: OWT | None) -> None:
        self.__wave = value

    @property
    def sample_offset(self) -> int:
        """Offset in samples to the start of the section.

        This can be used to exploit cyclic prefixes and suffixes in order to be more robust
        against timing offsets.
        """

        return self.__sample_offset

    @sample_offset.setter
    def sample_offset(self, value: int) -> None:
        self.__sample_offset = value

    @property
    def num_repetitions(self) -> int:
        """Number of section repetitions in the time-domain of an OFDM grid.

        Returns:
            int: The number of repetitions.

        Raises:
            ValueError: If set to a value smaller than one.
        """

        return self.__num_repetitions

    @num_repetitions.setter
    def num_repetitions(self, value: int) -> None:
        if value < 1:
            raise ValueError("OFDM frame number of repetitions must be greater or equal to one")

        self.__num_repetitions = value

    @property
    def num_symbols(self) -> int:
        """Number of data symbols this section can modulate.

        Returns:
            int: The number of symbols
        """

        return 0

    @property
    def num_references(self) -> int:
        """Number of reference symbols this section can modulate.

        Returns:
            int: The number of reference symbols
        """

        return 0

    @property
    def num_words(self) -> int:
        """Number of OFDM symbols, i.e. words of subcarrier symbols this section can modulate.

        Returns:
            int: The number of words.
        """

        return 0

    @property
    def num_subcarriers(self) -> int:
        """Number of subcarriers this section requires.

        Returns:
            int: The number of subcarriers.
        """

        return 0

    @property
    def resource_mask(self) -> np.ndarray:
        """Boolean masks selecting each element type within this section's grid.

        Returns:
            Boolean array of dimensions `len(ElementType)`x`num_words`x`num_subcarriers`.
        """

        return np.empty((len(ElementType), 0, 0), dtype=bool)

    @property
    @abstractmethod
    def num_samples(self) -> int:
        """Number of samples within this OFDM time-section.

        Returns:
            int: Number of samples
        """
        ...  # pragma: no cover

    def place_symbols(self, data_symbols: np.ndarray, reference_symbols: np.ndarray) -> np.ndarray:
        """Place this section's symbols into the resource grid.

        Args:

            data_symbols (numpy.ndarray): Data symbols to be placed. Numpy vector of size `num_symbols`.
            reference_symbols (numpy.ndarray): Reference symbols to be placed. Numpy vector of size `num_references`.

        Returns: Two dimensional numpy array of size `num_words`x`num_subcarriers`.
        """

        # Collect resource masks
        mask = self.resource_mask

        # Fill the reference and data positions, all other positions remain zero
        grid = np.zeros((self.num_words, self.num_subcarriers), dtype=np.complex128)
        grid[mask[ElementType.REFERENCE.value, ::]] = reference_symbols
        grid[mask[ElementType.DATA.value, ::]] = data_symbols

        return grid

    def pick_symbols(self, grid: np.ndarray) -> np.ndarray:
        """Pick this section's data symbols from the resource grid.

        Args:

            grid (numpy.ndarray):
                Resource grid. Numpy array whose last two dimensions are `num_words`x`num_subcarriers`.
                Any leading dimensions are preserved.

        Returns: Data symbols. Numpy array of size `num_symbols` along its last dimension.
        """

        # Collect resource masks
        mask = self.resource_mask

        # Restrict the grid to the subcarriers this section occupies
        subgrid = grid[..., : self.num_subcarriers]

        # Boolean-index the trailing (word, subcarrier) dimensions with the data mask,
        # keeping all leading dimensions intact
        leading_selector = (slice(None),) * (subgrid.ndim - 2)
        return subgrid[leading_selector + (mask[ElementType.DATA.value],)]

    @abstractmethod
    def place_samples(self, signal: np.ndarray) -> np.ndarray:
        """Place this section's samples into the time-domain signal.

        Args:

            signal (numpy.ndarray): Time-domain signal to be placed. Numpy vector of size `num_samples`.

        Returns: Time-domain signal with the section's samples placed.
        """
        ...  # pragma: no cover

    @abstractmethod
    def pick_samples(self, signal: np.ndarray) -> np.ndarray:
        """Pick this section's samples from the time-domain signal.

        Args:

            signal (numpy.ndarray): Time-domain signal to be picked from. Numpy vector of size `num_samples`.

        Returns: Time-domain signal with the section's samples picked.
        """
        ...  # pragma: no cover
class SymbolSection(GridSection["OrthogonalWaveform"], Serializable):
    """Grid section modulating a repeated pattern of grid resources.

    Each entry of `pattern` indexes one of the waveform's grid resources and
    yields one word (OFDM symbol) of this section. The whole pattern is
    repeated `num_repetitions` times in time-domain.
    """

    yaml_tag: str = "Symbol"
    serialized_attributes = {"pattern"}

    pattern: List[int]  # indices into the waveform's grid resources, one per word

    def __init__(
        self,
        num_repetitions: int = 1,
        pattern: List[int] | None = None,
        sample_offset: int = 0,
        wave: OrthogonalWaveform | None = None,
    ) -> None:
        """
        Args:

            num_repetitions (int, optional): Number of times this section is repeated in time-domain.
            pattern (List[int], optional): Resource pattern within this symbol section.
            sample_offset (int, optional): Offset in samples to the start of the section.
            wave (OrthogonalWaveform, optional): Waveform this section is associated with. Defaults to None.
        """

        # Initialize base class
        GridSection.__init__(self, num_repetitions, sample_offset, wave)

        # Initialize class attributes
        self.pattern = pattern if pattern is not None else []

    def _pattern_resources(self) -> List[GridResource]:
        """Grid resources referenced by the pattern, in pattern order (one per word)."""

        return [self.wave.grid_resources[resource_idx] for resource_idx in self.pattern]

    @staticmethod
    def _num_prefix_samples(resource: GridResource, num_slot_samples: int) -> int:
        """Number of prefix samples preceding `resource` in time-domain.

        This is the single source of truth for prefix lengths. `num_samples`,
        `place_samples` and `pick_samples` therefore agree on the frame layout.
        A resource with `PrefixType.NONE` never contributes prefix samples.
        """

        if resource.prefix_type == PrefixType.NONE:
            return 0

        return int(num_slot_samples * resource.prefix_ratio)

    @property
    def num_symbols(self) -> int:
        return self.num_repetitions * sum(r.num_symbols for r in self._pattern_resources())

    @property
    def num_references(self) -> int:
        return self.num_repetitions * sum(r.num_references for r in self._pattern_resources())

    @property
    def num_words(self) -> int:
        return self.num_repetitions * len(self.pattern)

    @property
    def num_subcarriers(self) -> int:
        # The section is as wide as its widest resource
        return max((r.num_subcarriers for r in self._pattern_resources()), default=0)

    @property
    def _padded_num_subcarriers(self) -> int:
        """Number of subcarriers required to represent this section in time-domain."""

        return self.wave.num_subcarriers * self.wave.oversampling_factor

    def place_samples(self, samples: np.ndarray) -> np.ndarray:
        """Concatenate the words' time-domain samples, inserting each resource's prefix.

        Args:

            samples (numpy.ndarray): Time-domain words, iterated row by row in word order.

        Returns: Time-domain signal of `num_samples` samples.

        Raises:

            RuntimeError: If a resource with a non-empty prefix has an unsupported prefix type.
        """

        num_slot_samples = self._padded_num_subcarriers

        # Zero-initialize so that no uninitialized memory can leak into the signal
        placed_samples = np.zeros(self.num_samples, dtype=np.complex128)
        sample_idx = 0

        word_idx: int
        resource_samples: np.ndarray
        for word_idx, resource_samples in enumerate(samples):
            # Infer the resource from the pattern
            resource = self.wave.grid_resources[self.pattern[word_idx % len(self.pattern)]]
            num_prefix_samples = self._num_prefix_samples(resource, num_slot_samples)

            # Only add a prefix if required
            if num_prefix_samples > 0:
                if resource.prefix_type == PrefixType.CYCLIC:
                    # Repeat the tail of the resource waveform in front of it
                    placed_samples[sample_idx : sample_idx + num_prefix_samples] = resource_samples[
                        -num_prefix_samples:
                    ]

                elif resource.prefix_type == PrefixType.ZEROPAD:
                    placed_samples[sample_idx : sample_idx + num_prefix_samples] = np.zeros(
                        num_prefix_samples, dtype=np.complex128
                    )

                else:
                    # Raise exception for unsupported prefix types
                    raise RuntimeError("Unsupported prefix type configured")

                # Advance the sample index by the prefix length
                sample_idx += num_prefix_samples

            # Append base resource waveform after prefix
            placed_samples[sample_idx : sample_idx + resource_samples.size] = resource_samples
            sample_idx += resource_samples.size

        return placed_samples

    def pick_samples(self, samples: np.ndarray) -> np.ndarray:
        """Cut the section's words out of a time-domain signal, discarding prefixes.

        Args:

            samples (numpy.ndarray): Time-domain signal, samples along the last dimension.

        Returns:
            Array of shape `(*samples.shape[:-1], num_words, padded_num_subcarriers)`.
        """

        num_slot_samples = self._padded_num_subcarriers
        num_words = len(self.pattern) * self.num_repetitions

        resource_samples = np.empty(
            (*samples.shape[:-1], num_words, num_slot_samples), dtype=complex
        )

        sample_index = 0
        for word_idx in range(num_words):
            # Infer the resource from the pattern
            resource = self.wave.grid_resources[self.pattern[word_idx % len(self.pattern)]]

            # Skip the prefix
            sample_index += self._num_prefix_samples(resource, num_slot_samples)

            # Sort the word's samples into their respective matrix section.
            # The window is shifted back by the configured sample offset.
            start = sample_index - self.sample_offset
            resource_samples[..., word_idx, :] = samples[..., start : start + num_slot_samples]

            # Advance sample index by resource length
            sample_index += num_slot_samples

        return resource_samples

    @property
    def resource_mask(self) -> np.ndarray:
        # Initialize the base mask as all false
        mask = np.zeros((len(ElementType), len(self.pattern), self.num_subcarriers), dtype=bool)

        for word_idx, resource in enumerate(self._pattern_resources()):
            mask[:, word_idx, : resource.num_subcarriers] = resource.mask

        # Repeat the pattern's mask along the word axis
        return np.tile(mask, (1, self.num_repetitions, 1))

    @property
    def num_samples(self) -> int:
        num_slot_samples = self._padded_num_subcarriers

        # Each word consists of its base samples plus its (possibly empty) prefix
        num = sum(
            num_slot_samples + self._num_prefix_samples(resource, num_slot_samples)
            for resource in self._pattern_resources()
        )

        return num * self.num_repetitions
class GuardSection(GridSection["OrthogonalWaveform"], Serializable):
    """Silent guard interval within a resource grid.

    The section produces zero-valued samples for `duration` seconds per repetition.
    """

    yaml_tag = "Guard"

    __duration: float

    def __init__(
        self, duration: float, num_repetitions: int = 1, frame: OrthogonalWaveform | None = None
    ) -> None:
        """
        Args:

            duration (float): Guard duration in seconds.
            num_repetitions (int, optional): Number of times this section is repeated in time-domain.
            frame (OrthogonalWaveform, optional): Waveform this section is associated with.

        Raises:

            ValueError: If `duration` is negative or `num_repetitions` is smaller than one.
        """

        GridSection.__init__(self, num_repetitions=num_repetitions, wave=frame)
        self.duration = duration

    @property
    def duration(self) -> float:
        """Guard section duration in seconds.

        Returns:
            float: Duration in seconds.

        Raises:
            ValueError: If set to a value smaller than zero.
        """

        return self.__duration

    @duration.setter
    def duration(self, value: float) -> None:
        if value < 0.0:
            raise ValueError("Guard section duration must be greater or equal to zero")

        self.__duration = value

    @property
    def num_samples(self) -> int:
        # Rounded down once over all repetitions
        return int(self.num_repetitions * self.__duration * self.wave.sampling_rate)

    def place_samples(self, signal: np.ndarray) -> np.ndarray:
        # The guard interval is silent; the input is ignored
        return np.zeros(self.num_samples, dtype=np.complex128)

    def pick_samples(self, signal: np.ndarray) -> np.ndarray:
        # A guard interval carries no words, hence an empty word matrix
        return np.empty(
            (0, self.wave.num_subcarriers * self.wave.oversampling_factor), dtype=np.complex128
        )
class PilotSection(Generic[OWT], GridSection[OWT], Serializable):
    """Pilot symbol section within a resource grid."""

    yaml_tag = "Pilot"
    """YAML serialization tag"""

    __pilot_elements: Symbols | None
    __cached_num_subcarriers: int
    __cached_oversampling_factor: int
    __cached_pilot: np.ndarray | None

    def __init__(self, pilot_elements: Symbols | None = None, wave: OWT | None = None) -> None:
        """
        Args:

            pilot_elements (Symbols, optional):
                Symbols with which the subcarriers within the pilot will be modulated.
                By default, a pseudo-random sequence from the frame mapping will be generated.

            wave (OWT, optional):
                The waveform configuration this pilot section is associated with.
        """

        # Initialize base class
        GridSection.__init__(self, 1, 0, wave=wave)

        # Initialize class attributes
        self.__pilot_elements = pilot_elements
        self.__cached_num_subcarriers = -1
        self.__cached_oversampling_factor = -1
        self.__cached_pilot = None

    @GridSection.num_repetitions.setter  # type: ignore
    def num_repetitions(self, value: int) -> None:
        if value != 1:
            raise ValueError("Pilot sections may not be repeated")

        GridSection.num_repetitions.fset(self, value)  # type: ignore

    @GridSection.sample_offset.setter  # type: ignore
    def sample_offset(self, value: int) -> None:
        if value != 0:
            raise ValueError("Pilot sections may not have a sample offset")

        GridSection.sample_offset.fset(self, value)  # type: ignore

    @property
    def num_samples(self) -> int:
        return self.wave.num_subcarriers * self.wave.oversampling_factor

    @property
    def num_symbols(self) -> int:
        return 0

    @property
    def num_words(self) -> int:
        return 1

    @property
    def num_subcarriers(self) -> int:
        return self.wave.num_subcarriers if self.wave else 0

    @property
    def num_references(self) -> int:
        # No references are reported when explicit pilot elements are configured
        # or no waveform is associated. NOTE(review): confirm this is intended.
        if self.__pilot_elements or self.wave is None:
            return 0

        return self.wave.num_subcarriers

    @property
    def resource_mask(self) -> np.ndarray:
        # A single word whose subcarriers are all reference elements
        mask = np.zeros(
            (len(ElementType), 1, self.wave.num_subcarriers if self.wave else 0), dtype=bool
        )
        mask[ElementType.REFERENCE.value, 0, ::] = True

        return mask

    @property
    def pilot_elements(self) -> Symbols | None:
        """Symbols with which the orthogonal subcarriers within the pilot will be modulated.

        Returns:
            A stream of symbols.
            `None`, if no pilot symbols were specified.

        Raises:
            ValueError: If the configured symbols contains multiple streams or no symbols.
        """

        return self.__pilot_elements

    @pilot_elements.setter
    def pilot_elements(self, value: Symbols | None) -> None:
        if value is not None:
            if value.num_streams != 1:
                raise ValueError("Subsymbol pilot configuration may only contain a single stream")

            if value.num_symbols < 1:
                raise ValueError("Subsymbol pilot configuration must contain at least one symbol")

        # Reset the cached pilot, since it was generated from the previous pilot elements.
        # This also applies when the elements are cleared.
        self.__cached_pilot = None
        self.__pilot_elements = value

    def _pilot_sequence(self, num_symbols: int | None = None) -> np.ndarray:
        """Generate a new sequence of pilot elements.

        Args:

            num_symbols (int, optional):
                The required number of symbols.
                By default, a symbol for each subcarrier is generated.

        Returns:
            A sequence of symbols.
        """

        num_symbols = self.wave.num_subcarriers if num_symbols is None else num_symbols

        if self.__pilot_elements is None:
            # Generate a pseudo-random symbol stream if no subsymbols are specified.
            # The fixed seed makes repeated calls yield the same sequence.
            rng = np.random.default_rng(50)
            num_bits = num_symbols * self.wave.mapping.bits_per_symbol
            subsymbols = self.wave.mapping.get_symbols(rng.integers(0, 2, num_bits))

        else:
            # Cyclically repeat the configured elements until enough symbols are available
            num_repetitions = int(np.ceil(num_symbols / self.__pilot_elements.num_symbols))
            subsymbols = np.tile(self.__pilot_elements.raw.flat, (num_repetitions))

        return subsymbols[:num_symbols]

    def place_symbols(self, data_symbols: np.ndarray, reference_symbols: np.ndarray) -> np.ndarray:
        # The provided reference symbols are replaced by the pilot sequence
        reference_symbols = self._pilot_sequence(self.wave.num_subcarriers)
        return GridSection.place_symbols(self, data_symbols, reference_symbols)

    def place_samples(self, signal: np.ndarray) -> np.ndarray:
        # Just a stub, since the pilot section does not consider any prefixing
        return signal

    def pick_samples(self, signal: np.ndarray) -> np.ndarray:
        # Just a stub, since the pilot section does not consider any prefixing
        return signal

    def generate(self) -> np.ndarray:
        """Generate the pilot section in time domain.

        Returns: The time-domain pilot signal.

        Raises:

            RuntimeError: If the section is not associated with a waveform.
        """

        if self.wave is None:
            raise RuntimeError("Pilot section must be associated with a waveform")

        # Return the cached pilot signal if available and the relevant frame parameters haven't changed
        if (
            self.__cached_pilot is not None
            and self.__cached_num_subcarriers == self.wave.num_subcarriers
            and self.__cached_oversampling_factor == self.wave.oversampling_factor
        ):
            return self.__cached_pilot

        pilot_symbols = self._pilot_sequence(self.wave.num_subcarriers)
        pilot = self.wave._forward_transformation(pilot_symbols[np.newaxis, :])

        # Cache the pilot
        self.__cached_pilot = pilot
        self.__cached_num_subcarriers = self.wave.num_subcarriers
        self.__cached_oversampling_factor = self.wave.oversampling_factor

        return pilot

    @classmethod
    def to_yaml(
        cls: Type[PilotSection], representer: SafeRepresenter, node: PilotSection
    ) -> MappingNode:
        """Serialize a serializable object to YAML.

        Args:

            representer (SafeRepresenter):
                A handle to a representer used to generate valid YAML code.
                The representer gets passed down the serialization tree to each node.

            node (PilotSection):
                The pilot section instance to be serialized.

        Returns: The serialized YAML node.
        """

        # Pilot elements are serialized as their raw symbol array, if configured
        additional_fields = {}
        if node.pilot_elements:
            additional_fields["pilot_elements"] = node.pilot_elements.raw

        return node._mapping_serialization_wrapper(
            representer, blacklist={"pilot_elements"}, additional_fields=additional_fields
        )

    @classmethod
    def from_yaml(
        cls: Type[PilotSection], constructor: SafeConstructor, node: Node
    ) -> PilotSection:
        """Recall a new serializable class instance from YAML.

        Args:

            constructor (SafeConstructor):
                A handle to the constructor extracting the YAML information.

            node (Node):
                YAML node representing the `PilotSection` serialization.

        Returns: The de-serialized object.
        """

        state: dict = constructor.construct_mapping(node, deep=True)

        # Re-wrap raw pilot element arrays into symbols
        pilot_elements = state.pop("pilot_elements", None)
        if pilot_elements is not None:
            pilot_elements = Symbols(pilot_elements)
        state["pilot_elements"] = pilot_elements

        return cls.InitializationWrapper(state)
class GridVisualization(VisualizableAttribute[ImageVisualization]):
"""Plot the grid structure of an orthogonal waveform."""
def __init__(self, wave: OrthogonalWaveform) -> None:
wave (OrthogonalWaveform): Waveform this plot is associated with.
# Initialize base class
# Initialize class attributes
self.__wave = wave
def title(self) -> str:
return "Resource Grid"
def __generate_image(self) -> np.ndarray:
mask = self.__wave.resource_mask
grid = np.zeros(mask.shape[1:], dtype=np.int_)
grid[mask[ElementType.NULL.value]] = 1
grid[mask[ElementType.REFERENCE.value]] = 2
grid[mask[ElementType.DATA.value]] = 3
return grid.T
def _prepare_visualization(
self, figure: plt.Figure | None, axes: VAT, **kwargs
) -> ImageVisualization:
ax: plt.Axes = axes.flat[0]
image = ax.imshow(self.__generate_image(), cmap="viridis", aspect="auto")
return ImageVisualization(figure, axes, image)
def _update_visualization(self, visualization: ImageVisualization, **kwargs) -> None: