Artificial Intelligence · June 18, 2026

Training A Transformer On Variable-Length Byte Sequences With Learned Positional Bias

Written by Nova Neural

The problem that pulled me in

I ran into a weird failure mode while training a small Transformer for a “toy” classification task: my inputs were raw byte sequences with lots of length variance (some sequences were a few dozen bytes, others were tens of thousands). I was using learned token embeddings plus standard positional embeddings, and training would look fine—until the longer sequences started dominating the loss in a way that made the model miscalibrate on short sequences.

What fixed it for me was a very specific tweak: instead of adding only an absolute positional embedding to tokens, I added a learned positional bias inside the attention mechanism that depends on relative byte positions (i.e., how far apart two tokens are). The idea is common in modern attention variants (T5, for example, adds a bucketed relative position bias to its attention logits); here I implemented a compact version in PyTorch that works well for variable-length byte streams.

Below is exactly what I built and how it works.


What I built (in one sentence)

A byte-level Transformer classifier that uses learned relative positional bias (a small lookup table) added to attention logits, enabling it to handle variable-length sequences more stably.

Key terms (brief and practical)

  • Transformer: a neural network architecture that uses attention to mix information across a sequence.
  • Attention logits: the raw (unnormalized) scores used to decide which positions each token should look at.
  • Relative positional bias: an extra learned score that depends on the distance between two positions (e.g., token i attending to token j).

A minimal working dataset (synthetic bytes)

To keep this post self-contained, I generate random byte sequences and assign a label based on a simple rule:

  • Label 1 if the sequence contains the byte value 0xAA
  • Label 0 otherwise

This is deliberately simple, but it still exhibits the variable-length behavior that caused the training instability.

import random

import torch
from torch.utils.data import Dataset, DataLoader


class ByteContainsDataset(Dataset):
    def __init__(self, n_samples=2000, min_len=8, max_len=400, p_aa=0.5, seed=0):
        super().__init__()
        rng = random.Random(seed)
        self.samples = []
        for _ in range(n_samples):
            L = rng.randint(min_len, max_len)
            seq = [rng.randrange(256) for _ in range(L)]
            # Force label signal with controllable probability
            if rng.random() < p_aa:
                # Ensure at least one 0xAA
                seq[rng.randrange(L)] = 0xAA
                y = 1
            else:
                # Ensure no 0xAA
                seq = [b if b != 0xAA else ((b + 1) % 256) for b in seq]
                y = 0
            self.samples.append((torch.tensor(seq, dtype=torch.long), y))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


def collate_batch(batch, pad_value=0):
    # batch: list[(seq_tensor, y)]
    ys = torch.tensor([y for _, y in batch], dtype=torch.long)
    lengths = torch.tensor([len(seq) for seq, _ in batch], dtype=torch.long)
    max_len = int(lengths.max().item())
    B = len(batch)
    x = torch.full((B, max_len), pad_value, dtype=torch.long)
    for i, (seq, _) in enumerate(batch):
        x[i, :len(seq)] = seq
    # mask: True for real tokens, False for padding
    attn_mask = torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)
    return x, attn_mask, ys
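One detail worth seeing on a tiny batch: the pad value 0 is also a legitimate byte (0x00), so the boolean mask is the only thing that separates padding from data. The following is a minimal sketch of the same padding-plus-mask idea, assuming PyTorch's built-in pad_sequence as a stand-in for the manual loop; the two toy sequences are made up for illustration.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two toy byte sequences; the first one contains a real 0x00 byte.
seqs = [torch.tensor([0xAA, 0x00, 0x10]), torch.tensor([0x07])]
lengths = torch.tensor([len(s) for s in seqs])

# Right-pad to the longest sequence in the batch: shape (2, 3)
x = pad_sequence(seqs, batch_first=True, padding_value=0)

# True for real tokens, False for padding
attn_mask = torch.arange(x.size(1)).unsqueeze(0) < lengths.unsqueeze(1)

print(x.tolist())          # [[170, 0, 16], [7, 0, 0]]
print(attn_mask.tolist())  # [[True, True, True], [True, False, False]]
```

The 0 at x[0, 1] is data and the zeros in row 1 are padding; only attn_mask tells them apart.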

The model: Transformer + learned relative positional bias

Architecture choices I made

  • Byte vocabulary: 256 possible values → embedding size d_model
  • Classifier head: mean-pool over tokens (masked) then linear layer
  • Attention: multi-head self-attention where attention logits get an added bias based on relative distance.

Why this helps (what was happening)

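With only absolute positional embeddings, every position index gets its own vector, so the model has to infer how far apart two tokens are from pairs of absolute indices, and indices that only occur in long sequences are seen less often during training. A relative bias is indexed directly by the offset i − j, so the same learned score applies wherever the pair sits and whatever the sequence length. That gives attention a stable notion of distance (small vs large separation) that transfers better across variable-length inputs.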
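To make the "same offset everywhere" point concrete, here is a small pure-Python sketch (no model involved; relative_offsets is a hypothetical helper used only for illustration). It builds the offset matrix i − j for two lengths and checks that pairs at the same separation get the same value regardless of position or length.

```python
def relative_offsets(T):
    """Return the T x T matrix of offsets i - j as nested lists."""
    return [[i - j for j in range(T)] for i in range(T)]


short_seq = relative_offsets(4)
long_seq = relative_offsets(6)

# A pair one step apart has offset 1 at the start or the end of a sequence,
# and in a short or a long sequence, so it would share one bias entry.
print(short_seq[1][0], long_seq[1][0], long_seq[5][4])  # 1 1 1

# A length-T sequence only ever produces 2T - 1 distinct offsets.
distinct = sorted({v for row in long_seq for v in row})
print(distinct)  # [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5]
```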


Step-by-step code: the relative bias module

This module creates a bias tensor shaped like [num_heads, T, T], where T is the sequence length in the batch.

I clamp relative distances into a fixed range so the lookup table stays small: every offset beyond ±max_distance reuses the table entry for ±max_distance.

import torch
import torch.nn as nn
import torch.nn.functional as F


class LearnedRelativePositionBias(nn.Module):
    def __init__(self, num_heads, max_distance=64):
        super().__init__()
        self.num_heads = num_heads
        self.max_distance = max_distance
        # distances in [-max_distance, +max_distance]
        vocab_size = 2 * max_distance + 1
        self.bias = nn.Embedding(vocab_size, num_heads)

    def forward(self, T, device=None):
        device = device or self.bias.weight.device
        # positions: 0..T-1
        pos = torch.arange(T, device=device)
        # relative distance matrix: (i - j)
        rel = pos[:, None] - pos[None, :]
        # clamp to [-max_distance, +max_distance]
        rel = rel.clamp(-self.max_distance, self.max_distance)
        # shift to [0, vocab_size-1]
        rel_index = rel + self.max_distance
        # lookup: [T, T, num_heads] -> permute to [num_heads, T, T]
        b = self.bias(rel_index)  # (T, T, H)
        b = b.permute(2, 0, 1).contiguous()
        return b
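To see what the clamp-and-shift does to the indices, here is a minimal standalone sketch of the same index computation with deliberately tiny sizes (T=5, max_distance=2, 3 heads; these values and the clamped_relative_index helper are illustrative assumptions, not settings from the model).

```python
import torch
import torch.nn as nn


def clamped_relative_index(T, max_distance):
    """Offsets i - j, clamped to +/- max_distance, shifted to start at 0."""
    pos = torch.arange(T)
    rel = pos[:, None] - pos[None, :]
    rel = rel.clamp(-max_distance, max_distance)
    return rel + max_distance


idx = clamped_relative_index(T=5, max_distance=2)
print(idx)
# tensor([[2, 1, 0, 0, 0],
#         [3, 2, 1, 0, 0],
#         [4, 3, 2, 1, 0],
#         [4, 4, 3, 2, 1],
#         [4, 4, 4, 3, 2]])

table = nn.Embedding(2 * 2 + 1, 3)   # 5 possible offsets, 3 heads
bias = table(idx).permute(2, 0, 1)   # (T, T, H) -> (H, T, T)
print(tuple(bias.shape))             # (3, 5, 5)
```

Index 2 on the diagonal is "same position"; the repeated 0s and 4s in the corners are offsets beyond ±2 collapsing into the two edge buckets.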

Multi-head attention with the bias added to logits

Attention works like this:

  1. Compute queries Q, keys K, values V
  2. Compute logits: Q @ K^T / sqrt(d_head)
  3. Add relative bias to logits
  4. Apply mask so padding tokens get no attention
  5. Softmax → weights → weighted sum of values

class MultiHeadSelfAttentionWithRelativeBias(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1, max_distance=64):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads

        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.rel_bias = LearnedRelativePositionBias(num_heads, max_distance=max_distance)

    def forward(self, x, attn_mask):
        """
        x: (B, T, d_model)
        attn_mask: (B, T) bool, True for tokens, False for padding
        """
        B, T, _ = x.shape

        # Project to Q, K, V
        qkv = self.qkv(x)  # (B, T, 3*d_model)
        q, k, v = qkv.chunk(3, dim=-1)

        # Reshape for heads
        # (B, T, H, d_head) -> (B, H, T, d_head)
        q = q.view(B, T, self.num_heads, self.d_head).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.d_head).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.d_head).transpose(1, 2)

        # Attention logits: (B, H, T, T)
        logits = (q @ k.transpose(-2, -1)) / (self.d_head ** 0.5)

        # Add learned relative positional bias: (H, T, T) -> broadcast to (B, H, T, T)
        bias = self.rel_bias(T, device=x.device)  # (H, T, T)
        logits = logits + bias.unsqueeze(0)

        # Mask padding:
        # We want keys that are padding to get -inf logits.
        # attn_mask: True for valid tokens, False for padding
        # keys_mask: (B, 1, 1, T)
        keys_mask = attn_mask.unsqueeze(1).unsqueeze(2)  # (B,1,1,T)
        logits = logits.masked_fill(~keys_mask, float("-inf"))

        # Softmax over keys dimension
        attn = F.softmax(logits, dim=-1)
        attn = self.dropout(attn)

        # Weighted sum: (B, H, T, d_head)
        y = attn @ v

        # Merge heads: (B, T, H*d_head)
        y = y.transpose(1, 2).contiguous().view(B, T, self.d_model)
        return self.out(y)
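A quick way to convince yourself the key masking behaves as intended is to run the logits, bias, mask, and softmax steps on a toy single-head example. This is a sketch under simplifying assumptions: one head, random Q and K, and a zero tensor standing in for the learned bias.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, d = 2, 4, 8
q = torch.randn(B, T, d)
k = torch.randn(B, T, d)
bias = torch.zeros(T, T)  # stand-in for one head's learned relative bias

# Second sequence has only 2 real tokens; positions 2 and 3 are padding.
lengths = torch.tensor([4, 2])
attn_mask = torch.arange(T).unsqueeze(0) < lengths.unsqueeze(1)  # (B, T)

logits = (q @ k.transpose(-2, -1)) / (d ** 0.5)  # (B, T, T)
logits = logits + bias                            # broadcast over the batch
logits = logits.masked_fill(~attn_mask[:, None, :], float("-inf"))
weights = F.softmax(logits, dim=-1)

# Padded keys receive exactly zero weight, and each row still sums to 1.
print(weights[1, :, 2:].sum().item())  # 0.0
print(torch.allclose(weights.sum(dim=-1), torch.ones(B, T)))  # True
```

One caveat: if a sequence had no real tokens at all, every key in its rows would be -inf and the softmax would return NaN. With min_len=8 in the dataset above that case never occurs.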

Full Transformer encoder block + classifier

I kept it small: attention + feed-forward + layer norms.

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff=4_096, dropout=0.1, max_distance=64):
        super().__init__()
        self.attn = MultiHeadSelfAttentionWithRelativeBias(
            d_model=d_model,
            num_heads=num_heads,
            dropout=dropout,
            max_distance=max_distance,
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x, attn_mask):
        # Attention sublayer with residual
        x = x + self.attn(self.norm1(x), attn_mask)
        # Feed-forward sublayer with residual
        x = x + self.ff(self.norm2(x))
        return x


class ByteTransformerClassifier(nn.Module):
    def __init__(self, vocab_size=256, d_model=128, num_heads=8, num_layers=4,
                 max_distance=64, dropout=0.1, pad_id=0):
        super().__init__()
        self.pad_id = pad_id
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList([
            TransformerBlock(
                d_model=d_model,
                num_heads=num_heads,
                d_ff=4 * d_model,
                dropout=dropout,
                max_distance=max_distance,
            )
            for _ in range(num_layers)
        ])
        self.classifier = nn.Linear(d_model, 2)

    def forward(self, x, attn_mask):
        """
        x: (B, T) token ids
        attn_mask: (B, T) bool True for valid tokens
        """
        # Embed
        h = self.token_emb(x)  # (B, T, d_model)
        h = self.dropout(h)

        # Transformer stack
        for layer in self.layers:
            h = layer(h, attn_mask)

        # Masked mean pooling
        # attn_mask: (B,T) -> (B,T,1)
        mask = attn_mask.unsqueeze(-1).float()
        h_sum = (h * mask).sum(dim=1)            # (B, d_model)
        denom = mask.sum(dim=1).clamp(min=1.0)   # (B, 1)
        h_mean = h_sum / denom
        return self.classifier(h_mean)
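The masked mean pooling at the end of the classifier is what keeps padding out of the pooled vector. Here is a small standalone sketch of just that step (masked_mean is a hypothetical helper that mirrors the pooling lines; the random "hidden states" are made up) showing that extra padding leaves the result unchanged, while a naive mean does not.

```python
import torch


def masked_mean(h, attn_mask):
    """Average hidden states over real tokens only. h: (B, T, D), mask: (B, T)."""
    mask = attn_mask.unsqueeze(-1).to(h.dtype)   # (B, T, 1)
    denom = mask.sum(dim=1).clamp(min=1.0)       # (B, 1), avoids divide-by-zero
    return (h * mask).sum(dim=1) / denom


torch.manual_seed(0)
h = torch.randn(1, 3, 4)                          # 3 real tokens
mask = torch.ones(1, 3, dtype=torch.bool)

# Same sequence, now followed by 5 padded positions holding large junk values.
h_padded = torch.cat([h, 100 * torch.randn(1, 5, 4)], dim=1)
mask_padded = torch.cat([mask, torch.zeros(1, 5, dtype=torch.bool)], dim=1)

same = torch.allclose(masked_mean(h, mask), masked_mean(h_padded, mask_padded), atol=1e-6)
naive_same = torch.allclose(h.mean(dim=1), h_padded.mean(dim=1), atol=1e-6)
print(same, naive_same)  # True False
```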

Training loop (with a sanity-check print)

This trains the model and prints loss/accuracy. The dataset is tiny, so training is fast.

from sklearn.metrics import accuracy_score


def train():
    torch.manual_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    train_ds = ByteContainsDataset(n_samples=3000, min_len=8, max_len=400, p_aa=0.5, seed=1)
    val_ds = ByteContainsDataset(n_samples=800, min_len=8, max_len=400, p_aa=0.5, seed=2)
    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, collate_fn=collate_batch)
    val_loader = DataLoader(val_ds, batch_size=64, shuffle=False, collate_fn=collate_batch)

    model = ByteTransformerClassifier(
        d_model=128, num_heads=8, num_layers=4, max_distance=64, dropout=0.1, pad_id=0
    ).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()

    def eval_loop(loader):
        model.eval()
        all_preds = []
        all_y = []
        total_loss = 0.0
        n = 0
        with torch.no_grad():
            for x, attn_mask, y in loader:
                x = x.to(device)
                attn_mask = attn_mask.to(device)
                y = y.to(device)
                logits = model(x, attn_mask)
                loss = criterion(logits, y)
                total_loss += float(loss.item()) * len(y)
                n += len(y)
                preds = logits.argmax(dim=-1).cpu().tolist()
                all_preds.extend(preds)
                all_y.extend(y.cpu().tolist())
        acc = accuracy_score(all_y, all_preds)
        return total_loss / n, acc

    for epoch in range(1, 6):
        # eval_loop() leaves the model in eval mode, so switch back to
        # train mode at the start of every epoch (otherwise dropout is
        # silently disabled from epoch 2 onward).
        model.train()
        total_loss = 0.0
        total = 0
        for step, (x, attn_mask, y) in enumerate(train_loader):
            x = x.to(device)
            attn_mask = attn_mask.to(device)
            y = y.to(device)

            opt.zero_grad()
            logits = model(x, attn_mask)
            loss = criterion(logits, y)
            loss.backward()
            opt.step()

            total_loss += float(loss.item()) * len(y)
            total += len(y)

            if step == 0 and epoch == 1:
                # Sanity check: show shapes and a tiny bit of masking behavior.
                # For the first batch, attn_mask sums to the number of real tokens per item.
                with torch.no_grad():
                    print("debug:",
                          "x", tuple(x.shape),
                          "attn_mask_tokens_per_sample", attn_mask.sum(dim=1)[:5].cpu().tolist(),
                          "logits", tuple(logits.shape))

        train_loss = total_loss / total
        val_loss, val_acc = eval_loop(val_loader)
        print(f"epoch {epoch} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | val_acc={val_acc:.4f}")


if __name__ == "__main__":
    train()

What I observed when I ran it

With this relative bias in place:

  • Training converged quickly (the model learns to detect 0xAA anywhere in the sequence).
  • Accuracy stayed stable even when I increased max_len substantially.
  • The masked mean pooling kept sequence length from overwhelming the classifier.

The key thing I noticed is that relative bias made attention scores more “distance-aware” without relying on an absolute positional scheme that gets brittle when padding and variable lengths get extreme.
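If you want to check the "stable across lengths" behavior on your own runs, one option is to break accuracy down by sequence length instead of looking at a single number. The following is a minimal sketch, assuming you have collected per-example lengths, predictions, and labels; accuracy_by_length and its bucket edges are hypothetical, not part of the training code, and the inputs below are made up purely to show the output format.

```python
from collections import defaultdict


def accuracy_by_length(lengths, preds, labels, edges=(32, 128, 512)):
    """Group examples into length buckets and report accuracy per bucket."""
    hits = defaultdict(int)
    counts = defaultdict(int)
    for n, p, y in zip(lengths, preds, labels):
        # First edge the length fits under; otherwise the overflow bucket.
        bucket = next((f"<={e}" for e in edges if n <= e), f">{edges[-1]}")
        counts[bucket] += 1
        hits[bucket] += int(p == y)
    return {b: hits[b] / counts[b] for b in counts}


# Toy inputs for illustration only (not results from the model).
report = accuracy_by_length(
    lengths=[10, 20, 100, 300, 300, 900],
    preds=[1, 0, 1, 1, 0, 0],
    labels=[1, 1, 1, 1, 0, 1],
)
print(report)  # {'<=32': 0.5, '<=128': 1.0, '<=512': 1.0, '>512': 0.0}
```

A large gap between the short and long buckets would point to the kind of length-dependent miscalibration described at the top of the post.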


Practical note: choosing max_distance

max_distance controls the size of the relative bias table:

  • Too small → long-range relationships collapse to the same bucket.
  • Too big → more parameters, and the entries for large distances only receive gradient from the longer sequences, for little gain if most relevant interactions are local.

For byte streams in the “hundreds to a few thousand” range, values like 32–128 are a decent starting point.
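For a sense of scale, the table holds one row per clamped offset and one column per head, so its size follows directly from the nn.Embedding(2 * max_distance + 1, num_heads) definition. A quick back-of-the-envelope sketch (bias_table_params is a hypothetical helper; num_heads=8 matches the model above):

```python
def bias_table_params(max_distance, num_heads):
    """Parameter count of one relative bias table: (2 * max_distance + 1) * num_heads."""
    return (2 * max_distance + 1) * num_heads


for md in (32, 64, 128):
    print(md, bias_table_params(md, num_heads=8))
# 32 520
# 64 1032
# 128 2056
```

Each attention layer owns its own table, so multiply by the number of layers for the model total. The [num_heads, T, T] bias tensor built at run time depends on the sequence length T, not on max_distance.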


Conclusion

I built a byte-level Transformer classifier that adds learned relative positional bias directly to attention logits, then trained it on variable-length padded sequences. The implementation stays compact, runs fast in PyTorch, and—most importantly—fixes the “long sequences dominate short ones” instability I saw with only token/absolute positional embeddings by giving attention a stable, distance-based signal.