# NOTE: the import paths for the project-local helpers (utils, constants,
# encoder/decoder layers) are assumed here; adjust them to match the actual
# package layout.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter

import all_constants as ac
import utils as ut
from encoders import Encoder, Decoder


class Model(nn.Module):
    """Model"""

    def __init__(self, config):
        super(Model, self).__init__()
        self.config = config

        self.init_embeddings()
        self.init_model()

    def init_embeddings(self):
        embed_dim = self.config['embed_dim']
        tie_mode = self.config['tie_mode']
        fix_norm = self.config['fix_norm']
        max_pos_length = self.config['max_pos_length']
        learned_pos = self.config['learned_pos']

        # get positional embedding
        if not learned_pos:
            self.pos_embedding = ut.get_positional_encoding(embed_dim, max_pos_length)
        else:
            self.pos_embedding = Parameter(torch.Tensor(max_pos_length, embed_dim))
            nn.init.normal_(self.pos_embedding, mean=0, std=embed_dim ** -0.5)

        # get word embeddings
        src_vocab_size, trg_vocab_size = ut.get_vocab_sizes(self.config)
        self.src_vocab_mask, self.trg_vocab_mask = ut.get_vocab_masks(
            self.config, src_vocab_size, trg_vocab_size)
        if tie_mode == ac.ALL_TIED:
            src_vocab_size = trg_vocab_size = self.trg_vocab_mask.shape[0]

        self.out_bias = Parameter(torch.Tensor(trg_vocab_size))
        nn.init.constant_(self.out_bias, 0.)

        self.src_embedding = nn.Embedding(src_vocab_size, embed_dim)
        self.trg_embedding = nn.Embedding(trg_vocab_size, embed_dim)
        self.out_embedding = self.trg_embedding.weight
        self.embed_scale = embed_dim ** 0.5

        if tie_mode == ac.ALL_TIED:
            self.src_embedding.weight = self.trg_embedding.weight

        if not fix_norm:
            nn.init.normal_(self.src_embedding.weight, mean=0, std=embed_dim ** -0.5)
            nn.init.normal_(self.trg_embedding.weight, mean=0, std=embed_dim ** -0.5)
        else:
            d = 0.01  # pure magic
            nn.init.uniform_(self.src_embedding.weight, a=-d, b=d)
            nn.init.uniform_(self.trg_embedding.weight, a=-d, b=d)

    def init_model(self):
        num_enc_layers = self.config['num_enc_layers']
        num_enc_heads = self.config['num_enc_heads']
        num_dec_layers = self.config['num_dec_layers']
        num_dec_heads = self.config['num_dec_heads']

        embed_dim = self.config['embed_dim']
        ff_dim = self.config['ff_dim']
        dropout = self.config['dropout']
        norm_in = self.config['norm_in']

        # get encoder, decoder
        self.encoder = Encoder(num_enc_layers, num_enc_heads, embed_dim, ff_dim,
                               dropout=dropout, norm_in=norm_in)
        self.decoder = Decoder(num_dec_layers, num_dec_heads, embed_dim, ff_dim,
                               dropout=dropout, norm_in=norm_in)

        # initialize weights; leave layer norm alone
        init_func = nn.init.xavier_normal_ if self.config['weight_init_type'] == ac.XAVIER_NORMAL else nn.init.xavier_uniform_
        for m in [self.encoder.self_atts, self.encoder.pos_ffs,
                  self.decoder.self_atts, self.decoder.pos_ffs,
                  self.decoder.enc_dec_atts]:
            for p in m.parameters():
                if p.dim() > 1:
                    init_func(p)
                else:
                    nn.init.constant_(p, 0.)
    def get_input(self, toks, is_src=True):
        embeds = self.src_embedding if is_src else self.trg_embedding
        word_embeds = embeds(toks)  # [bsz, max_len, embed_dim]

        if self.config['fix_norm']:
            word_embeds = ut.normalize(word_embeds, scale=False)
        else:
            word_embeds = word_embeds * self.embed_scale

        if toks.size()[-1] > self.pos_embedding.size()[-2]:
            ut.get_logger().error(
                "Sentence length ({}) is longer than max_pos_length ({}); "
                "please increase max_pos_length".format(
                    toks.size()[-1], self.pos_embedding.size()[0]))

        pos_embeds = self.pos_embedding[:toks.size()[-1], :].unsqueeze(0)  # [1, max_len, embed_dim]
        return word_embeds + pos_embeds

    def forward(self, src_toks, trg_toks, targets):
        encoder_mask = (src_toks == ac.PAD_ID).unsqueeze(1).unsqueeze(2)  # [bsz, 1, 1, max_src_len]
        decoder_mask = torch.triu(torch.ones((trg_toks.size()[-1], trg_toks.size()[-1])),
                                  diagonal=1).type(trg_toks.type()) == 1
        decoder_mask = decoder_mask.unsqueeze(0).unsqueeze(1)

        encoder_inputs = self.get_input(src_toks, is_src=True)
        encoder_outputs = self.encoder(encoder_inputs, encoder_mask)

        decoder_inputs = self.get_input(trg_toks, is_src=False)
        decoder_outputs = self.decoder(decoder_inputs, decoder_mask,
                                       encoder_outputs, encoder_mask)

        logits = self.logit_fn(decoder_outputs)
        neglprobs = F.log_softmax(logits, -1)
        neglprobs = neglprobs * self.trg_vocab_mask.type(neglprobs.type()).reshape(1, -1)

        targets = targets.reshape(-1, 1)
        non_pad_mask = targets != ac.PAD_ID
        nll_loss = -neglprobs.gather(dim=-1, index=targets)[non_pad_mask]
        smooth_loss = -neglprobs.sum(dim=-1, keepdim=True)[non_pad_mask]

        nll_loss = nll_loss.sum()
        smooth_loss = smooth_loss.sum()
        label_smoothing = self.config['label_smoothing']
        if label_smoothing > 0:
            loss = (1.0 - label_smoothing) * nll_loss + \
                label_smoothing * smooth_loss / self.trg_vocab_mask.type(smooth_loss.type()).sum()
        else:
            loss = nll_loss

        return {'loss': loss, 'nll_loss': nll_loss}

    def logit_fn(self, decoder_output):
        softmax_weight = self.out_embedding if not self.config['fix_norm'] else ut.normalize(self.out_embedding, scale=True)
        logits = F.linear(decoder_output, softmax_weight, bias=self.out_bias)
        logits = logits.reshape(-1, logits.size()[-1])
        logits[:, ~self.trg_vocab_mask] = -1e9
        return logits

    def beam_decode(self, src_toks):
        """Translate a minibatch of sentences.

        Arguments: src_toks[i,j] is the jth word of sentence i.
        Return: See encoders.Decoder.beam_decode
        """
        encoder_mask = (src_toks == ac.PAD_ID).unsqueeze(1).unsqueeze(2)  # [bsz, 1, 1, max_src_len]
        encoder_inputs = self.get_input(src_toks, is_src=True)
        encoder_outputs = self.encoder(encoder_inputs, encoder_mask)
        max_lengths = torch.sum(src_toks != ac.PAD_ID, dim=-1).type(src_toks.type()) + 50

        def get_trg_inp(ids, time_step):
            ids = ids.type(src_toks.type())
            word_embeds = self.trg_embedding(ids)
            if self.config['fix_norm']:
                word_embeds = ut.normalize(word_embeds, scale=False)
            else:
                word_embeds = word_embeds * self.embed_scale

            pos_embeds = self.pos_embedding[time_step, :].reshape(1, 1, -1)
            return word_embeds + pos_embeds

        def logprob(decoder_output):
            return F.log_softmax(self.logit_fn(decoder_output), dim=-1)

        if self.config['length_model'] == 'gnmt':
            length_model = ut.gnmt_length_model(self.config['length_alpha'])
        elif self.config['length_model'] == 'linear':
            length_model = lambda t, p: p + self.config['length_alpha'] * t
        elif self.config['length_model'] == 'none':
            length_model = lambda t, p: p
        else:
            raise ValueError("invalid length_model '{}'".format(self.config['length_model']))

        return self.decoder.beam_decode(encoder_outputs, encoder_mask,
                                        get_trg_inp, logprob, length_model,
                                        ac.BOS_ID, ac.EOS_ID, max_lengths,
                                        beam_size=self.config['beam_size'])
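# A minimal, self-contained sketch of the label-smoothed loss that Model.forward
# above computes: the NLL of the gathered targets blended with a uniform "smooth"
# term over the allowed vocabulary. The toy tensors, pad_id, and label_smoothing
# value below are placeholders for illustration only, not part of the original code.
def _label_smoothed_loss_sketch():
    import torch
    import torch.nn.functional as F

    def label_smoothed_loss(logits, targets, vocab_mask, label_smoothing, pad_id=0):
        # logits: [num_toks, vocab]; targets: [num_toks]; vocab_mask: [vocab] of 0/1
        lprobs = F.log_softmax(logits, dim=-1) * vocab_mask.reshape(1, -1)
        targets = targets.reshape(-1, 1)
        non_pad = targets != pad_id
        nll = -lprobs.gather(dim=-1, index=targets)[non_pad].sum()
        smooth = -lprobs.sum(dim=-1, keepdim=True)[non_pad].sum()
        if label_smoothing > 0:
            return (1.0 - label_smoothing) * nll + label_smoothing * smooth / vocab_mask.sum()
        return nll

    logits = torch.randn(5, 8)               # 5 target tokens, toy vocab of 8
    targets = torch.tensor([3, 1, 0, 2, 7])  # 0 plays the role of PAD here
    vocab_mask = torch.ones(8)
    print(label_smoothed_loss(logits, targets, vocab_mask, label_smoothing=0.1))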
class Transformer(nn.Module):
    """Transformer https://arxiv.org/pdf/1706.03762.pdf"""

    def __init__(self, args):
        super(Transformer, self).__init__()
        self.args = args

        embed_dim = args.embed_dim
        fix_norm = args.fix_norm
        joint_vocab_size = args.joint_vocab_size
        lang_vocab_size = args.lang_vocab_size
        use_bias = args.use_bias

        self.scale = embed_dim ** 0.5
        if args.mask_logit:
            # mask logits separately per language
            self.logit_mask = None
        else:
            # otherwise, use the same mask for all
            # this only masks out BOS and PAD
            mask = [1.] * joint_vocab_size
            mask[ac.BOS_ID] = 0.
            mask[ac.PAD_ID] = 0.
            self.logit_mask = torch.tensor(mask).type(torch.uint8)

        self.word_embedding = Parameter(torch.Tensor(joint_vocab_size, embed_dim))
        self.lang_embedding = Parameter(torch.Tensor(lang_vocab_size, embed_dim))
        self.out_bias = Parameter(torch.Tensor(joint_vocab_size)) if use_bias else None
        self.encoder = Encoder(args)
        self.decoder = Decoder(args)

        # initialize
        nn.init.normal_(self.lang_embedding, mean=0, std=embed_dim ** -0.5)
        if fix_norm:
            d = 0.01
            nn.init.uniform_(self.word_embedding, a=-d, b=d)
        else:
            nn.init.normal_(self.word_embedding, mean=0, std=embed_dim ** -0.5)
        if use_bias:
            nn.init.constant_(self.out_bias, 0.)

    def replace_with_unk(self, toks):
        # word-dropout
        p = self.args.word_dropout
        if self.training and 0 < p < 1:
            non_pad_mask = toks != ac.PAD_ID
            mask = (torch.rand(toks.size()) <= p).type(non_pad_mask.type())
            mask = (mask + non_pad_mask) >= 2
            toks[mask] = ac.UNK_ID

    def get_input(self, toks, lang_idx, word_embedding, pos_embedding):
        # word dropout, but replace with unk instead of zero-ing embed
        self.replace_with_unk(toks)
        word_embed = F.embedding(toks, word_embedding) * self.scale  # [bsz, len, dim]
        lang_embed = self.lang_embedding[lang_idx].unsqueeze(0).unsqueeze(1)  # [1, 1, dim]
        pos_embed = pos_embedding[:toks.size(-1), :].unsqueeze(0)  # [1, len, dim]
        return word_embed + lang_embed + pos_embed

    def forward(self, src, tgt, targets, src_lang_idx, tgt_lang_idx, logit_mask):
        embed_dim = self.args.embed_dim
        max_len = max(src.size(1), tgt.size(1))
        pos_embedding = ut.get_positional_encoding(embed_dim, max_len)
        word_embedding = F.normalize(self.word_embedding, dim=-1) if self.args.fix_norm else self.word_embedding

        encoder_inputs = self.get_input(src, src_lang_idx, word_embedding, pos_embedding)
        encoder_mask = (src == ac.PAD_ID).unsqueeze(1).unsqueeze(2)
        encoder_outputs = self.encoder(encoder_inputs, encoder_mask)

        decoder_inputs = self.get_input(tgt, tgt_lang_idx, word_embedding, pos_embedding)
        decoder_mask = torch.triu(torch.ones((tgt.size(-1), tgt.size(-1))),
                                  diagonal=1).type(tgt.type()) == 1
        decoder_mask = decoder_mask.unsqueeze(0).unsqueeze(1)
        decoder_outputs = self.decoder(decoder_inputs, decoder_mask,
                                       encoder_outputs, encoder_mask)

        logit_mask = logit_mask if self.logit_mask is None else self.logit_mask
        logits = self.logit_fn(decoder_outputs, word_embedding, logit_mask)
        neglprobs = F.log_softmax(logits, -1) * logit_mask.type(logits.type()).reshape(1, -1)
        targets = targets.reshape(-1, 1)
        non_pad_mask = targets != ac.PAD_ID
        nll_loss = neglprobs.gather(dim=-1, index=targets)[non_pad_mask]
        smooth_loss = neglprobs.sum(dim=-1, keepdim=True)[non_pad_mask]

        # label smoothing: https://arxiv.org/pdf/1701.06548.pdf
        nll_loss = -(nll_loss.sum())
        smooth_loss = -(smooth_loss.sum())
        label_smoothing = self.args.label_smoothing
        if label_smoothing > 0:
            loss = (1.0 - label_smoothing) * nll_loss + \
                label_smoothing * smooth_loss / logit_mask.type(nll_loss.type()).sum()
        else:
            loss = nll_loss

        num_words = non_pad_mask.type(loss.type()).sum()
        opt_loss = loss / num_words
        return {
            'opt_loss': opt_loss,
            'loss': loss,
            'nll_loss': nll_loss,
            'num_words': num_words
        }

    def logit_fn(self, decoder_output, softmax_weight, logit_mask):
        logits = F.linear(decoder_output, softmax_weight, bias=self.out_bias)
        logits = logits.reshape(-1, logits.size(-1))
        logits[:, ~logit_mask] = -1e9
        return logits

    def beam_decode(self, src, src_lang_idx, tgt_lang_idx, logit_mask):
        embed_dim = self.args.embed_dim
        max_len = src.size(1) + 51
        pos_embedding = ut.get_positional_encoding(embed_dim, max_len)
        word_embedding = F.normalize(self.word_embedding, dim=-1) if self.args.fix_norm else self.word_embedding
        logit_mask = logit_mask if self.logit_mask is None else self.logit_mask
        tgt_lang_embed = self.lang_embedding[tgt_lang_idx]

        encoder_inputs = self.get_input(src, src_lang_idx, word_embedding, pos_embedding)
        encoder_mask = (src == ac.PAD_ID).unsqueeze(1).unsqueeze(2)
        encoder_outputs = self.encoder(encoder_inputs, encoder_mask)

        def get_tgt_inp(tgt, time_step):
            word_embed = F.embedding(tgt.type(src.type()), word_embedding) * self.scale
            pos_embed = pos_embedding[time_step, :].reshape(1, 1, -1)
            return word_embed + tgt_lang_embed + pos_embed

        def logprob_fn(decoder_output):
            logits = self.logit_fn(decoder_output, word_embedding, logit_mask)
            return F.log_softmax(logits, dim=-1)

        # following Attention Is All You Need, we decode up to src_len + 50 tokens only
        max_lengths = torch.sum(src != ac.PAD_ID, dim=-1).type(src.type()) + 50
        return self.decoder.beam_decode(encoder_outputs, encoder_mask,
                                        get_tgt_inp, logprob_fn,
                                        ac.BOS_ID, ac.EOS_ID, max_lengths,
                                        beam_size=self.args.beam_size,
                                        alpha=self.args.beam_alpha)
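# A standalone sketch of the word-dropout trick used in Transformer.replace_with_unk
# above: during training each non-PAD token is replaced by UNK with probability p.
# This version expresses the original uint8 "sum >= 2" trick with boolean logic;
# the PAD_ID/UNK_ID values, p, and the toy batch are illustrative placeholders.
def _word_dropout_sketch():
    import torch

    PAD_ID, UNK_ID = 0, 1

    def replace_with_unk(toks, p):
        non_pad_mask = toks != PAD_ID
        # draw one Bernoulli(p) coin per position, then keep only non-PAD hits
        drop = (torch.rand(toks.size()) <= p) & non_pad_mask
        toks = toks.clone()
        toks[drop] = UNK_ID
        return toks

    toks = torch.tensor([[5, 6, 7, PAD_ID], [8, 9, PAD_ID, PAD_ID]])
    print(replace_with_unk(toks, p=0.3))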
class Model(nn.Module):
    """Model"""

    def __init__(self, config):
        super(Model, self).__init__()
        self.config = config

        self.init_embeddings()
        self.init_model()

    def init_embeddings(self):
        embed_dim = self.config['embed_dim']
        tie_mode = self.config['tie_mode']
        max_pos_length = self.config['max_pos_length']
        learned_pos = self.config['learned_pos']

        # get positional embedding
        if not learned_pos:
            self.pos_embedding = ut.get_positional_encoding(embed_dim, max_pos_length)
        else:
            self.pos_embedding = Parameter(torch.Tensor(max_pos_length, embed_dim))
            nn.init.normal_(self.pos_embedding, mean=0, std=embed_dim ** -0.5)

        # get word embeddings
        src_vocab_size, trg_vocab_size = ut.get_vocab_sizes(self.config)
        self.src_vocab_mask, self.trg_vocab_mask = ut.get_vocab_masks(
            self.config, src_vocab_size, trg_vocab_size)
        if tie_mode == ac.ALL_TIED:
            src_vocab_size = trg_vocab_size = self.trg_vocab_mask.shape[0]

        self.out_bias = Parameter(torch.Tensor(trg_vocab_size))
        nn.init.constant_(self.out_bias, 0.)

        self.src_embedding = nn.Embedding(src_vocab_size, embed_dim)
        self.trg_embedding = nn.Embedding(trg_vocab_size, embed_dim)
        self.out_embedding = self.trg_embedding.weight
        self.embed_scale = embed_dim ** 0.5

        if tie_mode == ac.ALL_TIED:
            self.src_embedding.weight = self.trg_embedding.weight

        nn.init.normal_(self.src_embedding.weight, mean=0, std=embed_dim ** -0.5)
        nn.init.normal_(self.trg_embedding.weight, mean=0, std=embed_dim ** -0.5)

    def init_model(self):
        num_enc_layers = self.config['num_enc_layers']
        num_enc_heads = self.config['num_enc_heads']
        num_dec_layers = self.config['num_dec_layers']
        num_dec_heads = self.config['num_dec_heads']

        embed_dim = self.config['embed_dim']
        ff_dim = self.config['ff_dim']
        dropout = self.config['dropout']

        # get encoder, decoder
        self.encoder = Encoder(num_enc_layers, num_enc_heads, embed_dim, ff_dim, dropout=dropout)
        self.decoder = Decoder(num_dec_layers, num_dec_heads, embed_dim, ff_dim, dropout=dropout)

        # initialize weights; leave layer norm alone
        init_func = nn.init.xavier_normal_ if self.config['init_type'] == ac.XAVIER_NORMAL else nn.init.xavier_uniform_
        for m in [self.encoder.self_atts, self.encoder.pos_ffs,
                  self.decoder.self_atts, self.decoder.pos_ffs,
                  self.decoder.enc_dec_atts]:
            for p in m.parameters():
                if p.dim() > 1:
                    init_func(p)
                else:
                    nn.init.constant_(p, 0.)
    def get_input(self, toks, is_src=True):
        embeds = self.src_embedding if is_src else self.trg_embedding
        word_embeds = embeds(toks)  # [bsz, max_len, embed_dim]
        pos_embeds = self.pos_embedding[:toks.size()[-1], :].unsqueeze(0)  # [1, max_len, embed_dim]
        return word_embeds * self.embed_scale + pos_embeds

    def forward(self, src_toks, trg_toks, targets):
        encoder_mask = (src_toks == ac.PAD_ID).unsqueeze(1).unsqueeze(2)  # [bsz, 1, 1, max_src_len]
        decoder_mask = torch.triu(torch.ones((trg_toks.size()[-1], trg_toks.size()[-1])),
                                  diagonal=1).type(trg_toks.type()) == 1
        decoder_mask = decoder_mask.unsqueeze(0).unsqueeze(1)

        encoder_inputs = self.get_input(src_toks, is_src=True)
        encoder_outputs = self.encoder(encoder_inputs, encoder_mask)

        decoder_inputs = self.get_input(trg_toks, is_src=False)
        decoder_outputs = self.decoder(decoder_inputs, decoder_mask,
                                       encoder_outputs, encoder_mask)

        logits = self.logit_fn(decoder_outputs)
        neglprobs = F.log_softmax(logits, -1)
        neglprobs = neglprobs * self.trg_vocab_mask.type(neglprobs.type()).reshape(1, -1)

        targets = targets.reshape(-1, 1)
        non_pad_mask = targets != ac.PAD_ID
        nll_loss = -neglprobs.gather(dim=-1, index=targets)[non_pad_mask]
        smooth_loss = -neglprobs.sum(dim=-1, keepdim=True)[non_pad_mask]

        nll_loss = nll_loss.sum()
        smooth_loss = smooth_loss.sum()
        label_smoothing = self.config['label_smoothing']
        loss = (1.0 - label_smoothing) * nll_loss + \
            label_smoothing * smooth_loss / self.trg_vocab_mask.type(smooth_loss.type()).sum()

        return {'loss': loss, 'nll_loss': nll_loss}

    def logit_fn(self, decoder_output):
        logits = F.linear(decoder_output, self.out_embedding, bias=self.out_bias)
        logits = logits.reshape(-1, logits.size()[-1])
        logits[:, ~self.trg_vocab_mask] = -1e9
        return logits

    def beam_decode(self, src_toks):
        encoder_mask = (src_toks == ac.PAD_ID).unsqueeze(1).unsqueeze(2)  # [bsz, 1, 1, max_src_len]
        encoder_inputs = self.get_input(src_toks, is_src=True)
        encoder_outputs = self.encoder(encoder_inputs, encoder_mask)
        max_lengths = torch.sum(src_toks != ac.PAD_ID, dim=-1).type(src_toks.type()) + 50

        def get_trg_inp(ids, time_step):
            ids = ids.type(src_toks.type())
            word_embeds = self.trg_embedding(ids)
            pos_embeds = self.pos_embedding[time_step, :].reshape(1, 1, -1)
            return word_embeds * self.embed_scale + pos_embeds

        def logprob(decoder_output):
            return F.log_softmax(self.logit_fn(decoder_output), dim=-1)

        return self.decoder.beam_decode(encoder_outputs, encoder_mask,
                                        get_trg_inp, logprob,
                                        ac.BOS_ID, ac.EOS_ID, max_lengths,
                                        beam_size=self.config['beam_size'],
                                        alpha=self.config['beam_alpha'])
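# A small sketch of how the attention masks used throughout these models are built:
# the encoder mask flags PAD positions, and the decoder mask is the upper-triangular
# causal mask that hides future target positions. PAD_ID and the toy batch below
# are illustrative placeholders only.
def _attention_mask_sketch():
    import torch

    PAD_ID = 0
    src_toks = torch.tensor([[4, 5, 6, PAD_ID]])  # [bsz, src_len]
    trg_toks = torch.tensor([[2, 7, 8]])          # [bsz, trg_len]

    encoder_mask = (src_toks == PAD_ID).unsqueeze(1).unsqueeze(2)  # [bsz, 1, 1, src_len]
    decoder_mask = torch.triu(torch.ones((trg_toks.size(-1), trg_toks.size(-1))),
                              diagonal=1) == 1                      # [trg_len, trg_len]
    decoder_mask = decoder_mask.unsqueeze(0).unsqueeze(1)           # [1, 1, trg_len, trg_len]

    print(encoder_mask.shape, decoder_mask.shape)
    print(decoder_mask[0, 0])  # True above the diagonal = positions masked out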