The image captioning RNN architecture in assignment 3:
Forward and backward passes through the above network:
# (1) CNN features (N, D) -> hidden state projection layer
h0, affine_cache = affine_forward(features, W_proj, b_proj) # (N, H)
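# (conceptually h0 = features.dot(W_proj) + b_proj: each image's D-dim CNN
# feature vector is mapped to the initial hidden state of the RNN)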
# (2) word embedding, captions_in (N, T), output (N, T, W)
# captions_in contains the indices of the caption words in the vocabulary
embedded_words, word_eb_cache = word_embedding_forward(captions_in, W_embed)
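# (conceptually a row lookup: embedded_words = W_embed[captions_in])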
# (3) vanilla RNN (or LSTM), h (N, T, H)
if self.cell_type == 'rnn':
    h, rnn_cache = rnn_forward(embedded_words, h0, Wx, Wh, b)
else:
    h, rnn_cache = lstm_forward(embedded_words, h0, Wx, Wh, b)
# (4) temporal affine, output (N, T, V): scores over the vocabulary (size V).
# The index with the highest score is the predicted word.
x, temp_cache = temporal_affine_forward(h, W_vocab, b_vocab)
# (5) softmax
loss, dout = temporal_softmax_loss(x, captions_out, mask)
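# mask (N, T) marks the real caption positions so that <NULL> padding after
# the end of each caption is ignored in the loss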
# find grads by calling the backward functions in reverse order
dh, dW_vocab, db_vocab = temporal_affine_backward(dout, temp_cache)
if self.cell_type == 'rnn':
    dx, dh0, dWx, dWh, db = rnn_backward(dh, rnn_cache)
else:
    dx, dh0, dWx, dWh, db = lstm_backward(dh, rnn_cache)
dW_embed = word_embedding_backward(dx, word_eb_cache)
d_feature, dW_proj, db_proj = affine_backward(dh0, affine_cache)
grads = {
    "W_proj": dW_proj,
    "b_proj": db_proj,
    "W_embed": dW_embed,
    "Wx": dWx,
    "Wh": dWh,
    "b": db,
    "W_vocab": dW_vocab,
    "b_vocab": db_vocab,
}
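For reference, the core of rnn_forward is just the standard vanilla RNN recurrence applied at every timestep. A minimal NumPy sketch of a single step (the standalone rnn_step below is illustrative only; shapes match the walkthrough, with x_t of shape (N, W) and prev_h of shape (N, H)):
import numpy as np

def rnn_step(x_t, prev_h, Wx, Wh, b):
    # next hidden state: tanh of the input-to-hidden plus hidden-to-hidden affine terms
    return np.tanh(x_t.dot(Wx) + prev_h.dot(Wh) + b)
rnn_forward loops this over t = 0..T-1 and caches each step so rnn_backward can run backpropagation through time in reverse.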
When using this model for sampling, captions_in is populated with only the <SOS> token; each subsequent input word is the one predicted at the previous timestep.
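A rough sketch of that greedy sampling loop, assuming the step-level helper rnn_step_forward(x, prev_h, Wx, Wh, b) from the same assignment and reusing the weights from the walkthrough above; start_idx, max_length, and the captions array are illustrative names, and the LSTM case would additionally carry a cell state through lstm_step_forward:
h = features.dot(W_proj) + b_proj                 # (N, H) initial hidden state
word = np.full(N, start_idx, dtype=np.int32)      # every caption starts with <SOS>
captions = np.zeros((N, max_length), dtype=np.int32)
for t in range(max_length):
    x = W_embed[word]                             # (N, W) embed the current word
    h, _ = rnn_step_forward(x, h, Wx, Wh, b)      # one recurrent step
    scores = h.dot(W_vocab) + b_vocab             # (N, V) vocabulary scores
    word = scores.argmax(axis=1)                  # greedy: take the highest-scoring word
    captions[:, t] = word                         # record it and feed it back in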