# Source code for demos.dpd.src.rx

# SPDX-License-Identifier: MIT
# Copyright (c) 2025–present Srikanth Pagadarai

"""
OFDM receiver for DPD evaluation.

Implements a minimal OFDM receiver chain for measuring DPD
performance via Error Vector Magnitude (EVM). The receiver handles the
full path from PA output back to constellation symbols, enabling direct
comparison of signal quality with and without predistortion.

The receiver is used only at inference time, never during training, because:
1. DPD training uses indirect learning on PA input/output, not decoded symbols
2. EVM is an evaluation metric, not a training objective

The receive chain uses NumPy/SciPy for resampling (scipy.signal.resample_poly)
because this runs only at inference time where graph-mode compatibility
is not required.
"""

import numpy as np
import tensorflow as tf
from fractions import Fraction
from scipy.signal import resample_poly
from sionna.phy.ofdm import OFDMDemodulator


class Rx(tf.keras.layers.Layer):
    """
    Minimal OFDM receiver chain for DPD performance evaluation.

    Implements the full receive path from PA output to equalized QAM
    symbols, with EVM computation for quantifying signal distortion.
    This receiver is designed for DPD evaluation, not for conducting
    end-to-end communication link performance evaluation.

    Parameters
    ----------
    signal_fs : float
        Baseband signal sample rate in Hz (e.g., 15.36 MHz for 15 kHz
        subcarrier spacing with 1024-point FFT).
    pa_sample_rate : float
        PA operating sample rate in Hz (typically 8x signal rate for
        adequate reconstruction of PA nonlinear products).
    fft_size : int
        OFDM FFT size (number of subcarriers including guards).
    cp_length : int
        Cyclic prefix length in samples.
    num_ofdm_symbols : int
        Number of OFDM symbols per slot.
    num_guard_lower : int
        Number of guard subcarriers at lower band edge.
    num_guard_upper : int
        Number of guard subcarriers at upper band edge.
    dc_null : bool
        Whether the DC subcarrier is nulled.
    **kwargs
        Additional keyword arguments passed to Keras Layer.

    Attributes
    ----------
    _ofdm_demod : OFDMDemodulator
        Sionna OFDM demodulator (CP removal + FFT).
    _lower_start, _lower_end : int
        Half-open subcarrier index range ``[start, end)`` of the lower
        data band.
    _upper_start, _upper_end : int
        Half-open subcarrier index range ``[start, end)`` of the upper
        data band.
    _resample_up, _resample_down : int
        Rational factor approximating ``signal_fs / pa_sample_rate`` used
        for ``resample_poly``.
    _frame_len : int
        Number of baseband samples in one frame,
        ``(fft_size + cp_length) * num_ofdm_symbols``.

    Notes
    -----
    **Receiver Processing Steps:**

    1. **Downsample**: Convert from PA rate to baseband rate
    2. **Time sync**: Cross-correlation to find frame boundary
    3. **OFDM demod**: Remove CP and apply FFT
    4. **Equalize**: Zero-forcing per-subcarrier equalization
    5. **EVM**: Compute error vector magnitude vs. reference

    **Why Zero-Forcing Equalization?**

    In this loopback test scenario, the channel is essentially flat (no
    multipath). The per-subcarrier ZF equalizer removes a single complex
    gain per subcarrier (e.g. the PA's linear gain and phase offset),
    leaving nonlinear distortion in the measured error vector. More
    sophisticated equalizers (MMSE, etc.) are unnecessary and would
    complicate DPD performance attribution.

    **Input requirements and conventions:**

    - PA output signals must be at ``pa_sample_rate``
    - Reference baseband signal must be at ``signal_fs``
    - All signals must represent the same transmitted frame
    - Equalized symbols are normalized to the reference constellation
    - EVM is returned as a percentage (0-100+ scale)

    Example
    -------
    >>> rx = Rx(
    ...     signal_fs=15.36e6,
    ...     pa_sample_rate=122.88e6,
    ...     fft_size=1024,
    ...     cp_length=72,
    ...     num_ofdm_symbols=14,
    ...     num_guard_lower=200,
    ...     num_guard_upper=199,
    ...     dc_null=True,
    ... )
    >>> results = rx.process_and_compute_evm(
    ...     pa_input, pa_output_no_dpd, pa_output_with_dpd,
    ...     tx_baseband, fd_symbols
    ... )
    >>> print(f"EVM with DPD: {results['evm_with_dpd']:.2f}%")
    """

    # Largest denominator allowed when approximating
    # signal_fs / pa_sample_rate as a rational up/down factor.
    _MAX_RESAMPLE_DENOMINATOR = 1000
    # Upper bound on the number of reference samples used for the
    # cross-correlation timing search (keeps np.correlate cheap).
    _MAX_SYNC_LEN = 1000

    def __init__(
        self,
        signal_fs: float,
        pa_sample_rate: float,
        fft_size: int,
        cp_length: int,
        num_ofdm_symbols: int,
        num_guard_lower: int,
        num_guard_upper: int,
        dc_null: bool,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._signal_fs = signal_fs
        self._pa_sample_rate = pa_sample_rate
        self._fft_size = fft_size
        self._cp_length = cp_length
        self._num_ofdm_symbols = num_ofdm_symbols
        self._num_guard_lower = num_guard_lower
        self._num_guard_upper = num_guard_upper
        self._dc_null = dc_null

        # Sionna's OFDM demodulator handles CP removal and FFT.
        # l_min=0 means no negative delay taps (single-path channel).
        self._ofdm_demod = OFDMDemodulator(
            fft_size=fft_size,
            l_min=0,
            cyclic_prefix_length=cp_length,
        )

        # Precompute data subcarrier index ranges.
        # OFDM spectrum layout: [guard_lower | data_lower | DC | data_upper | guard_upper]
        # After FFT, indices 0..fft_size/2-1 are lower half, fft_size/2..fft_size-1 are upper.
        self._lower_start = num_guard_lower
        self._lower_end = fft_size // 2
        self._upper_start = fft_size // 2 + (1 if dc_null else 0)
        self._upper_end = fft_size - num_guard_upper

        # Rational resampling factor PA rate -> baseband rate. It is the
        # same for every path, so compute it once. Rational resampling
        # handles non-integer rate ratios.
        frac = Fraction(signal_fs / pa_sample_rate).limit_denominator(
            self._MAX_RESAMPLE_DENOMINATOR
        )
        self._resample_up = frac.numerator
        self._resample_down = frac.denominator

        # Exact number of baseband samples the demodulator needs per frame.
        self._frame_len = (fft_size + cp_length) * num_ofdm_symbols

    def process_and_compute_evm(
        self,
        pa_input,
        pa_output_no_dpd,
        pa_output_with_dpd,
        tx_baseband,
        fd_symbols,
    ):
        """
        Process PA outputs and compute EVM for DPD performance comparison.

        Runs three signal paths (PA input, PA output without DPD, PA output
        with DPD) through the complete receiver chain and computes EVM for
        each. This enables direct comparison of DPD effectiveness.

        Parameters
        ----------
        pa_input : tf.Tensor or np.ndarray
            PA input signal at PA sample rate (reference for best-case EVM).
        pa_output_no_dpd : tf.Tensor or np.ndarray
            PA output without predistortion at PA sample rate.
        pa_output_with_dpd : tf.Tensor or np.ndarray
            PA output with predistortion at PA sample rate.
        tx_baseband : tf.Tensor or np.ndarray
            Original baseband transmit signal at signal sample rate (used as
            timing reference for synchronization).
        fd_symbols : tf.Tensor or np.ndarray
            Transmitted frequency-domain symbols, shape
            ``[num_data_subcarriers, num_ofdm_symbols]``. Used as reference
            for equalization and EVM calculation.

        Returns
        -------
        dict
            Results dictionary containing:

            - ``symbols_input`` : np.ndarray
                Equalized constellation symbols from PA input path.
            - ``symbols_no_dpd`` : np.ndarray
                Equalized symbols from PA output without DPD.
            - ``symbols_with_dpd`` : np.ndarray
                Equalized symbols from PA output with DPD.
            - ``evm_input`` : float
                EVM (%) for PA input (baseline, should be near-zero).
            - ``evm_no_dpd`` : float
                EVM (%) for PA output without DPD (shows PA distortion).
            - ``evm_with_dpd`` : float
                EVM (%) for PA output with DPD (shows DPD improvement).

        Raises
        ------
        ValueError
            If, for any path, fewer than one full frame of baseband samples
            is available after the detected timing offset.

        Notes
        -----
        The PA input path serves as a sanity check: its EVM should be very
        low since it hasn't passed through PA nonlinearity.
        """
        # Inputs may carry batch/stream dimensions that aren't needed here.
        tx_baseband_np = self._to_flat_numpy(tx_baseband)
        fd_np = fd_symbols.numpy() if isinstance(fd_symbols, tf.Tensor) else fd_symbols

        # Every path goes through the identical receive chain against the
        # same transmit reference.
        sym_input_np = self._receive(pa_input, tx_baseband_np, fd_symbols)
        sym_no_dpd_np = self._receive(pa_output_no_dpd, tx_baseband_np, fd_symbols)
        sym_with_dpd_np = self._receive(
            pa_output_with_dpd, tx_baseband_np, fd_symbols
        )

        # Step 5: Compute EVM as percentage.
        return {
            "symbols_input": sym_input_np,
            "symbols_no_dpd": sym_no_dpd_np,
            "symbols_with_dpd": sym_with_dpd_np,
            "evm_input": self._compute_evm(sym_input_np, fd_np),
            "evm_no_dpd": self._compute_evm(sym_no_dpd_np, fd_np),
            "evm_with_dpd": self._compute_evm(sym_with_dpd_np, fd_np),
        }

    def _receive(self, pa_signal, tx_ref, fd_symbols):
        """
        Run one PA-rate signal through resample, sync, demod and equalize.

        Parameters
        ----------
        pa_signal : tf.Tensor or np.ndarray
            Signal at ``pa_sample_rate``; any extra dimensions are flattened.
        tx_ref : np.ndarray
            1-D baseband transmit signal used as the timing reference.
        fd_symbols : tf.Tensor or np.ndarray
            Reference symbols, shape ``[num_data_subcarriers, num_ofdm_symbols]``.

        Returns
        -------
        np.ndarray
            Equalized data-subcarrier symbols, shape
            ``[num_data_subcarriers, num_ofdm_symbols]``.

        Raises
        ------
        ValueError
            If fewer than ``_frame_len`` samples remain after the detected delay.
        """
        # Step 1: Downsample from PA rate to baseband rate.
        baseband = resample_poly(
            self._to_flat_numpy(pa_signal), self._resample_up, self._resample_down
        )

        # Step 2: Time synchronization, then extract exactly one frame.
        delay = self._find_delay(baseband, tx_ref)
        frame = baseband[delay : delay + self._frame_len]
        if len(frame) < self._frame_len:
            # A truncated window would make the demodulator return fewer OFDM
            # symbols than fd_symbols has (presumably Sionna drops trailing
            # partial symbols). With exactly one symbol, broadcasting against
            # fd_symbols would yield a silently wrong EVM, so fail loudly.
            raise ValueError(
                f"Synchronized frame has {len(frame)} samples but "
                f"{self._frame_len} are required (detected delay {delay}); "
                "the received signal is too short or timing sync failed."
            )

        # Step 3: OFDM demodulation (CP removal + FFT).
        symbols = self._demod(frame)

        # Step 4: Per-subcarrier zero-forcing equalization.
        return self._equalize(symbols, fd_symbols).numpy()

    @staticmethod
    def _to_flat_numpy(x):
        """Return ``x`` (tf tensor/variable, ndarray or sequence) as a 1-D ndarray."""
        if hasattr(x, "numpy"):
            x = x.numpy()
        return np.asarray(x).reshape(-1)

    @classmethod
    def _find_delay(cls, signal, ref):
        """
        Find the sample delay of ``ref`` inside ``signal`` by peak cross-correlation.

        Only the first ``min(_MAX_SYNC_LEN, len(ref) // 2)`` reference samples
        are correlated, to limit computation. ``np.correlate`` conjugates its
        second argument, so this is the proper complex cross-correlation.
        """
        sync_len = min(cls._MAX_SYNC_LEN, len(ref) // 2)
        corr = np.correlate(signal, ref[:sync_len], mode="valid")
        return int(np.argmax(np.abs(corr)))

    def _demod(self, signal):
        """
        Demodulate time-domain OFDM signal to frequency-domain symbols.

        Parameters
        ----------
        signal : np.ndarray
            Time-domain signal, shape ``[num_samples]``, at baseband rate.
            Length must equal ``(fft_size + cp_length) * num_ofdm_symbols``.

        Returns
        -------
        tf.Tensor
            Data subcarrier symbols, shape ``[num_data_subcarriers, num_symbols]``.
            Guard bands and DC null are excluded.

        Notes
        -----
        The output concatenates lower and upper data bands. This matches the
        order used by the transmitter's resource grid mapper.
        """
        if not isinstance(signal, tf.Tensor):
            signal = tf.constant(signal, dtype=tf.complex64)

        # Reshape for Sionna demodulator: [batch, rx_ant, tx_ant, samples].
        signal_4d = tf.reshape(signal, [1, 1, 1, -1])

        # Demodulate: output is [batch, rx_ant, tx_ant, num_symbols, fft_size].
        rg = self._ofdm_demod(signal_4d)[0, 0, 0, :, :]  # [num_symbols, fft_size]

        # Extract data subcarriers (exclude guards and DC).
        # Transpose to get [subcarriers, symbols] ordering.
        fd_lower = tf.transpose(rg[:, self._lower_start : self._lower_end])
        fd_upper = tf.transpose(rg[:, self._upper_start : self._upper_end])
        return tf.concat([fd_lower, fd_upper], axis=0)

    def _equalize(self, rx, tx):
        """
        Apply zero-forcing per-subcarrier equalization.

        Estimates a single complex gain per subcarrier by least-squares fit
        across all OFDM symbols, then divides received symbols by this estimate.

        Parameters
        ----------
        rx : tf.Tensor
            Received symbols, shape ``[num_subcarriers, num_symbols]``.
        tx : tf.Tensor
            Transmitted reference symbols, same shape as ``rx``.

        Returns
        -------
        tf.Tensor
            Equalized symbols, same shape as input.

        Notes
        -----
        The channel estimate is computed as:

        .. math::

            \\hat{H}_k = \\frac{\\sum_n r_{k,n} \\cdot t_{k,n}^*}{\\sum_n |t_{k,n}|^2}

        where :math:`k` is subcarrier index and :math:`n` is symbol index.
        This is the least-squares estimate assuming the channel is constant
        across all symbols (valid for this static loopback scenario). A
        subcarrier whose reference power is zero would give a zero
        denominator and non-finite output.
        """
        rx = tf.cast(rx, tf.complex64)
        tx = tf.cast(tx, tf.complex64)

        # Least-squares channel estimate per subcarrier.
        # H = sum(rx * conj(tx)) / sum(|tx|^2)
        H = tf.reduce_sum(rx * tf.math.conj(tx), axis=1, keepdims=True) / tf.cast(
            tf.reduce_sum(tf.abs(tx) ** 2, axis=1, keepdims=True), tf.complex64
        )

        # Zero-forcing equalization: divide by channel estimate.
        return rx / H

    @staticmethod
    def _compute_evm(rx, tx):
        """
        Compute Error Vector Magnitude (EVM) as percentage.

        EVM quantifies the difference between received and ideal constellation
        points, normalized by the reference signal power.

        Parameters
        ----------
        rx : np.ndarray
            Received (equalized) symbols.
        tx : np.ndarray
            Transmitted reference symbols (same shape as ``rx``).

        Returns
        -------
        float
            EVM as percentage (e.g., 5.0 means 5% EVM).

        Notes
        -----
        EVM is computed as:

        .. math::

            \\text{EVM} = 100 \\times \\sqrt{\\frac{\\text{mean}(|r - t|^2)}{\\text{mean}(|t|^2)}}

        Lower EVM indicates better signal quality. Typical targets:

        - 64-QAM: < 8% EVM required
        - 256-QAM: < 3.5% EVM required
        """
        error = rx - tx
        evm = np.sqrt(np.mean(np.abs(error) ** 2) / np.mean(np.abs(tx) ** 2)) * 100
        return float(evm)