def forward_word_ins(self, encoder_outs, output_tokens, output_scores, attn, can_ins_word):
    """Run the ensemble's word-insertion step on the rows selected by ``can_ins_word``.

    Each model in ``self.models`` scores the selected rows with its
    ``decoder.forward_word_ins``. The per-model log-softmax outputs (over
    dim 2) are averaged in probability space. The maximum entry and its
    score at each position are handed to ``_apply_ins_words``. Tokens,
    scores and attention are then merged back into the full tensors
    with ``_fill``.

    Args:
        encoder_outs: One encoder output per model, paired with
            ``self.models`` by position.
        output_tokens: Current token tensor; rows are selected with
            ``can_ins_word``.
        output_scores: Current score tensor, indexed like ``output_tokens``.
        attn: Current attention tensor (possibly ``None``) that is updated
            for the selected rows.
        can_ins_word: Row selector used to index ``output_tokens`` /
            ``output_scores`` (presumably a boolean mask over the batch
            dimension -- TODO confirm against callers).

    Returns:
        Tuple ``(output_tokens, output_scores, attn)`` after the insertion.
    """
    num_models = len(self.models)
    word_ins_scores = []
    word_ins_attns = []
    for model, encoder_out in zip(self.models, encoder_outs):
        word_ins_out, word_ins_attn = model.decoder.forward_word_ins(
            _skip(output_tokens, can_ins_word),
            _skip_encoder_out(model.encoder, encoder_out, can_ins_word),
        )
        word_ins_scores.append(F.log_softmax(word_ins_out, 2))
        word_ins_attns.append(word_ins_attn)

    # log(mean_m p_m), computed stably as logsumexp_m(log p_m) - log(M).
    word_ins_score_avg = torch.logsumexp(
        torch.stack(word_ins_scores, dim=0), dim=0
    ) - math.log(num_models)

    # Reduce over the stacked model axis (dim 0) so the averaged attention
    # has the same shape as a single model's attention.
    if all(model_attn is not None for model_attn in word_ins_attns):
        word_ins_attn_avg = torch.stack(word_ins_attns, dim=0).mean(dim=0)
    else:
        # At least one model returned no attention; there is nothing
        # consistent to average.
        word_ins_attn_avg = None

    word_ins_score_max, word_ins_pred = word_ins_score_avg.max(-1)

    _tokens, _scores = _apply_ins_words(
        output_tokens[can_ins_word],
        output_scores[can_ins_word],
        word_ins_pred,
        word_ins_score_max,
        self.unk,
    )

    output_tokens = _fill(output_tokens, can_ins_word, _tokens, self.pad)
    output_scores = _fill(output_scores, can_ins_word, _scores, 0)
    # Merge the ensemble-averaged attention, not any single model's.
    attn = _fill(attn, can_ins_word, word_ins_attn_avg, 0.)
    return output_tokens, output_scores, attn
def forward_mask_ins(self, encoder_outs, output_tokens, output_scores, can_ins_mask, eos_penalty, max_lens):
    """Run the ensemble's placeholder-insertion step on rows selected by ``can_ins_mask``.

    Each model's ``decoder.forward_mask_ins`` output is converted to
    log-probabilities over dim 2, optionally penalized at index 0, and
    averaged across models in probability space. The per-position argmax
    is clamped elementwise by ``max_lens`` for the selected rows. The result
    is passed to ``_apply_ins_masks``, and the new tokens/scores are merged
    back into the full tensors with ``_fill``.

    Args:
        encoder_outs: One encoder output per model, paired with
            ``self.models`` by position.
        output_tokens: Current token tensor; rows are selected with
            ``can_ins_mask``.
        output_scores: Current score tensor, indexed like ``output_tokens``.
        can_ins_mask: Row selector used to index the token/score tensors
            and ``max_lens`` (presumably a boolean mask over the batch --
            TODO confirm).
        eos_penalty: When positive, subtracted from entry 0 of dim 2 of
            every model's log-probabilities before averaging.
        max_lens: Per-row upper bound applied to the predictions.

    Returns:
        Tuple ``(output_tokens, output_scores)`` after the insertion.
    """

    def _model_log_probs(model, encoder_out):
        # Score the selected rows with a single ensemble member.
        logits, _ = model.decoder.forward_mask_ins(
            _skip(output_tokens, can_ins_mask),
            _skip_encoder_out(model.encoder, encoder_out, can_ins_mask),
        )
        log_probs = F.log_softmax(logits, 2)
        if eos_penalty > 0.0:
            # Lowers the score of index 0 so the argmax picks it less often.
            log_probs[:, :, 0] -= eos_penalty
        return log_probs

    stacked = torch.stack(
        [_model_log_probs(member, enc) for member, enc in zip(self.models, encoder_outs)],
        dim=0,
    )
    # Average in probability space: logsumexp over models minus log(M).
    ensemble_log_probs = torch.logsumexp(stacked, dim=0) - math.log(len(self.models))
    _, ins_pred = ensemble_log_probs.max(-1)

    # Broadcast each selected row's limit along dim 1 and clamp.
    row_limits = max_lens[can_ins_mask].unsqueeze(1).expand_as(ins_pred)
    ins_pred = torch.min(ins_pred, row_limits)

    new_tokens, new_scores = _apply_ins_masks(
        output_tokens[can_ins_mask],
        output_scores[can_ins_mask],
        ins_pred,
        self.pad,
        self.unk,
        self.eos,
    )
    return (
        _fill(output_tokens, can_ins_mask, new_tokens, self.pad),
        _fill(output_scores, can_ins_mask, new_scores, 0),
    )