Example #1
    def __init__(self,
                 num_vocab,
                 num_embedding=128,
                 dim_feedforward=512,
                 num_encoder_layer=4,
                 num_decoder_layer=4,
                 dropout=0.3,
                 padding_idx=1,
                 max_seq_len=140,
                 nhead=8):
        super(FullTransformer, self).__init__()

        self.padding_idx = padding_idx

        # [x : seq_len,  batch_size ]
        self.inp_embedding = Embedding(num_vocab,
                                       num_embedding,
                                       padding_idx=padding_idx)

        # [ x : seq_len, batch_size, num_embedding ]
        self.pos_embedding = PositionalEncoding(num_embedding,
                                                dropout,
                                                max_len=max_seq_len)

        self.trfm = Transformer(d_model=num_embedding,
                                dim_feedforward=dim_feedforward,
                                num_encoder_layers=num_encoder_layer,
                                num_decoder_layers=num_decoder_layer,
                                dropout=dropout,
                                nhead=nhead)
        self.linear_out = torch.nn.Linear(num_embedding, num_vocab)
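A minimal usage sketch for this example (added here; not from the original source). It assumes the forward() defined for the same FullTransformer class in Example #11 below, plus the [seq_len, batch_size] input layout noted in the comments; the vocabulary and tensor sizes are illustrative.

import torch

model = FullTransformer(num_vocab=1000)
src = torch.randint(2, 1000, (30, 8))   # [src_seq_len, batch_size] token ids (1 = padding)
tgt = torch.randint(2, 1000, (25, 8))   # [tgt_seq_len, batch_size] token ids
logits = model(src, tgt)                # [tgt_seq_len, batch_size, num_vocab]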
Example #2
    def __init__(self,
                 vocab_size,
                 hidden_size,
                 num_attention_heads,
                 num_encoder_layers,
                 num_decoder_layers,
                 intermediate_size,
                 dropout=0.1):
        super(TransformerModel, self).__init__()

        # self.token_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.token_embeddings = nn.Embedding(vocab_size,
                                             hidden_size,
                                             padding_idx=1)
        self.position_embeddings = PositionalEncoding(hidden_size)
        self.hidden_size = hidden_size
        self.dropout = nn.Dropout(p=dropout)

        self.transformer = Transformer(
            d_model=hidden_size,
            nhead=num_attention_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=intermediate_size,
            dropout=dropout,
        )

        self.decoder_embeddings = nn.Linear(hidden_size, vocab_size)
        self.decoder_embeddings.weight = self.token_embeddings.weight

        self.init_weights()
Example #3
 def __init__(self, ninp, ntoken, ntoken_dec, nhid=2048, dropout=0):
     super(TransformerModel, self).__init__()
     self.model_type = 'Transformer'
     self.pos_encoder = PositionalEncoding(ninp, dropout)
     self.encoder = nn.Embedding(ntoken, ninp)
     self.ninp = ninp
     self.decoder_emb = nn.Embedding(ntoken_dec, ninp)
     self.decoder_out = nn.Linear(ninp, ntoken_dec)
     self.model = Transformer(d_model=ninp, dim_feedforward=nhid)
Example #4
class TransformerModel(nn.Module):
    def __init__(self, ninp, ntoken, ntoken_dec, nhid=2048, dropout=0):
        super(TransformerModel, self).__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder_emb = nn.Embedding(ntoken_dec, ninp)
        self.decoder_out = nn.Linear(ninp, ntoken_dec)
        self.model = Transformer(d_model=ninp, dim_feedforward=nhid)

    def forward(self, src, tgt, src_mask, tgt_mask):
        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        tgt = self.decoder_emb(tgt) * math.sqrt(self.ninp)
        tgt = self.pos_encoder(tgt)
        src_mask = src_mask != 1
        tgt_mask = tgt_mask != 1
        subseq_mask = self.model.generate_square_subsequent_mask(
            tgt.size(1)).to(tgt.device)
        output = self.model(src.transpose(0, 1),
                            tgt.transpose(0, 1),
                            tgt_mask=subseq_mask,
                            src_key_padding_mask=src_mask,
                            tgt_key_padding_mask=tgt_mask,
                            memory_key_padding_mask=src_mask)
        output = self.decoder_out(output)
        return output

    def greedy_decode(self, src, src_mask, sos_token, max_length=20):
        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        src_mask = src_mask != 1
        encoded = self.model.encoder(src.transpose(0, 1),
                                     src_key_padding_mask=src_mask)
        generated = encoded.new_full((encoded.size(1), 1),
                                     sos_token,
                                     dtype=torch.long)
        for i in range(max_length - 1):
            subseq_mask = self.model.generate_square_subsequent_mask(
                generated.size(1)).to(src.device)
            decoder_in = self.decoder_emb(generated) * math.sqrt(self.ninp)
            decoder_in = self.pos_encoder(decoder_in)
            logits = self.decoder_out(
                self.model.decoder(decoder_in.transpose(0, 1),
                                   encoded,
                                   tgt_mask=subseq_mask,
                                   memory_key_padding_mask=src_mask)[-1, :, :])
            new_generated = logits.argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, new_generated], dim=-1)
        return generated

    def save(self, file_dir):
        torch.save(self.state_dict(), file_dir)

    def load(self, file_dir):
        self.load_state_dict(torch.load(file_dir))
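A usage sketch for Example #4 (added; not part of the original snippet). Inputs are batch-first token ids; src_mask / tgt_mask carry 1 for real tokens, since forward() turns them into key-padding masks via `mask != 1`. All sizes below are illustrative assumptions.

import torch

model = TransformerModel(ninp=512, ntoken=1000, ntoken_dec=1200)
src = torch.randint(0, 1000, (8, 30))            # [batch, src_len]
tgt = torch.randint(0, 1200, (8, 25))            # [batch, tgt_len]
src_mask = torch.ones(8, 30, dtype=torch.long)   # 1 = keep, anything else = padding
tgt_mask = torch.ones(8, 25, dtype=torch.long)
logits = model(src, tgt, src_mask, tgt_mask)     # [tgt_len, batch, ntoken_dec]
hyps = model.greedy_decode(src, src_mask, sos_token=1, max_length=20)  # [batch, 20]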
Example #5
    def __init__(self, vocab_size, embed_dim, num_heads, hidden_dim, num_encoder_layers, num_decoder_layers, dropout=0.1):
        super().__init__()
        self.trg_mask = None
        self.pos_encoder = PositionalEncoding(embed_dim)

        self.transformer = Transformer(embed_dim, num_heads, num_encoder_layers, num_decoder_layers, hidden_dim, dropout=dropout)

        self.src_embed = nn.Embedding(vocab_size, embed_dim)
        self.trg_embed = nn.Embedding(vocab_size, embed_dim)

        self.feature_dim = embed_dim
        self.decoder = nn.Linear(embed_dim, vocab_size)
Example #6
    def __init__(self, nb_tokens: int, emb_size: int, nb_layers=2, nb_heads=4, hid_size=512, dropout=0.25, max_len=30):
        super(SimpleTransformerModel, self).__init__()
        from torch.nn import Transformer
        self.emb_size = emb_size
        self.max_len = max_len

        self.pos_encoder = PositionalEncoding(emb_size, dropout=dropout, max_len=max_len)
        self.embedder = nn.Embedding(nb_tokens, emb_size)

        self.transformer = Transformer(d_model=emb_size, nhead=nb_heads, num_encoder_layers=nb_layers,
                                       num_decoder_layers=nb_layers, dim_feedforward=hid_size, dropout=dropout)

        self.out_lin = nn.Linear(in_features=emb_size, out_features=nb_tokens)

        self.tgt_mask = None
Example #7
    def __init__(self,
                 n_tokens,
                 n_joints,
                 joints_dim,
                 nhead,
                 nhid,
                 nout,
                 n_enc_layers,
                 n_dec_layers,
                 dropout=0.5):
        super(TextPoseTransformer, self).__init__()
        from torch.nn import Transformer
        self.model_type = 'Transformer'
        self.src_mask = None
        self.token_pos_encoder = PositionalEncoding(nhid, dropout, max_len=40)
        self.pose_pos_encoder = PositionalEncoding(nhid, dropout, max_len=100)

        self.transformer = Transformer(nhid, nhead, n_enc_layers, n_dec_layers,
                                       nhid)

        self.token_embedding = nn.Embedding(n_tokens, nhid)
        self.hidden2pose_projection = nn.Linear(nhid, nout)
        self.pose2hidden_projection = nn.Linear(n_joints * joints_dim, nhid)

        self.init_weights()
Example #8
    def __init__(self, config: TransformerEncoderConfig):
        super().__init__()
        self.save_hyperparameters()
        TEXT = torchtext.data.Field(tokenize=get_tokenizer(config.data),
                                    init_token='<sos>',
                                    eos_token='<eos>',
                                    lower=True)
        train_txt, val_txt, test_txt = torchtext.datasets.WikiText2.splits(
            TEXT)
        TEXT.build_vocab(train_txt)
        self.TEXT = TEXT
        self.train_txt = train_txt
        self.val_txt = val_txt
        self.test_txt = test_txt
        self.ntoken = len(TEXT.vocab.stoi)
        self.model_type = 'Transformer'
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(config.ninp, config.dropout)
        encoder_layers = TransformerEncoderLayer(config.ninp, config.nhead,
                                                 config.nhid, config.dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers,
                                                      config.nlayers)
        self.encoder = nn.Embedding(self.ntoken, config.ninp)
        self.ninp = config.ninp
        self.transformer = Transformer(d_model=config.ninp,
                                       nhead=config.nhead,
                                       num_encoder_layers=config.nlayers,
                                       num_decoder_layers=1,
                                       dim_feedforward=config.nhid,
                                       dropout=config.dropout)
        self.out = nn.Linear(config.ninp, self.ntoken)

        self.init_weights()
Example #9
class TransformerDecoderModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, hidden_dim, num_encoder_layers, num_decoder_layers, dropout=0.1):
        super().__init__()
        self.trg_mask = None
        self.pos_encoder = PositionalEncoding(embed_dim)

        self.transformer = Transformer(embed_dim, num_heads, num_encoder_layers, num_decoder_layers, hidden_dim, dropout=dropout)

        self.src_embed = nn.Embedding(vocab_size, embed_dim)
        self.trg_embed = nn.Embedding(vocab_size, embed_dim)

        self.feature_dim = embed_dim
        self.decoder = nn.Linear(embed_dim, vocab_size)

    def forward(self, src, trg):
        if self.trg_mask is None or self.trg_mask.size(0) != len(trg):
            device = trg.device
            mask = self.transformer.generate_square_subsequent_mask(len(trg)).to(device)
            self.trg_mask = mask

        src = self.src_embed(src) * math.sqrt(self.feature_dim)
        src = self.pos_encoder(src)

        trg = self.trg_embed(trg) * math.sqrt(self.feature_dim)
        trg = self.pos_encoder(trg)

        output = self.transformer(src, trg, tgt_mask=self.trg_mask)
        output = self.decoder(output)
        return output
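A usage sketch for this example (added; not from the original). Inputs are sequence-first token ids; note that forward() applies only the causal target mask, so padded batches would additionally need key-padding masks. The sizes below are illustrative.

import torch

model = TransformerDecoderModel(vocab_size=1000, embed_dim=256, num_heads=8,
                                hidden_dim=1024, num_encoder_layers=3,
                                num_decoder_layers=3)
src = torch.randint(0, 1000, (30, 4))   # [src_len, batch]
trg = torch.randint(0, 1000, (25, 4))   # [trg_len, batch]
logits = model(src, trg)                # [trg_len, batch, vocab_size]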
Example #10
    def __init__(
        self,
        vocab: Vocabulary,
        source_embedder: TextFieldEmbedder,
        transformer: Dict,
        max_decoding_steps: int,
        target_namespace: str,
        target_embedder: TextFieldEmbedder = None,
        use_bleu: bool = True,
    ) -> None:
        super().__init__(vocab)
        self._target_namespace = target_namespace

        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)
        self._pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                     self._target_namespace)

        if use_bleu:
            self._bleu = BLEU(exclude_indices={
                self._pad_index, self._end_index, self._start_index
            })
        else:
            self._bleu = None
        self._seq_acc = SequenceAccuracy()

        self._max_decoding_steps = max_decoding_steps

        self._source_embedder = source_embedder

        self._ndim = transformer["d_model"]
        self.pos_encoder = PositionalEncoding(self._ndim,
                                              transformer["dropout"])

        num_classes = self.vocab.get_vocab_size(self._target_namespace)

        self._transformer = Transformer(**transformer)
        self._transformer.apply(inplace_relu)

        if target_embedder is None:
            self._target_embedder = self._source_embedder
        else:
            self._target_embedder = target_embedder

        self._output_projection_layer = Linear(self._ndim, num_classes)
Example #11
class FullTransformer(Module):

    def __init__(self, num_vocab, num_embedding=128, dim_feedforward=512, num_encoder_layer=4,
                 num_decoder_layer=4, dropout=0.3, padding_idx=1, max_seq_len=140):
        super(FullTransformer, self).__init__()

        self.padding_idx = padding_idx

        # [x : seq_len,  batch_size ]
        self.inp_embedding = Embedding(num_vocab, num_embedding, padding_idx=padding_idx)

        # [ x : seq_len, batch_size, num_embedding ]
        self.pos_embedding = PositionalEncoding(num_embedding, dropout, max_len=max_seq_len)

        self.trfm = Transformer(d_model=num_embedding, dim_feedforward=dim_feedforward,
                                num_encoder_layers=num_encoder_layer, num_decoder_layers=num_decoder_layer,
                                dropout=dropout)
        self.linear_out = torch.nn.Linear(num_embedding, num_vocab)

    def make_pad_mask(self, inp: torch.Tensor) -> torch.Tensor:
        """
        Make mask attention that caused 'True' element will not be attended (ignored).
        Padding stated in self.padding_idx will not be attended at all.

        :param inp : input that to be masked in boolean Tensor
        """
        return (inp == self.padding_idx).transpose(0, 1)

    def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
        """
        forward!

        :param src : source tensor
        :param tgt : target tensor
        """
        # Generate mask for decoder attention
        tgt_mask = self.trfm.generate_square_subsequent_mask(len(tgt)).to(tgt.device)

        # tgt_mask shape = [tgt_seq_len, tgt_seq_len]
        src_pad_mask = self.make_pad_mask(src)
        tgt_pad_mask = self.make_pad_mask(tgt)

        # [ src : seq_len, batch_size, num_embedding ]

        out_emb_enc = self.pos_embedding(self.inp_embedding(src))

        # [ src : seq_len, batch_size, num_embedding ]
        out_emb_dec = self.pos_embedding(self.inp_embedding(tgt))

        out_trf = self.trfm(out_emb_enc, out_emb_dec, src_mask=None, tgt_mask=tgt_mask, memory_mask=None,
                            src_key_padding_mask=src_pad_mask, tgt_key_padding_mask=tgt_pad_mask,
                            memory_key_padding_mask=src_pad_mask)

        # [ out_trf : seq_len, batch_size, num_embedding]

        out_to_logit = self.linear_out(out_trf)

        # final_out : [ seq_len, batch_size, vocab_size ]
        return out_to_logit
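A hedged training-step sketch for the FullTransformer above (added; not part of the original). It uses teacher forcing (feed tgt[:-1], predict tgt[1:]) and ignores the padding index in the loss; the optimizer choice, learning rate, and sizes are illustrative assumptions.

import torch

model = FullTransformer(num_vocab=1000)
criterion = torch.nn.CrossEntropyLoss(ignore_index=model.padding_idx)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

src = torch.randint(2, 1000, (30, 8))      # [src_len, batch]
tgt = torch.randint(2, 1000, (26, 8))      # [tgt_len, batch], incl. BOS/EOS
logits = model(src, tgt[:-1])              # predict the next token at each position
loss = criterion(logits.reshape(-1, logits.size(-1)), tgt[1:].reshape(-1))

optimizer.zero_grad()
loss.backward()
optimizer.step()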
Example #12
class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(
            emb_size, dropout=dropout)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)
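A greedy-decoding sketch built on the encode()/decode() helpers above (added; not in the original, and written in the style of the PyTorch translation tutorial this snippet follows). Inputs are sequence-first with batch size 1; the start/end symbol ids and src_mask handling are assumptions.

import torch

def greedy_decode(model, src, src_mask, max_len, start_symbol, end_symbol):
    # src: [src_len, 1] token ids; src_mask: [src_len, src_len] additive mask (zeros = no masking)
    memory = model.encode(src, src_mask)                      # [src_len, 1, emb_size]
    ys = torch.full((1, 1), start_symbol, dtype=torch.long)   # running output [len, 1]
    for _ in range(max_len - 1):
        tgt_mask = model.transformer.generate_square_subsequent_mask(
            ys.size(0)).to(ys.device)
        out = model.decode(ys, memory, tgt_mask)              # [cur_len, 1, emb_size]
        next_word = model.generator(out[-1]).argmax(dim=-1)   # [1]
        ys = torch.cat([ys, next_word.unsqueeze(0)], dim=0)
        if next_word.item() == end_symbol:
            break
    return ys                                                 # [out_len, 1] token ids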
Example #13
 def __init__(self,
              d_model=256,
              nhead=4,
              num_encoder_layers=1,
              num_decoder_layers=1,
              dim_feedforward=1028):
     super().__init__()
     self.transformer = Transformer(d_model, nhead, num_encoder_layers,
                                    num_decoder_layers, dim_feedforward)
Example #14
 def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6,
              dim_feedforward=2048, dropout=0.1, activation='relu',
              custom_encoder=None, custom_decoder=None, vocab_size=5000):
     super(PTransformer, self).__init__()
     self.transformer = Transformer(d_model=d_model, nhead=nhead,
                                    num_encoder_layers=num_encoder_layers,
                                    num_decoder_layers=num_decoder_layers,
                                    dim_feedforward=dim_feedforward, dropout=dropout,
                                    activation=activation, custom_encoder=custom_encoder,
                                    custom_decoder=custom_decoder)
     self.embedding = nn.Embedding(vocab_size, d_model)
     self.positional_encoding = PositionalEncoding(d_model, dropout=dropout)
     self.linear = nn.Linear(d_model, vocab_size)
Example #15
 def __init__(self,
              num_encoder_layers: int,
              num_decoder_layers: int,
              emb_size: int,
              nhead: int,
              src_vocab_size: int,
              tgt_vocab_size: int,
              dim_feedforward: int = 512,
              dropout: float = 0.1):
     super(Seq2SeqTransformer, self).__init__()
     self.transformer = Transformer(d_model=emb_size,
                                    nhead=nhead,
                                    num_encoder_layers=num_encoder_layers,
                                    num_decoder_layers=num_decoder_layers,
                                    dim_feedforward=dim_feedforward,
                                    dropout=dropout)
     self.generator = nn.Linear(emb_size, tgt_vocab_size)
     self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
     self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
     self.positional_encoding = PositionalEncoding(emb_size,
                                                   dropout=dropout)
Example #16
    def __init__(self, en_vocab_size, de_vocab_size, padding_idx, max_len,
                 embed_size, device):
        super(Transformer_fr, self).__init__()
        self.en_vocab = en_vocab_size
        self.de_vocab = de_vocab_size
        self.padd = padding_idx
        self.BOS = 1
        self.EOS = 2
        self.device = device

        #self.encode = Pos_encoding(embed_size, max_len, device)
        #self.en_emb = nn.Embedding(self.en_vocab, embed_size, padding_idx = 0)
        #self.de_emb = nn.Embedding(self.de_vocab, embed_size, padding_idx = 0)

        self.en_enc = Encoding(self.en_vocab, embed_size, max_len, 0.2, device)
        self.de_enc = Encoding(self.de_vocab, embed_size, max_len, 0.2, device)

        self.transformer = Transformer()
        self.fc = nn.Linear(embed_size, self.de_vocab)

        self.scale = embed_size**0.5
Example #17
	def __init__(self, num_tokens = 30, d_model = 30, nhead = 3, num_encoder_layers = 2, num_decoder_layers = 2, dim_feedforward = 512, dropout = 0.1, embed = True, max_len = 2000):
		super().__init__()
		self.num_tokens = num_tokens
		self.d_model = d_model
		self.nhead = nhead
		self.num_encoder_layers = num_encoder_layers
		self.num_decoder_layers = num_decoder_layers
		self.dim_feedforward = dim_feedforward
		self.dropout = dropout
		self.embed = embed
		self.max_len = max_len

		self.embedding = nn.Embedding(self.num_tokens, self.d_model)
		self.pos_encoder = PositionalEncoding(self.d_model, self.dropout, self.max_len)
		self.transformer = Transformer(d_model = self.d_model, nhead = self.nhead, num_encoder_layers = self.num_encoder_layers, num_decoder_layers = self.num_decoder_layers, dim_feedforward = self.dim_feedforward, dropout = self.dropout)
Example #18
    def __init__(self,
                 ntokens_src,
                 ntokens_tgt,
                 ninp,
                 nhead,
                 dim_feedforward,
                 nlayers,
                 pad_token,
                 dropout=0.5):
        super(TransformerModel, self).__init__()
        from torch.nn import Transformer
        self.model_type = 'Transformer'
        self.ninp = ninp
        self.pad_token = pad_token
        self.masks = {
            'src': None,
            'tgt': None,
            'memory': None,
        }
        # Token Encoders
        self.src_encoder = nn.Embedding(ntokens_src, ninp)
        self.tgt_encoder = nn.Embedding(ntokens_tgt, ninp)
        # Positional Encoding
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        # Transformer
        self.transformer = Transformer(
            d_model=ninp,
            nhead=nhead,
            num_encoder_layers=nlayers,
            num_decoder_layers=nlayers,
            dropout=dropout,
            dim_feedforward=dim_feedforward,
        )
        self.out = nn.Linear(ninp, ntokens_tgt)

        self.init_weights()
Example #19
 def __init__(self,
              in_size,
              hidden_size,
              out_size,
              n_layers,
              nhead=4,
              dropout=0.1):
     super(TrfmSmiles, self).__init__()
     self.in_size = in_size
     self.hidden_size = hidden_size
     self.embed = Embedding(in_size, hidden_size)
     self.pe = PositionalEncoding(hidden_size, dropout)
     self.trfm = Transformer(d_model=hidden_size,
                             nhead=nhead,
                             num_encoder_layers=n_layers,
                             num_decoder_layers=n_layers,
                             dim_feedforward=hidden_size)
     self.out = Linear(hidden_size, out_size)
Example #20
    def __init__(self,
                 d_model=256,
                 nhead=4,
                 num_encoder_layers=1,
                 num_decoder_layers=1,
                 dim_feedforward=1028):
        """

        Note that dim_feedforward is conventionally 4 * d_model (512 * 4 = 2048 in the
        original Transformer). With d_model=256 that would be 256 * 4 = 1024, and it may
        be useful to keep that ratio.


        :param d_model:
        :param nhead:
        :param num_encoder_layers:
        :param num_decoder_layers:
        :param dim_feedforward:
        """
        super().__init__()
        self.transformer = Transformer(d_model, nhead, num_encoder_layers,
                                       num_decoder_layers, dim_feedforward)
Example #21
    def __init__(self,
                 d_model=512,
                 vocabulary_size=30,
                 max_seq_len=75,
                 decoder_max_seq_len=35,
                 nhead=8,
                 num_encoder_layers=6,
                 num_decoder_layers=6,
                 dim_feedforward=2048,
                 dropout=0.1,
                 activation="relu",
                 stn_on=True,
                 bs=16,
                 ts=75):
        super(Transformer_model, self).__init__()
        self.stn_on = stn_on

        self.cfe = Convolutional_Feature_Extractor(d_model)  #(n,t,e)
        self.positionEncoding = PositionalEncoding(max_seq_len, d_model)
        self.decodeEmbedding = Embedding(vocabulary_size, d_model)
        self.decoder_max_seq_len = decoder_max_seq_len
        self.transformer = Transformer(d_model=d_model,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout,
                                       activation=activation)
        self.linear = Linear(d_model, vocabulary_size)
        self.ln = LayerNorm(d_model)
        self.stn = STNHead(in_planes=3, num_ctrlpoints=20, activation=None)
        self.tps = TPSSpatialTransformer(output_image_size=(50, 100),
                                         num_control_points=20,
                                         margins=tuple([0.05, 0.05]))
        self.bs = bs
        self.ts = ts
Example #22
    def __init__(self, cfg):
        super(Model, self).__init__()
        self.cfg = cfg
        self.stages = {'Trans': cfg['model']['transform'], 'Feat': cfg['model']['extraction'],
                       'Seq': cfg['model']['sequence'], 'Pred': cfg['model']['prediction']}

        """ Transformation """
        if cfg['model']['transform'] == 'TPS':
            self.Transformation = TPS_SpatialTransformerNetwork(
                F=cfg['transform']['num_fiducial'],
                I_size=(cfg['dataset']['imgH'],cfg['dataset']['imgW']),
                I_r_size=(cfg['dataset']['imgH'], cfg['dataset']['imgW']),
                I_channel_num=cfg['model']['input_channel'])
            print ("Transformation moduls : {}".format(cfg['model']['transform']))
            
        else:
            print('No Transformation module specified')
            
            
            
        """ FeatureExtraction """
        if cfg['model']['extraction'] == 'VGG':
            self.FeatureExtraction = VGG_FeatureExtractor(cfg['model']['input_channel'], cfg['model']['output_channel'])
            
        elif cfg['model']['extraction'] == 'RCNN':
            self.FeatureExtraction = RCNN_FeatureExtractor(cfg['model']['input_channel'], cfg['model']['output_channel'])
            
        elif cfg['model']['extraction'] == 'ResNet':
            self.FeatureExtraction = ResNet_FeatureExtractor(cfg['model']['input_channel'], cfg['model']['output_channel'])
            
        else:
            raise Exception('No FeatureExtraction module specified')
            
        self.FeatureExtraction_output = cfg['model']['output_channel']  # int(imgH/16-1) * 512
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1
        print('Feature extractor : {}'.format(cfg['model']['extraction']))

        """ Sequence modeling """
        if cfg['model']['sequence'] == 'BiLSTM':
            self.SequenceModeling = nn.Sequential(
                BidirectionalLSTM(self.FeatureExtraction_output, cfg['model']['hidden_size'], cfg['model']['hidden_size']),
                BidirectionalLSTM(cfg['model']['hidden_size'], cfg['model']['hidden_size'], cfg['model']['hidden_size']))
            self.SequenceModeling_output = cfg['model']['hidden_size']
        
        # SequenceModeling : Transformer
        elif cfg['model']['sequence'] == 'Transformer':
            self.SequenceModeling = Transformer(
                d_model=self.FeatureExtraction_output, 
                nhead=2, 
                num_encoder_layers=2,
                num_decoder_layers=2,
                dim_feedforward=cfg['model']['hidden_size'],
                dropout=0.1,
                activation='relu')
            print('SequenceModeling: Transformer initialized.')
            self.SequenceModeling_output = self.FeatureExtraction_output  # output dim equals the input dim
        
        else:
            print('No SequenceModeling module specified')
            self.SequenceModeling_output = self.FeatureExtraction_output
        print('Sequence modeling : {}'.format(cfg['model']['sequence']))

        """ Prediction """
        if cfg['model']['prediction'] == 'CTC':
            self.Prediction = nn.Linear(self.SequenceModeling_output, cfg['training']['num_class'])
            
        elif cfg['model']['prediction'] == 'Attn':
            self.Prediction = Attention(self.SequenceModeling_output, cfg['model']['hidden_size'], cfg['training']['num_class'])
            
        elif cfg['model']['prediction'] == 'Transformer':
            self.Prediction = nn.Linear(self.SequenceModeling_output, cfg['training']['num_class'])
            
        else:
            raise Exception('Prediction should be in [CTC | Attn | Transformer]')
        
        print ("Prediction : {}".format(cfg['model']['prediction']))
Example #23
class TransformerModel(nn.Module):
    def __init__(self,
                 vocab_size,
                 d_model,
                 num_attention_heads,
                 num_encoder_layers,
                 num_decoder_layers,
                 intermediate_size,
                 max_len,
                 dropout=0.1):
        super(TransformerModel, self).__init__()

        self.token_embeddings = nn.Embedding(vocab_size, d_model)
        self.position_embeddings = PositionalEncoding(d_model, max_len)
        self.hidden_size = d_model
        self.dropout = nn.Dropout(p=dropout)

        self.transformer = Transformer(d_model=d_model,
                                       nhead=num_attention_heads,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=intermediate_size,
                                       dropout=dropout)

        self.decoder_embeddings = nn.Linear(d_model, vocab_size)
        self.decoder_embeddings.weight = self.token_embeddings.weight

        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.token_embeddings.weight.data.uniform_(-initrange, initrange)
        self.decoder_embeddings.bias.data.zero_()
        self.decoder_embeddings.weight.data.uniform_(-initrange, initrange)

    def forward(self,
                src,
                tgt,
                src_key_padding_mask=None,
                tgt_key_padding_mask=None):
        src_embeddings = self.token_embeddings(src) * math.sqrt(
            self.hidden_size) + self.position_embeddings(src)
        src_embeddings = self.dropout(src_embeddings)

        tgt_embeddings = self.token_embeddings(tgt) * math.sqrt(
            self.hidden_size) + self.position_embeddings(tgt)
        tgt_embeddings = self.dropout(tgt_embeddings)

        tgt_mask = self.transformer.generate_square_subsequent_mask(
            tgt.size(0)).to(tgt.device)
        output = self.transformer(src_embeddings,
                                  tgt_embeddings,
                                  tgt_mask=tgt_mask,
                                  src_key_padding_mask=src_key_padding_mask,
                                  tgt_key_padding_mask=tgt_key_padding_mask)

        output = self.decoder_embeddings(output)
        return output

    def encode(self, src, src_key_padding_mask=None):
        src_embeddings = self.token_embeddings(src) * math.sqrt(
            self.hidden_size) + self.position_embeddings(src)
        src_embeddings = self.dropout(src_embeddings)

        memory = self.transformer.encoder(
            src_embeddings, src_key_padding_mask=src_key_padding_mask)
        return memory

    def decode(self,
               tgt,
               memory,
               tgt_key_padding_mask=None,
               memory_key_padding_mask=None):
        tgt_embeddings = self.token_embeddings(tgt) * math.sqrt(
            self.hidden_size) + self.position_embeddings(tgt)
        tgt_embeddings = self.dropout(tgt_embeddings)
        tgt_mask = self.transformer.generate_square_subsequent_mask(
            tgt.size(0)).to(tgt.device)

        output = self.transformer.decoder(
            tgt_embeddings,
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask)
        output = self.decoder_embeddings(output)
        return output
Example #24
class MyTransformer(Model):
    def __init__(
        self,
        vocab: Vocabulary,
        source_embedder: TextFieldEmbedder,
        transformer: Dict,
        max_decoding_steps: int,
        target_namespace: str,
        target_embedder: TextFieldEmbedder = None,
        use_bleu: bool = True,
    ) -> None:
        super().__init__(vocab)
        self._target_namespace = target_namespace

        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)
        self._pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                     self._target_namespace)

        if use_bleu:
            self._bleu = BLEU(exclude_indices={
                self._pad_index, self._end_index, self._start_index
            })
        else:
            self._bleu = None
        self._seq_acc = SequenceAccuracy()

        self._max_decoding_steps = max_decoding_steps

        self._source_embedder = source_embedder

        self._ndim = transformer["d_model"]
        self.pos_encoder = PositionalEncoding(self._ndim,
                                              transformer["dropout"])

        num_classes = self.vocab.get_vocab_size(self._target_namespace)

        self._transformer = Transformer(**transformer)
        self._transformer.apply(inplace_relu)

        if target_embedder is None:
            self._target_embedder = self._source_embedder
        else:
            self._target_embedder = target_embedder

        self._output_projection_layer = Linear(self._ndim, num_classes)

    def _get_mask(self, meta_data):
        mask = torch.zeros(1, len(meta_data),
                           self.vocab.get_vocab_size(
                               self._target_namespace)).float()
        for bidx, md in enumerate(meta_data):
            for k, v in self.vocab._token_to_index[
                    self._target_namespace].items():
                if 'position' in k and k not in md['avail_pos']:
                    mask[:, bidx, v] = float('-inf')
        return mask

    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == False,
                                        float('-inf')).masked_fill(
                                            mask == True, float(0.0))
        return mask

    @overrides
    def forward(
        self,
        source_tokens: Dict[str, torch.LongTensor],
        target_tokens: Dict[str, torch.LongTensor] = None,
        meta_data: Any = None,
    ) -> Dict[str, torch.Tensor]:
        src, src_key_padding_mask = self._encode(self._source_embedder,
                                                 source_tokens)
        memory = self._transformer.encoder(
            src, src_key_padding_mask=src_key_padding_mask)

        if meta_data is not None:
            target_vocab_mask = self._get_mask(meta_data)
            target_vocab_mask = target_vocab_mask.to(memory.device)
        else:
            target_vocab_mask = None
        output_dict = {}
        targets = None
        if target_tokens:
            targets = target_tokens["tokens"][:, 1:]
            target_mask = (util.get_text_field_mask({"tokens": targets}) == 1)
            assert targets.size(1) <= self._max_decoding_steps
        if self.training and target_tokens:
            tgt, tgt_key_padding_mask = self._encode(
                self._target_embedder,
                {"tokens": target_tokens["tokens"][:, :-1]})
            tgt_mask = self.generate_square_subsequent_mask(tgt.size(0)).to(
                memory.device)
            output = self._transformer.decoder(
                tgt,
                memory,
                tgt_mask=tgt_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=src_key_padding_mask)
            logits = self._output_projection_layer(output)
            if target_vocab_mask is not None:
                logits += target_vocab_mask
            class_probabilities = F.softmax(logits.detach(), dim=-1)
            _, predictions = torch.max(class_probabilities, -1)
            logits = logits.transpose(0, 1)
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss
        else:
            assert self.training is False
            output_dict["loss"] = torch.tensor(0.0).to(memory.device)
            if targets is not None:
                max_target_len = targets.size(1)
            else:
                max_target_len = None
            predictions, class_probabilities = self._decoder_step_by_step(
                memory,
                src_key_padding_mask,
                target_vocab_mask,
                max_target_len=max_target_len)
        predictions = predictions.transpose(0, 1)
        output_dict["predictions"] = predictions
        output_dict["class_probabilities"] = class_probabilities.transpose(
            0, 1)

        if target_tokens:
            with torch.no_grad():
                best_predictions = output_dict["predictions"]
                if self._bleu:
                    self._bleu(best_predictions, targets)
                batch_size = targets.size(0)
                max_sz = max(best_predictions.size(1), targets.size(1),
                             target_mask.size(1))
                best_predictions_ = torch.zeros(batch_size,
                                                max_sz).to(memory.device)
                best_predictions_[:, :best_predictions.
                                  size(1)] = best_predictions
                targets_ = torch.zeros(batch_size, max_sz).to(memory.device)
                targets_[:, :targets.size(1)] = targets.cpu()
                target_mask_ = torch.zeros(batch_size,
                                           max_sz).to(memory.device)
                target_mask_[:, :target_mask.size(1)] = target_mask
                self._seq_acc(best_predictions_.unsqueeze(1), targets_,
                              target_mask_)
        return output_dict

    @overrides
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            # shape: (batch_size, num_decoding_steps)
            predicted_indices = predicted_indices.detach().cpu().numpy()
            # class_probabilities = output_dict["class_probabilities"].detach().cpu()
            # sample_predicted_indices = []
            # for cp in class_probabilities:
            #     sample = torch.multinomial(cp, num_samples=1)
            #     sample_predicted_indices.append(sample)
            # # shape: (batch_size, num_decoding_steps, num_samples)
            # sample_predicted_indices = torch.stack(sample_predicted_indices)

        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [
                self.vocab.get_token_from_index(
                    x, namespace=self._target_namespace) for x in indices
            ]
            all_predicted_tokens.append(predicted_tokens)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _encode(
            self, embedder: TextFieldEmbedder,
            tokens: Dict[str,
                         torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        src = embedder(tokens) * math.sqrt(self._ndim)
        src = src.transpose(0, 1)
        src = self.pos_encoder(src)
        mask = util.get_text_field_mask(tokens)
        mask = (mask == 0)
        return src, mask

    def _decoder_step_by_step(
            self,
            memory: torch.Tensor,
            memory_key_padding_mask: torch.Tensor,
            target_vocab_mask: torch.Tensor = None,
            max_target_len: int = None) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = memory.size(1)
        if getattr(self, "target_limit_decode_steps",
                   False) and max_target_len is not None:
            num_decoding_steps = min(self._max_decoding_steps, max_target_len)
            print('decoding steps: ', num_decoding_steps)
        else:
            num_decoding_steps = self._max_decoding_steps

        last_predictions = memory.new_full(
            (batch_size, ), fill_value=self._start_index).long()

        step_predictions: List[torch.Tensor] = []
        all_predicts = memory.new_full((batch_size, num_decoding_steps),
                                       fill_value=0).long()
        for timestep in range(num_decoding_steps):
            all_predicts[:, timestep] = last_predictions
            tgt, tgt_key_padding_mask = self._encode(
                self._target_embedder,
                {"tokens": all_predicts[:, :timestep + 1]})
            tgt_mask = self.generate_square_subsequent_mask(timestep + 1).to(
                memory.device)
            output = self._transformer.decoder(
                tgt,
                memory,
                tgt_mask=tgt_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask)
            output_projections = self._output_projection_layer(output)
            if target_vocab_mask is not None:
                output_projections += target_vocab_mask

            class_probabilities = F.softmax(output_projections, dim=-1)
            _, predicted_classes = torch.max(class_probabilities, -1)

            # shape (predicted_classes): (batch_size,)
            last_predictions = predicted_classes[timestep, :]
            step_predictions.append(last_predictions)
            if ((last_predictions == self._end_index) +
                (last_predictions == self._pad_index)).all():
                break

        # shape: (num_decoding_steps, batch_size)
        predictions = torch.stack(step_predictions)
        return predictions, class_probabilities

    @staticmethod
    def _get_loss(logits: torch.FloatTensor, targets: torch.LongTensor,
                  target_mask: torch.FloatTensor) -> torch.Tensor:
        logits = logits.contiguous()
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets.contiguous()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask.contiguous()

        return util.sequence_cross_entropy_with_logits(logits,
                                                       relevant_targets,
                                                       relevant_mask)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}
        if self._bleu:
            all_metrics.update(self._bleu.get_metric(reset=reset))
        all_metrics['seq_acc'] = self._seq_acc.get_metric(reset=reset)
        return all_metrics

    def load_state_dict(self, state_dict, strict=True):
        new_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith('module.'):
                new_state_dict[k[len('module.'):]] = v
            else:
                new_state_dict[k] = v

        super(MyTransformer, self).load_state_dict(new_state_dict, strict)
Example #25
class Transformer_fr(nn.Module):
    def __init__(self, en_vocab_size, de_vocab_size, padding_idx, max_len,
                 embed_size, device):
        super(Transformer_fr, self).__init__()
        self.en_vocab = en_vocab_size
        self.de_vocab = de_vocab_size
        self.padd = padding_idx
        self.BOS = 1
        self.EOS = 2
        self.device = device

        #self.encode = Pos_encoding(embed_size, max_len, device)
        #self.en_emb = nn.Embedding(self.en_vocab, embed_size, padding_idx = 0)
        #self.de_emb = nn.Embedding(self.de_vocab, embed_size, padding_idx = 0)

        self.en_enc = Encoding(self.en_vocab, embed_size, max_len, 0.2, device)
        self.de_enc = Encoding(self.de_vocab, embed_size, max_len, 0.2, device)

        self.transformer = Transformer()
        self.fc = nn.Linear(embed_size, self.de_vocab)

        self.scale = embed_size**0.5

    def gen_src_mask(self, x):
        '''
        x = (B, S)
        src_mask = (B, 1, S_r) --> broadcast
        '''
        #(B,1,S)
        src_mask = (x == self.padd).unsqueeze(1)

        return src_mask.to(self.device)

    def gen_trg_mask(self, x):
        '''
        x = (B,S)
        trg_mask = (B, S, S_r) : triangle
        '''
        batch = x.shape[0]
        seq = x.shape[1]

        #B, 1, S
        #trg_pad = (x == self.padd).unsqueeze(1)
        #1, S, S
        #S, S
        trg_mask = torch.tril(torch.ones(seq, seq))
        trg_mask[trg_mask == 0] = float("-inf")
        trg_mask[trg_mask == 1] = float(0.0)
        #trg_mask = trg_pad | trg_idx
        #print(trg_mask)

        return trg_mask.to(self.device)

    def forward(self, src, trg):

        #src = self.en_emb(src) * self.scale + self.encode(src)
        #trg = self.de_emb(trg) * self.scale+ self.encode(trg)
        trg_seq = trg.size(1)

        src = self.en_enc(src)
        trg = self.de_enc(trg)

        trg_mask = self.transformer.generate_square_subsequent_mask(
            trg_seq).to(self.device)
        #trg_mask = self.gen_trg_mask(trg)

        #print(trg_mask)
        src = src.transpose(0, 1)
        trg = trg.transpose(0, 1)

        output = self.transformer(src, trg, tgt_mask=trg_mask)
        output = output.transpose(0, 1)
        output = self.fc(output)

        #print(src.shape, trg.shape, output.shape)
        return output

    def inference(self, src):
        '''
        x  = (B, S_source)
        return (B, S_target)
        '''

        # per the paper, max_seq = src seq + 300; a shorter +50 budget is used here
        max_seq = src.size(1) + 50
        batch = src.size(0)

        lengths = np.array([max_seq] * batch)
        #outputs = []

        outputs = torch.zeros((batch, 1)).to(torch.long).to(self.device)
        outputs[:, 0] = self.BOS

        for step in range(1, max_seq):
            out = self.forward(src, outputs)

            #out = out.view(batch, max_seq, -1)
            #print(out.shape)
            out = out[:, -1, :]
            pred = torch.topk(F.log_softmax(out, dim=-1), 1, dim=-1)[1]

            outputs = torch.cat([outputs, pred], dim=1)

            eos_batches = pred.data.eq(self.EOS)
            if eos_batches.dim() > 0:
                eos_batches = eos_batches.cpu().view(-1).numpy()
                update_idx = ((lengths > step) & eos_batches) != 0
                lengths[update_idx] = step

        return outputs.detach(), lengths
Example #26
class TransformerModel(nn.Module):
    def __init__(self,
                 vocab_size,
                 hidden_size,
                 num_attention_heads,
                 num_encoder_layers,
                 num_decoder_layers,
                 intermediate_size,
                 dropout=0.1):
        super(TransformerModel, self).__init__()

        # self.token_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.token_embeddings = nn.Embedding(vocab_size,
                                             hidden_size,
                                             padding_idx=1)
        self.position_embeddings = PositionalEncoding(hidden_size)
        self.hidden_size = hidden_size
        self.dropout = nn.Dropout(p=dropout)

        self.transformer = Transformer(
            d_model=hidden_size,
            nhead=num_attention_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=intermediate_size,
            dropout=dropout,
        )

        self.decoder_embeddings = nn.Linear(hidden_size, vocab_size)
        self.decoder_embeddings.weight = self.token_embeddings.weight

        self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(
            mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.token_embeddings.weight.data.uniform_(-initrange, initrange)
        self.decoder_embeddings.bias.data.zero_()
        self.decoder_embeddings.weight.data.uniform_(-initrange, initrange)

    def forward(self,
                src=None,
                tgt=None,
                memory=None,
                src_key_padding_mask=None,
                tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        if src is not None:
            src_embeddings = self.token_embeddings(src) * math.sqrt(
                self.hidden_size) + self.position_embeddings(src)
            src_embeddings = self.dropout(src_embeddings)

            if src_key_padding_mask is not None:
                src_key_padding_mask = src_key_padding_mask.t()

            if tgt is None:  # encode
                memory = self.transformer.encoder(
                    src_embeddings, src_key_padding_mask=src_key_padding_mask)
                return memory

        if tgt is not None:
            tgt_embeddings = self.token_embeddings(tgt) * math.sqrt(
                self.hidden_size) + self.position_embeddings(tgt)
            tgt_embeddings = self.dropout(tgt_embeddings)
            tgt_mask = self.transformer.generate_square_subsequent_mask(
                tgt.size(0)).to(tgt.device)

            if tgt_key_padding_mask is not None:
                tgt_key_padding_mask = tgt_key_padding_mask.t()

            if src is None and memory is not None:  # decode
                if memory_key_padding_mask is not None:
                    memory_key_padding_mask = memory_key_padding_mask.t()

                output = self.transformer.decoder(
                    tgt_embeddings,
                    memory,
                    tgt_mask=tgt_mask,
                    tgt_key_padding_mask=tgt_key_padding_mask,
                    memory_key_padding_mask=memory_key_padding_mask)
                output = self.decoder_embeddings(output)

                return output

        assert not (src is None and tgt is None)
        output = self.transformer(src_embeddings,
                                  tgt_embeddings,
                                  tgt_mask=tgt_mask,
                                  src_key_padding_mask=src_key_padding_mask,
                                  tgt_key_padding_mask=tgt_key_padding_mask)
        output = self.decoder_embeddings(output)
        return output
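A sketch of the three call patterns this forward() supports (added; not from the original). Token ids are sequence-first; key-padding masks are boolean, True at pad positions, and passed as [seq_len, batch] because forward() transposes them itself. Sizes are illustrative.

import torch

model = TransformerModel(vocab_size=1000, hidden_size=256, num_attention_heads=8,
                         num_encoder_layers=3, num_decoder_layers=3,
                         intermediate_size=1024)
src = torch.randint(2, 1000, (30, 4))       # [src_len, batch]; id 1 is the padding index
tgt = torch.randint(2, 1000, (25, 4))       # [tgt_len, batch]
src_pad = (src == 1)                        # [src_len, batch] bool padding mask

logits = model(src=src, tgt=tgt, src_key_padding_mask=src_pad)   # full seq2seq
memory = model(src=src, src_key_padding_mask=src_pad)            # encoder only
step = model(tgt=tgt, memory=memory,                             # decoder only
             memory_key_padding_mask=src_pad)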
Example #27
class TransformerModel(nn.Module):
    def __init__(self,
                 ntokens_src,
                 ntokens_tgt,
                 ninp,
                 nhead,
                 dim_feedforward,
                 nlayers,
                 pad_token,
                 dropout=0.5):
        super(TransformerModel, self).__init__()
        from torch.nn import Transformer
        self.model_type = 'Transformer'
        self.ninp = ninp
        self.pad_token = pad_token
        self.masks = {
            'src': None,
            'tgt': None,
            'memory': None,
        }
        # Token Encoders
        self.src_encoder = nn.Embedding(ntokens_src, ninp)
        self.tgt_encoder = nn.Embedding(ntokens_tgt, ninp)
        # Positional Encoding
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        # Transformer
        self.transformer = Transformer(
            d_model=ninp,
            nhead=nhead,
            num_encoder_layers=nlayers,
            num_decoder_layers=nlayers,
            dropout=dropout,
            dim_feedforward=dim_feedforward,
        )
        self.out = nn.Linear(ninp, ntokens_tgt)

        self.init_weights()

    def _generate_square_subsequent_mask(self, sx, sy=None):
        """Generate matrix for seqential reveal of tokens."""
        sy = sy or sx
        mask = (torch.triu(torch.ones((sx, sy))) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(
            mask == 1, float(0.0))
        return mask

    def init_weights(self):
        self.transformer._reset_parameters()

    def preprocess(self, x, x_type):
        # Create masks
        padding_mask = (x == self.pad_token).bool().t()
        if self.masks[x_type] is None or self.masks[x_type].size(0) != len(x):
            self.masks[x_type] = self._generate_square_subsequent_mask(
                len(x), len(x)).to(x.device)

        x_enc = self.src_encoder(x) if x_type == 'src' else self.tgt_encoder(x)
        x_enc *= math.sqrt(self.ninp)
        x_enc = self.pos_encoder(x_enc)

        return x_enc, self.masks[x_type], padding_mask

    def forward(self, src, tgt):

        # TODO: Do we need memory mask?
        if (self.masks['memory'] is None
                or self.masks['src'].size(0) != len(src)
                or self.masks['tgt'].size(0) != len(tgt)):
            self.masks['memory'] = self._generate_square_subsequent_mask(
                len(src), len(tgt)).to(src.device)

        src_enc, _, src_key_padding_mask = self.preprocess(src, 'src')
        tgt_enc, _, tgt_key_padding_mask = self.preprocess(tgt, 'tgt')
        memory_key_padding_mask = src_key_padding_mask.clone().detach()

        output = self.transformer(
            src_enc,
            tgt_enc,
            src_mask=self.masks['src'],
            tgt_mask=self.masks['tgt'],
            memory_mask=self.masks['memory'],
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        output = self.out(output)
        return output
Example #28
# X_train = X_train.reshape(S, N ,E)
# N,E,T = Y_train.shape
# Y_train = Y_train.reshape(T, N ,E)
# N,S,E = X_test.shape
# X_test = X_test.reshape(S, N ,E)
# N,E,T = Y_test.shape
# Y_test = Y_test.reshape(T, N ,E)
# print(X_train.shape)
# print((X_train[0,0,0]))

X_train = process_dataX(X_train)
Y_train = process_dataY(Y_train)
X_test = process_dataX(X_test)
Y_test = process_dataY(Y_test)

model = Transformer(d_model=6, nhead=6)

training_avg_losses, evaluating_avg_losses = train_loop(X_train,
                                                        X_test,
                                                        Y_train,
                                                        Y_test,
                                                        model,
                                                        loop_n=50)
print("training_avg_losses: ", training_avg_losses)
print("evaluating_avg_losses: ", evaluating_avg_losses)
#plot
epochs = [i + 1 for i in range(len(training_avg_losses))]
plt.plot(epochs, training_avg_losses, label="Training AVG Loss")
plt.legend(loc='upper right')
plt.title(f'Transformer Training Losses')
plt.savefig(f'Train_Transformer_Losses')

class SimpleTransformerModel(nn.Module):
    def __init__(self, nb_tokens: int, emb_size: int, nb_layers=2, nb_heads=4, hid_size=512, dropout=0.25, max_len=30):
        super(SimpleTransformerModel, self).__init__()
        from torch.nn import Transformer
        self.emb_size = emb_size
        self.max_len = max_len

        self.pos_encoder = PositionalEncoding(emb_size, dropout=dropout, max_len=max_len)
        self.embedder = nn.Embedding(nb_tokens, emb_size)

        self.transformer = Transformer(d_model=emb_size, nhead=nb_heads, num_encoder_layers=nb_layers,
                                       num_decoder_layers=nb_layers, dim_feedforward=hid_size, dropout=dropout)

        self.out_lin = nn.Linear(in_features=emb_size, out_features=nb_tokens)

        self.tgt_mask = None

    def _generate_square_subsequent_mask(self, sz):
        mask = torch.triu(torch.ones(sz, sz), diagonal=1).to(device)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.embedder.weight.data.uniform_(-initrange, initrange)
        self.out_lin.bias.data.zero_()
        self.out_lin.weight.data.uniform_(-initrange, initrange)

    def enc_forward(self, src):
        # Embed source
        src = self.embedder(src) * math.sqrt(self.emb_size)
        # Add positional encoding + reshape into format (seq element, batch element, embedding)
        src = self.pos_encoder(src.view(src.shape[0], 1, src.shape[1]))

        # Push through encoder
        output = self.transformer.encoder(src)

        return output

    def dec_forward(self, memory, tgt):
        # Generate target mask, if necessary
        if self.tgt_mask is None or self.tgt_mask.size(0) != len(tgt):
            mask = self._generate_square_subsequent_mask(len(tgt)).to(device)
            self.tgt_mask = mask

        # Embed target
        tgt = self.embedder(tgt) * math.sqrt(self.emb_size)
        # Add positional encoding + reshape into format (seq element, batch element, embedding)
        tgt = self.pos_encoder(tgt.view(tgt.shape[0], 1, tgt.shape[1]))

        # Push through decoder + linear output layer
        output = self.out_lin(self.transformer.decoder(memory=memory, tgt=tgt, tgt_mask=self.tgt_mask))
        # If using the model to evaluate, also take softmax of final layer to obtain probabilities
        if not self.training:
            output = torch.nn.functional.softmax(output, 2)

        return output

    def forward(self, src, tgt):
        memory = self.enc_forward(src)
        output = self.dec_forward(memory, tgt)

        return output

    def greedy_decode(self, src, max_len=None, start_symbol=0, stop_symbol=None):
        """
        Greedy decode input "src": generate output character one at a time, until "stop_symbol" is generated or
        the output exceeds max_len, whichever comes first.

        :param src: input src, 1D tensor
        :param max_len: int
        :param start_symbol: int
        :param stop_symbol: int
        :return: decoded output sequence
        """
        b_training = self.training
        if b_training:
            self.eval()

        if max_len is None:
            max_len = self.max_len
        elif max_len > self.max_len:
            raise ValueError(f"Parameter 'max_len' can not exceed model's own max_len,"
                             f" which is set at {self.max_len}.")
        # Get memory = output from encoder layer
        memory = self.enc_forward(src)
        # Initiate output buffer
        idxs = [start_symbol]
        # Keep track of last predicted symbol
        next_char = start_symbol
        # Keep generating output until "stop_symbol" is generated, or max_len is reached
        while next_char != stop_symbol:
            if len(idxs) == max_len:
                break
            # Convert output buffer to tensor
            ys = torch.LongTensor(idxs).to(device)
            # Push through decoder
            out = self.dec_forward(memory=memory, tgt=ys)
            # Get position of max probability of newly predicted character
            _, next_char = torch.max(out[-1, :, :], dim=1)
            next_char = next_char.item()

            # Append generated character to output buffer
            idxs.append(next_char)

        if b_training:
            self.train()

        return idxs
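A minimal usage sketch for the SimpleTransformerModel above (added; not from the original). It assumes the module-level `device` the class already relies on; src is a 1-D LongTensor because enc_forward()/dec_forward() reshape single sequences into a batch of one. Token counts and symbol ids are illustrative.

import torch

device = torch.device('cpu')                       # stands in for the global the class uses
model = SimpleTransformerModel(nb_tokens=30, emb_size=64).to(device)
src = torch.randint(2, 30, (20,), device=device)   # 1-D sequence of token ids
decoded = model.greedy_decode(src, start_symbol=0, stop_symbol=1)
print(decoded)                                     # list of generated token indices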