Part of cs231n assignment 2 is to calculate the gradients of the Batch Normalisation layer. Here are the equations for the BN forward pass:
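For reference, written with the same names the code below uses (μ = sample_mean, v = sample_var, σ = sample_std, x̂ = x_hat), the forward pass over a batch x of shape (N, D) is:

μ = (1/N) · Σ_i x_i
v = (1/N) · Σ_i (x_i − μ)²
σ = sqrt(v + ε)
x̂ = (x − μ) / σ
y = γ·x̂ + β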
The basic partial derivatives of the above equations are as follows; they are the building blocks for finding the final ∂L/∂x.
Here L is the loss function and dout = ∂L/∂y is the upstream gradient.
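Differentiating each forward step with respect to its immediate inputs (treating the other quantities as fixed), the building blocks are essentially:

∂y/∂x̂ = γ,  ∂y/∂γ = x̂,  ∂y/∂β = 1
∂x̂/∂x = 1/σ,  ∂x̂/∂μ = −1/σ,  ∂x̂/∂v = −(x − μ) / (2σ³)
∂σ/∂v = 1/(2σ)
∂μ/∂x_i = 1/N,  ∂v/∂x_i = 2(x_i − μ)/N  (with μ held fixed)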
However, chaining these together by hand is easy to get wrong. Thanks to this post, I understood the process using a computational graph. The graph is laid out below: the forward pass (black in the original figure) runs top-down, and the backward pass (red) runs bottom-up.
Forward pass (top-down; black in the original figure):
(1)  := x, shape (N, D)
(2)  := μ = 1/N · Σ_i (1)_i          (batch mean, shape (D,))
(3)  := (1) − (2)
(4)  := (3)²
(5)  := v = 1/N · Σ_i (4)_i          (variance, shape (D,))
(6)  := σ = sqrt((5) + ε)            (std)
(7)  := 1/(6)
(8)  := (3) · (7)                    (= x̂)
(9)  := γ, shape (D,)
(10) := (8) · (9)
(11) := β, shape (D,)
(12) := (10) + (11) = out, shape (N, D) → Loss

Backward pass (bottom-up; red in the original figure). dK is the gradient that node (K) passes down to its input, written dK_J when node (K) feeds several nodes:
dβ   = dout.sum(axis=0)
dγ   = (dout · (8)).sum(axis=0)
d10  = dout · (9)                         = ∂L/∂x̂, to node (8)
d8_3 = d10 · (7)                          to node (3)
d8_7 = (d10 · (3)).sum(axis=0)            to node (7)
d7   = d8_7 · (−1/(6)²)                   = ∂L/∂σ, to node (6)
d6   = d7 · 0.5/(6)                       = ∂L/∂v (since ∂σ/∂v = 0.5/σ), to node (5)
d5   = d6 · 1/N · np.ones((N, D))         to node (4)
d4   = d5 · 2·(3)                         = ∂L/∂v · ∂v/∂x, to node (3)
d3_1 = d4 + d8_3                          to node (1)
d3_2 = −(d4 + d8_3).sum(axis=0)           = ∂L/∂μ (both the direct piece and the variance piece, labelled ∂L/∂μ and ∂L/∂μ2 below), to node (2)
d2   = d3_2 · 1/N · np.ones((N, D))       to node (1)
dx   = d3_1 + d2                          = ∂L/∂x
In Python:

# inputs: dout = upstream gradient dL/dout, cache from the forward pass
import numpy as np

x, sample_mean, sample_var, sample_std, gamma, x_hat, eps = cache
N, D = dout.shape
dbeta = dout.sum(axis=0)
dgamma = (dout * x_hat).sum(axis=0)
# using computational graph in https://romenlaw.blogspot.com/2024/06/calculating-gradient-using-computation.html
# dK is the gradient node (K) passes down to its input
step3 = x - sample_mean                 # node (3) = x - mu
d10 = dout * gamma                      # to node (8): dL/dx_hat
d8_3 = d10 * (1/sample_std)             # node (8) -> node (3)
d8_7 = (d10 * step3).sum(axis=0)        # node (8) -> node (7)
d7 = - d8_7 / (sample_var + eps)        # node (7) -> node (6): * (-1/sigma^2)
d6 = d7 * 0.5 / sample_std              # node (6) -> node (5): * dsigma/dv
d5 = d6 / N * np.ones(shape=(N,D))      # node (5) -> node (4): * 1/N
d4 = d5 * 2 * step3                     # node (4) -> node (3): * 2(x - mu)
d3_1 = d4 + d8_3                        # node (3) -> node (1), shape (N,D)
d3_2 = -1 * (d4 + d8_3).sum(axis=0)     # node (3) -> node (2), shape (D,)
d2 = d3_2 / N * np.ones(shape=(N,D))    # node (2) -> node (1), shape (N,D)
dx = d2 + d3_1
Intuitively, this is just following the three paths from the output Y back to x along the red arrows: the three contributions to dx (d8_3, d4 and d2) come from the x̂ branch, the variance branch and the mean branch respectively. Now let's do it the analytical way, using the chain rule.
From page 4 of the original paper (https://arxiv.org/pdf/1502.03167) we have the formulae for the derivatives. However, since the equations used in assignment 2 differ from the paper's (especially in how the variance and mean are used to compute x_hat and y), we can rewrite the derivatives in the assignment 2 notation:
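In that notation (σ = sample_std = sqrt(v + ε), sums running over the batch index i), they come out roughly as:

∂L/∂x̂_i = ∂L/∂y_i · γ
∂L/∂v = Σ_i ∂L/∂x̂_i · (x_i − μ) · (−1/2) / σ³
∂L/∂μ = Σ_i ∂L/∂x̂_i · (−1/σ) + ∂L/∂v · (−2/N) · Σ_i (x_i − μ)
∂L/∂x_i = ∂L/∂x̂_i · (1/σ) + ∂L/∂v · 2(x_i − μ)/N + ∂L/∂μ · (1/N)
∂L/∂γ = Σ_i ∂L/∂y_i · x̂_i
∂L/∂β = Σ_i ∂L/∂y_i

(The second term of ∂L/∂μ corresponds to dL_mu2 in the code below; note that Σ_i (x_i − μ) = 0.)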
In Python code (modified slightly from here for readability):

x, mean, var, std, gamma, x_hat, eps = cache
S = lambda a: a.sum(axis=0)             # helper: sum over the batch dimension
dbeta = dout.sum(axis=0)
dgamma = (dout * x_hat).sum(axis=0)
N = dout.shape[0]                       # dout has shape (N, D)
dx = dout * gamma / (N * std)           # temporary scale value: dL/dx_hat / (N*std)
dL_v_x = -S(dx * x_hat) * x_hat         # dL/dv * dv/dx term
dL_mu = - N * dx                        # -dL/dx_hat / std (per sample)
dL_mu2 = -dL_v_x                        # second dL/dmu contribution (via the variance); sums to ~0 since sum(x_hat)=0
d_mu_x = S(-dx + dL_mu2)                # dL/dmu * dmu/dx = dL/dmu / N, broadcast over the batch
#d_mu_x = -S(dx)                        # the standard answer: dL_mu2 dropped
dx = dL_v_x - dL_mu + d_mu_x            # = dL/dx_hat/std + dL/dv*dv/dx + dL/dmu/N
The dx difference between this and the method above is about 1e-10. Curiously, the standard answer ignores the dL_mu2 term yet yields a better result, about 5e-13. I wondered why. Two months later, after watching Andrej Karpathy's lecture on backprop, I realised that ∂v/∂μ is actually 0: since v = (1/N)·Σ_i (x_i − μ)² and Σ_i (x_i − μ) = 0, we get ∂v/∂μ = −(2/N)·Σ_i (x_i − μ) = 0, so the dL_mu2 term is analytically zero and keeping it only adds rounding error.
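Here is a self-contained sketch of how the two versions can be compared against a numerical gradient. The forward pass and the names bn_forward, bn_backward_graph and bn_backward_alt are mine, chosen only so that the same cache tuple as above gets built:

import numpy as np

def bn_forward(x, gamma, beta, eps=1e-5):
    # minimal training-mode forward pass producing the cache used above
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    std = np.sqrt(var + eps)
    x_hat = (x - mean) / std
    out = gamma * x_hat + beta
    return out, (x, mean, var, std, gamma, x_hat, eps)

def bn_backward_graph(dout, cache):
    # the computational-graph version from above
    x, mean, var, std, gamma, x_hat, eps = cache
    N, D = dout.shape
    step3 = x - mean
    d10 = dout * gamma
    d8_3 = d10 / std
    d8_7 = (d10 * step3).sum(axis=0)
    d7 = -d8_7 / (var + eps)
    d6 = d7 * 0.5 / std
    d5 = d6 / N * np.ones((N, D))
    d4 = d5 * 2 * step3
    d3_1 = d4 + d8_3
    d2 = -(d4 + d8_3).sum(axis=0) / N * np.ones((N, D))
    return d3_1 + d2

def bn_backward_alt(dout, cache):
    # the analytical version from above
    x, mean, var, std, gamma, x_hat, eps = cache
    S = lambda a: a.sum(axis=0)
    N = dout.shape[0]
    dx = dout * gamma / (N * std)
    dL_v_x = -S(dx * x_hat) * x_hat
    dL_mu = -N * dx
    dL_mu2 = S(dx * x_hat) * x_hat      # = -dL_v_x
    d_mu_x = S(-dx + dL_mu2)
    # d_mu_x = -S(dx)                   # the "standard answer" variant
    return dL_v_x - dL_mu + d_mu_x

np.random.seed(0)
N, D = 4, 5
x = np.random.randn(N, D)
gamma, beta = np.random.randn(D), np.random.randn(D)
dout = np.random.randn(N, D)

out, cache = bn_forward(x, gamma, beta)
dx_graph = bn_backward_graph(dout, cache)
dx_alt = bn_backward_alt(dout, cache)

# numerical gradient of the scalar sum(out * dout) w.r.t. x (centred differences)
h = 1e-5
dx_num = np.zeros_like(x)
for i in range(N):
    for j in range(D):
        xp, xm = x.copy(), x.copy()
        xp[i, j] += h
        xm[i, j] -= h
        dx_num[i, j] = ((bn_forward(xp, gamma, beta)[0] * dout).sum()
                        - (bn_forward(xm, gamma, beta)[0] * dout).sum()) / (2 * h)

print(np.abs(dx_graph - dx_alt).max())   # the two backward passes agree closely
print(np.abs(dx_graph - dx_num).max())   # and both match the numerical gradient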