Example #1
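All of the examples in this section assume the same undeclared context: math, numpy, torch, torch.nn (including nn.Transformer), torch.nn.functional as F, and a PositionalEncoding module that is never shown. The sketch below (not part of any original example) fills in that shared context, assuming the sinusoidal encoding from the PyTorch sequence-to-sequence tutorial, which takes already-embedded, sequence-first tensors. Note the examples are not consistent about this interface: Example #1 applies it to batch-first tensors, Example #4 uses its own undefined Encoding class, and Examples #5 and #6 call it on raw token-id tensors and add the result to the embeddings themselves, so those would need matching variants.

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Embedding, Module, Transformer


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (tutorial-style sketch): adds position
    information to an already-embedded tensor of shape [seq_len, batch, d_model]."""

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: [seq_len, batch_size, d_model]
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
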
class TransformerModel(nn.Module):
    def __init__(self, ninp, ntoken, ntoken_dec, nhid=2048, dropout=0):
        super(TransformerModel, self).__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder_emb = nn.Embedding(ntoken_dec, ninp)
        self.decoder_out = nn.Linear(ninp, ntoken_dec)
        self.model = Transformer(d_model=ninp, dim_feedforward=nhid)

    def forward(self, src, tgt, src_mask, tgt_mask):
        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        tgt = self.decoder_emb(tgt) * math.sqrt(self.ninp)
        tgt = self.pos_encoder(tgt)
        src_mask = src_mask != 1
        tgt_mask = tgt_mask != 1
        subseq_mask = self.model.generate_square_subsequent_mask(
            tgt.size(1)).to(tgt.device)
        output = self.model(src.transpose(0, 1),
                            tgt.transpose(0, 1),
                            tgt_mask=subseq_mask,
                            src_key_padding_mask=src_mask,
                            tgt_key_padding_mask=tgt_mask,
                            memory_key_padding_mask=src_mask)
        output = self.decoder_out(output)
        return output

    def greedy_decode(self, src, src_mask, sos_token, max_length=20):
        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        src_mask = src_mask != 1
        encoded = self.model.encoder(src.transpose(0, 1),
                                     src_key_padding_mask=src_mask)
        generated = encoded.new_full((encoded.size(1), 1),
                                     sos_token,
                                     dtype=torch.long)
        for i in range(max_length - 1):
            subseq_mask = self.model.generate_square_subsequent_mask(
                generated.size(1)).to(src.device)
            decoder_in = self.decoder_emb(generated) * math.sqrt(self.ninp)
            decoder_in = self.pos_encoder(decoder_in)
            logits = self.decoder_out(
                self.model.decoder(decoder_in.transpose(0, 1),
                                   encoded,
                                   tgt_mask=subseq_mask,
                                   memory_key_padding_mask=src_mask)[-1, :, :])
            new_generated = logits.argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, new_generated], dim=-1)
        return generated

    def save(self, file_dir):
        torch.save(self.state_dict(), file_dir)

    def load(self, file_dir):
        self.load_state_dict(torch.load(file_dir))
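A hypothetical usage sketch for this model (all sizes invented, and assuming a batch-first-compatible PositionalEncoding as noted at the top of this section): inputs are batch-first token ids with 1-valued attention masks for real tokens, forward returns time-major logits, and greedy_decode returns batch-first token ids.

model = TransformerModel(ninp=512, ntoken=1000, ntoken_dec=1200)
src = torch.randint(0, 1000, (2, 7))        # (batch, src_len) token ids
tgt = torch.randint(0, 1200, (2, 5))        # (batch, tgt_len) token ids
src_mask = torch.ones(2, 7)                 # 1 = real token, 0 = padding
tgt_mask = torch.ones(2, 5)
logits = model(src, tgt, src_mask, tgt_mask)             # (tgt_len, batch, ntoken_dec)
hyps = model.greedy_decode(src, src_mask, sos_token=1)   # (batch, 20) greedy token ids

Note the asymmetry: forward takes batch-first inputs but returns the decoder output still in nn.Transformer's time-major layout.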
Example #2
class TransformerDecoderModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, hidden_dim, num_encoder_layers, num_decoder_layers, dropout=0.1):
        super().__init__()
        self.trg_mask = None
        self.pos_encoder = PositionalEncoding(embed_dim)

        self.transformer = Transformer(embed_dim, num_heads, num_encoder_layers, num_decoder_layers, hidden_dim, dropout=dropout)

        self.src_embed = nn.Embedding(vocab_size, embed_dim)
        self.trg_embed = nn.Embedding(vocab_size, embed_dim)

        self.feature_dim = embed_dim
        self.decoder = nn.Linear(embed_dim, vocab_size)

    def forward(self, src, trg):
        if self.trg_mask is None or self.trg_mask.size(0) != len(trg):
            device = trg.device
            mask = self.transformer.generate_square_subsequent_mask(len(trg)).to(device)
            self.trg_mask = mask

        src = self.src_embed(src) * math.sqrt(self.feature_dim)
        src = self.pos_encoder(src)

        trg = self.trg_embed(trg) * math.sqrt(self.feature_dim)
        trg = self.pos_encoder(trg)

        output = self.transformer(src, trg, tgt_mask=self.trg_mask)
        output = self.decoder(output)
        return output
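Unlike Example #1, this model takes sequence-first inputs: the cached causal mask is sized by len(trg), i.e. the first dimension. A shape check with hypothetical sizes:

model = TransformerDecoderModel(vocab_size=1000, embed_dim=512, num_heads=8,
                                hidden_dim=2048, num_encoder_layers=3,
                                num_decoder_layers=3)
src = torch.randint(0, 1000, (12, 4))   # (src_len, batch)
trg = torch.randint(0, 1000, (9, 4))    # (trg_len, batch)
out = model(src, trg)                   # (trg_len, batch, vocab_size) logits

The mask is cached keyed only on target length, so it is rebuilt automatically whenever the target length changes.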
Example #3
class FullTransformer(Module):

    def __init__(self, num_vocab, num_embedding=128, dim_feedforward=512, num_encoder_layer=4,
                 num_decoder_layer=4, dropout=0.3, padding_idx=1, max_seq_len=140):
        super(FullTransformer, self).__init__()

        self.padding_idx = padding_idx

        # [x : seq_len,  batch_size ]
        self.inp_embedding = Embedding(num_vocab, num_embedding, padding_idx=padding_idx)

        # [ x : seq_len, batch_size, num_embedding ]
        self.pos_embedding = PositionalEncoding(num_embedding, dropout, max_len=max_seq_len)

        self.trfm = Transformer(d_model=num_embedding, dim_feedforward=dim_feedforward,
                                num_encoder_layers=num_encoder_layer, num_decoder_layers=num_decoder_layer,
                                dropout=dropout)
        self.linear_out = torch.nn.Linear(num_embedding, num_vocab)

    def make_pad_mask(self, inp: torch.Tensor) -> torch.Tensor:
        """
        Make mask attention that caused 'True' element will not be attended (ignored).
        Padding stated in self.padding_idx will not be attended at all.

        :param inp : input that to be masked in boolean Tensor
        """
        return (inp == self.padding_idx).transpose(0, 1)

    def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
        """
        forward!

        :param src : source tensor
        :param tgt : target tensor
        """
        # Causal mask for decoder self-attention: [tgt_seq_len, tgt_seq_len]
        tgt_mask = self.trfm.generate_square_subsequent_mask(len(tgt)).to(tgt.device)

        # Key padding masks: [batch_size, seq_len]
        src_pad_mask = self.make_pad_mask(src)
        tgt_pad_mask = self.make_pad_mask(tgt)

        # [ out_emb_enc : src_seq_len, batch_size, num_embedding ]
        out_emb_enc = self.pos_embedding(self.inp_embedding(src))

        # [ out_emb_dec : tgt_seq_len, batch_size, num_embedding ]
        out_emb_dec = self.pos_embedding(self.inp_embedding(tgt))

        out_trf = self.trfm(out_emb_enc, out_emb_dec, src_mask=None, tgt_mask=tgt_mask, memory_mask=None,
                            src_key_padding_mask=src_pad_mask, tgt_key_padding_mask=tgt_pad_mask,
                            memory_key_padding_mask=src_pad_mask)

        # [ out_trf : tgt_seq_len, batch_size, num_embedding ]
        out_to_logit = self.linear_out(out_trf)

        # [ out_to_logit : tgt_seq_len, batch_size, num_vocab ]
        return out_to_logit
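This example wires up key-padding masks derived from padding_idx (1 by default). A sketch with a padded batch, again sequence-first; the sizes and the choice of which positions to pad are hypothetical:

model = FullTransformer(num_vocab=1000)
src = torch.randint(2, 1000, (10, 4))   # (src_len, batch); ids >= 2 stand in for real tokens
tgt = torch.randint(2, 1000, (8, 4))
src[7:, 0] = 1                          # pad the tail of the first source sequence
tgt[6:, 0] = 1                          # pad the tail of the first target sequence
logits = model(src, tgt)                # (tgt_len, batch, num_vocab)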
Example #4
class Transformer_fr(nn.Module):
    def __init__(self, en_vocab_size, de_vocab_size, padding_idx, max_len,
                 embed_size, device):
        super(Transformer_fr, self).__init__()
        self.en_vocab = en_vocab_size
        self.de_vocab = de_vocab_size
        self.padd = padding_idx
        self.BOS = 1
        self.EOS = 2
        self.device = device

        #self.encode = Pos_encoding(embed_size, max_len, device)
        #self.en_emb = nn.Embedding(self.en_vocab, embed_size, padding_idx = 0)
        #self.de_emb = nn.Embedding(self.de_vocab, embed_size, padding_idx = 0)

        self.en_enc = Encoding(self.en_vocab, embed_size, max_len, 0.2, device)
        self.de_enc = Encoding(self.de_vocab, embed_size, max_len, 0.2, device)

        self.transformer = Transformer()
        self.fc = nn.Linear(embed_size, self.de_vocab)

        self.scale = embed_size**0.5

    def gen_src_mask(self, x):
        '''
        x = (B, S)
        src_mask = (B, 1, S_r) --> broadcast
        '''
        #(B,1,S)
        src_mask = (x == self.padd).unsqueeze(1)

        return src_mask.to(self.device)

    def gen_trg_mask(self, x):
        '''
        x = (B, S)
        returns trg_mask = (S, S): additive lower-triangular (causal) mask
        '''
        batch = x.shape[0]
        seq = x.shape[1]

        #B, 1, S
        #trg_pad = (x == self.padd).unsqueeze(1)
        #1, S, S
        #S, S
        trg_mask = torch.tril(torch.ones(seq, seq))
        trg_mask[trg_mask == 0] = float("-inf")
        trg_mask[trg_mask == 1] = float(0.0)
        #trg_mask = trg_pad | trg_idx
        #print(trg_mask)

        return trg_mask.to(self.device)

    def forward(self, src, trg):

        #src = self.en_emb(src) * self.scale + self.encode(src)
        #trg = self.de_emb(trg) * self.scale+ self.encode(trg)
        trg_seq = trg.size(1)

        src = self.en_enc(src)
        trg = self.de_enc(trg)

        trg_mask = self.transformer.generate_square_subsequent_mask(
            trg_seq).to(self.device)
        #trg_mask = self.gen_trg_mask(trg)

        #print(trg_mask)
        src = src.transpose(0, 1)
        trg = trg.transpose(0, 1)

        output = self.transformer(src, trg, tgt_mask=trg_mask)
        output = output.transpose(0, 1)
        output = self.fc(output)

        #print(src.shape, trg.shape, output.shape)
        return output

    def inference(self, src):
        '''
        x  = (B, S_source)
        return (B, S_target)
        '''

        # per the paper this would be max_seq = src_len + 300; capped at src_len + 50 here
        max_seq = src.size(1) + 50
        batch = src.size(0)

        lengths = np.array([max_seq] * batch)
        #outputs = []

        outputs = torch.zeros((batch, 1)).to(torch.long).to(self.device)
        outputs[:, 0] = self.BOS

        for step in range(1, max_seq):
            out = self.forward(src, outputs)

            # logits of the last generated position: (B, de_vocab)
            out = out[:, -1, :]
            pred = torch.topk(F.log_softmax(out, dim=-1), 1, dim=-1)[1]

            outputs = torch.cat([outputs, pred], dim=1)

            eos_batches = pred.data.eq(self.EOS)
            if eos_batches.dim() > 0:
                eos_batches = eos_batches.cpu().view(-1).numpy()
                update_idx = ((lengths > step) & eos_batches) != 0
                lengths[update_idx] = step

        return outputs.detach(), lengths
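Two things to keep in mind with this class: Transformer() is built with library defaults, so embed_size must equal nn.Transformer's default d_model of 512, and the Encoding module is not defined in the snippet; the sketch below assumes it has the signature Encoding(vocab_size, embed_size, max_len, dropout, device) and returns batch-first embedded tensors of shape (B, S, E). All sizes are hypothetical.

device = torch.device('cpu')
model = Transformer_fr(en_vocab_size=1000, de_vocab_size=1200, padding_idx=0,
                       max_len=100, embed_size=512, device=device)
src = torch.randint(3, 1000, (2, 7)).to(device)   # (batch, src_len) source token ids
hyps, lengths = model.inference(src)              # (batch, src_len + 50) token ids,
                                                  # plus per-sample decoded lengths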
Example #5
class TransformerModel(nn.Module):
    def __init__(self,
                 vocab_size,
                 hidden_size,
                 num_attention_heads,
                 num_encoder_layers,
                 num_decoder_layers,
                 intermediate_size,
                 dropout=0.1):
        super(TransformerModel, self).__init__()

        # self.token_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.token_embeddings = nn.Embedding(vocab_size,
                                             hidden_size,
                                             padding_idx=1)
        self.position_embeddings = PositionalEncoding(hidden_size)
        self.hidden_size = hidden_size
        self.dropout = nn.Dropout(p=dropout)

        self.transformer = Transformer(
            d_model=hidden_size,
            nhead=num_attention_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=intermediate_size,
            dropout=dropout,
        )

        self.decoder_embeddings = nn.Linear(hidden_size, vocab_size)
        self.decoder_embeddings.weight = self.token_embeddings.weight

        self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(
            mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.token_embeddings.weight.data.uniform_(-initrange, initrange)
        self.decoder_embeddings.bias.data.zero_()
        self.decoder_embeddings.weight.data.uniform_(-initrange, initrange)

    def forward(self,
                src=None,
                tgt=None,
                memory=None,
                src_key_padding_mask=None,
                tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        if src is not None:
            src_embeddings = self.token_embeddings(src) * math.sqrt(
                self.hidden_size) + self.position_embeddings(src)
            src_embeddings = self.dropout(src_embeddings)

            if src_key_padding_mask is not None:
                src_key_padding_mask = src_key_padding_mask.t()

            if tgt is None:  # encode
                memory = self.transformer.encoder(
                    src_embeddings, src_key_padding_mask=src_key_padding_mask)
                return memory

        if tgt is not None:
            tgt_embeddings = self.token_embeddings(tgt) * math.sqrt(
                self.hidden_size) + self.position_embeddings(tgt)
            tgt_embeddings = self.dropout(tgt_embeddings)
            tgt_mask = self.transformer.generate_square_subsequent_mask(
                tgt.size(0)).to(tgt.device)

            if tgt_key_padding_mask is not None:
                tgt_key_padding_mask = tgt_key_padding_mask.t()

            if src is None and memory is not None:  # decode
                if memory_key_padding_mask is not None:
                    memory_key_padding_mask = memory_key_padding_mask.t()

                output = self.transformer.decoder(
                    tgt_embeddings,
                    memory,
                    tgt_mask=tgt_mask,
                    tgt_key_padding_mask=tgt_key_padding_mask,
                    memory_key_padding_mask=memory_key_padding_mask)
                output = self.decoder_embeddings(output)

                return output

        # The fall-through path is the full pass: both src and tgt are required here,
        # otherwise src_embeddings / tgt_embeddings would be undefined.
        assert src is not None and tgt is not None
        output = self.transformer(src_embeddings,
                                  tgt_embeddings,
                                  tgt_mask=tgt_mask,
                                  src_key_padding_mask=src_key_padding_mask,
                                  tgt_key_padding_mask=tgt_key_padding_mask,
                                  memory_key_padding_mask=src_key_padding_mask)
        output = self.decoder_embeddings(output)
        return output
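The branching forward supports three call patterns: encode only (src without tgt), decode only (tgt plus memory, without src), and the full pass (src and tgt). A sketch of the two-step path, which is what enables decoding against a cached encoder output. This assumes a PositionalEncoding variant that accepts a token-id tensor [seq_len, batch] and returns position vectors broadcastable to [seq_len, batch, hidden_size]; the tutorial-style module sketched at the top of this section does not have that interface. Also note that any key-padding masks must be passed transposed, i.e. (seq_len, batch), because forward calls .t() on them before handing them to nn.Transformer.

model = TransformerModel(vocab_size=1000, hidden_size=512, num_attention_heads=8,
                         num_encoder_layers=3, num_decoder_layers=3,
                         intermediate_size=2048)
src = torch.randint(2, 1000, (11, 4))    # (src_len, batch)
tgt = torch.randint(2, 1000, (6, 4))     # (tgt_len, batch)
memory = model(src=src)                  # encode only: (src_len, batch, hidden_size)
logits = model(tgt=tgt, memory=memory)   # decode only: (tgt_len, batch, vocab_size)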
Example #6
class TransformerModel(nn.Module):
    def __init__(self,
                 vocab_size,
                 d_model,
                 num_attention_heads,
                 num_encoder_layers,
                 num_decoder_layers,
                 intermediate_size,
                 max_len,
                 dropout=0.1):
        super(TransformerModel, self).__init__()

        self.token_embeddings = nn.Embedding(vocab_size, d_model)
        self.position_embeddings = PositionalEncoding(d_model, max_len)
        self.hidden_size = d_model
        self.dropout = nn.Dropout(p=dropout)

        self.transformer = Transformer(d_model=d_model,
                                       nhead=num_attention_heads,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=intermediate_size,
                                       dropout=dropout)

        self.decoder_embeddings = nn.Linear(d_model, vocab_size)
        self.decoder_embeddings.weight = self.token_embeddings.weight

        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.token_embeddings.weight.data.uniform_(-initrange, initrange)
        self.decoder_embeddings.bias.data.zero_()
        self.decoder_embeddings.weight.data.uniform_(-initrange, initrange)

    def forward(self,
                src,
                tgt,
                src_key_padding_mask=None,
                tgt_key_padding_mask=None):
        src_embeddings = self.token_embeddings(src) * math.sqrt(
            self.hidden_size) + self.position_embeddings(src)
        src_embeddings = self.dropout(src_embeddings)

        tgt_embeddings = self.token_embeddings(tgt) * math.sqrt(
            self.hidden_size) + self.position_embeddings(tgt)
        tgt_embeddings = self.dropout(tgt_embeddings)

        tgt_mask = self.transformer.generate_square_subsequent_mask(
            tgt.size(0)).to(tgt.device)
        output = self.transformer(src_embeddings,
                                  tgt_embeddings,
                                  tgt_mask=tgt_mask,
                                  src_key_padding_mask=src_key_padding_mask,
                                  tgt_key_padding_mask=tgt_key_padding_mask,
                                  memory_key_padding_mask=src_key_padding_mask)

        output = self.decoder_embeddings(output)
        return output

    def encode(self, src, src_key_padding_mask=None):
        src_embeddings = self.token_embeddings(src) * math.sqrt(
            self.hidden_size) + self.position_embeddings(src)
        src_embeddings = self.dropout(src_embeddings)

        memory = self.transformer.encoder(
            src_embeddings, src_key_padding_mask=src_key_padding_mask)
        return memory

    def decode(self,
               tgt,
               memory,
               tgt_key_padding_mask=None,
               memory_key_padding_mask=None):
        tgt_embeddings = self.token_embeddings(tgt) * math.sqrt(
            self.hidden_size) + self.position_embeddings(tgt)
        tgt_embeddings = self.dropout(tgt_embeddings)
        tgt_mask = self.transformer.generate_square_subsequent_mask(
            tgt.size(0)).to(tgt.device)

        output = self.transformer.decoder(
            tgt_embeddings,
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask)
        output = self.decoder_embeddings(output)
        return output
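The explicit encode/decode split lends itself to greedy decoding with a cached memory. A hypothetical sketch, again assuming a PositionalEncoding variant that maps token ids to position vectors (signature (d_model, max_len), per the constructor call above) and an invented BOS id:

model = TransformerModel(vocab_size=1000, d_model=512, num_attention_heads=8,
                         num_encoder_layers=3, num_decoder_layers=3,
                         intermediate_size=2048, max_len=128)
src = torch.randint(2, 1000, (11, 4))                # (src_len, batch)

BOS = 1                                              # hypothetical start-of-sequence id
with torch.no_grad():
    memory = model.encode(src)                       # (src_len, batch, d_model)
    ys = torch.full((1, 4), BOS, dtype=torch.long)   # growing prefix, (cur_len, batch)
    for _ in range(20):
        logits = model.decode(ys, memory)            # (cur_len, batch, vocab_size)
        next_tok = logits[-1].argmax(dim=-1, keepdim=True)   # (batch, 1)
        ys = torch.cat([ys, next_tok.t()], dim=0)    # append along the time axis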