Source code for hermespy.core.channel

# -*- coding: utf-8 -*-

from __future__ import annotations
from itertools import product
from typing import Generator, List, SupportsIndex, Tuple, Type
from enum import Enum

import numpy as np
import matplotlib.pyplot as plt
from h5py import Group
from scipy.fft import fft, ifft
from sparse import COO, SparseArray  # type: ignore

from .factory import HDFSerializable
from .signal_model import Signal

__author__ = "Jan Adler"
__copyright__ = "Copyright 2024, Barkhausen Institut gGmbH"
__credits__ = ["Jan Adler"]
__license__ = "AGPLv3"
__version__ = "1.4.0"
__maintainer__ = "Jan Adler"
__email__ = "jan.adler@barkhauseninstitut.org"
__status__ = "Prototype"


class ChannelStateFormat(Enum):
    """Format flag for wireless transmission link states.

    Indicates in which domain a channel state tensor is currently represented.
    """

    # Channel state in impulse response format
    IMPULSE_RESPONSE = 0

    # Channel state in frequency selectivity format
    FREQUENCY_SELECTIVITY = 1
class ChannelStateDimension(Enum):
    """Dimension selection of channel state information.

    The integer value of each member is used as the axis index into the
    four-dimensional channel state tensor.
    """

    # Axis 0: receive streams
    RECEIVE_STREAMS = 0

    # Axis 1: transmit streams
    TRANSMIT_STREAMS = 1

    # Axis 2: samples
    SAMPLES = 2

    # Axis 3: state information (delay taps or frequency weights)
    INFORMATION = 3
class ChannelStateInformation(HDFSerializable):
    """State of a single wireless link between a transmitting and receiving device.

    Attributes:

        __state_format (ChannelStateFormat):
            The current format of the channel state information.
            The format may change depending on the most recent format requests.

        __state (numpy.ndarray | sparse.SparseArray):
            The current channel state.
            A tensor of dimension
            `num_receive_streams`x`num_transmit_streams`x`num_samples`x`state_information`.
            In impulse response format, `num_samples` is the time domain of the channel state
            and `state_information` the delay taps.
            In frequency selectivity format, the last dimension holds the discrete
            frequency-domain (Fourier) weights obtained by transforming the delay taps.

        __num_delay_taps (int):
            Number of delay taps in impulse-response mode.
            Recovers the 4th tensor dimension during conversions.

        __num_frequency_bins (int):
            Number of discrete frequency bins in frequency-selectivity mode.
            Recovers the 4th tensor dimension during conversions.
    """

    __state_format: ChannelStateFormat
    __state: np.ndarray | SparseArray
    __num_delay_taps: int
    __num_frequency_bins: int

    def __init__(
        self,
        state_format: ChannelStateFormat,
        state: np.ndarray | SparseArray | None = None,
        num_delay_taps: int | None = None,
        num_frequency_bins: int | None = None,
    ) -> None:
        """Channel State Information object initialization.

        Args:

            state_format (ChannelStateFormat):
                Format of the `state` from which to initialize the channel state information.

            state (numpy.ndarray | sparse.SparseArray, optional):
                Channel state tensor of dimension
                `num_receive_streams`x`num_transmit_streams`x`num_samples`x`state_information`.

            num_delay_taps (int, optional):
                Number of delay taps in impulse-response mode.

            num_frequency_bins (int, optional):
                Number of discrete frequency bins in frequency-selectivity mode.

        Raises:

            ValueError: If `state` dimensions are invalid.
        """

        self.set_state(state_format, state, num_delay_taps, num_frequency_bins)

    @property
    def state_format(self) -> ChannelStateFormat:
        """Current channel state format.

        Returns: The current channel state format.
        """

        return self.__state_format

    @property
    def state(self) -> np.ndarray | SparseArray:
        """Current channel state tensor.

        Returns: The current channel state tensor, either dense or sparse.
        """

        return self.__state

    @state.setter
    def state(self, new_state: np.ndarray | SparseArray) -> None:
        """Modify the channel state tensor.

        The current format is kept; delay tap and frequency bin counts are reset
        to the size of the new state's last dimension.

        Args:

            new_state (numpy.ndarray | SparseArray):
                The new channel state.

        Raises:

            ValueError: On invalid state dimensions.
        """

        self.set_state(self.__state_format, new_state)

    def dense_state(self) -> np.ndarray:
        """Return the channel state in dense format.

        Note that this method will convert the channel state to dense format if it is currently in sparse format.
        This operation may be computationally expensive and should be avoided if possible.

        Returns: The channel state tensor in dense format.
        """

        return self.__state.todense() if isinstance(self.__state, SparseArray) else self.__state

    def set_state(
        self,
        state_format: ChannelStateFormat,
        state: np.ndarray | SparseArray | None = None,
        num_delay_taps: int | None = None,
        num_frequency_bins: int | None = None,
    ) -> None:
        """Set a new channel state.

        Args:

            state_format (ChannelStateFormat):
                Format of the `state` from which to initialize the channel state information.

            state (numpy.ndarray | SparseArray, optional):
                Channel state tensor of dimension
                `num_receive_streams`x`num_transmit_streams`x`num_samples`x`state_information`.
                Defaults to an empty tensor of shape (0, 0, 0, 1).

            num_delay_taps (int, optional):
                Number of delay taps.
                Defaults to the size of the state's last dimension.

            num_frequency_bins (int, optional):
                Number of discrete frequency bins.
                Defaults to the size of the state's last dimension.

        Raises:

            ValueError: If `state` is not 4-dimensional.
            ValueError: If, in impulse response format, `num_delay_taps` differs from the state's last dimension.
        """

        # A missing state is represented by an empty tensor with a single entry along the last axis
        state = np.empty((0, 0, 0, 1), dtype=complex) if state is None else state

        if state.ndim != 4:
            raise ValueError("Channel state tensor must be 4-dimensional")

        if num_delay_taps is None:
            num_delay_taps = state.shape[3]

        if num_frequency_bins is None:
            num_frequency_bins = state.shape[3]

        # In impulse response format the last dimension is the delay tap axis,
        # so an explicitly provided tap count must agree with it
        if state_format == ChannelStateFormat.IMPULSE_RESPONSE and num_delay_taps != state.shape[3]:
            raise ValueError(
                "Number of delay taps must be equal to the last dimension of the impulse response"
            )

        self.__state_format = state_format
        self.__state = state
        self.__num_delay_taps = num_delay_taps
        self.__num_frequency_bins = num_frequency_bins

    def to_impulse_response(self) -> ChannelStateInformation:
        """Access the channel state in time-domain.

        May convert the internal state format via an inverse FFT along the last dimension.
        The result is truncated to the recorded number of delay taps, which removes
        zero-padding introduced by a previous conversion to more frequency bins than taps.

        Returns:

            ChannelStateInformation:
                The current channel tensor of dimensions
                `num_receive_streams`x`num_transmit_streams`x`num_samples`x`num_delay_taps`.
        """

        if self.__state_format == ChannelStateFormat.FREQUENCY_SELECTIVITY:
            # Densify first, mirroring to_frequency_selectivity, since the FFT requires dense input
            impulse_response = ifft(self.dense_state(), axis=3)
            self.__state = impulse_response[:, :, :, : self.__num_delay_taps]

            # Keep the tap count in sync with the tensor so set_state's invariant
            # (num_delay_taps == last dimension in impulse response format) keeps holding
            self.__num_delay_taps = self.__state.shape[3]
            self.__state_format = ChannelStateFormat.IMPULSE_RESPONSE

        return self

    def to_frequency_selectivity(self, num_bins: int | None = None) -> ChannelStateInformation:
        """Access the channel state in frequency-domain.

        May convert the internal state format via FFT along the last dimension.

        Args:

            num_bins (int, optional):
                Number of discrete frequency bins.
                By default, the stored number of frequency bins is used, which itself defaults to the
                size of the state's last dimension at the time the state was set.
                A provided value is stored for subsequent conversions.

        Returns:

            ChannelStateInformation:
                The current channel tensor of dimensions
                `num_receive_streams`x`num_transmit_streams`x`num_samples`x`num_frequency_bins`.
        """

        if self.__state_format == ChannelStateFormat.IMPULSE_RESPONSE:
            if num_bins is None:
                num_bins = self.__num_frequency_bins

            else:
                self.__num_frequency_bins = num_bins

            # NOTE(review): the sample axis (axis 2) is truncated to `num_bins` here as well;
            # confirm this is intended, since it drops samples whenever num_samples > num_bins
            self.__state = fft(self.dense_state()[:, :, :num_bins, :], axis=3, n=num_bins)
            self.__state_format = ChannelStateFormat.FREQUENCY_SELECTIVITY

        return self

    @property
    def num_receive_streams(self) -> int:
        """Number of receive streams within this channel state.

        Returns: Number of receive streams.
        """

        return self.__state.shape[0]

    @property
    def num_transmit_streams(self) -> int:
        """Number of transmit streams within this channel state.

        Returns: Number of transmit streams.
        """

        return self.__state.shape[1]

    @property
    def num_samples(self) -> int:
        """Number of time-domain samples within this channel state.

        Returns: Number of samples.
        """

        return self.__state.shape[2]

    @property
    def num_symbols(self) -> int:
        """Number of symbols considered within this channel state.

        Returns:

            The number of samples in impulse response format, or the product of
            samples and frequency bins in frequency selectivity format.
        """

        if self.__state_format == ChannelStateFormat.IMPULSE_RESPONSE:
            return self.__state.shape[2]

        # Channel state is in frequency selectivity format
        return self.__state.shape[2] * self.__state.shape[3]

    @property
    def num_delay_taps(self) -> int:
        """Number of taps within the delay response of the channel state.

        Returns: Number of taps.
        """

        return self.__num_delay_taps

    @property
    def linear(self) -> SparseArray:
        """Convert the channel state to a linear transformation tensor.

        Returns: Sparse linear transformation tensor of dimension N_Rx x N_Tx x N_out x N_in.
        """

        if self.__state_format == ChannelStateFormat.IMPULSE_RESPONSE:
            return self.__impulse_response_transformation()

        # Channel state is in frequency selectivity format
        return self.__frequency_response_transformation()

    def __impulse_response_transformation(self) -> SparseArray:
        """Convert a channel impulse response to a linear transformation tensor.

        Returns:

            SparseArray:
                Sparse linear transformation tensor of dimension N_Rx x N_Tx x T+L x T.
                Note that the slice over the last two dimensions will form a lower triangular matrix.
        """

        num_rx = self.num_receive_streams
        num_tx = self.num_transmit_streams
        num_taps = self.__state.shape[3]
        num_s = self.num_samples
        num_out = num_s + num_taps - 1
        num_in = num_s

        # For each flattened (sample, tap) pair: input index is the sample, output index is sample + tap
        in_ids = np.repeat(np.arange(num_in), num_taps)
        out_ids = np.array([np.arange(num_taps) + t for t in range(num_in)]).flatten()
        rx_ids = np.arange(num_rx)
        tx_ids = np.arange(num_tx)

        coordinates = [
            rx_ids.repeat(num_tx * num_taps * num_in),
            # Transmit stream index of each entry of the C-order flattened state tensor
            np.tile(tx_ids.repeat(num_taps * num_in), num_rx),
            np.tile(out_ids, num_rx * num_tx),
            np.tile(in_ids, num_rx * num_tx),
        ]

        transformation = COO(
            coordinates, self.__state.flatten(), shape=(num_rx, num_tx, num_out, num_in)
        )
        return transformation

    def __frequency_response_transformation(self) -> SparseArray:
        """Convert a channel frequency response to a linear transformation tensor.

        Returns:

            COO:
                Sparse linear transformation tensor of dimension N_Rx x N_Tx x F*T x F*T.
                Note that the slice over the last two dimensions will be a diagonal matrix.
        """

        num_rx = self.num_receive_streams
        num_tx = self.num_transmit_streams
        num_s = self.num_samples
        num_frequencies = self.__state.shape[3]
        num_symbols = num_s * num_frequencies

        diagonal_ids = np.arange(num_symbols)
        rx_ids = np.arange(num_rx)
        tx_ids = np.arange(num_tx)

        coordinates = [
            rx_ids.repeat(num_tx * num_symbols),
            # Transmit stream index of each entry of the C-order flattened state tensor
            np.tile(tx_ids.repeat(num_symbols), num_rx),
            np.tile(diagonal_ids, num_rx * num_tx),
            np.tile(diagonal_ids, num_rx * num_tx),
        ]

        transformation = COO(
            coordinates, self.__state.flatten(), shape=(num_rx, num_tx, num_symbols, num_symbols)
        )
        return transformation

    @staticmethod
    def Ideal(
        num_samples: int, num_receive_streams: int = 1, num_transmit_streams: int = 1
    ) -> ChannelStateInformation:
        """Initialize an ideal channel state.

        Args:

            num_samples (int):
                Number of timestamps at which the channel state has been sampled.

            num_receive_streams (int, optional):
                Number of emerging data streams after channel propagation.

            num_transmit_streams (int, optional):
                Number of data streams feeding into the channel before propagation.

        Returns:

            ChannelStateInformation:
                Single-tap, unit-gain impulse response between every stream pair.
        """

        state = np.ones((num_receive_streams, num_transmit_streams, num_samples, 1), dtype=complex)
        return ChannelStateInformation(ChannelStateFormat.IMPULSE_RESPONSE, state)

    def received_streams(self) -> Generator[ChannelStateInformation, ChannelStateInformation, None]:
        """Iterate over the received streams slices within this channel state.

        Each slice keeps this state's format, delay tap count and frequency bin count.

        Returns:

            Generator of single-receive-stream channel states.
        """

        for received_stream in self.__state:
            yield ChannelStateInformation(
                self.__state_format,
                received_stream[np.newaxis, ...],
                self.__num_delay_taps,
                self.__num_frequency_bins,
            )

    def samples(self) -> Generator[ChannelStateInformation, ChannelStateInformation, None]:
        """Iterate over the sample slices within this channel state.

        Each slice keeps this state's format, delay tap count and frequency bin count.

        Returns:

            Generator of single-sample channel states.
        """

        for sample_idx in range(self.num_samples):
            yield ChannelStateInformation(
                self.__state_format,
                self.__state[:, :, [sample_idx], :],
                self.__num_delay_taps,
                self.__num_frequency_bins,
            )

    def __getitem__(
        self, section: SupportsIndex | Tuple[SupportsIndex | slice, ...] | slice
    ) -> ChannelStateInformation:
        """Slice the channel state information.

        Integer selections are re-inserted as singleton dimensions, so the resulting
        state is always 4-dimensional.

        Args:

            section:
                Selection of the channel state. Either a single index / slice addressing the
                receive stream dimension, or a tuple addressing the leading dimensions in order.
                Ellipsis and `numpy.newaxis` entries are not accounted for when restoring dimensions.

        Returns:

            ChannelStateInformation: New channel state with a section according to `section`.
        """

        # Wrap single selections so the dimension bookkeeping below can iterate them;
        # indexing with a one-element tuple is equivalent to indexing with its element
        selections = section if isinstance(section, tuple) else (section,)
        state_section = self.__state[selections]

        # Integer indices remove their axis; re-insert it to keep the tensor 4-dimensional
        for axis, selection in enumerate(selections):
            if isinstance(selection, (int, np.integer)):
                state_section = np.expand_dims(state_section, axis=axis)

        # The delay tap count can only be kept if the last dimension remained untouched
        num_delay_taps = (
            self.__num_delay_taps if state_section.shape[3] == self.__state.shape[3] else None
        )

        return ChannelStateInformation(self.__state_format, state_section, num_delay_taps)

    def __setitem__(
        self,
        key: SupportsIndex | slice | Tuple[SupportsIndex | slice, ...],
        value: ChannelStateInformation,
    ) -> None:
        """Update the channel state information.

        Args:

            key (slice):
                Section of the channel state to be set.

            value (ChannelStateInformation):
                The information to be set.

        Raises:

            NotImplementedError: If the formats of `value` and this channel do not match.
        """

        if value.state_format != self.__state_format:
            raise NotImplementedError("Setting CSIs of a different type is not yet supported")

        self.__state[key] = value.__state

    @staticmethod
    def concatenate(
        elements: List[ChannelStateInformation], dimension: ChannelStateDimension
    ) -> ChannelStateInformation:
        """Concatenate multiple channel states along a single tensor dimension.

        The format and number of delay taps of the result are adopted from the first element.

        Args:

            elements: Channel states to be concatenated.
            dimension: Tensor dimension along which the states are concatenated.

        Returns: The concatenated channel state information.

        Raises:

            ValueError: If `elements` is empty or the state shapes are incompatible.
        """

        if not elements:
            raise ValueError("At least one channel state is required for concatenation")

        states = [element.__state for element in elements]
        stack = np.concatenate(states, axis=dimension.value)

        # ToDo: Make this smarter, it's not generally correct
        first = elements[0]
        return ChannelStateInformation(first.__state_format, stack, first.__num_delay_taps)

    def plot(self) -> None:
        """Visualize the internal channel state information.

        Plots the absolute values of all channel state weights,
        one image per receive / transmit stream pair.
        """

        # Densify so that sparse states can be rendered as well
        state = self.dense_state()
        _, axes = plt.subplots(state.shape[0], state.shape[1], squeeze=False)

        for rx_id, receive_states in enumerate(state):
            for tx_id, transmit_states in enumerate(receive_states):
                axes[rx_id, tx_id].imshow(abs(transmit_states))

    def propagate(self, signal: Signal) -> Signal:
        """Propagate a single signal model over this channel state information.

        This method should generally be avoided, since it's computationally costly.
        Instead, whenever there is access to a :class:`ChannelRealization`,
        :meth:`ChannelRealization.propagate` should always be preferred.

        Note that this converts the state to impulse response format in place.

        Args:

            signal (Signal): Signal model to be propagated.

        Returns: Propagated signal model with `num_samples + num_delay_taps - 1` samples per receive stream.
        """

        # Make sure the accessed state is in impulse response format
        state = self.to_impulse_response().dense_state()

        # Propagate the signal
        propagated_samples = np.zeros(
            (state.shape[0], signal.num_samples + state.shape[3] - 1), dtype=np.complex128
        )

        for delay_index in range(state.shape[3]):
            for tx_idx, rx_idx in product(range(state.shape[1]), range(state.shape[0])):
                # Time-variant tap weight applied sample-wise, then shifted by the tap's delay
                delayed_signal = (
                    state[rx_idx, tx_idx, : signal.num_samples, delay_index]
                    * signal.getitem(tx_idx).flatten()
                )
                propagated_samples[
                    rx_idx, delay_index : delay_index + signal.num_samples
                ] += delayed_signal

        return signal.from_ndarray(propagated_samples)

    def reciprocal(self) -> ChannelStateInformation:
        """Compute the reciprocal channel state.

        Swaps the receive and transmit stream dimensions.

        Returns: The reciprocal channel state information.
        """

        reciprocal_state = self.__state.transpose((1, 0, 2, 3))
        return ChannelStateInformation(
            self.__state_format, reciprocal_state, self.num_delay_taps, self.__num_frequency_bins
        )

    @classmethod
    def from_HDF(cls: Type[ChannelStateInformation], group: Group) -> ChannelStateInformation:
        """Recall a channel state from an HDF group written by `to_HDF`.

        Args:

            group: HDF group containing the serialized channel state.

        Returns: The recalled channel state information.
        """

        # Recall datasets
        state = np.array(group["state"], dtype=complex)

        # Recall attributes
        state_format = ChannelStateFormat[group.attrs.get("format", "IMPULSE_RESPONSE")]

        # The tap count only carries extra information in frequency selectivity format;
        # in impulse response format it is implied by the state's last dimension
        num_delay_taps = None
        if state_format == ChannelStateFormat.FREQUENCY_SELECTIVITY and "num_taps" in group.attrs:
            num_delay_taps = int(group.attrs["num_taps"])

        # Initialize object from recalled state
        return cls(state=state, state_format=state_format, num_delay_taps=num_delay_taps)

    def to_HDF(self, group: Group) -> None:
        """Serialize this channel state into an HDF group.

        Args:

            group: HDF group to write the dense state dataset and descriptive attributes into.
        """

        # Serialize datasets
        group.create_dataset("state", data=self.dense_state())

        # Serialize attributes
        group.attrs["num_transmit_streams"] = self.num_transmit_streams
        group.attrs["num_receive_streams"] = self.num_receive_streams
        group.attrs["num_symbols"] = self.num_symbols
        group.attrs["num_taps"] = self.num_delay_taps
        group.attrs["num_samples"] = self.num_samples
        group.attrs["format"] = self.state_format.name