Saturday, 22 June 2024

Calculating Gradient of Batch Normalisation

Part of the cs231n assignment 2 is to calculate the gradients of the Batch Normalisation layer. Here are the equations that define BN:

$X = \begin{bmatrix} x_1 \\ x_2 \\ \vdots \\ x_N \end{bmatrix}$   with dimension (N, D)

$\mu = \frac{1}{N}\sum_{k=1}^{N} x_k$   with dimension (D,)

$v = \frac{1}{N}\sum_{k=1}^{N} (x_k - \mu)^2$   with dimension (D,)

$\sigma = \sqrt{v + \varepsilon}$   with dimension (D,)

$y_i = \frac{x_i - \mu}{\sigma}$   where $y_i$ has dimension (D,) and $Y$ (or x_hat) is (N, D)
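
As a sanity check on these definitions, here is a minimal NumPy sketch of the forward pass. The function name batchnorm_forward_sketch is my own; the cache it returns is laid out to match the backward code later in the post.

import numpy as np

def batchnorm_forward_sketch(x, gamma, beta, eps=1e-5):
    # x: (N, D); gamma, beta: (D,)
    mu = x.mean(axis=0)                  # (D,)  batch mean
    var = ((x - mu) ** 2).mean(axis=0)   # (D,)  biased variance v, as defined above
    std = np.sqrt(var + eps)             # (D,)  sigma
    x_hat = (x - mu) / std               # (N, D)  Y in the notation above
    out = gamma * x_hat + beta           # (N, D)
    cache = (x, mu, var, std, gamma, x_hat, eps)
    return out, cache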
The basic partial derivatives of the above equations are as follows. They are the building blocks for finding the final $\partial L/\partial x$.

$L$ is the loss function and $\frac{\partial L}{\partial Y} = \gamma \times dout$, of dimension (N, D).
$\frac{\partial \mu}{\partial x_i} = \frac{1}{N}\sum_{k=1}^{N}\frac{\partial x_k}{\partial x_i} = \frac{1}{N}$    of dimension (D,).

$\frac{\partial v}{\partial \mu} = \frac{1}{N}\sum_{k=1}^{N}(2\mu - 2x_k) = -\frac{2}{N}\sum_{k=1}^{N}(x_k - \mu)$    of dimension (D,).

This turns out to be 0, because the sum of the $x_k$ and the sum of $\mu$ over the batch are the same.

However, $\frac{\partial v}{\partial x_i} = \frac{2}{N}(x_i - \mu)$

$\frac{\partial \sigma}{\partial v} = 0.5 \times \frac{1}{\sqrt{v+\varepsilon}} = \frac{1}{2\sigma}$

$\frac{\partial y_i}{\partial \mu} = -\frac{1}{\sigma}$

$\frac{\partial y_i}{\partial \sigma} = -\frac{x_i - \mu}{\sigma^2}$
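
These building blocks can be sanity-checked numerically. Below is a small sketch (the helper batch_variance and the test sizes are my own) that compares $\frac{\partial v}{\partial x_i} = \frac{2}{N}(x_i - \mu)$ against a centred finite difference; it also illustrates the $\partial v/\partial \mu = 0$ point, since the finite difference lets $\mu$ move with $x$ and still agrees.

import numpy as np

def batch_variance(x):
    # v = (1/N) * sum_k (x_k - mu)^2, computed per column
    return ((x - x.mean(axis=0)) ** 2).mean(axis=0)

np.random.seed(0)
N, D = 5, 3
x = np.random.randn(N, D)

# analytic building block: dv/dx_i = (2/N) * (x_i - mu)
i, d = 2, 1
analytic = 2.0 / N * (x[i, d] - x[:, d].mean())

# centred finite difference; mu is recomputed each time, but because dv/dmu = 0
# the numerical result still matches the partial derivative above
h = 1e-5
xp, xm = x.copy(), x.copy()
xp[i, d] += h
xm[i, d] -= h
numeric = (batch_variance(xp)[d] - batch_variance(xm)[d]) / (2 * h)
print(analytic, numeric)   # should agree to several decimal places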

Thanks to this post, I understood the process using the computational graph. The graph is laid out below: the forward pass runs top-down (drawn in black in the original table); the backward pass runs bottom-up (in red).

Forward pass:

(1) := x, of shape (N, D)
(2) := mean = (1/N) Σᵢ xᵢ, of shape (D,)
(3) := (1) − (2)
(4) := (3)**2
(5) := var = (1/N) Σᵢ (4)ᵢ, of shape (D,)
(6) := std = sqrt((5) + ε)
(7) := 1/(6)
(8) := (3) * (7)   (this is x_hat, i.e. Y)
(9) := γ, of shape (D,)
(10) := (8) * (9)
(11) := β, of shape (D,)
(12) := (10) + (11) = out, of shape (N, D), which feeds the Loss

Backward pass, where d(n) denotes the gradient arriving at node (n):

d(12) = dout
dβ = dout.sum(axis=0)
d(10) = dout
dγ = (dout * (8)).sum(axis=0)
d(8) = d(10) * γ = ∂L/∂Y
d(7) = (d(8) * (3)).sum(axis=0)
d(6) = d(7) * (−1/((6)**2)) = ∂L/∂Y * ∂Y/∂σ = ∂L/∂σ
d(5) = d(6) * 0.5 * 1/std = ∂L/∂σ * ∂σ/∂v = ∂L/∂v
d(4) = d(5) * 1/N * np.ones((N, D))
d(3) = d(4) * 2 * (3) + d(8) * (7) = ∂L/∂v * ∂v/∂x − ∂L/∂μ   (the two paths into (3))
d(2) = (−1) * (d(4) * 2 * (3) + d(8) * (7)).sum(axis=0) = Σ(∂L/∂μ + ∂L/∂μ2), where ∂L/∂μ2 is the contribution that reaches μ through v
dx = d(3) + d(2) * 1/N * np.ones((N, D)) = d(3) + d(2) * ∂μ/∂x
In Python (wrapped as a function for convenience):

import numpy as np

def batchnorm_backward(dout, cache):
    x, sample_mean, sample_var, sample_std, gamma, x_hat, eps = cache
    N, D = dout.shape

    dbeta = dout.sum(axis=0)
    dgamma = (dout * x_hat).sum(axis=0)

    # using the computational graph in https://romenlaw.blogspot.com/2024/06/calculating-gradient-using-computation.html
    step3 = x - sample_mean                # node (3)

    d10 = dout * gamma                     # dL/dY, gradient into node (8), (N,D)
    d8_3 = d10 * (1 / sample_std)          # branch of (8)=(3)*(7) going to (3), (N,D)
    d8_7 = (d10 * step3).sum(axis=0)       # branch of (8)=(3)*(7) going to (7), (D,)
    d7 = -d8_7 / (sample_var + eps)        # through (7)=1/(6): multiply by -1/(6)^2, (D,)
    d6 = d7 * 0.5 / sample_std             # through (6)=sqrt((5)+eps): dL/dv, (D,)
    d5 = d6 / N * np.ones(shape=(N, D))    # through (5)=mean of (4), (N,D)
    d4 = d5 * 2 * step3                    # through (4)=(3)**2: gradient to (3) via (4), (N,D)
    d3_1 = d4 + d8_3                       # total gradient reaching (3), (N,D)
    d3_2 = -1 * (d4 + d8_3).sum(axis=0)    # gradient to (2)=mean, i.e. dL/dmu, (D,)
    d2 = d3_2 / N * np.ones(shape=(N, D))  # gradient to x via the mean path, (N,D)
    dx = d2 + d3_1                         # (N,D)

    return dx, dgamma, dbeta
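
The graph-based backward pass can be checked against a numerical gradient. Below is a quick sketch, assuming the hypothetical batchnorm_forward_sketch defined earlier and the batchnorm_backward function above; the loss L = out.sum() and the test sizes are arbitrary choices of mine.

import numpy as np

np.random.seed(1)
N, D = 4, 3
x = np.random.randn(N, D)
gamma, beta = np.random.randn(D), np.random.randn(D)

out, cache = batchnorm_forward_sketch(x, gamma, beta)
dout = np.ones_like(out)                      # dL/dout for L = out.sum()
dx, dgamma, dbeta = batchnorm_backward(dout, cache)

# centred finite difference on a single element of x
i, d, h = 1, 2, 1e-5
xp, xm = x.copy(), x.copy()
xp[i, d] += h
xm[i, d] -= h
numeric = (batchnorm_forward_sketch(xp, gamma, beta)[0].sum()
           - batchnorm_forward_sketch(xm, gamma, beta)[0].sum()) / (2 * h)
print(dx[i, d], numeric)                      # should agree to several decimal places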

Intuitively, it's like following the three paths from Y to X directed by the red arrows. Now let's do it the analytical way, using the chain rule.
From page 4 of the original paper (https://arxiv.org/pdf/1502.03167) we have the formulae for the derivatives. However, since the equations used in assignment 2 differ from the paper's (especially in how the variance and mean are used to calculate x_hat, i.e. y), we can rewrite the derivatives using the notation of assignment 2:
 
Note that the direct path uses $\frac{\partial y_i}{\partial x_i} = \frac{1}{\sigma} = -\frac{\partial y_i}{\partial \mu}$, which is why it is written below as $-\frac{\partial L}{\partial y_i}\frac{\partial y_i}{\partial \mu}$; also, $\frac{\partial L}{\partial \mu}$ and $\frac{\partial L}{\partial v}$ collect contributions from the whole batch, hence the sums appearing from the third line onwards.

$$
\begin{aligned}
\frac{\partial L}{\partial x_i}
&= \left( \frac{\partial L}{\partial \mu}\frac{\partial \mu}{\partial x_i}
 + \frac{\partial L}{\partial v}\frac{\partial v}{\partial \mu}\frac{\partial \mu}{\partial x_i} \right)
 + \frac{\partial L}{\partial v}\frac{\partial v}{\partial x_i}
 - \frac{\partial L}{\partial y_i}\frac{\partial y_i}{\partial \mu} \\
&= \frac{\partial L}{\partial y_i}\frac{\partial y_i}{\partial \mu}\frac{\partial \mu}{\partial x_i}
 + \frac{\partial L}{\partial y_i}\frac{\partial y_i}{\partial \sigma}\frac{\partial \sigma}{\partial v}\frac{\partial v}{\partial \mu}\frac{\partial \mu}{\partial x_i}
 + \frac{\partial L}{\partial y_i}\frac{\partial y_i}{\partial \sigma}\frac{\partial \sigma}{\partial v}\frac{\partial v}{\partial x_i}
 - \frac{\partial L}{\partial y_i}\frac{\partial y_i}{\partial \mu} \\
&= \left[ \frac{1}{N}\sum_{k=1}^{N} dout_k\,\gamma\left(-\frac{1}{\sigma}\right)
 + \frac{1}{N}\sum_{k=1}^{N} dout_k\,\gamma\left(-\frac{x_k-\mu}{\sigma^2}\right)\frac{1}{2\sigma}
   \left(-\frac{2}{N}\right)\sum_{j=1}^{N}(x_j-\mu) \right] \\
&\qquad + \sum_{k=1}^{N} dout_k\,\gamma\left(-\frac{x_k-\mu}{\sigma^2}\right)\frac{1}{2\sigma}\,\frac{2}{N}(x_i-\mu)
 - dout_i\,\gamma\left(-\frac{1}{\sigma}\right) \\
&= \frac{1}{N}\sum_{k=1}^{N}\left[ -\frac{dout_k\,\gamma}{\sigma}
 + \frac{1}{N}\left(\sum_{j=1}^{N}\frac{dout_j\,\gamma\,y_j}{\sigma}\right) y_k \right]
 \quad\leftarrow\ \text{in the Python code below, this is d\_mu\_x} \\
&\qquad - \frac{1}{N}\left(\sum_{k=1}^{N}\frac{dout_k\,\gamma\,y_k}{\sigma}\right) y_i
 \quad\leftarrow\ \text{this is dL\_v\_x} \\
&\qquad + \frac{dout_i\,\gamma}{\sigma}
 \quad\leftarrow\ \text{this is } -\text{dL\_mu (dL\_mu is defined with the opposite sign and subtracted)}
\end{aligned}
$$
In Python code (modified slightly from here for readability, and wrapped as a function for convenience):
def batchnorm_backward_alt(dout, cache):
    x, mean, var, std, gamma, x_hat, eps = cache
    S = lambda z: z.sum(axis=0)             # helper: sum over the batch axis
    N = dout.shape[0]                       # dout has dimension (N, D)

    dbeta = dout.sum(axis=0)
    dgamma = (dout * x_hat).sum(axis=0)

    dx = dout * gamma / (N * std)           # common scale factor dL/dY / (N*sigma); overwritten at the end

    dL_v_x = -S(dx * x_hat) * x_hat         # dL/dv * dv/dx term
    dL_mu = -N * dx                         # per-element dL/dY * dY/dmu = -dout*gamma/sigma

    dL_mu2 = -dL_v_x                        # dL/dv * dv/dmu contribution (actually 0, see note below)
    d_mu_x = S(-dx + dL_mu2)                # (dL/dmu + dL/dmu2) * dmu/dx; broadcast over (N, D) below
    # d_mu_x = -S(dx)                       # the version that drops the dL_mu2 term

    dx = dL_v_x - dL_mu + d_mu_x
    return dx, dgamma, dbeta
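
To quantify the difference discussed in the next paragraph, both versions can be run on the same cache and compared. Here is a small sketch, again reusing the hypothetical batchnorm_forward_sketch from earlier; rel_error is my own helper.

import numpy as np

def rel_error(a, b):
    # maximum relative error between two arrays
    return np.max(np.abs(a - b) / np.maximum(1e-8, np.abs(a) + np.abs(b)))

np.random.seed(2)
N, D = 8, 5
x = np.random.randn(N, D)
gamma, beta = np.random.randn(D), np.random.randn(D)

out, cache = batchnorm_forward_sketch(x, gamma, beta)   # forward sketch from earlier
dout = np.random.randn(N, D)

dx1, _, _ = batchnorm_backward(dout, cache)       # graph-based version above
dx2, _, _ = batchnorm_backward_alt(dout, cache)   # analytical version above
print(rel_error(dx1, dx2))                        # tiny, dominated by floating-point noise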
The dx difference between this and the method above is about 1e-10. Curiously, the standard answer ignores the dL_mu2 term yet yields a better result, about 5e-13, and I wondered why. Two months later, after watching Andrej Karpathy's lecture on backprop, I realised that ∂v/∂μ is actually 0, so the dL_mu2 term should vanish anyway.
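
For completeness, here is a sketch of the simplified closed form you get once the dL_mu2 term is dropped (using ∂v/∂μ = 0). The function name is my own; algebraically it is the same three terms as above, just collected under a common factor γ/(Nσ).

def batchnorm_backward_simplified(dout, cache):
    # my own compact rewrite: drops the dL_mu2 term, since dv/dmu = 0
    x, mean, var, std, gamma, x_hat, eps = cache
    N = dout.shape[0]

    dbeta = dout.sum(axis=0)
    dgamma = (dout * x_hat).sum(axis=0)

    # dL/dx_i = gamma/(N*sigma) * (N*dout_i - sum_k dout_k - y_i * sum_k dout_k*y_k)
    dx = gamma / (N * std) * (
        N * dout
        - dout.sum(axis=0)
        - x_hat * (dout * x_hat).sum(axis=0)
    )
    return dx, dgamma, dbeta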
