-
Notifications
You must be signed in to change notification settings - Fork 0
/
training.py
74 lines (52 loc) · 1.61 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from network import Net
def imshow(img):
    """Display a channels-first (C, H, W) image tensor with matplotlib."""
    array = img.numpy()
    # matplotlib expects channels-last (H, W, C), so move the channel axis.
    plt.imshow(np.transpose(array, (1, 2, 0)))
    plt.show()
# Path where trained weights would be persisted (save call is in __main__).
PATH = './models/mnist_net.pth'

# Only ToTensor: PIL images become float tensors in [0, 1], no normalization.
transform = transforms.Compose([transforms.ToTensor()])

# MNIST training split; downloaded into ./training-data on first run.
trainset = torchvision.datasets.MNIST("./training-data", train=True, download=True, transform=transform)

# Small batches, shuffled each epoch, loaded by two worker processes.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=4, shuffle=True, num_workers=2
)
if __name__ == "__main__":
    # Uncomment to preview one batch of training images:
    # data_iter = iter(trainloader)
    # images, labels = next(data_iter)  # next batch (size=4)
    # print(labels)
    # imshow(torchvision.utils.make_grid(images, padding=0))

    # Fresh network, plain SGD, and MSE against the target encoding
    # produced by net.get_target (defined in network.py).
    net = Net()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
    loss_fn = torch.nn.MSELoss()

    LOG_INTERVAL = 1000  # batches between progress-bar loss updates

    # training loop
    for epoch in range(2):
        running_loss = 0.0
        i = 0
        with tqdm(trainloader) as tqdm_iterator:
            for data in tqdm_iterator:
                # get the inputs; data is a list of [inputs, labels]
                inputs, labels = data

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                outputs = net(inputs)
                target = net.get_target(labels)
                loss = loss_fn(outputs, target)
                loss.backward()
                optimizer.step()

                i += 1
                running_loss += loss.item()
                # Show the *mean* loss over the last LOG_INTERVAL batches.
                # (The original displayed the raw sum, reset after 1001
                # batches due to `i > 1000`, and used the width spec `:5f`
                # where precision `:.5f` was intended.)
                if i >= LOG_INTERVAL:
                    tqdm_iterator.set_description(f"{running_loss / i:.5f}")
                    i = 0
                    running_loss = 0.0

    print('Finished Training')
    # NOTE(review): saving is disabled; enable (and create ./models first)
    # once training is validated, otherwise the trained weights are lost.
    # torch.save(net.state_dict(), PATH)