import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Embedding, Module, Transformer

# PositionalEncoding (and Encoding, used further below) are assumed to be
# defined elsewhere, e.g. the standard sinusoidal positional-encoding module.


class TransformerModel(nn.Module):
    def __init__(self, ninp, ntoken, ntoken_dec, nhid=2048, dropout=0):
        super(TransformerModel, self).__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder_emb = nn.Embedding(ntoken_dec, ninp)
        self.decoder_out = nn.Linear(ninp, ntoken_dec)
        self.model = Transformer(d_model=ninp, dim_feedforward=nhid)

    def forward(self, src, tgt, src_mask, tgt_mask):
        # src/tgt are batch-first token ids; masks hold 1 for real tokens.
        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        tgt = self.decoder_emb(tgt) * math.sqrt(self.ninp)
        tgt = self.pos_encoder(tgt)
        # Key-padding masks are boolean with True marking padded positions.
        src_mask = src_mask != 1
        tgt_mask = tgt_mask != 1
        subseq_mask = self.model.generate_square_subsequent_mask(
            tgt.size(1)).to(tgt.device)
        # nn.Transformer expects (seq_len, batch, d_model).
        output = self.model(src.transpose(0, 1), tgt.transpose(0, 1),
                            tgt_mask=subseq_mask,
                            src_key_padding_mask=src_mask,
                            tgt_key_padding_mask=tgt_mask,
                            memory_key_padding_mask=src_mask)
        output = self.decoder_out(output)
        return output

    def greedy_decode(self, src, src_mask, sos_token, max_length=20):
        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        src_mask = src_mask != 1
        encoded = self.model.encoder(src.transpose(0, 1),
                                     src_key_padding_mask=src_mask)
        # Start every sequence in the batch with the <sos> token.
        generated = encoded.new_full((encoded.size(1), 1), sos_token,
                                     dtype=torch.long)
        for i in range(max_length - 1):
            subseq_mask = self.model.generate_square_subsequent_mask(
                generated.size(1)).to(src.device)
            decoder_in = self.decoder_emb(generated) * math.sqrt(self.ninp)
            decoder_in = self.pos_encoder(decoder_in)
            # Only the last time step's logits are needed for greedy search.
            logits = self.decoder_out(
                self.model.decoder(decoder_in.transpose(0, 1), encoded,
                                   tgt_mask=subseq_mask,
                                   memory_key_padding_mask=src_mask)[-1, :, :])
            new_generated = logits.argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, new_generated], dim=-1)
        return generated

    def save(self, file_dir):
        torch.save(self.state_dict(), file_dir)

    def load(self, file_dir):
        self.load_state_dict(torch.load(file_dir))
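
# A hedged usage sketch for the class above (not part of the original source).
# The _example_greedy_translation name, vocabulary sizes, and shapes are
# illustrative assumptions; masks are assumed to hold 1 for real tokens and
# 0 for padding, matching the `!= 1` flips in forward()/greedy_decode().
# model_cls is bound at definition time so it refers to the class directly
# above, even though later snippets reuse the TransformerModel name.
def _example_greedy_translation(model_cls=TransformerModel):
    model = model_cls(ninp=512, ntoken=10000, ntoken_dec=10000)
    src = torch.randint(3, 10000, (8, 20))           # (batch, src_len)
    tgt = torch.randint(3, 10000, (8, 15))           # (batch, tgt_len)
    src_mask = torch.ones(8, 20, dtype=torch.long)   # 1 = token, 0 = pad
    tgt_mask = torch.ones(8, 15, dtype=torch.long)
    logits = model(src, tgt, src_mask, tgt_mask)     # (tgt_len, batch, ntoken_dec)
    hyps = model.greedy_decode(src, src_mask, sos_token=1)  # (batch, <= 20)
    return logits, hyps
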
class TransformerDecoderModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, hidden_dim,
                 num_encoder_layers, num_decoder_layers, dropout=0.1):
        super().__init__()
        self.trg_mask = None
        self.pos_encoder = PositionalEncoding(embed_dim)
        self.transformer = Transformer(embed_dim, num_heads,
                                       num_encoder_layers, num_decoder_layers,
                                       hidden_dim, dropout=dropout)
        self.src_embed = nn.Embedding(vocab_size, embed_dim)
        self.trg_embed = nn.Embedding(vocab_size, embed_dim)
        self.feature_dim = embed_dim
        self.decoder = nn.Linear(embed_dim, vocab_size)

    def forward(self, src, trg):
        if self.trg_mask is None or self.trg_mask.size(0) != len(trg):
            device = trg.device
            mask = self.transformer.generate_square_subsequent_mask(
                len(trg)).to(device)
            self.trg_mask = mask
        src = self.src_embed(src) * math.sqrt(self.feature_dim)
        src = self.pos_encoder(src)
        trg = self.trg_embed(trg) * math.sqrt(self.feature_dim)
        trg = self.pos_encoder(trg)
        output = self.transformer(src, trg, tgt_mask=self.trg_mask)
        output = self.decoder(output)
        return output
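
# A hedged usage sketch for the class above (not part of the original source);
# the _example_decoder_model name and all sizes are illustrative. Inputs are
# assumed to be sequence-first (seq_len, batch), as nn.Transformer expects.
def _example_decoder_model():
    model = TransformerDecoderModel(vocab_size=10000, embed_dim=512,
                                    num_heads=8, hidden_dim=2048,
                                    num_encoder_layers=6, num_decoder_layers=6)
    src = torch.randint(0, 10000, (20, 8))   # (src_len, batch)
    trg = torch.randint(0, 10000, (15, 8))   # (trg_len, batch)
    logits = model(src, trg)                 # (trg_len, batch, vocab_size)
    return logits
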
class FullTransformer(Module):
    def __init__(self, num_vocab, num_embedding=128, dim_feedforward=512,
                 num_encoder_layer=4, num_decoder_layer=4, dropout=0.3,
                 padding_idx=1, max_seq_len=140):
        super(FullTransformer, self).__init__()
        self.padding_idx = padding_idx
        # [ x : seq_len, batch_size ]
        self.inp_embedding = Embedding(num_vocab, num_embedding,
                                       padding_idx=padding_idx)
        # [ x : seq_len, batch_size, num_embedding ]
        self.pos_embedding = PositionalEncoding(num_embedding, dropout,
                                                max_len=max_seq_len)
        self.trfm = Transformer(d_model=num_embedding,
                                dim_feedforward=dim_feedforward,
                                num_encoder_layers=num_encoder_layer,
                                num_decoder_layers=num_decoder_layer,
                                dropout=dropout)
        self.linear_out = torch.nn.Linear(num_embedding, num_vocab)

    def make_pad_mask(self, inp: torch.Tensor) -> torch.Tensor:
        """
        Build a key-padding mask: positions whose value equals
        self.padding_idx are marked True and will not be attended (ignored).

        :param inp: token-id tensor of shape [seq_len, batch_size]
        :return: boolean mask of shape [batch_size, seq_len]
        """
        return (inp == self.padding_idx).transpose(0, 1)

    def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        :param src: source tensor
        :param tgt: target tensor
        """
        # Causal mask for decoder self-attention,
        # tgt_mask shape = [target_seq_len, target_seq_len]
        tgt_mask = self.trfm.generate_square_subsequent_mask(
            len(tgt)).to(tgt.device)
        src_pad_mask = self.make_pad_mask(src)
        tgt_pad_mask = self.make_pad_mask(tgt)
        # [ src : seq_len, batch_size, num_embedding ]
        out_emb_enc = self.pos_embedding(self.inp_embedding(src))
        # [ tgt : seq_len, batch_size, num_embedding ]
        out_emb_dec = self.pos_embedding(self.inp_embedding(tgt))
        out_trf = self.trfm(out_emb_enc, out_emb_dec,
                            src_mask=None, tgt_mask=tgt_mask, memory_mask=None,
                            src_key_padding_mask=src_pad_mask,
                            tgt_key_padding_mask=tgt_pad_mask,
                            memory_key_padding_mask=src_pad_mask)
        # [ out_trf : seq_len, batch_size, num_embedding ]
        out_to_logit = self.linear_out(out_trf)
        # final_out : [ seq_len, batch_size, vocab_size ]
        return out_to_logit
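
# A hedged usage sketch for the class above (not part of the original source);
# the _example_full_transformer name and all sizes are illustrative. Inputs are
# sequence-first (seq_len, batch) with padding_idx=1 marking padded positions.
def _example_full_transformer():
    model = FullTransformer(num_vocab=8000)
    src = torch.randint(2, 8000, (30, 4))    # (src_len, batch)
    tgt = torch.randint(2, 8000, (25, 4))    # (tgt_len, batch)
    logits = model(src, tgt)                 # (tgt_len, batch, num_vocab)
    return logits
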
class Transformer_fr(nn.Module):
    def __init__(self, en_vocab_size, de_vocab_size, padding_idx, max_len,
                 embed_size, device):
        super(Transformer_fr, self).__init__()
        self.en_vocab = en_vocab_size
        self.de_vocab = de_vocab_size
        self.padd = padding_idx
        self.BOS = 1
        self.EOS = 2
        self.device = device
        self.en_enc = Encoding(self.en_vocab, embed_size, max_len, 0.2, device)
        self.de_enc = Encoding(self.de_vocab, embed_size, max_len, 0.2, device)
        # d_model must match the embedding size produced by the encoders.
        self.transformer = Transformer(d_model=embed_size)
        self.fc = nn.Linear(embed_size, self.de_vocab)
        self.scale = embed_size ** 0.5

    def gen_src_mask(self, x):
        '''
        x = (B, S)
        src_mask = (B, 1, S) --> broadcast
        '''
        # (B, 1, S); True marks padding positions.
        src_mask = (x == self.padd).unsqueeze(1)
        return src_mask.to(self.device)

    def gen_trg_mask(self, x):
        '''
        x = (B, S)
        trg_mask = (S, S) : triangle
        '''
        seq = x.shape[1]
        # (S, S): 0.0 on visible positions, -inf on future positions.
        trg_mask = torch.tril(torch.ones(seq, seq))
        trg_mask[trg_mask == 0] = float('-inf')
        trg_mask[trg_mask == 1] = float(0.0)
        return trg_mask.to(self.device)

    def forward(self, src, trg):
        trg_seq = trg.size(1)
        src = self.en_enc(src)
        trg = self.de_enc(trg)
        trg_mask = self.transformer.generate_square_subsequent_mask(
            trg_seq).to(self.device)
        # nn.Transformer expects (seq_len, batch, embed); inputs are batch-first.
        src = src.transpose(0, 1)
        trg = trg.transpose(0, 1)
        output = self.transformer(src, trg, tgt_mask=trg_mask)
        output = output.transpose(0, 1)
        output = self.fc(output)
        return output

    def inference(self, src):
        '''
        src = (B, S_source)
        return (B, S_target), lengths
        '''
        # The paper decodes up to source length + 300; a budget of +50 is used here.
        max_seq = src.size(1) + 50
        batch = src.size(0)
        lengths = np.array([max_seq] * batch)
        outputs = torch.zeros((batch, 1)).to(torch.long).to(self.device)
        outputs[:, 0] = self.BOS
        for step in range(1, max_seq):
            out = self.forward(src, outputs)
            # Keep only the logits of the last generated position.
            out = out[:, -1, :]
            pred = torch.topk(F.log_softmax(out, dim=-1), 1, dim=-1)[1]
            outputs = torch.cat([outputs, pred], dim=1)
            # Record the step at which each sequence first emits EOS.
            eos_batches = pred.data.eq(self.EOS)
            if eos_batches.dim() > 0:
                eos_batches = eos_batches.cpu().view(-1).numpy()
                update_idx = ((lengths > step) & eos_batches) != 0
                lengths[update_idx] = step
        return outputs.detach(), lengths
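
# A hedged usage sketch for the class above (not part of the original source);
# the _example_transformer_fr name and all sizes are illustrative, and
# embed_size=512 keeps the default nhead=8 head size valid. Inputs are
# batch-first (batch, seq_len); inference() decodes greedily until EOS.
def _example_transformer_fr():
    device = torch.device('cpu')
    model = Transformer_fr(en_vocab_size=8000, de_vocab_size=8000,
                           padding_idx=0, max_len=100, embed_size=512,
                           device=device)
    src = torch.randint(3, 8000, (4, 20))    # (batch, src_len)
    trg = torch.randint(3, 8000, (4, 18))    # (batch, trg_len)
    logits = model(src, trg)                 # (batch, trg_len, de_vocab)
    hyps, lengths = model.inference(src)     # hyps: (batch, <= src_len + 50)
    return logits, hyps, lengths
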
class TransformerModel(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_attention_heads,
                 num_encoder_layers, num_decoder_layers, intermediate_size,
                 dropout=0.1):
        super(TransformerModel, self).__init__()
        self.token_embeddings = nn.Embedding(vocab_size, hidden_size,
                                             padding_idx=1)
        self.position_embeddings = PositionalEncoding(hidden_size)
        self.hidden_size = hidden_size
        self.dropout = nn.Dropout(p=dropout)
        self.transformer = Transformer(
            d_model=hidden_size,
            nhead=num_attention_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=intermediate_size,
            dropout=dropout,
        )
        # The output projection shares its weight with the input embeddings.
        self.decoder_embeddings = nn.Linear(hidden_size, vocab_size)
        self.decoder_embeddings.weight = self.token_embeddings.weight
        self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(
            mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.token_embeddings.weight.data.uniform_(-initrange, initrange)
        self.decoder_embeddings.bias.data.zero_()
        self.decoder_embeddings.weight.data.uniform_(-initrange, initrange)

    def forward(self, src=None, tgt=None, memory=None,
                src_key_padding_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        assert not (src is None and tgt is None)
        if src is not None:
            src_embeddings = self.token_embeddings(src) * math.sqrt(
                self.hidden_size) + self.position_embeddings(src)
            src_embeddings = self.dropout(src_embeddings)
            if src_key_padding_mask is not None:
                src_key_padding_mask = src_key_padding_mask.t()
            if tgt is None:
                # encode only
                memory = self.transformer.encoder(
                    src_embeddings,
                    src_key_padding_mask=src_key_padding_mask)
                return memory
        if tgt is not None:
            tgt_embeddings = self.token_embeddings(tgt) * math.sqrt(
                self.hidden_size) + self.position_embeddings(tgt)
            tgt_embeddings = self.dropout(tgt_embeddings)
            tgt_mask = self.transformer.generate_square_subsequent_mask(
                tgt.size(0)).to(tgt.device)
            if tgt_key_padding_mask is not None:
                tgt_key_padding_mask = tgt_key_padding_mask.t()
            if src is None and memory is not None:
                # decode only, against a precomputed encoder memory
                if memory_key_padding_mask is not None:
                    memory_key_padding_mask = memory_key_padding_mask.t()
                output = self.transformer.decoder(
                    tgt_embeddings, memory, tgt_mask=tgt_mask,
                    tgt_key_padding_mask=tgt_key_padding_mask,
                    memory_key_padding_mask=memory_key_padding_mask)
                output = self.decoder_embeddings(output)
                return output
        # full encoder-decoder pass
        output = self.transformer(src_embeddings, tgt_embeddings,
                                  tgt_mask=tgt_mask,
                                  src_key_padding_mask=src_key_padding_mask,
                                  tgt_key_padding_mask=tgt_key_padding_mask)
        output = self.decoder_embeddings(output)
        return output
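
# A hedged usage sketch for the class above (not part of the original source);
# the _example_split_forward name and all sizes are illustrative. Tensors are
# sequence-first (seq_len, batch); key-padding masks, if supplied, would also
# be (seq_len, batch), since forward() transposes them with .t() before
# handing them to nn.Transformer. model_cls is bound at definition time to
# the class directly above.
def _example_split_forward(model_cls=TransformerModel):
    model = model_cls(vocab_size=10000, hidden_size=512,
                      num_attention_heads=8, num_encoder_layers=6,
                      num_decoder_layers=6, intermediate_size=2048)
    src = torch.randint(2, 10000, (20, 4))   # (src_len, batch)
    tgt = torch.randint(2, 10000, (15, 4))   # (tgt_len, batch)
    logits = model(src=src, tgt=tgt)         # (tgt_len, batch, vocab_size)
    # Encode once, then decode against the cached memory.
    memory = model(src=src)                  # (src_len, batch, hidden_size)
    logits = model(tgt=tgt, memory=memory)   # (tgt_len, batch, vocab_size)
    return logits
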
class TransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model, num_attention_heads,
                 num_encoder_layers, num_decoder_layers, intermediate_size,
                 max_len, dropout=0.1):
        super(TransformerModel, self).__init__()
        self.token_embeddings = nn.Embedding(vocab_size, d_model)
        self.position_embeddings = PositionalEncoding(d_model, max_len)
        self.hidden_size = d_model
        self.dropout = nn.Dropout(p=dropout)
        self.transformer = Transformer(d_model=d_model,
                                       nhead=num_attention_heads,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=intermediate_size,
                                       dropout=dropout)
        self.decoder_embeddings = nn.Linear(d_model, vocab_size)
        self.decoder_embeddings.weight = self.token_embeddings.weight
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.token_embeddings.weight.data.uniform_(-initrange, initrange)
        self.decoder_embeddings.bias.data.zero_()
        self.decoder_embeddings.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, tgt, src_key_padding_mask=None,
                tgt_key_padding_mask=None):
        src_embeddings = self.token_embeddings(src) * math.sqrt(
            self.hidden_size) + self.position_embeddings(src)
        src_embeddings = self.dropout(src_embeddings)
        tgt_embeddings = self.token_embeddings(tgt) * math.sqrt(
            self.hidden_size) + self.position_embeddings(tgt)
        tgt_embeddings = self.dropout(tgt_embeddings)
        tgt_mask = self.transformer.generate_square_subsequent_mask(
            tgt.size(0)).to(tgt.device)
        output = self.transformer(src_embeddings, tgt_embeddings,
                                  tgt_mask=tgt_mask,
                                  src_key_padding_mask=src_key_padding_mask,
                                  tgt_key_padding_mask=tgt_key_padding_mask)
        output = self.decoder_embeddings(output)
        return output

    def encode(self, src, src_key_padding_mask=None):
        src_embeddings = self.token_embeddings(src) * math.sqrt(
            self.hidden_size) + self.position_embeddings(src)
        src_embeddings = self.dropout(src_embeddings)
        memory = self.transformer.encoder(
            src_embeddings, src_key_padding_mask=src_key_padding_mask)
        return memory

    def decode(self, tgt, memory, tgt_key_padding_mask=None,
               memory_key_padding_mask=None):
        tgt_embeddings = self.token_embeddings(tgt) * math.sqrt(
            self.hidden_size) + self.position_embeddings(tgt)
        tgt_embeddings = self.dropout(tgt_embeddings)
        tgt_mask = self.transformer.generate_square_subsequent_mask(
            tgt.size(0)).to(tgt.device)
        output = self.transformer.decoder(
            tgt_embeddings, memory, tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask)
        output = self.decoder_embeddings(output)
        return output
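
# A hedged usage sketch for the class above (not part of the original source);
# the _example_encode_decode name and all sizes are illustrative. Tensors are
# sequence-first (seq_len, batch); here key-padding masks would be passed
# through unchanged, i.e. shaped (batch, seq_len) as nn.Transformer expects.
# model_cls is bound at definition time to the class directly above.
def _example_encode_decode(model_cls=TransformerModel):
    model = model_cls(vocab_size=10000, d_model=512,
                      num_attention_heads=8, num_encoder_layers=6,
                      num_decoder_layers=6, intermediate_size=2048,
                      max_len=512)
    src = torch.randint(2, 10000, (20, 4))   # (src_len, batch)
    tgt = torch.randint(2, 10000, (15, 4))   # (tgt_len, batch)
    logits = model(src, tgt)                 # (tgt_len, batch, vocab_size)
    # Separate encode/decode path, e.g. for step-by-step generation.
    memory = model.encode(src)               # (src_len, batch, d_model)
    logits = model.decode(tgt, memory)       # (tgt_len, batch, vocab_size)
    return logits
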