def generate(self, models, sample, **kwargs):
    self.model.reset_incremental_state()
    finalized = super()._generate(sample, **kwargs)

    src_tokens = sample["net_input"]["src_tokens"]
    bsz = src_tokens.shape[0]
    beam_size = self.beam_size
    src_tokens, src_lengths, prev_output_tokens, tgt_tokens = \
        self._prepare_batch_for_alignment(sample, finalized)
    if any(getattr(m, "full_context_alignment", False) for m in self.model.models):
        attn = self.model.forward_align(src_tokens, src_lengths, prev_output_tokens)
    else:
        attn = [
            finalized[i // beam_size][i % beam_size]["attention"].transpose(1, 0)
            for i in range(bsz * beam_size)
        ]

    # Process the attn matrix to extract hard alignments.
    for i in range(bsz * beam_size):
        alignment = utils.extract_hard_alignment(
            attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos)
        finalized[i // beam_size][i % beam_size]["alignment"] = alignment
    return finalized
def generate(self, models, sample, **kwargs):
    model = EnsembleModelWithAlignment(models)
    finalized = super()._generate(model, sample, **kwargs)

    src_tokens = sample['net_input']['src_tokens']
    bsz = src_tokens.shape[0]
    beam_size = self.beam_size
    src_tokens, src_lengths, prev_output_tokens, tgt_tokens = \
        self._prepare_batch_for_alignment(sample, finalized)
    if any(getattr(m, 'full_context_alignment', False) for m in model.models):
        attn = model.forward_align(src_tokens, src_lengths, prev_output_tokens)
    else:
        attn = [
            finalized[i // beam_size][i % beam_size]['attention'].transpose(1, 0)
            for i in range(bsz * beam_size)
        ]

    # Process the attn matrix to extract hard alignments.
    for i in range(bsz * beam_size):
        alignment = utils.extract_hard_alignment(
            attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos)
        finalized[i // beam_size][i % beam_size]['alignment'] = alignment
    return finalized
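# A minimal, self-contained sketch (not fairseq's implementation) of the idea
# behind the hard alignments extracted above: for each target position, take the
# argmax over source positions of the attention matrix. fairseq's
# utils.extract_hard_alignment additionally drops pad/eos positions using the
# pad and eos indices passed in; the toy attention matrix below is made up
# purely for illustration.
import torch

def toy_hard_alignment(attn):
    # attn: (tgt_len, src_len) attention weights for a single hypothesis
    src_idx = attn.argmax(dim=1)  # best source position for each target token
    return [(int(s), t) for t, s in enumerate(src_idx)]

attn = torch.tensor([[0.7, 0.2, 0.1],
                     [0.1, 0.8, 0.1],
                     [0.2, 0.3, 0.5]])
# Prints "0-0 1-1 2-2", i.e. src_idx-tgt_idx pairs in the style fairseq uses
# when printing alignments.
print(' '.join('{}-{}'.format(s, t) for s, t in toy_hard_alignment(attn)))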
def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn,
                    prev_out_const_del, prev_out_const_ins, src_tokens):
    cutoff = prev_out_token.ne(self.pad)
    tokens = prev_out_token[cutoff]
    scores = prev_out_score[cutoff]

    const_del = None
    if prev_out_const_del is not None:
        const_del = prev_out_const_del[cutoff]

    const_ins = None
    if prev_out_const_ins is not None:
        const_ins = prev_out_const_ins[cutoff]

    if prev_out_attn is None:
        hypo_attn, alignment = None, None
    else:
        hypo_attn = prev_out_attn[cutoff]
        alignment = utils.extract_hard_alignment(
            hypo_attn, src_tokens, tokens, self.pad, self.eos)
    return {
        'steps': step,
        'tokens': tokens,
        'positional_scores': scores,
        'score': scores.mean(),
        'hypo_attn': hypo_attn,
        'alignment': alignment,
        'const_del': const_del,
        'const_ins': const_ins,
    }
def generate(self, models, sample, **kwargs): """Score a batch of translations.""" net_input = sample['net_input'] def batch_for_softmax(dec_out, target): # assumes decoder_out[0] is the only thing needed (may not be correct for future models!) first, rest = dec_out[0], dec_out[1:] bsz, tsz, dim = first.shape if bsz * tsz < self.softmax_batch: yield dec_out, target, True else: flat = first.contiguous().view(1, -1, dim) flat_tgt = target.contiguous().view(flat.shape[:-1]) s = 0 while s < flat.size(1): e = s + self.softmax_batch yield (flat[:, s:e], ) + rest, flat_tgt[:, s:e], False s = e def gather_target_probs(probs, target): probs = probs.gather( dim=2, index=target.unsqueeze(-1), ) return probs orig_target = sample['target'] # compute scores for each model in the ensemble avg_probs = None avg_attn = None for model in models: model.eval() decoder_out = model.forward(**net_input) attn = decoder_out[1] if type(attn) is dict: attn = attn.get('attn', None) batched = batch_for_softmax(decoder_out, orig_target) probs, idx = None, 0 for bd, tgt, is_single in batched: sample['target'] = tgt curr_prob = model.get_normalized_probs( bd, log_probs=len(models) == 1, sample=sample).data if is_single: probs = gather_target_probs(curr_prob, orig_target) else: if probs is None: probs = curr_prob.new(orig_target.numel()) step = curr_prob.size(0) * curr_prob.size(1) end = step + idx tgt_probs = gather_target_probs( curr_prob.view(tgt.shape + (curr_prob.size(-1), )), tgt) probs[idx:end] = tgt_probs.view(-1) idx = end sample['target'] = orig_target probs = probs.view(sample['target'].shape) if avg_probs is None: avg_probs = probs else: avg_probs.add_(probs) if attn is not None and torch.is_tensor(attn): attn = attn.data if avg_attn is None: avg_attn = attn else: avg_attn.add_(attn) if len(models) > 1: avg_probs.div_(len(models)) avg_probs.log_() if avg_attn is not None: avg_attn.div_(len(models)) bsz = avg_probs.size(0) hypos = [] start_idxs = sample[ 'start_indices'] if 'start_indices' in sample else [0] * bsz for i in range(bsz): # remove padding from ref ref = utils.strip_pad(sample['target'][i, start_idxs[i]:], self.pad) \ if sample['target'] is not None else None tgt_len = ref.numel() avg_probs_i = avg_probs[i][start_idxs[i]:start_idxs[i] + tgt_len] score_i = avg_probs_i.sum() / tgt_len if avg_attn is not None: avg_attn_i = avg_attn[i] alignment = utils.extract_hard_alignment( avg_attn_i, sample['net_input']['src_tokens'][i], sample['target'][i], self.pad, self.eos) else: avg_attn_i = alignment = None hypos.append([{ 'tokens': ref, 'score': score_i, 'attention': avg_attn_i, 'alignment': alignment, 'positional_scores': avg_probs_i, }]) return hypos
def generate(self, models, sample, **kwargs): """Score a batch of translations.""" net_input = sample['net_input'] def batch_for_softmax(dec_out, target): # assumes decoder_out[0] is the only thing needed (may not be correct for future models!) first, rest = dec_out[0], dec_out[1:] bsz, tsz, dim = first.shape if bsz * tsz < self.softmax_batch: yield dec_out, target, True else: flat = first.contiguous().view(1, -1, dim) flat_tgt = target.contiguous().view(flat.shape[:-1]) s = 0 while s < flat.size(1): e = s + self.softmax_batch yield (flat[:, s:e], ) + rest, flat_tgt[:, s:e], False s = e def gather_target_probs(probs, target): probs = probs.gather( dim=2, index=target.unsqueeze(-1), ) return probs def combine_knn_and_vocab_probs(knn_p, vocab_p, coeff): combine_probs = torch.stack([vocab_p, knn_p], dim=0) coeffs = torch.ones_like(combine_probs) coeffs[0] = np.log(1 - coeff) coeffs[1] = np.log(coeff) curr_prob = torch.logsumexp(combine_probs + coeffs, dim=0) return curr_prob orig_target = sample['target'] # compute scores for each model in the ensemble avg_probs = None avg_attn = None extra = None for i_model, model in enumerate(models): assert extra is None model.eval() decoder_out = model(**net_input) attn = decoder_out[1] if type(attn) is dict: attn = attn.get('attn', None) batched = batch_for_softmax(decoder_out, orig_target) probs, idx = None, 0 for i, (bd, tgt, is_single) in enumerate(batched): sample['target'] = tgt curr_prob = model.get_normalized_probs( bd, log_probs=len(models) == 1, sample=sample).data if is_single: probs = gather_target_probs(curr_prob, orig_target) else: if probs is None: probs = curr_prob.new(orig_target.numel()) step = curr_prob.size(0) * curr_prob.size(1) end = step + idx tgt_probs = gather_target_probs( curr_prob.view(tgt.shape + (curr_prob.size(-1), )), tgt) probs[idx:end] = tgt_probs.view(-1) idx = end sample['target'] = orig_target probs = probs.view(sample['target'].shape) extra = {} extra['probs'] = probs[orig_target != self.pad].clone() extra['src_tokens'] = sample['net_input']['src_tokens'][ orig_target != self.pad].clone() extra['target'] = orig_target[orig_target != self.pad] extra['keys'] = decoder_out[1][self.args.knn_keytype].permute( 1, 0, 2)[orig_target != self.pad] #d0, d1 = orig_target.shape #extra['src_id'] = sample['id'].view(d0, 1).expand(d0, d1)[orig_target != self.pad].clone() if 'knn_dstore' in kwargs: dstore = kwargs['knn_dstore'] # TxBxC queries = bd[1][self.args.knn_keytype] if len(models) != 1: raise ValueError('Only knn *log* probs are supported.') yhat_knn_prob, _extra = dstore.get_knn_log_prob( queries, orig_target.permute(1, 0), pad_idx=self.pad) for k, v in _extra.items(): extra[k] = v yhat_knn_prob = yhat_knn_prob.permute(1, 0, 2).squeeze(-1) if self.args.fp16: yhat_knn_prob = yhat_knn_prob.half() probs = probs.half() probs = combine_knn_and_vocab_probs(yhat_knn_prob, probs, self.args.lmbda) if avg_probs is None: avg_probs = probs else: avg_probs.add_(probs) if attn is not None and torch.is_tensor(attn): attn = attn.data if avg_attn is None: avg_attn = attn else: avg_attn.add_(attn) if len(models) > 1: avg_probs.div_(len(models)) avg_probs.log_() if avg_attn is not None: avg_attn.div_(len(models)) # save_extra = self.collate(save_extra) # TODO: This needs to be written to file. 
bsz = avg_probs.size(0) hypos = [] start_idxs = sample[ 'start_indices'] if 'start_indices' in sample else [0] * bsz for i in range(bsz): # remove padding from ref ref = utils.strip_pad(sample['target'][i, start_idxs[i]:], self.pad) \ if sample['target'] is not None else None tgt_len = ref.numel() avg_probs_i = avg_probs[i][start_idxs[i]:start_idxs[i] + tgt_len] score_i = avg_probs_i.sum() / tgt_len if avg_attn is not None: avg_attn_i = avg_attn[i] if self.compute_alignment: alignment = utils.extract_hard_alignment( avg_attn_i, sample['net_input']['src_tokens'][i], sample['target'][i], self.pad, self.eos, ) else: alignment = None else: avg_attn_i = alignment = None hypos.append([{ 'tokens': ref, 'score': score_i, 'attention': avg_attn_i, 'alignment': alignment, 'positional_scores': avg_probs_i, 'dstore_keys': decoder_out[1][self.args.knn_keytype][start_idxs[i]:, i, :] if self.args.save_knnlm_dstore else None, }]) return hypos, extra
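# Numeric check (a standalone sketch, not part of the scorer) that the
# combine_knn_and_vocab_probs helper above really computes
# log((1 - lmbda) * p_vocab + lmbda * p_knn) in log space: adding the
# log-coefficients and taking logsumexp over the stacked tensors is the
# numerically stable form of that mixture.
import numpy as np
import torch

def combine_knn_and_vocab_probs(knn_p, vocab_p, coeff):
    # mirrors the closure defined inside generate() above
    combine_probs = torch.stack([vocab_p, knn_p], dim=0)
    coeffs = torch.ones_like(combine_probs)
    coeffs[0] = np.log(1 - coeff)
    coeffs[1] = np.log(coeff)
    return torch.logsumexp(combine_probs + coeffs, dim=0)

p_vocab, p_knn, lmbda = 0.10, 0.40, 0.25
mixed = combine_knn_and_vocab_probs(
    torch.tensor([p_knn]).log(), torch.tensor([p_vocab]).log(), lmbda)
print(mixed.exp().item())                      # ~0.175
print((1 - lmbda) * p_vocab + lmbda * p_knn)   # 0.175: the same mixture, computed directly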
def generate(self, models, sample, **kwargs): """Score a batch of translations.""" net_input = sample['net_input'] def batch_for_softmax(dec_out, target): # assumes decoder_out[0] is the only thing needed (may not be correct for future models!) first, rest = dec_out[0], dec_out[1:] bsz, tsz, dim = first.shape if bsz * tsz < self.softmax_batch: yield dec_out, target, True else: flat = first.contiguous().view(1, -1, dim) flat_tgt = target.contiguous().view(flat.shape[:-1]) s = 0 while s < flat.size(1): e = s + self.softmax_batch yield (flat[:, s:e], ) + rest, flat_tgt[:, s:e], False s = e def gather_target_probs(probs, target): probs = probs.gather( dim=2, index=target.unsqueeze(-1), ) return probs def combine_knn_and_vocab_probs(knn_p, vocab_p, coeff): combine_probs = torch.stack([vocab_p, knn_p], dim=0) coeffs = torch.ones_like(combine_probs) assert coeff != 1.0 # have to mix when using parametric + non-parametric coeffs[0] = np.log(1 - coeff) coeffs[1] = np.log(coeff) curr_prob = torch.logsumexp(combine_probs + coeffs, dim=0) return curr_prob orig_target = sample['target'] # compute scores for each model in the ensemble avg_probs = None avg_attn = None for model in models: model.eval() decoder_out = model(**net_input) attn = decoder_out[1] if type(attn) is dict: attn = attn.get('attn', None) batched = batch_for_softmax(decoder_out, orig_target) probs, idx = None, 0 for i, (bd, tgt, is_single) in enumerate(batched): sample['target'] = tgt curr_prob = model.get_normalized_probs( bd, log_probs=len(models) == 1, sample=sample).data if is_single: probs = gather_target_probs(curr_prob, orig_target) else: if probs is None: probs = curr_prob.new(orig_target.numel()) step = curr_prob.size(0) * curr_prob.size(1) end = step + idx tgt_probs = gather_target_probs( curr_prob.view(tgt.shape + (curr_prob.size(-1), )), tgt) probs[idx:end] = tgt_probs.view(-1) idx = end sample['target'] = orig_target probs = probs.view(sample['target'].shape) if 'knn_dstore' in kwargs: dstore = kwargs['knn_dstore'] # TxBxC queries = bd[1][self.args.knn_keytype] if len(models) != 1: raise ValueError('Only knn *log* probs are supported.') yhat_knn_prob, dists_full, knns_full = dstore.get_knn_log_prob( queries, orig_target.permute(1, 0), pad_idx=self.pad) yhat_knn_prob = yhat_knn_prob.permute(1, 0, 2).squeeze(-1) dists_full = dists_full.permute(1, 0, 2).squeeze(-1) knns_full = knns_full.permute(1, 0, 2).squeeze(-1) if self.args.fp16: yhat_knn_prob = yhat_knn_prob.half() probs = probs.half() orig_probs = probs probs = combine_knn_and_vocab_probs(yhat_knn_prob, probs, self.args.lmbda) if avg_probs is None: avg_probs = probs else: avg_probs.add_(probs) if attn is not None and torch.is_tensor(attn): attn = attn.data if avg_attn is None: avg_attn = attn else: avg_attn.add_(attn) if len(models) > 1: avg_probs.div_(len(models)) avg_probs.log_() if avg_attn is not None: avg_attn.div_(len(models)) bsz = avg_probs.size(0) hypos = [] start_idxs = sample[ 'start_indices'] if 'start_indices' in sample else [0] * bsz for i in range(bsz): # remove padding from ref ref = utils.strip_pad(sample['target'][i, start_idxs[i]:], self.pad) \ if sample['target'] is not None else None src = utils.strip_pad(sample['net_input']['src_tokens'][i, :], self.pad) tgt_len = ref.numel() avg_probs_i = avg_probs[i][start_idxs[i]:start_idxs[i] + tgt_len] orig_probs_i = orig_probs[i][start_idxs[i]:start_idxs[i] + tgt_len] yhat_knn_prob_i = yhat_knn_prob[i][start_idxs[i]:start_idxs[i] + tgt_len] if 'knn_dstore' in kwargs: dists_full_i = 
dists_full[i][start_idxs[i]:start_idxs[i] + tgt_len] knns_full_i = knns_full[i][start_idxs[i]:start_idxs[i] + tgt_len] else: dists_full_i = None knns_full_i = None score_i = avg_probs_i.sum() / tgt_len if avg_attn is not None: avg_attn_i = avg_attn[i] if self.compute_alignment: alignment = utils.extract_hard_alignment( avg_attn_i, sample['net_input']['src_tokens'][i], sample['target'][i], self.pad, self.eos, ) else: alignment = None else: avg_attn_i = alignment = None if not self.args.save_knnlm_dstore: dstore_keys = None elif self.args.task == 'translation': # TODO, it seems like you need to trim some padding for MT dstore_keys = decoder_out[1][self.args.knn_keytype][ start_idxs[i]:, i, :][0:tgt_len] elif self.args.task == 'language_modeling': dstore_keys = decoder_out[1][self.args.knn_keytype][ start_idxs[i]:, i, :] hypos.append([{ 'source_tokens': src, 'tokens': ref, # This is the target sequence 'score': score_i, 'attention': avg_attn_i, 'alignment': alignment, 'positional_scores': avg_probs_i, 'original_scores': orig_probs_i, 'yhat_scores': yhat_knn_prob_i, 'dists_full': dists_full_i, 'knns_full': knns_full_i, 'dstore_keys': dstore_keys, }]) return hypos
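# Hedged sketch (not code from this repo) of one way a caller might persist the
# 'dstore_keys' returned above when save_knnlm_dstore is enabled: append each
# hypothesis's decoder keys and the matching target ids into memmap-backed
# arrays, in the spirit of kNN-LM datastore building. dstore_size, key_dim, and
# the file name prefix are illustrative assumptions, not values from this code.
import numpy as np

def write_dstore(hypos, dstore_size, key_dim, prefix='dstore'):
    keys = np.memmap(prefix + '_keys.npy', dtype=np.float16, mode='w+',
                     shape=(dstore_size, key_dim))
    vals = np.memmap(prefix + '_vals.npy', dtype=np.int64, mode='w+',
                     shape=(dstore_size, 1))
    offset = 0
    for hypo_list in hypos:          # hypos as returned by generate() above
        hypo = hypo_list[0]
        if hypo['dstore_keys'] is None:
            continue
        # keys may be longer than the pad-stripped target; keep the overlap only
        n = min(hypo['dstore_keys'].shape[0], hypo['tokens'].shape[0])
        keys[offset:offset + n] = hypo['dstore_keys'][:n].cpu().numpy().astype(np.float16)
        vals[offset:offset + n, 0] = hypo['tokens'][:n].cpu().numpy()
        offset += n
    return offset  # number of (key, value) pairs actually written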