Source code for demos.mimo_ofdm_neural_receiver.src.neural_rx

# SPDX-License-Identifier: MIT
# Copyright (c) 2025–present Srikanth Pagadarai

"""
Neural receiver for MIMO-OFDM simulation.

Implements a convolutional neural network-based receiver that
learns to map received OFDM signals directly to log-likelihood ratios (LLRs),
bypassing traditional channel estimation and equalization stages:

    Received Signal (y) + Noise Power (no) -> CNN -> LLRs -> [LDPC Decoder]

The architecture uses a ResNet-style design with:
- Input convolution to expand channel dimension
- Stack of residual blocks with layer normalization
- Output convolution producing per-bit LLR predictions

Key design decisions:

1. **End-to-end learning**: The network jointly learns channel estimation,
   equalization, and demapping in a single differentiable pipeline.

2. **Noise power as input**: Feeding log10(no) helps the network adapt its
   behavior across different SNR operating points.

3. **Training mode**: When ``channel_coding_off=True``, LDPC decoding is
   skipped and raw LLRs are returned for BCE loss computation.
"""

import tensorflow as tf
import sionna as sn
from sionna.phy.fec.ldpc import LDPC5GEncoder, LDPC5GDecoder
from tensorflow.keras.layers import Layer, Conv2D, LayerNormalization
from tensorflow.nn import relu

from .config import Config


class ResidualBlock(Layer):
    """
    Residual block with convolutions and layer normalization.

    Implements a pre-activation residual block where normalization and
    activation precede each convolution. The skip connection enables gradient
    flow through deep networks and allows the block to learn residual
    refinements rather than full transformations.

    Architecture per layer:
        LayerNorm -> ReLU -> Conv2D(3x3)

    The block applies ``num_resnet_layers`` such layers sequentially, then
    adds the input via skip connection.

    Parameters
    ----------
    num_conv2d_filters : int, (default 128)
        Number of output channels for each convolution. All convolutions
        in the block use the same filter count.
    num_resnet_layers : int, (default 2)
        Number of normalization-activation-convolution sequences in the
        block. Must be at least 1.
    **kwargs
        Forwarded unchanged to the Keras ``Layer`` constructor
        (e.g. ``name``, ``dtype``, ``trainable``).

    Raises
    ------
    ValueError
        If ``num_resnet_layers < 1``.

    Note
    ----
    Layer normalization is applied over spatial and channel dimensions
    (axes -1, -2, -3) rather than batch normalization. This provides more
    stable training with small batch sizes and varying SNR conditions.

    The 3x3 kernel with 'same' padding preserves spatial dimensions,
    allowing the skip connection to work without dimension adjustment.
    """

    def __init__(
        self,
        num_conv2d_filters: int = 128,
        num_resnet_layers: int = 2,
        **kwargs,
    ):
        """
        Initialize residual block layers.

        Parameters
        ----------
        num_conv2d_filters : int, (default 128)
            Output channels per convolution.
        num_resnet_layers : int, (default 2)
            Depth of the residual block (number of conv layers).
        **kwargs
            Extra keyword arguments passed through to ``Layer.__init__``.

        Raises
        ------
        ValueError
            If ``num_resnet_layers`` (after integer coercion) is below 1.

        Post-conditions
        ---------------
        - ``_layer_norms`` contains ``num_resnet_layers`` LayerNorm instances.
        - ``_convs`` contains ``num_resnet_layers`` Conv2D instances.
        - All convolutions use 3x3 kernels with 'same' padding.
        """
        super().__init__(**kwargs)

        # Coerce first so the validation below sees exactly the value that
        # is stored and later used to size the layer lists.
        self.num_conv2d_filters = int(num_conv2d_filters)
        self.num_resnet_layers = int(num_resnet_layers)
        if self.num_resnet_layers < 1:
            raise ValueError("num_resnet_layers must be >= 1")

        # Pre-activation design: LayerNorm -> ReLU -> Conv for each layer
        self._layer_norms = [
            LayerNormalization(axis=(-1, -2, -3))
            for _ in range(self.num_resnet_layers)
        ]
        self._convs = [
            Conv2D(
                filters=self.num_conv2d_filters,
                kernel_size=(3, 3),
                padding="same",
                activation=None,
            )
            for _ in range(self.num_resnet_layers)
        ]

    def get_config(self):
        """
        Return the constructor arguments needed to re-create this block.

        Returns
        -------
        dict
            The base ``Layer`` config extended with ``num_conv2d_filters``
            and ``num_resnet_layers``.
        """
        config = super().get_config()
        config.update(
            {
                "num_conv2d_filters": self.num_conv2d_filters,
                "num_resnet_layers": self.num_resnet_layers,
            }
        )
        return config

    # [resblock-call-start]
    def call(self, inputs):
        """
        Apply residual transformation to input tensor.

        Parameters
        ----------
        inputs : tf.Tensor, float32, [batch, height, width, channels]
            Input feature maps. Channel dimension must match
            ``num_conv2d_filters`` for the skip connection to work.

        Returns
        -------
        tf.Tensor, float32, [batch, height, width, channels]
            Output feature maps with same shape as input.

        Pre-conditions
        --------------
        - Input must be float32 (assertion checks this for debugging).
        - Input channels should equal ``num_conv2d_filters``.

        Post-conditions
        ---------------
        - Output shape equals input shape.
        - Output = transform(input) + input (residual connection).

        Invariants
        ----------
        - Spatial dimensions are preserved (3x3 conv with 'same' padding).
        """
        z = inputs

        for ln, conv in zip(self._layer_norms, self._convs):
            # Debug assertion: catch dtype issues early in development.
            # Checked before every layer, so intermediate activations are
            # verified as well as the block input.
            tf.debugging.assert_type(z, tf.float32)
            z = ln(z)
            z = relu(z)
            z = conv(z)

        # Skip connection: enables gradient flow and residual learning
        return z + inputs

    # [resblock-call-end]
class NeuralRx(Layer):
    """
    Convolutional neural receiver mapping received signals to LLRs.

    This network replaces the traditional channel estimation, equalization,
    and demapping stages with a learned CNN that directly produces
    log-likelihood ratios for each coded bit. The architecture processes the
    received signal across a time-frequency dimensional grid.

    Architecture:

    1. **Input preparation**: Concatenate [Re(y), Im(y), log10(no)]
    2. **Input convolution**: Expand to ``num_conv2d_filters`` channels
    3. **Residual stack**: ``num_res_blocks`` residual blocks
    4. **Output convolution**: Reduce to ``num_streams x bits_per_symbol``
    5. **Reshape**: Reorganize to per-stream, per-bit LLR format
    6. **Resource grid demapper**: Extract data symbol positions
    7. **LDPC decoder** (optional): Decode to information bits

    Parameters
    ----------
    cfg : ~demos.mimo_ofdm_neural_receiver.src.config.Config
        Configuration containing resource grid, modulation, and code params.
    channel_coding_off : bool, (default False)
        If True, skip LDPC decoding and return raw LLRs. Used during
        training to compute BCE loss against transmitted coded bits.
    num_conv2d_filters : int, (default 256)
        Channel dimension throughout the residual stack.
    num_resnet_layers : int, (default 2)
        Number of conv layers per residual block.
    num_res_blocks : int, (default 4)
        Number of residual blocks in the network.
    **kwargs
        Forwarded unchanged to the Keras ``Layer`` constructor
        (e.g. ``name``, ``dtype``, ``trainable``).

    Attributes
    ----------
    _cfg : ~demos.mimo_ofdm_neural_receiver.src.config.Config
        Reference to configuration object.
    _channel_coding_off : bool
        Whether to skip LDPC decoding.
    num_conv2d_filters : int
        Filter count used by the input convolution and residual blocks.
    num_resnet_layers : int
        Conv layers per residual block.
    num_res_blocks : int
        Number of residual blocks in the stack.

    Note
    ----
    The noise power is fed in log10 scale because:

    1. SNR varies over orders of magnitude during training
    2. Log scale provides more uniform gradient behavior
    3. Empirically improves convergence and final performance

    Example
    -------
    >>> cfg = Config(num_bits_per_symbol=BitsPerSym.QPSK)
    >>> neural_rx = NeuralRx(cfg, channel_coding_off=True)
    >>> out = neural_rx(y, no, batch_size)
    >>> llrs = out["llr"]  # Shape: [batch, 1, num_streams, n]
    """

    def __init__(
        self,
        cfg: Config,
        channel_coding_off: bool = False,
        num_conv2d_filters: int = 256,
        num_resnet_layers: int = 2,
        num_res_blocks: int = 4,
        **kwargs,
    ):
        """
        Initialize neural receiver architecture.

        Parameters
        ----------
        cfg : ~demos.mimo_ofdm_neural_receiver.src.config.Config
            Configuration specifying resource grid and modulation.
        channel_coding_off : bool, (default False)
            Skip LDPC decoding if True (training mode).
        num_conv2d_filters : int, (default 256)
            Width of the residual network.
        num_resnet_layers : int, (default 2)
            Depth of each residual block.
        num_res_blocks : int, (default 4)
            Number of residual blocks.
        **kwargs
            Extra keyword arguments passed through to ``Layer.__init__``.
        """
        super().__init__(**kwargs)

        self._cfg = cfg
        self._channel_coding_off = bool(channel_coding_off)
        self.num_conv2d_filters = int(num_conv2d_filters)
        self.num_resnet_layers = int(num_resnet_layers)
        self.num_res_blocks = int(num_res_blocks)

        # [neural_rx-definition-start]
        # Input conv: expand from (2*num_rx_ant + 1) to num_conv2d_filters channels
        self._input_conv = Conv2D(
            filters=self.num_conv2d_filters,
            kernel_size=(3, 3),
            padding="same",
            activation=None,
        )

        # Residual stack for feature extraction
        self._res_blocks = [
            ResidualBlock(
                num_conv2d_filters=self.num_conv2d_filters,
                num_resnet_layers=self.num_resnet_layers,
            )
            for _ in range(self.num_res_blocks)
        ]

        # Output conv: contract to (num_streams x bits_per_symbol) for LLR output
        # Each output channel corresponds to one bit position in the constellation
        self._output_conv = Conv2D(
            filters=int(
                self._cfg.rg.num_streams_per_tx * self._cfg.num_bits_per_symbol
            ),
            kernel_size=(3, 3),
            padding="same",
            activation=None,
        )
        # [neural_rx-definition-end]

        # Resource grid demapper extracts LLRs at data symbol positions
        self._rg_demapper = sn.phy.ofdm.ResourceGridDemapper(
            self._cfg.rg, self._cfg.sm
        )

        # LDPC decoder (used only during inference when channel_coding_off=False).
        # It is constructed unconditionally, so the attribute always exists.
        self._decoder = LDPC5GDecoder(
            LDPC5GEncoder(self._cfg.k, self._cfg.n), hard_out=True
        )

    def call(self, y: tf.Tensor, no: tf.Tensor, batch_size: tf.Tensor) -> dict:
        """
        Process received signal to produce LLRs and optionally decoded bits.

        Parameters
        ----------
        y : tf.Tensor, complex64, [batch, num_rx, num_rx_ant, num_ofdm_symbols, fft_size]
            Received OFDM signal after channel and noise.
        no : tf.Tensor, float32, [batch] or scalar
            Noise power spectral density.
        batch_size : tf.Tensor, int32, scalar
            Batch dimension size (needed for reshape operations in graph mode).

        Returns
        -------
        Dict[str, tf.Tensor]
            Dictionary containing:

            - ``"llr"``: Predicted log-likelihood ratios,
              shape [batch, 1, num_ut_ant, n].
            - ``"b_hat"``: Decoded information bits,
              shape [batch, 1, num_ut_ant, k].
              None if ``channel_coding_off=True``.

        Note
        ----
        The tensor transformations in this method follow a specific sequence:

        1. Remove num_rx dimension (assuming single receiver)
        2. Transpose to [batch, ofdm_symbols, subcarriers, antennas]
        3. Split complex to real channels: 2xnum_rx_ant + 1 (noise) channels
        4. Process through CNN
        5. Reshape output to match ResourceGridDemapper expectations
        6. Extract data positions and reshape for decoder input
        """
        # =====================================================================
        # Input Preparation
        # =====================================================================
        # Remove num_rx dimension (single receiver assumed in this demo)
        y = tf.squeeze(y, axis=1)

        # Convert noise power to log scale for better neural network conditioning
        # SNR ranges span orders of magnitude; log scale normalizes the input
        no = sn.phy.utils.log10(no)

        # Transpose to image format: [batch, OFDM_symbols, subcarriers, antennas]
        # CNN expects spatial dims in the middle, channels last
        y = tf.transpose(y, [0, 2, 3, 1])

        # Dynamic dimensions of the image-format input, looked up once
        y_shape = tf.shape(y)
        num_batch = y_shape[0]
        num_symbols = y_shape[1]
        num_subcarriers = y_shape[2]

        # Broadcast noise power to spatial dimensions for concatenation
        # Shape: [batch] -> [batch, 1, 1, 1] -> [batch, H, W, 1]
        no = tf.reshape(no, [-1])
        # Adding zeros of length `num_batch` expands a scalar (length-1) noise
        # value to one entry per batch element.
        no = no + tf.zeros([num_batch], dtype=no.dtype)
        no = tf.reshape(no, [num_batch, 1, 1, 1])
        # Broadcast to match spatial dimensions of y
        no = tf.broadcast_to(no, [num_batch, num_symbols, num_subcarriers, 1])

        # Concatenate real, imaginary, and noise channels
        # Input channels: 2 * num_rx_ant (real + imag) + 1 (noise) = 17 for 8 antennas
        z = tf.concat([tf.math.real(y), tf.math.imag(y), no], axis=-1)

        # [neural_rx-call-start]
        # =====================================================================
        # Neural Network Forward Pass
        # =====================================================================
        # Input convolution: expand channel dimension
        z = self._input_conv(z)

        # Residual stack: extract hierarchical features
        for block in self._res_blocks:
            z = block(z)

        # Output convolution: produce per-bit LLR predictions
        z = self._output_conv(z)
        # [neural_rx-call-end]

        # =====================================================================
        # Output Reshaping for Decoder Compatibility
        # =====================================================================
        # Reshape from [batch, H, W, streams*bits] to [batch, H, W, streams, bits]
        z_shape = tf.shape(z)
        z = tf.reshape(
            z,
            [
                z_shape[0],
                z_shape[1],
                z_shape[2],
                self._cfg.rg.num_streams_per_tx,
                self._cfg.num_bits_per_symbol,
            ],
        )

        # Transpose to ResourceGridDemapper expected format
        # From [batch, ofdm, subcarrier, stream, bits]
        # to   [batch, stream, ofdm, subcarrier, bits]
        z = tf.transpose(z, [0, 3, 1, 2, 4])

        # Add num_tx dimension (required by ResourceGridDemapper)
        z = tf.expand_dims(z, axis=1)

        # Extract LLRs at data symbol positions (removes pilots)
        llr = self._rg_demapper(z)

        # Reshape to decoder input format: [batch, 1, num_ut_ant, n]
        # NOTE(review): the CNN output is sized with rg.num_streams_per_tx but
        # this reshape uses cfg.num_ut_ant; presumably the two are equal in
        # Config - confirm, otherwise the reshape fails on element count.
        llr = tf.reshape(llr, [batch_size, 1, self._cfg.num_ut_ant, self._cfg.n])

        # =====================================================================
        # Optional LDPC Decoding
        # =====================================================================
        b_hat = None
        if not self._channel_coding_off:
            # Decode LLRs to hard bit decisions
            b_hat = self._decoder(llr)

        return {"llr": llr, "b_hat": b_hat}