def _translate_random_sampling(
        self,
        batch,
        src_vocabs,
        max_length,
        min_length=0,
        sampling_temp=1.0,
        keep_topk=-1,
        return_attention=False):
    """Alternative to beam search. Do random sampling at each step.

    Args:
        batch: input batch. ``batch.batch_size`` is read, and so is
            ``batch.src_map`` when ``self.copy_attn`` is set.
        src_vocabs: source vocabularies, forwarded to ``_gold_score`` and
            ``_decode_and_generate``.
        max_length (int): maximum number of decoding steps. The same value
            is handed to ``RandomSampling`` so that its length limit and
            this loop's bound agree.
        min_length (int): forwarded to ``RandomSampling``.
        sampling_temp (float): sampling temperature, forwarded to
            ``RandomSampling``.
        keep_topk (int): forwarded to ``RandomSampling``. Presumably it
            restricts sampling to the k most likely tokens; confirm there.
        return_attention (bool): forwarded to ``RandomSampling``.

    Returns:
        dict with keys ``"predictions"``, ``"scores"`` and ``"attention"``
        (taken from the sampler), ``"batch"`` and ``"gold_score"``.
    """
    assert self.beam_size == 1

    # TODO: support these blacklisted features.
    assert self.block_ngram_repeat == 0

    batch_size = batch.batch_size

    # Encoder forward.
    src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
    self.model.decoder.init_state(src, memory_bank, enc_states)

    use_src_map = self.copy_attn

    results = {
        "predictions": None,
        "scores": None,
        "attention": None,
        "batch": batch,
        "gold_score": self._gold_score(
            batch, memory_bank, src_lengths, src_vocabs, use_src_map,
            enc_states, batch_size, src)}

    memory_lengths = src_lengths
    src_map = batch.src_map if use_src_map else None

    # The memory bank may be a tuple of tensors; any member gives the device.
    if isinstance(memory_bank, tuple):
        mb_device = memory_bank[0].device
    else:
        mb_device = memory_bank.device

    # Pass the ``max_length`` argument (previously ``self.max_length``) so
    # that the sampler's length limit matches the loop bound below.
    # NOTE(review): this assumes RandomSampling uses this positional argument
    # to force hypotheses to finish at that length. Confirm in RandomSampling.
    random_sampler = RandomSampling(
        self._tgt_pad_idx, self._tgt_bos_idx, self._tgt_eos_idx,
        batch_size, mb_device, min_length, self.block_ngram_repeat,
        self._exclusion_idxs, return_attention, max_length,
        sampling_temp, keep_topk, memory_lengths)

    for step in range(max_length):
        # Shape: (1, B, 1)
        decoder_input = random_sampler.alive_seq[:, -1].view(1, -1, 1)

        log_probs, attn = self._decode_and_generate(
            decoder_input,
            memory_bank,
            batch,
            src_vocabs,
            memory_lengths=memory_lengths,
            src_map=src_map,
            step=step,
            batch_offset=random_sampler.select_indices
        )

        random_sampler.advance(log_probs, attn)
        if not random_sampler.is_finished.any():
            continue

        random_sampler.update_finished()
        if random_sampler.done:
            break

        # Reorder states so they follow ``random_sampler.select_indices``.
        select_indices = random_sampler.select_indices
        if isinstance(memory_bank, tuple):
            memory_bank = tuple(
                x.index_select(1, select_indices) for x in memory_bank)
        else:
            memory_bank = memory_bank.index_select(1, select_indices)

        memory_lengths = memory_lengths.index_select(0, select_indices)

        if src_map is not None:
            src_map = src_map.index_select(1, select_indices)

        self.model.decoder.map_state(
            lambda state, dim: state.index_select(dim, select_indices))

    results["scores"] = random_sampler.scores
    results["predictions"] = random_sampler.predictions
    results["attention"] = random_sampler.attention
    return results
def _translate_random_sampling(
        self, batch, src_vocabs, max_length, min_length=0,
        sampling_temp=1.0, keep_topk=-1, return_attention=False):
    """Alternative to beam search. Do random sampling at each step.

    Compared with plain random sampling, this variant does two extra things:

    * It passes optional per-example batch fields through decoding:
      ``batch.emotion``, ``batch.tgt_concept_emb`` and the
      ``(tgt_concept_index, tgt_concept_score, tgt_concept_VAD_score)``
      triple. Each is used only when present on ``batch``.
    * When the decoder is a ``KGTransformerDecoder`` and
      ``self.emotion_topk_decoding_temp`` is non-zero, it hands the decoder's
      ``score_temp``, ``VAD_score_temp``, ``src_relevance_temp``,
      ``VAD_lambda`` and ``embeddings`` to the sampler.

    Args:
        batch: input batch (see above for the optional fields read).
        src_vocabs: source vocabularies, forwarded to ``_gold_score`` and
            ``_decode_and_generate``.
        max_length (int): maximum number of decoding steps. The same value
            is handed to ``RandomSampling`` so the two limits agree.
        min_length (int): forwarded to ``RandomSampling``.
        sampling_temp (float): forwarded to ``RandomSampling``.
        keep_topk (int): forwarded to ``RandomSampling``.
        return_attention (bool): forwarded to ``RandomSampling``.

    Returns:
        dict with keys ``"predictions"``, ``"scores"``, ``"attention"``,
        ``"batch"`` and ``"gold_score"``.
    """
    assert self.beam_size == 1

    # TODO: support these blacklisted features.
    assert self.block_ngram_repeat == 0

    batch_size = batch.batch_size

    # Encoder forward.
    src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
    self.model.decoder.init_state(src, memory_bank, enc_states)

    use_src_map = self.copy_attn

    results = {
        "predictions": None,
        "scores": None,
        "attention": None,
        "batch": batch,
        "gold_score": self._gold_score(
            batch, memory_bank, src_lengths, src_vocabs, use_src_map,
            enc_states, batch_size, src)}

    memory_lengths = src_lengths
    src_map = batch.src_map if use_src_map else None

    if isinstance(memory_bank, tuple):
        mb_device = memory_bank[0].device
    else:
        mb_device = memory_bank.device

    # Emotion-aware decoding parameters. They stay None unless the decoder is
    # a KGTransformerDecoder and the feature is switched on.
    score_temp = VAD_score_temp = src_relevance_temp = None
    VAD_lambda = decoder_word_embedding = None
    decoder = self.model.decoder
    if (isinstance(decoder, KGTransformerDecoder)
            and self.emotion_topk_decoding_temp != 0):
        score_temp = decoder.score_temp
        VAD_score_temp = decoder.VAD_score_temp
        src_relevance_temp = decoder.src_relevance_temp
        VAD_lambda = decoder.VAD_lambda
        decoder_word_embedding = decoder.embeddings

    # Pass the ``max_length`` argument (previously ``self.max_length``) so
    # the sampler's length limit agrees with the loop bound below.
    # NOTE(review): this assumes the 10th positional argument of this
    # RandomSampling variant is its max length. Confirm.
    random_sampler = RandomSampling(
        self._tgt_pad_idx, self._tgt_bos_idx, self._tgt_eos_idx,
        batch_size, mb_device, min_length, self.block_ngram_repeat,
        self._exclusion_idxs, return_attention, max_length,
        sampling_temp, keep_topk, memory_lengths,
        self.emotion_topk_decoding_temp,
        score_temp=score_temp,
        VAD_score_temp=VAD_score_temp,
        src_relevance_temp=src_relevance_temp,
        VAD_lambda=VAD_lambda,
        decoder_word_embedding=decoder_word_embedding)

    # Optional per-example batch fields. None means "not provided".
    batch_emotion = getattr(batch, "emotion", None)
    batch_tgt_concept_emb = getattr(batch, "tgt_concept_emb", None)
    batch_tgt_concept_words = None
    if hasattr(batch, "tgt_concept_index"):
        batch_tgt_concept_words = (batch.tgt_concept_index,
                                   batch.tgt_concept_score,
                                   batch.tgt_concept_VAD_score)

    for step in range(max_length):
        # Shape: (1, B, 1)
        decoder_input = random_sampler.alive_seq[:, -1].view(1, -1, 1)

        # log_probs: ``(batch_size, vocab_size)``
        log_probs, attn = self._decode_and_generate(
            decoder_input,
            memory_bank,
            batch,
            src_vocabs,
            memory_lengths=memory_lengths,
            src_map=src_map,
            step=step,
            batch_offset=random_sampler.select_indices,
            batch_emotion=batch_emotion,
            batch_tgt_concept_emb=batch_tgt_concept_emb,
            batch_tgt_concept_words=batch_tgt_concept_words
        )

        if batch_tgt_concept_words is not None:
            random_sampler.advance(
                log_probs, attn, batch_tgt_concept_words, memory_bank)
        else:
            random_sampler.advance(log_probs, attn)

        if not random_sampler.is_finished.any():
            continue

        random_sampler.update_finished()
        if random_sampler.done:
            break

        # Reorder states so they follow ``random_sampler.select_indices``.
        select_indices = random_sampler.select_indices
        if isinstance(memory_bank, tuple):
            memory_bank = tuple(
                x.index_select(1, select_indices) for x in memory_bank)
        else:
            memory_bank = memory_bank.index_select(1, select_indices)

        memory_lengths = memory_lengths.index_select(0, select_indices)

        if src_map is not None:
            src_map = src_map.index_select(1, select_indices)

        self.model.decoder.map_state(
            lambda state, dim: state.index_select(dim, select_indices))

        # The optional batch fields are indexed along dim 0. Checking the
        # local value (rather than hasattr) also skips fields that are None.
        if batch_emotion is not None:
            batch_emotion = batch_emotion.index_select(0, select_indices)
        if batch_tgt_concept_emb is not None:
            batch_tgt_concept_emb = batch_tgt_concept_emb.index_select(
                0, select_indices)
        if batch_tgt_concept_words is not None:
            batch_tgt_concept_words = tuple(
                t.index_select(0, select_indices)
                for t in batch_tgt_concept_words)

    results["scores"] = random_sampler.scores
    results["predictions"] = random_sampler.predictions
    results["attention"] = random_sampler.attention
    return results
def test_returns_correct_scores_deterministic(self):
    """Check finish steps and finished scores when sampling is deterministic.

    The sampler is built with ``keep_topk`` set to 1, so each row takes its
    best-scoring token without retries, as this test assumes. The expected
    sequence is:

    * Row 0 ends with EOS on the first step.
    * Original row 8 ends on the second step.
    * Every remaining row ends on the third step.

    Each finished score must equal the chosen log-probability divided by
    the temperature.
    """
    vocab_size = 100
    eos = 2
    other_ids = [47, 51, 13, 88, 99]
    first_dist = torch.log_softmax(
        torch.tensor([6., 5., 4., 3., 2., 1.]), dim=0)
    second_dist = torch.log_softmax(torch.tensor([6., 1.]), dim=0)

    def all_masked(rows):
        # Every token starts impossible; each step opens selected ids.
        return torch.full((rows, vocab_size), -float('inf'))

    settings = [(bs, t) for bs in [1, 13] for t in [1., 3.]]
    for batch_sz, temp in settings:
        lengths = torch.randint(0, 30, (batch_sz,))
        sampler = RandomSampling(
            0, 1, 2, batch_sz, torch.device("cpu"), 0,
            False, set(), False, 30, temp, 1, lengths)

        # Step 1. Row 0's best token is EOS. Every other row can only pick
        # the first non-EOS id. The EOS row also gets at least one non-EOS
        # option that is greater than -1e20.
        scores = all_masked(batch_sz)
        scores[0, eos] = first_dist[0]
        scores[0, other_ids] = first_dist[1:]
        scores[1:, other_ids[0]] = 0
        sampler.advance(scores, torch.randn(1, batch_sz, 53))
        self.assertTrue(sampler.is_finished[0].eq(1).all())
        sampler.update_finished()
        self.assertEqual(sampler.scores[0], [first_dist[0] / temp])

        if batch_sz == 1:
            self.assertTrue(sampler.done)
            continue
        self.assertFalse(sampler.done)

        # Step 2. One row is gone, so original row 8 now sits at index 7.
        scores = all_masked(batch_sz - 1)
        scores[7, eos] = second_dist[0]
        scores[0:7, other_ids[:2]] = second_dist
        scores[8:, other_ids[:2]] = second_dist
        sampler.advance(scores, torch.randn(1, batch_sz, 53))
        self.assertTrue(sampler.is_finished[7].eq(1).all())
        sampler.update_finished()
        self.assertEqual(sampler.scores[8], [second_dist[0] / temp])

        # Step 3. Only EOS is possible, so everything finishes.
        scores = all_masked(batch_sz - 2)
        scores[:, eos] = 0
        sampler.advance(scores, torch.randn(1, batch_sz, 53))
        self.assertTrue(sampler.is_finished.eq(1).all())
        sampler.update_finished()
        for b in range(batch_sz):
            if b not in (0, 8):
                self.assertEqual(sampler.scores[b], [0])
        self.assertTrue(sampler.done)
def test_returns_correct_scores_non_deterministic(self):
    """Check finish steps and finished scores when sampling is stochastic.

    The sampler is built with ``keep_topk`` set to 2, so a row may not pick
    EOS on the first try. Each phase therefore keeps advancing until the
    expected rows finish, and calls ``self.fail`` once a fixed number of
    attempts runs out. Finished scores must equal the chosen log-probability
    divided by the temperature.
    """
    vocab_size = 100
    eos = 2
    other_ids = [47, 51, 13, 88, 99]
    first_dist = torch.log_softmax(
        torch.tensor([6., 5., 4., 3., 2., 1.]), dim=0)
    second_dist = torch.log_softmax(torch.tensor([6., 1.]), dim=0)

    def all_masked(rows):
        # Every token starts impossible; callers open selected ids.
        return torch.full((rows, vocab_size), -float('inf'))

    def drive(sampler, batch_sz, build_scores, stop, attempts, failure_msg,
              finish_each_step=False):
        # Advance with fresh scores until ``stop()`` holds. Otherwise fail.
        for _ in range(attempts):
            sampler.advance(build_scores(), torch.randn(1, batch_sz, 53))
            if finish_each_step and sampler.is_finished.any():
                sampler.update_finished()
            if stop():
                return
        self.fail(failure_msg)

    for batch_sz, temp in [(bs, t) for bs in [1, 13] for t in [1., 3.]]:
        lengths = torch.randint(0, 30, (batch_sz,))
        sampler = RandomSampling(
            0, 1, 2, batch_sz, torch.device("cpu"), 0,
            False, set(), False, 30, temp, 2, lengths)

        def first_scores():
            # Row 0 may pick EOS. The other rows have one non-EOS option.
            probs = all_masked(batch_sz)
            probs[0, eos] = first_dist[0]
            probs[0, other_ids] = first_dist[1:]
            probs[1:, other_ids[0]] = 0
            return probs

        drive(sampler, batch_sz, first_scores,
              lambda: sampler.is_finished[0].eq(1).all(), 100,
              "Batch 0 never ended (very unlikely but maybe "
              "due to stochasticisty. If so, please increase "
              "the range of the for-loop.")
        sampler.update_finished()
        self.assertEqual(sampler.scores[0], [first_dist[0] / temp])

        if batch_sz == 1:
            self.assertTrue(sampler.done)
            continue
        self.assertFalse(sampler.done)

        def second_scores():
            # Row 0 is gone, so original row 8 sits at index 7.
            probs = all_masked(batch_sz - 1)
            probs[7, eos] = second_dist[0]
            probs[0:7, other_ids[:2]] = second_dist
            probs[8:, other_ids[:2]] = second_dist
            return probs

        drive(sampler, batch_sz, second_scores,
              lambda: sampler.is_finished[7].eq(1).all(), 100,
              "Batch 8 never ended (very unlikely but maybe "
              "due to stochasticisty. If so, please increase "
              "the range of the for-loop.")
        sampler.update_finished()
        self.assertEqual(sampler.scores[8], [second_dist[0] / temp])

        def final_scores():
            # Only EOS is possible for every row that is still alive.
            probs = all_masked(sampler.alive_seq.shape[0])
            probs[:, eos] = 0
            return probs

        drive(sampler, batch_sz, final_scores,
              lambda: sampler.is_finished.eq(1).all(), 250,
              "All batches never ended (very unlikely but "
              "maybe due to stochasticisty. If so, please "
              "increase the range of the for-loop.",
              finish_each_step=True)

        for b in range(batch_sz):
            if b not in (0, 8):
                self.assertEqual(sampler.scores[b], [0])
        self.assertTrue(sampler.done)
def _translate_random_sampling(self, batch, src_vocabs, max_length,
                               min_length=0, sampling_temp=1.0,
                               keep_topk=-1, return_attention=False):
    """Alternative to beam search. Do random sampling at each step.

    Besides decoding, this variant gathers token statistics on the instance
    at every step. The statistics use the top-1 (argmax) token of each
    step's ``log_probs``, which is not necessarily the token the sampler
    draws.

    * ``self.total_tokens`` and ``self.vocab_dict`` count every top-1
      token. ``self.total_function_words`` and
      ``self.total_content_words`` split them using ``is_function_word``.
    * A top-1 token is "unaffected" when a counterfactual distribution in
      ``hack_dict`` (the third value returned by ``_decode_and_generate``)
      has the same top-1 token.
      - If ``self.counterfactual_attention_method`` is
        ``'uniform_or_zero_out_max'``, any one of the uniform, zero-out-max
        or permute variants agreeing is enough.
      - Otherwise ``hack_dict['log_probs_counterfactual']`` is used, and
        ``self.unaffected_words_count`` is updated as well.
      Unaffected tokens are tallied in ``self.unaffected_function_words*``
      or ``self.unaffected_content_words*``.

    Args:
        batch: input batch.
        src_vocabs: source vocabularies, forwarded to ``_gold_score`` and
            ``_decode_and_generate``.
        max_length (int): maximum number of decoding steps. The same value
            is handed to ``RandomSampling`` so the two limits agree.
        min_length (int): forwarded to ``RandomSampling``.
        sampling_temp (float): forwarded to ``RandomSampling``.
        keep_topk (int): forwarded to ``RandomSampling``.
        return_attention (bool): forwarded to ``RandomSampling``.

    Returns:
        dict with keys ``"predictions"``, ``"scores"``, ``"attention"``,
        ``"batch"`` and ``"gold_score"``.
    """
    assert self.beam_size == 1

    # TODO: support these blacklisted features.
    assert self.block_ngram_repeat == 0

    batch_size = batch.batch_size

    # Encoder forward.
    src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
    self.model.decoder.init_state(src, memory_bank, enc_states)

    use_src_map = self.copy_attn

    results = {
        "predictions": None,
        "scores": None,
        "attention": None,
        "batch": batch,
        "gold_score": self._gold_score(
            batch, memory_bank, src_lengths, src_vocabs, use_src_map,
            enc_states, batch_size, src)}

    memory_lengths = src_lengths
    src_map = batch.src_map if use_src_map else None

    if isinstance(memory_bank, tuple):
        mb_device = memory_bank[0].device
    else:
        mb_device = memory_bank.device

    # Pass the ``max_length`` argument (previously ``self.max_length``) so
    # the sampler's length limit agrees with the loop bound below.
    random_sampler = RandomSampling(
        self._tgt_pad_idx, self._tgt_bos_idx, self._tgt_eos_idx,
        batch_size, mb_device, min_length, self.block_ngram_repeat,
        self._exclusion_idxs, return_attention, max_length,
        sampling_temp, keep_topk, memory_lengths)

    # The target vocabulary does not change between steps, so look it up once.
    tgt_vocab = dict(self.fields)["tgt"].base_field.vocab

    def count_unaffected(words, unaffected):
        # Tally the words whose flag in ``unaffected`` (a (batch, 1) bool
        # tensor aligned with ``words``) is set, split by word class.
        for word, keep in zip(words, unaffected[:, 0].tolist()):
            if not keep:
                continue
            if is_function_word(word):
                self.unaffected_function_words_count += 1
                self.unaffected_function_words[word] += 1
            else:
                self.unaffected_content_words_count += 1
                self.unaffected_content_words[word] += 1

    for step in range(max_length):
        # Shape: (1, B, 1)
        decoder_input = random_sampler.alive_seq[:, -1].view(1, -1, 1)

        log_probs, attn, hack_dict = self._decode_and_generate(
            decoder_input,
            memory_bank,
            batch,
            src_vocabs,
            memory_lengths=memory_lengths,
            src_map=src_map,
            step=step,
            batch_offset=random_sampler.select_indices,
        )

        # Top-1 token ids, shape (batch, 1), and their surface forms.
        top_ids = torch.topk(log_probs, k=1, dim=1).indices
        self.total_tokens += top_ids.size()[0]
        top_words = [tgt_vocab.itos[idx] for idx in top_ids[:, 0].tolist()]
        for word in top_words:
            self.vocab_dict[word] += 1
            if is_function_word(word):
                self.total_function_words += 1
            else:
                self.total_content_words += 1

        if self.counterfactual_attention_method == 'uniform_or_zero_out_max':
            # A token is unaffected if ANY of the three variants keeps it.
            unaffected = torch.zeros_like(top_ids, dtype=torch.bool)
            for key in ('log_probs_uniform', 'log_probs_zero_out_max',
                        'log_probs_permute'):
                cf_ids = torch.topk(hack_dict[key], k=1, dim=1).indices
                unaffected |= top_ids == cf_ids
            # NOTE(review): unlike the branch below, this branch does not
            # update ``self.unaffected_words_count``. Confirm that is
            # intended.
            count_unaffected(top_words, unaffected)
        else:
            cf_ids = torch.topk(
                hack_dict['log_probs_counterfactual'], k=1, dim=1).indices
            unaffected = top_ids == cf_ids
            self.unaffected_words_count += unaffected.sum(
                dim=0).cpu().numpy()[0]
            count_unaffected(top_words, unaffected)

        random_sampler.advance(log_probs, attn)
        if not random_sampler.is_finished.any():
            continue

        random_sampler.update_finished()
        if random_sampler.done:
            break

        # Reorder states so they follow ``random_sampler.select_indices``.
        select_indices = random_sampler.select_indices
        if isinstance(memory_bank, tuple):
            memory_bank = tuple(
                x.index_select(1, select_indices) for x in memory_bank)
        else:
            memory_bank = memory_bank.index_select(1, select_indices)

        memory_lengths = memory_lengths.index_select(0, select_indices)

        if src_map is not None:
            src_map = src_map.index_select(1, select_indices)

        self.model.decoder.map_state(
            lambda state, dim: state.index_select(dim, select_indices))

    results["scores"] = random_sampler.scores
    results["predictions"] = random_sampler.predictions
    results["attention"] = random_sampler.attention
    return results