/
detect.py
48 lines (35 loc) · 1.42 KB
/
detect.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# -*- coding: utf-8 -*-
#
# Developed by Alex Jercan <jercan_alex27@yahoo.com>
#
# References:
#
import torch
import config
from models.decoder import Decoder
from models.encoder import Encoder
from util.common import num_channels, load_checkpoint, plot_volumes, to_volume
from util.dataset import LoadImages
def detect(path, encoder=None, decoder=None):
    """Run inference on every sample loaded from *path* and plot the results.

    Args:
        path: Input location handed to ``LoadImages``.
        encoder: Optional encoder module. Used as-is when *decoder* is
            also supplied.
        decoder: Optional decoder module. Used as-is when *encoder* is
            also supplied.

    If either model is ``None``, BOTH are rebuilt from ``config``, moved to
    ``config.DEVICE`` and restored from ``config.CHECKPOINT_FILE`` (a lone
    caller-supplied model is discarded, matching the previous behaviour).

    Side effects: enables ``torch.backends.cudnn.benchmark`` globally,
    leaves both models in eval mode, and calls ``plot_volumes`` once per
    sample yielded by the dataset.
    """
    torch.backends.cudnn.benchmark = True

    dataset = LoadImages(path, img_size=config.IMAGE_SIZE, used_layers=config.USED_LAYERS)

    # Compare against None explicitly: an nn.Module that defines __len__
    # (e.g. an empty container module) is falsy, so a truthiness test could
    # silently throw away a model the caller passed in.
    if encoder is None or decoder is None:
        in_channels = num_channels(config.USED_LAYERS)
        encoder = Encoder(in_channels=in_channels).to(config.DEVICE)
        decoder = Decoder(num_classes=config.NUM_CLASSES + 1).to(config.DEVICE)
        _, encoder, decoder = load_checkpoint(
            encoder, decoder, config.CHECKPOINT_FILE, config.DEVICE
        )

    encoder.eval()
    decoder.eval()

    # `sample_path` rather than `path`, so the function argument is not
    # shadowed by the per-sample value coming from the dataset.
    for _, layers, sample_path in dataset:
        with torch.no_grad():
            layers = torch.from_numpy(layers).to(config.DEVICE, non_blocking=True)
            if layers.ndimension() == 3:
                # Add a leading axis to 3-D input (presumably the batch
                # axis the encoder expects -- confirm against Encoder).
                layers = layers.unsqueeze(0)

            features = encoder(layers)
            predictions = decoder(features)
            out = predictions.sigmoid()

            plot_volumes(
                to_volume(out, config.VOXEL_THRESH).cpu(),
                [sample_path],
                config.NAMES,
            )
# Script entry point: run detection on the path configured in config.DETECT_PATH,
# letting detect() build the models and load them from the checkpoint.
if __name__ == "__main__":
    detect(path=config.DETECT_PATH)