MIMO-OFDM Neural Receiver

Overview

This demo implements a neural network-based receiver for MIMO-OFDM wireless communication systems using Sionna. It is a simple extension of Sionna's Neural Receiver for OFDM SIMO Systems tutorial. That tutorial considers a Single-Input Multiple-Output (SIMO) system, and this demo extends it to a Multiple-Input Multiple-Output (MIMO)-OFDM system with 4 transmit antennas and 8 receive antennas.

The neural receiver replaces the traditional receiver chain, namely channel estimation, equalization, and demapping, with a learned convolutional neural network that directly maps received signals and noise power estimates to Log-Likelihood Ratios (LLRs) for channel decoding. This end-to-end approach allows the network to jointly optimize these operations, potentially achieving better performance than the baseline receiver, particularly under imperfect channel state information (CSI).
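
To interpret the network's output, you need to know the LLR sign convention. The sketch below assumes Sionna's definition LLR = ln(P(b=1|y) / P(b=0|y)), so positive values favor bit 1. It shows how an LLR maps to a bit probability and to a hard decision. The helper names are hypothetical.

```python
import numpy as np

def llr_to_prob_one(llr):
    """P(b=1 | y) for an LLR defined as ln(P(b=1|y) / P(b=0|y)) (assumed convention)."""
    return 1.0 / (1.0 + np.exp(-np.asarray(llr, dtype=float)))

def hard_decision(llr):
    """Decide bit 1 when the LLR is positive, otherwise bit 0."""
    return (np.asarray(llr) > 0).astype(np.int8)

llrs = np.array([-4.0, 0.0, 2.5])
print(llr_to_prob_one(llrs))  # approximately [0.018, 0.5, 0.924]
print(hard_decision(llrs))    # [0 0 1]
```

An LLR of 0 carries no information (probability 0.5). The magnitude of an LLR expresses how confident the receiver is in its decision.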

System Architecture

The end-to-end system is implemented by the System class, whose architecture is depicted in the diagram below.

MIMO-OFDM Neural Receiver System

System uses the configuration defined in Config, which enforces a single, validated system configuration. The following snippet from Config lists the hard-coded parameters:

    # =========================================================================
    # Hard-coded PHY/system parameters
    # These define the physical layer configuration and should not be modified
    # without re-validating the entire simulation chain.
    # =========================================================================

    # Uplink direction: UTs transmit to BS (affects antenna role assignment)
    _direction: str = field(init=False, default="uplink", repr=False)

    # OFDM numerology: 15 kHz SCS is standard for sub-6 GHz 5G NR
    _subcarrier_spacing: float = field(init=False, default=15e3, repr=False)

    # FFT size chosen to balance frequency resolution vs. complexity
    _fft_size: int = field(init=False, default=76, repr=False)

    # 14 OFDM symbols per slot (normal cyclic prefix, 5G NR standard)
    _num_ofdm_symbols: int = field(init=False, default=14, repr=False)

    # CP must cover the channel's maximum excess delay to prevent ISI
    _cyclic_prefix_length: int = field(init=False, default=6, repr=False)

    # Guard carriers prevent aliasing at band edges (asymmetric for DC null)
    _num_guard_carriers: Tuple[int, int] = field(init=False, default=(5, 6), repr=False)

    # DC subcarrier nulled to avoid LO leakage issues in practical systems
    _dc_null: bool = field(init=False, default=True, repr=False)

    # Kronecker pilot pattern: orthogonal in time-frequency, good for MIMO
    _pilot_pattern: str = field(init=False, default="kronecker", repr=False)

    # Pilot positions distributed across slot for channel tracking
    _pilot_ofdm_symbol_indices: Tuple[int, ...] = field(
        init=False, default=(2, 5, 8, 11), repr=False
    )

    # 4x8 MIMO: 4 UT antennas (Tx), 8 BS antennas (Rx) for uplink
    # More receive (BS) than transmit (UT) antennas gives the BS receive diversity
    _num_ut_ant: int = field(init=False, default=4, repr=False)
    _num_bs_ant: int = field(init=False, default=8, repr=False)

    # QAM modulation (constellation type, not order)
    _modulation: str = field(init=False, default="qam", repr=False)

    # Default bits per symbol (overridden by user-settable num_bits_per_symbol)
    _num_bits_per_symbol: int = field(init=False, default=2, repr=False)  # QPSK

    # Rate-1/2 LDPC provides good balance of coding gain vs. complexity
    _coderate: float = field(init=False, default=0.5, repr=False)

    # Fixed seed for reproducible channel/noise realizations
    _seed: int = field(init=False, default=42, repr=False)
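
These hard-coded values determine how many resource elements carry data. The calculation below is a back-of-the-envelope sketch. It assumes Sionna's ResourceGrid conventions: the number of effective subcarriers is the FFT size minus the guard carriers minus the DC null, and Kronecker pilots occupy entire OFDM symbols. It also assumes one stream per UT antenna (4 streams), which Config does not state explicitly.

```python
# Values copied from Config; NUM_STREAMS is an assumption (one stream per UT antenna).
FFT_SIZE = 76
SUBCARRIER_SPACING = 15e3                 # Hz
NUM_OFDM_SYMBOLS = 14
CYCLIC_PREFIX_LENGTH = 6                  # samples
NUM_GUARD_CARRIERS = (5, 6)
DC_NULL = True
PILOT_OFDM_SYMBOL_INDICES = (2, 5, 8, 11)
NUM_BITS_PER_SYMBOL = 2                   # QPSK
CODERATE = 0.5
NUM_STREAMS = 4                           # assumption

# Subcarriers left after removing both guard bands and the DC null
num_effective_sc = FFT_SIZE - sum(NUM_GUARD_CARRIERS) - int(DC_NULL)
# Pilot OFDM symbols are fully occupied by pilots; the remaining symbols carry data
num_data_symbols = NUM_OFDM_SYMBOLS - len(PILOT_OFDM_SYMBOL_INDICES)
num_data_re = num_effective_sc * num_data_symbols      # per stream and slot

coded_bits = num_data_re * NUM_BITS_PER_SYMBOL         # per stream and slot
info_bits = int(coded_bits * CODERATE)

sample_rate = FFT_SIZE * SUBCARRIER_SPACING            # Hz
cp_duration = CYCLIC_PREFIX_LENGTH / sample_rate       # seconds

print(num_effective_sc, num_data_symbols, num_data_re)                   # 64 10 640
print(coded_bits, info_bits)                                             # 1280 640
print(f"{sample_rate / 1e6:.2f} MHz, CP = {cp_duration * 1e6:.2f} us")   # 1.14 MHz, CP = 5.26 us
print(NUM_STREAMS * coded_bits)                                          # 5120
```

The pilot symbols cost 4 of the 14 OFDM symbols, about 29% of the slot.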

The Tx class generates random information bits and applies LDPC encoding. It then maps the coded bits to QPSK symbols (the default modulation) and places them onto an OFDM resource grid with 14 OFDM symbols per slot, a 15 kHz subcarrier spacing, and an FFT size of 76.

The CSI class generates the frequency-domain CDL channel for the 4x8 MIMO link with the desired delay spread, carrier frequency, mobility, and propagation scenario. It proceeds in three steps:
1. It configures dual cross-polarized transmit and receive antenna arrays according to 3GPP TR 38.901 [1].
2. It generates delay-domain channel impulse responses (CIRs).
3. It converts the CIRs into normalized frequency-domain channel coefficients consistent with the resource grid.

The Rx class implements the baseline receiver. It first estimates the channel and uses the estimates to equalize the received signal. It then demaps the equalized symbols to LLRs, which are finally decoded by an LDPC decoder.
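
The baseline's equalization step can be illustrated on a single resource element of the 4x8 link. The sketch below uses NumPy and assumes unit-energy, uncorrelated transmit symbols and white noise of variance no. The function name and the random channel are hypothetical; this is not Sionna's LMMSEEqualizer API.

```python
import numpy as np

def lmmse_equalize(y, h, no):
    """LMMSE estimate of x from y = h @ x + n on one resource element.

    y: (num_rx,) received vector, h: (num_rx, num_tx) channel matrix,
    no: noise variance. Assumes E[x x^H] = I (unit-energy symbols).
    """
    num_tx = h.shape[1]
    hh = h.conj().T
    # W = (H^H H + no I)^{-1} H^H, computed with a linear solve instead of an explicit inverse
    w = np.linalg.solve(hh @ h + no * np.eye(num_tx), hh)
    return w @ y

rng = np.random.default_rng(0)
num_rx, num_tx = 8, 4
h = (rng.standard_normal((num_rx, num_tx)) + 1j * rng.standard_normal((num_rx, num_tx))) / np.sqrt(2)
# One QPSK symbol per transmit antenna, unit average energy
x = (rng.choice([-1, 1], num_tx) + 1j * rng.choice([-1, 1], num_tx)) / np.sqrt(2)
no = 0.01
n = np.sqrt(no / 2) * (rng.standard_normal(num_rx) + 1j * rng.standard_normal(num_rx))
y = h @ x + n

x_hat = lmmse_equalize(y, h, no)
print(np.round(x_hat, 2))
```

As no approaches 0, this estimator reduces to zero-forcing. With imperfect CSI, the estimated channel replaces h, and the estimation errors propagate into x_hat.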

Neural Receiver

The neural receiver architecture (see diagram below), implemented in NeuralRx, is a convolutional network (as in [2]). It processes the received resource grid as a 2D image whose spatial dimensions correspond to OFDM symbols (time) and subcarriers (frequency). Its input channels are the real and imaginary parts of the signal at each receive antenna plus the noise power estimate, giving 2 × 8 + 1 = 17 input channels.

Neural Receiver Architecture
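
The input layout described above can be sketched in NumPy. The sketch assumes the received grid has shape [batch, num_rx_ant, num_ofdm_symbols, fft_size] and that there is one noise-power value per example. The function name is hypothetical. Some implementations feed the noise power in log scale rather than linearly, as done here.

```python
import numpy as np

def build_input_features(y, no):
    """Stack Re/Im of every receive antenna plus a constant noise-power plane.

    y:  complex array [batch, num_rx_ant, num_ofdm_symbols, fft_size]
    no: real array [batch], noise power per example
    Returns a real array [batch, num_ofdm_symbols, fft_size, 2*num_rx_ant + 1]
    in the channels-last layout expected by Keras Conv2D.
    """
    y = np.transpose(y, (0, 2, 3, 1))   # move the antenna axis to the end
    batch, n_sym, n_sc, _ = y.shape
    # Repeat each example's noise power over the whole time-frequency grid
    no_plane = np.broadcast_to(no[:, None, None, None], (batch, n_sym, n_sc, 1))
    return np.concatenate([y.real, y.imag, no_plane], axis=-1)

y = np.zeros((2, 8, 14, 76), dtype=np.complex64)
z = build_input_features(y, np.array([0.1, 0.2], dtype=np.float32))
print(z.shape)  # (2, 14, 76, 17)
```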

The architecture consists of an input convolutional layer that projects the concatenated input features to 256 feature channels, followed by 4 residual blocks (ResidualBlock). Each residual block stacks two units of layer normalization, activation function, and convolutional layer, and adds a skip connection to mitigate vanishing gradients. The output convolutional layer produces LLRs for every resource element, stream, and bit position. These LLRs are reshaped and passed through a resource-grid demapper, which extracts the LLRs of the data-carrying resource elements, before LDPC decoding.
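
To make the block structure concrete, here is a NumPy forward pass of one residual block following the description above. Each of the two units applies layer normalization, ReLU, and a 3x3 "same" convolution, and a skip connection adds the block input to the result. This is an illustrative sketch, not the ResidualBlock class. ReLU as the activation, normalization over the (time, frequency, channel) axes without learned scale/offset, and all helper names are assumptions.

```python
import numpy as np

def conv2d_same(x, w, b):
    """3x3 'same' convolution (cross-correlation, as in Keras).

    x: [H, W, Cin], w: [3, 3, Cin, Cout], b: [Cout].
    """
    kh, kw = w.shape[:2]
    xp = np.pad(x, ((kh // 2, kh // 2), (kw // 2, kw // 2), (0, 0)))
    out = np.zeros(x.shape[:2] + (w.shape[3],))
    for i in range(kh):
        for j in range(kw):
            patch = xp[i:i + x.shape[0], j:j + x.shape[1], :]
            out += np.einsum("hwc,cd->hwd", patch, w[i, j])
    return out + b

def layer_norm(x, eps=1e-6):
    """Normalize over all of (time, frequency, channel); no learned parameters."""
    return (x - x.mean()) / np.sqrt(x.var() + eps)

def residual_block(x, params):
    """Two (norm -> ReLU -> conv) units followed by a skip connection."""
    z = x
    for w, b in params:
        z = layer_norm(z)
        z = np.maximum(z, 0.0)    # ReLU (assumed activation)
        z = conv2d_same(z, w, b)
    return z + x                  # skip connection: the input passes through unchanged

rng = np.random.default_rng(0)
c = 8                             # small channel count for the demo (256 in NeuralRx)
x = rng.standard_normal((14, 76, c))
params = [(0.1 * rng.standard_normal((3, 3, c, c)), np.zeros(c)) for _ in range(2)]
print(residual_block(x, params).shape)  # (14, 76, 8)
```

With all-zero weights, the block returns its input exactly. This identity path is what lets gradients bypass the convolutions.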

The input/output layers and the residual blocks are initialized as shown below:

        # Input conv: expand from (2*num_rx_ant + 1) to num_conv2d_filters channels
        self._input_conv = Conv2D(
            filters=self.num_conv2d_filters,
            kernel_size=(3, 3),
            padding="same",
            activation=None,
        )

        # Residual stack for feature extraction
        self._res_blocks = [
            ResidualBlock(
                num_conv2d_filters=self.num_conv2d_filters,
                num_resnet_layers=self.num_resnet_layers,
            )
            for _ in range(self.num_res_blocks)
        ]

        # Output conv: contract to (num_streams x bits_per_symbol) for LLR output
        # Each output channel corresponds to one bit position in the constellation
        self._output_conv = Conv2D(
            filters=int(
                self._cfg.rg.num_streams_per_tx * self._cfg.num_bits_per_symbol
            ),
            kernel_size=(3, 3),
            padding="same",
            activation=None,
        )

The forward pass then processes the features as shown below:

        # =====================================================================
        # Neural Network Forward Pass
        # =====================================================================
        # Input convolution: expand channel dimension
        z = self._input_conv(z)

        # Residual stack: extract hierarchical features
        for block in self._res_blocks:
            z = block(z)

        # Output convolution: produce per-bit LLR predictions
        z = self._output_conv(z)
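
The layer stack above determines both the size of the network and how far each output can see across the resource grid. The tally below assumes 3x3 kernels with biases in every Conv2D, two convolutions per residual block, 8 receive antennas, and 4 streams of QPSK (8 output channels). The stream count is an assumption, and layer-normalization parameters are not counted.

```python
def conv2d_params(kernel, c_in, c_out):
    """Weights plus biases of a Conv2D layer with a square kernel."""
    return kernel * kernel * c_in * c_out + c_out

def receptive_field(num_convs, kernel=3):
    """Per-axis receptive field of a stack of stride-1 convolutions."""
    return 1 + num_convs * (kernel - 1)

NUM_RX_ANT, FILTERS, NUM_RES_BLOCKS = 8, 256, 4
CONVS_PER_BLOCK = 2                     # assumption, per the description above
NUM_STREAMS, BITS_PER_SYMBOL = 4, 2     # NUM_STREAMS is an assumption

input_conv = conv2d_params(3, 2 * NUM_RX_ANT + 1, FILTERS)
res_convs = NUM_RES_BLOCKS * CONVS_PER_BLOCK * conv2d_params(3, FILTERS, FILTERS)
output_conv = conv2d_params(3, FILTERS, NUM_STREAMS * BITS_PER_SYMBOL)
print(input_conv, res_convs, output_conv)    # 39424 4720640 18440
print(input_conv + res_convs + output_conv)  # 4778504

# Input conv + residual convs + output conv; skip connections do not widen the view
num_convs = 1 + NUM_RES_BLOCKS * CONVS_PER_BLOCK + 1
rf = receptive_field(num_convs)
print(num_convs, rf)                         # 10 21

# Pilot OFDM symbols (from Config) within +/- rf//2 symbols of each position
PILOTS, NUM_OFDM_SYMBOLS = (2, 5, 8, 11), 14
min_pilots_seen = min(
    sum(abs(p - s) <= rf // 2 for p in PILOTS) for s in range(NUM_OFDM_SYMBOLS)
)
print(min_pilots_seen)                       # 3
```

Under these assumptions, the residual convolutions hold roughly 99% of the convolution weights. Each output LLR depends on a 21 x 21 neighborhood of resource elements (clipped at the grid edges). From any OFDM symbol in the slot, that neighborhood reaches at least three of the four pilot symbols.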

Training

Training minimizes the binary cross-entropy (BCE) loss between the neural receiver's soft LLR outputs and the transmitted coded bits. Because the loss is computed on the coded bits, LDPC encoding and decoding are bypassed during training, and gradients flow directly from the loss into the receiver. The training loop draws Eb/N0 uniformly from -3 dB to 7 dB independently for each sample in a batch. As a result, every batch spans the full training SNR range.
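
If LLRs are treated as logits for bit 1 (LLR = ln(P(b=1)/P(b=0))), the per-bit BCE can be written directly in terms of the LLR as softplus(LLR) - b * LLR. The NumPy sketch below implements that form. It also shows per-sample Eb/N0 sampling and a simplified conversion N0 = 1 / (Eb/N0 * coderate * bits_per_symbol), which ignores pilot and cyclic-prefix overhead. Sionna's ebnodb2no utility can additionally account for that overhead through a resource grid. The helper names here are hypothetical.

```python
import numpy as np

def bce_from_llr(llr, bits):
    """Mean BCE, using LLR = ln(P(b=1)/P(b=0)) as the logit for bit 1.

    Per bit: -[b log sigmoid(l) + (1 - b) log(1 - sigmoid(l))] = softplus(l) - b * l.
    """
    llr = np.asarray(llr, dtype=float)
    bits = np.asarray(bits, dtype=float)
    # np.logaddexp(0, l) computes softplus(l) without overflow
    return np.mean(np.logaddexp(0.0, llr) - bits * llr)

def ebno_db_to_no(ebno_db, num_bits_per_symbol, coderate):
    """Noise variance for unit-energy symbols; ignores pilot/CP overhead."""
    ebno = 10.0 ** (np.asarray(ebno_db) / 10.0)
    return 1.0 / (ebno * coderate * num_bits_per_symbol)

rng = np.random.default_rng(0)
BATCH_SIZE, EBN0_DB_MIN, EBN0_DB_MAX = 32, -3.0, 7.0
ebno_db = rng.uniform(EBN0_DB_MIN, EBN0_DB_MAX, size=BATCH_SIZE)  # one value per sample
no = ebno_db_to_no(ebno_db, num_bits_per_symbol=2, coderate=0.5)

print(bce_from_llr([0.0], [1]))           # about 0.6931 (ln 2): an uninformative LLR
print(bce_from_llr([8.0, -8.0], [1, 0]))  # about 0.000335: confident and correct
```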

Gradient accumulation over 4 mini-batches of size 32 yields an effective batch size of 128, balancing memory constraints with gradient stability. The Adam optimizer with default learning rate is used. Checkpointing supports training resumption via the --iterations and --fresh command-line arguments. The core of the training logic is shown below:

# =============================================================================
# Training Step Function
# Returns loss and gradients; accumulation handled in main loop
# =============================================================================
@tf.function(
    reduce_retracing=True,
    input_signature=[
        tf.TensorSpec([], tf.int32),
        tf.TensorSpec([None], tf.float32),
    ],
)
def train_step(batch_size, ebno_vec):
    """
    Execute single forward/backward pass.

    Parameters
    ----------
    batch_size : tf.Tensor, int32, scalar
        Number of samples in batch.

    ebno_vec : tf.Tensor, float32, [batch_size]
        Per-sample Eb/N0 in dB.

    Returns
    -------
    loss : tf.Tensor, float32, scalar
        BCE loss for this batch.

    grads : List[tf.Tensor]
        Gradients for each trainable variable. None gradients are
        replaced with zeros to avoid issues in accumulation.

    Note
    ----
    The @tf.function decorator with input_signature ensures the function
    is traced once and reused, avoiding retracing overhead each iteration.
    """
    with tf.GradientTape() as tape:
        loss = system(batch_size, ebno_vec)
    grads = tape.gradient(loss, system.trainable_variables)
    # Replace None gradients with zeros (occurs for unused variables)
    grads = [
        g if g is not None else tf.zeros_like(w)
        for g, w in zip(grads, system.trainable_variables)
    ]
    return loss, grads


# =============================================================================
# Validation: Accumulation Alignment
# Start/target must align with accumulation steps for correct averaging
# =============================================================================
if start_iteration % ACCUMULATION_STEPS != 0:
    raise ValueError("start_iteration must be a multiple of ACCUMULATION_STEPS")

if target_iteration % ACCUMULATION_STEPS != 0:
    raise ValueError("target_iteration must be a multiple of ACCUMULATION_STEPS")

# =============================================================================
# Training Loop
# Gradient accumulation: sum gradients over ACCUMULATION_STEPS, then apply
# =============================================================================
accumulated_grads = None
for i in range(start_iteration, target_iteration):
    # Sample random Eb/N0 for each batch element
    # Training across SNR range improves generalization
    ebno_db = rng.uniform([BATCH_SIZE], EBN0_DB_MIN, EBN0_DB_MAX, tf.float32)
    loss, grads = train_step(tf.constant(BATCH_SIZE, tf.int32), ebno_db)

    # Accumulate gradients
    if accumulated_grads is None:
        accumulated_grads = [tf.Variable(g, trainable=False) for g in grads]
    else:
        for acc_g, g in zip(accumulated_grads, grads):
            acc_g.assign_add(g)

    # Apply accumulated gradients every ACCUMULATION_STEPS iterations
    if (i + 1) % ACCUMULATION_STEPS == 0:
        avg_grads = [g / ACCUMULATION_STEPS for g in accumulated_grads]
        optimizer.apply_gradients(zip(avg_grads, system.trainable_variables))
        accumulated_grads = None

    loss_value = float(loss.numpy())
    loss_history.append(loss_value)

    print(
        f"\rStep {i + 1}/{target_iteration}  Loss: {loss_value:.4f}",
        end="",
        flush=True,
    )
print("\n\nTraining complete.")
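
The accumulation scheme in this loop matches training with a single 128-sample batch under two conditions. The loss must be a per-sample mean, and no layer may mix statistics across the batch; layer normalization normalizes each example on its own, so it does not. The NumPy check below uses a hypothetical linear least-squares model, not the neural receiver itself.

```python
import numpy as np

def grad_mse(w, x, y):
    """Gradient of mean((x @ w - y)**2) with respect to w."""
    return 2.0 * x.T @ (x @ w - y) / len(y)

rng = np.random.default_rng(0)
ACCUMULATION_STEPS, BATCH_SIZE = 4, 32
x = rng.standard_normal((ACCUMULATION_STEPS * BATCH_SIZE, 5))
y = rng.standard_normal(ACCUMULATION_STEPS * BATCH_SIZE)
w = rng.standard_normal(5)

# Sum the mini-batch gradients, then divide by the number of steps (as in the loop above)
acc = np.zeros_like(w)
for k in range(ACCUMULATION_STEPS):
    sl = slice(k * BATCH_SIZE, (k + 1) * BATCH_SIZE)
    acc += grad_mse(w, x[sl], y[sl])
avg_grad = acc / ACCUMULATION_STEPS

full_grad = grad_mse(w, x, y)            # one batch of 128 samples
print(np.allclose(avg_grad, full_grad))  # True
```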

Results

Performance is evaluated in terms of Bit Error Rate (BER) and Block Error Rate (BLER) for Eb/N0 from -3 dB to 6 dB under the CDL-C channel model. The baseline receiver uses LMMSE equalization and is tested under two conditions: perfect CSI, and imperfect CSI, where the channel is estimated with LS channel estimation.
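
BER counts bit errors over all information bits. BLER counts a codeword as erroneous if any of its bits is wrong. The NumPy sketch below computes both, assuming bits are arranged as [num_codewords, bits_per_codeword]. The function name is hypothetical.

```python
import numpy as np

def ber_bler(bits, bits_hat):
    """Return (BER, BLER) for arrays shaped [num_codewords, bits_per_codeword]."""
    errors = np.asarray(bits) != np.asarray(bits_hat)
    ber = errors.mean()                 # fraction of wrong bits
    bler = errors.any(axis=-1).mean()   # fraction of codewords with at least one error
    return float(ber), float(bler)

bits = np.zeros((4, 10), dtype=np.int8)
bits_hat = bits.copy()
bits_hat[0, 3] = 1    # one bit error in codeword 0
bits_hat[2, :2] = 1   # two bit errors in codeword 2
print(ber_bler(bits, bits_hat))  # (0.075, 0.5)
```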

BER Comparison

BER comparison under the CDL-C channel: neural receiver vs. baseline with perfect/imperfect CSI.

BLER Comparison

BLER comparison under the CDL-C channel: neural receiver vs. baseline with perfect/imperfect CSI.

The neural receiver significantly outperforms the imperfect-CSI baseline, particularly at mid-to-high SNR, where channel estimation errors dominate the baseline's performance. At 6 dB Eb/N0, the neural receiver achieves approximately 4x lower BER than the imperfect-CSI baseline. However, it does not yet match the perfect-CSI baseline, which suggests room for improvement through architectural changes or increased model capacity.

References

[1] 3GPP TR 38.901, “Study on channel model for frequencies from 0.5 to 100 GHz,” Release 18.0.

[2] M. Honkala, D. Korpi and J. M. J. Huttunen, “DeepRx: Fully Convolutional Deep Learning Receiver,” in IEEE Transactions on Wireless Communications, vol. 20, no. 6, pp. 3925-3940, June 2021, doi: 10.1109/TWC.2021.3054520.