Research Presentation

Can a Tiny Transformer Learn to Remember?

An end-to-end mechanistic study revealing how tiny transformers solve in-context learning. We reverse-engineer the model's weights and activations to show how multi-head attention self-organizes into specialized, parallel circuits—cooperatively decoupling Bayesian elimination and associative recall without explicit supervision.

Architecture: 1-layer Transformer
Vocabulary: V = 5 (numbers 1–5)
Models Compared: d=16 (1h) vs. d=32 (2h)
Key finding: Capacity enables dual circuits
A complete training sequence
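3 3 | 4 5 | 2 1 | 1 4 | 5 ?
Legend: in each pair the first token is the input x and the second is the output y; the final ? is the value to predict.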
The mapping 3→3, 4→5, 2→1, 1→4, 5→2 is revealed pair by pair. The model must read it from context — it changes every single example.
Section 01

The Task: Permutation Lookup

What is a Permutation?

A permutation over V=5 is a rule that maps each of the numbers 1–5 to a number in 1–5, using each output exactly once. Every input has exactly one output. No two inputs share an output. An input may map to itself, as 3→3 does in the example.

1 → 4, 2 → 1, 3 → 3, 4 → 5, 5 → 2
5! = 120 total possible permutations
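
As a quick check of the definition and the count, the following is a minimal sketch using only the Python standard library. The names `ELEMENTS`, `ALL_PERMS`, `as_mapping`, and `is_permutation` are illustrative and do not come from the presentation's codebase.

```python
from itertools import permutations

ELEMENTS = (1, 2, 3, 4, 5)

# Every ordering of the outputs is one permutation: index i holds the y for x = ELEMENTS[i].
ALL_PERMS = list(permutations(ELEMENTS))

def as_mapping(perm):
    """Turn a tuple of outputs into an x -> y dictionary."""
    return dict(zip(ELEMENTS, perm))

def is_permutation(mapping):
    """True if every input has an output and no output is used twice."""
    return sorted(mapping) == list(ELEMENTS) and sorted(mapping.values()) == list(ELEMENTS)

example = {1: 4, 2: 1, 3: 3, 4: 5, 5: 2}   # the mapping drawn above
print(len(ALL_PERMS))                      # 120
print(is_permutation(example))             # True
print(is_permutation({1: 4, 2: 4, 3: 3, 4: 5, 5: 2}))  # False: 4 is used twice
```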

Why Is This Hard?

The mapping is different on every single training example. The model cannot memorize "3 always maps to 3" — because in the next example, 3 might map to 1.

Instead, the model must learn a meta-skill: read the (x,y) pairs from context, then answer new queries using what it just read.

This is called in-context learning — the "program" (the mapping) lives in the input, not the weights.

Three sub-skills the model needs

Skill 1: Read pairs from context. When the model sees 3, 3 it should understand: "3 maps to 3 in this episode."

Skill 2: Track used values. Once y=3 has been output, it cannot appear again (the mapping is a valid permutation, so there are no repeats).

Skill 3: Answer queries. When given a new x, use the pairs seen so far to find its y, or at least narrow the candidates.

Section 02

The Dataset — How Data is Generated

The Train/Test Split

All 120 permutations are split once, before any training. 80% (96 permutations) become training mappings. 20% (24) become test mappings. These sets never overlap.

This is critical: the model is evaluated on mappings it has never seen. We're testing the general skill of reading context — not memorization of specific mappings.

If the model scores well on test mappings, it has genuinely learned a skill, not a lookup table.
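
One way such a split could be implemented is sketched below. The function name `split_permutations`, the use of a seeded `random.Random`, and the seed value are assumptions for illustration; the presentation does not show how its split is actually produced.

```python
import random
from itertools import permutations

def split_permutations(elements=(1, 2, 3, 4, 5), train_frac=0.8, seed=0):
    """Shuffle all permutations once and cut them into disjoint train/test lists."""
    perms = list(permutations(elements))
    rng = random.Random(seed)          # seed value is an assumption, for reproducibility
    rng.shuffle(perms)
    n_train = int(round(train_frac * len(perms)))
    return perms[:n_train], perms[n_train:]

train_perms, test_perms = split_permutations()
print(len(train_perms), len(test_perms))            # 96 24
print(set(train_perms).isdisjoint(test_perms))      # True
```

Splitting at the level of whole mappings, rather than individual sequences, is what keeps every test mapping unseen during training.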

How One Training Example is Built

Every time the dataloader asks for a sample, three things happen fresh:

1. Pick a random mapping from the 96 train perms
2. Shuffle the order of x-values: e.g. [3,4,2,1,5]
3. Interleave: [3,3, 4,5, 2,1, 1,4, 5,2]
Shuffling the x-order is essential. Without it, the model could learn "the 1st pair is always 1→?" and never learn to read context properly.

Why is $V=5$ the Optimal Choice?

Choosing a vocabulary size of $V=5$ and configuring the dataloader's epoch sizes represent a deliberate balance between computational simplicity, architectural capacity constraints, and optimization stability.

Combinatorial Space

There are $5! = 120$ possible permutation mappings. Shuffling the input order ($5! = 120$) yields $120 \times 120 = 14,400$ total unique sequence configurations, providing a diverse training space.

Preventing Weight Memorization

At $V=4$ ($24$ permutations), the mappings are fewer than the model's intermediate dimensions ($d_{\text{model}}=32, d_{\text{ff}}=64$), allowing neurons to hardcode each rule. At $V=5$, the 120 permutations outnumber these dimensions, forcing the model to learn a compact in-context retrieval algorithm.

Optimization Dynamics

With $100,000$ training samples per epoch, the model sees each unique training configuration ($11,520$) roughly $8.7$ times. This high density provides the Adam optimizer with stable, smoothed gradient updates, helping the narrow network avoid bad local minima.

Generalization & Evaluation

By withholding 24 of the 120 mappings entirely from training, evaluation on $5,000$ test samples measures a generalized meta-skill rather than memorization of specific mappings. The test space holds $24 \times 120 = 2,880$ configurations, so 5,000 samples draw each one about 1.7 times on average, which keeps the metrics statistically stable.
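
The counts quoted in this section can be reproduced with a few lines of arithmetic. This sketch counts only the non-repeated sequences described here (one mapping combined with one ordering of the x-values); the variable names are illustrative.

```python
from math import factorial

V = 5
n_mappings = factorial(V)              # 120 permutations of the outputs
n_orders = factorial(V)                # 120 orderings of the x-values
n_train_maps, n_test_maps = 96, 24     # the 80/20 split of the 120 mappings

total_configs = n_mappings * n_orders          # 14400
train_configs = n_train_maps * n_orders        # 11520
test_configs = n_test_maps * n_orders          # 2880

print(total_configs, train_configs, test_configs)
print(round(100_000 / train_configs, 2))       # 8.68 samples per training configuration
print(round(5_000 / test_configs, 2))          # 1.74 samples per test configuration
```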

The Code

import random
import torch

# Method of the dataset class: self.elements is [1,2,3,4,5], self.mappings holds the 96 train perms.
def __getitem__(self, idx):
    mapping_perm = random.choice(self.mappings)          # pick one of 96 train perms
    mapping = {x: y for x, y in zip(self.elements, mapping_perm)}

    X_order = list(self.elements)
    random.shuffle(X_order)                              # shuffle [1,2,3,4,5] → e.g. [3,4,2,1,5]

    seq = []
    for x in X_order:
        seq.append(x)
        seq.append(mapping[x])                           # interleave: [3,3, 4,5, 2,1, 1,4, 5,2]

    inputs  = torch.tensor(seq[:-1], dtype=torch.long) # tokens 0–8  (length 9)
    targets = torch.tensor(seq[1:],  dtype=torch.long) # tokens 1–9  (length 9, shifted by 1)
    mask = torch.zeros(len(inputs), dtype=torch.float32)
    mask[::2] = 1.0                                    # positions 0,2,4,6,8 → the y-prediction positions
    return inputs, targets, mask
Section 03

Exactly What Goes Into the Model — and What Comes Out

This is the most important section to understand before anything else. Let's be completely explicit.

The Raw Sequence (length 10)

Say the mapping for this example is 3→3, 4→5, 2→1, 1→4, 5→2 and the x-order is [3,4,2,1,5]. The full sequence is:

Tokens:    3 3 | 4 5 | 2 1 | 1 4 | 5 2
Positions: 0 1 | 2 3 | 4 5 | 6 7 | 8 9

Positions 0,2,4,6,8 are x tokens (inputs/queries). Positions 1,3,5,7,9 are y tokens (answers).

inputs tensor (what the model sees)

inputs = seq[:-1] — drop the last token. Length = 9.

Tokens:  3 3 | 4 5 | 2 1 | 1 4 | 5
Indices: 0 1 | 2 3 | 4 5 | 6 7 | 8

The model reads these 9 tokens all at once and produces 9 predictions simultaneously.

targets tensor (what we want predicted)

targets = seq[1:] — drop the first token. Length = 9.

Tokens:  3 | 4 5 | 2 1 | 1 4 | 5 2
Indices: 0 | 1 2 | 3 4 | 5 6 | 7 8

Each target is the token that comes after the corresponding input. This is the standard language-model "next token prediction" setup.

Aligning inputs to targets — position by position

Let's line them up and see what the model is predicting at each position:

Position | Input token (seen) | Target token (predict) | Meaningful? | What the model must do
0 | 3 | 3 | YES ✓ | Given $x_1=3$, predict $y_1$. (No prior context, so this is the hardest step.)
1 | 3 | 4 | NO | Predict the next $x$ (which is random); not meaningful, loss masked out.
2 | 4 | 5 | YES ✓ | Given $x_2=4$, predict $y_2$. Can see ($3\to3$) already.
4 | 2 | 1 | YES ✓ | Given $x_3=2$, predict $y_3$. Can see ($3\to3, 4\to5$).
6 | 1 | 4 | YES ✓ | Given $x_4=1$, predict $y_4$. Can see ($3\to3, 4\to5, 2\to1$).
8 | 5 | 2 | YES ✓ | Given $x_5=5$, predict $y_5$. All 4 prior pairs visible.
The model outputs 9 predictions simultaneously in one forward pass. But loss is only computed at positions 0, 2, 4, 6, 8 — the 5 $y$-prediction positions. The odd positions (predicting the next $x$) are masked out because those predictions are meaningless.
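
The alignment can be reproduced in plain Python for the example sequence. This is a sketch using lists instead of tensors; the variable name `supervised` is illustrative.

```python
seq = [3, 3, 4, 5, 2, 1, 1, 4, 5, 2]   # x1,y1, x2,y2, ... for the example above

inputs = seq[:-1]                      # what the model sees (length 9)
targets = seq[1:]                      # next-token targets (length 9)
mask = [1.0 if i % 2 == 0 else 0.0 for i in range(len(inputs))]

# Keep only the positions where the input is an x and the target is its y.
supervised = [(i, x, y) for i, (x, y, m) in enumerate(zip(inputs, targets, mask)) if m == 1.0]

for pos, x, y in supervised:
    print(f"position {pos}: given x={x}, predict y={y}")
```

The five printed lines correspond to the mapping 3→3, 4→5, 2→1, 1→4, 5→2, which confirms that the even positions are exactly the y-prediction positions.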
Section 04

How Loss is Computed — Step by Step

Loss is the number that training minimizes. Understanding exactly what it measures here is essential.

What is Cross-Entropy Loss?

At each $y$-prediction position, the model outputs a probability distribution over the 5 possible values. Cross-entropy measures how wrong that distribution is:

CE_loss = −log( P(correct answer) )
If model says P(y=3) = 0.90 → loss = −log(0.90) = 0.105 (low, good)
If model says P(y=3) = 0.20 → loss = −log(0.20) = 1.609 (high, bad)
Random guessing among 5 values gives P = 0.20 each → loss = −log(0.20) = ln(5) ≈ 1.609. This is the chance-level baseline, not an upper bound: a confidently wrong model scores worse. A model that does better than this has learned something.
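
The numbers above follow directly from the natural logarithm. A minimal sketch (the helper name `cross_entropy` is illustrative):

```python
import math

def cross_entropy(prob_of_correct):
    """Cross-entropy (natural log) for one prediction, given the probability of the true class."""
    return -math.log(prob_of_correct)

print(round(cross_entropy(0.90), 3))   # 0.105
print(round(cross_entropy(0.20), 3))   # 1.609
print(round(math.log(5), 3))           # 1.609, the uniform-guess loss over 5 values
print(round(cross_entropy(0.01), 3))   # 4.605, a confident wrong answer costs more than guessing
```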

The Masking — Why We Don't Train on Every Position

# mask[::2] = 1.0 means positions 0, 2, 4, 6, 8 get mask=1 (active)
# positions 1, 3, 5, 7 get mask=0 (ignored)

# criterion must return one loss per element, e.g. nn.CrossEntropyLoss(reduction='none')
loss_elementwise = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
# loss_elementwise shape: [batch*9] — one loss number per position per sample

masked_loss = (loss_elementwise * masks.view(-1)).sum()
# multiply by mask → zero out x-positions, keep only y-positions

loss = masked_loss / masks.sum()
# divide by number of active tokens (5 per sample) → average loss per y-prediction
Why mask? Because the order of $x$ values is shuffled at random by the dataset, predicting the next $x$ teaches the model nothing about the mapping. Including those positions in the loss would add noise and dilute the signal.
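
The effect of the mask on the average can be seen with a small list-based sketch. The per-position loss values below are made up for illustration and are not measurements from the model; `masked_mean` is a hypothetical helper that mirrors the multiply, sum, and divide steps of the snippet above.

```python
def masked_mean(losses, mask):
    """Average the per-position losses over positions where mask == 1."""
    kept = sum(l * m for l, m in zip(losses, mask))
    return kept / sum(mask)

# Hypothetical per-position losses for one 9-token sample (values made up for illustration).
losses = [1.6, 2.3, 1.4, 2.1, 1.1, 1.9, 0.7, 1.5, 0.0]
mask = [1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]

print(round(masked_mean(losses, mask), 3))        # 0.96, from the 5 y-positions only
print(round(sum(losses) / len(losses), 3))        # 1.4, what an unmasked mean would report
```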

Loss at Each Position — What a Perfect Model Would See

The 5 $y$-positions have different difficulty levels:

Step | Prior context | Best possible loss
y₁ | None (5 candidates) | ln(5) ≈ 1.609
y₂ | 1 pair seen (4 candidates) | ln(4) ≈ 1.386
y₃ | 2 pairs (3 candidates) | ln(3) ≈ 1.099
y₄ | 3 pairs (2 candidates) | ln(2) ≈ 0.693
y₅ | 4 pairs (1 candidate) | 0.000
A model that only does elimination (tracks used values, uniform over the rest) would achieve roughly: (1.609+1.386+1.099+0.693+0.000)/5 ≈ 0.957 average loss per $y$-token.
Our model plateaus at ~0.959 — almost exactly this number. This tells us the baseline model is doing elimination only and nothing more.
The loss alone cannot tell you the mechanism. To know whether the model is using elimination or recall, you have to probe it — which we do in Section 07.
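
The elimination-only baseline quoted in this section can be recomputed as a sketch. Note that the per-step losses sum to ln(5!) = ln(120), so the average is ln(120)/5.

```python
import math

V = 5

# At step k (1-based), k-1 outputs are used, so V-k+1 candidates remain.
per_step = [math.log(V - k + 1) for k in range(1, V + 1)]
baseline = sum(per_step) / V

print([round(x, 3) for x in per_step])   # [1.609, 1.386, 1.099, 0.693, 0.0]
print(round(baseline, 4))                # 0.9575
```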
Section 05

The Model Architecture

Token Embedding: Embedding(6, d_model), integers 1–5 become d-dimensional vectors
+
Positional Embedding: Embedding(9, d_model), positions 0–8 become d-dimensional vectors
Multi-Head Self-Attention: nhead heads, causal mask; each token gathers info from past tokens
Feed-Forward Network (FFN): d_model → 2×d_model → d_model with ReLU activation
Linear Output Layer: d_model → 6 logits → softmax → probability over values 1–5

Why sum token + position embeddings?

The token embedding encodes what a token is (e.g. "I am the number 3"). The positional embedding encodes where it is in the sequence (e.g. "I am at position 4"). Adding them together gives each position's representation both pieces of information.

Without positional embeddings, the model would treat [3,3,4,5] the same as [4,5,3,3] — it would be blind to order.

Configuration Evolution

Version | d | heads | Learns recall?
v1 | 16 | 1 | ✗ No
v2 | 32 | 1 | ~ Partial
v3 ✓ | 32 | 2 | ✓ 100.00%

What is $d_{\text{model}}$ and why does it matter?

Every token is represented as a vector of $d_{\text{model}}$ numbers. Think of it as the "working memory" size for each position. Everything the model knows about a token — its identity, its position, the context it has gathered — must fit in this vector.

Token identity

5 tokens need to be well-separated in space. ~5–6 dimensions sufficient.

Position signal

Is this an x or a y? Even or odd position? ~2–3 dimensions.

Context state

Which $y$-values are used? $2^5=32$ subsets → needs ~8+ dimensions.

$d=16$ falls just short of this rough budget of $6+3+8=17$. $d=32$ gives comfortable headroom, especially for the recall task where ($x\to y$) bindings must also be stored.
Section 06

Results: Elimination and the Bayesian Hypothesis

Test sequence from an unseen permutation: [2,3, 1,5, 4,4, 3,2, 5,1] — mapping: $2\to3, 1\to5, 4\to4, 3\to2, 5\to1$

Input seen so far → [2] | Used y: [] | Target: y=3
y=1: 0.208 | y=2: 0.188 | y=3: 0.201 (✓ target) | y=4: 0.169 | y=5: 0.233
The model has seen only one token: $x=2$. It has zero context — no ($x,y$) pairs have been revealed yet. All five $y$-values are equally possible. Output is close to uniform (≈20% each). This is exactly the right behavior: there is no information to act on.
The true target $y=3$ has only 20.1% probability. That is not a failure. It is the only honest answer given no context.
Input seen so far → [2,3, 1] | Used y: [3] | Target: y=5
y=1: 0.260 | y=2: 0.236 | y=3: 0.000 (✕ used) | y=4: 0.236 | y=5: 0.269 (✓ target)
Elimination kicks in. $y=3$ was used as $y_1$. The model has crushed its probability to ~0.000. The remaining 4 candidates share probability roughly equally (~25% each). The model has learned to track used values and rule them out.
Input seen → [2,3, 1,5, 4] | Used y: [3,5] | Target: y=4
y=1: 0.347 | y=2: 0.368 | y=3: 0.000 (✕ used) | y=4: 0.285 (✓ target) | y=5: 0.000 (✕ used)
Two values eliminated. Three remain ($y=1, y=2, y=4$). The model spreads probability ~⅓ each over valid candidates. It doesn't have enough context to know which of these three $x=4$ maps to — it has not seen $x=4$ before. Rational uncertainty.
Input seen → [2,3, 1,5, 4,4, 3] | Used y: [3,5,4] | Target: y=2
y=1: 0.546 | y=2: 0.454 (✓ target) | y=3: 0.000 (✕ used) | y=4: 0.000 (✕ used) | y=5: 0.000 (✕ used)
Three eliminated. Only $y=1$ and $y=2$ remain. The model is ~50/50. This is completely rational — $x=3$ hasn't been seen before, so without a prior ($x,y$) pair for $x=3$, there's no way to distinguish $y=1$ from $y=2$.
Input seen → [..., 3,2, 5] | Used y: [3,5,4,2] | Target: y=1
y=1: 0.9999 (✓ target) | y=2: 0.000 (✕ used) | y=3: 0.000 (✕ used) | y=4: 0.000 (✕ used) | y=5: 0.000 (✕ used)
99.99% confidence! Four values eliminated — only $y=1$ remains. The model is certain by deduction alone. Critically, it doesn't need to know $x=5$ maps to $y=1$; it just knows nothing else is available.
This is the danger: step 5 always looks perfect for a pure elimination model, regardless of whether it has learned the actual mapping. We must probe it to find out.

Quantitative Validation: The Perfect Bayesian Proof

By comparing the model's output distribution against the exact mathematical Bayesian posterior over the entire test set (5,000 samples), we can quantitatively measure the precision of its probabilistic reasoning:

Prediction Position | Categorical Accuracy | Bayesian Target | Mean Absolute Error (MAE) | Max Absolute Error (Max-AE)
Step 1 (no history) | 9.56% | 20.00% | 0.012388 | 0.034079
Step 2 (1 pair seen) | 16.40% | 25.00% | 0.015530 | 0.072218
Step 3 (2 pairs seen) | 24.56% | 33.33% | 0.018369 | 0.101983
Step 4 (3 pairs seen) | 42.26% | 50.00% | 0.015447 | 0.119183
Step 5 (4 pairs seen) | 100.00% | 100.00% | 0.000292 | 0.001334
Analysis:
  • For Step 1, with no context, no strategy can beat the 20% chance rate in expectation. The measured 9.56% is in fact below chance: when the output is nearly flat, the argmax is decided by tiny residual biases, so categorical accuracy is not an informative metric at this step. What matters is calibration: the MAE of 0.012 against the flat prior shows that the model spreads its probability almost uniformly. In that sense the model knows what it does not know.
  • At Step 5, the model resolves the permutation with near-perfect confidence: a Max-AE of 0.001334 shows that across all test sequences the model never assigned less than about 99.87% probability to the correct remaining element.
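
To make the two error metrics concrete, here is a sketch that compares one model output against the elimination posterior (uniform over unused values). It uses the Step 2 distribution from the worked example earlier in this section. How the presentation aggregates errors across classes and samples is not stated, so the per-class averaging here is an assumption, and the helper names are illustrative.

```python
def elimination_posterior(used, values=(1, 2, 3, 4, 5)):
    """Uniform probability over the values not yet used, zero on used ones."""
    remaining = [v for v in values if v not in used]
    return [1.0 / len(remaining) if v in remaining else 0.0 for v in values]

def abs_errors(model_probs, target_probs):
    """Per-class absolute difference between two distributions."""
    return [abs(p - q) for p, q in zip(model_probs, target_probs)]

# Step 2 of the worked example: y=3 is used; the model's output is copied from the text.
model = [0.260, 0.236, 0.000, 0.236, 0.269]
target = elimination_posterior(used={3})

errs = abs_errors(model, target)
print(target)                           # [0.25, 0.25, 0.0, 0.25, 0.25]
print(round(sum(errs) / len(errs), 4))  # 0.0114  (MAE for this one distribution)
print(round(max(errs), 3))              # 0.019   (Max-AE for this one distribution)
```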
Section 07

Probing Capacity: The Memorization Trap

The Memorization Probe Setup

To determine if the models have actually learned associative recall (matching $x \to y$ in memory) or are simply performing elimination, we construct a probe sequence on the repeated-prompt dataset. Here, the fifth $x$ token repeats $x_1$ ($x_5 = x_1$):

2 4 | 5 3 | 4 5 | 3 2 | 2 ?
The Elimination Shortcut ($y=1$)
Since $y \in \{4, 3, 5, 2\}$ have already been output, the only unused element is $y=1$. A pure elimination circuit will choose $y=1$ with absolute certainty.
The Associative Recall Truth ($y=4$)
Since $x_5 = 2$, and the context explicitly shows $2 \to 4$ at Step 1, the true mapping is $y_5 = 4$. To get this right, the model must ignore elimination and prioritize memorization.
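
The two competing answers for this probe can be computed by hand-written reference rules. This is a sketch of the two strategies as plain functions, not of what the network computes internally; the function names are illustrative.

```python
def elimination_answer(context, values=(1, 2, 3, 4, 5)):
    """Return the unused y-values given a flat [x1, y1, x2, y2, ..., x_query] context."""
    used = set(context[1::2])                  # y tokens sit at odd indices
    return [v for v in values if v not in used]

def recall_answer(context):
    """Return the y paired with an earlier occurrence of the query x, or None if x is new."""
    query = context[-1]
    pairs = dict(zip(context[0:-1:2], context[1::2]))
    return pairs.get(query)

probe = [2, 4, 5, 3, 4, 5, 3, 2, 2]            # the probe sequence above, ending in the query x=2
print(elimination_answer(probe))               # [1]
print(recall_answer(probe))                    # 4
```

On a sequence without a repeated x, `recall_answer` returns `None` and elimination is the only available strategy, which is why the probe needs the repeat to separate the two.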

We trained both models on this repeated dataset. Yet the architectural capacity of the model dictates whether it can solve it.

Model 1 ($d=16$, 1 head) — Capacity Deficit (Fails)

Even with repeat training, this architecture lacks the representational bandwidth to retain both rules.

Probe sequence: [2,4, 5,3, 4,5, 3,2, 2,?] — correct recall = $y=4$
y=1: 0.9991 (elimination trap) | y=4: 0.0002 (recall target)
Result: Fails completely. It reverts exclusively to the elimination circuit because $d_{\text{model}}=16$ does not provide enough active dimensions to store the key-value bindings of recall concurrently.

Model 2 ($d=32$, 2 heads) — Dual-Circuits Enabled (Succeeds)

Doubling the model dimension and adding a head allows both algorithms to run in parallel.

Probe sequence: [2,4, 5,3, 4,5, 3,2, 2,?] — correct recall = $y=4$
y=4: 0.9985 (correct recall!) | y=1: 0.000 (elimination)
Result: Achieves a near-perfect recall rate. At $d_{\text{model}}=32$ and with 2 heads, the network successfully decouples the two circuits, allowing the recall mechanism to completely override elimination when a repeated $x$ is queried.

Proving Dual-Circuit Competence in the Larger Model (Model 2)

When evaluated on the test set with the repeated structure, Model 2 successfully deploys both elimination and recall at Step 5. The metrics show precise execution of both processes:

Subgroup Evaluated (Model 2) | Categorical Accuracy | Mean Absolute Error (MAE) | Max Absolute Error (Max-AE)
Step 5 (Pure Elimination Subgroup) | 100.00% | 0.000081 | 0.001548
Step 5 (Memorization Subgroup) | 100.00% | 0.000838 | 0.055650
The Representation Verdict: Capacity is the ultimate bottleneck. Model 2 achieves a perfect 100.00% accuracy across both subgroups. The low MAE and Max-AE verify that the model has not simply "hacked" the loss; it acts as a mathematically rigorous Bayesian selector when a token is new, and transitions to a high-confidence associative recall engine when a token is repeated.
Section 08

Attention Maps — Decoupling the Mechanisms

Understanding Attention Weight Routing

Each token produces Query (Q), Key (K), and Value (V) vectors. The attention weight from token $i$ to token $j$ is the dot product $Q_i \cdot K_j$, scaled by $1/\sqrt{d_k}$ (with $d_k$ the per-head dimension) and passed through a softmax over the positions $j \le i$ that the causal mask allows. It sets the fraction of token $j$'s value vector that flows into token $i$'s output. Reading these maps shows how each model allocates its heads.
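
A minimal pure-Python sketch of causal scaled dot-product attention weights follows. The toy Q and K values are chosen for illustration only and are not taken from the trained models; `causal_attention_weights` is a hypothetical helper.

```python
import math

def causal_attention_weights(Q, K):
    """Row i holds softmax_j(Q[i]·K[j] / sqrt(d)) over positions j <= i."""
    d = len(Q[0])
    weights = []
    for i, q in enumerate(Q):
        scores = [sum(a * b for a, b in zip(q, K[j])) / math.sqrt(d) for j in range(i + 1)]
        m = max(scores)                                  # subtract max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        row = [e / total for e in exps] + [0.0] * (len(Q) - i - 1)   # future positions get 0
        weights.append(row)
    return weights

# Toy 3-token example with 2-dimensional queries and keys (values chosen for illustration).
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
for row in causal_attention_weights(Q, K):
    print([round(w, 3) for w in row])
```

Each row sums to 1, which is why the entries of an attention map can be read as fractions of information drawn from each earlier position.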

Task 1: Single-Head Model ($d=16$)

Attention Map

d=16, nhead=1

Mathematical Pooling in Action

Observe the weights for the query positions ($x_k$):

  • $x_2$: Attends $1.00$ to $y_1$ ($1/1$).
  • $x_3$: Attends $0.49$ to $y_1$ and $0.51$ to $y_2$ ($\approx 1/2$ each).
  • $x_4$: Attends $0.34$, $0.31$, $0.35$ to $y_1, y_2, y_3$ ($\approx 1/3$ each).
  • $x_5$: Attends $0.24$, $0.22$, $0.28$, $0.25$ to $y_1, y_2, y_3, y_4$ ($\approx 1/4$ each).
This single head acts as a near-uniform accumulator. By spreading attention evenly across all prior $y$ tokens, the representation at position $x_k$ receives an unweighted average of the used values, which the output layer then uses to drive the probabilities of those values toward zero.
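
The pooling idea can be illustrated with an idealized sketch. It assumes, purely for illustration, that each y token carries a one-hot value vector; the real model uses learned dense vectors, so this shows the principle rather than the trained weights.

```python
def one_hot(y, V=5):
    """Idealized value vector for token y: 1 at index y-1, 0 elsewhere."""
    return [1.0 if v == y else 0.0 for v in range(1, V + 1)]

def pooled_used_values(prior_ys, V=5):
    """Uniform attention over prior y tokens = the mean of their value vectors."""
    k = len(prior_ys)
    vecs = [one_hot(y, V) for y in prior_ys]
    return [sum(col) / k for col in zip(*vecs)]

pooled = pooled_used_values([3, 5, 4])          # used y's before x_4 in the Section 06 example
print(pooled)                                   # nonzero exactly at y=3, y=4, y=5
print([v for v, p in zip(range(1, 6), pooled) if p > 0])   # [3, 4, 5]
```

Under this idealization the pooled vector is nonzero exactly on the used values, so a downstream layer only needs to assign large negative logits wherever the pooled entry is positive.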

Task 2: Multi-Head Model ($d=32$)

Head 1 — Elimination

d=32, nhead=2

Historical Accumulation: Just like in the single-head model, this head focuses exclusively on $y$ tokens to track and eliminate already-used outputs.

The $x$ query rows ignore all other $x$ embeddings and build an average of preceding $y$ outputs, ensuring candidate restriction.

Head 2 — Recall

local (x,y) binding

Local Key-Value Binding: Here, $y$ tokens attend strongly to themselves ($y_2 \to y_2$: $1.00$) and $x$ tokens attend to the immediately preceding $y$ ($x_3 \to y_2$: $0.66$).

This isolates consecutive $(x, y)$ pairs as unified spatial states, allowing the Feed-Forward Network to recognize repeated tokens.

Head Specialization as an Emergent Modularity

There is no rule in our PyTorch codebase that states "Head 1 must calculate elimination." Both heads are structurally and mathematically identical. Yet across different seeds, gradient descent consistently enforces this division of labor.

This functional decoupling is highly robust. When given adequate dimensional space ($d_{\text{model}}=32$), the system naturally resolves the objective by allocating specialized algorithmic roles to separate heads. This emergent specialization fails entirely at $d_{\text{model}}=16$.
Section 09

How Recall Actually Works — Attention + FFN

The attention map for the recall head shows a suspicious pattern: $x$ tokens attend to the immediately preceding $y$, not to an earlier $x_1$. So how does the model recall $y_1$ when $x_1$ repeats? The answer is that attention and the FFN do different jobs.

Step 1: What Attention Does

The recall head (Head 2) creates tight local bonds between consecutive ($x, y$) pairs. Specifically:

y₁ attends to itself: 1.00 → y₁'s output representation is strongly self-anchored
x₂ attends to y₁: 1.00 → x₂ absorbs y₁'s representation
y₂ attends to itself: 1.00 → y₂ is self-anchored
x₃ attends to y₂: 0.66 → x₃ absorbs y₂'s representation

This creates a chain: each $x$ token's post-attention representation is heavily colored by the $y$ token that preceded it. The ($x, y$) pair is encoded as a unit.

Step 2: What the FFN Does

After attention, each token passes through the Feed-Forward Network — two linear layers with a ReLU in between. The FFN is applied independently to each position (it does not mix information between positions).

Research on transformers (Geva et al., 2021: "Transformer Feed-Forward Layers Are Key-Value Memories") has shown that FFN layers act as associative memories: the first layer's weights are "keys" and the second layer's weights are "values." The FFN learns to map input patterns to output content.

When $x_1$ repeats at position 8, its token embedding is identical to when it first appeared. The FFN — acting as a key-value store — recognizes this embedding pattern and returns $y_1$'s value, regardless of what elimination would suggest.
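
The key-value reading of a feed-forward layer can be sketched with hand-set weights. The keys and values below are constructed by hand for illustration and are not extracted from the trained model; biases are omitted, and `ffn_memory` is a hypothetical helper.

```python
def relu(z):
    return max(0.0, z)

def ffn_memory(x, keys, values):
    """FFN(x) = sum_i relu(x · k_i) * v_i, the key-value reading of a two-layer FFN (biases omitted)."""
    out = [0.0] * len(values[0])
    for k, v in zip(keys, values):
        act = relu(sum(a * b for a, b in zip(x, k)))     # how strongly x matches key i
        out = [o + act * vj for o, vj in zip(out, v)]
    return out

# Hand-built toy memory: two keys, each paired with the value it should emit.
keys = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
values = [[0.0, 5.0], [7.0, 0.0]]

print(ffn_memory([1.0, 0.0, 0.0], keys, values))   # [0.0, 5.0], first key fires
print(ffn_memory([0.0, 1.0, 0.0], keys, values))   # [7.0, 0.0], second key fires
print(ffn_memory([0.0, 0.0, 1.0], keys, values))   # [0.0, 0.0], no key matches
```

Because the layer is applied to each position independently, it can only react to what is already present in that position's vector, which is why the attention step has to place the relevant pair information there first.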

The Full Recall Mechanism

Training on repeated sequences

The model sees thousands of examples where $x$ repeats and the correct answer is the earlier $y$ — not the unused one. This creates a loss signal that pure elimination cannot minimize.

Head 2 forms (x,y) local bonds via attention

Each $y$ position's representation becomes a stable anchor for its value (strong self-attention). Each $x$ absorbs its preceding $y$. The pair is encoded as a unit in the residual stream.

FFN performs associative lookup

When $x_1$ repeats at position 8, its token embedding is identical to when it first appeared. The FFN recognizes this embedding pattern and returns $y_1$'s value, regardless of what elimination would suggest.

Head 1's elimination signal is overridden

Both signals feed into the final prediction. When recall is strong (99.9%), it dominates even though elimination says "pick the unused value." The model has learned to trust recall when the context demands it.

Why we can't be 100% certain it's the FFN

We can observe the attention weights directly — that's what the heatmaps show. But the FFN's internal operation is less transparent. What we can say confidently:

Confirmed: The recall head's attention pattern (local $x\to\text{preceding } y$ bonding) is necessary but not sufficient. Attention alone cannot perform a lookup to an $x$-position that isn't being attended to.
Strongly implied: The FFN must be completing the recall — it's the only other component with the capacity to perform position-independent associative lookups. This is consistent with Geva et al. (2021) and follow-up work on transformer memory.
Section 10

Summary — The Full Story

What We Built

A minimal causal transformer (1 layer, $d=32$, 2 heads) trained on a permutation-lookup task — the model must read ($x,y$) pairs from context and answer queries on unseen permutations.

Results Comparison

Config | Elimination | Recall Probe Acc. | Test Loss
$d=16$, 1h, repeated | ✓ 100.00% | ✗ 0.01% | 0.959
$d=32$, 2h, repeated | ✓ 100.00% | ✓ 100.00% | 0.956

Key Numbers

Random baseline loss: ln(5) = 1.609
Pure elimination loss floor: ≈ 0.957
Final model loss: 0.956
Step 5 Memorization Probe MAE: 0.000838

The 5 Core Findings

1. Loss can be misleading. The $d=16$ model achieves near-identical loss yet fails completely at recall. You must probe to understand the mechanism.
2. Models learn what training rewards. Without repeated sequences there is zero gradient signal toward recall. Add the signal → the model finds the skill.
3. The second head enabled a qualitative leap. Not just more capacity, but a structurally different computation. One head for elimination, one for recall.
4. Head specialization is emergent and robust. No head was told its role. The split appears consistently; only the head numbering varies across seeds.
5. Recall is attention + FFN together. Attention forms the associations; the FFN performs the lookup. Neither is sufficient alone.

Core insight

Multi-head attention discovers complementary functional roles without being told to — one head tracking what has been used, one binding what belongs together. This emergent modularity is not designed. It is found. And it only appears when the training data actually requires it.

Minimal Transformer Interpretability — d=32, nhead=2, 1 layer, V=5, 50 epochs, repeat_ratio=0.30