示例#1
0
    def generate(self, models, sample, **kwargs):
        """Generate hypotheses and attach a hard alignment to each of them.

        Args:
            models: accepted for interface compatibility; this implementation
                does not read it and works on ``self.model`` instead.
            sample: batch dictionary; ``sample["net_input"]["src_tokens"]``
                is read and its first dimension is treated as the batch size.
            **kwargs: forwarded unchanged to the parent ``_generate``.

        Returns:
            The structure returned by the parent ``_generate`` (indexed as
            ``finalized[sentence][beam]``), where every hypothesis dictionary
            has gained an ``"alignment"`` entry.
        """
        self.model.reset_incremental_state()
        finalized = super()._generate(sample, **kwargs)

        beam_size = self.beam_size
        # One hypothesis per (sentence, beam) pair, addressed by a flat index.
        num_hypos = sample["net_input"]["src_tokens"].shape[0] * beam_size

        (src_tokens, src_lengths, prev_output_tokens,
         tgt_tokens) = self._prepare_batch_for_alignment(sample, finalized)

        def hypo_at(flat_idx):
            # Map the flat index back to finalized[sentence][beam].
            sent_idx, beam_idx = divmod(flat_idx, beam_size)
            return finalized[sent_idx][beam_idx]

        wants_full_context = any(
            getattr(m, "full_context_alignment", False)
            for m in self.model.models)
        if wants_full_context:
            # At least one ensemble member asks for a dedicated alignment
            # pass instead of the attention recorded during generation.
            attn = self.model.forward_align(src_tokens, src_lengths,
                                            prev_output_tokens)
        else:
            attn = [
                hypo_at(k)["attention"].transpose(1, 0)
                for k in range(num_hypos)
            ]

        # Process the attn matrix to extract hard alignments.
        for k in range(num_hypos):
            hypo_at(k)["alignment"] = utils.extract_hard_alignment(
                attn[k], src_tokens[k], tgt_tokens[k], self.pad, self.eos)
        return finalized
示例#2
0
    def generate(self, models, sample, **kwargs):
        """Generate hypotheses with an alignment-aware ensemble and align them.

        Args:
            models: models wrapped into an ``EnsembleModelWithAlignment``.
            sample: batch dictionary; ``sample['net_input']['src_tokens']``
                is read and its first dimension is treated as the batch size.
            **kwargs: forwarded unchanged to the parent ``_generate``.

        Returns:
            The structure returned by the parent ``_generate`` (indexed as
            ``finalized[sentence][beam]``), where every hypothesis dictionary
            has gained an ``'alignment'`` entry.
        """
        model = EnsembleModelWithAlignment(models)
        finalized = super()._generate(model, sample, **kwargs)

        beam_size = self.beam_size
        bsz = sample['net_input']['src_tokens'].shape[0]
        # Flat indices enumerate every (sentence, beam) pair exactly once.
        flat_indices = range(bsz * beam_size)

        def locate(flat_idx):
            # Hypothesis dictionary that belongs to a flat index.
            sent_idx, beam_idx = divmod(flat_idx, beam_size)
            return finalized[sent_idx][beam_idx]

        prepared = self._prepare_batch_for_alignment(sample, finalized)
        src_tokens, src_lengths, prev_output_tokens, tgt_tokens = prepared

        full_context = any(
            getattr(member, 'full_context_alignment', False)
            for member in model.models)
        if full_context:
            # A member requests a separate alignment forward pass.
            attn = model.forward_align(src_tokens, src_lengths,
                                       prev_output_tokens)
        else:
            # Reuse the attention stored on each hypothesis, transposed.
            attn = [
                locate(k)['attention'].transpose(1, 0) for k in flat_indices
            ]

        # Process the attn matrix to extract hard alignments.
        for k in flat_indices:
            locate(k)['alignment'] = utils.extract_hard_alignment(
                attn[k], src_tokens[k], tgt_tokens[k], self.pad, self.eos)
        return finalized
 def finalized_hypos(step, prev_out_token, prev_out_score,
                     prev_out_attn, prev_out_const_del,
                     prev_out_const_ins, src_tokens):
     """Build the result dictionary for one finished hypothesis.

     Positions of ``prev_out_token`` equal to ``self.pad`` are dropped, and
     the same mask is applied to the scores and to every optional tensor
     that is not ``None``.

     Args:
         step: value stored unchanged under ``'steps'``.
         prev_out_token: output tokens, possibly containing padding.
         prev_out_score: per-position scores indexed like the tokens.
         prev_out_attn: optional attention; when ``None`` both
             ``'hypo_attn'`` and ``'alignment'`` are ``None``.
         prev_out_const_del: optional tensor masked like the tokens.
         prev_out_const_ins: optional tensor masked like the tokens.
         src_tokens: source tokens handed to the alignment extraction.

     Returns:
         Dictionary with the keys ``'steps'``, ``'tokens'``,
         ``'positional_scores'``, ``'score'`` (mean of the kept scores),
         ``'hypo_attn'``, ``'alignment'``, ``'const_del'`` and
         ``'const_ins'``.
     """
     keep = prev_out_token.ne(self.pad)

     def masked_or_none(values):
         # Optional inputs stay None; everything else loses its pad slots.
         return None if values is None else values[keep]

     tokens = prev_out_token[keep]
     scores = prev_out_score[keep]
     const_del = masked_or_none(prev_out_const_del)
     const_ins = masked_or_none(prev_out_const_ins)

     hypo_attn = masked_or_none(prev_out_attn)
     alignment = None
     if hypo_attn is not None:
         alignment = utils.extract_hard_alignment(
             hypo_attn, src_tokens, tokens, self.pad, self.eos)

     return {
         'steps': step,
         'tokens': tokens,
         'positional_scores': scores,
         'score': scores.mean(),
         'hypo_attn': hypo_attn,
         'alignment': alignment,
         'const_del': const_del,
         'const_ins': const_ins,
     }
    def generate(self, models, sample, **kwargs):
        """Score the reference targets of a batch under one or more models.

        Every model is run on ``sample['net_input']`` and the probability it
        assigns to each token of ``sample['target']`` is collected. With one
        model the values are log-probabilities; with several models plain
        probabilities are averaged and the logarithm is taken afterwards.

        Args:
            models: sized collection of models to score with.
            sample: batch dictionary providing ``'net_input'`` and
                ``'target'`` and, optionally, ``'start_indices'``.
                ``sample['target']`` is temporarily replaced while each
                softmax chunk is processed and restored afterwards.
            **kwargs: accepted for interface compatibility; not read here.

        Returns:
            A list with one single-element list per batch row. The element
            is a dictionary with the keys ``'tokens'``, ``'score'``,
            ``'attention'``, ``'alignment'`` and ``'positional_scores'``.
        """
        net_input = sample['net_input']
        reference = sample['target']
        num_models = len(models)
        # Only a lone model can hand back log-probabilities directly; an
        # ensemble must be averaged in probability space first.
        use_log_probs = num_models == 1

        def iter_softmax_chunks(dec_out, target):
            # assumes decoder_out[0] is the only thing needed (may not be correct for future models!)
            head, tail = dec_out[0], dec_out[1:]
            n_rows, n_cols, width = head.shape
            if n_rows * n_cols < self.softmax_batch:
                # Small enough: normalise the whole output in one go.
                yield dec_out, target, True
                return
            # Otherwise flatten to a single row and walk it in slices of
            # at most ``self.softmax_batch`` positions.
            flat_head = head.contiguous().view(1, -1, width)
            flat_target = target.contiguous().view(flat_head.shape[:-1])
            lo = 0
            while lo < flat_head.size(1):
                hi = lo + self.softmax_batch
                yield (flat_head[:, lo:hi], ) + tail, flat_target[:, lo:hi], False
                lo = hi

        def pick_target_probs(dist, target):
            # Select, along the last axis, the entry of each target token.
            return dist.gather(dim=2, index=target.unsqueeze(-1))

        # compute scores for each model in the ensemble
        avg_probs = None
        avg_attn = None
        for model in models:
            model.eval()
            decoder_out = model.forward(**net_input)

            attn = decoder_out[1]
            if type(attn) is dict:
                attn = attn.get('attn', None)

            probs = None
            filled = 0  # number of flattened positions already written
            for chunk_out, chunk_tgt, whole_batch in iter_softmax_chunks(
                    decoder_out, reference):
                # The model reads the matching target slice from ``sample``.
                sample['target'] = chunk_tgt
                chunk_probs = model.get_normalized_probs(
                    chunk_out, log_probs=use_log_probs, sample=sample).data
                if whole_batch:
                    probs = pick_target_probs(chunk_probs, reference)
                else:
                    if probs is None:
                        probs = chunk_probs.new(reference.numel())
                    upto = filled + chunk_probs.size(0) * chunk_probs.size(1)
                    reshaped = chunk_probs.view(
                        chunk_tgt.shape + (chunk_probs.size(-1), ))
                    probs[filled:upto] = pick_target_probs(
                        reshaped, chunk_tgt).view(-1)
                    filled = upto
                sample['target'] = reference

            probs = probs.view(sample['target'].shape)

            # In-place accumulation; ``add_`` returns the tensor it updates.
            avg_probs = probs if avg_probs is None else avg_probs.add_(probs)
            if torch.is_tensor(attn):
                attn = attn.data
                avg_attn = attn if avg_attn is None else avg_attn.add_(attn)

        if num_models > 1:
            avg_probs.div_(num_models)
            avg_probs.log_()
            if avg_attn is not None:
                avg_attn.div_(num_models)

        bsz = avg_probs.size(0)
        if 'start_indices' in sample:
            start_idxs = sample['start_indices']
        else:
            start_idxs = [0] * bsz

        hypos = []
        for i in range(bsz):
            start = start_idxs[i]
            target = sample['target']
            # remove padding from ref
            ref = None
            if target is not None:
                ref = utils.strip_pad(target[i, start:], self.pad)
            tgt_len = ref.numel()
            avg_probs_i = avg_probs[i][start:start + tgt_len]
            score_i = avg_probs_i.sum() / tgt_len

            avg_attn_i = None
            alignment = None
            if avg_attn is not None:
                avg_attn_i = avg_attn[i]
                alignment = utils.extract_hard_alignment(
                    avg_attn_i, sample['net_input']['src_tokens'][i],
                    sample['target'][i], self.pad, self.eos)

            hypos.append([{
                'tokens': ref,
                'score': score_i,
                'attention': avg_attn_i,
                'alignment': alignment,
                'positional_scores': avg_probs_i,
            }])
        return hypos
示例#5
0
    def generate(self, models, sample, **kwargs):
        """Score a batch of translations, optionally mixing in kNN scores.

        Each model is run on ``sample['net_input']`` and the probability of
        every token in ``sample['target']`` is gathered. When ``kwargs``
        contains ``'knn_dstore'`` the model log-probabilities are
        interpolated with the datastore's kNN log-probabilities using
        ``self.args.lmbda``.

        Args:
            models: sized collection of models. The ``assert`` at the top
                of the loop only holds for the first iteration, so more
                than one model is rejected.
            sample: batch dictionary providing ``'net_input'`` and
                ``'target'`` and, optionally, ``'start_indices'``.
                ``sample['target']`` is temporarily replaced while each
                softmax chunk is processed and restored afterwards.
            **kwargs: ``'knn_dstore'`` (optional) must offer
                ``get_knn_log_prob``.

        Returns:
            A ``(hypos, extra)`` pair. ``hypos`` holds one single-element
            list per batch row with the keys ``'tokens'``, ``'score'``,
            ``'attention'``, ``'alignment'``, ``'positional_scores'`` and
            ``'dstore_keys'``. ``extra`` holds the non-pad ``'probs'``,
            ``'src_tokens'``, ``'target'`` and ``'keys'`` plus whatever the
            datastore lookup returned.

        Raises:
            ValueError: a datastore was supplied together with a number of
                models other than one.
        """
        net_input = sample['net_input']
        reference = sample['target']
        num_models = len(models)
        use_log_probs = num_models == 1

        def iter_softmax_chunks(dec_out, target):
            # assumes decoder_out[0] is the only thing needed (may not be correct for future models!)
            head, tail = dec_out[0], dec_out[1:]
            n_rows, n_cols, width = head.shape
            if n_rows * n_cols < self.softmax_batch:
                yield dec_out, target, True
                return
            flat_head = head.contiguous().view(1, -1, width)
            flat_target = target.contiguous().view(flat_head.shape[:-1])
            lo = 0
            while lo < flat_head.size(1):
                hi = lo + self.softmax_batch
                yield (flat_head[:, lo:hi], ) + tail, flat_target[:, lo:hi], False
                lo = hi

        def pick_target_probs(dist, target):
            # Select, along the last axis, the entry of each target token.
            return dist.gather(dim=2, index=target.unsqueeze(-1))

        def interpolate_with_knn(knn_p, vocab_p, coeff):
            # Log-space mixture: logsumexp over the two sources after the
            # log mixture weights have been added to them.
            stacked = torch.stack([vocab_p, knn_p], dim=0)
            log_weights = torch.ones_like(stacked)
            log_weights[0] = np.log(1 - coeff)
            log_weights[1] = np.log(coeff)
            return torch.logsumexp(stacked + log_weights, dim=0)

        # compute scores for each model in the ensemble
        avg_probs = None
        avg_attn = None
        extra = None
        for model in models:
            assert extra is None
            model.eval()
            decoder_out = model(**net_input)
            decoder_extra = decoder_out[1]

            attn = decoder_extra
            if type(attn) is dict:
                attn = attn.get('attn', None)

            probs = None
            filled = 0  # number of flattened positions already written
            for chunk_out, chunk_tgt, whole_batch in iter_softmax_chunks(
                    decoder_out, reference):
                # The model reads the matching target slice from ``sample``.
                sample['target'] = chunk_tgt
                chunk_probs = model.get_normalized_probs(
                    chunk_out, log_probs=use_log_probs, sample=sample).data

                if whole_batch:
                    probs = pick_target_probs(chunk_probs, reference)
                else:
                    if probs is None:
                        probs = chunk_probs.new(reference.numel())
                    upto = filled + chunk_probs.size(0) * chunk_probs.size(1)
                    reshaped = chunk_probs.view(
                        chunk_tgt.shape + (chunk_probs.size(-1), ))
                    probs[filled:upto] = pick_target_probs(
                        reshaped, chunk_tgt).view(-1)
                    filled = upto
                sample['target'] = reference

            probs = probs.view(sample['target'].shape)

            # Side-channel data restricted to the non-pad target positions.
            non_pad = reference != self.pad
            extra = {
                'probs': probs[non_pad].clone(),
                'src_tokens':
                sample['net_input']['src_tokens'][non_pad].clone(),
                'target': reference[non_pad],
                'keys':
                decoder_extra[self.args.knn_keytype].permute(1, 0,
                                                             2)[non_pad],
            }

            if 'knn_dstore' in kwargs:
                dstore = kwargs['knn_dstore']
                # TxBxC
                queries = decoder_extra[self.args.knn_keytype]
                if num_models != 1:
                    raise ValueError('Only knn *log* probs are supported.')

                yhat_knn_prob, _extra = dstore.get_knn_log_prob(
                    queries, reference.permute(1, 0), pad_idx=self.pad)
                extra.update(_extra)
                yhat_knn_prob = yhat_knn_prob.permute(1, 0, 2).squeeze(-1)
                if self.args.fp16:
                    yhat_knn_prob = yhat_knn_prob.half()
                    probs = probs.half()

                probs = interpolate_with_knn(yhat_knn_prob, probs,
                                             self.args.lmbda)

            # In-place accumulation; ``add_`` returns the tensor it updates.
            avg_probs = probs if avg_probs is None else avg_probs.add_(probs)
            if torch.is_tensor(attn):
                attn = attn.data
                avg_attn = attn if avg_attn is None else avg_attn.add_(attn)

        if num_models > 1:
            avg_probs.div_(num_models)
            avg_probs.log_()
            if avg_attn is not None:
                avg_attn.div_(num_models)

        # TODO: This needs to be written to file.

        bsz = avg_probs.size(0)
        if 'start_indices' in sample:
            start_idxs = sample['start_indices']
        else:
            start_idxs = [0] * bsz

        hypos = []
        for i in range(bsz):
            start = start_idxs[i]
            target = sample['target']
            # remove padding from ref
            ref = None
            if target is not None:
                ref = utils.strip_pad(target[i, start:], self.pad)
            tgt_len = ref.numel()
            avg_probs_i = avg_probs[i][start:start + tgt_len]
            score_i = avg_probs_i.sum() / tgt_len

            avg_attn_i = None
            alignment = None
            if avg_attn is not None:
                avg_attn_i = avg_attn[i]
                if self.compute_alignment:
                    alignment = utils.extract_hard_alignment(
                        avg_attn_i,
                        sample['net_input']['src_tokens'][i],
                        sample['target'][i],
                        self.pad,
                        self.eos,
                    )

            # Keys come from the output of the last model that was run.
            dstore_keys = None
            if self.args.save_knnlm_dstore:
                dstore_keys = decoder_out[1][self.args.knn_keytype][
                    start_idxs[i]:, i, :]

            hypos.append([{
                'tokens': ref,
                'score': score_i,
                'attention': avg_attn_i,
                'alignment': alignment,
                'positional_scores': avg_probs_i,
                'dstore_keys': dstore_keys,
            }])
        return hypos, extra
示例#6
0
    def generate(self, models, sample, **kwargs):
        """Score a batch of translations, optionally mixing in kNN scores.

        Each model is run on ``sample['net_input']`` and the probability of
        every token in ``sample['target']`` is gathered. When ``kwargs``
        contains ``'knn_dstore'`` the model log-probabilities are
        interpolated with the datastore's kNN log-probabilities using
        ``self.args.lmbda``, and per-token kNN diagnostics are returned.

        Args:
            models: sized collection of models. A datastore may only be
                combined with exactly one model.
            sample: batch dictionary providing ``'net_input'`` and
                ``'target'`` and, optionally, ``'start_indices'``.
                ``sample['target']`` is temporarily replaced while each
                softmax chunk is processed and restored afterwards.
            **kwargs: ``'knn_dstore'`` (optional) must offer
                ``get_knn_log_prob`` returning three tensors.

        Returns:
            A list with one single-element list per batch row. The element
            is a dictionary with the keys ``'source_tokens'``, ``'tokens'``,
            ``'score'``, ``'attention'``, ``'alignment'``,
            ``'positional_scores'``, ``'original_scores'``,
            ``'yhat_scores'``, ``'dists_full'``, ``'knns_full'`` and
            ``'dstore_keys'``. Without a datastore ``'original_scores'``
            equals ``'positional_scores'`` and the three kNN entries are
            ``None``.

        Raises:
            ValueError: a datastore was supplied together with a number of
                models other than one.
        """
        net_input = sample['net_input']
        use_knn = 'knn_dstore' in kwargs

        def batch_for_softmax(dec_out, target):
            # assumes decoder_out[0] is the only thing needed (may not be correct for future models!)
            first, rest = dec_out[0], dec_out[1:]
            bsz, tsz, dim = first.shape
            if bsz * tsz < self.softmax_batch:
                yield dec_out, target, True
            else:
                flat = first.contiguous().view(1, -1, dim)
                flat_tgt = target.contiguous().view(flat.shape[:-1])
                s = 0
                while s < flat.size(1):
                    e = s + self.softmax_batch
                    yield (flat[:, s:e], ) + rest, flat_tgt[:, s:e], False
                    s = e

        def gather_target_probs(probs, target):
            # Select, along the last axis, the entry of each target token.
            return probs.gather(
                dim=2,
                index=target.unsqueeze(-1),
            )

        def combine_knn_and_vocab_probs(knn_p, vocab_p, coeff):
            # Log-space mixture: logsumexp over the two sources after the
            # log mixture weights have been added to them.
            combine_probs = torch.stack([vocab_p, knn_p], dim=0)
            coeffs = torch.ones_like(combine_probs)
            assert coeff != 1.0  # have to mix when using parametric + non-parametric
            coeffs[0] = np.log(1 - coeff)
            coeffs[1] = np.log(coeff)
            return torch.logsumexp(combine_probs + coeffs, dim=0)

        def row_span(tensor, row, lo, hi):
            # Per-sentence slice of an optional batch tensor.
            return None if tensor is None else tensor[row][lo:hi]

        orig_target = sample['target']

        # compute scores for each model in the ensemble
        avg_probs = None
        avg_attn = None
        # Bound up front so the per-sentence loop below never touches an
        # unassigned local when no datastore was supplied.
        orig_probs = None
        yhat_knn_prob = None
        dists_full = None
        knns_full = None
        for model in models:
            model.eval()
            decoder_out = model(**net_input)
            attn = decoder_out[1]
            if isinstance(attn, dict):
                attn = attn.get('attn', None)

            probs, idx = None, 0
            for bd, tgt, is_single in batch_for_softmax(
                    decoder_out, orig_target):
                # The model reads the matching target slice from ``sample``.
                sample['target'] = tgt
                curr_prob = model.get_normalized_probs(
                    bd, log_probs=len(models) == 1, sample=sample).data

                if is_single:
                    probs = gather_target_probs(curr_prob, orig_target)
                else:
                    if probs is None:
                        probs = curr_prob.new(orig_target.numel())
                    step = curr_prob.size(0) * curr_prob.size(1)
                    end = step + idx
                    tgt_probs = gather_target_probs(
                        curr_prob.view(tgt.shape + (curr_prob.size(-1), )),
                        tgt)
                    probs[idx:end] = tgt_probs.view(-1)
                    idx = end
                sample['target'] = orig_target

            probs = probs.view(sample['target'].shape)

            if use_knn:
                dstore = kwargs['knn_dstore']
                # TxBxC
                # Every chunk carries the same trailing decoder outputs, so
                # read them from ``decoder_out`` rather than from the
                # leftover loop variable.
                queries = decoder_out[1][self.args.knn_keytype]
                if len(models) != 1:
                    raise ValueError('Only knn *log* probs are supported.')

                yhat_knn_prob, dists_full, knns_full = dstore.get_knn_log_prob(
                    queries, orig_target.permute(1, 0), pad_idx=self.pad)
                yhat_knn_prob = yhat_knn_prob.permute(1, 0, 2).squeeze(-1)
                dists_full = dists_full.permute(1, 0, 2).squeeze(-1)
                knns_full = knns_full.permute(1, 0, 2).squeeze(-1)

                if self.args.fp16:
                    yhat_knn_prob = yhat_knn_prob.half()
                    probs = probs.half()
                # Keep the model-only scores before interpolation.
                orig_probs = probs
                probs = combine_knn_and_vocab_probs(yhat_knn_prob, probs,
                                                    self.args.lmbda)

            if avg_probs is None:
                avg_probs = probs
            else:
                avg_probs.add_(probs)
            if attn is not None and torch.is_tensor(attn):
                attn = attn.data
                if avg_attn is None:
                    avg_attn = attn
                else:
                    avg_attn.add_(attn)
        if len(models) > 1:
            avg_probs.div_(len(models))
            avg_probs.log_()
            if avg_attn is not None:
                avg_attn.div_(len(models))

        if not use_knn:
            # Nothing was interpolated, so the model-only scores are the
            # final scores.
            orig_probs = avg_probs

        bsz = avg_probs.size(0)
        hypos = []
        start_idxs = sample[
            'start_indices'] if 'start_indices' in sample else [0] * bsz

        for i in range(bsz):
            start = start_idxs[i]
            # remove padding from ref
            ref = utils.strip_pad(sample['target'][i, start:], self.pad) \
                if sample['target'] is not None else None
            src = utils.strip_pad(sample['net_input']['src_tokens'][i, :],
                                  self.pad)
            tgt_len = ref.numel()
            stop = start + tgt_len
            avg_probs_i = avg_probs[i][start:stop]
            orig_probs_i = row_span(orig_probs, i, start, stop)
            yhat_knn_prob_i = row_span(yhat_knn_prob, i, start, stop)
            dists_full_i = row_span(dists_full, i, start, stop)
            knns_full_i = row_span(knns_full, i, start, stop)
            score_i = avg_probs_i.sum() / tgt_len
            if avg_attn is not None:
                avg_attn_i = avg_attn[i]
                if self.compute_alignment:
                    alignment = utils.extract_hard_alignment(
                        avg_attn_i,
                        sample['net_input']['src_tokens'][i],
                        sample['target'][i],
                        self.pad,
                        self.eos,
                    )
                else:
                    alignment = None
            else:
                avg_attn_i = alignment = None
            if not self.args.save_knnlm_dstore:
                dstore_keys = None
            else:
                # Keys come from the output of the last model that was run.
                # Any task other than translation keeps the untrimmed keys,
                # so ``dstore_keys`` is always bound.
                dstore_keys = decoder_out[1][self.args.knn_keytype][
                    start:, i, :]
                if self.args.task == 'translation':  # TODO, it seems like you need to trim some padding for MT
                    dstore_keys = dstore_keys[0:tgt_len]
            hypos.append([{
                'source_tokens': src,
                'tokens': ref,  # This is the target sequence
                'score': score_i,
                'attention': avg_attn_i,
                'alignment': alignment,
                'positional_scores': avg_probs_i,
                'original_scores': orig_probs_i,
                'yhat_scores': yhat_knn_prob_i,
                'dists_full': dists_full_i,
                'knns_full': knns_full_i,
                'dstore_keys': dstore_keys,
            }])
        return hypos