Example #1
 def test_advance_with_all_repeats_gets_blocked(self):
     # all beams repeat (beam >= 1 repeat dummy scores)
     beam_sz = 5
     n_words = 100
     repeat_idx = 47
     ngram_repeat = 3
     beam = Beam(beam_sz, 0, 1, 2, n_best=2,
                 exclusion_tokens=set(),
                 global_scorer=GlobalScorerStub(),
                 block_ngram_repeat=ngram_repeat)
     for i in range(ngram_repeat + 4):
         # predict repeat_idx over and over again
         word_probs = torch.full((beam_sz, n_words), -float('inf'))
         word_probs[0, repeat_idx] = 0
         attns = torch.randn(beam_sz)
         beam.advance(word_probs, attns)
         if i <= ngram_repeat:
             self.assertTrue(
                 beam.scores.equal(
                     torch.tensor(
                         [0] + [-float('inf')] * (beam_sz - 1))))
         else:
             self.assertTrue(
                 beam.scores.equal(torch.tensor(
                     [self.BLOCKED_SCORE] * beam_sz)))
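These tests lean on two fixtures that are not shown: a GlobalScorerStub passed as global_scorer, and a BLOCKED_SCORE constant. A minimal sketch of the stub, assuming the Beam only ever calls score() and update_global_state() on its scorer; the BLOCKED_SCORE value is an assumption and must mirror whatever penalty the Beam implementation assigns to blocked hypotheses:

BLOCKED_SCORE = -10e20  # assumed; must match the Beam's blocking penalty

class GlobalScorerStub(object):
    def update_global_state(self, beam):
        # no-op: keeps no coverage or length statistics
        pass

    def score(self, beam, scores):
        # identity scoring: leaves the beam's log-probs untouched
        return scores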
Example #2
 def test_advance_with_some_repeats_gets_blocked(self):
     # beam 0 and beam >=2 will repeat (beam >= 2 repeat dummy scores)
     beam_sz = 5
     n_words = 100
     repeat_idx = 47
     ngram_repeat = 3
     beam = Beam(beam_sz, 0, 1, 2, n_best=2,
                 exclusion_tokens=set(),
                 global_scorer=GlobalScorerStub(),
                 block_ngram_repeat=ngram_repeat)
     for i in range(ngram_repeat + 4):
         # non-interesting beams are going to get dummy values
         word_probs = torch.full((beam_sz, n_words), -float('inf'))
         if i == 0:
             # on initial round, only predicted scores for beam 0
             # matter. Make two predictions. Top one will be repeated
             # in beam zero, second one will live on in beam 1.
             word_probs[0, repeat_idx] = -0.1
             word_probs[0, repeat_idx + i + 1] = -2.3
         else:
             # predict the same thing in beam 0
             word_probs[0, repeat_idx] = 0
             # continue pushing around what beam 1 predicts
             word_probs[1, repeat_idx + i + 1] = 0
         attns = torch.randn(beam_sz)
         beam.advance(word_probs, attns)
         if i <= ngram_repeat:
             self.assertFalse(beam.scores[0].eq(self.BLOCKED_SCORE))
             self.assertFalse(beam.scores[1].eq(self.BLOCKED_SCORE))
         else:
             # now beam 0 dies (along with the others), beam 1 -> beam 0
             self.assertFalse(beam.scores[0].eq(self.BLOCKED_SCORE))
             self.assertTrue(
                 beam.scores[1:].equal(torch.tensor(
                     [self.BLOCKED_SCORE] * (beam_sz - 1))))
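The blocking rule these tests exercise can be sketched independently of the library (a hypothetical check, not the Beam implementation): a hypothesis is penalized once its trailing ngram_repeat tokens already occur earlier in its token sequence.

def repeats_ngram(tokens, n):
    # True once the trailing n-gram already occurs earlier in the sequence.
    tail = tuple(tokens[-n:])
    seen = {tuple(tokens[i:i + n]) for i in range(len(tokens) - n)}
    return tail in seen

# the stream beam 0 keeps predicting above: 47, 47, 47, ...
assert repeats_ngram([47] * 6, 3)
assert not repeats_ngram([47, 48, 47, 49], 3)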
Example #3
    def test_beam_is_done_when_n_best_beams_eos_using_min_length(self):
        # this is also a test that when block_ngram_repeat=0,
        # repeating is acceptable
        beam_sz = 5
        n_words = 100
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(torch.tensor(
            [6., 5., 4., 3., 2., 1.]), dim=0)
        min_length = 5
        eos_idx = 2
        beam = Beam(beam_sz, 0, 1, eos_idx, n_best=2,
                    exclusion_tokens=set(),
                    min_length=min_length,
                    global_scorer=GlobalScorerStub(),
                    block_ngram_repeat=0)
        for i in range(min_length + 4):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full((beam_sz, n_words), -float('inf'))
            if i == 0:
                # "best" prediction is eos - that should be blocked
                word_probs[0, eos_idx] = valid_score_dist[0]
                # include at least beam_sz predictions OTHER than EOS
                # that are greater than -1e20
                for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                    word_probs[0, j] = score
            elif i <= min_length:
                # predict eos in beam 1
                word_probs[1, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx, j] = score
            else:
                word_probs[0, eos_idx] = valid_score_dist[0]
                word_probs[1, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx, j] = score

            attns = torch.randn(beam_sz)
            beam.advance(word_probs, attns)
            if i < min_length:
                self.assertFalse(beam.done)
            elif i == min_length:
                # beam 1 dies on min_length
                self.assertEqual(beam.finished[0][1], beam.min_length + 1)
                self.assertEqual(beam.finished[0][2], 1)
                self.assertFalse(beam.done)
            else:  # i > min_length
                # beam 0 dies on the step after beam 1 dies
                self.assertEqual(beam.finished[1][1], beam.min_length + 2)
                self.assertEqual(beam.finished[1][2], 0)
                self.assertTrue(beam.done)
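The indexing into beam.finished above relies on an entry layout that is only implied here; inferred from the assertions (not confirmed against the Beam source), each entry looks like (global_score, timestep, beam_index), with the timestep offset by one for the initial BOS step:

# toy illustration of the assumed (score, timestep, beam_idx) layout
finished = [(-0.37, 6, 1),   # beam 1 finishes at min_length + 1
            (-0.41, 7, 0)]   # beam 0 finishes one step later
assert finished[0][1] == 6 and finished[0][2] == 1
assert finished[1][1] == 7 and finished[1][2] == 0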
Example #4
    def test_doesnt_predict_eos_if_shorter_than_min_len(self):
        # beam 0 will always predict EOS. The other beams will predict
        # non-eos scores.
        # this is also a test that when block_ngram_repeat=0,
        # repeating is acceptable
        beam_sz = 5
        n_words = 100
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(torch.tensor(
            [6., 5., 4., 3., 2., 1.]), dim=0)
        min_length = 5
        eos_idx = 2
        beam = Beam(beam_sz, 0, 1, eos_idx, n_best=2,
                    exclusion_tokens=set(),
                    min_length=min_length,
                    global_scorer=GlobalScorerStub(),
                    block_ngram_repeat=0)
        for i in range(min_length + 4):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full((beam_sz, n_words), -float('inf'))
            if i == 0:
                # "best" prediction is eos - that should be blocked
                word_probs[0, eos_idx] = valid_score_dist[0]
                # include at least beam_sz predictions OTHER than EOS
                # that are greater than -1e20
                for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                    word_probs[0, j] = score
            else:
                # predict eos in beam 0
                word_probs[0, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx, j] = score

            attns = torch.randn(beam_sz)
            beam.advance(word_probs, attns)
            if i < min_length:
                expected_score_dist = (i+1) * valid_score_dist[1:]
                self.assertTrue(beam.scores.allclose(expected_score_dist))
            elif i == min_length:
                # now the top beam has ended and no others have
                # first beam finished had length beam.min_length
                self.assertEqual(beam.finished[0][1], beam.min_length + 1)
                # first beam finished was 0
                self.assertEqual(beam.finished[0][2], 0)
            else:  # i > min_length
                # not of interest, but want to make sure it keeps running
                # since only beam 0 terminates and n_best = 2
                pass
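The min-length behavior both of the preceding tests depend on can be sketched as a masking step (an assumption about the mechanism, not the library's code): until min_length tokens have been generated, the EOS column is forced to a score so low it can never claim a top-k slot.

import torch

def mask_eos(word_probs, cur_len, min_length, eos_idx):
    # disallow EOS while the hypotheses are still shorter than min_length
    if cur_len < min_length:
        word_probs[:, eos_idx] = -1e20
    return word_probs

probs = mask_eos(torch.zeros(5, 100), cur_len=2, min_length=5, eos_idx=2)
assert (probs[:, 2] == -1e20).all()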
Example #5
 def test_repeating_excluded_index_does_not_die(self):
     # beam 0 and beam >= 2 will repeat (beam 2 repeats excluded idx)
     beam_sz = 5
     n_words = 100
     repeat_idx = 47  # will be repeated and should be blocked
     repeat_idx_ignored = 7  # will be repeated and should not be blocked
     ngram_repeat = 3
     beam = Beam(beam_sz,
                 0,
                 1,
                 2,
                 n_best=2,
                 exclusion_tokens=set([repeat_idx_ignored]),
                 global_scorer=GlobalScorerStub(),
                 block_ngram_repeat=ngram_repeat)
     for i in range(ngram_repeat + 4):
         # non-interesting beams are going to get dummy values
         word_probs = torch.full((beam_sz, n_words), -float('inf'))
         if i == 0:
             word_probs[0, repeat_idx] = -0.1
             word_probs[0, repeat_idx + i + 1] = -2.3
             word_probs[0, repeat_idx_ignored] = -5.0
         else:
             # predict the same thing in beam 0
             word_probs[0, repeat_idx] = 0
             # continue pushing around what beam 1 predicts
             word_probs[1, repeat_idx + i + 1] = 0
             # predict the allowed-repeat again in beam 2
             word_probs[2, repeat_idx_ignored] = 0
         attns = torch.randn(beam_sz)
         beam.advance(word_probs, attns)
         if i <= ngram_repeat:
             self.assertFalse(beam.scores[0].eq(self.BLOCKED_SCORE))
             self.assertFalse(beam.scores[1].eq(self.BLOCKED_SCORE))
             self.assertFalse(beam.scores[2].eq(self.BLOCKED_SCORE))
         else:
             # now beam 0 dies, beam 1 -> beam 0, beam 2 -> beam 1
             # and the rest die
             self.assertFalse(beam.scores[0].eq(self.BLOCKED_SCORE))
             # since all preds after i=0 are 0, we can check
             # that the beam is the correct idx by checking that
             # the curr score is the initial score
             self.assertTrue(beam.scores[0].eq(-2.3))
             self.assertFalse(beam.scores[1].eq(self.BLOCKED_SCORE))
             self.assertTrue(beam.scores[1].eq(-5.0))
             self.assertTrue(beam.scores[2:].equal(
                 torch.tensor([self.BLOCKED_SCORE] * (beam_sz - 2))))
Example #6
 def test_repeating_excluded_index_does_not_die(self):
     # beam 0 and beam >= 2 will repeat (beam 2 repeats excluded idx)
     beam_sz = 5
     n_words = 100
     repeat_idx = 47  # will be repeated and should be blocked
     repeat_idx_ignored = 7  # will be repeated and should not be blocked
     ngram_repeat = 3
     beam = Beam(beam_sz, 0, 1, 2, n_best=2,
                 exclusion_tokens=set([repeat_idx_ignored]),
                 global_scorer=GlobalScorerStub(),
                 block_ngram_repeat=ngram_repeat)
     for i in range(ngram_repeat + 4):
         # non-interesting beams are going to get dummy values
         word_probs = torch.full((beam_sz, n_words), -float('inf'))
         if i == 0:
             word_probs[0, repeat_idx] = -0.1
             word_probs[0, repeat_idx + i + 1] = -2.3
             word_probs[0, repeat_idx_ignored] = -5.0
         else:
             # predict the same thing in beam 0
             word_probs[0, repeat_idx] = 0
             # continue pushing around what beam 1 predicts
             word_probs[1, repeat_idx + i + 1] = 0
             # predict the allowed-repeat again in beam 2
             word_probs[2, repeat_idx_ignored] = 0
         attns = torch.randn(beam_sz)
         beam.advance(word_probs, attns)
         if i <= ngram_repeat:
             self.assertFalse(beam.scores[0].eq(self.BLOCKED_SCORE))
             self.assertFalse(beam.scores[1].eq(self.BLOCKED_SCORE))
             self.assertFalse(beam.scores[2].eq(self.BLOCKED_SCORE))
         else:
             # now beam 0 dies, beam 1 -> beam 0, beam 2 -> beam 1
             # and the rest die
             self.assertFalse(beam.scores[0].eq(self.BLOCKED_SCORE))
             # since all preds after i=0 are 0, we can check
             # that the beam is the correct idx by checking that
             # the curr score is the initial score
             self.assertTrue(beam.scores[0].eq(-2.3))
             self.assertFalse(beam.scores[1].eq(self.BLOCKED_SCORE))
             self.assertTrue(beam.scores[1].eq(-5.0))
             self.assertTrue(
                 beam.scores[2:].equal(torch.tensor(
                     [self.BLOCKED_SCORE] * (beam_sz - 2))))
Example #7
 def test_beam_advance_against_known_reference(self):
     scorer = GNMTGlobalScorer(0.7, 0., "avg", "none")
     beam = Beam(self.BEAM_SZ, 0, 1, self.EOS_IDX, n_best=self.N_BEST,
                 exclusion_tokens=set(),
                 min_length=0,
                 global_scorer=scorer,
                 block_ngram_repeat=0)
     expected_beam_scores = self.init_step(beam)
     expected_beam_scores = self.first_step(beam, expected_beam_scores, 3)
     expected_beam_scores = self.second_step(beam, expected_beam_scores, 4)
     self.third_step(beam, expected_beam_scores, 5)
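The positional arguments to GNMTGlobalScorer here appear to be (alpha, beta, length_penalty, coverage_penalty); if so, the call is equivalent to the keyword form below, i.e. length-averaged scores and no coverage penalty:

scorer = GNMTGlobalScorer(alpha=0.7, beta=0.,
                          length_penalty="avg", coverage_penalty="none")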
Example #8
    def test_beam_advance_against_known_reference(self):
        beam = Beam(self.BEAM_SZ, 0, 1, self.EOS_IDX, n_best=self.N_BEST,
                    exclusion_tokens=set(),
                    min_length=0,
                    global_scorer=GlobalScorerStub(),
                    block_ngram_repeat=0)

        expected_beam_scores = self.init_step(beam)
        expected_beam_scores = self.first_step(beam, expected_beam_scores, 1)
        expected_beam_scores = self.second_step(beam, expected_beam_scores, 1)
        self.third_step(beam, expected_beam_scores, 1)
Example #9
 def test_advance_with_some_repeats_gets_blocked(self):
     # beam 0 and beam >=2 will repeat (beam >= 2 repeat dummy scores)
     beam_sz = 5
     n_words = 100
     repeat_idx = 47
     ngram_repeat = 3
     beam = Beam(beam_sz,
                 0,
                 1,
                 2,
                 n_best=2,
                 exclusion_tokens=set(),
                 global_scorer=GlobalScorerStub(),
                 block_ngram_repeat=ngram_repeat)
     for i in range(ngram_repeat + 4):
         # non-interesting beams are going to get dummy values
         word_probs = torch.full((beam_sz, n_words), -float('inf'))
         if i == 0:
             # on initial round, only predicted scores for beam 0
             # matter. Make two predictions. Top one will be repeated
             # in beam zero, second one will live on in beam 1.
             word_probs[0, repeat_idx] = -0.1
             word_probs[0, repeat_idx + i + 1] = -2.3
         else:
             # predict the same thing in beam 0
             word_probs[0, repeat_idx] = 0
             # continue pushing around what beam 1 predicts
             word_probs[1, repeat_idx + i + 1] = 0
         attns = torch.randn(beam_sz)
         beam.advance(word_probs, attns)
         if i <= ngram_repeat:
             self.assertFalse(beam.scores[0].eq(self.BLOCKED_SCORE))
             self.assertFalse(beam.scores[1].eq(self.BLOCKED_SCORE))
         else:
             # now beam 0 dies (along with the others), beam 1 -> beam 0
             self.assertFalse(beam.scores[0].eq(self.BLOCKED_SCORE))
             self.assertTrue(beam.scores[1:].equal(
                 torch.tensor([self.BLOCKED_SCORE] * (beam_sz - 1))))
Example #10
    def test_beam_is_done_when_n_best_beams_eos_using_min_length(self):
        # this is also a test that when block_ngram_repeat=0,
        # repeating is acceptable
        beam_sz = 5
        n_words = 100
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(torch.tensor(
            [6., 5., 4., 3., 2., 1.]), dim=0)
        min_length = 5
        eos_idx = 2
        beam = Beam(beam_sz,
                    0,
                    1,
                    eos_idx,
                    n_best=2,
                    exclusion_tokens=set(),
                    min_length=min_length,
                    global_scorer=GlobalScorerStub(),
                    block_ngram_repeat=0)
        for i in range(min_length + 4):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full((beam_sz, n_words), -float('inf'))
            if i == 0:
                # "best" prediction is eos - that should be blocked
                word_probs[0, eos_idx] = valid_score_dist[0]
                # include at least beam_sz predictions OTHER than EOS
                # that are greater than -1e20
                for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                    word_probs[0, j] = score
            elif i <= min_length:
                # predict eos in beam 1
                word_probs[1, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz - 1, k)
                    word_probs[beam_idx, j] = score
            else:
                word_probs[0, eos_idx] = valid_score_dist[0]
                word_probs[1, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz - 1, k)
                    word_probs[beam_idx, j] = score

            attns = torch.randn(beam_sz)
            beam.advance(word_probs, attns)
            if i < min_length:
                self.assertFalse(beam.done)
            elif i == min_length:
                # beam 1 dies on min_length
                self.assertEqual(beam.finished[0][1], beam.min_length + 1)
                self.assertEqual(beam.finished[0][2], 1)
                self.assertFalse(beam.done)
            else:  # i > min_length
                # beam 0 dies on the step after beam 1 dies
                self.assertEqual(beam.finished[1][1], beam.min_length + 2)
                self.assertEqual(beam.finished[1][2], 0)
                self.assertTrue(beam.done)
Example #11
    def translate_batch(self, batch):
        beam_size = self.beam_size
        tgt_field = self.fields['tgt'][0][1]
        vocab = tgt_field.vocab

        pad = vocab.stoi[tgt_field.pad_token]
        eos = vocab.stoi[tgt_field.eos_token]
        bos = vocab.stoi[tgt_field.init_token]
        b = Beam(beam_size,
                 n_best=self.n_best,
                 cuda=self.cuda,
                 pad=pad,
                 eos=eos,
                 bos=bos)

        src, src_lengths = batch.src
        # why doesn't this contain inflection source lengths when ensembling?
        side_info = side_information(batch)

        encoder_out = self.model.encode(src, lengths=src_lengths, **side_info)
        enc_states = encoder_out["enc_state"]
        memory_bank = encoder_out["memory_bank"]
        infl_memory_bank = encoder_out.get("inflection_memory_bank", None)

        self.model.init_decoder_state(enc_states)

        results = dict()

        if "tgt" in batch.__dict__:
            results["gold_score"] = self._score_target(
                batch,
                memory_bank,
                src_lengths,
                inflection_memory_bank=infl_memory_bank,
                **side_info)
            self.model.init_decoder_state(enc_states)
        else:
            results["gold_score"] = 0

        # (2) Repeat src objects `beam_size` times.
        self.model.map_decoder_state(
            lambda state, dim: tile(state, beam_size, dim=dim))

        if isinstance(memory_bank, tuple):
            memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
        else:
            memory_bank = tile(memory_bank, beam_size, dim=1)
        memory_lengths = tile(src_lengths, beam_size)

        if infl_memory_bank is not None:
            if isinstance(infl_memory_bank, tuple):
                infl_memory_bank = tuple(
                    tile(x, beam_size, dim=1) for x in infl_memory_bank)
            else:
                infl_memory_bank = tile(infl_memory_bank, beam_size, dim=1)
            tiled_infl_len = tile(side_info["inflection_lengths"], beam_size)
            side_info["inflection_lengths"] = tiled_infl_len

        if "language" in side_info:
            side_info["language"] = tile(side_info["language"], beam_size)

        for i in range(self.max_length):
            if b.done():
                break

            inp = b.current_state.unsqueeze(0)

            # the decoder expects an input of tgt_len x batch
            dec_out, dec_attn = self.model.decode(
                inp,
                memory_bank,
                memory_lengths=memory_lengths,
                inflection_memory_bank=infl_memory_bank,
                **side_info)
            attn = dec_attn["lemma"].squeeze(0)
            out = self.model.generator(dec_out.squeeze(0),
                                       transform=True,
                                       **side_info)

            # b.advance takes the predictions and the full attention dict
            b.advance(out, dec_attn)
            select_indices = b.current_origin

            self.model.map_decoder_state(
                lambda state, dim: state.index_select(dim, select_indices))

        scores, ks = b.sort_finished()
        hyps, attn, out_probs = [], [], []
        for i, (times, k) in enumerate(ks[:self.n_best]):
            hyp, att, out_p = b.get_hyp(times, k)
            hyps.append(hyp)
            attn.append(att)
            out_probs.append(out_p)

        results["preds"] = hyps
        results["scores"] = scores
        results["attn"] = attn

        if self.beam_accum is not None:
            parent_ids = [t.tolist() for t in b.prev_ks]
            self.beam_accum["beam_parent_ids"].append(parent_ids)
            scores = [["%4f" % s for s in t.tolist()]
                      for t in b.all_scores][1:]
            self.beam_accum["scores"].append(scores)
            pred_ids = [[vocab.itos[i] for i in t.tolist()]
                        for t in b.next_ys][1:]
            self.beam_accum["predicted_ids"].append(pred_ids)

        if self.attn_path is not None:
            save_attn = {k: v.cpu() for k, v in attn[0].items()}
            src_seq = self.itos(src, "src")
            pred_seq = self.itos(hyps[0], "tgt")
            attn_dict = {"src": src_seq, "pred": pred_seq, "attn": save_attn}
            if "inflection" in save_attn:
                inflection_seq = self.itos(batch.inflection[0], "inflection")
                attn_dict["inflection"] = inflection_seq
            self.attns.append(attn_dict)

        if self.probs_path is not None:
            save_probs = out_probs[0].cpu()
            self.probs.append(save_probs)

        return results
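The tile helper used throughout step (2) repeats every batch entry beam_size times along a given dimension. A functional sketch in plain PyTorch, assuming each entry's copies are grouped consecutively, which is what the index arithmetic in the next example expects:

import torch

def tile(x, count, dim=0):
    # repeat each slice `count` times consecutively along `dim`,
    # e.g. [a, b] -> [a, a, b, b] for count=2
    return x.repeat_interleave(count, dim=dim)

print(tile(torch.tensor([1, 2]), 3).tolist())  # [1, 1, 1, 2, 2, 2]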
Example #12
    def _translate_batch(self, batch, data):
        # (0) Prep each of the components of the search.
        # And helper method for reducing verbosity.
        beam_size = self.beam_size
        batch_size = batch.batch_size
        data_type = data.data_type
        vocab = self.fields["tgt"].vocab

        # Define a list of tokens to exclude from ngram-blocking
        # exclusion_list = ["<t>", "</t>", "."]
        exclusion_tokens = set(
            [vocab.stoi[t] for t in self.ignore_when_blocking])

        beam_batch = [
            Beam(
                beam_size,
                n_best=self.n_best,
                cuda=self.cuda,
                global_scorer=self.global_scorer,
                pad=vocab.stoi[inputters.PAD_WORD],
                eos=vocab.stoi[inputters.EOS_WORD],
                bos=vocab.stoi[inputters.BOS_WORD],
                min_length=self.min_length,
                stepwise_penalty=self.stepwise_penalty,
                block_ngram_repeat=self.block_ngram_repeat,
                exclusion_tokens=exclusion_tokens,
            ) for __ in range(batch_size)
        ]

        # (1) Run the encoder on the src.
        src, enc_states, memory_bank, src_lengths = self._run_encoder(
            batch, data_type)
        self.model.decoder.init_state(src, memory_bank, enc_states)

        results = {}
        results["predictions"] = []
        results["scores"] = []
        results["attention"] = []
        results["batch"] = batch
        if "tgt" in batch.__dict__:
            results["gold_score"] = self._score_target(
                batch,
                memory_bank,
                src_lengths,
                data,
                batch.src_map
                if data_type == "text" and self.copy_attn else None,
            )
            self.model.decoder.init_state(src, memory_bank, enc_states)
        else:
            results["gold_score"] = [0] * batch_size

        # (2) Repeat src objects `beam_size` times.
        # We now use batch_size x beam_size (same as the fast mode)
        src_map = (tile(batch.src_map, beam_size, dim=1)
                   if data.data_type == "text" and self.copy_attn else None)
        self.model.decoder.map_state(
            lambda state, dim: tile(state, beam_size, dim=dim))

        if isinstance(memory_bank, tuple):
            memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
        else:
            memory_bank = tile(memory_bank, beam_size, dim=1)
        memory_lengths = tile(src_lengths, beam_size)

        # (3) run the decoder to generate sentences, using beam search.
        for i in range(self.max_length):
            if all((b.done() for b in beam_batch)):
                break

            # (a) Construct batch x beam_size nxt words.
            # Get all the pending current beam words and arrange for forward.

            inp = torch.stack([b.current_state for b in beam_batch])
            inp = inp.view(1, -1, 1)

            # (b) Decode and forward
            out, beam_attn = self._decode_and_generate(
                inp,
                memory_bank,
                batch,
                data,
                memory_lengths=memory_lengths,
                src_map=src_map,
                step=i,
            )

            out = out.view(batch_size, beam_size, -1)
            beam_attn = beam_attn.view(batch_size, beam_size, -1)

            # (c) Advance each beam.
            select_indices_array = []
            # Loop over the batch_size number of beam
            for j, b in enumerate(beam_batch):
                b.advance(out[j, :], beam_attn[j, :, :memory_lengths[j]])
                select_indices_array.append(b.current_origin + j * beam_size)
            select_indices = torch.cat(select_indices_array)

            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        # (4) Extract sentences from beam.
        for b in beam_batch:
            n_best = self.n_best
            scores, ks = b.sort_finished(minimum=n_best)
            hyps, attn = [], []
            for i, (times, k) in enumerate(ks[:n_best]):
                hyp, att = b.get_hyp(times, k)
                hyps.append(hyp)
                attn.append(att)
            results["predictions"].append(hyps)
            results["scores"].append(scores)
            results["attention"].append(attn)

            if self.beam_accum is not None:
                parent_ids = [t.tolist() for t in b.prev_ks]
                self.beam_accum["beam_parent_ids"].append(parent_ids)
                scores = [["%4f" % s for s in t.tolist()]
                          for t in b.all_scores][1:]
                self.beam_accum["scores"].append(scores)
                pred_ids = [[vocab.itos[i] for i in t.tolist()]
                            for t in b.next_ys][1:]
                self.beam_accum["predicted_ids"].append(pred_ids)

            if self.beam_scores is not None:
                self.beam_scores.append(torch.stack(b.all_scores).cpu())

        return results
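The offset b.current_origin + j * beam_size in step (c) translates each beam's parent pointers into row indices of the flattened batch_size * beam_size decoder state. A small worked example with illustrative values:

import torch

beam_size = 2
origins = [torch.tensor([1, 0]),   # parents picked for batch item 0
           torch.tensor([0, 0])]   # parents picked for batch item 1
select_indices = torch.cat(
    [origin + j * beam_size for j, origin in enumerate(origins)])
print(select_indices.tolist())  # [1, 0, 2, 2]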
Example #13
    def generate(self, src, lengths, dec_idx,
                 max_length=20, beam_size=5, n_best=1):
        assert dec_idx == 0 or dec_idx == 1
        batch_size = src.size(1)
        
        def var(a):
            # detach to avoid the copy-construct warning from torch.tensor(tensor)
            return a.clone().detach()
        
        def rvar(a):
            return var(a.repeat(1, beam_size, 1))
        
        def bottle(m):
            return m.view(batch_size * beam_size, -1)

        def unbottle(m):
            return m.view(beam_size, batch_size, -1)
        
        def from_beam(beam):
            ret = {"predictions": [],
                   "scores": [],
                   "attention": []}
            for b in beam:
                scores, ks = b.sort_finished(minimum=n_best)
                hyps, attn = [], []
                for i, (times, k) in enumerate(ks[:n_best]):
                    hyp, att = b.get_hyp(times, k)
                    hyps.append(hyp)
                    attn.append(att)
                ret["predictions"].append(hyps)
                ret["scores"].append(scores)
                ret["attention"].append(attn)
            return ret
        
        
        scorer = GNMTGlobalScorer(0, 0, "none", "none")
        
        beam = [Beam(beam_size, n_best=n_best,
                     cuda=self.cuda(),
                     global_scorer=scorer,
                     pad=PAD_IDX,
                     eos=EOS_IDX,
                     bos=SOS_IDX,
                     min_length=0,
                     stepwise_penalty=False,
                     block_ngram_repeat=0)
                for __ in range(batch_size)]
        
        enc_final, memory_bank = self.encoder(src, lengths)
        
        dec_state = self.choose_decoder(dec_idx).init_decoder_state(
            src, memory_bank, enc_final)
               
        memory_bank = rvar(memory_bank.data)
        memory_lengths = lengths.repeat(beam_size)
        dec_state.repeat_beam_size_times(beam_size)
        
        # unroll
        all_indices = []
        for i in range(max_length):
            if all((b.done() for b in beam)):
                break
                
            inp = var(torch.stack([b.get_current_state() for b in beam])
                      .t().contiguous().view(1, -1))
            inp = inp.unsqueeze(2)
                
            decoder_output, dec_state, attn = self.choose_decoder(dec_idx)(
                inp, memory_bank, dec_state,
                memory_lengths=memory_lengths, step=i)
            
            decoder_output = decoder_output.squeeze(0)
            
            out = self.generator(decoder_output).data
            out = unbottle(out)  # beam x batch x tgt_vocab

            beam_attn = unbottle(attn["std"])  # beam x batch x src_len
            
            for j, b in enumerate(beam):
                b.advance(out[:, j], beam_attn.data[:, j, :memory_lengths[j]])
                dec_state.beam_update(j, b.get_current_origin(), beam_size)
    
        ret = from_beam(beam)
#        ret["src"] = src.transpose(1, 0)

        return ret
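A hypothetical call site for generate, assuming a trained model instance and a (src, lengths) pair shaped the way the encoder expects; the names below are illustrative:

# hypothetical usage; `model`, `src`, and `lengths` come from the caller
results = model.generate(src, lengths, dec_idx=0,
                         max_length=20, beam_size=5, n_best=1)
best_hyp = results["predictions"][0][0]   # token ids of the best hypothesis
best_score = results["scores"][0][0]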