def beam_sample(self, image_features, beam_size=5):
    batch_size = image_features.size(0)
    beam_searcher = BeamSearch(beam_size, batch_size, 17)

    # Initialise the LSTM states from the image features and tile them
    # across the beam dimension.
    states = self.init_hidden_noise(image_features)
    states = (states[0].repeat(1, beam_size, 1).cuda(),
              states[1].repeat(1, beam_size, 1).cuda())

    # Embed the start symbol for every beam.
    words_feed = self.embed.word_embeddings([self.embed.START_SYMBOL] * batch_size) \
        .repeat(beam_size, 1).unsqueeze(1).cuda()

    outcaps = []
    for i in range(self.max_sentence_length):
        hidden, states = self.lstm(words_feed, states)
        outputs = self.output_linear(hidden.squeeze(1))
        beam_indices, words_indices = beam_searcher.expand_beam(outputs=outputs)

        if len(beam_indices) == 0 or i == 15:
            # Every beam has finished (or the hard length cap was hit):
            # take the best-scoring beam and map indices back to words.
            generated_captions = beam_searcher.get_results()[:, 0]
            outcaps = self.embed.words_from_indices(
                generated_captions.cpu().numpy())
            break
        else:
            # Feed the words chosen by the beam search back into the LSTM.
            words_feed = torch.stack([
                self.embed.word_embeddings_from_indices(words_indices)
            ]).view(beam_size, 1, -1).cuda()

    return " ".join(outcaps)  # .split(self.embed.END_SYMBOL)[0]
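# Usage sketch (an assumption, not part of the original code): `beam_sample`
# takes a batch of image feature vectors and returns the best caption as a
# space-joined string. `model`, `encoder`, and `image_tensor` below are
# hypothetical placeholders for an instance of this captioner and whatever
# produces its image features.
#
#     model.eval()
#     with torch.no_grad():
#         features = encoder(image_tensor.cuda())        # (1, feature_dim)
#         caption = model.beam_sample(features, beam_size=5)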
def sample(self, image_features, image_regions, start_word, beam_size=5):
    images_count = image_features.shape[0]
    sentence_length = 17
    batch_size = images_count * beam_size
    h0, c0 = self.get_start_states(batch_size)

    # Tile the image features and regions across the beam so that the
    # beam_size copies of every image stay adjacent in the batch.
    image_features = torch.stack([image_features] * beam_size) \
        .permute(1, 0, 2) \
        .contiguous() \
        .view(-1, image_features.shape[-1])
    image_regions = torch.stack([image_regions.view(images_count, -1)] * beam_size) \
        .permute(1, 0, 2) \
        .contiguous() \
        .view(-1, image_regions.shape[1], image_regions.shape[2])

    word = start_word.repeat(batch_size)
    alphas = []
    all_words_indices = []
    beam_searcher = BeamSearch(beam_size, images_count, sentence_length)

    for step in range(sentence_length):
        if self.use_cuda:
            word = word.cuda()
        embeddings = self.embeds_1(word)
        embeddings_2 = self.embeds_2(embeddings)
        hiddens, (h0, c0) = self.rnn_cell(
            embeddings_2.view(1, batch_size, 2048), (h0, c0))

        # Attend over the image regions with the current hidden state;
        # alpha has shape (images_count, beam_size, regions).
        atten_features, alpha = self._attention_layer(
            image_regions, hiddens.view(batch_size, 512))
        alphas.append(alpha.reshape(images_count, beam_size, -1))

        mm_features = self.multi_modal(embeddings_2,
                                       hiddens.view(batch_size, -1),
                                       atten_features,
                                       image_features)
        # Project onto the vocabulary by reusing (tying) the embedding weights.
        intermediate_features = F.linear(mm_features, weight=self.embeds_1.weight)

        beam_indices, words_indices = beam_searcher.expand_beam(
            outputs=intermediate_features)
        words_indices = torch.tensor(words_indices)
        # One chosen word index per (image, beam) pair.
        all_words_indices.append(words_indices.reshape(images_count, beam_size))
        word = words_indices

    results = beam_searcher.get_results()

    if images_count == 1:
        # Keep, for every step, only the attention map of the beam that
        # produced the word surviving in the final caption.
        for j in range(images_count):
            for i in range(len(results)):
                nonzero = (all_words_indices[i][j] == results[i][j]).nonzero()[0]
                alphas[i] = alphas[i][j][nonzero].squeeze()
    else:
        alphas = []

    return results, alphas
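# Usage sketch (an assumption, not part of the original code): `sample`
# returns the beam-search word indices together with, for a single image,
# the per-step attention maps over the image regions, which can be overlaid
# on the input image to visualise where the model attended for each word.
# `model`, `vocab`, `encoder`, and `image_tensor` below are hypothetical
# placeholders, and the indexing of `results` as (step, image) is inferred
# from the loop above.
#
#     features, regions = encoder(image_tensor)    # global + region features
#     start = torch.tensor([vocab.start_index])
#     results, alphas = model.sample(features, regions, start, beam_size=5)
#     caption = [vocab.index_to_word[int(idx)] for idx in results[:, 0]]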