Example #1
 def test_advance_with_all_repeats_gets_blocked(self):
     # all beams repeat (beam >= 1 repeat dummy scores)
     beam_sz = 5
     n_words = 100
     repeat_idx = 47
     ngram_repeat = 3
     for batch_sz in [1, 3]:
         beam = BeamSearch(beam_sz, batch_sz,
                           0, 1, 2, 2, torch.device("cpu"),
                           GlobalScorerStub(), 0, 30, False, ngram_repeat,
                           set(), torch.randint(0, 30, (batch_sz, )), False)
         for i in range(ngram_repeat + 4):
             # predict repeat_idx over and over again
             word_probs = torch.full((batch_sz * beam_sz, n_words),
                                     -float('inf'))
             word_probs[0::beam_sz, repeat_idx] = 0
             attns = torch.randn(1, batch_sz * beam_sz, 53)
             beam.advance(word_probs, attns)
             if i <= ngram_repeat:
                 expected_scores = torch.tensor(
                             [0] + [-float('inf')] * (beam_sz - 1))\
                         .repeat(batch_sz, 1)
                 self.assertTrue(beam.topk_log_probs.equal(expected_scores))
             else:
                 self.assertTrue(
                     beam.topk_log_probs.equal(
                         torch.tensor(self.BLOCKED_SCORE).repeat(
                             batch_sz, beam_sz)))
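Several of these tests reference class-level constants that the excerpts do not show. A minimal sketch of what they might look like, modeled on OpenNMT-py's test suite (the exact values there may differ):

import unittest
import torch

class TestBeamSearch(unittest.TestCase):
    # Score BeamSearch assigns to blocked (repeated-ngram) hypotheses;
    # a large negative stand-in for -inf.
    BLOCKED_SCORE = -10e20
    BEAM_SZ = 5
    BATCH_SZ = 3
    N_BEST = 2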
Example #2
    def test_beam_is_done_when_n_best_beams_eos_using_min_length(self):
        # this is also a test that when block_ngram_repeat=0,
        # repeating is acceptable
        beam_sz = 5
        batch_sz = 3
        n_words = 100
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(torch.tensor(
            [6., 5., 4., 3., 2., 1.]), dim=0)
        min_length = 5
        eos_idx = 2
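        # Positional args below map to (beam_size, batch_size, pad, bos, eos,
        # n_best, global_scorer, min_length, max_length, return_attention,
        # block_ngram_repeat, exclusion_tokens, stepwise_penalty, ratio),
        # inferred from the keyword-style calls in Examples #10 and #15.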
        beam = BeamSearch(
            beam_sz, batch_sz, 0, 1, 2, 2,
            GlobalScorerStub(),
            min_length, 30, False, 0, set(),
            False, 0.)
        device_init = torch.zeros(1, 1)
        beam.initialize(device_init, torch.randint(0, 30, (batch_sz,)))
        for i in range(min_length + 4):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full(
                (batch_sz * beam_sz, n_words), -float('inf'))
            if i == 0:
                # "best" prediction is eos - that should be blocked
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                # include at least beam_sz predictions OTHER than EOS
                # that are greater than -1e20
                for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                    word_probs[0::beam_sz, j] = score
            elif i <= min_length:
                # predict eos in beam 1
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx::beam_sz, j] = score
            else:
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx::beam_sz, j] = score

            attns = torch.randn(1, batch_sz * beam_sz, 53)
            beam.advance(word_probs, attns)
            if i < min_length:
                self.assertFalse(beam.done)
            elif i == min_length:
                # beam 1 dies on min_length
                self.assertTrue(beam.is_finished[:, 1].all())
                beam.update_finished()
                self.assertFalse(beam.done)
            else:  # i > min_length
                # beam 0 dies on the step after beam 1 dies
                self.assertTrue(beam.is_finished[:, 0].all())
                beam.update_finished()
                self.assertTrue(beam.done)
Example #3
 def test_advance_with_all_repeats_gets_blocked(self):
     # all beams repeat (beam >= 1 repeat dummy scores)
     beam_sz = 5
     n_words = 100
     repeat_idx = 47
     ngram_repeat = 3
     for batch_sz in [1, 3]:
         beam = BeamSearch(
             beam_sz, batch_sz, 0, 1, 2, 2,
             torch.device("cpu"), GlobalScorerStub(), 0, 30,
             False, ngram_repeat, set(),
             torch.randint(0, 30, (batch_sz,)), False, 0.)
         for i in range(ngram_repeat + 4):
             # predict repeat_idx over and over again
             word_probs = torch.full(
                 (batch_sz * beam_sz, n_words), -float('inf'))
             word_probs[0::beam_sz, repeat_idx] = 0
             attns = torch.randn(1, batch_sz * beam_sz, 53)
             beam.advance(word_probs, attns)
             if i <= ngram_repeat:
                 expected_scores = torch.tensor(
                             [0] + [-float('inf')] * (beam_sz - 1))\
                         .repeat(batch_sz, 1)
                 self.assertTrue(beam.topk_log_probs.equal(expected_scores))
             else:
                 self.assertTrue(
                     beam.topk_log_probs.equal(
                         torch.tensor(self.BLOCKED_SCORE)
                         .repeat(batch_sz, beam_sz)))
Example #4
 def test_beam_advance_against_known_reference(self):
     beam = BeamSearch(self.BEAM_SZ, self.BATCH_SZ, 0, 1, 2, self.N_BEST,
                       GlobalScorerStub(), 0, 30, False, 0, set(), False,
                       0.)
     device_init = torch.zeros(1, 1)
     beam.initialize(device_init, torch.randint(0, 30, (self.BATCH_SZ, )))
     expected_beam_scores = self.init_step(beam, 1)
     expected_beam_scores = self.first_step(beam, expected_beam_scores, 1)
     expected_beam_scores = self.second_step(beam, expected_beam_scores, 1)
     self.third_step(beam, expected_beam_scores, 1)
Example #5
 def test_beam_advance_against_known_reference(self):
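      # Keyword usage in Example #17 suggests these positional args are
      # (alpha, beta, length_penalty, coverage_penalty): alpha=0.7, beta=0.,
      # "avg" length penalty, no coverage penalty.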
     scorer = GNMTGlobalScorer(0.7, 0., "avg", "none")
     beam = BeamSearch(self.BEAM_SZ, self.BATCH_SZ, 0, 1, 2, self.N_BEST,
                       scorer, 0, 30, False, 0, set(), False, 0.)
     device_init = torch.zeros(1, 1)
     beam.initialize(device_init, torch.randint(0, 30, (self.BATCH_SZ, )))
     expected_beam_scores = self.init_step(beam, 1.)
     expected_beam_scores = self.first_step(beam, expected_beam_scores, 3)
     expected_beam_scores = self.second_step(beam, expected_beam_scores, 4)
     self.third_step(beam, expected_beam_scores, 5)
Example #6
    def test_doesnt_predict_eos_if_shorter_than_min_len(self):
        # beam 0 will always predict EOS. The other beams will predict
        # non-eos scores.
        for batch_sz in [1, 3]:
            beam_sz = 5
            n_words = 100
            _non_eos_idxs = [47, 51, 13, 88, 99]
            valid_score_dist = torch.log_softmax(torch.tensor(
                [6., 5., 4., 3., 2., 1.]), dim=0)
            min_length = 5
            eos_idx = 2
            lengths = torch.randint(0, 30, (batch_sz,))
            beam = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 2,
                              torch.device("cpu"), GlobalScorerStub(),
                              min_length, 30, False, 0, set(),
                              lengths, False, 0.)
            all_attns = []
            for i in range(min_length + 4):
                # non-interesting beams are going to get dummy values
                word_probs = torch.full(
                    (batch_sz * beam_sz, n_words), -float('inf'))
                if i == 0:
                    # "best" prediction is eos - that should be blocked
                    word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                    # include at least beam_sz predictions OTHER than EOS
                    # that are greater than -1e20
                    for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                        word_probs[0::beam_sz, j] = score
                else:
                    # predict eos in beam 0
                    word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                    # provide beam_sz other good predictions
                    for k, (j, score) in enumerate(
                            zip(_non_eos_idxs, valid_score_dist[1:])):
                        beam_idx = min(beam_sz-1, k)
                        word_probs[beam_idx::beam_sz, j] = score

                attns = torch.randn(1, batch_sz * beam_sz, 53)
                all_attns.append(attns)
                beam.advance(word_probs, attns)
                if i < min_length:
                    expected_score_dist = \
                        (i+1) * valid_score_dist[1:].unsqueeze(0)
                    self.assertTrue(
                        beam.topk_log_probs.allclose(
                            expected_score_dist))
                elif i == min_length:
                    # now the top beam has ended and no others have
                    self.assertTrue(beam.is_finished[:, 0].eq(1).all())
                    self.assertTrue(beam.is_finished[:, 1:].eq(0).all())
                else:  # i > min_length
                    # not of interest, but want to make sure it keeps running
                    # since only beam 0 terminates and n_best = 2
                    pass
Example #7
    def test_doesnt_predict_eos_if_shorter_than_min_len(self):
        # beam 0 will always predict EOS. The other beams will predict
        # non-eos scores.
        for batch_sz in [1, 3]:
            beam_sz = 5
            n_words = 100
            _non_eos_idxs = [47, 51, 13, 88, 99]
            valid_score_dist = torch.log_softmax(torch.tensor(
                [6., 5., 4., 3., 2., 1.]), dim=0)
            min_length = 5
            eos_idx = 2
            lengths = torch.randint(0, 30, (batch_sz,))
            beam = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 2,
                              torch.device("cpu"), GlobalScorerStub(),
                              min_length, 30, False, 0, set(),
                              lengths, False)
            all_attns = []
            for i in range(min_length + 4):
                # non-interesting beams are going to get dummy values
                word_probs = torch.full(
                    (batch_sz * beam_sz, n_words), -float('inf'))
                if i == 0:
                    # "best" prediction is eos - that should be blocked
                    word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                    # include at least beam_sz predictions OTHER than EOS
                    # that are greater than -1e20
                    for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                        word_probs[0::beam_sz, j] = score
                else:
                    # predict eos in beam 0
                    word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                    # provide beam_sz other good predictions
                    for k, (j, score) in enumerate(
                            zip(_non_eos_idxs, valid_score_dist[1:])):
                        beam_idx = min(beam_sz-1, k)
                        word_probs[beam_idx::beam_sz, j] = score

                attns = torch.randn(1, batch_sz * beam_sz, 53)
                all_attns.append(attns)
                beam.advance(word_probs, attns)
                if i < min_length:
                    expected_score_dist = \
                        (i+1) * valid_score_dist[1:].unsqueeze(0)
                    self.assertTrue(
                        beam.topk_log_probs.allclose(
                            expected_score_dist))
                elif i == min_length:
                    # now the top beam has ended and no others have
                    self.assertTrue(beam.is_finished[:, 0].eq(1).all())
                    self.assertTrue(beam.is_finished[:, 1:].eq(0).all())
                else:  # i > min_length
                    # not of interest, but want to make sure it keeps running
                    # since only beam 0 terminates and n_best = 2
                    pass
Example #8
 def test_repeating_excluded_index_does_not_die(self):
     # beam 0 and beam >= 2 will repeat (beam 2 repeats excluded idx)
     beam_sz = 5
     n_words = 100
     repeat_idx = 47  # will be repeated and should be blocked
     repeat_idx_ignored = 7  # will be repeated and should not be blocked
     ngram_repeat = 3
     for batch_sz in [1, 3]:
         beam = BeamSearch(
             beam_sz, batch_sz, 0, 1, 2, 2,
             torch.device("cpu"), GlobalScorerStub(), 0, 30,
             False, ngram_repeat, {repeat_idx_ignored},
             torch.randint(0, 30, (batch_sz,)))
         for i in range(ngram_repeat + 4):
             # non-interesting beams are going to get dummy values
             word_probs = torch.full(
                 (batch_sz * beam_sz, n_words), -float('inf'))
             if i == 0:
                 word_probs[0::beam_sz, repeat_idx] = -0.1
                 word_probs[0::beam_sz, repeat_idx + i + 1] = -2.3
                 word_probs[0::beam_sz, repeat_idx_ignored] = -5.0
             else:
                 # predict the same thing in beam 0
                 word_probs[0::beam_sz, repeat_idx] = 0
                 # continue pushing around what beam 1 predicts
                 word_probs[1::beam_sz, repeat_idx + i + 1] = 0
                 # predict the allowed-repeat again in beam 2
                 word_probs[2::beam_sz, repeat_idx_ignored] = 0
             attns = torch.randn(beam_sz)
             beam.advance(word_probs, attns)
             if i <= ngram_repeat:
                 self.assertFalse(beam.topk_log_probs[:, 0].eq(
                     self.BLOCKED_SCORE).any())
                 self.assertFalse(beam.topk_log_probs[:, 1].eq(
                     self.BLOCKED_SCORE).any())
                 self.assertFalse(beam.topk_log_probs[:, 2].eq(
                     self.BLOCKED_SCORE).any())
             else:
                 # now beam 0 dies, beam 1 -> beam 0, beam 2 -> beam 1
                 # and the rest die
                 self.assertFalse(beam.topk_log_probs[:, 0].eq(
                     self.BLOCKED_SCORE).any())
                 # since all preds after i=0 are 0, we can check
                 # that the beam is the correct idx by checking that
                 # the curr score is the initial score
                 self.assertTrue(beam.topk_log_probs[:, 0].eq(-2.3).all())
                 self.assertFalse(beam.topk_log_probs[:, 1].eq(
                     self.BLOCKED_SCORE).all())
                 self.assertTrue(beam.topk_log_probs[:, 1].eq(-5.0).all())
                 self.assertTrue(
                     beam.topk_log_probs[:, 2:].equal(
                         torch.tensor(self.BLOCKED_SCORE)
                         .repeat(batch_sz, beam_sz - 2)))
Example #9
 def test_repeating_excluded_index_does_not_die(self):
     # beam 0 and beam >= 2 will repeat (beam 2 repeats excluded idx)
     beam_sz = 5
     n_words = 100
     repeat_idx = 47  # will be repeated and should be blocked
     repeat_idx_ignored = 7  # will be repeated and should not be blocked
     ngram_repeat = 3
     for batch_sz in [1, 3]:
         beam = BeamSearch(
             beam_sz, batch_sz, 0, 1, 2, 2,
             torch.device("cpu"), GlobalScorerStub(), 0, 30,
             False, ngram_repeat, {repeat_idx_ignored},
             torch.randint(0, 30, (batch_sz,)), False, 0.)
         for i in range(ngram_repeat + 4):
             # non-interesting beams are going to get dummy values
             word_probs = torch.full(
                 (batch_sz * beam_sz, n_words), -float('inf'))
             if i == 0:
                 word_probs[0::beam_sz, repeat_idx] = -0.1
                 word_probs[0::beam_sz, repeat_idx + i + 1] = -2.3
                 word_probs[0::beam_sz, repeat_idx_ignored] = -5.0
             else:
                 # predict the same thing in beam 0
                 word_probs[0::beam_sz, repeat_idx] = 0
                 # continue pushing around what beam 1 predicts
                 word_probs[1::beam_sz, repeat_idx + i + 1] = 0
                 # predict the allowed-repeat again in beam 2
                 word_probs[2::beam_sz, repeat_idx_ignored] = 0
             attns = torch.randn(1, batch_sz * beam_sz, 53)
             beam.advance(word_probs, attns)
             if i <= ngram_repeat:
                 self.assertFalse(beam.topk_log_probs[:, 0].eq(
                     self.BLOCKED_SCORE).any())
                 self.assertFalse(beam.topk_log_probs[:, 1].eq(
                     self.BLOCKED_SCORE).any())
                 self.assertFalse(beam.topk_log_probs[:, 2].eq(
                     self.BLOCKED_SCORE).any())
             else:
                 # now beam 0 dies, beam 1 -> beam 0, beam 2 -> beam 1
                 # and the rest die
                 self.assertFalse(beam.topk_log_probs[:, 0].eq(
                     self.BLOCKED_SCORE).any())
                 # since all preds after i=0 are 0, we can check
                 # that the beam is the correct idx by checking that
                 # the curr score is the initial score
                 self.assertTrue(beam.topk_log_probs[:, 0].eq(-2.3).all())
                 self.assertFalse(beam.topk_log_probs[:, 1].eq(
                     self.BLOCKED_SCORE).all())
                 self.assertTrue(beam.topk_log_probs[:, 1].eq(-5.0).all())
                 self.assertTrue(
                     beam.topk_log_probs[:, 2:].equal(
                         torch.tensor(self.BLOCKED_SCORE)
                         .repeat(batch_sz, beam_sz - 2)))
Example #10
 def translate_batch(self, batch, src_vocabs, attn_debug):
     """Translate a batch of sentences."""
     with torch.no_grad():
         if self.beam_size == 1:
             decode_strategy = GreedySearch(
                 pad=self._tgt_pad_idx,
                 bos=self._tgt_bos_idx,
                 eos=self._tgt_eos_idx,
                 batch_size=batch.batch_size,
                 min_length=self.min_length, max_length=self.max_length,
                 block_ngram_repeat=self.block_ngram_repeat,
                 exclusion_tokens=self._exclusion_idxs,
                 return_attention=attn_debug or self.replace_unk,
                 sampling_temp=self.random_sampling_temp,
                 keep_topk=self.sample_from_topk)
         else:
             # TODO: support these blacklisted features
             assert not self.dump_beam
             decode_strategy = BeamSearch(
                 self.beam_size,
                 batch_size=batch.batch_size,
                 pad=self._tgt_pad_idx,
                 bos=self._tgt_bos_idx,
                 eos=self._tgt_eos_idx,
                 n_best=self.n_best,
                 global_scorer=self.global_scorer,
                 min_length=self.min_length, max_length=self.max_length,
                 return_attention=attn_debug or self.replace_unk,
                 block_ngram_repeat=self.block_ngram_repeat,
                 exclusion_tokens=self._exclusion_idxs,
                 stepwise_penalty=self.stepwise_penalty,
                 ratio=self.ratio)
         return self._translate_batch_with_strategy(batch, src_vocabs,
                                                    decode_strategy)
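Example #10 hands the strategy to _translate_batch_with_strategy, which is not shown in these excerpts. A minimal sketch of the loop such a helper drives, pieced together from the loops in Examples #17 and #20 (drive_strategy and decode_fn are illustrative names, not OpenNMT-py API):

def drive_strategy(decode_strategy, memory_bank, src_lengths, decode_fn,
                   max_length):
    # decode_fn(decoder_input, step) stands in for the model's
    # decode-and-generate step; it returns (log_probs, attn).
    decode_strategy.initialize(memory_bank, src_lengths)
    for step in range(max_length):
        decoder_input = decode_strategy.current_predictions.view(1, -1, 1)
        log_probs, attn = decode_fn(decoder_input, step)
        decode_strategy.advance(log_probs, attn)
        if decode_strategy.is_finished.any():
            decode_strategy.update_finished()
            if decode_strategy.done:
                break
        # a real implementation reorders decoder state and memory bank here
        # with decode_strategy.current_origin (see Examples #17 and #20)
    return decode_strategy.predictions, decode_strategy.scores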
Example #11
 def test_advance_with_some_repeats_gets_blocked(self):
     # beam 0 and beam >=2 will repeat (beam >= 2 repeat dummy scores)
     beam_sz = 5
     n_words = 100
     repeat_idx = 47
     ngram_repeat = 3
     for batch_sz in [1, 3]:
         beam = BeamSearch(
             beam_sz, batch_sz, 0, 1, 2, 2,
             torch.device("cpu"), GlobalScorerStub(), 0, 30,
             False, ngram_repeat, set(),
             torch.randint(0, 30, (batch_sz,)))
         for i in range(ngram_repeat + 4):
             # non-interesting beams are going to get dummy values
             word_probs = torch.full(
                 (batch_sz * beam_sz, n_words), -float('inf'))
             if i == 0:
                 # on initial round, only predicted scores for beam 0
                 # matter. Make two predictions. Top one will be repeated
                 # in beam zero, second one will live on in beam 1.
                 word_probs[0::beam_sz, repeat_idx] = -0.1
                 word_probs[0::beam_sz, repeat_idx + i + 1] = -2.3
             else:
                 # predict the same thing in beam 0
                 word_probs[0::beam_sz, repeat_idx] = 0
                 # continue pushing around what beam 1 predicts
                 word_probs[1::beam_sz, repeat_idx + i + 1] = 0
             attns = torch.randn(beam_sz)
             beam.advance(word_probs, attns)
             if i <= ngram_repeat:
                 self.assertFalse(
                     beam.topk_log_probs[0::beam_sz].eq(
                         self.BLOCKED_SCORE).any())
                 self.assertFalse(
                     beam.topk_log_probs[1::beam_sz].eq(
                         self.BLOCKED_SCORE).any())
             else:
                 # now beam 0 dies (along with the others), beam 1 -> beam 0
                 self.assertFalse(
                     beam.topk_log_probs[:, 0].eq(
                         self.BLOCKED_SCORE).any())
                 self.assertTrue(
                     beam.topk_log_probs[:, 1:].equal(
                         torch.tensor(self.BLOCKED_SCORE)
                         .repeat(batch_sz, beam_sz-1)))
Example #12
 def test_advance_with_some_repeats_gets_blocked(self):
     # beam 0 and beam >=2 will repeat (beam >= 2 repeat dummy scores)
     beam_sz = 5
     n_words = 100
     repeat_idx = 47
     ngram_repeat = 3
     for batch_sz in [1, 3]:
         beam = BeamSearch(
             beam_sz, batch_sz, 0, 1, 2, 2,
             torch.device("cpu"), GlobalScorerStub(), 0, 30,
             False, ngram_repeat, set(),
             torch.randint(0, 30, (batch_sz,)), False, 0.)
         for i in range(ngram_repeat + 4):
             # non-interesting beams are going to get dummy values
             word_probs = torch.full(
                 (batch_sz * beam_sz, n_words), -float('inf'))
             if i == 0:
                 # on initial round, only predicted scores for beam 0
                 # matter. Make two predictions. Top one will be repeated
                 # in beam zero, second one will live on in beam 1.
                 word_probs[0::beam_sz, repeat_idx] = -0.1
                 word_probs[0::beam_sz, repeat_idx + i + 1] = -2.3
             else:
                 # predict the same thing in beam 0
                 word_probs[0::beam_sz, repeat_idx] = 0
                 # continue pushing around what beam 1 predicts
                 word_probs[1::beam_sz, repeat_idx + i + 1] = 0
             attns = torch.randn(1, batch_sz * beam_sz, 53)
             beam.advance(word_probs, attns)
             if i <= ngram_repeat:
                 self.assertFalse(
                     beam.topk_log_probs[0::beam_sz].eq(
                         self.BLOCKED_SCORE).any())
                 self.assertFalse(
                     beam.topk_log_probs[1::beam_sz].eq(
                         self.BLOCKED_SCORE).any())
             else:
                 # now beam 0 dies (along with the others), beam 1 -> beam 0
                 self.assertFalse(
                     beam.topk_log_probs[:, 0].eq(
                         self.BLOCKED_SCORE).any())
                 self.assertTrue(
                     beam.topk_log_probs[:, 1:].equal(
                         torch.tensor(self.BLOCKED_SCORE)
                         .repeat(batch_sz, beam_sz-1)))
Example #13
    def test_beam_advance_against_known_reference(self):
        beam = BeamSearch(self.BEAM_SZ, self.BATCH_SZ, 0, 1, 2, self.N_BEST,
                          torch.device("cpu"), GlobalScorerStub(), 0, 30,
                          False, 0, set(),
                          torch.randint(0, 30, (self.BATCH_SZ, )))

        expected_beam_scores = self.init_step(beam)
        expected_beam_scores = self.first_step(beam, expected_beam_scores)
        expected_beam_scores = self.second_step(beam, expected_beam_scores)
        self.third_step(beam, expected_beam_scores)
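GlobalScorerStub is defined in the test module rather than in these excerpts. A plausible minimal version, assuming BeamSearch only needs no-op penalty hooks and a pass-through score (the real stub in OpenNMT-py's test suite may differ):

import torch

class GlobalScorerStub(object):
    alpha = 0
    beta = 0

    def __init__(self):
        # penalties are no-ops so beam scores pass through unchanged
        self.length_penalty = lambda x, alpha: 1.
        self.cov_penalty = lambda cov, beta: torch.zeros(
            (1, cov.shape[-2]), device=cov.device, dtype=torch.float)
        self.has_cov_pen = False
        self.has_len_pen = False

    def update_global_state(self, beam):
        pass

    def score(self, beam, scores):
        return scores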
Example #14
    def test_beam_is_done_when_n_best_beams_eos_using_min_length(self):
        # this is also a test that when block_ngram_repeat=0,
        # repeating is acceptable
        beam_sz = 5
        batch_sz = 3
        n_words = 100
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(torch.tensor(
            [6., 5., 4., 3., 2., 1.]), dim=0)
        min_length = 5
        eos_idx = 2
        beam = BeamSearch(
            beam_sz, batch_sz, 0, 1, 2, 2,
            torch.device("cpu"), GlobalScorerStub(),
            min_length, 30, False, 0, set(),
            torch.randint(0, 30, (batch_sz,)), False, 0.)
        for i in range(min_length + 4):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full(
                (batch_sz * beam_sz, n_words), -float('inf'))
            if i == 0:
                # "best" prediction is eos - that should be blocked
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                # include at least beam_sz predictions OTHER than EOS
                # that are greater than -1e20
                for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                    word_probs[0::beam_sz, j] = score
            elif i <= min_length:
                # predict eos in beam 1
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx::beam_sz, j] = score
            else:
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx::beam_sz, j] = score

            attns = torch.randn(1, batch_sz * beam_sz, 53)
            beam.advance(word_probs, attns)
            if i < min_length:
                self.assertFalse(beam.done)
            elif i == min_length:
                # beam 1 dies on min_length
                self.assertTrue(beam.is_finished[:, 1].all())
                beam.update_finished()
                self.assertFalse(beam.done)
            else:  # i > min_length
                # beam 0 dies on the step after beam 1 dies
                self.assertTrue(beam.is_finished[:, 0].all())
                beam.update_finished()
                self.assertTrue(beam.done)
Example #15
    def translate_batch(self, batch, src_vocabs, attn_debug):
        """Translate a batch of sentences."""
        with torch.no_grad():
            tic = time.perf_counter()
            if self.beam_size == 1:
                decode_strategy = GreedySearch(
                    pad=self._tgt_pad_idx,
                    bos=self._tgt_bos_idx,
                    eos=self._tgt_eos_idx,
                    batch_size=batch.batch_size,
                    min_length=self.min_length,
                    max_length=self.max_length,
                    block_ngram_repeat=self.block_ngram_repeat,
                    exclusion_tokens=self._exclusion_idxs,
                    return_attention=attn_debug or self.replace_unk,
                    sampling_temp=self.random_sampling_temp,
                    keep_topk=self.sample_from_topk)
            else:
                # TODO: support these blacklisted features
                assert not self.dump_beam

                decode_strategy = BeamSearch(
                    self.beam_size,
                    batch_size=batch.batch_size,
                    pad=self._tgt_pad_idx,
                    bos=self._tgt_bos_idx,
                    eos=self._tgt_eos_idx,
                    n_best=self.n_best,
                    global_scorer=self.global_scorer,
                    min_length=self.min_length,
                    max_length=self.max_length,
                    return_attention=attn_debug or self.replace_unk,
                    block_ngram_repeat=self.block_ngram_repeat,
                    exclusion_tokens=self._exclusion_idxs,
                    stepwise_penalty=self.stepwise_penalty,
                    ratio=self.ratio)
            toc = time.perf_counter()
            beam_search_time = toc - tic

            tic = time.perf_counter()
            ret = self._translate_batch_with_strategy(batch, src_vocabs,
                                                      decode_strategy)
            toc = time.perf_counter()
            translate_batch_with_strategy_time = toc - tic

            # `show_profile_detail` is assumed to be a module-level flag.
            if show_profile_detail:
                print(f"BeamSearch Time {beam_search_time:0.4f} seconds, "
                      f"translate_batch_with_strategy Time "
                      f"{translate_batch_with_strategy_time:0.4f} seconds")
            return ret
Example #16
    def test_advance_with_some_repeats_gets_blocked(self):
        # beam 0 and beam >=2 will repeat (beam >= 2 repeat dummy scores)
        beam_sz = 5
        n_words = 100
        repeat_idx = 47
        ngram_repeat = 3
        no_repeat_score = -2.3
        repeat_score = -0.1
        device_init = torch.zeros(1, 1)
        for batch_sz in [1, 3]:
            beam = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 2,
                              GlobalScorerStub(), 0, 30, False, ngram_repeat,
                              set(), False, 0.)
            beam.initialize(device_init, torch.randint(0, 30, (batch_sz, )))
            for i in range(ngram_repeat + 4):
                # non-interesting beams are going to get dummy values
                word_probs = torch.full((batch_sz * beam_sz, n_words),
                                        -float('inf'))
                if i == 0:
                    # on initial round, only predicted scores for beam 0
                    # matter. Make two predictions. Top one will be repeated
                    # in beam zero, second one will live on in beam 1.
                    word_probs[0::beam_sz, repeat_idx] = repeat_score
                    word_probs[0::beam_sz,
                               repeat_idx + i + 1] = no_repeat_score
                else:
                    # predict the same thing in beam 0
                    word_probs[0::beam_sz, repeat_idx] = 0
                    # continue pushing around what beam 1 predicts
                    word_probs[1::beam_sz, repeat_idx + i + 1] = 0
                attns = torch.randn(1, batch_sz * beam_sz, 53)
                beam.advance(word_probs, attns)
                if i < ngram_repeat:
                    self.assertFalse(beam.topk_log_probs[0::beam_sz].eq(
                        self.BLOCKED_SCORE).any())
                    self.assertFalse(beam.topk_log_probs[1::beam_sz].eq(
                        self.BLOCKED_SCORE).any())
                elif i == ngram_repeat:
                    # now beam 0 dies (along with the others), beam 1 -> beam 0
                    self.assertFalse(beam.topk_log_probs[:, 0].eq(
                        self.BLOCKED_SCORE).any())

                    expected = torch.full([batch_sz, beam_sz], float("-inf"))
                    expected[:, 0] = no_repeat_score
                    expected[:, 1] = self.BLOCKED_SCORE
                    self.assertTrue(beam.topk_log_probs[:, :].equal(expected))
                else:
                    # now beam 0 dies (along with the others), beam 1 -> beam 0
                    self.assertFalse(beam.topk_log_probs[:, 0].eq(
                        self.BLOCKED_SCORE).any())

                    expected = torch.full([batch_sz, beam_sz], float("-inf"))
                    expected[:, 0] = no_repeat_score
                    expected[:, 1:3] = self.BLOCKED_SCORE
                    expected[:, 3:] = float("-inf")
                    self.assertTrue(beam.topk_log_probs.equal(expected))
Example #17
    def forward_dev_beam_search(self, encoder_output: torch.Tensor, pad_mask):
        batch_size = encoder_output.size(1)

        self.state["cache"] = None
        memory_lengths = pad_mask.ne(pad_token_index).sum(dim=0)

        self.map_state(lambda state, dim: tile(state, self.beam_size, dim=dim))
        encoder_output = tile(encoder_output, self.beam_size, dim=1)
        pad_mask = tile(pad_mask, self.beam_size, dim=1)
        memory_lengths = tile(memory_lengths, self.beam_size, dim=0)

        # TODO:
        #  - fix attn (?)
        #  - use coverage_penalty="summary" or "wu" and beta=0.2 (or not)
        #  - use length_penalty="wu" and alpha=0.2 (or not)
        beam = BeamSearch(beam_size=self.beam_size, n_best=1, batch_size=batch_size, mb_device=default_device,
                          global_scorer=GNMTGlobalScorer(alpha=0, beta=0, coverage_penalty="none", length_penalty="avg"),
                          pad=pad_token_index, eos=eos_token_index, bos=bos_token_index, min_length=1, max_length=100,
                          return_attention=False, stepwise_penalty=False, block_ngram_repeat=0, exclusion_tokens=set(),
                          memory_lengths=memory_lengths, ratio=-1)

        for i in range(self.max_seq_out_len):
            inp = beam.current_predictions.view(1, -1)

            out, attn = self.forward_step(src=pad_mask, tgt=inp, memory_bank=encoder_output, step=i)  # 1 x batch*beam x hidden
            out = self.linear(out)  # 1 x batch*beam x vocab_out
            out = log_softmax(out, dim=2)  # 1 x batch*beam x vocab_out

            out = out.squeeze(0)  # batch*beam x vocab_out
            # attn = attn.squeeze(0)  # batch*beam x vocab_out
            # out = out.view(batch_size, self.beam_size, -1)  # batch x beam x vocab_out
            # attn = attn.view(batch_size, self.beam_size, -1)
            # TODO: fix attn (?)

            beam.advance(out, attn)
            any_beam_is_finished = beam.is_finished.any()
            if any_beam_is_finished:
                beam.update_finished()
                if beam.done:
                    break

            select_indices = beam.current_origin

            if any_beam_is_finished:
                # Reorder states.
                encoder_output = encoder_output.index_select(1, select_indices)
                pad_mask = pad_mask.index_select(1, select_indices)
                memory_lengths = memory_lengths.index_select(0, select_indices)

            self.map_state(lambda state, dim: state.index_select(dim, select_indices))

        outputs = beam.predictions
        outputs = [x[0] for x in outputs]
        outputs = pad_sequence(outputs, batch_first=True)
        return [outputs]
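This example and Example #20 both rely on tile (onmt.utils.misc.tile) to repeat encoder outputs beam_size times. A rough functional equivalent, assuming the grouped-repeat semantics the loops above depend on:

import torch

def tile(x, count, dim=0):
    """Repeat each slice of x count times along dim: [a, b] -> [a, a, b, b]."""
    return torch.repeat_interleave(x, count, dim=dim)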
Example #18
  def translate_batch(self, batch, src_vocabs, attn_debug, src=None,
                      enc_states=None, memory_bank=None, src_lengths=None,
                      src_embed=None, tgt2=False, hidden_state=None):
     """Translate a batch of sentences."""
     if self.beam_size == 1:
         decode_strategy = GreedySearch(
             pad=self._tgt_pad_idx,
             bos=self._tgt_bos_idx,
             eos=self._tgt_eos_idx,
             batch_size=batch.batch_size,
             min_length=self.min_length,
             max_length=self.max_length,
             block_ngram_repeat=self.block_ngram_repeat,
             exclusion_tokens=self._exclusion_idxs,
             return_attention=attn_debug or self.replace_unk,
             sampling_temp=self.random_sampling_temp,
             keep_topk=self.sample_from_topk)
     else:
         # TODO: support these blacklisted features
         assert not self.dump_beam
         decode_strategy = BeamSearch(
             self.beam_size,
             batch_size=batch.batch_size,
             pad=self._tgt_pad_idx,
             bos=self._tgt_bos_idx,
             eos=self._tgt_eos_idx,
             n_best=self.n_best,
             global_scorer=self.global_scorer,
             min_length=self.min_length,
             max_length=self.max_length,
             return_attention=attn_debug or self.replace_unk,
             block_ngram_repeat=self.block_ngram_repeat,
             exclusion_tokens=self._exclusion_idxs,
             stepwise_penalty=self.stepwise_penalty,
             ratio=self.ratio)
     return self._return_gold(batch,
                              src_vocabs,
                              decode_strategy,
                              src,
                              enc_states,
                              memory_bank,
                              src_lengths,
                              src_embed,
                              tgt2,
                              hidden_state=hidden_state)
Example #19
    def test_advance_with_all_repeats_gets_blocked(self):
        # all beams repeat (beam >= 1 repeat dummy scores)
        beam_sz = 5
        n_words = 100
        repeat_idx = 47
        ngram_repeat = 3
        device_init = torch.zeros(1, 1)
        for batch_sz in [1, 3]:
            beam = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 2,
                              GlobalScorerStub(), 0, 30, False, ngram_repeat,
                              set(), False, 0.)
            beam.initialize(device_init, torch.randint(0, 30, (batch_sz, )))
            for i in range(ngram_repeat + 4):
                # predict repeat_idx over and over again
                word_probs = torch.full((batch_sz * beam_sz, n_words),
                                        -float('inf'))
                word_probs[0::beam_sz, repeat_idx] = 0

                attns = torch.randn(1, batch_sz * beam_sz, 53)
                beam.advance(word_probs, attns)

                if i < ngram_repeat:
                    # before repeat, scores are either 0 or -inf
                    expected_scores = torch.tensor(
                        [0] + [-float('inf')] * (beam_sz - 1))\
                        .repeat(batch_sz, 1)
                    self.assertTrue(beam.topk_log_probs.equal(expected_scores))
                elif i % ngram_repeat == 0:
                    # on repeat, the `repeat_idx` score is BLOCKED_SCORE
                    # (but it's still the best score, so we get
                    # [BLOCKED_SCORE, -inf, -inf, -inf, -inf])
                    expected_scores = torch.tensor(
                        [0] + [-float('inf')] * (beam_sz - 1))\
                        .repeat(batch_sz, 1)
                    expected_scores[:, 0] = self.BLOCKED_SCORE
                    self.assertTrue(beam.topk_log_probs.equal(expected_scores))
                else:
                    # repetition keeps maximizing the score: index 0 has been
                    # blocked, so every beam ends up at BLOCKED_SCORE
                    expected_scores = torch.tensor(self.BLOCKED_SCORE).repeat(
                        batch_sz, beam_sz)
                    self.assertTrue(
                        beam.topk_log_probs.equal(expected_scores))
Example #20
    def _translate_batch(
            self,
            batch,
            src_vocabs,
            max_length,
            min_length=0,
            ratio=0.,
            n_best=1,
            return_attention=False):
        # TODO: support these blacklisted features.
        assert not self.dump_beam

        # (0) Prep the components of the search.
        use_src_map = self.copy_attn
        beam_size = self.beam_size
        batch_size = batch.batch_size

        # (1) Run the encoder on the src.
        src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
        self.model.decoder.init_state(src, memory_bank, enc_states)

        results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "batch": batch,
            "gold_score": self._gold_score(
                batch, memory_bank, src_lengths, src_vocabs, use_src_map,
                enc_states, batch_size, src)}

        # (2) Repeat src objects `beam_size` times.
        # We use batch_size x beam_size
        src_map = (tile(batch.src_map, beam_size, dim=1)
                   if use_src_map else None)
        self.model.decoder.map_state(
            lambda state, dim: tile(state, beam_size, dim=dim))

        if isinstance(memory_bank, tuple):
            memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
            mb_device = memory_bank[0].device
        else:
            memory_bank = tile(memory_bank, beam_size, dim=1)
            mb_device = memory_bank.device
        memory_lengths = tile(src_lengths, beam_size)

        # (0) pt 2, prep the beam object
        beam = BeamSearch(
            beam_size,
            n_best=n_best,
            batch_size=batch_size,
            global_scorer=self.global_scorer,
            pad=self._tgt_pad_idx,
            eos=self._tgt_eos_idx,
            bos=self._tgt_bos_idx,
            min_length=min_length,
            ratio=ratio,
            max_length=max_length,
            mb_device=mb_device,
            return_attention=return_attention,
            stepwise_penalty=self.stepwise_penalty,
            block_ngram_repeat=self.block_ngram_repeat,
            exclusion_tokens=self._exclusion_idxs,
            memory_lengths=memory_lengths)

        for step in range(max_length):
            decoder_input = beam.current_predictions.view(1, -1, 1)

            log_probs, attn = self._decode_and_generate(
                decoder_input,
                memory_bank,
                batch,
                src_vocabs,
                memory_lengths=memory_lengths,
                src_map=src_map,
                step=step,
                batch_offset=beam._batch_offset)

            beam.advance(log_probs, attn)
            any_beam_is_finished = beam.is_finished.any()
            if any_beam_is_finished:
                beam.update_finished()
                if beam.done:
                    break

            select_indices = beam.current_origin

            if any_beam_is_finished:
                # Reorder states.
                if isinstance(memory_bank, tuple):
                    memory_bank = tuple(x.index_select(1, select_indices)
                                        for x in memory_bank)
                else:
                    memory_bank = memory_bank.index_select(1, select_indices)

                memory_lengths = memory_lengths.index_select(0, select_indices)

                if src_map is not None:
                    src_map = src_map.index_select(1, select_indices)

            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        results["scores"] = beam.scores
        results["predictions"] = beam.predictions
        results["attention"] = beam.attention
        return results
Example #21
    def _translate_batch(
            self,
            batch,
            src_vocabs,
            max_length,
            min_length=0,
            ratio=0.,
            n_best=1,
            return_attention=False):
        # TODO: support these blacklisted features.
        assert not self.dump_beam

        # (0) Prep the components of the search.
        use_src_map = self.copy_attn
        beam_size = self.beam_size
        batch_size = batch.batch_size

        # (1) Run the encoder on the src.
        src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
        self.model.decoder.init_state(src, memory_bank, enc_states)

        results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "batch": batch,
            "gold_score": self._gold_score(
                batch, memory_bank, src_lengths, src_vocabs, use_src_map,
                enc_states, batch_size, src)}

        # (2) Repeat src objects `beam_size` times.
        # We use batch_size x beam_size
        src_map = (tile(batch.src_map, beam_size, dim=1)
                   if use_src_map else None)
        self.model.decoder.map_state(
            lambda state, dim: tile(state, beam_size, dim=dim))

        if isinstance(memory_bank, tuple):
            memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
            mb_device = memory_bank[0].device
        else:
            memory_bank = tile(memory_bank, beam_size, dim=1)
            mb_device = memory_bank.device
        memory_lengths = tile(src_lengths, beam_size)

        # (0) pt 2, prep the beam object
        beam = BeamSearch(
            beam_size,
            n_best=n_best,
            batch_size=batch_size,
            global_scorer=self.global_scorer,
            pad=self._tgt_pad_idx,
            eos=self._tgt_eos_idx,
            bos=self._tgt_bos_idx,
            min_length=min_length,
            ratio=ratio,
            max_length=max_length,
            mb_device=mb_device,
            return_attention=return_attention,
            stepwise_penalty=self.stepwise_penalty,
            block_ngram_repeat=self.block_ngram_repeat,
            exclusion_tokens=self._exclusion_idxs,
            memory_lengths=memory_lengths)

        states = None

        see = 0
        if self.constraint:
            itos = self.fields["tgt"].base_field.vocab.itos
            stoi = self.fields["tgt"].base_field.vocab.stoi
            states = BB_sequence_state(
                itos,
                stoi,
                mb_device,
                batch_size,
                beam_size,
                eos=self._tgt_eos_idx)

        for step in range(max_length):
            decoder_input = beam.current_predictions.view(1, -1, 1)
            log_probs, attn = self._decode_and_generate(
                decoder_input,
                states,
                memory_bank,
                batch,
                src_vocabs,
                memory_lengths=memory_lengths,
                src_map=src_map,
                step=step,
                batch_offset=beam._batch_offset)

            beam.advance(log_probs, attn)

            latest_action = beam.current_predictions.data.tolist()
            latest_score = beam.current_scores.view(-1).data.tolist()
            select_indices = beam.current_origin

            if states is not None:
                states.update_beam(latest_action,
                                   select_indices.data.tolist(),
                                   latest_score)
            any_beam_is_finished = beam.is_finished.any()
            if any_beam_is_finished:
                finished_batch = beam.update_finished()
                cnt = 0
                for bidx in finished_batch:
                    if bidx < see:
                        cnt += 1
                see -= cnt
                if beam.done:
                    break

            select_indices = beam.current_origin

            if any_beam_is_finished:
                # Reorder states.
                if states is not None:
                    states.index_select(select_indices)
                if isinstance(memory_bank, tuple):
                    memory_bank = tuple(x.index_select(1, select_indices)
                                        for x in memory_bank)
                else:
                    memory_bank = memory_bank.index_select(1, select_indices)

                memory_lengths = memory_lengths.index_select(0, select_indices)

                if src_map is not None:
                    src_map = src_map.index_select(1, select_indices)

            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        results["scores"] = beam.scores
        results["predictions"] = beam.predictions
        results["attention"] = beam.attention
        return results
Example #22
    def _translate_batch(
            self,
            batch,
            src_vocabs,
            max_length,
            min_length=0,
            ratio=0.,
            n_best=1,
            return_attention=False, tags=[]):
        # TODO: support these blacklisted features.
        assert not self.dump_beam

        # (0) Prep the components of the search.
        use_src_map = self.copy_attn
        beam_size = self.beam_size
        batch_size = batch.batch_size
        tags = self.ctags

        if tags[-1] == "EN" or tags[-1] == "DE":
            lang = tags[-1]
            tags = tags[:-1]
        else:
            lang = None
            assert False

        if lang is not None:
            enc1 = self.model.decoder
            enc2 = self.model.decoder2
            if lang == "EN":
                self.model.decoder = enc2


        # (1) Run the encoder on the src.
        allstuff = self._run_encoder(batch, tags=tags)
        if len(allstuff) == 3:
            # note: src_lengths stays undefined on this branch even though
            # it is used below; the 4-tuple branch is the one exercised here
            src, enc_states, memory_bank = allstuff
        elif len(allstuff) == 4:
            src, enc_states, memory_bank, src_lengths = allstuff

        thing = enc_states.data.cpu().numpy()
        lengths = src_lengths.data.cpu().numpy()
        maxvecs = []


        self.model.decoder.init_state(src, memory_bank, enc_states)

        results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "batch": batch,
            "gold_score": self._gold_score(
                batch, memory_bank, src_lengths, src_vocabs, use_src_map,
                enc_states, batch_size, src)}

        # (2) Repeat src objects `beam_size` times.
        # We use batch_size x beam_size
        src_map = (tile(batch.src_map, beam_size, dim=1)
                   if use_src_map else None)

        self.model.decoder.map_state(
                lambda state, dim: tile(state, beam_size, dim=dim))

        if isinstance(memory_bank, tuple):
            memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
            mb_device = memory_bank[0].device
        else:
            memory_bank = tile(memory_bank, beam_size, dim=1)
            mb_device = memory_bank.device
        memory_lengths = tile(src_lengths, beam_size)
        # (0) pt 2, prep the beam object

        if hasattr(batch,"tgt"):
            print ("tgt")
            tgt  = (batch.tgt.data.cpu().numpy()).T
        else:
            print ("no tgt")
            tgt = [[] for _ in range(100000)] #
        beam = BeamSearch(
            beam_size,
            n_best=n_best,
            batch_size=batch_size,
            global_scorer=self.global_scorer,
            pad=self._tgt_pad_idx,
            eos=self._tgt_eos_idx,
            bos=self._tgt_bos_idx,
            min_length=min_length,
            ratio=ratio,
            max_length=max_length,
            mb_device=mb_device,
            return_attention=return_attention,
            stepwise_penalty=self.stepwise_penalty,
            block_ngram_repeat=self.block_ngram_repeat,
            exclusion_tokens=self._exclusion_idxs,
            memory_lengths=memory_lengths,
            i2w=self._tgt_vocab.itos,
            batch=batch)
        if hasattr(batch,"tgt"):

        
            tgt  = (batch.tgt.data.cpu().numpy()).T.squeeze()

        else:

            tgt = [[] for _ in range(100000)] #
        for step in range(max_length):
            decoder_input = beam.current_predictions.view(1, -1, 1)

            log_probs, attn = self._decode_and_generate(
                decoder_input,
                memory_bank,
                batch,
                src_vocabs,
                memory_lengths=memory_lengths,
                src_map=src_map,
                step=step,
                batch_offset=beam._batch_offset)
            
            beam.advance(log_probs, attn)
            any_beam_is_finished = beam.is_finished.any()
            if any_beam_is_finished:
                beam.update_finished()
                if beam.done:
                    break

            select_indices = beam.current_origin
            if any_beam_is_finished:
                # Reorder states.
                if isinstance(memory_bank, tuple):
                    memory_bank = tuple(x.index_select(1, select_indices)
                                        for x in memory_bank)
                else:
                    memory_bank = memory_bank.index_select(1, select_indices)

                memory_lengths = memory_lengths.index_select(0, select_indices)

                if src_map is not None:
                    src_map = src_map.index_select(1, select_indices)
                    print ("NOW")
                    print (src_map)
            if lang is not None and lang=="EN" and False:
                self.model.decoder2.map_state(
                    lambda state, dim: state.index_select(dim, select_indices))
            else:
                self.model.decoder.map_state(
                    lambda state, dim: state.index_select(dim, select_indices))
        results["scores"] = beam.scores
        results["predictions"] = beam.predictions
        results["attention"] = beam.attention
        results["maxvecs"] = []
        if lang is not None:
            # restore the decoder that was swapped out above
            self.model.decoder = enc1
        return results
Example #23
    def test_beam_returns_attn_with_correct_length(self):
        beam_sz = 5
        batch_sz = 3
        n_words = 100
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(torch.tensor(
            [6., 5., 4., 3., 2., 1.]), dim=0)
        min_length = 5
        eos_idx = 2
        inp_lens = torch.randint(1, 30, (batch_sz, ))
        beam = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 2, GlobalScorerStub(),
                          min_length, 30, True, 0, set(), False, 0.)
        device_init = torch.zeros(1, 1)
        # inp_lens is tiled in initialize; reassign so the attn checks match
        _, _, inp_lens, _ = beam.initialize(device_init, inp_lens)
        for i in range(min_length + 2):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full((batch_sz * beam_sz, n_words),
                                    -float('inf'))
            if i == 0:
                # "best" prediction is eos - that should be blocked
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                # include at least beam_sz predictions OTHER than EOS
                # that are greater than -1e20
                for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                    word_probs[0::beam_sz, j] = score
            elif i <= min_length:
                # predict eos in beam 1
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz - 1, k)
                    word_probs[beam_idx::beam_sz, j] = score
            else:
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz - 1, k)
                    word_probs[beam_idx::beam_sz, j] = score

            attns = torch.randn(1, batch_sz * beam_sz, 53)
            beam.advance(word_probs, attns)
            if i < min_length:
                self.assertFalse(beam.done)
                # no top beams are finished yet
                for b in range(batch_sz):
                    self.assertEqual(beam.attention[b], [])
            elif i == min_length:
                # beam 1 dies on min_length
                self.assertTrue(beam.is_finished[:, 1].all())
                beam.update_finished()
                self.assertFalse(beam.done)
                # no top beams are finished yet
                for b in range(batch_sz):
                    self.assertEqual(beam.attention[b], [])
            else:  # i > min_length
                # beam 0 dies on the step after beam 1 dies
                self.assertTrue(beam.is_finished[:, 0].all())
                beam.update_finished()
                self.assertTrue(beam.done)
                # top beam is finished now so there are attentions
                for b in range(batch_sz):
                    # two beams are finished in each batch
                    self.assertEqual(len(beam.attention[b]), 2)
                    for k in range(2):
                        # second dim is cut down to the non-padded src length
                        self.assertEqual(beam.attention[b][k].shape[-1],
                                         inp_lens[b])
                    # first dim is equal to the time of death
                    # (beam 0 died at current step - adjust for SOS)
                    self.assertEqual(beam.attention[b][0].shape[0], i + 1)
                    # (beam 1 died at last step - adjust for SOS)
                    self.assertEqual(beam.attention[b][1].shape[0], i)
                # behavior gets weird when beam is already done so just stop
                break
Example #24
0
def run(
        n_iterations=7,
        beam_sz=5,
        batch_sz=1,
        n_words=100,
        repeat_idx=47,
        ngram_repeat=-1,
        repeat_logprob=-0.2,
        no_repeat_logprob=-0.1,
        base_logprob=float("-inf"),
        verbose=True,
):
    """
        At each timestep `i`:
            - token `repeat_idx` get a logprob of `repeat_logprob`
            - token `repeat_idx+i` get `no_repeat_logprob`
            - other tokens get `base_logprob`

    """
    def log(*args, **kwargs):
        if verbose:
            print(*args, **kwargs)

    device_init = torch.zeros(1, 1)

    beam = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 4, GlobalScorerStub(), 0, 30,
                      False, ngram_repeat, set(), False, 0.)
    beam.initialize(device_init, torch.randint(0, 30, (batch_sz, )))
    for i in range(n_iterations):
        # non-interesting beams are going to get dummy values
        word_logprobs = torch.full((batch_sz * beam_sz, n_words), base_logprob)
        if i == 0:
            # on initial round, only predicted scores for beam 0
            # matter. Make two predictions. Top one will be repeated
            # in beam zero, second one will live on in beam 1.
            word_logprobs[0::beam_sz, repeat_idx] = repeat_logprob
            word_logprobs[0::beam_sz, repeat_idx + i + 1] = no_repeat_logprob
        else:
            # predict the same thing in beam 0
            word_logprobs[0::beam_sz, repeat_idx] = repeat_logprob
            # continue pushing around what beam 1 predicts
            word_logprobs[0::beam_sz, repeat_idx + i + 1] = no_repeat_logprob
        attns = torch.randn(1, batch_sz * beam_sz, 0)
        beam.advance(word_logprobs, attns)
        if ngram_repeat > 0:
            # NOTE: IGNORE IT FOR NOW
            # if i < ngram_repeat:
            #     assertFalse(
            #         beam.topk_log_probs[0::beam_sz].eq(
            #             BLOCKED_SCORE).any())
            #     assertFalse(
            #         beam.topk_log_probs[1::beam_sz].eq(
            #             BLOCKED_SCORE).any())
            # elif i == ngram_repeat:
            #     assertFalse(
            #         beam.topk_log_probs[:, 0].eq(
            #             BLOCKED_SCORE).any())

            #     expected = torch.full([batch_sz, beam_sz], base_logprob)
            #     expected[:, 0] = (i+1) * no_repeat_logprob
            #     expected[:, 1] = BLOCKED_SCORE
            # else:
            #     expected = torch.full([batch_sz, beam_sz], base_logprob)
            #     expected[:, 0] = i * no_repeat_logprob
            #     expected[:, 1] = BLOCKED_SCORE
            pass
        # log("Iteration (%d): expected %s" % (i, str(expected)))
        log("Iteration (%d): logprobs %s" % (i, str(beam.topk_log_probs)))
        log("Iteration (%d): seq %s" % (i, str(beam.alive_seq)))
        log("Iteration (%d): indices %s" % (i, str(beam.topk_ids)))
        if ngram_repeat > 0:
            log("Iteration (%d): blocked %s" % (i, str(beam.forbidden_tokens)))
    return beam.topk_log_probs, beam.alive_seq
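
A minimal usage sketch for the `run` helper above; this is an illustrative addition, not part of the original snippet, and assumes `BeamSearch` and `GlobalScorerStub` are importable as in the other examples in this listing:

if __name__ == "__main__":
    # With ngram blocking enabled, beam 0's repeated token should eventually
    # be assigned the blocked score; with ngram_repeat=-1 it never is.
    blocked_logprobs, blocked_seq = run(n_iterations=7, ngram_repeat=2,
                                        verbose=True)
    unblocked_logprobs, unblocked_seq = run(n_iterations=7, ngram_repeat=-1,
                                            verbose=False)
    print("with blocking:", blocked_logprobs)
    print("without blocking:", unblocked_logprobs)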
Example #25
0
    def _translate_batch(
            self,
            batch,
            src_vocabs,
            max_length,
            min_length=0,
            ratio=0.,
            n_best=1,
            return_attention=False,
            xlation_builder=None,
    ):
        # TODO: support these blacklisted features.
        assert not self.dump_beam

        # (0) Prep the components of the search.
        use_src_map = self.copy_attn
        beam_size = self.beam_size
        batch_size = batch.batch_size

        # (1) Run the encoder on the src.
        src_list, enc_states_list, memory_bank_list, src_lengths_list = self._run_encoder(batch)
        self.model.decoder.init_state(src_list, memory_bank_list, enc_states_list)

        results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "batch": batch,
            "gold_score": self._gold_score(
                batch, memory_bank_list, src_lengths_list, src_vocabs, use_src_map,
                enc_states_list, batch_size, src_list)}
        
        # (2) Repeat src objects `beam_size` times.
        # We use batch_size x beam_size
        src_map_list = list()
        for src_type in self.src_types:
            src_map_list.append((tile(getattr(batch, f"src_map.{src_type}"), beam_size, dim=1) if use_src_map else None))
        # end for

        self.model.decoder.map_state(
            lambda state, dim: tile(state, beam_size, dim=dim))

        # Tile each source's memory bank and lengths beam_size times.
        memory_lengths_list = list()
        memory_lengths = list()
        for src_i in range(len(memory_bank_list)):
            if isinstance(memory_bank_list[src_i], tuple):
                memory_bank_list[src_i] = tuple(tile(x, beam_size, dim=1) for x in memory_bank_list[src_i])
                mb_device = memory_bank_list[src_i][0].device
            else:
                memory_bank_list[src_i] = tile(memory_bank_list[src_i], beam_size, dim=1)
                mb_device = memory_bank_list[src_i].device
            # end if
            memory_lengths_list.append(tile(src_lengths_list[src_i], beam_size))
            memory_lengths.append(src_lengths_list[src_i])
        # end for
        # Sum the per-source lengths, presumably so the beam sees a single
        # total source length per example, then tile per beam.
        memory_lengths = tile(torch.stack(memory_lengths, dim=0).sum(dim=0), beam_size)

        indexes = tile(torch.tensor(list(range(batch_size)), device=self._dev), beam_size)

        # (0) pt 2, prep the beam object
        beam = BeamSearch(
            beam_size,
            n_best=n_best,
            batch_size=batch_size,
            global_scorer=self.global_scorer,
            pad=self._tgt_pad_idx,
            eos=self._tgt_eos_idx,
            bos=self._tgt_bos_idx,
            min_length=min_length,
            ratio=ratio,
            max_length=max_length,
            mb_device=mb_device,
            return_attention=return_attention,
            stepwise_penalty=self.stepwise_penalty,
            block_ngram_repeat=self.block_ngram_repeat,
            exclusion_tokens=self._exclusion_idxs,
            memory_lengths=memory_lengths)

        for step in range(max_length):
            decoder_input = beam.current_predictions.view(1, -1, 1)

            log_probs, attn = self._decode_and_generate(
                decoder_input,
                memory_bank_list,
                batch,
                src_vocabs,
                memory_lengths_list=memory_lengths_list,
                src_map_list=src_map_list,
                step=step,
                batch_offset=beam._batch_offset)

            if self.reranker is not None:
                log_probs = self.reranker.rerank_step_beam_batch(
                    batch,
                    beam,
                    self.beam_size,
                    indexes,
                    log_probs,
                    attn,
                    self.fields["tgt"].base_field.vocab,
                    xlation_builder,
                )
            # end if

            non_finished = None
            beam.advance(log_probs, attn)
            any_beam_is_finished = beam.is_finished.any()
            if any_beam_is_finished:
                non_finished = beam.update_finished()
                if beam.done:
                    break

            select_indices = beam.current_origin

            if any_beam_is_finished:
                # Reorder states.
                for src_i in range(len(memory_bank_list)):
                    if isinstance(memory_bank_list[src_i], tuple):
                        memory_bank_list[src_i] = tuple(x.index_select(1, select_indices)
                                            for x in memory_bank_list[src_i])
                    else:
                        memory_bank_list[src_i] = memory_bank_list[src_i].index_select(1, select_indices)
                    # end if

                    memory_lengths_list[src_i] = memory_lengths_list[src_i].index_select(0, select_indices)
                # end for

                if use_src_map and src_map_list[0] is not None:
                    for src_i in range(len(src_map_list)):
                        src_map_list[src_i] = src_map_list[src_i].index_select(1, select_indices)
                    # end for
                # end if

                indexes = indexes.index_select(0, select_indices)

            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        results["scores"] = beam.scores
        results["predictions"] = beam.predictions
        results["attention"] = beam.attention
        return results
Example #26
0
 model.decoder.map_state(lambda state, dim: tile(state, beam_size, dim=dim))
 if isinstance(memory_bank, tuple):
     memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
     mb_device = memory_bank[0].device
 else:
     memory_bank = tile(memory_bank, beam_size, dim=1)
     mb_device = memory_bank.device
 memory_lengths = tile(src_lengths, beam_size)
 beam = BeamSearch(beam_size,
                   n_best=n_best,
                   batch_size=batch_size,
                   global_scorer=global_scorer,
                   pad=tgt_pad_idx,
                   eos=tgt_eos_idx,
                   bos=tgt_bos_idx,
                   min_length=min_length,
                   ratio=ratio,
                   max_length=max_length,
                   mb_device=mb_device,
                   return_attention=return_attention,
                   stepwise_penalty=stepwise_penalty,
                   block_ngram_repeat=block_ngram_repeat,
                   exclusion_tokens=exclusion_idxs,
                   memory_lengths=memory_lengths)
 for step in range(max_length):
     decoder_input = beam.current_predictions.view(1, -1, 1)
     dec_out, dec_attn = model.decoder(decoder_input,
                                       memory_bank,
                                       memory_lengths=memory_lengths,
                                       step=step)
     log_probs = model.generator(dec_out.squeeze(0))
     attn = dec_attn['std']
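
The fragment above stops right after computing `log_probs` and `attn`. A plausible completion of the decoding loop, following the same advance/update/reorder pattern as the other examples in this listing (a sketch, not the original author's code):

     beam.advance(log_probs, attn)
     if beam.is_finished.any():
         beam.update_finished()
         if beam.done:
             break
     select_indices = beam.current_origin
     # Reorder states to follow the surviving hypotheses.
     if isinstance(memory_bank, tuple):
         memory_bank = tuple(x.index_select(1, select_indices)
                             for x in memory_bank)
     else:
         memory_bank = memory_bank.index_select(1, select_indices)
     memory_lengths = memory_lengths.index_select(0, select_indices)
     model.decoder.map_state(
         lambda state, dim: state.index_select(dim, select_indices))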
Example #27
0
    def __init__(self,
                 onmt_summarizer,
                 batch_size,
                 beam_size=None,
                 n_best=None,
                 diverse=None,
                 distractor=None,
                 scramble_idxs=None,
                 logger=None):
        super().__init__(logger)

        assert diverse in ['rank', None], 'Invalid diverse beam type!'

        s = onmt_summarizer
        T = onmt_summarizer.translator

        # default values are given by T
        beam_size = T.beam_size if beam_size is None else beam_size
        n_best = T.n_best if n_best is None else n_best

        if distractor is None:
            memory_lengths = s.memory_lengths
        else:
            tgt_idx = scramble2tgt(scramble_idxs, distractor.d_factor)
            memory_lengths = s.memory_lengths.view(-1, beam_size) \
                .index_select(0, tgt_idx).view(-1)

        if diverse is None:
            self.beam = BeamSearch(
                beam_size,
                n_best=n_best,
                batch_size=batch_size,  # actual batch size
                global_scorer=T.global_scorer,
                pad=T._tgt_pad_idx,
                eos=T._tgt_eos_idx,
                bos=T._tgt_bos_idx,
                min_length=T.min_length,
                ratio=T.ratio,
                max_length=T.max_length,
                mb_device=s.mb_device,
                return_attention=T.replace_unk,
                stepwise_penalty=T.stepwise_penalty,
                block_ngram_repeat=T.block_ngram_repeat,
                exclusion_tokens=T._exclusion_idxs,
                memory_lengths=memory_lengths)
        elif diverse == 'rank':
            self.beam = RankDiverseBeam(
                beam_size,
                n_best=n_best,
                batch_size=batch_size,  # actual batch size
                global_scorer=T.global_scorer,
                pad=T._tgt_pad_idx,
                eos=T._tgt_eos_idx,
                bos=T._tgt_bos_idx,
                min_length=T.min_length,
                ratio=T.ratio,
                max_length=T.max_length,
                mb_device=s.mb_device,
                return_attention=T.replace_unk,
                stepwise_penalty=T.stepwise_penalty,
                block_ngram_repeat=T.block_ngram_repeat,
                exclusion_tokens=T._exclusion_idxs,
                memory_lengths=memory_lengths)
Example #28
0
    def _translate_batch(
            self,
            batch,
            src_vocabs,
            max_length,
            min_length=0,
            ratio=0.,
            n_best=1,
            return_attention=False):
        # TODO: support these blacklisted features.
        assert not self.dump_beam

        # (0) Prep the components of the search.
        use_src_map = self.copy_attn

        beam_size = self.beam_size #default 5
        batch_size = batch.batch_size #default 30

        # (1) Run the encoder on the src.
        src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
        #src.size() = torch.Size([59, 30, 1]) [src_len,batch_size,1]
        #enc_states[0/1].size() = [2,30,500]
        #memory_bank.size() =[59,30,500]
        #src_lengths.size() = [30]
        self.model.decoder.init_state(src, memory_bank, enc_states)

        results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "batch": batch,
            "gold_score": self._gold_score(
                batch, memory_bank, src_lengths, src_vocabs, use_src_map,
                enc_states, batch_size, src)}

        # (2) Repeat src objects `beam_size` times.
        # We use batch_size x beam_size
        src_map = (tile(batch.src_map, beam_size, dim=1)
                   if use_src_map else None)
        self.model.decoder.map_state(
            lambda state, dim: tile(state, beam_size, dim=dim))

        if isinstance(memory_bank, tuple):
            memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)  # tile each tensor x beam_size times along dim=1 (the batch dimension)
            mb_device = memory_bank[0].device
        else:
            memory_bank = tile(memory_bank, beam_size, dim=1)
            mb_device = memory_bank.device
        memory_lengths = tile(src_lengths, beam_size)

        # (0) pt 2, prep the beam object
        beam = BeamSearch(
            beam_size,
            n_best=n_best,
            batch_size=batch_size,
            global_scorer=self.global_scorer,
            pad=self._tgt_pad_idx,
            eos=self._tgt_eos_idx,
            bos=self._tgt_bos_idx,
            min_length=min_length,
            ratio=ratio,
            max_length=max_length, #Maximum prediction length.
            mb_device=mb_device,
            return_attention=return_attention,
            stepwise_penalty=self.stepwise_penalty,
            block_ngram_repeat=self.block_ngram_repeat,
            exclusion_tokens=self._exclusion_idxs,
            memory_lengths=memory_lengths)

        for step in range(max_length): # run up to this many steps; each step adds beam_size * batch_size candidate branches
            decoder_input = beam.current_predictions.view(1, -1, 1) # decoder_input.size() = torch.Size([1, 150, 1]); 150 = 30 * 5 = batch_size * beam_size
            # @property
            # def current_predictions(self):
            #     return self.alive_seq[:, -1]
            log_probs, attn = self._decode_and_generate(
                decoder_input,
                memory_bank,#torch.Size([59, 150, 500])
                batch,
                src_vocabs,
                memory_lengths=memory_lengths,
                src_map=src_map,
                step=step,
                batch_offset=beam._batch_offset)
            # print("log_probs = ",log_probs) #[150, 50004] 这个50004应该是词表的大小,词表中的单词应该是5万,多出来的4个应该是<s> </s> <unk> <pad>
            # print("attn = ",attn) #torch.Size([1, 150, 59]) 这个59应该是src中的最长的句子的长度
            # print("decoder_input = ",decoder_input.size())
            beam.advance(log_probs, attn)#这个里面完成的工作应该是将150再变回30,
            any_beam_is_finished = beam.is_finished.any()
            if any_beam_is_finished:
                beam.update_finished()
                if beam.done:
                    break

            select_indices = beam.current_origin

            if any_beam_is_finished:
                # Reorder states.
                if isinstance(memory_bank, tuple):
                    memory_bank = tuple(x.index_select(1, select_indices)
                                        for x in memory_bank)
                else:
                    memory_bank = memory_bank.index_select(1, select_indices)

                memory_lengths = memory_lengths.index_select(0, select_indices)

                if src_map is not None:
                    src_map = src_map.index_select(1, select_indices)

            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        results["scores"] = beam.scores
        results["predictions"] = beam.predictions
        results["attention"] = beam.attention
        return results
Example #29
0
    def _translate_batch(
            self,
            src,
            src_lengths,
            batch_size,
            min_length=0,
            ratio=0.,
            n_best=1,
            return_attention=False):

        max_length = self.config.max_sequence_length + 1 # to account for EOS
        beam_size = 3
        
        # Encoder forward.
        enc_states, memory_bank, src_lengths = self.encoder(src, src_lengths)
        self.decoder.init_state(src, memory_bank, enc_states)

        results = { "predictions": None, "scores": None, "attention": None }

        # (2) Repeat src objects `beam_size` times.
        # We use batch_size x beam_size
        self.decoder.map_state(lambda state, dim: tile(state, beam_size, dim=dim))

        #if isinstance(memory_bank, tuple):
        #    memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
        #    mb_device = memory_bank[0].device
        #else:
        memory_bank = tile(memory_bank, beam_size, dim=1)
        # mb_device handles both tuple and tensor memory banks (the tuple
        # case is the commented-out branch above).
        mb_device = memory_bank[0].device if isinstance(memory_bank, tuple) else memory_bank.device
        memory_lengths = tile(src_lengths, beam_size)
        
        block_ngram_repeat = 0
        _exclusion_idxs = set()  # BeamSearch expects a set of token ids, not a dict

        # (0) pt 2, prep the beam object
        beam = BeamSearch(
            beam_size,
            n_best=n_best,
            batch_size=batch_size,
            global_scorer=self.scorer,
            pad=self.config.tgt_padding,
            eos=self.config.tgt_eos,
            bos=self.config.tgt_bos,
            min_length=min_length,
            ratio=ratio,
            max_length=max_length,
            mb_device=mb_device,
            return_attention=return_attention,
            stepwise_penalty=None,
            block_ngram_repeat=block_ngram_repeat,
            exclusion_tokens=_exclusion_idxs,
            memory_lengths=memory_lengths)

        for step in range(max_length):
            decoder_input = beam.current_predictions.view(1, -1, 1)

            log_probs, attn = self._decode_and_generate(decoder_input, memory_bank, memory_lengths, step, pretraining = True)

            beam.advance(log_probs, attn)
            any_beam_is_finished = beam.is_finished.any()
            if any_beam_is_finished:
                beam.update_finished()
                if beam.done:
                    break

            select_indices = beam.current_origin

            if any_beam_is_finished:
                # Reorder states.
                if isinstance(memory_bank, tuple):
                    memory_bank = tuple(x.index_select(1, select_indices)
                                        for x in memory_bank)
                else:
                    memory_bank = memory_bank.index_select(1, select_indices)

                memory_lengths = memory_lengths.index_select(0, select_indices)

            self.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        results["scores"] = beam.scores
        results["predictions"] = beam.predictions
        results["attention"] = beam.attention
        return results
Example #30
0
    def _translate_batch(self,
                         batch,
                         src_vocabs,
                         max_length,
                         min_length=0,
                         ratio=0.,
                         n_best=1,
                         return_attention=False):
        # TODO: support these blacklisted features.
        assert not self.dump_beam

        # (0) Prep the components of the search.
        use_src_map = self.copy_attn
        beam_size = self.beam_size
        batch_size = batch.batch_size

        #### TODO: Augment batch with distractors

        # (1) Run the encoder on the src.
        src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
        # src has shape [1311, 2, 1]
        # enc_states has shape [1311, 2, 512],
        # Memory_bank has shape [1311, 2, 512]
        self.model.decoder.init_state(src, memory_bank, enc_states)

        results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "batch": batch,
            "gold_score": self._gold_score(batch, memory_bank, src_lengths,
                                           src_vocabs, use_src_map, enc_states,
                                           batch_size, src)
        }

        # (2) Repeat src objects `beam_size` times.
        # We use batch_size x beam_size
        src_map = (tile(batch.src_map, beam_size, dim=1)
                   if use_src_map else None)
        self.model.decoder.map_state(
            lambda state, dim: tile(state, beam_size, dim=dim))

        if isinstance(memory_bank, tuple):
            memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
            mb_device = memory_bank[0].device
        else:
            memory_bank = tile(memory_bank, beam_size, dim=1)
            mb_device = memory_bank.device
        memory_lengths = tile(src_lengths, beam_size)
        print('memory_bank size after tile:',
              memory_bank.shape)  #[1311, 20, 512]

        # (0) pt 2, prep the beam object
        beam = BeamSearch(beam_size,
                          n_best=n_best,
                          batch_size=batch_size,
                          global_scorer=self.global_scorer,
                          pad=self._tgt_pad_idx,
                          eos=self._tgt_eos_idx,
                          bos=self._tgt_bos_idx,
                          min_length=min_length,
                          ratio=ratio,
                          max_length=max_length,
                          mb_device=mb_device,
                          return_attention=return_attention,
                          stepwise_penalty=self.stepwise_penalty,
                          block_ngram_repeat=self.block_ngram_repeat,
                          exclusion_tokens=self._exclusion_idxs,
                          memory_lengths=memory_lengths)

        all_log_probs = []
        all_attn = []

        for step in range(max_length):
            decoder_input = beam.current_predictions.view(1, -1, 1)
            # decoder_input has shape[1,20,1]
            # decoder_input gives top 10 predictions for each batch element
            verbose = (step == 10)
            log_probs, attn = self._decode_and_generate(
                decoder_input,
                memory_bank,
                batch,
                src_vocabs,
                memory_lengths=memory_lengths,
                src_map=src_map,
                step=step,
                batch_offset=beam._batch_offset,
                verbose=verbose)

            # log_probs and attn are the probs for next word given that the
            # current word is that in decoder_input
            all_log_probs.append(log_probs)
            all_attn.append(attn)

            beam.advance(log_probs, attn, verbose=verbose)

            any_beam_is_finished = beam.is_finished.any()
            if any_beam_is_finished:
                beam.update_finished()
                if beam.done:
                    break

            select_indices = beam.current_origin

            if any_beam_is_finished:
                # Reorder states.
                if isinstance(memory_bank, tuple):
                    memory_bank = tuple(
                        x.index_select(1, select_indices) for x in memory_bank)
                else:
                    memory_bank = memory_bank.index_select(1, select_indices)

                memory_lengths = memory_lengths.index_select(0, select_indices)

                if src_map is not None:
                    src_map = src_map.index_select(1, select_indices)

            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        print('batch_size:', batch_size)
        print('max_length:', max_length)
        print('all_log_probs has len', len(all_log_probs))
        print('all_log_probs[0].shape', all_log_probs[0].shape)
        print('comparing log_probs[0]', all_log_probs[2][:, 0])

        results["scores"] = beam.scores
        results["predictions"] = beam.predictions
        results["attention"] = beam.attention
        return results
Example #31
0
    def test_beam_returns_attn_with_correct_length(self):
        beam_sz = 5
        batch_sz = 3
        n_words = 100
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(torch.tensor(
            [6., 5., 4., 3., 2., 1.]), dim=0)
        min_length = 5
        eos_idx = 2
        inp_lens = torch.randint(1, 30, (batch_sz,))
        beam = BeamSearch(
            beam_sz, batch_sz, 0, 1, 2, 2,
            torch.device("cpu"), GlobalScorerStub(),
            min_length, 30, True, 0, set(),
            inp_lens, False, 0.)
        for i in range(min_length + 2):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full(
                (batch_sz * beam_sz, n_words), -float('inf'))
            if i == 0:
                # "best" prediction is eos - that should be blocked
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                # include at least beam_sz predictions OTHER than EOS
                # that are greater than -1e20
                for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                    word_probs[0::beam_sz, j] = score
            elif i <= min_length:
                # predict eos in beam 1
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx::beam_sz, j] = score
            else:
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx::beam_sz, j] = score

            attns = torch.randn(1, batch_sz * beam_sz, 53)
            beam.advance(word_probs, attns)
            if i < min_length:
                self.assertFalse(beam.done)
                # no top beams are finished yet
                for b in range(batch_sz):
                    self.assertEqual(beam.attention[b], [])
            elif i == min_length:
                # beam 1 dies on min_length
                self.assertTrue(beam.is_finished[:, 1].all())
                beam.update_finished()
                self.assertFalse(beam.done)
                # no top beams are finished yet
                for b in range(batch_sz):
                    self.assertEqual(beam.attention[b], [])
            else:  # i > min_length
                # beam 0 dies on the step after beam 1 dies
                self.assertTrue(beam.is_finished[:, 0].all())
                beam.update_finished()
                self.assertTrue(beam.done)
                # top beam is finished now so there are attentions
                for b in range(batch_sz):
                    # two beams are finished in each batch
                    self.assertEqual(len(beam.attention[b]), 2)
                    for k in range(2):
                        # second dim is cut down to the non-padded src length
                        self.assertEqual(beam.attention[b][k].shape[-1],
                                         inp_lens[b])
                    # first dim is equal to the time of death
                    # (beam 0 died at current step - adjust for SOS)
                    self.assertEqual(beam.attention[b][0].shape[0], i+1)
                    # (beam 1 died at last step - adjust for SOS)
                    self.assertEqual(beam.attention[b][1].shape[0], i)
                # behavior gets weird when beam is already done so just stop
                break
Example #32
0
def batch_beam_search_trs(model, inputs, batch_size, device="cpu", max_len=128, beam_size=20, n_best=1, alpha=1.):
    """ beam search with batch input for Transformer model

    Arguments:
        beam {onmt.BeamSearch} -- opennmt BeamSearch class
        model {torch.nn.Module} -- subclass of torch.nn.Module, required to implement .encode() and .decode() method
        inputs {list} -- list of torch.Tensor for input of encode()

    Keyword Arguments:
        device {str} -- device to eval model (default: {"cpu"})

    Returns:
        result -- 2D list (B, N-best), each element is an (seq, score) pair
    """
    beam = BeamSearch(beam_size, batch_size,
                    pad=C.PAD, bos=C.BOS, eos=C.EOS,
                    n_best=n_best,
                    mb_device=device,
                    global_scorer=GNMTGlobalScorer(alpha, 0.1, "avg", "none"),
                    min_length=0,
                    max_length=max_len,
                    ratio=0.0,
                    memory_lengths=None,
                    block_ngram_repeat=0,
                    exclusion_tokens=set(),
                    stepwise_penalty=True,
                    return_attention=False,
                    )
    model.eval()
    is_finished = [False] * beam.batch_size
    with torch.no_grad():
        src_ids, _, src_pos, _, _, src_key_padding_mask, _, original_memory_key_padding_mask = list(
            map(lambda x: x.to(device), inputs))
        
        # If the last batch is smaller than batch_size, pad it with copies of
        # the first example so the shapes expected by BeamSearch still hold.
        if src_ids.shape[1] != batch_size:
            diff = batch_size - src_ids.shape[1]
            src_ids = torch.cat([src_ids] + [src_ids[:, :1]] * diff, dim=1)
            src_pos = torch.cat([src_pos] + [src_pos[:, :1]] * diff, dim=1)
            src_key_padding_mask = torch.cat([src_key_padding_mask] + [src_key_padding_mask[:1]] * diff, dim=0)
            original_memory_key_padding_mask = torch.cat([original_memory_key_padding_mask] + [original_memory_key_padding_mask[:1]] * diff, dim=0)


        model.to(device)
        original_memory = model.encode(src_ids, src_pos, src_key_padding_mask=src_key_padding_mask)

        memory = original_memory
        memory_key_padding_mask = original_memory_key_padding_mask
        while not beam.done:
            len_decoder_inputs = beam.alive_seq.shape[1]
            dec_pos = torch.arange(1, len_decoder_inputs+1).repeat(beam.alive_seq.shape[0], 1).permute(1, 0).to(device)

            # unsqueeze the memory and memory_key_padding_mask in B dim to match the size (BM*BS)
            repeated_memory = memory.repeat(1, 1, beam.beam_size).reshape(
                memory.shape[0], -1, memory.shape[-1])
            repeated_memory_key_padding_mask = memory_key_padding_mask.repeat(
                1, beam.beam_size).reshape(-1, memory_key_padding_mask.shape[1])

            decoder_outputs = model.decode(beam.alive_seq.permute(1, 0), dec_pos, _, repeated_memory, memory_key_padding_mask=repeated_memory_key_padding_mask)[-1]
            if hasattr(model, "proj"):
                logits = model.proj(decoder_outputs)
            elif hasattr(model, "gen"):
                logits = model.gen(decoder_outputs)
            else:
                raise ValueError("Unknown generator!")

            log_probs = torch.nn.functional.log_softmax(logits, dim=1)
            beam.advance(log_probs, None)
            if beam.is_finished.any():
                beam.update_finished()

                # mark examples that have collected n_best finished hypotheses
                # (loop variable renamed so it does not shadow the n_best arg)
                for i, preds in enumerate(beam.predictions):
                    if not is_finished[i] and len(preds) == beam.n_best:
                        is_finished[i] = True

                alive_example_idx = [i for i in range(
                    len(is_finished)) if not is_finished[i]]
                if alive_example_idx:
                    memory = original_memory[:, alive_example_idx, :]
                    memory_key_padding_mask = original_memory_key_padding_mask[alive_example_idx]

    # packing data for easy accessing
    results = []
    for batch_preds, batch_scores in zip(beam.predictions, beam.scores):
        n_best_result = []
        for n_best_pred, n_best_score in zip(batch_preds, batch_scores):
            assert isinstance(n_best_pred, torch.Tensor)
            assert isinstance(n_best_score, torch.Tensor)
            n_best_result.append(
                (n_best_pred.tolist(), n_best_score.item())
            )
        results.append(n_best_result)

    return results
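
A hedged usage sketch for `batch_beam_search_trs`; `model` and `inputs` here are placeholders for whatever Transformer and preprocessing pipeline this helper is paired with:

# Hypothetical call: `model` must implement .encode()/.decode() and expose a
# `proj` or `gen` projection; `inputs` must unpack in the order assumed above
# (src_ids, _, src_pos, _, _, src_key_padding_mask, _, memory_key_padding_mask).
results = batch_beam_search_trs(model, inputs, batch_size=8,
                                device="cuda", max_len=64,
                                beam_size=10, n_best=3, alpha=0.6)
best_tokens, best_score = results[0][0]  # top hypothesis for the first example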