Example #1
 def test_advance_with_all_repeats_gets_blocked(self):
     # all beams repeat (beam >= 1 repeat dummy scores)
     beam_sz = 5
     n_words = 100
     repeat_idx = 47
     ngram_repeat = 3
     for batch_sz in [1, 3]:
         beam = BeamSearch(
             beam_sz, batch_sz, 0, 1, 2, 2,
             torch.device("cpu"), GlobalScorerStub(), 0, 30,
             False, ngram_repeat, set(),
             torch.randint(0, 30, (batch_sz,)), False, 0.)
         for i in range(ngram_repeat + 4):
             # predict repeat_idx over and over again
             word_probs = torch.full(
                 (batch_sz * beam_sz, n_words), -float('inf'))
             word_probs[0::beam_sz, repeat_idx] = 0
             attns = torch.randn(1, batch_sz * beam_sz, 53)
             beam.advance(word_probs, attns)
             if i <= ngram_repeat:
                 expected_scores = torch.tensor(
                             [0] + [-float('inf')] * (beam_sz - 1))\
                         .repeat(batch_sz, 1)
                 self.assertTrue(beam.topk_log_probs.equal(expected_scores))
             else:
                 self.assertTrue(
                     beam.topk_log_probs.equal(
                         torch.tensor(self.BLOCKED_SCORE)
                         .repeat(batch_sz, beam_sz)))
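Note: the constructor call above passes OpenNMT-py's BeamSearch arguments positionally. Assuming the 1.0-era signature (argument names may differ across versions; later releases move mb_device and memory_lengths into beam.initialize(), as some examples below show), an equivalent keyword form would be:

 beam = BeamSearch(
     beam_size=beam_sz, batch_size=batch_sz,
     pad=0, bos=1, eos=2, n_best=2,
     mb_device=torch.device("cpu"), global_scorer=GlobalScorerStub(),
     min_length=0, max_length=30,
     return_attention=False, block_ngram_repeat=ngram_repeat,
     exclusion_tokens=set(),
     memory_lengths=torch.randint(0, 30, (batch_sz,)),
     stepwise_penalty=False, ratio=0.)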
Example #2
    def test_beam_is_done_when_n_best_beams_eos_using_min_length(self):
        # this is also a test that when block_ngram_repeat=0,
        # repeating is acceptable
        beam_sz = 5
        batch_sz = 3
        n_words = 100
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(torch.tensor(
            [6., 5., 4., 3., 2., 1.]), dim=0)
        min_length = 5
        eos_idx = 2
        # beam includes start token in cur_len count.
        # Add one to its min_length to compensate
        beam = BeamSearch(
            beam_sz, batch_sz, 0, 1, 2, 2,
            torch.device("cpu"), GlobalScorerStub(),
            min_length + 1, 30, False, 0, set(),
            torch.randint(0, 30, (batch_sz,)))
        for i in range(min_length + 4):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full(
                (batch_sz * beam_sz, n_words), -float('inf'))
            if i == 0:
                # "best" prediction is eos - that should be blocked
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                # include at least beam_sz predictions OTHER than EOS
                # that are greater than -1e20
                for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                    word_probs[0::beam_sz, j] = score
            elif i <= min_length:
                # predict eos in beam 1
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx::beam_sz, j] = score
            else:
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx::beam_sz, j] = score

            attns = torch.randn(1, batch_sz * beam_sz, 53)
            beam.advance(word_probs, attns)
            if i < min_length:
                self.assertFalse(beam.done)
            elif i == min_length:
                # beam 1 dies on min_length
                self.assertTrue(beam.is_finished[:, 1].all())
                beam.update_finished()
                self.assertFalse(beam.done)
            else:  # i > min_length
                # beam 0 dies on the step after beam 1 dies
                self.assertTrue(beam.is_finished[:, 0].all())
                beam.update_finished()
                self.assertTrue(beam.done)
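These tests all construct a GlobalScorerStub, which the listings do not define. A minimal sketch that satisfies what the tests exercise, modeled on OpenNMT-py's own test stub (the exact attribute set is an assumption), could look like:

class GlobalScorerStub(object):
    # Neutral scorer: no length or coverage penalty, so advance()
    # ranks raw cumulative log-probs.
    alpha = 0
    beta = 0

    def __init__(self):
        self.length_penalty = lambda x, alpha: 1.
        self.cov_penalty = lambda cov, beta: torch.zeros(
            (1, cov.shape[-2]), device=cov.device, dtype=cov.dtype)
        self.has_cov_pen = False
        self.has_len_pen = False

    def update_global_state(self, beam):
        pass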
Example #3
 def test_advance_with_all_repeats_gets_blocked(self):
     # all beams repeat (beam >= 1 repeat dummy scores)
     beam_sz = 5
     n_words = 100
     repeat_idx = 47
     ngram_repeat = 3
     for batch_sz in [1, 3]:
         beam = BeamSearch(beam_sz, batch_sz,
                           0, 1, 2, 2, torch.device("cpu"),
                           GlobalScorerStub(), 0, 30, False, ngram_repeat,
                           set(), torch.randint(0, 30, (batch_sz, )), False)
         for i in range(ngram_repeat + 4):
             # predict repeat_idx over and over again
             word_probs = torch.full((batch_sz * beam_sz, n_words),
                                     -float('inf'))
             word_probs[0::beam_sz, repeat_idx] = 0
             attns = torch.randn(1, batch_sz * beam_sz, 53)
             beam.advance(word_probs, attns)
             if i <= ngram_repeat:
                 expected_scores = torch.tensor(
                             [0] + [-float('inf')] * (beam_sz - 1))\
                         .repeat(batch_sz, 1)
                 self.assertTrue(beam.topk_log_probs.equal(expected_scores))
             else:
                 self.assertTrue(
                     beam.topk_log_probs.equal(
                         torch.tensor(self.BLOCKED_SCORE).repeat(
                             batch_sz, beam_sz)))
Example #4
    def test_beam_is_done_when_n_best_beams_eos_using_min_length(self):
        # this is also a test that when block_ngram_repeat=0,
        # repeating is acceptable
        beam_sz = 5
        batch_sz = 3
        n_words = 100
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(torch.tensor(
            [6., 5., 4., 3., 2., 1.]), dim=0)
        min_length = 5
        eos_idx = 2
        beam = BeamSearch(
            beam_sz, batch_sz, 0, 1, 2, 2,
            torch.device("cpu"), GlobalScorerStub(),
            min_length, 30, False, 0, set(),
            torch.randint(0, 30, (batch_sz,)), False, 0.)
        for i in range(min_length + 4):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full(
                (batch_sz * beam_sz, n_words), -float('inf'))
            if i == 0:
                # "best" prediction is eos - that should be blocked
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                # include at least beam_sz predictions OTHER than EOS
                # that are greater than -1e20
                for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                    word_probs[0::beam_sz, j] = score
            elif i <= min_length:
                # predict eos in beam 1
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx::beam_sz, j] = score
            else:
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx::beam_sz, j] = score

            attns = torch.randn(1, batch_sz * beam_sz, 53)
            beam.advance(word_probs, attns)
            if i < min_length:
                self.assertFalse(beam.done)
            elif i == min_length:
                # beam 1 dies on min_length
                self.assertTrue(beam.is_finished[:, 1].all())
                beam.update_finished()
                self.assertFalse(beam.done)
            else:  # i > min_length
                # beam 0 dies on the step after beam 1 dies
                self.assertTrue(beam.is_finished[:, 0].all())
                beam.update_finished()
                self.assertTrue(beam.done)
Example #5
    def test_advance_with_some_repeats_gets_blocked(self):
        # beam 0 and beam >=2 will repeat (beam >= 2 repeat dummy scores)
        beam_sz = 5
        n_words = 100
        repeat_idx = 47
        ngram_repeat = 3
        no_repeat_score = -2.3
        repeat_score = -0.1
        device_init = torch.zeros(1, 1)
        for batch_sz in [1, 3]:
            beam = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 2,
                              GlobalScorerStub(), 0, 30, False, ngram_repeat,
                              set(), False, 0.)
            beam.initialize(device_init, torch.randint(0, 30, (batch_sz, )))
            for i in range(ngram_repeat + 4):
                # non-interesting beams are going to get dummy values
                word_probs = torch.full((batch_sz * beam_sz, n_words),
                                        -float('inf'))
                if i == 0:
                    # on initial round, only predicted scores for beam 0
                    # matter. Make two predictions. Top one will be repeated
                    # in beam zero, second one will live on in beam 1.
                    word_probs[0::beam_sz, repeat_idx] = repeat_score
                    word_probs[0::beam_sz,
                               repeat_idx + i + 1] = no_repeat_score
                else:
                    # predict the same thing in beam 0
                    word_probs[0::beam_sz, repeat_idx] = 0
                    # continue pushing around what beam 1 predicts
                    word_probs[1::beam_sz, repeat_idx + i + 1] = 0
                attns = torch.randn(1, batch_sz * beam_sz, 53)
                beam.advance(word_probs, attns)
                if i < ngram_repeat:
                    self.assertFalse(beam.topk_log_probs[0::beam_sz].eq(
                        self.BLOCKED_SCORE).any())
                    self.assertFalse(beam.topk_log_probs[1::beam_sz].eq(
                        self.BLOCKED_SCORE).any())
                elif i == ngram_repeat:
                    # now beam 0 dies (along with the others), beam 1 -> beam 0
                    self.assertFalse(beam.topk_log_probs[:, 0].eq(
                        self.BLOCKED_SCORE).any())

                    expected = torch.full([batch_sz, beam_sz], float("-inf"))
                    expected[:, 0] = no_repeat_score
                    expected[:, 1] = self.BLOCKED_SCORE
                    self.assertTrue(beam.topk_log_probs[:, :].equal(expected))
                else:
                    # now beam 0 dies (along with the others), beam 1 -> beam 0
                    self.assertFalse(beam.topk_log_probs[:, 0].eq(
                        self.BLOCKED_SCORE).any())

                    expected = torch.full([batch_sz, beam_sz], float("-inf"))
                    expected[:, 0] = no_repeat_score
                    expected[:, 1:3] = self.BLOCKED_SCORE
                    expected[:, 3:] = float("-inf")
                    self.assertTrue(beam.topk_log_probs.equal(expected))
Example #6
    def test_doesnt_predict_eos_if_shorter_than_min_len(self):
        # beam 0 will always predict EOS. The other beams will predict
        # non-eos scores.
        for batch_sz in [1, 3]:
            beam_sz = 5
            n_words = 100
            _non_eos_idxs = [47, 51, 13, 88, 99]
            valid_score_dist = torch.log_softmax(torch.tensor(
                [6., 5., 4., 3., 2., 1.]), dim=0)
            min_length = 5
            eos_idx = 2
            lengths = torch.randint(0, 30, (batch_sz,))
            beam = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 2,
                              GlobalScorerStub(),
                              min_length, 30, False, 0, set(),
                              False, 0.)
            device_init = torch.zeros(1, 1)
            beam.initialize(device_init, lengths)
            all_attns = []
            for i in range(min_length + 4):
                # non-interesting beams are going to get dummy values
                word_probs = torch.full(
                    (batch_sz * beam_sz, n_words), -float('inf'))
                if i == 0:
                    # "best" prediction is eos - that should be blocked
                    word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                    # include at least beam_sz predictions OTHER than EOS
                    # that are greater than -1e20
                    for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                        word_probs[0::beam_sz, j] = score
                else:
                    # predict eos in beam 0
                    word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                    # provide beam_sz other good predictions
                    for k, (j, score) in enumerate(
                            zip(_non_eos_idxs, valid_score_dist[1:])):
                        beam_idx = min(beam_sz-1, k)
                        word_probs[beam_idx::beam_sz, j] = score

                attns = torch.randn(1, batch_sz * beam_sz, 53)
                all_attns.append(attns)
                beam.advance(word_probs, attns)
                if i < min_length:
                    expected_score_dist = \
                        (i+1) * valid_score_dist[1:].unsqueeze(0)
                    self.assertTrue(
                        beam.topk_log_probs.allclose(
                            expected_score_dist))
                elif i == min_length:
                    # now the top beam has ended and no others have
                    self.assertTrue(beam.is_finished[:, 0].eq(1).all())
                    self.assertTrue(beam.is_finished[:, 1:].eq(0).all())
                else:  # i > min_length
                    # not of interest, but want to make sure it keeps running
                    # since only beam 0 terminates and n_best = 2
                    pass
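Why the tests index word_probs with strides of beam_sz: rows are laid out batch-major, so row b * beam_sz + k is beam k of batch element b. A self-contained check (plain torch, no OpenNMT required):

batch_sz, beam_sz, n_words = 3, 5, 100
word_probs = torch.full((batch_sz * beam_sz, n_words), -float('inf'))
word_probs[0::beam_sz, 47] = 0.  # touches beam 0 of every batch element
rows = (word_probs[:, 47] == 0.).nonzero(as_tuple=True)[0]
print(rows.tolist())  # [0, 5, 10], i.e. b * beam_sz for b = 0, 1, 2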
Example #7
    def forward_dev_beam_search(self, encoder_output: torch.Tensor, pad_mask):
        batch_size = encoder_output.size(1)

        self.state["cache"] = None
        memory_lengths = pad_mask.ne(pad_token_index).sum(dim=0)

        self.map_state(lambda state, dim: tile(state, self.beam_size, dim=dim))
        encoder_output = tile(encoder_output, self.beam_size, dim=1)
        pad_mask = tile(pad_mask, self.beam_size, dim=1)
        memory_lengths = tile(memory_lengths, self.beam_size, dim=0)

        # TODO:
        #  - fix attn (?)
        #  - use coverage_penalty="summary" or "wu" and beta=0.2 (or not)
        #  - use length_penalty="wu" and alpha=0.2 (or not)
        beam = BeamSearch(beam_size=self.beam_size, n_best=1, batch_size=batch_size, mb_device=default_device,
                          global_scorer=GNMTGlobalScorer(alpha=0, beta=0, coverage_penalty="none", length_penalty="avg"),
                          pad=pad_token_index, eos=eos_token_index, bos=bos_token_index, min_length=1, max_length=100,
                          return_attention=False, stepwise_penalty=False, block_ngram_repeat=0, exclusion_tokens=set(),
                          memory_lengths=memory_lengths, ratio=-1)

        for i in range(self.max_seq_out_len):
            inp = beam.current_predictions.view(1, -1)

            out, attn = self.forward_step(src=pad_mask, tgt=inp, memory_bank=encoder_output, step=i)  # 1 x batch*beam x hidden
            out = self.linear(out)  # 1 x batch*beam x vocab_out
            out = log_softmax(out, dim=2)  # 1 x batch*beam x vocab_out

            out = out.squeeze(0)  # batch*beam x vocab_out
            # attn = attn.squeeze(0)  # batch*beam x vocab_out
            # out = out.view(batch_size, self.beam_size, -1)  # batch x beam x vocab_out
            # attn = attn.view(batch_size, self.beam_size, -1)
            # TODO: fix attn (?)

            beam.advance(out, attn)
            any_beam_is_finished = beam.is_finished.any()
            if any_beam_is_finished:
                beam.update_finished()
                if beam.done:
                    break

            select_indices = beam.current_origin

            if any_beam_is_finished:
                # Reorder states.
                encoder_output = encoder_output.index_select(1, select_indices)
                pad_mask = pad_mask.index_select(1, select_indices)
                memory_lengths = memory_lengths.index_select(0, select_indices)

            self.map_state(lambda state, dim: state.index_select(dim, select_indices))

        outputs = beam.predictions
        outputs = [x[0] for x in outputs]
        outputs = pad_sequence(outputs, batch_first=True)
        return [outputs]
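forward_dev_beam_search expands every batch-shaped tensor to batch * beam shape with onmt.utils.misc.tile. As a mental model (an equivalent sketch, not the library's actual implementation), tile(x, count, dim) repeats each slice consecutively along dim, i.e. repeat_interleave:

def tile_like(x, count, dim=0):
    # [a, b] -> [a, a, a, b, b, b] for count=3: the beam copies of one
    # batch element stay adjacent, matching the batch-major row layout.
    return x.repeat_interleave(count, dim=dim)

enc = torch.arange(6.).view(1, 2, 3)   # (steps, batch, hidden)
print(tile_like(enc, 3, dim=1).shape)  # torch.Size([1, 6, 3])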
Example #8
 def test_repeating_excluded_index_does_not_die(self):
     # beam 0 and beam >= 2 will repeat (beam 2 repeats excluded idx)
     beam_sz = 5
     n_words = 100
     repeat_idx = 47  # will be repeated and should be blocked
     repeat_idx_ignored = 7  # will be repeated and should not be blocked
     ngram_repeat = 3
     device_init = torch.zeros(1, 1)
     for batch_sz in [1, 3]:
         beam = BeamSearch(
             beam_sz, batch_sz, 0, 1, 2, 2,
             GlobalScorerStub(), 0, 30,
             False, ngram_repeat, {repeat_idx_ignored},
             False, 0.)
         beam.initialize(device_init, torch.randint(0, 30, (batch_sz,)))
         for i in range(ngram_repeat + 4):
             # non-interesting beams are going to get dummy values
             word_probs = torch.full(
                 (batch_sz * beam_sz, n_words), -float('inf'))
             if i == 0:
                 word_probs[0::beam_sz, repeat_idx] = -0.1
                 word_probs[0::beam_sz, repeat_idx + i + 1] = -2.3
                 word_probs[0::beam_sz, repeat_idx_ignored] = -5.0
             else:
                 # predict the same thing in beam 0
                 word_probs[0::beam_sz, repeat_idx] = 0
                 # continue pushing around what beam 1 predicts
                 word_probs[1::beam_sz, repeat_idx + i + 1] = 0
                 # predict the allowed-repeat again in beam 2
                 word_probs[2::beam_sz, repeat_idx_ignored] = 0
             attns = torch.randn(1, batch_sz * beam_sz, 53)
             beam.advance(word_probs, attns)
             if i <= ngram_repeat:
                 self.assertFalse(beam.topk_log_probs[:, 0].eq(
                     self.BLOCKED_SCORE).any())
                 self.assertFalse(beam.topk_log_probs[:, 1].eq(
                     self.BLOCKED_SCORE).any())
                 self.assertFalse(beam.topk_log_probs[:, 2].eq(
                     self.BLOCKED_SCORE).any())
             else:
                 # now beam 0 dies, beam 1 -> beam 0, beam 2 -> beam 1
                 # and the rest die
                 self.assertFalse(beam.topk_log_probs[:, 0].eq(
                     self.BLOCKED_SCORE).any())
                 # since all preds after i=0 are 0, we can check
                 # that the beam is the correct idx by checking that
                 # the curr score is the initial score
                 self.assertTrue(beam.topk_log_probs[:, 0].eq(-2.3).all())
                 self.assertFalse(beam.topk_log_probs[:, 1].eq(
                     self.BLOCKED_SCORE).all())
                 self.assertTrue(beam.topk_log_probs[:, 1].eq(-5.0).all())
                 self.assertTrue(
                     beam.topk_log_probs[:, 2:].equal(
                         torch.tensor(self.BLOCKED_SCORE)
                         .repeat(batch_sz, beam_sz - 2)))
Example #9
    def test_doesnt_predict_eos_if_shorter_than_min_len(self):
        # beam 0 will always predict EOS. The other beams will predict
        # non-eos scores.
        for batch_sz in [1, 3]:
            beam_sz = 5
            n_words = 100
            _non_eos_idxs = [47, 51, 13, 88, 99]
            valid_score_dist = torch.log_softmax(torch.tensor(
                [6., 5., 4., 3., 2., 1.]), dim=0)
            min_length = 5
            eos_idx = 2
            lengths = torch.randint(0, 30, (batch_sz,))
            beam = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 2,
                              torch.device("cpu"), GlobalScorerStub(),
                              min_length, 30, False, 0, set(),
                              lengths, False, 0.)
            all_attns = []
            for i in range(min_length + 4):
                # non-interesting beams are going to get dummy values
                word_probs = torch.full(
                    (batch_sz * beam_sz, n_words), -float('inf'))
                if i == 0:
                    # "best" prediction is eos - that should be blocked
                    word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                    # include at least beam_sz predictions OTHER than EOS
                    # that are greater than -1e20
                    for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                        word_probs[0::beam_sz, j] = score
                else:
                    # predict eos in beam 0
                    word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                    # provide beam_sz other good predictions
                    for k, (j, score) in enumerate(
                            zip(_non_eos_idxs, valid_score_dist[1:])):
                        beam_idx = min(beam_sz-1, k)
                        word_probs[beam_idx::beam_sz, j] = score

                attns = torch.randn(1, batch_sz * beam_sz, 53)
                all_attns.append(attns)
                beam.advance(word_probs, attns)
                if i < min_length:
                    expected_score_dist = \
                        (i+1) * valid_score_dist[1:].unsqueeze(0)
                    self.assertTrue(
                        beam.topk_log_probs.allclose(
                            expected_score_dist))
                elif i == min_length:
                    # now the top beam has ended and no others have
                    self.assertTrue(beam.is_finished[:, 0].eq(1).all())
                    self.assertTrue(beam.is_finished[:, 1:].eq(0).all())
                else:  # i > min_length
                    # not of interest, but want to make sure it keeps running
                    # since only beam 0 terminates and n_best = 2
                    pass
Example #10
 def test_repeating_excluded_index_does_not_die(self):
     # beam 0 and beam >= 2 will repeat (beam 2 repeats excluded idx)
     beam_sz = 5
     n_words = 100
     repeat_idx = 47  # will be repeated and should be blocked
     repeat_idx_ignored = 7  # will be repeated and should not be blocked
     ngram_repeat = 3
     for batch_sz in [1, 3]:
         beam = BeamSearch(
             beam_sz, batch_sz, 0, 1, 2, 2,
             torch.device("cpu"), GlobalScorerStub(), 0, 30,
             False, ngram_repeat, {repeat_idx_ignored},
             torch.randint(0, 30, (batch_sz,)), False, 0.)
         for i in range(ngram_repeat + 4):
             # non-interesting beams are going to get dummy values
             word_probs = torch.full(
                 (batch_sz * beam_sz, n_words), -float('inf'))
             if i == 0:
                 word_probs[0::beam_sz, repeat_idx] = -0.1
                 word_probs[0::beam_sz, repeat_idx + i + 1] = -2.3
                 word_probs[0::beam_sz, repeat_idx_ignored] = -5.0
             else:
                 # predict the same thing in beam 0
                 word_probs[0::beam_sz, repeat_idx] = 0
                 # continue pushing around what beam 1 predicts
                 word_probs[1::beam_sz, repeat_idx + i + 1] = 0
                 # predict the allowed-repeat again in beam 2
                 word_probs[2::beam_sz, repeat_idx_ignored] = 0
             attns = torch.randn(1, batch_sz * beam_sz, 53)
             beam.advance(word_probs, attns)
             if i <= ngram_repeat:
                 self.assertFalse(beam.topk_log_probs[:, 0].eq(
                     self.BLOCKED_SCORE).any())
                 self.assertFalse(beam.topk_log_probs[:, 1].eq(
                     self.BLOCKED_SCORE).any())
                 self.assertFalse(beam.topk_log_probs[:, 2].eq(
                     self.BLOCKED_SCORE).any())
             else:
                 # now beam 0 dies, beam 1 -> beam 0, beam 2 -> beam 1
                 # and the rest die
                 self.assertFalse(beam.topk_log_probs[:, 0].eq(
                     self.BLOCKED_SCORE).any())
                 # since all preds after i=0 are 0, we can check
                 # that the beam is the correct idx by checking that
                 # the curr score is the initial score
                 self.assertTrue(beam.topk_log_probs[:, 0].eq(-2.3).all())
                 self.assertFalse(beam.topk_log_probs[:, 1].eq(
                     self.BLOCKED_SCORE).all())
                 self.assertTrue(beam.topk_log_probs[:, 1].eq(-5.0).all())
                 self.assertTrue(
                     beam.topk_log_probs[:, 2:].equal(
                         torch.tensor(self.BLOCKED_SCORE)
                         .repeat(batch_sz, beam_sz - 2)))
Example #11
    def test_advance_with_all_repeats_gets_blocked(self):
        # all beams repeat (beam >= 1 repeat dummy scores)
        beam_sz = 5
        n_words = 100
        repeat_idx = 47
        ngram_repeat = 3
        device_init = torch.zeros(1, 1)
        for batch_sz in [1, 3]:
            beam = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 2,
                              GlobalScorerStub(), 0, 30, False, ngram_repeat,
                              set(), False, 0.)
            beam.initialize(device_init, torch.randint(0, 30, (batch_sz, )))
            for i in range(ngram_repeat + 4):
                # predict repeat_idx over and over again
                word_probs = torch.full((batch_sz * beam_sz, n_words),
                                        -float('inf'))
                word_probs[0::beam_sz, repeat_idx] = 0

                attns = torch.randn(1, batch_sz * beam_sz, 53)
                beam.advance(word_probs, attns)

                if i < ngram_repeat:
                    # before repeat, scores are either 0 or -inf
                    expected_scores = torch.tensor(
                        [0] + [-float('inf')] * (beam_sz - 1))\
                        .repeat(batch_sz, 1)
                    self.assertTrue(beam.topk_log_probs.equal(expected_scores))
                elif i % ngram_repeat == 0:
                    # on repeat, `repeat_idx` score is BLOCKED_SCORE
                    # (but it's still the best score, thus we have
                    # [BLOCKED_SCORE, -inf, -inf, -inf, -inf])
                    expected_scores = torch.tensor(
                        [0] + [-float('inf')] * (beam_sz - 1))\
                        .repeat(batch_sz, 1)
                    expected_scores[:, 0] = self.BLOCKED_SCORE
                    self.assertTrue(beam.topk_log_probs.equal(expected_scores))
                else:
                    # repetition keeps maximizing the score:
                    # index 0 has been blocked, so repeating => +0.0 score;
                    # other indexes are -inf, so repeating => BLOCKED_SCORE,
                    # which is higher
                    expected_scores = torch.full(
                        [batch_sz, beam_sz], self.BLOCKED_SCORE)
                    self.assertTrue(
                        beam.topk_log_probs.equal(expected_scores))
Example #12
 def test_advance_with_some_repeats_gets_blocked(self):
     # beam 0 and beam >=2 will repeat (beam >= 2 repeat dummy scores)
     beam_sz = 5
     n_words = 100
     repeat_idx = 47
     ngram_repeat = 3
     for batch_sz in [1, 3]:
         beam = BeamSearch(
             beam_sz, batch_sz, 0, 1, 2, 2,
             torch.device("cpu"), GlobalScorerStub(), 0, 30,
             False, ngram_repeat, set(),
             torch.randint(0, 30, (batch_sz,)), False, 0.)
         for i in range(ngram_repeat + 4):
             # non-interesting beams are going to get dummy values
             word_probs = torch.full(
                 (batch_sz * beam_sz, n_words), -float('inf'))
             if i == 0:
                 # on initial round, only predicted scores for beam 0
                 # matter. Make two predictions. Top one will be repeated
                 # in beam zero, second one will live on in beam 1.
                 word_probs[0::beam_sz, repeat_idx] = -0.1
                 word_probs[0::beam_sz, repeat_idx + i + 1] = -2.3
             else:
                 # predict the same thing in beam 0
                 word_probs[0::beam_sz, repeat_idx] = 0
                 # continue pushing around what beam 1 predicts
                 word_probs[1::beam_sz, repeat_idx + i + 1] = 0
             attns = torch.randn(1, batch_sz * beam_sz, 53)
             beam.advance(word_probs, attns)
             if i <= ngram_repeat:
                 self.assertFalse(
                     beam.topk_log_probs[0::beam_sz].eq(
                         self.BLOCKED_SCORE).any())
                 self.assertFalse(
                     beam.topk_log_probs[1::beam_sz].eq(
                         self.BLOCKED_SCORE).any())
             else:
                 # now beam 0 dies (along with the others), beam 1 -> beam 0
                 self.assertFalse(
                     beam.topk_log_probs[:, 0].eq(
                         self.BLOCKED_SCORE).any())
                 self.assertTrue(
                     beam.topk_log_probs[:, 1:].equal(
                         torch.tensor(self.BLOCKED_SCORE)
                         .repeat(batch_sz, beam_sz-1)))
Example #13
 def test_advance_with_some_repeats_gets_blocked(self):
     # beam 0 and beam >=2 will repeat (beam >= 2 repeat dummy scores)
     beam_sz = 5
     n_words = 100
     repeat_idx = 47
     ngram_repeat = 3
     for batch_sz in [1, 3]:
         beam = BeamSearch(
             beam_sz, batch_sz, 0, 1, 2, 2,
             torch.device("cpu"), GlobalScorerStub(), 0, 30,
             False, ngram_repeat, set(),
             torch.randint(0, 30, (batch_sz,)))
         for i in range(ngram_repeat + 4):
             # non-interesting beams are going to get dummy values
             word_probs = torch.full(
                 (batch_sz * beam_sz, n_words), -float('inf'))
             if i == 0:
                 # on initial round, only predicted scores for beam 0
                 # matter. Make two predictions. Top one will be repeated
                 # in beam zero, second one will live on in beam 1.
                 word_probs[0::beam_sz, repeat_idx] = -0.1
                 word_probs[0::beam_sz, repeat_idx + i + 1] = -2.3
             else:
                 # predict the same thing in beam 0
                 word_probs[0::beam_sz, repeat_idx] = 0
                 # continue pushing around what beam 1 predicts
                 word_probs[1::beam_sz, repeat_idx + i + 1] = 0
              attns = torch.randn(1, batch_sz * beam_sz, 53)
             beam.advance(word_probs, attns)
             if i <= ngram_repeat:
                 self.assertFalse(
                     beam.topk_log_probs[0::beam_sz].eq(
                         self.BLOCKED_SCORE).any())
                 self.assertFalse(
                     beam.topk_log_probs[1::beam_sz].eq(
                         self.BLOCKED_SCORE).any())
             else:
                 # now beam 0 dies (along with the others), beam 1 -> beam 0
                 self.assertFalse(
                     beam.topk_log_probs[:, 0].eq(
                         self.BLOCKED_SCORE).any())
                 self.assertTrue(
                     beam.topk_log_probs[:, 1:].equal(
                         torch.tensor(self.BLOCKED_SCORE)
                         .repeat(batch_sz, beam_sz-1)))
Example #14
    def _translate_batch(self,
                         batch,
                         src_vocabs,
                         max_length,
                         min_length=0,
                         ratio=0.,
                         n_best=1,
                         return_attention=False):
        # TODO: support these blacklisted features.
        assert not self.dump_beam

        # (0) Prep the components of the search.
        use_src_map = self.copy_attn
        beam_size = self.beam_size
        batch_size = batch.batch_size

        #### TODO: Augment batch with distractors

        # (1) Run the encoder on the src.
        src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
        # src has shape [1311, 2, 1]
        # enc_states has shape [1311, 2, 512],
        # Memory_bank has shape [1311, 2, 512]
        self.model.decoder.init_state(src, memory_bank, enc_states)

        results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "batch": batch,
            "gold_score": self._gold_score(
                batch, memory_bank, src_lengths, src_vocabs,
                use_src_map, enc_states, batch_size, src)
        }

        # (2) Repeat src objects `beam_size` times.
        # We use batch_size x beam_size
        src_map = (tile(batch.src_map, beam_size, dim=1)
                   if use_src_map else None)
        self.model.decoder.map_state(
            lambda state, dim: tile(state, beam_size, dim=dim))

        if isinstance(memory_bank, tuple):
            memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
            mb_device = memory_bank[0].device
        else:
            memory_bank = tile(memory_bank, beam_size, dim=1)
            mb_device = memory_bank.device
        memory_lengths = tile(src_lengths, beam_size)
        print('memory_bank size after tile:',
              memory_bank.shape)  #[1311, 20, 512]

        # (0) pt 2, prep the beam object
        beam = BeamSearch(beam_size,
                          n_best=n_best,
                          batch_size=batch_size,
                          global_scorer=self.global_scorer,
                          pad=self._tgt_pad_idx,
                          eos=self._tgt_eos_idx,
                          bos=self._tgt_bos_idx,
                          min_length=min_length,
                          ratio=ratio,
                          max_length=max_length,
                          mb_device=mb_device,
                          return_attention=return_attention,
                          stepwise_penalty=self.stepwise_penalty,
                          block_ngram_repeat=self.block_ngram_repeat,
                          exclusion_tokens=self._exclusion_idxs,
                          memory_lengths=memory_lengths)

        all_log_probs = []
        all_attn = []

        for step in range(max_length):
            decoder_input = beam.current_predictions.view(1, -1, 1)
            # decoder_input has shape[1,20,1]
            # decoder_input gives top 10 predictions for each batch element
            verbose = (step == 10)
            log_probs, attn = self._decode_and_generate(
                decoder_input,
                memory_bank,
                batch,
                src_vocabs,
                memory_lengths=memory_lengths,
                src_map=src_map,
                step=step,
                batch_offset=beam._batch_offset,
                verbose=verbose)

            # log_probs and attn are the probs for next word given that the
            # current word is that in decoder_input
            all_log_probs.append(log_probs)
            all_attn.append(attn)

            beam.advance(log_probs, attn, verbose=verbose)

            any_beam_is_finished = beam.is_finished.any()
            if any_beam_is_finished:
                beam.update_finished()
                if beam.done:
                    break

            select_indices = beam.current_origin

            if any_beam_is_finished:
                # Reorder states.
                if isinstance(memory_bank, tuple):
                    memory_bank = tuple(
                        x.index_select(1, select_indices) for x in memory_bank)
                else:
                    memory_bank = memory_bank.index_select(1, select_indices)

                memory_lengths = memory_lengths.index_select(0, select_indices)

                if src_map is not None:
                    src_map = src_map.index_select(1, select_indices)

            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        print('batch_size:', batch_size)
        print('max_length:', max_length)
        print('all_log_probs has len', len(all_log_probs))
        print('all_log_probs[0].shape', all_log_probs[0].shape)
        print('comparing log_probs[2]', all_log_probs[2][:, 0])

        results["scores"] = beam.scores
        results["predictions"] = beam.predictions
        results["attention"] = beam.attention
        return results
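The results handed back above are nested per batch example: beam.predictions holds, for each example, a list of up to n_best 1-D tensors of token ids, with beam.scores parallel to it. A sketch of unpacking them (structure inferred from the usage above, so treat it as an assumption):

for b, (preds, scores) in enumerate(zip(results["predictions"],
                                        results["scores"])):
    best = preds[0]  # highest-scoring hypothesis for example b
    print("example %d: score %.3f, %d tokens" % (b, scores[0], len(best)))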
Example #15
    def _translate_batch(
            self,
            src,
            src_lengths,
            batch_size,
            min_length=0,
            ratio=0.,
            n_best=1,
            return_attention=False):

        max_length = self.config.max_sequence_length + 1 # to account for EOS
        beam_size = 3
        
        # Encoder forward.
        enc_states, memory_bank, src_lengths = self.encoder(src, src_lengths)
        self.decoder.init_state(src, memory_bank, enc_states)

        results = { "predictions": None, "scores": None, "attention": None }

        # (2) Repeat src objects `beam_size` times.
        # We use batch_size x beam_size
        self.decoder.map_state(lambda state, dim: tile(state, beam_size, dim=dim))

        #if isinstance(memory_bank, tuple):
        #    memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
        #    mb_device = memory_bank[0].device
        #else:
        memory_bank = tile(memory_bank, beam_size, dim=1)
        memory_lengths = tile(src_lengths, beam_size)
        mb_device = memory_bank[0].device if isinstance(memory_bank, tuple) \
            else memory_bank.device
        
        block_ngram_repeat = 0
        _exclusion_idxs = set()

        # (0) pt 2, prep the beam object
        beam = BeamSearch(
            beam_size,
            n_best=n_best,
            batch_size=batch_size,
            global_scorer=self.scorer,
            pad=self.config.tgt_padding,
            eos=self.config.tgt_eos,
            bos=self.config.tgt_bos,
            min_length=min_length,
            ratio=ratio,
            max_length=max_length,
            mb_device=mb_device,
            return_attention=return_attention,
            stepwise_penalty=None,
            block_ngram_repeat=block_ngram_repeat,
            exclusion_tokens=_exclusion_idxs,
            memory_lengths=memory_lengths)

        for step in range(max_length):
            decoder_input = beam.current_predictions.view(1, -1, 1)

            log_probs, attn = self._decode_and_generate(decoder_input, memory_bank, memory_lengths, step, pretraining = True)

            beam.advance(log_probs, attn)
            any_beam_is_finished = beam.is_finished.any()
            if any_beam_is_finished:
                beam.update_finished()
                if beam.done:
                    break

            select_indices = beam.current_origin

            if any_beam_is_finished:
                # Reorder states.
                if isinstance(memory_bank, tuple):
                    memory_bank = tuple(x.index_select(1, select_indices)
                                        for x in memory_bank)
                else:
                    memory_bank = memory_bank.index_select(1, select_indices)

                memory_lengths = memory_lengths.index_select(0, select_indices)

            self.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        results["scores"] = beam.scores
        results["predictions"] = beam.predictions
        results["attention"] = beam.attention
        return results
Example #16
def run(
        n_iterations=7,
        beam_sz=5,
        batch_sz=1,
        n_words=100,
        repeat_idx=47,
        ngram_repeat=-1,
        repeat_logprob=-0.2,
        no_repeat_logprob=-0.1,
        base_logprob=float("-inf"),
        verbose=True,
):
    """
        At each timestep `i`:
            - token `repeat_idx` get a logprob of `repeat_logprob`
            - token `repeat_idx+i` get `no_repeat_logprob`
            - other tokens get `base_logprob`

    """
    def log(*args, **kwargs):
        if verbose:
            print(*args, **kwargs)

    device_init = torch.zeros(1, 1)

    beam = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 4, GlobalScorerStub(), 0, 30,
                      False, ngram_repeat, set(), False, 0.)
    beam.initialize(device_init, torch.randint(0, 30, (batch_sz, )))
    for i in range(n_iterations):
        # non-interesting beams are going to get dummy values
        word_logprobs = torch.full((batch_sz * beam_sz, n_words), base_logprob)
        if i == 0:
            # on initial round, only predicted scores for beam 0
            # matter. Make two predictions. Top one will be repeated
            # in beam zero, second one will live on in beam 1.
            word_logprobs[0::beam_sz, repeat_idx] = repeat_logprob
            word_logprobs[0::beam_sz, repeat_idx + i + 1] = no_repeat_logprob
        else:
            # keep predicting the repeated token in beam 0
            word_logprobs[0::beam_sz, repeat_idx] = repeat_logprob
            # and keep offering a fresh, non-repeating continuation
            word_logprobs[0::beam_sz, repeat_idx + i + 1] = no_repeat_logprob
        attns = torch.randn(1, batch_sz * beam_sz, 0)
        beam.advance(word_logprobs, attns)
        if ngram_repeat > 0:
            # NOTE: IGNORE IT FOR NOW
            # if i < ngram_repeat:
            #     assertFalse(
            #         beam.topk_log_probs[0::beam_sz].eq(
            #             BLOCKED_SCORE).any())
            #     assertFalse(
            #         beam.topk_log_probs[1::beam_sz].eq(
            #             BLOCKED_SCORE).any())
            # elif i == ngram_repeat:
            #     assertFalse(
            #         beam.topk_log_probs[:, 0].eq(
            #             BLOCKED_SCORE).any())

            #     expected = torch.full([batch_sz, beam_sz], base_logprob)
            #     expected[:, 0] = (i+1) * no_repeat_logprob
            #     expected[:, 1] = BLOCKED_SCORE
            # else:
            #     expected = torch.full([batch_sz, beam_sz], base_logprob)
            #     expected[:, 0] = i * no_repeat_logprob
            #     expected[:, 1] = BLOCKED_SCORE
            pass
        # log("Iteration (%d): expected %s" % (i, str(expected)))
        log("Iteration (%d): logprobs %s" % (i, str(beam.topk_log_probs)))
        log("Iteration (%d): seq %s" % (i, str(beam.alive_seq)))
        log("Iteration (%d): indices %s" % (i, str(beam.topk_ids)))
        if ngram_repeat > 0:
            log("Iteration (%d): blocked %s" % (i, str(beam.forbidden_tokens)))
    return beam.topk_log_probs, beam.alive_seq
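A hypothetical invocation of the probe above (argument values are only illustrative):

topk_log_probs, alive_seq = run(n_iterations=7, ngram_repeat=3, verbose=False)
# topk_log_probs: (batch_sz, beam_sz) running beam scores
# alive_seq: (batch_sz * beam_sz, n_iterations + 1) token ids including BOS,
# assuming no beam finishes early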
Example #17
    def test_beam_returns_attn_with_correct_length(self):
        beam_sz = 5
        batch_sz = 3
        n_words = 100
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(torch.tensor(
            [6., 5., 4., 3., 2., 1.]), dim=0)
        min_length = 5
        eos_idx = 2
        inp_lens = torch.randint(1, 30, (batch_sz,))
        beam = BeamSearch(
            beam_sz, batch_sz, 0, 1, 2, 2,
            torch.device("cpu"), GlobalScorerStub(),
            min_length, 30, True, 0, set(),
            inp_lens, False, 0.)
        for i in range(min_length + 2):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full(
                (batch_sz * beam_sz, n_words), -float('inf'))
            if i == 0:
                # "best" prediction is eos - that should be blocked
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                # include at least beam_sz predictions OTHER than EOS
                # that are greater than -1e20
                for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                    word_probs[0::beam_sz, j] = score
            elif i <= min_length:
                # predict eos in beam 1
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx::beam_sz, j] = score
            else:
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx::beam_sz, j] = score

            attns = torch.randn(1, batch_sz * beam_sz, 53)
            beam.advance(word_probs, attns)
            if i < min_length:
                self.assertFalse(beam.done)
                # no top beams are finished yet
                for b in range(batch_sz):
                    self.assertEqual(beam.attention[b], [])
            elif i == min_length:
                # beam 1 dies on min_length
                self.assertTrue(beam.is_finished[:, 1].all())
                beam.update_finished()
                self.assertFalse(beam.done)
                # no top beams are finished yet
                for b in range(batch_sz):
                    self.assertEqual(beam.attention[b], [])
            else:  # i > min_length
                # beam 0 dies on the step after beam 1 dies
                self.assertTrue(beam.is_finished[:, 0].all())
                beam.update_finished()
                self.assertTrue(beam.done)
                # top beam is finished now so there are attentions
                for b in range(batch_sz):
                    # two beams are finished in each batch
                    self.assertEqual(len(beam.attention[b]), 2)
                    for k in range(2):
                        # second dim is cut down to the non-padded src length
                        self.assertEqual(beam.attention[b][k].shape[-1],
                                         inp_lens[b])
                    # first dim is equal to the time of death
                    # (beam 0 died at current step - adjust for SOS)
                    self.assertEqual(beam.attention[b][0].shape[0], i+1)
                    # (beam 1 died at last step - adjust for SOS)
                    self.assertEqual(beam.attention[b][1].shape[0], i)
                # behavior gets weird when beam is already done so just stop
                break
Example #18
    def _translate_batch(
            self,
            batch,
            src_vocabs,
            max_length,
            min_length=0,
            ratio=0.,
            n_best=1,
            return_attention=False,
            xlation_builder=None,
    ):
        # TODO: support these blacklisted features.
        assert not self.dump_beam

        # (0) Prep the components of the search.
        use_src_map = self.copy_attn
        beam_size = self.beam_size
        batch_size = batch.batch_size

        # (1) Run the encoder on the src.
        src_list, enc_states_list, memory_bank_list, src_lengths_list = self._run_encoder(batch)
        self.model.decoder.init_state(src_list, memory_bank_list, enc_states_list)

        results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "batch": batch,
            "gold_score": self._gold_score(
                batch, memory_bank_list, src_lengths_list, src_vocabs, use_src_map,
                enc_states_list, batch_size, src_list)}
        
        # (2) Repeat src objects `beam_size` times.
        # We use batch_size x beam_size
        src_map_list = list()
        for src_type in self.src_types:
            src_map_list.append((tile(getattr(batch, f"src_map.{src_type}"), beam_size, dim=1) if use_src_map else None))
        # end for

        self.model.decoder.map_state(
            lambda state, dim: tile(state, beam_size, dim=dim))

        memory_lengths_list = list()
        memory_lengths = list()
        for src_i in range(len(memory_bank_list)):
            if isinstance(memory_bank_list[src_i], tuple):
                memory_bank_list[src_i] = tuple(tile(x, beam_size, dim=1) for x in memory_bank_list[src_i])
                mb_device = memory_bank_list[src_i][0].device
            else:
                memory_bank_list[src_i] = tile(memory_bank_list[src_i], beam_size, dim=1)
                mb_device = memory_bank_list[src_i].device
            # end if
            memory_lengths_list.append(tile(src_lengths_list[src_i], beam_size))
            memory_lengths.append(src_lengths_list[src_i])
        # end for
        memory_lengths = tile(torch.stack(memory_lengths, dim=0).sum(dim=0), beam_size)

        indexes = tile(torch.tensor(list(range(batch_size)), device=self._dev), beam_size)

        # (0) pt 2, prep the beam object
        beam = BeamSearch(
            beam_size,
            n_best=n_best,
            batch_size=batch_size,
            global_scorer=self.global_scorer,
            pad=self._tgt_pad_idx,
            eos=self._tgt_eos_idx,
            bos=self._tgt_bos_idx,
            min_length=min_length,
            ratio=ratio,
            max_length=max_length,
            mb_device=mb_device,
            return_attention=return_attention,
            stepwise_penalty=self.stepwise_penalty,
            block_ngram_repeat=self.block_ngram_repeat,
            exclusion_tokens=self._exclusion_idxs,
            memory_lengths=memory_lengths)

        for step in range(max_length):
            decoder_input = beam.current_predictions.view(1, -1, 1)

            log_probs, attn = self._decode_and_generate(
                decoder_input,
                memory_bank_list,
                batch,
                src_vocabs,
                memory_lengths_list=memory_lengths_list,
                src_map_list=src_map_list,
                step=step,
                batch_offset=beam._batch_offset)

            if self.reranker is not None:
                log_probs = self.reranker.rerank_step_beam_batch(
                    batch,
                    beam,
                    self.beam_size,
                    indexes,
                    log_probs,
                    attn,
                    self.fields["tgt"].base_field.vocab,
                    xlation_builder,
                )
            # end if

            non_finished = None
            beam.advance(log_probs, attn)
            any_beam_is_finished = beam.is_finished.any()
            if any_beam_is_finished:
                non_finished = beam.update_finished()
                if beam.done:
                    break

            select_indices = beam.current_origin

            if any_beam_is_finished:
                # Reorder states.
                for src_i in range(len(memory_bank_list)):
                    if isinstance(memory_bank_list[src_i], tuple):
                        memory_bank_list[src_i] = tuple(x.index_select(1, select_indices)
                                            for x in memory_bank_list[src_i])
                    else:
                        memory_bank_list[src_i] = memory_bank_list[src_i].index_select(1, select_indices)
                    # end if

                    memory_lengths_list[src_i] = memory_lengths_list[src_i].index_select(0, select_indices)
                # end for

                if use_src_map and src_map_list[0] is not None:
                    for src_i in range(len(src_map_list)):
                        src_map_list[src_i] = src_map_list[src_i].index_select(1, select_indices)
                    # end for
                # end if

                indexes = indexes.index_select(0, select_indices)

            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        results["scores"] = beam.scores
        results["predictions"] = beam.predictions
        results["attention"] = beam.attention
        return results
Example #19
def batch_beam_search_trs(model, inputs, batch_size, device="cpu", max_len=128, beam_size=20, n_best=1, alpha=1.):
    """ beam search with batch input for Transformer model

    Arguments:
        beam {onmt.BeamSearch} -- opennmt BeamSearch class
        model {torch.nn.Module} -- subclass of torch.nn.Module, required to implement .encode() and .decode() method
        inputs {list} -- list of torch.Tensor for input of encode()

    Keyword Arguments:
        device {str} -- device to eval model (default: {"cpu"})

    Returns:
        result -- 2D list (B, N-best), each element is an (seq, score) pair
    """
    beam = BeamSearch(beam_size, batch_size,
                    pad=C.PAD, bos=C.BOS, eos=C.EOS,
                    n_best=n_best,
                    mb_device=device,
                    global_scorer=GNMTGlobalScorer(alpha, 0.1, "avg", "none"),
                    min_length=0,
                    max_length=max_len,
                    ratio=0.0,
                    memory_lengths=None,
                    block_ngram_repeat=0,
                    exclusion_tokens=set(),
                    stepwise_penalty=True,
                    return_attention=False,
                    )
    model.eval()
    is_finished = [False] * beam.batch_size
    with torch.no_grad():
        src_ids, _, src_pos, _, _, src_key_padding_mask, _, original_memory_key_padding_mask = list(
            map(lambda x: x.to(device), inputs))
        
        if src_ids.shape[1]!=batch_size:
            diff = batch_size - src_ids.shape[1]
            src_ids = torch.cat([src_ids] + [src_ids[:,:1]] * diff, dim=1)
            src_pos = torch.cat([src_pos] + [src_pos[:,:1]] * diff, dim=1)
            src_key_padding_mask = torch.cat([src_key_padding_mask]+[src_key_padding_mask[:1]]* diff, dim=0)
            original_memory_key_padding_mask = torch.cat([original_memory_key_padding_mask] +[original_memory_key_padding_mask[:1]]*diff, dim=0)


        model.to(device)
        original_memory = model.encode(src_ids, src_pos, src_key_padding_mask=src_key_padding_mask)

        memory = original_memory
        memory_key_padding_mask = original_memory_key_padding_mask
        while not beam.done:
            len_decoder_inputs = beam.alive_seq.shape[1]
            dec_pos = torch.arange(1, len_decoder_inputs+1).repeat(beam.alive_seq.shape[0], 1).permute(1, 0).to(device)

            # unsqueeze the memory and memory_key_padding_mask in B dim to match the size (BM*BS)
            repeated_memory = memory.repeat(1, 1, beam.beam_size).reshape(
                memory.shape[0], -1, memory.shape[-1])
            repeated_memory_key_padding_mask = memory_key_padding_mask.repeat(
                1, beam.beam_size).reshape(-1, memory_key_padding_mask.shape[1])

            decoder_outputs = model.decode(beam.alive_seq.permute(1, 0), dec_pos, _, repeated_memory, memory_key_padding_mask=repeated_memory_key_padding_mask)[-1]
            if hasattr(model, "proj"):
                logits = model.proj(decoder_outputs)
            elif hasattr(model, "gen"):
                logits = model.gen(decoder_outputs)
            else:
                raise ValueError("Unknown generator!")

            log_probs = torch.nn.functional.log_softmax(logits, dim=1)
            beam.advance(log_probs, None)
            if beam.is_finished.any():
                beam.update_finished()

                # select data for the still-alive index
                for i, preds in enumerate(beam.predictions):
                    if not is_finished[i] and len(preds) == beam.n_best:
                        is_finished[i] = True

                alive_example_idx = [i for i in range(
                    len(is_finished)) if not is_finished[i]]
                if alive_example_idx:
                    memory = original_memory[:, alive_example_idx, :]
                    memory_key_padding_mask = original_memory_key_padding_mask[alive_example_idx]

    # packing data for easy accessing
    results = []
    for batch_preds, batch_scores in zip(beam.predictions, beam.scores):
        n_best_result = []
        for n_best_pred, n_best_score in zip(batch_preds, batch_scores):
            assert isinstance(n_best_pred, torch.Tensor)
            assert isinstance(n_best_score, torch.Tensor)
            n_best_result.append(
                (n_best_pred.tolist(), n_best_score.item())
            )
        results.append(n_best_result)

    return results
Example #20
    def test_beam_returns_attn_with_correct_length(self):
        beam_sz = 5
        batch_sz = 3
        n_words = 100
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(torch.tensor(
            [6., 5., 4., 3., 2., 1.]),
                                             dim=0)
        min_length = 5
        eos_idx = 2
        inp_lens = torch.randint(1, 30, (batch_sz, ))
        beam = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 2, GlobalScorerStub(),
                          min_length, 30, True, 0, set(), False, 0.)
        device_init = torch.zeros(1, 1)
        _, _, inp_lens, _ = beam.initialize(device_init, inp_lens)
        # inp_lens is tiled in initialize, reassign to make attn match
        for i in range(min_length + 2):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full((batch_sz * beam_sz, n_words),
                                    -float('inf'))
            if i == 0:
                # "best" prediction is eos - that should be blocked
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                # include at least beam_sz predictions OTHER than EOS
                # that are greater than -1e20
                for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                    word_probs[0::beam_sz, j] = score
            elif i <= min_length:
                # predict eos in beam 1
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz - 1, k)
                    word_probs[beam_idx::beam_sz, j] = score
            else:
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz - 1, k)
                    word_probs[beam_idx::beam_sz, j] = score

            attns = torch.randn(1, batch_sz * beam_sz, 53)
            beam.advance(word_probs, attns)
            if i < min_length:
                self.assertFalse(beam.done)
                # no top beams are finished yet
                for b in range(batch_sz):
                    self.assertEqual(beam.attention[b], [])
            elif i == min_length:
                # beam 1 dies on min_length
                self.assertTrue(beam.is_finished[:, 1].all())
                beam.update_finished()
                self.assertFalse(beam.done)
                # no top beams are finished yet
                for b in range(batch_sz):
                    self.assertEqual(beam.attention[b], [])
            else:  # i > min_length
                # beam 0 dies on the step after beam 1 dies
                self.assertTrue(beam.is_finished[:, 0].all())
                beam.update_finished()
                self.assertTrue(beam.done)
                # top beam is finished now so there are attentions
                for b in range(batch_sz):
                    # two beams are finished in each batch
                    self.assertEqual(len(beam.attention[b]), 2)
                    for k in range(2):
                        # second dim is cut down to the non-padded src length
                        self.assertEqual(beam.attention[b][k].shape[-1],
                                         inp_lens[b])
                    # first dim is equal to the time of death
                    # (beam 0 died at current step - adjust for SOS)
                    self.assertEqual(beam.attention[b][0].shape[0], i + 1)
                    # (beam 1 died at last step - adjust for SOS)
                    self.assertEqual(beam.attention[b][1].shape[0], i)
                # behavior gets weird when beam is already done so just stop
                break
Example #21
                   mb_device=mb_device,
                   return_attention=return_attention,
                   stepwise_penalty=stepwise_penalty,
                   block_ngram_repeat=block_ngram_repeat,
                   exclusion_tokens=exclusion_idxs,
                   memory_lengths=memory_lengths)
 for step in range(max_length):
     decoder_input = beam.current_predictions.view(1, -1, 1)
     dec_out, dec_attn = model.decoder(decoder_input,
                                       memory_bank,
                                       memory_lengths=memory_lengths,
                                       step=step)
     log_probs = model.generator(dec_out.squeeze(0))
     attn = dec_attn['std']
     log_probs, attn = log_probs.detach(), attn.detach()
     beam.advance(log_probs, attn)
     any_beam_is_finished = beam.is_finished.any()
     if any_beam_is_finished:
         beam.update_finished()
         if beam.done:
             break
     select_indices = beam.current_origin
     if any_beam_is_finished:
         if isinstance(memory_bank, tuple):
             memory_bank = tuple(
                 x.index_select(1, select_indices) for x in memory_bank)
         else:
             memory_bank = memory_bank.index_select(1, select_indices)
         memory_lengths = memory_lengths.index_select(0, select_indices)
     model.decoder.map_state(
         lambda state, dim: state.index_select(dim, select_indices))