# supervised.py
#pylint:disable=E1101
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.distributions import Categorical
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import argparse

from models import MyModel
from math_dataset import MyDataset


def main():
    _i, _j, _k = 2, 3, 3
    dataset = MyDataset(_i, _j, _k)
    dtype = torch.float
    device = torch.device("cpu")
    # device = torch.device("cuda:0")

    # batch, input, hidden, output
    N, D_in, H, D_out = 10, _i + _j + _k, 16, _i * _j * _k
    msg_len = 10

    x, y = dataset.get_frame()
    x = torch.tensor(x, dtype=dtype, device=device)
    # x = torch.cat((x, x, x, x, x), 0)
    y = torch.tensor(y, dtype=torch.long, device=device).squeeze()
    # y = torch.cat((y, y, y, y, y), 0)
    print(x.size(), y.size())
    # x = torch.zeros(N, D_in, device=device, dtype=dtype)
    # y = torch.zeros(N, device=device, dtype=dtype)

    model = MyModel(D_in, H, D_out)
    # model = torch.nn.Linear(D_in, D_out)
    loss_fn = torch.nn.CrossEntropyLoss()  # `reduce=None` is deprecated; default reduction is 'mean'
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

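    # The loop below is toggled by hand via the if/elif/else constants:
    # the first branch does a REINFORCE-style update (sample a prediction,
    # reward exact matches against y), while the final, currently unreachable
    # branch does a plain supervised cross-entropy update.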
    for t in range(10001):
        if True:  # REINFORCE-style update
            y_pred = model(x)
            probs = F.softmax(y_pred, dim=1)
            m = Categorical(probs)
            # Sample a prediction per example; reward is 1 for an exact match,
            # mean-subtracted as a simple baseline.
            action = m.sample()
            reward = torch.eq(action, y).to(torch.float)
            reward = reward - reward.mean()
            loss = -m.log_prob(action) * reward
            model.zero_grad()
            loss.sum().backward()
            # loss.backward(loss)
            optimizer.step()
        elif True:  # forward pass only (unreachable)
            y_pred = model(x)
        else:  # supervised cross-entropy update (unreachable)
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            model.zero_grad()
            loss.backward()
            optimizer.step()

        if t % 100 == 0:
            # Report greedy (argmax) accuracy and write a checkpoint.
            with torch.no_grad():
                y_pred = model(x)
                eq = torch.eq(torch.argmax(y_pred, dim=1), y)
                print("t: {}, acc: {}/{} = {}".format(
                    t, torch.sum(eq).item(), eq.numel(),
                    torch.sum(eq).item() / eq.numel()))
            torch.save({'epoch': t,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': loss,
                        }, "checkpoints.tar")


if __name__ == "__main__":
    main()

# Leftover scratch code for restoring the checkpoint written by main():
# model3 = MyModel(D_in, H, D_out)
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# checkpoint = torch.load("checkpoints.tar")
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# epoch = checkpoint['epoch']
# loss = checkpoint['loss']
# print(model.state_dict())
# print(optimizer.state_dict())

# Saving/loading just the model weights:
# PATH = "model.pt"
# torch.save(model.state_dict(), PATH)
# model2 = MyModel(D_in, H, D_out)
# model.load_state_dict(torch.load(PATH))
# model.eval()  # for dropout and BN
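

# A minimal inference sketch (an assumption-based illustration, not part of the
# original script): it assumes "checkpoints.tar" was written by main() above and
# that MyDataset.get_frame() returns inputs and integer labels shaped as in main().
def evaluate_checkpoint(path="checkpoints.tar"):
    _i, _j, _k = 2, 3, 3
    dataset = MyDataset(_i, _j, _k)
    D_in, H, D_out = _i + _j + _k, 16, _i * _j * _k
    model = MyModel(D_in, H, D_out)
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    x, y = dataset.get_frame()
    x = torch.tensor(x, dtype=torch.float)
    y = torch.tensor(y, dtype=torch.long).squeeze()
    with torch.no_grad():
        pred = torch.argmax(model(x), dim=1)
    acc = torch.eq(pred, y).float().mean().item()
    print("restored epoch {}, accuracy {:.3f}".format(checkpoint['epoch'], acc))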