Exemplo n.º 1
0
    def beam_sample(self, image_features, beam_size=5):
        """Decode a caption for ``image_features`` with beam search.

        Args:
            image_features: features passed to ``self.init_hidden_noise`` to
                build the initial LSTM state; the batch size is read from
                ``image_features.size(0)``.
            beam_size: number of hypotheses kept at each step.

        Returns:
            The words of column 0 of ``BeamSearch.get_results()`` (presumably
            the top-scoring hypothesis) joined with single spaces.

        NOTE(review): ``words_feed`` is reshaped to ``(beam_size, 1, -1)`` and a
        single string is returned, so this appears to assume
        ``batch_size == 1`` — confirm with callers.
        """
        # Length handed to BeamSearch and the last step index that is expanded.
        # NOTE(review): these were the hard-coded 17 and 15 of the original;
        # the gap of 2 is preserved — confirm against BeamSearch's semantics.
        beam_length = 17
        last_step = 15

        batch_size = image_features.size(0)
        beam_searcher = BeamSearch(beam_size, batch_size, beam_length)

        # Initial LSTM (h, c), tiled along dim 1 so every beam gets a copy.
        states = self.init_hidden_noise(image_features)
        states = (states[0].repeat(1, beam_size, 1).cuda(),
                  states[1].repeat(1, beam_size, 1).cuda())

        # Embed the start symbol once per image, then tile it for every beam.
        words_feed = self.embed.word_embeddings([self.embed.START_SYMBOL] * batch_size) \
            .repeat(beam_size, 1).unsqueeze(1).cuda()

        for i in range(self.max_sentence_length):
            hidden, states = self.lstm(words_feed, states)
            outputs = self.output_linear(hidden.squeeze(1))
            beam_indices, words_indices = beam_searcher.expand_beam(
                outputs=outputs)

            # Stop when no hypotheses remain or the step budget is spent.
            # Without this break the loop would keep expanding the beam with a
            # stale ``words_feed``.
            if len(beam_indices) == 0 or i == last_step:
                break

            # Feed the words chosen at this step back in as the next input.
            words_feed = torch.stack([
                self.embed.word_embeddings_from_indices(words_indices)
            ]).view(beam_size, 1, -1).cuda()

        # Read results after the loop so a caption is produced even when the
        # loop ends without hitting the stop condition above.
        generated_captions = beam_searcher.get_results()[:, 0]
        outcaps = self.embed.words_from_indices(
            generated_captions.cpu().numpy())
        # NOTE(review): the caption is not truncated at END_SYMBOL; a former
        # commented-out ``.split(self.embed.END_SYMBOL)[0]`` hinted at that.
        return " ".join(outcaps)
Exemplo n.º 2
0
    def sample(self, image_features, image_regions, start_word, beam_size=5,
               sentence_length=17):
        """Beam-search decode captions and collect per-step attention maps.

        Args:
            image_features: global image features; the stack/permute below
                requires a 2-D ``(images_count, feature_dim)`` tensor.
            image_regions: region features, reshaped below using
                ``shape[1]`` and ``shape[2]`` (treated as
                ``(images_count, regions, region_dim)``).
            start_word: start-token index tensor, tiled to every hypothesis.
            beam_size: number of hypotheses kept per image.
            sentence_length: number of decoding steps, also passed to
                ``BeamSearch``. Defaults to the previously hard-coded 17.

        Returns:
            ``(results, alphas)``. ``results`` is ``BeamSearch.get_results()``.
            For a single image, ``alphas[i]`` is the attention vector of the
            first beam slot whose step-``i`` word equals ``results[i][0]``.
            For several images, ``alphas`` is an empty list.
        """
        images_count = image_features.shape[0]
        batch_size = images_count * beam_size
        h0, c0 = self.get_start_states(batch_size)

        # Tile each image beam_size times with its copies kept adjacent
        # (image-major: img0 x beam, img1 x beam, ...). A plain
        # ``tensor.repeat(beam_size, ...)`` would interleave images instead.
        image_features = torch.stack([image_features] * beam_size) \
            .permute(1, 0, 2) \
            .contiguous() \
            .view(-1, image_features.shape[-1])

        image_regions = torch.stack([image_regions.view(images_count, -1)] * beam_size) \
            .permute(1, 0, 2) \
            .contiguous() \
            .view(-1, image_regions.shape[1], image_regions.shape[2])

        word = start_word.repeat(batch_size)
        alphas = []
        all_words_indices = []
        beam_searcher = BeamSearch(beam_size, images_count, sentence_length)
        for _ in range(sentence_length):
            if self.use_cuda:
                word = word.cuda()
            embeddings = self.embeds_1(word)
            embeddings_2 = self.embeds_2(embeddings)

            # One timestep for the whole batch; presumably the RNN is
            # sequence-first, i.e. (seq_len=1, batch, features).
            hiddens, (h0, c0) = self.rnn_cell(
                embeddings_2.view(1, batch_size, -1), (h0, c0))

            atten_features, alpha = self._attention_layer(
                image_regions, hiddens.view(batch_size, -1))
            # images_count x beam_size x regions
            alphas.append(alpha.reshape(images_count, beam_size, -1))

            mm_features = self.multi_modal(embeddings_2,
                                           hiddens.view(batch_size, -1),
                                           atten_features, image_features)
            # Output scores reuse the word-embedding matrix as the projection
            # weight (one score per row of embeds_1.weight).
            intermediate_features = F.linear(mm_features,
                                             weight=self.embeds_1.weight)
            _, words_indices = beam_searcher.expand_beam(
                outputs=intermediate_features)
            # as_tensor avoids a copy/warning when a tensor is already returned.
            words_indices = torch.as_tensor(words_indices)
            # images_count x beam_size word indices chosen at this step
            all_words_indices.append(
                words_indices.reshape(images_count, beam_size))
            word = words_indices

        results = beam_searcher.get_results()
        if images_count != 1:
            # Attention maps are only aligned for the single-image case.
            return results, []

        # For each step keep the attention row of the first beam slot whose
        # word equals the decoded word for image 0.
        # NOTE(review): this matches by word instead of following beam
        # back-pointers (expand_beam's first return value is discarded), so it
        # may select a different hypothesis that emitted the same word.
        for step in range(len(results)):
            match = (all_words_indices[step][0] == results[step][0]).nonzero()[0]
            alphas[step] = alphas[step][0][match].squeeze()

        return results, alphas