Example #1

# NOTE: the classes below assume the following imports; `ut`, `ac`, and the
# Encoder/Decoder classes are project-local modules whose exact import paths
# are not shown here and differ between the snippets.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
class Model(nn.Module):
    """Model"""
    def __init__(self, config):
        super(Model, self).__init__()
        self.config = config

        self.init_embeddings()
        self.init_model()

    def init_embeddings(self):
        embed_dim = self.config['embed_dim']
        tie_mode = self.config['tie_mode']
        fix_norm = self.config['fix_norm']
        max_pos_length = self.config['max_pos_length']
        learned_pos = self.config['learned_pos']

        # get positional embedding
        if not learned_pos:
            self.pos_embedding = ut.get_positional_encoding(
                embed_dim, max_pos_length)
        else:
            self.pos_embedding = Parameter(
                torch.Tensor(max_pos_length, embed_dim))
            nn.init.normal_(self.pos_embedding, mean=0, std=embed_dim**-0.5)

        # get word embeddings
        src_vocab_size, trg_vocab_size = ut.get_vocab_sizes(self.config)
        self.src_vocab_mask, self.trg_vocab_mask = ut.get_vocab_masks(
            self.config, src_vocab_size, trg_vocab_size)
        if tie_mode == ac.ALL_TIED:
            src_vocab_size = trg_vocab_size = self.trg_vocab_mask.shape[0]

        self.out_bias = Parameter(torch.Tensor(trg_vocab_size))
        nn.init.constant_(self.out_bias, 0.)

        self.src_embedding = nn.Embedding(src_vocab_size, embed_dim)
        self.trg_embedding = nn.Embedding(trg_vocab_size, embed_dim)
        self.out_embedding = self.trg_embedding.weight
        self.embed_scale = embed_dim**0.5

        if tie_mode == ac.ALL_TIED:
            self.src_embedding.weight = self.trg_embedding.weight

        if not fix_norm:
            nn.init.normal_(self.src_embedding.weight,
                            mean=0,
                            std=embed_dim**-0.5)
            nn.init.normal_(self.trg_embedding.weight,
                            mean=0,
                            std=embed_dim**-0.5)
        else:
            d = 0.01  # pure magic
            nn.init.uniform_(self.src_embedding.weight, a=-d, b=d)
            nn.init.uniform_(self.trg_embedding.weight, a=-d, b=d)

    def init_model(self):
        num_enc_layers = self.config['num_enc_layers']
        num_enc_heads = self.config['num_enc_heads']
        num_dec_layers = self.config['num_dec_layers']
        num_dec_heads = self.config['num_dec_heads']

        embed_dim = self.config['embed_dim']
        ff_dim = self.config['ff_dim']
        dropout = self.config['dropout']
        norm_in = self.config['norm_in']

        # get encoder, decoder
        self.encoder = Encoder(num_enc_layers,
                               num_enc_heads,
                               embed_dim,
                               ff_dim,
                               dropout=dropout,
                               norm_in=norm_in)
        self.decoder = Decoder(num_dec_layers,
                               num_dec_heads,
                               embed_dim,
                               ff_dim,
                               dropout=dropout,
                               norm_in=norm_in)

        # leave layer norm alone
        init_func = (nn.init.xavier_normal_
                     if self.config['weight_init_type'] == ac.XAVIER_NORMAL
                     else nn.init.xavier_uniform_)
        for m in [
                self.encoder.self_atts, self.encoder.pos_ffs,
                self.decoder.self_atts, self.decoder.pos_ffs,
                self.decoder.enc_dec_atts
        ]:
            for p in m.parameters():
                if p.dim() > 1:
                    init_func(p)
                else:
                    nn.init.constant_(p, 0.)

    def get_input(self, toks, is_src=True):
        embeds = self.src_embedding if is_src else self.trg_embedding
        word_embeds = embeds(toks)  # [bsz, max_len, embed_dim]
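        # fix_norm: L2-normalize the word embeddings instead of scaling by sqrt(embed_dim)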
        if self.config['fix_norm']:
            word_embeds = ut.normalize(word_embeds, scale=False)
        else:
            word_embeds = word_embeds * self.embed_scale

        if toks.size()[-1] > self.pos_embedding.size()[-2]:
            ut.get_logger().error(
                "Sentence length ({}) is longer than max_pos_length ({}); please increase max_pos_length"
                .format(toks.size()[-1],
                        self.pos_embedding.size()[0]))

        pos_embeds = self.pos_embedding[:toks.size()[-1], :].unsqueeze(
            0)  # [1, max_len, embed_dim]
        return word_embeds + pos_embeds

    def forward(self, src_toks, trg_toks, targets):
        encoder_mask = (src_toks == ac.PAD_ID).unsqueeze(1).unsqueeze(
            2)  # [bsz, 1, 1, max_src_len]
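        # causal ("no peeking ahead") mask: decoder position i may only attend to positions <= i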
        decoder_mask = torch.triu(torch.ones(
            (trg_toks.size()[-1], trg_toks.size()[-1])),
                                  diagonal=1).type(trg_toks.type()) == 1
        decoder_mask = decoder_mask.unsqueeze(0).unsqueeze(1)

        encoder_inputs = self.get_input(src_toks, is_src=True)
        encoder_outputs = self.encoder(encoder_inputs, encoder_mask)

        decoder_inputs = self.get_input(trg_toks, is_src=False)
        decoder_outputs = self.decoder(decoder_inputs, decoder_mask,
                                       encoder_outputs, encoder_mask)

        logits = self.logit_fn(decoder_outputs)
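        # label-smoothed cross entropy (cf. https://arxiv.org/pdf/1701.06548.pdf):
        # interpolate the gold-token NLL with the average NLL over the masked vocabulary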
        neglprobs = F.log_softmax(logits, -1)
        neglprobs = neglprobs * self.trg_vocab_mask.type(
            neglprobs.type()).reshape(1, -1)
        targets = targets.reshape(-1, 1)
        non_pad_mask = targets != ac.PAD_ID
        nll_loss = -neglprobs.gather(dim=-1, index=targets)[non_pad_mask]
        smooth_loss = -neglprobs.sum(dim=-1, keepdim=True)[non_pad_mask]

        nll_loss = nll_loss.sum()
        smooth_loss = smooth_loss.sum()
        label_smoothing = self.config['label_smoothing']

        if label_smoothing > 0:
            loss = (
                1.0 - label_smoothing
            ) * nll_loss + label_smoothing * smooth_loss / self.trg_vocab_mask.type(
                smooth_loss.type()).sum()
        else:
            loss = nll_loss

        return {'loss': loss, 'nll_loss': nll_loss}

    def logit_fn(self, decoder_output):
        softmax_weight = (ut.normalize(self.out_embedding, scale=True)
                          if self.config['fix_norm'] else self.out_embedding)
        logits = F.linear(decoder_output, softmax_weight, bias=self.out_bias)
        logits = logits.reshape(-1, logits.size()[-1])
        logits[:, ~self.trg_vocab_mask] = -1e9
        return logits

    def beam_decode(self, src_toks):
        """Translate a minibatch of sentences. 

        Arguments: src_toks[i,j] is the jth word of sentence i.

        Return: See encoders.Decoder.beam_decode
        """
        encoder_mask = (src_toks == ac.PAD_ID).unsqueeze(1).unsqueeze(
            2)  # [bsz, 1, 1, max_src_len]
        encoder_inputs = self.get_input(src_toks, is_src=True)
        encoder_outputs = self.encoder(encoder_inputs, encoder_mask)
        max_lengths = torch.sum(src_toks != ac.PAD_ID, dim=-1).type(
            src_toks.type()) + 50

        def get_trg_inp(ids, time_step):
            ids = ids.type(src_toks.type())
            word_embeds = self.trg_embedding(ids)
            if self.config['fix_norm']:
                word_embeds = ut.normalize(word_embeds, scale=False)
            else:
                word_embeds = word_embeds * self.embed_scale

            pos_embeds = self.pos_embedding[time_step, :].reshape(1, 1, -1)
            return word_embeds + pos_embeds

        def logprob(decoder_output):
            return F.log_softmax(self.logit_fn(decoder_output), dim=-1)

        if self.config['length_model'] == 'gnmt':
            length_model = ut.gnmt_length_model(self.config['length_alpha'])
        elif self.config['length_model'] == 'linear':
            length_model = lambda t, p: p + self.config['length_alpha'] * t
        elif self.config['length_model'] == 'none':
            length_model = lambda t, p: p
        else:
            raise ValueError("invalid length_model '{}'".format(
                self.config['length_model']))

        return self.decoder.beam_decode(encoder_outputs,
                                        encoder_mask,
                                        get_trg_inp,
                                        logprob,
                                        length_model,
                                        ac.BOS_ID,
                                        ac.EOS_ID,
                                        max_lengths,
                                        beam_size=self.config['beam_size'])
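# ---------------------------------------------------------------------------
# Hedged usage sketch (not in the original source): one way the Model above
# might be driven end to end. It assumes the project-local helpers (`ut`,
# `ac`, Encoder, Decoder) are importable, that `ut.get_vocab_sizes` and
# `ut.get_vocab_masks` can work from this config alone (the real helpers may
# need extra keys such as data paths), and that the special ids
# (PAD/BOS/EOS/UNK) occupy the low indices avoided by the random ids below.
# All hyperparameter values are illustrative only.
def _example_model_usage():
    config = {
        'embed_dim': 512,
        'ff_dim': 2048,
        'num_enc_layers': 6, 'num_enc_heads': 8,
        'num_dec_layers': 6, 'num_dec_heads': 8,
        'dropout': 0.1,
        'norm_in': True,
        'tie_mode': ac.ALL_TIED,
        'fix_norm': False,
        'learned_pos': False,
        'max_pos_length': 1024,
        'weight_init_type': ac.XAVIER_NORMAL,
        'label_smoothing': 0.1,
        'beam_size': 4,
        'length_model': 'gnmt',
        'length_alpha': 0.6,
    }
    model = Model(config)

    bsz, src_len, trg_len = 2, 7, 5
    src_toks = torch.randint(4, 100, (bsz, src_len))  # fake source ids
    trg_toks = torch.randint(4, 100, (bsz, trg_len))  # BOS-shifted target ids
    targets = torch.randint(4, 100, (bsz, trg_len))   # gold next-token ids

    losses = model(src_toks, trg_toks, targets)       # {'loss': ..., 'nll_loss': ...}

    model.eval()
    with torch.no_grad():
        hyps = model.beam_decode(src_toks)            # see encoders.Decoder.beam_decode
    return losses, hyps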
class Transformer(nn.Module):
    """Transformer https://arxiv.org/pdf/1706.03762.pdf"""
    def __init__(self, args):
        super(Transformer, self).__init__()
        self.args = args

        embed_dim = args.embed_dim
        fix_norm = args.fix_norm
        joint_vocab_size = args.joint_vocab_size
        lang_vocab_size = args.lang_vocab_size
        use_bias = args.use_bias
        self.scale = embed_dim**0.5

        if args.mask_logit:
            # mask logits separately per language
            self.logit_mask = None
        else:
            # otherwise, use the same mask for all
            # this only masks out BOS and PAD
            mask = [1.] * joint_vocab_size
            mask[ac.BOS_ID] = 0.
            mask[ac.PAD_ID] = 0.
            # keep the mask boolean so that `~logit_mask` in logit_fn is a logical not
            self.logit_mask = torch.tensor(mask).type(torch.bool)

        self.word_embedding = Parameter(
            torch.Tensor(joint_vocab_size, embed_dim))
        self.lang_embedding = Parameter(
            torch.Tensor(lang_vocab_size, embed_dim))
        self.out_bias = Parameter(
            torch.Tensor(joint_vocab_size)) if use_bias else None

        self.encoder = Encoder(args)
        self.decoder = Decoder(args)

        # initialize
        nn.init.normal_(self.lang_embedding, mean=0, std=embed_dim**-0.5)
        if fix_norm:
            d = 0.01
            nn.init.uniform_(self.word_embedding, a=-d, b=d)
        else:
            nn.init.normal_(self.word_embedding, mean=0, std=embed_dim**-0.5)

        if use_bias:
            nn.init.constant_(self.out_bias, 0.)

    def replace_with_unk(self, toks):
        # word-dropout
        p = self.args.word_dropout
        if self.training and 0 < p < 1:
            non_pad_mask = toks != ac.PAD_ID
            mask = (torch.rand(toks.size()) <= p).type(non_pad_mask.type())
            mask = mask & non_pad_mask  # drop a position only if it is sampled and non-pad
            toks[mask] = ac.UNK_ID

    def get_input(self, toks, lang_idx, word_embedding, pos_embedding):
        # word dropout, but replace with unk instead of zero-ing embed
        self.replace_with_unk(toks)
        word_embed = F.embedding(
            toks, word_embedding) * self.scale  # [bsz, len, dim]
        lang_embed = self.lang_embedding[lang_idx].unsqueeze(0).unsqueeze(
            1)  # [1, 1, dim]
        pos_embed = pos_embedding[:toks.size(-1), :].unsqueeze(
            0)  # [1, len, dim]

        return word_embed + lang_embed + pos_embed

    def forward(self, src, tgt, targets, src_lang_idx, tgt_lang_idx,
                logit_mask):
        embed_dim = self.args.embed_dim
        max_len = max(src.size(1), tgt.size(1))
        pos_embedding = ut.get_positional_encoding(embed_dim, max_len)
        word_embedding = (F.normalize(self.word_embedding, dim=-1)
                          if self.args.fix_norm else self.word_embedding)

        encoder_inputs = self.get_input(src, src_lang_idx, word_embedding,
                                        pos_embedding)
        encoder_mask = (src == ac.PAD_ID).unsqueeze(1).unsqueeze(2)
        encoder_outputs = self.encoder(encoder_inputs, encoder_mask)

        decoder_inputs = self.get_input(tgt, tgt_lang_idx, word_embedding,
                                        pos_embedding)
        decoder_mask = torch.triu(torch.ones((tgt.size(-1), tgt.size(-1))),
                                  diagonal=1).type(tgt.type()) == 1
        decoder_mask = decoder_mask.unsqueeze(0).unsqueeze(1)
        decoder_outputs = self.decoder(decoder_inputs, decoder_mask,
                                       encoder_outputs, encoder_mask)

        logit_mask = logit_mask if self.logit_mask is None else self.logit_mask
        logits = self.logit_fn(decoder_outputs, word_embedding, logit_mask)
        neglprobs = F.log_softmax(logits, -1) * logit_mask.type(
            logits.type()).reshape(1, -1)
        targets = targets.reshape(-1, 1)
        non_pad_mask = targets != ac.PAD_ID

        nll_loss = neglprobs.gather(dim=-1, index=targets)[non_pad_mask]
        smooth_loss = neglprobs.sum(dim=-1, keepdim=True)[non_pad_mask]

        # label smoothing: https://arxiv.org/pdf/1701.06548.pdf
        nll_loss = -(nll_loss.sum())
        smooth_loss = -(smooth_loss.sum())
        label_smoothing = self.args.label_smoothing
        if label_smoothing > 0:
            loss = (
                1.0 - label_smoothing
            ) * nll_loss + label_smoothing * smooth_loss / logit_mask.type(
                nll_loss.type()).sum()
        else:
            loss = nll_loss

        num_words = non_pad_mask.type(loss.type()).sum()
        opt_loss = loss / num_words
        return {
            'opt_loss': opt_loss,
            'loss': loss,
            'nll_loss': nll_loss,
            'num_words': num_words
        }

    def logit_fn(self, decoder_output, softmax_weight, logit_mask):
        logits = F.linear(decoder_output, softmax_weight, bias=self.out_bias)
        logits = logits.reshape(-1, logits.size(-1))
        logits[:, ~logit_mask] = -1e9
        return logits

    def beam_decode(self, src, src_lang_idx, tgt_lang_idx, logit_mask):
        embed_dim = self.args.embed_dim
        max_len = src.size(1) + 51
        pos_embedding = ut.get_positional_encoding(embed_dim, max_len)
        word_embedding = (F.normalize(self.word_embedding, dim=-1)
                          if self.args.fix_norm else self.word_embedding)
        logit_mask = logit_mask if self.logit_mask is None else self.logit_mask
        tgt_lang_embed = self.lang_embedding[tgt_lang_idx]

        encoder_inputs = self.get_input(src, src_lang_idx, word_embedding,
                                        pos_embedding)
        encoder_mask = (src == ac.PAD_ID).unsqueeze(1).unsqueeze(2)
        encoder_outputs = self.encoder(encoder_inputs, encoder_mask)

        def get_tgt_inp(tgt, time_step):
            word_embed = F.embedding(tgt.type(src.type()),
                                     word_embedding) * self.scale
            pos_embed = pos_embedding[time_step, :].reshape(1, 1, -1)
            return word_embed + tgt_lang_embed + pos_embed

        def logprob_fn(decoder_output):
            logits = self.logit_fn(decoder_output, word_embedding, logit_mask)
            return F.log_softmax(logits, dim=-1)

        # following Attention is all you need, we decode up to src_len + 50 tokens only
        max_lengths = torch.sum(src != ac.PAD_ID, dim=-1).type(src.type()) + 50
        return self.decoder.beam_decode(encoder_outputs,
                                        encoder_mask,
                                        get_tgt_inp,
                                        logprob_fn,
                                        ac.BOS_ID,
                                        ac.EOS_ID,
                                        max_lengths,
                                        beam_size=self.args.beam_size,
                                        alpha=self.args.beam_alpha)
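# ---------------------------------------------------------------------------
# Hedged usage sketch (not in the original source) for the multilingual
# Transformer above. The attribute names mirror what this class reads from
# `args`; Encoder(args)/Decoder(args) almost certainly read additional
# hyperparameters (layer/head counts etc.) that are omitted here, and all
# values are illustrative only.
def _example_transformer_usage():
    from argparse import Namespace

    args = Namespace(
        embed_dim=512,
        fix_norm=False,
        joint_vocab_size=32000,   # shared subword vocabulary across languages
        lang_vocab_size=4,        # number of language tags
        use_bias=True,
        mask_logit=False,         # use the shared BOS/PAD mask built in __init__
        word_dropout=0.1,
        label_smoothing=0.1,
        beam_size=4,
        beam_alpha=0.6,
    )
    model = Transformer(args)

    src = torch.randint(4, args.joint_vocab_size, (2, 9))
    tgt = torch.randint(4, args.joint_vocab_size, (2, 6))
    targets = torch.randint(4, args.joint_vocab_size, (2, 6))

    # logit_mask=None is fine here because with mask_logit=False the shared
    # BOS/PAD mask built in __init__ is used instead of a per-language one.
    out = model(src, tgt, targets, src_lang_idx=0, tgt_lang_idx=1,
                logit_mask=None)
    return out['opt_loss'], out['num_words']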
class Model(nn.Module):
    """Model"""
    def __init__(self, config):
        super(Model, self).__init__()
        self.config = config

        self.init_embeddings()
        self.init_model()

    def init_embeddings(self):
        embed_dim = self.config['embed_dim']
        tie_mode = self.config['tie_mode']
        max_pos_length = self.config['max_pos_length']
        learned_pos = self.config['learned_pos']

        # get positional embedding
        if not learned_pos:
            self.pos_embedding = ut.get_positional_encoding(
                embed_dim, max_pos_length)
        else:
            self.pos_embedding = Parameter(
                torch.Tensor(max_pos_length, embed_dim))
            nn.init.normal_(self.pos_embedding, mean=0, std=embed_dim**-0.5)

        # get word embeddings
        src_vocab_size, trg_vocab_size = ut.get_vocab_sizes(self.config)
        self.src_vocab_mask, self.trg_vocab_mask = ut.get_vocab_masks(
            self.config, src_vocab_size, trg_vocab_size)
        if tie_mode == ac.ALL_TIED:
            src_vocab_size = trg_vocab_size = self.trg_vocab_mask.shape[0]

        self.out_bias = Parameter(torch.Tensor(trg_vocab_size))
        nn.init.constant_(self.out_bias, 0.)

        self.src_embedding = nn.Embedding(src_vocab_size, embed_dim)
        self.trg_embedding = nn.Embedding(trg_vocab_size, embed_dim)
        self.out_embedding = self.trg_embedding.weight
        self.embed_scale = embed_dim**0.5

        if tie_mode == ac.ALL_TIED:
            self.src_embedding.weight = self.trg_embedding.weight

        nn.init.normal_(self.src_embedding.weight, mean=0, std=embed_dim**-0.5)
        nn.init.normal_(self.trg_embedding.weight, mean=0, std=embed_dim**-0.5)

    def init_model(self):
        num_enc_layers = self.config['num_enc_layers']
        num_enc_heads = self.config['num_enc_heads']
        num_dec_layers = self.config['num_dec_layers']
        num_dec_heads = self.config['num_dec_heads']

        embed_dim = self.config['embed_dim']
        ff_dim = self.config['ff_dim']
        dropout = self.config['dropout']

        # get encoder, decoder
        self.encoder = Encoder(num_enc_layers,
                               num_enc_heads,
                               embed_dim,
                               ff_dim,
                               dropout=dropout)
        self.decoder = Decoder(num_dec_layers,
                               num_dec_heads,
                               embed_dim,
                               ff_dim,
                               dropout=dropout)

        # leave layer norm alone
        init_func = (nn.init.xavier_normal_
                     if self.config['init_type'] == ac.XAVIER_NORMAL
                     else nn.init.xavier_uniform_)
        for m in [
                self.encoder.self_atts, self.encoder.pos_ffs,
                self.decoder.self_atts, self.decoder.pos_ffs,
                self.decoder.enc_dec_atts
        ]:
            for p in m.parameters():
                if p.dim() > 1:
                    init_func(p)
                else:
                    nn.init.constant_(p, 0.)

    def get_input(self, toks, is_src=True):
        embeds = self.src_embedding if is_src else self.trg_embedding
        word_embeds = embeds(toks)  # [bsz, max_len, embed_dim]
        pos_embeds = self.pos_embedding[:toks.size()[-1], :].unsqueeze(
            0)  # [1, max_len, embed_dim]
        return word_embeds * self.embed_scale + pos_embeds

    def forward(self, src_toks, trg_toks, targets):
        encoder_mask = (src_toks == ac.PAD_ID).unsqueeze(1).unsqueeze(
            2)  # [bsz, 1, 1, max_src_len]
        decoder_mask = torch.triu(torch.ones(
            (trg_toks.size()[-1], trg_toks.size()[-1])),
                                  diagonal=1).type(trg_toks.type()) == 1
        decoder_mask = decoder_mask.unsqueeze(0).unsqueeze(1)

        encoder_inputs = self.get_input(src_toks, is_src=True)
        encoder_outputs = self.encoder(encoder_inputs, encoder_mask)

        decoder_inputs = self.get_input(trg_toks, is_src=False)
        decoder_outputs = self.decoder(decoder_inputs, decoder_mask,
                                       encoder_outputs, encoder_mask)

        logits = self.logit_fn(decoder_outputs)
        neglprobs = F.log_softmax(logits, -1)
        neglprobs = neglprobs * self.trg_vocab_mask.type(
            neglprobs.type()).reshape(1, -1)
        targets = targets.reshape(-1, 1)
        non_pad_mask = targets != ac.PAD_ID
        nll_loss = -neglprobs.gather(dim=-1, index=targets)[non_pad_mask]
        smooth_loss = -neglprobs.sum(dim=-1, keepdim=True)[non_pad_mask]

        nll_loss = nll_loss.sum()
        smooth_loss = smooth_loss.sum()
        label_smoothing = self.config['label_smoothing']
        loss = (
            1.0 - label_smoothing
        ) * nll_loss + label_smoothing * smooth_loss / self.trg_vocab_mask.type(
            smooth_loss.type()).sum()

        return {'loss': loss, 'nll_loss': nll_loss}

    def logit_fn(self, decoder_output):
        logits = F.linear(decoder_output,
                          self.out_embedding,
                          bias=self.out_bias)
        logits = logits.reshape(-1, logits.size()[-1])
        logits[:, ~self.trg_vocab_mask] = -1e9
        return logits

    def beam_decode(self, src_toks):
        encoder_mask = (src_toks == ac.PAD_ID).unsqueeze(1).unsqueeze(
            2)  # [bsz, 1, 1, max_src_len]
        encoder_inputs = self.get_input(src_toks, is_src=True)
        encoder_outputs = self.encoder(encoder_inputs, encoder_mask)
        max_lengths = torch.sum(src_toks != ac.PAD_ID, dim=-1).type(
            src_toks.type()) + 50

        def get_trg_inp(ids, time_step):
            ids = ids.type(src_toks.type())
            word_embeds = self.trg_embedding(ids)
            pos_embeds = self.pos_embedding[time_step, :].reshape(1, 1, -1)
            return word_embeds * self.embed_scale + pos_embeds

        def logprob(decoder_output):
            return F.log_softmax(self.logit_fn(decoder_output), dim=-1)

        return self.decoder.beam_decode(encoder_outputs,
                                        encoder_mask,
                                        get_trg_inp,
                                        logprob,
                                        ac.BOS_ID,
                                        ac.EOS_ID,
                                        max_lengths,
                                        beam_size=self.config['beam_size'],
                                        alpha=self.config['beam_alpha'])
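# ---------------------------------------------------------------------------
# Hedged sketch (not in the original source): every snippet above indexes
# `ut.get_positional_encoding(embed_dim, max_len)` as a [max_len, embed_dim]
# tensor. A standard sinusoidal encoding in that shape could look like the
# following (assuming an even embed_dim); the real `ut` helper may differ in
# dtype/device handling or in how it is cached.
import math

def sinusoidal_positional_encoding(embed_dim, max_len):
    position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)      # [max_len, 1]
    inv_freq = torch.exp(torch.arange(0, embed_dim, 2, dtype=torch.float)
                         * (-math.log(10000.0) / embed_dim))              # [embed_dim // 2]
    pe = torch.zeros(max_len, embed_dim)
    pe[:, 0::2] = torch.sin(position * inv_freq)  # even dimensions
    pe[:, 1::2] = torch.cos(position * inv_freq)  # odd dimensions
    return pe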