import IPython
import time
import math
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.utils.data import DataLoader
from data_helper import SRLDataSet
from model import UnifiedFramework
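# Build the CPB SRL training and dev sets from the preprocessed data files and
# wrap them in DataLoaders (batch size 16, with shuffling for training).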
trainset = SRLDataSet('SRL_data/data/cpbtrain.txt', 'SRL_data/data/word_dict.json', 'SRL_data/data/pos_dict.json',
                      'SRL_data/data/label_dict.json', 'SRL_data/data/depend_dict.json',
                      'SRL_data/data/cpbtrain_tree.txt', is_test=False)
trainloader = DataLoader(dataset=trainset, batch_size=16, shuffle=True)
devset = SRLDataSet('SRL_data/data/cpbdev.txt', 'SRL_data/data/word_dict.json', 'SRL_data/data/pos_dict.json',
                    'SRL_data/data/label_dict.json', 'SRL_data/data/depend_dict.json',
                    'SRL_data/data/cpbdev_tree.txt', is_test=False)
devloader = DataLoader(dataset=devset, batch_size=16)
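# Model hyperparameters. The vocabulary, POS-tag, dependency-label and output-label
# sizes are derived from the dictionaries loaded by the dataset.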
config = {
    'vocab_size': max(trainloader.dataset.word2idx.values()) + 1,
    'word_embedding_dim': 100,
    'pos_embedding_dim': 25,
    'depend_embedding_dim': 25,
    'pos_set_size': max(trainloader.dataset.pos2idx.values()) + 1,
    'depend_set_size': max(trainloader.dataset.depend2idx.values()) + 1,
    'gcr_hidden_size': 100,
    'gcr_num_layers': 2,
    'gpr_hidden_size': 100,
    'gpr_num_layers': 2,
    'rpr_hidden_size': 100,
    'rpr_num_layers': 2,
    'feature_size': 200,
    'drop_out': 0.3,
    'categories': max(trainloader.dataset.label2idx.values()) + 1
}
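# Build the model, optionally move it to GPU DEVICE_NO (-1 keeps it on the CPU), and set up
# Adagrad plus a cross-entropy loss that ignores the padding label (index 0).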
uf = UnifiedFramework(config)
DEVICE_NO = 1
if DEVICE_NO != -1:
    uf = uf.cuda(DEVICE_NO)
optimizer = torch.optim.Adagrad(uf.parameters(), lr=0.01)
criteria = nn.CrossEntropyLoss(ignore_index=0)
log_interval = 50
epochs = 20
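# Run one epoch over the training data: move each batch to the device, compute the
# per-token cross-entropy, backpropagate, and report the running loss/perplexity every
# `log_interval` batches (padding tokens are excluded from the averages).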
def train(dataloader):
    uf.train()
    total_loss = 0
    total_items = 0
    start_time = time.time()
    for i_batch, batch in enumerate(dataloader):
        output_seq = Variable(batch['output_seq'])
        del batch['output_seq']
        for k in batch:
            batch[k] = Variable(batch[k])
        if DEVICE_NO != -1:
            output_seq = output_seq.cuda(DEVICE_NO)
            for k in batch:
                batch[k] = batch[k].cuda(DEVICE_NO)
        uf.zero_grad()
        pred = uf.forward(**batch)
        pred = pred.view(-1, pred.size(-1))
        output_seq = output_seq.view(-1)
        loss = criteria(pred, output_seq)
        loss.backward()
        num_items = len([x for x in output_seq if int(x) != criteria.ignore_index])
        total_loss += num_items * loss.data
        total_items += num_items
        optimizer.step()
        if i_batch % log_interval == 0 and i_batch > 0:
            cur_loss = total_loss[0] / total_items
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:04.2f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                      epoch, i_batch, len(dataloader.dataset) // dataloader.batch_size,
                      optimizer.param_groups[0]['lr'],
                      elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss)))
            total_loss = 0
            total_items = 0
            start_time = time.time()
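# Compute the average per-token loss over a dataset without updating the model,
# again skipping padding positions when averaging.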
def evaluate(dataloader):
    uf.eval()  # disable dropout while evaluating
    total_loss = 0
    total_items = 0
    for batch in dataloader:
        output_seq = Variable(batch['output_seq'])
        del batch['output_seq']
        for k in batch:
            batch[k] = Variable(batch[k])
        if DEVICE_NO != -1:
            output_seq = output_seq.cuda(DEVICE_NO)
            for k in batch:
                batch[k] = batch[k].cuda(DEVICE_NO)
        pred = uf.forward(**batch)
        pred = pred.view(-1, pred.size(-1))
        output_seq = output_seq.view(-1)
        num_items = len([x for x in output_seq if int(x) != criteria.ignore_index])
        total_loss += num_items * criteria(pred, output_seq).data
        total_items += num_items
    return total_loss[0] / total_items
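# Main training driver: initialise the model from the pre-trained weight file, train for
# `epochs` epochs, and checkpoint to model.pkl whenever the dev loss improves. A
# KeyboardInterrupt stops training early; the final evaluation below decides whether
# to save one last checkpoint.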
best_val_loss = 1000
try:
    print(uf)
    uf.init_weights(pre_trained_filename='xinhuashe/weights.pkl.npy')
    for epoch in range(1, epochs + 1):
        # scheduler.step()
        epoch_start_time = time.time()
        train(trainloader)
        val_loss = evaluate(devloader)
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
              'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                         val_loss, math.exp(val_loss)))
        print('-' * 89)
        # Save the model if the validation loss is the best we've seen so far.
        if not best_val_loss or val_loss < best_val_loss:
            print('new best val loss, saving model')
            with open('model.pkl', 'wb') as f:
                torch.save(uf, f)
            best_val_loss = val_loss
        else:
            # Anneal the learning rate if no improvement has been seen in the validation dataset.
            pass
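            # A minimal annealing sketch (hypothetical, not part of the original script):
            # shrink the Adagrad learning rate when the dev loss stops improving.
            # Uncomment to enable.
            # for group in optimizer.param_groups:
            #     group['lr'] /= 4.0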
except KeyboardInterrupt:
    print('-' * 89)
    print('Exiting from training early')
val_loss = evaluate(devloader)
# Save the model if the validation loss is the best we've seen so far.
if not best_val_loss or val_loss < best_val_loss:
    print('new best val loss, saving model')
    with open('model.pkl', 'wb') as f:
        torch.save(uf, f)
    best_val_loss = val_loss
else:
    # Anneal the learning rate if no improvement has been seen in the validation dataset.
    pass