    def forward_word_ins(self, encoder_outs, output_tokens, output_scores,
                         attn, can_ins_word):
        word_ins_score_avg = []
        word_ins_attn_avg = []
        for model, encoder_out in zip(self.models, encoder_outs):
            word_ins_out, word_ins_attn = model.decoder.forward_word_ins(
                _skip(output_tokens, can_ins_word),
                _skip(encoder_out, can_ins_word),
            )
            word_ins_score = F.log_softmax(word_ins_out, 2)
            word_ins_score_avg.append(word_ins_score)
            word_ins_attn_avg.append(word_ins_attn)
        word_ins_score_avg = torch.logsumexp(
            torch.stack(word_ins_score_avg, dim=0), dim=0) - math.log(
                len(self.models))
        if word_ins_attn_avg[0] is not None:
            # average attention over the models; stacking alone would keep a
            # leading model dimension and break the _fill call below
            word_ins_attn_avg = torch.stack(word_ins_attn_avg,
                                            dim=0).mean(dim=0)
        else:
            word_ins_attn_avg = None
        word_ins_score_max, word_ins_pred = word_ins_score_avg.max(-1)

        _tokens, _scores = _apply_ins_words(
            output_tokens[can_ins_word],
            output_scores[can_ins_word],
            word_ins_pred,
            word_ins_score_max,
            self.unk,
        )

        output_tokens = _fill(output_tokens, can_ins_word, _tokens, self.pad)
        output_scores = _fill(output_scores, can_ins_word, _scores, 0)
        # fill with the ensemble-averaged attention, not the last model's
        attn = _fill(attn, can_ins_word, word_ins_attn_avg, 0.)
        return output_tokens, output_scores, attn
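Both ensemble methods average the model scores in log space. A minimal sketch (toy probabilities, not from the original code) showing that logsumexp over stacked log-probabilities minus log(n_models) is exactly the log of the mean probability:

import math
import torch

log_p = torch.tensor([[0.2, 0.8], [0.6, 0.4]]).log()  # two models' distributions
avg = torch.logsumexp(log_p, dim=0) - math.log(log_p.size(0))
assert torch.allclose(avg.exp(), log_p.exp().mean(dim=0))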
    def forward_mask_ins(self, encoder_outs, output_tokens, output_scores,
                         can_ins_mask, eos_penalty, max_lens):
        mask_ins_score_avg = []
        for model, encoder_out in zip(self.models, encoder_outs):
            mask_ins_out, _ = model.decoder.forward_mask_ins(
                _skip(output_tokens, can_ins_mask),
                _skip(encoder_out, can_ins_mask),
            )
            mask_ins_score = F.log_softmax(mask_ins_out, 2)
            if eos_penalty > 0.0:
                # class 0 means "insert nothing"; penalizing it discourages
                # overly short hypotheses
                mask_ins_score[:, :, 0] -= eos_penalty
            mask_ins_score_avg.append(mask_ins_score)
        # ensemble-average the per-slot insertion distributions in log space
        mask_ins_score_avg = torch.logsumexp(
            torch.stack(mask_ins_score_avg, dim=0), dim=0) - math.log(
                len(self.models))
        mask_ins_pred = mask_ins_score_avg.max(-1)[1]
        # never insert beyond each sentence's length budget
        mask_ins_pred = torch.min(
            mask_ins_pred, max_lens[can_ins_mask,
                                    None].expand_as(mask_ins_pred))
        _tokens, _scores = _apply_ins_masks(
            output_tokens[can_ins_mask],
            output_scores[can_ins_mask],
            mask_ins_pred,
            self.pad,
            self.unk,
            self.eos,
        )
        output_tokens = _fill(output_tokens, can_ins_mask, _tokens, self.pad)
        output_scores = _fill(output_scores, can_ins_mask, _scores, 0)
        return output_tokens, output_scores
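A toy illustration (values assumed) of the torch.min clamp above: each sentence's predicted placeholder counts are capped by its own max_lens entry:

import torch

mask_ins_pred = torch.tensor([[3, 5], [1, 7]])  # placeholders to insert per slot
max_lens = torch.tensor([4, 2])                 # per-sentence length budget
clamped = torch.min(mask_ins_pred, max_lens[:, None].expand_as(mask_ins_pred))
print(clamped)  # tensor([[3, 4], [1, 2]])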
Example 3
    def generate(self, models, sample, prefix_tokens=None):

        # TODO: model ensemble
        assert len(models) == 1, 'only supports a single model'
        model = models[0]
        if not self.retain_dropout:
            model.eval()

        # TODO: better encoder inputs?
        src_tokens = sample['net_input']['src_tokens']
        src_lengths = sample['net_input']['src_lengths']
        bsz, src_len = src_tokens.size()
        sent_idxs = torch.arange(bsz, device=src_tokens.device)

        # encoding
        encoder_out = model.forward_encoder([src_tokens, src_lengths])

        # initialize buffers (very model specific, with length prediction or not)
        prev_decoder_out = model.initialize_output_tokens(
            encoder_out, src_tokens)
        prev_out_tokens = prev_decoder_out['output_tokens'].clone()

        finalized = [[] for _ in range(bsz)]

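        # a sentence has converged when one more refinement step leaves its
        # tokens unchanged, after padding both sides to a common length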
        def is_a_loop(x, y, s, a):
            b, l_x, l_y = x.size(0), x.size(1), y.size(1)
            if l_x > l_y:
                y = torch.cat([y, x.new_zeros(b, l_x - l_y).fill_(self.pad)],
                              1)
                s = torch.cat([s, s.new_zeros(b, l_x - l_y)], 1)
                if a is not None:
                    a = torch.cat([a, a.new_zeros(b, l_x - l_y, a.size(2))], 1)
            elif l_x < l_y:
                x = torch.cat([x, y.new_zeros(b, l_y - l_x).fill_(self.pad)],
                              1)
            return (x == y).all(1), y, s, a

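        # strip padding and package one finished sentence as a hypothesis dict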
        def finalized_hypos(step, prev_out_token, prev_out_score,
                            prev_out_attn):
            cutoff = prev_out_token.ne(self.pad)
            tokens = prev_out_token[cutoff]
            scores = prev_out_score[cutoff]
            if prev_out_attn is None:
                hypo_attn, alignment = None, None
            else:
                hypo_attn = prev_out_attn[cutoff]
                alignment = hypo_attn.max(dim=1)[1]
            return {
                'steps': step,
                'tokens': tokens,
                'positional_scores': scores,
                'score': scores.mean(),
                'hypo_attn': hypo_attn,
                'alignment': alignment,
            }

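        # refinement loop: each iteration applies one full round of edits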
        for step in range(self.max_iter + 1):

            decoder_options = {
                'eos_penalty': self.eos_penalty,
                'max_ratio': self.max_ratio,
                'decoding_format': self.decoding_format
            }
            prev_decoder_out['step'] = step
            prev_decoder_out['max_step'] = self.max_iter + 1

            decoder_out = model.forward_decoder(prev_decoder_out, encoder_out,
                                                **decoder_options)

            if self.adaptive:
                # terminate if there is a loop
                terminated, out_tokens, out_scores, out_attn = is_a_loop(
                    prev_out_tokens, decoder_out['output_tokens'],
                    decoder_out['output_scores'], decoder_out['attn'])
                decoder_out['output_tokens'] = out_tokens
                decoder_out['output_scores'] = out_scores
                decoder_out['attn'] = out_attn

            else:
                terminated = decoder_out['output_tokens'].new_zeros(
                    decoder_out['output_tokens'].size(0)).bool()

            if step == self.max_iter:  # reached the last iteration; stop refining
                terminated.fill_(1)

            # collect finalized sentences
            finalized_idxs = sent_idxs[terminated]
            finalized_tokens = decoder_out['output_tokens'][terminated]
            finalized_scores = decoder_out['output_scores'][terminated]
            finalized_attn = None if decoder_out[
                'attn'] is None else decoder_out['attn'][terminated]

            for i in range(finalized_idxs.size(0)):
                finalized[finalized_idxs[i]] = [
                    finalized_hypos(
                        step, finalized_tokens[i], finalized_scores[i],
                        None if finalized_attn is None else finalized_attn[i])
                ]
            # check if all terminated
            if terminated.sum() == terminated.size(0):
                break

            # for next step
            prev_decoder_out = _skip(decoder_out, ~terminated)
            encoder_out = _skip(encoder_out, ~terminated)
            sent_idxs = _skip(sent_idxs, ~terminated)

            prev_out_tokens = prev_decoder_out['output_tokens'].clone()

        return finalized
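is_a_loop declares a sentence finished when an extra refinement pass reproduces the same tokens; a minimal sketch (toy ids, pad assumed to be 1) of that fixed-point check:

import torch

pad = 1
prev = torch.tensor([[5, 6, 2]])               # tokens before this iteration
new = torch.tensor([[5, 6, 2, 1]])             # tokens after one more edit round
prev = torch.cat([prev, prev.new_full((1, 1), pad)], 1)  # pad to common length
print((prev == new).all(1))  # tensor([True]): this sentence stops refining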
Example 4
    def generate(self, models, sample, prefix_tokens=None):

        if len(models) == 1:
            # Keep this for NAT models that do not yet have ensemble wrappers; delete it once they do.
            model = models[0]
        elif isinstance(models[0], LevenshteinTransformerModel):
            model = EnsembleLevT(models)
        else:
            raise NotImplementedError
        if not self.retain_dropout:
            model.eval()

        # TODO: better encoder inputs?
        src_tokens = sample["net_input"]["src_tokens"]
        src_lengths = sample["net_input"]["src_lengths"]
        bsz, src_len = src_tokens.size()
        # keep indices on the same device as the masks used to index them
        sent_idxs = torch.arange(bsz, device=src_tokens.device)

        # encoding
        encoder_out = model.forward_encoder([src_tokens, src_lengths])

        # initialize buffers (very model specific, with length prediction or not)
        prev_decoder_out = model.initialize_output_tokens(
            encoder_out, src_tokens)
        prev_output_tokens = prev_decoder_out[0].clone()

        finalized = [[] for _ in range(bsz)]

        def is_a_loop(x, y, s, a):
            b, l_x, l_y = x.size(0), x.size(1), y.size(1)
            if l_x > l_y:
                y = torch.cat([y, x.new_zeros(b, l_x - l_y).fill_(self.pad)],
                              1)
                s = torch.cat([s, s.new_zeros(b, l_x - l_y)], 1)
                if a is not None:
                    a = torch.cat([a, a.new_zeros(b, l_x - l_y, a.size(2))], 1)
            elif l_x < l_y:
                x = torch.cat([x, y.new_zeros(b, l_y - l_x).fill_(self.pad)],
                              1)
            return (x == y).all(1), y, s, a

        def finalized_hypos(step, prev_out_token, prev_out_score,
                            prev_out_attn):
            cutoff = prev_out_token.ne(self.pad)
            tokens = prev_out_token[cutoff]
            scores = prev_out_score[cutoff]
            if prev_out_attn is None:
                hypo_attn, alignment = None, None
            else:
                hypo_attn = prev_out_attn[cutoff]
                alignment = hypo_attn.max(dim=1)[1]
            return {
                "steps": step,
                "tokens": tokens,
                "positional_scores": scores,
                "score": scores.mean(),
                "hypo_attn": hypo_attn,
                "alignment": alignment,
            }

        for step in range(self.max_iter + 1):

            decoder_options = {
                "eos_penalty": self.eos_penalty,
                "max_ratio": self.max_ratio,
                "decoding_format": self.decoding_format,
            }
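            # the decoder state is a plain list here:
            # [tokens, scores, attn, step, max_step]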
            prev_decoder_out[3] = step
            prev_decoder_out[4] = self.max_iter + 1

            decoder_out = model.forward_decoder(prev_decoder_out, encoder_out,
                                                **decoder_options)

            if self.adaptive:
                # terminate if there is a loop
                terminated, out_tokens, out_scores, out_attn = is_a_loop(
                    prev_output_tokens, decoder_out[0], decoder_out[1],
                    decoder_out[2])
                decoder_out[0] = out_tokens
                decoder_out[1] = out_scores
                decoder_out[2] = out_attn

            else:
                terminated = decoder_out[0].new_zeros(
                    decoder_out[0].size(0)).bool()

            if step == self.max_iter:  # reached the last iteration; stop refining
                terminated.fill_(1)

            # collect finalized sentences
            finalized_idxs = sent_idxs[terminated]
            finalized_tokens = decoder_out[0][terminated]
            finalized_scores = decoder_out[1][terminated]
            finalized_attn = (None if decoder_out[2] is None else
                              decoder_out[2][terminated])

            for i in range(finalized_idxs.size(0)):
                finalized[finalized_idxs[i]] = [
                    finalized_hypos(
                        step,
                        finalized_tokens[i],
                        finalized_scores[i],
                        None if finalized_attn is None else finalized_attn[i],
                    )
                ]
            # check if all terminated
            if terminated.sum() == terminated.size(0):
                break

            # for next step
            prev_decoder_out = _skip(decoder_out, ~terminated)
            encoder_out = script_skip_tensor_list(encoder_out, ~terminated)
            sent_idxs = _skip(sent_idxs, ~terminated)

            prev_output_tokens = prev_decoder_out[0].clone()

        return finalized
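At the end of each iteration only the unfinished sentences are carried forward; a small sketch (toy mask) of how the terminated mask splits the batch:

import torch

terminated = torch.tensor([True, False, True])
sent_idxs = torch.arange(3)
print(sent_idxs[terminated])   # tensor([0, 2]): emitted as finalized hypotheses
print(sent_idxs[~terminated])  # tensor([1]): kept for the next iteration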
Example 5
    def forward_decoder(self,
                        decoder_out,
                        encoder_out,
                        eos_penalty=0.0,
                        max_ratio=None,
                        **kwargs):

        output_tokens = decoder_out["output_tokens"]
        output_scores = decoder_out["output_scores"]
        attn = decoder_out["attn"]

        if max_ratio is None:
            max_lens = output_tokens.new(output_tokens.size(0)).fill_(255)
        else:
            max_lens = ((~encoder_out["encoder_padding_mask"]).sum(1) *
                        max_ratio).clamp(min=10)

        # delete words
        # never delete when the sentence holds only <s> and </s>
        can_del_word = output_tokens.ne(self.pad).sum(1) > 2
        if can_del_word.sum() != 0:  # skip when no sentence allows deletion
            word_del_out, word_del_attn = self.decoder.forward_word_del(
                _skip(output_tokens, can_del_word),
                _skip(encoder_out, can_del_word))
            word_del_score = F.log_softmax(word_del_out, 2)
            word_del_pred = word_del_score.max(-1)[1].bool()

            _tokens, _scores, _attn = _apply_del_words(
                output_tokens[can_del_word],
                output_scores[can_del_word],
                word_del_attn,
                word_del_pred,
                self.pad,
                self.bos,
                self.eos,
            )
            output_tokens = _fill(output_tokens, can_del_word, _tokens,
                                  self.pad)
            output_scores = _fill(output_scores, can_del_word, _scores, 0)
            attn = _fill(attn, can_del_word, _attn, 0.)

        # insert placeholders
        can_ins_mask = output_tokens.ne(self.pad).sum(1) < max_lens
        if can_ins_mask.sum() != 0:
            mask_ins_out, _ = self.decoder.forward_mask_ins(
                _skip(output_tokens, can_ins_mask),
                _skip(encoder_out, can_ins_mask))
            mask_ins_score = F.log_softmax(mask_ins_out, 2)
            if eos_penalty > 0.0:
                mask_ins_score[:, :, 0] -= eos_penalty
            mask_ins_pred = mask_ins_score.max(-1)[1]
            mask_ins_pred = torch.min(
                mask_ins_pred, max_lens[:, None].expand_as(mask_ins_pred))

            _tokens, _scores = _apply_ins_masks(
                output_tokens[can_ins_mask],
                output_scores[can_ins_mask],
                mask_ins_pred,
                self.pad,
                self.unk,
                self.eos,
            )
            output_tokens = _fill(output_tokens, can_ins_mask, _tokens,
                                  self.pad)
            output_scores = _fill(output_scores, can_ins_mask, _scores, 0)

        # insert words
        can_ins_word = output_tokens.eq(self.unk).sum(1) > 0
        if can_ins_word.sum() != 0:
            word_ins_out, word_ins_attn = self.decoder.forward_word_ins(
                _skip(output_tokens, can_ins_word),
                _skip(encoder_out, can_ins_word))
            word_ins_score = F.log_softmax(word_ins_out, 2)
            # keep the best score alongside its index; _apply_ins_words needs
            # per-position scores, not the full vocabulary distribution
            word_ins_score_max, word_ins_pred = word_ins_score.max(-1)

            _tokens, _scores = _apply_ins_words(
                output_tokens[can_ins_word],
                output_scores[can_ins_word],
                word_ins_pred,
                word_ins_score_max,
                self.unk,
            )

            output_tokens = _fill(output_tokens, can_ins_word, _tokens,
                                  self.pad)
            output_scores = _fill(output_scores, can_ins_word, _scores, 0)
            attn = _fill(attn, can_ins_word, word_ins_attn, 0.)

        # delete some unnecessary paddings
        cut_off = output_tokens.ne(self.pad).sum(1).max()
        output_tokens = output_tokens[:, :cut_off]
        output_scores = output_scores[:, :cut_off]
        attn = None if attn is None else attn[:, :cut_off, :]
        return {
            "output_tokens": output_tokens,
            "output_scores": output_scores,
            "attn": attn,
        }
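The three edit phases above are each gated by a batch mask; a sketch (special-token ids assumed) of how can_del_word and can_ins_word are derived from the current tokens:

import torch

pad, unk = 1, 3                                   # assumed special-token ids
tokens = torch.tensor([[0, 3, 2, 1],              # <s> <unk> </s> <pad>
                       [0, 2, 1, 1]])             # <s> </s> <pad> <pad>
can_del_word = tokens.ne(pad).sum(1) > 2          # more than just <s> ... </s>
can_ins_word = tokens.eq(unk).sum(1) > 0          # has placeholders to fill
print(can_del_word, can_ins_word)  # tensor([True, False]) twice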
    def generate(self, models, sample, prefix_tokens=None):

        if len(models) == 1:
            # Keep this for NAT models that do not yet have ensemble wrappers; delete it once they do.
            model = models[0]
        elif isinstance(models[0], LevenshteinTransformerModel):
            model = EnsembleLevT(models)
        else:
            raise NotImplementedError

        if not self.retain_dropout:
            model.eval()

        # TODO: better encoder inputs?
        src_tokens = sample['net_input']['src_tokens']
        src_lengths = sample['net_input']['src_lengths']
        tgt_init_tokens = None
        tgt_init_lengths = None
        if 'tgt_init_tokens' in sample['net_input']:
            tgt_init_tokens = sample['net_input']['tgt_init_tokens']
            tgt_init_lengths = sample['net_input']['tgt_init_lengths']
        bsz, src_len = src_tokens.size()
        sent_idxs = torch.arange(bsz, device=src_tokens.device)

        # encoding
        encoder_out = model.forward_encoder([src_tokens, src_lengths])

        # initialize buffers (very model specific, with length prediction or not)
        prev_decoder_out = model.initialize_output_tokens(
            encoder_out, src_tokens, tgt_init_tokens, tgt_init_lengths)
        prev_out_tokens = prev_decoder_out['output_tokens'].clone()

        finalized = [[] for _ in range(bsz)]

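        # same convergence check as in the earlier examples, extended to carry
        # the constraint masks (c: deletion, d: insertion) alongside the rest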
        def is_a_loop(x, y, s, a, c, d):
            b, l_x, l_y = x.size(0), x.size(1), y.size(1)
            if l_x > l_y:
                y = torch.cat([y, x.new_zeros(b, l_x - l_y).fill_(self.pad)],
                              1)
                s = torch.cat([s, s.new_zeros(b, l_x - l_y)], 1)
                if a is not None:
                    a = torch.cat([a, a.new_zeros(b, l_x - l_y, a.size(2))], 1)
                if c is not None:
                    c = torch.cat([c, c.new_zeros(b, l_x - l_y)], 1)
                if d is not None:
                    d = torch.cat([d, d.new_zeros(b, l_x - l_y)], 1)
            elif l_x < l_y:
                x = torch.cat([x, y.new_zeros(b, l_y - l_x).fill_(self.pad)],
                              1)
            return (x == y).all(1), y, s, a, c, d

        def finalized_hypos(step, prev_out_token, prev_out_score,
                            prev_out_attn, prev_out_const_del,
                            prev_out_const_ins, src_tokens):
            cutoff = prev_out_token.ne(self.pad)
            tokens = prev_out_token[cutoff]
            scores = prev_out_score[cutoff]
            const_del = None
            if prev_out_const_del is not None:
                const_del = prev_out_const_del[cutoff]
            const_ins = None
            if prev_out_const_ins is not None:
                const_ins = prev_out_const_ins[cutoff]
            if prev_out_attn is None:
                hypo_attn, alignment = None, None
            else:
                hypo_attn = prev_out_attn[cutoff]
                alignment = utils.extract_hard_alignment(
                    hypo_attn, src_tokens, tokens, self.pad, self.eos)
            return {
                'steps': step,
                'tokens': tokens,
                'positional_scores': scores,
                'score': scores.mean(),
                'hypo_attn': hypo_attn,
                'alignment': alignment,
                'const_del': const_del,
                'const_ins': const_ins,
            }

        for step in range(self.max_iter + 1):

            decoder_options = {
                'eos_penalty': self.eos_penalty,
                'max_ratio': self.max_ratio,
                'decoding_format': self.decoding_format,
                'preserve_constraint': self.preserve_constraint,
                'allow_insertion_constraint': self.allow_insertion_constraint,
            }
            prev_decoder_out['step'] = step
            prev_decoder_out['max_step'] = self.max_iter + 1

            decoder_out = model.forward_decoder(prev_decoder_out, encoder_out,
                                                **decoder_options)

            if self.adaptive:
                # terminate if there is a loop
                terminated, out_tokens, out_scores, out_attn, out_const_del, out_const_ins = is_a_loop(
                    prev_out_tokens, decoder_out['output_tokens'],
                    decoder_out['output_scores'], decoder_out['attn'],
                    decoder_out['const_del'], decoder_out['const_ins'])
                decoder_out['output_tokens'] = out_tokens
                decoder_out['output_scores'] = out_scores
                decoder_out['attn'] = out_attn
                decoder_out['const_del'] = out_const_del
                decoder_out['const_ins'] = out_const_ins

            else:
                terminated = decoder_out['output_tokens'].new_zeros(
                    decoder_out['output_tokens'].size(0)).bool()

            if step == self.max_iter:  # reached the last iteration; stop refining
                terminated.fill_(1)

            # collect finalized sentences
            finalized_idxs = sent_idxs[terminated]
            finalized_tokens = decoder_out['output_tokens'][terminated]
            finalized_scores = decoder_out['output_scores'][terminated]
            finalized_attn = None if decoder_out[
                'attn'] is None else decoder_out['attn'][terminated]
            finalized_const_del = None if decoder_out[
                'const_del'] is None else decoder_out['const_del'][terminated]
            finalized_const_ins = None if decoder_out[
                'const_ins'] is None else decoder_out['const_ins'][terminated]

            for i in range(finalized_idxs.size(0)):
                finalized[finalized_idxs[i]] = [
                    finalized_hypos(
                        step, finalized_tokens[i], finalized_scores[i],
                        None if finalized_attn is None else finalized_attn[i],
                        None if finalized_const_del is None else
                        finalized_const_del[i],
                        None if finalized_const_ins is None else
                        finalized_const_ins[i], src_tokens[finalized_idxs[i]])
                ]
            # check if all terminated
            if terminated.sum() == terminated.size(0):
                break

            # for next step
            prev_decoder_out = _skip(decoder_out, ~terminated)
            encoder_out = _skip(encoder_out, ~terminated)
            sent_idxs = _skip(sent_idxs, ~terminated)

            prev_out_tokens = prev_decoder_out['output_tokens'].clone()

        return finalized
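finalized_hypos trims padding before packaging a hypothesis; a toy example (assumed pad id of 1) of the cutoff and the mean-score computation:

import torch

pad = 1                                      # assumed padding id
tokens = torch.tensor([0, 7, 9, 2, 1, 1])
scores = torch.tensor([0.0, -0.1, -0.3, 0.0, 0.0, 0.0])
cutoff = tokens.ne(pad)
print(tokens[cutoff])         # tensor([0, 7, 9, 2])
print(scores[cutoff].mean())  # tensor(-0.1000): the hypothesis 'score'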