pusch_autoencoder

This module implements an end-to-end autoencoder for the 5G NR Physical Uplink Shared Channel (PUSCH).

Configuration

class demos.pusch_autoencoder.src.config.Config(num_bs_ant=32)[source]

Bases: object

Central configuration for the PUSCH-Autoencoder demo.

This dataclass consolidates all hard-coded system parameters to ensure consistency across CIR generation, model instantiation, training, and inference. Parameters are intentionally non-configurable at instantiation (init=False) to enforce a single, validated system configuration. The number of BS antennas (num_bs_ant) is configurable to allow flexibility in array size.

The configuration defines a MU-MIMO uplink scenario with:

  • 4 UEs, each with 4 antennas (cross-polarized)

  • 1 BS with configurable antennas (default 32, cross-polarized)

  • 16-QAM modulation (MCS index 14, table 1)

  • Site-specific ray-traced channel (Munich scene)

Parameters:

num_bs_ant (int, optional) – Number of BS antennas (cross-polarized array). Must be even since the antenna array uses cross-polarization (num_cols = num_bs_ant // 2). Default is 32.

Notes

  • After __post_init__, num_bits_per_symbol and target_coderate are populated from the 3GPP NR MCS tables.

  • All properties return consistent, validated values.

  • System dimensions (antennas, UEs) remain fixed after instantiation.

  • PUSCH-related properties (resource_grid, pusch_pilot_indices, etc.) are set externally by PUSCHLinkE2E after transmitter construction.

Example

>>> cfg = Config()
>>> print(cfg.num_bits_per_symbol)  # 4 for 16-QAM
4
>>> print(cfg.num_ue, cfg.num_bs_ant)
4 32
>>> cfg_32ant = Config(num_bs_ant=32)
>>> print(cfg_32ant.num_bs_ant)
32

Notes

The MCS decoder call in __post_init__ uses transform_precoding=True and pi2bpsk=False, i.e., the MCS table interpretation for PUSCH with DFT-s-OFDM transform precoding enabled and π/2-BPSK disabled.

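The mapping performed in __post_init__ can be illustrated with a toy lookup. This is a hypothetical helper, not the module's actual decoder; the few entries shown are transcribed from 3GPP TS 38.214 Table 5.1.3.1-1 and should be verified against the spec before reuse.

```python
# Sketch of the MCS-index lookup performed in __post_init__.
# Only a handful of MCS table 1 entries are shown (values from
# 3GPP TS 38.214 Table 5.1.3.1-1 -- verify against the spec).
MCS_TABLE_1 = {
    # index: (modulation order Qm, target code rate x 1024)
    9:  (2, 679),   # QPSK
    13: (4, 490),   # 16-QAM
    14: (4, 553),   # 16-QAM (the demo's default MCS)
    16: (4, 658),   # 16-QAM
    17: (6, 438),   # 64-QAM
}

def decode_mcs(index):
    """Return (num_bits_per_symbol, target_coderate) for an MCS index."""
    qm, rate_1024 = MCS_TABLE_1[index]
    return qm, rate_1024 / 1024.0

num_bits_per_symbol, target_coderate = decode_mcs(14)
print(num_bits_per_symbol, round(target_coderate, 3))  # 4 0.54
```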
num_bs_ant: int = 32
property subcarrier_spacing: float

Subcarrier spacing in Hz (30 kHz for FR1 NR).

Type:

float

property num_time_steps: int

Number of OFDM symbols per slot (14 with normal CP).

Type:

int

property num_ue: int

Number of User Equipment (UE) devices in the MU-MIMO scenario.

Type:

int

property num_bs: int

Number of base stations (currently single-cell only).

Type:

int

property num_ue_ant: int

Number of antennas per UE (4 = 2x2 cross-polarized array).

Type:

int

property batch_size_cir: int

Batch size for ray-tracing CIR generation.

Type:

int

property target_num_cirs: int

Total number of CIR realizations to generate.

Type:

int

property max_depth: int

Maximum number of ray reflections in path tracing.

Type:

int

property min_gain_db: float

Minimum path gain threshold (dB) for UE position sampling.

Type:

float

property max_gain_db: float

Maximum path gain threshold (dB) for UE position sampling.

Type:

float

property min_dist_m: float

Minimum UE-BS distance (m) for position sampling.

Type:

float

property max_dist_m: float

Maximum UE-BS distance (m) for position sampling.

Type:

float

property rm_cell_size: Tuple[float, float]

Radio map cell size (x, y) in meters.

Type:

Tuple[float, float]

property rm_samples_per_tx: int

Monte Carlo samples per transmitter for radio map.

Type:

int

property rm_vmin_db: float

Minimum value (dB) for radio map colormap.

Type:

float

property rm_clip_at: float

Clipping threshold for radio map visualization.

Type:

float

property rm_resolution: Tuple[int, int]

Radio map image resolution (width, height) pixels.

Type:

Tuple[int, int]

property rm_num_samples: int

Anti-aliasing samples per pixel for rendering.

Type:

int

property batch_size: int

Batch size for training and BER/BLER simulation.

Type:

int

property seed: int

Global random seed for reproducibility.

Type:

int

property num_prb: int

Number of Physical Resource Blocks (PRBs) allocated.

Type:

int

property mcs_index: int

Modulation and Coding Scheme index (0-28 for table 1).

Type:

int

property num_layers: int

Number of MIMO layers per UE (spatial streams).

Type:

int

property mcs_table: int

MCS table index (1=64-QAM max, 2=256-QAM, 3=64-QAM low-SE).

Type:

int

property domain: str

Processing domain (‘freq’ or ‘time’) for transmitter output.

Type:

str

property num_bits_per_symbol: int

Bits per constellation symbol (derived from MCS).

Type:

int

property target_coderate: float

Target code rate (derived from MCS, typically 0.3-0.9).

Type:

float

property resource_grid

Sionna OFDM resource grid object.

Set externally by PUSCHLinkE2E after transmitter construction. Used by the neural detector to determine grid dimensions.

Type:

ResourceGrid

property pusch_pilot_indices

OFDM symbol indices containing DMRS pilots.

Set externally by PUSCHLinkE2E from PUSCHConfig.dmrs_symbol_indices. The neural detector uses this to separate pilot and data processing.

Type:

List[int]

property pusch_num_subcarriers

Number of subcarriers in the PUSCH allocation.

Set externally by PUSCHLinkE2E from PUSCHConfig.num_subcarriers.

Type:

int

property pusch_num_symbols_per_slot

Number of OFDM symbols per slot in the carrier configuration.

Set externally by PUSCHLinkE2E from carrier settings.

Type:

int

Channel Impulse Response

class demos.pusch_autoencoder.src.cir_generator.CIRGenerator(a, tau, num_tx)[source]

Bases: object

Infinite generator for random CIR sample selection.

This class wraps a pre-loaded CIR dataset and provides an infinite generator interface compatible with Sionna’s CIRDataset. Each call to the generator yields a new random combination of UE channels, simulating the scenario where different UEs are co-scheduled in each transmission slot.

Parameters:
  • a (tf.Tensor or np.ndarray, complex64) – CIR path coefficients from ray-tracing. Shape: [num_samples, 1, num_bs_ant, 1, num_ue_ant, max_paths, num_time_steps]

  • tau (tf.Tensor or np.ndarray, float32) – Path delays in seconds. Shape: [num_samples, 1, 1, max_paths]

  • num_tx (int) – Number of transmitters (UEs) to sample for each yield. Typically equals num_ue in the system configuration.

Example

>>> a, tau = manager.load_from_tfrecord()  # [5000, 1, 16, 1, 4, 51, 14]
>>> gen = CIRGenerator(a, tau, num_tx=4)
>>> for a_sample, tau_sample in gen():
...     print(a_sample.shape)  # [16, 4, 4, 51, 14]
...     break

Notes

The generator uses tf.random.uniform_candidate_sampler for efficient random selection without replacement. This is more efficient than tf.random.shuffle for large datasets when only selecting a small subset.

The dimension transpositions in __call__ reorder from the storage format (sample-first) to Sionna’s expected format (antenna-first):

  • Input a: [num_tx, 1, num_bs_ant, 1, num_ue_ant, paths, time]

  • Output a: [num_bs_ant, num_tx, num_ue_ant, paths, time]

This reordering places the BS antenna dimension first, followed by the UE (TX) dimension, matching Sionna’s channel tensor conventions.

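The sampling and re-ordering described above can be sketched in NumPy (shapes follow the docstring; the real class uses tf.random.uniform_candidate_sampler and tf.gather rather than np.random.choice):

```python
import numpy as np

# Toy CIR dataset in the storage format described above:
# [num_samples, 1, num_bs_ant, 1, num_ue_ant, max_paths, num_time_steps]
num_samples, num_bs_ant, num_ue_ant, max_paths, num_t = 100, 16, 4, 51, 14
shape = (num_samples, 1, num_bs_ant, 1, num_ue_ant, max_paths, num_t)
a = (np.random.randn(*shape) + 1j * np.random.randn(*shape)).astype(np.complex64)

num_tx = 4  # UEs co-scheduled per slot

# Random selection without replacement (stand-in for
# tf.random.uniform_candidate_sampler in the real generator).
idx = np.random.choice(num_samples, size=num_tx, replace=False)
a_sel = a[idx]                      # [num_tx, 1, num_bs_ant, 1, num_ue_ant, paths, time]

# Drop the singleton axes and move the BS-antenna dimension first,
# matching Sionna's channel tensor convention.
a_sel = a_sel.squeeze(axis=(1, 3))  # [num_tx, num_bs_ant, num_ue_ant, paths, time]
a_out = a_sel.transpose(1, 0, 2, 3, 4)
print(a_out.shape)  # (16, 4, 4, 51, 14)
```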
class demos.pusch_autoencoder.src.cir_manager.CIRManager(config=None)[source]

Bases: object

Unified manager for CIR generation, storage, and loading.

This class encapsulates the complete CIR data pipeline for the PUSCH autoencoder demo, including:

  • Munich scene setup with configurable antenna arrays

  • Radio map computation for coverage-aware UE position sampling

  • Ray-traced CIR generation with multipath propagation

  • TFRecord serialization for efficient data storage

  • MU-MIMO sample grouping for multi-user training

Parameters:
  • config (Config, optional) – Configuration object with system parameters. If None, uses default Config() with standard MU-MIMO settings.

Pre-conditions

  • Sionna must be installed with ray-tracing support (sionna.rt).

  • GPU with sufficient memory for ray-tracing recommended.

  • Munich scene assets must be available (bundled with Sionna).

Post-conditions

  • After setup_scene(), the scene is configured and ready for ray-tracing.

  • After compute_radio_map(), the radio map is cached in self.rm.

  • After generate_cir_data(), CIR coefficients and delays are available.

Invariants

  • Config parameters (antenna counts, bandwidths) are fixed after init.

  • Scene geometry (TX position, array orientations) is fixed.

Example

>>> # Generation workflow (offline, run once)
>>> manager = CIRManager()
>>> manager.generate_and_save([0, 1, 2], tfrecord_dir="../cir_tfrecords")
>>> # Loading workflow (online, run each training)
>>> manager = CIRManager()
>>> a, tau = manager.load_from_tfrecord(group_for_mumimo=True)
>>> model = PUSCHLinkE2E((a, tau), ...)

Notes

The CIR generation process is computationally expensive but only needs to run once. Generated TFRecords can be reused across many training runs with different hyperparameters.

The group_for_mumimo option in load_from_tfrecord() combines num_ue individual CIRs into single MU-MIMO samples, simulating co-scheduled uplink transmissions from multiple UEs.

setup_scene()[source]

Initialize Munich scene with BS and UE antenna configurations.

This method loads the Munich urban scene and configures:

  • BS array: 16x2 cross-polarized panel (32 elements) on a rooftop

  • UE array: 2x2 cross-polarized array (4 elements) with isotropic pattern

  • Camera: Overhead view for radio map visualization

Returns:

scene – Configured scene object ready for ray-tracing.

Return type:

sionna.rt.Scene

Notes

The BS is positioned at [8.5, 21, 27] meters, which places it on a building rooftop in the Munich scene. The look_at direction points toward the main coverage area where UEs will be sampled.

The BS uses tr38901 antenna pattern (3GPP sector antenna) while UEs use iso (isotropic) pattern, reflecting realistic deployments where BS antennas are directional but UE antennas are omnidirectional.

compute_radio_map(save_images=True)[source]

Compute path gain radio map for coverage-based UE sampling.

The radio map provides spatially-resolved path gain information used to sample UE positions in areas with valid coverage (avoiding dead zones and extreme near-field regions).

Parameters:

save_images (bool) – If True, saves radio map visualization to PNG file.

Returns:

rm – Radio map object with path gain values per spatial cell.

Return type:

sionna.rt.RadioMap

Notes

Radio map computation uses Monte Carlo ray-tracing with 10^7 samples per TX, which takes several minutes but provides smooth coverage maps. The result is cached in self.rm for subsequent UE position sampling.

generate_cir_data(seed_offset=0, max_num_paths=0)[source]

Generate ray-traced CIR data for multiple UE positions.

This method performs the core ray-tracing loop:

  1. Sample UE positions from radio map (coverage-aware)

  2. Place receiver objects at sampled positions

  3. Run path solver to compute multipath propagation

  4. Extract CIR (path gains and delays) for each link

Parameters:
  • seed_offset (int) – Random seed offset for reproducible position sampling. Different offsets produce different UE position sets.

  • max_num_paths (int) – Initial value for path count tracking. Updated during generation.

Returns:

  • a (np.ndarray, complex64) – CIR coefficients with shape [num_samples, 1, num_bs_ant, 1, num_ue_ant, max_paths, num_time_steps]

  • tau (np.ndarray, float32) – Path delays with shape [num_samples, 1, 1, max_paths]

  • max_num_paths (int) – Maximum number of paths across all generated samples.

Notes

The generation loop processes batch_size_cir positions at a time, updating receiver positions and re-running the path solver. Progress is printed to console since this can take hours for large datasets.

Empty CIRs (zero total power) are filtered out, as they represent positions with no valid propagation paths (e.g., inside buildings).

The seed formula idx + seed_offset * 1000 ensures different files (with different seed_offsets) produce non-overlapping position samples.

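The non-overlap claim is easy to verify: as long as fewer than 1000 batches are generated per file, the per-batch seeds from different seed_offsets never collide (a quick sanity check, not part of the module):

```python
def batch_seed(idx, seed_offset):
    # Seed formula from generate_cir_data: idx + seed_offset * 1000
    return idx + seed_offset * 1000

num_batches = 200  # must stay below 1000 for the guarantee to hold
seeds = {off: {batch_seed(i, off) for i in range(num_batches)}
         for off in (0, 1, 2)}
assert seeds[0].isdisjoint(seeds[1])
assert seeds[1].isdisjoint(seeds[2])
```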
save_to_tfrecord(a, tau, filename)[source]

Serialize CIR data to TFRecord format for efficient storage.

TFRecord provides efficient sequential reading, compression, and seamless integration with tf.data pipelines for training.

Parameters:
  • a (np.ndarray, complex64) – CIR coefficients, shape [num_samples, ...].

  • tau (np.ndarray, float32) – Path delays, shape [num_samples, ...].

  • filename (str) – Output TFRecord file path.

Notes

Each sample is stored as a separate TFRecord Example containing:

  • a: Serialized complex64 tensor (path gains)

  • tau: Serialized float32 tensor (path delays)

  • a_shape: Shape metadata for deserialization

  • tau_shape: Shape metadata for deserialization

The shape metadata enables reconstruction of tensors with varying path counts across different TFRecord files.

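The role of the shape metadata can be illustrated without the TFRecord machinery: a raw byte serialization is only invertible if the shape is stored alongside it. The real method stores the equivalent a/a_shape (and tau/tau_shape) pair inside a tf.train.Example; this NumPy sketch shows why both fields are needed.

```python
import numpy as np

# A toy CIR tensor with a file-specific path count.
a = (np.random.randn(5, 1, 1, 7)
     + 1j * np.random.randn(5, 1, 1, 7)).astype(np.complex64)

# Serialize: raw bytes plus shape metadata, mirroring the per-Example
# fields described above.
record = {"a": a.tobytes(), "a_shape": a.shape}

# Deserialize: the byte stream alone is shapeless; the metadata
# restores the tensor even when max_paths differs between files.
restored = np.frombuffer(record["a"], dtype=np.complex64).reshape(record["a_shape"])
assert np.array_equal(restored, a)
```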
load_from_tfrecord(tfrecord_dir=None, group_for_mumimo=False)[source]

Load CIR data from TFRecord files with optional MU-MIMO grouping.

Parameters:
  • tfrecord_dir (str, optional) – Directory containing TFRecord files, relative to this module. If not provided, defaults to cir_tfrecords_ant{num_bs_ant} within the demo directory.

  • group_for_mumimo (bool) – If True, groups num_ue individual CIRs into MU-MIMO samples. This simulates co-scheduled uplink transmissions.

Returns:

  • all_a (tf.Tensor, complex64) – CIR coefficients. Shape depends on group_for_mumimo:

    • False: [num_samples, 1, num_bs_ant, 1, num_ue_ant, max_paths, num_time_steps]

    • True: [num_mu_samples, 1, num_bs_ant, num_ue, num_ue_ant, max_paths, num_time_steps]

  • all_tau (tf.Tensor, float32) – Path delays. Shape depends on group_for_mumimo:

    • False: [num_samples, 1, 1, max_paths]

    • True: [num_mu_samples, 1, num_ue, max_paths]

Notes

When group_for_mumimo=True, consecutive samples are combined to form multi-user scenarios. For example, with num_ue=4, samples 0-3 become MU-MIMO sample 0, samples 4-7 become sample 1, etc.

This grouping is valid because each original sample represents an independent UE position, and combining them simulates the realistic scenario where multiple UEs transmit simultaneously.

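Assuming the consecutive grouping described above, the operation reduces to a reshape that folds num_ue single-UE samples into the UE axis of one MU-MIMO sample (a NumPy sketch of the shape arithmetic, not the actual implementation):

```python
import numpy as np

num_samples, num_bs_ant, num_ue_ant, max_paths, num_t = 12, 16, 4, 51, 14
num_ue = 4

# Single-UE storage format: [num_samples, 1, num_bs_ant, 1, num_ue_ant, paths, time]
a = np.zeros((num_samples, 1, num_bs_ant, 1, num_ue_ant, max_paths, num_t), np.complex64)
tau = np.zeros((num_samples, 1, 1, max_paths), np.float32)

# Fold groups of num_ue consecutive samples into the UE dimension:
# samples 0-3 -> MU sample 0, samples 4-7 -> MU sample 1, ...
num_mu = num_samples // num_ue
a_mu = a.squeeze(axis=3).reshape(num_mu, num_ue, 1, num_bs_ant, num_ue_ant, max_paths, num_t)
a_mu = a_mu.transpose(0, 2, 3, 1, 4, 5, 6)   # UE axis into position 3
tau_mu = tau.squeeze(axis=1).reshape(num_mu, num_ue, 1, max_paths).transpose(0, 2, 1, 3)

print(a_mu.shape)    # (3, 1, 16, 4, 4, 51, 14)
print(tau_mu.shape)  # (3, 1, 4, 51)
```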
build_channel_model(batch_size=None, num_bs=None, num_bs_ant=None, num_ue=None, num_ue_ant=None, num_time_steps=None, tfrecord_dir=None)[source]

Build CIRDataset channel model from TFRecord files.

This method creates a Sionna CIRDataset that can be used with OFDMChannel for baseline BER/BLER evaluation. The dataset provides on-demand CIR sampling during simulation.

Parameters:
  • batch_size (int, optional) – Batch size for dataset. Default from config.

  • num_bs (int, optional) – Number of base stations. Default from config.

  • num_bs_ant (int, optional) – BS antenna count. Default from config.

  • num_ue (int, optional) – Number of UEs. Default from config.

  • num_ue_ant (int, optional) – UE antenna count. Default from config.

  • num_time_steps (int, optional) – OFDM symbols per slot. Default from config.

  • tfrecord_dir (str, optional) – Directory containing TFRecord files. If not provided, defaults to cir_tfrecords_ant{num_bs_ant} within the demo directory.

Returns:

channel_model – Channel model compatible with Sionna’s OFDMChannel.

Return type:

CIRDataset

Notes

This method is primarily used for baseline evaluation where a CIRDataset object is needed. For autoencoder training, use load_from_tfrecord(group_for_mumimo=True) directly to get tensors that can be indexed during training.

save_visualization_ue_positions(filename='munich_ue_positions.png')[source]

Render radio map with current UE positions overlaid.

Parameters:

filename (str) – Output PNG file path.

Raises:

RuntimeError – If scene, radio map, or camera are not initialized.

generate_and_save(seed_offsets, tfrecord_dir=None, save_radio_map=True)[source]

Complete CIR generation pipeline: generate, visualize, and save.

This is the main entry point for offline CIR dataset creation. It handles scene setup, radio map computation, CIR generation, and TFRecord serialization in a single call.

Parameters:
  • seed_offsets (int or list of int) – Random seed offset(s) for UE position sampling. Each seed produces a separate TFRecord file.

  • tfrecord_dir (str, optional) – Output directory for TFRecord files (relative to this module). If not provided, defaults to cir_tfrecords_ant{num_bs_ant} within the demo directory.

  • save_radio_map (bool) – If True, saves radio map and UE position visualizations.

Example

>>> manager = CIRManager()
>>> # Generate 3 files with 5000 CIRs each (15000 total)
>>> manager.generate_and_save([0, 1, 2])

Notes

Multiple seed offsets produce independent position samples, which can be beneficial for:

  1. Increasing total dataset size beyond single-file limits

  2. Enabling parallel generation on multiple machines

  3. Creating train/validation/test splits with different seeds

The output directory includes the antenna count suffix to keep CIR data for different antenna configurations separate.

Trainable Transmitter

class demos.pusch_autoencoder.src.pusch_trainable_transmitter.PUSCHTrainableTransmitter(*args, **kwargs)[source]

Bases: PUSCHTransmitter

PUSCH Transmitter with SYMMETRIC trainable constellation geometry AND labeling.

This subclass enforces 4-fold symmetry (I-axis, Q-axis, origin) in the constellation while supporting learnable labeling for autoencoder-based communication system design.

Parameters:
  • *args (tuple) – Positional arguments passed to PUSCHTransmitter.

  • training (bool) – If True, constellation and labeling are trainable with soft Gumbel-Softmax assignment. Default False.

  • gumbel_temperature (float) – Temperature for Gumbel-Softmax. Lower = sharper (more discrete). Default 0.5.

  • **kwargs (dict) – Keyword arguments passed to PUSCHTransmitter.

Example

>>> tx = PUSCHTrainableTransmitter(pusch_configs, output_domain="freq",
...                                 training=True, gumbel_temperature=0.5)
>>> x_map, x, b, c = tx(batch_size=32)

Notes

The symmetric constellation is stored as [num_base_points] complex values (typically 4 for 16-QAM), and the full constellation is computed via reflections. This guarantees 4-fold symmetry is preserved during training.

The learnable labeling uses a permutation logits matrix of shape [num_points, num_points]. Row i contains logits for which constellation point should be assigned to bit pattern i.

static generate_random_symmetric_constellation(num_points, seed=None)[source]

Generate random constellation with 4-fold symmetry and unit average energy.

Creates a constellation by:

  1. Generating num_points/4 random base points in the first quadrant

  2. Reflecting them across I-axis, Q-axis, and origin

  3. Normalizing to unit average power

Parameters:
  • num_points (int) – Total constellation size (must be divisible by 4). For 16-QAM: 16.

  • seed (int, optional) – Random seed for reproducibility.

Returns:

Symmetric constellation points with shape [num_points]. Points are ordered: [Q1_points, Q2_points, Q4_points, Q3_points]

Return type:

tf.Tensor, complex64

Notes

The generation process:

  1. Sample num_points/4 base points in Q1 (I>0, Q>0)

  2. Reflect across Q-axis: negate real part -> Q2 (I<0, Q>0)

  3. Reflect across I-axis: negate imag part -> Q4 (I>0, Q<0)

  4. Reflect across origin: negate both parts -> Q3 (I<0, Q<0)

  5. Normalize complete constellation to unit average power

Example

>>> points = PUSCHTrainableTransmitter.generate_random_symmetric_constellation(16, seed=42)
>>> print(f"Power: {tf.reduce_mean(tf.abs(points)**2):.6f}")  # Should be ~1.0
>>> # Verify I-axis symmetry (Q1 vs Q4)
>>> print(f"I-sym check: {tf.reduce_mean(tf.abs(points[0] - tf.math.conj(points[8]))):.6f}")  # ~0
property trainable_variables

Return all trainable variables: base geometry + labeling.

Returns:

Three-element list: [_base_points_r, _base_points_i, _labeling_logits] where:

  • _base_points_r: Real parts of Q1 base points

  • _base_points_i: Imaginary parts of Q1 base points

  • _labeling_logits: Permutation logits for bit-to-point mapping

Return type:

list of tf.Variable

property geometry_variables

Return only the base geometry (Q1 constellation points) variables.

property labeling_variables

Return only the labeling (permutation) variable.

property gumbel_temperature

Current Gumbel-Softmax temperature.

Type:

float

get_base_points()[source]

Get the base constellation points (Q1 only, not normalized).

Returns:

Base constellation points in first quadrant, shape [num_base_points]. For 16-QAM, this is [4] complex values.

Return type:

tf.Tensor, complex64

get_normalized_constellation()[source]

Compute 4-fold symmetric and power-normalized constellation points.

Returns:

Normalized constellation points with 4-fold symmetry, shape [num_points]. For 16-QAM, this is [16] complex values.

Return type:

tf.Tensor, complex64

Notes

The normalization process:

  1. Retrieve base points from trainable variables (Q1 only)

  2. Reflect to create full 4-fold symmetric constellation

  3. Normalize to unit average power

The normalization is differentiable and does NOT break symmetry since it’s a uniform scaling operation applied to all points equally.

Symmetry properties maintained:

  • I-axis: constellation[i] = conj(constellation[j]) for mirrored i,j

  • Q-axis: constellation[i] = -conj(constellation[k]) for mirrored i,k

  • Origin: constellation[i] = -constellation[m] for opposite i,m

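The reflection construction and the symmetry properties above can be reproduced in a few lines of NumPy (a sketch of the documented procedure, not the class's exact code):

```python
import numpy as np

rng = np.random.default_rng(42)
num_points = 16
n_base = num_points // 4

# 1. Sample base points in the first quadrant (I>0, Q>0).
q1 = rng.uniform(0.1, 1.0, n_base) + 1j * rng.uniform(0.1, 1.0, n_base)

# 2.-4. Reflect across Q-axis (-conj), I-axis (conj), and origin (-),
# keeping the documented ordering [Q1, Q2, Q4, Q3].
points = np.concatenate([q1, -np.conj(q1), np.conj(q1), -q1])

# 5. Normalize the complete constellation to unit average power.
points = points / np.sqrt(np.mean(np.abs(points) ** 2))

assert np.isclose(np.mean(np.abs(points) ** 2), 1.0)
# I-axis symmetry: Q4 entries are conjugates of Q1 entries.
assert np.allclose(points[8:12], np.conj(points[0:4]))
# Q-axis symmetry: Q2 = -conj(Q1); origin symmetry: Q3 = -Q1.
assert np.allclose(points[4:8], -np.conj(points[0:4]))
assert np.allclose(points[12:16], -points[0:4])
```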
get_soft_labeling_matrix(hard=False)[source]

Get the current labeling matrix (permutation of constellation points).

Parameters:

hard (bool) – If True, return hard one-hot assignment. If False, return soft Gumbel-Softmax probabilities.

Returns:

Labeling matrix, shape [num_points, num_points]. Row i indicates which constellation point(s) are assigned to bit pattern i.

Return type:

tf.Tensor, float

Example

For hard assignment, row i will have a single 1 at column j, meaning bit pattern i maps to constellation point j.

For soft assignment during training, row i contains probabilities summing to 1, enabling gradient flow through the labeling.

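The hard/soft behavior can be sketched with standard Gumbel-Softmax in NumPy (the class's exact sampling details may differ; this only illustrates the row-stochastic soft matrix and its one-hot hard counterpart):

```python
import numpy as np

rng = np.random.default_rng(0)
num_points = 16
temperature = 0.5

# Trainable permutation logits: row i scores candidate constellation
# points for bit pattern i.
logits = rng.normal(size=(num_points, num_points))

# Gumbel-Softmax: perturb logits with Gumbel noise, then apply a
# temperature-scaled softmax per row.
gumbel = -np.log(-np.log(rng.uniform(1e-9, 1.0, logits.shape)))
z = (logits + gumbel) / temperature
soft = np.exp(z - z.max(axis=1, keepdims=True))
soft /= soft.sum(axis=1, keepdims=True)

# Soft rows are probability distributions (gradients can flow).
assert np.allclose(soft.sum(axis=1), 1.0)

# Hard assignment: one-hot argmax per row.
hard = np.eye(num_points)[soft.argmax(axis=1)]
assert np.all(hard.sum(axis=1) == 1)
```

Lower temperatures make the soft rows approach the one-hot hard rows, which is why the class exposes gumbel_temperature as a tunable knob.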
call(inputs)[source]

Execute transmitter processing chain with trainable constellation and labeling.

The symmetric constellation is computed from base points before each forward pass.

Parameters:

inputs (int or tf.Tensor) – If return_bits=True: int specifying batch size. If return_bits=False: tf.Tensor of input bits.

Returns:

If return_bits=True: tuple (x_map, x, b, c). If return_bits=False: just x (the transmitted signal).

Return type:

tuple or tf.Tensor

Notes

The processing chain:

  1. Compute full 16-point symmetric constellation from 4 base points

  2. Standard PUSCH transmission (bits -> symbols -> OFDM)

  3. Symbol mapping uses learnable labeling via Gumbel-Softmax

Neural Detector

class demos.pusch_autoencoder.src.pusch_neural_detector.Conv2DResBlock(*args, **kwargs)[source]

Bases: Layer

Pre-activation residual block with two convolutional layers.

Implements the pre-activation ResNet variant where normalization and activation precede each convolution. This ordering improves gradient flow and enables training of deeper networks.

The block computes: output = input + conv2(act(norm(conv1(act(norm(input))))))

Parameters:
  • filters (int) – Number of convolutional filters in both layers.

  • kernel_size (tuple of int) – Spatial kernel size, default (3, 3).

  • name (str, optional) – Layer name for TensorFlow graph.

Notes

  • Uses LayerNormalization (not BatchNorm) for stable training with small batch sizes typical in communication system simulation.

  • LeakyReLU with α=0.1 prevents dead neurons while maintaining near-linear behavior for small negative inputs.

  • Identity skip connection requires input and output channel counts to match; caller must ensure filters equals input channels.

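The pre-activation ordering can be sketched in NumPy. For brevity, the two 3x3 convolutions are replaced by 1x1 (pointwise) linear maps; this keeps the norm → act → conv structure and the identity skip intact while staying self-contained:

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # Normalize over the channel axis, as LayerNormalization does here.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def leaky_relu(x, alpha=0.1):
    return np.where(x > 0, x, alpha * x)

def res_block(x, w1, w2):
    # output = input + conv2(act(norm(conv1(act(norm(input))))))
    h = layer_norm(x)
    h = leaky_relu(h)
    h = h @ w1           # stand-in for the first 3x3 conv
    h = layer_norm(h)
    h = leaky_relu(h)
    h = h @ w2           # stand-in for the second 3x3 conv
    return x + h         # identity skip: channel counts must match

rng = np.random.default_rng(0)
filters = 8
x = rng.normal(size=(2, 4, 4, filters))          # [batch, H, W, filters]
w1 = rng.normal(size=(filters, filters)) * 0.1
w2 = rng.normal(size=(filters, filters)) * 0.1
y = res_block(x, w1, w2)
assert y.shape == x.shape  # residual addition preserves shape
```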
call(x)[source]

Apply residual transformation.

Parameters:

x (tf.Tensor) – Input tensor of shape [batch, height, width, filters].

Returns:

Output tensor with same shape as input (residual added).

Return type:

tf.Tensor

class demos.pusch_autoencoder.src.pusch_neural_detector.PUSCHNeuralDetector(*args, **kwargs)[source]

Bases: Layer

Neural MIMO detector with learned channel estimation refinement for PUSCH.

This detector implements a residual learning architecture that refines classical LS channel estimates and LMMSE-based soft symbol estimates using convolutional neural networks. The key design principle is that learning corrections to strong classical baselines is more effective than learning detection from scratch.

Architecture

The detector processes data through six stages:

  1. Feature extraction: Assembles input features including LS channel estimate, received signal, matched filter output, Gram matrix structure, estimation error variance, and noise level.

  2. Shared backbone: Processes features through convolutional ResBlocks to learn joint representations useful for both CE refinement and detection.

  3. CE refinement head: Predicts additive corrections Δh and multiplicative log-domain corrections Δlog(err_var) to the LS estimates.

  4. Scaled correction application: Applies learned corrections scaled by trainable parameters that start at zero, enabling gradual departure from classical behavior during training.

  5. Classical LMMSE: Performs LMMSE equalization using refined channel estimate and error variance, followed by max-log demapping.

  6. LLR refinement: Injects LMMSE outputs back into the network to predict additive LLR corrections, again scaled by a trainable parameter.

Parameters:
  • cfg (Config) – System configuration containing MIMO dimensions, modulation order, and PUSCH resource grid information.

  • num_conv2d_filters (int) – Number of filters in all convolutional layers. Higher values increase model capacity but also computational cost. Default 128.

  • num_shared_res_blocks (int) – Number of ResBlocks in the shared backbone. Controls depth of joint feature learning. Default 4.

  • num_det_res_blocks (int) – Number of ResBlocks in the detection continuation path. Controls capacity for LLR refinement. Default 6.

  • kernel_size (tuple of int) – Spatial kernel size for ResBlock convolutions. Default (3, 3).

Pre-conditions

  • cfg.pusch_pilot_indices must be set (by PUSCHLinkE2E) before instantiation to enable pilot/data symbol separation.

  • Input tensors must follow Sionna’s PUSCH dimension conventions.

Post-conditions

  • trainable_variables returns correction scales first, then network weights (enables separate optimizer configuration).

  • last_h_hat_refined and last_err_var_refined contain the most recent refined estimates (useful for auxiliary losses).

Invariants

  • Correction scales initialized to 0.0, so initial behavior matches classical LMMSE exactly (graceful degradation if training fails).

  • Error variance correction scale is always positive (softplus transform).

  • Output LLR shape matches Sionna’s PUSCHReceiver expectations.

Example

>>> cfg = Config()
>>> # ... set cfg.pusch_pilot_indices via PUSCHLinkE2E ...
>>> detector = PUSCHNeuralDetector(cfg, num_conv2d_filters=64)
>>> llr = detector(y, h_hat, err_var, no)

Notes

The trainable correction scales serve multiple purposes:

  1. Safe initialization: Starting at 0.0 means the detector initially behaves exactly like classical LMMSE, providing a stable starting point.

  2. Interpretability: Scale magnitudes indicate how much the network deviates from classical processing (useful for debugging).

  3. Gradient balancing: Separate scales for h, err_var, and LLR allow independent learning rates for different correction types.

The error variance scale uses softplus to ensure positivity, since negative variance would be physically meaningless and cause numerical issues in LMMSE computation.

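The "safe initialization" property is easy to see in isolation: with a correction scale at its initial value of 0.0, the refined quantity collapses to the classical one. A NumPy sketch (variable names and the raw softplus argument are illustrative, not the class's actual values):

```python
import numpy as np

rng = np.random.default_rng(0)
h_ls = rng.normal(size=(4, 4)) + 1j * rng.normal(size=(4, 4))      # LS channel estimate
delta_h = rng.normal(size=(4, 4)) + 1j * rng.normal(size=(4, 4))   # predicted correction

# Additive correction scaled by a trainable scalar that starts at 0.0:
h_scale = 0.0
h_refined = h_ls + h_scale * delta_h
assert np.array_equal(h_refined, h_ls)  # at init: classical behavior exactly

# The error-variance scale is kept strictly positive via softplus,
# since a negative variance would break the LMMSE computation.
def softplus(x):
    return np.log1p(np.exp(x))

ev_scale = softplus(-5.0)  # arbitrary raw value; output is always > 0
assert ev_scale > 0
```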
property err_var_correction_scale

Effective error variance correction scale (always positive).

Returns:

Scalar tensor with softplus-transformed scale value. softplus(x) = log(1 + exp(x)) ensures output > 0.

Return type:

tf.Tensor

property trainable_variables

Collect all trainable variables in a specific order.

The ordering places correction scales first, enabling separate optimizer configuration (e.g., higher learning rate for scales).

Returns:

Ordered list: [correction_scales, backbone_weights, ce_head_weights, detection_weights].

Return type:

list of tf.Variable

call(y, h_hat, err_var, no, constellation=None, training=None)[source]

Execute neural MIMO detection with channel estimation refinement.

Parameters:
  • y (tf.Tensor, complex64) – Received OFDM signal after FFT. Shape: [B, num_bs, num_bs_ant, num_ofdm_syms, num_subcarriers]

  • h_hat (tf.Tensor, complex64) – LS channel estimate from DMRS processing. Shape: [B, num_bs, num_bs_ant, num_ue, num_streams, num_ofdm_syms, num_subcarriers]

  • err_var (tf.Tensor, float32) – Channel estimation error variance per stream and RE. Shape: [B, 1, 1, num_ue, num_streams, num_ofdm_syms, num_subcarriers]

  • no (tf.Tensor, float32) – Noise variance, shape [B] or scalar.

  • constellation (tf.Tensor, complex64, optional) – Trainable constellation points from transmitter. If provided, updates the demapper’s constellation for consistent symbol mapping. Shape: [num_constellation_points]

  • training (bool, optional) – Training mode flag (currently unused, for Keras API compatibility).

Returns:

Log-likelihood ratios for all coded bits. Shape: [B, num_ue, num_streams_per_ue, num_data_symbols * num_bits_per_symbol]

Return type:

tf.Tensor, float32

Notes

The processing flow:

  1. Feature assembly: Extracts and concatenates multiple signal representations (channel, received signal, matched filter, Gram matrix structure, estimation error, noise level).

  2. Shared backbone: Processes features through ResBlocks to learn joint representations.

  3. CE refinement: Predicts channel and error variance corrections, applies them scaled by trainable parameters.

  4. LMMSE equalization: Classical equalization with refined estimates, using whitened interference model for improved performance.

  5. LLR refinement: Predicts additive corrections to LMMSE LLRs, enabling learned compensation for model mismatch.

Channel estimation refinement operates on ALL OFDM symbols (including pilots) to leverage full spatial-frequency context, while detection operates only on DATA symbols to produce the final LLRs.

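Stage 4 of the flow above is a classical LMMSE solve. A per-resource-element sketch (ignoring the whitened-interference refinement mentioned above; dimensions and values are illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)
num_rx, num_tx = 16, 4  # BS antennas x total streams at one resource element
no = 0.01               # noise variance

H = (rng.normal(size=(num_rx, num_tx))
     + 1j * rng.normal(size=(num_rx, num_tx))) / np.sqrt(2)
x = (rng.choice([-1, 1], num_tx) + 1j * rng.choice([-1, 1], num_tx)) / np.sqrt(2)
y = H @ x + np.sqrt(no / 2) * (rng.normal(size=num_rx) + 1j * rng.normal(size=num_rx))

# LMMSE estimate: x_hat = (H^H H + no * I)^-1 H^H y
G = np.linalg.solve(H.conj().T @ H + no * np.eye(num_tx), H.conj().T)
x_hat = G @ y
assert np.max(np.abs(x_hat - x)) < 0.5  # close enough for hard decisions here
```

In the detector, H and the error variance entering this solve are the refined estimates from stage 3, and the resulting soft symbols feed the max-log demapper before LLR refinement.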

Trainable Receiver

class demos.pusch_autoencoder.src.pusch_trainable_receiver.PUSCHTrainableReceiver(*args, **kwargs)[source]

Bases: PUSCHReceiver

PUSCH Receiver variant for autoencoder training with neural detection.

This class extends PUSCHReceiver to support end-to-end training by:

  1. Returning soft LLRs before TB decoding when in training mode

  2. Passing the trainable constellation to the neural detector for consistent symbol demapping

  3. Exposing the neural detector’s trainable variables

The receiver supports both perfect CSI (ground-truth channel provided) and imperfect CSI (LS estimation with neural refinement) scenarios.

Parameters:
  • *args (tuple) – Positional arguments passed to PUSCHReceiver.

  • training (bool) – If True, call() returns LLRs without TB decoding. If False, call() returns decoded bits. Default False.

  • pusch_transmitter (PUSCHTrainableTransmitter, optional) – Reference to the transmitter for constellation synchronization. Required for proper demapping when constellation is trainable.

  • **kwargs (dict) – Keyword arguments passed to PUSCHReceiver (e.g., mimo_detector, input_domain, channel_estimator).

Example

>>> # Training setup
>>> detector = PUSCHNeuralDetector(cfg)
>>> rx = PUSCHTrainableReceiver(
...     mimo_detector=detector,
...     input_domain="freq",
...     pusch_transmitter=tx,
...     training=True
... )
>>> llr = rx(y, no)  # Returns soft LLRs for BCE loss
>>> # Inference setup
>>> rx_eval = PUSCHTrainableReceiver(
...     mimo_detector=detector,
...     input_domain="freq",
...     pusch_transmitter=tx,
...     training=False
... )
>>> b_hat = rx_eval(y, no)  # Returns decoded bits for BER

Notes

The same constellation is used for mapping (TX) and demapping (RX) for autoencoder training. The _get_normalized_constellation() method retrieves the current (normalized) constellation from the transmitter at each forward pass, ensuring the demapper always uses the same points that were used for mapping, even as they evolve during training.
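The normalization step can be sketched as follows (illustrative NumPy; the demo's `_get_normalized_constellation()` may differ in detail):

```python
import numpy as np

def normalize_constellation(points: np.ndarray) -> np.ndarray:
    """Scale constellation points to unit average power."""
    energy = np.mean(np.abs(points) ** 2)
    return points / np.sqrt(energy)

# Example: an unnormalized 4-point constellation
points = np.array([1 + 1j, -1 + 1j, -1 - 1j, 1 - 1j]) * 3.0
norm_points = normalize_constellation(points)  # mean |x|^2 is now 1.0
```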

property trainable_variables

Collect trainable variables from the neural MIMO detector.

Returns:

Trainable variables from self._mimo_detector, or empty list if the detector has no trainable variables (e.g., classical LMMSE).

Return type:

list of tf.Variable

Notes

This property enables the optimizer to access detector weights without knowing the internal structure. The receiver itself has no trainable parameters; all learning happens in the neural detector.
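The delegation pattern described here can be sketched in plain Python (illustrative, not the actual class):

```python
class Receiver:
    """Minimal sketch of delegating trainable_variables to the detector."""

    def __init__(self, mimo_detector):
        self._mimo_detector = mimo_detector

    @property
    def trainable_variables(self):
        # Neural detectors expose their weights; classical detectors
        # (e.g. LMMSE) have no such attribute, so fall back to [].
        return getattr(self._mimo_detector, "trainable_variables", [])
```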

call(y, no, h=None)[source]

Execute receiver processing chain with optional training mode.

Parameters:
  • y (tf.Tensor, complex64) –

    Received signal. Shape depends on input_domain:

    • "freq": [batch, num_rx, num_rx_ant, num_ofdm_symbols, num_subcarriers]

    • "time": [batch, num_rx, num_rx_ant, num_samples]

  • no (tf.Tensor, float32) – Noise variance, shape [batch] or scalar.

  • h (tf.Tensor, complex64, optional) – Ground-truth channel matrix for perfect CSI mode. Shape: [batch, num_rx, num_rx_ant, num_tx, num_tx_ant, num_ofdm_symbols, num_subcarriers]. Required if channel_estimator="perfect" was set in the constructor.

Returns:

  • Training mode (training=True): LLRs for coded bits, shape [batch, num_ue, num_coded_bits]

  • Inference mode (training=False): Decoded information bits, shape [batch, num_ue, tb_size]

  • If return_tb_crc_status=True in inference mode, returns tuple (b_hat, tb_crc_status)

Return type:

tf.Tensor

Notes

The processing chain follows standard PUSCH reception:

  1. OFDM demodulation (if time domain): FFT and CP removal

  2. Channel estimation: Perfect CSI passthrough or LS estimation

  3. MIMO detection: Neural detector with constellation sync

  4. Layer demapping: Separate streams back to UE data

  5. TB decoding (inference only): LDPC decoding, CRC check

In training mode, step 5 is skipped because:

  • TB decoding is non-differentiable (hard decisions)

  • Soft LLRs are needed for the BCE loss against the coded bits c

  • The loss provides gradients to train the neural detector

The squeeze operation on LLRs (when shape has singleton dimension 2) handles the case where num_layers=1, ensuring consistent output shape regardless of MIMO layer configuration.
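The shape handling described above amounts to the following (NumPy sketch; the axis index is taken from the docstring):

```python
import numpy as np

# [batch, num_ue, num_layers, num_coded_bits] with num_layers = 1
llr = np.zeros((8, 4, 1, 1024))

# Drop the singleton layer dimension so the output shape is the same
# regardless of the MIMO layer configuration
if llr.shape[2] == 1:
    llr = np.squeeze(llr, axis=2)  # -> [batch, num_ue, num_coded_bits]
```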

System

class demos.pusch_autoencoder.src.system.PUSCHLinkE2E(*args, **kwargs)[source]

Bases: Model

End-to-end differentiable PUSCH link model for MU-MIMO autoencoder training.

This class simulates a complete 5G NR PUSCH uplink that can operate in two modes:

  1. Baseline mode (use_autoencoder=False): Uses standard QAM constellation with LS channel estimation + LMMSE equalization for BER/BLER benchmarking.

  2. Autoencoder mode (use_autoencoder=True): Uses trainable constellation geometry AND labeling, plus a neural detector, enabling end-to-end optimization.

The model supports both perfect and imperfect CSI scenarios, where imperfect CSI uses LS channel estimation with optional neural refinement.

Parameters:
  • channel_model (tuple or CIRDataset) – For baseline mode: CIRDataset object for on-demand CIR generation. For autoencoder mode: tuple (a, tau) of pre-loaded CIR tensors.

  • perfect_csi (bool) – If True, provides ground-truth channel to the receiver. If False, receiver performs LS channel estimation.

  • use_autoencoder (bool) – If True, uses trainable transmitter and neural detector. If False, uses standard PUSCH TX/RX with LMMSE detection.

  • training (bool) – If True, call() returns the training loss (BCE + regularization). If False, call() returns (bits, bits_hat) for BER evaluation.

  • config (Config, optional) – System configuration. Defaults to Config() if not provided. Use this to customize system parameters like num_bs_ant.

  • gumbel_temperature (float, optional) – Initial Gumbel-Softmax temperature for learnable labeling. Lower values produce sharper (more discrete) assignments. Default 0.5.

Notes

For autoencoder mode, channel_model must be a tuple (a, tau) where a contains complex CIR coefficients with shape [num_samples, num_bs, num_bs_ant, num_ue, num_ue_ant, num_paths, num_time_steps] and tau contains path delays with shape [num_samples, num_bs, num_ue, num_paths]. For baseline mode, channel_model must be a valid CIRDataset.
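Given pre-loaded CIR tensors with the documented layout, the per-step random channel draw can be sketched as follows (toy dimensions for illustration only; the real tensors come from `CIRManager`):

```python
import numpy as np

rng = np.random.default_rng(0)
num_samples = 20

# Toy tensors matching the documented CIR layout:
# a:   [num_samples, num_bs, num_bs_ant, num_ue, num_ue_ant, num_paths, num_time_steps]
# tau: [num_samples, num_bs, num_ue, num_paths]
a = rng.standard_normal((num_samples, 1, 16, 4, 4, 10, 14)).astype(np.complex64)
tau = rng.random((num_samples, 1, 4, 10)).astype(np.float32)

# One random CIR realization per sample in the training batch
batch_size = 32
idx = rng.integers(0, num_samples, size=batch_size)
a_batch, tau_batch = a[idx], tau[idx]
```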

The trainable transmitter includes:

  • Learnable constellation geometry: Point positions can be optimized starting from random initialization to escape QAM local minima.

  • Learnable labeling: Bit-to-symbol assignment learned via Gumbel-Softmax enables optimization beyond fixed Gray coding.

This dual optimization enables full geometric shaping potential that is limited when using fixed Gray labeling.
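A minimal NumPy sketch of Gumbel-Softmax sampling over labeling logits follows. This is illustrative only: the demo's layer is implemented in TensorFlow, and it may enforce a proper permutation (one label per symbol) differently than this independent per-row softmax:

```python
import numpy as np

def gumbel_softmax(logits, tau, rng):
    """Draw a soft (differentiable) categorical sample per row."""
    u = rng.uniform(1e-9, 1.0, size=logits.shape)
    g = -np.log(-np.log(u))                 # Gumbel(0, 1) noise
    z = (logits + g) / tau                  # lower tau -> sharper samples
    z = z - z.max(axis=-1, keepdims=True)   # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
logits = rng.standard_normal((16, 16))      # 16 symbols x 16 candidate labels
soft = gumbel_softmax(logits, tau=0.5, rng=rng)  # rows sum to 1
hard = soft.argmax(axis=-1)                 # hard labeling, cf. get_hard_labeling()
```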

Additional notes:

  • self._cfg contains PUSCH resource grid information after construction.

  • self.trainable_variables returns all trainable weights (TX + RX).

  • In training mode, call() returns a scalar loss tensor.

  • In inference mode, call() returns (b, b_hat) bit tensors.

  • The PUSCH configuration (PRBs, MCS, layers) remains fixed after init.

  • Constellation normalization maintains unit average power.

  • Channel model type (tuple vs CIRDataset) determines internal processing path.

Example

>>> # Autoencoder training setup
>>> cir_manager = CIRManager()
>>> a, tau = cir_manager.load_from_tfrecord(group_for_mumimo=True)
>>> model = PUSCHLinkE2E((a, tau), perfect_csi=False,
...                       use_autoencoder=True, training=True)
>>> loss = model(batch_size=32, ebno_db=10.0)
>>> # Train with separate optimizers for geometry, labeling, and receiver
property trainable_variables

Collect all trainable variables from transmitter and receiver.

Returns:

Combined list of transmitter and receiver trainable variables. For autoencoder mode, this includes:

  • Constellation geometry (real/imag coordinates)

  • Labeling logits (permutation matrix)

  • Neural detector weights (ResBlocks, correction scales)

Return type:

list of tf.Variable

property tx_geometry_variables

Get transmitter geometry (constellation point) variables.

Returns:

Two-element list: [_points_r, _points_i] for separate optimization from labeling variables.

Return type:

list of tf.Variable

property tx_labeling_variables

Get transmitter labeling (permutation) variables.

Returns:

One-element list: [_labeling_logits] for separate optimization from geometry variables.

Return type:

list of tf.Variable

property rx_variables

Get receiver trainable variables.

Returns:

All trainable variables from the receiver (neural detector weights including correction scales and ResBlock parameters).

Return type:

list of tf.Variable

property constellation

Get the current normalized constellation points.

Returns:

Complex tensor of shape [num_points] with unit average power. For 16-QAM, this is 16 complex values.

Return type:

tf.Tensor

property gumbel_temperature

Get current Gumbel-Softmax temperature.

Returns:

Current temperature value, or None if transmitter does not support learnable labeling.

Return type:

float or None

get_hard_labeling()[source]

Get the hard (argmax) labeling permutation from the transmitter.

Returns:

Integer permutation indices if learnable labeling is enabled, otherwise None.

Return type:

tf.Tensor or None

call(batch_size, ebno_db)

Execute forward pass through the end-to-end PUSCH link.

Parameters:
  • batch_size (int) – Number of transport blocks to simulate in parallel.

  • ebno_db (tf.Tensor) – Energy per bit to noise power spectral density ratio in dB. Can be scalar (same SNR for all samples) or vector [batch_size].

Returns:

  • Training mode: Scalar loss tensor (BCE plus constellation regularization)

  • Inference mode: Tuple (b, b_hat), where b contains the original information bits, shape [batch_size, num_ue, tb_size], and b_hat contains the detected bits with the same shape

Return type:

tf.Tensor or tuple

Notes

JIT compilation is disabled (jit_compile=False) because the neural detector uses dynamic shapes and control flow that are incompatible with XLA compilation.

The processing flow follows standard PUSCH transmission:

  1. Transmitter: Bit generation, encoding, constellation mapping with learnable labeling, layer mapping, resource grid mapping, precoding (if enabled), OFDM modulation (if time domain)

  2. Channel: Random sampling from pre-loaded CIR tensors (autoencoder) or on-demand CIR generation (baseline), CIR-to-OFDM conversion, AWGN addition

  3. Receiver: OFDM demodulation (if time domain), channel estimation (perfect or LS), neural MIMO detection with constellation synchronization, layer demapping, transport block decoding (inference)

  4. Loss computation (training only): BCE loss on soft LLRs plus minimum distance regularization to prevent constellation collapse
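The loss in step 4 can be sketched as follows (NumPy illustration; the demo computes this in TensorFlow, the LLR sign convention is assumed, and `reg_weight` is an assumed hyperparameter name):

```python
import numpy as np

def bce_from_llr(c, llr):
    """Binary cross-entropy of coded bits c given LLRs.

    Assumed convention: llr = log p(c=1)/p(c=0), i.e. p(c=1) = sigmoid(llr).
    Uses the numerically stable BCE-with-logits form.
    """
    return np.mean(np.maximum(llr, 0) - llr * c + np.log1p(np.exp(-np.abs(llr))))

def min_distance_penalty(points):
    """Penalize small pairwise distances to prevent constellation collapse."""
    d = np.abs(points[:, None] - points[None, :])
    d = d + np.eye(len(points)) * 1e9   # mask self-distances
    return -d.min()                      # push the minimum distance up

rng = np.random.default_rng(0)
c = rng.integers(0, 2, size=1024).astype(float)   # coded bits
llr = rng.standard_normal(1024)                   # detector LLRs
points = np.exp(2j * np.pi * np.arange(16) / 16)  # example 16-point constellation

reg_weight = 0.01  # assumed regularization weight
loss = bce_from_llr(c, llr) + reg_weight * min_distance_penalty(points)
```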