-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_1.py
56 lines (50 loc) · 2.12 KB
/
main_1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# -*- coding: utf-8 -*-
from models import load_model
import my_utils
from torch.utils.data import DataLoader
import torch
# Spatial size handed to my_utils.getDataset for every split.
IMAGE_SIZE = (360, 480)


def _build_loader(images_dir, labels_dir, args, shuffle):
    """Return a DataLoader over one split's image/label folders.

    Batch size and worker count come from the parsed command-line ``args``
    (``args.batch_size`` and ``args.num_of_workers``).
    """
    dataset = my_utils.getDataset(images_dir, labels_dir, size=IMAGE_SIZE)
    return DataLoader(dataset,
                      batch_size=args.batch_size,
                      num_workers=args.num_of_workers,
                      shuffle=shuffle)


def main():
    """Train (or load) a model, then evaluate it on train/val/test splits.

    Steps, in order:
      1. Parse arguments and resolve the dataset location from ./config.txt.
      2. Build the model and move it to the GPU.
      3. Build the train / validation / test data loaders.
      4. If ``args.fresh_train`` is non-zero, run ``my_utils.Trainer``;
         otherwise load ``<save_path>/best_model.pth``.
      5. Run ``my_utils.Tester`` on each split, writing to
         ``<save_path>/eval_train``, ``eval_val`` and ``eval_test``.
    """
    args = my_utils.get_args()
    # noc is forwarded to the model, Trainer and Tester
    # (presumably "number of classes" - confirm in my_utils).
    data_path, noc = my_utils.get_data_path(args.dataset, './config.txt')
    model = load_model(args.model, noc).cuda()

    # DATA LOADERS
    # Only the training split is shuffled; evaluation order stays fixed.
    train_loader = _build_loader(data_path + '/train/images',
                                 data_path + '/train/labels',
                                 args, shuffle=True)
    val_loader = _build_loader(data_path + '/validation/images',
                               data_path + '/validation/labels',
                               args, shuffle=False)
    test_loader = _build_loader(data_path + '/test/images',
                                data_path + '/test/labels',
                                args, shuffle=False)

    # TRAINING
    epochs = args.max_epochs
    save_path = args.save_path
    train_flag = int(args.fresh_train)
    if train_flag:
        trainer = my_utils.Trainer(model, train_loader, val_loader,
                                   save_path, epochs, noc)
        trained_model = trainer.train()
    else:
        print('Loading model')
        # NOTE(review): torch.load unpickles the file, so only load
        # checkpoints from a trusted source. Recent PyTorch releases default
        # to weights_only=True, which may refuse a fully pickled model -
        # confirm against the installed torch version.
        trained_model = torch.load(save_path + '/best_model.pth')
        print('model loaded')

    # EVALUATION
    # Same order as before: train, then validation, then test.
    evaluation_runs = (
        (train_loader, '/eval_train'),
        (val_loader, '/eval_val'),
        (test_loader, '/eval_test'),
    )
    for loader, out_suffix in evaluation_runs:
        tester = my_utils.Tester(trained_model, loader,
                                 save_path + out_suffix, noc)
        tester.test()


# The guard is required for multi-process data loading: on platforms that
# spawn workers (Windows, macOS) every DataLoader worker re-imports this
# module, and without the guard it would re-run the whole script.
if __name__ == '__main__':
    main()