def rerank(self, reranker, finalized, encoder_input, beam_size):
    # Rescore the finalized length-beam hypotheses with an autoregressive reranker.

    def rebuild_batch(finalized):
        finalized_tokens = [f[0]['tokens'] for f in finalized]
        finalized_maxlen = max(f.size(0) for f in finalized_tokens)
        final_output_tokens = finalized_tokens[0].new_zeros(
            len(finalized_tokens), finalized_maxlen).fill_(self.pad)
        for i, f in enumerate(finalized_tokens):
            final_output_tokens[i, :f.size(0)] = f
        return final_output_tokens

    final_output_tokens = rebuild_batch(finalized)
    final_output_tokens[:, 0] = self.eos  # autoregressive model assumes starting with EOS

    reranker_encoder_out = reranker.encoder(*encoder_input)
    length_beam_order = utils.new_arange(
        final_output_tokens, beam_size,
        reranker_encoder_out.encoder_out.size(1)).t().reshape(-1)
    reranker_encoder_out = reranker.encoder.reorder_encoder_out(
        reranker_encoder_out, length_beam_order)
    reranking_scores = reranker.get_normalized_probs(
        reranker.decoder(final_output_tokens[:, :-1], reranker_encoder_out),
        True, None)
    reranking_scores = reranking_scores.gather(2, final_output_tokens[:, 1:, None])
    reranking_masks = final_output_tokens[:, 1:].ne(self.pad)

    # length-normalized log-probability over the non-padding target positions
    reranking_scores = reranking_scores[:, :, 0].masked_fill_(
        ~reranking_masks, 0).sum(1)
    reranking_scores = reranking_scores / reranking_masks.sum(1).type_as(
        reranking_scores)

    for i in range(len(finalized)):
        finalized[i][0]['score'] = reranking_scores[i]

    return finalized
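# A minimal standalone sketch (illustrative name, toy tensors; not part of the
# original code) of the scoring rule rerank applies once the reranker's
# log-probabilities are aligned with the target tokens: the average token
# log-probability over non-pad positions.
def _example_rerank_score(lprobs, tokens, pad=1):
    # lprobs: (batch, tgt_len, vocab) log-probabilities; tokens: (batch, tgt_len)
    scores = lprobs.gather(2, tokens[:, :, None])[:, :, 0]
    mask = tokens.ne(pad)
    return scores.masked_fill(~mask, 0).sum(1) / mask.sum(1).type_as(scores)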
def _apply_ins_masks(
    in_tokens, in_scores, mask_ins_pred, padding_idx, unk_idx, eos_idx
):
    # Insert mask_ins_pred[b, i] unk placeholders into the i-th gap of each
    # sequence, re-padding the batch to the new maximum length.
    in_masks = in_tokens.ne(padding_idx)
    in_lengths = in_masks.sum(1)

    # HACK: hacky way to shift all the paddings to eos first.
    in_tokens.masked_fill_(~in_masks, eos_idx)
    mask_ins_pred.masked_fill_(~in_masks[:, 1:], 0)

    out_lengths = in_lengths + mask_ins_pred.sum(1)
    out_max_len = out_lengths.max()
    out_masks = (new_arange(out_lengths, out_max_len)[None, :]
                 < out_lengths[:, None])

    # position of each original token (except the first) after the insertions
    reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1)
    out_tokens = (
        in_tokens.new_zeros(in_tokens.size(0), out_max_len)
        .fill_(padding_idx)
        .masked_fill_(out_masks, unk_idx)
    )
    out_tokens[:, 0] = in_tokens[:, 0]
    out_tokens.scatter_(1, reordering, in_tokens[:, 1:])

    out_scores = None
    if in_scores is not None:
        in_scores.masked_fill_(~in_masks, 0)
        out_scores = in_scores.new_zeros(*out_tokens.size())
        out_scores[:, 0] = in_scores[:, 0]
        out_scores.scatter_(1, reordering, in_scores[:, 1:])

    return out_tokens, out_scores
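# Toy check of _apply_ins_masks, assuming the helper above (and fairseq's
# new_arange) is in scope; token ids bos=0, pad=1, eos=2, unk=3 are made up.
# Predicting [2, 0, 1] placeholder insertions for the three gaps of
# [bos, 7, 8, eos] interleaves unk placeholders at the requested slots.
def _example_apply_ins_masks():
    in_tokens = torch.tensor([[0, 7, 8, 2]])
    mask_ins_pred = torch.tensor([[2, 0, 1]])
    out_tokens, _ = _apply_ins_masks(
        in_tokens, None, mask_ins_pred, padding_idx=1, unk_idx=3, eos_idx=2)
    return out_tokens  # tensor([[0, 3, 3, 7, 8, 3, 2]]) == [bos, unk, unk, 7, 8, unk, eos]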
def _apply_del_words(
    in_tokens, in_scores, in_attn, word_del_pred, padding_idx, bos_idx, eos_idx
):
    # apply deletion to a tensor
    in_masks = in_tokens.ne(padding_idx)
    bos_eos_masks = in_tokens.eq(bos_idx) | in_tokens.eq(eos_idx)

    max_len = in_tokens.size(1)
    word_del_pred.masked_fill_(~in_masks, 1)      # always "delete" padding
    word_del_pred.masked_fill_(bos_eos_masks, 0)  # never delete bos / eos

    # sort so that surviving tokens are left-compacted and deleted ones move to the tail
    reordering = (new_arange(in_tokens)
                  .masked_fill_(word_del_pred, max_len)
                  .sort(1)[1])

    out_tokens = in_tokens.masked_fill(word_del_pred, padding_idx).gather(1, reordering)

    out_scores = None
    if in_scores is not None:
        out_scores = in_scores.masked_fill(word_del_pred, 0).gather(1, reordering)

    out_attn = None
    if in_attn is not None:
        _mask = word_del_pred[:, :, None].expand_as(in_attn)
        _reordering = reordering[:, :, None].expand_as(in_attn)
        out_attn = in_attn.masked_fill(_mask, 0.).gather(1, _reordering)

    return out_tokens, out_scores, out_attn
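# Toy check of _apply_del_words, assuming the helper above and new_arange are
# in scope; ids bos=0, pad=1, eos=2 are made up. Deleting the third token of
# [bos, 7, 8, 9, eos] left-compacts the survivors and pads the tail.
def _example_apply_del_words():
    in_tokens = torch.tensor([[0, 7, 8, 9, 2]])
    word_del_pred = torch.tensor([[False, False, True, False, False]])
    out_tokens, _, _ = _apply_del_words(
        in_tokens, None, None, word_del_pred, padding_idx=1, bos_idx=0, eos_idx=2)
    return out_tokens  # tensor([[0, 7, 9, 2, 1]]) == [bos, 7, 9, eos, pad]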
def _skeptical_unmasking(output_scores, output_masks, p):
    # Re-mask the fraction p of lowest-scoring (least confident) positions;
    # the "- 2" keeps the bos/eos slots out of the re-masking budget.
    sorted_index = output_scores.sort(-1)[1]
    boundary_len = (
        (output_masks.sum(1, keepdim=True).type_as(output_scores) - 2) * p
    ).long()
    skeptical_mask = new_arange(output_masks) < boundary_len
    return skeptical_mask.scatter(1, sorted_index, skeptical_mask)
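# Toy check of _skeptical_unmasking with p=0.5, assuming the helper above and
# new_arange are in scope: of the 5 positions (two of which are notionally
# bos/eos), (5 - 2) * 0.5 -> 1 position is re-masked, and it is the
# lowest-scoring one (index 1 here).
def _example_skeptical_unmasking():
    scores = torch.tensor([[0.9, 0.1, 0.5, 0.7, 0.3]])
    masks = torch.ones_like(scores, dtype=torch.bool)
    # -> tensor([[False, True, False, False, False]])
    return _skeptical_unmasking(scores, masks, p=0.5)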
def initialize_output_tokens(self, encoder_out, src_tokens):
    # length prediction
    length_tgt = self.decoder.forward_length_prediction(
        self.decoder.forward_length(normalize=True, encoder_out=encoder_out),
        encoder_out=encoder_out)

    # build the initial [bos, unk, ..., unk, eos] canvas for each predicted length
    max_length = length_tgt.clamp_(min=2).max()
    idx_length = utils.new_arange(src_tokens, max_length)

    initial_output_tokens = src_tokens.new_zeros(
        src_tokens.size(0), max_length).fill_(self.pad)
    initial_output_tokens.masked_fill_(
        idx_length[None, :] < length_tgt[:, None], self.unk)
    initial_output_tokens[:, 0] = self.bos
    initial_output_tokens.scatter_(1, length_tgt[:, None] - 1, self.eos)

    initial_output_scores = initial_output_tokens.new_zeros(
        *initial_output_tokens.size()).type_as(encoder_out.encoder_out)

    return DecoderOut(
        output_tokens=initial_output_tokens,
        output_scores=initial_output_scores,
        attn=None,
        step=0,
        max_step=0,
        history=None)
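# Standalone sketch (not fairseq API; ids bos=0, pad=1, eos=2, unk=3 are made
# up) that rebuilds the same [bos, unk, ..., unk, eos] canvas as
# initialize_output_tokens from a tensor of predicted lengths, so the initial
# decoder input is easy to inspect in isolation.
def _example_initial_canvas(length_tgt, bos=0, pad=1, eos=2, unk=3):
    length_tgt = length_tgt.clamp(min=2)
    max_length = int(length_tgt.max())
    idx_length = torch.arange(max_length, device=length_tgt.device)
    canvas = length_tgt.new_full((length_tgt.size(0), max_length), pad)
    canvas.masked_fill_(idx_length[None, :] < length_tgt[:, None], unk)
    canvas[:, 0] = bos
    canvas.scatter_(1, length_tgt[:, None] - 1, eos)
    # e.g. tensor([4, 6]) -> [[0, 3, 3, 2, 1, 1], [0, 3, 3, 3, 3, 2]]
    return canvas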
def _uniform_assignment(src_lens, trg_lens):
    max_trg_len = trg_lens.max()
    steps = (src_lens.float() - 1) / (trg_lens.float() - 1)  # step size
    # positions 0 .. max_trg_len - 1
    index_t = utils.new_arange(trg_lens, max_trg_len).float()
    index_t = steps[:, None] * index_t[None, :]  # batch_size X max_trg_len
    index_t = torch.round(index_t).long().detach()
    return index_t
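# Toy check of _uniform_assignment, assuming the helper above and fairseq's
# utils are in scope: with a 5-token source and a 3-token target, target
# positions map uniformly onto source positions [0, 2, 4] (endpoints included),
# which is how NAT decoder inputs copy source embeddings.
def _example_uniform_assignment():
    src_lens = torch.tensor([5])
    trg_lens = torch.tensor([3])
    return _uniform_assignment(src_lens, trg_lens)  # tensor([[0, 2, 4]])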
def regenerate_length_beam(self, decoder_out, beam_size):
    # Expand each hypothesis into beam_size fresh canvases whose lengths are
    # centered on the current hypothesis length.
    output_tokens = decoder_out.output_tokens
    length_tgt = output_tokens.ne(self.pad).sum(1)
    length_tgt = length_tgt[:, None] + utils.new_arange(
        length_tgt, 1, beam_size) - beam_size // 2
    length_tgt = length_tgt.view(-1).clamp_(min=2)
    max_length = length_tgt.max()
    idx_length = utils.new_arange(length_tgt, max_length)

    initial_output_tokens = output_tokens.new_zeros(
        length_tgt.size(0), max_length).fill_(self.pad)
    initial_output_tokens.masked_fill_(
        idx_length[None, :] < length_tgt[:, None], self.unk)
    initial_output_tokens[:, 0] = self.bos
    initial_output_tokens.scatter_(1, length_tgt[:, None] - 1, self.eos)

    initial_output_scores = initial_output_tokens.new_zeros(
        *initial_output_tokens.size()).type_as(decoder_out.output_scores)

    return decoder_out._replace(
        output_tokens=initial_output_tokens,
        output_scores=initial_output_scores)
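# Toy check of the candidate-length arithmetic used above (plain torch, no
# model needed): a hypothesis of length 5 with beam_size=3 expands to candidate
# lengths [4, 5, 6], centered on the original length and clamped to at least 2.
def _example_length_beam(length=5, beam_size=3):
    length_tgt = torch.tensor([length])
    cands = length_tgt[:, None] + torch.arange(beam_size)[None, :] - beam_size // 2
    return cands.view(-1).clamp(min=2)  # tensor([4, 5, 6])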
def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, padding_idx):
    padding_masks = in_tokens[:, 1:].eq(padding_idx)
    word_ins_scores.masked_fill_(padding_masks, 0.0)
    word_ins_pred.masked_fill_(padding_masks, padding_idx)

    in_coords = new_arange(in_tokens).type_as(in_scores)

    # shift all padding predictions to infinite
    out_coords = (in_coords[:, 1:] - 0.5).masked_fill(
        word_ins_pred.eq(padding_idx), float("inf"))
    out_coords = torch.cat([in_coords, out_coords], 1).sort(-1)[1]
    out_tokens = torch.cat([in_tokens, word_ins_pred], 1).gather(1, out_coords)
    out_scores = torch.cat([in_scores, word_ins_scores], 1).gather(1, out_coords)
    return out_tokens, out_scores
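# Toy check of _apply_ins_words, assuming the helper above and new_arange are
# in scope; ids bos=0, pad=1, eos=2 are made up. With current hypothesis
# [bos, 7, eos] and slot predictions [11, 12] for the two gaps, the predicted
# words are interleaved between the existing tokens.
def _example_apply_ins_words():
    in_tokens = torch.tensor([[0, 7, 2]])
    in_scores = torch.zeros(1, 3)
    word_ins_pred = torch.tensor([[11, 12]])
    word_ins_scores = torch.tensor([[0.5, 0.6]])
    out_tokens, _ = _apply_ins_words(
        in_tokens, in_scores, word_ins_pred, word_ins_scores, padding_idx=1)
    return out_tokens  # tensor([[0, 11, 7, 12, 2]]) == [bos, 11, 7, 12, eos]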
def _random_mask(target_tokens):
    # Corrupt the target by replacing a random fraction of its ordinary tokens
    # with unk; this is the decoder input the model learns to reconstruct.
    pad = self.tgt_dict.pad()
    bos = self.tgt_dict.bos()
    eos = self.tgt_dict.eos()
    unk = self.tgt_dict.unk()

    target_masks = (
        target_tokens.ne(pad) &
        target_tokens.ne(bos) &
        target_tokens.ne(eos)
    )
    target_score = target_tokens.clone().float().uniform_()
    target_score.masked_fill_(~target_masks, 2.0)  # special tokens are never picked
    target_length = target_masks.sum(1).float()
    target_length = target_length * target_length.clone().uniform_()
    target_length = target_length + 1  # make sure to mask at least one token.

    _, target_rank = target_score.sort(1)
    target_cutoff = new_arange(target_rank) < target_length[:, None].long()
    prev_target_tokens = target_tokens.masked_fill(
        target_cutoff.scatter(1, target_rank, target_cutoff), unk)
    return prev_target_tokens
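# Simplified sketch of the corruption above, using a fixed masking ratio
# instead of a uniformly random one (function name and ids bos=0, pad=1,
# eos=2, unk=3 are made up): replace roughly `ratio` of the non-special tokens
# with unk, always masking at least one.
def _example_fixed_ratio_mask(target_tokens, ratio=0.5, bos=0, pad=1, eos=2, unk=3):
    special = target_tokens.eq(pad) | target_tokens.eq(bos) | target_tokens.eq(eos)
    scores = target_tokens.clone().float().uniform_()
    scores.masked_fill_(special, 2.0)  # special tokens are never picked
    n_mask = ((~special).sum(1).float() * ratio).long().clamp(min=1)
    rank = scores.sort(1)[1]
    cutoff = (torch.arange(target_tokens.size(1),
                           device=target_tokens.device)[None, :] < n_mask[:, None])
    return target_tokens.masked_fill(cutoff.scatter(1, rank, cutoff), unk)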
def generate(self, models, sample, prefix_tokens=None):
    # TODO: iterative refinement generator does not support ensemble for now.
    if not self.retain_dropout:
        for model in models:
            model.eval()

    model, reranker = models[0], None
    if self.reranking:
        assert len(models) > 1, "Assuming the last checkpoint is the reranker"
        assert self.beam_size > 1, "Reranking requires multiple translations per example"

        reranker = models[-1]
        models = models[:-1]

    if len(models) > 1 and hasattr(model, 'enable_ensemble'):
        assert model.allow_ensemble, "{} does not support ensembling".format(
            model.__class__.__name__)
        model.enable_ensemble(models)

    # TODO: better encoder inputs?
    src_tokens = sample["net_input"]["src_tokens"]
    src_lengths = sample["net_input"]["src_lengths"]
    bsz, src_len = src_tokens.size()

    # initialize
    encoder_out = model.forward_encoder([src_tokens, src_lengths])
    prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)

    if self.beam_size > 1:
        assert model.allow_length_beam, \
            "{} does not support decoding with length beam.".format(
                model.__class__.__name__)

        # regenerate data based on length-beam
        length_beam_order = utils.new_arange(
            src_tokens, self.beam_size, bsz).t().reshape(-1)
        encoder_out = model.encoder.reorder_encoder_out(
            encoder_out, length_beam_order)
        prev_decoder_out = model.regenerate_length_beam(
            prev_decoder_out, self.beam_size)
        bsz = bsz * self.beam_size

    sent_idxs = torch.arange(bsz)
    prev_output_tokens = prev_decoder_out.output_tokens.clone()

    if self.retain_history:
        prev_decoder_out = prev_decoder_out._replace(history=[prev_output_tokens])

    finalized = [[] for _ in range(bsz)]

    def is_a_loop(x, y, s, a):
        # pad the old and new hypotheses to the same length, then flag
        # sentences whose tokens did not change in this iteration
        b, l_x, l_y = x.size(0), x.size(1), y.size(1)
        if l_x > l_y:
            y = torch.cat([y, x.new_zeros(b, l_x - l_y).fill_(self.pad)], 1)
            s = torch.cat([s, s.new_zeros(b, l_x - l_y)], 1)
            if a is not None:
                a = torch.cat([a, a.new_zeros(b, l_x - l_y, a.size(2))], 1)
        elif l_x < l_y:
            x = torch.cat([x, y.new_zeros(b, l_y - l_x).fill_(self.pad)], 1)
        return (x == y).all(1), y, s, a

    def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
        cutoff = prev_out_token.ne(self.pad)
        tokens = prev_out_token[cutoff]
        if prev_out_score is None:
            scores, score = None, None
        else:
            scores = prev_out_score[cutoff]
            score = scores.mean()

        if prev_out_attn is None:
            hypo_attn, alignment = None, None
        else:
            hypo_attn = prev_out_attn[cutoff]
            alignment = hypo_attn.max(dim=1)[1]
        return {
            "steps": step,
            "tokens": tokens,
            "positional_scores": scores,
            "score": score,
            "hypo_attn": hypo_attn,
            "alignment": alignment,
        }

    for step in range(self.max_iter + 1):

        decoder_options = {
            "eos_penalty": self.eos_penalty,
            "max_ratio": self.max_ratio,
            "decoding_format": self.decoding_format,
        }
        prev_decoder_out = prev_decoder_out._replace(
            step=step,
            max_step=self.max_iter + 1,
        )

        decoder_out = model.forward_decoder(
            prev_decoder_out, encoder_out, **decoder_options)

        if self.adaptive:
            # terminate if there is a loop
            terminated, out_tokens, out_scores, out_attn = is_a_loop(
                prev_output_tokens, decoder_out.output_tokens,
                decoder_out.output_scores, decoder_out.attn)
            decoder_out = decoder_out._replace(
                output_tokens=out_tokens,
                output_scores=out_scores,
                attn=out_attn,
            )
        else:
            terminated = decoder_out.output_tokens.new_zeros(
                decoder_out.output_tokens.size(0)).bool()

        if step == self.max_iter:  # reach last iteration, terminate
            terminated.fill_(1)

        # collect finalized sentences
        finalized_idxs = sent_idxs[terminated]
        finalized_tokens = decoder_out.output_tokens[terminated]
        finalized_scores = decoder_out.output_scores[terminated]
        finalized_attn = (
            None if decoder_out.attn is None else decoder_out.attn[terminated]
        )

        if self.retain_history:
            finalized_history_tokens = [h[terminated] for h in decoder_out.history]

        for i in range(finalized_idxs.size(0)):
            finalized[finalized_idxs[i]] = [
                finalized_hypos(
                    step,
                    finalized_tokens[i],
                    finalized_scores[i],
                    None if finalized_attn is None else finalized_attn[i],
                )
            ]

            if self.retain_history:
                finalized[finalized_idxs[i]][0]['history'] = []
                for j in range(len(finalized_history_tokens)):
                    finalized[finalized_idxs[i]][0]['history'].append(
                        finalized_hypos(
                            step, finalized_history_tokens[j][i], None, None))

        # check if all terminated
        if terminated.sum() == terminated.size(0):
            break

        # for next step
        not_terminated = ~terminated
        prev_decoder_out = decoder_out._replace(
            output_tokens=decoder_out.output_tokens[not_terminated],
            output_scores=decoder_out.output_scores[not_terminated],
            attn=decoder_out.attn[not_terminated]
            if decoder_out.attn is not None else None,
            history=[h[not_terminated] for h in decoder_out.history]
            if decoder_out.history is not None else None,
        )
        encoder_out = model.encoder.reorder_encoder_out(
            encoder_out, not_terminated.nonzero().squeeze())
        sent_idxs = sent_idxs[not_terminated]
        prev_output_tokens = prev_decoder_out.output_tokens.clone()

    if self.beam_size > 1:
        if reranker is not None:
            finalized = self.rerank(
                reranker, finalized, [src_tokens, src_lengths], self.beam_size)

        # aggregate information from length beam
        finalized = [
            finalized[np.argmax([
                finalized[self.beam_size * i + j][0]['score']
                for j in range(self.beam_size)
            ]) + self.beam_size * i]
            for i in range(len(finalized) // self.beam_size)
        ]

    return finalized
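# Hedged usage sketch (argument names are illustrative): `generator` is an
# instance of the class these methods belong to and `sample` is a fairseq
# batch whose net_input carries src_tokens / src_lengths; generate returns,
# per sentence, a list holding the best hypothesis dict.
def _example_generate(generator, model, sample):
    hypos = generator.generate([model], sample)
    return [h[0]["tokens"] for h in hypos]  # best surviving hypothesis per sentence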