An end-to-end mechanistic study of how tiny transformers solve in-context learning. We reverse-engineer the model's weights and activations to show how multi-head attention self-organizes into specialized, parallel circuits that cooperatively decouple Bayesian elimination from associative recall without explicit supervision.
The mapping 3→3, 4→5, 2→1, 1→4, 5→2 is revealed pair by pair. The model must read it from context, because it changes with every example.
A permutation over V=5 is a rule that maps each of the numbers 1–5 to a unique number in 1–5 (a number may map to itself, as 3→3 does above). Every input has exactly one output. No two inputs share an output.
The mapping is different on every single training example. The model cannot memorize "3 always maps to 3" — because in the next example, 3 might map to 1.
Instead, the model must learn a meta-skill: read the (x,y) pairs from context, then answer new queries using what it just read.
Read pairs from context. When the model sees 3, 3 it should understand: "3 maps to 3 in this episode."
Track used values. Once y=3 has been output, it cannot appear again (it's a valid permutation — no repeats).
Answer queries. When given a new x, use the pairs seen so far to find its y — or at least narrow the candidates.
All 120 permutations are split once, before any training. 80% (96 permutations) become training mappings. 20% (24) become test mappings. These sets never overlap.
This is critical: the model is evaluated on mappings it has never seen. We're testing the general skill of reading context — not memorization of specific mappings.
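The split described above can be sketched with the standard library. This is a minimal illustration, assuming a fixed seed and a plain shuffle-then-slice split; the names `split_permutations`, `train_frac` and `seed` are hypothetical and are not taken from the original code.

```python
import itertools
import random


def split_permutations(V=5, train_frac=0.8, seed=0):
    """Enumerate all V! mappings once and split them into disjoint train/test sets.

    `seed` and `train_frac` are illustrative assumptions, not the original settings.
    """
    elements = list(range(1, V + 1))
    all_perms = list(itertools.permutations(elements))  # 120 tuples for V=5
    rng = random.Random(seed)  # local RNG so the split is reproducible
    rng.shuffle(all_perms)
    n_train = int(round(train_frac * len(all_perms)))
    return all_perms[:n_train], all_perms[n_train:]


train_mappings, test_mappings = split_permutations()
print(len(train_mappings), len(test_mappings))
print(set(train_mappings).isdisjoint(test_mappings))
```

Because the split happens once on whole mappings, no ordering of the x tokens can leak a test mapping into training.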
Every time the dataloader asks for a sample, three things happen fresh:
Choosing a vocabulary size of $V=5$ and configuring the dataloader's epoch sizes represent a deliberate balance between computational simplicity, architectural capacity constraints, and optimization stability.
There are $5! = 120$ possible permutation mappings. Shuffling the input order ($5! = 120$) yields $120 \times 120 = 14,400$ total unique sequence configurations, providing a diverse training space.
At $V=4$ ($24$ permutations), the mappings are fewer than the model's intermediate dimensions ($d_{\text{model}}=32, d_{\text{ff}}=64$), allowing neurons to hardcode each rule. At $V=5$, the 120 permutations outnumber these dimensions, forcing the model to learn a compact in-context retrieval algorithm.
With $100,000$ training samples per epoch, the model sees each unique training configuration ($11,520$) roughly $8.7$ times. This high density provides the Adam optimizer with stable, smoothed gradient updates, helping the narrow network avoid bad local minima.
Because 24 of the 120 mappings are withheld entirely from training, accuracy on the $5,000$ test samples measures a generalized meta-skill rather than memorized mappings. Those samples cover the test space ($2,880$ configurations) about $1.7$ times over, which keeps the reported metrics statistically stable.
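The counts quoted in this section follow from a few lines of arithmetic. The sketch below only re-derives the numbers stated in the text; the variable names are chosen here for illustration.

```python
import math

V = 5
n_mappings = math.factorial(V)           # 120 permutation mappings
n_orders = math.factorial(V)             # 120 ways to order the x tokens
total_configs = n_mappings * n_orders    # 14,400 sequence configurations

n_train_mappings = n_mappings * 4 // 5   # 80% of 120 = 96
n_test_mappings = n_mappings - n_train_mappings  # 24

train_configs = n_train_mappings * n_orders      # 11,520
test_configs = n_test_mappings * n_orders        # 2,880

passes_per_config = 100_000 / train_configs      # about 8.7 visits per configuration
test_coverage = 5_000 / test_configs             # about 1.7 samples per configuration

print(total_configs, train_configs, test_configs)
print(round(passes_per_config, 2), round(test_coverage, 2))
```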
```python
def __getitem__(self, idx):
    mapping_perm = random.choice(self.mappings)  # pick one of 96 train perms
    mapping = {x: y for x, y in zip(self.elements, mapping_perm)}
    X_order = list(self.elements)
    random.shuffle(X_order)  # shuffle [1,2,3,4,5] → e.g. [3,4,2,1,5]
    seq = []
    for x in X_order:
        seq.append(x)
        seq.append(mapping[x])  # interleave: [3,3, 4,5, 2,1, 1,4, 5,2]
    inputs = torch.tensor(seq[:-1], dtype=torch.long)  # tokens 0–8 (length 9)
    targets = torch.tensor(seq[1:], dtype=torch.long)  # tokens 1–9 (length 9, shifted by 1)
    mask = torch.zeros(len(inputs), dtype=torch.float32)
    mask[::2] = 1.0  # positions 0,2,4,6,8 → the y-prediction positions
    return inputs, targets, mask
```
Say the mapping for this example is 3→3, 4→5, 2→1, 1→4, 5→2 and the x-order is [3,4,2,1,5]. The full sequence is [3,3, 4,5, 2,1, 1,4, 5,2].
Positions 0,2,4,6,8 are x tokens (inputs/queries). Positions 1,3,5,7,9 are y tokens (answers).
inputs = seq[:-1] — drop the last token. Length = 9.
The model reads these 9 tokens all at once and produces 9 predictions simultaneously.
targets = seq[1:] — drop the first token. Length = 9.
Each target is the token that comes after the corresponding input. This is the standard language-model "next token prediction" setup.
Let's line them up and see what the model is predicting at each position:
| Position | Input token (seen) | Target token (predict) | Meaningful? | What the model must do |
|---|---|---|---|---|
| 0 | 3 | 3 | YES ✓ | Given $x_1=3$, predict $y_1$. (No prior context: the hardest step.) |
| 1 | 3 | 4 | NO | Predict the next $x$ (which is random). Not meaningful, so the loss is masked out. |
| 2 | 4 | 5 | YES ✓ | Given $x_2=4$, predict $y_2$. Can see ($3\to3$) already. |
| 3 | 5 | 2 | NO | Predict the next $x$ (random). Loss masked out. |
| 4 | 2 | 1 | YES ✓ | Given $x_3=2$, predict $y_3$. Can see ($3\to3, 4\to5$). |
| 5 | 1 | 1 | NO | Predict the next $x$ (random). Loss masked out. |
| 6 | 1 | 4 | YES ✓ | Given $x_4=1$, predict $y_4$. Can see ($3\to3, 4\to5, 2\to1$). |
| 7 | 4 | 5 | NO | Predict the next $x$ (random). Loss masked out. |
| 8 | 5 | 2 | YES ✓ | Given $x_5=5$, predict $y_5$. All 4 prior pairs visible. |
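
The alignment in the table can be reproduced in plain Python. This sketch uses lists instead of tensors so that it runs without PyTorch; the mapping and x-order are the ones from the worked example, and the printed role labels are wording chosen here.

```python
mapping = {3: 3, 4: 5, 2: 1, 1: 4, 5: 2}   # the example mapping from the text
x_order = [3, 4, 2, 1, 5]

seq = []
for x in x_order:
    seq += [x, mapping[x]]                  # interleave x and y

inputs = seq[:-1]                           # drop the last token
targets = seq[1:]                           # drop the first token
mask = [1.0 if pos % 2 == 0 else 0.0 for pos in range(len(inputs))]

for pos, (inp, tgt, m) in enumerate(zip(inputs, targets, mask)):
    role = "predict y (scored)" if m else "predict next x (masked)"
    print(pos, inp, tgt, role)

# Targets at the scored positions are exactly the y tokens, in order.
scored_targets = [t for t, m in zip(targets, mask) if m]
```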
At each $y$-prediction position, the model outputs a probability distribution over the 5 possible values. Cross-entropy measures how wrong that distribution is:
```python
# mask[::2] = 1.0 means positions 0, 2, 4, 6, 8 get mask=1 (active)
# positions 1, 3, 5, 7 get mask=0 (ignored)
# criterion must return one loss per token, e.g. nn.CrossEntropyLoss(reduction='none')
loss_elementwise = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
# loss_elementwise shape: [batch*9], one loss number per position per sample
masked_loss = (loss_elementwise * masks.view(-1)).sum()
# multiply by mask → zero out the positions whose target is an x, keep only the y-predictions
loss = masked_loss / masks.sum()
# divide by number of active tokens (5 per sample) → average loss per y-prediction
```
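To make the masking arithmetic concrete, here is a plain-Python stand-in for the PyTorch computation above. It is a sketch, not the original code: `masked_cross_entropy` is a hypothetical helper, and the predicted distributions are given directly as probabilities rather than logits.

```python
import math


def masked_cross_entropy(probs, targets, mask):
    """Average negative log-likelihood over positions where mask == 1.

    probs:   one dict {token: probability} per position
    targets: the correct next token at each position
    mask:    1.0 for scored positions, 0.0 for ignored ones
    """
    per_position = [-math.log(p[t]) for p, t in zip(probs, targets)]
    masked_sum = sum(nll * m for nll, m in zip(per_position, mask))
    return masked_sum / sum(mask)


uniform = {tok: 0.2 for tok in range(1, 6)}   # a model that always guesses uniformly
probs = [uniform] * 9
targets = [3, 4, 5, 2, 1, 1, 4, 5, 2]
mask = [1.0, 0.0] * 4 + [1.0]

loss = masked_cross_entropy(probs, targets, mask)
print(round(loss, 3))  # 1.609
```

A uniform guesser scores ln(5) at every active position, so its masked average is also ln(5). Whatever is predicted at the masked positions has no effect on the result.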
The 5 $y$-positions have different difficulty levels:
| Step | Prior context | Best possible loss |
|---|---|---|
| y₁ | None — 5 candidates | ln(5)≈1.609 |
| y₂ | 1 pair seen — 4 candidates | ln(4)≈1.386 |
| y₃ | 2 pairs — 3 candidates | ln(3)≈1.099 |
| y₄ | 3 pairs — 2 candidates | ln(2)≈0.693 |
| y₅ | 4 pairs — 1 candidate | 0.000 |
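
Averaging the five per-step floors gives the lowest mean loss any model can reach on valid permutation sequences. The sketch below assumes every mapping is equally likely, so the best prediction at each step is uniform over the remaining candidates.

```python
import math

V = 5
floors = [math.log(V - step) for step in range(V)]  # ln 5, ln 4, ln 3, ln 2, ln 1
mean_floor = sum(floors) / V                         # equals ln(5!) / 5

for step, floor in enumerate(floors, start=1):
    print(f"y{step}: {floor:.3f}")
print(f"mean: {mean_floor:.4f}")
```

The mean comes out at roughly 0.957, which is where the training losses reported later in the article (about 0.959 and 0.956) level off.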
The token embedding encodes what a token is (e.g. "I am the number 3"). The positional embedding encodes where it is in the sequence (e.g. "I am at position 4"). Adding them together gives each position's representation both pieces of information.
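In symbols, the sum described above can be written as follows. The names $E_{\text{tok}}$, $E_{\text{pos}}$ and $h_t$ are notation assumed here for illustration; the range of $t$ matches the 9 input positions.

$$
h_t = E_{\text{tok}}[x_t] + E_{\text{pos}}[t], \qquad h_t \in \mathbb{R}^{d_{\text{model}}}, \quad t = 0, \dots, 8
$$

Since $E_{\text{pos}}[t]$ can differ from one position to the next, the same token can produce a different $h_t$ depending on where it appears.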
Without the positional embedding, the model would treat [3,3,4,5] the same as [4,5,3,3]; it would be blind to order.

| Version | d | heads | Learns recall? |
|---|---|---|---|
| v1 | 16 | 1 | ✗ No |
| v2 | 32 | 1 | ~ Partial |
| v3 ✓ | 32 | 2 | ✓ 100.00% |
Every token is represented as a vector of $d_{\text{model}}$ numbers. Think of it as the "working memory" size for each position. Everything the model knows about a token — its identity, its position, the context it has gathered — must fit in this vector.
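The 5 tokens need to be well separated in embedding space; roughly 5–6 dimensions are sufficient.
The model must know whether a token is an x or a y, that is, whether it sits at an even or odd position; this takes roughly 2–3 dimensions.
The model must track which $y$-values are used; there are $2^5=32$ possible subsets, which needs roughly 8 or more dimensions.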
Test sequence from an unseen permutation: [2,3, 1,5, 4,4, 3,2, 5,1] — mapping: $2\to3, 1\to5, 4\to4, 3\to2, 5\to1$
By comparing the model's output distribution against the exact mathematical Bayesian posterior over the entire test set (5,000 samples), we can quantitatively measure the precision of its probabilistic reasoning:
| Prediction Position | Categorical Accuracy | Bayesian Target | Mean Absolute Error (MAE) | Max Absolute Error (Max-AE) |
|---|---|---|---|---|
| Step 1 (no history) | 9.56% | 20.00% | 0.012388 | 0.034079 |
| Step 2 (1 pair seen) | 16.40% | 25.00% | 0.015530 | 0.072218 |
| Step 3 (2 pairs seen) | 24.56% | 33.33% | 0.018369 | 0.101983 |
| Step 4 (3 pairs seen) | 42.26% | 50.00% | 0.015447 | 0.119183 |
| Step 5 (4 pairs seen) | 100.00% | 100.00% | 0.000292 | 0.001334 |
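
The comparison behind this table can be sketched for a single prediction. The code assumes the Bayesian target is uniform over the y-values not yet used, which matches the "Bayesian Target" column. The helper names and the example model output are invented for illustration, and how the article aggregates errors over the 5,000 samples is not specified.

```python
def bayesian_posterior(used_y, V=5):
    """Uniform posterior over the y-values not yet used (valid permutations only)."""
    remaining = [y for y in range(1, V + 1) if y not in used_y]
    return {y: (1.0 / len(remaining) if y in remaining else 0.0)
            for y in range(1, V + 1)}


def abs_errors(predicted, target):
    """Mean and max absolute error between two distributions over the same tokens."""
    errs = [abs(predicted[y] - target[y]) for y in target]
    return sum(errs) / len(errs), max(errs)


# Step 3 of the example test sequence: y=3 and y=5 are already used.
target = bayesian_posterior(used_y={3, 5})
# Hypothetical model output, invented for illustration only.
predicted = {1: 0.30, 2: 0.35, 3: 0.01, 4: 0.33, 5: 0.01}
mae, max_ae = abs_errors(predicted, target)
print(target)
print(mae, max_ae)
```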
We take a real test sequence, keep the first 4 pairs intact, then replace x₅ with a repeat of x₁. The correct answer is x₁'s original y — which is already in the "used" set.
```python
# Built programmatically from a real test sequence (no hardcoding)
inputs, targets, _ = test_dataset[0]
full_seq = [inputs[0].item()] + targets.tolist()
mapping = {full_seq[i]: full_seq[i+1] for i in range(0, 10, 2)}
x_order = [full_seq[i] for i in range(0, 10, 2)]
probe_seq = []
for x in x_order[:4]:  # first 4 pairs: normal
    probe_seq.append(x)
    probe_seq.append(mapping[x])
probe_seq.append(x_order[0])  # x5 = repeat of x1
probe_seq.append(mapping[x_order[0]])  # correct answer if model recalls
```
Probe sequence: [2, 4, 5, 3, 4, 5, 3, 2, 2, 4] — x=2 repeats, correct recall answer = y=4
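The two competing answers for a probe can be computed directly from the sequence. The helper below is a hypothetical checker written for this explanation, not part of the original code. It assumes the first four x tokens are distinct and the fifth repeats one of them.

```python
def probe_answers(probe_seq, V=5):
    """Return (recall_answer, elimination_answer) for a 5-pair probe sequence.

    Assumes the first four x tokens are distinct and the fifth x repeats one of them.
    """
    context = probe_seq[:8]                          # first four (x, y) pairs
    seen = dict(zip(context[0::2], context[1::2]))   # x -> y as read from context
    query_x = probe_seq[8]
    recall = seen[query_x]                           # y previously paired with the repeated x
    unused = [y for y in range(1, V + 1) if y not in seen.values()]
    return recall, unused[0]                         # exactly one y is unused after 4 pairs


print(probe_answers([2, 4, 5, 3, 4, 5, 3, 2, 2, 4]))  # (4, 1)
```

A model that only eliminates picks the second value; a model that recalls picks the first.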
The original dataset contains only valid permutations — every x appears exactly once per sequence. The model was never penalized for ignoring (x→y) associations, because elimination always gave the right answer. There was literally zero gradient signal pushing it toward recall.
We introduce a modified dataset in which 30% of sequences repeat an earlier x as the fifth x (x₅). Everything else (the mapping split, the interleaving logic, the masking) stays identical.
```python
class RepeatedPermutationSequenceDataset(Dataset):
    def __getitem__(self, idx):
        mapping_perm = random.choice(self.mappings)
        mapping = {x: y for x, y in zip(self.elements, mapping_perm)}
        X_order = list(self.elements)
        random.shuffle(X_order)
        if random.random() < self.repeat_ratio:  # 30% of the time:
            X_order[4] = random.choice(X_order[:4])  # x5 becomes a copy of one of x1–x4
        seq = []
        for x in X_order:
            seq.append(x)
            seq.append(mapping[x])  # same mapping: x still maps to the same y
        # inputs, targets, mask: identical to before
        inputs = torch.tensor(seq[:-1], dtype=torch.long)
        targets = torch.tensor(seq[1:], dtype=torch.long)
        mask = torch.zeros(len(inputs), dtype=torch.float32)
        mask[::2] = 1.0
        return inputs, targets, mask
```
On a repeated sequence, four distinct y-values have been used by step 5. Elimination says to pick the one unused value (y=1 in the probe above). But the correct answer is the y that the repeated x mapped to earlier, which is already in the used set. The model cannot minimize loss without learning to recall.
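The behaviour the modified dataset rewards can be written as a small reference rule: recall when the query x has already appeared, eliminate otherwise. This is a sketch of the ideal predictor under that assumption, not a description of what the trained network computes internally; `oracle_distribution` is a hypothetical name.

```python
def oracle_distribution(context_pairs, query_x, V=5):
    """Ideal next-y distribution: recall if query_x was seen, otherwise eliminate."""
    seen = dict(context_pairs)
    values = range(1, V + 1)
    if query_x in seen:                      # repeated x: the answer is already known
        return {y: (1.0 if y == seen[query_x] else 0.0) for y in values}
    remaining = [y for y in values if y not in seen.values()]
    return {y: (1.0 / len(remaining) if y in remaining else 0.0) for y in values}


pairs = [(2, 3), (1, 5), (5, 1), (4, 4)]
print(oracle_distribution(pairs, query_x=2))   # repeated x: all mass on y=3
print(oracle_distribution(pairs, query_x=3))   # new x: all mass on the unused y=2
```

With only the elimination branch, the first query would receive all its probability on y=2, which is the wrong answer for a repeated x.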
The split is on permutation mappings, not sequence types. The 24 test permutations are never used during training. The model is tested on mappings it has never seen, regardless of whether normal or repeated sequences appear during training.
We also upgrade the model: d=32, nhead=2. Everything else in the training loop is identical.
```python
# New dataset with repeats
train_dataset = RepeatedPermutationSequenceDataset(train_mappings, V=V, num_samples=100000)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Upgraded model
model = MinimalCausalTransformer(vocab_size=V, max_seq_len=9, d_model=32,
                                 nhead=2, num_layers=1).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
epochs = 50

# Training loop: identical to before
# (criterion is assumed to be nn.CrossEntropyLoss(reduction='none'), giving one loss per token)
for epoch in range(1, epochs + 1):
    for inputs, targets, masks in train_loader:
        inputs, targets, masks = inputs.to(device), targets.to(device), masks.to(device)
        optimizer.zero_grad()  # clear gradients left over from the previous step
        logits = model(inputs)
        loss_elementwise = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
        masked_loss = (loss_elementwise * masks.view(-1)).sum()
        loss = masked_loss / masks.sum()
        loss.backward()
        optimizer.step()
```
Test sequence from an unseen permutation: [2,3, 1,5, 5,1, 4,4, 3,2] — mapping: 2→3, 1→5, 5→1, 4→4, 3→2
Same probe design: x₁=2 repeats at step 5. Correct recall answer = y=3. Only unused value = y=2.
Steps 1–4 show normal elimination — identical behaviour to Task 1. Used values are crushed to zero. Probability spreads over remaining candidates.
We evaluate against the theoretically perfect Bayesian distribution at each step — for both subgroups at step 5 (pure elimination vs. memorization).
| Step | Categorical Accuracy | Theoretical Perfect | MAE | Max-AE |
|---|---|---|---|---|
| 1 | 14.56% | 20.00% | 0.014704 | 0.055206 |
| 2 | 16.40% | 25.00% | 0.020145 | 0.090151 |
| 3 | 23.36% | 33.33% | 0.021205 | 0.143933 |
| 4 | 41.50% | 50.00% | 0.020931 | 0.168254 |
| 5 — Elimination subgroup | 100.00% | 100.00% | 0.000081 | 0.001548 |
| 5 — Memorization subgroup | 100.00% | 100.00% | 0.000838 | 0.055650 |
Self-attention lets each token in the sequence ask: "which past tokens should I pay attention to?" The heatmap shows the answer. Each row is a token asking. Each column is a token being looked at. The number in each cell is the attention weight — how much the row token attends to the column token. Each row sums to 1.0.
The upper-right triangle is always zero — the causal mask prevents any token from seeing future positions. Only the lower-left triangle matters.
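The two properties of the heatmap (rows sum to 1.0, upper-right triangle is zero) follow from applying a softmax only over the visible columns. Below is a minimal plain-Python sketch with arbitrary scores; it illustrates causal masking in general and does not use the trained model's weights.

```python
import math


def causal_attention_weights(scores):
    """Row-wise softmax over scores[i][j], restricted to columns j <= i."""
    n = len(scores)
    weights = []
    for i in range(n):
        visible = scores[i][: i + 1]                 # causal mask: drop future columns
        peak = max(visible)                          # subtract max for numerical stability
        exps = [math.exp(s - peak) for s in visible]
        total = sum(exps)
        row = [e / total for e in exps] + [0.0] * (n - i - 1)
        weights.append(row)
    return weights


# Arbitrary scores for a 4-token sequence, invented for illustration.
scores = [[0.0, 1.0, 2.0, 3.0],
          [1.0, 0.0, 1.0, 2.0],
          [2.0, 1.0, 0.0, 1.0],
          [0.5, 0.5, 0.5, 0.5]]
weights = causal_attention_weights(scores)
for row in weights:
    print([round(w, 2) for w in row])
```

The first row is always [1.0, 0, 0, ...], because the first token can only attend to itself.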
There is nothing in the architecture that forces Head 1 to do elimination and Head 2 to do recall. The two heads are mathematically symmetric. Across different random seeds, gradient descent sometimes assigns the roles the other way around. The functional split is always the same — only the head number changes. This is actually strong evidence that the specialization is robust: the model always converges to this two-role solution, it just doesn't care which head gets which job.
The recall head creates tight local bonds. Each y attends to itself with near-1.0 weight — making each y's representation a stable, self-anchored vector. Each x then absorbs the representation of its preceding y.
Each (x, y) pair becomes encoded as a unit. The x's post-attention representation carries the y that came right before it.
After attention, each token passes through the Feed-Forward Network — two linear layers with ReLU. The FFN is applied independently to each position. It does not mix information between positions.
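A position-wise FFN is simply the same small function applied to each position's vector separately. The sketch below uses tiny hand-picked weights (d_model=2, d_ff=3) that are invented for illustration; the real model's dimensions and weights differ.

```python
def relu(v):
    return [max(0.0, a) for a in v]


def linear(v, W, b):
    """Compute W @ v + b for a weight matrix given as a list of rows."""
    return [sum(w * a for w, a in zip(row, v)) + bias for row, bias in zip(W, b)]


def ffn(x, W1, b1, W2, b2):
    """Two linear layers with a ReLU in between, applied to one position's vector."""
    return linear(relu(linear(x, W1, b1)), W2, b2)


# Tiny illustrative weights (d_model=2, d_ff=3), invented for this example.
W1, b1 = [[1.0, 0.0], [0.0, 1.0], [1.0, -1.0]], [0.0, 0.0, 0.0]
W2, b2 = [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]], [0.0, 0.0]

sequence = [[1.0, 2.0], [3.0, 1.0]]
outputs = [ffn(x, W1, b1, W2, b2) for x in sequence]   # same weights at every position
print(outputs)
```

Changing the vector at one position leaves the outputs at all other positions unchanged, which is what "does not mix information between positions" means in practice.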
Research (Geva et al., 2021) shows FFN layers act as key-value memories: the first layer's weights are "keys", the second layer's weights are "values." The FFN learns to map specific input patterns to specific output content.
The attention weights are directly readable — that's what the heatmaps show. The FFN's internal operation is less transparent. What we can say:
Model trains on clean permutation sequences. Learns elimination perfectly — crushed used values, 100% accuracy at step 5 by deduction. Fails the memorization probe completely (99.91% on the wrong answer). The training signal never rewarded recall, so the model never learned it.
Model trains on sequences where 30% have a repeated x. Now recall is required to minimize loss. The model discovers both skills. Memorization probe: 99.85% on the correct recalled value. Formal evaluation: 100% accuracy on both elimination subgroup and memorization subgroup at step 5.
| Config | Elim | Recall | Loss |
|---|---|---|---|
| d=16, 1h, normal | ✓ 100% | ✗ 0% | ~0.959 |
| d=32, 2h, repeated | ✓ 100% | ✓ 100% | ~0.956 |
Loss alone is misleading. Both models plateau near 0.96 (about 0.959 and 0.956), yet only one of them has learned recall. You must probe to understand what the model actually learned.
Models learn what training rewards. No repeated sequences → no gradient toward recall → model never learns it. Add the signal → the skill appears.
Two heads enable a qualitative leap. Not just more capacity — a structurally different computation. One head per skill, discovered autonomously.
Specialization is emergent and robust. No head was told its role. Across random seeds, the same split always appears — just with head labels swapped.
Recall requires attention + FFN together. Attention forms the (x,y) bonds; the FFN performs the associative lookup. Neither is sufficient alone.
Multi-head attention discovers complementary functional roles without being told to. One head tracks what has been used. One head binds what belongs together. This emergent modularity is not designed — it is found. And it only appears when the training data actually requires it.