def build_model(model_name, num_classes):
    """Instantiate a segmentation network by name.

    Args:
        model_name: architecture identifier, e.g. ``'LinkNet'`` or ``'ESPNet_v2'``.
        num_classes: number of output classes, passed as ``classes=`` to the
            model constructor.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: if ``model_name`` is not a supported architecture.
    """
    # Each entry is a zero-argument factory so a model class is only looked up
    # when it is actually requested (as with the original if/elif chain).
    factories = {
        'SQNet': lambda: SQNet(classes=num_classes),
        'LinkNet': lambda: LinkNet(classes=num_classes),
        'SegNet': lambda: SegNet(classes=num_classes),
        'UNet': lambda: UNet(classes=num_classes),
        'ENet': lambda: ENet(classes=num_classes),
        'ERFNet': lambda: ERFNet(classes=num_classes),
        'CGNet': lambda: CGNet(classes=num_classes),
        'EDANet': lambda: EDANet(classes=num_classes),
        'ESNet': lambda: ESNet(classes=num_classes),
        'ESPNet': lambda: ESPNet(classes=num_classes),
        'LEDNet': lambda: LEDNet(classes=num_classes),
        # The public name differs from the implementing class here.
        'ESPNet_v2': lambda: EESPNet_Seg(classes=num_classes),
        'ContextNet': lambda: ContextNet(classes=num_classes),
        'FastSCNN': lambda: FastSCNN(classes=num_classes),
        'DABNet': lambda: DABNet(classes=num_classes),
        'FSSNet': lambda: FSSNet(classes=num_classes),
        'FPENet': lambda: FPENet(classes=num_classes),
    }
    if model_name not in factories:
        # Fail loudly here instead of returning None, which would only blow up
        # later at the first attribute access on the "model".
        raise ValueError('Unknown model_name %r; expected one of: %s'
                         % (model_name, ', '.join(sorted(factories))))
    return factories[model_name]()
def processImage(infile, args):
    """Segment every frame of a video with LinkNet and display the result live.

    Each frame is resized to 768x576, run through the model, and shown in a
    ``'camvid'`` window with the input on the left and the decoded prediction
    on the right. Playback stops when the capture runs out of frames.

    Args:
        infile: source passed to ``cv2.VideoCapture`` (e.g. a video file path).
        args: namespace whose ``model_path`` names the checkpoint to load.
    """
    n_classes = 12
    height, width = 576, 768
    use_cuda = torch.cuda.is_available()

    model = LinkNet(n_classes)
    model.load_state_dict(torch.load(args.model_path))
    if use_cuda:
        model = model.cuda(0)
    model.eval()

    gif = cv2.VideoCapture(infile)
    cv2.namedWindow('camvid')
    try:
        while gif.isOpened():
            ret, frame = gif.read()
            if not ret:
                # End of stream or unreadable frame: frame is None here and
                # cv2.resize would raise, so stop instead.
                break
            frame = cv2.resize(frame, (width, height))

            images = get_tensor(frame)
            if use_cuda:
                images = images.cuda(0)
            outputs = model(Variable(images))

            # Index of the highest-scoring class (dim 1) for each pixel.
            pred = outputs.data.max(1)[1].cpu().numpy().reshape(height, width)
            # NOTE(review): decode_segmap's output must broadcast to
            # (576, 768, 3); if it returns floats in [0, 1] the uint8 canvas
            # below would render it nearly black -- TODO confirm.
            pred = decode_segmap(pred)

            # Side-by-side canvas: input frame | colourised prediction.
            vis = np.zeros((height, 2 * width, 3), np.uint8)
            vis[:height, :width, :3] = frame
            vis[:height, width:2 * width, :3] = pred
            cv2.imshow('camvid', vis)
            cv2.waitKey(10)
    finally:
        # Free the capture handle and the window even if inference raises.
        gif.release()
        cv2.destroyWindow('camvid')
def build_model(model_name, num_classes):
    """Instantiate a segmentation network by name, including DeepLabV3 variants.

    Args:
        model_name: architecture identifier, e.g. ``'LinkNet'``, ``'FCN'`` or
            ``'deeplabv3plus_resnet101'``.
        num_classes: number of output classes.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: if ``model_name`` is not a supported architecture.
    """
    # Zero-argument factories: a class is only looked up when requested.
    factories = {
        'SQNet': lambda: SQNet(classes=num_classes),
        'LinkNet': lambda: LinkNet(classes=num_classes),
        'SegNet': lambda: SegNet(classes=num_classes),
        'UNet': lambda: UNet(classes=num_classes),
        'ENet': lambda: ENet(classes=num_classes),
        'ERFNet': lambda: ERFNet(classes=num_classes),
        'CGNet': lambda: CGNet(classes=num_classes),
        'EDANet': lambda: EDANet(classes=num_classes),
        'ESNet': lambda: ESNet(classes=num_classes),
        'ESPNet': lambda: ESPNet(classes=num_classes),
        'LEDNet': lambda: LEDNet(classes=num_classes),
        'ESPNet_v2': lambda: EESPNet_Seg(classes=num_classes),
        'ContextNet': lambda: ContextNet(classes=num_classes),
        'FastSCNN': lambda: FastSCNN(classes=num_classes),
        'DABNet': lambda: DABNet(classes=num_classes),
        'FSSNet': lambda: FSSNet(classes=num_classes),
        'FPENet': lambda: FPENet(classes=num_classes),
        'FCN': lambda: FCN32VGG(classes=num_classes),
    }
    # DeepLabV3 builders live on the `network` module under these exact names.
    # They are resolved lazily so other models never touch `network`.
    deeplab_names = (
        'deeplabv3_resnet50',
        'deeplabv3plus_resnet50',
        'deeplabv3_resnet101',
        'deeplabv3plus_resnet101',
        'deeplabv3_mobilenet',
        'deeplabv3plus_mobilenet',
    )

    if model_name in factories:
        return factories[model_name]()
    if model_name in deeplab_names:
        return getattr(network, model_name)(num_classes, output_stride=8)
    # Fail loudly instead of returning None to the caller.
    raise ValueError('Unknown model_name %r; expected one of: %s'
                     % (model_name,
                        ', '.join(sorted(list(factories) + list(deeplab_names)))))
def validate(args):
    """Evaluate a LinkNet checkpoint on a dataset split and print the scores.

    Prints the forward-pass time of every batch, then each entry of the
    aggregate ``score`` mapping, then the per-class IoU.

    Args:
        args: namespace providing ``dataset``, ``split``, ``batch_size`` and
            ``model_path``.
    """
    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    loader = data_loader(data_path, split=args.split, is_transform=True)
    n_classes = loader.n_classes
    valloader = data.DataLoader(loader, batch_size=args.batch_size)

    # Setup Model
    model = LinkNet(n_classes)
    model.load_state_dict(torch.load(args.model_path))
    model.eval()
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model.cuda(0)

    gts, preds = [], []
    for images, labels in valloader:
        if use_cuda:
            images, labels = images.cuda(0), labels.cuda(0)
        images = Variable(images)
        labels = Variable(labels)

        t1 = time.time()
        outputs = model(images)
        if use_cuda:
            # CUDA kernels run asynchronously; wait for them so the timing
            # covers the forward pass rather than just the kernel launches.
            torch.cuda.synchronize()
        t2 = time.time()
        print(t2 - t1)

        # Index of the highest-scoring class (dim 1) per pixel.
        pred = outputs.data.max(1)[1].cpu().numpy()
        gt = labels.data.cpu().numpy()
        # Collect one array per sample (iterating the batch axis).
        for gt_, pred_ in zip(gt, pred):
            gts.append(gt_)
            preds.append(pred_)

    score, class_iou = scores(gts, preds, n_class=n_classes)
    # Single-argument print() produces the same "a b" line on Python 2 and 3
    # (the original `print k, v` statements are a SyntaxError on Python 3).
    for k, v in score.items():
        print('%s %s' % (k, v))
    for i in range(n_classes):
        print('%s %s' % (i, class_iou[i]))