# Source code for demos.pusch_autoencoder.src.pusch_trainable_transmitter

# SPDX-License-Identifier: MIT
# Copyright (c) 2025–present Srikanth Pagadarai

"""
Trainable PUSCH transmitter with SYMMETRIC learnable constellation geometry AND labeling.

Extends Sionna's standard PUSCHTransmitter to support end-to-end
optimization of constellation points. The key is maintaining valid
constellation properties (unit average power, centered) while
allowing gradient flow through the constellation geometry.

Design Approach
---------------
**Symmetric Constellation Geometry**:

Instead of 16 independent complex points, we store only 4 base points in the first
quadrant (I>0, Q>0). The complete 16-point constellation is constructed by:

1. Original 4 base points -> 1st quadrant (I>0, Q>0)
2. Reflection across Q-axis -> 2nd quadrant (I<0, Q>0)
3. Reflection across I-axis -> 4th quadrant (I>0, Q<0)
4. Reflection across origin -> 3rd quadrant (I<0, Q<0)

This enforces mathematical symmetry that is preserved during gradient descent.

**Learnable Labeling** (unchanged):

Gumbel-Softmax labeling remains essential because:
- Random constellation initialization breaks Gray labeling assumptions
- The system needs to jointly optimize geometry AND bit-to-symbol assignment
- Even with symmetric geometry, labeling can and should adapt

Initialization Modes
--------------------
- **Training mode** (``training=True``): Random base points in Q1, then reflected
  to create full symmetric constellation. Labeling initialized to identity.

- **Inference mode** (``training=False``): Standard QAM initialization as placeholder
  until weights are loaded from trained model.
"""

from sionna.phy.nr import PUSCHTransmitter
from sionna.phy.mapping import Constellation
import tensorflow as tf
import numpy as np


class PUSCHTrainableTransmitter(PUSCHTransmitter):
    r"""
    PUSCH Transmitter with SYMMETRIC trainable constellation geometry AND labeling.

    This subclass enforces 4-fold symmetry (I-axis, Q-axis, origin) in the
    constellation while supporting learnable labeling for autoencoder-based
    communication system design.

    Parameters
    ----------
    *args : tuple
        Positional arguments passed to ``PUSCHTransmitter``.
    training : bool
        If ``True``, constellation and labeling are trainable with soft
        Gumbel-Softmax assignment. Default ``False``.
    gumbel_temperature : float
        Temperature for Gumbel-Softmax. Lower = sharper (more discrete).
        Default 0.5.
    **kwargs : dict
        Keyword arguments passed to ``PUSCHTransmitter``.

    Example
    -------
    >>> tx = PUSCHTrainableTransmitter(pusch_configs, output_domain="freq",
    ...                                training=True, gumbel_temperature=0.5)
    >>> x_map, x, b, c = tx(32)  # batch size is the positional ``inputs``

    Notes
    -----
    The symmetric constellation is stored as ``[num_base_points]`` complex
    values (typically 4 for 16-QAM), and the full constellation is computed
    via reflections. This guarantees 4-fold symmetry is preserved during
    training.

    The learnable labeling uses a permutation logits matrix of shape
    ``[num_points, num_points]``. Row ``i`` contains logits for which
    constellation point should be assigned to bit pattern ``i``.
    """

    def __init__(self, *args, training=False, gumbel_temperature=0.5, **kwargs):
        self._training = training
        self._gumbel_temperature = gumbel_temperature

        # Parent constructor sets up standard PUSCH processing chain
        super().__init__(*args, **kwargs)

        # Replace standard constellation with symmetric trainable version
        self._setup_custom_constellation()

    @staticmethod
    def generate_random_symmetric_constellation(num_points, seed=None):
        """
        Generate random constellation with 4-fold symmetry and unit average energy.

        Creates a constellation by:

        1. Generating num_points/4 random base points in the first quadrant
        2. Reflecting them across I-axis, Q-axis, and origin
        3. Normalizing to unit average power

        Parameters
        ----------
        num_points : int
            Total constellation size (must be a positive multiple of 4).
            For 16-QAM: 16.
        seed : int, optional
            Random seed for reproducibility. When given, the base points are
            drawn with stateless random ops keyed on this seed, so the global
            NumPy / TensorFlow random state is left untouched.

        Returns
        -------
        tf.Tensor, complex64
            Symmetric constellation points with shape ``[num_points]``.
            Points are ordered: [Q1_points, Q2_points, Q4_points, Q3_points]

        Raises
        ------
        ValueError
            If ``num_points`` is not a positive multiple of 4.

        Notes
        -----
        The generation process:

        1. Sample num_points/4 base points in Q1 (I>0, Q>0)
        2. Reflect across Q-axis: negate real part -> Q2 (I<0, Q>0)
        3. Reflect across I-axis: negate imag part -> Q4 (I>0, Q<0)
        4. Reflect across origin: negate both parts -> Q3 (I<0, Q<0)
        5. Normalize complete constellation to unit average power

        Example
        -------
        >>> points = PUSCHTrainableTransmitter.generate_random_symmetric_constellation(16, seed=42)
        >>> print(f"Power: {tf.reduce_mean(tf.abs(points)**2):.6f}")  # Should be ~1.0
        >>> # Verify I-axis symmetry (Q1 vs Q4)
        >>> print(f"I-sym check: {tf.reduce_mean(tf.abs(points[0] - tf.math.conj(points[8]))):.6f}")  # ~0
        """
        if num_points <= 0:
            # Zero points would make the power normalization divide 0 by 0.
            raise ValueError(f"num_points must be positive, got {num_points}")
        if num_points % 4 != 0:
            raise ValueError(f"num_points must be divisible by 4, got {num_points}")

        base_shape = [num_points // 4]

        # Sample random base points in first quadrant (I>0, Q>0)
        if seed is None:
            real_base = tf.random.uniform(base_shape, minval=0.1, maxval=1.0)
            imag_base = tf.random.uniform(base_shape, minval=0.1, maxval=1.0)
        else:
            # Stateless draws give reproducibility without reseeding the
            # process-wide generators (which would silently alter every other
            # random op in the program). Distinct second seed components keep
            # the real and imaginary parts independent.
            real_base = tf.random.stateless_uniform(
                base_shape,
                seed=tf.constant([seed, 0], dtype=tf.int64),
                minval=0.1,
                maxval=1.0,
            )
            imag_base = tf.random.stateless_uniform(
                base_shape,
                seed=tf.constant([seed, 1], dtype=tf.int64),
                minval=0.1,
                maxval=1.0,
            )
        base_points = tf.complex(real_base, imag_base)

        # Generate all 4 quadrants via reflections
        points = PUSCHTrainableTransmitter._reflect_to_full_constellation(base_points)

        # Normalize to unit average power for consistent SNR interpretation
        return PUSCHTrainableTransmitter._normalize_to_unit_power(points)

    @property
    def trainable_variables(self):
        """
        Return all trainable variables: base geometry + labeling.

        Returns
        -------
        list of tf.Variable
            Three-element list:
            ``[_base_points_r, _base_points_i, _labeling_logits]`` where:

            - ``_base_points_r``: Real parts of Q1 base points
            - ``_base_points_i``: Imaginary parts of Q1 base points
            - ``_labeling_logits``: Permutation logits for bit-to-point mapping
        """
        return [self._base_points_r, self._base_points_i, self._labeling_logits]

    @property
    def geometry_variables(self):
        """Return only the base geometry (Q1 constellation points) variables."""
        return [self._base_points_r, self._base_points_i]

    @property
    def labeling_variables(self):
        """Return only the labeling (permutation) variable."""
        return [self._labeling_logits]

    @property
    def gumbel_temperature(self):
        """float: Current Gumbel-Softmax temperature."""
        return self._gumbel_temperature

    @gumbel_temperature.setter
    def gumbel_temperature(self, value):
        """Set Gumbel-Softmax temperature (e.g., for annealing)."""
        self._gumbel_temperature = value

    def get_base_points(self):
        """
        Get the base constellation points (Q1 only, not normalized).

        Returns
        -------
        tf.Tensor, complex64
            Base constellation points in first quadrant, shape
            ``[num_base_points]``. For 16-QAM, this is ``[4]`` complex values.
        """
        return tf.complex(
            tf.cast(self._base_points_r, self.rdtype),
            tf.cast(self._base_points_i, self.rdtype),
        )

    @staticmethod
    def _reflect_to_full_constellation(base_points):
        """
        Reflect base points to create full 4-fold symmetric constellation.

        Parameters
        ----------
        base_points : tf.Tensor, complex64
            Base points in Q1, shape ``[num_base_points]``.

        Returns
        -------
        tf.Tensor, complex64
            Full constellation with 4-fold symmetry, shape
            ``[4*num_base_points]``, ordered ``[Q1, Q2, Q4, Q3]``.

        Notes
        -----
        Reflection mapping:

        - Q1 (I>0, Q>0): base_points (original)
        - Q2 (I<0, Q>0): reflect across Q-axis (negate real part)
        - Q4 (I>0, Q<0): reflect across I-axis (negate imag part)
        - Q3 (I<0, Q<0): reflect across origin (negate both parts)
        """
        real_base = tf.math.real(base_points)
        imag_base = tf.math.imag(base_points)

        q1_points = base_points  # Original
        q2_points = tf.complex(-real_base, imag_base)  # Q-axis reflection
        q4_points = tf.complex(real_base, -imag_base)  # I-axis reflection
        q3_points = tf.complex(-real_base, -imag_base)  # Origin reflection

        return tf.concat([q1_points, q2_points, q4_points, q3_points], axis=0)

    @staticmethod
    def _normalize_to_unit_power(points):
        """Scale ``points`` uniformly so their mean squared magnitude equals 1."""
        energy = tf.reduce_mean(tf.square(tf.abs(points)))
        return points / tf.cast(tf.sqrt(energy), points.dtype)

    def get_normalized_constellation(self):
        """
        Compute 4-fold symmetric and power-normalized constellation points.

        Returns
        -------
        tf.Tensor, complex64
            Normalized constellation points with 4-fold symmetry, shape
            ``[num_points]``. For 16-QAM, this is ``[16]`` complex values.

        Notes
        -----
        The normalization process:

        1. Retrieve base points from trainable variables (Q1 only)
        2. Reflect to create full 4-fold symmetric constellation
        3. Normalize to unit average power

        The normalization is differentiable and does NOT break symmetry since
        it's a uniform scaling operation applied to all points equally.

        Symmetry properties maintained:

        - I-axis: constellation[i] = conj(constellation[j]) for mirrored i,j
        - Q-axis: constellation[i] = -conj(constellation[k]) for mirrored i,k
        - Origin: constellation[i] = -constellation[m] for opposite i,m
        """
        # Get base points from trainable variables
        base_points = self.get_base_points()

        # Reflect to full constellation
        points = self._reflect_to_full_constellation(base_points)

        # Normalize to unit power (differentiable)
        return self._normalize_to_unit_power(points)

    def _gumbel_softmax(self, logits, temperature, hard=False):
        """
        Gumbel-Softmax sampling for differentiable discrete sampling.

        Parameters
        ----------
        logits : tf.Tensor, float
            Logits for categorical distribution, shape ``[..., num_classes]``.
        temperature : float
            Temperature for Gumbel-Softmax. Lower values make distribution
            more peaked (closer to hard one-hot). Only used when ``hard`` is
            ``False``.
        hard : bool
            If True, return the one-hot argmax of ``logits`` (no noise added,
            non-differentiable). If False, return soft samples
            (differentiable).

        Returns
        -------
        tf.Tensor, float
            Sampled distribution, same shape and dtype as ``logits``.

        Notes
        -----
        During training, use ``hard=False`` for gradient flow.
        During inference, use ``hard=True`` for discrete assignment.
        """
        if hard:
            # Hard one-hot assignment (inference mode)
            indices = tf.argmax(logits, axis=-1)
            return tf.one_hot(indices, depth=tf.shape(logits)[-1], dtype=logits.dtype)

        # Soft Gumbel-Softmax (training mode).
        # Draw the noise in the logits' own dtype so double-precision logits
        # do not hit a float32/float64 mismatch, and keep the uniform sample
        # strictly positive: u == 0 would give -log(-log(0)) = -inf.
        smallest = np.finfo(logits.dtype.as_numpy_dtype).tiny
        uniform = tf.random.uniform(
            tf.shape(logits), minval=smallest, maxval=1.0, dtype=logits.dtype
        )
        gumbel_noise = -tf.math.log(-tf.math.log(uniform))
        y = logits + gumbel_noise
        return tf.nn.softmax(y / tf.cast(temperature, logits.dtype))

    def get_soft_labeling_matrix(self, hard=False):
        """
        Get the current labeling matrix (permutation of constellation points).

        Parameters
        ----------
        hard : bool
            If True, return hard one-hot assignment.
            If False, return soft Gumbel-Softmax probabilities.

        Returns
        -------
        tf.Tensor, float
            Labeling matrix, shape ``[num_points, num_points]``.
            Row i indicates which constellation point(s) are assigned to bit
            pattern i.

        Example
        -------
        For hard assignment, row i will have a single 1 at column j,
        meaning bit pattern i maps to constellation point j.

        For soft assignment during training, row i contains probabilities
        summing to 1, enabling gradient flow through the labeling.
        """
        return self._gumbel_softmax(
            self._labeling_logits, self._gumbel_temperature, hard=hard
        )

    # [custom-constellation-start]
    def _setup_custom_constellation(self):
        """
        Initialize symmetric trainable constellation and learnable labeling.

        For 16-QAM (num_bits_per_symbol=4):

        - Stores 4 base points in Q1 as trainable variables
        - Other 12 points computed via reflections (enforces symmetry)
        - Labeling initialized to identity matrix (preserves structure initially)

        Raises
        ------
        ValueError
            If ``2**num_bits_per_symbol`` is not divisible by 4, since the
            quadrant-reflection construction needs at least one point per
            quadrant.

        Notes
        -----
        Even with symmetric geometry, we preserve flexible labeling to allow
        the optimizer to discover optimal bit-to-symbol assignments that may
        differ from standard Gray coding.
        """
        num_points = 2**self._num_bits_per_symbol
        if num_points % 4 != 0:
            # Fail early in BOTH modes; otherwise inference mode would build
            # an empty base-point set and a NaN-normalized constellation.
            raise ValueError(
                "Symmetric constellation requires a size divisible by 4, got "
                f"{num_points} points "
                f"(num_bits_per_symbol={self._num_bits_per_symbol})"
            )

        self._num_constellation_points = num_points
        self._num_base_points = num_points // 4

        if self._training:
            # Training mode: Use symmetric random initialization
            constellation_points = self.generate_random_symmetric_constellation(
                num_points, seed=None  # No seed = different init each run
            )
        else:
            # Inference mode: Use standard QAM as placeholder
            constellation_points = Constellation(
                "qam", num_bits_per_symbol=self._num_bits_per_symbol
            ).points

        # Keep just the first quarter as the Q1 base points. For the random
        # constellation this is exact (Q1 block comes first); for QAM it
        # assumes the first quarter of Sionna's ordering lies in Q1
        # (NOTE(review): confirm against Sionna's QAM indexing).
        base_points = constellation_points[: self._num_base_points]

        # Store ONLY base points as trainable variables
        self._base_points_r = tf.Variable(
            tf.cast(tf.math.real(base_points), self.rdtype),
            trainable=self._training,
            name="constellation_base_real",
        )
        self._base_points_i = tf.Variable(
            tf.cast(tf.math.imag(base_points), self.rdtype),
            trainable=self._training,
            name="constellation_base_imag",
        )

        # Initialize labeling logits to identity-like matrix
        # This means bit pattern i initially maps to constellation point i
        identity_scale = 5.0  # Start sharply peaked at identity
        init_logits = tf.eye(num_points) * identity_scale
        self._labeling_logits = tf.Variable(
            tf.cast(init_logits, self.rdtype),
            trainable=self._training,
            name="labeling_logits",
        )

        # Create constellation object for compatibility with Sionna internals
        # (used by resource grid mapper, etc.)
        # Initialize with full reflected constellation
        full_constellation = self.get_normalized_constellation()
        self._constellation = Constellation(
            "custom",
            num_bits_per_symbol=self._num_bits_per_symbol,
            points=full_constellation,
            normalize=False,
            center=False,
        )

    # [custom-constellation-end]

    def _soft_symbol_mapping(self, c):
        """
        Map coded bits to symbols using soft learnable labeling.

        This replaces the standard hard lookup with a differentiable soft
        assignment that enables gradient flow through the labeling.

        Parameters
        ----------
        c : tf.Tensor, float32
            Coded bits, shape ``[batch, num_tx, num_coded_bits]``.

        Returns
        -------
        tf.Tensor, complex64
            Mapped symbols, shape ``[batch, num_tx, num_symbols]``.

        Notes
        -----
        The symbol for bit pattern ``i`` is
        ``sum_q assignment[i, q] * constellation[q]``. That table has only
        ``num_points`` entries, so it is built once and then indexed per
        symbol, instead of expanding every symbol to a one-hot vector and
        contracting it against the assignment matrix.
        """
        # Get normalized constellation points (with 4-fold symmetry)
        constellation = self.get_normalized_constellation()  # [num_points]

        # Soft assignment during training (gradient flow), hard at inference
        assignment = self.get_soft_labeling_matrix(
            hard=not self._training
        )  # [num_points, num_points]

        # Per-label symbol table: weighted sum over constellation points
        label_real = tf.linalg.matvec(assignment, tf.math.real(constellation))
        label_imag = tf.linalg.matvec(assignment, tf.math.imag(constellation))
        label_symbols = tf.complex(label_real, label_imag)  # [num_points]

        # Group bits per symbol: [batch, num_tx, num_symbols, bits_per_symbol]
        bits_per_sym = self._num_bits_per_symbol
        c_shape = tf.shape(c)
        c_grouped = tf.reshape(
            c, [c_shape[0], c_shape[1], c_shape[2] // bits_per_sym, bits_per_sym]
        )

        # Convert bits to integer indices (binary to decimal, MSB first)
        powers = tf.cast(2 ** tf.range(bits_per_sym - 1, -1, -1), c.dtype)
        indices = tf.cast(
            tf.reduce_sum(c_grouped * powers, axis=-1), tf.int32
        )  # [batch, num_tx, num_symbols]

        # Look up each symbol; gradients flow back into label_symbols
        return tf.gather(label_symbols, indices)

    def call(self, inputs):
        """
        Execute transmitter processing chain with trainable constellation and labeling.

        The symmetric constellation is computed from base points before each
        forward pass.

        Parameters
        ----------
        inputs : int or tf.Tensor
            If ``return_bits=True``: ``int`` specifying batch size.
            If ``return_bits=False``: ``tf.Tensor`` of input bits.

        Returns
        -------
        tuple or tf.Tensor
            If ``return_bits=True``: tuple ``(x_map, x, b, c)``
            If ``return_bits=False``: just ``x`` (transmitted signal).

        Notes
        -----
        The processing chain:

        1. Compute the full symmetric constellation from the Q1 base points
        2. Standard PUSCH transmission (bits -> symbols -> OFDM)
        3. Symbol mapping uses learnable labeling via Gumbel-Softmax
        """
        # Update constellation object with current normalized symmetric points
        self._constellation.points = self.get_normalized_constellation()

        if self._return_bits:
            # ``inputs`` is the batch size; draw the information bits here
            batch_size = inputs
            b = self._binary_source([batch_size, self._num_tx, self._tb_size])
        else:
            b = inputs

        # TB encoding: CRC, segmentation, LDPC encoding, rate matching
        c = self._tb_encoder(b)

        # Custom soft symbol mapping with learnable labeling
        x_map = self._soft_symbol_mapping(c)

        # Distribute symbols across MIMO layers
        x_layer = self._layer_mapper(x_map)

        # Place symbols on OFDM resource grid with DMRS
        x_grid = self._resource_grid_mapper(x_layer)

        # Apply precoding if configured
        if self._precoding == "codebook":
            x_pre = self._precoder(x_grid)
        else:
            x_pre = x_grid

        # Convert to time domain if requested
        if self._output_domain == "time":
            x = self._ofdm_modulator(x_pre)
        else:
            x = x_pre

        if self._return_bits:
            return x_map, x, b, c
        return x