Momentum gives every parameter the same learning rate. But in a deep network different parameters live
on wildly different scales: some gradients are consistently large, others tiny. A single step
size that suits one will be hopeless for the other.
Adam (Adaptive Moment Estimation) gives every parameter its own
learning rate. It keeps two running averages of the gradient: the first moment
(the mean — momentum) and the second moment (the mean of the squared gradient
— à la RMSProp). Dividing the first by the root of the second normalizes each parameter's step
by how big its gradients typically are. Large-gradient directions get reined in; small-gradient
directions get amplified.
Deriving the Adam update
At step t let g_t = \nabla L(\theta_{t-1}), the gradient at the current parameters.
Adam carries two buffers, m_t (first moment) and
v_t (second moment), both starting at zero, with decay rates
\beta_1 \approx 0.9 and \beta_2 \approx 0.999.
Step 1 — the first moment (momentum). An exponentially-weighted average of
the gradient:
m_t = \beta_1\, m_{t-1} + (1 - \beta_1)\, g_t.
Step 2 — the second moment (RMS). An exponentially-weighted average of the
squared gradient (element-wise g_t^2):
v_t = \beta_2\, v_{t-1} + (1 - \beta_2)\, g_t^2.
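Steps 1 and 2 can be sketched in a few lines of NumPy. This is a minimal illustration, not a reference implementation: the helper name update_moments and the two example gradient values are made up for the demonstration.

```python
import numpy as np

def update_moments(m, v, g, beta1=0.9, beta2=0.999):
    """One exponentially-weighted update of both buffers (Steps 1 and 2)."""
    m = beta1 * m + (1 - beta1) * g        # first moment: running mean of g
    v = beta2 * v + (1 - beta2) * g**2     # second moment: running mean of g^2 (element-wise)
    return m, v

# Hypothetical gradient: one large-scale and one small-scale parameter.
g = np.array([10.0, 0.01])
m = np.zeros_like(g)   # both buffers start at zero
v = np.zeros_like(g)

m, v = update_moments(m, v, g)
print("m after one step:", m)
print("v after one step:", v)
```

After a single step both buffers hold only a small fraction of g and g^2 (a factor 1 - \beta_1 and 1 - \beta_2 respectively), which is the bias the next step deals with.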
Step 3 — why the moments are biased. Both buffers start at
0, so early on they are pulled toward zero — biased
low. Unrolling the first moment with a steady gradient
g shows the deficit exactly:
m_t = (1 - \beta_1) \sum_{k=1}^{t} \beta_1^{\,t-k}\, g = g\,(1 - \beta_1^{\,t}).
So m_t is the true mean g scaled down by the factor (1 - \beta_1^{\,t}),
a large shortfall when t is small: at t = 1 with \beta_1 = 0.9,
m_1 is only a tenth of g.
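The closed form is easy to check numerically. The sketch below assumes a constant scalar gradient (the value 2.0 is arbitrary) and compares the running buffer with g (1 - \beta_1^{\,t}) at each step.

```python
beta1 = 0.9
g = 2.0      # hypothetical steady gradient
m = 0.0      # buffer starts at zero

for t in range(1, 11):
    m = beta1 * m + (1 - beta1) * g          # Step 1 recursion
    closed_form = g * (1 - beta1**t)         # unrolled sum from Step 3
    print(f"t={t:2d}  m_t={m:.6f}  g(1-beta1^t)={closed_form:.6f}")
```

The two columns agree up to floating-point rounding, and m_t is still well below g after ten steps.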
Step 4 — bias-correct. Divide out exactly that factor to recover an unbiased
estimate:
\hat{m}_t = \frac{m_t}{1 - \beta_1^{\,t}}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^{\,t}}.
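As t \to \infty the correction
(1 - \beta_1^{\,t}) \to 1 and fades away. With \beta_1 = 0.9 it is negligible after a
few dozen steps; with \beta_2 = 0.999 the second-moment correction lingers for a few
thousand. Either way it matters only early in training, exactly when the raw moments are most starved.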
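To get a feel for how fast the correction fades, it helps to tabulate the multiplier 1 / (1 - \beta^{\,t}) for both decay rates. The helper name correction is just for this sketch.

```python
def correction(beta, t):
    """Bias-correction multiplier 1 / (1 - beta^t) applied at step t."""
    return 1.0 / (1.0 - beta**t)

print(" t      beta1=0.9   beta2=0.999")
for t in (1, 10, 100, 1000, 10000):
    print(f"{t:6d}  {correction(0.9, t):10.4f}  {correction(0.999, t):12.4f}")
```

The first-moment multiplier is close to 1 well before step 100, while the second-moment multiplier, with its much slower decay rate, takes far longer to settle.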
Step 5 — the adaptive update. Step along the bias-corrected first moment,
scaled per parameter by the root of the bias-corrected second moment (with a tiny
\varepsilon \approx 10^{-8} to avoid dividing by zero):
\theta_t = \theta_{t-1} - \eta\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \varepsilon}.
The ratio is the magic: a parameter whose gradients are consistently large has large
\sqrt{\hat{v}_t}, so its step shrinks; a parameter with tiny
gradients has small \sqrt{\hat{v}_t}, so its step grows. Every
parameter ends up taking a step of roughly comparable size, no matter its natural scale.
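The five steps fit in one short function. The sketch below is an illustration under simple assumptions, not a production optimizer: the name adam_step, the learning rate 1e-3, and the two gradient values (five orders of magnitude apart, held constant) are all chosen for the demonstration.

```python
import numpy as np

def adam_step(theta, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update (Steps 1 to 5). t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * g              # first moment
    v = beta2 * v + (1 - beta2) * g**2           # second moment
    m_hat = m / (1 - beta1**t)                   # bias correction
    v_hat = v / (1 - beta2**t)
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v

theta_start = np.array([1.0, 1.0])
theta = theta_start.copy()
m = np.zeros(2)
v = np.zeros(2)
g = np.array([100.0, 0.001])   # hypothetical steady gradients on very different scales

for t in range(1, 6):
    theta, m, v = adam_step(theta, g, m, v, t)

moved = theta_start - theta
print("distance moved per parameter:", moved)
```

Although the gradients differ by a factor of 100,000, both parameters move by nearly the same amount, about the learning rate per step.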
With g_t = \nabla L(\theta_{t-1}) and
m_0 = v_0 = 0:
- Two moments. m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t (mean) and v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2 (mean square).
- Bias correction. \hat{m}_t = m_t/(1-\beta_1^{\,t}) and \hat{v}_t = v_t/(1-\beta_2^{\,t}) undo the zero-initialization bias; the correction vanishes as t\to\infty.
- Adaptive update. \theta_t = \theta_{t-1} - \eta\,\hat{m}_t/(\sqrt{\hat{v}_t} + \varepsilon) gives each parameter its own effective learning rate.
Classic L2 regularization adds \lambda\theta to the gradient.
But Adam then divides that penalty by \sqrt{\hat{v}_t} too — so
parameters with large gradients get less decay, which is backwards. The weight
decay gets tangled up in the adaptive scaling.
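A rough back-of-the-envelope calculation shows the size of the effect. The sketch assumes steady gradients, so that \sqrt{\hat{v}_t} is approximately the magnitude of the total gradient; the values of \lambda, \eta, and the two gradient scales are hypothetical.

```python
import numpy as np

lr = 1e-3
lam = 0.01
eps = 1e-8
theta = np.array([1.0, 1.0])          # two weights of equal size
loss_grad = np.array([10.0, 0.1])     # hypothetical typical loss gradients

# L2 regularization folds lam * theta into the gradient Adam sees.
g_total = loss_grad + lam * theta
sqrt_v_hat = np.abs(g_total)          # steady-gradient approximation

# Portion of the Adam step attributable to the L2 penalty.
decay_l2 = lr * lam * theta / (sqrt_v_hat + eps)
ratio = decay_l2[1] / decay_l2[0]
print("effective decay per parameter:", decay_l2)
print("small-gradient / large-gradient decay ratio:", ratio)
```

Two weights of identical size end up decayed at very different rates purely because of their gradient history.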
AdamW fixes this by decoupling weight decay from the gradient:
apply the adaptive step, then shrink the weights separately,
\theta_t = \theta_{t-1} - \eta\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \varepsilon} - \eta\,\lambda\,\theta_{t-1}.
The decay now shrinks every parameter at the same rate \eta\lambda, independent of its
gradient history. This small change tends to improve generalization in practice, which is
why AdamW has become a standard default for training transformers and many other modern large models.
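The decoupled update can be sketched as follows. The function name adamw_step and all numeric values are assumptions for illustration; library implementations may order or parameterize the decay differently. Running one step with and without decay isolates the decay term.

```python
import numpy as np

def adamw_step(theta, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    """One AdamW update: adaptive step plus a separate shrink of the weights."""
    m = beta1 * m + (1 - beta1) * g              # g is the loss gradient only
    v = beta2 * v + (1 - beta2) * g**2
    m_hat = m / (1 - beta1**t)
    v_hat = v / (1 - beta2**t)
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps) - lr * weight_decay * theta
    return theta, m, v

theta = np.array([1.0, 1.0])
g = np.array([10.0, 0.1])            # hypothetical gradients on different scales
m0 = np.zeros(2)
v0 = np.zeros(2)

with_decay, _, _ = adamw_step(theta, g, m0, v0, t=1, weight_decay=0.01)
no_decay, _, _ = adamw_step(theta, g, m0, v0, t=1, weight_decay=0.0)

decay_part = no_decay - with_decay   # what the decay term alone contributed
print("decay applied to each parameter:", decay_part)
```

Both weights receive the same shrinkage, \eta\lambda\theta, regardless of how large their gradients are.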
Three optimizers, one bowl
The same elongated, ill-conditioned bowl, descended three ways from the same start: plain
SGD, SGD with momentum, and Adam. SGD crawls and zig-zags; momentum cuts across but can
overshoot; Adam, rescaling each axis by its own gradient magnitude, drives almost straight
to the minimum. Step through the iterations and watch the lead change.
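A comparison along these lines can be reproduced offline with a small script. Everything specific here is an assumption made for the sketch: the bowl L(x, y) = (x^2 + 50 y^2)/2, the start point (1, 1), the learning rates, and the classical momentum form (velocity = \mu * velocity + gradient). Results depend heavily on those choices, so treat the numbers as illustrative only.

```python
import numpy as np

CURV = np.array([1.0, 50.0])     # hypothetical curvatures: shallow x, steep y

def loss(theta):
    return 0.5 * np.sum(CURV * theta**2)

def grad(theta):
    return CURV * theta

def run_sgd(theta, lr=0.03, steps=100):
    theta = theta.copy()
    for _ in range(steps):
        theta -= lr * grad(theta)
    return theta

def run_momentum(theta, lr=0.03, mu=0.9, steps=100):
    theta = theta.copy()
    vel = np.zeros_like(theta)
    for _ in range(steps):
        vel = mu * vel + grad(theta)             # classical (heavy-ball) momentum
        theta -= lr * vel
    return theta

def run_adam(theta, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8, steps=100):
    theta = theta.copy()
    m = np.zeros_like(theta)
    v = np.zeros_like(theta)
    for t in range(1, steps + 1):
        g = grad(theta)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g**2
        m_hat = m / (1 - beta1**t)
        v_hat = v / (1 - beta2**t)
        theta -= lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta

start = np.array([1.0, 1.0])
results = {
    "sgd": run_sgd(start),
    "momentum": run_momentum(start),
    "adam": run_adam(start),
}
for name, theta in results.items():
    print(f"{name:9s} theta={theta}  loss={loss(theta):.3e}")
```

One thing worth noticing in the output: Adam's two coordinates stay almost identical throughout, because its update is insensitive to a constant rescaling of a coordinate's gradient. From a start on the diagonal it therefore heads along the diagonal toward the minimum, whereas SGD shrinks the steep axis quickly and the shallow axis slowly.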