-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
105 lines (79 loc) · 2.53 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# -*- coding: UTF-8 -*-
import torch
import torch.nn.functional as F
import torchtext
import time
import opts
from utils import load_data, clip_gradient, evaluate, validate
import models
from visualize import Visualizer
from tqdm import tqdm
# -------- configuration & setup --------
# Parse run options (hyperparameters, paths, batch size, ...).
opt = opts.parse_opt()
# Use the GPU only when one is actually present. The previous hard-coded
# `opt.use_cuda = True` crashed on CPU-only machines.
opt.use_cuda = torch.cuda.is_available()
# Select the model architecture ('lstm' or 'cnn'); the visdom environment
# is named after it so runs of different models are kept apart.
opt.model = 'lstm'
opt.env = opt.model
# Visdom dashboard for live loss/accuracy plots and text logging.
vis = Visualizer(opt.env)
# Log the full configuration so every run is reproducible from its dashboard.
vis.log('user config:')
for k, v in opt.__dict__.items():
    if not k.startswith('__'):
        # v already holds the attribute value; getattr(opt, k) was redundant.
        vis.log('{} {}'.format(k, v))
# Load the train/test batch iterators via torchtext.
train_iter, test_iter = load_data(opt)
model = models.init(opt)
print(type(model))
# Move the model parameters to the GPU when available.
if opt.use_cuda:
    model.cuda()
# Training mode: enables dropout / batch-norm statistic updates.
model.train()
# Adam over the trainable parameters only (frozen ones, e.g. pretrained
# embeddings with requires_grad=False, are excluded).
optim = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=opt.learning_rate)
optim.zero_grad()
# Multi-class classification -> cross-entropy loss.
criterion = F.cross_entropy
# Best test accuracy seen so far; used to checkpoint the best model.
# NOTE(review): name kept as-is (misspelled) because the training loop reads it.
best_accuaracy = 0
# -------- training loop --------
for i in range(opt.max_epoch):
    for train_epoch, batch in enumerate(train_iter):
        # BUG FIX: gradients must be reset for every batch. The original
        # called optim.zero_grad() exactly once before the loop, so
        # gradients from all previous batches kept accumulating.
        optim.zero_grad()
        # torchtext batches text as (token_tensor, lengths); take the tokens.
        text = batch.text[0]
        pred = model(text)
        loss = criterion(pred, batch.label)
        loss.backward()
        # Training trick: clip gradients to avoid gradient explosion.
        # https://blog.csdn.net/u010814042/article/details/76154391
        clip_gradient(optimizer=optim, grad_clip=opt.grad_clip)
        optim.step()
        # Periodically report / plot the training loss.
        if train_epoch % 50 == 0:
            # .item() replaces the deprecated pre-0.4 `loss.data[0]` indexing
            # (which raises on modern PyTorch) and already returns a plain
            # Python number regardless of device, so no cpu()/cuda branch.
            loss_data = loss.item()
            print("{} EPOCH {} batch: train loss {}".format(i, train_epoch, loss_data))
            vis.plot('loss', loss_data)
    # Evaluate on the test set once per epoch.
    accuracy = evaluate(model, test_iter, opt)
    vis.log("{} EPOCH, accuaracy : {}".format(i, accuracy))
    vis.plot('accuracy', accuracy)
    # Checkpoint whenever the test accuracy improves (best_<model>.pth).
    if accuracy > best_accuaracy:
        best_accuaracy = accuracy
        torch.save(model.state_dict(), './best_{}.pth'.format(opt.model))
        print('best accuracy: {}'.format(best_accuaracy))