Example #1
    def _translate_random_sampling(self,
                                   batch,
                                   src_vocabs,
                                   max_length,
                                   min_length=0,
                                   sampling_temp=1.0,
                                   keep_topk=-1,
                                   return_attention=False):
        """Alternative to beam search. Do random sampling at each step."""

        assert self.beam_size == 1

        # TODO: support these blacklisted features.
        assert self.block_ngram_repeat == 0

        batch_size = batch.batch_size

        # Encoder forward.
        src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
        self.model.decoder.init_state(src, memory_bank, enc_states)

        use_src_map = self.copy_attn

        results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "batch": batch,
            "gold_score": self._gold_score(
                batch, memory_bank, src_lengths, src_vocabs, use_src_map,
                enc_states, batch_size, src)}

        memory_lengths = src_lengths
        src_map = batch.src_map if use_src_map else None

        if isinstance(memory_bank, tuple):
            mb_device = memory_bank[0].device
        else:
            mb_device = memory_bank.device

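        # The sampler keeps one alive sequence per batch example; it needs
        # the target special-token indices and the encoder output's device
        # to allocate its state tensors.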
        random_sampler = RandomSampling(
            self._tgt_pad_idx, self._tgt_bos_idx, self._tgt_eos_idx,
            batch_size, mb_device, min_length, self.block_ngram_repeat,
            self._exclusion_idxs, return_attention, self.max_length,
            sampling_temp, keep_topk, memory_lengths)

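        # Decoding loop: feed the last sampled token back into the decoder
        # one step at a time.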
        for step in range(max_length):
            # Shape: (1, B, 1)
            decoder_input = random_sampler.alive_seq[:, -1].view(1, -1, 1)

            log_probs, attn = self._decode_and_generate(
                decoder_input,
                memory_bank,
                batch,
                src_vocabs,
                memory_lengths=memory_lengths,
                src_map=src_map,
                step=step,
                batch_offset=random_sampler.select_indices)

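            # Sample the next token for each alive sequence from the
            # temperature-scaled (and optionally top-k-truncated)
            # distribution, flagging sequences that produced EOS.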
            random_sampler.advance(log_probs, attn)
            any_batch_is_finished = random_sampler.is_finished.any()
            if any_batch_is_finished:
                random_sampler.update_finished()
                if random_sampler.done:
                    break

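            # Shrink all step-dependent tensors to the still-alive
            # sequences so later steps only decode survivors.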
            if any_batch_is_finished:
                select_indices = random_sampler.select_indices

                # Reorder states.
                if isinstance(memory_bank, tuple):
                    memory_bank = tuple(
                        x.index_select(1, select_indices) for x in memory_bank)
                else:
                    memory_bank = memory_bank.index_select(1, select_indices)

                memory_lengths = memory_lengths.index_select(0, select_indices)

                if src_map is not None:
                    src_map = src_map.index_select(1, select_indices)

                self.model.decoder.map_state(
                    lambda state, dim: state.index_select(dim, select_indices))

        results["scores"] = random_sampler.scores
        results["predictions"] = random_sampler.predictions
        results["attention"] = random_sampler.attention
        return results
Example #2
    def _translate_random_sampling(
            self,
            batch,
            src_vocabs,
            max_length,
            min_length=0,
            sampling_temp=1.0,
            keep_topk=-1,
            return_attention=False):
        """Alternative to beam search. Do random sampling at each step."""

        assert self.beam_size == 1

        # TODO: support these blacklisted features.
        assert self.block_ngram_repeat == 0

        batch_size = batch.batch_size

        # Encoder forward.
        src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
        self.model.decoder.init_state(src, memory_bank, enc_states)

        use_src_map = self.copy_attn

        results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "batch": batch,
            "gold_score": self._gold_score(
                batch, memory_bank, src_lengths, src_vocabs, use_src_map,
                enc_states, batch_size, src)}

        memory_lengths = src_lengths
        src_map = batch.src_map if use_src_map else None

        if isinstance(memory_bank, tuple):
            mb_device = memory_bank[0].device
        else:
            mb_device = memory_bank.device

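        # When decoding with a knowledge-grounded decoder and emotion-aware
        # top-k sampling enabled, forward the decoder's temperature and
        # interpolation parameters to the sampler.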
        score_temp = VAD_score_temp = src_relevance_temp = None
        VAD_lambda = decoder_word_embedding = None
        if (isinstance(self.model.decoder, KGTransformerDecoder)
                and self.emotion_topk_decoding_temp != 0):
            score_temp = self.model.decoder.score_temp
            VAD_score_temp = self.model.decoder.VAD_score_temp
            src_relevance_temp = self.model.decoder.src_relevance_temp
            VAD_lambda = self.model.decoder.VAD_lambda
            decoder_word_embedding = self.model.decoder.embeddings
        
        random_sampler = RandomSampling(
            self._tgt_pad_idx, self._tgt_bos_idx, self._tgt_eos_idx,
            batch_size, mb_device, min_length, self.block_ngram_repeat,
            self._exclusion_idxs, return_attention, self.max_length,
            sampling_temp, keep_topk, memory_lengths,
            self.emotion_topk_decoding_temp, score_temp=score_temp,
            VAD_score_temp=VAD_score_temp,
            src_relevance_temp=src_relevance_temp, VAD_lambda=VAD_lambda,
            decoder_word_embedding=decoder_word_embedding)

        batch_emotion = None
        if hasattr(batch, "emotion"):
            batch_emotion = batch.emotion
        
        batch_tgt_concept_emb = None
        if hasattr(batch, "tgt_concept_emb"):
            batch_tgt_concept_emb = batch.tgt_concept_emb
        
        batch_tgt_concept_words = None
        if hasattr(batch, "tgt_concept_index"):
            batch_tgt_concept_words = batch.tgt_concept_index, batch.tgt_concept_score, batch.tgt_concept_VAD_score

        for step in range(max_length):
            # Shape: (1, B, 1)
            decoder_input = random_sampler.alive_seq[:, -1].view(1, -1, 1)

            # log_probs: ``(batch_size, vocab_size)``
            log_probs, attn = self._decode_and_generate(
                decoder_input,
                memory_bank,
                batch,
                src_vocabs,
                memory_lengths=memory_lengths,
                src_map=src_map,
                step=step,
                batch_offset=random_sampler.select_indices,
                batch_emotion=batch_emotion,
                batch_tgt_concept_emb=batch_tgt_concept_emb,
                batch_tgt_concept_words=batch_tgt_concept_words
            )

            if hasattr(batch, "tgt_concept_index"):
                random_sampler.advance(
                    log_probs, attn, batch_tgt_concept_words, memory_bank)
            else:
                random_sampler.advance(log_probs, attn)
            
            any_batch_is_finished = random_sampler.is_finished.any()
            if any_batch_is_finished:
                random_sampler.update_finished()
                if random_sampler.done:
                    break

            if any_batch_is_finished:
                select_indices = random_sampler.select_indices

                # Reorder states.
                if isinstance(memory_bank, tuple):
                    # print("Reordering memory bank tuple")
                    # print(select_indices)
                    memory_bank = tuple(x.index_select(1, select_indices)
                                        for x in memory_bank)
                else:
                    # print("Reordering memory bank")
                    # print(select_indices)
                    memory_bank = memory_bank.index_select(1, select_indices)

                memory_lengths = memory_lengths.index_select(0, select_indices)

                if src_map is not None:
                    src_map = src_map.index_select(1, select_indices)

                self.model.decoder.map_state(
                    lambda state, dim: state.index_select(dim, select_indices))

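                # Per-example emotion and concept tensors must be filtered
                # along with the rest of the decoder state.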
                if hasattr(batch, "emotion"):
                    batch_emotion = batch_emotion.index_select(0, select_indices)

                if hasattr(batch, "tgt_concept_emb"):
                    batch_tgt_concept_emb = batch_tgt_concept_emb.index_select(
                        0, select_indices)
            
                if hasattr(batch, "tgt_concept_index"):
                    batch_tgt_concept_words = (
                        batch_tgt_concept_words[0].index_select(0, select_indices),
                        batch_tgt_concept_words[1].index_select(0, select_indices),
                        batch_tgt_concept_words[2].index_select(0, select_indices))
            
            # TODO: update batch_emotion using a mask.

        results["scores"] = random_sampler.scores
        results["predictions"] = random_sampler.predictions
        results["attention"] = random_sampler.attention
        return results
Example #3
    def test_returns_correct_scores_deterministic(self):
        for batch_sz in [1, 13]:
            for temp in [1., 3.]:
                n_words = 100
                _non_eos_idxs = [47, 51, 13, 88, 99]
                valid_score_dist_1 = torch.log_softmax(torch.tensor(
                    [6., 5., 4., 3., 2., 1.]), dim=0)
                valid_score_dist_2 = torch.log_softmax(torch.tensor(
                    [6., 1.]), dim=0)
                eos_idx = 2
                lengths = torch.randint(0, 30, (batch_sz,))
                samp = RandomSampling(
                    0, 1, 2, batch_sz, torch.device("cpu"), 0,
                    False, set(), False, 30, temp, 1, lengths)
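                # keep_topk=1 restricts sampling to the single most likely
                # token, so decoding here is effectively deterministic argmax.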

                # initial step
                i = 0
                word_probs = torch.full(
                    (batch_sz, n_words), -float('inf'))
                # batch 0 dies on step 0
                word_probs[0, eos_idx] = valid_score_dist_1[0]
                # include at least one prediction OTHER than EOS
                # that is greater than -1e20
                word_probs[0, _non_eos_idxs] = valid_score_dist_1[1:]
                word_probs[1:, _non_eos_idxs[0] + i] = 0

                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)
                self.assertTrue(samp.is_finished[0].eq(1).all())
                samp.update_finished()
                self.assertEqual(
                    samp.scores[0], [valid_score_dist_1[0] / temp])
                if batch_sz == 1:
                    self.assertTrue(samp.done)
                    continue
                else:
                    self.assertFalse(samp.done)

                # step 2
                i = 1
                word_probs = torch.full(
                    (batch_sz - 1, n_words), -float('inf'))
                # (old) batch 8 dies on step 1
                word_probs[7, eos_idx] = valid_score_dist_2[0]
                word_probs[0:7, _non_eos_idxs[:2]] = valid_score_dist_2
                word_probs[8:, _non_eos_idxs[:2]] = valid_score_dist_2

                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)

                self.assertTrue(samp.is_finished[7].eq(1).all())
                samp.update_finished()
                self.assertEqual(
                    samp.scores[8], [valid_score_dist_2[0] / temp])

                # step 3
                i = 2
                word_probs = torch.full(
                    (batch_sz - 2, n_words), -float('inf'))
                # everything dies
                word_probs[:, eos_idx] = 0

                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)

                self.assertTrue(samp.is_finished.eq(1).all())
                samp.update_finished()
                for b in range(batch_sz):
                    if b != 0 and b != 8:
                        self.assertEqual(samp.scores[b], [0])
                self.assertTrue(samp.done)
Example #4
    def test_returns_correct_scores_non_deterministic(self):
        for batch_sz in [1, 13]:
            for temp in [1., 3.]:
                n_words = 100
                _non_eos_idxs = [47, 51, 13, 88, 99]
                valid_score_dist_1 = torch.log_softmax(torch.tensor(
                    [6., 5., 4., 3., 2., 1.]), dim=0)
                valid_score_dist_2 = torch.log_softmax(torch.tensor(
                    [6., 1.]), dim=0)
                eos_idx = 2
                lengths = torch.randint(0, 30, (batch_sz,))
                samp = RandomSampling(
                    0, 1, 2, batch_sz, torch.device("cpu"), 0,
                    False, set(), False, 30, temp, 2, lengths)
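                # keep_topk=2 leaves two candidate tokens per step, so each
                # phase below retries until the expected EOS is drawn.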

                # initial step
                i = 0
                for _ in range(100):
                    word_probs = torch.full(
                        (batch_sz, n_words), -float('inf'))
                    # batch 0 dies on step 0
                    word_probs[0, eos_idx] = valid_score_dist_1[0]
                    # include at least one prediction OTHER than EOS
                    # that is greater than -1e20
                    word_probs[0, _non_eos_idxs] = valid_score_dist_1[1:]
                    word_probs[1:, _non_eos_idxs[0] + i] = 0

                    attns = torch.randn(1, batch_sz, 53)
                    samp.advance(word_probs, attns)
                    if samp.is_finished[0].eq(1).all():
                        break
                else:
                    self.fail("Batch 0 never ended (very unlikely but maybe "
                              "due to stochasticisty. If so, please increase "
                              "the range of the for-loop.")
                samp.update_finished()
                self.assertEqual(
                    samp.scores[0], [valid_score_dist_1[0] / temp])
                if batch_sz == 1:
                    self.assertTrue(samp.done)
                    continue
                else:
                    self.assertFalse(samp.done)

                # step 2
                i = 1
                for _ in range(100):
                    word_probs = torch.full(
                        (batch_sz - 1, n_words), -float('inf'))
                    # (old) batch 8 dies on step 1
                    word_probs[7, eos_idx] = valid_score_dist_2[0]
                    word_probs[0:7, _non_eos_idxs[:2]] = valid_score_dist_2
                    word_probs[8:, _non_eos_idxs[:2]] = valid_score_dist_2

                    attns = torch.randn(1, batch_sz, 53)
                    samp.advance(word_probs, attns)
                    if samp.is_finished[7].eq(1).all():
                        break
                else:
                    self.fail("Batch 8 never ended (very unlikely but maybe "
                              "due to stochasticisty. If so, please increase "
                              "the range of the for-loop.")

                samp.update_finished()
                self.assertEqual(
                    samp.scores[8], [valid_score_dist_2[0] / temp])

                # step 3
                i = 2
                for _ in range(250):
                    word_probs = torch.full(
                        (samp.alive_seq.shape[0], n_words), -float('inf'))
                    # everything dies
                    word_probs[:, eos_idx] = 0

                    attns = torch.randn(1, batch_sz, 53)
                    samp.advance(word_probs, attns)
                    if samp.is_finished.any():
                        samp.update_finished()
                    if samp.is_finished.eq(1).all():
                        break
                else:
                    self.fail("All batches never ended (very unlikely but "
                              "maybe due to stochasticisty. If so, please "
                              "increase the range of the for-loop.")

                for b in range(batch_sz):
                    if b != 0 and b != 8:
                        self.assertEqual(samp.scores[b], [0])
                self.assertTrue(samp.done)
Example #5
    def _translate_random_sampling(self,
                                   batch,
                                   src_vocabs,
                                   max_length,
                                   min_length=0,
                                   sampling_temp=1.0,
                                   keep_topk=-1,
                                   return_attention=False):
        """Alternative to beam search. Do random sampling at each step."""

        assert self.beam_size == 1

        # TODO: support these blacklisted features.
        assert self.block_ngram_repeat == 0

        batch_size = batch.batch_size

        # Encoder forward.
        src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
        self.model.decoder.init_state(src, memory_bank, enc_states)

        use_src_map = self.copy_attn

        results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "batch": batch,
            "gold_score": self._gold_score(
                batch, memory_bank, src_lengths, src_vocabs, use_src_map,
                enc_states, batch_size, src)}

        memory_lengths = src_lengths
        src_map = batch.src_map if use_src_map else None

        if isinstance(memory_bank, tuple):
            mb_device = memory_bank[0].device
        else:
            mb_device = memory_bank.device

        random_sampler = RandomSampling(
            self._tgt_pad_idx, self._tgt_bos_idx, self._tgt_eos_idx,
            batch_size, mb_device, min_length, self.block_ngram_repeat,
            self._exclusion_idxs, return_attention, self.max_length,
            sampling_temp, keep_topk, memory_lengths)

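        # The decoding loop below doubles as an analysis pass: besides
        # sampling, it compares the factual top-1 prediction against the
        # predictions under counterfactual attention.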
        for step in range(max_length):
            # Shape: (1, B, 1)
            decoder_input = random_sampler.alive_seq[:, -1].view(1, -1, 1)

            log_probs, attn, hack_dict = self._decode_and_generate(
                decoder_input,
                memory_bank,
                batch,
                src_vocabs,
                memory_lengths=memory_lengths,
                src_map=src_map,
                step=step,
                batch_offset=random_sampler.select_indices,
            )

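            # Record the greedy (top-1) prediction at each step for the
            # function-word vs. content-word statistics below.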
            top_prob = torch.topk(log_probs, k=1, dim=1)

            self.total_tokens += top_prob.indices.size()[0]

            tgt_field = dict(self.fields)["tgt"].base_field
            vocab = tgt_field.vocab

            src_field = dict(self.fields)["src"].base_field
            vocab_src = src_field.vocab

            for i in range(top_prob.indices.size()[0]):
                word = vocab.itos[top_prob.indices[i][0]]
                self.vocab_dict[word] += 1

                if is_function_word(word):
                    self.total_function_words += 1
                else:
                    self.total_content_words += 1

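            # Compare the factual argmax with the argmax under each
            # counterfactual attention variant; predictions that do not
            # change are counted as "unaffected".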
            if self.counterfactual_attention_method == 'uniform_or_zero_out_max':
                log_probs_uniform = hack_dict['log_probs_uniform']
                top_prob_uniform = torch.topk(log_probs_uniform, k=1, dim=1)
                unaffected_boolean_uniform = (
                    top_prob.indices == top_prob_uniform.indices)

                log_probs_zero_out_max = hack_dict['log_probs_zero_out_max']
                top_prob_zero_out_max = torch.topk(log_probs_zero_out_max,
                                                   k=1,
                                                   dim=1)
                unaffected_boolean_zero_out_max = (
                    top_prob.indices == top_prob_zero_out_max.indices)

                log_probs_permute = hack_dict['log_probs_permute']
                top_prob_permute = torch.topk(log_probs_permute, k=1, dim=1)
                unaffected_boolean_permute = (
                    top_prob.indices == top_prob_permute.indices)

                for i in range(unaffected_boolean_zero_out_max.size()[0]):
                    if (unaffected_boolean_uniform[i][0].item() == 1 or
                            unaffected_boolean_zero_out_max[i][0].item() == 1
                            or unaffected_boolean_permute[i][0].item() == 1):
                        word = vocab.itos[top_prob.indices[i][0]]

                        if is_function_word(word):
                            self.unaffected_function_words_count += 1
                            self.unaffected_function_words[word] += 1
                        else:
                            self.unaffected_content_words_count += 1
                            self.unaffected_content_words[word] += 1

            else:
                log_probs_counterfactual = hack_dict[
                    'log_probs_counterfactual']
                top_prob_counterfactual = torch.topk(log_probs_counterfactual,
                                                     k=1,
                                                     dim=1)
                unaffected_boolean = (
                    top_prob.indices == top_prob_counterfactual.indices)
                self.unaffected_words_count += unaffected_boolean.sum(
                    dim=0).cpu().numpy()[0]

                for i in range(unaffected_boolean.size()[0]):
                    if unaffected_boolean[i][0].item() == 1:
                        word = vocab.itos[top_prob.indices[i][0]]

                        if is_function_word(word):
                            self.unaffected_function_words_count += 1
                            self.unaffected_function_words[word] += 1
                        else:
                            self.unaffected_content_words_count += 1
                            self.unaffected_content_words[word] += 1

            random_sampler.advance(log_probs, attn)
            any_batch_is_finished = random_sampler.is_finished.any()
            if any_batch_is_finished:
                random_sampler.update_finished()
                if random_sampler.done:
                    break

            if any_batch_is_finished:
                select_indices = random_sampler.select_indices

                # Reorder states.
                if isinstance(memory_bank, tuple):
                    memory_bank = tuple(
                        x.index_select(1, select_indices) for x in memory_bank)
                else:
                    memory_bank = memory_bank.index_select(1, select_indices)

                memory_lengths = memory_lengths.index_select(0, select_indices)

                if src_map is not None:
                    src_map = src_map.index_select(1, select_indices)

                self.model.decoder.map_state(
                    lambda state, dim: state.index_select(dim, select_indices))

        results["scores"] = random_sampler.scores
        results["predictions"] = random_sampler.predictions
        results["attention"] = random_sampler.attention
        return results