Example No. 1
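The snippets below come from OpenNMT-py's beam search unit tests and all assume the same harness: torch, the BeamSearch class, a stub global scorer, and a BLOCKED_SCORE constant referenced as self.BLOCKED_SCORE. A minimal sketch of that harness is given here for context; the import path, the stub body, and the constant's value are assumptions reconstructed from how the tests use them, not the project's verbatim code. The test methods shown in the examples would live on a test case class like the one sketched at the end of this block.

import unittest

import torch

# import path assumed; it differs across OpenNMT-py versions
from onmt.translate.beam_search import BeamSearch, GNMTGlobalScorer


class GlobalScorerStub(object):
    # a do-nothing stand-in for GNMTGlobalScorer: applies no length or
    # coverage penalty (the exact body is an assumption, not verbatim code)
    alpha = 0
    beta = 0

    def __init__(self):
        self.length_penalty = lambda x, alpha: 1.
        self.cov_penalty = lambda cov, beta: torch.zeros(
            (1, cov.shape[-2]), device=cov.device, dtype=torch.float)
        self.has_cov_pen = False
        self.has_len_pen = False

    def update_global_state(self, beam):
        pass

    def score(self, beam, scores):
        return scores


class TestBeamSearch(unittest.TestCase):
    # score assigned to hypotheses killed by ngram blocking (assumed value)
    BLOCKED_SCORE = -10e20
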
    def test_advance_with_all_repeats_gets_blocked(self):
        # all beams repeat (beam >= 1 repeat dummy scores)
        beam_sz = 5
        n_words = 100
        repeat_idx = 47
        ngram_repeat = 3
        device_init = torch.zeros(1, 1)
        for batch_sz in [1, 3]:
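            # positional args: beam_size, batch_size, pad, bos, eos, n_best,
            #   global_scorer, min_length, max_length, return_attention,
            #   block_ngram_repeat, exclusion_tokens, stepwise_penalty, ratio
            #   (argument names inferred from how these tests use them)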
            beam = BeamSearch(
                beam_sz, batch_sz, 0, 1, 2, 2,
                GlobalScorerStub(), 0, 30,
                False, ngram_repeat, set(),
                False, 0.)
            beam.initialize(device_init, torch.randint(0, 30, (batch_sz,)))
            for i in range(ngram_repeat + 4):
                # predict repeat_idx over and over again
                word_probs = torch.full(
                    (batch_sz * beam_sz, n_words), -float('inf'))
                word_probs[0::beam_sz, repeat_idx] = 0
                attns = torch.randn(1, batch_sz * beam_sz, 53)
                beam.advance(word_probs, attns)
                if i <= ngram_repeat:
                    expected_scores = torch.tensor(
                        [0] + [-float('inf')] * (beam_sz - 1))\
                        .repeat(batch_sz, 1)
                    self.assertTrue(beam.topk_log_probs.equal(expected_scores))
                else:
                    self.assertTrue(
                        beam.topk_log_probs.equal(
                            torch.tensor(self.BLOCKED_SCORE)
                            .repeat(batch_sz, beam_sz)))
Example No. 2
    def test_beam_is_done_when_n_best_beams_eos_using_min_length(self):
        # this is also a test that when block_ngram_repeat=0,
        # repeating is acceptable
        beam_sz = 5
        batch_sz = 3
        n_words = 100
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(torch.tensor(
            [6., 5., 4., 3., 2., 1.]), dim=0)
        min_length = 5
        eos_idx = 2
        beam = BeamSearch(
            beam_sz, batch_sz, 0, 1, 2, 2,
            GlobalScorerStub(),
            min_length, 30, False, 0, set(),
            False, 0.)
        device_init = torch.zeros(1, 1)
        beam.initialize(device_init, torch.randint(0, 30, (batch_sz,)))
        for i in range(min_length + 4):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full(
                (batch_sz * beam_sz, n_words), -float('inf'))
            if i == 0:
                # "best" prediction is eos - that should be blocked
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                # include at least beam_sz predictions OTHER than EOS
                # that are greater than -1e20
                for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                    word_probs[0::beam_sz, j] = score
            elif i <= min_length:
                # predict eos in beam 1
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx::beam_sz, j] = score
            else:
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx::beam_sz, j] = score

            attns = torch.randn(1, batch_sz * beam_sz, 53)
            beam.advance(word_probs, attns)
            if i < min_length:
                self.assertFalse(beam.done)
            elif i == min_length:
                # beam 1 dies on min_length
                self.assertTrue(beam.is_finished[:, 1].all())
                beam.update_finished()
                self.assertFalse(beam.done)
            else:  # i > min_length
                # beam 0 dies on the step after beam 1 dies
                self.assertTrue(beam.is_finished[:, 0].all())
                beam.update_finished()
                self.assertTrue(beam.done)
Example No. 3
    def test_doesnt_predict_eos_if_shorter_than_min_len(self):
        # beam 0 will always predict EOS. The other beams will predict
        # non-eos scores.
        for batch_sz in [1, 3]:
            beam_sz = 5
            n_words = 100
            _non_eos_idxs = [47, 51, 13, 88, 99]
            valid_score_dist = torch.log_softmax(torch.tensor(
                [6., 5., 4., 3., 2., 1.]), dim=0)
            min_length = 5
            eos_idx = 2
            lengths = torch.randint(0, 30, (batch_sz,))
            beam = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 2,
                              GlobalScorerStub(),
                              min_length, 30, False, 0, set(),
                              False, 0.)
            device_init = torch.zeros(1, 1)
            beam.initialize(device_init, lengths)
            all_attns = []
            for i in range(min_length + 4):
                # non-interesting beams are going to get dummy values
                word_probs = torch.full(
                    (batch_sz * beam_sz, n_words), -float('inf'))
                if i == 0:
                    # "best" prediction is eos - that should be blocked
                    word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                    # include at least beam_sz predictions OTHER than EOS
                    # that are greater than -1e20
                    for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                        word_probs[0::beam_sz, j] = score
                else:
                    # predict eos in beam 0
                    word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                    # provide beam_sz other good predictions
                    for k, (j, score) in enumerate(
                            zip(_non_eos_idxs, valid_score_dist[1:])):
                        beam_idx = min(beam_sz-1, k)
                        word_probs[beam_idx::beam_sz, j] = score

                attns = torch.randn(1, batch_sz * beam_sz, 53)
                all_attns.append(attns)
                beam.advance(word_probs, attns)
                if i < min_length:
                    expected_score_dist = \
                        (i+1) * valid_score_dist[1:].unsqueeze(0)
                    self.assertTrue(
                        beam.topk_log_probs.allclose(
                            expected_score_dist))
                elif i == min_length:
                    # now the top beam has ended and no others have
                    self.assertTrue(beam.is_finished[:, 0].eq(1).all())
                    self.assertTrue(beam.is_finished[:, 1:].eq(0).all())
                else:  # i > min_length
                    # not of interest, but want to make sure it keeps running
                    # since only beam 0 terminates and n_best = 2
                    pass
Example No. 4
    def test_advance_with_some_repeats_gets_blocked(self):
        # beam 0 and beam >=2 will repeat (beam >= 2 repeat dummy scores)
        beam_sz = 5
        n_words = 100
        repeat_idx = 47
        ngram_repeat = 3
        no_repeat_score = -2.3
        repeat_score = -0.1
        device_init = torch.zeros(1, 1)
        for batch_sz in [1, 3]:
            beam = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 2,
                              GlobalScorerStub(), 0, 30, False, ngram_repeat,
                              set(), False, 0.)
            beam.initialize(device_init, torch.randint(0, 30, (batch_sz, )))
            for i in range(ngram_repeat + 4):
                # non-interesting beams are going to get dummy values
                word_probs = torch.full((batch_sz * beam_sz, n_words),
                                        -float('inf'))
                if i == 0:
                    # on initial round, only predicted scores for beam 0
                    # matter. Make two predictions. Top one will be repeated
                    # in beam zero, second one will live on in beam 1.
                    word_probs[0::beam_sz, repeat_idx] = repeat_score
                    word_probs[0::beam_sz,
                               repeat_idx + i + 1] = no_repeat_score
                else:
                    # predict the same thing in beam 0
                    word_probs[0::beam_sz, repeat_idx] = 0
                    # continue pushing around what beam 1 predicts
                    word_probs[1::beam_sz, repeat_idx + i + 1] = 0
                attns = torch.randn(1, batch_sz * beam_sz, 53)
                beam.advance(word_probs, attns)
                if i < ngram_repeat:
                    self.assertFalse(beam.topk_log_probs[0::beam_sz].eq(
                        self.BLOCKED_SCORE).any())
                    self.assertFalse(beam.topk_log_probs[1::beam_sz].eq(
                        self.BLOCKED_SCORE).any())
                elif i == ngram_repeat:
                    # now beam 0 dies (along with the others), beam 1 -> beam 0
                    self.assertFalse(beam.topk_log_probs[:, 0].eq(
                        self.BLOCKED_SCORE).any())

                    expected = torch.full([batch_sz, beam_sz], float("-inf"))
                    expected[:, 0] = no_repeat_score
                    expected[:, 1] = self.BLOCKED_SCORE
                    self.assertTrue(beam.topk_log_probs[:, :].equal(expected))
                else:
                    # now beam 0 dies (along with the others), beam 1 -> beam 0
                    self.assertFalse(beam.topk_log_probs[:, 0].eq(
                        self.BLOCKED_SCORE).any())

                    expected = torch.full([batch_sz, beam_sz], float("-inf"))
                    expected[:, 0] = no_repeat_score
                    expected[:, 1:3] = self.BLOCKED_SCORE
                    expected[:, 3:] = float("-inf")
                    self.assertTrue(beam.topk_log_probs.equal(expected))
Example No. 5
    def test_beam_advance_against_known_reference(self):
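        # GNMTGlobalScorer(alpha, beta, length_penalty, coverage_penalty):
        # alpha=0.7 with "avg" length normalization and no coverage penalty
        # (parameter order assumed from the usual OpenNMT-py signature)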
        scorer = GNMTGlobalScorer(0.7, 0., "avg", "none")
        beam = BeamSearch(self.BEAM_SZ, self.BATCH_SZ, 0, 1, 2, self.N_BEST,
                          scorer, 0, 30, False, 0, set(), False, 0.)
        device_init = torch.zeros(1, 1)
        beam.initialize(device_init, torch.randint(0, 30, (self.BATCH_SZ, )))
        expected_beam_scores = self.init_step(beam, 1.)
        expected_beam_scores = self.first_step(beam, expected_beam_scores, 3)
        expected_beam_scores = self.second_step(beam, expected_beam_scores, 4)
        self.third_step(beam, expected_beam_scores, 5)
Example No. 6
    def test_beam_advance_against_known_reference(self):
        beam = BeamSearch(self.BEAM_SZ, self.BATCH_SZ, 0, 1, 2, self.N_BEST,
                          GlobalScorerStub(), 0, 30, False, 0, set(), False,
                          0.)
        device_init = torch.zeros(1, 1)
        beam.initialize(device_init, torch.randint(0, 30, (self.BATCH_SZ, )))
        expected_beam_scores = self.init_step(beam, 1)
        expected_beam_scores = self.first_step(beam, expected_beam_scores, 1)
        expected_beam_scores = self.second_step(beam, expected_beam_scores, 1)
        self.third_step(beam, expected_beam_scores, 1)
Example No. 7
    def test_repeating_excluded_index_does_not_die(self):
        # beam 0 and beam >= 2 will repeat (beam 2 repeats excluded idx)
        beam_sz = 5
        n_words = 100
        repeat_idx = 47  # will be repeated and should be blocked
        repeat_idx_ignored = 7  # will be repeated and should not be blocked
        ngram_repeat = 3
        device_init = torch.zeros(1, 1)
        for batch_sz in [1, 3]:
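            # passing {repeat_idx_ignored} where the other examples pass set()
            # (the exclusion-token set, name inferred) is what exempts that
            # index from ngram blocking in this test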
            beam = BeamSearch(
                beam_sz, batch_sz, 0, 1, 2, 2,
                GlobalScorerStub(), 0, 30,
                False, ngram_repeat, {repeat_idx_ignored},
                False, 0.)
            beam.initialize(device_init, torch.randint(0, 30, (batch_sz,)))
            for i in range(ngram_repeat + 4):
                # non-interesting beams are going to get dummy values
                word_probs = torch.full(
                    (batch_sz * beam_sz, n_words), -float('inf'))
                if i == 0:
                    word_probs[0::beam_sz, repeat_idx] = -0.1
                    word_probs[0::beam_sz, repeat_idx + i + 1] = -2.3
                    word_probs[0::beam_sz, repeat_idx_ignored] = -5.0
                else:
                    # predict the same thing in beam 0
                    word_probs[0::beam_sz, repeat_idx] = 0
                    # continue pushing around what beam 1 predicts
                    word_probs[1::beam_sz, repeat_idx + i + 1] = 0
                    # predict the allowed-repeat again in beam 2
                    word_probs[2::beam_sz, repeat_idx_ignored] = 0
                attns = torch.randn(1, batch_sz * beam_sz, 53)
                beam.advance(word_probs, attns)
                if i <= ngram_repeat:
                    self.assertFalse(beam.topk_log_probs[:, 0].eq(
                        self.BLOCKED_SCORE).any())
                    self.assertFalse(beam.topk_log_probs[:, 1].eq(
                        self.BLOCKED_SCORE).any())
                    self.assertFalse(beam.topk_log_probs[:, 2].eq(
                        self.BLOCKED_SCORE).any())
                else:
                    # now beam 0 dies, beam 1 -> beam 0, beam 2 -> beam 1
                    # and the rest die
                    self.assertFalse(beam.topk_log_probs[:, 0].eq(
                        self.BLOCKED_SCORE).any())
                    # since all preds after i=0 are 0, we can check
                    # that the beam is the correct idx by checking that
                    # the curr score is the initial score
                    self.assertTrue(beam.topk_log_probs[:, 0].eq(-2.3).all())
                    self.assertFalse(beam.topk_log_probs[:, 1].eq(
                        self.BLOCKED_SCORE).all())
                    self.assertTrue(beam.topk_log_probs[:, 1].eq(-5.0).all())
                    self.assertTrue(
                        beam.topk_log_probs[:, 2:].equal(
                            torch.tensor(self.BLOCKED_SCORE)
                            .repeat(batch_sz, beam_sz - 2)))
Example No. 8
    def test_advance_with_all_repeats_gets_blocked(self):
        # all beams repeat (beam >= 1 repeat dummy scores)
        beam_sz = 5
        n_words = 100
        repeat_idx = 47
        ngram_repeat = 3
        device_init = torch.zeros(1, 1)
        for batch_sz in [1, 3]:
            beam = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 2,
                              GlobalScorerStub(), 0, 30, False, ngram_repeat,
                              set(), False, 0.)
            beam.initialize(device_init, torch.randint(0, 30, (batch_sz, )))
            for i in range(ngram_repeat + 4):
                # predict repeat_idx over and over again
                word_probs = torch.full((batch_sz * beam_sz, n_words),
                                        -float('inf'))
                word_probs[0::beam_sz, repeat_idx] = 0

                attns = torch.randn(1, batch_sz * beam_sz, 53)
                beam.advance(word_probs, attns)

                if i < ngram_repeat:
                    # before repeat, scores are either 0 or -inf
                    expected_scores = torch.tensor(
                        [0] + [-float('inf')] * (beam_sz - 1))\
                        .repeat(batch_sz, 1)
                    self.assertTrue(beam.topk_log_probs.equal(expected_scores))
                elif i % ngram_repeat == 0:
                    # on repeat, `repeat_idx` score is BLOCKED_SCORE
                    # (but it's still the best score, thus we have
                    # [BLOCKED_SCORE, -inf, -inf, -inf, -inf]
                    expected_scores = torch.tensor(
                        [0] + [-float('inf')] * (beam_sz - 1))\
                        .repeat(batch_sz, 1)
                    expected_scores[:, 0] = self.BLOCKED_SCORE
                    self.assertTrue(beam.topk_log_probs.equal(expected_scores))
                else:
                    # repetition keeps maximizing the score:
                    # index 0 has already been blocked, so repeating it adds
                    # 0.0, while the other indexes are -inf, so repeating
                    # them yields BLOCKED_SCORE, which is higher
                    expected_scores = torch.tensor(self.BLOCKED_SCORE).repeat(
                        batch_sz, beam_sz)
                    self.assertTrue(beam.topk_log_probs.equal(expected_scores))
Example No. 9
    def test_beam_returns_attn_with_correct_length(self):
        beam_sz = 5
        batch_sz = 3
        n_words = 100
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(torch.tensor(
            [6., 5., 4., 3., 2., 1.]), dim=0)
        min_length = 5
        eos_idx = 2
        inp_lens = torch.randint(1, 30, (batch_sz, ))
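        # the argument after max_length=30 (return_attention, name inferred)
        # is True here, unlike in the other examples, so finished hypotheses
        # keep their attention history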
        beam = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 2, GlobalScorerStub(),
                          min_length, 30, True, 0, set(), False, 0.)
        device_init = torch.zeros(1, 1)
        _, _, inp_lens, _ = beam.initialize(device_init, inp_lens)
        # inp_lens is tiled in initialize, reassign to make attn match
        for i in range(min_length + 2):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full((batch_sz * beam_sz, n_words),
                                    -float('inf'))
            if i == 0:
                # "best" prediction is eos - that should be blocked
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                # include at least beam_sz predictions OTHER than EOS
                # that are greater than -1e20
                for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                    word_probs[0::beam_sz, j] = score
            elif i <= min_length:
                # predict eos in beam 1
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz - 1, k)
                    word_probs[beam_idx::beam_sz, j] = score
            else:
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz - 1, k)
                    word_probs[beam_idx::beam_sz, j] = score

            attns = torch.randn(1, batch_sz * beam_sz, 53)
            beam.advance(word_probs, attns)
            if i < min_length:
                self.assertFalse(beam.done)
                # no top beams are finished yet
                for b in range(batch_sz):
                    self.assertEqual(beam.attention[b], [])
            elif i == min_length:
                # beam 1 dies on min_length
                self.assertTrue(beam.is_finished[:, 1].all())
                beam.update_finished()
                self.assertFalse(beam.done)
                # no top beams are finished yet
                for b in range(batch_sz):
                    self.assertEqual(beam.attention[b], [])
            else:  # i > min_length
                # beam 0 dies on the step after beam 1 dies
                self.assertTrue(beam.is_finished[:, 0].all())
                beam.update_finished()
                self.assertTrue(beam.done)
                # top beam is finished now so there are attentions
                for b in range(batch_sz):
                    # two beams are finished in each batch
                    self.assertEqual(len(beam.attention[b]), 2)
                    for k in range(2):
                        # second dim is cut down to the non-padded src length
                        self.assertEqual(beam.attention[b][k].shape[-1],
                                         inp_lens[b])
                    # first dim is equal to the time of death
                    # (beam 0 died at current step - adjust for SOS)
                    self.assertEqual(beam.attention[b][0].shape[0], i + 1)
                    # (beam 1 died at last step - adjust for SOS)
                    self.assertEqual(beam.attention[b][1].shape[0], i)
                # behavior gets weird when beam is already done so just stop
                break
Example No. 10
def run(
        n_iterations=7,
        beam_sz=5,
        batch_sz=1,
        n_words=100,
        repeat_idx=47,
        ngram_repeat=-1,
        repeat_logprob=-0.2,
        no_repeat_logprob=-0.1,
        base_logprob=float("-inf"),
        verbose=True,
):
    """
        At each timestep `i`:
            - token `repeat_idx` get a logprob of `repeat_logprob`
            - token `repeat_idx+i` get `no_repeat_logprob`
            - other tokens get `base_logprob`

    """
    def log(*args, **kwargs):
        if verbose:
            print(*args, **kwargs)

    device_init = torch.zeros(1, 1)

    beam = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 4, GlobalScorerStub(), 0, 30,
                      False, ngram_repeat, set(), False, 0.)
    beam.initialize(device_init, torch.randint(0, 30, (batch_sz, )))
    for i in range(n_iterations):
        # non-interesting beams are going to get dummy values
        word_logprobs = torch.full((batch_sz * beam_sz, n_words), base_logprob)
        if i == 0:
            # on initial round, only predicted scores for beam 0
            # matter. Make two predictions. Top one will be repeated
            # in beam zero, second one will live on in beam 1.
            word_logprobs[0::beam_sz, repeat_idx] = repeat_logprob
            word_logprobs[0::beam_sz, repeat_idx + i + 1] = no_repeat_logprob
        else:
            # keep predicting the repeated token in beam 0's rows ...
            word_logprobs[0::beam_sz, repeat_idx] = repeat_logprob
            # ... and also score a fresh, non-repeating token in the same rows
            word_logprobs[0::beam_sz, repeat_idx + i + 1] = no_repeat_logprob
        attns = torch.randn(1, batch_sz * beam_sz, 0)
        beam.advance(word_logprobs, attns)
        if ngram_repeat > 0:
            # NOTE: IGNORE IT FOR NOW
            # if i < ngram_repeat:
            #     assertFalse(
            #         beam.topk_log_probs[0::beam_sz].eq(
            #             BLOCKED_SCORE).any())
            #     assertFalse(
            #         beam.topk_log_probs[1::beam_sz].eq(
            #             BLOCKED_SCORE).any())
            # elif i == ngram_repeat:
            #     assertFalse(
            #         beam.topk_log_probs[:, 0].eq(
            #             BLOCKED_SCORE).any())

            #     expected = torch.full([batch_sz, beam_sz], base_logprob)
            #     expected[:, 0] = (i+1) * no_repeat_logprob
            #     expected[:, 1] = BLOCKED_SCORE
            # else:
            #     expected = torch.full([batch_sz, beam_sz], base_logprob)
            #     expected[:, 0] = i * no_repeat_logprob
            #     expected[:, 1] = BLOCKED_SCORE
            pass
        # log("Iteration (%d): expected %s" % (i, str(expected)))
        log("Iteration (%d): logprobs %s" % (i, str(beam.topk_log_probs)))
        log("Iteration (%d): seq %s" % (i, str(beam.alive_seq)))
        log("Iteration (%d): indices %s" % (i, str(beam.topk_ids)))
        if ngram_repeat > 0:
            log("Iteration (%d): blocked %s" % (i, str(beam.forbidden_tokens)))
    return beam.topk_log_probs, beam.alive_seq
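
A hypothetical invocation of the helper above (the argument values are illustrative only) might look like this:

if __name__ == "__main__":
    # run with ngram blocking enabled so the blocking path is exercised
    final_logprobs, final_seqs = run(ngram_repeat=3, verbose=True)
    print("final beam log-probs:", final_logprobs)
    print("alive sequences:", final_seqs)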