Example #1
    def _translate_random_sampling(
            self, src, src_lengths, batch_size, min_length=0,
            sampling_temp=1.0, keep_topk=1, return_attention=False,
            pretraining=False):

        max_length = self.config.max_sequence_length + 1 # to account for EOS
        
        # Encoder forward.
        enc_states, memory_bank, src_lengths = self.encoder(src, src_lengths)
        self.decoder.init_state(src, memory_bank, enc_states)

        results = { "predictions": None, "scores": None, "attention": None }

        memory_lengths = src_lengths

        if isinstance(memory_bank, tuple):
            mb_device = memory_bank[0].device
        else:
            mb_device = memory_bank.device
        
        block_ngram_repeat = 0
        _exclusion_idxs = {}

        random_sampler = RandomSampling(
            self.config.tgt_padding, self.config.tgt_bos, self.config.tgt_eos,
            batch_size, mb_device, min_length, block_ngram_repeat,
            _exclusion_idxs, return_attention, max_length,
            sampling_temp, keep_topk, memory_lengths)

        for step in range(max_length):
            # Shape: (1, B, 1)
            decoder_input = random_sampler.alive_seq[:, -1].view(1, -1, 1)

            log_probs, attn = self._decode_and_generate(
                decoder_input, memory_bank, memory_lengths, step, pretraining)
                        
            if self.config.DISTRIBUTIONAL and not pretraining:
                # Collapse the distributional output to scalar log-probs by
                # weighting each quantile and summing over the quantile dim.
                log_probs = (log_probs * self.quantile_weight).sum(dim=3).squeeze(0)
                            
            random_sampler.advance(log_probs, attn)
            any_batch_is_finished = random_sampler.is_finished.any()
            if any_batch_is_finished:
                random_sampler.update_finished()
                if random_sampler.done:
                    break

            if any_batch_is_finished:
                select_indices = random_sampler.select_indices

                # Reorder states.
                if isinstance(memory_bank, tuple):
                    memory_bank = tuple(x.index_select(1, select_indices)
                                        for x in memory_bank)
                else:
                    memory_bank = memory_bank.index_select(1, select_indices)

                memory_lengths = memory_lengths.index_select(0, select_indices)

                self.decoder.map_state(
                    lambda state, dim: state.index_select(dim, select_indices))
        
        results["scores"] = random_sampler.scores
        results["predictions"] = random_sampler.predictions
        results["attention"] = random_sampler.attention
        return results
Example #2
    def _translate_random_sampling(
            self,
            batch,
            src_vocabs,
            max_length,
            min_length=0,
            sampling_temp=1.0,
            keep_topk=-1,
            return_attention=False):
        """Alternative to beam search. Do random sampling at each step."""

        assert self.beam_size == 1

        # TODO: support these blacklisted features.
        assert self.block_ngram_repeat == 0

        batch_size = batch.batch_size

        # Encoder forward.
        src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
        self.model.decoder.init_state(src, memory_bank, enc_states)

        use_src_map = self.copy_attn

        results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "batch": batch,
            "gold_score": self._gold_score(
                batch, memory_bank, src_lengths, src_vocabs, use_src_map,
                enc_states, batch_size, src)}

        memory_lengths = src_lengths
        src_map = batch.src_map if use_src_map else None

        if isinstance(memory_bank, tuple):
            mb_device = memory_bank[0].device
        else:
            mb_device = memory_bank.device

        random_sampler = RandomSampling(
            self._tgt_pad_idx, self._tgt_bos_idx, self._tgt_eos_idx,
            batch_size, mb_device, min_length, self.block_ngram_repeat,
            self._exclusion_idxs, return_attention, self.max_length,
            sampling_temp, keep_topk, memory_lengths)

        for step in range(max_length):
            # Shape: (1, B, 1)
            decoder_input = random_sampler.alive_seq[:, -1].view(1, -1, 1)

            log_probs, attn = self._decode_and_generate(
                decoder_input,
                memory_bank,
                batch,
                src_vocabs,
                memory_lengths=memory_lengths,
                src_map=src_map,
                step=step,
                batch_offset=random_sampler.select_indices
            )

            random_sampler.advance(log_probs, attn)
            any_batch_is_finished = random_sampler.is_finished.any()
            if any_batch_is_finished:
                random_sampler.update_finished()
                if random_sampler.done:
                    break

            if any_batch_is_finished:
                select_indices = random_sampler.select_indices

                # Reorder states.
                if isinstance(memory_bank, tuple):
                    memory_bank = tuple(x.index_select(1, select_indices)
                                        for x in memory_bank)
                else:
                    memory_bank = memory_bank.index_select(1, select_indices)

                memory_lengths = memory_lengths.index_select(0, select_indices)

                if src_map is not None:
                    src_map = src_map.index_select(1, select_indices)

                self.model.decoder.map_state(
                    lambda state, dim: state.index_select(dim, select_indices))

        results["scores"] = random_sampler.scores
        results["predictions"] = random_sampler.predictions
        results["attention"] = random_sampler.attention
        return results
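The sampling_temp and keep_topk arguments threaded through all of these
snippets control how RandomSampling picks the next token. A minimal sketch of
that step, assuming the usual temperature/top-k contract (this helper is
illustrative, not lifted from the code above): logits are divided by the
temperature, everything below the k-th best logit is masked out, and one token
is drawn from what remains; keep_topk=1 or a temperature of 0 degenerates to
greedy argmax.

    import torch

    def sample_with_temperature(logits, sampling_temp=1.0, keep_topk=-1):
        """Sketch: temperature + top-k sampling over (batch, vocab) logits."""
        if sampling_temp == 0.0 or keep_topk == 1:
            # Greedy decoding: take the single best-scoring token.
            topk_scores, topk_ids = logits.topk(1, dim=-1)
            if sampling_temp > 0:
                topk_scores = topk_scores / sampling_temp
        else:
            logits = logits / sampling_temp
            if keep_topk > 0:
                # Mask every entry below the k-th best logit per row.
                top_values, _ = logits.topk(keep_topk, dim=1)
                kth_best = top_values[:, -1].view(-1, 1)
                logits = logits.masked_fill(logits < kth_best, -float("inf"))
            dist = torch.distributions.Categorical(logits=logits)
            topk_ids = dist.sample().view(-1, 1)
            topk_scores = logits.gather(dim=1, index=topk_ids)
        return topk_ids, topk_scores

Note that the score returned here is the temperature-scaled logit, which is
consistent with what the tests below assert (score == raw logit / temp).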
Example #3
    def test_returns_correct_scores_non_deterministic(self):
        for batch_sz in [1, 13]:
            for temp in [1., 3.]:
                n_words = 100
                _non_eos_idxs = [47, 51, 13, 88, 99]
                valid_score_dist_1 = torch.log_softmax(torch.tensor(
                    [6., 5., 4., 3., 2., 1.]),
                                                       dim=0)
                valid_score_dist_2 = torch.log_softmax(torch.tensor([6., 1.]),
                                                       dim=0)
                eos_idx = 2
                lengths = torch.randint(0, 30, (batch_sz, ))
                samp = RandomSampling(0, 1, 2,
                                      batch_sz, torch.device("cpu"), 0, False,
                                      set(), False, 30, temp, 2, lengths)

                # initial step
                i = 0
                for _ in range(100):
                    word_probs = torch.full((batch_sz, n_words), -float('inf'))
                    # batch 0 dies on step 0
                    word_probs[0, eos_idx] = valid_score_dist_1[0]
                    # include at least one prediction OTHER than EOS
                    # that is greater than -1e20
                    word_probs[0, _non_eos_idxs] = valid_score_dist_1[1:]
                    word_probs[1:, _non_eos_idxs[0] + i] = 0

                    attns = torch.randn(1, batch_sz, 53)
                    samp.advance(word_probs, attns)
                    if samp.is_finished[0].eq(1).all():
                        break
                else:
                    self.fail("Batch 0 never ended (very unlikely but maybe "
                              "due to stochasticisty. If so, please increase "
                              "the range of the for-loop.")
                samp.update_finished()
                self.assertEqual(samp.scores[0],
                                 [valid_score_dist_1[0] / temp])
                if batch_sz == 1:
                    self.assertTrue(samp.done)
                    continue
                else:
                    self.assertFalse(samp.done)

                # step 2
                i = 1
                for _ in range(100):
                    word_probs = torch.full((batch_sz - 1, n_words),
                                            -float('inf'))
                    # (old) batch 8 dies on step 1
                    word_probs[7, eos_idx] = valid_score_dist_2[0]
                    word_probs[0:7, _non_eos_idxs[:2]] = valid_score_dist_2
                    word_probs[8:, _non_eos_idxs[:2]] = valid_score_dist_2

                    attns = torch.randn(1, batch_sz, 53)
                    samp.advance(word_probs, attns)
                    if samp.is_finished[7].eq(1).all():
                        break
                else:
                    self.fail("Batch 8 never ended (very unlikely but maybe "
                              "due to stochasticisty. If so, please increase "
                              "the range of the for-loop.")

                samp.update_finished()
                self.assertEqual(samp.scores[8],
                                 [valid_score_dist_2[0] / temp])

                # step 3
                i = 2
                for _ in range(250):
                    word_probs = torch.full((samp.alive_seq.shape[0], n_words),
                                            -float('inf'))
                    # everything dies
                    word_probs[:, eos_idx] = 0

                    attns = torch.randn(1, batch_sz, 53)
                    samp.advance(word_probs, attns)
                    if samp.is_finished.any():
                        samp.update_finished()
                    if samp.is_finished.eq(1).all():
                        break
                else:
                    self.fail("All batches never ended (very unlikely but "
                              "maybe due to stochasticisty. If so, please "
                              "increase the range of the for-loop.")

                for b in range(batch_sz):
                    if b != 0 and b != 8:
                        self.assertEqual(samp.scores[b], [0])
                self.assertTrue(samp.done)
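One detail worth calling out: these retry loops rely on Python's for/else. The
else suite runs only when the loop completes without hitting break, i.e. only
when EOS was never sampled within the retry budget, which is exactly when the
test should fail. A minimal illustration:

    for attempt in range(3):
        if attempt == 1:
            break  # loop exits early, so the else suite below is skipped
    else:
        raise AssertionError("only reached if the loop never breaks")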
Example #4
    def test_returns_correct_scores_deterministic(self):
        for batch_sz in [1, 13]:
            for temp in [1., 3.]:
                n_words = 100
                _non_eos_idxs = [47, 51, 13, 88, 99]
                valid_score_dist_1 = torch.log_softmax(torch.tensor(
                    [6., 5., 4., 3., 2., 1.]),
                                                       dim=0)
                valid_score_dist_2 = torch.log_softmax(torch.tensor([6., 1.]),
                                                       dim=0)
                eos_idx = 2
                lengths = torch.randint(0, 30, (batch_sz, ))
                samp = RandomSampling(0, 1, 2,
                                      batch_sz, torch.device("cpu"), 0, False,
                                      set(), False, 30, temp, 1, lengths)

                # initial step
                i = 0
                word_probs = torch.full((batch_sz, n_words), -float('inf'))
                # batch 0 dies on step 0
                word_probs[0, eos_idx] = valid_score_dist_1[0]
                # include at least one prediction OTHER than EOS
                # that is greater than -1e20
                word_probs[0, _non_eos_idxs] = valid_score_dist_1[1:]
                word_probs[1:, _non_eos_idxs[0] + i] = 0

                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)
                self.assertTrue(samp.is_finished[0].eq(1).all())
                samp.update_finished()
                self.assertEqual(samp.scores[0],
                                 [valid_score_dist_1[0] / temp])
                if batch_sz == 1:
                    self.assertTrue(samp.done)
                    continue
                else:
                    self.assertFalse(samp.done)

                # step 2
                i = 1
                word_probs = torch.full((batch_sz - 1, n_words), -float('inf'))
                # (old) batch 8 dies on step 1
                word_probs[7, eos_idx] = valid_score_dist_2[0]
                word_probs[0:7, _non_eos_idxs[:2]] = valid_score_dist_2
                word_probs[8:, _non_eos_idxs[:2]] = valid_score_dist_2

                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)

                self.assertTrue(samp.is_finished[7].eq(1).all())
                samp.update_finished()
                self.assertEqual(samp.scores[8],
                                 [valid_score_dist_2[0] / temp])

                # step 3
                i = 2
                word_probs = torch.full((batch_sz - 2, n_words), -float('inf'))
                # everything dies
                word_probs[:, eos_idx] = 0

                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)

                self.assertTrue(samp.is_finished.eq(1).all())
                samp.update_finished()
                for b in range(batch_sz):
                    if b != 0 and b != 8:
                        self.assertEqual(samp.scores[b], [0])
                self.assertTrue(samp.done)
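This deterministic variant can drop the retry loops because it constructs the
sampler with keep_topk=1, which (under the sampling sketch after Example #2)
degenerates to greedy argmax: the highest-logit token is chosen on every step,
so each EOS "death" happens exactly when the test plants it. A quick check of
that assumption:

    import torch

    logits = torch.log_softmax(torch.tensor([[6., 5., 4.]]), dim=-1)
    topk_scores, topk_ids = logits.topk(1, dim=-1)
    assert topk_ids.item() == 0  # greedy always picks the best-scoring token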
Example #5
    def _translate_random_sampling(
            self,
            batch,
            src_vocabs,
            max_length,
            min_length=0,
            sampling_temp=1.0,
            keep_topk=-1,
            return_attention=False):
        """Alternative to beam search. Do random sampling at each step."""

        assert self.beam_size == 1

        # TODO: support these blacklisted features.
        assert self.block_ngram_repeat == 0

        batch_size = batch.batch_size

        # Encoder forward.
        src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
        self.model.decoder.init_state(src, memory_bank, enc_states)

        use_src_map = self.copy_attn

        results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "batch": batch,
            "gold_score": self._gold_score(
                batch, memory_bank, src_lengths, src_vocabs, use_src_map,
                enc_states, batch_size, src)}

        memory_lengths = src_lengths
        src_map = batch.src_map if use_src_map else None

        if isinstance(memory_bank, tuple):
            mb_device = memory_bank[0].device
        else:
            mb_device = memory_bank.device

        score_temp = None
        VAD_score_temp = None
        src_relevance_temp = None
        VAD_lambda = None
        decoder_word_embedding = None
        if (isinstance(self.model.decoder, KGTransformerDecoder)
                and self.emotion_topk_decoding_temp != 0):
            score_temp = self.model.decoder.score_temp
            VAD_score_temp = self.model.decoder.VAD_score_temp
            src_relevance_temp = self.model.decoder.src_relevance_temp
            VAD_lambda = self.model.decoder.VAD_lambda
            decoder_word_embedding = self.model.decoder.embeddings
        
        random_sampler = RandomSampling(
            self._tgt_pad_idx, self._tgt_bos_idx, self._tgt_eos_idx,
            batch_size, mb_device, min_length, self.block_ngram_repeat,
            self._exclusion_idxs, return_attention, self.max_length,
            sampling_temp, keep_topk, memory_lengths,
            self.emotion_topk_decoding_temp,
            score_temp=score_temp, VAD_score_temp=VAD_score_temp,
            src_relevance_temp=src_relevance_temp, VAD_lambda=VAD_lambda,
            decoder_word_embedding=decoder_word_embedding)

        batch_emotion = None
        if hasattr(batch, "emotion"):
            batch_emotion = batch.emotion
        
        batch_tgt_concept_emb = None
        if hasattr(batch, "tgt_concept_emb"):
            batch_tgt_concept_emb = batch.tgt_concept_emb
        
        batch_tgt_concept_words = None
        if hasattr(batch, "tgt_concept_index"):
            batch_tgt_concept_words = batch.tgt_concept_index, batch.tgt_concept_score, batch.tgt_concept_VAD_score

        for step in range(max_length):
            # Shape: (1, B, 1)
            decoder_input = random_sampler.alive_seq[:, -1].view(1, -1, 1)

            # log_probs: ``(batch_size, vocab_size)``
            log_probs, attn = self._decode_and_generate(
                decoder_input,
                memory_bank,
                batch,
                src_vocabs,
                memory_lengths=memory_lengths,
                src_map=src_map,
                step=step,
                batch_offset=random_sampler.select_indices,
                batch_emotion=batch_emotion,
                batch_tgt_concept_emb=batch_tgt_concept_emb,
                batch_tgt_concept_words=batch_tgt_concept_words
            )

            if hasattr(batch, "tgt_concept_index"):
                random_sampler.advance(log_probs, attn, batch_tgt_concept_words, memory_bank)
            else:
                random_sampler.advance(log_probs, attn)
            
            any_batch_is_finished = random_sampler.is_finished.any()
            if any_batch_is_finished:
                random_sampler.update_finished()
                if random_sampler.done:
                    break

            if any_batch_is_finished:
                select_indices = random_sampler.select_indices

                # Reorder states.
                if isinstance(memory_bank, tuple):
                    memory_bank = tuple(x.index_select(1, select_indices)
                                        for x in memory_bank)
                else:
                    memory_bank = memory_bank.index_select(1, select_indices)

                memory_lengths = memory_lengths.index_select(0, select_indices)

                if src_map is not None:
                    src_map = src_map.index_select(1, select_indices)

                self.model.decoder.map_state(
                    lambda state, dim: state.index_select(dim, select_indices))

                if hasattr(batch, "emotion"):
                    batch_emotion = batch_emotion.index_select(0, select_indices)

                if hasattr(batch, "tgt_concept_emb"):
                    batch_tgt_concept_emb = batch_tgt_concept_emb.index_select(0, select_indices)
            
                if hasattr(batch, "tgt_concept_index"):
                    batch_tgt_concept_words = batch_tgt_concept_words[0].index_select(0, select_indices), \
                        batch_tgt_concept_words[1].index_select(0, select_indices), \
                            batch_tgt_concept_words[2].index_select(0, select_indices)
            
            # TODO: update batch_emotion using a mask.

        results["scores"] = random_sampler.scores
        results["predictions"] = random_sampler.predictions
        results["attention"] = random_sampler.attention
        return results
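The extra bookkeeping in this variant illustrates a general rule: any
per-example side input threaded through decoding (here batch_emotion and the
tgt_concept_* triple) must be re-indexed with select_indices whenever the
alive batch shrinks, or it silently drifts out of alignment with the surviving
rows. A minimal sketch of that pattern (names are illustrative, not from the
source):

    import torch

    def reorder_side_inputs(select_indices, emotion=None, concept_triple=None):
        """Keep batch-first side inputs aligned with the shrinking batch."""
        if emotion is not None:
            emotion = emotion.index_select(0, select_indices)
        if concept_triple is not None:
            concept_triple = tuple(t.index_select(0, select_indices)
                                   for t in concept_triple)
        return emotion, concept_triple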
Example #6
class Translator(object):
    """Translate a batch of sentences with a saved model.

    Args:
        model (onmt.modules.NMTModel): NMT model to use for translation
        fields (dict[str, torchtext.data.Field]): A dict
            mapping each side to its list of name-Field pairs.
        src_reader (onmt.inputters.DataReaderBase): Source reader.
        tgt_reader (onmt.inputters.TextDataReader): Target reader.
        gpu (int): GPU device. Set to negative for no GPU.
        n_best (int): How many beams to wait for.
        min_length (int): See
            :class:`onmt.translate.decode_strategy.DecodeStrategy`.
        max_length (int): See
            :class:`onmt.translate.decode_strategy.DecodeStrategy`.
        beam_size (int): Number of beams.
        random_sampling_topk (int): See
            :class:`onmt.translate.random_sampling.RandomSampling`.
        random_sampling_temp (int): See
            :class:`onmt.translate.random_sampling.RandomSampling`.
        stepwise_penalty (bool): Whether coverage penalty is applied every step
            or not.
        dump_beam (bool): Debugging option.
        block_ngram_repeat (int): See
            :class:`onmt.translate.decode_strategy.DecodeStrategy`.
        ignore_when_blocking (set or frozenset): See
            :class:`onmt.translate.decode_strategy.DecodeStrategy`.
        replace_unk (bool): Replace unknown token.
        data_type (str): Source data type.
        verbose (bool): Print/log every translation.
        report_bleu (bool): Print/log BLEU metric.
        report_rouge (bool): Print/log ROUGE metric.
        report_time (bool): Print/log total time/frequency.
        copy_attn (bool): Use copy attention.
        global_scorer (onmt.translate.GNMTGlobalScorer): Translation
            scoring/reranking object.
        out_file (TextIO or codecs.StreamReaderWriter): Output file.
        report_score (bool) : Whether to report scores
        logger (logging.Logger or NoneType): Logger.
    """
    def __init__(self,
                 model,
                 fields,
                 src_reader,
                 tgt_reader,
                 gpu=-1,
                 n_best=1,
                 min_length=0,
                 max_length=100,
                 beam_size=30,
                 random_sampling_topk=1,
                 random_sampling_temp=1,
                 stepwise_penalty=None,
                 dump_beam=False,
                 block_ngram_repeat=0,
                 ignore_when_blocking=frozenset(),
                 replace_unk=False,
                 data_type="text",
                 verbose=False,
                 report_bleu=False,
                 report_rouge=False,
                 report_time=False,
                 copy_attn=False,
                 simple_fusion=False,
                 gpt_tgt=False,
                 global_scorer=None,
                 out_file=None,
                 report_score=True,
                 logger=None,
                 seed=-1):
        self.model = model
        self.fields = fields
        tgt_field = dict(self.fields)["tgt"].base_field
        self._tgt_vocab = tgt_field.vocab
        self._tgt_eos_idx = self._tgt_vocab.stoi[tgt_field.eos_token]
        self._tgt_pad_idx = self._tgt_vocab.stoi[tgt_field.pad_token]
        self._tgt_bos_idx = self._tgt_vocab.stoi[tgt_field.init_token]
        self._tgt_unk_idx = self._tgt_vocab.stoi[tgt_field.unk_token]
        self._tgt_vocab_len = len(self._tgt_vocab)

        self._gpu = gpu
        self._use_cuda = gpu > -1
        self._dev = torch.device("cuda", self._gpu) \
            if self._use_cuda else torch.device("cpu")

        self.n_best = n_best
        self.max_length = max_length

        self.beam_size = beam_size
        self.random_sampling_temp = random_sampling_temp
        self.sample_from_topk = random_sampling_topk

        self.min_length = min_length
        self.stepwise_penalty = stepwise_penalty
        self.dump_beam = dump_beam
        self.block_ngram_repeat = block_ngram_repeat
        self.ignore_when_blocking = ignore_when_blocking
        self._exclusion_idxs = {
            self._tgt_vocab.stoi[t]
            for t in self.ignore_when_blocking
        }
        self.src_reader = src_reader
        self.tgt_reader = tgt_reader
        self.replace_unk = replace_unk
        if self.replace_unk and not self.model.decoder.attentional:
            raise ValueError("replace_unk requires an attentional decoder.")
        self.data_type = data_type
        self.verbose = verbose
        self.report_bleu = report_bleu
        self.report_rouge = report_rouge
        self.report_time = report_time

        self.copy_attn = copy_attn
        self.simple_fusion = simple_fusion
        self.gpt_tgt = gpt_tgt

        self.global_scorer = global_scorer
        if self.global_scorer.has_cov_pen and \
                not self.model.decoder.attentional:
            raise ValueError(
                "Coverage penalty requires an attentional decoder.")
        self.out_file = out_file
        self.report_score = report_score
        self.logger = logger

        self.use_filter_pred = False
        self._filter_pred = None

        # for debugging
        self.beam_trace = self.dump_beam != ""
        self.beam_accum = None
        if self.beam_trace:
            self.beam_accum = {
                "predicted_ids": [],
                "beam_parent_ids": [],
                "scores": [],
                "log_probs": []
            }

        set_random_seed(seed, self._use_cuda)

    @classmethod
    def from_opt(cls,
                 model,
                 fields,
                 opt,
                 model_opt,
                 global_scorer=None,
                 out_file=None,
                 report_score=True,
                 logger=None):
        """Alternate constructor.

        Args:
            model (onmt.modules.NMTModel): See :func:`__init__()`.
            fields (dict[str, torchtext.data.Field]): See
                :func:`__init__()`.
            opt (argparse.Namespace): Command line options
            model_opt (argparse.Namespace): Command line options saved with
                the model checkpoint.
            global_scorer (onmt.translate.GNMTGlobalScorer): See
                :func:`__init__()`.
            out_file (TextIO or codecs.StreamReaderWriter): See
                :func:`__init__()`.
            report_score (bool) : See :func:`__init__()`.
            logger (logging.Logger or NoneType): See :func:`__init__()`.
        """

        if opt.data_type == 'none':
            src_reader = None
        else:
            src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
        tgt_reader = inputters.str2reader["text"].from_opt(opt)
        return cls(model,
                   fields,
                   src_reader,
                   tgt_reader,
                   gpu=opt.gpu,
                   n_best=opt.n_best,
                   min_length=opt.min_length,
                   max_length=opt.max_length,
                   beam_size=opt.beam_size,
                   random_sampling_topk=opt.random_sampling_topk,
                   random_sampling_temp=opt.random_sampling_temp,
                   stepwise_penalty=opt.stepwise_penalty,
                   dump_beam=opt.dump_beam,
                   block_ngram_repeat=opt.block_ngram_repeat,
                   ignore_when_blocking=set(opt.ignore_when_blocking),
                   replace_unk=opt.replace_unk,
                   data_type=opt.data_type,
                   verbose=opt.verbose,
                   report_bleu=opt.report_bleu,
                   report_rouge=opt.report_rouge,
                   report_time=opt.report_time,
                   copy_attn=model_opt.copy_attn,
                   simple_fusion=model_opt.simple_fusion,
                   gpt_tgt=model_opt.GPT_representation_mode != 'none'
                   and model_opt.GPT_representation_loc in ['tgt', 'both'],
                   global_scorer=global_scorer,
                   out_file=out_file,
                   report_score=report_score,
                   logger=logger,
                   seed=opt.seed)

    def _gold_score(self, batch, memory_bank, src_lengths, src_vocabs,
                    use_src_map, enc_states, batch_size, src):
        if "tgt" in batch.__dict__:
            gs = self._score_target(batch, memory_bank, src_lengths,
                                    src_vocabs,
                                    batch.src_map if use_src_map else None)
            self.model.decoder.init_state(src, memory_bank, enc_states)
            if self.simple_fusion:
                self.model.lm_decoder.init_state(src, None, None)
        else:
            gs = [0] * batch_size
        return gs

    def build_data_iter(self, opt, batch_size=1):
        if batch_size is None:
            raise ValueError("batch_size must be set")

        def read_file(path):
            with open(path, "rb") as f:
                return f.readlines()

        src = read_file(opt.src)
        tgt = read_file(opt.tgt) if opt.tgt is not None else None

        readers, data, dirs = [], [], []
        if self.src_reader:
            readers += [self.src_reader]
            data += [("src", src)]
            dirs += [None]
        if tgt:
            readers += [self.tgt_reader]
            data += [("tgt", tgt)]
            dirs += [None]

        data = inputters.Dataset(
            self.fields,
            readers=readers,
            data=data,
            dirs=dirs,
            sort_key=inputters.str2sortkey[self.data_type],
            filter_pred=self._filter_pred)

        data_iter = inputters.OrderedIterator(dataset=data,
                                              device=self._dev,
                                              batch_size=batch_size,
                                              train=False,
                                              sort=False,
                                              sort_within_batch=True,
                                              shuffle=False)

        self.src_vocabs = data.src_vocabs

        self._build_xlation(data, tgt)
        self.all_scores = []
        self.all_predictions = []

        self.pred_score_total, self.pred_words_total = 0, 0
        self.gold_score_total, self.gold_words_total = 0, 0
        self.counter = count(1)
        return data_iter

    def set_encoder_state(self, batch, tags=None, temperature=1.0):

        if self.beam_size != 1:
            self.beam_size = 1
        if self.block_ngram_repeat != 0:
            self.block_ngram_repeat = 0

        # Encoder forward.
        src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
        self.model.decoder.init_state(src, memory_bank, enc_states)
        if self.simple_fusion:
            self.model.lm_decoder.init_state(src, None, None)

        use_src_map = self.copy_attn
        memory_lengths = src_lengths
        src_map = batch.src_map if use_src_map else None

        # set for the decoder usage
        self.enc_states = enc_states
        self.src = src
        self.src_lengths = src_lengths
        self.memory_bank = memory_bank
        self.memory_lengths = memory_lengths
        self.src_map = src_map
        self.batch = batch
        self.batch_size = batch.batch_size
        self.set_random_sampler(temperature=temperature)
        self._build_result()

    def forward_pass(self,
                     decoder_in,
                     step,
                     past=None,
                     input_embeds=None,
                     tags=None,
                     use_copy=True):
        memory_bank = self.memory_bank
        src_vocabs = self.src_vocabs
        memory_lengths = self.memory_lengths
        src_map = self.src_map
        batch = self.batch
        if self.copy_attn:
            # Turn any copied words into UNKs.
            decoder_in = decoder_in.masked_fill(
                decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx)

        decoder = self.model.decoder

        dec_out, all_hidden_states, past, dec_attn = decoder(
            decoder_in,
            memory_bank,
            memory_lengths=memory_lengths,
            step=step,
            past=past,
            input_embeds=input_embeds,
            pplm_return=True)

        # Generator forward.
        if not self.copy_attn:
            if "std" in dec_attn:
                attn = dec_attn["std"]
            else:
                attn = None

            if self.simple_fusion:
                lm_dec_out, _ = self.model.lm_decoder(decoder_in,
                                                      memory_bank.new_zeros(
                                                          1, 1, 1),
                                                      step=step)
                probs = self.model.generator(dec_out.squeeze(0),
                                             lm_dec_out.squeeze(0))
            else:
                probs = self.model.generator(dec_out.squeeze(0))
                # returns [(batch_size x beam_size) , vocab ] when 1 step
                # or [ tgt_len, batch_size, vocab ] when full sentence
        else:
            attn = dec_attn["copy"]

            scores, p_copy = self.model.generator(dec_out.view(
                -1, dec_out.size(2)),
                                                  attn.view(-1, attn.size(2)),
                                                  src_map,
                                                  tags=tags)

            scores = scores.view(batch.batch_size, -1, scores.size(-1))
            scores = collapse_copy_scores(scores,
                                          batch,
                                          self._tgt_vocab,
                                          src_vocabs,
                                          batch_dim=0,
                                          batch_offset=None)
            scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
            # log_probs = scores.squeeze(0).log()
            probs = scores.squeeze(0)
            if use_copy is False:
                # Drop the copy-extended entries, keeping only the base
                # GPT-2 vocabulary (50257 tokens).
                probs = probs[:, :50257]
                return probs, attn, all_hidden_states, past
            return probs, attn, all_hidden_states, past, p_copy
        return probs, attn, all_hidden_states, past

    def set_random_sampler(self, return_attention=False, temperature=1.0):
        if isinstance(self.memory_bank, tuple) or isinstance(
                self.memory_bank, list):
            if isinstance(self.memory_bank[0], dict):
                mb_device = self.memory_bank[0][list(
                    self.memory_bank[0].keys())[0]].device
            else:
                mb_device = self.memory_bank[0].device
        else:
            mb_device = self.memory_bank.device

        if self.max_length < 400:
            self.max_length = 400
        if self.min_length < 300:
            self.min_length = 300
        self.random_sampler = RandomSampling(
            self._tgt_pad_idx, self._tgt_bos_idx, self._tgt_eos_idx,
            self.batch_size, mb_device, self.min_length,
            self.block_ngram_repeat, self._exclusion_idxs, return_attention,
            self.max_length, temperature, self.sample_from_topk,
            self.memory_lengths)

    def generate_tokens(self, log_probs, attn=None):

        self.random_sampler.advance(log_probs, attn)
        any_batch_is_finished = self.random_sampler.is_finished.any()
        # ATTENTION: The batch size is one
        if any_batch_is_finished:
            self.random_sampler.update_finished()
            if self.random_sampler.done:
                # Finish the generation and set the final result.
                self._finalize_result()
                self.generate_sentence_batch()
                return False
        return self.random_sampler.alive_seq[:, -1].view(1, 1)

    def generate_sentence_batch(self):
        batch = self.results
        translations = self.xlation_builder.from_batch(batch)
        for trans in translations:
            self.all_scores += [trans.pred_scores[:self.n_best]]
            self.pred_score_total += trans.pred_scores[0]
            self.pred_words_total += len(trans.pred_sents[0])

            n_best_preds = [
                " ".join(pred) for pred in trans.pred_sents[:self.n_best]
            ]
            self.all_predictions += [n_best_preds]
            self.out_file.write('\n'.join(n_best_preds) + '\n')
            self.out_file.flush()

            if self.verbose:
                sent_number = next(self.counter)
                output = trans.log(sent_number)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

    def _build_result(self):
        use_src_map = self.copy_attn
        self.results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "batch": self.batch,
            "gold_score": self._gold_score(
                self.batch, self.memory_bank, self.src_lengths,
                self.src_vocabs, use_src_map, self.enc_states,
                self.batch_size, self.src)}

    def _build_xlation(self, data, tgt):
        self.xlation_builder = onmt.translate.TranslationBuilder(
            data, self.fields, self.n_best, self.replace_unk, tgt)

    def _finalize_result(self):
        self.results["scores"] = self.random_sampler.scores
        self.results["predictions"] = self.random_sampler.predictions
        self.results["attention"] = self.random_sampler.attention

    def _run_encoder(self, batch):
        if hasattr(batch, 'src'):
            src, src_lengths = batch.src if isinstance(batch.src, tuple) \
                else (batch.src, None)

            enc_states, memory_bank, src_lengths = self.model.encoder(
                src, src_lengths)
            if src_lengths is None:
                assert not isinstance(memory_bank, tuple), \
                    'Ensemble decoding only supported for text data'
                src_lengths = torch.Tensor(batch.batch_size) \
                    .type_as(memory_bank) \
                    .long() \
                    .fill_(memory_bank.size(0))
        else:
            src = None
            enc_states = None
            memory_bank = torch.zeros((1, batch.tgt[0].shape[1], 1),
                                      dtype=torch.float,
                                      device=batch.tgt[0].device)
            src_lengths = torch.ones((batch.tgt[0].shape[1], ),
                                     dtype=torch.long,
                                     device=batch.tgt[0].device)
        return src, enc_states, memory_bank, src_lengths

    def freeze_parameter(self):
        for parameter in self.model.encoder.parameters():
            parameter.requires_grad = False
        for parameter in self.model.decoder.parameters():
            parameter.requires_grad = False
        for parameter in self.model.generator.parameters():
            parameter.requires_grad = False
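Putting the pieces above together, a hedged usage sketch: model, fields, opt,
and model_opt are assumed to come from the usual onmt checkpoint loading and
option parsing, and GNMTGlobalScorer.from_opt is an assumed scorer API rather
than something shown in this snippet.

    import codecs

    import onmt.translate

    def build_translator(model, fields, opt, model_opt):
        """Sketch: wire up Translator.from_opt plus its data iterator."""
        scorer = onmt.translate.GNMTGlobalScorer.from_opt(opt)  # assumed API
        out_file = codecs.open(opt.output, "w+", "utf-8")
        translator = Translator.from_opt(
            model, fields, opt, model_opt,
            global_scorer=scorer, out_file=out_file)
        data_iter = translator.build_data_iter(opt, batch_size=1)
        return translator, data_iter

Example #7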
    def test_returns_correct_scores_non_deterministic(self):
        for batch_sz in [1, 13]:
            for temp in [1., 3.]:
                n_words = 100
                _non_eos_idxs = [47, 51, 13, 88, 99]
                valid_score_dist_1 = torch.log_softmax(torch.tensor(
                    [6., 5., 4., 3., 2., 1.]), dim=0)
                valid_score_dist_2 = torch.log_softmax(torch.tensor(
                    [6., 1.]), dim=0)
                eos_idx = 2
                lengths = torch.randint(0, 30, (batch_sz,))
                samp = RandomSampling(
                    0, 1, 2, batch_sz, torch.device("cpu"), 0,
                    False, set(), False, 30, temp, 2, lengths)

                # initial step
                i = 0
                for _ in range(100):
                    word_probs = torch.full(
                        (batch_sz, n_words), -float('inf'))
                    # batch 0 dies on step 0
                    word_probs[0, eos_idx] = valid_score_dist_1[0]
                    # include at least one prediction OTHER than EOS
                    # that is greater than -1e20
                    word_probs[0, _non_eos_idxs] = valid_score_dist_1[1:]
                    word_probs[1:, _non_eos_idxs[0] + i] = 0

                    attns = torch.randn(1, batch_sz, 53)
                    samp.advance(word_probs, attns)
                    if samp.is_finished[0].eq(1).all():
                        break
                else:
                    self.fail("Batch 0 never ended (very unlikely but maybe "
                              "due to stochasticisty. If so, please increase "
                              "the range of the for-loop.")
                samp.update_finished()
                self.assertEqual(
                    samp.scores[0], [valid_score_dist_1[0] / temp])
                if batch_sz == 1:
                    self.assertTrue(samp.done)
                    continue
                else:
                    self.assertFalse(samp.done)

                # step 2
                i = 1
                for _ in range(100):
                    word_probs = torch.full(
                        (batch_sz - 1, n_words), -float('inf'))
                    # (old) batch 8 dies on step 1
                    word_probs[7, eos_idx] = valid_score_dist_2[0]
                    word_probs[0:7, _non_eos_idxs[:2]] = valid_score_dist_2
                    word_probs[8:, _non_eos_idxs[:2]] = valid_score_dist_2

                    attns = torch.randn(1, batch_sz, 53)
                    samp.advance(word_probs, attns)
                    if samp.is_finished[7].eq(1).all():
                        break
                else:
                    self.fail("Batch 8 never ended (very unlikely but maybe "
                              "due to stochasticisty. If so, please increase "
                              "the range of the for-loop.")

                samp.update_finished()
                self.assertEqual(
                    samp.scores[8], [valid_score_dist_2[0] / temp])

                # step 3
                i = 2
                for _ in range(250):
                    word_probs = torch.full(
                        (samp.alive_seq.shape[0], n_words), -float('inf'))
                    # everything dies
                    word_probs[:, eos_idx] = 0

                    attns = torch.randn(1, batch_sz, 53)
                    samp.advance(word_probs, attns)
                    if samp.is_finished.any():
                        samp.update_finished()
                    if samp.is_finished.eq(1).all():
                        break
                else:
                    self.fail("All batches never ended (very unlikely but "
                              "maybe due to stochasticisty. If so, please "
                              "increase the range of the for-loop.")

                for b in range(batch_sz):
                    if b != 0 and b != 8:
                        self.assertEqual(samp.scores[b], [0])
                self.assertTrue(samp.done)
Example #8
    def test_returns_correct_scores_deterministic(self):
        for batch_sz in [1, 13]:
            for temp in [1., 3.]:
                n_words = 100
                _non_eos_idxs = [47, 51, 13, 88, 99]
                valid_score_dist_1 = torch.log_softmax(torch.tensor(
                    [6., 5., 4., 3., 2., 1.]), dim=0)
                valid_score_dist_2 = torch.log_softmax(torch.tensor(
                    [6., 1.]), dim=0)
                eos_idx = 2
                lengths = torch.randint(0, 30, (batch_sz,))
                samp = RandomSampling(
                    0, 1, 2, batch_sz, torch.device("cpu"), 0,
                    False, set(), False, 30, temp, 1, lengths)

                # initial step
                i = 0
                word_probs = torch.full(
                    (batch_sz, n_words), -float('inf'))
                # batch 0 dies on step 0
                word_probs[0, eos_idx] = valid_score_dist_1[0]
                # include at least one prediction OTHER than EOS
                # that is greater than -1e20
                word_probs[0, _non_eos_idxs] = valid_score_dist_1[1:]
                word_probs[1:, _non_eos_idxs[0] + i] = 0

                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)
                self.assertTrue(samp.is_finished[0].eq(1).all())
                samp.update_finished()
                self.assertEqual(
                    samp.scores[0], [valid_score_dist_1[0] / temp])
                if batch_sz == 1:
                    self.assertTrue(samp.done)
                    continue
                else:
                    self.assertFalse(samp.done)

                # step 2
                i = 1
                word_probs = torch.full(
                    (batch_sz - 1, n_words), -float('inf'))
                # (old) batch 8 dies on step 1
                word_probs[7, eos_idx] = valid_score_dist_2[0]
                word_probs[0:7, _non_eos_idxs[:2]] = valid_score_dist_2
                word_probs[8:, _non_eos_idxs[:2]] = valid_score_dist_2

                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)

                self.assertTrue(samp.is_finished[7].eq(1).all())
                samp.update_finished()
                self.assertEqual(
                    samp.scores[8], [valid_score_dist_2[0] / temp])

                # step 3
                i = 2
                word_probs = torch.full(
                    (batch_sz - 2, n_words), -float('inf'))
                # everything dies
                word_probs[:, eos_idx] = 0

                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)

                self.assertTrue(samp.is_finished.eq(1).all())
                samp.update_finished()
                for b in range(batch_sz):
                    if b != 0 and b != 8:
                        self.assertEqual(samp.scores[b], [0])
                self.assertTrue(samp.done)
Example #9
    def _translate_random_sampling(self,
                                   batch,
                                   src_vocabs,
                                   max_length,
                                   min_length=0,
                                   sampling_temp=1.0,
                                   keep_topk=-1,
                                   return_attention=False):
        """Alternative to beam search. Do random sampling at each step."""

        assert self.beam_size == 1

        # TODO: support these blacklisted features.
        assert self.block_ngram_repeat == 0

        batch_size = batch.batch_size

        # Encoder forward.
        src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
        self.model.decoder.init_state(src, memory_bank, enc_states)

        use_src_map = self.copy_attn

        results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "batch": batch,
            "gold_score": self._gold_score(
                batch, memory_bank, src_lengths, src_vocabs, use_src_map,
                enc_states, batch_size, src)}

        memory_lengths = src_lengths
        src_map = batch.src_map if use_src_map else None

        if isinstance(memory_bank, tuple):
            mb_device = memory_bank[0].device
        else:
            mb_device = memory_bank.device

        random_sampler = RandomSampling(
            self._tgt_pad_idx, self._tgt_bos_idx, self._tgt_eos_idx,
            batch_size, mb_device, min_length, self.block_ngram_repeat,
            self._exclusion_idxs, return_attention, self.max_length,
            sampling_temp, keep_topk, memory_lengths)

        for step in range(max_length):
            # Shape: (1, B, 1)
            decoder_input = random_sampler.alive_seq[:, -1].view(1, -1, 1)

            log_probs, attn, hack_dict = self._decode_and_generate(
                decoder_input,
                memory_bank,
                batch,
                src_vocabs,
                memory_lengths=memory_lengths,
                src_map=src_map,
                step=step,
                batch_offset=random_sampler.select_indices,
            )

            top_prob = torch.topk(log_probs, k=1, dim=1)

            self.total_tokens += top_prob.indices.size()[0]

            tgt_field = dict(self.fields)["tgt"].base_field
            vocab = tgt_field.vocab

            src_field = dict(self.fields)["src"].base_field
            vocab_src = src_field.vocab

            for i in range(top_prob.indices.size()[0]):
                word = vocab.itos[top_prob.indices[i][0]]
                self.vocab_dict[word] += 1

                if is_function_word(word):
                    self.total_function_words += 1
                else:
                    self.total_content_words += 1

            if self.counterfactual_attention_method == 'uniform_or_zero_out_max':
                log_probs_uniform = hack_dict['log_probs_uniform']
                top_prob_uniform = torch.topk(log_probs_uniform, k=1, dim=1)
                unaffected_boolean_uniform = (
                    top_prob.indices == top_prob_uniform.indices)

                log_probs_zero_out_max = hack_dict['log_probs_zero_out_max']
                top_prob_zero_out_max = torch.topk(log_probs_zero_out_max,
                                                   k=1,
                                                   dim=1)
                unaffected_boolean_zero_out_max = (
                    top_prob.indices == top_prob_zero_out_max.indices)

                log_probs_permute = hack_dict['log_probs_permute']
                top_prob_permute = torch.topk(log_probs_permute, k=1, dim=1)
                unaffected_boolean_permute = (
                    top_prob.indices == top_prob_permute.indices)

                for i in range(unaffected_boolean_zero_out_max.size()[0]):
                    if (unaffected_boolean_uniform[i][0].item() == 1 or
                            unaffected_boolean_zero_out_max[i][0].item() == 1
                            or unaffected_boolean_permute[i][0].item() == 1):
                        word = vocab.itos[top_prob.indices[i][0]]

                        if is_function_word(word):
                            self.unaffected_function_words_count += 1
                            self.unaffected_function_words[word] += 1
                        else:
                            self.unaffected_content_words_count += 1
                            self.unaffected_content_words[word] += 1

            else:
                log_probs_counterfactual = hack_dict[
                    'log_probs_counterfactual']
                top_prob_counterfactual = torch.topk(log_probs_counterfactual,
                                                     k=1,
                                                     dim=1)
                unaffected_boolean = (
                    top_prob.indices == top_prob_counterfactual.indices)
                self.unaffected_words_count += unaffected_boolean.sum(
                    dim=0).cpu().numpy()[0]

                for i in range(unaffected_boolean.size()[0]):
                    if unaffected_boolean[i][0].item() == 1:
                        word = vocab.itos[top_prob.indices[i][0]]

                        if is_function_word(word):
                            self.unaffected_function_words_count += 1
                            self.unaffected_function_words[word] += 1
                        else:
                            self.unaffected_content_words_count += 1
                            self.unaffected_content_words[word] += 1

            random_sampler.advance(log_probs, attn)
            any_batch_is_finished = random_sampler.is_finished.any()
            if any_batch_is_finished:
                random_sampler.update_finished()
                if random_sampler.done:
                    break

            if any_batch_is_finished:
                select_indices = random_sampler.select_indices

                # Reorder states.
                if isinstance(memory_bank, tuple):
                    memory_bank = tuple(
                        x.index_select(1, select_indices) for x in memory_bank)
                else:
                    memory_bank = memory_bank.index_select(1, select_indices)

                memory_lengths = memory_lengths.index_select(0, select_indices)

                if src_map is not None:
                    src_map = src_map.index_select(1, select_indices)

                self.model.decoder.map_state(
                    lambda state, dim: state.index_select(dim, select_indices))

        results["scores"] = random_sampler.scores
        results["predictions"] = random_sampler.predictions
        results["attention"] = random_sampler.attention
        return results
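All of the per-token accounting above reduces to one question: does the argmax
prediction survive when the attention distribution is corrupted (made uniform,
its maximum zeroed out, or permuted)? A minimal sketch of that check over
(batch, vocab) log-prob tensors, using nothing beyond what the loop itself
computes:

    import torch

    def unaffected_mask(log_probs, log_probs_counterfactual):
        """True where the top-1 prediction is unchanged by the intervention."""
        return log_probs.argmax(dim=1) == log_probs_counterfactual.argmax(dim=1)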
Example #10
    def _translate_random_sampling(
            self,
            batch,
            src_vocabs,
            max_length,
            min_length=0,
            sampling_temp=1.0,
            keep_topk=-1,
            return_attention=False):
        """Alternative to beam search. Do random sampling at each step."""

        assert self.beam_size == 1

        # TODO: support these blacklisted features.
        assert self.block_ngram_repeat == 0

        batch_size = batch.batch_size

        # Encoder forward.
        src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
        self.model.decoder.init_state(src, memory_bank, enc_states)

        use_src_map = self.copy_attn

        results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "batch": batch,
            "gold_score": self._gold_score(
                batch, memory_bank, src_lengths, src_vocabs, use_src_map,
                enc_states, batch_size, src)}

        memory_lengths = src_lengths
        src_map = batch.src_map if use_src_map else None

        if isinstance(memory_bank, tuple):
            mb_device = memory_bank[0].device
        else:
            mb_device = memory_bank.device

        random_sampler = RandomSampling(
            self._tgt_pad_idx, self._tgt_bos_idx, self._tgt_eos_idx,
            batch_size, mb_device, min_length, self.block_ngram_repeat,
            self._exclusion_idxs, return_attention, self.max_length,
            sampling_temp, keep_topk, memory_lengths)

        for step in range(max_length):
            # Shape: (1, B, 1)
            decoder_input = random_sampler.alive_seq[:, -1].view(1, -1, 1)

            log_probs, attn = self._decode_and_generate(
                decoder_input,
                memory_bank,
                batch,
                src_vocabs,
                memory_lengths=memory_lengths,
                src_map=src_map,
                step=step,
                batch_offset=random_sampler.select_indices
            )

            random_sampler.advance(log_probs, attn)
            any_batch_is_finished = random_sampler.is_finished.any()
            if any_batch_is_finished:
                random_sampler.update_finished()
                if random_sampler.done:
                    break

            if any_batch_is_finished:
                select_indices = random_sampler.select_indices

                # Reorder states.
                if isinstance(memory_bank, tuple):
                    memory_bank = tuple(x.index_select(1, select_indices)
                                        for x in memory_bank)
                else:
                    memory_bank = memory_bank.index_select(1, select_indices)

                memory_lengths = memory_lengths.index_select(0, select_indices)

                if src_map is not None:
                    src_map = src_map.index_select(1, select_indices)

                self.model.decoder.map_state(
                    lambda state, dim: state.index_select(dim, select_indices))

        results["scores"] = random_sampler.scores
        results["predictions"] = random_sampler.predictions
        results["attention"] = random_sampler.attention
        return results