def _fast_translate_batch(self, batch, max_length, min_length=0):
        # TODO: faster code path for beam_size == 1.

        # TODO: support these blacklisted features.
        assert not self.dump_beam

        beam_size = self.beam_size
        batch_size = batch.batch_size
        src = batch.src
        segs = batch.segs
        mask_src = batch.mask_src

        src_features = self.model.bert(src, segs, mask_src)
        dec_states = self.model.decoder.init_decoder_state(src,
                                                           src_features,
                                                           with_cache=True)
        device = src_features.device

        # Tile states and memory beam_size times.
        dec_states.map_batch_fn(
            lambda state, dim: tile(state, beam_size, dim=dim))
        src_features = tile(src_features, beam_size, dim=0)
        batch_offset = torch.arange(batch_size,
                                    dtype=torch.long,
                                    device=device)
        beam_offset = torch.arange(0,
                                   batch_size * beam_size,
                                   step=beam_size,
                                   dtype=torch.long,
                                   device=device)
        alive_seq = torch.full([batch_size * beam_size, 1],
                               self.start_token,
                               dtype=torch.long,
                               device=device)

        # Give full probability to the first beam on the first step.
        topk_log_probs = (torch.tensor([0.0] + [float("-inf")] *
                                       (beam_size - 1),
                                       device=device).repeat(batch_size))
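        # (At step 0 every beam of a batch entry holds the same start token, so only
        # the first beam is given probability mass; otherwise the first top-k would
        # simply return beam_size copies of the same word.)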

        # Structure that holds finished hypotheses.
        hypotheses = [[] for _ in range(batch_size)]  # noqa: F812

        results = {}
        results["predictions"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["scores"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["gold_score"] = [0] * batch_size
        results["batch"] = batch

        for step in range(max_length):
            decoder_input = alive_seq[:, -1].view(1, -1)

            # Decoder forward.
            decoder_input = decoder_input.transpose(0, 1)

            dec_out, dec_states = self.model.decoder(decoder_input,
                                                     src_features,
                                                     dec_states,
                                                     step=step)

            # Generator forward.
            log_probs = self.generator.forward(
                dec_out.transpose(0, 1).squeeze(0))
            vocab_size = log_probs.size(-1)

            if step < min_length:
                log_probs[:, self.end_token] = -1e20

            # Multiply probs by the beam probability.
            log_probs += topk_log_probs.view(-1).unsqueeze(1)

            alpha = self.global_scorer.alpha
            length_penalty = ((5.0 + (step + 1)) / 6.0)**alpha
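            # GNMT length penalty (Wu et al., 2016): lp = ((5 + length) / 6) ** alpha.
            # Dividing the summed log probabilities by it favours longer hypotheses
            # when alpha > 0; e.g. with alpha = 1.0 and step = 9 (ten generated
            # tokens), lp = (15 / 6) ** 1.0 = 2.5.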

            # Flatten probs into a list of possibilities.
            curr_scores = log_probs / length_penalty

            if (self.args.block_trigram):
                cur_len = alive_seq.size(1)
                if (cur_len > 3):
                    for i in range(alive_seq.size(0)):
                        fail = False
                        words = [int(w) for w in alive_seq[i]]
                        words = [self.vocab.ids_to_tokens[w] for w in words]
                        words = ' '.join(words).replace(' ##', '').split()
                        if (len(words) <= 3):
                            continue
                        trigrams = [(words[i - 1], words[i], words[i + 1])
                                    for i in range(1,
                                                   len(words) - 1)]
                        trigram = tuple(trigrams[-1])
                        if trigram in trigrams[:-1]:
                            fail = True
                        if fail:
                            curr_scores[i] = -10e20

            curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
            topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

            # Recover log probs.
            topk_log_probs = topk_scores * length_penalty

            # Resolve beam origin and true word ids.
            topk_beam_index = topk_ids.div(vocab_size)
            topk_ids = topk_ids.fmod(vocab_size)
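            # topk_ids index the flattened (beam_size * vocab_size) score matrix, so
            # integer division recovers the originating beam and the remainder the
            # vocabulary id; e.g. with vocab_size = 10 and beam_size = 2, flat id 13
            # means beam 1, token id 3. (On recent PyTorch, Tensor.div performs true
            # division, so rounding_mode='floor' or // is needed here to keep
            # topk_beam_index integral.)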

            # Map beam_index to batch_index in the flat representation.
            batch_index = (topk_beam_index +
                           beam_offset[:topk_beam_index.size(0)].unsqueeze(1))
            select_indices = batch_index.view(-1)

            # Append last prediction.
            alive_seq = torch.cat([
                alive_seq.index_select(0, select_indices),
                topk_ids.view(-1, 1)
            ], -1)

            is_finished = topk_ids.eq(self.end_token)
            if step + 1 == max_length:
                is_finished.fill_(1)
            # End condition: the top beam is finished.
            end_condition = is_finished[:, 0].eq(1)
            # Save finished hypotheses.
            if is_finished.any():
                predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
                for i in range(is_finished.size(0)):
                    b = batch_offset[i]
                    if end_condition[i]:
                        is_finished[i].fill_(1)
                    finished_hyp = is_finished[i].nonzero().view(-1)
                    # Store finished hypotheses for this batch.
                    for j in finished_hyp:
                        hypotheses[b].append(
                            (topk_scores[i, j], predictions[i, j, 1:]))
                    # If the batch reached the end, save the n_best hypotheses.
                    if end_condition[i]:
                        best_hyp = sorted(hypotheses[b],
                                          key=lambda x: x[0],
                                          reverse=True)
                        score, pred = best_hyp[0]

                        results["scores"][b].append(score)
                        results["predictions"][b].append(pred)
                non_finished = end_condition.eq(0).nonzero().view(-1)
                # If all sentences are translated, no need to go further.
                if len(non_finished) == 0:
                    break
                # Remove finished batches for the next step.
                topk_log_probs = topk_log_probs.index_select(0, non_finished)
                batch_index = batch_index.index_select(0, non_finished)
                batch_offset = batch_offset.index_select(0, non_finished)
                alive_seq = predictions.index_select(0, non_finished) \
                    .view(-1, alive_seq.size(-1))
            # Reorder states.
            select_indices = batch_index.view(-1)
            src_features = src_features.index_select(0, select_indices)
            dec_states.map_batch_fn(
                lambda state, dim: state.index_select(dim, select_indices))

        return results
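
The tile helper used throughout these snippets is not shown on this page. Below is a minimal sketch consistent with how it is called here (tile(x, count, dim)), modelled on the OpenNMT-py utility of the same name; the implementation actually imported by the repository may differ.

import torch


def tile(x, count, dim=0):
    """Repeat every slice of x along `dim` `count` times, keeping copies grouped.

    With dim=0 a (batch, ...) tensor becomes (batch * count, ...) where rows
    i*count .. i*count + count - 1 are copies of input row i, which is the
    layout the beam_offset arithmetic above assumes.
    """
    perm = list(range(x.dim()))
    if dim != 0:
        # Move the tiled dimension to the front, tile, then move it back.
        perm[0], perm[dim] = perm[dim], perm[0]
        x = x.permute(perm).contiguous()
    out_size = list(x.size())
    out_size[0] *= count
    batch = x.size(0)
    x = (x.view(batch, -1)
          .transpose(0, 1)
          .repeat(count, 1)
          .transpose(0, 1)
          .contiguous()
          .view(*out_size))
    if dim != 0:
        x = x.permute(perm).contiguous()
    return x


# Example: a (2, 5, 8) src_features tensor tiled with beam_size 3 -> (6, 5, 8).
src_features = tile(torch.randn(2, 5, 8), 3, dim=0)
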
Example n. 2
    def _fast_translate_batch(self,
                              batch,
                              memory_bank,
                              max_length,
                              memory_mask=None,
                              min_length=2,
                              beam_size=3,
                              hidden_state=None,
                              copy_attn=False):

        batch_size = memory_bank.size(0)

        if self.args.decoder == "rnn":
            dec_states = self.decoder.init_decoder_state(
                batch.src, memory_bank, hidden_state)
        else:
            dec_states = self.decoder.init_decoder_state(batch.src,
                                                         memory_bank,
                                                         with_cache=True)

        # Tile states and memory beam_size times.
        dec_states.map_batch_fn(
            lambda state, dim: tile(state, beam_size, dim=dim))
        memory_bank = tile(memory_bank, beam_size, dim=0)
        memory_mask = tile(memory_mask, beam_size, dim=0)
        if copy_attn:
            src_map = tile(batch.src_map, beam_size, dim=0)
        else:
            src_map = None

        batch_offset = torch.arange(batch_size,
                                    dtype=torch.long,
                                    device=self.device)
        beam_offset = torch.arange(0,
                                   batch_size * beam_size,
                                   step=beam_size,
                                   dtype=torch.long,
                                   device=self.device)

        alive_seq = torch.full([batch_size * beam_size, 1],
                               self.start_token,
                               dtype=torch.long,
                               device=self.device)

        # Give full probability to the first beam on the first step.
        topk_log_probs = (torch.tensor([0.0] + [float("-inf")] *
                                       (beam_size - 1),
                                       device=self.device).repeat(batch_size))

        # Structure that holds finished hypotheses.
        hypotheses = [[] for _ in range(batch_size)]  # noqa: F812

        results = [[] for _ in range(batch_size)]  # noqa: F812

        for step in range(max_length):
            # Decoder forward.
            decoder_input = alive_seq[:, -1].view(1, -1)
            decoder_input = decoder_input.transpose(0, 1)

            if self.args.decoder == "rnn":
                dec_out, dec_states, attn = self.decoder(
                    decoder_input,
                    memory_bank,
                    dec_states,
                    step=step,
                    memory_masks=memory_mask)
            else:
                dec_out, dec_states, attn = self.decoder(
                    decoder_input,
                    memory_bank,
                    dec_states,
                    step=step,
                    memory_masks=memory_mask,
                    requires_att=copy_attn)

            # Generator forward.
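            # With copy attention, the generator scores the extended
            # (vocabulary + source) space from the decoder output, the copy
            # attention and src_map; collapse_copy_scores then folds probability
            # mass on copied source positions back onto their vocabulary ids
            # before the result is truncated to vocab_size and log-transformed.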
            if copy_attn:
                probs = self.generator(
                    dec_out.transpose(0, 1).squeeze(0),
                    attn['copy'].transpose(0, 1).squeeze(0), src_map)
                probs = collapse_copy_scores(
                    probs.unsqueeze(1), batch, self.vocab,
                    tile(batch_offset, beam_size, dim=0))
                log_probs = probs.squeeze(1)[:, :self.vocab_size].log()
            else:
                log_probs = self.generator(dec_out.transpose(0, 1).squeeze(0))

            vocab_size = log_probs.size(-1)

            if step < min_length:
                log_probs[:, self.end_token] = -1e20

            if self.args.block_trigram:
                cur_len = alive_seq.size(1)
                if (cur_len > 3):
                    for i in range(alive_seq.size(0)):
                        fail = False
                        words = [int(w) for w in alive_seq[i]]
                        if (len(words) <= 3):
                            continue
                        trigrams = [(words[i - 1], words[i], words[i + 1])
                                    for i in range(1,
                                                   len(words) - 1)]
                        trigram = tuple(trigrams[-1])
                        if trigram in trigrams[:-1]:
                            fail = True
                        if fail:
                            log_probs[i] = -1e20

            # Multiply probs by the beam probability.
            log_probs += topk_log_probs.view(-1).unsqueeze(1)

            alpha = self.args.alpha
            length_penalty = ((5.0 + (step + 1)) / 6.0)**alpha

            # Flatten probs into a list of possibilities.
            curr_scores = log_probs / length_penalty

            curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
            topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

            # Recover log probs.
            topk_log_probs = topk_scores * length_penalty

            # Resolve beam origin and true word ids.
            topk_beam_index = topk_ids.div(vocab_size)
            topk_ids = topk_ids.fmod(vocab_size)

            # Map beam_index to batch_index in the flat representation.
            batch_index = (topk_beam_index +
                           beam_offset[:topk_beam_index.size(0)].unsqueeze(1))
            select_indices = batch_index.view(-1)

            # Append last prediction.
            alive_seq = torch.cat([
                alive_seq.index_select(0, select_indices),
                topk_ids.view(-1, 1)
            ], -1)

            is_finished = topk_ids.eq(self.end_token)
            if step + 1 == max_length:
                is_finished.fill_(1)
            # End condition: the top beam is finished.
            end_condition = is_finished[:, 0].eq(1)
            # Save finished hypotheses.
            if is_finished.any():
                predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
                for i in range(is_finished.size(0)):
                    b = batch_offset[i]
                    if end_condition[i]:
                        is_finished[i].fill_(1)
                    finished_hyp = is_finished[i].nonzero().view(-1)
                    # Store finished hypotheses for this batch.
                    for j in finished_hyp:
                        hypotheses[b].append(
                            (topk_scores[i, j], predictions[i, j, 1:]))
                    # If the batch reached the end, save the n_best hypotheses.
                    if end_condition[i]:
                        best_hyp = sorted(hypotheses[b],
                                          key=lambda x: x[0],
                                          reverse=True)
                        _, pred = best_hyp[0]
                        results[b].append(pred)
                non_finished = end_condition.eq(0).nonzero().view(-1)
                # If all sentences are translated, no need to go further.
                if len(non_finished) == 0:
                    break
                # Remove finished batches for the next step.
                topk_log_probs = topk_log_probs.index_select(0, non_finished)
                batch_index = batch_index.index_select(0, non_finished)
                batch_offset = batch_offset.index_select(0, non_finished)
                alive_seq = predictions.index_select(0, non_finished) \
                    .view(-1, alive_seq.size(-1))
            # Reorder states.
            select_indices = batch_index.view(-1)
            if memory_bank is not None:
                memory_bank = memory_bank.index_select(0, select_indices)
            if memory_mask is not None:
                memory_mask = memory_mask.index_select(0, select_indices)
            if src_map is not None:
                src_map = src_map.index_select(0, select_indices)
            dec_states.map_batch_fn(
                lambda state, dim: state.index_select(dim, select_indices))

        results = [t[0] for t in results]
        return results
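
The trigram-blocking loop that several of these snippets repeat inline boils down to a single predicate over the tokens decoded so far. A minimal sketch of that check (hypothetical helper name, not part of the original code):

def has_repeated_trigram(tokens):
    """Return True if the newest trigram in `tokens` already occurred earlier."""
    if len(tokens) <= 3:
        return False
    trigrams = [tuple(tokens[i - 1:i + 2]) for i in range(1, len(tokens) - 1)]
    return trigrams[-1] in trigrams[:-1]


# Beams whose hypothesis fails the check get their score forced to a large
# negative value (curr_scores[i] = -10e20 or log_probs[i] = -1e20 above), which
# effectively removes them from the beam search.
assert has_repeated_trigram(["a", "b", "c", "a", "b", "c"])
assert not has_repeated_trigram(["a", "b", "c", "d"])
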
Example n. 3
    def _fast_translate_batch(self, batch, max_length, min_length=0):
        # TODO: faster code path for beam_size == 1.

        # TODO: support these blacklisted features.
        assert not self.dump_beam

        beam_size = self.beam_size
        batch_size = batch.batch_size
        src = batch.src
        segs = batch.segs
        mask_src = batch.mask_src
        # print("mask_tgt = ", batch.mask_tgt.size())
        # print(batch.mask_tgt)

        # print("tgt_segs = ", batch.tgt_segs.size())
        # print(batch.tgt_segs)
        # print("tgt = ", batch.tgt.size())
        # print(batch.tgt)
        # exit()

        if self.args.bart:
            src_features = self.model.bert.model.encoder(
                input_ids=src, attention_mask=mask_src)[0]
            # print("output = ", self.model。)
            # past_key_values =
            # print("src_features = ", src_features.size())
            # print(src_features)
        else:
            src_features = self.model.bert(src, segs, mask_src)
        dec_states = self.model.decoder.init_decoder_state(src,
                                                           src_features,
                                                           with_cache=True)
        device = src_features.device

        # Tile states and memory beam_size times.
        dec_states.map_batch_fn(
            lambda state, dim: tile(state, beam_size, dim=dim))
        src_features = tile(src_features, beam_size, dim=0)

        if self.args.bart:
            # print("src = ", src.size())
            # print(src)
            # print("mask_src0 = ", mask_src.size())
            # print(mask_src)
            mask_src = tile(mask_src, beam_size, dim=0).byte()
            bart_dec_cache = None

        batch_offset = torch.arange(batch_size,
                                    dtype=torch.long,
                                    device=device)
        beam_offset = torch.arange(0,
                                   batch_size * beam_size,
                                   step=beam_size,
                                   dtype=torch.long,
                                   device=device)
        alive_seq = torch.full([batch_size * beam_size, 1],
                               self.start_token,
                               dtype=torch.long,
                               device=device)

        if self.args.language_limit:
            language_limit = torch.Tensor(json.load(open(
                self.args.tgt_mask))).long().cuda()

        # language_seg
        language_segs = torch.full([batch_size * beam_size, 1],
                                   0,
                                   dtype=torch.long,
                                   device=device)

        # Give full probability to the first beam on the first step.
        topk_log_probs = (torch.tensor([0.0] + [float("-inf")] *
                                       (beam_size - 1),
                                       device=device).repeat(batch_size))

        # Structure that holds finished hypotheses.
        hypotheses = [[] for _ in range(batch_size)]  # noqa: F812

        results = {}
        results["predictions"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["scores"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["gold_score"] = [0] * batch_size
        results["batch"] = batch

        for step in range(max_length):
            # print("alive_seq = ", alive_seq.size())
            # print(alive_seq)
            # print("language_segs = ", language_segs.size())
            # print(language_segs)
            decoder_input = alive_seq[:, -1].view(1, -1)

            # language_segs is kept aligned with alive_seq: whenever alive_seq is
            # re-indexed by beam selection, language_segs is re-indexed with the
            # same indices. After selecting, check whether the token produced at
            # this step is 5 (the language separator), build a 0/1 mask from that,
            # and set the new segment value to 1 at those positions; finally
            # concatenate the new column onto language_segs.
            decoder_seg_input = language_segs[:, -1].view(1, -1)

            # Decoder forward.
            decoder_input = decoder_input.transpose(0, 1)

            decoder_seg_input = decoder_seg_input.transpose(0, 1)

            if self.args.bart:
                tgt_mask = torch.zeros(decoder_input.size()).byte().cuda()

                # causal_mask = (1 - _get_attn_subsequent_mask(tgt_mask.size(1)).float().cuda()) # * float("-inf")).cuda()
                causal_mask = torch.triu(
                    torch.zeros(tgt_mask.size(1),
                                tgt_mask.size(1)).float().fill_(
                                    float("-inf")).float(), 1).cuda()
                # print(src_features.size())
                # print(mask_src.size())
                # print(mask_src)
                # model_inputs = self.bert.model.prepare_inputs_for_generation(
                #     decoder_input, past=src_features, attention_mask=tgt_mask, use_cache=True
                # )
                dec_output = self.model.bert.model.decoder(
                    input_ids=alive_seq,
                    encoder_hidden_states=src_features,
                    encoder_padding_mask=mask_src,
                    decoder_padding_mask=tgt_mask,
                    decoder_causal_mask=causal_mask,
                    decoder_cached_states=bart_dec_cache,
                    use_cache=True)
                dec_out = dec_output[0]
                bart_dec_cache = dec_output[1][1]
                # print(bart_dec_cache)
                # print('dec_out = ')
                # print(dec_out[0])
                # exit()
            elif self.args.predict_first_language and self.args.multi_task:
                dec_out, dec_states = self.model.decoder_monolingual(
                    decoder_input,
                    src_features,
                    dec_states,
                    step=step,
                    tgt_segs=decoder_seg_input)
            else:
                dec_out, dec_states = self.model.decoder(
                    decoder_input,
                    src_features,
                    dec_states,
                    step=step,
                    tgt_segs=decoder_seg_input)

            # Generator forward.
            log_probs = self.generator.forward(
                dec_out.transpose(0, 1).squeeze(0))
            vocab_size = log_probs.size(-1)

            if self.args.language_limit:

                mask_language_limit = torch.zeros(log_probs.size()).cuda()
                mask_language_limit.index_fill_(1, language_limit, 1)
                # If the two languages are concatenated, only restrict once the
                # second language is being generated; for direct cross-lingual
                # generation, restrict from the very first step.
                if self.args.predict_2language:
                    mask_language_limit = mask_language_limit.long(
                    ) * decoder_seg_input
                    mask_language_limit = mask_language_limit + (
                        1 - decoder_seg_input)
                else:
                    mask_language_limit = mask_language_limit.long(
                    ) * torch.ones(decoder_seg_input.size()).long().cuda()

                # Here, set every position outside the candidate set to negative infinity.
                log_probs.masked_fill_((1 - mask_language_limit).byte(), -1e20)

            if step < min_length:
                log_probs[:, self.end_token] = -1e20

            # Multiply probs by the beam probability.
            log_probs += topk_log_probs.view(-1).unsqueeze(1)

            alpha = self.global_scorer.alpha
            length_penalty = ((5.0 + (step + 1)) / 6.0)**alpha

            # Flatten probs into a list of possibilities.
            curr_scores = log_probs / length_penalty

            if (self.args.block_trigram):
                cur_len = alive_seq.size(1)
                if (cur_len > 3):
                    for i in range(alive_seq.size(0)):
                        fail = False
                        words = [int(w) for w in alive_seq[i]]
                        if self.args.bart:
                            words = [self.vocab.decoder[w] for w in words]
                        else:
                            words = [
                                self.vocab.ids_to_tokens[w] for w in words
                            ]
                        words = ' '.join(words).replace(' ##', '').split()
                        if (len(words) <= 3):
                            continue
                        trigrams = [(words[i - 1], words[i], words[i + 1])
                                    for i in range(1,
                                                   len(words) - 1)]
                        trigram = tuple(trigrams[-1])
                        if trigram in trigrams[:-1]:
                            fail = True
                        if fail:
                            curr_scores[i] = -10e20

            curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
            topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

            # Recover log probs.
            topk_log_probs = topk_scores * length_penalty

            # Resolve beam origin and true word ids.
            topk_beam_index = topk_ids.div(vocab_size)
            topk_ids = topk_ids.fmod(vocab_size)

            # Map beam_index to batch_index in the flat representation.
            batch_index = (topk_beam_index +
                           beam_offset[:topk_beam_index.size(0)].unsqueeze(1))
            select_indices = batch_index.view(-1)

            # Append last prediction.
            alive_seq = torch.cat([
                alive_seq.index_select(0, select_indices),
                topk_ids.view(-1, 1)
            ], -1)

            # print("topk_ids = ", topk_ids)
            # print("end_token = ", self.end_token)
            if self.args.bart:
                is_finished = topk_ids.eq(2)
            else:
                is_finished = topk_ids.eq(self.end_token)

            # Has decoding entered the second language?
            # Take the last segment column, set it to 1 wherever a 5 was just
            # generated, then concatenate it back on.
            is_languaged = topk_ids.eq(5)
            # is_languaged = alive_seq[:, -2].unsqueeze(0).eq(5)
            language_segs = language_segs.index_select(0, select_indices)

            last_segs = language_segs[:, -1]
            tmp_seg = last_segs.masked_fill(is_languaged, 1)
            language_segs = torch.cat([language_segs, tmp_seg.view(-1, 1)], -1)

            # print("is_finished = ", is_finished)
            if step + 1 == max_length:
                is_finished.fill_(1)
            # End condition: the top beam is finished.
            end_condition = is_finished[:, 0].eq(1)
            # Save finished hypotheses.
            if is_finished.any():
                predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
                for i in range(is_finished.size(0)):
                    b = batch_offset[i]
                    if end_condition[i]:
                        is_finished[i].fill_(1)
                    finished_hyp = is_finished[i].nonzero().view(-1)
                    # Store finished hypotheses for this batch.
                    for j in finished_hyp:
                        hypotheses[b].append(
                            (topk_scores[i, j], predictions[i, j, 1:]))
                    # If the batch reached the end, save the n_best hypotheses.
                    if end_condition[i]:
                        best_hyp = sorted(hypotheses[b],
                                          key=lambda x: x[0],
                                          reverse=True)
                        score, pred = best_hyp[0]
                        results["scores"][b].append(score)
                        results["predictions"][b].append(pred)
                non_finished = end_condition.eq(0).nonzero().view(-1)
                # If all sentences are translated, no need to go further.
                if len(non_finished) == 0:
                    break
                # Remove finished batches for the next step.
                topk_log_probs = topk_log_probs.index_select(0, non_finished)
                batch_index = batch_index.index_select(0, non_finished)
                batch_offset = batch_offset.index_select(0, non_finished)
                alive_seq = predictions.index_select(0, non_finished) \
                    .view(-1, alive_seq.size(-1))
            # Reorder states.
            select_indices = batch_index.view(-1)
            src_features = src_features.index_select(0, select_indices)
            dec_states.map_batch_fn(
                lambda state, dim: state.index_select(dim, select_indices))

        return results
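
The index bookkeeping shared by all of these variants (beam_offset, batch_index, select_indices) is easiest to see on a toy case; a minimal, self-contained sketch with made-up numbers:

import torch

batch_size, beam_size, vocab_size = 2, 3, 10

# Flat layout: rows 0-2 are the beams of sentence 0, rows 3-5 those of sentence 1.
beam_offset = torch.arange(0, batch_size * beam_size, step=beam_size)  # tensor([0, 3])

# Pretend the top-k over the flattened (beam_size * vocab_size) scores picked these.
topk_ids = torch.tensor([[4, 13, 27],   # sentence 0
                         [9, 21, 25]])  # sentence 1
topk_beam_index = topk_ids // vocab_size  # beam each candidate came from
topk_token_ids = topk_ids % vocab_size    # vocabulary id of each candidate

# Map (sentence, beam) pairs back to rows of the flat batch*beam representation.
batch_index = topk_beam_index + beam_offset[:topk_beam_index.size(0)].unsqueeze(1)
select_indices = batch_index.view(-1)

print(topk_beam_index)  # tensor([[0, 1, 2], [0, 2, 2]])
print(topk_token_ids)   # tensor([[4, 3, 7], [9, 1, 5]])
print(select_indices)   # tensor([0, 1, 2, 3, 5, 5])
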
Example n. 4
    def _fast_translate_batch(self, batch, max_length, min_length=0):
        # TODO: faster code path for beam_size == 1.

        # TODO: support these blacklisted features.
        assert not self.dump_beam

        beam_size = self.beam_size
        batch_size = batch.batch_size
        src = batch.src
        segs = batch.segs
        mask_src = batch.mask_src
        clss = batch.clss
        mask_cls = batch.mask_cls
        # print("src ", src.size())
        # print(src)
        # print("mask src", mask_src.size())
        # print(mask_src)
        if self.args.oracle:
            labels = batch.src_sent_labels
            ext_scores = ((labels.float() + 0.1) / 1.3) * mask_cls.float()
        else:
            ext_scores, _, sent_vec = self.model.extractor(
                src, segs, clss, mask_src, mask_cls)

        # print("ext_scores : ", ext_scores.size())
        # print(ext_scores)

        src_features = self.model.abstractor.bert(src, segs, mask_src)
        dec_states = self.model.abstractor.decoder.init_decoder_state(
            src, src_features, with_cache=True)

        device = src_features.device

        # Tile states and memory beam_size times.
        dec_states.map_batch_fn(
            lambda state, dim: tile(state, beam_size, dim=dim))
        src_features = tile(src_features, beam_size, dim=0)
        ext_scores = tile(ext_scores, beam_size, dim=0)
        mask_src = tile(mask_src, beam_size, dim=0)
        mask_cls = tile(mask_cls, beam_size, dim=0)
        clss = tile(clss, beam_size, dim=0)
        src = tile(src, beam_size, dim=0)

        batch_offset = torch.arange(batch_size,
                                    dtype=torch.long,
                                    device=device)
        beam_offset = torch.arange(0,
                                   batch_size * beam_size,
                                   step=beam_size,
                                   dtype=torch.long,
                                   device=device)
        alive_seq = torch.full([batch_size * beam_size, 1],
                               self.start_token,
                               dtype=torch.long,
                               device=device)

        # Give full probability to the first beam on the first step.
        topk_log_probs = (torch.tensor([0.0] + [float("-inf")] *
                                       (beam_size - 1),
                                       device=device).repeat(batch_size))

        # Structure that holds finished hypotheses.
        hypotheses = [[] for _ in range(batch_size)]  # noqa: F812

        results = {}
        results["predictions"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["scores"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["gold_score"] = [0] * batch_size
        results["batch"] = batch

        for step in range(max_length):
            decoder_input = alive_seq[:, -1].view(1, -1)

            # print("alive seq ", alive_seq.size())
            # print(alive_seq)
            # print("decoder_input ", decoder_input.size())
            # print(decoder_input)

            # Decoder forward.
            decoder_input = decoder_input.transpose(0, 1)
            encoder_state = src_features

            decoder_outputs, dec_states, y_emb = self.model.abstractor.decoder(
                decoder_input,
                src_features,
                dec_states,
                step=step,
                need_y_emb=True)
            # print("encoder_state ", encoder_state.size())
            # print(encoder_state)
            # print("decoder_outputs", decoder_outputs.size())
            # print(decoder_outputs)
            # print("y_emb", y_emb.size())
            # print(y_emb)

            src_pad_mask = (1 - mask_src).unsqueeze(1)
            # print("src_pad_mask ", src_pad_mask.size())
            # print(src_pad_mask)
            context_vector, attn_dist = self.model.context_attn(
                src_features,
                src_features,
                decoder_outputs,
                mask=src_pad_mask,
                # layer_cache=layer_cache,
                type="context")
            # print("context vector ", context_vector.size())
            # print(context_vector)
            # print("attn_dist 0", attn_dist.size())
            # print(attn_dist)
            g = torch.sigmoid(
                F.linear(
                    torch.cat([decoder_outputs, y_emb, context_vector], -1),
                    self.model.v, self.model.bv))
            xids = src.unsqueeze(0).transpose(0, 1)

            # xids = xids * mask_tgt.unsqueeze(2)[:, :-1, :].long()

            len0 = src.size(1)
            len0 = torch.Tensor([[len0]]).repeat(src.size(0),
                                                 1).long().to('cuda')
            clss_up = torch.cat((clss, len0), dim=1)
            sent_len = (clss_up[:, 1:] - clss) * mask_cls.long()
            for i in range(mask_cls.size(0)):
                for j in range(mask_cls.size(1)):
                    if sent_len[i][j] < 0:
                        sent_len[i][j] += src.size(1)
            # print("sent_len ", sent_len.size())
            # print(sent_len)
            ext_scores_0 = ext_scores.unsqueeze(1).transpose(1, 2).repeat(
                1, 1, mask_src.size(1))
            for i in range(mask_cls.size(0)):
                tmp_vec = ext_scores_0[i, 0, :sent_len[i][0].int()]

                for j in range(1, mask_cls.size(1)):
                    tmp_vec = torch.cat(
                        (tmp_vec, ext_scores_0[i, j, :sent_len[i][j].int()]),
                        dim=0)
                if i == 0:
                    ext_scores_new = tmp_vec.unsqueeze(0)
                else:
                    ext_scores_new = torch.cat(
                        (ext_scores_new, tmp_vec.unsqueeze(0)), dim=0)
            # print("ext_score_new", ext_scores_new.size())
            # print(ext_scores_new)
            ext_scores_new = ext_scores_new * mask_src.float()
            attn_dist = attn_dist * (ext_scores_new + 1).unsqueeze(1)

            # print("attn_dist 1", attn_dist.size())
            # print(attn_dist)

            ext_dist = Variable(
                torch.zeros(
                    src.size(0), 1,
                    self.model.abstractor.bert.model.config.vocab_size).to(
                        'cuda'))
            # ext_vocab_prob = ext_dist.scatter_add(2, xids, (1 - g) * mask_tgt.unsqueeze(2)[:,:-1,:].float() * attn_pad_mask) * mask_tgt.unsqueeze(2)[:,:-1,:].float()
            ext_vocab_prob = ext_dist.scatter_add(2, xids, (1 - g) * attn_dist)
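            # Pointer-generator style mixture: (1 - g) * attn_dist has been scattered
            # onto the source token ids above (ext_vocab_prob); below, the generator's
            # vocabulary distribution is weighted by the gate g, added to it, and
            # log-transformed for beam search.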

            # Generator forward.
            softmax_probs = self.model.abstractor.generator.forward(
                decoder_outputs.transpose(0, 1).squeeze(0)) * g.transpose(
                    0, 1).squeeze(0) + ext_vocab_prob.transpose(0,
                                                                1).squeeze(0)
            log_probs = torch.log(softmax_probs)
            vocab_size = log_probs.size(-1)

            if step < min_length:
                log_probs[:, self.end_token] = -1e20

            # Multiply probs by the beam probability.
            log_probs += topk_log_probs.view(-1).unsqueeze(1)

            alpha = self.global_scorer.alpha
            length_penalty = ((5.0 + (step + 1)) / 6.0)**alpha

            # Flatten probs into a list of possibilities.
            curr_scores = log_probs / length_penalty

            if (self.args.block_trigram):
                cur_len = alive_seq.size(1)
                if (cur_len > 3):
                    for i in range(alive_seq.size(0)):
                        fail = False
                        words = [int(w) for w in alive_seq[i]]
                        words = [self.vocab.ids_to_tokens[w] for w in words]
                        words = ' '.join(words).replace(' ##', '').split()
                        if (len(words) <= 3):
                            continue
                        trigrams = [(words[i - 1], words[i], words[i + 1])
                                    for i in range(1,
                                                   len(words) - 1)]
                        trigram = tuple(trigrams[-1])
                        if trigram in trigrams[:-1]:
                            fail = True
                        if fail:
                            curr_scores[i] = -10e20

            curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
            topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

            # Recover log probs.
            topk_log_probs = topk_scores * length_penalty

            # Resolve beam origin and true word ids.
            topk_beam_index = topk_ids.div(vocab_size)
            topk_ids = topk_ids.fmod(vocab_size)

            # Map beam_index to batch_index in the flat representation.
            batch_index = (topk_beam_index +
                           beam_offset[:topk_beam_index.size(0)].unsqueeze(1))
            select_indices = batch_index.view(-1)

            # Append last prediction.
            alive_seq = torch.cat([
                alive_seq.index_select(0, select_indices),
                topk_ids.view(-1, 1)
            ], -1)

            is_finished = topk_ids.eq(self.end_token)
            if step + 1 == max_length:
                is_finished.fill_(1)
            # End condition: the top beam is finished.
            end_condition = is_finished[:, 0].eq(1)
            # Save finished hypotheses.
            if is_finished.any():
                predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
                for i in range(is_finished.size(0)):
                    b = batch_offset[i]
                    if end_condition[i]:
                        is_finished[i].fill_(1)
                    finished_hyp = is_finished[i].nonzero().view(-1)
                    # Store finished hypotheses for this batch.
                    for j in finished_hyp:
                        hypotheses[b].append(
                            (topk_scores[i, j], predictions[i, j, 1:]))
                    # If the batch reached the end, save the n_best hypotheses.
                    if end_condition[i]:
                        best_hyp = sorted(hypotheses[b],
                                          key=lambda x: x[0],
                                          reverse=True)
                        score, pred = best_hyp[0]

                        results["scores"][b].append(score)
                        results["predictions"][b].append(pred)
                non_finished = end_condition.eq(0).nonzero().view(-1)
                # If all sentences are translated, no need to go further.
                if len(non_finished) == 0:
                    break
                # Remove finished batches for the next step.
                topk_log_probs = topk_log_probs.index_select(0, non_finished)
                batch_index = batch_index.index_select(0, non_finished)
                batch_offset = batch_offset.index_select(0, non_finished)
                alive_seq = predictions.index_select(0, non_finished) \
                    .view(-1, alive_seq.size(-1))
            # Reorder states.
            select_indices = batch_index.view(-1)
            src_features = src_features.index_select(0, select_indices)
            ext_scores = ext_scores.index_select(0, select_indices)
            mask_src = mask_src.index_select(0, select_indices)
            mask_cls = mask_cls.index_select(0, select_indices)
            clss = clss.index_select(0, select_indices)
            src = src.index_select(0, select_indices)

            dec_states.map_batch_fn(
                lambda state, dim: state.index_select(dim, select_indices))

        return results