pusch_autoencoder¶
This module implements an end-to-end autoencoder for the 5G NR Physical Uplink Shared Channel (PUSCH).
Configuration¶
- class demos.pusch_autoencoder.src.config.Config(num_bs_ant=32)[source]¶
Bases: object

Central configuration for the PUSCH-Autoencoder demo.

This dataclass consolidates all hard-coded system parameters to ensure consistency across CIR generation, model instantiation, training, and inference. Parameters are intentionally non-configurable at instantiation (init=False) to enforce a single, validated system configuration. The number of BS antennas (num_bs_ant) is configurable to allow flexibility in array size.

The configuration defines a MU-MIMO uplink scenario with:
4 UEs, each with 4 antennas (cross-polarized)
1 BS with configurable antennas (default 16, cross-polarized)
16-QAM modulation (MCS index 14, table 1)
Site- (Munich-) specific ray-traced channel
- Parameters:
num_bs_ant (int, optional) – Number of BS antennas (cross-polarized array). Must be even since the antenna array uses cross-polarization (num_cols = num_bs_ant // 2). Default is 16.
Notes
After __post_init__, num_bits_per_symbol and target_coderate are populated from the 3GPP NR MCS tables. All properties return consistent, validated values.
System dimensions (antennas, UEs) remain fixed after instantiation.
PUSCH-related properties (resource_grid, pusch_pilot_indices, etc.) are set externally by PUSCHLinkE2E after transmitter construction.
Example
>>> cfg = Config()
>>> print(cfg.num_bits_per_symbol)  # 4 for 16-QAM
4
>>> print(cfg.num_ue, cfg.num_bs_ant)
4 16
>>> cfg_32ant = Config(num_bs_ant=32)
>>> print(cfg_32ant.num_bs_ant)
32
Notes
The MCS decoder call in __post_init__ uses transform_precoding=True and pi2bpsk=False, which is appropriate for standard PUSCH without DFT-s-OFDM transform precoding at the physical layer.
- property resource_grid¶
Sionna OFDM resource grid object.
Set externally by PUSCHLinkE2E after transmitter construction. Used by the neural detector to determine grid dimensions.
- Type:
ResourceGrid
- property pusch_pilot_indices¶
OFDM symbol indices containing DMRS pilots.
Set externally by PUSCHLinkE2E from PUSCHConfig.dmrs_symbol_indices. The neural detector uses this to separate pilot and data processing.
- Type:
List[int]
- property pusch_num_subcarriers¶
Number of subcarriers in the PUSCH allocation.
Set externally by PUSCHLinkE2E from PUSCHConfig.num_subcarriers.
- Type:
int
Channel Impulse Response¶
- class demos.pusch_autoencoder.src.cir_generator.CIRGenerator(a, tau, num_tx)[source]¶
Bases: object

Infinite generator for random CIR sample selection.

This class wraps a pre-loaded CIR dataset and provides an infinite generator interface compatible with Sionna’s CIRDataset. Each call to the generator yields a new random combination of UE channels, simulating the scenario where different UEs are co-scheduled in each transmission slot.
- Parameters:
a (tf.Tensor or np.ndarray, complex64) – CIR path coefficients from ray-tracing. Shape: [num_samples, 1, num_bs_ant, 1, num_ue_ant, max_paths, num_time_steps]
tau (tf.Tensor or np.ndarray, float32) – Path delays in seconds. Shape: [num_samples, 1, 1, max_paths]
num_tx (int) – Number of transmitters (UEs) to sample for each yield. Typically equals num_ue in the system configuration.
Example
>>> a, tau = manager.load_from_tfrecord()  # [5000, 1, 16, 1, 4, 51, 14]
>>> gen = CIRGenerator(a, tau, num_tx=4)
>>> for a_sample, tau_sample in gen():
...     print(a_sample.shape)  # [16, 4, 4, 51, 14]
...     break
Notes
The generator uses tf.random.uniform_candidate_sampler for efficient random selection without replacement. This is more efficient than tf.random.shuffle for large datasets when only selecting a small subset.

The dimension transpositions in __call__ reorder from the storage format (sample-first) to Sionna’s expected format (antenna-first):

Input a: [num_tx, 1, num_bs_ant, 1, num_ue_ant, paths, time]
Output a: [num_bs_ant, num_tx, num_ue_ant, paths, time]
This reordering places the BS antenna dimension first, followed by the UE (TX) dimension, matching Sionna’s channel tensor conventions.
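The selection-and-reorder step can be illustrated with a NumPy sketch. This is a hypothetical stand-in for the TF-based generator (which uses `tf.random.uniform_candidate_sampler`), showing only the indexing and transposition logic described above:

```python
import numpy as np

def sample_and_reorder(a, tau, num_tx, rng):
    """Pick `num_tx` random CIR samples without replacement and reorder
    them to Sionna's antenna-first convention (NumPy stand-in for the
    TF-based generator)."""
    idx = rng.choice(a.shape[0], size=num_tx, replace=False)
    # Storage format: [num_tx, 1, num_bs_ant, 1, num_ue_ant, paths, time]
    a_sel, tau_sel = a[idx], tau[idx]
    # Drop the singleton axes and move BS antennas first:
    # a -> [num_bs_ant, num_tx, num_ue_ant, paths, time]
    a_out = np.transpose(a_sel[:, 0, :, 0], (1, 0, 2, 3, 4))
    # tau -> [num_tx, paths]
    tau_out = tau_sel[:, 0, 0, :]
    return a_out, tau_out
```

With the example shapes above ([5000, 1, 16, 1, 4, 51, 14]), each yield produces an `a` of shape [16, 4, 4, 51, 14], matching the doctest.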
- class demos.pusch_autoencoder.src.cir_manager.CIRManager(config=None)[source]¶
Bases: object

Unified manager for CIR generation, storage, and loading.
This class encapsulates the complete CIR data pipeline for the PUSCH autoencoder demo, including:
Munich scene setup with configurable antenna arrays
Radio map computation for coverage-aware UE position sampling
Ray-traced CIR generation with multipath propagation
TFRecord serialization for efficient data storage
MU-MIMO sample grouping for multi-user training
- Parameters:
config (Config, optional) – Configuration object with system parameters. If None, uses the default Config() with standard MU-MIMO settings.

Pre-conditions
--------------
- Sionna must be installed with ray-tracing support (sionna.rt).
- GPU with sufficient memory for ray-tracing (recommended).
- Munich scene assets must be available (bundled with Sionna).

Post-conditions
---------------
- After setup_scene(), the configured scene is available.
- After compute_radio_map(), the radio map is cached for UE position sampling.
- After generate_cir_data(), CIR coefficients and delays are available.

Invariants
----------
- Config parameters (antenna counts, bandwidths) are fixed after init.
- Scene geometry (TX position, array orientations) is fixed.
Example
>>> # Generation workflow (offline, run once)
>>> manager = CIRManager()
>>> manager.generate_and_save([0, 1, 2], tfrecord_dir="../cir_tfrecords")
>>> # Loading workflow (online, run each training)
>>> manager = CIRManager()
>>> a, tau = manager.load_from_tfrecord(group_for_mumimo=True)
>>> model = PUSCHLinkE2E((a, tau), ...)
Notes
The CIR generation process is computationally expensive but only needs to run once. Generated TFRecords can be reused across many training runs with different hyperparameters.
The group_for_mumimo option in load_from_tfrecord() combines num_ue individual CIRs into single MU-MIMO samples, simulating co-scheduled uplink transmissions from multiple UEs.
- setup_scene()[source]¶
Initialize Munich scene with BS and UE antenna configurations.
This method loads the Munich urban scene and configures:
BS array: 16x2 cross-polarized panel (32 elements) on a rooftop
UE array: 2x2 cross-polarized array (4 elements) with isotropic pattern
Camera: Overhead view for radio map visualization
- Returns:
scene – Configured scene object ready for ray-tracing.
- Return type:
sionna.rt.Scene
Notes
The BS is positioned at [8.5, 21, 27] meters, which places it on a building rooftop in the Munich scene. The look_at direction points toward the main coverage area where UEs will be sampled.

The BS uses the tr38901 antenna pattern (3GPP sector antenna) while UEs use the iso (isotropic) pattern, reflecting realistic deployments where BS antennas are directional but UE antennas are omnidirectional.
- compute_radio_map(save_images=True)[source]¶
Compute path gain radio map for coverage-based UE sampling.
The radio map provides spatially-resolved path gain information used to sample UE positions in areas with valid coverage (avoiding dead zones and extreme near-field regions).
- Parameters:
save_images (bool) – If True, saves the radio map visualization to a PNG file.
- Returns:
rm – Radio map object with path gain values per spatial cell.
- Return type:
sionna.rt.RadioMap
Notes
Radio map computation uses Monte Carlo ray-tracing with 10^7 samples per TX, which takes several minutes but provides smooth coverage maps. The result is cached in self.rm for subsequent UE position sampling.
- generate_cir_data(seed_offset=0, max_num_paths=0)[source]¶
Generate ray-traced CIR data for multiple UE positions.
This method performs the core ray-tracing loop:
Sample UE positions from radio map (coverage-aware)
Place receiver objects at sampled positions
Run path solver to compute multipath propagation
Extract CIR (path gains and delays) for each link
- Parameters:
seed_offset (int, optional) – Random seed offset for UE position sampling. Default 0.
max_num_paths (int, optional) – Maximum number of paths; with the default of 0, it is determined from the generated samples.
- Returns:
a (np.ndarray, complex64) – CIR coefficients with shape [num_samples, 1, num_bs_ant, 1, num_ue_ant, max_paths, num_time_steps]
tau (np.ndarray, float32) – Path delays with shape [num_samples, 1, 1, max_paths]
max_num_paths (int) – Maximum number of paths across all generated samples.
Notes
The generation loop processes batch_size_cir positions at a time, updating receiver positions and re-running the path solver. Progress is printed to the console since this can take hours for large datasets.

Empty CIRs (zero total power) are filtered out, as they represent positions with no valid propagation paths (e.g., inside buildings).

The seed formula idx + seed_offset * 1000 ensures that different files (with different seed offsets) produce non-overlapping position samples.
- save_to_tfrecord(a, tau, filename)[source]¶
Serialize CIR data to TFRecord format for efficient storage.
TFRecord provides efficient sequential reading, compression, and seamless integration with tf.data pipelines for training.
- Parameters:
a (np.ndarray, complex64) – CIR coefficients, shape [num_samples, ...].
tau (np.ndarray, float32) – Path delays, shape [num_samples, ...].
filename (str) – Output TFRecord file path.
Notes
Each sample is stored as a separate TFRecord Example containing:

a: serialized complex64 tensor (path gains)
tau: serialized float32 tensor (path delays)
a_shape: shape metadata for deserialization
tau_shape: shape metadata for deserialization
The shape metadata enables reconstruction of tensors with varying path counts across different TFRecord files.
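The shape-metadata idea can be sketched without TensorFlow. This hypothetical NumPy round-trip shows why storing the shape alongside the raw bytes allows reconstruction when path counts vary; the actual module wraps equivalent fields in a tf.train.Example inside a TFRecord:

```python
import numpy as np

def serialize_with_shape(arr):
    """Store raw bytes plus shape metadata (NumPy sketch of the
    a / a_shape field pairing used in the TFRecord Example)."""
    return {
        "data": arr.tobytes(),
        "shape": np.asarray(arr.shape, dtype=np.int64).tobytes(),
        "dtype": arr.dtype.str,
    }

def deserialize_with_shape(record):
    """Rebuild the tensor from its bytes and stored shape, which is what
    permits varying path counts across different files."""
    shape = np.frombuffer(record["shape"], dtype=np.int64)
    return np.frombuffer(record["data"], dtype=record["dtype"]).reshape(shape)
```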
- load_from_tfrecord(tfrecord_dir=None, group_for_mumimo=False)[source]¶
Load CIR data from TFRecord files with optional MU-MIMO grouping.
- Parameters:
tfrecord_dir (str, optional) – Directory containing TFRecord files, relative to this module. If not provided, defaults to cir_tfrecords_ant{num_bs_ant} within the demo directory.
group_for_mumimo (bool) – If True, groups num_ue individual CIRs into MU-MIMO samples. This simulates co-scheduled uplink transmissions.
- Returns:
all_a (tf.Tensor, complex64) – CIR coefficients. Shape depends on group_for_mumimo:

False: [num_samples, 1, num_bs_ant, 1, num_ue_ant, max_paths, num_time_steps]
True: [num_mu_samples, 1, num_bs_ant, num_ue, num_ue_ant, max_paths, num_time_steps]

all_tau (tf.Tensor, float32) – Path delays. Shape depends on group_for_mumimo:

False: [num_samples, 1, 1, max_paths]
True: [num_mu_samples, 1, num_ue, max_paths]
Notes
When group_for_mumimo=True, consecutive samples are combined to form multi-user scenarios. For example, with num_ue=4, samples 0-3 become MU-MIMO sample 0, samples 4-7 become sample 1, etc.

This grouping is valid because each original sample represents an independent UE position, and combining them simulates the realistic scenario where multiple UEs transmit simultaneously.
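The consecutive-sample grouping can be sketched in NumPy. This is an illustrative reshape consistent with the shapes documented above, not the module's actual implementation:

```python
import numpy as np

def group_for_mumimo(a, tau, num_ue):
    """Sketch: samples 0..num_ue-1 form MU-MIMO sample 0, and so on.
    Shapes follow the load_from_tfrecord docstring."""
    n = (a.shape[0] // num_ue) * num_ue          # drop any remainder
    num_mu = n // num_ue
    # a: [n, 1, num_bs_ant, 1, num_ue_ant, paths, time]
    a_g = a[:n].reshape(num_mu, num_ue, *a.shape[1:])
    # Move the new UE axis into the singleton TX slot:
    a_g = np.moveaxis(a_g, 1, 4)
    # -> [num_mu, 1, num_bs_ant, num_ue, num_ue_ant, paths, time]
    a_g = np.squeeze(a_g, axis=3)
    # tau: [n, 1, 1, paths] -> [num_mu, 1, num_ue, paths]
    tau_g = tau[:n, 0, 0].reshape(num_mu, 1, num_ue, -1)
    return a_g, tau_g
```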
- build_channel_model(batch_size=None, num_bs=None, num_bs_ant=None, num_ue=None, num_ue_ant=None, num_time_steps=None, tfrecord_dir=None)[source]¶
Build CIRDataset channel model from TFRecord files.
This method creates a Sionna CIRDataset that can be used with OFDMChannel for baseline BER/BLER evaluation. The dataset provides on-demand CIR sampling during simulation.
- Parameters:
batch_size (int, optional) – Batch size for dataset. Default from config.
num_bs (int, optional) – Number of base stations. Default from config.
num_bs_ant (int, optional) – BS antenna count. Default from config.
num_ue (int, optional) – Number of UEs. Default from config.
num_ue_ant (int, optional) – UE antenna count. Default from config.
num_time_steps (int, optional) – OFDM symbols per slot. Default from config.
tfrecord_dir (str, optional) – Directory containing TFRecord files. If not provided, defaults to cir_tfrecords_ant{num_bs_ant} within the demo directory.
- Returns:
channel_model – Channel model compatible with Sionna’s OFDMChannel.
- Return type:
CIRDataset
Notes
This method is primarily used for baseline evaluation where a CIRDataset object is needed. For autoencoder training, use load_from_tfrecord(group_for_mumimo=True) directly to get tensors that can be indexed during training.
- save_visualization_ue_positions(filename='munich_ue_positions.png')[source]¶
Render radio map with current UE positions overlaid.
- Parameters:
filename (str) – Output PNG file path.
- Raises:
RuntimeError – If scene, radio map, or camera are not initialized.
- generate_and_save(seed_offsets, tfrecord_dir=None, save_radio_map=True)[source]¶
Complete CIR generation pipeline: generate, visualize, and save.
This is the main entry point for offline CIR dataset creation. It handles scene setup, radio map computation, CIR generation, and TFRecord serialization in a single call.
- Parameters:
seed_offsets (int or list of int) – Random seed offset(s) for UE position sampling. Each seed produces a separate TFRecord file.
tfrecord_dir (str, optional) – Output directory for TFRecord files (relative to this module). If not provided, defaults to cir_tfrecords_ant{num_bs_ant} within the demo directory.
save_radio_map (bool) – If True, saves radio map and UE position visualizations.
Example
>>> manager = CIRManager()
>>> # Generate 3 files with 5000 CIRs each (15000 total)
>>> manager.generate_and_save([0, 1, 2])
Notes
Multiple seed offsets produce independent position samples, which can be beneficial for:
Increasing total dataset size beyond single-file limits
Enabling parallel generation on multiple machines
Creating train/validation/test splits with different seeds
The output directory includes the antenna count suffix to keep CIR data for different antenna configurations separate.
Trainable Transmitter¶
- class demos.pusch_autoencoder.src.pusch_trainable_transmitter.PUSCHTrainableTransmitter(*args, **kwargs)[source]¶
Bases: PUSCHTransmitter

PUSCH Transmitter with SYMMETRIC trainable constellation geometry AND labeling.
This subclass enforces 4-fold symmetry (I-axis, Q-axis, origin) in the constellation while supporting learnable labeling for autoencoder-based communication system design.
- Parameters:
*args (tuple) – Positional arguments passed to PUSCHTransmitter.
training (bool) – If True, constellation and labeling are trainable with soft Gumbel-Softmax assignment. Default False.
gumbel_temperature (float) – Temperature for Gumbel-Softmax. Lower = sharper (more discrete). Default 0.5.
**kwargs (dict) – Keyword arguments passed to PUSCHTransmitter.
Example
>>> tx = PUSCHTrainableTransmitter(pusch_configs, output_domain="freq",
...                                training=True, gumbel_temperature=0.5)
>>> x_map, x, b, c = tx(batch_size=32)
Notes
The symmetric constellation is stored as [num_base_points] complex values (typically 4 for 16-QAM), and the full constellation is computed via reflections. This guarantees 4-fold symmetry is preserved during training.

The learnable labeling uses a permutation logits matrix of shape [num_points, num_points]. Row i contains logits for which constellation point should be assigned to bit pattern i.
- static generate_random_symmetric_constellation(num_points, seed=None)[source]¶
Generate random constellation with 4-fold symmetry and unit average energy.
Creates a constellation by:

1. Generating num_points/4 random base points in the first quadrant
2. Reflecting them across the I-axis, Q-axis, and origin
3. Normalizing to unit average power
- Parameters:
num_points (int) – Total number of constellation points; must be divisible by 4.
seed (int, optional) – Random seed for reproducible generation.
- Returns:
Symmetric constellation points with shape [num_points]. Points are ordered: [Q1_points, Q2_points, Q4_points, Q3_points]
- Return type:
tf.Tensor, complex64
Notes
The generation process:
Sample num_points/4 base points in Q1 (I>0, Q>0)
Reflect across Q-axis: negate real part -> Q2 (I<0, Q>0)
Reflect across I-axis: negate imag part -> Q4 (I>0, Q<0)
Reflect across origin: negate both parts -> Q3 (I<0, Q<0)
Normalize complete constellation to unit average power
Example
>>> points = PUSCHTrainableTransmitter.generate_random_symmetric_constellation(16, seed=42)
>>> print(f"Power: {tf.reduce_mean(tf.abs(points)**2):.6f}")  # Should be ~1.0
>>> # Verify I-axis symmetry (Q1 vs Q4)
>>> print(f"I-sym check: {tf.reduce_mean(tf.abs(points[0] - tf.math.conj(points[8]))):.6f}")  # ~0
- property trainable_variables¶
Return all trainable variables: base geometry + labeling.
- Returns:
Three-element list: [_base_points_r, _base_points_i, _labeling_logits] where:
_base_points_r: Real parts of Q1 base points
_base_points_i: Imaginary parts of Q1 base points
_labeling_logits: Permutation logits for bit-to-point mapping
- Return type:
list of tf.Variable
- property geometry_variables¶
Return only the base geometry (Q1 constellation points) variables.
- property labeling_variables¶
Return only the labeling (permutation) variable.
- get_base_points()[source]¶
Get the base constellation points (Q1 only, not normalized).
- Returns:
Base constellation points in the first quadrant, shape [num_base_points]. For 16-QAM, this is [4] complex values.
- Return type:
tf.Tensor, complex64
- get_normalized_constellation()[source]¶
Compute 4-fold symmetric and power-normalized constellation points.
- Returns:
Normalized constellation points with 4-fold symmetry, shape [num_points]. For 16-QAM, this is [16] complex values.
- Return type:
tf.Tensor, complex64
Notes
The normalization process:
Retrieve base points from trainable variables (Q1 only)
Reflect to create full 4-fold symmetric constellation
Normalize to unit average power
The normalization is differentiable and does NOT break symmetry since it’s a uniform scaling operation applied to all points equally.
Symmetry properties maintained:
I-axis: constellation[i] = conj(constellation[j]) for mirrored i, j
Q-axis: constellation[i] = -conj(constellation[k]) for mirrored i, k
Origin: constellation[i] = -constellation[m] for opposite i, m
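The reflect-then-normalize construction can be sketched in NumPy. This illustrates the symmetry properties and the [Q1, Q2, Q4, Q3] ordering from the class docs; it is not the TF implementation itself:

```python
import numpy as np

def symmetric_constellation(base_q1):
    """Build a 4-fold symmetric, unit-power constellation from Q1 base
    points (NumPy sketch of get_normalized_constellation)."""
    q1 = np.asarray(base_q1, dtype=np.complex64)
    q2 = -np.conj(q1)   # negate real part -> Q2 (I<0, Q>0)
    q4 = np.conj(q1)    # negate imag part -> Q4 (I>0, Q<0)
    q3 = -q1            # negate both      -> Q3 (I<0, Q<0)
    pts = np.concatenate([q1, q2, q4, q3])
    # Uniform scaling preserves all three symmetries
    return pts / np.sqrt(np.mean(np.abs(pts) ** 2))
```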
- get_soft_labeling_matrix(hard=False)[source]¶
Get the current labeling matrix (permutation of constellation points).
- Parameters:
hard (bool) – If True, return hard one-hot assignment. If False, return soft Gumbel-Softmax probabilities.
- Returns:
Labeling matrix, shape [num_points, num_points]. Row i indicates which constellation point(s) are assigned to bit pattern i.
- Return type:
tf.Tensor, float
Example
For hard assignment, row i will have a single 1 at column j, meaning bit pattern i maps to constellation point j.
For soft assignment during training, row i contains probabilities summing to 1, enabling gradient flow through the labeling.
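The soft/hard behavior can be sketched with a NumPy Gumbel-Softmax. This is an illustrative version only: the actual layer is implemented in TF and typically pairs the hard path with a straight-through gradient estimator:

```python
import numpy as np

def soft_labeling_matrix(logits, temperature=0.5, hard=False, rng=None):
    """Row-wise Gumbel-Softmax over permutation logits (NumPy sketch)."""
    rng = rng if rng is not None else np.random.default_rng()
    gumbel = -np.log(-np.log(rng.uniform(1e-9, 1.0, size=logits.shape)))
    z = (logits + gumbel) / temperature
    z = z - z.max(axis=-1, keepdims=True)        # numerically stable softmax
    soft = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
    if hard:
        # One-hot row: bit pattern i maps to exactly one point
        return np.eye(logits.shape[-1])[soft.argmax(axis=-1)]
    return soft
```

Lower temperatures sharpen the rows toward one-hot, which is why the class exposes gumbel_temperature as a tuning knob.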
- call(inputs)[source]¶
Execute transmitter processing chain with trainable constellation and labeling.
The symmetric constellation is computed from base points before each forward pass.
- Parameters:
inputs (int or tf.Tensor) – If return_bits=True: an int specifying the batch size. If return_bits=False: a tf.Tensor of input bits.
- Returns:
If return_bits=True: tuple (x_map, x, b, c). If return_bits=False: just x (the transmitted signal).
- Return type:
tuple or tf.Tensor
Notes
The processing chain:
Compute full 16-point symmetric constellation from 4 base points
Standard PUSCH transmission (bits -> symbols -> OFDM)
Symbol mapping uses learnable labeling via Gumbel-Softmax
Neural Detector¶
- class demos.pusch_autoencoder.src.pusch_neural_detector.Conv2DResBlock(*args, **kwargs)[source]¶
Bases: Layer

Pre-activation residual block with two convolutional layers.
Implements the pre-activation ResNet variant where normalization and activation precede each convolution. This ordering improves gradient flow and enables training of deeper networks.
The block computes:
output = input + conv2(act(norm(conv1(act(norm(input))))))
- Parameters:
Notes
Uses LayerNormalization (not BatchNorm) for stable training with small batch sizes typical in communication system simulation.
LeakyReLU with α=0.1 prevents dead neurons while maintaining near-linear behavior for small negative inputs.
Identity skip connection requires input and output channel counts to match; the caller must ensure filters equals the input channel count.
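The pre-activation ordering can be illustrated with 1×1 "convolutions" (channel-mixing matmuls) standing in for Conv2D. This is a minimal NumPy sketch of the formula above, not the Keras layer itself:

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # Normalize over the channel (last) axis, as LayerNormalization does
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def leaky_relu(x, alpha=0.1):
    return np.where(x > 0, x, alpha * x)

def pre_act_res_block(x, w1, w2):
    """output = input + conv2(act(norm(conv1(act(norm(input))))))
    with 1x1 convs realized as matmuls over the channel axis."""
    h = leaky_relu(layer_norm(x)) @ w1
    h = leaky_relu(layer_norm(h)) @ w2
    return x + h   # identity skip: w2 must map back to x's channel count
```

With both weight matrices zero, the block reduces to the identity, which is the property that makes residual stacks easy to train.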
- class demos.pusch_autoencoder.src.pusch_neural_detector.PUSCHNeuralDetector(*args, **kwargs)[source]¶
Bases: Layer

Neural MIMO detector with learned channel estimation refinement for PUSCH.
This detector implements a residual learning architecture that refines classical LS channel estimates and LMMSE-based soft symbol estimates using convolutional neural networks. The key design principle is that learning corrections to strong classical baselines is more effective than learning detection from scratch.
Architecture¶
The detector processes data through six stages:
Feature extraction: Assembles input features including LS channel estimate, received signal, matched filter output, Gram matrix structure, estimation error variance, and noise level.
Shared backbone: Processes features through convolutional ResBlocks to learn joint representations useful for both CE refinement and detection.
CE refinement head: Predicts additive corrections Δh and multiplicative log-domain corrections Δlog(err_var) to the LS estimates.
Scaled correction application: Applies learned corrections scaled by trainable parameters that start at zero, enabling gradual departure from classical behavior during training.
Classical LMMSE: Performs LMMSE equalization using refined channel estimate and error variance, followed by max-log demapping.
LLR refinement: Injects LMMSE outputs back into the network to predict additive LLR corrections, again scaled by a trainable parameter.
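The classical LMMSE stage (stage 5) can be sketched per resource element. This NumPy illustration treats the channel estimation error as extra white noise, a common simplification; the detector's exact whitened-interference model may differ:

```python
import numpy as np

def lmmse_detect(y, h_hat, err_var, no):
    """Per-RE LMMSE sketch with unit-energy symbols:
    x_hat = (H^H H + (no + err_var) I)^{-1} H^H y"""
    num_streams = h_hat.shape[1]
    g = h_hat.conj().T @ h_hat + (no + err_var) * np.eye(num_streams)
    return np.linalg.solve(g, h_hat.conj().T @ y)
```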
- Parameters:
cfg (Config) – System configuration containing MIMO dimensions, modulation order, and PUSCH resource grid information.
num_conv2d_filters (int) – Number of filters in all convolutional layers. Higher values increase model capacity but also computational cost. Default 128.
num_shared_res_blocks (int) – Number of ResBlocks in the shared backbone. Controls the depth of joint feature learning. Default 4.
num_det_res_blocks (int) – Number of ResBlocks in the detection continuation path. Controls capacity for LLR refinement. Default 6.
kernel_size (tuple of int) – Spatial kernel size for ResBlock convolutions. Default (3, 3).

Pre-conditions
--------------
- cfg.pusch_pilot_indices must be set (by PUSCHLinkE2E) before instantiation to enable pilot/data symbol separation.
- Input tensors must follow Sionna’s PUSCH dimension conventions.

Post-conditions
---------------
- trainable_variables returns correction scales first, then network weights (enables separate optimizer configuration).
- last_h_hat_refined and last_err_var_refined contain the most recent refined estimates (useful for auxiliary losses).

Invariants
----------
- Correction scales are initialized to 0.0, so initial behavior matches classical LMMSE exactly (graceful degradation if training fails).
- The error variance correction scale is always positive (softplus transform).
- Output LLR shape matches Sionna’s PUSCHReceiver expectations.
Example
>>> cfg = Config()
>>> # ... set cfg.pusch_pilot_indices via PUSCHLinkE2E ...
>>> detector = PUSCHNeuralDetector(cfg, num_conv2d_filters=64)
>>> llr = detector(y, h_hat, err_var, no)
Notes
The trainable correction scales serve multiple purposes:
Safe initialization: Starting at 0.0 means the detector initially behaves exactly like classical LMMSE, providing a stable starting point.
Interpretability: Scale magnitudes indicate how much the network deviates from classical processing (useful for debugging).
Gradient balancing: Separate scales for h, err_var, and LLR allow independent learning rates for different correction types.
The error variance scale uses softplus to ensure positivity, since negative variance would be physically meaningless and cause numerical issues in LMMSE computation.
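The scaled-correction idea can be sketched in NumPy. The exact wiring is internal to the layer; this hypothetical version shows the additive channel update and the multiplicative log-domain error variance update, with softplus keeping the variance scale positive:

```python
import numpy as np

def softplus(x):
    return np.logaddexp(0.0, x)   # log(1 + exp(x)), numerically stable

def refine_estimates(h_ls, err_var_ls, delta_h, delta_log_ev,
                     scale_h, scale_ev_raw):
    """Apply learned corrections scaled by trainable parameters
    (NumPy sketch of the CE refinement stage)."""
    h_ref = h_ls + scale_h * delta_h
    err_var_ref = err_var_ls * np.exp(softplus(scale_ev_raw) * delta_log_ev)
    return h_ref, err_var_ref
```

With scale_h = 0 the channel estimate passes through unchanged, and the exponential update can never drive the variance negative.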
- property err_var_correction_scale¶
Effective error variance correction scale (always positive).
- Returns:
Scalar tensor with softplus-transformed scale value. softplus(x) = log(1 + exp(x)) ensures output > 0.
- Return type:
tf.Tensor
- property trainable_variables¶
Collect all trainable variables in a specific order.
The ordering places correction scales first, enabling separate optimizer configuration (e.g., higher learning rate for scales).
- Returns:
Ordered list: [correction_scales, backbone_weights, ce_head_weights, detection_weights].
- Return type:
list of tf.Variable
- call(y, h_hat, err_var, no, constellation=None, training=None)[source]¶
Execute neural MIMO detection with channel estimation refinement.
- Parameters:
y (tf.Tensor, complex64) – Received OFDM signal after FFT. Shape: [B, num_bs, num_bs_ant, num_ofdm_syms, num_subcarriers]
h_hat (tf.Tensor, complex64) – LS channel estimate from DMRS processing. Shape: [B, num_bs, num_bs_ant, num_ue, num_streams, num_ofdm_syms, num_subcarriers]
err_var (tf.Tensor, float32) – Channel estimation error variance per stream and RE. Shape: [B, 1, 1, num_ue, num_streams, num_ofdm_syms, num_subcarriers]
no (tf.Tensor, float32) – Noise variance, shape [B] or scalar.
constellation (tf.Tensor, complex64, optional) – Trainable constellation points from the transmitter. If provided, updates the demapper’s constellation for consistent symbol mapping. Shape: [num_constellation_points]
training (bool, optional) – Training mode flag (currently unused, for Keras API compatibility).
- Returns:
Log-likelihood ratios for all coded bits. Shape: [B, num_ue, num_streams_per_ue, num_data_symbols * num_bits_per_symbol]
- Return type:
tf.Tensor, float32
Notes
The processing flow:
Feature assembly: Extracts and concatenates multiple signal representations (channel, received signal, matched filter, Gram matrix structure, estimation error, noise level).
Shared backbone: Processes features through ResBlocks to learn joint representations.
CE refinement: Predicts channel and error variance corrections, applies them scaled by trainable parameters.
LMMSE equalization: Classical equalization with refined estimates, using whitened interference model for improved performance.
LLR refinement: Predicts additive corrections to LMMSE LLRs, enabling learned compensation for model mismatch.
Channel estimation refinement operates on ALL OFDM symbols (including pilots) to leverage full spatial-frequency context, while detection operates only on DATA symbols to produce the final LLRs.
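Two of the assembled input features, the matched-filter output and the Gram matrix, can be sketched for a single resource element. This is an illustrative NumPy computation consistent with the feature list above, not the detector's internal code:

```python
import numpy as np

def mf_and_gram_features(y, h):
    """Matched-filter output h^H y and Gram matrix h^H h for one RE
    (NumPy sketch of two assembled detector features)."""
    # y: [num_bs_ant], h: [num_bs_ant, num_streams]
    mf = h.conj().T @ y        # [num_streams]
    gram = h.conj().T @ h      # [num_streams, num_streams], Hermitian
    return mf, gram
```

The Gram matrix exposes the inter-stream interference structure, which is why it is a useful network input alongside the raw received signal.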
Trainable Receiver¶
- class demos.pusch_autoencoder.src.pusch_trainable_receiver.PUSCHTrainableReceiver(*args, **kwargs)[source]¶
Bases: PUSCHReceiver

PUSCH Receiver variant for autoencoder training with neural detection.

This class extends PUSCHReceiver to support end-to-end training by:

Returning soft LLRs before TB decoding when in training mode
Passing the trainable constellation to the neural detector for consistent symbol demapping
Exposing the neural detector’s trainable variables
The receiver supports both perfect CSI (ground-truth channel provided) and imperfect CSI (LS estimation with neural refinement) scenarios.
- Parameters:
*args (tuple) – Positional arguments passed to PUSCHReceiver.
training (bool) – If True, call() returns LLRs without TB decoding. If False, call() returns decoded bits. Default False.
pusch_transmitter (PUSCHTrainableTransmitter, optional) – Reference to the transmitter for constellation synchronization. Required for proper demapping when the constellation is trainable.
**kwargs (dict) – Keyword arguments passed to PUSCHReceiver (e.g., mimo_detector, input_domain, channel_estimator).
Example
>>> # Training setup
>>> detector = PUSCHNeuralDetector(cfg)
>>> rx = PUSCHTrainableReceiver(
...     mimo_detector=detector,
...     input_domain="freq",
...     pusch_transmitter=tx,
...     training=True
... )
>>> llr = rx(y, no)  # Returns soft LLRs for BCE loss
>>> # Inference setup
>>> rx_eval = PUSCHTrainableReceiver(
...     mimo_detector=detector,
...     input_domain="freq",
...     pusch_transmitter=tx,
...     training=False
... )
>>> b_hat = rx_eval(y, no)  # Returns decoded bits for BER
Notes
The same constellation is used for mapping (TX) and demapping (RX) for autoencoder training. The _get_normalized_constellation() method retrieves the current (normalized) constellation from the transmitter at each forward pass, ensuring the demapper always uses the same points that were used for mapping, even as they evolve during training.
- property trainable_variables¶
Collect trainable variables from the neural MIMO detector.
- Returns:
Trainable variables from self._mimo_detector, or an empty list if the detector has no trainable variables (e.g., classical LMMSE).
list of tf.Variable
Notes
This property enables the optimizer to access detector weights without knowing the internal structure. The receiver itself has no trainable parameters; all learning happens in the neural detector.
- call(y, no, h=None)[source]¶
Execute receiver processing chain with optional training mode.
- Parameters:
y (tf.Tensor, complex64) – Received signal. Shape depends on input_domain:
"freq": [batch, num_rx, num_rx_ant, num_ofdm_symbols, num_subcarriers]
"time": [batch, num_rx, num_rx_ant, num_samples]
no (tf.Tensor, float32) – Noise variance, shape [batch] or scalar.
h (tf.Tensor, complex64, optional) – Ground-truth channel matrix for perfect CSI mode. Shape: [batch, num_rx, num_rx_ant, num_tx, num_tx_ant, num_ofdm_symbols, num_subcarriers]. Required if channel_estimator="perfect" was set in the constructor.
- Returns:
Training mode (training=True): LLRs for coded bits, shape [batch, num_ue, num_coded_bits]
Inference mode (training=False): Decoded information bits, shape [batch, num_ue, tb_size]
If return_tb_crc_status=True in inference mode, returns tuple (b_hat, tb_crc_status)
- Return type:
tf.Tensor
Notes
The processing chain follows standard PUSCH reception:
OFDM demodulation (if time domain): FFT and CP removal
Channel estimation: Perfect CSI passthrough or LS estimation
MIMO detection: Neural detector with constellation sync
Layer demapping: Separate streams back to UE data
TB decoding (inference only): LDPC decoding, CRC check
In training mode, step 5 is skipped because:

TB decoding is non-differentiable (hard decisions)
The BCE loss requires LLRs against the coded bits
The loss provides gradients to train the neural detector
The squeeze operation on LLRs (when the shape has a singleton dimension 2) handles the case where num_layers=1, ensuring a consistent output shape regardless of MIMO layer configuration.
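The BCE loss on LLRs can be sketched in NumPy. This assumes the convention llr = log p(c=1) - log p(c=0), so p(c=1) = sigmoid(llr); Sionna's LLR sign convention may be the opposite, in which case the llr sign flips:

```python
import numpy as np

def bce_from_llrs(c, llr):
    """Binary cross-entropy between coded bits c in {0, 1} and LLRs,
    in a numerically stable closed form (NumPy sketch)."""
    softplus = lambda x: np.logaddexp(0.0, x)   # log(1 + exp(x))
    # -log sigmoid(llr) = softplus(-llr); -log(1 - sigmoid(llr)) = softplus(llr)
    return np.mean(c * softplus(-llr) + (1 - c) * softplus(llr))
```

Confident, correct LLRs drive the loss toward zero; confident, wrong LLRs are penalized roughly linearly in |llr|.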
System¶
- class demos.pusch_autoencoder.src.system.PUSCHLinkE2E(*args, **kwargs)[source]¶
Bases: Model

End-to-end differentiable PUSCH link model for MU-MIMO autoencoder training.
This class simulates a complete 5G NR PUSCH uplink that can operate in two modes:
Baseline mode (use_autoencoder=False): Uses the standard QAM constellation with LS channel estimation + LMMSE equalization for BER/BLER benchmarking.
Autoencoder mode (use_autoencoder=True): Uses trainable constellation geometry AND labeling, plus a neural detector, enabling end-to-end optimization.
The model supports both perfect and imperfect CSI scenarios, where imperfect CSI uses LS channel estimation with optional neural refinement.
- Parameters:
channel_model (tuple or CIRDataset) – For baseline mode: a CIRDataset object for on-demand CIR generation. For autoencoder mode: a tuple (a, tau) of pre-loaded CIR tensors.
perfect_csi (bool) – If True, provides the ground-truth channel to the receiver. If False, the receiver performs LS channel estimation.
use_autoencoder (bool) – If True, uses the trainable transmitter and neural detector. If False, uses standard PUSCH TX/RX with LMMSE detection.
training (bool) – If True, call() returns the training loss (BCE + regularization). If False, call() returns (bits, bits_hat) for BER evaluation.
config (Config, optional) – System configuration. Defaults to Config() if not provided. Use this to customize system parameters like num_bs_ant.
gumbel_temperature (float, optional) – Initial Gumbel-Softmax temperature for learnable labeling. Lower values produce sharper (more discrete) assignments. Default 0.5.
Notes
For autoencoder mode, channel_model must be a tuple (a, tau) where a contains complex CIR coefficients with shape [num_samples, num_bs, num_bs_ant, num_ue, num_ue_ant, num_paths, num_time_steps] and tau contains path delays with shape [num_samples, num_bs, num_ue, num_paths]. For baseline mode, channel_model must be a valid CIRDataset.

The trainable transmitter now includes:
Learnable constellation geometry: Point positions can be optimized starting from random initialization to escape QAM local minima.
Learnable labeling: Bit-to-symbol assignment learned via Gumbel-Softmax enables optimization beyond fixed Gray coding.
This dual optimization unlocks the full geometric-shaping potential, which remains limited when fixed Gray labeling is used.
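As a concrete illustration of the documented CIR layout, the following sketch builds a (a, tau) tuple with the expected shapes. The dimensions are placeholder values and the entries are random, not real ray-traced CIRs:

```python
import numpy as np

# Placeholder dimensions; real values depend on the ray-traced scenario.
num_samples, num_bs, num_bs_ant = 8, 1, 16
num_ue, num_ue_ant, num_paths, num_time_steps = 4, 4, 20, 14

rng = np.random.default_rng(0)
cir_shape = (num_samples, num_bs, num_bs_ant, num_ue,
             num_ue_ant, num_paths, num_time_steps)

# a: complex CIR coefficients per path and time step
a = rng.standard_normal(cir_shape) + 1j * rng.standard_normal(cir_shape)
# tau: non-negative path delays in seconds
tau = np.abs(rng.standard_normal((num_samples, num_bs, num_ue, num_paths))) * 1e-7

channel_model = (a, tau)  # tuple form expected in autoencoder mode
print(a.shape, tau.shape)
```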
Additional notes:
self._cfg contains PUSCH resource grid information after construction.
self.trainable_variables returns all trainable weights (TX + RX).
In training mode, call() returns a scalar loss tensor.
In inference mode, call() returns (b, b_hat) bit tensors.
The PUSCH configuration (PRBs, MCS, layers) remains fixed after initialization.
Constellation normalization maintains unit average power.
Channel model type (tuple vs CIRDataset) determines internal processing path.
Example
>>> # Autoencoder training setup
>>> cir_manager = CIRManager()
>>> a, tau = cir_manager.load_from_tfrecord(group_for_mumimo=True)
>>> model = PUSCHLinkE2E((a, tau), perfect_csi=False,
...                      use_autoencoder=True, training=True)
>>> loss = model(batch_size=32, ebno_db=10.0)
>>> # Train with separate optimizers for geometry, labeling, and receiver
- property trainable_variables¶
Collect all trainable variables from transmitter and receiver.
- Returns:
Combined list of transmitter and receiver trainable variables. For autoencoder mode, this includes:
- Constellation geometry (real/imag coordinates)
- Labeling logits (permutation matrix)
- Neural detector weights (ResBlocks, correction scales)
- Return type:
list of tf.Variable
- property tx_geometry_variables¶
Get transmitter geometry (constellation point) variables.
- Returns:
Two-element list [_points_r, _points_i], for optimization separately from the labeling variables.
- Return type:
list of tf.Variable
- property tx_labeling_variables¶
Get transmitter labeling (permutation) variables.
- Returns:
One-element list [_labeling_logits], for optimization separately from the geometry variables.
- Return type:
list of tf.Variable
- property rx_variables¶
Get receiver trainable variables.
- Returns:
All trainable variables from the receiver (neural detector weights including correction scales and ResBlock parameters).
- Return type:
list of tf.Variable
- property constellation¶
Get the current normalized constellation points.
- Returns:
Complex tensor of shape [num_points] with unit average power. For 16-QAM, this contains 16 complex values.
- Return type:
tf.Tensor
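The unit-average-power constraint mentioned above can be sketched as a simple scaling. The points below are hypothetical, and the exact normalization inside the transmitter is assumed here:

```python
import numpy as np

# Arbitrary (hypothetical) learned constellation points
points = np.array([1.0 + 1.0j, -2.0 + 0.5j, 0.3 - 1.0j, -0.5 - 0.5j])

# Rescale so that the average symbol power is exactly 1
points = points / np.sqrt(np.mean(np.abs(points) ** 2))
print(np.isclose(np.mean(np.abs(points) ** 2), 1.0))
```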
- property gumbel_temperature¶
Get current Gumbel-Softmax temperature.
- Returns:
Current temperature value, or None if the transmitter does not support learnable labeling.
- Return type:
float or None
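The effect of the temperature can be illustrated with a small NumPy sketch of Gumbel-Softmax sampling. The logits and the exact sampling details inside the transmitter are assumptions for illustration:

```python
import numpy as np

def gumbel_softmax(logits, temperature, rng):
    """Soft, differentiable approximation of a one-hot assignment."""
    # Gumbel(0, 1) noise makes the argmax of (logits + g) a categorical sample
    g = -np.log(-np.log(rng.uniform(size=logits.shape)))
    z = (logits + g) / temperature
    z = z - z.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

num_points = 16  # e.g., 16-QAM labeling: 16 points x 16 candidate labels
logits = np.random.default_rng(42).standard_normal((num_points, num_points))

# Same noise realization, two temperatures
sharp = gumbel_softmax(logits, temperature=0.1, rng=np.random.default_rng(0))
smooth = gumbel_softmax(logits, temperature=5.0, rng=np.random.default_rng(0))

# Lower temperature -> rows concentrate near one-hot assignments
print(sharp.max(axis=-1).mean() > smooth.max(axis=-1).mean())
```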
- get_hard_labeling()[source]¶
Get the hard (argmax) labeling permutation from the transmitter.
- Returns:
Integer permutation indices if learnable labeling is enabled, otherwise None.
- Return type:
tf.Tensor or None
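Conceptually, the hard labeling corresponds to a row-wise argmax over the learned labeling logits. The logits below are hypothetical, and the real implementation may additionally resolve duplicate assignments to guarantee a valid permutation:

```python
import numpy as np

# Hypothetical labeling logits: rows = constellation points,
# columns = candidate bit labels
logits = np.array([[0.1, 2.0, -1.0],
                   [1.5, 0.0,  0.3],
                   [-0.2, 0.1, 3.0]])

hard_labeling = logits.argmax(axis=-1)
print(hard_labeling)  # -> [1 0 2]
```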
- call(batch_size, ebno_db)¶
Execute forward pass through the end-to-end PUSCH link.
- Parameters:
batch_size (int) – Number of transport blocks to simulate in parallel.
ebno_db (tf.Tensor) – Energy per bit to noise power spectral density ratio in dB. Can be a scalar (same SNR for all samples) or a vector of shape [batch_size].
- Returns:
Training mode: scalar loss tensor (BCE).
Inference mode: tuple (b, b_hat), where:
- b: original bits, shape [batch_size, num_ue, tb_size]
- b_hat: detected bits, same shape as b
- Return type:
tf.Tensor or tuple
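For intuition, ebno_db is typically converted to a noise variance using the bits per symbol and the code rate. The sketch below is simplified: Sionna's ebnodb2no additionally accounts for resource-grid overhead such as pilots, and the rate-1/2 default is an assumed placeholder rather than the demo's actual MCS code rate:

```python
import numpy as np

def ebnodb_to_no(ebno_db, num_bits_per_symbol=4, coderate=0.5):
    """Simplified Eb/No (dB) -> noise variance N0 conversion.

    Ignores resource-grid overhead (pilots, guard carriers), which the
    full link-level conversion would also take into account.
    """
    ebno = 10.0 ** (np.asarray(ebno_db) / 10.0)
    return 1.0 / (ebno * num_bits_per_symbol * coderate)

print(ebnodb_to_no(10.0))  # 10 dB, 16-QAM (4 bits/symbol), rate 1/2
```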
Notes
JIT compilation is disabled (jit_compile=False) because the neural detector uses dynamic shapes and control flow that are incompatible with XLA compilation.

The processing flow follows standard PUSCH transmission:
Transmitter: Bit generation, encoding, constellation mapping with learnable labeling, layer mapping, resource grid mapping, precoding (if enabled), OFDM modulation (if time domain)
Channel: Random sampling from pre-loaded CIR tensors (autoencoder) or on-demand CIR generation (baseline), CIR-to-OFDM conversion, AWGN addition
Receiver: OFDM demodulation (if time domain), channel estimation (perfect or LS), neural MIMO detection with constellation synchronization, layer demapping, transport block decoding (inference)
Loss computation (training only): BCE loss on soft LLRs plus minimum distance regularization to prevent constellation collapse
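The training loss can be sketched as follows. The BCE term uses the standard numerically stable logits form with the LLR as logit; the minimum-distance term is an assumed hinge-style penalty, since the demo's exact regularizer and its hyperparameters are not specified here:

```python
import numpy as np

def bce_from_llrs(bits, llrs):
    """Binary cross-entropy treating the LLR as the logit for bit = 1
    (sign convention assumed); numerically stable log1p formulation."""
    return np.mean(np.log1p(np.exp(-np.abs(llrs)))
                   + np.maximum(-llrs * (2.0 * bits - 1.0), 0.0))

def min_distance_penalty(const_points, d_min=0.4):
    """Hinge penalty on pairwise distances below d_min, discouraging
    constellation collapse (d_min is an assumed hyperparameter)."""
    diff = const_points[:, None] - const_points[None, :]
    dist = np.abs(diff) + np.eye(len(const_points)) * 1e9  # mask self-pairs
    return np.sum(np.maximum(d_min - dist, 0.0) ** 2)

bits = np.array([1.0, 0.0, 1.0, 1.0])
llrs = np.array([2.0, -3.0, 0.5, -1.0])
const_points = np.array([1 + 1j, -1 + 1j, 1 - 1j, -1 - 1j]) / np.sqrt(2.0)

loss = bce_from_llrs(bits, llrs) + 0.1 * min_distance_penalty(const_points)
print(loss > 0.0)
```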