Example #1
    def test_beam_is_done_when_n_best_beams_eos_using_min_length(self):
        # this is also a test that when block_ngram_repeat=0,
        # repeating is acceptable
        beam_sz = 5
        batch_sz = 3
        n_words = 100
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(torch.tensor(
            [6., 5., 4., 3., 2., 1.]), dim=0)
        min_length = 5
        eos_idx = 2
        # beam includes start token in cur_len count.
        # Add one to its min_length to compensate
        beam = BeamSearch(
            beam_sz, batch_sz, 0, 1, 2, 2,
            torch.device("cpu"), GlobalScorerStub(),
            min_length + 1, 30, False, 0, set(),
            torch.randint(0, 30, (batch_sz,)))
        for i in range(min_length + 4):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full(
                (batch_sz * beam_sz, n_words), -float('inf'))
            if i == 0:
                # "best" prediction is eos - that should be blocked
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                # include at least beam_sz predictions OTHER than EOS
                # that are greater than -1e20
                for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                    word_probs[0::beam_sz, j] = score
            elif i <= min_length:
                # predict eos in beam 1
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx::beam_sz, j] = score
            else:
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx::beam_sz, j] = score

            # shape (1, batch*beam, src_len), matching the other examples;
            # unused here since return_attention is False
            attns = torch.randn(1, batch_sz * beam_sz, 53)
            beam.advance(word_probs, attns)
            if i < min_length:
                self.assertFalse(beam.done)
            elif i == min_length:
                # beam 1 dies on min_length
                self.assertTrue(beam.is_finished[:, 1].all())
                beam.update_finished()
                self.assertFalse(beam.done)
            else:  # i > min_length
                # beam 0 dies on the step after beam 1 dies
                self.assertTrue(beam.is_finished[:, 0].all())
                beam.update_finished()
                self.assertTrue(beam.done)
                # every batch has n_best finished hypotheses; stop here
                break
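These tests construct a GlobalScorerStub that this page does not show. A minimal sketch of such a stub, assuming only the hooks that OpenNMT-py's BeamSearch consults here (attribute names mirror GNMTGlobalScorer; this is an approximation, not the project's actual stub):

class GlobalScorerStub(object):
    # no-op scorer: leaves beam scores unchanged
    alpha = 0
    beta = 0

    def __init__(self):
        # neutral penalties; neither alters scores since the has_* flags are False
        self.length_penalty = lambda cur_len, alpha: 1.
        self.cov_penalty = lambda cov, beta: 0.
        self.has_cov_pen = False
        self.has_len_pen = False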
Example #2
    def test_beam_is_done_when_n_best_beams_eos_using_min_length(self):
        # this is also a test that when block_ngram_repeat=0,
        # repeating is acceptable
        beam_sz = 5
        batch_sz = 3
        n_words = 100
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(torch.tensor(
            [6., 5., 4., 3., 2., 1.]), dim=0)
        min_length = 5
        eos_idx = 2
        beam = BeamSearch(
            beam_sz, batch_sz, 0, 1, 2, 2,
            torch.device("cpu"), GlobalScorerStub(),
            min_length, 30, False, 0, set(),
            torch.randint(0, 30, (batch_sz,)), False, 0.)
        for i in range(min_length + 4):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full(
                (batch_sz * beam_sz, n_words), -float('inf'))
            if i == 0:
                # "best" prediction is eos - that should be blocked
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                # include at least beam_sz predictions OTHER than EOS
                # that are greater than -1e20
                for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                    word_probs[0::beam_sz, j] = score
            elif i <= min_length:
                # predict eos in beam 1
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx::beam_sz, j] = score
            else:
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx::beam_sz, j] = score

            attns = torch.randn(1, batch_sz * beam_sz, 53)
            beam.advance(word_probs, attns)
            if i < min_length:
                self.assertFalse(beam.done)
            elif i == min_length:
                # beam 1 dies on min_length
                self.assertTrue(beam.is_finished[:, 1].all())
                beam.update_finished()
                self.assertFalse(beam.done)
            else:  # i > min_length
                # beam 0 dies on the step after beam 1 dies
                self.assertTrue(beam.is_finished[:, 0].all())
                beam.update_finished()
                self.assertTrue(beam.done)
                # every batch has n_best finished hypotheses; stop here
                break
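For readers decoding the positional constructor calls above: the same construction as Example #2, spelled with keyword arguments. The parameter names follow the OpenNMT-py 1.0-era signature and are an assumption; they have shifted between releases:

import torch
from onmt.translate import BeamSearch

beam = BeamSearch(
    beam_size=beam_sz, batch_size=batch_sz,
    pad=0, bos=1, eos=2, n_best=2,
    mb_device=torch.device("cpu"),
    global_scorer=GlobalScorerStub(),
    min_length=min_length, max_length=30,
    return_attention=False,
    block_ngram_repeat=0, exclusion_tokens=set(),
    memory_lengths=torch.randint(0, 30, (batch_sz,)),
    stepwise_penalty=False, ratio=0.)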
Example #3
    def forward_dev_beam_search(self, encoder_output: torch.Tensor, pad_mask):
        batch_size = encoder_output.size(1)

        self.state["cache"] = None
        memory_lengths = pad_mask.ne(pad_token_index).sum(dim=0)

        self.map_state(lambda state, dim: tile(state, self.beam_size, dim=dim))
        encoder_output = tile(encoder_output, self.beam_size, dim=1)
        pad_mask = tile(pad_mask, self.beam_size, dim=1)
        memory_lengths = tile(memory_lengths, self.beam_size, dim=0)

        # TODO:
        #  - fix attn (?)
        #  - use coverage_penalty="summary" ou "wu" and beta=0.2 (ou pas)
        #  - use length_penalty="wu" and alpha=0.2 (ou pas)
        beam = BeamSearch(beam_size=self.beam_size, n_best=1, batch_size=batch_size, mb_device=default_device,
                          global_scorer=GNMTGlobalScorer(alpha=0, beta=0, coverage_penalty="none", length_penalty="avg"),
                          pad=pad_token_index, eos=eos_token_index, bos=bos_token_index, min_length=1, max_length=100,
                          return_attention=False, stepwise_penalty=False, block_ngram_repeat=0, exclusion_tokens=set(),
                          memory_lengths=memory_lengths, ratio=-1)

        for i in range(self.max_seq_out_len):
            inp = beam.current_predictions.view(1, -1)

            out, attn = self.forward_step(src=pad_mask, tgt=inp, memory_bank=encoder_output, step=i)  # 1 x batch*beam x hidden
            out = self.linear(out)  # 1 x batch*beam x vocab_out
            out = log_softmax(out, dim=2)  # 1 x batch*beam x vocab_out

            out = out.squeeze(0)  # batch*beam x vocab_out
            # attn = attn.squeeze(0)  # batch*beam x vocab_out
            # out = out.view(batch_size, self.beam_size, -1)  # batch x beam x vocab_out
            # attn = attn.view(batch_size, self.beam_size, -1)
            # TODO: fix attn (?)

            beam.advance(out, attn)
            any_beam_is_finished = beam.is_finished.any()
            if any_beam_is_finished:
                beam.update_finished()
                if beam.done:
                    break

            select_indices = beam.current_origin

            if any_beam_is_finished:
                # Reorder states.
                encoder_output = encoder_output.index_select(1, select_indices)
                pad_mask = pad_mask.index_select(1, select_indices)
                memory_lengths = memory_lengths.index_select(0, select_indices)

            self.map_state(lambda state, dim: state.index_select(dim, select_indices))

        outputs = beam.predictions
        outputs = [x[0] for x in outputs]
        outputs = pad_sequence(outputs, batch_first=True)
        return [outputs]
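Example #3 repeatedly calls tile from onmt.utils.misc to copy encoder state beam_size times per example. A minimal stand-in, assuming tile interleave-repeats along the given dimension (which matches how these examples use it):

import torch

def tile(x, count, dim=0):
    # repeat each slice of x along `dim` count times, interleaved:
    # with count=2, [a, b] becomes [a, a, b, b]
    return torch.repeat_interleave(x, count, dim=dim)

The interleaving matters: BeamSearch expects the beam_size copies of one example to sit next to each other, which is also why the tests address beam 0 of every example with the stride word_probs[0::beam_sz].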
Example #4
    def test_beam_returns_attn_with_correct_length(self):
        beam_sz = 5
        batch_sz = 3
        n_words = 100
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(torch.tensor(
            [6., 5., 4., 3., 2., 1.]), dim=0)
        min_length = 5
        eos_idx = 2
        inp_lens = torch.randint(1, 30, (batch_sz, ))
        beam = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 2, GlobalScorerStub(),
                          min_length, 30, True, 0, set(), False, 0.)
        device_init = torch.zeros(1, 1)
        # inp_lens is tiled in initialize; reassign so the attn checks match
        _, _, inp_lens, _ = beam.initialize(device_init, inp_lens)
        for i in range(min_length + 2):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full((batch_sz * beam_sz, n_words),
                                    -float('inf'))
            if i == 0:
                # "best" prediction is eos - that should be blocked
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                # include at least beam_sz predictions OTHER than EOS
                # that are greater than -1e20
                for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                    word_probs[0::beam_sz, j] = score
            elif i <= min_length:
                # predict eos in beam 1
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz - 1, k)
                    word_probs[beam_idx::beam_sz, j] = score
            else:
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz - 1, k)
                    word_probs[beam_idx::beam_sz, j] = score

            attns = torch.randn(1, batch_sz * beam_sz, 53)
            beam.advance(word_probs, attns)
            if i < min_length:
                self.assertFalse(beam.done)
                # no top beams are finished yet
                for b in range(batch_sz):
                    self.assertEqual(beam.attention[b], [])
            elif i == min_length:
                # beam 1 dies on min_length
                self.assertTrue(beam.is_finished[:, 1].all())
                beam.update_finished()
                self.assertFalse(beam.done)
                # no top beams are finished yet
                for b in range(batch_sz):
                    self.assertEqual(beam.attention[b], [])
            else:  # i > min_length
                # beam 0 dies on the step after beam 1 dies
                self.assertTrue(beam.is_finished[:, 0].all())
                beam.update_finished()
                self.assertTrue(beam.done)
                # top beam is finished now so there are attentions
                for b in range(batch_sz):
                    # two beams are finished in each batch
                    self.assertEqual(len(beam.attention[b]), 2)
                    for k in range(2):
                        # second dim is cut down to the non-padded src length
                        self.assertEqual(beam.attention[b][k].shape[-1],
                                         inp_lens[b])
                    # first dim is equal to the time of death
                    # (beam 0 died at current step - adjust for SOS)
                    self.assertEqual(beam.attention[b][0].shape[0], i + 1)
                    # (beam 1 died at last step - adjust for SOS)
                    self.assertEqual(beam.attention[b][1].shape[0], i)
                # behavior gets weird when beam is already done so just stop
                break
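Example #4 targets a newer BeamSearch API than Examples #1, #2, and #7: the device and the source lengths are no longer constructor arguments but go through initialize, which tiles the lengths to batch * beam and returns them. A sketch of the two call shapes, reusing the names above (the four-tuple return order is inferred from Example #4 and may vary across versions):

# older API: device and memory lengths go to the constructor
beam_old = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 2,
                      torch.device("cpu"), GlobalScorerStub(),
                      min_length, 30, True, 0, set(),
                      inp_lens, False, 0.)

# newer API: the constructor omits them; initialize() tiles and returns them
beam_new = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 2, GlobalScorerStub(),
                      min_length, 30, True, 0, set(), False, 0.)
_, _, tiled_lens, _ = beam_new.initialize(torch.zeros(1, 1), inp_lens)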
Example #5
    def _translate_batch(
            self,
            batch,
            src_vocabs,
            max_length,
            min_length=0,
            ratio=0.,
            n_best=1,
            return_attention=False,
            xlation_builder=None,
    ):
        # TODO: support these blacklisted features.
        assert not self.dump_beam

        # (0) Prep the components of the search.
        use_src_map = self.copy_attn
        beam_size = self.beam_size
        batch_size = batch.batch_size

        # (1) Run the encoder on the src.
        src_list, enc_states_list, memory_bank_list, src_lengths_list = self._run_encoder(batch)
        self.model.decoder.init_state(src_list, memory_bank_list, enc_states_list)

        results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "batch": batch,
            "gold_score": self._gold_score(
                batch, memory_bank_list, src_lengths_list, src_vocabs, use_src_map,
                enc_states_list, batch_size, src_list)}
        
        # (2) Repeat src objects `beam_size` times.
        # We use batch_size x beam_size
        src_map_list = list()
        for src_type in self.src_types:
            src_map_list.append((tile(getattr(batch, f"src_map.{src_type}"), beam_size, dim=1) if use_src_map else None))
        # end for

        self.model.decoder.map_state(
            lambda state, dim: tile(state, beam_size, dim=dim))

        memory_lengths_list = list()
        memory_lengths = list()
        for src_i in range(len(memory_bank_list)):
            if isinstance(memory_bank_list[src_i], tuple):
                memory_bank_list[src_i] = tuple(tile(x, beam_size, dim=1) for x in memory_bank_list[src_i])
                mb_device = memory_bank_list[src_i][0].device
            else:
                memory_bank_list[src_i] = tile(memory_bank_list[src_i], beam_size, dim=1)
                mb_device = memory_bank_list[src_i].device
            # end if
            memory_lengths_list.append(tile(src_lengths_list[src_i], beam_size))
            memory_lengths.append(src_lengths_list[src_i])
        # end for
        memory_lengths = tile(torch.stack(memory_lengths, dim=0).sum(dim=0), beam_size)

        indexes = tile(torch.tensor(list(range(batch_size)), device=self._dev), beam_size)

        # (0) pt 2, prep the beam object
        beam = BeamSearch(
            beam_size,
            n_best=n_best,
            batch_size=batch_size,
            global_scorer=self.global_scorer,
            pad=self._tgt_pad_idx,
            eos=self._tgt_eos_idx,
            bos=self._tgt_bos_idx,
            min_length=min_length,
            ratio=ratio,
            max_length=max_length,
            mb_device=mb_device,
            return_attention=return_attention,
            stepwise_penalty=self.stepwise_penalty,
            block_ngram_repeat=self.block_ngram_repeat,
            exclusion_tokens=self._exclusion_idxs,
            memory_lengths=memory_lengths)

        for step in range(max_length):
            decoder_input = beam.current_predictions.view(1, -1, 1)

            log_probs, attn = self._decode_and_generate(
                decoder_input,
                memory_bank_list,
                batch,
                src_vocabs,
                memory_lengths_list=memory_lengths_list,
                src_map_list=src_map_list,
                step=step,
                batch_offset=beam._batch_offset)

            if self.reranker is not None:
                log_probs = self.reranker.rerank_step_beam_batch(
                    batch,
                    beam,
                    self.beam_size,
                    indexes,
                    log_probs,
                    attn,
                    self.fields["tgt"].base_field.vocab,
                    xlation_builder,
                )
            # end if

            non_finished = None
            beam.advance(log_probs, attn)
            any_beam_is_finished = beam.is_finished.any()
            if any_beam_is_finished:
                non_finished = beam.update_finished()
                if beam.done:
                    break

            select_indices = beam.current_origin

            if any_beam_is_finished:
                # Reorder states.
                for src_i in range(len(memory_bank_list)):
                    if isinstance(memory_bank_list[src_i], tuple):
                        memory_bank_list[src_i] = tuple(x.index_select(1, select_indices)
                                            for x in memory_bank_list[src_i])
                    else:
                        memory_bank_list[src_i] = memory_bank_list[src_i].index_select(1, select_indices)
                    # end if

                    memory_lengths_list[src_i] = memory_lengths_list[src_i].index_select(0, select_indices)
                # end for

                if use_src_map and src_map_list[0] is not None:
                    for src_i in range(len(src_map_list)):
                        src_map_list[src_i] = src_map_list[src_i].index_select(1, select_indices)
                    # end for
                # end if

                indexes = indexes.index_select(0, select_indices)

            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        results["scores"] = beam.scores
        results["predictions"] = beam.predictions
        results["attention"] = beam.attention
        return results
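Examples #3, #5, and #8-#10 all repeat one decode loop. A distilled skeleton of that pattern; decode_step is a hypothetical stand-in for whatever produces (log_probs, attn) at each step, and the surrounding names mirror the examples:

for step in range(max_length):
    # feed the frontier of live hypotheses back into the decoder
    decoder_input = beam.current_predictions.view(1, -1, 1)
    log_probs, attn = decode_step(decoder_input, step)  # hypothetical helper

    beam.advance(log_probs, attn)
    any_finished = beam.is_finished.any()
    if any_finished:
        beam.update_finished()  # moves finished hypotheses into beam.predictions
        if beam.done:
            break

    select_indices = beam.current_origin  # backpointers into the previous step
    if any_finished:
        # drop finished examples from the tiled encoder state
        memory_bank = memory_bank.index_select(1, select_indices)
        memory_lengths = memory_lengths.index_select(0, select_indices)
    decoder.map_state(
        lambda state, dim: state.index_select(dim, select_indices))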
Example #6
def batch_beam_search_trs(model, inputs, batch_size, device="cpu", max_len=128, beam_size=20, n_best=1, alpha=1.):
    """ beam search with batch input for Transformer model

    Arguments:
        beam {onmt.BeamSearch} -- opennmt BeamSearch class
        model {torch.nn.Module} -- subclass of torch.nn.Module, required to implement .encode() and .decode() method
        inputs {list} -- list of torch.Tensor for input of encode()

    Keyword Arguments:
        device {str} -- device to eval model (default: {"cpu"})

    Returns:
        result -- 2D list (B, N-best), each element is an (seq, score) pair
    """
    beam = BeamSearch(beam_size, batch_size,
                    pad=C.PAD, bos=C.BOS, eos=C.EOS,
                    n_best=n_best,
                    mb_device=device,
                    global_scorer=GNMTGlobalScorer(alpha, 0.1, "avg", "none"),
                    min_length=0,
                    max_length=max_len,
                    ratio=0.0,
                    memory_lengths=None,  # unused here since ratio=0.0
                    block_ngram_repeat=0,
                    exclusion_tokens=set(),
                    stepwise_penalty=True,
                    return_attention=False,
                    )
    model.eval()
    is_finished = [False] * beam.batch_size
    with torch.no_grad():
        src_ids, _, src_pos, _, _, src_key_padding_mask, _, original_memory_key_padding_mask = list(
            map(lambda x: x.to(device), inputs))
        
        if src_ids.shape[1] != batch_size:
            # pad the batch up to batch_size by repeating the first example
            diff = batch_size - src_ids.shape[1]
            src_ids = torch.cat([src_ids] + [src_ids[:, :1]] * diff, dim=1)
            src_pos = torch.cat([src_pos] + [src_pos[:, :1]] * diff, dim=1)
            src_key_padding_mask = torch.cat(
                [src_key_padding_mask] + [src_key_padding_mask[:1]] * diff, dim=0)
            original_memory_key_padding_mask = torch.cat(
                [original_memory_key_padding_mask]
                + [original_memory_key_padding_mask[:1]] * diff, dim=0)

        model.to(device)
        original_memory = model.encode(src_ids, src_pos, src_key_padding_mask=src_key_padding_mask)

        memory = original_memory
        memory_key_padding_mask = original_memory_key_padding_mask
        while not beam.done:
            len_decoder_inputs = beam.alive_seq.shape[1]
            dec_pos = torch.arange(1, len_decoder_inputs+1).repeat(beam.alive_seq.shape[0], 1).permute(1, 0).to(device)

            # unsqueeze the memory and memory_key_padding_mask in B dim to match the size (BM*BS)
            repeated_memory = memory.repeat(1, 1, beam.beam_size).reshape(
                memory.shape[0], -1, memory.shape[-1])
            repeated_memory_key_padding_mask = memory_key_padding_mask.repeat(
                1, beam.beam_size).reshape(-1, memory_key_padding_mask.shape[1])

            decoder_outputs = model.decode(beam.alive_seq.permute(1, 0), dec_pos, _, repeated_memory, memory_key_padding_mask=repeated_memory_key_padding_mask)[-1]
            if hasattr(model, "proj"):
                logits = model.proj(decoder_outputs)
            elif hasattr(model, "gen"):
                logits = model.gen(decoder_outputs)
            else:
                raise ValueError("Unknown generator!")

            log_probs = torch.nn.functional.log_softmax(logits, dim=1)
            beam.advance(log_probs, None)
            if beam.is_finished.any():
                beam.update_finished()

                # select data for the still-alive index
                for i, preds in enumerate(beam.predictions):
                    if not is_finished[i] and len(preds) == beam.n_best:
                        is_finished[i] = True

                alive_example_idx = [i for i in range(
                    len(is_finished)) if not is_finished[i]]
                if alive_example_idx:
                    memory = original_memory[:, alive_example_idx, :]
                    memory_key_padding_mask = original_memory_key_padding_mask[alive_example_idx]

    # packing data for easy accessing
    results = []
    for batch_preds, batch_scores in zip(beam.predictions, beam.scores):
        n_best_result = []
        for n_best_pred, n_best_score in zip(batch_preds, batch_scores):
            assert isinstance(n_best_pred, torch.Tensor)
            assert isinstance(n_best_score, torch.Tensor)
            n_best_result.append(
                (n_best_pred.tolist(), n_best_score.item())
            )
        results.append(n_best_result)

    return results
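A hypothetical call to batch_beam_search_trs, assuming inputs has already been collated into the eight tensors the function unpacks and that C.PAD/C.BOS/C.EOS come from the same module as the example:

results = batch_beam_search_trs(model, inputs, batch_size=32,
                                device="cuda", max_len=128,
                                beam_size=20, n_best=3, alpha=1.)
for n_best_result in results:                # one entry per batch element
    best_seq, best_score = n_best_result[0]  # hypotheses come back best-first
    print(best_score, best_seq)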
Example #7
    def test_beam_returns_attn_with_correct_length(self):
        beam_sz = 5
        batch_sz = 3
        n_words = 100
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(torch.tensor(
            [6., 5., 4., 3., 2., 1.]), dim=0)
        min_length = 5
        eos_idx = 2
        inp_lens = torch.randint(1, 30, (batch_sz,))
        beam = BeamSearch(
            beam_sz, batch_sz, 0, 1, 2, 2,
            torch.device("cpu"), GlobalScorerStub(),
            min_length, 30, True, 0, set(),
            inp_lens, False, 0.)
        for i in range(min_length + 2):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full(
                (batch_sz * beam_sz, n_words), -float('inf'))
            if i == 0:
                # "best" prediction is eos - that should be blocked
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                # include at least beam_sz predictions OTHER than EOS
                # that are greater than -1e20
                for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                    word_probs[0::beam_sz, j] = score
            elif i <= min_length:
                # predict eos in beam 1
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx::beam_sz, j] = score
            else:
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx::beam_sz, j] = score

            attns = torch.randn(1, batch_sz * beam_sz, 53)
            beam.advance(word_probs, attns)
            if i < min_length:
                self.assertFalse(beam.done)
                # no top beams are finished yet
                for b in range(batch_sz):
                    self.assertEqual(beam.attention[b], [])
            elif i == min_length:
                # beam 1 dies on min_length
                self.assertTrue(beam.is_finished[:, 1].all())
                beam.update_finished()
                self.assertFalse(beam.done)
                # no top beams are finished yet
                for b in range(batch_sz):
                    self.assertEqual(beam.attention[b], [])
            else:  # i > min_length
                # beam 0 dies on the step after beam 1 dies
                self.assertTrue(beam.is_finished[:, 0].all())
                beam.update_finished()
                self.assertTrue(beam.done)
                # top beam is finished now so there are attentions
                for b in range(batch_sz):
                    # two beams are finished in each batch
                    self.assertEqual(len(beam.attention[b]), 2)
                    for k in range(2):
                        # second dim is cut down to the non-padded src length
                        self.assertEqual(beam.attention[b][k].shape[-1],
                                         inp_lens[b])
                    # first dim is equal to the time of death
                    # (beam 0 died at current step - adjust for SOS)
                    self.assertEqual(beam.attention[b][0].shape[0], i+1)
                    # (beam 1 died at last step - adjust for SOS)
                    self.assertEqual(beam.attention[b][1].shape[0], i)
                # behavior gets weird when beam is already done so just stop
                break
Example #8
    def _translate_batch(self,
                         batch,
                         src_vocabs,
                         max_length,
                         min_length=0,
                         ratio=0.,
                         n_best=1,
                         return_attention=False):
        # TODO: support these blacklisted features.
        assert not self.dump_beam

        # (0) Prep the components of the search.
        use_src_map = self.copy_attn
        beam_size = self.beam_size
        batch_size = batch.batch_size

        #### TODO: Augment batch with distractors

        # (1) Run the encoder on the src.
        src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
        # src has shape [1311, 2, 1]
        # enc_states has shape [1311, 2, 512]
        # memory_bank has shape [1311, 2, 512]
        self.model.decoder.init_state(src, memory_bank, enc_states)

        results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "batch": batch,
            "gold_score": self._gold_score(batch, memory_bank, src_lengths,
                                           src_vocabs, use_src_map, enc_states,
                                           batch_size, src)
        }

        # (2) Repeat src objects `beam_size` times.
        # We use batch_size x beam_size
        src_map = (tile(batch.src_map, beam_size, dim=1)
                   if use_src_map else None)
        self.model.decoder.map_state(
            lambda state, dim: tile(state, beam_size, dim=dim))

        if isinstance(memory_bank, tuple):
            memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
            mb_device = memory_bank[0].device
        else:
            memory_bank = tile(memory_bank, beam_size, dim=1)
            mb_device = memory_bank.device
        memory_lengths = tile(src_lengths, beam_size)
        print('memory_bank size after tile:',
              memory_bank.shape)  #[1311, 20, 512]

        # (0) pt 2, prep the beam object
        beam = BeamSearch(beam_size,
                          n_best=n_best,
                          batch_size=batch_size,
                          global_scorer=self.global_scorer,
                          pad=self._tgt_pad_idx,
                          eos=self._tgt_eos_idx,
                          bos=self._tgt_bos_idx,
                          min_length=min_length,
                          ratio=ratio,
                          max_length=max_length,
                          mb_device=mb_device,
                          return_attention=return_attention,
                          stepwise_penalty=self.stepwise_penalty,
                          block_ngram_repeat=self.block_ngram_repeat,
                          exclusion_tokens=self._exclusion_idxs,
                          memory_lengths=memory_lengths)

        all_log_probs = []
        all_attn = []

        for step in range(max_length):
            decoder_input = beam.current_predictions.view(1, -1, 1)
            # decoder_input has shape[1,20,1]
            # decoder_input gives top 10 predictions for each batch element
            verbose = (step == 10)
            log_probs, attn = self._decode_and_generate(
                decoder_input,
                memory_bank,
                batch,
                src_vocabs,
                memory_lengths=memory_lengths,
                src_map=src_map,
                step=step,
                batch_offset=beam._batch_offset,
                verbose=verbose)

            # log_probs and attn are the probs for next word given that the
            # current word is that in decoder_input
            all_log_probs.append(log_probs)
            all_attn.append(attn)

            beam.advance(log_probs, attn, verbose=verbose)

            any_beam_is_finished = beam.is_finished.any()
            if any_beam_is_finished:
                beam.update_finished()
                if beam.done:
                    break

            select_indices = beam.current_origin

            if any_beam_is_finished:
                # Reorder states.
                if isinstance(memory_bank, tuple):
                    memory_bank = tuple(
                        x.index_select(1, select_indices) for x in memory_bank)
                else:
                    memory_bank = memory_bank.index_select(1, select_indices)

                memory_lengths = memory_lengths.index_select(0, select_indices)

                if src_map is not None:
                    src_map = src_map.index_select(1, select_indices)

            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        print('batch_size:', batch_size)
        print('max_length:', max_length)
        print('all_log_probs has len', len(all_log_probs))
        print('all_log_probs[0].shape', all_log_probs[0].shape)
        print('comparing log_probs[0]', all_log_probs[2][:, 0])

        results["scores"] = beam.scores
        results["predictions"] = beam.predictions
        results["attention"] = beam.attention
        return results
Example #9
    def _translate_batch(
            self,
            src,
            src_lengths,
            batch_size,
            min_length=0,
            ratio=0.,
            n_best=1,
            return_attention=False):

        max_length = self.config.max_sequence_length + 1 # to account for EOS
        beam_size = 3
        
        # Encoder forward.
        enc_states, memory_bank, src_lengths = self.encoder(src, src_lengths)
        self.decoder.init_state(src, memory_bank, enc_states)

        results = { "predictions": None, "scores": None, "attention": None }

        # (2) Repeat src objects `beam_size` times.
        # We use batch_size x beam_size
        self.decoder.map_state(lambda state, dim: tile(state, beam_size, dim=dim))

        #if isinstance(memory_bank, tuple):
        #    memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
        #    mb_device = memory_bank[0].device
        #else:
        memory_bank = tile(memory_bank, beam_size, dim=1)
        memory_lengths = tile(src_lengths, beam_size)

        mb_device = (memory_bank[0].device if isinstance(memory_bank, tuple)
                     else memory_bank.device)
        
        block_ngram_repeat = 0
        _exclusion_idxs = set()  # {} would be an empty dict, not a set

        # (0) pt 2, prep the beam object
        beam = BeamSearch(
            beam_size,
            n_best=n_best,
            batch_size=batch_size,
            global_scorer=self.scorer,
            pad=self.config.tgt_padding,
            eos=self.config.tgt_eos,
            bos=self.config.tgt_bos,
            min_length=min_length,
            ratio=ratio,
            max_length=max_length,
            mb_device=mb_device,
            return_attention=return_attention,
            stepwise_penalty=None,
            block_ngram_repeat=block_ngram_repeat,
            exclusion_tokens=_exclusion_idxs,
            memory_lengths=memory_lengths)

        for step in range(max_length):
            decoder_input = beam.current_predictions.view(1, -1, 1)

            log_probs, attn = self._decode_and_generate(
                decoder_input, memory_bank, memory_lengths, step,
                pretraining=True)

            beam.advance(log_probs, attn)
            any_beam_is_finished = beam.is_finished.any()
            if any_beam_is_finished:
                beam.update_finished()
                if beam.done:
                    break

            select_indices = beam.current_origin

            if any_beam_is_finished:
                # Reorder states.
                if isinstance(memory_bank, tuple):
                    memory_bank = tuple(x.index_select(1, select_indices)
                                        for x in memory_bank)
                else:
                    memory_bank = memory_bank.index_select(1, select_indices)

                memory_lengths = memory_lengths.index_select(0, select_indices)

            self.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        results["scores"] = beam.scores
        results["predictions"] = beam.predictions
        results["attention"] = beam.attention
        return results
Example #10
                   block_ngram_repeat=block_ngram_repeat,
                   exclusion_tokens=exclusion_idxs,
                   memory_lengths=memory_lengths)
for step in range(max_length):
    decoder_input = beam.current_predictions.view(1, -1, 1)
    dec_out, dec_attn = model.decoder(decoder_input,
                                      memory_bank,
                                      memory_lengths=memory_lengths,
                                      step=step)
    log_probs = model.generator(dec_out.squeeze(0))
    attn = dec_attn['std']
    log_probs, attn = log_probs.detach(), attn.detach()
    beam.advance(log_probs, attn)
    any_beam_is_finished = beam.is_finished.any()
    if any_beam_is_finished:
        beam.update_finished()
        if beam.done:
            break
    select_indices = beam.current_origin
    if any_beam_is_finished:
        if isinstance(memory_bank, tuple):
            memory_bank = tuple(
                x.index_select(1, select_indices) for x in memory_bank)
        else:
            memory_bank = memory_bank.index_select(1, select_indices)
        memory_lengths = memory_lengths.index_select(0, select_indices)
    model.decoder.map_state(
        lambda state, dim: state.index_select(dim, select_indices))
print(src.shape)
# print([list(map(lambda x: tgt_vocab.itos[x], input)) for input in src.transpose(1, 0)])
# print(beam.predictions)
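The commented-out prints above hint at turning beam.predictions back into tokens. A minimal sketch, assuming a torchtext-style tgt_vocab with an itos list (both names are hypothetical here):

for b, hyps in enumerate(beam.predictions):
    best = hyps[0]  # top-scoring hypothesis for batch element b
    tokens = [tgt_vocab.itos[idx] for idx in best.tolist()]
    print(b, " ".join(tokens))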