def test_repeating_excluded_index_does_not_die(self):
    """A repeated token in the exclusion set must never be blocked.

    Batch 0 keeps repeating the excluded index, batch 1 keeps repeating
    a non-excluded index, and the remaining batches always advance.
    """
    n_words = 100
    repeat_idx = 47  # will be repeated and should be blocked
    repeat_idx_ignored = 7  # will be repeated and should not be blocked
    ngram_repeat = 3
    for batch_sz in (1, 3, 17):
        samp = RandomSampling(
            0, 1, 2, batch_sz, torch.device("cpu"), 0, ngram_repeat,
            {repeat_idx_ignored}, False, 30, 1., 5,
            torch.randint(0, 30, (batch_sz,)))
        for step in range(ngram_repeat + 4):
            word_probs = torch.full((batch_sz, n_words), -float('inf'))
            word_probs[0, repeat_idx_ignored] = 0
            if batch_sz > 1:
                word_probs[1, repeat_idx] = 0
                word_probs[2:, repeat_idx + step] = 0
            samp.advance(word_probs, torch.randn(1, batch_sz, 53))
            blocked = samp.topk_scores.eq(self.BLOCKED_SCORE)
            if step <= ngram_repeat:
                # nothing has repeated long enough to be blocked yet
                self.assertFalse(blocked.any())
            else:
                # from now on only batch 1 (non-excluded repeater) dies
                self.assertFalse(blocked[0].any())
                if batch_sz > 1:
                    self.assertTrue(blocked[1].all())
                    self.assertFalse(blocked[2:].any())
def test_repeating_excluded_index_does_not_die(self):
    # batch 0 repeats the excluded token, batch 1 repeats a blockable one
    vocab_size = 100
    blockable_idx = 47   # repeating this should eventually be blocked
    excluded_idx = 7     # repeating this must never be blocked
    ngram_repeat = 3
    for n_batch in [1, 3, 17]:
        sampler = RandomSampling(
            0, 1, 2, n_batch, torch.device("cpu"), 0, ngram_repeat,
            {excluded_idx}, False, 30, 1., 5,
            torch.randint(0, 30, (n_batch,)))
        for t in range(ngram_repeat + 4):
            scores = torch.full((n_batch, vocab_size), -float('inf'))
            scores[0, excluded_idx] = 0
            if n_batch > 1:
                scores[1, blockable_idx] = 0
                scores[2:, blockable_idx + t] = 0
            attn = torch.randn(1, n_batch, 53)
            sampler.advance(scores, attn)
            if t <= ngram_repeat:
                self.assertFalse(
                    sampler.topk_scores.eq(self.BLOCKED_SCORE).any())
            else:
                # beyond the n-gram window, only batch 1 is blocked
                self.assertFalse(
                    sampler.topk_scores[0].eq(self.BLOCKED_SCORE).any())
                if n_batch > 1:
                    self.assertTrue(
                        sampler.topk_scores[1].eq(self.BLOCKED_SCORE).all())
                    self.assertFalse(
                        sampler.topk_scores[2:].eq(self.BLOCKED_SCORE).any())
def _translate_random_sampling(self, src, src_lengths, batch_size, min_length=0, sampling_temp=1.0, keep_topk=1, return_attention=False, pretraining=False): max_length = self.config.max_sequence_length + 1 # to account for EOS # Encoder forward. enc_states, memory_bank, src_lengths = self.encoder(src, src_lengths) self.decoder.init_state(src, memory_bank, enc_states) results = { "predictions": None, "scores": None, "attention": None } memory_lengths = src_lengths mb_device = memory_bank[0].device if isinstance(memory_bank, tuple) else memory_bank.device block_ngram_repeat = 0 _exclusion_idxs = {} random_sampler = RandomSampling( self.config.tgt_padding, self.config.tgt_bos, self.config.tgt_eos, batch_size, mb_device, min_length, block_ngram_repeat, _exclusion_idxs, return_attention, max_length, sampling_temp, keep_topk, memory_lengths) for step in range(max_length): # Shape: (1, B, 1) decoder_input = random_sampler.alive_seq[:, -1].view(1, -1, 1) log_probs, attn = self._decode_and_generate(decoder_input, memory_bank, memory_lengths, step, pretraining) if self.config.DISTRIBUTIONAL and not pretraining: log_probs = (log_probs * self.quantile_weight).sum(dim=3).squeeze(0) random_sampler.advance(log_probs, attn) any_batch_is_finished = random_sampler.is_finished.any() if any_batch_is_finished: random_sampler.update_finished() if random_sampler.done: break if any_batch_is_finished: select_indices = random_sampler.select_indices # Reorder states. if isinstance(memory_bank, tuple): memory_bank = tuple(x.index_select(1, select_indices) for x in memory_bank) else: memory_bank = memory_bank.index_select(1, select_indices) memory_lengths = memory_lengths.index_select(0, select_indices) self.decoder.map_state( lambda state, dim: state.index_select(dim, select_indices)) results["scores"] = random_sampler.scores results["predictions"] = random_sampler.predictions results["attention"] = random_sampler.attention return results
def test_advance_with_some_repeats_gets_blocked(self):
    """Batches 0 and 7 repeat an n-gram and get blocked; the rest advance."""
    n_words = 100
    repeat_idx = 47
    other_repeat_idx = 12
    ngram_repeat = 3
    for batch_sz in (1, 3, 13):
        samp = RandomSampling(
            0, 1, 2, batch_sz, torch.device("cpu"), 0, ngram_repeat,
            set(), False, 30, 1., 5, torch.randint(0, 30, (batch_sz,)))
        for step in range(ngram_repeat + 4):
            word_probs = torch.full((batch_sz, n_words), -float('inf'))
            # batches 0 and 7 predict a constant token every step
            word_probs[0, repeat_idx] = 0
            if batch_sz > 7:
                word_probs[7, other_repeat_idx] = 0
            # every other batch predicts a fresh token each step
            word_probs[1:7, repeat_idx + step] = 0
            if batch_sz > 7:
                word_probs[8:, repeat_idx + step] = 0
            samp.advance(word_probs, torch.randn(1, batch_sz, 53))
            if step <= ngram_repeat:
                self.assertFalse(
                    samp.topk_scores.eq(self.BLOCKED_SCORE).any())
            else:
                # the repeaters are now blocked, everyone else is fine
                self.assertTrue(samp.topk_scores[0].eq(self.BLOCKED_SCORE))
                if batch_sz > 7:
                    self.assertTrue(
                        samp.topk_scores[7].eq(self.BLOCKED_SCORE))
                self.assertFalse(
                    samp.topk_scores[1:7].eq(self.BLOCKED_SCORE).any())
                if batch_sz > 7:
                    self.assertFalse(
                        samp.topk_scores[8:].eq(self.BLOCKED_SCORE).any())
def test_advance_with_some_repeats_gets_blocked(self):
    # batches 0 and 7 stall on one token each; everyone else moves on
    vocab_size = 100
    stuck_idx = 47
    stuck_idx_b7 = 12
    ngram_repeat = 3
    for n_batch in [1, 3, 13]:
        sampler = RandomSampling(
            0, 1, 2, n_batch, torch.device("cpu"), 0, ngram_repeat,
            set(), False, 30, 1., 5, torch.randint(0, 30, (n_batch,)))
        for t in range(ngram_repeat + 4):
            probs = torch.full((n_batch, vocab_size), -float('inf'))
            # repeaters: same prediction at every t
            probs[0, stuck_idx] = 0
            if n_batch > 7:
                probs[7, stuck_idx_b7] = 0
            # movers: prediction shifts with t
            probs[1:7, stuck_idx + t] = 0
            if n_batch > 7:
                probs[8:, stuck_idx + t] = 0
            attn = torch.randn(1, n_batch, 53)
            sampler.advance(probs, attn)
            blocked = sampler.topk_scores.eq(self.BLOCKED_SCORE)
            if t <= ngram_repeat:
                self.assertFalse(blocked.any())
            else:
                # batch 0 (and 7 when present) are dead; the rest live
                self.assertTrue(sampler.topk_scores[0].eq(self.BLOCKED_SCORE))
                if n_batch > 7:
                    self.assertTrue(
                        sampler.topk_scores[7].eq(self.BLOCKED_SCORE))
                self.assertFalse(blocked[1:7].any())
                if n_batch > 7:
                    self.assertFalse(blocked[8:].any())
def test_advance_with_repeats_gets_blocked(self):
    """Every batch repeats one token; all get blocked past the limit."""
    vocab_size = 100
    stuck_idx = 47
    ngram_repeat = 3
    for n_batch in (1, 3):
        sampler = RandomSampling(0, 1, 2, n_batch, torch.device("cpu"),
                                 0, ngram_repeat, set(), False, 30, 1., 5,
                                 torch.randint(0, 30, (n_batch,)))
        for t in range(ngram_repeat + 4):
            # every batch keeps predicting stuck_idx
            probs = torch.full((n_batch, vocab_size), -float('inf'))
            probs[:, stuck_idx] = 0
            sampler.advance(probs, torch.randn(1, n_batch, 53))
            if t <= ngram_repeat:
                self.assertTrue(
                    sampler.topk_scores.equal(torch.zeros((n_batch, 1))))
            else:
                expected = torch.tensor(self.BLOCKED_SCORE).repeat(n_batch, 1)
                self.assertTrue(sampler.topk_scores.equal(expected))
def test_doesnt_predict_eos_if_shorter_than_min_len(self):
    """EOS must not be sampled before min_length steps have elapsed.

    Batch 0 always puts its best mass on EOS; the others never do.
    """
    for batch_sz in (1, 3):
        n_words = 100
        non_eos_idx = 47
        valid_score_dist = torch.log_softmax(torch.tensor([6., 5.]), dim=0)
        min_length = 5
        eos_idx = 2
        lengths = torch.randint(0, 30, (batch_sz,))
        samp = RandomSampling(
            0, 1, 2, batch_sz, torch.device("cpu"), min_length,
            False, set(), False, 30, 1., 1, lengths)
        all_attns = []
        for step in range(min_length + 4):
            word_probs = torch.full((batch_sz, n_words), -float('inf'))
            # batch 0's best option is EOS - it should be suppressed
            word_probs[0, eos_idx] = valid_score_dist[0]
            # ... but it still needs one finite non-EOS alternative
            word_probs[0, non_eos_idx] = valid_score_dist[1]
            word_probs[1:, non_eos_idx + step] = 0
            attns = torch.randn(1, batch_sz, 53)
            all_attns.append(attns)
            samp.advance(word_probs, attns)
            if step < min_length:
                self.assertTrue(
                    samp.topk_scores[0].allclose(valid_score_dist[1]))
                self.assertTrue(samp.topk_scores[1:].eq(0).all())
            elif step == min_length:
                # batch 0 just finished; nobody else has
                self.assertTrue(samp.is_finished[0, :].eq(1).all())
                self.assertTrue(samp.is_finished[1:, 1:].eq(0).all())
            else:  # step > min_length: nothing left to check
                break
def test_doesnt_predict_eos_if_shorter_than_min_len(self):
    # batch 0 always ranks EOS highest; remaining batches never emit EOS
    for batch_sz in [1, 3]:
        n_words = 100
        alt_idx = 47  # a finite-scoring alternative to EOS
        score_dist = torch.log_softmax(torch.tensor([6., 5.]), dim=0)
        min_length = 5
        eos_idx = 2
        samp = RandomSampling(
            0, 1, 2, batch_sz, torch.device("cpu"), min_length, False,
            set(), False, 30, 1., 1, torch.randint(0, 30, (batch_sz,)))
        all_attns = []
        step = 0
        while step < min_length + 4:
            word_probs = torch.full((batch_sz, n_words), -float('inf'))
            # EOS gets the top score in batch 0 but must be suppressed
            word_probs[0, eos_idx] = score_dist[0]
            # keep one non-EOS option above -1e20 so sampling can proceed
            word_probs[0, alt_idx] = score_dist[1]
            word_probs[1:, alt_idx + step] = 0
            attns = torch.randn(1, batch_sz, 53)
            all_attns.append(attns)
            samp.advance(word_probs, attns)
            if step < min_length:
                self.assertTrue(samp.topk_scores[0].allclose(score_dist[1]))
                self.assertTrue(samp.topk_scores[1:].eq(0).all())
            elif step == min_length:
                # batch 0 has now ended; the others are still alive
                self.assertTrue(samp.is_finished[0, :].eq(1).all())
                self.assertTrue(samp.is_finished[1:, 1:].eq(0).all())
            else:
                break
            step += 1
def test_advance_with_repeats_gets_blocked(self):
    # all batches repeat the same token, so every batch should be
    # blocked once the n-gram limit is exceeded
    n_words = 100
    repeat_idx = 47
    ngram_repeat = 3
    for batch_sz in [1, 3]:
        samp = RandomSampling(
            0, 1, 2, batch_sz, torch.device("cpu"), 0, ngram_repeat,
            set(), False, 30, 1., 5, torch.randint(0, 30, (batch_sz,)))
        blocked = torch.tensor(self.BLOCKED_SCORE).repeat(batch_sz, 1)
        for i in range(ngram_repeat + 4):
            word_probs = torch.full((batch_sz, n_words), -float('inf'))
            word_probs[:, repeat_idx] = 0
            samp.advance(word_probs, torch.randn(1, batch_sz, 53))
            if i <= ngram_repeat:
                self.assertTrue(
                    samp.topk_scores.equal(torch.zeros((batch_sz, 1))))
            else:
                self.assertTrue(samp.topk_scores.equal(blocked))
def test_returns_correct_scores_non_deterministic(self):
    """Finished scores equal the EOS log-prob / temp even with sampling.

    Because sampling (keep_topk=2) is stochastic, each step retries until
    the targeted batch happens to sample EOS; ``for/else`` fails the test
    if that never occurs within the retry budget.
    """
    for batch_sz in [1, 13]:
        for temp in [1., 3.]:
            n_words = 100
            _non_eos_idxs = [47, 51, 13, 88, 99]
            valid_score_dist_1 = torch.log_softmax(torch.tensor(
                [6., 5., 4., 3., 2., 1.]), dim=0)
            valid_score_dist_2 = torch.log_softmax(torch.tensor(
                [6., 1.]), dim=0)
            eos_idx = 2
            lengths = torch.randint(0, 30, (batch_sz,))
            samp = RandomSampling(
                0, 1, 2, batch_sz, torch.device("cpu"), 0,
                False, set(), False, 30, temp, 2, lengths)
            # initial step
            i = 0
            for _ in range(100):
                word_probs = torch.full(
                    (batch_sz, n_words), -float('inf'))
                # batch 0 dies on step 0
                word_probs[0, eos_idx] = valid_score_dist_1[0]
                # include at least one prediction OTHER than EOS
                # that is greater than -1e20
                word_probs[0, _non_eos_idxs] = valid_score_dist_1[1:]
                word_probs[1:, _non_eos_idxs[0] + i] = 0
                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)
                if samp.is_finished[0].eq(1).all():
                    break
            else:
                self.fail("Batch 0 never ended (very unlikely but maybe "
                          "due to stochasticisty. If so, please increase "
                          "the range of the for-loop.")
            samp.update_finished()
            self.assertEqual(
                samp.scores[0], [valid_score_dist_1[0] / temp])
            if batch_sz == 1:
                self.assertTrue(samp.done)
                continue
            else:
                self.assertFalse(samp.done)
            # step 2
            i = 1
            for _ in range(100):
                # NOTE(review): word_probs shrinks to batch_sz - 1 rows but
                # attns keeps batch_sz — presumably advance() tolerates the
                # extra attention rows; confirm against RandomSampling.
                word_probs = torch.full(
                    (batch_sz - 1, n_words), -float('inf'))
                # (old) batch 8 dies on step 1
                word_probs[7, eos_idx] = valid_score_dist_2[0]
                word_probs[0:7, _non_eos_idxs[:2]] = valid_score_dist_2
                word_probs[8:, _non_eos_idxs[:2]] = valid_score_dist_2
                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)
                if samp.is_finished[7].eq(1).all():
                    break
            else:
                self.fail("Batch 8 never ended (very unlikely but maybe "
                          "due to stochasticisty. If so, please increase "
                          "the range of the for-loop.")
            samp.update_finished()
            self.assertEqual(
                samp.scores[8], [valid_score_dist_2[0] / temp])
            # step 3
            i = 2
            for _ in range(250):
                word_probs = torch.full(
                    (samp.alive_seq.shape[0], n_words), -float('inf'))
                # everything dies
                word_probs[:, eos_idx] = 0
                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)
                if samp.is_finished.any():
                    samp.update_finished()
                if samp.is_finished.eq(1).all():
                    break
            else:
                self.fail("All batches never ended (very unlikely but "
                          "maybe due to stochasticisty. If so, please "
                          "increase the range of the for-loop.")
            # every batch that did not end early finishes with score 0
            for b in range(batch_sz):
                if b != 0 and b != 8:
                    self.assertEqual(samp.scores[b], [0])
            self.assertTrue(samp.done)
def test_returns_correct_scores_deterministic(self):
    """With keep_topk=1 sampling is greedy, so finished scores are exact.

    Batch 0 finishes at step 0, (original) batch 8 at step 1, and every
    remaining batch at step 2; each finished score must equal the EOS
    log-prob divided by the sampling temperature.
    """
    for batch_sz in [1, 13]:
        for temp in [1., 3.]:
            n_words = 100
            _non_eos_idxs = [47, 51, 13, 88, 99]
            valid_score_dist_1 = torch.log_softmax(torch.tensor(
                [6., 5., 4., 3., 2., 1.]), dim=0)
            valid_score_dist_2 = torch.log_softmax(torch.tensor(
                [6., 1.]), dim=0)
            eos_idx = 2
            lengths = torch.randint(0, 30, (batch_sz, ))
            samp = RandomSampling(0, 1, 2, batch_sz, torch.device("cpu"),
                                  0, False, set(), False, 30, temp, 1,
                                  lengths)
            # initial step
            i = 0
            word_probs = torch.full((batch_sz, n_words), -float('inf'))
            # batch 0 dies on step 0
            word_probs[0, eos_idx] = valid_score_dist_1[0]
            # include at least one prediction OTHER than EOS
            # that is greater than -1e20
            word_probs[0, _non_eos_idxs] = valid_score_dist_1[1:]
            word_probs[1:, _non_eos_idxs[0] + i] = 0
            attns = torch.randn(1, batch_sz, 53)
            samp.advance(word_probs, attns)
            self.assertTrue(samp.is_finished[0].eq(1).all())
            samp.update_finished()
            self.assertEqual(samp.scores[0],
                             [valid_score_dist_1[0] / temp])
            if batch_sz == 1:
                self.assertTrue(samp.done)
                continue
            else:
                self.assertFalse(samp.done)
            # step 2
            i = 1
            # NOTE(review): word_probs has batch_sz - 1 rows while attns
            # keeps batch_sz — presumably advance() ignores the extra
            # attention rows; confirm against RandomSampling.
            word_probs = torch.full((batch_sz - 1, n_words),
                                    -float('inf'))
            # (old) batch 8 dies on step 1
            word_probs[7, eos_idx] = valid_score_dist_2[0]
            word_probs[0:7, _non_eos_idxs[:2]] = valid_score_dist_2
            word_probs[8:, _non_eos_idxs[:2]] = valid_score_dist_2
            attns = torch.randn(1, batch_sz, 53)
            samp.advance(word_probs, attns)
            self.assertTrue(samp.is_finished[7].eq(1).all())
            samp.update_finished()
            self.assertEqual(samp.scores[8],
                             [valid_score_dist_2[0] / temp])
            # step 3
            i = 2
            word_probs = torch.full((batch_sz - 2, n_words),
                                    -float('inf'))
            # everything dies
            word_probs[:, eos_idx] = 0
            attns = torch.randn(1, batch_sz, 53)
            samp.advance(word_probs, attns)
            self.assertTrue(samp.is_finished.eq(1).all())
            samp.update_finished()
            # all remaining batches finish with score 0
            for b in range(batch_sz):
                if b != 0 and b != 8:
                    self.assertEqual(samp.scores[b], [0])
            self.assertTrue(samp.done)
def _translate_random_sampling(
        self,
        batch,
        src_vocabs,
        max_length,
        min_length=0,
        sampling_temp=1.0,
        keep_topk=-1,
        return_attention=False):
    """Alternative to beam search. Do random sampling at each step.

    This variant additionally threads emotion / knowledge-graph concept
    features (``batch.emotion``, ``batch.tgt_concept_*``) through the
    decoder and keeps them aligned with surviving sequences whenever
    finished batches are removed.
    """
    assert self.beam_size == 1
    # TODO: support these blacklisted features.
    assert self.block_ngram_repeat == 0
    batch_size = batch.batch_size
    # Encoder forward.
    src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
    self.model.decoder.init_state(src, memory_bank, enc_states)
    use_src_map = self.copy_attn
    results = {
        "predictions": None,
        "scores": None,
        "attention": None,
        "batch": batch,
        "gold_score": self._gold_score(
            batch, memory_bank, src_lengths, src_vocabs, use_src_map,
            enc_states, batch_size, src)}
    memory_lengths = src_lengths
    src_map = batch.src_map if use_src_map else None
    if isinstance(memory_bank, tuple):
        mb_device = memory_bank[0].device
    else:
        mb_device = memory_bank.device
    # KG-decoder temperatures stay None unless emotion-aware top-k
    # decoding is enabled on a KGTransformerDecoder.
    score_temp, VAD_score_temp, src_relevance_temp, VAD_lambda, \
        decoder_word_embedding = None, None, None, None, None
    if isinstance(self.model.decoder, KGTransformerDecoder) \
            and self.emotion_topk_decoding_temp != 0:
        score_temp = self.model.decoder.score_temp
        VAD_score_temp = self.model.decoder.VAD_score_temp
        src_relevance_temp = self.model.decoder.src_relevance_temp
        VAD_lambda = self.model.decoder.VAD_lambda
        decoder_word_embedding = self.model.decoder.embeddings
    random_sampler = RandomSampling(
        self._tgt_pad_idx, self._tgt_bos_idx, self._tgt_eos_idx,
        batch_size, mb_device, min_length, self.block_ngram_repeat,
        self._exclusion_idxs, return_attention, self.max_length,
        sampling_temp, keep_topk, memory_lengths,
        self.emotion_topk_decoding_temp, score_temp=score_temp,
        VAD_score_temp=VAD_score_temp,
        src_relevance_temp=src_relevance_temp, VAD_lambda=VAD_lambda,
        decoder_word_embedding=decoder_word_embedding)
    # batch_emotion = None
    # if hasattr(batch, "emotion"):
    #     print("batch has emotion")
    #     batch_emotion = batch.emotion
    # else:
    #     print("batch has no emotion")
    batch_emotion = None
    if hasattr(batch, "emotion"):
        batch_emotion = batch.emotion
    batch_tgt_concept_emb = None
    if hasattr(batch, "tgt_concept_emb"):
        batch_tgt_concept_emb = batch.tgt_concept_emb
    # concept words travel as an (index, score, VAD_score) tensor triple
    batch_tgt_concept_words = None
    if hasattr(batch, "tgt_concept_index"):
        batch_tgt_concept_words = batch.tgt_concept_index, \
            batch.tgt_concept_score, batch.tgt_concept_VAD_score
    for step in range(max_length):
        # Shape: (1, B, 1)
        decoder_input = random_sampler.alive_seq[:, -1].view(1, -1, 1)
        # log_probs: ``(batch_size, vocab_size)``
        log_probs, attn = self._decode_and_generate(
            decoder_input,
            memory_bank,
            batch,
            src_vocabs,
            memory_lengths=memory_lengths,
            src_map=src_map,
            step=step,
            batch_offset=random_sampler.select_indices,
            batch_emotion=batch_emotion,
            batch_tgt_concept_emb=batch_tgt_concept_emb,
            batch_tgt_concept_words=batch_tgt_concept_words
        )
        # concept-aware batches use the extended advance() signature
        if hasattr(batch, "tgt_concept_index"):
            random_sampler.advance(log_probs, attn,
                                   batch_tgt_concept_words, memory_bank)
        else:
            random_sampler.advance(log_probs, attn)
        any_batch_is_finished = random_sampler.is_finished.any()
        if any_batch_is_finished:
            random_sampler.update_finished()
            if random_sampler.done:
                break
        if any_batch_is_finished:
            select_indices = random_sampler.select_indices
            # Reorder states so they track only still-alive sequences.
            if isinstance(memory_bank, tuple):
                # print("Reordering memory bank tuple")
                # print(select_indices)
                memory_bank = tuple(x.index_select(1, select_indices)
                                    for x in memory_bank)
            else:
                # print("Reordering memory bank")
                # print(select_indices)
                memory_bank = memory_bank.index_select(1, select_indices)
            memory_lengths = memory_lengths.index_select(0, select_indices)
            if src_map is not None:
                src_map = src_map.index_select(1, select_indices)
            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))
            # keep the auxiliary per-batch features aligned too
            if hasattr(batch, "emotion"):
                batch_emotion = batch_emotion.index_select(0, select_indices)
            if hasattr(batch, "tgt_concept_emb"):
                batch_tgt_concept_emb = \
                    batch_tgt_concept_emb.index_select(0, select_indices)
            if hasattr(batch, "tgt_concept_index"):
                batch_tgt_concept_words = \
                    batch_tgt_concept_words[0].index_select(0, select_indices), \
                    batch_tgt_concept_words[1].index_select(0, select_indices), \
                    batch_tgt_concept_words[2].index_select(0, select_indices)
            # update batch_emotion using a mask
            # TO DO
    results["scores"] = random_sampler.scores
    results["predictions"] = random_sampler.predictions
    results["attention"] = random_sampler.attention
    return results
def _translate_random_sampling(self, batch, src_vocabs, max_length,
                               min_length=0, sampling_temp=1.0,
                               keep_topk=-1, return_attention=False):
    """Alternative to beam search. Do random sampling at each step.

    This instrumented variant also collects counterfactual-attention
    statistics: for each step it compares the greedy top-1 token under
    the true attention against the top-1 under perturbed attentions
    (uniform / zero-out-max / permute, supplied via ``hack_dict``) and
    tallies per-word "unaffected" counts, split into function words and
    content words.
    """
    assert self.beam_size == 1
    # TODO: support these blacklisted features.
    assert self.block_ngram_repeat == 0
    batch_size = batch.batch_size
    # Encoder forward.
    src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
    self.model.decoder.init_state(src, memory_bank, enc_states)
    use_src_map = self.copy_attn
    results = {
        "predictions": None,
        "scores": None,
        "attention": None,
        "batch": batch,
        "gold_score": self._gold_score(batch, memory_bank, src_lengths,
                                       src_vocabs, use_src_map, enc_states,
                                       batch_size, src)
    }
    memory_lengths = src_lengths
    src_map = batch.src_map if use_src_map else None
    if isinstance(memory_bank, tuple):
        mb_device = memory_bank[0].device
    else:
        mb_device = memory_bank.device
    random_sampler = RandomSampling(
        self._tgt_pad_idx, self._tgt_bos_idx, self._tgt_eos_idx,
        batch_size, mb_device, min_length, self.block_ngram_repeat,
        self._exclusion_idxs, return_attention, self.max_length,
        sampling_temp, keep_topk, memory_lengths)
    for step in range(max_length):
        # Shape: (1, B, 1)
        decoder_input = random_sampler.alive_seq[:, -1].view(1, -1, 1)
        log_probs, attn, hack_dict = self._decode_and_generate(
            decoder_input,
            memory_bank,
            batch,
            src_vocabs,
            memory_lengths=memory_lengths,
            src_map=src_map,
            step=step,
            batch_offset=random_sampler.select_indices,
        )
        # greedy top-1 under the unperturbed attention
        top_prob = torch.topk(log_probs, k=1, dim=1)
        self.total_tokens += top_prob.indices.size()[0]
        tgt_field = dict(self.fields)["tgt"].base_field
        vocab = tgt_field.vocab
        src_field = dict(self.fields)["src"].base_field
        # NOTE(review): vocab_src is assigned but never used below.
        vocab_src = src_field.vocab
        # tally overall function/content word frequencies
        for i in range(top_prob.indices.size()[0]):
            word = vocab.itos[top_prob.indices[i][0]]
            self.vocab_dict[word] += 1
            if (is_function_word(word)):
                self.total_function_words += 1
            else:
                self.total_content_words += 1
        if self.counterfactual_attention_method == 'uniform_or_zero_out_max':
            # a token is "unaffected" if ANY of the three perturbations
            # leaves the top-1 prediction unchanged
            log_probs_uniform = hack_dict['log_probs_uniform']
            top_prob_uniform = torch.topk(log_probs_uniform, k=1, dim=1)
            unaffected_boolean_uniform = (
                top_prob.indices == top_prob_uniform.indices)
            log_probs_zero_out_max = hack_dict['log_probs_zero_out_max']
            top_prob_zero_out_max = torch.topk(log_probs_zero_out_max,
                                               k=1, dim=1)
            unaffected_boolean_zero_out_max = (
                top_prob.indices == top_prob_zero_out_max.indices)
            log_probs_permute = hack_dict['log_probs_permute']
            top_prob_permute = torch.topk(log_probs_permute, k=1, dim=1)
            unaffected_boolean_permute = (
                top_prob.indices == top_prob_permute.indices)
            for i in range(unaffected_boolean_zero_out_max.size()[0]):
                if (unaffected_boolean_uniform[i][0].item() == 1
                        or unaffected_boolean_zero_out_max[i][0].item() == 1
                        or unaffected_boolean_permute[i][0].item() == 1):
                    word = vocab.itos[top_prob.indices[i][0]]
                    if is_function_word(word):
                        self.unaffected_function_words_count += 1
                        self.unaffected_function_words[word] += 1
                    else:
                        self.unaffected_content_words_count += 1
                        self.unaffected_content_words[word] += 1
        else:
            # single counterfactual distribution supplied by the decoder
            log_probs_counterfactual = hack_dict[
                'log_probs_counterfactual']
            top_prob_counterfactual = torch.topk(log_probs_counterfactual,
                                                 k=1, dim=1)
            unaffected_boolean = (
                top_prob.indices == top_prob_counterfactual.indices)
            self.unaffected_words_count += unaffected_boolean.sum(
                dim=0).cpu().numpy()[0]
            for i in range(unaffected_boolean.size()[0]):
                if (unaffected_boolean[i][0].item() == 1):
                    word = vocab.itos[top_prob.indices[i][0]]
                    if is_function_word(word):
                        self.unaffected_function_words_count += 1
                        self.unaffected_function_words[word] += 1
                    else:
                        self.unaffected_content_words_count += 1
                        self.unaffected_content_words[word] += 1
        # if tvd_permute is True:
        #     max_attention = hack_dict['tvd_max_attention'].cpu()
        #     dist_change_median = hack_dict['tvd_dist_change_median'].cpu()
        #     for i in range(top_prob.indices.size()[0]):
        #         if(max_attention[i] > 0.5 and dist_change_median[i] < 0.2):
        #             self.tvd_tokens[vocab.itos[top_prob.indices[i][0]]] += 1
        #     assert (len(max_attention.size()) == 1)
        #     assert (len(dist_change_median.size()) == 1)
        #     assert (max_attention.size()[0] == dist_change_median.size()[0])
        #     for i in range(max_attention.size()[0]):
        #         max_att_dist_change_pairs.append((float(max_attention[i].cpu()), float(dist_change_median[i].cpu())))
        random_sampler.advance(log_probs, attn)
        any_batch_is_finished = random_sampler.is_finished.any()
        if any_batch_is_finished:
            random_sampler.update_finished()
            if random_sampler.done:
                break
        if any_batch_is_finished:
            select_indices = random_sampler.select_indices
            # Reorder states.
            if isinstance(memory_bank, tuple):
                memory_bank = tuple(
                    x.index_select(1, select_indices) for x in memory_bank)
            else:
                memory_bank = memory_bank.index_select(1, select_indices)
            memory_lengths = memory_lengths.index_select(0, select_indices)
            if src_map is not None:
                src_map = src_map.index_select(1, select_indices)
            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))
    results["scores"] = random_sampler.scores
    results["predictions"] = random_sampler.predictions
    results["attention"] = random_sampler.attention
    return results
def test_returns_correct_scores_non_deterministic(self):
    """Finished scores equal the EOS log-prob / temp even with sampling.

    Because sampling (keep_topk=2) is stochastic, each stage retries
    until the targeted batch happens to sample EOS; ``for/else`` fails
    the test if that never occurs within the retry budget.

    Fix: the three ``self.fail`` messages had an unclosed parenthesis
    and misspelled "stochasticity".
    """
    for batch_sz in [1, 13]:
        for temp in [1., 3.]:
            n_words = 100
            _non_eos_idxs = [47, 51, 13, 88, 99]
            valid_score_dist_1 = torch.log_softmax(torch.tensor(
                [6., 5., 4., 3., 2., 1.]), dim=0)
            valid_score_dist_2 = torch.log_softmax(torch.tensor(
                [6., 1.]), dim=0)
            eos_idx = 2
            lengths = torch.randint(0, 30, (batch_sz,))
            samp = RandomSampling(
                0, 1, 2, batch_sz, torch.device("cpu"), 0,
                False, set(), False, 30, temp, 2, lengths)
            # initial step: retry until batch 0 samples EOS
            i = 0
            for _ in range(100):
                word_probs = torch.full(
                    (batch_sz, n_words), -float('inf'))
                # batch 0 dies on step 0
                word_probs[0, eos_idx] = valid_score_dist_1[0]
                # include at least one prediction OTHER than EOS
                # that is greater than -1e20
                word_probs[0, _non_eos_idxs] = valid_score_dist_1[1:]
                word_probs[1:, _non_eos_idxs[0] + i] = 0
                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)
                if samp.is_finished[0].eq(1).all():
                    break
            else:
                self.fail("Batch 0 never ended (very unlikely but maybe "
                          "due to stochasticity). If so, please increase "
                          "the range of the for-loop.")
            samp.update_finished()
            self.assertEqual(
                samp.scores[0], [valid_score_dist_1[0] / temp])
            if batch_sz == 1:
                self.assertTrue(samp.done)
                continue
            else:
                self.assertFalse(samp.done)
            # step 2: retry until (original) batch 8 samples EOS
            i = 1
            for _ in range(100):
                word_probs = torch.full(
                    (batch_sz - 1, n_words), -float('inf'))
                # (old) batch 8 dies on step 1
                word_probs[7, eos_idx] = valid_score_dist_2[0]
                word_probs[0:7, _non_eos_idxs[:2]] = valid_score_dist_2
                word_probs[8:, _non_eos_idxs[:2]] = valid_score_dist_2
                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)
                if samp.is_finished[7].eq(1).all():
                    break
            else:
                self.fail("Batch 8 never ended (very unlikely but maybe "
                          "due to stochasticity). If so, please increase "
                          "the range of the for-loop.")
            samp.update_finished()
            self.assertEqual(
                samp.scores[8], [valid_score_dist_2[0] / temp])
            # step 3: retry until every remaining batch has ended
            i = 2
            for _ in range(250):
                word_probs = torch.full(
                    (samp.alive_seq.shape[0], n_words), -float('inf'))
                # everything dies
                word_probs[:, eos_idx] = 0
                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)
                if samp.is_finished.any():
                    samp.update_finished()
                if samp.is_finished.eq(1).all():
                    break
            else:
                self.fail("All batches never ended (very unlikely but "
                          "maybe due to stochasticity). If so, please "
                          "increase the range of the for-loop.")
            # every batch that did not end early finishes with score 0
            for b in range(batch_sz):
                if b != 0 and b != 8:
                    self.assertEqual(samp.scores[b], [0])
            self.assertTrue(samp.done)
def test_returns_correct_scores_deterministic(self):
    """Greedy (keep_topk=1) sampling yields exactly predictable scores.

    Batch 0 ends on the first step, (original) batch 8 on the second,
    and all others on the third; each finished score must be the EOS
    log-prob divided by the temperature.
    """
    for batch_sz in [1, 13]:
        for temp in [1., 3.]:
            n_words = 100
            _non_eos_idxs = [47, 51, 13, 88, 99]
            valid_score_dist_1 = torch.log_softmax(torch.tensor(
                [6., 5., 4., 3., 2., 1.]), dim=0)
            valid_score_dist_2 = torch.log_softmax(torch.tensor(
                [6., 1.]), dim=0)
            eos_idx = 2
            lengths = torch.randint(0, 30, (batch_sz,))
            samp = RandomSampling(
                0, 1, 2, batch_sz, torch.device("cpu"), 0,
                False, set(), False, 30, temp, 1, lengths)
            # initial step
            i = 0
            word_probs = torch.full(
                (batch_sz, n_words), -float('inf'))
            # batch 0 dies on step 0
            word_probs[0, eos_idx] = valid_score_dist_1[0]
            # include at least one prediction OTHER than EOS
            # that is greater than -1e20
            word_probs[0, _non_eos_idxs] = valid_score_dist_1[1:]
            word_probs[1:, _non_eos_idxs[0] + i] = 0
            attns = torch.randn(1, batch_sz, 53)
            samp.advance(word_probs, attns)
            self.assertTrue(samp.is_finished[0].eq(1).all())
            samp.update_finished()
            self.assertEqual(
                samp.scores[0], [valid_score_dist_1[0] / temp])
            if batch_sz == 1:
                self.assertTrue(samp.done)
                continue
            else:
                self.assertFalse(samp.done)
            # step 2
            i = 1
            # NOTE(review): word_probs shrinks to batch_sz - 1 rows but
            # attns keeps batch_sz — presumably advance() tolerates the
            # extra attention rows; confirm against RandomSampling.
            word_probs = torch.full(
                (batch_sz - 1, n_words), -float('inf'))
            # (old) batch 8 dies on step 1
            word_probs[7, eos_idx] = valid_score_dist_2[0]
            word_probs[0:7, _non_eos_idxs[:2]] = valid_score_dist_2
            word_probs[8:, _non_eos_idxs[:2]] = valid_score_dist_2
            attns = torch.randn(1, batch_sz, 53)
            samp.advance(word_probs, attns)
            self.assertTrue(samp.is_finished[7].eq(1).all())
            samp.update_finished()
            self.assertEqual(
                samp.scores[8], [valid_score_dist_2[0] / temp])
            # step 3
            i = 2
            word_probs = torch.full(
                (batch_sz - 2, n_words), -float('inf'))
            # everything dies
            word_probs[:, eos_idx] = 0
            attns = torch.randn(1, batch_sz, 53)
            samp.advance(word_probs, attns)
            self.assertTrue(samp.is_finished.eq(1).all())
            samp.update_finished()
            # all remaining batches finish with score 0
            for b in range(batch_sz):
                if b != 0 and b != 8:
                    self.assertEqual(samp.scores[b], [0])
            self.assertTrue(samp.done)
class Translator(object):
    """Translate a batch of sentences with a saved model.

    Args:
        model (onmt.modules.NMTModel): NMT model to use for translation
        fields (dict[str, torchtext.data.Field]): A dict
            mapping each side to its list of name-Field pairs.
        src_reader (onmt.inputters.DataReaderBase): Source reader.
        tgt_reader (onmt.inputters.TextDataReader): Target reader.
        gpu (int): GPU device. Set to negative for no GPU.
        n_best (int): How many beams to wait for.
        min_length (int): See
            :class:`onmt.translate.decode_strategy.DecodeStrategy`.
        max_length (int): See
            :class:`onmt.translate.decode_strategy.DecodeStrategy`.
        beam_size (int): Number of beams.
        random_sampling_topk (int): See
            :class:`onmt.translate.random_sampling.RandomSampling`.
        random_sampling_temp (int): See
            :class:`onmt.translate.random_sampling.RandomSampling`.
        stepwise_penalty (bool): Whether coverage penalty is applied every step
            or not.
        dump_beam (bool): Debugging option.
        block_ngram_repeat (int): See
            :class:`onmt.translate.decode_strategy.DecodeStrategy`.
        ignore_when_blocking (set or frozenset): See
            :class:`onmt.translate.decode_strategy.DecodeStrategy`.
        replace_unk (bool): Replace unknown token.
        data_type (str): Source data type.
        verbose (bool): Print/log every translation.
        report_bleu (bool): Print/log Bleu metric.
        report_rouge (bool): Print/log Rouge metric.
        report_time (bool): Print/log total time/frequency.
        copy_attn (bool): Use copy attention.
        simple_fusion (bool): Use a simple-fusion language-model decoder
            alongside the main decoder.
        gpt_tgt (bool): Whether the target side uses GPT representations.
        global_scorer (onmt.translate.GNMTGlobalScorer): Translation
            scoring/reranking object.
        out_file (TextIO or codecs.StreamReaderWriter): Output file.
        report_score (bool) : Whether to report scores
        logger (logging.Logger or NoneType): Logger.
        seed (int): Random seed; negative means "do not seed".
    """

    def __init__(
            self,
            model,
            fields,
            src_reader,
            tgt_reader,
            gpu=-1,
            n_best=1,
            min_length=0,
            max_length=100,
            beam_size=30,
            random_sampling_topk=1,
            random_sampling_temp=1,
            stepwise_penalty=None,
            dump_beam=False,
            block_ngram_repeat=0,
            ignore_when_blocking=frozenset(),
            replace_unk=False,
            data_type="text",
            verbose=False,
            report_bleu=False,
            report_rouge=False,
            report_time=False,
            copy_attn=False,
            simple_fusion=False,
            gpt_tgt=False,
            global_scorer=None,
            out_file=None,
            report_score=True,
            logger=None,
            seed=-1):
        self.model = model
        self.fields = fields
        # Resolve the special-token indices from the target vocabulary.
        tgt_field = dict(self.fields)["tgt"].base_field
        self._tgt_vocab = tgt_field.vocab
        self._tgt_eos_idx = self._tgt_vocab.stoi[tgt_field.eos_token]
        self._tgt_pad_idx = self._tgt_vocab.stoi[tgt_field.pad_token]
        self._tgt_bos_idx = self._tgt_vocab.stoi[tgt_field.init_token]
        # self._tgt_bos_idx = self._tgt_vocab.stoi['Ġsee']
        # NOTE(review): leftover debug output — consider removing or
        # routing through the logger.
        print(self._tgt_bos_idx)
        self._tgt_unk_idx = self._tgt_vocab.stoi[tgt_field.unk_token]
        self._tgt_vocab_len = len(self._tgt_vocab)

        self._gpu = gpu
        self._use_cuda = gpu > -1
        self._dev = torch.device("cuda", self._gpu) \
            if self._use_cuda else torch.device("cpu")

        self.n_best = n_best
        self.max_length = max_length
        self.beam_size = beam_size
        self.random_sampling_temp = random_sampling_temp
        self.sample_from_topk = random_sampling_topk
        self.min_length = min_length
        self.stepwise_penalty = stepwise_penalty
        self.dump_beam = dump_beam
        self.block_ngram_repeat = block_ngram_repeat
        self.ignore_when_blocking = ignore_when_blocking
        # vocab indices that n-gram blocking must never penalize
        self._exclusion_idxs = {
            self._tgt_vocab.stoi[t] for t in self.ignore_when_blocking
        }
        self.src_reader = src_reader
        self.tgt_reader = tgt_reader
        self.replace_unk = replace_unk
        if self.replace_unk and not self.model.decoder.attentional:
            raise ValueError("replace_unk requires an attentional decoder.")
        self.data_type = data_type
        self.verbose = verbose
        self.report_bleu = report_bleu
        self.report_rouge = report_rouge
        self.report_time = report_time

        self.copy_attn = copy_attn
        self.simple_fusion = simple_fusion
        self.gpt_tgt = gpt_tgt

        self.global_scorer = global_scorer
        if self.global_scorer.has_cov_pen and \
                not self.model.decoder.attentional:
            raise ValueError(
                "Coverage penalty requires an attentional decoder.")
        self.out_file = out_file
        self.report_score = report_score
        self.logger = logger

        self.use_filter_pred = False
        self._filter_pred = None

        # for debugging
        self.beam_trace = self.dump_beam != ""
        self.beam_accum = None
        if self.beam_trace:
            self.beam_accum = {
                "predicted_ids": [],
                "beam_parent_ids": [],
                "scores": [],
                "log_probs": []
            }

        set_random_seed(seed, self._use_cuda)

    @classmethod
    def from_opt(cls, model, fields, opt, model_opt, global_scorer=None,
                 out_file=None, report_score=True, logger=None):
        """Alternate constructor.

        Args:
            model (onmt.modules.NMTModel): See :func:`__init__()`.
            fields (dict[str, torchtext.data.Field]): See
                :func:`__init__()`.
            opt (argparse.Namespace): Command line options
            model_opt (argparse.Namespace): Command line options saved with
                the model checkpoint.
            global_scorer (onmt.translate.GNMTGlobalScorer): See
                :func:`__init__()`..
            out_file (TextIO or codecs.StreamReaderWriter): See
                :func:`__init__()`.
            report_score (bool) : See :func:`__init__()`.
            logger (logging.Logger or NoneType): See :func:`__init__()`.
        """
        # 'none' means the model is target-only (no source reader).
        if opt.data_type == 'none':
            src_reader = None
        else:
            src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
        tgt_reader = inputters.str2reader["text"].from_opt(opt)
        return cls(
            model,
            fields,
            src_reader,
            tgt_reader,
            gpu=opt.gpu,
            n_best=opt.n_best,
            min_length=opt.min_length,
            max_length=opt.max_length,
            beam_size=opt.beam_size,
            random_sampling_topk=opt.random_sampling_topk,
            random_sampling_temp=opt.random_sampling_temp,
            stepwise_penalty=opt.stepwise_penalty,
            dump_beam=opt.dump_beam,
            block_ngram_repeat=opt.block_ngram_repeat,
            ignore_when_blocking=set(opt.ignore_when_blocking),
            replace_unk=opt.replace_unk,
            data_type=opt.data_type,
            verbose=opt.verbose,
            report_bleu=opt.report_bleu,
            report_rouge=opt.report_rouge,
            report_time=opt.report_time,
            copy_attn=model_opt.copy_attn,
            simple_fusion=model_opt.simple_fusion,
            gpt_tgt=model_opt.GPT_representation_mode != 'none' and
            model_opt.GPT_representation_loc in ['tgt', 'both'],
            global_scorer=global_scorer,
            out_file=out_file,
            report_score=report_score,
            logger=logger,
            seed=opt.seed)

    def _gold_score(self, batch, memory_bank, src_lengths, src_vocabs,
                    use_src_map, enc_states, batch_size, src):
        """Score the gold target of ``batch`` if one exists, else zeros.

        Re-initializes the decoder state afterwards because scoring the
        target advances it.
        """
        if "tgt" in batch.__dict__:
            gs = self._score_target(
                batch, memory_bank, src_lengths, src_vocabs,
                batch.src_map if use_src_map else None)
            self.model.decoder.init_state(src, memory_bank, enc_states)
            if self.simple_fusion:
                self.model.lm_decoder.init_state(src, None, None)
        else:
            gs = [0] * batch_size
        return gs

    def build_data_iter(self, opt, batch_size=1):
        """Build an ordered iterator over ``opt.src`` (and optional
        ``opt.tgt``) and reset the accumulators used during translation.

        Returns:
            onmt.inputters.OrderedIterator: the batch iterator.
        """
        if batch_size is None:
            raise ValueError("batch_size must be set")

        def read_file(path):
            # files are always opened in binary mode ("rb")
            priv_str = "r"
            priv_str += "b"
            with open(path, priv_str) as f:
                return f.readlines()

        src = read_file(opt.src)
        tgt = read_file(opt.tgt) if opt.tgt is not None else None

        readers, data, dirs = [], [], []
        if self.src_reader:
            readers += [self.src_reader]
            data += [("src", src)]
            dirs += [None]
        if tgt:
            readers += [self.tgt_reader]
            data += [("tgt", tgt)]
            dirs += [None]

        data = inputters.Dataset(
            self.fields,
            readers=readers,
            data=data,
            dirs=dirs,
            sort_key=inputters.str2sortkey[self.data_type],
            filter_pred=self._filter_pred)

        data_iter = inputters.OrderedIterator(dataset=data,
                                              device=self._dev,
                                              batch_size=batch_size,
                                              train=False,
                                              sort=False,
                                              sort_within_batch=True,
                                              shuffle=False)

        self.src_vocabs = data.src_vocabs
        self._build_xlation(data, tgt)
        # reset per-run accumulators
        self.all_scores = []
        self.all_predictions = []
        self.pred_score_total, self.pred_words_total = 0, 0
        self.gold_score_total, self.gold_words_total = 0, 0
        self.counter = count(1)
        return data_iter

    def set_encoder_state(self, batch, tags=None, temperature=1.0):
        """Run the encoder on ``batch`` and cache everything the stepwise
        decoding API (:func:`forward_pass`/:func:`generate_tokens`) needs.

        NOTE(review): this silently forces ``beam_size`` to 1 and
        ``block_ngram_repeat`` to 0, mutating the translator's config.
        """
        if self.beam_size != 1:
            self.beam_size = 1
        if self.block_ngram_repeat != 0:
            self.block_ngram_repeat = 0
        # Encoder forward.
        src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
        self.model.decoder.init_state(src, memory_bank, enc_states)
        if self.simple_fusion:
            self.model.lm_decoder.init_state(src, None, None)
        use_src_map = self.copy_attn
        memory_lengths = src_lengths
        src_map = batch.src_map if use_src_map else None
        # set for the decoder usage
        self.enc_states = enc_states
        self.src = src
        self.src_lengths = src_lengths
        self.memory_bank = memory_bank
        self.memory_lengths = memory_lengths
        self.src_map = src_map
        self.batch = batch
        self.batch_size = batch.batch_size
        self.set_random_sampler(temperature=temperature)
        self._build_result()

    def forward_pass(self, decoder_in, step, past=None, input_embeds=None,
                     tags=None, use_copy=True):
        """One decoder + generator step over the cached encoder state.

        NOTE(review): the return arity varies — 4-tuple
        ``(probs, attn, all_hidden_states, past)`` on the non-copy path
        and when ``use_copy`` is False, but a 5-tuple ending in
        ``p_copy`` on the copy path with ``use_copy=True``. Callers must
        handle both shapes.
        """
        memory_bank = self.memory_bank
        src_vocabs = self.src_vocabs
        memory_lengths = self.memory_lengths
        src_map = self.src_map
        batch = self.batch
        if self.copy_attn:
            # Turn any copied words into UNKs.
            decoder_in = decoder_in.masked_fill(
                decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx)
        decoder = self.model.decoder
        dec_out, all_hidden_states, past, dec_attn = decoder(
            decoder_in,
            memory_bank,
            memory_lengths=memory_lengths,
            step=step,
            past=past,
            input_embeds=input_embeds,
            pplm_return=True)

        # Generator forward.
        if not self.copy_attn:
            if "std" in dec_attn:
                attn = dec_attn["std"]
            else:
                attn = None
            if self.simple_fusion:
                # fuse the LM decoder's logits with the main decoder's
                lm_dec_out, _ = self.model.lm_decoder(
                    decoder_in, memory_bank.new_zeros(1, 1, 1), step=step)
                probs = self.model.generator(dec_out.squeeze(0),
                                             lm_dec_out.squeeze(0))
            else:
                probs = self.model.generator(dec_out.squeeze(0))
            # print(log_probs)
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        else:
            attn = dec_attn["copy"]
            scores, p_copy = self.model.generator(
                dec_out.view(-1, dec_out.size(2)),
                attn.view(-1, attn.size(2)), src_map, tags=tags)
            scores = scores.view(batch.batch_size, -1, scores.size(-1))
            scores = collapse_copy_scores(scores,
                                          batch,
                                          self._tgt_vocab,
                                          src_vocabs,
                                          batch_dim=0,
                                          batch_offset=None)
            scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
            # log_probs = scores.squeeze(0).log()
            probs = scores.squeeze(0)
            if use_copy is False:
                # NOTE(review): 50257 is presumably the GPT-2 vocab size;
                # extended/copy indices are truncated away — confirm.
                probs = probs[:, :50257]
                return probs, attn, all_hidden_states, past
            return probs, attn, all_hidden_states, past, p_copy
        return probs, attn, all_hidden_states, past

    def set_random_sampler(self, return_attention=False, temperature=1.0):
        """Create ``self.random_sampler`` over the cached encoder state.

        NOTE(review): mutates ``self.max_length`` (floor 400) and
        ``self.min_length`` (floor 300) as a side effect — confirm this
        clamping is intended for every caller.
        """
        # memory_bank may be a tensor, a tuple/list of tensors, or a
        # tuple/list of dicts of tensors; find a device in all cases.
        if isinstance(self.memory_bank, tuple) or isinstance(
                self.memory_bank, list):
            if isinstance(self.memory_bank[0], dict):
                mb_device = self.memory_bank[0][list(
                    self.memory_bank[0].keys())[0]].device
            else:
                mb_device = self.memory_bank[0].device
        else:
            mb_device = self.memory_bank.device
        if self.max_length < 400:
            self.max_length = 400
        if self.min_length < 300:
            self.min_length = 300
        self.random_sampler = RandomSampling(
            self._tgt_pad_idx, self._tgt_bos_idx, self._tgt_eos_idx,
            self.batch_size, mb_device, self.min_length,
            self.block_ngram_repeat, self._exclusion_idxs,
            return_attention, self.max_length, temperature,
            self.sample_from_topk, self.memory_lengths)

    def generate_tokens(self, log_probs, attn=None):
        """Advance the sampler one step.

        Returns the next input token (shape ``(1, 1)``) while decoding
        continues, or ``False`` once all sentences are finished (after
        finalizing and writing the results).
        """
        self.random_sampler.advance(log_probs, attn)
        any_batch_is_finished = self.random_sampler.is_finished.any()
        # ATTENTION: The batch size is one, so "any finished" here
        # implies the whole batch is finished.
        if any_batch_is_finished:
            self.random_sampler.update_finished()
            if self.random_sampler.done:
                # Finish the generation, set the resulf for generation
                self._finalize_result()
                self.generate_sentence_batch()
                return False
        else:
            return self.random_sampler.alive_seq[:, -1].view(1, 1)

    def generate_sentence_batch(self):
        """Turn ``self.results`` into translations and write/report them."""
        batch = self.results
        translations = self.xlation_builder.from_batch(batch)
        for trans in translations:
            self.all_scores += [trans.pred_scores[:self.n_best]]
            self.pred_score_total += trans.pred_scores[0]
            self.pred_words_total += len(trans.pred_sents[0])
            n_best_preds = [
                " ".join(pred) for pred in trans.pred_sents[:self.n_best]
            ]
            self.all_predictions += [n_best_preds]
            self.out_file.write('\n'.join(n_best_preds) + '\n')
            self.out_file.flush()
            if self.verbose:
                sent_number = next(self.counter)
                output = trans.log(sent_number)
                if self.logger:
                    self.logger.info(output)
                else:
                    # raw write to stdout (fd 1), bypassing print buffering
                    os.write(1, output.encode('utf-8'))

    def _build_result(self):
        """Initialize the results dict (with gold score) for this batch."""
        use_src_map = self.copy_attn
        self.results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "batch": self.batch,
            "gold_score": self._gold_score(self.batch, self.memory_bank,
                                           self.src_lengths, self.src_vocabs,
                                           use_src_map, self.enc_states,
                                           self.batch_size, self.src)
        }

    def _build_xlation(self, data, tgt):
        """Create the TranslationBuilder used to render predictions."""
        self.xlation_builder = onmt.translate.TranslationBuilder(
            data, self.fields, self.n_best, self.replace_unk, tgt)

    def _finalize_result(self):
        """Copy the sampler's outputs into the results dict."""
        self.results["scores"] = self.random_sampler.scores
        self.results["predictions"] = self.random_sampler.predictions
        self.results["attention"] = self.random_sampler.attention

    def _run_encoder(self, batch):
        """Encode ``batch.src`` if present; otherwise fabricate a dummy
        memory bank sized from ``batch.tgt`` (target-only decoding).

        Returns:
            (src, enc_states, memory_bank, src_lengths)
        """
        if hasattr(batch, 'src'):
            src, src_lengths = batch.src if isinstance(batch.src, tuple) \
                else (batch.src, None)
            enc_states, memory_bank, src_lengths = self.model.encoder(
                src, src_lengths)
            if src_lengths is None:
                assert not isinstance(memory_bank, tuple), \
                    'Ensemble decoding only supported for text data'
                # all sentences get the full source length
                src_lengths = torch.Tensor(batch.batch_size) \
                    .type_as(memory_bank) \
                    .long() \
                    .fill_(memory_bank.size(0))
        else:
            src = None
            enc_states = None
            memory_bank = torch.zeros((1, batch.tgt[0].shape[1], 1),
                                      dtype=torch.float,
                                      device=batch.tgt[0].device)
            src_lengths = torch.ones((batch.tgt[0].shape[1], ),
                                     dtype=torch.long,
                                     device=batch.tgt[0].device)
            # src_lengths = None
        return src, enc_states, memory_bank, src_lengths

    def freeze_parameter(self):
        """Disable gradients on encoder, decoder and generator weights."""
        for parameter in self.model.encoder.parameters():
            parameter.requires_grad = False
        for parameter in self.model.decoder.parameters():
            parameter.requires_grad = False
        for parameter in self.model.generator.parameters():
            parameter.requires_grad = False
def _translate_random_sampling(
        self, batch, src_vocabs, max_length, min_length=0,
        sampling_temp=1.0, keep_topk=-1, return_attention=False):
    """Alternative to beam search. Do random sampling at each step."""
    # Sampling supports exactly one hypothesis per sentence and no
    # n-gram blocking.
    assert self.beam_size == 1
    # TODO: support these blacklisted features.
    assert self.block_ngram_repeat == 0
    batch_size = batch.batch_size

    # Encoder forward.
    src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
    self.model.decoder.init_state(src, memory_bank, enc_states)

    use_src_map = self.copy_attn
    results = {
        "predictions": None,
        "scores": None,
        "attention": None,
        "batch": batch,
        "gold_score": self._gold_score(
            batch, memory_bank, src_lengths, src_vocabs, use_src_map,
            enc_states, batch_size, src)}

    memory_lengths = src_lengths
    src_map = batch.src_map if use_src_map else None

    mb_device = (memory_bank[0].device if isinstance(memory_bank, tuple)
                 else memory_bank.device)

    sampler = RandomSampling(
        self._tgt_pad_idx, self._tgt_bos_idx, self._tgt_eos_idx,
        batch_size, mb_device, min_length, self.block_ngram_repeat,
        self._exclusion_idxs, return_attention, self.max_length,
        sampling_temp, keep_topk, memory_lengths)

    for step in range(max_length):
        # Shape: (1, B, 1)
        decoder_input = sampler.alive_seq[:, -1].view(1, -1, 1)

        log_probs, attn = self._decode_and_generate(
            decoder_input,
            memory_bank,
            batch,
            src_vocabs,
            memory_lengths=memory_lengths,
            src_map=src_map,
            step=step,
            batch_offset=sampler.select_indices)

        sampler.advance(log_probs, attn)

        if sampler.is_finished.any():
            sampler.update_finished()
            if sampler.done:
                break
            # Some sentences finished: drop them and reorder every piece
            # of decoding state to match the surviving sentences.
            keep = sampler.select_indices
            if isinstance(memory_bank, tuple):
                memory_bank = tuple(
                    bank.index_select(1, keep) for bank in memory_bank)
            else:
                memory_bank = memory_bank.index_select(1, keep)
            memory_lengths = memory_lengths.index_select(0, keep)
            if src_map is not None:
                src_map = src_map.index_select(1, keep)
            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, keep))

    results["scores"] = sampler.scores
    results["predictions"] = sampler.predictions
    results["attention"] = sampler.attention
    return results
# NOTE(review): this method is a byte-for-byte duplicate of the
# ``_translate_random_sampling`` defined immediately above. Because it is
# defined second, this copy is the one Python binds to the name; the
# earlier copy is dead code. One of the two definitions should be deleted.
def _translate_random_sampling(
        self,
        batch,
        src_vocabs,
        max_length,
        min_length=0,
        sampling_temp=1.0,
        keep_topk=-1,
        return_attention=False):
    """Alternative to beam search. Do random sampling at each step."""
    assert self.beam_size == 1

    # TODO: support these blacklisted features.
    assert self.block_ngram_repeat == 0

    batch_size = batch.batch_size

    # Encoder forward.
    src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
    self.model.decoder.init_state(src, memory_bank, enc_states)

    use_src_map = self.copy_attn

    results = {
        "predictions": None,
        "scores": None,
        "attention": None,
        "batch": batch,
        "gold_score": self._gold_score(
            batch, memory_bank, src_lengths, src_vocabs, use_src_map,
            enc_states, batch_size, src)}

    memory_lengths = src_lengths
    src_map = batch.src_map if use_src_map else None

    if isinstance(memory_bank, tuple):
        mb_device = memory_bank[0].device
    else:
        mb_device = memory_bank.device

    # NOTE(review): the sampler gets ``self.max_length`` while the loop
    # below runs for the ``max_length`` argument — confirm which bound
    # is intended when they differ.
    random_sampler = RandomSampling(
        self._tgt_pad_idx, self._tgt_bos_idx, self._tgt_eos_idx,
        batch_size, mb_device, min_length, self.block_ngram_repeat,
        self._exclusion_idxs, return_attention, self.max_length,
        sampling_temp, keep_topk, memory_lengths)

    for step in range(max_length):
        # Shape: (1, B, 1)
        decoder_input = random_sampler.alive_seq[:, -1].view(1, -1, 1)

        log_probs, attn = self._decode_and_generate(
            decoder_input,
            memory_bank,
            batch,
            src_vocabs,
            memory_lengths=memory_lengths,
            src_map=src_map,
            step=step,
            batch_offset=random_sampler.select_indices
        )

        random_sampler.advance(log_probs, attn)
        any_batch_is_finished = random_sampler.is_finished.any()
        if any_batch_is_finished:
            random_sampler.update_finished()
            if random_sampler.done:
                break

        if any_batch_is_finished:
            select_indices = random_sampler.select_indices

            # Reorder states.
            if isinstance(memory_bank, tuple):
                memory_bank = tuple(x.index_select(1, select_indices)
                                    for x in memory_bank)
            else:
                memory_bank = memory_bank.index_select(1, select_indices)

            memory_lengths = memory_lengths.index_select(0, select_indices)

            if src_map is not None:
                src_map = src_map.index_select(1, select_indices)

            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

    results["scores"] = random_sampler.scores
    results["predictions"] = random_sampler.predictions
    results["attention"] = random_sampler.attention
    return results