def forward_word_ins(self, encoder_outs, output_tokens, output_scores, attn,
                     can_ins_word):
    """Ensemble word-insertion step.

    For every sentence flagged in ``can_ins_word``, runs each member
    model's word-insertion decoder, averages the per-model probabilities
    in log space, and writes the arg-max tokens/scores back into the
    full-batch buffers.

    Args:
        encoder_outs: one encoder output per member model.
        output_tokens / output_scores / attn: full-batch decoding buffers.
        can_ins_word: bool mask selecting sentences to update.

    Returns:
        Updated ``(output_tokens, output_scores, attn)``.
    """
    word_ins_score_avg = []
    word_ins_attn_avg = []
    for model, encoder_out in zip(self.models, encoder_outs):
        word_ins_out, word_ins_attn = model.decoder.forward_word_ins(
            _skip(output_tokens, can_ins_word),
            _skip(encoder_out, can_ins_word),
        )
        word_ins_score = F.log_softmax(word_ins_out, 2)
        word_ins_score_avg.append(word_ins_score)
        word_ins_attn_avg.append(word_ins_attn)
    # log of the mean probability over K models: logsumexp(...) - log(K)
    word_ins_score_avg = torch.logsumexp(
        torch.stack(word_ins_score_avg, dim=0), dim=0) - math.log(
            len(self.models))
    if word_ins_attn_avg[0] is not None:
        # BUGFIX: average over the stacked model dimension. The old code
        # divided by len(self.models) without reducing dim 0 (leaving an
        # extra leading dim) and the result was never used.
        word_ins_attn_avg = torch.stack(word_ins_attn_avg, dim=0).mean(0)
    else:
        word_ins_attn_avg = None
    word_ins_score_max, word_ins_pred = word_ins_score_avg.max(-1)
    _tokens, _scores = _apply_ins_words(
        output_tokens[can_ins_word],
        output_scores[can_ins_word],
        word_ins_pred,
        word_ins_score_max,
        self.unk,
    )
    output_tokens = _fill(output_tokens, can_ins_word, _tokens, self.pad)
    output_scores = _fill(output_scores, can_ins_word, _scores, 0)
    # BUGFIX: write back the ensemble-averaged attention, not the attention
    # of whichever model happened to run last in the loop.
    attn = _fill(attn, can_ins_word, word_ins_attn_avg, 0.)
    return output_tokens, output_scores, attn
def forward_mask_ins(self, encoder_outs, output_tokens, output_scores,
                     can_ins_mask, eos_penalty, max_lens):
    """Ensemble mask-insertion step.

    Predicts, for each sentence flagged in ``can_ins_mask``, how many
    placeholder tokens to insert at each position by averaging the member
    models' predictions in log space, then inserts the placeholders.

    Args:
        encoder_outs: one encoder output per member model.
        output_tokens / output_scores: full-batch decoding buffers.
        can_ins_mask: bool mask selecting sentences to update.
        eos_penalty: subtracted from the "insert 0" class to discourage
            early stopping (only applied when > 0).
        max_lens: per-sentence cap on the number of insertions.

    Returns:
        Updated ``(output_tokens, output_scores)``.
    """
    per_model_scores = []
    for model, encoder_out in zip(self.models, encoder_outs):
        logits, _ = model.decoder.forward_mask_ins(
            _skip(output_tokens, can_ins_mask),
            _skip(encoder_out, can_ins_mask),
        )
        log_probs = F.log_softmax(logits, 2)
        if eos_penalty > 0.0:
            # class 0 means "insert nothing"; penalize it to keep decoding
            log_probs[:, :, 0] -= eos_penalty
        per_model_scores.append(log_probs)
    # average probabilities across models in log space
    stacked = torch.stack(per_model_scores, dim=0)
    avg_scores = torch.logsumexp(stacked, dim=0) - math.log(len(self.models))
    pred = avg_scores.max(-1)[1]
    # clamp the predicted insertion counts by the per-sentence budget
    budget = max_lens[can_ins_mask, None].expand_as(pred)
    pred = torch.min(pred, budget)
    _tokens, _scores = _apply_ins_masks(
        output_tokens[can_ins_mask],
        output_scores[can_ins_mask],
        pred,
        self.pad,
        self.unk,
        self.eos,
    )
    output_tokens = _fill(output_tokens, can_ins_mask, _tokens, self.pad)
    output_scores = _fill(output_scores, can_ins_mask, _scores, 0)
    return output_tokens, output_scores
def generate(self, models, sample, prefix_tokens=None):
    """Iteratively refine a batch of hypotheses with a single model.

    Runs up to ``self.max_iter + 1`` decoder passes; a sentence is
    finalized when its output stops changing (adaptive mode) or at the
    last iteration, and finished sentences are dropped from the batch.

    Returns a list (one entry per input sentence) of single-hypothesis
    lists, each hypothesis a dict of tokens/scores/attention/alignment.
    """
    # TODO: model ensemble
    assert len(models) == 1, 'only support single model'
    model = models[0]
    if not self.retain_dropout:
        model.eval()

    # TODO: better encoder inputs?
    src_tokens = sample['net_input']['src_tokens']
    src_lengths = sample['net_input']['src_lengths']
    bsz, src_len = src_tokens.size()
    # maps rows of the (shrinking) working batch back to original indices
    sent_idxs = torch.arange(bsz, device=src_tokens.device)

    # encoding
    encoder_out = model.forward_encoder([src_tokens, src_lengths])

    # initialize buffers (very model specific, with length prediction or not)
    prev_decoder_out = model.initialize_output_tokens(
        encoder_out, src_tokens)
    prev_out_tokens = prev_decoder_out['output_tokens'].clone()

    finalized = [[] for _ in range(bsz)]

    def is_a_loop(x, y, s, a):
        # Detect sentences whose tokens did not change between iterations.
        # Pads the shorter of x/y (and the matching scores/attn) so the
        # element-wise comparison is well defined.
        b, l_x, l_y = x.size(0), x.size(1), y.size(1)
        if l_x > l_y:
            y = torch.cat([y, x.new_zeros(b, l_x - l_y).fill_(self.pad)], 1)
            s = torch.cat([s, s.new_zeros(b, l_x - l_y)], 1)
            if a is not None:
                a = torch.cat([a, a.new_zeros(b, l_x - l_y, a.size(2))], 1)
        elif l_x < l_y:
            x = torch.cat([x, y.new_zeros(b, l_y - l_x).fill_(self.pad)], 1)
        return (x == y).all(1), y, s, a

    def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
        # Strip padding and package one finished hypothesis as a dict.
        cutoff = prev_out_token.ne(self.pad)
        tokens = prev_out_token[cutoff]
        scores = prev_out_score[cutoff]
        if prev_out_attn is None:
            hypo_attn, alignment = None, None
        else:
            hypo_attn = prev_out_attn[cutoff]
            # hard alignment: arg-max source position per target token
            alignment = hypo_attn.max(dim=1)[1]
        return {
            'steps': step,
            'tokens': tokens,
            'positional_scores': scores,
            'score': scores.mean(),
            'hypo_attn': hypo_attn,
            'alignment': alignment,
        }

    for step in range(self.max_iter + 1):
        decoder_options = {
            'eos_penalty': self.eos_penalty,
            'max_ratio': self.max_ratio,
            'decoding_format': self.decoding_format
        }
        prev_decoder_out['step'] = step
        prev_decoder_out['max_step'] = self.max_iter + 1

        decoder_out = model.forward_decoder(
            prev_decoder_out, encoder_out, **decoder_options)

        if self.adaptive:
            # terminate if there is a loop
            terminated, out_tokens, out_scores, out_attn = is_a_loop(
                prev_out_tokens, decoder_out['output_tokens'],
                decoder_out['output_scores'], decoder_out['attn'])
            decoder_out['output_tokens'] = out_tokens
            decoder_out['output_scores'] = out_scores
            decoder_out['attn'] = out_attn
        else:
            terminated = decoder_out['output_tokens'].new_zeros(
                decoder_out['output_tokens'].size(0)).bool()

        if step == self.max_iter:  # reach last iteration, terminate
            terminated.fill_(1)

        # collect finalized sentences
        finalized_idxs = sent_idxs[terminated]
        finalized_tokens = decoder_out['output_tokens'][terminated]
        finalized_scores = decoder_out['output_scores'][terminated]
        finalized_attn = None if decoder_out[
            'attn'] is None else decoder_out['attn'][terminated]

        for i in range(finalized_idxs.size(0)):
            finalized[finalized_idxs[i]] = [
                finalized_hypos(
                    step, finalized_tokens[i], finalized_scores[i],
                    None if finalized_attn is None else finalized_attn[i])
            ]

        # check if all terminated
        if terminated.sum() == terminated.size(0):
            break

        # for next step: drop finished sentences from every buffer
        prev_decoder_out = _skip(decoder_out, ~terminated)
        encoder_out = _skip(encoder_out, ~terminated)
        sent_idxs = _skip(sent_idxs, ~terminated)
        prev_out_tokens = prev_decoder_out['output_tokens'].clone()

    return finalized
def generate(self, models, sample, prefix_tokens=None):
    """Iteratively refine a batch of hypotheses (ensemble-capable variant).

    Wraps multiple ``LevenshteinTransformerModel`` instances in
    ``EnsembleLevT``; other NAT models are restricted to a single model.
    Decoder state is an indexed structure here:
    [0]=tokens, [1]=scores, [2]=attn, [3]=step, [4]=max_step.

    Returns a list (one entry per input sentence) of single-hypothesis
    lists, each hypothesis a dict of tokens/scores/attention/alignment.
    """
    if len(models) == 1:
        # Keep this for other NAT models for which we have yet to implement
        # ensemble wrappers. Later delete this.
        model = models[0]
    elif isinstance(models[0], LevenshteinTransformerModel):
        model = EnsembleLevT(models)
    else:
        raise NotImplementedError
    if not self.retain_dropout:
        model.eval()

    # TODO: better encoder inputs?
    src_tokens = sample["net_input"]["src_tokens"]
    src_lengths = sample["net_input"]["src_lengths"]
    bsz, src_len = src_tokens.size()
    # BUGFIX: create on the same device as the batch; `terminated` below is
    # a device tensor and indexing a CPU tensor with it fails on GPU.
    # (Also matches the other generate() variants in this file.)
    sent_idxs = torch.arange(bsz, device=src_tokens.device)

    # encoding
    encoder_out = model.forward_encoder([src_tokens, src_lengths])

    # initialize buffers (very model specific, with length prediction or not)
    prev_decoder_out = model.initialize_output_tokens(
        encoder_out, src_tokens)
    prev_output_tokens = prev_decoder_out[0].clone()

    finalized = [[] for _ in range(bsz)]

    def is_a_loop(x, y, s, a):
        # True per sentence when tokens did not change between iterations;
        # pads the shorter side so the comparison is well defined.
        b, l_x, l_y = x.size(0), x.size(1), y.size(1)
        if l_x > l_y:
            y = torch.cat([y, x.new_zeros(b, l_x - l_y).fill_(self.pad)], 1)
            s = torch.cat([s, s.new_zeros(b, l_x - l_y)], 1)
            if a is not None:
                a = torch.cat([a, a.new_zeros(b, l_x - l_y, a.size(2))], 1)
        elif l_x < l_y:
            x = torch.cat([x, y.new_zeros(b, l_y - l_x).fill_(self.pad)], 1)
        return (x == y).all(1), y, s, a

    def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
        # Strip padding and package one finished hypothesis as a dict.
        cutoff = prev_out_token.ne(self.pad)
        tokens = prev_out_token[cutoff]
        scores = prev_out_score[cutoff]
        if prev_out_attn is None:
            hypo_attn, alignment = None, None
        else:
            hypo_attn = prev_out_attn[cutoff]
            # hard alignment: arg-max source position per target token
            alignment = hypo_attn.max(dim=1)[1]
        return {
            "steps": step,
            "tokens": tokens,
            "positional_scores": scores,
            "score": scores.mean(),
            "hypo_attn": hypo_attn,
            "alignment": alignment,
        }

    for step in range(self.max_iter + 1):
        decoder_options = {
            "eos_penalty": self.eos_penalty,
            "max_ratio": self.max_ratio,
            "decoding_format": self.decoding_format,
        }
        prev_decoder_out[3] = step
        prev_decoder_out[4] = self.max_iter + 1

        decoder_out = model.forward_decoder(
            prev_decoder_out, encoder_out, **decoder_options)

        if self.adaptive:
            # terminate if there is a loop
            terminated, out_tokens, out_scores, out_attn = is_a_loop(
                prev_output_tokens, decoder_out[0], decoder_out[1],
                decoder_out[2])
            decoder_out[0] = out_tokens
            decoder_out[1] = out_scores
            decoder_out[2] = out_attn
        else:
            terminated = decoder_out[0].new_zeros(
                decoder_out[0].size(0)).bool()

        if step == self.max_iter:  # reach last iteration, terminate
            terminated.fill_(1)

        # collect finalized sentences
        finalized_idxs = sent_idxs[terminated]
        finalized_tokens = decoder_out[0][terminated]
        finalized_scores = decoder_out[1][terminated]
        finalized_attn = (None
                          if decoder_out[2] is None
                          else decoder_out[2][terminated])

        for i in range(finalized_idxs.size(0)):
            finalized[finalized_idxs[i]] = [
                finalized_hypos(
                    step,
                    finalized_tokens[i],
                    finalized_scores[i],
                    None if finalized_attn is None else finalized_attn[i],
                )
            ]

        # check if all terminated
        if terminated.sum() == terminated.size(0):
            break

        # for next step: drop finished sentences from every buffer
        prev_decoder_out = _skip(decoder_out, ~terminated)
        encoder_out = script_skip_tensor_list(encoder_out, ~terminated)
        sent_idxs = _skip(sent_idxs, ~terminated)
        prev_output_tokens = prev_decoder_out[0].clone()

    return finalized
def forward_decoder(self, decoder_out, encoder_out, eos_penalty=0.0,
                    max_ratio=None, **kwargs):
    """One Levenshtein refinement pass: delete words, insert placeholders,
    then fill the placeholders, updating the decoding buffers in place.

    Args:
        decoder_out: dict with "output_tokens", "output_scores", "attn".
        encoder_out: encoder output (dict with "encoder_padding_mask").
        eos_penalty: subtracted from the "insert 0" class to discourage
            early stopping (only applied when > 0).
        max_ratio: if set, caps target length at max_ratio * source length
            (min 10); otherwise a fixed cap of 255 is used.

    Returns:
        dict with updated "output_tokens", "output_scores", "attn".
    """
    output_tokens = decoder_out["output_tokens"]
    output_scores = decoder_out["output_scores"]
    attn = decoder_out["attn"]
    if max_ratio is None:
        max_lens = output_tokens.new(output_tokens.size(0)).fill_(255)
    else:
        max_lens = ((~encoder_out["encoder_padding_mask"]).sum(1) *
                    max_ratio).clamp(min=10)

    # delete words
    # do not delete tokens if it is <s> </s>
    can_del_word = output_tokens.ne(self.pad).sum(1) > 2
    if can_del_word.sum() != 0:  # we cannot delete, skip
        word_del_out, word_del_attn = self.decoder.forward_word_del(
            _skip(output_tokens, can_del_word),
            _skip(encoder_out, can_del_word))
        word_del_score = F.log_softmax(word_del_out, 2)
        word_del_pred = word_del_score.max(-1)[1].bool()

        _tokens, _scores, _attn = _apply_del_words(
            output_tokens[can_del_word],
            output_scores[can_del_word],
            word_del_attn,
            word_del_pred,
            self.pad,
            self.bos,
            self.eos,
        )
        output_tokens = _fill(output_tokens, can_del_word, _tokens, self.pad)
        output_scores = _fill(output_scores, can_del_word, _scores, 0)
        attn = _fill(attn, can_del_word, _attn, 0.)

    # insert placeholders
    can_ins_mask = output_tokens.ne(self.pad).sum(1) < max_lens
    if can_ins_mask.sum() != 0:
        mask_ins_out, _ = self.decoder.forward_mask_ins(
            _skip(output_tokens, can_ins_mask),
            _skip(encoder_out, can_ins_mask))
        mask_ins_score = F.log_softmax(mask_ins_out, 2)
        if eos_penalty > 0.0:
            # class 0 means "insert nothing"; penalize to keep decoding
            mask_ins_score[:, :, 0] -= eos_penalty
        mask_ins_pred = mask_ins_score.max(-1)[1]
        # clamp insertion counts by the per-sentence length budget
        mask_ins_pred = torch.min(
            mask_ins_pred, max_lens[:, None].expand_as(mask_ins_pred))

        _tokens, _scores = _apply_ins_masks(
            output_tokens[can_ins_mask],
            output_scores[can_ins_mask],
            mask_ins_pred,
            self.pad,
            self.unk,
            self.eos,
        )
        output_tokens = _fill(output_tokens, can_ins_mask, _tokens, self.pad)
        output_scores = _fill(output_scores, can_ins_mask, _scores, 0)

    # insert words
    can_ins_word = output_tokens.eq(self.unk).sum(1) > 0
    if can_ins_word.sum() != 0:
        word_ins_out, word_ins_attn = self.decoder.forward_word_ins(
            _skip(output_tokens, can_ins_word),
            _skip(encoder_out, can_ins_word))
        word_ins_score = F.log_softmax(word_ins_out, 2)
        # BUGFIX: _apply_ins_words expects the per-position best score
        # (B x T), not the full vocab-sized log-prob tensor (B x T x V).
        # The ensemble version of this step passes word_ins_score_max too.
        word_ins_score_max, word_ins_pred = word_ins_score.max(-1)

        _tokens, _scores = _apply_ins_words(
            output_tokens[can_ins_word],
            output_scores[can_ins_word],
            word_ins_pred,
            word_ins_score_max,
            self.unk,
        )
        output_tokens = _fill(output_tokens, can_ins_word, _tokens, self.pad)
        output_scores = _fill(output_scores, can_ins_word, _scores, 0)
        attn = _fill(attn, can_ins_word, word_ins_attn, 0.)

    # delete some unnecessary paddings
    cut_off = output_tokens.ne(self.pad).sum(1).max()
    output_tokens = output_tokens[:, :cut_off]
    output_scores = output_scores[:, :cut_off]
    attn = None if attn is None else attn[:, :cut_off, :]
    return {
        "output_tokens": output_tokens,
        "output_scores": output_scores,
        "attn": attn,
    }
def generate(self, models, sample, prefix_tokens=None):
    """Iteratively refine hypotheses with optional lexical constraints.

    Like the other generate() variants, but threads constraint tensors
    ("const_del" / "const_ins") through loop detection and finalization,
    and can seed decoding from target-side initial tokens supplied in
    the sample.
    """
    if len(models) == 1:
        # Keep this for other NAT models for which we have yet to implement
        # ensemble wrappers. Later delete this.
        model = models[0]
    elif isinstance(models[0], LevenshteinTransformerModel):
        model = EnsembleLevT(models)
    else:
        raise NotImplementedError
    if not self.retain_dropout:
        model.eval()

    # TODO: better encoder inputs?
    src_tokens = sample['net_input']['src_tokens']
    src_lengths = sample['net_input']['src_lengths']
    # optional target-side seed tokens (constraint decoding)
    tgt_init_tokens = None
    tgt_init_lengths = None
    if 'tgt_init_tokens' in sample['net_input']:
        tgt_init_tokens = sample['net_input']['tgt_init_tokens']
        tgt_init_lengths = sample['net_input']['tgt_init_lengths']
    bsz, src_len = src_tokens.size()
    # maps rows of the (shrinking) working batch back to original indices
    sent_idxs = torch.arange(bsz, device=src_tokens.device)

    # encoding
    encoder_out = model.forward_encoder([src_tokens, src_lengths])

    # initialize buffers (very model specific, with length prediction or not)
    prev_decoder_out = model.initialize_output_tokens(
        encoder_out, src_tokens, tgt_init_tokens, tgt_init_lengths)
    prev_out_tokens = prev_decoder_out['output_tokens'].clone()

    finalized = [[] for _ in range(bsz)]

    def is_a_loop(x, y, s, a, c, d):
        # True per sentence when tokens did not change between iterations;
        # pads the shorter side (scores, attn, and both constraint tensors
        # c/d alongside) so the comparison is well defined.
        b, l_x, l_y = x.size(0), x.size(1), y.size(1)
        if l_x > l_y:
            y = torch.cat([y, x.new_zeros(b, l_x - l_y).fill_(self.pad)], 1)
            s = torch.cat([s, s.new_zeros(b, l_x - l_y)], 1)
            if a is not None:
                a = torch.cat([a, a.new_zeros(b, l_x - l_y, a.size(2))], 1)
            if c is not None:
                c = torch.cat([c, c.new_zeros(b, l_x - l_y)], 1)
            if d is not None:
                d = torch.cat([d, d.new_zeros(b, l_x - l_y)], 1)
        elif l_x < l_y:
            x = torch.cat([x, y.new_zeros(b, l_y - l_x).fill_(self.pad)], 1)
        return (x == y).all(1), y, s, a, c, d

    def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn,
                        prev_out_const_del, prev_out_const_ins, src_tokens):
        # Strip padding and package one finished hypothesis as a dict,
        # carrying the constraint tensors along when present.
        cutoff = prev_out_token.ne(self.pad)
        tokens = prev_out_token[cutoff]
        scores = prev_out_score[cutoff]
        const_del = None
        if prev_out_const_del is not None:
            const_del = prev_out_const_del[cutoff]
        const_ins = None
        if prev_out_const_ins is not None:
            const_ins = prev_out_const_ins[cutoff]
        if prev_out_attn is None:
            hypo_attn, alignment = None, None
        else:
            hypo_attn = prev_out_attn[cutoff]
            alignment = utils.extract_hard_alignment(
                hypo_attn, src_tokens, tokens, self.pad, self.eos)
        return {
            'steps': step,
            'tokens': tokens,
            'positional_scores': scores,
            'score': scores.mean(),
            'hypo_attn': hypo_attn,
            'alignment': alignment,
            'const_del': const_del,
            'const_ins': const_ins,
        }

    for step in range(self.max_iter + 1):
        decoder_options = {
            'eos_penalty': self.eos_penalty,
            'max_ratio': self.max_ratio,
            'decoding_format': self.decoding_format,
            'preserve_constraint': self.preserve_constraint,
            'allow_insertion_constraint': self.allow_insertion_constraint,
        }
        prev_decoder_out['step'] = step
        prev_decoder_out['max_step'] = self.max_iter + 1

        decoder_out = model.forward_decoder(
            prev_decoder_out, encoder_out, **decoder_options)

        if self.adaptive:
            # terminate if there is a loop
            terminated, out_tokens, out_scores, out_attn, out_const_del, out_const_ins = is_a_loop(
                prev_out_tokens, decoder_out['output_tokens'],
                decoder_out['output_scores'], decoder_out['attn'],
                decoder_out['const_del'], decoder_out['const_ins'])
            decoder_out['output_tokens'] = out_tokens
            decoder_out['output_scores'] = out_scores
            decoder_out['attn'] = out_attn
            decoder_out['const_del'] = out_const_del
            decoder_out['const_ins'] = out_const_ins
        else:
            terminated = decoder_out['output_tokens'].new_zeros(
                decoder_out['output_tokens'].size(0)).bool()

        if step == self.max_iter:  # reach last iteration, terminate
            terminated.fill_(1)

        # collect finalized sentences
        finalized_idxs = sent_idxs[terminated]
        finalized_tokens = decoder_out['output_tokens'][terminated]
        finalized_scores = decoder_out['output_scores'][terminated]
        finalized_attn = None if decoder_out[
            'attn'] is None else decoder_out['attn'][terminated]
        finalized_const_del = None if decoder_out[
            'const_del'] is None else decoder_out['const_del'][terminated]
        finalized_const_ins = None if decoder_out[
            'const_ins'] is None else decoder_out['const_ins'][terminated]

        for i in range(finalized_idxs.size(0)):
            finalized[finalized_idxs[i]] = [
                finalized_hypos(
                    step, finalized_tokens[i], finalized_scores[i],
                    None if finalized_attn is None else finalized_attn[i],
                    None if finalized_const_del is None else
                    finalized_const_del[i],
                    None if finalized_const_ins is None else
                    finalized_const_ins[i],
                    src_tokens[finalized_idxs[i]])
            ]

        # check if all terminated
        if terminated.sum() == terminated.size(0):
            break

        # for next step: drop finished sentences from every buffer
        prev_decoder_out = _skip(decoder_out, ~terminated)
        encoder_out = _skip(encoder_out, ~terminated)
        sent_idxs = _skip(sent_idxs, ~terminated)
        prev_out_tokens = prev_decoder_out['output_tokens'].clone()

    return finalized