Source code for demos.pusch_autoencoder.src.pusch_neural_detector

# SPDX-License-Identifier: MIT
# Copyright (c) 2025–present Srikanth Pagadarai

"""
Neural MIMO detector for PUSCH with learned channel estimation refinement.

This module implements a hybrid classical/neural network architecture that
combines the reliability of LS channel estimation and LMMSE equalization
with the adaptability of neural networks. The key motivation behind this
design is that learning residual corrections to classical estimates is
more stable than learning detection from scratch.
"""

import tensorflow as tf
from tensorflow.keras.layers import Layer, Conv2D, LayerNormalization
from tensorflow.nn import leaky_relu
from sionna.phy.mimo import lmmse_equalizer
from sionna.phy.mapping import Demapper, Constellation
from sionna.phy.utils import log10

from .config import Config


[docs] class Conv2DResBlock(Layer): r""" Pre-activation residual block with two convolutional layers. Implements the pre-activation ResNet variant where normalization and activation precede each convolution. This ordering improves gradient flow and enables training of deeper networks. The block computes: ``output = input + conv2(act(norm(conv1(act(norm(input))))))`` Parameters ---------- filters : int Number of convolutional filters in both layers. kernel_size : tuple of int Spatial kernel size, default ``(3, 3)``. name : str, optional Layer name for TensorFlow graph. Notes ----- - Uses LayerNormalization (not BatchNorm) for stable training with small batch sizes typical in communication system simulation. - LeakyReLU with α=0.1 prevents dead neurons while maintaining near-linear behavior for small negative inputs. - Identity skip connection requires input and output channel counts to match; caller must ensure ``filters`` equals input channels. """ def __init__(self, filters: int, kernel_size=(3, 3), name: str = None): super().__init__(name=name) self.filters = int(filters) self.kernel_size = kernel_size self._layer_norm1 = LayerNormalization(axis=-1) self._conv1 = Conv2D( filters=self.filters, kernel_size=self.kernel_size, padding="same" ) self._layer_norm2 = LayerNormalization(axis=-1) self._conv2 = Conv2D( filters=self.filters, kernel_size=self.kernel_size, padding="same" )
[docs] def call(self, x): """ Apply residual transformation. Parameters ---------- x : tf.Tensor Input tensor of shape ``[batch, height, width, filters]``. Returns ------- tf.Tensor Output tensor with same shape as input (residual added). """ y = self._layer_norm1(x) y = leaky_relu(y, alpha=0.1) y = self._conv1(y) y = self._layer_norm2(y) y = leaky_relu(y, alpha=0.1) y = self._conv2(y) return x + y
[docs] class PUSCHNeuralDetector(Layer): r""" Neural MIMO detector with learned channel estimation refinement for PUSCH. This detector implements a residual learning architecture that refines classical LS channel estimates and LMMSE-based soft symbol estimates using convolutional neural networks. The key design principle is that learning corrections to strong classical baselines is more effective than learning detection from scratch. Architecture ------------ The detector processes data through six stages: 1. **Feature extraction**: Assembles input features including LS channel estimate, received signal, matched filter output, Gram matrix structure, estimation error variance, and noise level. 2. **Shared backbone**: Processes features through convolutional ResBlocks to learn joint representations useful for both CE refinement and detection. 3. **CE refinement head**: Predicts additive corrections Δh and multiplicative log-domain corrections Δlog(err_var) to the LS estimates. 4. **Scaled correction application**: Applies learned corrections scaled by trainable parameters that start at zero, enabling gradual departure from classical behavior during training. 5. **Classical LMMSE**: Performs LMMSE equalization using refined channel estimate and error variance, followed by max-log demapping. 6. **LLR refinement**: Injects LMMSE outputs back into the network to predict additive LLR corrections, again scaled by a trainable parameter. Parameters ---------- cfg : ~demos.pusch_autoencoder.src.config.Config System configuration containing MIMO dimensions, modulation order, and PUSCH resource grid information. num_conv2d_filters : int Number of filters in all convolutional layers. Higher values increase model capacity but also computational cost. Default 128. num_shared_res_blocks : int Number of ResBlocks in the shared backbone. Controls depth of joint feature learning. Default 4. num_det_res_blocks : int Number of ResBlocks in the detection continuation path. Controls capacity for LLR refinement. Default 6. kernel_size : tuple of int Spatial kernel size for ResBlock convolutions. Default ``(3, 3)``. Pre-conditions -------------- - ``cfg.pusch_pilot_indices`` must be set (by ``PUSCHLinkE2E``) before instantiation to enable pilot/data symbol separation. - Input tensors must follow Sionna's PUSCH dimension conventions. Post-conditions --------------- - ``trainable_variables`` returns correction scales first, then network weights (enables separate optimizer configuration). - ``last_h_hat_refined`` and ``last_err_var_refined`` contain the most recent refined estimates (useful for auxiliary losses). Invariants ---------- - Correction scales initialized to 0.0, so initial behavior matches classical LMMSE exactly (graceful degradation if training fails). - Error variance correction scale is always positive (softplus transform). - Output LLR shape matches Sionna's ``PUSCHReceiver`` expectations. Example ------- >>> cfg = Config() >>> # ... set cfg.pusch_pilot_indices via PUSCHLinkE2E ... >>> detector = PUSCHNeuralDetector(cfg, num_conv2d_filters=64) >>> llr = detector(y, h_hat, err_var, no) Notes ----- The trainable correction scales serve multiple purposes: 1. **Safe initialization**: Starting at 0.0 means the detector initially behaves exactly like classical LMMSE, providing a stable starting point. 2. **Interpretability**: Scale magnitudes indicate how much the network deviates from classical processing (useful for debugging). 3. **Gradient balancing**: Separate scales for h, err_var, and LLR allow independent learning rates for different correction types. The error variance scale uses softplus to ensure positivity, since negative variance would be physically meaningless and cause numerical issues in LMMSE computation. """ def __init__( self, cfg: Config, num_conv2d_filters: int = 128, num_shared_res_blocks: int = 4, num_det_res_blocks: int = 3, kernel_size=(3, 3), ): super().__init__() self._cfg = cfg self.num_conv2d_filters = int(num_conv2d_filters) self.num_shared_res_blocks = int(num_shared_res_blocks) self.num_det_res_blocks = int(num_det_res_blocks) self.kernel_size = kernel_size # ===================================================================== # Extract Dimensions from Config # ===================================================================== self._num_bits_per_symbol = int(self._cfg.num_bits_per_symbol) self._num_ue = int(self._cfg.num_ue) self._num_streams_per_ue = int(self._cfg.num_layers) self._num_streams_total = self._num_ue * self._num_streams_per_ue self._num_bs = int(self._cfg.num_bs) self._num_bs_ant = int(self._cfg.num_bs_ant) self._num_rx_ant = self._num_bs * self._num_bs_ant self._pusch_pilot_indices = list(self._cfg.pusch_pilot_indices) self._pusch_num_symbols_per_slot = int(self._cfg.pusch_num_symbols_per_slot) # Separate pilot and data symbol indices for selective processing. # CE refinement uses all symbols; detection uses data symbols only. all_symbols = list(range(self._pusch_num_symbols_per_slot)) pilots = set(self._pusch_pilot_indices) self._data_symbol_indices = [s for s in all_symbols if s not in pilots] # ===================================================================== # Constellation and Demapper # ===================================================================== # Initialize with standard QAM; points may be overwritten during call() # if a trainable constellation is provided by the transmitter. qam_points = Constellation("qam", self._num_bits_per_symbol).points self._constellation = Constellation( "custom", num_bits_per_symbol=self._num_bits_per_symbol, points=qam_points, normalize=False, ) self._demapper = Demapper("app", constellation=self._constellation) # ===================================================================== # Compute Input Feature Dimensions # ===================================================================== # The shared backbone receives a concatenation of multiple feature types, # each providing complementary information about the channel and signal. S = self._num_streams_total Nr = self._num_rx_ant self._c_in_shared = ( 2 * Nr * S # h_ls: real and imaginary parts of channel estimate + 2 * Nr # y: real and imaginary parts of received signal + 2 * S # z_mf: matched filter output (H^H @ y), captures signal energy + S # gram_diag: diagonal of H^H @ H, indicates per-stream SNR + S * (S - 1) # gram_offdiag: inter-stream interference structure + S # err_var: channel estimation error variance per stream + 1 # no: noise variance (log scale for numerical stability) ) # [autoencoder-definition-start] # ===================================================================== # Shared Backbone Network # ===================================================================== # Processes all input features to learn representations useful for both # channel estimation refinement and detection. Sharing weights between # these tasks acts as implicit regularization and reduces parameters. self._shared_conv_in = Conv2D( filters=self.num_conv2d_filters, kernel_size=(3, 3), padding="same", name="shared_conv_in", ) self._shared_res_blocks = [ Conv2DResBlock( filters=self.num_conv2d_filters, kernel_size=self.kernel_size, name=f"shared_resblock_{i}", ) for i in range(self.num_shared_res_blocks) ] # ===================================================================== # Channel Estimation Refinement Head # ===================================================================== # Lightweight head that projects shared features to channel corrections. # Separate outputs for Δh (complex) and Δlog(err_var) (real). self._ce_head_conv1 = tf.keras.Sequential( [ Conv2D(self.num_conv2d_filters, (3, 3), padding="same"), LayerNormalization(axis=-1), tf.keras.layers.LeakyReLU(0.1), Conv2D(self.num_conv2d_filters, (3, 3), padding="same"), LayerNormalization(axis=-1), tf.keras.layers.LeakyReLU(0.1), ], name="ce_head_conv1", ) # Output layer for channel correction: 2 * Nr * S values (real + imag) self._ce_head_out_h = Conv2D( filters=2 * Nr * S, kernel_size=(1, 1), padding="same", activation=None, name="ce_head_out_h", ) # Output layer for error variance log-correction: S values per RE self._ce_head_out_loge = Conv2D( filters=S, kernel_size=(1, 1), padding="same", activation=None, name="ce_head_out_loge", ) # ===================================================================== # Detection Continuation Network # ===================================================================== # After LMMSE equalization with refined estimates, this network learns # to correct the resulting LLRs based on both the original features # and the LMMSE outputs (symbol estimates, effective noise). self._c_lmmse_feats = ( 2 * S # x_lmmse: equalized symbols (real + imag) + S # no_eff: post-equalization noise variance (log scale) + S * self._num_bits_per_symbol # llr_lmmse: baseline soft bits ) # Injection convolution fuses shared backbone features with LMMSE outputs self._det_inject_conv = Conv2D( filters=self.num_conv2d_filters, kernel_size=(3, 3), padding="same", name="det_inject_conv", ) # Additional ResBlocks for LLR refinement self._det_res_blocks = [ Conv2DResBlock( filters=self.num_conv2d_filters, kernel_size=self.kernel_size, name=f"det_resblock_{i}", ) for i in range(self.num_det_res_blocks) ] # Output convolution produces LLR corrections for all bits self._det_conv_out = Conv2D( filters=S * self._num_bits_per_symbol, kernel_size=(3, 3), padding="same", activation=None, name="det_conv_out", ) # ===================================================================== # Trainable Correction Scales # ===================================================================== # Initialize all scales to 0.0 so that initial output matches classical # LMMSE exactly. This provides a stable starting point and enables # graceful degradation if training fails to improve on classical. # Channel estimate correction: h_refined = h_ls + scale * delta_h # Unbounded since corrections can be positive or negative. self._h_correction_scale = tf.Variable( 0.0, trainable=True, name="h_correction_scale", dtype=tf.float32 ) # LLR correction: llr_final = llr_lmmse + scale * delta_llr # Unbounded to allow both confidence increase and decrease. self._llr_correction_scale = tf.Variable( 0.0, trainable=True, name="llr_correction_scale", dtype=tf.float32 ) # Error variance correction in log domain for numerical stability: # err_var_refined = exp(log(err_var) + scale * delta_log_err) # Uses softplus(raw_value) to ensure positivity; softplus(0) = ln(2) ≈ 0.69 # but we want initial scale ≈ 1.0, and softplus(0.54) ≈ 1.0 # For simplicity, initialize to 0.0; the network will adapt. self._err_var_correction_scale_raw = tf.Variable( 0.0, trainable=True, name="err_var_correction_scale_raw", dtype=tf.float32 ) # [autoencoder-definition-end] # ===================================================================== # State for Auxiliary Losses # ===================================================================== # These attributes store refined estimates from the most recent forward # pass, enabling auxiliary loss computation (e.g., MSE on h_hat). self.last_h_hat_refined = None self.last_err_var_refined = None self.last_err_var_refined_flat = None @property def err_var_correction_scale(self): """ Effective error variance correction scale (always positive). Returns ------- tf.Tensor Scalar tensor with softplus-transformed scale value. softplus(x) = log(1 + exp(x)) ensures output > 0. """ return tf.nn.softplus(self._err_var_correction_scale_raw) @property def trainable_variables(self): """ Collect all trainable variables in a specific order. The ordering places correction scales first, enabling separate optimizer configuration (e.g., higher learning rate for scales). Returns ------- list of tf.Variable Ordered list: [correction_scales, backbone_weights, ce_head_weights, detection_weights]. """ vars_ = [] # Correction scales first (for easy separate optimizer setup) vars_ += [self._h_correction_scale] vars_ += [self._err_var_correction_scale_raw] vars_ += [self._llr_correction_scale] # Shared backbone vars_ += self._shared_conv_in.trainable_variables for block in self._shared_res_blocks: vars_ += block.trainable_variables # CE head vars_ += self._ce_head_conv1.trainable_variables vars_ += self._ce_head_out_h.trainable_variables vars_ += self._ce_head_out_loge.trainable_variables # Detection continuation vars_ += self._det_inject_conv.trainable_variables for block in self._det_res_blocks: vars_ += block.trainable_variables vars_ += self._det_conv_out.trainable_variables return vars_ def _reshape_logits_to_llr(self, logits, num_data_symbols): """ Reshape detector output to match Sionna's PUSCHReceiver LLR format. Parameters ---------- logits : tf.Tensor Network output, shape ``[B, H_data, W, S * num_bits_per_symbol]``. num_data_symbols : int or tf.Tensor Total number of data resource elements (H_data * W). Returns ------- tf.Tensor Reshaped LLRs, shape ``[B, num_ue, streams_per_ue, num_data_symbols * bits]``. Notes ----- The reshape sequence preserves bit ordering while reorganizing from spatial (H, W) layout to the flattened per-UE format expected by the transport block decoder. """ B = tf.shape(logits)[0] logits = tf.reshape( logits, [B, num_data_symbols, self._num_streams_total, self._num_bits_per_symbol], ) logits = tf.reshape( logits, [ B, num_data_symbols, self._num_ue, self._num_streams_per_ue, self._num_bits_per_symbol, ], ) logits = tf.transpose(logits, [0, 2, 3, 1, 4]) llr = tf.reshape( logits, [ B, self._num_ue, self._num_streams_per_ue, num_data_symbols * self._num_bits_per_symbol, ], ) llr.set_shape([None, self._num_ue, self._num_streams_per_ue, None]) return llr
[docs] def call(self, y, h_hat, err_var, no, constellation=None, training=None): """ Execute neural MIMO detection with channel estimation refinement. Parameters ---------- y : tf.Tensor, complex64 Received OFDM signal after FFT. Shape: ``[B, num_bs, num_bs_ant, num_ofdm_syms, num_subcarriers]`` h_hat : tf.Tensor, complex64 LS channel estimate from DMRS processing. Shape: ``[B, num_bs, num_bs_ant, num_ue, num_streams, num_ofdm_syms, num_subcarriers]`` err_var : tf.Tensor, float32 Channel estimation error variance per stream and RE. Shape: ``[B, 1, 1, num_ue, num_streams, num_ofdm_syms, num_subcarriers]`` no : tf.Tensor, float32 Noise variance, shape ``[B]`` or scalar. constellation : tf.Tensor, complex64, optional Trainable constellation points from transmitter. If provided, updates the demapper's constellation for consistent symbol mapping. Shape: ``[num_constellation_points]`` training : bool, optional Training mode flag (currently unused, for Keras API compatibility). Returns ------- tf.Tensor, float32 Log-likelihood ratios for all coded bits. Shape: ``[B, num_ue, num_streams_per_ue, num_data_symbols * num_bits_per_symbol]`` Notes ----- The processing flow: 1. **Feature assembly**: Extracts and concatenates multiple signal representations (channel, received signal, matched filter, Gram matrix structure, estimation error, noise level). 2. **Shared backbone**: Processes features through ResBlocks to learn joint representations. 3. **CE refinement**: Predicts channel and error variance corrections, applies them scaled by trainable parameters. 4. **LMMSE equalization**: Classical equalization with refined estimates, using whitened interference model for improved performance. 5. **LLR refinement**: Predicts additive corrections to LMMSE LLRs, enabling learned compensation for model mismatch. Channel estimation refinement operates on ALL OFDM symbols (including pilots) to leverage full spatial-frequency context, while detection operates only on DATA symbols to produce the final LLRs. """ # ===================================================================== # Input Preparation # ===================================================================== y = tf.cast(y, tf.complex64) h_hat = tf.cast(h_hat, tf.complex64) err_var = tf.cast(err_var, tf.float32) no = tf.cast(no, tf.float32) # Synchronize demapper constellation with transmitter if trainable if constellation is not None: self._constellation.points = tf.cast(constellation, tf.complex64) # Precompute data symbol indices for later slicing data_idx = tf.constant(self._data_symbol_indices, dtype=tf.int32) # ===================================================================== # Extract Dimensions # ===================================================================== # Process full OFDM grid for CE refinement (pilots provide context) B = tf.shape(y)[0] H = tf.shape(y)[3] # num_ofdm_syms (including pilots) W = tf.shape(y)[4] # num_subcarriers S = self._num_streams_total Nr = self._num_rx_ant # ===================================================================== # Reshape Inputs to Conv2D-Friendly Format [B, H, W, C] # ===================================================================== # Sionna uses [B, rx, tx, time, freq] convention; we reshape to # [B, time, freq, features] for 2D convolution over the OFDM grid. # y: [B, num_bs, num_bs_ant, H, W] -> [B, H, W, Nr] y_flat = tf.reshape(y, [B, -1, H, W]) y_flat = tf.transpose(y_flat, [0, 2, 3, 1]) # h_hat: [B, num_bs, num_bs_ant, num_ue, streams, H, W] -> [B, H, W, Nr, S] h_flat = tf.reshape(h_hat, [B, -1, S, H, W]) h_flat = tf.transpose(h_flat, [0, 3, 4, 1, 2]) # err_var: [B, 1, 1, num_ue, streams, H, W] -> [B, H, W, S] err_var_t = tf.squeeze(err_var, axis=[1, 2]) err_var_flat = tf.transpose(err_var_t, [0, 3, 4, 1, 2]) err_var_flat = tf.reshape(err_var_flat, [B, H, W, S]) # ===================================================================== # Compute Derived Features # ===================================================================== # These features provide the network with multiple views of the signal, # each capturing different aspects of the channel and interference. # Matched filter output: z_mf = H^H @ y # Captures signal energy and provides a sufficient statistic for detection h_conj_t = tf.transpose(tf.math.conj(h_flat), [0, 1, 2, 4, 3]) # [B,H,W,S,Nr] y_col = y_flat[..., tf.newaxis] # [B,H,W,Nr,1] z_mf = tf.squeeze(tf.matmul(h_conj_t, y_col), axis=-1) # [B,H,W,S] # Gram matrix: G = H^H @ H # Diagonal elements indicate per-stream channel gain # Off-diagonal elements capture inter-stream interference gram = tf.matmul(h_conj_t, h_flat) # [B,H,W,S,S] gram_diag = tf.math.real(tf.linalg.diag_part(gram)) # [B,H,W,S] # Extract off-diagonal elements (interference structure) mask = 1.0 - tf.eye(S, dtype=tf.float32) mask = mask[tf.newaxis, tf.newaxis, tf.newaxis, :, :] gram_masked = gram * tf.cast(mask, gram.dtype) gram_offdiag = tf.abs(gram_masked) gram_offdiag_flat = tf.reshape(gram_offdiag, [B, H, W, -1]) # Select only off-diagonal entries (S*(S-1) values per RE) indices = [i * S + j for i in range(S) for j in range(S) if i != j] gram_offdiag_feats = tf.gather(gram_offdiag_flat, indices, axis=-1) # Channel estimate features (real and imaginary parts) h_flat_features = tf.reshape(h_flat, [B, H, W, Nr * S]) h_feats = tf.concat( [tf.math.real(h_flat_features), tf.math.imag(h_flat_features)], axis=-1 ) # Received signal features (real and imaginary parts) y_feats = tf.concat([tf.math.real(y_flat), tf.math.imag(y_flat)], axis=-1) # Matched filter features (real and imaginary parts) z_mf_feats = tf.concat([tf.math.real(z_mf), tf.math.imag(z_mf)], axis=-1) # Noise variance in log scale for numerical stability and scale invariance no_log = log10(no + 1e-10) no_feat = tf.broadcast_to( no_log[:, tf.newaxis, tf.newaxis, tf.newaxis], [B, H, W, 1] ) # ===================================================================== # Assemble Shared Input Features # ===================================================================== shared_input = tf.concat( [ h_feats, # 2 * Nr * S: channel estimate y_feats, # 2 * Nr: received signal z_mf_feats, # 2 * S: matched filter output gram_diag, # S: per-stream channel power gram_offdiag_feats, # S * (S-1): interference structure err_var_flat, # S: estimation error variance no_feat, # 1: noise level ], axis=-1, ) shared_input = tf.cast(shared_input, tf.float32) # [shared-backbone-start] # ===================================================================== # Shared Backbone Forward Pass # ===================================================================== shared_features = self._shared_conv_in(shared_input) for block in self._shared_res_blocks: shared_features = block(shared_features) # shared_features: [B, H, W, num_filters] # [shared-backbone-end] # [ce-head-start] # ===================================================================== # Channel Estimation Refinement Head # ===================================================================== ce_hidden = self._ce_head_conv1(shared_features) delta_h_raw = self._ce_head_out_h(ce_hidden) # [B,H,W, 2*Nr*S] delta_loge = self._ce_head_out_loge(ce_hidden) # [B,H,W, S] # Parse channel correction into complex format delta_h_raw = tf.cast(delta_h_raw, tf.float32) delta_h_r = delta_h_raw[..., : Nr * S] delta_h_i = delta_h_raw[..., Nr * S :] delta_h_c = tf.complex(delta_h_r, delta_h_i) delta_h_c = tf.reshape(delta_h_c, [B, H, W, Nr, S]) # ===================================================================== # Apply Scaled Channel Refinement # ===================================================================== # Additive correction: h_refined = h_ls + scale * delta_h # Scale starts at 0, so initial behavior matches LS exactly. h_scale = tf.cast(self._h_correction_scale, tf.complex64) h_flat_refined = h_flat + h_scale * tf.cast(delta_h_c, h_flat.dtype) # ===================================================================== # Apply Scaled Error Variance Refinement (Log Domain) # ===================================================================== # Multiplicative correction in log domain for numerical stability: # err_var_refined = exp(log(err_var) + scale * delta_log_err) # This ensures err_var remains positive regardless of delta magnitude. err_var_scale = self.err_var_correction_scale # softplus-transformed log_err = tf.math.log(err_var_flat + 1e-10) log_err_refined = log_err + err_var_scale * tf.cast(delta_loge, log_err.dtype) err_var_flat_refined = tf.exp(log_err_refined) # [ce-head-end] # ===================================================================== # Store Refined Estimates for Auxiliary Losses # ===================================================================== # Reshape back to Sionna's dimension convention for external access # h_hat_refined: [B,H,W,Nr,S] -> [B, num_bs, num_bs_ant, num_ue, streams, H, W] h_ref_t = tf.transpose(h_flat_refined, [0, 3, 4, 1, 2]) # [B,Nr,S,H,W] h_ref_t = tf.reshape( h_ref_t, [ B, self._num_bs, self._num_bs_ant, self._num_ue, self._num_streams_per_ue, H, W, ], ) self.last_h_hat_refined = h_ref_t # err_var_refined: [B,H,W,S] -> [B,1,1,num_ue,streams,H,W] ev_ref = tf.reshape( err_var_flat_refined, [B, H, W, self._num_ue, self._num_streams_per_ue] ) ev_ref = tf.transpose(ev_ref, [0, 3, 4, 1, 2]) ev_ref = ev_ref[:, tf.newaxis, tf.newaxis, ...] self.last_err_var_refined = ev_ref self.last_err_var_refined_flat = tf.cast(err_var_flat_refined, tf.float32) # ===================================================================== # LMMSE Equalization on Data Symbols Only # ===================================================================== # Slice to data symbols (exclude pilots) for detection y_flat_data = tf.gather(y_flat, data_idx, axis=1) # [B, H_data, W, Nr] shared_features_data = tf.gather( shared_features, data_idx, axis=1 ) # [B, H_data, W, F] h_flat_refined_data = tf.gather( h_flat_refined, data_idx, axis=1 ) # [B, H_data, W, Nr, S] err_var_flat_refined_data = tf.gather( err_var_flat_refined, data_idx, axis=1 ) # [B, H_data, W, S] H_data = tf.shape(y_flat_data)[1] num_data_symbols = H_data * W # Build noise covariance matrix for LMMSE # Total noise = AWGN + channel estimation error (summed over streams) no_expanded = no[:, tf.newaxis, tf.newaxis] sum_err_var = tf.reduce_sum(err_var_flat_refined_data, axis=-1) total_noise_var = no_expanded + sum_err_var eye = tf.eye(Nr, dtype=tf.complex64)[tf.newaxis, tf.newaxis, tf.newaxis, :, :] s_cov_data = ( tf.cast(total_noise_var[..., tf.newaxis, tf.newaxis], tf.complex64) * eye ) # LMMSE equalization with whitened interference model x_lmmse, no_eff = lmmse_equalizer( y_flat_data, h_flat_refined_data, s_cov_data, whiten_interference=True ) # x_lmmse: [B, H_data, W, S] - equalized symbols # no_eff: [B, H_data, W, S] - effective noise variance per stream # ===================================================================== # Demapping to Baseline LLRs # ===================================================================== x_lmmse_flat_dm = tf.reshape(x_lmmse, [-1]) no_eff_flat_dm = tf.reshape(no_eff, [-1]) llr_lmmse_flat = self._demapper(x_lmmse_flat_dm, no_eff_flat_dm) llr_lmmse = tf.reshape( llr_lmmse_flat, [B, H_data, W, S, self._num_bits_per_symbol] ) llr_lmmse = tf.reshape(llr_lmmse, [B, H_data, W, S * self._num_bits_per_symbol]) # ===================================================================== # Build LMMSE Features for Detection Continuation # ===================================================================== # Provide the network with LMMSE outputs to enable learned refinement x_lmmse_feats = tf.concat( [tf.math.real(x_lmmse), tf.math.imag(x_lmmse)], axis=-1 ) # 2*S no_eff_feats = tf.math.log(no_eff + 1e-10) # S (log scale) lmmse_features = tf.concat( [ x_lmmse_feats, # 2 * S: equalized symbols no_eff_feats, # S: effective noise variance llr_lmmse, # S * bits: baseline LLRs ], axis=-1, ) lmmse_features = tf.cast(lmmse_features, tf.float32) # ===================================================================== # Detection Continuation Network # ===================================================================== # Fuse shared backbone features with LMMSE outputs combined_features = tf.concat([shared_features_data, lmmse_features], axis=-1) # [det-head-start] # Injection convolution reduces dimensions and fuses information det_features = self._det_inject_conv(combined_features) # Additional ResBlocks for LLR refinement for block in self._det_res_blocks: det_features = block(det_features) # Predict LLR corrections llr_correction = self._det_conv_out(det_features) # [det-head-end] # ===================================================================== # Final LLR with Scaled Correction # ===================================================================== # Additive correction: llr_final = llr_lmmse + scale * delta_llr # Scale starts at 0, so initial behavior matches LMMSE exactly. llr_final = llr_lmmse + self._llr_correction_scale * llr_correction return self._reshape_logits_to_llr(llr_final, num_data_symbols)