Пример #1
0
    def initialize_output_tokens(self, encoder_out, prev_output_tokens):
        """Create the initial decoder canvas from a source-length heuristic.

        The target length of each sentence is ``1.2 * src_len + 10``
        truncated to an integer, where ``src_len`` counts the positions at
        which ``encoder_out.encoder_padding_mask`` is False. Positions before
        ``length - 1`` hold ``self.unk``, position ``length - 1`` holds
        ``self.eos`` and everything after it holds ``self.pad``.

        Args:
            encoder_out: object exposing ``encoder_padding_mask`` and
                ``encoder_out``.
            prev_output_tokens: tensor whose first dimension gives the batch
                size and whose dtype/device are reused for the new tokens.

        Returns:
            DecoderOut holding the initial tokens, all-zero scores cast to
            the type of ``encoder_out.encoder_out``, ``step=0`` and
            ``max_step=10000``.
        """
        # Per-sentence count of unmasked source positions (dim 1 is kept).
        unmasked = ~encoder_out.encoder_padding_mask
        src_lengths = unmasked.sum(1, keepdim=True)
        length_tgt = (src_lengths * 1.2 + 10).long()
        longest = length_tgt.max()
        batch_size = prev_output_tokens.size(0)

        # Column indices 0..longest-1 repeated once per sentence.
        columns = torch.arange(longest, device=prev_output_tokens.device)
        columns = columns.squeeze(0).repeat([batch_size, 1])
        eos_position = length_tgt - 1

        # Start from an all-pad canvas, then write unk and the closing eos.
        tokens = prev_output_tokens.new_ones([batch_size, longest]) * self.pad
        tokens = tokens.masked_fill(columns < eos_position, self.unk)
        tokens = tokens.masked_fill(columns.eq(eos_position), self.eos)

        scores = tokens.new_zeros(*tokens.size())
        scores = scores.type_as(encoder_out.encoder_out)

        return DecoderOut(
            output_tokens=tokens,
            output_scores=scores,
            attn=None,
            step=0,
            max_step=10000,
            history=None,
        )
Пример #2
0
    def initialize_output_tokens(self, encoder_out, src_tokens):
        """Create the initial decoder canvas from the decoder's length output.

        Target lengths come from ``self.decoder.forward_length_prediction``
        applied to ``self.decoder.forward_length(normalize=True, ...)`` and
        are clamped in place to at least 2. Each row is laid out as
        ``bos, unk, ..., unk, eos, pad, ...`` with ``eos`` at index
        ``length - 1``.

        Args:
            encoder_out: encoder output passed to the length modules; its
                ``encoder_out`` attribute determines the score type.
            src_tokens: tensor whose first dimension gives the batch size
                and whose dtype/device are reused for the new tokens.

        Returns:
            DecoderOut holding the initial tokens, all-zero scores,
            ``step=0`` and ``max_step=0``.
        """
        length_out = self.decoder.forward_length(
            normalize=True, encoder_out=encoder_out)
        length_tgt = self.decoder.forward_length_prediction(
            length_out, encoder_out=encoder_out)

        # clamp_ works in place, so length_tgt itself is >= 2 from here on.
        longest = length_tgt.clamp_(min=2).max()
        columns = utils.new_arange(src_tokens, longest)
        batch_size = src_tokens.size(0)

        tokens = src_tokens.new_zeros(batch_size, longest).fill_(self.pad)
        inside_target = columns.unsqueeze(0) < length_tgt.unsqueeze(1)
        tokens.masked_fill_(inside_target, self.unk)
        tokens[:, 0] = self.bos
        # Overwrite the last in-length slot of every row with eos.
        tokens.scatter_(1, (length_tgt - 1).unsqueeze(1), self.eos)

        scores = tokens.new_zeros(*tokens.size())
        scores = scores.type_as(encoder_out.encoder_out)

        return DecoderOut(output_tokens=tokens,
                          output_scores=scores,
                          attn=None,
                          step=0,
                          max_step=0,
                          history=None)
Пример #3
0
    def initialize_output_tokens(self, encoder_out, src_tokens):
        """Create the initial decoder canvas with twice the source length.

        The source length of each sentence is the number of False entries in
        ``encoder_out.encoder_padding_mask``; when the mask is None every
        sentence is given length ``encoder_out.encoder_out.size(0)``. The
        target length is twice that, clamped in place to at least 2. Each
        row is laid out as ``bos, unk, ..., unk, eos, pad, ...`` with
        ``eos`` at index ``length - 1``.

        Args:
            encoder_out: object exposing ``encoder_out`` and
                ``encoder_padding_mask``.
            src_tokens: tensor whose first dimension gives the batch size
                and whose dtype/device are reused for the new tokens.

        Returns:
            DecoderOut holding the initial tokens, all-zero scores,
            ``step=0`` and ``max_step=0``.
        """
        features = encoder_out.encoder_out  # T x B x C
        padding_mask = encoder_out.encoder_padding_mask  # B x T or None

        if padding_mask is not None:
            # Count unmasked positions per sentence.
            keep = (~padding_mask).transpose(0, 1).type_as(features)
            src_lengths = keep.sum(0)
        else:
            # No mask: every sentence spans the full time dimension.
            src_lengths = features.new_ones(features.size(1)).fill_(
                features.size(0))
        src_lengths = src_lengths.long()

        length_tgt = src_lengths * 2

        # clamp_ works in place, so length_tgt itself is >= 2 from here on.
        longest = length_tgt.clamp_(min=2).max()
        columns = utils.new_arange(src_tokens, longest)
        batch_size = src_tokens.size(0)

        tokens = src_tokens.new_zeros(batch_size, longest).fill_(self.pad)
        inside_target = columns.unsqueeze(0) < length_tgt.unsqueeze(1)
        tokens.masked_fill_(inside_target, self.unk)
        tokens[:, 0] = self.bos
        # Overwrite the last in-length slot of every row with eos.
        tokens.scatter_(1, (length_tgt - 1).unsqueeze(1), self.eos)

        scores = tokens.new_zeros(*tokens.size())
        scores = scores.type_as(encoder_out.encoder_out)

        return DecoderOut(output_tokens=tokens,
                          output_scores=scores,
                          attn=None,
                          step=0,
                          max_step=0,
                          history=None)
Пример #4
0
    def initialize_output_tokens(self, encoder_out, src_tokens):
        """Create a minimal ``[bos, eos]`` canvas for every sentence.

        Args:
            encoder_out: object whose ``encoder_out`` attribute determines
                the type of the returned scores.
            src_tokens: tensor whose first dimension gives the batch size
                and whose dtype/device are reused for the new tokens.

        Returns:
            DecoderOut holding a two-column token tensor, all-zero scores,
            ``step=0`` and ``max_step=0``.
        """
        batch_size = src_tokens.size(0)
        tokens = src_tokens.new_zeros(batch_size, 2)
        tokens[:, 0] = self.bos
        tokens[:, 1] = self.eos

        scores = tokens.new_zeros(*tokens.size())
        scores = scores.type_as(encoder_out.encoder_out)

        return DecoderOut(
            output_tokens=tokens,
            output_scores=scores,
            attn=None,
            step=0,
            max_step=0,
            history=None,
        )
    def initialize_output_tokens(self, encoder_out, src_tokens, tgt_lang):
        """Create the initial decoder canvas for a given target language.

        Target lengths come from ``self.decoder.forward_length_prediction``
        applied to ``self.decoder.forward_length(normalize=True, ...)`` and
        are clamped in place to at least 2. Each row is laid out as
        ``bos, <language tag>, unk, ..., unk, eos, pad, ...`` with ``eos``
        at index ``length - 1`` (for a length of 2 the eos overwrites the
        language tag, because it is written last).

        Args:
            encoder_out: encoder output passed to the length modules; its
                ``encoder_out`` attribute determines the score type.
            src_tokens: tensor whose first dimension gives the batch size
                and whose dtype/device are reused for the new tensors.
            tgt_lang: ``"en"`` or ``"ch"``. Selects ``self.en_tag`` /
                ``self.ch_tag`` as the second token and an all-zero /
                all-one ``langs`` tensor respectively.

        Returns:
            Tuple ``(langs, positions, decoder_out)`` where ``langs`` is the
            per-position language id tensor, ``positions`` holds 1-based
            positions with 0 beyond each target length, and ``decoder_out``
            is a DecoderOut with the initial tokens, all-zero scores,
            ``step=0`` and ``max_step=0``.

        Raises:
            ValueError: if ``tgt_lang`` is neither ``"en"`` nor ``"ch"``.
        """
        # Fail fast with a real exception. The previous
        # `assert tgt_lang == ("en", "ch")` compared a string to a tuple (so
        # it could never pass) and is stripped under `python -O`, which
        # left `langs` unbound at the return statement.
        if tgt_lang not in ("en", "ch"):
            raise ValueError(
                "tgt_lang must be 'en' or 'ch', got {!r}".format(tgt_lang))

        # Function-scope import keeps this block self-contained.
        import logging

        # length prediction
        length_tgt = self.decoder.forward_length_prediction(
            self.decoder.forward_length(normalize=True,
                                        encoder_out=encoder_out),
            tgt_lengths=None)
        # Debug-level log instead of an unconditional print on every call;
        # %-style args avoid formatting the tensor unless debug is enabled.
        logging.getLogger(__name__).debug("predict tgt_lengths: %s",
                                          length_tgt)

        # clamp_ works in place, so length_tgt itself is >= 2 from here on.
        max_length = length_tgt.clamp_(min=2).max()
        idx_length = utils.new_arange(src_tokens, max_length)
        bsz = src_tokens.size(0)

        # 1-based positions, created directly on the target device, with 0
        # written to every slot beyond the sentence's target length.
        positions = torch.arange(
            1, max_length + 1,
            device=src_tokens.device)[None, :].repeat(bsz, 1)
        positions.masked_fill_(idx_length[None, :] + 1 > length_tgt[:, None],
                               0)

        initial_output_tokens = src_tokens.new_zeros(
            bsz, max_length).long().fill_(self.pad)
        initial_output_tokens.masked_fill_(
            idx_length[None, :] < length_tgt[:, None], self.unk)
        initial_output_tokens[:, 0] = self.bos

        if tgt_lang == "en":
            initial_output_tokens[:, 1] = self.en_tag
            langs = src_tokens.new_zeros(bsz, max_length).long()
        else:  # "ch" -- guaranteed by the validation above
            initial_output_tokens[:, 1] = self.ch_tag
            langs = src_tokens.new_ones(bsz, max_length).long()

        # eos is written last so it always closes the sequence.
        initial_output_tokens.scatter_(1, length_tgt[:, None] - 1, self.eos)
        initial_output_scores = initial_output_tokens.new_zeros(
            *initial_output_tokens.size()).type_as(encoder_out.encoder_out)

        return langs, positions, DecoderOut(
            output_tokens=initial_output_tokens,
            output_scores=initial_output_scores,
            attn=None,
            step=0,
            max_step=0,
            history=None)
Пример #6
0
    def initialize_output_tokens(self, encoder_out, src_tokens, initial_tokens,
                                 initial_marks):
        """Create the initial canvas, optionally seeded with given tokens.

        Without ``initial_tokens`` every row is ``[bos, eos]``. Otherwise the
        canvas is ``max_len + 2`` columns wide, where ``max_len`` is the
        longest row of ``initial_tokens``; each row becomes
        ``bos, <its tokens>, eos, pad, ...``. ``initial_marks`` values, when
        given, are copied into a same-shaped marks tensor starting at
        column 1; all other mark entries stay 0.

        Args:
            encoder_out: object whose ``encoder_out`` attribute determines
                the type of the returned scores.
            src_tokens: tensor whose first dimension gives the batch size
                and whose dtype/device are reused for the new tensors.
            initial_tokens: object with ``tolist()`` yielding one token
                sequence per sentence, or None.
            initial_marks: object with ``tolist()`` yielding one mark
                sequence per sentence, or None.

        Returns:
            DecoderOut holding the tokens, marks, all-zero scores,
            ``step=0``, ``max_step=0`` and ``num_ops=(0, 0)``.
        """
        token_rows = None if initial_tokens is None else initial_tokens.tolist()
        mark_rows = None if initial_marks is None else initial_marks.tolist()

        longest = max(len(row) for row in token_rows) if token_rows else 0
        # Two extra columns: one for bos, one for eos.
        width = longest + 2
        batch_size = src_tokens.size(0)

        marks = src_tokens.new_zeros(batch_size, width)
        tokens = src_tokens.new_zeros(batch_size, width)
        tokens[:, 0] = self.bos
        tokens[:, 1] = self.eos

        if token_rows:
            for row_idx, row in enumerate(token_rows):
                for col, tok in enumerate(row, start=1):
                    tokens[row_idx, col] = tok
                eos_col = len(row) + 1
                tokens[row_idx, eos_col] = self.eos
                # Everything after eos is padding (no-op for full rows).
                tokens[row_idx, eos_col + 1:] = self.pad

        if mark_rows:
            for row_idx, row in enumerate(mark_rows):
                for col, mark in enumerate(row, start=1):
                    marks[row_idx, col] = mark

        scores = tokens.new_zeros(*tokens.size())
        scores = scores.type_as(encoder_out.encoder_out)

        return DecoderOut(
            output_tokens=tokens,
            output_marks=marks,
            output_scores=scores,
            attn=None,
            step=0,
            max_step=0,
            num_ops=(0, 0),
            history=None,
        )
Пример #7
0
    def initialize_output_tokens(self, encoder_out, src_tokens):
        """Create a ``[bos, eos]`` canvas plus an initial attention tensor.

        Args:
            encoder_out: object whose ``encoder_out`` tensor is the target
                of the ``.to(...)`` conversion applied to the scores.
            src_tokens: tensor whose dimensions 0 and 1 size the outputs
                and whose dtype/device are reused for the new tokens.

        Returns:
            DecoderOut holding a two-column token tensor, all-zero scores,
            ``step=0`` and ``max_step=0``. ``attn`` is a zero tensor of
            shape ``[src_tokens.size(0), 2, src_tokens.size(1)]`` converted
            with ``.to(tokens)`` when the last decoder layer's ``need_attn``
            is truthy (treated as True if the attribute is missing), and
            ``torch.empty([0])`` otherwise.
        """
        batch_size = src_tokens.size(0)
        bos_column = src_tokens.new_zeros(batch_size, 1).fill_(self.bos)
        eos_column = src_tokens.new_zeros(batch_size, 1).fill_(self.eos)
        tokens = torch.cat([bos_column, eos_column], 1)

        scores = torch.zeros_like(tokens).to(encoder_out.encoder_out)

        last_layer = self.decoder.layers[-1]
        if getattr(last_layer, "need_attn", True):
            attn = torch.zeros([batch_size, 2, src_tokens.size(1)]).to(tokens)
        else:
            # Empty placeholder when attention is not requested.
            attn = torch.empty([0])

        return DecoderOut(
            output_tokens=tokens,
            output_scores=scores,
            attn=attn,
            step=0,
            max_step=0,
            history=None,
        )