Example #1
    def forward(self, encoder_output):
        # Project each image region into the joint embedding space
        images_embedded = self.image_embedding(encoder_output)

        # Score each region and turn the scores into attention weights
        weights = self.linear_image_embedding_weights(images_embedded)
        normalized_weights = self.softmax(weights)

        # Weighted mean over the regions, then L2-normalize the pooled vector
        weighted_image_boxes = normalized_weights * images_embedded
        weighted_image_boxes_summed = weighted_image_boxes.sum(dim=1)

        v_mean_embedded = l2_norm(weighted_image_boxes_summed)
        return images_embedded, v_mean_embedded
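A minimal, self-contained sketch of the pooling in Example #1, assuming (these choices are not taken from the original class) that image_embedding is a Linear projection, linear_image_embedding_weights maps each region embedding to a scalar score, softmax normalizes over the region dimension, and l2_norm divides by the L2 norm along the last dimension:

import torch
import torch.nn as nn


def l2_norm(x, eps=1e-8):
    # Normalize each row to unit L2 length (assumed behaviour of the helper).
    return x / (x.norm(dim=-1, keepdim=True) + eps)


encoder_dim, embed_dim = 2048, 512
image_embedding = nn.Linear(encoder_dim, embed_dim)
linear_image_embedding_weights = nn.Linear(embed_dim, 1)
softmax = nn.Softmax(dim=1)

encoder_output = torch.randn(4, 36, encoder_dim)  # (batch, regions, features)

images_embedded = image_embedding(encoder_output)          # (4, 36, 512)
weights = linear_image_embedding_weights(images_embedded)  # (4, 36, 1)
normalized_weights = softmax(weights)                      # sums to 1 over regions
v_mean_embedded = l2_norm((normalized_weights * images_embedded).sum(dim=1))
print(v_mean_embedded.shape)  # torch.Size([4, 512])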
Example #2
    def embed_captions(self, captions, decode_lengths):
        # Initialize LSTM state
        batch_size = captions.size(0)
        h_lan_enc, c_lan_enc = self.language_encoding_lstm.init_state(
            batch_size)

        # Tensor to store hidden activations
        lang_enc_hidden_activations = torch.zeros(
            (batch_size, self.language_encoding_lstm_size), device=device)

        for t in range(max(decode_lengths)):
            prev_words_embedded = self.embeddings(captions[:, t])
            h_lan_enc, c_lan_enc = self.language_encoding_lstm(
                h_lan_enc, c_lan_enc, prev_words_embedded)
            # Keep the hidden state of captions whose last token is at step t
            lang_enc_hidden_activations[decode_lengths == t + 1] = \
                h_lan_enc[decode_lengths == t + 1]

        captions_embedded = self.caption_embedding(lang_enc_hidden_activations)
        captions_embedded = l2_norm(captions_embedded)
        return captions_embedded
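The key trick in Example #2 is the boolean mask decode_lengths == t + 1, which copies the LSTM hidden state exactly at each caption's final step. A toy illustration with the LSTM stubbed out (all names here are placeholders, not the original model):

import torch

batch_size, hidden_size = 3, 4
decode_lengths = torch.tensor([2, 4, 3])  # per-caption lengths
last_hidden = torch.zeros(batch_size, hidden_size)

for t in range(decode_lengths.max().item()):
    # Stand-in for h_lan_enc: every entry equals the current step number + 1.
    h_t = torch.full((batch_size, hidden_size), float(t + 1))
    # Captions whose last token is at step t contribute their hidden state now.
    last_hidden[decode_lengths == t + 1] = h_t[decode_lengths == t + 1]

print(last_hidden[:, 0])  # tensor([2., 4., 3.]) -- one value per caption length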
Example #3
    def forward(self,
                encoder_output,
                target_captions=None,
                decode_lengths=None,
                teacher_forcing=0.0,
                mask_prob=0.0,
                mask_type=None):
        """Forward pass for both ranking and caption generation."""

        batch_size = encoder_output.size(0)

        # Flatten image
        encoder_output = encoder_output.view(batch_size, -1,
                                             encoder_output.size(-1))

        # Initialize LSTM states
        states = self.init_hidden_states(encoder_output)
        lang_enc_hidden_activations = None
        if self.training:
            # Tensor to store the hidden activations of the language encoding LSTM
            # at every timestep; an attention over these forms the caption embedding
            lang_enc_hidden_activations = torch.zeros(
                (batch_size, max(decode_lengths),
                 self.language_encoding_lstm_size),
                device=device)

        # Tensor to hold word prediction scores
        scores = torch.zeros(
            (batch_size, max(decode_lengths), self.vocab_size), device=device)

        # For multitask training, start from the first token of the target captions
        if self.training and target_captions is not None:
            prev_words = torch.ones(
                (batch_size, ), dtype=torch.int64,
                device=device) * target_captions[:, 0]
        else:
            # At the start, all 'previous words' are the <start> token
            prev_words = torch.full((batch_size, ),
                                    self.word_map[TOKEN_START],
                                    dtype=torch.int64,
                                    device=device)

        target_clones = target_captions
        if self.training and mask_prob:
            # Mask-interleaved training: randomly corrupt tag/word pairs in a copy of the targets
            target_clones = target_clones.clone()
            tag_ix = self.word_map[TOKEN_MASK_TAG]
            word_ix = self.word_map[TOKEN_MASK_WORD]
            probs = np.random.uniform(0, 1, len(target_captions))
            tochange_ixs = [ix for ix, v in enumerate(probs < mask_prob) if v]
            mask_tag_ixs = np.array([
                np.random.choice(range(0, l - 1, 2))
                for l in decode_lengths.tolist()
            ])
            mask_tag_ixs = mask_tag_ixs[tochange_ixs]
            if mask_type in {"tags", "both"}:
                target_clones[tochange_ixs, mask_tag_ixs + 1] = tag_ix
            if mask_type in {"words", "both"}:
                target_clones[tochange_ixs, mask_tag_ixs + 2] = word_ix

        for t in range(max(decode_lengths)):

            if not self.training:
                # Find all sequences where an <end> token has been produced in the last timestep
                ind_end_token = torch.nonzero(
                    prev_words == self.word_map[TOKEN_END]).view(-1).tolist()
                # Update the decode lengths accordingly
                decode_lengths[ind_end_token] = torch.min(
                    decode_lengths[ind_end_token],
                    torch.full_like(decode_lengths[ind_end_token],
                                    t,
                                    device=device),
                )

            # Check if all sequences are finished:
            incomplete_sequences_ixs = torch.nonzero(
                decode_lengths > t).view(-1)
            if len(incomplete_sequences_ixs) == 0:
                break

            # Forward prop.
            prev_words_embedded = self.embeddings(prev_words)
            scores_for_timestep, states, alphas_for_timestep = \
                self.forward_step(encoder_output, prev_words_embedded, states)

            # Update the previously predicted words
            prev_words = self.update_previous_word(scores_for_timestep,
                                                   target_clones, t,
                                                   teacher_forcing)

            scores[incomplete_sequences_ixs, t, :] = \
                scores_for_timestep[incomplete_sequences_ixs]
            if self.training:
                h_lan_enc = states[0]
                lang_enc_hidden_activations[decode_lengths >= t + 1, t] = \
                    h_lan_enc[decode_lengths >= t + 1]

        captions_embedded = None
        v_mean_embedded = None
        if self.training:
            _, v_mean_embedded = self.image_embedding(encoder_output)
            captions_attention = self.caption_attention(
                lang_enc_hidden_activations, decode_lengths)
            captions_embedded = self.caption_embedding(captions_attention)
            captions_embedded = l2_norm(captions_embedded)

        extras = {
            'images_embedded': v_mean_embedded,
            'captions_embedded': captions_embedded
        }

        return scores, decode_lengths, extras
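The mask-interleaved step above is easiest to see in isolation. The sketch below assumes captions are laid out as <start> tag word tag word ..., so even offsets after <start> hold tags and the next position holds the paired word; the mask-token ids and the toy captions are made up for the illustration, and mask_prob is set to 1.0 so every caption is affected:

import numpy as np
import torch

mask_prob, mask_type = 1.0, "both"
tag_ix, word_ix = 98, 99                          # hypothetical mask-token ids
target_captions = torch.arange(1, 33).view(4, 8)  # (batch, max_caption_length)
decode_lengths = torch.tensor([6, 6, 4, 6])

target_clones = target_captions.clone()
probs = np.random.uniform(0, 1, len(target_captions))
tochange_ixs = [ix for ix, v in enumerate(probs < mask_prob) if v]
# Pick a random even offset per caption, i.e. a (tag, word) pair to corrupt.
mask_tag_ixs = np.array([
    np.random.choice(range(0, l - 1, 2)) for l in decode_lengths.tolist()
])
mask_tag_ixs = mask_tag_ixs[tochange_ixs]
if mask_type in {"tags", "both"}:
    target_clones[tochange_ixs, mask_tag_ixs + 1] = tag_ix
if mask_type in {"words", "both"}:
    target_clones[tochange_ixs, mask_tag_ixs + 2] = word_ix
print(target_clones)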
Example #4
    def forward_multi(self,
                      encoder_output,
                      target_captions=None,
                      decode_lengths=None,
                      teacher_forcing=0.0):
        """
        Forward propagation.

        :param encoder_output: output features of the encoder
        :param target_captions: encoded target captions, shape: (batch_size, max_caption_length)
        :param decode_lengths: caption lengths, shape: (batch_size, 1)
        :param teacher_forcing: teacher forcing ratio
        :return: scores for the word and tag heads, decode lengths, extras (image and caption embeddings)
        """

        batch_size = encoder_output.size(0)

        # Flatten image
        encoder_output = encoder_output.view(batch_size, -1,
                                             encoder_output.size(-1))

        # Initialize LSTM state
        states = self.init_hidden_states(encoder_output)
        lang_enc_hidden_activations = None
        if self.training:
            # Tensor to store hidden activations of the language encoding LSTM of the last timestep
            # These will be the caption embedding
            lang_enc_hidden_activations = torch.zeros(
                (batch_size, self.language_encoding_lstm_size), device=device)

        # Tensors to hold word and tag prediction scores
        w_scores = torch.zeros(
            (batch_size, max(decode_lengths), self.vocab_size), device=device)
        t_scores = torch.zeros(
            (batch_size, max(decode_lengths), self.vocab_size), device=device)

        # For multitask training, start from the first token of the target captions
        if self.training and target_captions is not None:
            prev_words = torch.ones(
                (batch_size, ), dtype=torch.int64,
                device=device) * target_captions[:, 0]
        else:
            # At the start, all 'previous words' are the <start> token
            prev_words = torch.full((batch_size, ),
                                    self.word_map[TOKEN_START],
                                    dtype=torch.int64,
                                    device=device)

        for t in range(max(decode_lengths)):

            if not self.training:
                # Find all sequences where an <end> token has been produced in the last timestep
                ind_end_token = torch.nonzero(
                    prev_words == self.word_map[TOKEN_END]).view(-1).tolist()
                # Update the decode lengths accordingly
                decode_lengths[ind_end_token] = torch.min(
                    decode_lengths[ind_end_token],
                    torch.full_like(decode_lengths[ind_end_token],
                                    t,
                                    device=device),
                )

            # Check if all sequences are finished:
            incomplete_sequences_ixs = torch.nonzero(
                decode_lengths > t).view(-1)
            if len(incomplete_sequences_ixs) == 0:
                break

            # Forward prop.
            prev_words_embedded = self.embeddings(prev_words)
            scores_for_timestep, states, alphas_for_timestep = \
                self.forward_multi_step(encoder_output, prev_words_embedded, states)

            # Update the previously predicted words
            w_scores_for_timestep, t_scores_for_timestep = scores_for_timestep
            prev_words = self.update_previous_word(w_scores_for_timestep,
                                                   target_captions, t,
                                                   teacher_forcing)

            w_scores[incomplete_sequences_ixs, t, :] = \
                w_scores_for_timestep[incomplete_sequences_ixs]
            t_scores[incomplete_sequences_ixs, t, :] = \
                t_scores_for_timestep[incomplete_sequences_ixs]
            if self.training:
                h_lan_enc = states[0]
                lang_enc_hidden_activations[decode_lengths == t + 1] = \
                    h_lan_enc[decode_lengths == t + 1]

        captions_embedded = None
        v_mean_embedded = None
        if self.training:
            _, v_mean_embedded = self.image_embedding(encoder_output)
            captions_embedded = self.caption_embedding(
                lang_enc_hidden_activations)
            captions_embedded = l2_norm(captions_embedded)

        scores = [w_scores, t_scores]
        extras = {
            'images_embedded': v_mean_embedded,
            'captions_embedded': captions_embedded
        }

        return scores, decode_lengths, extras
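Both forward passes share the same inference-time bookkeeping: once a sequence emits <end>, its decode length is clamped to the current step, and the loop stops when no sequence is still incomplete. A stripped-down sketch of just that logic, with a random stand-in for the decoder and a made-up END_ID:

import torch

END_ID, vocab_size = 0, 10
batch_size, max_len = 3, 8
decode_lengths = torch.full((batch_size,), max_len, dtype=torch.long)
prev_words = torch.randint(1, vocab_size, (batch_size,))  # no <end> yet

for t in range(max_len):
    # Clamp lengths for sequences that produced <end> at the previous step.
    ind_end_token = torch.nonzero(prev_words == END_ID).view(-1)
    decode_lengths[ind_end_token] = torch.min(
        decode_lengths[ind_end_token],
        torch.full_like(decode_lengths[ind_end_token], t))

    incomplete_sequences_ixs = torch.nonzero(decode_lengths > t).view(-1)
    if len(incomplete_sequences_ixs) == 0:
        break  # every caption has finished

    # Placeholder "prediction": sample the next word at random.
    prev_words = torch.randint(0, vocab_size, (batch_size,))

print(decode_lengths)  # clamped below max_len wherever END_ID was sampled early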