Example #1
 def translate_batch(self, batch, src_vocabs, attn_debug):
     """Translate a batch of sentences."""
     with torch.no_grad():
         if self.beam_size == 1:
             decode_strategy = GreedySearch(
                 pad=self._tgt_pad_idx,
                 bos=self._tgt_bos_idx,
                 eos=self._tgt_eos_idx,
                 batch_size=batch.batch_size,
                 min_length=self.min_length, max_length=self.max_length,
                 block_ngram_repeat=self.block_ngram_repeat,
                 exclusion_tokens=self._exclusion_idxs,
                 return_attention=attn_debug or self.replace_unk,
                 sampling_temp=self.random_sampling_temp,
                 keep_topk=self.sample_from_topk)
         else:
             # TODO: support these blacklisted features
             assert not self.dump_beam
             decode_strategy = BeamSearch(
                 self.beam_size,
                 batch_size=batch.batch_size,
                 pad=self._tgt_pad_idx,
                 bos=self._tgt_bos_idx,
                 eos=self._tgt_eos_idx,
                 n_best=self.n_best,
                 global_scorer=self.global_scorer,
                 min_length=self.min_length, max_length=self.max_length,
                 return_attention=attn_debug or self.replace_unk,
                 block_ngram_repeat=self.block_ngram_repeat,
                 exclusion_tokens=self._exclusion_idxs,
                 stepwise_penalty=self.stepwise_penalty,
                 ratio=self.ratio)
         return self._translate_batch_with_strategy(batch, src_vocabs,
                                                    decode_strategy)
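
The method dispatches on beam_size: GreedySearch when beam_size == 1, BeamSearch otherwise. A minimal usage sketch, assuming a configured translator object and a prepared batch from the surrounding OpenNMT-style pipeline (all three names below are hypothetical stand-ins), and assuming the returned dict carries "scores" and "predictions" as in Example #9:

# Hypothetical usage: `translator`, `batch`, and `src_vocabs` are built
# elsewhere in the pipeline and are not defined in this snippet.
results = translator.translate_batch(batch, src_vocabs, attn_debug=False)
for scores, preds in zip(results["scores"], results["predictions"]):
    # One list of hypotheses per example; index 0 is the highest-scoring.
    print(scores[0], preds[0])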
Example #2
    def translate_batch(self, batch, src_vocabs, attn_debug):
        """Translate a batch of sentences."""
        with torch.no_grad():
            tic = time.perf_counter()
            if self.beam_size == 1:
                decode_strategy = GreedySearch(
                    pad=self._tgt_pad_idx,
                    bos=self._tgt_bos_idx,
                    eos=self._tgt_eos_idx,
                    batch_size=batch.batch_size,
                    min_length=self.min_length,
                    max_length=self.max_length,
                    block_ngram_repeat=self.block_ngram_repeat,
                    exclusion_tokens=self._exclusion_idxs,
                    return_attention=attn_debug or self.replace_unk,
                    sampling_temp=self.random_sampling_temp,
                    keep_topk=self.sample_from_topk)
            else:
                # TODO: support these blacklisted features
                assert not self.dump_beam

                decode_strategy = BeamSearch(
                    self.beam_size,
                    batch_size=batch.batch_size,
                    pad=self._tgt_pad_idx,
                    bos=self._tgt_bos_idx,
                    eos=self._tgt_eos_idx,
                    n_best=self.n_best,
                    global_scorer=self.global_scorer,
                    min_length=self.min_length,
                    max_length=self.max_length,
                    return_attention=attn_debug or self.replace_unk,
                    block_ngram_repeat=self.block_ngram_repeat,
                    exclusion_tokens=self._exclusion_idxs,
                    stepwise_penalty=self.stepwise_penalty,
                    ratio=self.ratio)
            toc = time.perf_counter()
            beam_search_time = toc - tic

            tic = time.perf_counter()
            ret = self._translate_batch_with_strategy(batch, src_vocabs,
                                                      decode_strategy)
            toc = time.perf_counter()
            translate_batch_with_strategy_time = toc - tic

            # show_profile_detail is assumed to be a profiling flag defined
            # elsewhere; it is not set anywhere in this snippet.
            if show_profile_detail:
                print(f"BeamSearch Time {beam_search_time:0.4f} seconds, "
                      f"translate_batch_with_strategy Time "
                      f"{translate_batch_with_strategy_time:0.4f} seconds")
            return ret
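
The paired tic/toc blocks above could be factored into a small context manager. This is an editorial sketch, not part of the original code; it assumes a boolean flag such as show_profile_detail is available to the caller:

import time
from contextlib import contextmanager

@contextmanager
def timed(label, enabled=True):
    # Wrap a block with time.perf_counter and print the elapsed time on exit.
    tic = time.perf_counter()
    try:
        yield
    finally:
        if enabled:
            toc = time.perf_counter()
            print(f"{label} took {toc - tic:0.4f} seconds")

With this helper, each measured region in Example #2 reduces to a single `with timed("BeamSearch construction", show_profile_detail):` block.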
Example #3
 def sample_from_batch(self, batch):
     with torch.no_grad():
         decode_strategy = GreedySearch(
             pad=self._tgt_pad_idx,
             bos=self._tgt_bos_idx,
             eos=self._tgt_eos_idx,
             batch_size=batch.batch_size,
             min_length=self.min_length,
             max_length=self.max_length,
             block_ngram_repeat=0,
             exclusion_tokens=self._exclusion_idxs,
             return_attention=False,
             sampling_temp=self.random_sampling_temp,
             keep_topk=self.sample_from_topk)
         return self._sample_batch_with_strategy(batch, decode_strategy)
Example #4
 def translate_batch(self, batch, src_vocabs, attn_debug, src=None,
                     enc_states=None, memory_bank=None, src_lengths=None,
                     src_embed=None, tgt2=False, hidden_state=None):
     """Translate a batch of sentences."""
     if self.beam_size == 1:
         decode_strategy = GreedySearch(
             pad=self._tgt_pad_idx,
             bos=self._tgt_bos_idx,
             eos=self._tgt_eos_idx,
             batch_size=batch.batch_size,
             min_length=self.min_length,
             max_length=self.max_length,
             block_ngram_repeat=self.block_ngram_repeat,
             exclusion_tokens=self._exclusion_idxs,
             return_attention=attn_debug or self.replace_unk,
             sampling_temp=self.random_sampling_temp,
             keep_topk=self.sample_from_topk)
     else:
         # TODO: support these blacklisted features
         assert not self.dump_beam
         decode_strategy = BeamSearch(
             self.beam_size,
             batch_size=batch.batch_size,
             pad=self._tgt_pad_idx,
             bos=self._tgt_bos_idx,
             eos=self._tgt_eos_idx,
             n_best=self.n_best,
             global_scorer=self.global_scorer,
             min_length=self.min_length,
             max_length=self.max_length,
             return_attention=attn_debug or self.replace_unk,
             block_ngram_repeat=self.block_ngram_repeat,
             exclusion_tokens=self._exclusion_idxs,
             stepwise_penalty=self.stepwise_penalty,
             ratio=self.ratio)
     return self._return_gold(batch,
                              src_vocabs,
                              decode_strategy,
                              src,
                              enc_states,
                              memory_bank,
                              src_lengths,
                              src_embed,
                              tgt2,
                              hidden_state=hidden_state)
Example #5
    def test_doesnt_predict_eos_if_shorter_than_min_len(self):
        # batch 0 will always predict EOS. The other batches will predict
        # non-eos scores.
        for batch_sz in [1, 3]:
            n_words = 100
            _non_eos_idxs = [47]
            valid_score_dist = torch.log_softmax(torch.tensor(
                [6., 5.]), dim=0)
            min_length = 5
            eos_idx = 2
            lengths = torch.randint(0, 30, (batch_sz,))
            samp = GreedySearch(
                0, 1, 2, batch_sz, min_length,
                False, set(), False, 30, 1., 1)
            samp.initialize(torch.zeros(1), lengths)
            all_attns = []
            for i in range(min_length + 4):
                word_probs = torch.full(
                    (batch_sz, n_words), -float('inf'))
                # "best" prediction is eos - that should be blocked
                word_probs[0, eos_idx] = valid_score_dist[0]
                # include at least one prediction OTHER than EOS
                # that is greater than -1e20
                word_probs[0, _non_eos_idxs[0]] = valid_score_dist[1]
                word_probs[1:, _non_eos_idxs[0] + i] = 0

                attns = torch.randn(1, batch_sz, 53)
                all_attns.append(attns)
                samp.advance(word_probs, attns)
                if i < min_length:
                    self.assertTrue(
                        samp.topk_scores[0].allclose(valid_score_dist[1]))
                    self.assertTrue(
                        samp.topk_scores[1:].eq(0).all())
                elif i == min_length:
                    # now batch 0 has ended and no others have
                    self.assertTrue(samp.is_finished[0, :].eq(1).all())
                    self.assertTrue(samp.is_finished[1:, :].eq(0).all())
                else:  # i > min_length
                    break
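
Example #5 checks that EOS is never selected before min_length steps. The mechanism under test amounts to masking the EOS column of the log-probabilities; a minimal sketch of the idea, with illustrative names (the library's internal method may differ):

import torch

def block_eos_before_min_length(log_probs, step, min_length, eos_idx):
    # Before min_length tokens have been produced, push the EOS
    # log-probability low enough that it can never win the argmax,
    # while leaving every other token untouched.
    if step < min_length:
        log_probs[:, eos_idx] = -1e20
    return log_probs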
Example #6
    def test_returns_correct_scores_deterministic(self):
        for batch_sz in [1, 13]:
            for temp in [1., 3.]:
                n_words = 100
                _non_eos_idxs = [47, 51, 13, 88, 99]
                valid_score_dist_1 = torch.log_softmax(torch.tensor(
                    [6., 5., 4., 3., 2., 1.]), dim=0)
                valid_score_dist_2 = torch.log_softmax(torch.tensor(
                    [6., 1.]), dim=0)
                eos_idx = 2
                lengths = torch.randint(0, 30, (batch_sz,))
                samp = GreedySearch(
                    0, 1, 2, batch_sz, 0,
                    False, set(), False, 30, temp, 1)
                samp.initialize(torch.zeros(1), lengths)
                # initial step
                i = 0
                word_probs = torch.full(
                    (batch_sz, n_words), -float('inf'))
                # batch 0 dies on step 0
                word_probs[0, eos_idx] = valid_score_dist_1[0]
                # include at least one prediction OTHER than EOS
                # that is greater than -1e20
                word_probs[0, _non_eos_idxs] = valid_score_dist_1[1:]
                word_probs[1:, _non_eos_idxs[0] + i] = 0

                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)
                self.assertTrue(samp.is_finished[0].eq(1).all())
                samp.update_finished()
                self.assertEqual(
                    samp.scores[0], [valid_score_dist_1[0] / temp])
                if batch_sz == 1:
                    self.assertTrue(samp.done)
                    continue
                else:
                    self.assertFalse(samp.done)

                # step 2
                i = 1
                word_probs = torch.full(
                    (batch_sz - 1, n_words), -float('inf'))
                # (old) batch 8 dies on step 1
                word_probs[7, eos_idx] = valid_score_dist_2[0]
                word_probs[0:7, _non_eos_idxs[:2]] = valid_score_dist_2
                word_probs[8:, _non_eos_idxs[:2]] = valid_score_dist_2

                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)

                self.assertTrue(samp.is_finished[7].eq(1).all())
                samp.update_finished()
                self.assertEqual(
                    samp.scores[8], [valid_score_dist_2[0] / temp])

                # step 3
                i = 2
                word_probs = torch.full(
                    (batch_sz - 2, n_words), -float('inf'))
                # everything dies
                word_probs[:, eos_idx] = 0

                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)

                self.assertTrue(samp.is_finished.eq(1).all())
                samp.update_finished()
                for b in range(batch_sz):
                    if b != 0 and b != 8:
                        self.assertEqual(samp.scores[b], [0])
                self.assertTrue(samp.done)
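
The expected score here is valid_score_dist_1[0] / temp because sampling with temperature divides the scores by temp before a token is selected, and the returned score is read from the rescaled distribution. A rough, self-contained sketch of temperature plus top-k sampling (illustrative only, not the library's exact function):

import torch

def sample_with_temperature(logits, temp=1.0, keep_topk=-1):
    # temp > 1 flattens the distribution, temp < 1 sharpens it;
    # keep_topk == 1 degenerates to greedy argmax.
    logits = logits / temp
    if keep_topk > 0:
        # Mask everything outside the top-k before sampling.
        topk_vals, _ = logits.topk(keep_topk, dim=-1)
        kth_best = topk_vals[:, -1].unsqueeze(-1)
        logits = logits.masked_fill(logits < kth_best, -float('inf'))
    probs = torch.softmax(logits, dim=-1)
    ids = torch.multinomial(probs, 1)
    scores = logits.gather(1, ids)  # rescaled score, hence the "/ temp"
    return ids, scores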
Example #7
    def test_returns_correct_scores_non_deterministic(self):
        for batch_sz in [1, 13]:
            for temp in [1., 3.]:
                n_words = 100
                _non_eos_idxs = [47, 51, 13, 88, 99]
                valid_score_dist_1 = torch.log_softmax(torch.tensor(
                    [6., 5., 4., 3., 2., 1.]), dim=0)
                valid_score_dist_2 = torch.log_softmax(torch.tensor(
                    [6., 1.]), dim=0)
                eos_idx = 2
                lengths = torch.randint(0, 30, (batch_sz,))
                samp = GreedySearch(
                    0, 1, 2, batch_sz, 0,
                    False, set(), False, 30, temp, 2)
                samp.initialize(torch.zeros(1), lengths)
                # initial step
                i = 0
                for _ in range(100):
                    word_probs = torch.full(
                        (batch_sz, n_words), -float('inf'))
                    # batch 0 dies on step 0
                    word_probs[0, eos_idx] = valid_score_dist_1[0]
                    # include at least one prediction OTHER than EOS
                    # that is greater than -1e20
                    word_probs[0, _non_eos_idxs] = valid_score_dist_1[1:]
                    word_probs[1:, _non_eos_idxs[0] + i] = 0

                    attns = torch.randn(1, batch_sz, 53)
                    samp.advance(word_probs, attns)
                    if samp.is_finished[0].eq(1).all():
                        break
                else:
                    self.fail("Batch 0 never ended (very unlikely but maybe "
                              "due to stochasticisty. If so, please increase "
                              "the range of the for-loop.")
                samp.update_finished()
                self.assertEqual(
                    samp.scores[0], [valid_score_dist_1[0] / temp])
                if batch_sz == 1:
                    self.assertTrue(samp.done)
                    continue
                else:
                    self.assertFalse(samp.done)

                # step 2
                i = 1
                for _ in range(100):
                    word_probs = torch.full(
                        (batch_sz - 1, n_words), -float('inf'))
                    # (old) batch 8 dies on step 1
                    word_probs[7, eos_idx] = valid_score_dist_2[0]
                    word_probs[0:7, _non_eos_idxs[:2]] = valid_score_dist_2
                    word_probs[8:, _non_eos_idxs[:2]] = valid_score_dist_2

                    attns = torch.randn(1, batch_sz, 53)
                    samp.advance(word_probs, attns)
                    if samp.is_finished[7].eq(1).all():
                        break
                else:
                    self.fail("Batch 8 never ended (very unlikely but maybe "
                              "due to stochasticisty. If so, please increase "
                              "the range of the for-loop.")

                samp.update_finished()
                self.assertEqual(
                    samp.scores[8], [valid_score_dist_2[0] / temp])

                # step 3
                i = 2
                for _ in range(250):
                    word_probs = torch.full(
                        (samp.alive_seq.shape[0], n_words), -float('inf'))
                    # everything dies
                    word_probs[:, eos_idx] = 0

                    attns = torch.randn(1, batch_sz, 53)
                    samp.advance(word_probs, attns)
                    if samp.is_finished.any():
                        samp.update_finished()
                    if samp.is_finished.eq(1).all():
                        break
                else:
                    self.fail("All batches never ended (very unlikely but "
                              "maybe due to stochasticisty. If so, please "
                              "increase the range of the for-loop.")

                for b in range(batch_sz):
                    if b != 0 and b != 8:
                        self.assertEqual(samp.scores[b], [0])
                self.assertTrue(samp.done)
Example #8
    def test_returns_correct_scores_non_deterministic_beams(self):
        beam_size = 10
        for batch_sz in [1, 13]:
            for temp in [1., 3.]:
                n_words = 100
                _non_eos_idxs = [47, 51, 13, 88, 99]
                valid_score_dist_1 = torch.log_softmax(torch.tensor(
                    [6., 5., 4., 3., 2., 1.]), dim=0)
                valid_score_dist_2 = torch.log_softmax(torch.tensor(
                    [6., 1.]), dim=0)
                eos_idx = 2
                lengths = torch.randint(0, 30, (batch_sz, ))
                samp = GreedySearch(0, 1, 2, 3,
                                    batch_sz, GlobalScorerStub(), 0, False,
                                    set(), False, 30, temp, 50, 0, beam_size,
                                    False)
                samp.initialize(torch.zeros((1, 1)), lengths)
                # initial step
                # finish one beam
                i = 0
                for _ in range(100):
                    word_probs = torch.full((batch_sz * beam_size, n_words),
                                            -float('inf'))

                    word_probs[beam_size - 2, eos_idx] = valid_score_dist_1[0]
                    # include at least one prediction OTHER than EOS
                    # that is greater than -1e20
                    word_probs[beam_size - 2,
                               _non_eos_idxs] = valid_score_dist_1[1:]
                    word_probs[beam_size - 2 + 1:, _non_eos_idxs[0] + i] = 0
                    word_probs[:beam_size - 2, _non_eos_idxs[0] + i] = 0

                    attns = torch.randn(1, batch_sz, 53)
                    samp.advance(word_probs, attns)
                    if samp.is_finished[beam_size - 2].eq(1).all():
                        self.assertFalse(
                            samp.is_finished[:beam_size - 2].eq(1).any())
                        self.assertFalse(
                            samp.is_finished[beam_size - 2 + 1].eq(1).any())
                        break
                else:
                    self.fail("Batch 0 never ended (very unlikely but maybe "
                              "due to stochasticisty. If so, please increase "
                              "the range of the for-loop.")
                samp.update_finished()
                self.assertEqual([samp.topk_scores[beam_size - 2]],
                                 [valid_score_dist_1[0] / temp])

                # step 2
                # finish example in last batch
                i = 1
                for _ in range(100):
                    word_probs = torch.full(
                        (batch_sz * beam_size - 1, n_words), -float('inf'))
                    # (old) batch 8 dies on step 1
                    word_probs[(batch_sz - 1) * beam_size + 7,
                               eos_idx] = valid_score_dist_2[0]
                    word_probs[:(batch_sz - 1) * beam_size + 7,
                               _non_eos_idxs[:2]] = valid_score_dist_2
                    word_probs[(batch_sz - 1) * beam_size + 8:,
                               _non_eos_idxs[:2]] = valid_score_dist_2

                    attns = torch.randn(1, batch_sz, 53)
                    samp.advance(word_probs, attns)
                    if (samp.is_finished[(batch_sz - 1) * beam_size +
                                         7].eq(1).all()):
                        break
                else:
                    self.fail("Batch 8 never ended (very unlikely but maybe "
                              "due to stochasticisty. If so, please increase "
                              "the range of the for-loop.")

                samp.update_finished()
                self.assertEqual([
                    score for score, _, _ in samp.hypotheses[batch_sz - 1][-1:]
                ], [valid_score_dist_2[0] / temp])

                # step 3
                i = 2
                for _ in range(250):
                    word_probs = torch.full((samp.alive_seq.shape[0], n_words),
                                            -float('inf'))
                    # everything dies
                    word_probs[:, eos_idx] = 0

                    attns = torch.randn(1, batch_sz, 53)
                    samp.advance(word_probs, attns)
                    if samp.is_finished.any():
                        samp.update_finished()
                    if samp.is_finished.eq(1).all():
                        break
                else:
                    self.fail("All batches never ended (very unlikely but "
                              "maybe due to stochasticisty. If so, please "
                              "increase the range of the for-loop.")

                self.assertTrue(samp.done)
Example #9
    def _sample_batch_with_strategy(self, batch,
                                    decode_strategy: GreedySearch):
        use_src_map = False
        parallel_paths = decode_strategy.parallel_paths

        src, enc_states, memory_bank, src_lengths = self._run_encoder(
            batch=batch)
        self.model.decoder.init_state(src, memory_bank, enc_states)
        results = {
            "scores": None,
            "predictions": None,
            "attention": None,
        }
        src_map = batch.src_map if use_src_map else None
        (fn_map_state, memory_bank, memory_lengths,
         src_map) = decode_strategy.initialize(
            memory_bank, src_lengths, src_map)
        if fn_map_state is not None:
            self.model.decoder.map_state(fn_map_state)

        for step in range(decode_strategy.max_length):
            decoder_input = decode_strategy.current_predictions.view(1, -1, 1)

            log_probs, attn = self._decoder_and_generate(
                decoder_input,
                memory_bank,
                batch,
                memory_lengths=memory_lengths,
                src_map=src_map,
                step=step,
                batch_offset=decode_strategy.batch_offset)

            decode_strategy.advance(log_probs, attn)
            any_finished = decode_strategy.is_finished.any()
            if any_finished:
                decode_strategy.update_finished()
                if decode_strategy.done:
                    break

            select_indices = decode_strategy.select_indices

            if any_finished:
                # Reorder states.
                if isinstance(memory_bank, tuple):
                    memory_bank = tuple(
                        x.index_select(1, select_indices) for x in memory_bank)
                else:
                    memory_bank = memory_bank.index_select(1, select_indices)

                memory_lengths = memory_lengths.index_select(0, select_indices)

                if src_map is not None:
                    src_map = src_map.index_select(1, select_indices)

            if parallel_paths > 1 or any_finished:
                self.model.decoder.map_state(
                    lambda state, dim: state.index_select(dim, select_indices))

        results['scores'] = decode_strategy.scores
        results['predictions'] = decode_strategy.predictions
        # predictions is a list whose length equals the batch size
        results['attention'] = decode_strategy.attention
        return results
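
When any sequence finishes, the loop above shrinks every decoder-side state with index_select along the batch dimension so that only alive sequences keep being decoded. A standalone illustration of the same reordering pattern on a toy tensor:

import torch

# memory_bank is (src_len, batch, hidden); suppose only examples 0 and 2
# of an original batch of 3 are still alive after update_finished().
memory_bank = torch.randn(4, 3, 8)
select_indices = torch.tensor([0, 2])
# Dim 1 is the batch dimension, matching index_select(1, ...) above.
memory_bank = memory_bank.index_select(1, select_indices)
print(memory_bank.shape)  # torch.Size([4, 2, 8])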