An end-to-end mechanistic study revealing how tiny transformers solve in-context learning. We reverse-engineer the model's weights and activations to show how multi-head attention self-organizes into specialized, parallel circuits—cooperatively decoupling Bayesian elimination and associative recall without explicit supervision.
The mapping 3→3, 4→5, 2→1, 1→4, 5→2 is revealed pair by pair. The model must read it from context, because it changes on every single example.
A permutation over V=5 is a rule that maps each of the numbers 1–5 to a unique number in 1–5 (possibly itself, as in 3→3). Every input has exactly one output. No two inputs share an output.
The mapping is different on every single training example. The model cannot memorize "3 always maps to 3" — because in the next example, 3 might map to 1.
Instead, the model must learn a meta-skill: read the (x,y) pairs from context, then answer new queries using what it just read.
Read pairs from context. When the model sees 3, 3 it should understand: "3 maps to 3 in this episode."
Track used values. Once y=3 has been output, it cannot appear again (it's a valid permutation — no repeats).
Answer queries. When given a new x, use the pairs seen so far to find its y — or at least narrow the candidates.
All 120 permutations are split once, before any training. 80% (96 permutations) become training mappings. 20% (24) become test mappings. These sets never overlap.
This is critical: the model is evaluated on mappings it has never seen. We're testing the general skill of reading context — not memorization of specific mappings.
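The split described above can be sketched in a few lines. This is a minimal illustration, not the original code: the seed value and the variable names (`all_perms`, `train_perms`, `test_perms`) are assumptions.

```python
import itertools
import random

V = 5
elements = list(range(1, V + 1))

# Enumerate all V! = 120 permutations of the output values.
all_perms = list(itertools.permutations(elements))

# Shuffle once with a fixed seed (the seed value is an assumption),
# then split 80/20 before any training happens.
rng = random.Random(0)
rng.shuffle(all_perms)
n_train = len(all_perms) * 4 // 5
train_perms = all_perms[:n_train]   # 96 training mappings
test_perms = all_perms[n_train:]    # 24 held-out test mappings

# The two sets share no mapping.
overlap = set(train_perms) & set(test_perms)
```

Because the split happens once and up front, no amount of resampling during training can leak a test mapping into the training set.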
Every time the dataloader asks for a sample, three things happen fresh:
Choosing a vocabulary size of $V=5$ and configuring the dataloader's epoch sizes represent a deliberate balance between computational simplicity, architectural capacity constraints, and optimization stability.
There are $5! = 120$ possible permutation mappings. Combined with the $5! = 120$ possible orderings of the inputs, this yields $120 \times 120 = 14{,}400$ unique sequence configurations, providing a diverse training space.
At $V=4$ ($24$ permutations), the mappings are fewer than the model's intermediate dimensions ($d_{\text{model}}=32, d_{\text{ff}}=64$), allowing neurons to hardcode each rule. At $V=5$, the 120 permutations outnumber these dimensions, forcing the model to learn a compact in-context retrieval algorithm.
With $100{,}000$ training samples per epoch, the model sees each of the $96 \times 120 = 11{,}520$ unique training configurations roughly $8.7$ times per epoch on average. This high density provides the Adam optimizer with stable, smoothed gradient updates, helping the narrow network avoid bad local minima.
By withholding 24 of the 120 mappings entirely from training and evaluating on $5{,}000$ test samples, we ensure that high test accuracy reflects a generalized meta-skill rather than memorized mappings. The test space contains $24 \times 120 = 2{,}880$ configurations, so $5{,}000$ samples give about $1.7$ draws per configuration on average, which keeps the statistical metrics stable.
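The counts quoted in this section follow from simple arithmetic, reproduced here as a quick sanity check. The variable names are illustrative only.

```python
import math

V = 5
n_perms = math.factorial(V)          # 120 possible mappings
n_orders = math.factorial(V)         # 120 possible x-orderings

total_configs = n_perms * n_orders   # all unique sequences
train_configs = 96 * n_orders        # sequences built from training mappings
test_configs = 24 * n_orders         # sequences built from held-out mappings

# Average number of times each configuration is drawn.
train_visits_per_epoch = 100_000 / train_configs
test_draws_per_config = 5_000 / test_configs
```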
```python
def __getitem__(self, idx):
    # Pick one of the 96 training permutations
    mapping_perm = random.choice(self.mappings)
    mapping = {x: y for x, y in zip(self.elements, mapping_perm)}

    # Shuffle [1,2,3,4,5] → e.g. [3,4,2,1,5]
    X_order = list(self.elements)
    random.shuffle(X_order)

    # Interleave x and y: [3,3, 4,5, 2,1, 1,4, 5,2]
    seq = []
    for x in X_order:
        seq.append(x)
        seq.append(mapping[x])

    inputs = torch.tensor(seq[:-1], dtype=torch.long)   # tokens 0–8 (length 9)
    targets = torch.tensor(seq[1:], dtype=torch.long)   # tokens 1–9 (length 9, shifted by 1)

    mask = torch.zeros(len(inputs), dtype=torch.float32)
    mask[::2] = 1.0  # positions 0,2,4,6,8 → the y-prediction positions
    return inputs, targets, mask
```
Say the mapping for this example is 3→3, 4→5, 2→1, 1→4, 5→2 and the x-order is [3,4,2,1,5]. The full sequence is [3,3, 4,5, 2,1, 1,4, 5,2].
Positions 0,2,4,6,8 are x tokens (inputs/queries). Positions 1,3,5,7,9 are y tokens (answers).
inputs = seq[:-1] — drop the last token. Length = 9.
The model reads these 9 tokens all at once and produces 9 predictions simultaneously.
targets = seq[1:] — drop the first token. Length = 9.
Each target is the token that comes after the corresponding input. This is the standard language-model "next token prediction" setup.
Let's line them up and see what the model is predicting at each position:
| Position | Input token (seen) | Target token (predict) | Meaningful? | What the model must do |
|---|---|---|---|---|
| 0 | 3 | 3 | YES ✓ | Given $x_1=3$, predict $y_1$. (No prior context — hardest step.) |
| 1 | 3 | 4 | NO | Predict next $x$ (which is random) — not meaningful, loss masked out. |
| 2 | 4 | 5 | YES ✓ | Given $x_2=4$, predict $y_2$. Can see ($3\to3$) already. |
| 4 | 2 | 1 | YES ✓ | Given $x_3=2$, predict $y_3$. Can see ($3\to3, 4\to5$). |
| 6 | 1 | 4 | YES ✓ | Given $x_4=1$, predict $y_4$. Can see ($3\to3, 4\to5, 2\to1$). |
| 8 | 5 | 2 | YES ✓ | Given $x_5=5$, predict $y_5$. All 4 prior pairs visible. |
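The alignment in the table can be reproduced with plain Python lists. This sketch mirrors the slicing and masking logic of the dataloader without PyTorch; the name `scored` is introduced here for illustration.

```python
mapping = {3: 3, 4: 5, 2: 1, 1: 4, 5: 2}
x_order = [3, 4, 2, 1, 5]

# Interleave x and y tokens into one sequence of length 10.
seq = []
for x in x_order:
    seq += [x, mapping[x]]

inputs = seq[:-1]    # drop the last token (length 9)
targets = seq[1:]    # drop the first token (length 9)

# Even positions hold an x as input, so their target is the matching y.
mask = [1.0 if i % 2 == 0 else 0.0 for i in range(len(inputs))]

# (position, input token, target token) for every position that is scored.
scored = [(i, inputs[i], targets[i]) for i in range(len(inputs)) if mask[i] == 1.0]
```

Each tuple in `scored` corresponds to one "YES" row of the table: the input is an $x$ and the target is the $y$ it maps to.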
At each $y$-prediction position, the model outputs a probability distribution over the 5 possible values. Cross-entropy measures how wrong that distribution is:
```python
# criterion is assumed to be nn.CrossEntropyLoss(reduction='none'),
# so it returns one loss value per position instead of a single scalar.
# mask[::2] = 1.0 means positions 0, 2, 4, 6, 8 get mask=1 (active);
# positions 1, 3, 5, 7 get mask=0 (ignored).
loss_elementwise = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
# loss_elementwise shape: [batch*9], one loss number per position per sample

# Multiply by the mask → zero out the positions that predict the next x,
# keep only the y-predictions
masked_loss = (loss_elementwise * masks.view(-1)).sum()

# Divide by the number of active tokens (5 per sample) → average loss per y-prediction
loss = masked_loss / masks.sum()
```
The 5 $y$-positions have different difficulty levels:
| Step | Prior context | Best possible loss |
|---|---|---|
| y₁ | None — 5 candidates | ln(5)≈1.609 |
| y₂ | 1 pair seen — 4 candidates | ln(4)≈1.386 |
| y₃ | 2 pairs — 3 candidates | ln(3)≈1.099 |
| y₄ | 3 pairs — 2 candidates | ln(2)≈0.693 |
| y₅ | 4 pairs — 1 candidate | 0.000 |
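Averaging the five per-step floors gives the lowest mean loss any predictor can reach on this task without having seen the mapping before. The short calculation below is a sketch of that arithmetic; it assumes the predictor spreads its probability uniformly over the values not yet used.

```python
import math

V = 5

# At step k (0-indexed), V - k candidate values remain, so the best
# achievable cross-entropy is ln(V - k): ln5, ln4, ln3, ln2, ln1.
per_step_floor = [math.log(V - k) for k in range(V)]

# The masked loss averages over the five y-positions.
mean_floor = sum(per_step_floor) / V   # equals ln(120) / 5
```

The result is about 0.957, which is in the same range as the test losses reported in the summary table (0.959 and 0.956).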
The token embedding encodes what a token is (e.g. "I am the number 3"). The positional embedding encodes where it is in the sequence (e.g. "I am at position 4"). Adding them together gives each position's representation both pieces of information.
Without positional information, the model would treat [3,3,4,5] the same as [4,5,3,3]; it would be blind to order.

| Version | d | heads | Learns recall? |
|---|---|---|---|
| v1 | 16 | 1 | ✗ No |
| v2 | 32 | 1 | ~ Partial |
| v3 ✓ | 32 | 2 | ✓ 100.00% |
Every token is represented as a vector of $d_{\text{model}}$ numbers. Think of it as the "working memory" size for each position. Everything the model knows about a token — its identity, its position, the context it has gathered — must fit in this vector.
The 5 tokens need to be well separated in embedding space. About 5–6 dimensions are sufficient.
Is this token an x or a y, at an even or odd position? ~2–3 dimensions.
Which $y$-values have been used? There are $2^5=32$ possible subsets, which needs ~8+ dimensions.
Test sequence from an unseen permutation: [2,3, 1,5, 4,4, 3,2, 5,1] — mapping: $2\to3, 1\to5, 4\to4, 3\to2, 5\to1$
By comparing the model's output distribution against the exact mathematical Bayesian posterior over the entire test set (5,000 samples), we can quantitatively measure the precision of its probabilistic reasoning:
| Prediction Position | Categorical Accuracy | Bayesian Target | Mean Absolute Error (MAE) | Max Absolute Error (Max-AE) |
|---|---|---|---|---|
| Step 1 (no history) | 9.56% | 20.00% | 0.012388 | 0.034079 |
| Step 2 (1 pair seen) | 16.40% | 25.00% | 0.015530 | 0.072218 |
| Step 3 (2 pairs seen) | 24.56% | 33.33% | 0.018369 | 0.101983 |
| Step 4 (3 pairs seen) | 42.26% | 50.00% | 0.015447 | 0.119183 |
| Step 5 (4 pairs seen) | 100.00% | 100.00% | 0.000292 | 0.001334 |
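For reference, the Bayesian target at each step is a uniform distribution over the values that have not yet appeared as a $y$. The sketch below shows one way to compute that posterior and the two error metrics for a single prediction. The helper names and the `model_probs` values are made up for illustration, and how the original evaluation aggregates errors across samples is not specified here.

```python
def bayes_posterior(seen_ys, V=5):
    """Uniform distribution over the y-values not yet used in this episode."""
    remaining = [y for y in range(1, V + 1) if y not in seen_ys]
    p = 1.0 / len(remaining)
    return [p if y in remaining else 0.0 for y in range(1, V + 1)]

def abs_errors(model_probs, target_probs):
    """Per-class absolute difference between two distributions."""
    return [abs(m - t) for m, t in zip(model_probs, target_probs)]

# Step 3: two pairs seen, with y-values 3 and 5 already used.
posterior = bayes_posterior({3, 5})

# Hypothetical model output over the values 1..5 (illustrative numbers only).
model_probs = [0.30, 0.36, 0.01, 0.32, 0.01]

errs = abs_errors(model_probs, posterior)
mae = sum(errs) / len(errs)   # mean absolute error for this prediction
max_ae = max(errs)            # max absolute error for this prediction
```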
To determine if the models have actually learned associative recall (matching $x \to y$ in memory) or are simply performing elimination, we construct a probe sequence on the repeated-prompt dataset. Here, the fifth $x$ token repeats $x_1$ ($x_5 = x_1$):
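A minimal sketch of how such a probe could be constructed, assuming the fifth query simply reuses the first $x$. The helper `build_repeat_probe` is hypothetical and not taken from the original codebase. It also returns the two competing answers: the one recall would give and the one pure elimination would give.

```python
def build_repeat_probe(mapping, x_order):
    """Build a 10-token sequence whose fifth x repeats the first x.

    Returns (seq, recall_answer, elimination_answer).
    """
    first_four = x_order[:4]
    seq = []
    for x in first_four:
        seq += [x, mapping[x]]

    # Fifth query repeats x_1; the correct target is the y already seen for it.
    x_rep = first_four[0]
    seq += [x_rep, mapping[x_rep]]

    # Pure elimination would instead pick the single y-value not yet used.
    used = {mapping[x] for x in first_four}
    unused = [y for y in mapping.values() if y not in used]
    return seq, mapping[x_rep], unused[0]

mapping = {3: 3, 4: 5, 2: 1, 1: 4, 5: 2}
seq, recall_answer, elimination_answer = build_repeat_probe(mapping, [3, 4, 2, 1])
```

The two answers always differ, so accuracy on the final position separates a model that recalls from one that only eliminates.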
We trained **both** models on this repeated dataset. Yet, the architectural capacity of the model dictates whether it can solve it.
Even with repeat training, this architecture lacks the representational bandwidth to retain both rules.
Doubling the model dimension and adding a head allows both algorithms to run in parallel.
When evaluated on the test set with the repeated structure, Model 2 successfully deploys both elimination and recall at Step 5. The metrics show precise execution of both processes:
| Subgroup Evaluated (Model 2) | Categorical Accuracy | Mean Absolute Error (MAE) | Max Absolute Error (Max-AE) |
|---|---|---|---|
| Step 5 (Pure Elimination Subgroup) | 100.00% | 0.000081 | 0.001548 |
| Step 5 (Memorization Subgroup) | 100.00% | 0.000838 | 0.055650 |
Each token produces Query (Q), Key (K), and Value (V) vectors. The attention weight from token $i$ to token $j$ is the dot product $Q_i \cdot K_j$, scaled (by $1/\sqrt{d_k}$ in the standard formulation) and passed through a softmax over the positions $j$ that token $i$ can see. It sets the fraction of token $j$'s value vector that flows into token $i$'s output. By reading these attention maps, we can work out how the models allocate their heads.
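To make the computation concrete, here is a small pure-Python sketch of causal attention weights for a single head. It assumes the standard $1/\sqrt{d_k}$ scaling and a causal mask; the toy `Q` and `K` values are arbitrary and not taken from the trained model.

```python
import math

def causal_attention_weights(Q, K):
    """Row i holds the softmax of Q[i]·K[j] / sqrt(d_k) over positions j <= i."""
    d_k = len(Q[0])
    T = len(Q)
    weights = []
    for i in range(T):
        # Causal mask: position i can only look at positions 0..i.
        scores = [
            sum(q * k for q, k in zip(Q[i], K[j])) / math.sqrt(d_k)
            for j in range(i + 1)
        ]
        # Numerically stable softmax.
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        row = [e / total for e in exps] + [0.0] * (T - i - 1)
        weights.append(row)
    return weights

# Toy example: 3 positions, d_k = 2.
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W = causal_attention_weights(Q, K)
```

Each row of `W` is one row of an attention heatmap: it sums to 1 and is zero for every future position.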
Observe the weights for the query positions ($x_k$):
Historical Accumulation: Just like in the single-head model, this head focuses exclusively on $y$ tokens to track and eliminate already-used outputs.
Local Key-Value Binding: Here, $y$ tokens attend strongly to themselves ($y_2 \to y_2$: $1.00$) and $x$ tokens attend to the immediately preceding $y$ ($x_3 \to y_2$: $0.66$).
There is no rule in our PyTorch codebase that states "Head 1 must calculate elimination." Both heads are structurally and mathematically identical. Yet across different seeds, gradient descent consistently enforces this division of labor.
The recall head (Head 2) creates tight local bonds between consecutive ($x, y$) pairs. Specifically:
This creates a chain: each $x$ token's post-attention representation is heavily colored by the $y$ token that preceded it. The ($x, y$) pair is encoded as a unit.
After attention, each token passes through the Feed-Forward Network — two linear layers with a ReLU in between. The FFN is applied independently to each position (it does not mix information between positions).
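The position-wise property can be checked directly. The sketch below builds a randomly initialized FFN with the dimensions quoted earlier ($d_{\text{model}}=32$, $d_{\text{ff}}=64$) in pure Python; the weights are random placeholders, not the trained model's, and `make_ffn` is a hypothetical helper.

```python
import random

def make_ffn(d_model=32, d_ff=64, seed=0):
    """Return a function computing W2 · ReLU(W1 · x + b1) + b2 for one position."""
    rng = random.Random(seed)
    W1 = [[rng.gauss(0, 0.1) for _ in range(d_model)] for _ in range(d_ff)]
    b1 = [0.0] * d_ff
    W2 = [[rng.gauss(0, 0.1) for _ in range(d_ff)] for _ in range(d_model)]
    b2 = [0.0] * d_model

    def ffn(x):
        # First linear layer followed by ReLU.
        h = [max(0.0, sum(w * xi for w, xi in zip(row, x)) + b)
             for row, b in zip(W1, b1)]
        # Second linear layer projects back to d_model.
        return [sum(w * hi for w, hi in zip(row, h)) + b
                for row, b in zip(W2, b2)]

    return ffn

ffn = make_ffn()
rng = random.Random(1)
tokens = [[rng.gauss(0, 1) for _ in range(32)] for _ in range(3)]

# Apply the same FFN to each position separately, in two different orders.
out = [ffn(t) for t in tokens]
out_reversed = [ffn(t) for t in reversed(tokens)]
```

Reordering the positions only reorders the outputs, which confirms that any mixing of information between tokens has to happen in attention, not in the FFN.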
Research on transformers (Geva et al., 2021: "Transformer Feed-Forward Layers Are Key-Value Memories") has shown that FFN layers act as associative memories: the first layer's weights are "keys" and the second layer's weights are "values." The FFN learns to map input patterns to output content.
The model sees thousands of examples where $x$ repeats and the correct answer is the earlier $y$ — not the unused one. This creates a loss signal that pure elimination cannot minimize.
Each $y$ position's representation becomes a stable anchor for its value (strong self-attention). Each $x$ absorbs its preceding $y$. The pair is encoded as a unit in the residual stream.
When $x_1$ repeats at position 8, its token embedding is identical to when it first appeared. The FFN recognizes this embedding pattern and returns $y_1$'s value, regardless of what elimination would suggest.
Both signals feed into the final prediction. When recall is strong (99.9%), it dominates even though elimination says "pick the unused value." The model has learned to trust recall when the context demands it.
We can observe the attention weights directly — that's what the heatmaps show. But the FFN's internal operation is less transparent. What we can say confidently:
A minimal causal transformer (1 layer, $d=32$, 2 heads) trained on a permutation-lookup task — the model must read ($x,y$) pairs from context and answer queries on unseen permutations.
| Config | Elimination | Recall Probe Acc. | Test Loss |
|---|---|---|---|
| $d=16$, 1h, repeated | ✓ 100.00% | ✗ 0.01% | 0.959 |
| $d=32$, 2h, repeated | ✓ 100.00% | ✓ 100.00% | 0.956 |
Loss can be misleading. The $d=16$ model achieves near-identical loss yet fails completely at recall. You must probe to understand the mechanism.
Models learn what training rewards. Without repeated sequences there is zero gradient signal toward recall. Add the signal → the model finds the skill.
The second head enabled a qualitative leap. Not just more capacity — a structurally different computation. One head for elimination, one for recall.
Head specialization is emergent and robust. No head was told its role. The split appears consistently — only the head numbering varies across seeds.
Recall is attention + FFN together. Attention forms the associations; the FFN performs the lookup. Neither is sufficient alone.
Multi-head attention discovers complementary functional roles without being told to — one head tracking what has been used, one binding what belongs together. This emergent modularity is not designed. It is found. And it only appears when the training data actually requires it.