Nested Learning & HOPE Architecture

Multi-level optimization with test-time learning (NeurIPS 2025)

Overview
Architecture
Optimizers
Training
Inference

Nested Learning Paradigm

A model is a hierarchy of nested optimization problems, each compressing its own "context flow"

Associative Memory (Definition 1)
M* = arg min_M L̃(M(K); V)
Memory learns to map Keys → Values by minimizing objective L̃

Learning vs. Memorization (from paper)

Memory: A neural update caused by an input

Learning: The process for acquiring effective and useful memory

Update Frequency Hierarchy (Definition 2)

Higher level = lower frequency. A ≻ B means f_A > f_B.

Level 1: Memory Mt f = 1/token (fastest)
Level 2: Projections Wk, Wv, Wq f = 1/batch
Level 3: Momentum f = 1/batch
Level 4: Pre-training f = 1/epoch (slowest)

Key Insight

Training: All levels update at their respective frequencies. Each level has its own gradient flow and context.

Inference: Only Level 1 (Memory) updates! This enables "test-time learning": the model keeps adapting its fast memory without backpropagating into any frozen level.

Transformers are a special case of CMS with k=1 (a single MLP). All their components freeze at inference → "anterograde amnesia".

Architecture Comparison

Transformer

Static after pre-training (CMS with k=1)
xt

Multi-Head Attention

Wq Wk Wv

Static projections

qt=xtWq   kt=xtWk   vt=xtWv
softmax(QKᵀ/√d)V

❄ All frozen at inference

Add & LayerNorm

Feed-Forward Network

(CMS with k=1: single MLP)

❄ Frozen at inference

yt
Limitation: All parameters frozen after pre-training. Like "anterograde amnesia" - cannot form new long-term memories.

HOPE

Self-Modifying Titans + Continuum Memory
xt

Self-Modifying Memory (Titans-based)

Wq Wk Wv

(+ potential data-dependent components)

Memory Mt 🔥 ACTIVE AT INFERENCE
Mt+1 = Mt(I - ktktᵀ) - η∇L

Delta rule (Eq. 28-29): considers token dependencies

Mt

✨ Online learning at test time!

yt = Mt · qt (Eq. 14)
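A minimal numpy sketch of the self-modifying memory step above, assuming L(M; k, v) = ½∥Mk − v∥² so that the gradient step reproduces the delta-rule form M(I − η kkᵀ) + η vkᵀ; the function name and dimensions are illustrative, not from the paper.

```python
import numpy as np

def delta_rule_step(M, k, v, lr=1.0):
    """One delta-rule memory write (in the spirit of Eq. 28-29).

    For L(M; k, v) = 0.5 * ||M @ k - v||^2 the gradient w.r.t. M is
    (M @ k - v) k^T, so M - lr*grad = M (I - lr k k^T) + lr v k^T:
    the update decays the old association along k and writes v k^T.
    """
    grad = np.outer(M @ k - v, k)   # (Mk - v) k^T
    return M - lr * grad

d = 4
M = np.zeros((d, d))
k = np.array([1.0, 0.0, 0.0, 0.0])  # unit-norm key
v = np.array([0.0, 2.0, 0.0, 0.0])
M = delta_rule_step(M, k, v)
y = M @ k                           # retrieval: y_t = M_t q_t with q_t = k_t
```

Because the key is unit-norm, a single step with lr=1 stores the association exactly, so the query recovers v.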

Continuum Memory System (CMS)

Eq. 30: yt = MLP(fk)(...MLP(f1)(xt))
MLP1 f1 fastest
MLP2 f2
MLP3 f3
MLPk fk slowest

f1 > f2 > ... > fk (frequency hierarchy)

Each θ(fℓ) updates every C(ℓ) steps (Eq. 31)

yt
Key Innovation: Memory updates at inference using delta rule, enabling continual learning. CMS provides multi-timescale knowledge storage.
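The CMS chaining and frequency gating can be sketched as follows; this is a toy scaffold under assumed shapes (linear blocks standing in for the MLPs, and a decay-free "eligibility" check instead of real training), not the paper's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

class CMS:
    """Continuum Memory System sketch (Eq. 30-31, simplified).

    k blocks chained as y = f_k(...f_1(x)); block l's parameters are
    only eligible to update every periods[l] steps, giving the
    frequency hierarchy f_1 > f_2 > ... > f_k.
    """
    def __init__(self, dim, periods):
        self.W = [rng.standard_normal((dim, dim)) / dim**0.5 for _ in periods]
        self.periods = periods
        self.step = 0

    def forward(self, x):
        for W in self.W:                 # MLP^(f1) ... MLP^(fk) chain
            x = np.tanh(W @ x)
        return x

    def levels_firing(self):
        """Advance one step; return which levels may update now."""
        self.step += 1
        return [l for l, C in enumerate(self.periods) if self.step % C == 0]

cms = CMS(dim=8, periods=[1, 4, 16])     # fastest -> slowest
fired = [cms.levels_firing() for _ in range(16)]
```

Level 0 fires every step, level 1 every 4th, level 2 only on step 16, mirroring the f1 > f2 > f3 hierarchy in the diagram.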

Transformer

xt

Attention

❄ Frozen

FFN (k=1)

❄ Frozen

yt

HOPE

xt

Memory Mt

🔥 Active

CMS (k MLPs)

Multi-frequency

yt

Optimizers as Associative Memory

How gradient descent with momentum becomes a 2-level nested optimization

Gradient Descent with Momentum (Eq. 17)
mi+1 = αi+1mi - ηi∇L(Wi; xi)
Wi+1 = Wi + mi+1
Note: Paper uses + not - in weight update

Reinterpret as Nested Optimization

Level 2: Weight Update
Wi+1 = Wi + mi+1
Level 1: Momentum as Memory (Eq. 10)
mt+1 = arg min_m -⟨m, ∇L(Wt; xt+1)⟩ + ηt+1∥m - mt∥²
Momentum is a "key-less" associative memory compressing gradients
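The two-level reading can be made concrete with a few lines of plain Python; this is an illustrative sketch (function name and hyperparameters are assumptions), with signs following Eq. 17 as printed above.

```python
def momentum_gd(grads, alpha=0.9, lr=0.1, w0=0.0):
    """Two-level view of GD with momentum (Eq. 17).

    Level 1 (inner, per step): m compresses the stream of gradients,
    acting as a 'key-less' associative memory over {grad_i}.
    Level 2 (outer): the weight update consumes that memory.
    Signs match the document: m_{i+1} = a*m_i - lr*g_i, W_{i+1} = W_i + m_{i+1}.
    """
    m, w = 0.0, w0
    for g in grads:
        m = alpha * m - lr * g   # Level 1: memory (momentum) update
        w = w + m                # Level 2: weight update
    return w, m

# With a constant gradient of 1.0, the memory converges to the
# fixed point m* = -lr / (1 - alpha) = -1.0.
w, m = momentum_gd([1.0] * 200)
```

The fixed point shows what "compression" means here: a constant gradient stream is summarized by a single stable momentum value.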

Extensions (Section 2.3)

More Expressive Objectives

Use L2 regression instead of dot-product → Delta rule (Eq. 21-22)

More Expressive Memory

Replace linear momentum with MLP → Deep Momentum GD (Eq. 23)

Non-Linear Outputs

σ(·) = Newton-Schulz → Muon optimizer (Eq. 24)

Better Backprop

L2 objective considers token dependencies → Eq. 28-29

From Section 2.3: "Adam with a small modification is the optimal associative memory for the models' gradients" (see Appendix C.4)
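As a flavor of the σ(·) = Newton-Schulz step used in Muon-style optimizers, here is the classic cubic Newton-Schulz iteration, a simplified stand-in for the paper's variant: it drives the singular values of a matrix toward 1, approximating the orthogonal factor of its SVD.

```python
import numpy as np

def newton_schulz_orth(G, iters=25):
    """Cubic Newton-Schulz iteration (sketch of a Muon-style sigma(.)).

    X <- 1.5 X - 0.5 X X^T X pushes every singular value of X toward 1,
    so the result approximates the orthogonal factor U V^T of G's SVD.
    Normalizing by the Frobenius norm puts the singular values inside
    the iteration's convergence region (0, sqrt(3)).
    """
    X = G / np.linalg.norm(G)
    for _ in range(iters):
        X = 1.5 * X - 0.5 * X @ X.T @ X
    return X

rng = np.random.default_rng(1)
G = rng.standard_normal((4, 4))
O = newton_schulz_orth(G)        # O @ O.T is close to the identity
```

Applying this to the momentum matrix before the weight update is the "non-linear output" idea: the memory readout is post-processed rather than used raw. (Muon itself uses a tuned quintic polynomial; the cubic above is the textbook version.)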

Training Mode

All levels update at their respective frequencies

Level 4: Pre-training (f = 1/epoch) - Outermost loop
Level 3: Momentum (f = 1/batch) - Gradient compression
mi+1 = αi+1mi - η∇L(Wi; xi)
Wi+1 = Wi + mi+1 (Eq. 17)
Level 2: Projections Wk, Wv, Wq (f = 1/batch)
Level 1: Memory Mt (f = 1/token) - Context compression
Linear Attention (Eq. 13):
Mt = Mt-1 + vtktᵀ
Equivalent to: Mt = arg min_M -⟨Mkt, vt⟩ + ∥M - Mt-1∥²
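The Level-1 recurrence above is ordinary linear attention; a minimal numpy sketch (helper name and toy data are illustrative):

```python
import numpy as np

def linear_attention(keys, values, queries):
    """Linear attention as a recurrent memory (Eq. 13-14 sketch).

    Write: M_t = M_{t-1} + v_t k_t^T (Hebbian outer-product update).
    Read:  y_t = M_t q_t. No softmax and no growing KV cache; the
    whole context is compressed into the fixed-size matrix M.
    """
    d = keys.shape[1]
    M = np.zeros((d, d))
    outs = []
    for k, v, q in zip(keys, values, queries):
        M = M + np.outer(v, k)   # write
        outs.append(M @ q)       # read
    return np.array(outs)

K = np.eye(3)                    # orthonormal keys -> exact recall
V = np.array([[1.0, 0, 0], [0, 2.0, 0], [0, 0, 3.0]])
Y = linear_attention(K, V, K)    # query each stored key back
```

With orthonormal keys the memory stores each (k, v) pair without interference, so querying with k_t returns exactly v_t.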

Training Flow

1
Token arrives → Update Memory Mt (Eq. 13)

Level 1 (fastest): Memory updates every token. Inner optimization with dot-product objective.

2
Batch complete → Backprop through projections

Level 2 activates: Projection layers optimized with accumulated gradients.

3
Compute loss → Update momentum (Eq. 17)

Level 3: Momentum term is itself an associative memory (key-less).

4
Epoch complete → Aggregate all nested updates

Level 4 (slowest): Outermost optimization aggregates all nested updates.
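The four-step flow above reduces to a nested loop whose levels fire at different rates; this sketch logs only the schedule (the actual memory writes, backprop, momentum, and aggregation are elided, and all names are illustrative).

```python
def training_schedule(num_epochs=2, batches_per_epoch=3, tokens_per_batch=4):
    """Frequency skeleton of the four-level training flow.

    Level 1 fires per token, Levels 2-3 per batch, Level 4 per epoch.
    """
    events = []
    for _epoch in range(num_epochs):
        for _batch in range(batches_per_epoch):
            for _token in range(tokens_per_batch):
                events.append("L1:memory")        # every token (fastest)
            events.append("L2:projections")       # every batch
            events.append("L3:momentum")          # every batch
        events.append("L4:pretrain-aggregate")    # every epoch (slowest)
    return events

ev = training_schedule()
```

Counting events makes the frequency hierarchy explicit: 24 Level-1 updates vs. 6 each at Levels 2-3 and only 2 at Level 4.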

Inference Mode

Only Level 1 (Memory) updates - Test-time learning!

Input Token
xt
Wk
Wv
Wq

Eq. 12: kt=xtWk, vt=xtWv, qt=xtWq

kt vt qt
Memory Update ACTIVE!
Mt+1 = Mt(I - ktktᵀ) - η∇L(Mt; kt, vt)

Delta rule (Eq. 28-29) - handles token dependencies

Mt

🔥 Only this updates at inference!

yt = Mt · qt (Eq. 14)

Continuum Memory System (Eq. 30-31)

MLP1
MLP2
...
MLPk

Each MLP(fℓ) updates at its own frequency fℓ, with f1 > f2 > ... > fk

Inference Flow

1
New token xt arrives

At inference, only Level 1 (Memory) updates. All other levels are frozen.

2
Project through frozen Wk, Wv, Wq (Eq. 12)

Projections are 'slow weights' - consolidated knowledge from pre-training.

3
Update memory with delta rule (Eq. 28-29)

Unlike a simple Hebbian write (Mt + vtktᵀ), the delta rule erases the old value stored under kt before writing, which manages memory capacity better.

4
Query retrieves: yt = Mt·qt

Model combines fast (memory) and slow (frozen) knowledge for prediction.
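The capacity difference between the two write rules in step 3 can be shown in a few lines of numpy; this assumes a unit-norm key written twice with different values, purely for illustration.

```python
import numpy as np

k = np.array([1.0, 0.0])            # the same unit-norm key, written twice
v_old = np.array([1.0, 0.0])
v_new = np.array([0.0, 1.0])

# Hebbian write (M += v k^T): the stale association lingers,
# so the readout mixes old and new values.
M_heb = np.outer(v_old, k) + np.outer(v_new, k)
y_heb = M_heb @ k

# Delta-rule write (M <- M (I - k k^T) + v k^T): the projection
# erases whatever was stored under k before the new value goes in.
M_del = np.outer(v_old, k)
M_del = M_del @ (np.eye(2) - np.outer(k, k)) + np.outer(v_new, k)
y_del = M_del @ k                   # retrieves only v_new
```

The Hebbian readout returns v_old + v_new, while the delta-rule readout returns v_new alone: overwriting a key frees its capacity instead of accumulating interference.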

Key Equations Reference

Eq. 1: M* = arg min L̃(M(K); V)
Eq. 12-14: Linear attention formulation
Eq. 17: GD with momentum
Eq. 28-29: Delta rule for HOPE
Eq. 30-31: CMS formulation
Def. 2: Update frequency