Example 1
    def initialize(self,
                   src,
                   src_lengths,
                   src_map=None,
                   device=None,
                   target_prefix=None):
        """Initialize for decoding.
        Repeat src objects `beam_size` times.
        """
        def fn_map_state(state, dim):
            return tile(state, self.beam_size, dim=dim)

        src = fn_map_state(src, dim=1)
        if src_map is not None:
            src_map = tile(src_map, self.beam_size, dim=1)
        if device is None:
            device = src.device

        self.memory_lengths = tile(src_lengths, self.beam_size)
        if target_prefix is not None:
            target_prefix = tile(target_prefix, self.beam_size, dim=1)

        super(BeamSearchLM, self).initialize_(None,
                                              self.memory_lengths,
                                              src_map=src_map,
                                              device=device,
                                              target_prefix=target_prefix)

        return fn_map_state, src, self.memory_lengths, src_map
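Every snippet in this collection leans on a `tile` helper to repeat src-side tensors into a `batch * beam` layout before decoding. As a point of reference, here is a minimal sketch consistent with how `tile` is called throughout (OpenNMT-py's `onmt.utils.misc.tile` produces the same adjacent-copies layout); the shapes in the check are illustrative assumptions:

import torch

def tile(x, count, dim=0):
    # Repeat every slice of `x` along `dim` `count` times, keeping the copies
    # of each slice adjacent: [a, b] -> [a, a, a, b, b, b] for count=3.
    return torch.repeat_interleave(x, count, dim=dim)

# Illustrative shapes: a (src_len, batch, hidden) memory bank tiled on dim=1
# becomes (src_len, batch * beam_size, hidden), matching the calls above.
memory_bank = torch.randn(7, 2, 16)
assert tile(memory_bank, 3, dim=1).shape == (7, 6, 16)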
Example 2
    def initialize(self, memory_bank, src_lengths, src_map=None, device=None):
        """Initialize for decoding."""
        #fn_map_state = None
        def fn_map_state(state, dim):
            return tile(state, self.sample_size, dim=dim)

        if isinstance(memory_bank, tuple):
            memory_bank = tuple(tile(x, self.sample_size, dim=1)
                                for x in memory_bank)
            mb_device = memory_bank[0].device
        else:
            memory_bank = tile(memory_bank, self.sample_size, dim=1)
            mb_device = memory_bank.device
        if src_map is not None:
            src_map = tile(src_map, self.sample_size, dim=1)
        if device is None:
            device = mb_device

        self.memory_lengths = tile(src_lengths, self.sample_size)
        super(PriorSampling, self).initialize(
            memory_bank, self.memory_lengths, src_map, device)
        self.select_indices = torch.arange(
            self.batch_size * self.sample_size, dtype=torch.long, device=device)
        self.original_batch_idx = tile(torch.arange(
            self.batch_size, dtype=torch.long, device=device), self.sample_size)
        return fn_map_state, memory_bank, self.memory_lengths, src_map
Example 3
    def initialize(self,
                   memory_bank,
                   src_lengths,
                   src_map=None,
                   device=None,
                   target_prefix=None):
        """Initialize for decoding.
        Repeat src objects `beam_size` times.
        """
        def fn_map_state(state, dim):
            return tile(state, self.beam_size, dim=dim)

        if isinstance(memory_bank, tuple):
            memory_bank = tuple(
                tile(x, self.beam_size, dim=1) for x in memory_bank)
            mb_device = memory_bank[0].device
        else:
            memory_bank = tile(memory_bank, self.beam_size, dim=1)
            mb_device = memory_bank.device
        if src_map is not None:
            src_map = tile(src_map, self.beam_size, dim=1)
        if device is None:
            device = mb_device

        self.memory_lengths = tile(src_lengths, self.beam_size)
        if target_prefix is not None:
            target_prefix = tile(target_prefix, self.beam_size, dim=1)

        super(BeamSearch, self).initialize_(memory_bank, self.memory_lengths,
                                            src_map, device, target_prefix)

        return fn_map_state, memory_bank, self.memory_lengths, src_map
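The `fn_map_state` closure returned by `initialize` is what the caller hands to the decoder's `map_state`, so that any cached decoder state gets tiled the same way as the memory bank. A small self-contained sketch of that pattern (ToyDecoder is a stand-in used only for illustration, not the real decoder class):

import torch

def tile(x, count, dim=0):
    return torch.repeat_interleave(x, count, dim=dim)

class ToyDecoder:
    # Stand-in for a decoder whose map_state applies a function to every
    # cached tensor, passing the dim that holds the (tiled) batch axis.
    def __init__(self):
        self.cache = {"keys": torch.zeros(4, 2, 8)}  # (len, batch, dim)

    def map_state(self, fn):
        self.cache = {name: fn(t, 1) for name, t in self.cache.items()}

beam_size = 5
def fn_map_state(state, dim):
    return tile(state, beam_size, dim=dim)

decoder = ToyDecoder()
decoder.map_state(fn_map_state)  # cached state now laid out batch * beam
assert decoder.cache["keys"].shape == (4, 10, 8)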
Example 4
    def forward_dev_beam_search(self, encoder_output: torch.Tensor, pad_mask):
        batch_size = encoder_output.size(1)

        self.state["cache"] = None
        memory_lengths = pad_mask.ne(pad_token_index).sum(dim=0)

        self.map_state(lambda state, dim: tile(state, self.beam_size, dim=dim))
        encoder_output = tile(encoder_output, self.beam_size, dim=1)
        pad_mask = tile(pad_mask, self.beam_size, dim=1)
        memory_lengths = tile(memory_lengths, self.beam_size, dim=0)

        # TODO:
        #  - fix attn (?)
        #  - use coverage_penalty="summary" or "wu" and beta=0.2 (or not)
        #  - use length_penalty="wu" and alpha=0.2 (or not)
        beam = BeamSearch(beam_size=self.beam_size, n_best=1, batch_size=batch_size, mb_device=default_device,
                          global_scorer=GNMTGlobalScorer(alpha=0, beta=0, coverage_penalty="none", length_penalty="avg"),
                          pad=pad_token_index, eos=eos_token_index, bos=bos_token_index, min_length=1, max_length=100,
                          return_attention=False, stepwise_penalty=False, block_ngram_repeat=0, exclusion_tokens=set(),
                          memory_lengths=memory_lengths, ratio=-1)

        for i in range(self.max_seq_out_len):
            inp = beam.current_predictions.view(1, -1)

            out, attn = self.forward_step(src=pad_mask, tgt=inp, memory_bank=encoder_output, step=i)  # 1 x batch*beam x hidden
            out = self.linear(out)  # 1 x batch*beam x vocab_out
            out = log_softmax(out, dim=2)  # 1 x batch*beam x vocab_out

            out = out.squeeze(0)  # batch*beam x vocab_out
            # attn = attn.squeeze(0)  # batch*beam x vocab_out
            # out = out.view(batch_size, self.beam_size, -1)  # batch x beam x vocab_out
            # attn = attn.view(batch_size, self.beam_size, -1)
            # TODO: fix attn (?)

            beam.advance(out, attn)
            any_beam_is_finished = beam.is_finished.any()
            if any_beam_is_finished:
                beam.update_finished()
                if beam.done:
                    break

            select_indices = beam.current_origin

            if any_beam_is_finished:
                # Reorder states.
                encoder_output = encoder_output.index_select(1, select_indices)
                pad_mask = pad_mask.index_select(1, select_indices)
                memory_lengths = memory_lengths.index_select(0, select_indices)

            self.map_state(lambda state, dim: state.index_select(dim, select_indices))

        outputs = beam.predictions
        outputs = [x[0] for x in outputs]
        outputs = pad_sequence(outputs, batch_first=True)
        return [outputs]
Example 5
    def initialize(self,
                   memory_bank,
                   src_lengths,
                   src_map=None,
                   device=None,
                   target_prefix=None):
        """Initialize for decoding.
        Repeat src objects `beam_size` times.
        """
        def fn_map_state(state, dim):
            return tile(state, self.beam_size, dim=dim)

        if isinstance(memory_bank, tuple):
            memory_bank = tuple(
                tile(x, self.beam_size, dim=1) for x in memory_bank)
            mb_device = memory_bank[0].device
        else:
            memory_bank = tile(memory_bank, self.beam_size, dim=1)
            mb_device = memory_bank.device
        if src_map is not None:
            src_map = tile(src_map, self.beam_size, dim=1)
        if device is None:
            device = mb_device

        self.memory_lengths = tile(src_lengths, self.beam_size)
        if target_prefix is not None:
            target_prefix = tile(target_prefix, self.beam_size, dim=1)

        super(BeamSearch, self).initialize(memory_bank, self.memory_lengths,
                                           src_map, device, target_prefix)

        self.best_scores = torch.full([self.batch_size],
                                      -1e10,
                                      dtype=torch.float,
                                      device=device)
        self._beam_offset = torch.arange(0,
                                         self.batch_size * self.beam_size,
                                         step=self.beam_size,
                                         dtype=torch.long,
                                         device=device)
        self.topk_log_probs = torch.tensor(
            [0.0] + [float("-inf")] * (self.beam_size - 1),
            device=device).repeat(self.batch_size)
        # buffers for the topk scores and 'backpointer'
        self.topk_scores = torch.empty((self.batch_size, self.beam_size),
                                       dtype=torch.float,
                                       device=device)
        self.topk_ids = torch.empty((self.batch_size, self.beam_size),
                                    dtype=torch.long,
                                    device=device)
        self._batch_index = torch.empty([self.batch_size, self.beam_size],
                                        dtype=torch.long,
                                        device=device)
        return fn_map_state, memory_bank, self.memory_lengths, src_map
Example 6
    def _align_forward(self, batch, predictions):
        """
        For a batch of inputs and their predictions, return a list of batch
        prediction alignment src index Tensors of size ``(batch, n_best,)``.
        """
        # (0) add BOS and padding to tgt prediction
        if hasattr(batch, 'tgt'):
            batch_tgt_idxs = batch.tgt.transpose(1, 2).transpose(0, 2)
        else:
            batch_tgt_idxs = self._align_pad_prediction(predictions,
                                                        bos=self._tgt_bos_idx,
                                                        pad=self._tgt_pad_idx)
        tgt_mask = (batch_tgt_idxs.eq(self._tgt_pad_idx)
                    | batch_tgt_idxs.eq(self._tgt_eos_idx)
                    | batch_tgt_idxs.eq(self._tgt_bos_idx))

        n_best = batch_tgt_idxs.size(1)
        # (1) Encoder forward.
        src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)

        # (2) Repeat src objects `n_best` times.
        # We use batch_size x n_best, get ``(src_len, batch * n_best, nfeat)``
        src = tile(src, n_best, dim=1)
        enc_states = tile(enc_states, n_best, dim=1)
        if isinstance(memory_bank, tuple):
            memory_bank = tuple(tile(x, n_best, dim=1) for x in memory_bank)
        else:
            memory_bank = tile(memory_bank, n_best, dim=1)
        src_lengths = tile(src_lengths, n_best)  # ``(batch * n_best,)``

        # (3) Init decoder with n_best src,
        self.model.decoder.init_state(src, memory_bank, enc_states)
        self.turbo_decoder.init_state(src, memory_bank, enc_states)

        # reshape tgt to ``(len, batch * n_best, nfeat)``
        tgt = batch_tgt_idxs.view(-1, batch_tgt_idxs.size(-1)).T.unsqueeze(-1)
        dec_in = tgt[:-1]  # exclude last target from inputs

        _, attns = self.model.decoder(dec_in,
                                      memory_bank,
                                      memory_lengths=src_lengths,
                                      with_align=True)

        alignment_attn = attns["align"]  # ``(B, tgt_len-1, src_len)``
        # masked_select
        align_tgt_mask = tgt_mask.view(-1, tgt_mask.size(-1))
        prediction_mask = align_tgt_mask[:, 1:]  # exclude bos to match pred
        # get aligned src id for each prediction's valid tgt tokens
        alignement = extract_alignment(alignment_attn, prediction_mask,
                                       src_lengths, n_best)
        return alignement
Example 7
    def forward(self, batch_size, beam_size, max_seq_len, memory,
                memory_seq_lens):
        extended_memory = tile(memory, beam_size)
        extended_memory_seq_lens = tile(memory_seq_lens, beam_size)
        output_ids, parent_ids, out_seq_lens = self.decoding.forward(
            batch_size, beam_size, max_seq_len, extended_memory,
            extended_memory_seq_lens)
        parent_ids = parent_ids % beam_size
        beams, lengths = finalize(beam_size,
                                  output_ids,
                                  parent_ids,
                                  out_seq_lens,
                                  self.end_id,
                                  max_seq_len,
                                  args=self.args)
        return beams, lengths
Example 8
    def initialize_tile(self,
                        memory_bank,
                        src_lengths,
                        src_map=None,
                        target_prefix=None):
        def fn_map_state(state, dim):
            return tile(state, self.beam_size, dim=dim)

        if isinstance(memory_bank, tuple):
            memory_bank = tuple(
                tile(x, self.beam_size, dim=1) for x in memory_bank)
        elif memory_bank is not None:
            memory_bank = tile(memory_bank, self.beam_size, dim=1)
        if src_map is not None:
            src_map = tile(src_map, self.beam_size, dim=1)

        self.memory_lengths = tile(src_lengths, self.beam_size)
        if target_prefix is not None:
            target_prefix = tile(target_prefix, self.beam_size, dim=1)

        return fn_map_state, memory_bank, src_map, target_prefix
Example 9
    def forward(self, batch_size, beam_size, max_seq_len, memory,
                memory_seq_lens):
        extended_memory = tile(memory, beam_size)
        batchxbeam = extended_memory.size(0)
        extended_memory = extended_memory.transpose(0, 1).contiguous()

        extended_memory_seq_lens = tile(memory_seq_lens, beam_size)
        start_ids = extended_memory_seq_lens.new_full((batchxbeam, ),
                                                      self.start_id,
                                                      dtype=torch.int64)

        initial_log_probs = extended_memory.new_full((beam_size, ),
                                                     -float("inf"),
                                                     dtype=torch.float32)
        initial_log_probs[0] = 0.
        initial_log_probs = initial_log_probs.repeat(batch_size)
        sequence_lengths = extended_memory_seq_lens.new_full((batchxbeam, ), 0)
        finished = extended_memory_seq_lens.new_full((batchxbeam, ),
                                                     0,
                                                     dtype=torch.bool)

        dtype_info = torch.finfo(extended_memory.dtype)
        eos_max_prob = extended_memory.new_full((batchxbeam, self.vocab_size),
                                                dtype_info.min)
        eos_max_prob[:, self.end_id] = dtype_info.max

        self.decoder.init_state(extended_memory, extended_memory, None)
        word_ids = start_ids
        cum_log_probs = initial_log_probs

        for step in range(max_seq_len):
            if not torch.bitwise_not(finished).any():
                break
            word_ids = word_ids.view(1, -1, 1)
            dec_out, dec_attn = self.decoder(
                word_ids,
                extended_memory,
                memory_lengths=extended_memory_seq_lens,
                step=step)
            logits = self.generator(dec_out.squeeze(0))
            logits = torch.where(finished.view(-1, 1), eos_max_prob,
                                 logits).to(torch.float32)
            log_probs = self.logsoftmax(logits.to(torch.float32))

            total_probs = log_probs + torch.unsqueeze(cum_log_probs, 1)
            total_probs = total_probs.view(-1, beam_size * self.vocab_size)

            # beam search
            # _, sample_ids = torch.topk(total_probs, beam_size)
            # sample_ids = sample_ids.view(-1)

            # diverse sibling search
            sibling_score = torch.arange(1, beam_size + 1).to(
                total_probs.dtype).to(extended_memory.device
                                      ) * self.diversity_rate  # [beam_size]
            scores, ids = torch.topk(
                total_probs.view(-1, beam_size, self.vocab_size),
                beam_size)  # [batch size, beam width, beam width]
            scores = scores + sibling_score  # [batch size, beam width, beam width]
            scores = scores.view(-1, beam_size * beam_size)
            ids = ids + torch.unsqueeze(
                torch.unsqueeze(
                    torch.arange(0, beam_size).to(extended_memory.device) *
                    self.vocab_size, 0), -1)
            ids = ids.view(-1, beam_size * beam_size)
            _, final_ids = torch.topk(scores,
                                      beam_size)  # [batch size, beam size]
            final_ids = final_ids.view(-1, 1)
            batch_index = torch.arange(0, batch_size).to(
                extended_memory.device).view(-1,
                                             1).repeat(1,
                                                       beam_size).view(-1, 1)
            index = torch.cat([batch_index, final_ids], 1)
            sample_ids = gather_nd(ids, index)

            word_ids = sample_ids % self.vocab_size  # [batch_size * beam_size]
            beam_ids = sample_ids // self.vocab_size  # [batch_size * beam_size]
            beam_indices = (torch.arange(batchxbeam).to(extended_memory.device)
                            // beam_size) * beam_size + beam_ids

            sequence_lengths = torch.where(finished, sequence_lengths,
                                           sequence_lengths + 1)

            batch_pos = torch.arange(batchxbeam).to(
                extended_memory.device) // beam_size
            next_cum_log_probs = gather_nd(
                total_probs, torch.stack([batch_pos, sample_ids],
                                         -1))  # [batch_size * beam_size]
            finished = finished.index_select(0, beam_indices)
            sequence_lengths = sequence_lengths.index_select(0, beam_indices)

            self.decoder.map_state(
                lambda state, dim: state.index_select(dim, beam_indices))
            if step == 0:
                parent_ids = beam_ids.view(1, -1)
                output_ids = word_ids.view(1, -1)
            else:
                parent_ids = torch.cat((parent_ids, beam_ids.view(1, -1)))
                output_ids = torch.cat((output_ids, word_ids.view(1, -1)))
            cum_log_probs = torch.where(finished, cum_log_probs,
                                        next_cum_log_probs)
            finished = torch.bitwise_or(finished,
                                        torch.eq(word_ids, self.end_id))

        beams, lengths = finalize(beam_size,
                                  output_ids,
                                  parent_ids,
                                  sequence_lengths,
                                  self.end_id,
                                  args=self.args)
        return beams, lengths
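Example 9 relies on a `gather_nd` helper that is not shown. Below is a minimal sketch matching the only way it is called there, a 2-D score matrix indexed by an `(N, 2)` tensor of `(row, col)` pairs; the helper itself is an assumption reconstructed from those call sites rather than the original implementation.

import torch

def gather_nd(tensor, index):
    # Pick tensor[row, col] for every (row, col) pair in `index`.
    # `index` has shape (N, 2); the result has shape (N,). This is the only
    # behaviour Example 9 depends on (TensorFlow-style gather_nd restricted
    # to 2-D inputs).
    return tensor[index[:, 0], index[:, 1]]

scores = torch.tensor([[0.1, 0.2, 0.3],
                       [0.4, 0.5, 0.6]])
idx = torch.tensor([[0, 2], [1, 0]])
assert torch.allclose(gather_nd(scores, idx), torch.tensor([0.3, 0.4]))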
Example 10
    def _translate_batch(self, batch, data):
        # (0) Prep each of the components of the search.
        # And helper method for reducing verbosity.
        beam_size = self.beam_size
        batch_size = batch.batch_size
        tgt_field = self.fields['tgt'][0][1].base_field
        vocab = tgt_field.vocab

        # Define a set of tokens to exclude from ngram-blocking
        exclusion_tokens = {vocab.stoi[t] for t in self.ignore_when_blocking}

        pad = vocab.stoi[tgt_field.pad_token]
        eos = vocab.stoi[tgt_field.eos_token]
        bos = vocab.stoi[tgt_field.init_token]
        beam = [
            onmt.translate.Beam(beam_size,
                                n_best=self.n_best,
                                cuda=self.cuda,
                                global_scorer=self.global_scorer,
                                pad=pad,
                                eos=eos,
                                bos=bos,
                                min_length=self.min_length,
                                stepwise_penalty=self.stepwise_penalty,
                                block_ngram_repeat=self.block_ngram_repeat,
                                exclusion_tokens=exclusion_tokens)
            for __ in range(batch_size)
        ]

        # (1) Run the encoder on the src.
        src, enc_states_ques, memory_bank_ques, ques_lengths, ans, enc_states_ans, memory_bank_ans, ans_lengths = self._run_encoder(
            batch)
        ################# Modified #######################
        enc_states_final = tuple(
            torch.add(enc_q, enc_ans)
            for enc_q, enc_ans in zip(enc_states_ques, enc_states_ans))
        memory_bank_final = torch.cat([memory_bank_ques, memory_bank_ans], 0)
        memory_lengths_final = torch.add(ques_lengths, ans_lengths)
        self.model.decoder.init_state(src, memory_bank_final, enc_states_final)

        results = {}
        results["predictions"] = []
        results["scores"] = []
        results["attention"] = []
        results["batch"] = batch
        if "tgt" in batch.__dict__:
            results["gold_score"] = self._score_target(
                batch, memory_bank_final, memory_lengths_final, data,
                batch.src_map if self.copy_attn else None)
            self.model.decoder.init_state(src, memory_bank_final,
                                          enc_states_final)
        else:
            results["gold_score"] = [0] * batch_size

        # (2) Repeat src objects `beam_size` times.
        # We now use batch_size x beam_size (same as fast mode)
        src_map = (tile(batch.src_map, beam_size, dim=1)
                   if self.copy_attn else None)
        self.model.decoder.map_state(
            lambda state, dim: tile(state, beam_size, dim=dim))

        if isinstance(memory_bank_final, tuple):
            memory_bank_final = tuple(
                tile(x, beam_size, dim=1) for x in memory_bank_final)
        else:
            memory_bank_final = tile(memory_bank_final, beam_size, dim=1)
        memory_lengths_final = tile(memory_lengths_final, beam_size)

        # (3) run the decoder to generate sentences, using beam search.
        for i in range(self.max_length):
            if all((b.done() for b in beam)):
                break

            # (a) Construct batch x beam_size nxt words.
            # Get all the pending current beam words and arrange for forward.

            inp = torch.stack([b.get_current_state() for b in beam])
            inp = inp.view(1, -1, 1)

            # (b) Decode and forward
            out, beam_attn = self._decode_and_generate(
                inp,
                memory_bank_final,
                batch,
                data,
                memory_lengths=memory_lengths_final,
                src_map=src_map,
                step=i)
            out = out.view(batch_size, beam_size, -1)
            beam_attn = beam_attn.view(batch_size, beam_size, -1)

            # (c) Advance each beam.
            select_indices_array = []
            # Loop over the batch_size number of beam
            for j, b in enumerate(beam):
                b.advance(out[j, :],
                          beam_attn.data[j, :, :memory_lengths_final[j]])
                select_indices_array.append(b.get_current_origin() +
                                            j * beam_size)
            select_indices = torch.cat(select_indices_array)

            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        # (4) Extract sentences from beam.
        for b in beam:
            scores, ks = b.sort_finished(minimum=self.n_best)
            hyps, attn = [], []
            for times, k in ks[:self.n_best]:
                hyp, att = b.get_hyp(times, k)
                hyps.append(hyp)
                attn.append(att)
            results["predictions"].append(hyps)
            results["scores"].append(scores)
            results["attention"].append(attn)

        return results
Example 11
    def _translate_batch(
            self,
            batch,
            src_vocabs,
            max_length,
            min_length=0,
            ratio=0.,
            n_best=1,
            return_attention=False):
        # TODO: support these blacklisted features.
        assert not self.dump_beam

        # (0) Prep the components of the search.
        use_src_map = self.copy_attn
        beam_size = self.beam_size
        batch_size = batch.batch_size

        # (1) Run the encoder on the src.
        src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
        self.model.decoder.init_state(src, memory_bank, enc_states)

        results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "batch": batch,
            "gold_score": self._gold_score(
                batch, memory_bank, src_lengths, src_vocabs, use_src_map,
                enc_states, batch_size, src)}

        # (2) Repeat src objects `beam_size` times.
        # We use batch_size x beam_size
        src_map = (tile(batch.src_map, beam_size, dim=1)
                   if use_src_map else None)
        self.model.decoder.map_state(
            lambda state, dim: tile(state, beam_size, dim=dim))

        if isinstance(memory_bank, tuple):
            memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
            mb_device = memory_bank[0].device
        else:
            memory_bank = tile(memory_bank, beam_size, dim=1)
            mb_device = memory_bank.device
        memory_lengths = tile(src_lengths, beam_size)

        # (0) pt 2, prep the beam object
        beam = BeamSearch(
            beam_size,
            n_best=n_best,
            batch_size=batch_size,
            global_scorer=self.global_scorer,
            pad=self._tgt_pad_idx,
            eos=self._tgt_eos_idx,
            bos=self._tgt_bos_idx,
            min_length=min_length,
            ratio=ratio,
            max_length=max_length,
            mb_device=mb_device,
            return_attention=return_attention,
            stepwise_penalty=self.stepwise_penalty,
            block_ngram_repeat=self.block_ngram_repeat,
            exclusion_tokens=self._exclusion_idxs,
            memory_lengths=memory_lengths)

        for step in range(max_length):
            decoder_input = beam.current_predictions.view(1, -1, 1)

            log_probs, attn = self._decode_and_generate(
                decoder_input,
                memory_bank,
                batch,
                src_vocabs,
                memory_lengths=memory_lengths,
                src_map=src_map,
                step=step,
                batch_offset=beam._batch_offset)

            beam.advance(log_probs, attn)
            any_beam_is_finished = beam.is_finished.any()
            if any_beam_is_finished:
                beam.update_finished()
                if beam.done:
                    break

            select_indices = beam.current_origin

            if any_beam_is_finished:
                # Reorder states.
                if isinstance(memory_bank, tuple):
                    memory_bank = tuple(x.index_select(1, select_indices)
                                        for x in memory_bank)
                else:
                    memory_bank = memory_bank.index_select(1, select_indices)

                memory_lengths = memory_lengths.index_select(0, select_indices)

                if src_map is not None:
                    src_map = src_map.index_select(1, select_indices)

            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        results["scores"] = beam.scores
        results["predictions"] = beam.predictions
        results["attention"] = beam.attention
        return results
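When some beams finish, Examples 11, 12 and 14 all shrink the tiled tensors with `index_select` over the flat `batch * beam` axis before the next step. A toy illustration of that reorder step (the shapes and surviving indices are made up for the example):

import torch

memory_bank = torch.randn(7, 6, 16)           # (src_len, batch * beam, hidden)
memory_lengths = torch.tensor([5, 5, 5, 9, 9, 9])
select_indices = torch.tensor([0, 2, 3, 5])   # hypothetical surviving rows

memory_bank = memory_bank.index_select(1, select_indices)
memory_lengths = memory_lengths.index_select(0, select_indices)
assert memory_bank.shape == (7, 4, 16)
assert memory_lengths.tolist() == [5, 5, 9, 9]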
Example 12
    def _fast_translate_batch(self,
                              batch,
                              data,
                              max_length,
                              min_length=0,
                              n_best=1,
                              return_attention=False):
        # TODO: faster code path for beam_size == 1.

        # TODO: support these blacklisted features.
        assert data.data_type == 'text'
        assert not self.copy_attn
        assert not self.dump_beam
        assert not self.use_filter_pred
        assert self.block_ngram_repeat == 0
        assert self.global_scorer.beta == 0

        beam_size = self.beam_size
        batch_size = batch.batch_size
        vocab = self.fields["tgt"].vocab
        start_token = vocab.stoi[inputters.BOS_WORD]
        end_token = vocab.stoi[inputters.EOS_WORD]

        # Encoder forward.
        src = inputters.make_features(batch, 'src', data.data_type)
        _, src_lengths = batch.src
        enc_states, memory_bank, src_lengths \
            = self.model.encoder(src, src_lengths)
        dec_states = self.model.decoder.init_decoder_state(
            src, memory_bank, enc_states, with_cache=True)

        # Tile states and memory beam_size times.
        dec_states.map_batch_fn(
            lambda state, dim: tile(state, beam_size, dim=dim))

        if type(memory_bank) == tuple:
            device = memory_bank[0].device
            memory_bank = tuple(tile(m, beam_size, dim=1) for m in memory_bank)
        else:
            memory_bank = tile(memory_bank, beam_size, dim=1)
            device = memory_bank.device
        memory_lengths = tile(src_lengths, beam_size)

        top_beam_finished = torch.zeros([batch_size], dtype=torch.uint8)
        batch_offset = torch.arange(batch_size, dtype=torch.long)

        beam_offset = torch.arange(
            0,
            batch_size * beam_size,
            step=beam_size,
            dtype=torch.long,
            device=device)
        alive_seq = torch.full(
            [batch_size * beam_size, 1],
            start_token,
            dtype=torch.long,
            device=device)
        alive_attn = None

        # Give full probability to the first beam on the first step.
        topk_log_probs = (
            torch.tensor([0.0] + [float("-inf")] * (beam_size - 1),
                         device=device).repeat(batch_size))

        # Structure that holds finished hypotheses.
        hypotheses = [[] for _ in range(batch_size)]  # noqa: F812

        results = {}
        results["predictions"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["scores"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["attention"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["gold_score"] = [0] * batch_size
        results["batch"] = batch

        if self.mask is not None:
            mask = self.mask.get_log_probs_masking_tensor(
                src.squeeze(2), beam_size).to(device)

        for step in range(max_length):
            decoder_input = alive_seq[:, -1].view(1, -1, 1)

            # Decoder forward.
            dec_out, dec_states, attn = self.model.decoder(
                decoder_input,
                memory_bank,
                dec_states,
                memory_lengths=memory_lengths,
                step=step)

            # Generator forward.
            log_probs = self.model.generator.forward(dec_out.squeeze(0))
            vocab_size = log_probs.size(-1)

            if step < min_length:
                log_probs[:, end_token] = -1e20

            if self.mask is not None:
                log_probs = log_probs * mask

            # Multiply probs by the beam probability.
            log_probs += topk_log_probs.view(-1).unsqueeze(1)

            alpha = self.global_scorer.alpha
            length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha

            # Flatten probs into a list of possibilities.
            curr_scores = log_probs / length_penalty
            curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
            topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

            # Recover log probs.
            topk_log_probs = topk_scores * length_penalty

            # Resolve beam origin and true word ids.
            topk_beam_index = topk_ids.div(vocab_size)
            topk_ids = topk_ids.fmod(vocab_size)

            # Map beam_index to batch_index in the flat representation.
            batch_index = (
                    topk_beam_index
                    + beam_offset[:topk_beam_index.size(0)].unsqueeze(1))
            select_indices = batch_index.view(-1)

            # Append last prediction.
            alive_seq = torch.cat(
                [alive_seq.index_select(0, select_indices),
                 topk_ids.view(-1, 1)], -1)
            if return_attention:
                current_attn = attn["std"].index_select(1, select_indices)
                if alive_attn is None:
                    alive_attn = current_attn
                else:
                    alive_attn = alive_attn.index_select(1, select_indices)
                    alive_attn = torch.cat([alive_attn, current_attn], 0)

            is_finished = topk_ids.eq(end_token)
            if step + 1 == max_length:
                is_finished.fill_(1)

            # Save finished hypotheses.
            if is_finished.any():
                # Penalize beams that finished.
                topk_log_probs.masked_fill_(is_finished, -1e10)
                is_finished = is_finished.to('cpu')
                top_beam_finished |= is_finished[:, 0].eq(1)

                predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
                attention = (
                    alive_attn.view(
                        alive_attn.size(0), -1, beam_size, alive_attn.size(-1))
                    if alive_attn is not None else None)
                non_finished_batch = []
                for i in range(is_finished.size(0)):
                    b = batch_offset[i]
                    finished_hyp = is_finished[i].nonzero().view(-1)
                    # Store finished hypotheses for this batch.
                    for j in finished_hyp:
                        # if (predictions[i, j, 1:] == end_token).sum() <= 1:
                        hypotheses[b].append((
                            topk_scores[i, j],
                            predictions[i, j, 1:],  # Ignore start_token.
                            attention[:, i, j, :memory_lengths[i]]
                            if attention is not None else None))
                    # End condition is the top beam finished and we can return
                    # n_best hypotheses.
                    if top_beam_finished[i] and len(hypotheses[b]) >= n_best:
                        best_hyp = sorted(
                            hypotheses[b], key=lambda x: x[0], reverse=True)
                        for n, (score, pred, attn) in enumerate(best_hyp):
                            if n >= n_best:
                                break
                            results["scores"][b].append(score)
                            results["predictions"][b].append(pred)
                            results["attention"][b].append(
                                attn if attn is not None else [])
                    else:
                        non_finished_batch.append(i)
                non_finished = torch.tensor(non_finished_batch)
                # If all sentences are translated, no need to go further.
                if len(non_finished) == 0:
                    break
                # Remove finished batches for the next step.
                top_beam_finished = top_beam_finished.index_select(
                    0, non_finished)
                batch_offset = batch_offset.index_select(0, non_finished)
                non_finished = non_finished.to(topk_ids.device)
                topk_log_probs = topk_log_probs.index_select(0, non_finished)
                batch_index = batch_index.index_select(0, non_finished)
                select_indices = batch_index.view(-1)
                alive_seq = predictions.index_select(0, non_finished) \
                    .view(-1, alive_seq.size(-1))
                if alive_attn is not None:
                    alive_attn = attention.index_select(1, non_finished) \
                        .view(alive_attn.size(0),
                              -1, alive_attn.size(-1))

            # Reorder states.
            if type(memory_bank) == tuple:
                memory_bank = tuple(m.index_select(1, select_indices) for m in memory_bank)
            else:
                memory_bank = memory_bank.index_select(1, select_indices)
            memory_lengths = memory_lengths.index_select(0, select_indices)
            dec_states.map_batch_fn(
                lambda state, dim: state.index_select(dim, select_indices))

            if self.mask is not None:
                mask = mask.index_select(0, select_indices)

        return results
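The hard-coded normalisation `((5.0 + (step + 1)) / 6.0) ** alpha` in this example (and again in Example 15) is the GNMT length penalty from Wu et al. (2016). Pulled out here as a tiny reference helper, a sketch rather than anything the original code defines:

def gnmt_length_penalty(cur_len, alpha):
    # lp(Y) = ((5 + |Y|) / 6) ** alpha; scores are divided by this so longer
    # hypotheses are not penalised purely for their length when alpha > 0.
    return ((5.0 + cur_len) / 6.0) ** alpha

# With alpha = 0 the penalty is 1.0 and leaves the scores untouched.
assert gnmt_length_penalty(cur_len=1, alpha=0.0) == 1.0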
Example 13
    def fn_map_state(state, dim):
        return tile(state, self.beam_size, dim=dim)
Example 14
    def _translate_batch(
            self,
            batch,
            src_vocabs,
            max_length,
            min_length=0,
            ratio=0.,
            n_best=1,
            return_attention=False,
            xlation_builder=None,
    ):
        # TODO: support these blacklisted features.
        assert not self.dump_beam

        # (0) Prep the components of the search.
        use_src_map = self.copy_attn
        beam_size = self.beam_size
        batch_size = batch.batch_size

        # (1) Run the encoder on the src.
        src_list, enc_states_list, memory_bank_list, src_lengths_list = self._run_encoder(batch)
        self.model.decoder.init_state(src_list, memory_bank_list, enc_states_list)

        results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "batch": batch,
            "gold_score": self._gold_score(
                batch, memory_bank_list, src_lengths_list, src_vocabs, use_src_map,
                enc_states_list, batch_size, src_list)}
        
        # (2) Repeat src objects `beam_size` times.
        # We use batch_size x beam_size
        src_map_list = list()
        for src_type in self.src_types:
            src_map_list.append((tile(getattr(batch, f"src_map.{src_type}"), beam_size, dim=1) if use_src_map else None))
        # end for

        self.model.decoder.map_state(
            lambda state, dim: tile(state, beam_size, dim=dim))

        memory_lengths_list = list()
        memory_lengths = list()
        for src_i in range(len(memory_bank_list)):
            if isinstance(memory_bank_list[src_i], tuple):
                memory_bank_list[src_i] = tuple(tile(x, beam_size, dim=1) for x in memory_bank_list[src_i])
                mb_device = memory_bank_list[src_i][0].device
            else:
                memory_bank_list[src_i] = tile(memory_bank_list[src_i], beam_size, dim=1)
                mb_device = memory_bank_list[src_i].device
            # end if
            memory_lengths_list.append(tile(src_lengths_list[src_i], beam_size))
            memory_lengths.append(src_lengths_list[src_i])
        # end for
        memory_lengths = tile(torch.stack(memory_lengths, dim=0).sum(dim=0), beam_size)

        indexes = tile(torch.tensor(list(range(batch_size)), device=self._dev), beam_size)

        # (0) pt 2, prep the beam object
        beam = BeamSearch(
            beam_size,
            n_best=n_best,
            batch_size=batch_size,
            global_scorer=self.global_scorer,
            pad=self._tgt_pad_idx,
            eos=self._tgt_eos_idx,
            bos=self._tgt_bos_idx,
            min_length=min_length,
            ratio=ratio,
            max_length=max_length,
            mb_device=mb_device,
            return_attention=return_attention,
            stepwise_penalty=self.stepwise_penalty,
            block_ngram_repeat=self.block_ngram_repeat,
            exclusion_tokens=self._exclusion_idxs,
            memory_lengths=memory_lengths)

        for step in range(max_length):
            decoder_input = beam.current_predictions.view(1, -1, 1)

            log_probs, attn = self._decode_and_generate(
                decoder_input,
                memory_bank_list,
                batch,
                src_vocabs,
                memory_lengths_list=memory_lengths_list,
                src_map_list=src_map_list,
                step=step,
                batch_offset=beam._batch_offset)

            if self.reranker is not None:
                log_probs = self.reranker.rerank_step_beam_batch(
                    batch,
                    beam,
                    self.beam_size,
                    indexes,
                    log_probs,
                    attn,
                    self.fields["tgt"].base_field.vocab,
                    xlation_builder,
                )
            # end if

            non_finished = None
            beam.advance(log_probs, attn)
            any_beam_is_finished = beam.is_finished.any()
            if any_beam_is_finished:
                non_finished = beam.update_finished()
                if beam.done:
                    break

            select_indices = beam.current_origin

            if any_beam_is_finished:
                # Reorder states.
                for src_i in range(len(memory_bank_list)):
                    if isinstance(memory_bank_list[src_i], tuple):
                        memory_bank_list[src_i] = tuple(x.index_select(1, select_indices)
                                            for x in memory_bank_list[src_i])
                    else:
                        memory_bank_list[src_i] = memory_bank_list[src_i].index_select(1, select_indices)
                    # end if

                    memory_lengths_list[src_i] = memory_lengths_list[src_i].index_select(0, select_indices)
                # end for

                if use_src_map and src_map_list[0] is not None:
                    for src_i in range(len(src_map_list)):
                        src_map_list[src_i] = src_map_list[src_i].index_select(1, select_indices)
                    # end for
                # end if

                indexes = indexes.index_select(0, select_indices)

            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        results["scores"] = beam.scores
        results["predictions"] = beam.predictions
        results["attention"] = beam.attention
        return results
Example 15
    def _fast_translate_batch(self,
                              batch,
                              data,
                              min_length=0,
                              node_type=None,
                              atc=None):
        # TODO: faster code path for beam_size == 1.

        # TODO: support these blacklisted features.
        assert data.data_type == 'text'
        assert not self.copy_attn
        assert not self.dump_beam
        assert not self.use_filter_pred
        assert self.block_ngram_repeat == 0
        assert self.global_scorer.beta == 0
        beam_size = self.beam_size
        batch_size = batch.batch_size
        vocab = self.fields["tgt"].vocab
        node_type_vocab = self.fields['tgt_feat_0'].vocab
        start_token = vocab.stoi[inputters.BOS_WORD]
        end_token = vocab.stoi[inputters.EOS_WORD]
        unk_token = vocab.stoi[inputters.UNK]
        token_masks, allowed_token_indices, not_allowed_token_indices\
            = generate_token_mask(atc, node_type_vocab, vocab)
        # debug(token_masks.keys())
        assert batch_size == 1, "Only 1 example decoding at a time supported"
        assert (node_type is not None) and isinstance(node_type, str), \
            "Node type string must be provided to translate tokens"

        node_types = [
            var(node_type_vocab.stoi[n_type.strip()])
            for n_type in node_type.split()
        ]
        node_types.append(var(node_type_vocab.stoi['-1']))
        if self.cuda:
            node_types = [n_type.cuda() for n_type in node_types]
        # debug(node_types)
        max_length = len(node_types)

        # Encoder forward.
        src = inputters.make_features(batch, 'src', data.data_type)
        _, src_lengths = batch.src
        enc_states, memory_bank = self.model.encoder(src, src_lengths)
        dec_states = self.model.decoder.init_decoder_state(src,
                                                           memory_bank,
                                                           enc_states,
                                                           with_cache=True)

        # Tile states and memory beam_size times.
        dec_states.map_batch_fn(
            lambda state, dim: tile(state, beam_size, dim=dim))
        memory_bank = tile(memory_bank, beam_size, dim=1)
        memory_lengths = tile(src_lengths, beam_size)

        batch_offset = torch.arange(batch_size,
                                    dtype=torch.long,
                                    device=memory_bank.device)
        beam_offset = torch.arange(0,
                                   batch_size * beam_size,
                                   step=beam_size,
                                   dtype=torch.long,
                                   device=memory_bank.device)
        alive_seq = torch.full([batch_size * beam_size, 1],
                               start_token,
                               dtype=torch.long,
                               device=memory_bank.device)
        alive_attn = None

        # Give full probability to the first beam on the first step.
        topk_log_probs = (torch.tensor(
            [0.0] + [float("-inf")] * (beam_size - 1),
            device=memory_bank.device).repeat(batch_size))

        results = dict()
        results["predictions"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["scores"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["attention"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["gold_score"] = [0] * batch_size
        results["batch"] = batch
        save_attention = [[] for _ in range(batch_size)]

        # max_length += 1

        for step in range(max_length):
            decoder_input = alive_seq[:, -1].view(1, -1, 1)
            node_type = node_types[step]
            node_type_str = str(node_type_vocab.itos[node_type.item()])
            not_allowed_indices = not_allowed_token_indices[node_type_str]
            extra_input = torch.stack(
                [node_types[step] for _ in range(decoder_input.shape[1])])
            extra_input = extra_input.view(1, -1, 1)
            final_input = torch.cat((decoder_input, extra_input), dim=-1)
            if self.cuda:
                final_input = final_input.cuda()
            # Decoder forward.
            dec_out, dec_states, attn = self.model.decoder(
                final_input,
                memory_bank,
                dec_states,
                memory_lengths=memory_lengths,
                step=step)

            # Generator forward.
            log_probs = self.model.generator.forward(dec_out.squeeze(0))
            vocab_size = log_probs.size(-1)
            # debug(vocab_size, len(vocab))
            if step < min_length:
                log_probs[:, end_token] = -1.1e20

            # debug(len(not_allowed_indices))
            log_probs[:, not_allowed_indices] = -1e20
            log_probs += topk_log_probs.view(-1).unsqueeze(1)
            # debug('Source  Shape :\t', memory_bank.size())
            # debug('Probab Shape :\t', log_probs.size())
            #
            attn_probs = attn['std'].squeeze()  # (beam_size, source_length)
            # debug('Inside Attn:\t', attn['std'].size())

            alpha = self.global_scorer.alpha
            length_penalty = ((5.0 + (step + 1)) / 6.0)**alpha

            # Flatten probs into a list of possibilities.
            curr_scores = log_probs / length_penalty
            curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
            topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

            # Recover log probs.
            topk_log_probs = topk_scores * length_penalty

            # Resolve beam origin and true word ids.
            topk_beam_index = topk_ids.div(vocab_size)
            topk_ids = topk_ids.fmod(vocab_size)
            beam_indices = topk_beam_index.squeeze().cpu().numpy().tolist()
            if len(attn_probs.shape) == 1:
                attn_to_save = attn_probs[:]
            else:
                attn_to_save = attn_probs[beam_indices, :]
            save_attention[0].append(attn_to_save)
            # Map beam_index to batch_index in the flat representation.
            batch_index = (topk_beam_index +
                           beam_offset[:topk_beam_index.size(0)].unsqueeze(1))

            # Select and reorder alive batches.
            select_indices = batch_index.view(-1)
            alive_seq = alive_seq.index_select(0, select_indices)
            memory_bank = memory_bank.index_select(1, select_indices)
            memory_lengths = memory_lengths.index_select(0, select_indices)
            dec_states.map_batch_fn(
                lambda state, dim: state.index_select(dim, select_indices))
            alive_seq = torch.cat([alive_seq, topk_ids.view(-1, 1)], -1)

        # # End condition is the top beam reached end_token.
        # end_condition = topk_ids[: , 0].eq(end_token)
        # if step + 1 == max_length:
        #     end_condition.fill_(1)
        # finished = end_condition.nonzero().view(-1)
        #
        # # Save result of finished sentences.
        # if len(finished) > 0:
        predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
        scores = topk_scores.view(-1, beam_size)
        attention = None
        if alive_attn is not None:
            attention = alive_attn.view(alive_attn.size(0), -1, beam_size,
                                        alive_attn.size(-1))
        # debug(predictions.size())
        # debug(end_token)
        # debug(start_token)

        for i in range(len(predictions)):
            b = batch_offset[i]
            for n in range(self.n_best):
                # debug(unk_token in predictions[i, n, 1:].cpu().numpy().tolist())
                results["predictions"][b].append(predictions[i, n, 1:])
                results["scores"][b].append(scores[i, n])
                if attention is None:
                    results["attention"][b].append([])
                else:
                    results["attention"][b].append(
                        attention[:, i, n, :memory_lengths[i]])
                results["save_attention"] = save_attention
                # results["save_attention"] =
                # non_finished = end_condition.eq(0).nonzero().view(-1)
                # # If all sentences are translated, no need to go further.
                # if len(non_finished) == 0:
                #     break
                # # Remove finished batches for the next step.
                # topk_log_probs = topk_log_probs.index_select(
                #     0 , non_finished.to(topk_log_probs.device))
                # topk_ids = topk_ids.index_select(0 , non_finished)
                # batch_index = batch_index.index_select(0 , non_finished)
                # batch_offset = batch_offset.index_select(0 , non_finished)
        results["gold_score"] = [0] * batch_size
        if "tgt" in batch.__dict__:
            results["gold_score"] = self._run_target(batch, data)
        return results
Example 16
    def initialize(self, memory_bank, src_lengths, src_map=None, device=None):
        """Initialize for decoding.
        Repeat src objects `beam_size` times.
        """

        def fn_map_state(state, dim):
            return tile(state, self.beam_size, dim=dim)

        if isinstance(memory_bank, tuple):
            memory_bank = tuple(tile(x, self.beam_size, dim=1)
                                for x in memory_bank)
            mb_device = memory_bank[0].device
        else:
            memory_bank = tile(memory_bank, self.beam_size, dim=1)
            mb_device = memory_bank.device
        if src_map is not None:
            src_map = tile(src_map, self.beam_size, dim=1)
        if device is None:
            device = mb_device

        self.memory_lengths = tile(src_lengths, self.beam_size)
        super(BeamSearch, self).initialize(
            memory_bank, self.memory_lengths, src_map, device)

        ######### base scripts ###########
        # if device is None:
        #     device = torch.device('cpu')
        # self.alive_seq = torch.full(
        #     [self.batch_size * self.parallel_paths, 1], self.bos,
        #     dtype=torch.long, device=device)
        # self.is_finished = torch.zeros(
        #     [self.batch_size, self.parallel_paths],
        #     dtype=torch.uint8, device=device)
        # return None, memory_bank, src_lengths, src_map
        ######### base scripts ###########

        self.best_scores = torch.full(
            [self.batch_size], -1e10, dtype=torch.float, device=device)
        self._beam_offset = torch.arange(
            0, self.batch_size * self.beam_size, step=self.beam_size,
            dtype=torch.long, device=device)
        self.topk_log_probs = torch.tensor(
            [0.0] + [float("-inf")] * (self.beam_size - 1), device=device
        ).repeat(self.batch_size)
        # buffers for the topk scores and 'backpointer'
        self.topk_scores = torch.empty((self.batch_size, self.beam_size),
                                       dtype=torch.float, device=device)
        self.topk_ids = torch.empty((self.batch_size, self.beam_size),
                                    dtype=torch.long, device=device)
        self._batch_index = torch.empty([self.batch_size, self.beam_size],
                                        dtype=torch.long, device=device)
        return fn_map_state, memory_bank, self.memory_lengths, src_map

    def __len__(self):
        return self.alive_seq.shape[1]

    def ensure_min_length(self, log_probs):
        if len(self) <= self.min_length:
            log_probs[:, self.eos] = -1e20

    def ensure_max_length(self):
        # add one to account for BOS. Don't account for EOS because hitting
        # this implies it hasn't been found.
        if len(self) == self.max_length + 1:
            self.is_finished.fill_(1)

    def block_ngram_repeats(self, log_probs):
        cur_len = len(self)
        if self.block_ngram_repeat > 0 and cur_len > 1:
            for path_idx in range(self.alive_seq.shape[0]):
                # skip BOS
                hyp = self.alive_seq[path_idx, 1:]
                ngrams = set()
                fail = False
                gram = []
                for i in range(cur_len - 1):
                    # Last n tokens, n = block_ngram_repeat
                    gram = (gram + [hyp[i].item()])[-self.block_ngram_repeat:]
                    # skip the blocking if any token in gram is excluded
                    if set(gram) & self.exclusion_tokens:
                        continue
                    if tuple(gram) in ngrams:
                        fail = True
                    ngrams.add(tuple(gram))
                if fail:
                    log_probs[path_idx] = -10e20

    @property
    def current_predictions(self):
        return self.alive_seq[:, -1]

    @property
    def current_backptr(self):
        # for testing
        return self.select_indices.view(self.batch_size, self.beam_size)\
            .fmod(self.beam_size)

    @property
    def batch_offset(self):
        return self._batch_offset

    def advance(self, log_probs, attn):
        vocab_size = log_probs.size(-1)

        # using integer division to get an integer _B without casting
        _B = log_probs.shape[0] // self.beam_size

        if self._stepwise_cov_pen and self._prev_penalty is not None:
            self.topk_log_probs += self._prev_penalty
            self.topk_log_probs -= self.global_scorer.cov_penalty(
                self._coverage + attn, self.global_scorer.beta).view(
                _B, self.beam_size)

        # force the output to be longer than self.min_length
        step = len(self)
        self.ensure_min_length(log_probs)

        # Multiply probs by the beam probability.
        log_probs += self.topk_log_probs.view(_B * self.beam_size, 1)

        self.block_ngram_repeats(log_probs)

        # if the sequence ends now, then the penalty is the current
        # length + 1, to include the EOS token
        length_penalty = self.global_scorer.length_penalty(
            step + 1, alpha=self.global_scorer.alpha)

        # Flatten probs into a list of possibilities.
        curr_scores = log_probs / length_penalty
        curr_scores = curr_scores.reshape(_B, self.beam_size * vocab_size)
        torch.topk(curr_scores,  self.beam_size, dim=-1,
                   out=(self.topk_scores, self.topk_ids))

        # Recover log probs.
        # Length penalty is just a scalar. It doesn't matter if it's applied
        # before or after the topk.
        torch.mul(self.topk_scores, length_penalty, out=self.topk_log_probs)

        # Resolve beam origin and map to batch index flat representation.
        torch.div(self.topk_ids, vocab_size, out=self._batch_index)
        self._batch_index += self._beam_offset[:_B].unsqueeze(1)
        self.select_indices = self._batch_index.view(_B * self.beam_size)
        self.topk_ids.fmod_(vocab_size)  # resolve true word ids

        # Append last prediction.
        self.alive_seq = torch.cat(
            [self.alive_seq.index_select(0, self.select_indices),
             self.topk_ids.view(_B * self.beam_size, 1)], -1)
        if self.return_attention or self._cov_pen:
            current_attn = attn.index_select(1, self.select_indices)
            if step == 1:
                self.alive_attn = current_attn
                # update global state (step == 1)
                if self._cov_pen:  # coverage penalty
                    self._prev_penalty = torch.zeros_like(self.topk_log_probs)
                    self._coverage = current_attn
            else:
                self.alive_attn = self.alive_attn.index_select(
                    1, self.select_indices)
                self.alive_attn = torch.cat([self.alive_attn, current_attn], 0)
                # update global state (step > 1)
                if self._cov_pen:
                    self._coverage = self._coverage.index_select(
                        1, self.select_indices)
                    self._coverage += current_attn
                    self._prev_penalty = self.global_scorer.cov_penalty(
                        self._coverage, beta=self.global_scorer.beta).view(
                            _B, self.beam_size)

        if self._vanilla_cov_pen:
            # shape: (batch_size x beam_size, 1)
            cov_penalty = self.global_scorer.cov_penalty(
                self._coverage,
                beta=self.global_scorer.beta)
            self.topk_scores -= cov_penalty.view(_B, self.beam_size).float()

        self.is_finished = self.topk_ids.eq(self.eos)
        self.ensure_max_length()

    def update_finished(self):
        # Penalize beams that finished.
        _B_old = self.topk_log_probs.shape[0]
        step = self.alive_seq.shape[-1]  # 1 greater than the step in advance
        self.topk_log_probs.masked_fill_(self.is_finished, -1e10)
        # on real data (newstest2017) with the pretrained transformer,
        # it's faster to not move this back to the original device
        self.is_finished = self.is_finished.to('cpu')
        self.top_beam_finished |= self.is_finished[:, 0].eq(1)
        predictions = self.alive_seq.view(_B_old, self.beam_size, step)
        attention = (
            self.alive_attn.view(
                step - 1, _B_old, self.beam_size, self.alive_attn.size(-1))
            if self.alive_attn is not None else None)
        non_finished_batch = []
        for i in range(self.is_finished.size(0)):  # Batch level
            b = self._batch_offset[i]
            finished_hyp = self.is_finished[i].nonzero().view(-1)
            # Store finished hypotheses for this batch.
            for j in finished_hyp:  # Beam level: finished beam j in batch i
                if self.ratio > 0:
                    s = self.topk_scores[i, j] / (step + 1)
                    if self.best_scores[b] < s:
                        self.best_scores[b] = s
                self.hypotheses[b].append((
                    self.topk_scores[i, j],
                    predictions[i, j, 1:],  # Ignore start_token.
                    attention[:, i, j, :self.memory_lengths[i]]
                    if attention is not None else None))
            # End condition is the top beam finished and we can return
            # n_best hypotheses.
            if self.ratio > 0:
                pred_len = self.memory_lengths[i] * self.ratio
                finish_flag = ((self.topk_scores[i, 0] / pred_len)
                               <= self.best_scores[b]) or \
                    self.is_finished[i].all()
            else:
                finish_flag = self.top_beam_finished[i] != 0
            if finish_flag and len(self.hypotheses[b]) >= self.n_best:
                best_hyp = sorted(
                    self.hypotheses[b], key=lambda x: x[0], reverse=True)
                for n, (score, pred, attn) in enumerate(best_hyp):
                    if n >= self.n_best:
                        break
                    self.scores[b].append(score)
                    self.predictions[b].append(pred)  # ``(batch, n_best,)``
                    self.attention[b].append(
                        attn if attn is not None else [])
            else:
                non_finished_batch.append(i)
        non_finished = torch.tensor(non_finished_batch)
        # If all sentences are translated, no need to go further.
        if len(non_finished) == 0:
            self.done = True
            return

        _B_new = non_finished.shape[0]
        # Remove finished batches for the next step.
        self.top_beam_finished = self.top_beam_finished.index_select(
            0, non_finished)
        self._batch_offset = self._batch_offset.index_select(0, non_finished)
        non_finished = non_finished.to(self.topk_ids.device)
        self.topk_log_probs = self.topk_log_probs.index_select(0,
                                                               non_finished)
        self._batch_index = self._batch_index.index_select(0, non_finished)
        self.select_indices = self._batch_index.view(_B_new * self.beam_size)
        self.alive_seq = predictions.index_select(0, non_finished) \
            .view(-1, self.alive_seq.size(-1))
        self.topk_scores = self.topk_scores.index_select(0, non_finished)
        self.topk_ids = self.topk_ids.index_select(0, non_finished)
        if self.alive_attn is not None:
            inp_seq_len = self.alive_attn.size(-1)
            self.alive_attn = attention.index_select(1, non_finished) \
                .view(step - 1, _B_new * self.beam_size, inp_seq_len)
            if self._cov_pen:
                self._coverage = self._coverage \
                    .view(1, _B_old, self.beam_size, inp_seq_len) \
                    .index_select(1, non_finished) \
                    .view(1, _B_new * self.beam_size, inp_seq_len)
                if self._stepwise_cov_pen:
                    self._prev_penalty = self._prev_penalty.index_select(
                        0, non_finished)
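
The block_ngram_repeats method above bans a path as soon as the n-gram window it just produced has already been seen earlier on the same path (windows containing an excluded token are skipped). The toy function below is a minimal, standalone restatement of that loop; the token ids and the n value are made up purely for illustration.

# Toy re-implementation of the n-gram blocking idea used in block_ngram_repeats
# above (illustrative only; token ids are invented).
def repeats_ngram(hyp, n, exclusion_tokens=frozenset()):
    """Return True if the trailing n-gram of `hyp` already occurred earlier."""
    seen = set()
    gram = []
    fail = False
    for tok in hyp:
        gram = (gram + [tok])[-n:]          # sliding window of the last n tokens
        if set(gram) & exclusion_tokens:    # skip blocking around excluded tokens
            continue
        if tuple(gram) in seen:
            fail = True
        seen.add(tuple(gram))
    return fail

print(repeats_ngram([4, 7, 9, 4, 7, 9], n=3))   # True: trigram (4, 7, 9) repeats
print(repeats_ngram([4, 7, 9, 4, 7, 2], n=3))   # False: no trigram repeats
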
Esempio n. 17
0
    def _translate_batch(self, batch, data, builder):
        # (0) Prep each of the components of the search.
        # And helper method for reducing verbosity.
        beam_size = self.beam_size
        batch_size = batch.batch_size
        data_type = data.data_type
        tgt_field = self.fields['tgt'][0][1]
        vocab = tgt_field.vocab

        # Define a set of tokens to exclude from ngram-blocking
        exclusion_tokens = {vocab.stoi[t] for t in self.ignore_when_blocking}

        pad = vocab.stoi[tgt_field.pad_token]
        eos = vocab.stoi[tgt_field.eos_token]
        bos = vocab.stoi[tgt_field.init_token]
        beam = [
            onmt.translate.Beam(beam_size,
                                n_best=self.n_best,
                                cuda=self.cuda,
                                global_scorer=self.global_scorer,
                                pad=pad,
                                eos=eos,
                                bos=bos,
                                vocab=vocab,
                                min_length=self.min_length,
                                stepwise_penalty=self.stepwise_penalty,
                                block_ngram_repeat=self.block_ngram_repeat,
                                exclusion_tokens=exclusion_tokens,
                                num_clusters=self.num_clusters,
                                embeddings=self.cluster_embeddings,
                                prev_hyps=self.prev_hyps,
                                hamming_dist=self.hamming_dist,
                                k_per_cand=self.k_per_cand,
                                hamming_penalty=self.hamming_penalty)
            for __ in range(batch_size)
        ]

        # (1) Run the encoder on the src.
        src, enc_states, memory_bank, src_lengths = self._run_encoder(
            batch, data_type)
        self.model.decoder.init_state(src, memory_bank, enc_states)

        results = {}
        results["predictions"] = []
        results["scores"] = []
        results["attention"] = []
        results["batch"] = batch
        if "tgt" in batch.__dict__:
            results["gold_score"] = self._score_target(
                batch, memory_bank, src_lengths, data, batch.src_map
                if data_type == 'text' and self.copy_attn else None)
            self.model.decoder.init_state(src, memory_bank, enc_states)
        else:
            results["gold_score"] = [0] * batch_size

        # (2) Repeat src objects `beam_size` times.
        # We now use batch_size x beam_size (same as fast mode)
        src_map = (tile(batch.src_map, beam_size, dim=1)
                   if data.data_type == 'text' and self.copy_attn else None)
        self.model.decoder.map_state(
            lambda state, dim: tile(state, beam_size, dim=dim))

        if isinstance(memory_bank, tuple):
            memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
        else:
            memory_bank = tile(memory_bank, beam_size, dim=1)
        memory_lengths = tile(src_lengths, beam_size)

        # Saves new hypotheses
        new_hyps = []

        # (3) run the decoder to generate sentences, using beam search.
        for i in range(self.max_length):
            if all((b.done() for b in beam)):
                break

            # (a) Construct batch x beam_size next words.
            # Get all the pending current beam words and arrange for forward.

            inp = torch.stack([b.get_current_state() for b in beam])
            inp = inp.view(1, -1, 1)

            # (b) Decode and forward
            out, beam_attn = self._decode_and_generate(
                inp,
                memory_bank,
                batch,
                data,
                memory_lengths=memory_lengths,
                src_map=src_map,
                step=i)
            out = out.view(batch_size, beam_size, -1)
            beam_attn = beam_attn.view(batch_size, beam_size, -1)

            # (c) Advance each beam.
            select_indices_array = []
            # Loop over the batch_size beams
            for j, b in enumerate(beam):

                # Get the previous beam's partial hypotheses
                current_beam = []
                if i > 0:
                    ret2, fins = self._from_current_beam(beam)
                    ret2["gold_score"] = [0] * batch_size
                    if "tgt" in batch.__dict__:
                        ret2["gold_score"] = self._run_target(batch, data)
                    ret2["batch"] = batch
                    current_beam = self.debug_translation(ret2, builder,
                                                          fins)[0]
                new_hyps += current_beam

                b.advance(out[j, :], beam_attn.data[j, :, :memory_lengths[j]],
                          current_beam, i)
                select_indices_array.append(b.get_current_origin() +
                                            j * beam_size)
            select_indices = torch.cat(select_indices_array)

            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        # Adds all partial hypotheses to prev_hyps for the next iteration of beam search
        self.prev_hyps += new_hyps

        # (4) Extract sentences from beam.
        for b in beam:
            scores, ks = b.sort_finished(minimum=self.n_best)
            hyps, attn = [], []
            for i, (times, k) in enumerate(ks[:self.n_best]):
                hyp, att = b.get_hyp(times, k)
                hyps.append(hyp)
                attn.append(att)
            results["predictions"].append(hyps)
            results["scores"].append(scores)
            results["attention"].append(attn)

        return results
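
Every example in this collection repeats its encoder outputs with tile(...) before decoding, but the helper itself (onmt.utils.misc.tile) is never shown. As a rough stand-in, its observable behaviour matches repeat_interleave along the chosen dimension; the sketch below is illustrative, not the original implementation, and the tensor sizes are made up.

import torch

def tile(x, count, dim=0):
    """Sketch of the tile helper: repeat each slice of x along `dim` `count`
    times, so batch entry b becomes count adjacent copies of itself."""
    return x.repeat_interleave(count, dim=dim)

src_lengths = torch.tensor([7, 12])          # one length per batch entry
print(tile(src_lengths, 3))                  # tensor([ 7,  7,  7, 12, 12, 12])

memory_bank = torch.randn(5, 2, 8)           # (src_len, batch, hidden)
print(tile(memory_bank, 3, dim=1).shape)     # torch.Size([5, 6, 8])
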
Esempio n. 18
0
    def _translate_batch(self,
                         batch,
                         src_vocabs,
                         max_length,
                         min_length=0,
                         ratio=0.,
                         n_best=1,
                         return_attention=False):
        # TODO: support these blacklisted features.
        assert not self.dump_beam

        # (0) Prep the components of the search.
        use_src_map = self.copy_attn
        beam_size = self.beam_size
        batch_size = batch.batch_size

        #### TODO: Augment batch with distractors

        # (1) Run the encoder on the src.
        src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
        # src has shape [1311, 2, 1]
        # enc_states has shape [1311, 2, 512],
        # Memory_bank has shape [1311, 2, 512]
        self.model.decoder.init_state(src, memory_bank, enc_states)

        results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "batch": batch,
            "gold_score": self._gold_score(batch, memory_bank, src_lengths,
                                           src_vocabs, use_src_map, enc_states,
                                           batch_size, src)
        }

        # (2) Repeat src objects `beam_size` times.
        # We use batch_size x beam_size
        src_map = (tile(batch.src_map, beam_size, dim=1)
                   if use_src_map else None)
        self.model.decoder.map_state(
            lambda state, dim: tile(state, beam_size, dim=dim))

        if isinstance(memory_bank, tuple):
            memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
            mb_device = memory_bank[0].device
        else:
            memory_bank = tile(memory_bank, beam_size, dim=1)
            mb_device = memory_bank.device
        memory_lengths = tile(src_lengths, beam_size)
        print('memory_bank size after tile:',
              memory_bank.shape)  #[1311, 20, 512]

        # (0) pt 2, prep the beam object
        beam = BeamSearch(beam_size,
                          n_best=n_best,
                          batch_size=batch_size,
                          global_scorer=self.global_scorer,
                          pad=self._tgt_pad_idx,
                          eos=self._tgt_eos_idx,
                          bos=self._tgt_bos_idx,
                          min_length=min_length,
                          ratio=ratio,
                          max_length=max_length,
                          mb_device=mb_device,
                          return_attention=return_attention,
                          stepwise_penalty=self.stepwise_penalty,
                          block_ngram_repeat=self.block_ngram_repeat,
                          exclusion_tokens=self._exclusion_idxs,
                          memory_lengths=memory_lengths)

        all_log_probs = []
        all_attn = []

        for step in range(max_length):
            decoder_input = beam.current_predictions.view(1, -1, 1)
            # decoder_input has shape[1,20,1]
            # decoder_input gives top 10 predictions for each batch element
            verbose = (step == 10)
            log_probs, attn = self._decode_and_generate(
                decoder_input,
                memory_bank,
                batch,
                src_vocabs,
                memory_lengths=memory_lengths,
                src_map=src_map,
                step=step,
                batch_offset=beam._batch_offset,
                verbose=verbose)

            # log_probs and attn are the probs for next word given that the
            # current word is that in decoder_input
            all_log_probs.append(log_probs)
            all_attn.append(attn)

            beam.advance(log_probs, attn, verbose=verbose)

            any_beam_is_finished = beam.is_finished.any()
            if any_beam_is_finished:
                beam.update_finished()
                if beam.done:
                    break

            select_indices = beam.current_origin

            if any_beam_is_finished:
                # Reorder states.
                if isinstance(memory_bank, tuple):
                    memory_bank = tuple(
                        x.index_select(1, select_indices) for x in memory_bank)
                else:
                    memory_bank = memory_bank.index_select(1, select_indices)

                memory_lengths = memory_lengths.index_select(0, select_indices)

                if src_map is not None:
                    src_map = src_map.index_select(1, select_indices)

            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        print('batch_size:', batch_size)
        print('max_length:', max_length)
        print('all_log_probs has len', len(all_log_probs))
        print('all_log_probs[0].shape', all_log_probs[0].shape)
        print('comparing log_probs[0]', all_log_probs[2][:, 0])

        results["scores"] = beam.scores
        results["predictions"] = beam.predictions
        results["attention"] = beam.attention
        return results
Esempio n. 19
0
    def _translate_batch(
            self,
            src,
            src_lengths,
            batch_size,
            min_length=0,
            ratio=0.,
            n_best=1,
            return_attention=False):

        max_length = self.config.max_sequence_length + 1 # to account for EOS
        beam_size = 3
        
        # Encoder forward.
        enc_states, memory_bank, src_lengths = self.encoder(src, src_lengths)
        self.decoder.init_state(src, memory_bank, enc_states)

        results = { "predictions": None, "scores": None, "attention": None }

        # (2) Repeat src objects `beam_size` times.
        # We use batch_size x beam_size
        self.decoder.map_state(lambda state, dim: tile(state, beam_size, dim=dim))

        #if isinstance(memory_bank, tuple):
        #    memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
        #    mb_device = memory_bank[0].device
        #else:
        memory_bank = tile(memory_bank, beam_size, dim=1)
        mb_device = memory_bank.device
        memory_lengths = tile(src_lengths, beam_size)

        
        block_ngram_repeat = 0
        _exclusion_idxs = {}

        # (0) pt 2, prep the beam object
        beam = BeamSearch(
            beam_size,
            n_best=n_best,
            batch_size=batch_size,
            global_scorer=self.scorer,
            pad=self.config.tgt_padding,
            eos=self.config.tgt_eos,
            bos=self.config.tgt_bos,
            min_length=min_length,
            ratio=ratio,
            max_length=max_length,
            mb_device=mb_device,
            return_attention=return_attention,
            stepwise_penalty=None,
            block_ngram_repeat=block_ngram_repeat,
            exclusion_tokens=_exclusion_idxs,
            memory_lengths=memory_lengths)

        for step in range(max_length):
            decoder_input = beam.current_predictions.view(1, -1, 1)

            log_probs, attn = self._decode_and_generate(decoder_input, memory_bank, memory_lengths, step, pretraining = True)

            beam.advance(log_probs, attn)
            any_beam_is_finished = beam.is_finished.any()
            if any_beam_is_finished:
                beam.update_finished()
                if beam.done:
                    break

            select_indices = beam.current_origin

            if any_beam_is_finished:
                # Reorder states.
                if isinstance(memory_bank, tuple):
                    memory_bank = tuple(x.index_select(1, select_indices)
                                        for x in memory_bank)
                else:
                    memory_bank = memory_bank.index_select(1, select_indices)

                memory_lengths = memory_lengths.index_select(0, select_indices)

            self.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        results["scores"] = beam.scores
        results["predictions"] = beam.predictions
        results["attention"] = beam.attention
        return results
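
In the examples built on the BeamSearch class, results["predictions"] ends up as a list with one entry per batch element, each holding up to n_best tensors of target token ids (the start token is stripped in update_finished; the EOS token, when generated, is still present). Below is a hedged sketch of turning those ids back into text; the itos list and function name are illustrative assumptions (a torchtext-style vocab), not part of the examples above.

def ids_to_tokens(results, itos, eos_token="</s>"):
    """Sketch: map BeamSearch prediction tensors to token strings.
    `itos` is assumed to be an index->string list such as vocab.itos."""
    texts = []
    for n_best_preds in results["predictions"]:
        hyps = []
        for pred in n_best_preds:
            toks = [itos[int(i)] for i in pred]
            if eos_token in toks:                 # drop EOS and anything after it
                toks = toks[:toks.index(eos_token)]
            hyps.append(" ".join(toks))
        texts.append(hyps)
    return texts
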
Esempio n. 20
0
    def _translate_batch(self, batch, data):
        # (0) Prep each of the components of the search.
        # And helper method for reducing verbosity.
        beam_size = self.beam_size
        batch_size = batch.batch_size
        data_type = data.data_type
        vocab = self.fields["tgt"].vocab

        # Define a set of tokens to exclude from ngram-blocking
        exclusion_tokens = {vocab.stoi[t] for t in self.ignore_when_blocking}

        pad = vocab.stoi[self.fields['tgt'].pad_token]
        eos = vocab.stoi[self.fields['tgt'].eos_token]
        bos = vocab.stoi[self.fields['tgt'].init_token]
        beam = [
            onmt.translate.Beam(beam_size,
                                n_best=self.n_best,
                                cuda=self.cuda,
                                global_scorer=self.global_scorer,
                                pad=pad,
                                eos=eos,
                                bos=bos,
                                min_length=self.min_length,
                                stepwise_penalty=self.stepwise_penalty,
                                block_ngram_repeat=self.block_ngram_repeat,
                                exclusion_tokens=exclusion_tokens)
            for __ in range(batch_size)
        ]

        # (1) Run the encoder on the src.
        src, knl, enc_states, his_memory_bank, src_memory_bank, knl_memory_bank, src_lengths = self._run_encoder(
            batch, data_type)
        self.model.decoder.init_state(src[100:, :, :], src[100:, :, :],
                                      src_memory_bank, enc_states)

        first_dec_words = torch.zeros(
            (self.max_length, batch_size,
             1)).fill_(self.model.encoder.embeddings.word_padding_idx).long()

        results = {}
        results["predictions"] = []
        results["scores"] = []
        results["attention"] = []
        results["batch"] = batch
        if "tgt" in batch.__dict__:
            results["gold_score"] = self._score_target(
                batch,
                his_memory_bank,
                src_memory_bank,
                knl_memory_bank,
                knl,
                src_lengths,
                data,
                batch.src_map
                if data_type == 'text' and self.copy_attn else None,
            )
            self.model.decoder.init_state(src[100:, :, :], src[100:, :, :],
                                          src_memory_bank, enc_states)
        else:
            results["gold_score"] = [0] * batch_size
        # (2) Repeat src objects `beam_size` times.
        # We now use batch_size x beam_size (same as fast mode)
        src_map = (tile(batch.src_map, beam_size, dim=1)
                   if data.data_type == 'text' and self.copy_attn else None)
        self.model.decoder.map_state(
            lambda state, dim: tile(state, beam_size, dim=dim))

        if isinstance(src_memory_bank, tuple):
            src_memory_bank = tuple(
                tile(x, beam_size, dim=1) for x in src_memory_bank)
            mb_device = src_memory_bank[0].device
        else:
            src_memory_bank = tile(src_memory_bank, beam_size, dim=1)
            mb_device = src_memory_bank.device

        if isinstance(knl_memory_bank, tuple):
            knl_memory_bank = tuple(
                tile(x, beam_size, dim=1) for x in knl_memory_bank)
            mb_device = knl_memory_bank[0].device
        else:
            knl_memory_bank = tile(knl_memory_bank, beam_size, dim=1)
            mb_device = knl_memory_bank.device

        if isinstance(his_memory_bank, tuple):
            his_memory_bank = tuple(
                tile(x, beam_size, dim=1) for x in his_memory_bank)
            mb_device = his_memory_bank[0].device
        else:
            his_memory_bank = tile(his_memory_bank, beam_size, dim=1)
            mb_device = his_memory_bank.device
        memory_lengths = tile(src_lengths, beam_size)
        # (3) run the first decoder to generate sentences, using beam search.
        for i in range(self.max_length):
            if all((b.done() for b in beam)):
                break

            # (a) Construct batch x beam_size next words.
            # Get all the pending current beam words and arrange for forward.

            inp = torch.stack([b.get_current_state() for b in beam])
            inp = inp.view(1, -1, 1)

            # (b) Decode and forward
            dec_out, dec_attn = self.model.decoder(
                inp,
                src_memory_bank,
                his_memory_bank,
                memory_lengths=memory_lengths,
                step=i)
            beam_attn = dec_attn["std"]
            out = self.model.generator(dec_out.squeeze(0))

            out = out.view(batch_size, beam_size, -1)
            beam_attn = beam_attn.view(batch_size, beam_size, -1)

            # (c) Advance each beam.
            select_indices_array = []
            # Loop over the batch_size beams
            for j, b in enumerate(beam):
                b.advance(out[j, :], beam_attn.data[j, :, :memory_lengths[j]])
                select_indices_array.append(b.get_current_origin() +
                                            j * beam_size)
            select_indices = torch.cat(select_indices_array)

            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        # Extract the first decoder sentences from beam.
        for j, b in enumerate(beam):
            scores, ks = b.sort_finished(minimum=self.n_best)
            hyps = []
            for i, (times, k) in enumerate(ks[:self.n_best]):
                hyp, _ = b.get_hyp(times, k)
                for h in range(len(hyp)):
                    first_dec_words[h, j, 0] = hyp[h]
        first_dec_words = first_dec_words.cuda()
        emb, decode1_memory_bank, decode1_mask = self.model.encoder.histransformer(
            first_dec_words[:50, :, :], None)
        self.model.decoder2.init_state(first_dec_words[:50, :, :],
                                       knl[600:, :, :], None, None)
        if isinstance(decode1_memory_bank, tuple):
            decode1_memory_bank = tuple(
                tile(x, beam_size, dim=1) for x in decode1_memory_bank)
            mb_device = decode1_memory_bank[0].device
        else:
            decode1_memory_bank = tile(decode1_memory_bank, beam_size, dim=1)
            mb_device = decode1_memory_bank.device
        self.model.decoder2.map_state(
            lambda state, dim: tile(state, beam_size, dim=dim))

        # (4) run the second decoder to generate sentences, using beam search.
        for i in range(self.max_length):
            if all((b.done() for b in beam)):
                break

            # (a) Construct batch x beam_size next words.
            # Get all the pending current beam words and arrange for forward.

            inp = torch.stack([b.get_current_state() for b in beam])
            inp = inp.view(1, -1, 1)

            # (b) Decode and forward
            dec_out, dec_attn = self.model.decoder2(
                inp,
                decode1_memory_bank,
                knl_memory_bank,
                memory_lengths=memory_lengths,
                step=i)
            beam_attn = dec_attn["std"]
            out = self.model.generator(dec_out.squeeze(0))

            out = out.view(batch_size, beam_size, -1)
            beam_attn = beam_attn.view(batch_size, beam_size, -1)

            # (c) Advance each beam.
            select_indices_array = []
            # Loop over the batch_size beams
            for j, b in enumerate(beam):
                b.advance(out[j, :], beam_attn.data[j, :, :memory_lengths[j]])
                select_indices_array.append(b.get_current_origin() +
                                            j * beam_size)
            select_indices = torch.cat(select_indices_array)

            self.model.decoder2.map_state(
                lambda state, dim: state.index_select(dim, select_indices))
        # Extract the second decoder sentences from beam.
        for b in beam:
            scores, ks = b.sort_finished(minimum=self.n_best)
            hyps, attn = [], []
            for i, (times, k) in enumerate(ks[:self.n_best]):
                hyp, att = b.get_hyp(times, k)
                hyps.append(hyp)
                attn.append(att)
            results["predictions"].append(hyps)
            results["scores"].append(scores)
            results["attention"].append(attn)

        return results
Esempio n. 21
0
    def _translate_batch(
            self,
            batch,
            src_vocabs,
            max_length,
            min_length=0,
            ratio=0.,
            n_best=1,
            return_attention=False):
        # TODO: support these blacklisted features.
        assert not self.dump_beam

        # (0) Prep the components of the search.
        use_src_map = self.copy_attn

        beam_size = self.beam_size #default 5
        batch_size = batch.batch_size #default 30

        # (1) Run the encoder on the src.
        src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
        #src.size() = torch.Size([59, 30, 1]) [src_len,batch_size,1]
        #enc_states[0/1].size() = [2,30,500]
        #memory_bank.size() =[59,30,500]
        #src_lengths.size() = [30]
        self.model.decoder.init_state(src, memory_bank, enc_states)

        results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "batch": batch,
            "gold_score": self._gold_score(
                batch, memory_bank, src_lengths, src_vocabs, use_src_map,
                enc_states, batch_size, src)}

        # (2) Repeat src objects `beam_size` times.
        # We use batch_size x beam_size
        src_map = (tile(batch.src_map, beam_size, dim=1)
                   if use_src_map else None)
        self.model.decoder.map_state(
            lambda state, dim: tile(state, beam_size, dim=dim))

        if isinstance(memory_bank, tuple):
            memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)  # tile each tensor beam_size times along dim=1 (the batch dimension)
            mb_device = memory_bank[0].device
        else:
            memory_bank = tile(memory_bank, beam_size, dim=1)
            mb_device = memory_bank.device
        memory_lengths = tile(src_lengths, beam_size)

        # (0) pt 2, prep the beam object
        beam = BeamSearch(
            beam_size,
            n_best=n_best,
            batch_size=batch_size,
            global_scorer=self.global_scorer,
            pad=self._tgt_pad_idx,
            eos=self._tgt_eos_idx,
            bos=self._tgt_bos_idx,
            min_length=min_length,
            ratio=ratio,
            max_length=max_length, #Maximum prediction length.
            mb_device=mb_device,
            return_attention=return_attention,
            stepwise_penalty=self.stepwise_penalty,
            block_ngram_repeat=self.block_ngram_repeat,
            exclusion_tokens=self._exclusion_idxs,
            memory_lengths=memory_lengths)

        for step in range(max_length):  # at most max_length steps; each step scores beam_size * batch_size candidate branches
            decoder_input = beam.current_predictions.view(1, -1, 1) #decoder_input.size() = torch.Size([1,150,1]) 150 = 30 * 5 = batch_size * beam_size
            # @property
            # def current_predictions(self):
            #     return self.alive_seq[:, -1]
            log_probs, attn = self._decode_and_generate(
                decoder_input,
                memory_bank,#torch.Size([59, 150, 500])
                batch,
                src_vocabs,
                memory_lengths=memory_lengths,
                src_map=src_map,
                step=step,
                batch_offset=beam._batch_offset)
            # print("log_probs = ",log_probs) #[150, 50004] 这个50004应该是词表的大小,词表中的单词应该是5万,多出来的4个应该是<s> </s> <unk> <pad>
            # print("attn = ",attn) #torch.Size([1, 150, 59]) 这个59应该是src中的最长的句子的长度
            # print("decoder_input = ",decoder_input.size())
            beam.advance(log_probs, attn)#这个里面完成的工作应该是将150再变回30,
            any_beam_is_finished = beam.is_finished.any()
            if any_beam_is_finished:
                beam.update_finished()
                if beam.done:
                    break

            select_indices = beam.current_origin

            if any_beam_is_finished:
                # Reorder states.
                if isinstance(memory_bank, tuple):
                    memory_bank = tuple(x.index_select(1, select_indices)
                                        for x in memory_bank)
                else:
                    memory_bank = memory_bank.index_select(1, select_indices)

                memory_lengths = memory_lengths.index_select(0, select_indices)

                if src_map is not None:
                    src_map = src_map.index_select(1, select_indices)

            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        results["scores"] = beam.scores
        results["predictions"] = beam.predictions
        results["attention"] = beam.attention
        return results
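
The comments in this example describe the same flattening that the fast path in the next example performs by hand: each step scores batch_size * beam_size candidates at once, the per-batch scores are reshaped to (batch, beam * vocab) for a single topk, and the winning flat indices are split back into a beam of origin (integer division by the vocab size) and a true word id (modulo the vocab size). The self-contained illustration below uses made-up sizes.

import torch

# Illustrative sizes only: batch of 2, beam of 3, vocab of 5.
B, beam_size, vocab_size = 2, 3, 5
log_probs = torch.randn(B * beam_size, vocab_size)

# Flatten each batch entry's beam*vocab candidates into one row and take topk.
curr_scores = log_probs.reshape(B, beam_size * vocab_size)
topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

# Recover which beam each winner came from and its true word id.
beam_origin = topk_ids // vocab_size      # values in [0, beam_size)
word_ids = topk_ids.fmod(vocab_size)      # values in [0, vocab_size)

# Map (batch, beam_origin) back to rows of the flat (B * beam_size) layout,
# mirroring the beam_offset / select_indices bookkeeping in the examples.
beam_offset = torch.arange(0, B * beam_size, step=beam_size)
select_indices = (beam_origin + beam_offset.unsqueeze(1)).view(-1)
print(select_indices)  # row indices into the flattened candidate dimension
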
Esempio n. 22
0
    def _fast_translate_batch(self, batch, data):
        # TODO: faster code path for beam_size == 1.

        # TODO: support these blacklisted features.
        assert data.data_type == 'text'
        assert self.n_best == 1
        assert self.min_length == 0
        assert not self.copy_attn
        assert not self.replace_unk
        assert not self.dump_beam
        assert not self.use_filter_pred
        assert self.block_ngram_repeat == 0
        assert self.global_scorer.beta == 0

        beam_size = self.beam_size
        batch_size = batch.batch_size
        vocab = self.fields["tgt"].vocab
        start_token = vocab.stoi[inputters.BOS_WORD]
        end_token = vocab.stoi[inputters.EOS_WORD]

        # Encoder forward.
        src = inputters.make_features(batch, 'src', data.data_type)
        _, src_lengths = batch.src
        enc_states, memory_bank = self.model.encoder(src, src_lengths)
        dec_states = self.model.decoder.init_decoder_state(
            src, memory_bank, enc_states)

        # Tile states and memory beam_size times.
        dec_states.map_batch_fn(
            lambda state, dim: tile(state, beam_size, dim=dim))
        memory_bank = tile(memory_bank, beam_size, dim=1)
        memory_lengths = tile(src_lengths, beam_size)

        batch_offset = torch.arange(batch_size,
                                    dtype=torch.long,
                                    device=memory_bank.device)
        beam_offset = torch.arange(0,
                                   batch_size * beam_size,
                                   step=beam_size,
                                   dtype=torch.long,
                                   device=memory_bank.device)
        alive_seq = torch.full([batch_size * beam_size, 1],
                               start_token,
                               dtype=torch.long,
                               device=memory_bank.device)

        # Give full probability to the first beam on the first step.
        topk_log_probs = (torch.tensor(
            [0.0] + [float("-inf")] * (beam_size - 1),
            device=memory_bank.device).repeat(batch_size))

        results = {}
        results["predictions"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["scores"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["attention"] = [[[]] for _ in range(batch_size)]  # noqa: F812
        results["gold_score"] = [0] * batch_size
        results["batch"] = batch

        for step in range(self.max_length):
            decoder_input = alive_seq[:, -1].view(1, -1, 1)

            # Decoder forward.
            dec_out, dec_states, attn = self.model.decoder(
                decoder_input,
                memory_bank,
                dec_states,
                memory_lengths=memory_lengths,
                step=step)

            # Generator forward.
            log_probs = self.model.generator.forward(dec_out.squeeze(0))
            vocab_size = log_probs.size(-1)

            # Multiply probs by the beam probability.
            log_probs += topk_log_probs.view(-1).unsqueeze(1)

            alpha = self.global_scorer.alpha
            length_penalty = ((5.0 + (step + 1)) / 6.0)**alpha

            # Flatten probs into a list of possibilities.
            curr_scores = log_probs / length_penalty
            curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
            topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

            # Recover log probs.
            topk_log_probs = topk_scores * length_penalty

            # Resolve beam origin and true word ids.
            topk_beam_index = topk_ids.div(vocab_size)
            topk_ids = topk_ids.fmod(vocab_size)

            # Map beam_index to batch_index in the flat representation.
            batch_index = (topk_beam_index +
                           beam_offset[:topk_beam_index.size(0)].unsqueeze(1))

            # End condition is the top beam reached end_token.
            finished = topk_ids[:, 0].eq(end_token)
            finished_count = finished.sum()

            # Save result of finished sentences.
            if finished_count > 0 or step + 1 == self.max_length:
                predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
                scores = topk_scores.view(-1, beam_size)
                for i, is_finished in enumerate(finished.tolist()):
                    if step + 1 != self.max_length and is_finished == 0:
                        continue
                    # TODO: if we get there because of max_length, the last
                    # predicted token is currently discarded.
                    b = batch_offset[i]
                    results["predictions"][b].append(predictions[i, 0, 1:])
                    results["scores"][b].append(scores[i, 0])

            non_finished = finished.eq(0).nonzero().view(-1)

            # If all sentences are translated, no need to go further.
            if non_finished.nelement() == 0:
                break

            # Remove finished batches for the next step.
            if non_finished.nelement() < finished.nelement():
                topk_log_probs = topk_log_probs.index_select(
                    0, non_finished.to(topk_log_probs.device))
                topk_ids = topk_ids.index_select(0, non_finished)
                batch_index = batch_index.index_select(0, non_finished)
                batch_offset = batch_offset.index_select(0, non_finished)

            # Select and reorder alive batches.
            select_indices = batch_index.view(-1)
            alive_seq = alive_seq.index_select(0, select_indices)
            memory_bank = memory_bank.index_select(1, select_indices)
            memory_lengths = memory_lengths.index_select(0, select_indices)
            dec_states.map_batch_fn(
                lambda state, dim: state.index_select(dim, select_indices))

            # Append last prediction.
            alive_seq = torch.cat([alive_seq, topk_ids.view(-1, 1)], -1)

        return results
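
The fast path above rescales accumulated log-probabilities with the GNMT length penalty ((5 + length) / 6) ** alpha before the topk, then multiplies the penalty back in to recover raw log-probabilities. A quick numeric sketch of that rescaling; the alpha value and the score are chosen only for illustration (in the code above alpha comes from global_scorer.alpha).

# GNMT-style length penalty as used above: lp(l) = ((5 + l) / 6) ** alpha.
def length_penalty(length, alpha=0.6):
    return ((5.0 + length) / 6.0) ** alpha

raw_score = -7.2                       # sum of log-probs for a 10-token hypothesis
lp = length_penalty(10)                # ~1.73
print(raw_score / lp)                  # ~-4.15: normalized score fed to topk
print((raw_score / lp) * lp)           # -7.2 again: the "recover log probs" step
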
Esempio n. 23
0
    def _translate_batch(self, batch, data):
        # (0) Prep each of the components of the search.
        # And helper method for reducing verbosity.
        beam_size = self.beam_size
        batch_size = batch.batch_size
        data_type = data.data_type
        vocab = self.fields["tgt"].vocab

        # Define a list of tokens to exclude from ngram-blocking
        # exclusion_list = ["<t>", "</t>", "."]
        exclusion_tokens = set([vocab.stoi[t]
                                for t in self.ignore_when_blocking])

        beam = [onmt.translate.Beam(beam_size, n_best=self.n_best,
                                    cuda=self.cuda,
                                    global_scorer=self.global_scorer,
                                    pad=vocab.stoi[inputters.PAD_WORD],
                                    eos=vocab.stoi[inputters.EOS_WORD],
                                    bos=vocab.stoi[inputters.BOS_WORD],
                                    min_length=self.min_length,
                                    stepwise_penalty=self.stepwise_penalty,
                                    block_ngram_repeat=self.block_ngram_repeat,
                                    exclusion_tokens=exclusion_tokens)
                for __ in range(batch_size)]

        # (1) Run the encoder on the src.
        src, enc_states, memory_bank, src_lengths, _ = self._run_encoder(
            batch, data_type)
        self.model.decoder.init_state(src, memory_bank, enc_states)

        results = {}
        results["predictions"] = []
        results["scores"] = []
        results["attention"] = []
        results["batch"] = batch
        if "tgt" in batch.__dict__:
            results["gold_score"] = self._score_target(
                batch, memory_bank, src_lengths, data, batch.src_map
                if data_type == 'text' and self.copy_attn else None)
            self.model.decoder.init_state(src, memory_bank, enc_states)
        else:
            results["gold_score"] = [0] * batch_size

        # (2) Repeat src objects `beam_size` times.
        # We now use batch_size x beam_size (same as fast mode)
        src_map = (tile(batch.src_map, beam_size, dim=1)
                   if data.data_type == 'text' and self.copy_attn else None)
        self.model.decoder.map_state(
            lambda state, dim: tile(state, beam_size, dim=dim))

        if isinstance(memory_bank, tuple):
            memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
        else:
            memory_bank = tile(memory_bank, beam_size, dim=1)
        memory_lengths = tile(src_lengths, beam_size)

        # (3) run the decoder to generate sentences, using beam search.
        for i in range(self.max_length):
            if all((b.done() for b in beam)):
                break

            # (a) Construct batch x beam_size next words.
            # Get all the pending current beam words and arrange for forward.

            inp = torch.stack([b.get_current_state() for b in beam])
            inp = inp.view(1, -1, 1)

            # (b) Decode and forward
            out, beam_attn = \
                self._decode_and_generate(inp, memory_bank, batch, data,
                                          memory_lengths=memory_lengths,
                                          src_map=src_map, step=i)

            out = out.view(batch_size, beam_size, -1)
            beam_attn = beam_attn.view(batch_size, beam_size, -1)

            # (c) Advance each beam.
            select_indices_array = []
            # Loop over the batch_size beams
            for j, b in enumerate(beam):
                b.advance(out[j, :],
                          beam_attn.data[j, :, :memory_lengths[j]])
                select_indices_array.append(
                    b.get_current_origin() + j * beam_size)
            select_indices = torch.cat(select_indices_array)

            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        # (4) Extract sentences from beam.
        for b in beam:
            n_best = self.n_best
            scores, ks = b.sort_finished(minimum=n_best)
            hyps, attn = [], []
            for i, (times, k) in enumerate(ks[:n_best]):
                hyp, att = b.get_hyp(times, k)
                hyps.append(hyp)
                attn.append(att)
            results["predictions"].append(hyps)
            results["scores"].append(scores)
            results["attention"].append(attn)

        return results
Esempio n. 24
0
    def translate_batch(self, batch):
        beam_size = self.beam_size
        tgt_field = self.fields['tgt'][0][1]
        vocab = tgt_field.vocab

        pad = vocab.stoi[tgt_field.pad_token]
        eos = vocab.stoi[tgt_field.eos_token]
        bos = vocab.stoi[tgt_field.init_token]
        b = Beam(beam_size,
                 n_best=self.n_best,
                 cuda=self.cuda,
                 pad=pad,
                 eos=eos,
                 bos=bos)

        src, src_lengths = batch.src
        # why doesn't this contain inflection source lengths when ensembling?
        side_info = side_information(batch)

        encoder_out = self.model.encode(src, lengths=src_lengths, **side_info)
        enc_states = encoder_out["enc_state"]
        memory_bank = encoder_out["memory_bank"]
        infl_memory_bank = encoder_out.get("inflection_memory_bank", None)

        self.model.init_decoder_state(enc_states)

        results = dict()

        if "tgt" in batch.__dict__:
            results["gold_score"] = self._score_target(
                batch,
                memory_bank,
                src_lengths,
                inflection_memory_bank=infl_memory_bank,
                **side_info)
            self.model.init_decoder_state(enc_states)
        else:
            results["gold_score"] = 0

        # (2) Repeat src objects `beam_size` times.
        self.model.map_decoder_state(
            lambda state, dim: tile(state, beam_size, dim=dim))

        if isinstance(memory_bank, tuple):
            memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
        else:
            memory_bank = tile(memory_bank, beam_size, dim=1)
        memory_lengths = tile(src_lengths, beam_size)

        if infl_memory_bank is not None:
            if isinstance(infl_memory_bank, tuple):
                infl_memory_bank = tuple(
                    tile(x, beam_size, dim=1) for x in infl_memory_bank)
            else:
                infl_memory_bank = tile(infl_memory_bank, beam_size, dim=1)
            tiled_infl_len = tile(side_info["inflection_lengths"], beam_size)
            side_info["inflection_lengths"] = tiled_infl_len

        if "language" in side_info:
            side_info["language"] = tile(side_info["language"], beam_size)

        for i in range(self.max_length):
            if b.done():
                break

            inp = b.current_state.unsqueeze(0)

            # the decoder expects an input of tgt_len x batch
            dec_out, dec_attn = self.model.decode(
                inp,
                memory_bank,
                memory_lengths=memory_lengths,
                inflection_memory_bank=infl_memory_bank,
                **side_info)
            attn = dec_attn["lemma"].squeeze(0)
            out = self.model.generator(dec_out.squeeze(0),
                                       transform=True,
                                       **side_info)

            # b.advance will take attn (beam size x src length)
            b.advance(out, dec_attn)
            select_indices = b.current_origin

            self.model.map_decoder_state(
                lambda state, dim: state.index_select(dim, select_indices))

        scores, ks = b.sort_finished()
        hyps, attn, out_probs = [], [], []
        for i, (times, k) in enumerate(ks[:self.n_best]):
            hyp, att, out_p = b.get_hyp(times, k)
            hyps.append(hyp)
            attn.append(att)
            out_probs.append(out_p)

        results["preds"] = hyps
        results["scores"] = scores
        results["attn"] = attn

        if self.beam_accum is not None:
            parent_ids = [t.tolist() for t in b.prev_ks]
            self.beam_accum["beam_parent_ids"].append(parent_ids)
            scores = [["%4f" % s for s in t.tolist()]
                      for t in b.all_scores][1:]
            self.beam_accum["scores"].append(scores)
            pred_ids = [[vocab.itos[i] for i in t.tolist()]
                        for t in b.next_ys][1:]
            self.beam_accum["predicted_ids"].append(pred_ids)

        if self.attn_path is not None:
            save_attn = {k: v.cpu() for k, v in attn[0].items()}
            src_seq = self.itos(src, "src")
            pred_seq = self.itos(hyps[0], "tgt")
            attn_dict = {"src": src_seq, "pred": pred_seq, "attn": save_attn}
            if "inflection" in save_attn:
                inflection_seq = self.itos(batch.inflection[0], "inflection")
                attn_dict["inflection"] = inflection_seq
            self.attns.append(attn_dict)

        if self.probs_path is not None:
            save_probs = out_probs[0].cpu()
            self.probs.append(save_probs)

        return results
Esempio n. 25
0
    def _translate_batch(
            self,
            batch,
            src_vocabs,
            max_length,
            min_length=0,
            ratio=0.,
            n_best=1,
            return_attention=False):
        # TODO: support these blacklisted features.
        assert not self.dump_beam

        # (0) Prep the components of the search.
        use_src_map = self.copy_attn
        beam_size = self.beam_size
        batch_size = batch.batch_size

        # (1) Run the encoder on the src.
        src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
        self.model.decoder.init_state(src, memory_bank, enc_states)

        results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "batch": batch,
            "gold_score": self._gold_score(
                batch, memory_bank, src_lengths, src_vocabs, use_src_map,
                enc_states, batch_size, src)}

        # (2) Repeat src objects `beam_size` times.
        # We use batch_size x beam_size
        src_map = (tile(batch.src_map, beam_size, dim=1)
                   if use_src_map else None)
        self.model.decoder.map_state(
            lambda state, dim: tile(state, beam_size, dim=dim))

        if isinstance(memory_bank, tuple):
            memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
            mb_device = memory_bank[0].device
        else:
            memory_bank = tile(memory_bank, beam_size, dim=1)
            mb_device = memory_bank.device
        memory_lengths = tile(src_lengths, beam_size)

        # (0) pt 2, prep the beam object
        beam = BeamSearch(
            beam_size,
            n_best=n_best,
            batch_size=batch_size,
            global_scorer=self.global_scorer,
            pad=self._tgt_pad_idx,
            eos=self._tgt_eos_idx,
            bos=self._tgt_bos_idx,
            min_length=min_length,
            ratio=ratio,
            max_length=max_length,
            mb_device=mb_device,
            return_attention=return_attention,
            stepwise_penalty=self.stepwise_penalty,
            block_ngram_repeat=self.block_ngram_repeat,
            exclusion_tokens=self._exclusion_idxs,
            memory_lengths=memory_lengths)

        for step in range(max_length):
            decoder_input = beam.current_predictions.view(1, -1, 1)

            log_probs, attn = self._decode_and_generate(
                decoder_input,
                memory_bank,
                batch,
                src_vocabs,
                memory_lengths=memory_lengths,
                src_map=src_map,
                step=step,
                batch_offset=beam._batch_offset)

            beam.advance(log_probs, attn)
            any_beam_is_finished = beam.is_finished.any()
            if any_beam_is_finished:
                beam.update_finished()
                if beam.done:
                    break

            select_indices = beam.current_origin

            if any_beam_is_finished:
                # Reorder states.
                if isinstance(memory_bank, tuple):
                    memory_bank = tuple(x.index_select(1, select_indices)
                                        for x in memory_bank)
                else:
                    memory_bank = memory_bank.index_select(1, select_indices)

                memory_lengths = memory_lengths.index_select(0, select_indices)

                if src_map is not None:
                    src_map = src_map.index_select(1, select_indices)

            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        results["scores"] = beam.scores
        results["predictions"] = beam.predictions
        results["attention"] = beam.attention
        return results
Esempio n. 26
0
    def _fast_translate_batch(self,
                              batch,
                              data,
                              max_length,
                              min_length=0,
                              n_best=1,
                              return_attention=False):
        # TODO: faster code path for beam_size == 1.

        # TODO: support these blacklisted features.
        assert not self.dump_beam
        assert not self.use_filter_pred
        assert self.block_ngram_repeat == 0
        assert self.global_scorer.beta == 0

        beam_size = self.beam_size
        batch_size = batch.batch_size
        vocab = self.fields["tgt"].vocab
        start_token = vocab.stoi[inputters.BOS_WORD]
        end_token = vocab.stoi[inputters.EOS_WORD]

        # Encoder forward.
        src, enc_states, memory_bank, src_lengths = self._run_encoder(
            batch, data.data_type)
        self.model.decoder.init_state(src, memory_bank, enc_states, with_cache=True)

        if self.refer:
            ref_list, ref_states_list, ref_bank_list, ref_lengths_list, ref_prs_list = [], [], [], [], []
            for i in range(self.refer):
                ref, ref_states, ref_bank, ref_lengths, ref_prs = self._run_refer(batch, data.data_type, k=i)
                ref_list.append(ref)
                ref_states_list.append(ref_states)
                ref_bank_list.append(ref_bank)
                ref_lengths_list.append(ref_lengths)
                ref_prs_list.append(ref_prs)
                self.extra_decoders[i].init_state(ref, ref_bank, ref_states, with_cache=True)
        else:
            ref_list, ref_states_list, ref_bank_list, ref_lengths_list, ref_prs_list = tuple([None]*5)

        results = dict()
        results["predictions"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["scores"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["attention"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["batch"] = batch
        if "tgt" in batch.__dict__:
            results["gold_score"] = self._score_target(
                batch, memory_bank, src_lengths, data, batch.src_map
                if data.data_type == 'text' and self.copy_attn else None)
            self.model.decoder.init_state(
                src, memory_bank, enc_states, with_cache=True)
        else:
            results["gold_score"] = [0] * batch_size

        # Tile states and memory beam_size times.
        self.model.decoder.map_state(
            lambda state, dim: tile(state, beam_size, dim=dim))
        if isinstance(memory_bank, tuple):
            memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)

            mb_device = memory_bank[0].device
        else:
            memory_bank = tile(memory_bank, beam_size, dim=1)

            mb_device = memory_bank.device

        memory_lengths = tile(src_lengths, beam_size)
        src_map = (tile(batch.src_map, beam_size, dim=1)
                   if data.data_type == 'text' and self.copy_attn else None)

        if self.refer:
            for i in range(self.refer):
                self.extra_decoders[i].map_state(
                    lambda state, dim: tile(state, beam_size, dim=dim))
                if isinstance(ref_bank_list[i], tuple):
                    ref_bank_list[i] = tuple(tile(x, beam_size, dim=1) for x in ref_bank_list[i])
                else:
                    ref_bank_list[i] = tile(ref_bank_list[i], beam_size, dim=1)
                ref_lengths_list[i] = tile(ref_lengths_list[i], beam_size)
                ref_prs_list[i] = tile(ref_prs_list[i], beam_size).view(-1, 1)
            # ref_prs = torch.rand_like(ref_prs)
        else:
            ref_bank_list, ref_lengths_list = None, None

        top_beam_finished = torch.zeros([batch_size], dtype=torch.uint8)
        batch_offset = torch.arange(batch_size, dtype=torch.long)
        beam_offset = torch.arange(
            0,
            batch_size * beam_size,
            step=beam_size,
            dtype=torch.long,
            device=mb_device)
        alive_seq = torch.full(
            [batch_size * beam_size, 1],
            start_token,
            dtype=torch.long,
            device=mb_device)
        alive_attn = None

        # Give full probability to the first beam on the first step.
        topk_log_probs = (
            torch.tensor([0.0] + [float("-inf")] * (beam_size - 1),
                         device=mb_device).repeat(batch_size))

        # Structure that holds finished hypotheses.
        hypotheses = [[] for _ in range(batch_size)]  # noqa: F812

        for step in range(max_length):
            decoder_input = alive_seq[:, -1].view(1, -1, 1)

            log_probs, attn = \
                self._decode_and_generate(decoder_input, memory_bank,
                                          batch, data,
                                          memory_lengths=memory_lengths,
                                          src_map=src_map,
                                          step=step,
                                          batch_offset=batch_offset, ref_bank=ref_bank_list,
                                          ref_lengths=ref_lengths_list, ref_prs=ref_prs_list)

            vocab_size = log_probs.size(-1)

            if self.guide:
                log_probs = self.guide_by_tp(alive_seq, log_probs)

            if step < min_length:
                log_probs[:, end_token] = -1e20

            # Multiply probs by the beam probability.
            log_probs += topk_log_probs.view(-1).unsqueeze(1)

            alpha = self.global_scorer.alpha
            length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha

            # Flatten probs into a list of possibilities.
            curr_scores = log_probs / length_penalty
            curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
            topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

            # Recover log probs.
            topk_log_probs = topk_scores * length_penalty

            # Resolve beam origin and true word ids.
            # Floor division recovers the origin beam of each candidate; it also
            # keeps the index integral on newer PyTorch, where Tensor.div()
            # performs true (float) division.
            topk_beam_index = topk_ids // vocab_size
            topk_ids = topk_ids.fmod(vocab_size)

            # topk_ids = torch.cat([batch.tgt[step + 1][batch_offset].view(-1, 1), topk_ids], -1)[:, :self.beam_size]

            # Map beam_index to batch_index in the flat representation.
            batch_index = (
                    topk_beam_index
                    + beam_offset[:topk_beam_index.size(0)].unsqueeze(1))
            select_indices = batch_index.view(-1)

            # Append last prediction.
            alive_seq = torch.cat(
                [alive_seq.index_select(0, select_indices),
                 topk_ids.contiguous().view(-1, 1)], -1)
            if return_attention:
                current_attn = attn.index_select(1, select_indices)
                if alive_attn is None:
                    alive_attn = current_attn
                else:
                    alive_attn = alive_attn.index_select(1, select_indices)
                    alive_attn = torch.cat([alive_attn, current_attn], 0)

            is_finished = topk_ids.eq(end_token)
            if step + 1 == max_length:
                is_finished.fill_(1)

            # Save finished hypotheses.
            if is_finished.any():
                # Penalize beams that finished.
                topk_log_probs.masked_fill_(is_finished, -1e10)
                is_finished = is_finished.to('cpu')
                top_beam_finished |= is_finished[:, 0].eq(1)
                predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
                attention = (
                    alive_attn.view(
                        alive_attn.size(0), -1, beam_size, alive_attn.size(-1))
                    if alive_attn is not None else None)
                non_finished_batch = []
                for i in range(is_finished.size(0)):
                    b = batch_offset[i]
                    finished_hyp = is_finished[i].nonzero().view(-1)
                    # Store finished hypotheses for this batch.
                    for j in finished_hyp:
                        hypotheses[b].append((
                            topk_scores[i, j],
                            predictions[i, j, 1:],  # Ignore start_token.
                            attention[:, i, j, :memory_lengths[i]]
                            if attention is not None else None))
                    # End condition is the top beam finished and we can return
                    # n_best hypotheses.
                    if top_beam_finished[i] and len(hypotheses[b]) >= n_best:
                        best_hyp = sorted(
                            hypotheses[b], key=lambda x: x[0], reverse=True)
                        for n, (score, pred, attn) in enumerate(best_hyp):
                            if n >= n_best:
                                break
                            results["scores"][b].append(score)
                            results["predictions"][b].append(pred)
                            results["attention"][b].append(
                                attn if attn is not None else [])
                    else:
                        non_finished_batch.append(i)
                non_finished = torch.tensor(non_finished_batch)
                # If all sentences are translated, no need to go further.
                if len(non_finished) == 0:
                    break
                # Remove finished batches for the next step.
                top_beam_finished = top_beam_finished.index_select(
                    0, non_finished)
                batch_offset = batch_offset.index_select(0, non_finished)
                non_finished = non_finished.to(topk_ids.device)
                topk_log_probs = topk_log_probs.index_select(0, non_finished)
                batch_index = batch_index.index_select(0, non_finished)
                select_indices = batch_index.view(-1)
                alive_seq = predictions.index_select(0, non_finished) \
                    .view(-1, alive_seq.size(-1))
                if alive_attn is not None:
                    alive_attn = attention.index_select(1, non_finished) \
                        .view(alive_attn.size(0),
                              -1, alive_attn.size(-1))

            # Reorder states.
            if isinstance(memory_bank, tuple):
                memory_bank = tuple(x.index_select(1, select_indices)
                                    for x in memory_bank)
            else:
                memory_bank = memory_bank.index_select(1, select_indices)

            memory_lengths = memory_lengths.index_select(0, select_indices)
            if self.refer:
                for i in range(self.refer):
                    if isinstance(ref_bank_list[i], tuple):
                        ref_bank_list[i] = tuple(
                            x.index_select(1, select_indices)
                            for x in ref_bank_list[i])
                    else:
                        ref_bank_list[i] = ref_bank_list[i].index_select(1, select_indices)

                    ref_lengths_list[i] = ref_lengths_list[i].index_select(0, select_indices)
                    ref_prs_list[i] = ref_prs_list[i].index_select(0, select_indices)
                    self.extra_decoders[i].map_state(
                        lambda state, dim: state.index_select(dim, select_indices))

            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))
            if src_map is not None:
                src_map = src_map.index_select(1, select_indices)
        if self.guide:
            print(self.batch_num)
            self.batch_num += 1
        return results
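The loop above keeps all batch_size * beam_size hypotheses in one flat dimension and recovers both the origin beam and the chosen token from a single topk over beam_size * vocab_size scores. A minimal, self-contained sketch of that bookkeeping (toy sizes and random values; the variable names mirror the snippet, but nothing here comes from the actual model):

import torch

batch_size, beam_size, vocab_size = 2, 3, 5

# Log-probs for every alive hypothesis, laid out flat as [batch * beam, vocab].
log_probs = torch.randn(batch_size * beam_size, vocab_size).log_softmax(dim=-1)

# GNMT-style length penalty, as in the loop above (alpha and step are toy values).
alpha, step = 0.6, 4
length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha

# Fold the beam dimension into the vocab dimension and keep the best beam_size.
curr_scores = (log_probs / length_penalty).reshape(batch_size, beam_size * vocab_size)
topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

# Recover which beam each survivor came from and which token it appended.
topk_beam_index = topk_ids // vocab_size   # origin beam within each batch item
topk_ids = topk_ids % vocab_size           # token id in the vocabulary

# Map (batch item, origin beam) back into the flat [batch * beam] state dimension.
beam_offset = torch.arange(0, batch_size * beam_size, step=beam_size)
batch_index = topk_beam_index + beam_offset[:topk_beam_index.size(0)].unsqueeze(1)
select_indices = batch_index.view(-1)      # used to reorder memory banks and decoder state
print(select_indices.shape)                # torch.Size([6])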
Example n. 27
0
    def forward(self, input, attn_mask=None):
        """
        Inputs of forward function
            input: [target length, batch size, embed dim]
            attn_mask: [(batch size), sequence length, sequence length]

        Outputs of forward function
            attn_output: [target length, batch size, embed dim]
            attn_output_weights: [batch size, target length, sequence length]
        """

        seq_len, bsz, embed_dim = input.size()
        assert embed_dim == self.embed_dim

        # self-attention
        q, k, v = F.linear(input, self.in_proj_weight,
                           self.in_proj_bias).chunk(3, dim=-1)
        q *= self.scaling

        # Split q, k and v into num_heads heads
        q = q.contiguous().view(seq_len, bsz * self.num_heads,
                                self.head_dim).transpose(0, 1)
        k = k.contiguous().view(-1, bsz * self.num_heads,
                                self.head_dim).transpose(0, 1)
        v = v.contiguous().view(-1, bsz * self.num_heads,
                                self.head_dim).transpose(0, 1)

        #  Gated Linear Unit
        if self._use_glu:
            q = self.query_glu(q)
            k = self.key_glu(k)

        # Batch matrix multiply query against key
        # attn_output_weights is [bsz * num_heads, seq_len, seq_len]
        attn_output_weights = torch.bmm(q, k.transpose(1, 2))

        assert list(attn_output_weights.size()) == [
            bsz * self.num_heads, seq_len, seq_len
        ]

        if attn_mask is not None:
            if attn_mask.dim() == 2:
                # We use the same mask for each item in the batch
                attn_mask = attn_mask.unsqueeze(0)
            elif attn_mask.dim() == 3:
                # Each item in the batch has its own mask.
                # We need to inflate the mask to go with all heads
                attn_mask = tile(attn_mask, count=self.num_heads, dim=0)
            else:
                # Don't know what to do with any other mask shape
                raise RuntimeError(f'Wrong mask dim: {attn_mask.dim()}')

            # The mask should be either 0 or -inf so it adds cleanly before softmax
            attn_output_weights += attn_mask

        # Compute the softmax in float32 when the weights are half precision,
        # for numerical stability; the dtype argument handles the upcast.
        attn_output_weights = F.softmax(
            attn_output_weights,
            dim=-1,
            dtype=torch.float32 if attn_output_weights.dtype == torch.float16
            else attn_output_weights.dtype)
        attn_output_weights = F.dropout(attn_output_weights,
                                        p=self.dropout,
                                        training=self.training)

        attn_output = torch.bmm(attn_output_weights, v)
        assert list(attn_output.size()) == [
            bsz * self.num_heads, seq_len, self.head_dim
        ]
        attn_output = attn_output.transpose(0, 1).contiguous().view(
            seq_len, bsz, embed_dim)
        attn_output = self.out_proj(attn_output)

        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, self.num_heads,
                                                       seq_len, seq_len)
        attn_output_weights = attn_output_weights.sum(dim=1) / self.num_heads

        return attn_output, attn_output_weights
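The mask handling above relies on an additive convention: entries are 0 where attention is allowed and -inf where it is blocked, and a per-item mask has to be repeated once per head so it lines up with attention weights of shape [bsz * num_heads, seq_len, seq_len]. A small illustrative sketch of building such a mask with plain PyTorch; the causal pattern and sizes are toy choices, and the per-head repetition assumes (as the snippet suggests) that tile(attn_mask, count=num_heads, dim=0) keeps each item's copies adjacent:

import torch

bsz, num_heads, seq_len = 2, 4, 5

# Causal mask shared by every item in the batch: 0 on/below the diagonal, -inf above.
causal = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

# Per-item masks would instead be [bsz, seq_len, seq_len]; repeat each item's mask
# once per head so it matches weights of shape [bsz * num_heads, seq_len, seq_len].
per_item = causal.unsqueeze(0).expand(bsz, -1, -1)        # [bsz, L, L]
per_head = per_item.repeat_interleave(num_heads, dim=0)   # [bsz * num_heads, L, L]

attn_weights = torch.zeros(bsz * num_heads, seq_len, seq_len)
attn_weights = (attn_weights + per_head).softmax(dim=-1)
print(attn_weights.shape)  # torch.Size([8, 5, 5])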
Example n. 28
0
    def _translate_batch_with_strategy(self, batch, src_vocabs,
                                       decode_strategy, only_gold_score):
        """Translate a batch of sentences step by step using cache.

        Args:
            batch: a batch of sentences, yield by data iterator.
            src_vocabs (list): list of torchtext.data.Vocab if can_copy.
            decode_strategy (DecodeStrategy): A decode strategy to use for
                generate translation step by step.

        Returns:
            results (dict): The translation results.
        """
        # (0) Prep the components of the search.
        use_src_map = self.copy_attn
        parallel_paths = decode_strategy.parallel_paths  # beam_size
        batch_size = batch.batch_size

        initial_encoding_token = None
        initial_decoding_token = self.expert_index(self.expert_id)

        # initial_token = self.expert_index(self.expert_id)

        # (1) Run the encoder on the src.
        src, enc_states, memory_bank, src_lengths = \
            self._run_encoder(batch, initial_encoding_token)
        self.model.decoder.init_state(src, memory_bank, enc_states)

        results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "batch": batch,
            "gold_score": self._gold_score(batch, memory_bank, src_lengths,
                                           src_vocabs, use_src_map, enc_states,
                                           batch_size, src,
                                           initial_decoding_token)
        }

        if only_gold_score:
            results["scores"] = [[x] for x in results["gold_score"]]
            results["predictions"] = [[None]
                                      for _ in decode_strategy.predictions]
            results["attention"] = [[None] for _ in range(batch_size)]
            results["alignment"] = [[] for _ in range(batch_size)]
        else:
            # (2) prep decode_strategy. Possibly repeat src objects.
            src_map = batch.src_map if use_src_map else None
            fn_map_state, memory_bank, memory_lengths, src_map = \
                decode_strategy.initialize(memory_bank, src_lengths, src_map,
                                           initial_token=initial_decoding_token)
            if fn_map_state is not None:
                self.model.decoder.map_state(fn_map_state)

            # (3) Begin decoding step by step:
            for step in range(decode_strategy.max_length):
                decoder_input = decode_strategy.current_predictions.view(
                    1, -1, 1)
                log_probs, attn = self._decode_and_generate(
                    decoder_input,
                    memory_bank,
                    batch,
                    src_vocabs,
                    memory_lengths=memory_lengths,
                    src_map=src_map,
                    step=step,
                    batch_offset=decode_strategy.batch_offset)
                if step == 0 and self.model.prior is not None:
                    lprob_z = self.model.prior(memory_bank, memory_lengths)
                    log_probs += tile(lprob_z[:, self.expert_id].unsqueeze(-1),
                                      self._tgt_vocab_len,
                                      dim=1)

                decode_strategy.advance(log_probs, attn)
                any_finished = decode_strategy.is_finished.any()
                if any_finished:
                    decode_strategy.update_finished()
                    if decode_strategy.done:
                        break

                select_indices = decode_strategy.select_indices

                if any_finished:
                    # Reorder states.
                    if isinstance(memory_bank, tuple):
                        memory_bank = tuple(
                            x.index_select(1, select_indices)
                            for x in memory_bank)
                    else:
                        memory_bank = memory_bank.index_select(
                            1, select_indices)

                    memory_lengths = memory_lengths.index_select(
                        0, select_indices)

                    if src_map is not None:
                        src_map = src_map.index_select(1, select_indices)

                if parallel_paths > 1 or any_finished:
                    self.model.decoder.map_state(
                        lambda state, dim: state.index_select(
                            dim, select_indices))

            results["scores"] = decode_strategy.scores
            results["predictions"] = decode_strategy.predictions
            results["attention"] = decode_strategy.attention
            if self.report_align:
                results["alignment"] = self._align_forward(
                    batch, decode_strategy.predictions)
            else:
                results["alignment"] = [[] for _ in range(batch_size)]
        return results
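These examples lean on a tile helper that is never shown; in OpenNMT-py it repeats every slice of a tensor count times along one dimension, keeping the copies adjacent, so a batch of size B becomes B * count. A minimal stand-in under that assumption (not the library's actual implementation):

import torch

def tile(x, count, dim=0):
    """Repeat each slice of `x` along `dim` `count` times, keeping copies adjacent.

    Behaves like torch.repeat_interleave(x, count, dim=dim): e.g. a memory bank of
    shape [src_len, batch, hidden] tiled with dim=1 becomes [src_len, batch*count, hidden].
    """
    return torch.repeat_interleave(x, count, dim=dim)

src_lengths = torch.tensor([7, 3])
print(tile(src_lengths, 3))               # tensor([7, 7, 7, 3, 3, 3])

memory_bank = torch.randn(10, 2, 8)       # [src_len, batch, hidden]
print(tile(memory_bank, 3, dim=1).shape)  # torch.Size([10, 6, 8])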
Example n. 29
0
    def _fast_translate_batch(self,
                              batch,
                              data,
                              max_length,
                              min_length=0,
                              n_best=1,
                              return_attention=False):
        # TODO: faster code path for beam_size == 1.

        # TODO: support these blacklisted features.
        assert data.data_type == 'text'
        assert not self.copy_attn
        assert not self.dump_beam
        assert not self.use_filter_pred
        assert self.block_ngram_repeat == 0
        assert self.global_scorer.beta == 0

        beam_size = self.beam_size
        batch_size = batch.batch_size
        vocab = self.fields["tgt"].vocab
        start_token = vocab.stoi[inputters.BOS_WORD]
        end_token = vocab.stoi[inputters.EOS_WORD]

        # Encoder forward.
        src = inputters.make_features(batch, 'src', data.data_type)
        _, src_lengths = batch.src
        enc_states, memory_bank = self.model.encoder(src, src_lengths)
        dec_states = self.model.decoder.init_decoder_state(
            src, memory_bank, enc_states, with_cache=True)

        # Tile states and memory beam_size times.
        dec_states.map_batch_fn(
            lambda state, dim: tile(state, beam_size, dim=dim))
        memory_bank = tile(memory_bank, beam_size, dim=1)
        memory_lengths = tile(src_lengths, beam_size)

        batch_offset = torch.arange(
            batch_size, dtype=torch.long, device=memory_bank.device)
        beam_offset = torch.arange(
            0,
            batch_size * beam_size,
            step=beam_size,
            dtype=torch.long,
            device=memory_bank.device)
        alive_seq = torch.full(
            [batch_size * beam_size, 1],
            start_token,
            dtype=torch.long,
            device=memory_bank.device)
        alive_attn = None

        # Give full probability to the first beam on the first step.
        topk_log_probs = (
            torch.tensor([0.0] + [float("-inf")] * (beam_size - 1),
                         device=memory_bank.device).repeat(batch_size))

        results = {}
        results["predictions"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["scores"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["attention"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["gold_score"] = [0] * batch_size
        results["batch"] = batch

        max_length += 1

        for step in range(max_length):
            decoder_input = alive_seq[:, -1].view(1, -1, 1)

            # Decoder forward.
            dec_out, dec_states, attn = self.model.decoder(
                decoder_input,
                memory_bank,
                dec_states,
                memory_lengths=memory_lengths,
                step=step)

            # Generator forward.
            log_probs = self.model.generator.forward(dec_out.squeeze(0))
            vocab_size = log_probs.size(-1)

            if step < min_length:
                log_probs[:, end_token] = -1e20

            # Multiply probs by the beam probability.
            log_probs += topk_log_probs.view(-1).unsqueeze(1)

            alpha = self.global_scorer.alpha
            length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha

            # Flatten probs into a list of possibilities.
            curr_scores = log_probs / length_penalty
            curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
            topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

            # Recover log probs.
            topk_log_probs = topk_scores * length_penalty

            # Resolve beam origin and true word ids.
            # Floor division keeps the origin-beam index integral on newer
            # PyTorch, where Tensor.div() performs true (float) division.
            topk_beam_index = topk_ids // vocab_size
            topk_ids = topk_ids.fmod(vocab_size)

            # Map beam_index to batch_index in the flat representation.
            batch_index = (
                topk_beam_index
                + beam_offset[:topk_beam_index.size(0)].unsqueeze(1))

            # End condition is the top beam reached end_token.
            end_condition = topk_ids[:, 0].eq(end_token)
            if step + 1 == max_length:
                end_condition.fill_(1)
            finished = end_condition.nonzero().view(-1)

            # Save result of finished sentences.
            if len(finished) > 0:
                predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
                scores = topk_scores.view(-1, beam_size)
                attention = None
                if alive_attn is not None:
                    attention = alive_attn.view(
                        alive_attn.size(0), -1, beam_size, alive_attn.size(-1))
                for i in finished:
                    b = batch_offset[i]
                    for n in range(n_best):
                        results["predictions"][b].append(predictions[i, n, 1:])
                        results["scores"][b].append(scores[i, n])
                        if attention is None:
                            results["attention"][b].append([])
                        else:
                            results["attention"][b].append(
                                attention[:, i, n, :memory_lengths[i]])
                non_finished = end_condition.eq(0).nonzero().view(-1)
                # If all sentences are translated, no need to go further.
                if len(non_finished) == 0:
                    break
                # Remove finished batches for the next step.
                topk_log_probs = topk_log_probs.index_select(
                    0, non_finished.to(topk_log_probs.device))
                topk_ids = topk_ids.index_select(0, non_finished)
                batch_index = batch_index.index_select(0, non_finished)
                batch_offset = batch_offset.index_select(0, non_finished)

            # Select and reorder alive batches.
            select_indices = batch_index.view(-1)
            alive_seq = alive_seq.index_select(0, select_indices)
            memory_bank = memory_bank.index_select(1, select_indices)
            memory_lengths = memory_lengths.index_select(0, select_indices)
            dec_states.map_batch_fn(
                lambda state, dim: state.index_select(dim, select_indices))

            # Append last prediction.
            alive_seq = torch.cat([alive_seq, topk_ids.view(-1, 1)], -1)

            if return_attention:
                current_attn = attn["std"].index_select(1, select_indices)
                if alive_attn is None:
                    alive_attn = current_attn
                else:
                    alive_attn = alive_attn.index_select(1, select_indices)
                    alive_attn = torch.cat([alive_attn, current_attn], 0)

        return results
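The topk_log_probs initialisation above ([0.0, -inf, ..., -inf] repeated per batch item) is what makes the first topk draw all beam_size candidates from beam 0: the tiled copies of the decoder state are identical at step 0, so without the -inf offsets every beam would propose the same token. A tiny numeric illustration with toy values only:

import torch

beam_size, vocab_size = 3, 4
# One batch item; the three tiled beams start from the same decoder state,
# so they produce identical log-probs at step 0.
step0 = torch.log_softmax(torch.tensor([2.0, 1.0, 0.5, 0.1]), dim=-1)
log_probs = step0.repeat(beam_size, 1)                       # [beam, vocab]

topk_log_probs = torch.tensor([0.0] + [float("-inf")] * (beam_size - 1))
scores = (log_probs + topk_log_probs.unsqueeze(1)).view(1, -1)

topk_scores, topk_ids = scores.topk(beam_size, dim=-1)
print(topk_ids // vocab_size)  # tensor([[0, 0, 0]]) -> all survivors come from beam 0
print(topk_ids % vocab_size)   # the three best tokens, not one token three times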
Example n. 30
0
min_length = opt.min_length
ratio = 0.
max_length = opt.max_length
mb_device = 0
stepwise_penalty = None
block_ngram_repeat = 0
global_scorer = onmt.translate.GNMTGlobalScorer.from_opt(opt)
return_attention = False

for batch in data_iter:
    print()
    src, src_lengths = batch.src
    memory_lengths = src_lengths
    enc_states, memory_bank, src_lengths = model.encoder(src, src_lengths)
    model.decoder.init_state(src, memory_bank, enc_states)
    model.decoder.map_state(lambda state, dim: tile(state, beam_size, dim=dim))
    if isinstance(memory_bank, tuple):
        memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
        mb_device = memory_bank[0].device
    else:
        memory_bank = tile(memory_bank, beam_size, dim=1)
        mb_device = memory_bank.device
    memory_lengths = tile(src_lengths, beam_size)
    beam = BeamSearch(beam_size,
                      n_best=n_best,
                      batch_size=batch_size,
                      global_scorer=global_scorer,
                      pad=tgt_pad_idx,
                      eos=tgt_eos_idx,
                      bos=tgt_bos_idx,
                      min_length=min_length,
Example n. 31
0
    def _translate_batch_deprecated(self, batch, src_vocabs):
        # (0) Prep each of the components of the search.
        # And helper method for reducing verbosity.
        use_src_map = self.copy_attn
        beam_size = self.beam_size
        batch_size = batch.batch_size

        beam = [onmt.translate.Beam(
            beam_size,
            n_best=self.n_best,
            cuda=self.cuda,
            global_scorer=self.global_scorer,
            pad=self._tgt_pad_idx,
            eos=self._tgt_eos_idx,
            bos=self._tgt_bos_idx,
            min_length=self.min_length,
            stepwise_penalty=self.stepwise_penalty,
            block_ngram_repeat=self.block_ngram_repeat,
            exclusion_tokens=self._exclusion_idxs)
            for __ in range(batch_size)]

        # (1) Run the encoder on the src.
        src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
        self.model.decoder.init_state(src, memory_bank, enc_states)

        results = {
            "predictions": [],
            "scores": [],
            "attention": [],
            "batch": batch,
            "gold_score": self._gold_score(
                batch, memory_bank, src_lengths, src_vocabs, use_src_map,
                enc_states, batch_size, src)}

        # (2) Repeat src objects `beam_size` times.
        # Now we use batch_size x beam_size (same as fast mode)
        src_map = (tile(batch.src_map, beam_size, dim=1)
                   if use_src_map else None)
        self.model.decoder.map_state(
            lambda state, dim: tile(state, beam_size, dim=dim))

        if isinstance(memory_bank, tuple):
            memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
        else:
            memory_bank = tile(memory_bank, beam_size, dim=1)
        memory_lengths = tile(src_lengths, beam_size)

        # (3) run the decoder to generate sentences, using beam search.
        for i in range(self.max_length):
            if all((b.done for b in beam)):
                break

            # (a) Construct batch x beam_size next words.
            # Get all the pending current beam words and arrange for forward.

            inp = torch.stack([b.current_predictions for b in beam])
            inp = inp.view(1, -1, 1)

            # (b) Decode and forward
            out, beam_attn = self._decode_and_generate(
                inp, memory_bank, batch, src_vocabs,
                memory_lengths=memory_lengths, src_map=src_map, step=i
            )
            out = out.view(batch_size, beam_size, -1)
            beam_attn = beam_attn.view(batch_size, beam_size, -1)

            # (c) Advance each beam.
            select_indices_array = []
            # Loop over the batch_size beams
            for j, b in enumerate(beam):
                if not b.done:
                    b.advance(out[j, :],
                              beam_attn.data[j, :, :memory_lengths[j]])
                select_indices_array.append(
                    b.current_origin + j * beam_size)
            select_indices = torch.cat(select_indices_array)

            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        # (4) Extract sentences from beam.
        for b in beam:
            scores, ks = b.sort_finished(minimum=self.n_best)
            hyps, attn = [], []
            for times, k in ks[:self.n_best]:
                hyp, att = b.get_hyp(times, k)
                hyps.append(hyp)
                attn.append(att)
            results["predictions"].append(hyps)
            results["scores"].append(scores)
            results["attention"].append(attn)

        return results
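In this deprecated path every sentence owns its own Beam object, so b.current_origin + j * beam_size converts the beam-local parent indices of sentence j into row indices of the flat batch * beam decoder state. A short sketch of that offsetting with illustrative values:

import torch

beam_size = 3
# Parent-beam indices chosen by two independent Beam objects at some step.
current_origins = [torch.tensor([2, 0, 0]),   # sentence 0 keeps beams 2, 0, 0
                   torch.tensor([1, 1, 2])]   # sentence 1 keeps beams 1, 1, 2

select_indices = torch.cat(
    [origin + j * beam_size for j, origin in enumerate(current_origins)])
print(select_indices)  # tensor([2, 0, 0, 4, 4, 5]) -> rows of the flat [batch*beam] state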