# Source code for demos.mimo_ofdm_neural_receiver.src.system

# SPDX-License-Identifier: MIT
# Copyright (c) 2025–present Srikanth Pagadarai

"""
End-to-end MIMO-OFDM system.

Provides the ``System`` class which orchestrates the complete
simulation pipeline by composing Tx, Channel, and Rx components:

    Tx -> Channel -> Rx (baseline or neural)

The system supports two operational modes:

1. **Training mode** (``training=True``):
   - Channel coding disabled in Tx (random coded bits)
   - Neural receiver outputs raw LLRs (no LDPC decoding)
   - Returns BCE loss for gradient-based optimization

2. **Inference mode** (``training=False``):
   - Full LDPC encoding in Tx
   - Complete decoding in Rx (baseline or neural)
   - Returns (transmitted_bits, decoded_bits) for BER/BLER computation

The system is implemented as a Keras Model, enabling:
- Automatic variable tracking for optimization
- Integration with TensorFlow training loops
- Checkpointing and weight serialization
"""

import tensorflow as tf
from sionna.phy.utils import ebnodb2no
from tensorflow.keras import Model

from .config import Config, BitsPerSym, CDLModel
from .csi import CSI
from .tx import Tx
from sionna.phy.channel import ApplyOFDMChannel
from .rx import Rx
from .neural_rx import NeuralRx


class System(Model):
    """
    End-to-end MIMO-OFDM system with baseline and neural receiver options.

    This Keras Model composes all simulation components and provides a
    unified interface for both training and inference. The system generates
    transmitted signals, applies channel effects, and processes received
    signals through either a conventional or neural receiver.

    The processing pipeline is:

    1. **CSI Generation**: Create frequency-domain channel response
    2. **Transmission**: Generate bits, encode, modulate, map to OFDM grid
    3. **Channel**: Apply frequency-domain channel and add AWGN
    4. **Reception**: Process received signal (baseline LMMSE or neural CNN)
    5. **Output**: Return loss (training) or bit tensors (inference)

    Parameters
    ----------
    training : bool, (default False)
        If True, configure for training mode:
        - Disable channel coding in Tx/Rx
        - Return BCE loss instead of bit tensors
    perfect_csi : bool, (default False)
        If True, baseline receiver uses ground-truth CSI.
        Only affects baseline Rx; neural Rx never uses explicit CSI.
    cdl_model : {"A", "B", "C", "D", "E"}, (default "D")
        3GPP CDL channel model variant.
    delay_spread : float, (default 300e-9)
        RMS delay spread in seconds.
    carrier_frequency : float, (default 2.6e9)
        Carrier frequency in Hz.
    speed : float, (default 0.0)
        UE speed in m/s for Doppler modeling.
    num_bits_per_symbol : BitsPerSym, (default BitsPerSym.QPSK)
        Modulation order.
    use_neural_rx : bool, (default False)
        If True, use neural receiver; otherwise use baseline LMMSE receiver.
    num_conv2d_filters : int, (default 256)
        Neural receiver CNN width.
    num_resnet_layers : int, (default 2)
        Layers per residual block in neural receiver.
    num_res_blocks : int, (default 4)
        Number of residual blocks in neural receiver.
    name : str, (default "system")
        Keras model name for variable scoping.

    Attributes
    ----------
    bce : tf.keras.losses.BinaryCrossentropy
        Loss function for training (expects logits, not probabilities).

    Note
    ----
    The system accepts Eb/N0 in dB and internally converts to noise power
    using ``ebnodb2no``. This allows consistent SNR specification across
    different modulation orders and code rates.

    Both ``__call__`` and ``call_scalar`` are provided:

    - ``__call__``: Takes vector Eb/N0 (one per batch sample)
    - ``call_scalar``: Takes scalar Eb/N0 (broadcast to all samples)

    The scalar variant is required for compatibility with Sionna's
    ``PlotBER.simulate()`` which passes scalar SNR values.

    Example
    -------
    >>> # Training
    >>> system = System(training=True, use_neural_rx=True)
    >>> loss = system(batch_size, ebno_db_vector)
    >>> # Inference
    >>> system = System(training=False, use_neural_rx=True)
    >>> b, b_hat = system(batch_size, ebno_db_vector)
    """

    def __init__(
        self,
        *,
        training: bool = False,
        perfect_csi: bool = False,
        cdl_model: CDLModel = "D",
        delay_spread: float = 300e-9,
        carrier_frequency: float = 2.6e9,
        speed: float = 0.0,
        num_bits_per_symbol: BitsPerSym = BitsPerSym.QPSK,
        use_neural_rx: bool = False,
        num_conv2d_filters: int = 256,
        num_resnet_layers: int = 2,
        num_res_blocks: int = 4,
        name: str = "system",
    ):
        """
        Initialize end-to-end system with all components.

        Parameters
        ----------
        training : bool
            Enable training mode (affects Tx encoding and Rx decoding).
        perfect_csi : bool
            Enable perfect CSI for baseline receiver.
        cdl_model : str
            CDL channel model variant.
        delay_spread : float
            Channel delay spread in seconds.
        carrier_frequency : float
            Carrier frequency in Hz.
        speed : float
            UE speed for Doppler.
        num_bits_per_symbol : BitsPerSym
            Modulation order.
        use_neural_rx : bool
            Select neural vs baseline receiver.
        num_conv2d_filters : int
            Neural receiver width.
        num_resnet_layers : int
            Neural receiver block depth.
        num_res_blocks : int
            Neural receiver depth.
        name : str
            Keras model name.
        """
        super().__init__(name=name)
        self._training = training

        # Neural receiver architecture parameters
        self._use_neural_rx = bool(use_neural_rx)
        self._num_conv2d_filters = num_conv2d_filters
        self._num_resnet_layers = num_resnet_layers
        self._num_res_blocks = num_res_blocks

        # Build configuration with PHY parameters
        self._cfg = Config(
            perfect_csi=perfect_csi,
            cdl_model=cdl_model,
            delay_spread=delay_spread,
            carrier_frequency=carrier_frequency,
            speed=speed,
            num_bits_per_symbol=num_bits_per_symbol,
        )

        # =====================================================================
        # Component Instantiation
        # All components share the same Config and CSI for consistency
        # =====================================================================
        self._csi = CSI(self._cfg)

        # Tx: channel coding disabled during training to allow gradient flow
        self._tx = Tx(self._cfg, self._training)

        self._ch = ApplyOFDMChannel(add_awgn=True)

        # Baseline Rx: conventional LMMSE receiver (not used during training)
        self._rx = Rx(self._cfg, self._csi)

        # Neural Rx: CNN-based receiver with configurable architecture
        # Channel decoding disabled during training (BCE loss on raw LLRs)
        self._neural_rx = NeuralRx(
            self._cfg,
            self._training,
            self._num_conv2d_filters,
            self._num_resnet_layers,
            self._num_res_blocks,
        )

        # BCE loss for training: from_logits=True because neural Rx outputs LLRs
        # LLRs are log-odds ratios, equivalent to logits for binary classification
        self.bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    @tf.function(
        reduce_retracing=True,
        input_signature=[
            tf.TensorSpec(shape=[], dtype=tf.int32),  # scalar batch size
            tf.TensorSpec(shape=[], dtype=tf.float32),  # scalar ebno_db
        ],
    )
    def call_scalar(self, batch_size, ebno_db_scalar):
        """
        Forward pass with scalar Eb/N0 (for PlotBER compatibility).

        This method broadcasts a single Eb/N0 value to all batch samples,
        providing compatibility with Sionna's ``PlotBER.simulate()`` which
        calls the model with scalar SNR values.

        Parameters
        ----------
        batch_size : tf.Tensor, int32, scalar
            Number of samples in the batch.
        ebno_db_scalar : tf.Tensor, float32, scalar
            Eb/N0 in dB, applied uniformly to all batch samples.

        Returns
        -------
        See ``__call__`` for return value documentation.

        Note
        ----
        This is a thin wrapper that expands the scalar to a vector and
        delegates to ``__call__``.
        """
        # Broadcast scalar Eb/N0 to vector of length batch_size
        ebno_vec = tf.fill([batch_size], ebno_db_scalar)
        return self.__call__(batch_size, ebno_vec)

    @tf.function(
        reduce_retracing=True,
        input_signature=[
            tf.TensorSpec(shape=[], dtype=tf.int32),  # batch_size (scalar)
            tf.TensorSpec(
                shape=[None], dtype=tf.float32
            ),  # ebno_db vector (len == batch_size)
        ],
    )
    def __call__(self, batch_size: tf.Tensor, ebno_db: tf.Tensor):
        """
        Forward pass through complete MIMO-OFDM system.

        Parameters
        ----------
        batch_size : tf.Tensor, int32, scalar
            Number of independent channel realizations to simulate.
        ebno_db : tf.Tensor, float32, [batch_size]
            Per-sample Eb/N0 in dB. Allows training across SNR range
            by varying Eb/N0 within each batch.

        Returns
        -------
        Training mode (``training=True`` and ``use_neural_rx=True``):
            loss : tf.Tensor, float32, scalar
                Binary cross-entropy loss between transmitted coded bits
                and predicted LLRs. Suitable for gradient descent.

        Inference mode (``training=False``):
            b : tf.Tensor, float32, [batch, 1, num_streams, k]
                Transmitted information bits (ground truth).
            b_hat : tf.Tensor, float32, [batch, 1, num_streams, k]
                Decoded information bits (receiver output).

        Pre-conditions
        --------------
        - ``len(ebno_db) == batch_size`` (vector length must match batch).
        - ``ebno_db`` values should be in reasonable range (e.g., -5 to 20 dB).

        Post-conditions
        ---------------
        - Training: Loss is scalar, suitable for optimizer.apply_gradients().
        - Inference: b and b_hat have matching shapes for BER computation.

        Note
        ----
        The Eb/N0 to noise power conversion accounts for:
        - Modulation order (bits per symbol)
        - Code rate
        - Resource grid structure (data vs pilot symbols)

        This ensures consistent SNR interpretation across configurations.

        Invariants
        ----------
        - Same channel realization (h_freq) used for transmission and reception.
        - Random bits are independent across batch samples.
        """
        # =====================================================================
        # Channel Generation
        # Generate new channel realization for this batch
        # =====================================================================
        h_freq = self._csi.build(batch_size)

        # Convert Eb/N0 (dB) to noise power spectral density
        # Accounts for modulation order, code rate, and resource grid efficiency
        no = ebnodb2no(
            ebno_db, self._cfg.num_bits_per_symbol, self._cfg.coderate, self._cfg.rg
        )

        # =====================================================================
        # Transmission and Channel
        # =====================================================================
        tx_out = self._tx(batch_size)
        y = self._ch(tx_out["x_rg"], h_freq, no)

        # =====================================================================
        # Reception
        # Select receiver and prepare arguments based on configuration
        # =====================================================================
        rx_to_use = self._neural_rx if self._use_neural_rx else self._rx

        # Neural Rx takes (y, no, batch_size); baseline Rx takes (y, h_freq, no)
        # This difference reflects that neural Rx doesn't use explicit CSI
        rx_args_to_pass = (
            (y, no, batch_size) if self._use_neural_rx else (y, h_freq, no)
        )
        rx_out = rx_to_use(*rx_args_to_pass)

        # =====================================================================
        # Output Selection
        # Training returns loss; inference returns bit tensors
        # =====================================================================
        if self._use_neural_rx and self._training:
            # BCE loss: compare predicted LLRs against transmitted coded bits
            # tx_out["c"] contains the ground-truth coded bits
            # rx_out["llr"] contains the predicted log-likelihood ratios
            loss = self.bce(tx_out["c"], rx_out["llr"])
            return loss

        # Inference: return transmitted and decoded information bits
        return tx_out["b"], rx_out["b_hat"]