-
Notifications
You must be signed in to change notification settings - Fork 2
/
main.py
38 lines (31 loc) · 1013 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from helpers import make_data, score_iou
import numpy as np
import torch
from tqdm import tqdm
from network import Net
from helpers import unnormalize
from torchsummary import summary
import pdb
def eval(*, path="model.pth.tar", n_samples=1000, threshold=0.7, seed=46):
    """Evaluate a saved ``Net`` checkpoint on freshly generated samples.

    Loads the state dict stored at *path* into a new ``Net``, then for each of
    *n_samples* samples from ``make_data()`` runs ``model.predict`` and scores
    the prediction with ``score_iou``. Prints and returns the fraction of
    non-NaN scores strictly greater than *threshold*.

    NOTE: the name shadows the builtin ``eval``. It is kept unchanged for
    backward compatibility with existing callers, including this module's
    ``__main__`` guard.

    Args:
        path: Checkpoint file containing a state dict for ``Net``.
        n_samples: Number of samples to generate and score.
        threshold: Scores strictly greater than this count as hits.
        seed: Seed for NumPy's global RNG, applied before any sampling.

    Returns:
        The hit fraction as a float. Returns ``nan`` when no non-NaN score
        remains (e.g. every score was NaN, or *n_samples* is 0).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Seeds NumPy's global RNG. Presumably make_data() draws from it, which
    # makes the evaluation set reproducible -- confirm against helpers.
    np.random.seed(seed=seed)

    model = Net()
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (consider weights_only=True on torch >= 1.13).
    checkpoint = torch.load(path, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint)
    model = model.to(device)
    model.eval()

    ious = []
    # Inference only: disable autograd so no graph is recorded per sample.
    with torch.no_grad():
        for _ in tqdm(range(n_samples)):
            img, label = make_data()
            img = torch.from_numpy(np.asarray(img, dtype=np.float32))
            # Prepend two singleton dims -> (1, 1, *img.shape); presumably
            # the (batch, channel) layout Net expects -- TODO confirm.
            img = img[None, None].to(device)
            pred = model.predict(img)
            ious.append(score_iou(label, pred))

    ious = np.asarray(ious, dtype="float")
    # Drop NaN scores; the original author notes these are true negatives.
    ious = ious[~np.isnan(ious)]
    # Guard the empty case explicitly: mean() of an empty array emits a
    # RuntimeWarning before yielding nan.
    hit_rate = float((ious > threshold).mean()) if ious.size else float("nan")
    print(hit_rate)
    return hit_rate
if __name__ == "__main__":
    # Script entry point: run the evaluation only when this file is executed
    # directly, never as a side effect of importing it.
    eval()