-
Notifications
You must be signed in to change notification settings - Fork 0
/
vi_cls_tta.py
119 lines (89 loc) · 3.49 KB
/
vi_cls_tta.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import pickle
import gc
import time
import torch
from torch.utils.data import DataLoader
from global_parameter import *
from dataset import CloudTrainDataset2,CloudTestDataset
import segmentation_models_pytorch as smp
import utils
from models import load_model
import warnings
warnings.filterwarnings('ignore')
# Expose only physical GPU 1 to this process (torch then sees it as cuda:0).
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# Run identifier: it names the checkpoint directory and encodes the architecture.
cls_name = 'resnext50_32x4d_1115-1540_all'

# Run-name markers for architectures whose names contain '_' and therefore
# cannot be recovered as the first '_'-separated token of the run name.
# Checked in order, first hit wins. resnext is listed first so a name that
# contained both markers still resolves to 'resnext50_32x4d', matching the
# precedence of the earlier if/if/else code.
MULTI_TOKEN_MODEL_NAMES = (
    ('resnext50_32x4d_1115', 'resnext50_32x4d'),
    ('wide_resnet50_2', 'wide_resnet50_2'),
)


def resolve_model_name(run_name, known=MULTI_TOKEN_MODEL_NAMES):
    """Return the architecture name to pass to ``load_model`` for ``run_name``.

    Args:
        run_name: run identifier such as ``'resnext50_32x4d_1115-1540_all'``.
        known: ``(marker, model_name)`` pairs matched by substring, in order.

    Returns:
        The model name paired with the first marker found in ``run_name``;
        otherwise the part of ``run_name`` before its first ``'_'``.
    """
    for marker, name in known:
        if marker in run_name:
            return name
    return run_name.split('_')[0]


# The previous code used two independent `if`s whose `else` belonged only to
# the second one, so a wide_resnet50_2 run was silently overwritten with 'wide'.
model_name = resolve_model_name(cls_name)
# Per-run directory: the per-fold checkpoints are read from here and the
# prediction arrays are written back here.
save_dir = os.path.join(SAVE_PATH, 'classify', cls_name)
print(cls_name)

# DataLoader settings and the (height, width)-style size tag used in the
# resized-image directory names.
NUM_WORKERS, BATCH_SIZE = 4, 32
RESIZE = (480, 640)

# Training-side inputs under DATA_PATH.
train_csv, img_path = (
    os.path.join(DATA_PATH, leaf) for leaf in ('train.csv', 'train_images')
)
kfold_path = 'kfold.pkl'

# NOTE(review): machine-specific absolute path, unlike test_data below which
# is built from DATA_PATH; another user's copy lived under
# /home/noel/pythonproject/cloud_segment/data/.
data_path = '/home/jianglb/pythonproject/cloud_segment/data/train_{}_{}'.format(*RESIZE)
test_data = os.path.join(DATA_PATH, 'test_{}_{}'.format(*RESIZE))
# Number of test-time-augmentation variants: datasets are built with
# transform_type = 0 .. N_TTA - 1 and their predictions are averaged.
N_TTA = 4


def _tta_predict(model, build_dataset, n_tta=N_TTA,
                 batch_size=BATCH_SIZE, num_workers=NUM_WORKERS):
    """Return ``model.predict`` outputs averaged over ``n_tta`` TTA variants.

    Args:
        model: classifier exposing ``predict(tensor)`` (as built by ``load_model``).
        build_dataset: callable mapping a transform type ``tt`` to a dataset
            whose batches carry the image tensor at index 1.
        n_tta: number of transform types to average over.
        batch_size, num_workers: forwarded to ``DataLoader``.

    Returns:
        numpy array of predictions concatenated over batches in loader order
        (``shuffle=False``), averaged over the ``n_tta`` transform types.

    Raises:
        ValueError: from ``np.concatenate`` if a loader yields no batches.
    """
    total = 0
    for tt in range(n_tta):
        loader = DataLoader(build_dataset(tt), batch_size=batch_size,
                            num_workers=num_workers, shuffle=False)
        # Both dataset classes used below put the images at index 1 of each
        # batch: validation yields (_, imgs, masks, classes), test (names, imgs, _).
        batch_preds = [model.predict(batch[1].float().cuda()).cpu().numpy()
                       for batch in tqdm(loader)]
        total = total + np.concatenate(batch_preds, axis=0)
        del loader, batch_preds
        gc.collect()
    return total / n_tta


# ImageNet normalisation taken from smp's 'resnet34' encoder settings is used
# for every architecture. NOTE(review): presumably the same preprocessing the
# checkpoints were trained with — confirm. Loop-invariant, so built once.
preprocessing_fn = smp.encoders.get_preprocessing_fn('resnet34', 'imagenet')

cls_probs = []   # per-fold TTA-averaged validation predictions, in fold order
cloud_class = 0  # running sum over folds of TTA-averaged test predictions
with torch.no_grad():
    for fold in range(K):
        print('Fold{}:'.format(fold))
        cls_model = load_model(model_name, classes=4, dropout=0., pretrained=False)
        # Deserialize on CPU first: CUDA tensors saved from another GPU index
        # cannot be restored when CUDA_VISIBLE_DEVICES exposes a single device.
        state_dict = torch.load(
            os.path.join(save_dir, 'model_{}.pth'.format(fold)), map_location='cpu')
        cls_model.load_state_dict(state_dict)
        del state_dict
        cls_model.cuda()
        cls_model.eval()

        # validate: predictions on this fold's 'validate' split. The lambda is
        # invoked inside _tta_predict during this iteration, so it sees the
        # current `fold`.
        cls_probs.append(_tta_predict(
            cls_model,
            lambda tt: CloudTrainDataset2(
                train_csv, data_path, kfold_path, fold, phase='validate',
                transform_type=tt,
                preprocessing=utils.get_preprocessing(preprocessing_fn)),
        ))

        # inference: test-set predictions, accumulated across folds.
        cloud_class += _tta_predict(
            cls_model,
            lambda tt: CloudTestDataset(
                test_data, transform_type=tt,
                preprocessing=utils.get_preprocessing(preprocessing_fn)),
        )

        # Release this fold's model before the next one is constructed so two
        # models are never alive at the same time.
        del cls_model
        gc.collect()

cls_probs = np.concatenate(cls_probs, axis=0)
# Mean over folds of per-fold TTA means, i.e. the total sum / (K * N_TTA).
cloud_class /= K
gc.collect()
np.save(os.path.join(save_dir, 'cls_probs.npy'), cls_probs)
np.save(os.path.join(save_dir, 'cloud_class.npy'), cloud_class)