Exemple #1
0
    def regenerate_length_beam(self, decoder_out, beam_size):
        """Rebuild the decoding canvas with ``beam_size`` candidate lengths per row.

        The current length of every row of ``decoder_out.output_tokens`` (its
        number of non-pad tokens) is expanded into ``beam_size`` candidate
        lengths by adding ``utils.new_arange(..., 1, beam_size)`` and
        subtracting ``beam_size // 2``; every candidate is clamped to at
        least 2.  For each candidate a fresh token row is created: ``self.bos``
        at position 0, ``self.eos`` at the last in-length position,
        ``self.unk`` in between and ``self.pad`` past the length.

        Returns:
            ``decoder_out`` with ``output_tokens`` replaced by the new canvas
            and ``output_scores`` replaced by zeros of the same shape, cast to
            the type of the previous scores.
        """
        prev_tokens = decoder_out.output_tokens

        # Current length of each hypothesis = number of non-pad tokens.
        base_len = prev_tokens.ne(self.pad).sum(1)
        beam_offsets = utils.new_arange(base_len, 1, beam_size)
        beam_len = base_len[:, None] + beam_offsets - beam_size // 2
        # Flatten to one length per (sentence, candidate); room for bos + eos.
        beam_len = beam_len.view(-1).clamp_(min=2)

        longest = beam_len.max()
        positions = utils.new_arange(beam_len, longest)
        within_length = positions[None, :] < beam_len[:, None]

        canvas = prev_tokens.new_zeros(beam_len.size(0), longest).fill_(self.pad)
        canvas.masked_fill_(within_length, self.unk)
        canvas[:, 0] = self.bos
        # eos sits on the last in-length position of every row.
        canvas.scatter_(1, beam_len[:, None] - 1, self.eos)

        canvas_scores = canvas.new_zeros(*canvas.size())
        canvas_scores = canvas_scores.type_as(decoder_out.output_scores)

        return decoder_out._replace(
            output_tokens=canvas,
            output_scores=canvas_scores,
        )
Exemple #2
0
def _skeptical_unmasking(output_scores, output_masks, p):
    """Build a boolean mask of positions to re-mask (non-consecutive decoding).

    Args:
        output_scores: 2-D score tensor (rows are sequences).  It is read
            only; this function does not modify it.
        output_masks: boolean tensor of the same shape; its row sums are used
            as the per-sequence length.
        p: masking ratio.

    Returns:
        A boolean tensor shaped like ``output_masks`` that is true at the
        selected positions.

    For ``p >= 0.5`` the ``(length - 2) * p`` lowest-scoring positions are
    selected.  For smaller ``p`` the selection is made by the rank-shifting
    procedure below.
    """
    if p >= 0.5:
        sorted_index = output_scores.sort(-1)[1]
        boundary_len = (
            (output_masks.sum(1, keepdim=True).type_as(output_scores) - 2) *
            p).long()
        skeptical_mask = new_arange(output_masks) < boundary_len
        return skeptical_mask.scatter(1, sorted_index, skeptical_mask)

    mask_ratio = p

    # large number guaranteed to be sorted to last
    _large_num = output_scores.shape[1] * 10
    # get length for each example in the batch
    target_length = output_masks.sum(1).float()

    # get num of masks & unmasks for each example in the batch
    mask_nums = (target_length * mask_ratio).long() + 1
    unmask_nums = target_length - mask_nums + 1
    # create a binary mask where 1 means unmasks
    unmask_cutoff = new_arange(output_scores) < unmask_nums[:, None]

    # make indices larger than B be sorted last -> we will only sample from 0~B.
    # Out-of-place fill: the in-place variant overwrote the caller's
    # ``output_scores`` with ``_large_num``.
    target_score = output_scores.masked_fill(~unmask_cutoff, _large_num)

    # sorting for the top locations
    _, target_rank = target_score.sort(1)

    # create a binary mask where 1 means masks
    mask_cutoff = new_arange(target_score) < mask_nums[:, None]

    # sort the desired indices while discarding the rest
    mask_pre, _ = target_rank.masked_fill(~mask_cutoff, _large_num).sort(1)
    # add the cumsum to indices to transform to sequence indices
    # NOTE(review): mask_pre + arange can reach past the last column when a
    # selected rank is near the end of the row -- confirm scatter stays in range.
    mask_mid = mask_pre + new_arange(mask_pre)
    # replace the discarded part with duplicated first column -> ensures correctness.
    duped = mask_mid[:, :1].expand(*mask_mid.shape)
    mask_fin = mask_mid * mask_cutoff + duped * (~mask_cutoff)

    # scatter 1 to locations indicated by mask_fin, then fill
    skeptical_mask = mask_cutoff.new_zeros(mask_cutoff.size())
    return skeptical_mask.scatter(1, mask_fin, 1)
Exemple #3
0
    def initialize_output_tokens(self, encoder_out, src_tokens):
        """Create the initial decoder canvas from the predicted target lengths.

        Lengths come from ``self.decoder.forward_length_prediction`` applied
        to ``self.decoder.forward_length(normalize=True, ...)`` and are
        clamped (in place) to at least 2.  Every row gets ``self.bos`` at
        position 0, ``self.eos`` at its last in-length position,
        ``self.unk`` in between and ``self.pad`` past the length.

        Returns:
            A ``DecoderOut`` holding the token canvas, zero scores of the same
            shape (cast to the type of ``encoder_out.encoder_out``), no
            attention, no history, and ``step == max_step == 0``.
        """
        # length prediction
        length_out = self.decoder.forward_length(
            normalize=True, encoder_out=encoder_out
        )
        length_tgt = self.decoder.forward_length_prediction(
            length_out, encoder_out=encoder_out
        )

        # Every target needs room for at least bos + eos.
        length_tgt.clamp_(min=2)
        longest = length_tgt.max()
        positions = utils.new_arange(src_tokens, longest)
        within_length = positions[None, :] < length_tgt[:, None]

        canvas = src_tokens.new_zeros(src_tokens.size(0), longest)
        canvas.fill_(self.pad)
        canvas.masked_fill_(within_length, self.unk)
        canvas[:, 0] = self.bos
        canvas.scatter_(1, length_tgt[:, None] - 1, self.eos)

        canvas_scores = canvas.new_zeros(*canvas.size())
        canvas_scores = canvas_scores.type_as(encoder_out.encoder_out)

        return DecoderOut(
            output_tokens=canvas,
            output_scores=canvas_scores,
            attn=None,
            step=0,
            max_step=0,
            history=None,
        )
    def rerank(self, reranker, finalized, encoder_input, beam_size):
        """Overwrite each hypothesis' ``score`` with a score from ``reranker``.

        The first hypothesis of every entry in ``finalized`` is padded into
        one batch, its first column is set to ``self.eos``, and the batch is
        run through the reranker's decoder.  The per-token values gathered
        for the actual next tokens are averaged over non-pad positions and
        written to ``finalized[i][0]['score']``.

        Returns:
            The same ``finalized`` list, updated in place.
        """
        # Re-pad the first hypothesis of every entry into a single batch.
        hypo_tokens = [hypos[0]['tokens'] for hypos in finalized]
        width = max(tokens.size(0) for tokens in hypo_tokens)
        batch = hypo_tokens[0].new_zeros(len(hypo_tokens), width).fill_(self.pad)
        for row, tokens in enumerate(hypo_tokens):
            batch[row, :tokens.size(0)] = tokens
        batch[:, 0] = self.eos  # autoregressive model assumes starting with EOS

        # Encode the sources once, then repeat them for every beam candidate.
        enc_out = reranker.encoder(*encoder_input)
        num_sources = enc_out.encoder_out.size(1)
        beam_order = utils.new_arange(batch, beam_size, num_sources).t().reshape(-1)
        enc_out = reranker.encoder.reorder_encoder_out(enc_out, beam_order)

        decoder_input = batch[:, :-1]
        targets = batch[:, 1:]
        normalized = reranker.get_normalized_probs(
            reranker.decoder(decoder_input, enc_out), True, None)

        # Pick the value of each actual next token; ignore padded positions.
        token_scores = normalized.gather(2, targets[:, :, None])
        valid = targets.ne(self.pad)
        sentence_scores = token_scores[:, :, 0].masked_fill_(~valid, 0).sum(1)
        sentence_scores = sentence_scores / valid.sum(1).type_as(sentence_scores)

        for i, hypos in enumerate(finalized):
            hypos[0]['score'] = sentence_scores[i]

        return finalized
def _apply_del_words(
    in_tokens, in_scores, in_attn, word_del_pred, padding_idx, bos_idx, eos_idx
):
    """Delete the predicted tokens and compact the survivors to the left.

    ``word_del_pred`` is modified in place: padding positions are forced to
    "delete" and bos/eos positions are forced to "keep".  Deleted slots are
    filled with ``padding_idx`` (tokens) or 0 (scores, attention) and moved
    behind the kept ones.

    Returns:
        ``(out_tokens, out_scores, out_attn)``; the last two are ``None`` when
        the corresponding input is ``None``.
    """
    # apply deletion to a tensor
    non_pad = in_tokens.ne(padding_idx)
    is_boundary = in_tokens.eq(bos_idx) | in_tokens.eq(eos_idx)
    seq_len = in_tokens.size(1)

    word_del_pred.masked_fill_(~non_pad, 1)
    word_del_pred.masked_fill_(is_boundary, 0)

    # Deleted positions get the sort key `seq_len`, so they are ordered last.
    sort_keys = new_arange(in_tokens).masked_fill_(word_del_pred, seq_len)
    reordering = sort_keys.sort(1)[1]

    out_tokens = in_tokens.masked_fill(word_del_pred, padding_idx)
    out_tokens = out_tokens.gather(1, reordering)

    if in_scores is None:
        out_scores = None
    else:
        out_scores = in_scores.masked_fill(word_del_pred, 0)
        out_scores = out_scores.gather(1, reordering)

    if in_attn is None:
        out_attn = None
    else:
        # Broadcast the mask and the permutation over the last attention dim.
        attn_mask = word_del_pred[:, :, None].expand_as(in_attn)
        attn_order = reordering[:, :, None].expand_as(in_attn)
        out_attn = in_attn.masked_fill(attn_mask, 0.).gather(1, attn_order)

    return out_tokens, out_scores, out_attn
def _apply_ins_masks(
    in_tokens, in_scores, mask_ins_pred, padding_idx, unk_idx, eos_idx
):
    """Insert ``unk_idx`` placeholders between tokens as predicted.

    ``mask_ins_pred[:, i]`` is the number of placeholders to insert before
    input token ``i + 1``.  Modifies its inputs in place: padding in
    ``in_tokens`` is overwritten with ``eos_idx``, predictions on padded
    slots in ``mask_ins_pred`` are zeroed, and padded entries of
    ``in_scores`` (when given) are zeroed.

    Returns:
        ``(out_tokens, out_scores)``; ``out_scores`` is ``None`` when
        ``in_scores`` is ``None``.
    """
    non_pad = in_tokens.ne(padding_idx)
    src_lengths = non_pad.sum(1)
    tail_non_pad = non_pad[:, 1:]

    # HACK: hacky way to shift all the paddings to eos first.
    in_tokens.masked_fill_(~non_pad, eos_idx)
    mask_ins_pred.masked_fill_(~tail_non_pad, 0)

    out_lengths = src_lengths + mask_ins_pred.sum(1)
    out_max_len = out_lengths.max()
    positions = new_arange(out_lengths, out_max_len)
    out_masks = positions[None, :] < out_lengths[:, None]

    # Destination column of every input token after position 0.
    reordering = (mask_ins_pred + tail_non_pad.long()).cumsum(1)

    out_tokens = in_tokens.new_zeros(in_tokens.size(0), out_max_len)
    out_tokens.fill_(padding_idx)
    out_tokens.masked_fill_(out_masks, unk_idx)
    out_tokens[:, 0] = in_tokens[:, 0]
    out_tokens.scatter_(1, reordering, in_tokens[:, 1:])

    if in_scores is None:
        return out_tokens, None

    in_scores.masked_fill_(~non_pad, 0)
    out_scores = in_scores.new_zeros(*out_tokens.size())
    out_scores[:, 0] = in_scores[:, 0]
    out_scores.scatter_(1, reordering, in_scores[:, 1:])

    return out_tokens, out_scores
def _skeptical_unmasking(output_scores, output_masks, p):
    """Return a boolean mask selecting the lowest-scoring positions per row.

    The number selected in each row is ``((row sum of output_masks) - 2) * p``
    truncated to an integer.
    """
    # Indices of each row's scores in ascending order.
    _, ascending = output_scores.sort(-1)
    row_lengths = output_masks.sum(1, keepdim=True).type_as(output_scores)
    boundary_len = ((row_lengths - 2) * p).long()

    # True for the first `boundary_len` ranks, then mapped back to positions.
    rank_mask = new_arange(output_masks) < boundary_len
    return rank_mask.scatter(1, ascending, rank_mask)
        def _random_mask(target_tokens):
            """Replace a random subset of maskable target tokens with unk.

            Pad, bos, eos and the ``<en>`` / ``<ch>`` tag tokens are never
            replaced.  The count per row is ``maskable_count * u + 1``
            (``u`` uniform random), truncated to an integer.
            """
            vocab = self.tgt_dict
            pad = vocab.pad()
            bos = vocab.bos()
            eos = vocab.eos()
            unk = vocab.unk()
            en_tag = vocab.index("<en>")
            ch_tag = vocab.index("<ch>")

            maskable = target_tokens.ne(pad)
            for special in (bos, eos, en_tag, ch_tag):
                maskable = maskable & target_tokens.ne(special)

            # Random sort keys; protected positions get 2.0 so they sort last.
            noise = target_tokens.clone().float().uniform_()
            noise.masked_fill_(~maskable, 2.0)

            maskable_count = maskable.sum(1).float()
            # "+ 1" makes sure to mask at least one token.
            num_to_mask = maskable_count * maskable_count.clone().uniform_() + 1

            _, rank = noise.sort(1)
            cutoff = new_arange(rank) < num_to_mask[:, None].long()
            chosen = cutoff.scatter(1, rank, cutoff)
            return target_tokens.masked_fill(chosen, unk)
Exemple #9
0
    def initialize_output_tokens(self, encoder_out, src_tokens):
        """Create the initial decoder canvas with length = 2 x source length.

        Source lengths are taken from the encoder padding mask when present,
        otherwise every source is assumed to span the full first dimension of
        ``encoder_out.encoder_out``.  Target lengths are clamped to at least
        2.  Each row gets ``self.bos`` first, ``self.eos`` at its last
        in-length position, ``self.unk`` in between and ``self.pad`` after.

        Returns:
            A ``DecoderOut`` with the token canvas, zero scores of the same
            shape, no attention, no history and ``step == max_step == 0``.
        """
        # length prediction
        features = encoder_out.encoder_out  # T x B x C
        pad_mask = encoder_out.encoder_padding_mask  # B x T or None

        if pad_mask is not None:
            src_lengs = (~pad_mask).transpose(0, 1).type_as(features).sum(0)
        else:
            src_lengs = features.new_ones(features.size(1)).fill_(
                features.size(0))

        length_tgt = src_lengs.long() * 2
        length_tgt.clamp_(min=2)

        longest = length_tgt.max()
        positions = utils.new_arange(src_tokens, longest)
        within_length = positions[None, :] < length_tgt[:, None]

        canvas = src_tokens.new_zeros(src_tokens.size(0), longest)
        canvas.fill_(self.pad)
        canvas.masked_fill_(within_length, self.unk)
        canvas[:, 0] = self.bos
        canvas.scatter_(1, length_tgt[:, None] - 1, self.eos)

        canvas_scores = canvas.new_zeros(*canvas.size())
        canvas_scores = canvas_scores.type_as(encoder_out.encoder_out)

        return DecoderOut(
            output_tokens=canvas,
            output_scores=canvas_scores,
            attn=None,
            step=0,
            max_step=0,
            history=None,
        )
Exemple #10
0
def _uniform_assignment(src_lens, trg_lens):
    max_trg_len = trg_lens.max()
    steps = (src_lens.float() - 1) / (trg_lens.float() - 1)  # step-size
    # max_trg_len
    index_t = utils.new_arange(trg_lens, max_trg_len).float()
    index_t = steps[:, None] * index_t[None, :]  # batch_size X max_trg_len
    index_t = torch.round(index_t).long().detach()
    return index_t
Exemple #11
0
def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, padding_idx):
    """Interleave predicted insertions with the existing tokens.

    Prediction ``i`` is placed between input tokens ``i`` and ``i + 1``.
    Predictions equal to ``padding_idx`` are pushed to the end of the row.
    ``word_ins_pred`` and ``word_ins_scores`` are modified in place:
    slots whose following input token is padding are reset to
    ``padding_idx`` / ``0.0``.

    Returns:
        ``(out_tokens, out_scores)``, each with the input columns and the
        prediction columns merged.
    """
    slot_is_pad = in_tokens[:, 1:].eq(padding_idx)
    word_ins_scores.masked_fill_(slot_is_pad, 0.0)
    word_ins_pred.masked_fill_(slot_is_pad, padding_idx)

    in_coords = new_arange(in_tokens).type_as(in_scores)

    # Insertion i sits half a step before token i + 1;
    # shift all padding predictions to infinite
    ins_coords = in_coords[:, 1:] - 0.5
    ins_coords = ins_coords.masked_fill(
        word_ins_pred.eq(padding_idx), float("inf")
    )

    merged_order = torch.cat([in_coords, ins_coords], 1).sort(-1)[1]
    all_tokens = torch.cat([in_tokens, word_ins_pred], 1)
    all_scores = torch.cat([in_scores, word_ins_scores], 1)

    out_tokens = all_tokens.gather(1, merged_order)
    out_scores = all_scores.gather(1, merged_order)
    return out_tokens, out_scores
        def _random_mask(target_tokens):
            """Replace a random subset of non-special target tokens with unk.

            Pad, bos and eos are never replaced.  The count per row is
            ``maskable_count * u + 1`` (``u`` uniform random), truncated to
            an integer.
            """
            maskable = target_tokens.ne(pad)
            for special in (bos, eos):
                maskable = maskable & target_tokens.ne(special)

            # Random sort keys; protected positions get 2.0 so they sort last.
            noise = target_tokens.clone().float().uniform_()
            noise.masked_fill_(~maskable, 2.0)

            maskable_count = maskable.sum(1).float()
            # make sure to mask at least one token.
            num_to_mask = maskable_count * maskable_count.clone().uniform_() + 1

            _, rank = noise.sort(1)
            cutoff = new_arange(rank) < num_to_mask[:, None].long()
            chosen = cutoff.scatter(1, rank, cutoff)
            return target_tokens.masked_fill(chosen, unk)
    def initialize_output_tokens(self, encoder_out, src_tokens):
        """Create the initial decoder canvas from the predicted target lengths.

        The second value returned by
        ``self.decoder.forward_length_prediction(encoder_out)`` is used as the
        per-sentence target length.  Each row gets ``self.bos`` at position 0,
        ``self.eos`` at its last in-length position, ``self.unk`` in between
        and ``self.pad`` past the length.

        Returns:
            A dict with ``"output_tokens"``, ``"output_scores"`` (zeros of the
            same shape, cast to the type of ``encoder_out["encoder_out"]``)
            and ``"attn"`` set to ``None``.
        """
        # length prediction
        _, length_tgt = self.decoder.forward_length_prediction(encoder_out)
        # Every target needs room for bos + eos: a length of 0 would make the
        # bos write and the eos index (length - 1) invalid, and a length of 1
        # would let eos overwrite bos.  Clamped out of place so the tensor
        # owned by the decoder is left untouched.
        length_tgt = length_tgt.clamp(min=2)
        max_length = length_tgt.max()
        idx_length = utils.new_arange(src_tokens, max_length)

        initial_output_tokens = src_tokens.new_zeros(
            src_tokens.size(0), max_length).fill_(self.pad)
        initial_output_tokens.masked_fill_(
            idx_length[None, :] < length_tgt[:, None], self.unk)
        initial_output_tokens[:, 0] = self.bos
        initial_output_tokens.scatter_(1, length_tgt[:, None] - 1, self.eos)

        initial_output_scores = initial_output_tokens.new_zeros(
            *initial_output_tokens.size()).type_as(encoder_out["encoder_out"])

        return {
            "output_tokens": initial_output_tokens,
            "output_scores": initial_output_scores,
            "attn": None
        }
    def initialize_output_tokens(self, encoder_out, src_tokens, tgt_lang):
        """Create the initial decoder canvas plus language and position ids.

        Args:
            encoder_out: encoder output; passed to the length predictor and
                used for the dtype of the score tensor.
            src_tokens: source token tensor; supplies batch size, dtype and
                device of the created tensors.
            tgt_lang: ``"en"`` or ``"ch"``.

        Returns:
            ``(langs, positions, decoder_out)`` where ``langs`` is all zeros
            for ``"en"`` and all ones for ``"ch"``, ``positions`` counts from
            1 inside each target length and is 0 beyond it, and
            ``decoder_out`` is a ``DecoderOut`` whose rows hold ``self.bos``,
            the language tag at position 1, ``self.unk`` fillers and
            ``self.eos`` at the last in-length position.

        Raises:
            AssertionError: if ``tgt_lang`` is neither ``"en"`` nor ``"ch"``.
        """
        # Fail fast on an unsupported language.  The previous check compared
        # the string against a tuple and left ``langs`` unbound; the
        # exception type is kept for callers.
        if tgt_lang not in ("en", "ch"):
            raise AssertionError(
                "unsupported tgt_lang {!r}; expected 'en' or 'ch'".format(tgt_lang)
            )

        # length prediction
        length_tgt = self.decoder.forward_length_prediction(
            self.decoder.forward_length(normalize=True,
                                        encoder_out=encoder_out),
            tgt_lengths=None)
        max_length = length_tgt.clamp_(min=2).max()
        idx_length = utils.new_arange(src_tokens, max_length)
        bsz = src_tokens.size(0)

        # 1-based positions inside each target, 0 for the padded tail.
        positions = torch.arange(1, max_length + 1)[None, :].repeat(
            bsz, 1).to(src_tokens.device)
        positions.masked_fill_(idx_length[None, :] + 1 > length_tgt[:, None],
                               0)

        initial_output_tokens = src_tokens.new_zeros(
            bsz, max_length).long().fill_(self.pad)
        initial_output_tokens.masked_fill_(
            idx_length[None, :] < length_tgt[:, None], self.unk)
        initial_output_tokens[:, 0] = self.bos
        if tgt_lang == "en":
            initial_output_tokens[:, 1] = self.en_tag
            langs = src_tokens.new_zeros(bsz, max_length).long()
        else:  # "ch", guaranteed by the check above
            initial_output_tokens[:, 1] = self.ch_tag
            langs = src_tokens.new_ones(bsz, max_length).long()
        # Written last: for a length-2 target eos lands on position 1.
        initial_output_tokens.scatter_(1, length_tgt[:, None] - 1, self.eos)
        initial_output_scores = initial_output_tokens.new_zeros(
            *initial_output_tokens.size()).type_as(encoder_out.encoder_out)

        return langs, positions, DecoderOut(
            output_tokens=initial_output_tokens,
            output_scores=initial_output_scores,
            attn=None,
            step=0,
            max_step=0,
            history=None)
    def generate(self, models, sample, prefix_tokens=None):
        """Run iterative-refinement decoding on one batch.

        Args:
            models: list of models; ``models[0]`` does the decoding.  When
                ``self.reranking`` is set, ``models[-1]`` is used as the
                reranker and dropped from the list.
            sample: batch dict; ``sample["net_input"]["src_tokens"]`` and
                ``sample["net_input"]["src_lengths"]`` are read.
            prefix_tokens: accepted but not used by this method.

        Returns:
            A list with one entry per source sentence.  Each entry is a
            one-element list containing the hypothesis dict built by
            ``finalized_hypos`` (keys ``steps``, ``tokens``,
            ``positional_scores``, ``score``, ``hypo_attn``, ``alignment``
            and, when ``self.retain_history`` is set, ``history``).  With
            ``self.beam_size > 1`` only the best-scoring length candidate of
            each sentence is kept.
        """
        # TODO: iterative refinement generator does not support ensemble for now.
        if not self.retain_dropout:
            for model in models:
                model.eval()

        model, reranker = models[0], None
        if self.reranking:
            assert len(models) > 1, "Assuming the last checkpoint is the reranker"
            assert self.beam_size > 1, "Reranking requires multiple translation for each example"

            reranker = models[-1]
            models = models[:-1]

        if len(models) > 1 and hasattr(model, 'enable_ensemble'):
            assert model.allow_ensemble, "{} does not support ensembling".format(model.__class__.__name__)
            model.enable_ensemble(models)

        # TODO: better encoder inputs?
        src_tokens = sample["net_input"]["src_tokens"]
        src_lengths = sample["net_input"]["src_lengths"]
        bsz, _ = src_tokens.size()

        # initialize
        encoder_out = model.forward_encoder([src_tokens, src_lengths])
        prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)

        if self.beam_size > 1:
            assert model.allow_length_beam, \
                "{} does not support decoding with length beam.".format(model.__class__.__name__)

            # regenerate data based on length-beam
            length_beam_order = utils.new_arange(src_tokens, self.beam_size, bsz).t().reshape(-1)
            encoder_out = model.encoder.reorder_encoder_out(encoder_out, length_beam_order)
            prev_decoder_out = model.regenerate_length_beam(prev_decoder_out, self.beam_size)
            bsz = bsz * self.beam_size

        # sent_idxs maps rows of the shrinking working batch to original slots.
        sent_idxs = torch.arange(bsz)
        prev_output_tokens = prev_decoder_out.output_tokens.clone()

        if self.retain_history:
            prev_decoder_out = prev_decoder_out._replace(history=[prev_output_tokens])

        finalized = [[] for _ in range(bsz)]

        def is_a_loop(x, y, s, a):
            # Pad the shorter of x / y (and y's scores / attention) to a common
            # width, then report the rows where the output did not change.
            b, l_x, l_y = x.size(0), x.size(1), y.size(1)
            if l_x > l_y:
                y = torch.cat([y, x.new_zeros(b, l_x - l_y).fill_(self.pad)], 1)
                s = torch.cat([s, s.new_zeros(b, l_x - l_y)], 1)
                if a is not None:
                    a = torch.cat([a, a.new_zeros(b, l_x - l_y, a.size(2))], 1)
            elif l_x < l_y:
                x = torch.cat([x, y.new_zeros(b, l_y - l_x).fill_(self.pad)], 1)
            return (x == y).all(1), y, s, a

        def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
            # Strip pad / bos / eos and package one hypothesis as a dict.
            cutoff = prev_out_token.ne(self.pad) & prev_out_token.ne(self.bos) & prev_out_token.ne(self.eos)
            tokens = prev_out_token[cutoff]
            if prev_out_score is None:
                scores, score = None, None
            else:
                scores = prev_out_score[cutoff]
                score = scores.mean()

            if prev_out_attn is None:
                hypo_attn, alignment = None, None
            else:
                hypo_attn = prev_out_attn[cutoff]
                alignment = hypo_attn.max(dim=1)[1]
            return {
                "steps": step,
                "tokens": tokens,
                "positional_scores": scores,
                "score": score,
                "hypo_attn": hypo_attn,
                "alignment": alignment,
            }

        # Identical for every refinement step, so built once.
        decoder_options = {
            "eos_penalty": self.eos_penalty,
            "max_ratio": self.max_ratio,
            "decoding_format": self.decoding_format,
        }

        for step in range(self.max_iter + 1):

            prev_decoder_out = prev_decoder_out._replace(
                step=step,
                max_step=self.max_iter + 1,
            )

            decoder_out = model.forward_decoder(
                prev_decoder_out, encoder_out, **decoder_options
            )

            if self.adaptive:
                # terminate if there is a loop
                terminated, out_tokens, out_scores, out_attn = is_a_loop(
                    prev_output_tokens, decoder_out.output_tokens, decoder_out.output_scores, decoder_out.attn
                )

                if self.decoding_format == 'easy-first':
                    # Parallel Easy-First that exits early.
                    # If the best sequence terminated in beam, terminated all.
                    new_bsz, seq_len = out_tokens.size()
                    nb_sents = new_bsz // self.beam_size
                    out_tokens_beam = out_tokens.view(nb_sents, self.beam_size, seq_len)
                    out_scores_beam = out_scores.view(nb_sents, self.beam_size, seq_len)
                    score_mask = out_tokens_beam.ne(self.pad) & out_tokens_beam.ne(
                        self.bos) & out_tokens_beam.ne(self.eos)
                    # NOTE(review): a candidate with no regular tokens makes the
                    # denominator 0 here -- confirm this cannot happen.
                    out_scores_avg_beam = (score_mask*out_scores_beam).sum(2) /(score_mask.sum(2))
                    _, max_idxes = out_scores_avg_beam.max(1)
                    terminated_beam = terminated.view(nb_sents, self.beam_size)
                    terminated = terminated_beam[torch.arange(
                        nb_sents, device=terminated_beam.device), max_idxes]
                    # [nb_sents]
                    if terminated.any():
                        # Push the scores of every candidate to -1000 and then
                        # restore the best candidate of each terminated
                        # sentence, so that candidate wins the final argmax.
                        new_out_scores_beam = score_mask*(torch.ones_like(out_scores_beam)*(-1000))
                        new_out_scores_beam[terminated, max_idxes[terminated]] = out_scores_beam[terminated, max_idxes[terminated]]
                        out_scores_beam = new_out_scores_beam
                    out_scores = out_scores_beam.view(new_bsz, seq_len)
                    terminated = terminated.unsqueeze(1).repeat([1, self.beam_size]).view(new_bsz)

                decoder_out = decoder_out._replace(
                    output_tokens=out_tokens,
                    output_scores=out_scores,
                    attn=out_attn,
                )

            else:
                terminated = decoder_out.output_tokens.new_zeros(decoder_out.output_tokens.size(0)).bool()

            if step == self.max_iter:  # reach last iteration, terminate
                terminated.fill_(1)

            # collect finalized sentences
            finalized_idxs = sent_idxs[terminated]
            finalized_tokens = decoder_out.output_tokens[terminated]
            finalized_scores = decoder_out.output_scores[terminated]
            finalized_attn = (
                None if (decoder_out.attn is None or decoder_out.attn.size(0) == 0) else decoder_out.attn[terminated]
            )

            if self.retain_history:
                finalized_history_tokens = [h[terminated] for h in decoder_out.history]

            for i in range(finalized_idxs.size(0)):
                finalized[finalized_idxs[i]] = [
                    finalized_hypos(
                        step,
                        finalized_tokens[i],
                        finalized_scores[i],
                        None if finalized_attn is None else finalized_attn[i],
                    )
                ]

                if self.retain_history:
                    finalized[finalized_idxs[i]][0]['history'] = [
                        finalized_hypos(step, history_tokens[i], None, None)
                        for history_tokens in finalized_history_tokens
                    ]

            # check if all terminated
            if terminated.sum() == terminated.size(0):
                break

            # for next step
            not_terminated = ~terminated
            prev_decoder_out = decoder_out._replace(
                output_tokens=decoder_out.output_tokens[not_terminated],
                output_scores=decoder_out.output_scores[not_terminated],
                attn=decoder_out.attn[not_terminated]
                if (decoder_out.attn is not None and decoder_out.attn.size(0) > 0)
                else None,
                history=[h[not_terminated] for h in decoder_out.history]
                if decoder_out.history is not None
                else None,
            )
            # squeeze(-1) keeps the index 1-D even when a single sentence is
            # left; a bare squeeze() would collapse it to a 0-dim tensor.
            encoder_out = model.encoder.reorder_encoder_out(encoder_out, not_terminated.nonzero().squeeze(-1))
            sent_idxs = sent_idxs[not_terminated]
            prev_output_tokens = prev_decoder_out.output_tokens.clone()

        if self.beam_size > 1:
            if reranker is not None:
                finalized = self.rerank(
                    reranker, finalized, [src_tokens, src_lengths], self.beam_size
                )

            # aggregate information from length beam
            finalized = [
                finalized[np.argmax(
                    [finalized[self.beam_size * i + j][0]['score'] for j in range(self.beam_size)]
                    ) + self.beam_size * i] for i in range(len(finalized) // self.beam_size)
                ]

        return finalized