img_name = img_name[0]; label = label[0]

        img_path = voc12.data.get_img_path(img_name, args.voc12_root)
        orig_img = np.asarray(Image.open(img_path))
        orig_img_size = orig_img.shape[:2]

        def _work(i, img):
            with torch.no_grad():
                _, cam = model(img.cuda())
                cam = F.upsample(cam[:,1:,:,:], orig_img_size, mode='bilinear', align_corners=False)[0]
                cam = cam.cpu().numpy() * label.clone().view(20, 1, 1).numpy()
                if i % 2 == 1:
                    cam = np.flip(cam, axis=-1)
                return cam

        thread_pool = pyutils.BatchThreader(_work, list(enumerate(img_list)),
                                            batch_size=12, prefetch_size=0, processes=args.num_workers)

        cam_list = thread_pool.pop_results()

        sum_cam = np.sum(cam_list, axis=0)
        sum_cam[sum_cam < 0] = 0
        cam_max = np.max(sum_cam, (1,2), keepdims=True)
        cam_min = np.min(sum_cam, (1,2), keepdims=True)
        sum_cam[sum_cam < cam_min+1e-5] = 0
        norm_cam = (sum_cam-cam_min-1e-5) / (cam_max - cam_min + 1e-5)

        cam_dict = {}
        for i in range(20):
            if label[i] > 1e-5:
                cam_dict[i] = norm_cam[i]
예제 #2
0
    def get_iou(weights, bg_threshold=0.2):
        """Evaluate a CAM checkpoint on ``args.infer_list`` and print its score.

        Loads ``Net`` from ``args.network`` with the given state dict.
        Computes multi-scale class activation maps for every image. Turns
        them into a segmentation by adding a constant background channel.
        Finally prints ``scores(gts, preds, n_class=21)`` against the
        ground-truth masks.

        Args:
            weights: Path handed to ``torch.load`` for the model state dict.
                NOTE(review): ``torch.load`` unpickles the file. Only load
                trusted checkpoints.
            bg_threshold: Constant score given to the background channel
                (index 0). A pixel gets a foreground class only if that
                class's normalized CAM exceeds this value. Defaults to 0.2,
                the previously hard-coded value.
        """
        model = getattr(importlib.import_module(args.network), 'Net')()
        model.load_state_dict(torch.load(weights))
        model.eval()
        model.cuda()

        infer_dataset = voc12.data.VOC12ClsDatasetMSFseg(
            args.infer_list,
            voc12_root=args.voc12_root,
            scales=(1, 0.5, 1.5, 2.0),
            inter_transform=torchvision.transforms.Compose(
                [np.asarray, model.normalize, imutils.HWC_to_CHW]))

        # batch_size is left at its default of 1, which is why every batch
        # below is unwrapped with [0].
        infer_data_loader = DataLoader(infer_dataset,
                                       shuffle=False,
                                       num_workers=args.num_workers,
                                       pin_memory=True)

        # One model copy per visible GPU; views are spread across them
        # round-robin by index.
        n_gpus = torch.cuda.device_count()
        model_replicas = torch.nn.parallel.replicate(model,
                                                     list(range(n_gpus)))
        preds, gts = [], []
        for step, (img_name, img_list, label,
                   labelseg) in enumerate(infer_data_loader):
            img_name = img_name[0]
            label = label[0]

            img_path = voc12.data.get_img_path(img_name, args.voc12_root)
            orig_img = np.asarray(Image.open(img_path))
            orig_img_size = orig_img.shape[:2]

            # Per-class multiplier that zeroes the CAMs of classes absent
            # from the image-level label. It is built once per image rather
            # than once per view.
            label_mask = label.view(20, 1, 1).numpy()

            def _work(i, img):
                # Odd-indexed views are flipped back after inference.
                # Presumably the dataset emits (view, horizontally flipped
                # view) pairs; confirm in VOC12ClsDatasetMSFseg.
                with torch.no_grad(), torch.cuda.device(i % n_gpus):
                    cam = model_replicas[i % n_gpus].forward_cam(img.cuda())
                    # F.interpolate is the non-deprecated replacement for
                    # F.upsample and accepts the same arguments.
                    cam = F.interpolate(cam,
                                        orig_img_size,
                                        mode='bilinear',
                                        align_corners=False)[0]
                    cam = cam.cpu().numpy() * label_mask
                    if i % 2 == 1:
                        cam = np.flip(cam, axis=-1)
                    return cam

            thread_pool = pyutils.BatchThreader(_work,
                                                list(enumerate(img_list)),
                                                batch_size=12,
                                                prefetch_size=0,
                                                processes=args.num_workers)

            cam_list = thread_pool.pop_results()

            # Sum the CAMs over all views, then scale each class map by its
            # own maximum. The 1e-5 keeps all-zero maps of absent classes
            # from dividing by zero.
            sum_cam = np.sum(cam_list, axis=0)
            norm_cam = sum_cam / (np.max(sum_cam,
                                         (1, 2), keepdims=True) + 1e-5)

            # Prepend a constant background channel. After argmax, index 0
            # is background and index k + 1 is label class k. Ties resolve
            # to background.
            bg_score = [np.ones_like(norm_cam[0]) * bg_threshold]
            pred = np.argmax(np.concatenate((bg_score, norm_cam)), 0)

            if step % 50 == 0:
                print(step, weights)

            # Rows of the HxW masks are accumulated. NOTE(review): assumes
            # ``scores`` handles this row-wise layout; confirm at its
            # definition.
            preds += list(pred)
            gts += list(labelseg[0].cpu().numpy())
        score = scores(gts, preds, n_class=21)
        print(weights)
        print(score)