def forward_word_ins(self, encoder_outs, output_tokens, output_scores,
                         attn, can_ins_word):
        """Ensemble word-insertion step.

        For each model, score candidate word insertions on the rows selected
        by ``can_ins_word``, average the per-model log-probabilities
        (log-mean-exp), pick the argmax word per slot, and write the results
        back into the full batch with ``_fill``.

        Args:
            encoder_outs: per-model encoder outputs, aligned with ``self.models``.
            output_tokens: current partial hypotheses (batch of token ids).
            output_scores: per-token scores aligned with ``output_tokens``.
            attn: running attention tensor to update (may be ``None`` —
                presumably ``_fill`` tolerates that; verify against caller).
            can_ins_word: boolean mask over the batch selecting rows that
                may receive word insertions.

        Returns:
            ``(output_tokens, output_scores, attn)`` with insertions applied.
        """
        word_ins_scores = []
        word_ins_attns = []
        for model, encoder_out in zip(self.models, encoder_outs):
            word_ins_out, word_ins_attn = model.decoder.forward_word_ins(
                _skip(output_tokens, can_ins_word),
                _skip_encoder_out(model.encoder, encoder_out, can_ins_word),
            )
            word_ins_scores.append(F.log_softmax(word_ins_out, 2))
            word_ins_attns.append(word_ins_attn)
        # Average probabilities across models in log space:
        # log(mean_i p_i) = logsumexp_i(log p_i) - log(num_models).
        word_ins_score_avg = torch.logsumexp(
            torch.stack(word_ins_scores, dim=0), dim=0) - math.log(
                len(self.models))
        if word_ins_attns[0] is not None:
            # BUG FIX: the old code did `torch.stack(...) / len(self.models)`,
            # which divides element-wise but never reduces the stacked model
            # dimension (wrong shape, not an average) — and the result was
            # dead anyway, since _fill below received `word_ins_attn`, i.e.
            # only the LAST model's attention. Reduce properly and use it.
            word_ins_attn_avg = torch.stack(word_ins_attns, dim=0).mean(dim=0)
        else:
            word_ins_attn_avg = None
        word_ins_score_max, word_ins_pred = word_ins_score_avg.max(-1)

        _tokens, _scores = _apply_ins_words(
            output_tokens[can_ins_word],
            output_scores[can_ins_word],
            word_ins_pred,
            word_ins_score_max,
            self.unk,
        )

        output_tokens = _fill(output_tokens, can_ins_word, _tokens, self.pad)
        output_scores = _fill(output_scores, can_ins_word, _scores, 0)
        attn = _fill(attn, can_ins_word, word_ins_attn_avg, 0.)
        return output_tokens, output_scores, attn
 def forward_mask_ins(self, encoder_outs, output_tokens, output_scores,
                      can_ins_mask, eos_penalty, max_lens):
     """Ensemble mask-insertion step.

     Each model predicts, for the rows selected by ``can_ins_mask``, how
     many placeholder tokens to insert at each position. Per-model
     log-probabilities are averaged in log space (log-mean-exp), an
     optional EOS penalty discourages predicting zero insertions, the
     predicted counts are clamped to ``max_lens``, and the expanded
     sequences are written back into the full batch with ``_fill``.

     Args:
         encoder_outs: per-model encoder outputs, aligned with ``self.models``.
         output_tokens: current partial hypotheses (batch of token ids).
         output_scores: per-token scores aligned with ``output_tokens``.
         can_ins_mask: boolean mask over the batch selecting rows that
             may receive mask insertions.
         eos_penalty: subtracted from class 0 when positive — presumably
             class 0 means "insert nothing"; confirm against the decoder.
         max_lens: per-row cap on the number of insertions.

     Returns:
         ``(output_tokens, output_scores)`` with placeholders inserted.
     """
     per_model_scores = []
     for model, encoder_out in zip(self.models, encoder_outs):
         logits, _ = model.decoder.forward_mask_ins(
             _skip(output_tokens, can_ins_mask),
             _skip_encoder_out(model.encoder, encoder_out, can_ins_mask),
         )
         scores = F.log_softmax(logits, 2)
         if eos_penalty > 0.0:
             # Penalize the "no insertion" class in place.
             scores[:, :, 0] -= eos_penalty
         per_model_scores.append(scores)
     # log(mean_i p_i) = logsumexp_i(log p_i) - log(num_models)
     stacked = torch.stack(per_model_scores, dim=0)
     avg_scores = torch.logsumexp(stacked, dim=0) - math.log(len(self.models))
     pred_counts = avg_scores.max(-1)[1]
     # Clamp predicted insertion counts to the per-row length budget.
     budget = max_lens[can_ins_mask, None].expand_as(pred_counts)
     pred_counts = torch.min(pred_counts, budget)
     new_tokens, new_scores = _apply_ins_masks(
         output_tokens[can_ins_mask],
         output_scores[can_ins_mask],
         pred_counts,
         self.pad,
         self.unk,
         self.eos,
     )
     output_tokens = _fill(output_tokens, can_ins_mask, new_tokens, self.pad)
     output_scores = _fill(output_scores, can_ins_mask, new_scores, 0)
     return output_tokens, output_scores