# main_TC.py — training entry point: trains OrdMatch with an interpolated
# matching/ordering NLL loss, logging every `interval` steps and
# checkpointing/validating once per epoch.
import time
import torch
import torch.nn as nn
import torch.optim as optim
import os
from ordmatch import OrdMatch
from corpus import Corpus
from config import args
from evaluate import evaluation, accuracy
# Pin the visible GPU and cap CPU thread parallelism before any CUDA work.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
torch.set_num_threads(3)

# Seed RNGs for reproducibility (CPU always; CUDA only when actually used).
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    else:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

# Build the dataset and the matching/ordering model; start in train mode.
corpus = Corpus(args)
print("Corpus built.")
model = OrdMatch(corpus, args)
model.train()

# NLL loss over the model's log-probability outputs; optimize only the
# parameters that require gradients.
criterion = nn.NLLLoss()
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adamax(parameters, lr=args.lr)
if args.cuda:
    model.cuda()
    criterion.cuda()

# Training-loop bookkeeping: running losses, logging/checkpoint cadence,
# and the best validation score seen so far.
start_time = time.time()
total_loss = 0
total_loss1 = 0
total_loss2 = 0
interval = args.interval
save_interval = len(corpus.data_all['train']) // args.batch_size  # ~one epoch of steps
best_dev_score = -99999
iterations = args.epochs * len(corpus.data_all['train']) // args.batch_size
print('max iterations: ' + str(iterations))
# Main training loop: one optimizer step per batch, periodic console
# logging, and a checkpoint + validation pass once per epoch.
for step in range(iterations):  # renamed from `iter`, which shadowed the builtin
    # --- forward / backward on one training batch ---
    optimizer.zero_grad()
    data = corpus.get_batch(args.batch_size, 'train', div=True)
    output1, output2, _ = model(data)
    labels = data[2].cuda() if args.cuda else data[2]
    # loss1: matching score, loss2: ordering score
    loss1 = criterion(output1, labels)
    loss2 = criterion(output2, labels)
    # train via interpolated loss of matching and ordering score
    loss = (1 - args.lamda) * loss1 + args.lamda * loss2
    loss.backward()
    optimizer.step()
    # (removed: a train-batch accuracy was computed here but never used, and a
    # redundant second zero_grad() followed optimizer.step())
    # .item() replaces the deprecated float(loss.data) access
    total_loss += loss.item()
    total_loss1 += loss1.item()
    total_loss2 += loss2.item()

    # --- periodic logging of the running average losses ---
    if step % interval == 0:
        denom = interval if step != 0 else 1  # at step 0 only one batch has accumulated
        cur_loss = total_loss / denom
        cur_loss1 = total_loss1 / denom
        cur_loss2 = total_loss2 / denom
        elapsed = time.time() - start_time
        print('| iterations {:3d} | start_id {:3d} | ms/batch {:5.2f} | loss {:5.3f} loss1 {:5.3f} loss2 {:5.3f} '.format(
            step, corpus.start_id['train'], elapsed * 1000 / interval, cur_loss, cur_loss1, cur_loss2))
        total_loss = 0
        total_loss1 = 0
        total_loss2 = 0
        start_time = time.time()

    # --- once per epoch: checkpoint, evaluate on dev, track the best score ---
    if step % save_interval == 0:
        save_path = os.path.join(args.save_path, args.task)
        # makedirs(exist_ok=True) handles nested paths and is race-free,
        # unlike the original os.mkdir
        os.makedirs(save_path, exist_ok=True)
        torch.save([model, optimizer, criterion], os.path.join(save_path, f'save_4_{args.lamda}.pt'))
        score = evaluation(model, corpus, args.task, args.batch_size, dataset='val', div=True, reg=True)
        # NOTE(review): evaluation() presumably switches the model to eval
        # mode; restore train mode so dropout/batch-norm behave correctly
        # on subsequent batches — confirm against evaluate.py.
        model.train()
        print('DEV accuracy: ' + str(score))
        with open(os.path.join(save_path, f'record_4_{args.lamda}.txt'), 'a', encoding='utf-8') as fpw:
            if step == 0:
                fpw.write(str(args) + '\n')  # record the run configuration once
            fpw.write(str(step) + ':\tDEV accuracy:\t' + str(score) + '\n')
        if score > best_dev_score:
            best_dev_score = score
            torch.save([model, optimizer, criterion], os.path.join(save_path, f'save_best_4_{args.lamda}.pt'))
# Optional per-epoch LR decay (kept from the original, still disabled):
# if (step+1) % (len(corpus.data_all['train']) // args.batch_size) == 0:
#     for param_group in optimizer.param_groups:
#         param_group['lr'] *= 0.95