def main():
    """Load a RetinaNet checkpoint and run the evaluation modes selected in ``args``.

    Reads the module-level ``args`` namespace (``ckpt``, ``vmode``, ``mmode``,
    ``image_root``, ``pred_root``, ``anno_root``) and the module-level
    ``val_image_list``.

    * When ``args.vmode`` is truthy, ``args.pred_root`` is created if missing
      and ``visualize`` is called.
    * When ``args.mmode`` is truthy, ``./mPA/detection-results`` is emptied
      (any existing content is deleted) and recreated, then ``mPA_pred`` is
      called.

    Returns:
        int: always ``0``.

    Raises:
        OSError: if the stale results directory cannot be removed or a
            required directory cannot be created.
    """
    # Imported locally so this function does not depend on the file's
    # top-level import block already providing shutil.
    import shutil

    ckpt_path = '%s' % (args.ckpt)
    print('Loading model from %s' % (args.ckpt))
    net = RetinaNet()
    load_checkpoint(ckpt_path, net)
    net.eval()

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print('Device used:', device)
    net = net.to(device)

    encoder = DataEncoder()

    if args.vmode:
        pred_root = "%s" % (args.pred_root)
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(pred_root, exist_ok=True)
        visualize(encoder, net, device, args.image_root, val_image_list,
                  args.pred_root)

    if args.mmode:
        results_dir = "./mPA/detection-results"
        # lexists (not exists) so a dangling symlink is also cleaned up
        # instead of making the makedirs call below fail.
        if os.path.lexists(results_dir):
            print("Remove Pred file")
            # Portable replacement for `rm -rf`: rmtree refuses symlinks and
            # plain files, so those are unlinked directly.
            if os.path.isdir(results_dir) and not os.path.islink(results_dir):
                shutil.rmtree(results_dir)
            else:
                os.remove(results_dir)
        # Also creates the parent "./mPA" directory when it is missing.
        os.makedirs(results_dir, exist_ok=True)
        mPA_pred(encoder, net, device, args.image_root, args.anno_root)

    return 0
def test():
    """Run a one-batch smoke test of the data pipeline, model and decoder.

    Takes the first batch yielded by a shuffled ``ListDataset`` loader over
    ``train2017``, writes it to ``a.jpg``, reloads and resizes that image to
    360x360, runs it through a freshly constructed ``RetinaNet`` (no
    checkpoint is loaded here), decodes boxes/labels with ``DataEncoder``,
    draws them on the image and saves the figure to ``./test.jpg``.
    Only the first batch is processed.
    """
    import torchvision

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    split = 'train2017'
    image_dir = "../COCO_dataset/images/%s" % (split)
    list_path = "./data/%s.txt" % (split)
    dataset = ListDataset(
        root=image_dir,
        list_file=list_path,
        train=True,
        transform=transform,
        input_size=360,
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
        num_workers=1,
        collate_fn=dataset.collate_fn,
    )

    for images, loc_targets, cls_targets in loader:
        for tensor in (images, loc_targets, cls_targets):
            print(tensor.shape)

        # Round-trip the batch through a JPEG file on disk.
        torchvision.utils.save_image(
            torchvision.utils.make_grid(images, 1), 'a.jpg')

        print('Loading image..')
        net = RetinaNet()
        net.eval()
        w = h = 360
        img = Image.open('a.jpg').resize((w, h))

        print('Predicting..')
        x = Variable(transform(img).unsqueeze(0))

        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda" if use_cuda else "cpu")
        print('Device used:', device)
        if use_cuda:
            net = net.to(device)
            x = x.to(device)
            loc_targets = loc_targets.to(device)

        with torch.no_grad():
            loc_preds, cls_preds = net(x)
        print(loc_preds.shape)
        print(cls_preds.shape)

        print('Decoding..')
        encoder = DataEncoder()
        # NOTE(review): decode is fed loc_targets (from the dataset), not
        # loc_preds (from the network) -- confirm this is intended.
        boxes, labels, _ = encoder.decode(
            loc_targets.data.cpu()[0],
            cls_preds.data.cpu().squeeze(),
            (w, h),
        )
        print("Label : ", labels)

        draw = ImageDraw.Draw(img)
        # use a truetype font
        font = ImageFont.truetype("./font/DELIA_regular.ttf", 20)
        for box, label in zip(boxes, labels):
            class_id = int(label)
            caption = my_cate[class_id]
            left, top = box[0], box[1]

            # Bounding box outline.
            draw.rectangle(list(box), outline=color_map(class_id), width=5)

            # Filled tag just above the box, sized from the caption length.
            tag = [left, top - 17, left + 10 * len(caption) + 5, top]
            draw.rectangle(tag, outline=color_map(class_id), width=3,
                           fill=color_map(class_id))
            draw.text((left + 3, top - 16), caption, font=font,
                      fill=(0, 0, 0, 100), width=5)

        plt.imshow(img)
        plt.savefig("./test.jpg")
        break