-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
93 lines (62 loc) · 2.22 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from tqdm import tqdm
import os, sys
import torch
import torch.nn as nn
import torch.nn.functional as f
import torch.optim as optim
import data as dt
import config as cfg
from utils import getModel
def train(model_name):
    """Train `model_name` on CIFAR-10, resuming from a checkpoint if one exists.

    Args:
        model_name: key understood by utils.getModel; also used to name the
            checkpoint file `<cfg.log_path>/<model_name>_cifar10.pt`.

    Side effects:
        Writes a checkpoint every `cfg.save_epoch` epochs when
        `cfg.save_model` is set; prints per-epoch progress to stdout.
    """
    train_loader = dt.loadData(train=True)
    model = getModel(model_name=model_name)  # presumably placed on cfg.device by getModel — TODO confirm
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=cfg.lr, momentum=cfg.momentum)
    scheduler = optim.lr_scheduler.StepLR(
        optimizer, cfg.lr_decaying_step, gamma=cfg.lr_decaying_value)

    saved_model_path = os.path.join(cfg.log_path, model_name + "_cifar10.pt")
    if os.path.isfile(saved_model_path):
        # Resume: restore model/optimizer/scheduler state. The original code
        # dropped the scheduler state, so a resumed run restarted the LR
        # decay schedule from step 0 — now it continues where it left off.
        checkpoint = torch.load(saved_model_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if 'scheduler_state_dict' in checkpoint:  # older checkpoints lack this key
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch']  # saved as "next epoch to run"
    else:
        start_epoch = 1

    model.train()
    print("\nStart training ", model_name, "...")
    for epoch in range(start_epoch, cfg.epochs + 1):
        for data, target in tqdm(train_loader):
            data, target = data.to(cfg.device), target.to(cfg.device)
            if cfg.convert_to_RGB:
                # Broadcast a single-channel image to `converted_channel`
                # channels; expand() returns a view, no copy. The original
                # also .view()-ed the tensor to its own shape first — a
                # no-op, removed.
                batch_size, channel, width, height = data.size()
                data = data.expand(batch_size, cfg.converted_channel, width, height)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
        scheduler.step()  # step the LR schedule once per epoch

        if cfg.save_model and (epoch % cfg.save_epoch == 0):
            # exist_ok avoids the isdir/makedirs race of the original.
            os.makedirs(cfg.log_path, exist_ok=True)
            torch.save({
                'epoch': epoch + 1,  # epoch to resume from
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': loss,
            }, saved_model_path)
        print('\tTrain Epoch: {} / {} \t Loss: {:.6f}\n'.format(
            epoch, cfg.epochs, loss.item()))
    print("Done!\n\n")
if __name__ == "__main__":
    # Seed exactly one backend (GPU or CPU) before training any model.
    if not cfg.no_cuda:
        # Pin device numbering to the physical PCI bus so the index given
        # on the command line is stable across driver versions.
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        # Optional CLI argument selects the GPU; default to device 0
        # instead of crashing with IndexError when no argument is given
        # (the original required `python train.py <gpu_index>`).
        gpu_index = int(sys.argv[1]) if len(sys.argv) > 1 else 0
        torch.cuda.set_device(gpu_index)
        torch.cuda.manual_seed(cfg.gpu_seed)
    else:
        # CPU-only run: seed the default RNG for reproducibility.
        torch.manual_seed(cfg.cpu_seed)

    # Train every configured model sequentially.
    for model_name in cfg.model_list:
        train(model_name=model_name)