Example #1
# Demo: overfit a small Encoder stack on one fixed random batch.
# Move the model to `device` so it matches the criterion below
# (the original only moved the loss module, which breaks on GPU).
encoder_stacks = Encoder(d_model=32,
                         d_inner=64,
                         n_layers=2,
                         n_head=4,
                         d_k=16,
                         d_v=16,
                         dropout=0.1).to(device)

criterion = torch.nn.MSELoss().to(device)
# NOTE(review): lr=1 is unusually large; kept as-is since this is a demo
# and gradients are clipped below.
optimizer = torch.optim.SGD(encoder_stacks.parameters(), lr=1)

# Create the tensors directly on the target device.
src = torch.rand(1, 2, 32, device=device, requires_grad=True)
tgt = torch.rand(1, 2, 32, device=device)

print(src)

encoder_stacks.train()

for i in range(100):
    # Call the module itself (not .forward) so registered hooks run.
    out, attn = encoder_stacks(src, src_mask=None)
    loss = criterion(out, tgt)
    optimizer.zero_grad()
    loss.backward()
    # Clip gradients to keep the large learning rate stable.
    torch.nn.utils.clip_grad_norm_(encoder_stacks.parameters(), 0.5)
    optimizer.step()
    print(loss.item())

print("out:", out)
print("tgt:", tgt)
print("attn:", attn)
Example #2
class TransformerEncoder(torch.nn.Module):
    """Sequence classifier: 1x1-Conv input projection -> Transformer
    encoder -> temporal max-pool -> linear head -> log-probabilities.

    Input to forward(): a tensor shaped (batch, in_channels, time)
    (inferred from the transpose in _logits — confirm against callers).
    """

    def __init__(self, in_channels=13, len_max_seq=100,
            d_word_vec=512, d_model=512, d_inner=2048,
            n_layers=6, n_head=8, d_k=64, d_v=64,
            dropout=0.2, nclasses=6):

        super(TransformerEncoder, self).__init__()

        # Bug fix: honor the d_model argument instead of hard-coding 512.
        self.d_model = d_model

        self.inlayernorm = nn.LayerNorm(in_channels)
        self.convlayernorm = nn.LayerNorm(d_model)
        self.outlayernorm = nn.LayerNorm(d_model)

        # 1x1 convolution: per-time-step projection in_channels -> d_model.
        self.inconv = torch.nn.Conv1d(in_channels, d_model, 1)

        self.encoder = Encoder(
            n_src_vocab=None, len_max_seq=len_max_seq,
            d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
            n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
            dropout=dropout)

        self.outlinear = nn.Linear(d_model, nclasses, bias=False)

        # Pools over the full temporal dimension (length len_max_seq).
        self.tempmaxpool = nn.MaxPool1d(len_max_seq)

        self.logsoftmax = nn.LogSoftmax(dim=-1)

    def _logits(self, x):
        """Return unnormalized class scores of shape (batch, nclasses)."""
        # (b, d, t) -> (b, t, d)
        x = x.transpose(1, 2)

        x = self.inlayernorm(x)

        # Conv1d wants (b, d, t); project, then restore (b, t, d).
        x = self.inconv(x.transpose(1, 2)).transpose(1, 2)

        x = self.convlayernorm(x)

        batchsize, seq, d = x.shape
        # 1-based positions for the encoder's positional embedding.
        # Bug fix: build on x.device instead of unconditionally calling
        # .cuda() when CUDA exists — the original broke CPU inference on
        # CUDA machines and multi-GPU setups.
        src_pos = torch.arange(1, seq + 1, dtype=torch.long,
                               device=x.device).expand(batchsize, seq)

        enc_output, enc_slf_attn_list = self.encoder.forward(
            src_seq=x, src_pos=src_pos, return_attns=True)

        enc_output = self.outlayernorm(enc_output)

        # Max-pool over time: (b, t, d) -> (b, d, t) -> (b, d)
        enc_output = self.tempmaxpool(enc_output.transpose(1, 2)).squeeze(-1)

        logits = self.outlinear(enc_output)

        return logits

    def forward(self, x):
        """Return per-class log-probabilities for input x."""
        logits = self._logits(x)

        logprobabilities = self.logsoftmax(logits)

        return logprobabilities

    def save(self, path="model.pth", **kwargs):
        """Save state_dict (plus any extra metadata kwargs) to path."""
        print("\nsaving model to "+path)
        model_state = self.state_dict()
        # Bug fix: os.makedirs("") raises FileNotFoundError, so only
        # create the directory when path actually has a directory part
        # (the original crashed with the default bare filename).
        dirname = os.path.dirname(path)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        torch.save(dict(model_state=model_state, **kwargs), path)

    def load(self, path):
        """Load weights from path; return any extra snapshot metadata."""
        print("loading model from "+path)
        snapshot = torch.load(path, map_location="cpu")
        # Tolerate snapshots that are a bare state_dict (no wrapper key).
        model_state = snapshot.pop('model_state', snapshot)
        self.load_state_dict(model_state)
        return snapshot