Source code for hermespy.modem.waveform_single_carrier

# -*- coding: utf-8 -*-

from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any
from typing_extensions import override

import matplotlib.pyplot as plt
import numpy as np

from hermespy.core import (
    Executable,
    FloatingError,
    Serializable,
    Signal,
    SerializationProcess,
    DeserializationProcess,
)
from .waveform import (
    ConfigurablePilotWaveform,
    MappedPilotSymbolSequence,
    CommunicationWaveform,
    ChannelEstimation,
    ChannelEqualization,
    PilotSymbolSequence,
    ZeroForcingChannelEqualization,
)
from hermespy.modem.tools.psk_qam_mapping import PskQamMapping
from .symbols import StatedSymbols, Symbols
from .waveform_correlation_synchronization import CorrelationSynchronization

__author__ = "Andre Noll Barreto"
__copyright__ = "Copyright 2024, Barkhausen Institut gGmbH"
__credits__ = ["Andre Noll Barreto", "Tobias Kronauer", "Jan Adler"]
__license__ = "AGPLv3"
__version__ = "1.5.0"
__maintainer__ = "Jan Adler"
__email__ = "jan.adler@barkhauseninstitut.org"
__status__ = "Prototype"


class FilteredSingleCarrierWaveform(ConfigurablePilotWaveform):
    """Base class of a generic filtered single-carrier PSK/QAM modem.

    Symbols are placed on an oversampled comb and convolved with a transmit
    pulse; reception convolves with a receive pulse and re-samples at the
    symbol instants. The pulses are supplied by subclasses through
    :meth:`_transmit_filter`, :meth:`_receive_filter` and :attr:`_filter_delay`.

    The modem has the following characteristics:

    - root-raised cosine filter with arbitrary roll-off factor
    - arbitrary constellation, as defined in modem.tools.psk_qam_mapping:PskQamMapping

    This implementation has currently the following limitations:

    - hard output only (no LLR)
    - no reference signal
    - ideal channel estimation
    - equalization of ISI with FMCW in AWGN channel only
    - no equalization (only amplitude and phase of first propagation path is compensated)
    """

    # Defaults shared by the constructor and the deserialization routine
    __DEFAULT_SYMBOL_RATE: float = 1e6
    __DEFAULT_NUM_PREAMBLE_SYMBOLS: int = 16
    __DEFAULT_NUM_DATA_SYMBOLS: int = 256
    __DEFAULT_NUM_POSTAMBLE_SYMBOLS: int = 0
    __DEFAULT_PILOT_RATE: int = 0
    __DEFAULT_GUARD_INTERVAL: float = 0.0
    __DEFAULT_OVERSAMPLING_FACTOR: int = 4

    __symbol_rate: float
    __num_preamble_symbols: int
    __num_data_symbols: int
    __num_postamble_symbols: int
    __guard_interval: float
    __mapping: PskQamMapping
    __pilot_rate: int
    _data_symbol_idx: np.ndarray | None
    _pulse_correlation_matrix: np.ndarray | None

    def __init__(
        self,
        symbol_rate: float = __DEFAULT_SYMBOL_RATE,
        num_preamble_symbols: int = __DEFAULT_NUM_PREAMBLE_SYMBOLS,
        num_data_symbols: int = __DEFAULT_NUM_DATA_SYMBOLS,
        num_postamble_symbols: int = __DEFAULT_NUM_POSTAMBLE_SYMBOLS,
        pilot_rate: int = __DEFAULT_PILOT_RATE,
        guard_interval: float = __DEFAULT_GUARD_INTERVAL,
        oversampling_factor: int = __DEFAULT_OVERSAMPLING_FACTOR,
        pilot_symbol_sequence: PilotSymbolSequence | None = None,
        repeat_pilot_symbol_sequence: bool = True,
        **kwargs: Any,
    ) -> None:
        # Raw docstring: the escaped asterisks below are meant for the
        # documentation generator and are not valid Python escape sequences.
        r"""
        Args:

            symbol_rate:
                Rate at which symbols are being generated in Hz.

            num_preamble_symbols:
                Number of preamble symbols within a single communication frame.

            num_data_symbols:
                Number of data symbols within a single communication frame.

            num_postamble_symbols:
                Number of postamble symbols within a single communication frame.

            pilot_rate:
                Pilot symbol rate.
                Zero by default, i.e. no pilot symbols.

            guard_interval:
                Guard interval between communication frames in seconds.
                Zero by default.

            oversampling_factor:
                The oversampling factor of the waveform.

            pilot_symbol_sequence:
                The configured pilot symbol sequence.
                If not specified, a sequence mapped from the waveform's
                PSK/QAM mapping is used.

            repeat_pilot_symbol_sequence:
                Allow the repetition of pilot symbol sequences.
                Enabled by default.

            \*\*kwargs:
                Waveform generator base class initialization parameters.
        """

        # Init base class
        ConfigurablePilotWaveform.__init__(
            self,
            repeat_symbol_sequence=repeat_pilot_symbol_sequence,
            oversampling_factor=oversampling_factor,
            **kwargs,
        )

        # Route everything through the validating property setters
        self.symbol_rate = symbol_rate
        self.num_preamble_symbols = num_preamble_symbols
        self.num_data_symbols = num_data_symbols
        self.num_postamble_symbols = num_postamble_symbols
        self.pilot_rate = pilot_rate
        self.guard_interval = guard_interval
        self.pilot_symbol_sequence = (
            MappedPilotSymbolSequence(self.__mapping)
            if pilot_symbol_sequence is None
            else pilot_symbol_sequence
        )

    @abstractmethod
    def _transmit_filter(self) -> np.ndarray:
        """Pulse shaping filter applied to data symbols during transmission.

        Returns: The shaping filter impulse response.
        """
        ...  # pragma: no cover

    @abstractmethod
    def _receive_filter(self) -> np.ndarray:
        """Pulse shaping filter applied to signal streams during reception.

        Returns: The shaping filter impulse response.
        """
        ...  # pragma: no cover

    @property
    @abstractmethod
    def _filter_delay(self) -> int:
        """Cumulative delay introduced during transmit and receive filtering.

        Returns: Delay in samples.
        """
        ...  # pragma: no cover

    @property
    def symbol_rate(self) -> float:
        """Repetition rate of symbols.

        Returns: Symbol rate in Hz.

        Raises:
            ValueError: For rates smaller or equal to zero.
        """

        return self.__symbol_rate

    @symbol_rate.setter
    def symbol_rate(self, value: float) -> None:
        if value <= 0.0:
            raise ValueError("Symbol rate must be greater than zero")

        self.__symbol_rate = value

    @property
    def num_preamble_symbols(self) -> int:
        """Number of preamble symbols.

        Transmitted at the beginning of communication frames.

        Raises:
            ValueError: If the number of symbols is smaller than zero.
        """

        return self.__num_preamble_symbols

    @num_preamble_symbols.setter
    def num_preamble_symbols(self, value: int) -> None:
        if value < 0:
            raise ValueError("Nummber of preamble symbols must be greater or equal to zero")

        self.__num_preamble_symbols = value

    @property
    def num_postamble_symbols(self) -> int:
        """Number of postamble symbols.

        Transmitted at the end of communication frames.

        Raises:
            ValueError: If the number of symbols is smaller than zero.
        """

        return self.__num_postamble_symbols

    @num_postamble_symbols.setter
    def num_postamble_symbols(self, value: int) -> None:
        if value < 0:
            raise ValueError("Nummber of postamble symbols must be greater or equal to zero")

        self.__num_postamble_symbols = value

    @CommunicationWaveform.modulation_order.setter  # type: ignore
    def modulation_order(self, value: int) -> None:
        # Rebuild the hard-decision mapping and the pilot sequence derived
        # from it before forwarding the new order to the base class setter
        self.__mapping = PskQamMapping(value, soft_output=False)
        self.pilot_symbol_sequence = MappedPilotSymbolSequence(
            self.__mapping
        )  # ToDo: Find a better way to update the pilot symbol sequence
        CommunicationWaveform.modulation_order.fset(self, value)  # type: ignore

    @property
    def pilot_signal(self) -> Signal:
        """Signal generated by pulse-shaping the preamble symbols.

        An empty signal is returned if no preamble symbols are configured.
        """

        if self.num_preamble_symbols < 1:
            return Signal.Empty(self.sampling_rate)

        # Place one preamble symbol every `oversampling_factor` samples
        pilot_symbols = np.zeros(
            1 + (self.num_preamble_symbols - 1) * self.oversampling_factor, dtype=complex
        )
        pilot_symbols[:: self.oversampling_factor] = self.pilot_symbols(self.num_preamble_symbols)

        return Signal.Create(
            np.convolve(pilot_symbols, self._transmit_filter()), sampling_rate=self.sampling_rate
        )

    @override
    def map(self, bits: np.ndarray) -> Symbols:
        return Symbols(self.__mapping.get_symbols(bits))

    @override
    def unmap(self, symbols: Symbols) -> np.ndarray:
        return self.__mapping.detect_bits(symbols.raw.flatten())

    @override
    def place(self, data_symbols: Symbols) -> Symbols:
        # Generate pilot symbol sequences
        pilot_symbols = self.pilot_symbols(
            self.num_preamble_symbols + self._num_pilot_symbols + self.num_postamble_symbols
        )

        placed_symbols = np.empty(self._num_frame_symbols, dtype=np.complex128)

        # Assign preamble symbols within the frame
        placed_symbols[: self.num_preamble_symbols] = pilot_symbols[: self.num_preamble_symbols]

        # Assign postamble symbols within the frame
        placed_symbols[self.num_preamble_symbols + self._num_payload_symbols :] = pilot_symbols[
            self.num_preamble_symbols + self._num_pilot_symbols :
        ]

        # Assign payload symbols within the frame
        # The payload consists of data symbols interleaved with pilots according to the pilot rate
        placed_symbols[self.num_preamble_symbols + self._pilot_symbol_indices] = pilot_symbols[
            self.num_preamble_symbols : self.num_preamble_symbols + self._num_pilot_symbols
        ]
        placed_symbols[self.num_preamble_symbols + self._data_symbol_indices] = (
            data_symbols.raw.flatten()
        )

        return Symbols(placed_symbols[np.newaxis, :, np.newaxis])

    @override
    def pick(self, symbols: StatedSymbols) -> StatedSymbols:
        # Data symbol positions are stored relative to the payload start
        data_block_indices = self.num_preamble_symbols + self._data_symbol_indices

        picked_symbol_blocks = symbols.raw[:, data_block_indices, :]
        picked_state_blocks = symbols.states[:, :, data_block_indices, :]

        return StatedSymbols(picked_symbol_blocks, picked_state_blocks)

    @override
    def modulate(self, symbols: Symbols) -> np.ndarray:
        frame = np.zeros(
            1 + (self._num_frame_symbols - 1) * self.oversampling_factor, dtype=complex
        )
        frame[:: self.oversampling_factor] = symbols.raw.flatten()

        # Generate waveforms by treating the frame as a comb and convolving with the impulse response
        output_signal = np.convolve(frame, self._transmit_filter())
        return output_signal

    @override
    def demodulate(self, signal: np.ndarray) -> Symbols:
        # Query filters
        filter_delay = self._filter_delay

        # Filter the signal, then sample once per symbol starting at the filter delay
        filtered_signal = np.convolve(signal, self._receive_filter())
        symbols = filtered_signal[
            filter_delay : filter_delay
            + self._num_frame_symbols * self.oversampling_factor : self.oversampling_factor
        ]

        return Symbols(symbols[np.newaxis, :, np.newaxis])

    @property
    def guard_interval(self) -> float:
        """Frame guard interval.

        Raises:
            ValueError: If `interval` is smaller than zero.
        """

        return self.__guard_interval

    @guard_interval.setter
    def guard_interval(self, interval: float) -> None:
        if interval < 0.0:
            raise ValueError("Guard interval must be greater or equal to zero")

        self.__guard_interval = interval

    @property
    def num_guard_samples(self) -> int:
        """Number of samples within the guarding section of a frame."""

        return int(round(self.guard_interval * self.sampling_rate))

    @property
    def pilot_rate(self) -> int:
        """Repetition rate of pilot symbols within the frame.

        A pilot rate of zero indicates no pilot symbols within the data frame.

        Raises:
            ValueError: If the pilot rate is smaller than zero.
        """

        return self.__pilot_rate

    @pilot_rate.setter
    def pilot_rate(self, value: int) -> None:
        if value < 0:
            raise ValueError("Pilot symbol rate must be greater or equal to zero")

        self.__pilot_rate = int(value)

    @property
    def _num_pilot_symbols(self) -> int:
        """Number of pilot symbols interleaved with the data symbols."""

        if self.pilot_rate <= 0:
            return 0

        return max(0, int(self.num_data_symbols / self.pilot_rate) - 1)

    @property
    def _num_payload_symbols(self) -> int:
        """Number of data plus interleaved pilot symbols."""

        num_symbols = self.num_data_symbols + self._num_pilot_symbols
        return num_symbols

    @property
    def _num_frame_symbols(self) -> int:
        """Number of preamble, payload and postamble symbols combined."""

        return self.num_preamble_symbols + self._num_payload_symbols + self.num_postamble_symbols

    @property
    def _pilot_symbol_indices(self) -> np.ndarray:
        """Indices of pilot symbols within the payload section of the frame."""

        if self.pilot_rate <= 0:
            return np.empty(0, dtype=int)

        # One pilot after every `pilot_rate` data symbols
        pilot_indices = np.arange(1, 1 + self._num_pilot_symbols) * (1 + self.pilot_rate) - 1
        return pilot_indices

    @property
    def _data_symbol_indices(self) -> np.ndarray:
        """Indices of data symbols within the payload section of the frame."""

        data_indices = np.arange(self._num_payload_symbols)

        # Evaluate the pilot index property once and drop those positions
        pilot_indices = self._pilot_symbol_indices
        if len(pilot_indices) > 0:
            data_indices = np.delete(data_indices, pilot_indices)

        return data_indices

    @property
    def num_data_symbols(self) -> int:
        """Number of data symbols per frame.

        Raises:
            ValueError: If `num` is smaller than zero.
        """

        return self.__num_data_symbols

    @num_data_symbols.setter
    def num_data_symbols(self, num: int) -> None:
        if num < 0:
            raise ValueError("Number of data symbols must be greater or equal to zero")

        self.__num_data_symbols = num

    @property
    def samples_per_frame(self) -> int:
        # Length of the symbol comb convolved with the transmit filter
        return (
            self._num_frame_symbols - 1
        ) * self.oversampling_factor + self._transmit_filter().shape[0]

    @property
    @override
    def symbol_duration(self) -> float:
        return 1 / self.symbol_rate

    @property
    @override
    def bit_energy(self) -> float:
        return 1 / self.bits_per_symbol

    @property
    @override
    def symbol_energy(self) -> float:
        return 1.0

    @property
    @override
    def power(self) -> float:
        return 1 / self.oversampling_factor

    @property
    @override
    def sampling_rate(self) -> float:
        return self.symbol_rate * self.oversampling_factor

    def plot_filter_correlation(self) -> plt.Figure:
        """Plot the convolution between transmit and receive filter shapes.

        Returns: Handle to the generated matplotlib figure.
        """

        with Executable.style_context():
            tx_filter = self._transmit_filter()
            rx_filter = self._receive_filter()

            autocorrelation = np.convolve(tx_filter, rx_filter)

            fig, axes = plt.subplots()
            fig.suptitle("Pulse Autocorrelation")

            axes.plot(np.abs(autocorrelation))

            return fig

    def plot_filter(self) -> plt.Figure:
        """Plot the transmit filter shape.

        Returns: Handle to the generated matplotlib figure.
        """

        with Executable.style_context():
            tx_filter = self._transmit_filter()

            fig, axes = plt.subplots()
            fig.suptitle("Pulse Shape")

            # Real and imaginary parts are drawn as separate curves
            axes.plot(tx_filter.real)
            axes.plot(tx_filter.imag)

            return fig

    @override
    def serialize(self, process: SerializationProcess) -> None:
        ConfigurablePilotWaveform.serialize(self, process)
        process.serialize_floating(self.symbol_rate, "symbol_rate")
        process.serialize_integer(self.num_preamble_symbols, "num_preamble_symbols")
        process.serialize_integer(self.num_data_symbols, "num_data_symbols")
        process.serialize_integer(self.num_postamble_symbols, "num_postamble_symbols")
        process.serialize_floating(self.guard_interval, "guard_interval")
        process.serialize_integer(self.pilot_rate, "pilot_rate")

    @override
    @classmethod
    def _DeserializeParameters(cls, process: DeserializationProcess) -> dict[str, Any]:
        # Missing fields fall back to the class-level defaults
        parameters = ConfigurablePilotWaveform._DeserializeParameters(process)
        parameters["symbol_rate"] = process.deserialize_floating(
            "symbol_rate", cls.__DEFAULT_SYMBOL_RATE
        )
        parameters["num_preamble_symbols"] = process.deserialize_integer(
            "num_preamble_symbols", cls.__DEFAULT_NUM_PREAMBLE_SYMBOLS
        )
        parameters["num_data_symbols"] = process.deserialize_integer(
            "num_data_symbols", cls.__DEFAULT_NUM_DATA_SYMBOLS
        )
        parameters["num_postamble_symbols"] = process.deserialize_integer(
            "num_postamble_symbols", cls.__DEFAULT_NUM_POSTAMBLE_SYMBOLS
        )
        parameters["guard_interval"] = process.deserialize_floating(
            "guard_interval", cls.__DEFAULT_GUARD_INTERVAL
        )
        parameters["pilot_rate"] = process.deserialize_integer(
            "pilot_rate", cls.__DEFAULT_PILOT_RATE
        )
        return parameters

    @override
    @classmethod
    def Deserialize(cls, process: DeserializationProcess) -> FilteredSingleCarrierWaveform:
        return cls(**cls._DeserializeParameters(process))  # type: ignore[arg-type]
class SingleCarrierCorrelationSynchronization(
    CorrelationSynchronization[FilteredSingleCarrierWaveform]
):
    """Correlation-based clock synchronization for filtered single-carrier waveforms."""


class SingleCarrierChannelEstimation(ChannelEstimation[FilteredSingleCarrierWaveform], ABC):
    """Base class of channel estimation routines for filtered single-carrier waveforms."""

    def __init__(self, waveform: FilteredSingleCarrierWaveform | None = None) -> None:
        """
        Args:

            waveform:
                The waveform generator this channel estimation routine is attached to.
        """

        ChannelEstimation.__init__(self, waveform)


class SingleCarrierLeastSquaresChannelEstimation(SingleCarrierChannelEstimation):
    """Least-squares channel estimation for filtered single-carrier waveforms."""

    def __init__(self, waveform: FilteredSingleCarrierWaveform | None = None) -> None:
        """
        Args:

            waveform:
                The waveform generator this channel estimation routine is attached to.
        """

        SingleCarrierChannelEstimation.__init__(self, waveform)

    def estimate_channel(self, symbols: Symbols, delay: float = 0.0) -> StatedSymbols:
        """Estimate the channel from the frame's reference symbols.

        Received preamble, pilot and postamble symbols are divided by their
        transmitted counterparts; the resulting per-stream estimates are
        linearly interpolated over all symbol blocks of the frame.

        Args:
            symbols: Demodulated frame symbols.
            delay: Not used by this routine.

        Returns: The input symbols paired with the interpolated channel state.

        Raises:
            FloatingError: If no waveform is attached to the estimator.
        """

        waveform = self.waveform
        if waveform is None:
            raise FloatingError(
                "Error trying to fetch the pilot section of a floating channel estimator"
            )

        # Frame layout as configured in the attached waveform
        n_preamble = waveform.num_preamble_symbols
        n_postamble = waveform.num_postamble_symbols
        n_payload = waveform._num_payload_symbols
        n_pilots = waveform._num_pilot_symbols
        pilot_offsets = waveform._pilot_symbol_indices
        payload_end = n_preamble + n_payload

        # Known reference symbols, ordered preamble / pilots / postamble
        reference_tx = waveform.pilot_symbols(n_preamble + n_pilots + n_postamble)

        # Received reference symbols in the same order
        payload_rx = symbols.raw[:, n_preamble:payload_end, 0]
        reference_rx = np.concatenate(
            (
                symbols.raw[:, :n_preamble, 0],
                payload_rx[:, pilot_offsets],
                symbols.raw[:, payload_end:, 0],
            ),
            axis=1,
        )

        # Least-squares estimates at the reference positions
        stem_values = reference_rx / reference_tx
        stem_positions = np.concatenate(
            (
                np.arange(n_preamble),
                pilot_offsets + n_preamble,
                np.arange(n_postamble) + payload_end,
            )
        )

        # Interpolate each stream's estimate over every block of the frame
        block_positions = np.arange(symbols.num_blocks)
        interpolated = np.empty((symbols.num_streams, symbols.num_blocks), dtype=complex)
        for stream, stream_stems in enumerate(stem_values):
            interpolated[stream, :] = np.interp(block_positions, stem_positions, stream_stems)

        return StatedSymbols(symbols.raw, interpolated[:, np.newaxis, :, np.newaxis])


class SingleCarrierChannelEqualization(ChannelEqualization[FilteredSingleCarrierWaveform], ABC):
    """Base class of channel equalization routines for filtered single-carrier waveforms."""

    def __init__(self, waveform: FilteredSingleCarrierWaveform | None = None) -> None:
        """
        Args:

            waveform:
                The waveform generator this equalization routine is attached to.
        """

        ChannelEqualization.__init__(self, waveform)


class SingleCarrierZeroForcingChannelEqualization(
    ZeroForcingChannelEqualization[FilteredSingleCarrierWaveform]
):
    """Zero-forcing channel equalization for filtered single-carrier waveforms."""


class SingleCarrierMinimumMeanSquareChannelEqualization(SingleCarrierChannelEqualization, ABC):
    """Minimum-mean-square channel equalization for filtered single-carrier waveforms."""

    def __init__(self, waveform: FilteredSingleCarrierWaveform | None = None) -> None:
        """
        Args:

            waveform:
                The waveform generator this equalization routine is attached to.
        """

        SingleCarrierChannelEqualization.__init__(self, waveform)

    def equalize_channel(self, symbols: StatedSymbols) -> Symbols:
        """Equalize symbols with their attached channel state.

        Args:
            symbols: Symbols together with the channel state to be compensated.

        Returns: The equalized symbols.

        Raises:
            RuntimeError: If there are more transmit streams than receive streams.
        """

        # The SNR is currently fixed, so the noise term below evaluates to zero
        snr = float("inf")  # self.waveform.modem.receiving_device.snr
        noise_term = 1 / snr

        num_tx = symbols.num_transmit_streams
        num_rx = symbols.num_streams

        # Single stream on both sides: plain element-wise division
        if num_tx < 2 and num_rx < 2:
            return Symbols(symbols.raw / (symbols.states[:, 0, :, :] + noise_term))

        if num_tx > num_rx:
            raise RuntimeError(
                "MMSE equalization is not supported for more transmit streams than receive streams"
            )

        # MIMO: apply the pseudo-inverse of the channel state for each block and symbol
        equalized = np.empty((num_tx, symbols.num_blocks, symbols.num_symbols), dtype=complex)
        for block, symbol in np.ndindex(symbols.num_blocks, symbols.num_symbols):
            received = symbols.raw[:, block, symbol]

            # ToDo: Introduce noise term here
            inverse_state = np.linalg.pinv(symbols.states[:, :, block, symbol])
            equalized[:, block, symbol] = inverse_state @ received

        return Symbols(equalized)
class RolledOffSingleCarrierWaveform(FilteredSingleCarrierWaveform):
    """Base class for single carrier waveforms applying linear filters longer than a single symbol duration.

    The same base filter, provided by subclasses, is applied during both
    transmission and reception.
    """

    # Default pulse bandwidth relative to the symbol rate
    __DEFAULT_RELATIVE_BANDWIDTH: float = 1.0
    # Default filter pulse roll off factor
    __DEFAULT_ROLL_OFF: float = 0.0
    # Default filter length in modulation symbols
    __DEFAULT_FILTER_LENGTH: int = 16

    __relative_bandwidth: float  # Pulse bandwidth relative to the configured symbol rate
    __roll_off: float  # Filter pulse roll off factor
    __filter_length: int  # Filter length in modulation symbols

    def __init__(
        self,
        relative_bandwidth: float = __DEFAULT_RELATIVE_BANDWIDTH,
        roll_off: float = __DEFAULT_ROLL_OFF,
        filter_length: int = __DEFAULT_FILTER_LENGTH,
        *args,
        **kwargs,
    ) -> None:
        """
        Args:

            relative_bandwidth:
                Bandwidth relative to the configured symbol rate.
                One by default, meaning the pulse bandwidth is equal to the
                symbol rate in Hz, assuming zero `roll_off`.

            roll_off:
                Filter pulse shape roll off factor between zero and one.
                Zero by default, meaning no inter-symbol interference at the
                sampling instances.

            filter_length:
                Filter length in modulation symbols.
                16 by default.
        """

        # The filter parameters are assigned ahead of the base class initialization
        self.relative_bandwidth = relative_bandwidth
        self.roll_off = roll_off
        self.filter_length = filter_length

        FilteredSingleCarrierWaveform.__init__(self, *args, **kwargs)

    @property
    def filter_length(self) -> int:
        """Filter length in modulation symbols.

        Configures how far the shaping filter stretches in terms of the
        number of modulation symbols it overlaps with.

        Raises:
            ValueError: For filter lengths smaller than one.
        """

        return self.__filter_length

    @filter_length.setter
    def filter_length(self, value: int) -> None:
        if value < 1:
            raise ValueError("Filter length must be greater than zero")

        self.__filter_length = value

    @property
    def relative_bandwidth(self) -> float:
        """Bandwidth relative to the configured symbol rate.

        Raises:
            ValueError: On values smaller or equal to zero.
        """

        return self.__relative_bandwidth

    @relative_bandwidth.setter
    def relative_bandwidth(self, value: float) -> None:
        if value <= 0.0:
            raise ValueError("Relative pulse bandwidth must be greater than zero")

        self.__relative_bandwidth = value

    @property
    def roll_off(self) -> float:
        """Filter pulse shape roll off factor.

        Raises:
            ValueError: On values smaller than zero or larger than one.
        """

        return self.__roll_off

    @roll_off.setter
    def roll_off(self, value: float) -> None:
        if value < 0.0 or value > 1.0:
            raise ValueError(
                "Filter pulse shape roll off factor value must be between zero and one"
            )

        self.__roll_off = value

    @property
    def bandwidth(self) -> float:
        # Symbol rate scaled by the relative bandwidth, widened by the roll off
        scaled_rate = self.symbol_rate * self.relative_bandwidth
        return scaled_rate * (1 + self.roll_off)

    @abstractmethod
    def _base_filter(self) -> np.ndarray:
        """Generate the base filter impulse response.

        Returns: The base filter impulse response as a numpy array.
        """
        ...  # pragma: no cover

    def _transmit_filter(self) -> np.ndarray:
        return self._base_filter()

    def _receive_filter(self) -> np.ndarray:
        return self._base_filter()

    @property
    def _filter_delay(self) -> int:
        half_length = int(0.5 * self.filter_length)
        return 2 * half_length * self.oversampling_factor

    @override
    def serialize(self, process: SerializationProcess) -> None:
        FilteredSingleCarrierWaveform.serialize(self, process)

        process.serialize_floating(self.relative_bandwidth, "relative_bandwidth")
        process.serialize_floating(self.roll_off, "roll_off")
        process.serialize_integer(self.filter_length, "filter_length")

    @override
    @classmethod
    def _DeserializeParameters(cls, process: DeserializationProcess) -> dict[str, Any]:
        # Start from the base class parameters; missing fields use the class defaults
        parameters = FilteredSingleCarrierWaveform._DeserializeParameters(process)

        relative_bandwidth = process.deserialize_floating(
            "relative_bandwidth", cls.__DEFAULT_RELATIVE_BANDWIDTH
        )
        roll_off = process.deserialize_floating("roll_off", cls.__DEFAULT_ROLL_OFF)
        filter_length = process.deserialize_integer(
            "filter_length", cls.__DEFAULT_FILTER_LENGTH
        )

        parameters["relative_bandwidth"] = relative_bandwidth
        parameters["roll_off"] = roll_off
        parameters["filter_length"] = filter_length
        return parameters
class RootRaisedCosineWaveform(RolledOffSingleCarrierWaveform, Serializable):
    """Root-Raised-Cosine filtered single carrier modulation."""

    def __init__(self, *args, **kwargs) -> None:
        RolledOffSingleCarrierWaveform.__init__(self, *args, **kwargs)

    def _base_filter(self) -> np.ndarray:
        """Generate the root-raised-cosine impulse response.

        The response is sampled at `filter_length * oversampling_factor`
        points and normalized to unit Euclidean norm.

        Returns: The impulse response as a real-valued numpy array.
        """

        num_samples = self.filter_length * self.oversampling_factor
        half_length = int(0.5 * self.filter_length)
        impulse_response = np.zeros(num_samples)

        # Generate timestamps in symbol durations, scaled by the relative bandwidth
        time = (
            np.linspace(
                -half_length,
                half_length,
                num_samples,
                endpoint=(self.filter_length % 2 == 1),
            )
            * self.relative_bandwidth
        )

        # The closed-form expression is undefined at t = 0 (0 / 0) and at
        # |t| = 1 / (4 * roll_off) (x / 0). The timestamps are floating point
        # values, so these samples are located with a tolerance; an exact
        # comparison could miss them by rounding and evaluate the general
        # expression with a denominator very close to zero.
        idx_0_by_0 = np.isclose(time, 0.0)

        if self.roll_off != 0:
            idx_x_by_0 = np.isclose(np.abs(time), 1 / (4 * self.roll_off))
        else:
            idx_x_by_0 = np.zeros_like(time, dtype=bool)

        idx = (~idx_0_by_0) & (~idx_x_by_0)

        # General expression away from the singular points
        impulse_response[idx] = (
            np.sin(np.pi * time[idx] * (1 - self.roll_off))
            + 4 * self.roll_off * time[idx] * np.cos(np.pi * time[idx] * (1 + self.roll_off))
        ) / (np.pi * time[idx] * (1 - (4 * self.roll_off * time[idx]) ** 2))

        # Dedicated values at the singular points
        if np.any(idx_x_by_0):
            impulse_response[idx_x_by_0] = (
                self.roll_off
                / np.sqrt(2)
                * (
                    (1 + 2 / np.pi) * np.sin(np.pi / (4 * self.roll_off))
                    + (1 - 2 / np.pi) * np.cos(np.pi / (4 * self.roll_off))
                )
            )
        impulse_response[idx_0_by_0] = 1 + self.roll_off * (4 / np.pi - 1)

        return impulse_response / np.linalg.norm(impulse_response)
class RaisedCosineWaveform(RolledOffSingleCarrierWaveform, Serializable):
    """Raised-Cosine filtered single carrier modulation."""

    def __init__(self, *args, **kwargs) -> None:
        RolledOffSingleCarrierWaveform.__init__(self, *args, **kwargs)

    def _base_filter(self) -> np.ndarray:
        """Generate the raised-cosine impulse response.

        The response is sampled at `filter_length * oversampling_factor`
        points and normalized to unit Euclidean norm.

        Returns: The impulse response as a real-valued numpy array.
        """

        num_samples = self.filter_length * self.oversampling_factor
        half_length = int(0.5 * self.filter_length)
        impulse_response = np.zeros(num_samples)

        # Generate timestamps in symbol durations, scaled by the relative bandwidth
        time = (
            np.linspace(
                -half_length,
                half_length,
                num_samples,
                endpoint=(self.filter_length % 2 == 1),
            )
            * self.relative_bandwidth
        )

        # The closed-form expression is undefined (0 / 0) at
        # |t| = 1 / (2 * roll_off). The timestamps are floating point values,
        # so these samples are located with a tolerance; an exact comparison
        # could miss them by rounding and evaluate the general expression
        # with a denominator very close to zero.
        if self.roll_off != 0:
            idx_0_by_0 = np.isclose(np.abs(time), 1 / (2 * self.roll_off))
        else:
            idx_0_by_0 = np.zeros_like(time, dtype=bool)

        idx = ~idx_0_by_0

        # General expression away from the singular points
        impulse_response[idx] = (
            np.sinc(time[idx])
            * np.cos(np.pi * self.roll_off * time[idx])
            / (1 - (2 * self.roll_off * time[idx]) ** 2)
        )

        # Dedicated value at the singular points
        if np.any(idx_0_by_0):
            impulse_response[idx_0_by_0] = np.pi / 4 * np.sinc(1 / (2 * self.roll_off))

        return impulse_response / np.linalg.norm(impulse_response)
class RectangularWaveform(FilteredSingleCarrierWaveform, Serializable):
    """Single carrier modulation shaped by a rectangular pulse."""

    # Default pulse bandwidth relative to the symbol rate
    __DEFAULT_RELATIVE_BANDWIDTH: float = 1.0

    __relative_bandwidth: float

    def __init__(
        self, relative_bandwidth: float = __DEFAULT_RELATIVE_BANDWIDTH, *args, **kwargs
    ) -> None:
        """
        Args:

            relative_bandwidth:
                Bandwidth relative to the configured symbol rate.
                One by default.
        """

        # The base class is initialized first, the pulse bandwidth afterwards
        FilteredSingleCarrierWaveform.__init__(self, *args, **kwargs)
        self.relative_bandwidth = relative_bandwidth

    @property
    def relative_bandwidth(self) -> float:
        """Bandwidth relative to the configured symbol rate.

        Raises:
            ValueError: On values smaller or equal to zero.
        """

        return self.__relative_bandwidth

    @relative_bandwidth.setter
    def relative_bandwidth(self, value: float) -> None:
        if value <= 0.0:
            raise ValueError("Relative pulse bandwidth must be greater than zero")

        self.__relative_bandwidth = value

    @property
    def bandwidth(self) -> float:
        return self.symbol_rate * self.relative_bandwidth

    def _transmit_filter(self) -> np.ndarray:
        # Pulse width in samples, truncated towards zero
        num_pulse_samples = int(self.oversampling_factor / self.relative_bandwidth)

        # Constant pulse scaled by the square root of its length
        scale = np.sqrt(num_pulse_samples)
        return np.ones(num_pulse_samples, dtype=complex) / scale

    def _receive_filter(self) -> np.ndarray:
        # The receive filter equals the transmit filter
        return self._transmit_filter()

    @property
    def _filter_delay(self) -> int:
        num_pulse_samples = int(self.oversampling_factor / self.relative_bandwidth)
        return num_pulse_samples - 1

    @override
    def serialize(self, process: SerializationProcess) -> None:
        FilteredSingleCarrierWaveform.serialize(self, process)

        process.serialize_floating(self.relative_bandwidth, "relative_bandwidth")

    @override
    @classmethod
    def _DeserializeParameters(cls, process: DeserializationProcess) -> dict[str, Any]:
        # Start from the base class parameters; a missing field uses the class default
        parameters = FilteredSingleCarrierWaveform._DeserializeParameters(process)

        relative_bandwidth = process.deserialize_floating(
            "relative_bandwidth", cls.__DEFAULT_RELATIVE_BANDWIDTH
        )
        parameters["relative_bandwidth"] = relative_bandwidth
        return parameters
class FMCWWaveform(FilteredSingleCarrierWaveform, Serializable):
    """Single carrier modulation shaped by frequency modulated chirps (FMCW)."""

    __bandwidth: float  # Chirp bandwidth in Hz
    __chirp_duration: float  # Chirp duration in seconds

    def __init__(self, bandwidth: float, chirp_duration: float = 0.0, *args, **kwargs) -> None:
        """
        Args:

            bandwidth:
                The chirp bandwidth in Hz.

            chirp_duration:
                Duration of each FMCW chirp in seconds.
                By default, the inverse symbol rate is assumed.
        """

        # The chirp parameters are assigned ahead of the base class initialization
        self.bandwidth = bandwidth
        self.chirp_duration = chirp_duration

        FilteredSingleCarrierWaveform.__init__(self, *args, **kwargs)

    @property
    def chirp_duration(self) -> float:
        """FMCW Chirp duration.

        A duration of zero will result in the inverse symbol rate as chirp duration.

        Raises:
            ValueError: If the duration is smaller than zero.
        """

        return self.__chirp_duration

    @chirp_duration.setter
    def chirp_duration(self, value: float) -> None:
        if value < 0.0:
            raise ValueError("Chirp duration must be greater or equal to zero")

        self.__chirp_duration = value

    @property
    def __true_chirp_duration(self) -> float:
        """Chirp duration for internal calculations."""

        # Fall back to one symbol duration if no chirp duration is configured
        return 1 / self.symbol_rate if self.chirp_duration <= 0.0 else self.chirp_duration

    @property
    def bandwidth(self) -> float:
        return self.__bandwidth

    @bandwidth.setter
    def bandwidth(self, value: float) -> None:
        if value <= 0.0:
            raise ValueError("Chirp bandwidth must be greater than zero")

        self.__bandwidth = value

    @property
    def chirp_slope(self) -> float:
        """Chirp slope.

        The slope is equal to the chirp bandwidth divided by its duration.
        """

        return self.bandwidth / self.__true_chirp_duration

    def _transmit_filter(self) -> np.ndarray:
        # One symbol duration, sampled at `oversampling_factor` points
        timestamps = np.linspace(0, 1 / self.symbol_rate, self.oversampling_factor)

        # Chirp with a linear and a quadratic phase term
        phase = self.bandwidth * timestamps + self.chirp_slope * timestamps**2
        chirp = np.exp(1j * np.pi * phase)

        # Cut off the chirp appropriately
        beyond_chirp = timestamps > self.__true_chirp_duration
        chirp[beyond_chirp] = 0.0

        return chirp / np.linalg.norm(chirp)

    def _receive_filter(self) -> np.ndarray:
        # Conjugated and reversed transmit filter
        conjugated = self._transmit_filter().conj()
        return np.flip(conjugated)

    @property
    def _filter_delay(self) -> int:
        return self.oversampling_factor - 1

    @override
    def serialize(self, process: SerializationProcess) -> None:
        FilteredSingleCarrierWaveform.serialize(self, process)

        process.serialize_floating(self.bandwidth, "bandwidth")
        process.serialize_floating(self.chirp_duration, "chirp_duration")

    @override
    @classmethod
    def Deserialize(cls, process: DeserializationProcess) -> FMCWWaveform:
        # Read the chirp parameters first, then the base class parameters
        bandwidth = process.deserialize_floating("bandwidth")
        chirp_duration = process.deserialize_floating("chirp_duration")
        base_parameters = cls._DeserializeParameters(process)

        return cls(bandwidth, chirp_duration, **base_parameters)