Sunday 14 July 2024

Image Captioning with RNN and LSTM

The image captioning RNN architecture in assignment 3:

Forward and backward passes through the above network:
        # (1) project CNN image features (N, D) to the initial hidden state
        h0, affine_cache = affine_forward(features, W_proj, b_proj)  # (N, H)

        # (2) word embedding: captions_in (N, T) -> embedded_words (N, T, W)
        # captions_in holds the vocabulary indices of the input words
        embedded_words, word_eb_cache = word_embedding_forward(captions_in, W_embed)

        # (3) recurrent layer (vanilla RNN or LSTM), hidden states h (N, T, H)
        if self.cell_type == 'rnn':
          h, rnn_cache = rnn_forward(embedded_words, h0, Wx, Wh, b)
        else:
          h, rnn_cache = lstm_forward(embedded_words, h0, Wx, Wh, b)

        # (4) temporal affine: scores x (N, T), one score per vocabulary word
        # (range(V)) at each time step; the highest-scoring index is the
        # predicted word
        x, temp_cache = temporal_affine_forward(h, W_vocab, b_vocab)

        # (5) temporal softmax loss against the ground-truth next words
        # captions_out; mask (N, T) selects which time steps count in the loss
        loss, dout = temporal_softmax_loss(x, captions_out, mask)

        # backward pass: call the backward functions in reverse order
        dh, dW_vocab, db_vocab = temporal_affine_backward(dout, temp_cache)
        if self.cell_type == 'rnn':
          dx, dh0, dWx, dWh, db = rnn_backward(dh, rnn_cache)
        else:
          dx, dh0, dWx, dWh, db = lstm_backward(dh, rnn_cache)
        dW_embed = word_embedding_backward(dx, word_eb_cache)
        # d_feature is not used: the CNN features are inputs, not parameters
        d_feature, dW_proj, db_proj = affine_backward(dh0, affine_cache)

        grads = {
          "W_proj": dW_proj,
          "b_proj": db_proj,
          "W_embed": dW_embed,
          "Wx": dWx,
          "Wh": dWh,
          "b": db,
          "W_vocab": dW_vocab,
          "b_vocab": db_vocab
        }
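The layer functions called in steps (1) to (5) are not shown in this post. As a rough, self-contained illustration of what they compute in the vanilla RNN case, here is a minimal NumPy sketch of the forward loss plus the embedding gradient. The function names (`embed_forward`, `embed_backward`, `rnn_caption_loss`), the toy sizes, the random weights, and the teacher-forcing split `captions[:, :-1]` / `captions[:, 1:]` are assumptions for illustration. The loss is summed over unmasked time steps and divided by the batch size N, which is one common convention, not necessarily the assignment's exact one.

```python
import numpy as np


def embed_forward(captions_in, W_embed):
    """Look up word vectors: (N, T) int indices -> (N, T, W)."""
    return W_embed[captions_in]


def embed_backward(dout, captions_in, vocab_size):
    """Scatter-add upstream grads (N, T, W) back into a (V, W) table.

    np.add.at accumulates correctly when the same word index appears
    more than once in the batch; dW[idx] += dout would keep only one update.
    """
    dW_embed = np.zeros((vocab_size, dout.shape[-1]))
    np.add.at(dW_embed, captions_in, dout)
    return dW_embed


def rnn_caption_loss(features, captions_in, captions_out, mask, p):
    """Forward pass of a vanilla-RNN captioner; returns (loss, scores)."""
    h = features @ p["W_proj"] + p["b_proj"]           # (1) h0: (N, H)
    x = embed_forward(captions_in, p["W_embed"])        # (2) (N, T, W)
    N, T, _ = x.shape
    hs = np.zeros((N, T, h.shape[1]))
    for t in range(T):                                  # (3) unroll over time
        h = np.tanh(x[:, t] @ p["Wx"] + h @ p["Wh"] + p["b"])
        hs[:, t] = h
    scores = hs @ p["W_vocab"] + p["b_vocab"]           # (4) (N, T, V)
    # (5) softmax cross-entropy, counting only time steps where mask is True
    shifted = scores - scores.max(axis=2, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=2, keepdims=True))
    picked = np.take_along_axis(log_probs, captions_out[..., None], axis=2)[..., 0]
    loss = -(picked * mask).sum() / N
    return loss, scores


# Toy usage with random weights (all sizes are arbitrary assumptions).
rng = np.random.default_rng(0)
N, T, D, H, W, V = 2, 4, 5, 6, 3, 7
params = {
    "W_proj": 0.1 * rng.standard_normal((D, H)), "b_proj": np.zeros(H),
    "W_embed": 0.1 * rng.standard_normal((V, W)),
    "Wx": 0.1 * rng.standard_normal((W, H)),
    "Wh": 0.1 * rng.standard_normal((H, H)), "b": np.zeros(H),
    "W_vocab": 0.1 * rng.standard_normal((H, V)), "b_vocab": np.zeros(V),
}
captions = rng.integers(0, V, size=(N, T + 1))    # T + 1 tokens per caption
captions_in, captions_out = captions[:, :-1], captions[:, 1:]
mask = np.ones((N, T), dtype=bool)               # no padding in this toy batch
features = rng.standard_normal((N, D))
loss, scores = rnn_caption_loss(features, captions_in, captions_out, mask, params)
print(scores.shape)  # (2, 4, 7)
```

With small random weights the scores are nearly uniform, so the loss sits close to T·log(V). Swapping the `np.tanh` update for an LSTM cell would correspond to the `lstm` branch; the rest of the pipeline stays the same.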

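When using this network for sampling, there are no ground-truth captions: captions_in starts with only the <SOS> token, and each later input word is the word predicted at the previous time step. The RNN is therefore run one step at a time, and no loss or backward pass is computed.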
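A minimal greedy decoder for the vanilla RNN case makes the sampling loop concrete. It assumes a parameter dictionary with the same key names as the `grads` dictionary in the code above; `greedy_sample`, `start_id`, and `max_length` are hypothetical names, and the toy weights are random. This is a sketch, not the assignment's sampling code.

```python
import numpy as np


def greedy_sample(features, params, start_id, max_length=5):
    """Greedy decoding with a vanilla RNN; returns (N, max_length) word ids.

    Only the <SOS> id is supplied; every later input is the argmax word
    from the previous step. No loss or backward pass is computed.
    """
    N = features.shape[0]
    h = features @ params["W_proj"] + params["b_proj"]      # h0: (N, H)
    words = np.full(N, start_id, dtype=np.int64)             # first input: <SOS>
    captions = np.zeros((N, max_length), dtype=np.int64)
    for t in range(max_length):
        x = params["W_embed"][words]                         # (N, W)
        h = np.tanh(x @ params["Wx"] + h @ params["Wh"] + params["b"])
        scores = h @ params["W_vocab"] + params["b_vocab"]   # (N, V)
        words = scores.argmax(axis=1)                        # greedy choice
        captions[:, t] = words                               # feed back next step
    return captions


# Toy usage: random weights, so the sampled ids are meaningless.
rng = np.random.default_rng(0)
D, H, W, V = 5, 6, 3, 7
params = {
    "W_proj": 0.1 * rng.standard_normal((D, H)), "b_proj": np.zeros(H),
    "W_embed": 0.1 * rng.standard_normal((V, W)),
    "Wx": 0.1 * rng.standard_normal((W, H)),
    "Wh": 0.1 * rng.standard_normal((H, H)), "b": np.zeros(H),
    "W_vocab": 0.1 * rng.standard_normal((H, V)), "b_vocab": np.zeros(V),
}
sampled = greedy_sample(rng.standard_normal((3, D)), params, start_id=1)
print(sampled.shape)  # (3, 5)
```

This sketch always runs `max_length` steps; a fuller decoder would typically stop (or ignore later outputs) once an end token is produced, and an LSTM version would carry a cell state alongside `h`.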
