# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import Any, Generic, List, TypeVar
from typing_extensions import override
import numpy as np
from scipy.signal import correlate, find_peaks
from hermespy.core import SerializationProcess, DeserializationProcess
from .waveform import PilotCommunicationWaveform, Synchronization
# Module metadata
__author__ = "Jan Adler"
__copyright__ = "Copyright 2021, Barkhausen Institut gGmbH"
__credits__ = ["Jan Adler"]
__license__ = "AGPLv3"
__version__ = "1.5.0"
__maintainer__ = "Jan Adler"
__email__ = "jan.adler@barkhauseninstitut.org"
__status__ = "Prototype"
# Type variable parameterizing the synchronizer over its waveform type,
# restricted to subclasses of PilotCommunicationWaveform.
PGT = TypeVar("PGT", bound=PilotCommunicationWaveform)
"""Type of pilot-generating waveforms."""
class CorrelationSynchronization(Generic[PGT], Synchronization[PGT]):
    """Correlation-based clock synchronization for arbitrary communication waveforms.

    The implemented algorithm is equivalent to :cite:p:`1976:knapp` without pre-filtering.
    The waveform's pilot signal is cross-correlated with every received stream, the
    correlation magnitudes are summed over all streams and normalized to a maximum of one,
    and frame starts are located by a peak search over the resulting profile.
    """

    __DEFAULT_THRESHOLD: float = 0.9  # Default correlation threshold
    __DEFAULT_GUARD_RATIO: float = 0.8  # Default guard ratio
    __DEFAULT_PEAK_PROMINENCE: float = 0.2  # Default peak prominence

    __threshold: float  # Correlation threshold at which a pilot signal is detected
    __guard_ratio: float  # Guard ratio of frame duration
    __peak_prominence: float  # Minimum peak prominence for peak detection

    def __init__(
        self,
        threshold: float = __DEFAULT_THRESHOLD,
        guard_ratio: float = __DEFAULT_GUARD_RATIO,
        peak_prominence: float = __DEFAULT_PEAK_PROMINENCE,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Args:

            threshold:
                Correlation threshold at which a pilot signal is detected.

            guard_ratio:
                Guard ratio of frame duration.

            peak_prominence:
                Minimum peak prominence for peak detection in the interval (0, 1].
                :math:`0.2` is a good default value for most applications.

            *args:
                Synchronization base class initialization parameters.

            **kwargs:
                Synchronization base class initialization keyword parameters.

        Raises:

            ValueError: If `threshold`, `guard_ratio` or `peak_prominence` is out of range.
        """

        self.threshold = threshold
        self.guard_ratio = guard_ratio
        self.peak_prominence = peak_prominence

        Synchronization.__init__(self, *args, **kwargs)

    @property
    def threshold(self) -> float:
        """Correlation threshold at which a pilot signal is detected.

        Used as the minimum height of a peak in the normalized correlation profile.

        Returns:
            float: Threshold between zero and one.

        Raises:
            ValueError: If threshold is smaller than zero or greater than one.
        """

        return self.__threshold

    @threshold.setter
    def threshold(self, value: float) -> None:
        """Set correlation threshold at which a pilot signal is detected."""

        # Chained comparison also rejects NaN, which would slip through `< 0 or > 1`
        if not 0.0 <= value <= 1.0:
            raise ValueError("Synchronization threshold must be between zero and one.")

        self.__threshold = value

    @property
    def guard_ratio(self) -> float:
        """Guard span after a detected pilot, relative to the frame duration.

        After the detection of a pilot section, `guard_ratio` prevents the detection of another
        pilot within the following `guard_ratio` times the configured frame length in samples.

        Returns:
            float: Guard Ratio between zero and one.

        Raises:
            ValueError: If guard ratio is smaller than zero or greater than one.
        """

        return self.__guard_ratio

    @guard_ratio.setter
    def guard_ratio(self, value: float) -> None:
        """Set correlation guard ratio at which a pilot signal is detected."""

        if not 0.0 <= value <= 1.0:
            raise ValueError("Synchronization guard ratio must be between zero and one.")

        self.__guard_ratio = value

    @property
    def peak_prominence(self) -> float:
        """Minimum prominence a correlation peak must have to be detected.

        Measured on the correlation profile normalized to a maximum of one.

        Returns:
            float: Prominence in the interval (0, 1].

        Raises:
            ValueError: If the prominence is not within the interval (0, 1].
        """

        return self.__peak_prominence

    @peak_prominence.setter
    def peak_prominence(self, value: float) -> None:
        """Set the minimum peak prominence for peak detection."""

        if not 0.0 < value <= 1.0:
            raise ValueError("Synchronization peak prominence must be in the interval (0, 1].")

        self.__peak_prominence = value

    def synchronize(self, signal: np.ndarray) -> List[int]:
        """Estimate the sample indices at which pilot-led frames start.

        Args:

            signal:
                Received samples, either a single flat stream or a two-dimensional array
                holding one stream per row.

        Returns:
            Frame start indices clamped to the range of feasible frame starts.
            Empty if no pilot was detected.

        Raises:

            RuntimeError: If the waveform has no pilot sequence configured.
        """

        # Expand the dimensionality for flat signal streams
        if signal.ndim == 1:
            signal = signal[np.newaxis, :]

        num_streams, num_samples = signal.shape[0], signal.shape[1]

        # Query the pilot signal from the waveform generator
        pilot_sequence = self.waveform.pilot_signal.getitem().flatten()
        pilot_length = len(pilot_sequence)

        # Raise a runtime error if pilot sequence is empty
        if pilot_length < 1:
            raise RuntimeError(
                "No pilot sequence configured, time-domain correlation synchronization impossible"
            )

        # Nothing can be detected within an empty signal
        if num_streams < 1 or num_samples < 1:
            return []

        # Compute the correlation between each signal stream and the pilot sequence, sum up as a result
        correlation = np.zeros(pilot_length + num_samples - 1, dtype=float)
        for stream in signal:
            correlation += np.abs(correlate(stream, pilot_sequence, mode="full", method="fft"))

        # A profile without any positive value (e.g. a silent signal) cannot be normalized;
        # the negated comparison also catches a NaN maximum
        peak_value = correlation.max()
        if not peak_value > 0.0:
            return []

        correlation /= peak_value  # Normalize correlation

        # Determine the pilot sequence locations by performing a peak search over the correlation profile,
        # honoring the configured detection parameters
        frame_length = self.waveform.samples_per_frame
        guard_distance = max(1, int(self.guard_ratio * frame_length))  # find_peaks requires >= 1
        pilot_indices, _ = find_peaks(
            correlation,
            height=self.threshold,
            distance=guard_distance,
            prominence=self.peak_prominence,
        )

        # Abort if no pilot section has been detected
        if len(pilot_indices) < 1:
            return []

        # Correct pilot indices by the convolution length
        pilot_indices = pilot_indices - (pilot_length - 1)

        # Correct infeasible pilot index choices:
        # a frame must start within the signal and, if possible, fit entirely inside it.
        # Signals shorter than a frame only admit a start at index zero.
        last_feasible_start = max(num_samples - frame_length, 0)
        pilot_indices = np.clip(pilot_indices, 0, last_feasible_start)

        return pilot_indices.tolist()

    @override
    def serialize(self, process: SerializationProcess) -> None:
        process.serialize_floating(self.threshold, "threshold")
        process.serialize_floating(self.guard_ratio, "guard_ratio")
        process.serialize_floating(self.peak_prominence, "peak_prominence")

    @override
    @classmethod
    def Deserialize(cls, process: DeserializationProcess) -> CorrelationSynchronization:
        return cls(
            threshold=process.deserialize_floating("threshold", cls.__DEFAULT_THRESHOLD),
            guard_ratio=process.deserialize_floating("guard_ratio", cls.__DEFAULT_GUARD_RATIO),
            peak_prominence=process.deserialize_floating(
                "peak_prominence", cls.__DEFAULT_PEAK_PROMINENCE
            ),
        )