Multi-level optimization with test-time learning (NeurIPS 2025)
A model is a hierarchy of nested optimization problems, each compressing its own "context flow"
Memory: A neural update caused by an input
Learning: The process for acquiring effective and useful memory
Higher level = lower update frequency; in the frequency ordering, A ≻ B means fA > fB (A updates more often, so A sits at a lower level)
Training: All levels update at their respective frequencies. Each level has its own gradient flow and context.
Inference: Only Level 1 (Memory) updates! This enables "test-time learning" - the model continues learning without backpropagation.
Transformers: A special case of CMS with k=1 (a single MLP). All components freeze at inference → "anterograde amnesia".
Static projections
❄ All frozen at inference
❄ Frozen at inference
(+ potential data-dependent components)
Delta rule (Eq. 28-29): considers token dependencies
✨ Online learning at test time!
f1 > f2 > ... > fk (frequency hierarchy)
Each θ(fℓ) updates every C(ℓ) steps (Eq. 31)
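The multi-frequency schedule above can be sketched in a few lines (a minimal sketch; the chunk sizes are hypothetical, with level 0 as the fastest):

```python
def levels_to_update(step, chunk_sizes):
    """Return the levels whose parameters update at this step.

    Level l updates every chunk_sizes[l] steps (cf. Eq. 31); level 0 is
    the fastest (per-token memory), the last level the slowest.
    """
    return [l for l, c in enumerate(chunk_sizes) if step % c == 0]

# Hypothetical hierarchy with f1 > f2 > f3: chunk sizes 1, 4, 16.
C = [1, 4, 16]
schedule = {t: levels_to_update(t, C) for t in range(17)}
```

At step 0 every level fires; at most steps only level 0 does, which is exactly the frequency hierarchy f1 > f2 > ... > fk.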
❄ Frozen
❄ Frozen
Multi-frequency
How gradient descent with momentum becomes a 2-level nested optimization
Reinterpret as Nested Optimization
Use L2 regression instead of dot-product → Delta rule (Eq. 21-22)
Replace linear momentum with MLP → Deep Momentum GD (Eq. 23)
σ(·) = Newton-Schulz → Muon optimizer (Eq. 24)
L2 objective considers token dependencies → Eq. 28-29
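The reinterpretation above can be sketched on a toy problem (hedged: plain momentum GD on a scalar quadratic, written as the two nested updates of Eq. 17; step sizes are illustrative):

```python
# Level 1 (fast): the momentum buffer m compresses the gradient stream
# via a linear recurrence (a key-less associative memory).
# Level 2 (slow): the parameters theta consume the compressed state.
def momentum_step(theta, m, grad, lr=0.1, beta=0.5):
    m = beta * m - lr * grad      # inner update: compress gradients
    theta = theta + m             # outer update: apply the memory
    return theta, m

# Minimize f(x) = x^2 (grad = 2x); during training both levels update.
theta, m = 5.0, 0.0
for _ in range(100):
    theta, m = momentum_step(theta, m, 2 * theta)
```

Swapping the inner dot-product objective for L2 regression, or the linear recurrence for an MLP, gives the delta-rule and Deep Momentum GD variants listed above.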
All levels update at their respective frequencies
Level 1 (fastest): Memory updates every token. Inner optimization with dot-product objective.
Level 2 activates: Projection layers optimized with accumulated gradients.
Level 3: Momentum term is itself an associative memory (key-less).
Level 4 (slowest): Outermost optimization aggregates all nested updates.
Only Level 1 (Memory) updates - Test-time learning!
Eq. 12: kt = xt Wk, vt = xt Wv, qt = xt Wq
Delta rule (Eq. 28-29) - handles token dependencies
🔥 Only this updates at inference!
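A minimal sketch of the pieces above (hypothetical shapes; Wk, Wv, Wq stand in for the frozen pre-trained projections of Eq. 12, and the memory M takes a simplified per-token delta-rule write with learning rate 1 and normalized keys):

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4
W_k, W_v, W_q = (rng.standard_normal((d, d)) for _ in range(3))  # frozen

M = np.zeros((d, d))                      # fast memory: updates at inference
for x_t in rng.standard_normal((8, d)):   # stream of test-time tokens
    k_t, v_t = x_t @ W_k, x_t @ W_v       # Eq. 12 projections
    k_t = k_t / np.linalg.norm(k_t)       # unit-norm key for the toy write
    M = M + np.outer(v_t - M @ k_t, k_t)  # delta rule: write the error only
    y_t = M @ (x_t @ W_q)                 # read memory with the query
```

Only M changes in this loop; the projections stay exactly as pre-training left them.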
Each MLP updates at frequency fℓ
At inference, only Level 1 (Memory) updates. All other levels are frozen.
Projections are 'slow weights' - consolidated knowledge from pre-training.
Unlike the simple Hebbian write (Mt + vt·kt⊤), the delta rule first subtracts the memory's current prediction Mt·kt, so it overwrites stale associations instead of accumulating them and manages memory capacity better.
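The contrast can be shown directly (a minimal sketch assuming learning rate 1 and a unit-norm key written twice):

```python
import numpy as np

def hebbian_write(M, k, v):
    return M + np.outer(v, k)        # always adds, even for a known key

def delta_write(M, k, v):
    err = v - M @ k                  # prediction error for this key
    return M + np.outer(err, k)      # write only what is missing

k = np.array([1.0, 0.0])             # unit-norm key, used twice
v1, v2 = np.array([1.0, 2.0]), np.array([3.0, 4.0])

M_h = hebbian_write(hebbian_write(np.zeros((2, 2)), k, v1), k, v2)
M_d = delta_write(delta_write(np.zeros((2, 2)), k, v1), k, v2)
# Hebbian recall M_h @ k blends both values; delta-rule recall M_d @ k
# returns only the latest one, preserving capacity for other keys.
```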
Model combines fast (memory) and slow (frozen) knowledge for prediction.
Eq. 1: M* = arg min_M L̃(M(K); V)
Eq. 12-14: Linear attention formulation
Eq. 17: GD with momentum
Eq. 28-29: Delta rule for HOPE
Eq. 30-31: CMS formulation
Def. 2: Update frequency