run.py
import json
import os
import time
import torch
from torch import optim
from torch.utils.data import DataLoader
from config import get_opt
from dataset import MedicalExtractionDataset
from model import MedicalExtractionModel
from utils import get_cuda, logging, print_params


def train(opt):
    train_ds = MedicalExtractionDataset(opt.train_data)
    dev_ds = MedicalExtractionDataset(opt.dev_data)

    dev_dl = DataLoader(dev_ds,
                        batch_size=opt.dev_batch_size,
                        shuffle=False,
                        num_workers=1
                        )

    model = MedicalExtractionModel(opt)
    print(model)  # print the architecture; model.parameters without () only prints a bound-method repr
    print_params(model)

    start_epoch = 1
    learning_rate = opt.lr
    total_epochs = opt.epochs
    log_step = opt.log_step
    pretrain_model = opt.pretrain_model
    model_name = opt.model_name  # name used for saved checkpoint files

    # load a pretrained checkpoint if one is given
    if pretrain_model != '':
        chkpt = torch.load(pretrain_model, map_location=torch.device('cpu'))
        model.load_state_dict(chkpt['checkpoint'])  # key must match the dict saved below
        logging('load model from {}'.format(pretrain_model))
        start_epoch = chkpt['epoch'] + 1
        learning_rate = chkpt['learning_rate']
        logging('resume from epoch {} with learning_rate {}'.format(start_epoch, learning_rate))
    else:
        logging('training from scratch with learning_rate {}'.format(learning_rate))

    model = get_cuda(model)

    # TODO: switch to AdamW if the encoder is BERT-based
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
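    # A hedged alternative for a BERT-style encoder, kept commented out so the
    # default above stays in effect (torch.optim.AdamW, decoupled weight decay):
    # optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)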

    # TODO: settle on the final loss function; a hedged placeholder follows
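    # A minimal placeholder, assuming binary token/span labels (an assumption
    # about the task, not the project's confirmed objective):
    criterion = torch.nn.BCEWithLogitsLoss()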

    checkpoint_dir = opt.checkpoint_dir
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)  # makedirs also creates missing parent directories

    # training loop
    global_step = 0
    total_loss = 0
    for epoch in range(start_epoch, total_epochs + 1):  # start at start_epoch so resuming works
        start_time = time.time()
        train_dl = DataLoader(train_ds,
                              batch_size=opt.batch_size,
                              shuffle=True,
                              num_workers=8
                              )
        model.train()

        for batch in train_dl:
            optimizer.zero_grad()
            # TODO: feed the batch to the model and compute the loss
            # (a hedged sketch follows)
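            # A minimal sketch, assuming each batch is a dict holding
            # 'input_ids' and 'labels' tensors and that get_cuda moves tensors
            # as well as modules (hypothetical field names; adapt to whatever
            # MedicalExtractionDataset actually yields):
            input_ids = get_cuda(batch['input_ids'])
            labels = get_cuda(batch['labels'])
            logits = model(input_ids)
            loss = criterion(logits, labels.float())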
            loss.backward()
            optimizer.step()

            global_step += 1
            total_loss += loss.item()
            if global_step % log_step == 0:
                cur_loss = total_loss / log_step
                elapsed = time.time() - start_time
                logging(
                    '| epoch {:2d} | step {:4d} | ms/b {:5.2f} | train loss {:5.3f}'.format(
                        epoch, global_step, elapsed * 1000 / log_step, cur_loss))
                total_loss = 0
                start_time = time.time()

        if epoch % opt.test_epoch == 0:
            model.eval()
            dev_loss = 0.0  # accumulated by the sketch below
            with torch.no_grad():
                for batch in dev_dl:
                    # TODO: replace this sketch with real dev-set evaluation
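                    # A minimal sketch reusing the training-batch layout
                    # assumptions (hypothetical field names):
                    input_ids = get_cuda(batch['input_ids'])
                    labels = get_cuda(batch['labels'])
                    logits = model(input_ids)
                    dev_loss += criterion(logits, labels.float()).item()
            logging('| epoch {:2d} | dev loss {:5.3f}'.format(
                epoch, dev_loss / max(len(dev_dl), 1)))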

        # save the model
        # TODO: optionally keep only the checkpoint that scores best on dev
        if epoch % opt.save_model_freq == 0:
            path = os.path.join(checkpoint_dir, model_name + '_{}.pt'.format(epoch))
            torch.save({
                'epoch': epoch,
                'learning_rate': learning_rate,
                'checkpoint': model.state_dict()
            }, path)


if __name__ == '__main__':
    print('processId:', os.getpid())
    print('parent processId:', os.getppid())
    opt = get_opt()
    print(json.dumps(opt.__dict__, indent=4))
    train(opt)
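
# Typical invocation, assuming get_opt exposes argparse flags matching the opt
# attributes used above (flag names and paths here are guesses, not confirmed):
# python run.py --train_data data/train.json --dev_data data/dev.json --lr 1e-3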