Example #1
def eval_model(net, save_dir, batch_size=10):
    # Setting parameters
    relax_crop = 50  # Enlarge the bounding box by relax_crop pixels
    zero_pad_crop = True  # Insert zero padding when cropping the image

    net.eval()
    composed_transforms_ts = transforms.Compose([
        tr.CropFromMask(crop_elems=('image', 'gt'),
                        relax=relax_crop,
                        zero_pad=zero_pad_crop),
        tr.FixedResize(resolutions={
            'gt': None,
            'crop_image': (512, 512),
            'crop_gt': (512, 512)
        }),
        tr.ExtremePoints(sigma=10, pert=0, elem='crop_gt'),
        tr.ToImage(norm_elem='extreme_points'),
        tr.ConcatInputs(elems=('crop_image', 'extreme_points')),
        tr.ToTensor()
    ])
    db_test = pascal.VOCSegmentation(split='val',
                                     transform=composed_transforms_ts,
                                     retname=True)
    testloader = DataLoader(db_test,
                            batch_size=1,
                            shuffle=False,
                            num_workers=2)

    save_dir.mkdir(exist_ok=True)

    with torch.no_grad():
        test(net, testloader, save_dir)
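A minimal usage sketch (assumptions: load_my_net is a hypothetical helper standing in for however the trained network is obtained, and save_dir must be a pathlib.Path, since the function calls save_dir.mkdir). Note that the batch_size argument is accepted but unused: the DataLoader above is hard-coded to batch_size=1.

from pathlib import Path

net = load_my_net()                    # hypothetical: any trained network
eval_model(net, Path('results/eval'))  # predictions are written under results/eval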
Example #2
def create_transforms(relax_crop, zero_crop):
    # Preparation of the data loaders
    first = [
        tr.CropFromMask(crop_elems=('image', 'gt'),
                        relax=relax_crop,
                        zero_pad=zero_crop),
        tr.FixedResize(resolutions={
            'crop_image': (512, 512),
            'crop_gt': (512, 512)
        })
    ]
    second = [
        tr.ToImage(norm_elem='extreme_points'),
        tr.ConcatInputs(elems=('crop_image', 'extreme_points')),
        tr.ToTensor()
    ]
    train_tf = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.ScaleNRotate(rots=(-20, 20), scales=(.75, 1.25)),
        *first,
        tr.ExtremePoints(sigma=10, pert=5, elem='crop_gt'),
        *second
    ])
    test_tf = transforms.Compose([
        *first,
        tr.ExtremePoints(sigma=10, pert=0, elem='crop_gt'),
        *second
    ])
    return train_tf, test_tf
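A quick usage sketch of the factory above; the argument values mirror the settings in Example #1:

train_tf, test_tf = create_transforms(relax_crop=50, zero_crop=True)
voc_train = pascal.VOCSegmentation(split='train', transform=train_tf)
voc_val = pascal.VOCSegmentation(split='val', transform=test_tf)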
Example #3
    def transform_val(self, sample):

        composed_transforms = transforms.Compose([
            tr.CropFromMask(crop_elems=('image', 'gt'), relax=20, zero_pad=True),
            tr.FixedResize(resolutions={'crop_image': (256, 256), 'crop_gt': (256, 256)}),
            tr.Normalize(elems='crop_image'),
            tr.ToTensor()])

        return composed_transforms(sample)
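A hedged sketch of how transform_val is applied: the transforms operate on a dict of numpy arrays (an HxWx3 image and a binary HxW mask are assumed), and CropFromMask/FixedResize add the crop_image and crop_gt entries that ToTensor converts:

sample = {'image': image_np, 'gt': mask_np}  # numpy inputs, shapes as assumed above
tensors = self.transform_val(sample)         # now holds 'crop_image' and 'crop_gt' tensors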
    def __init__(self,
                 sbox_net=None,
                 split='train',
                 transform=None,
                 sbox='sbox_513_8925.pth.tar',
                 miou_thres=0.85,
                 which='click'):
        # NOTE: the sbox_net argument is ignored; the model is always loaded
        # from the checkpoint named by the sbox argument.
        sbox_net = utils.load_model(DeepLabX(pretrain=False),
                                    'run/sbox/' + sbox)

        VOCSegmentation.__init__(self,
                                 root=Path.db_root_dir('pascal'),
                                 split=split,
                                 transform=transform,
                                 download=False,
                                 preprocess=False,
                                 area_thres=500,
                                 retname=True,
                                 suppress_void_pixels=True,
                                 which_part=which,
                                 default=False)
        # sbox_net.eval()
        self.sbox_net = sbox_net
        self.device = torch.device(
            "cuda:" + str(0) if torch.cuda.is_available() else "cpu")
        self.sbox_net = self.sbox_net.to(self.device)
        self.sbox_net.eval()

        self.miou_thres = miou_thres
        self.click_list_file = os.path.join(
            self.root, self.BASE_DIR, 'ImageSets', 'Segmentation', '-'.join(
                [sbox, 'mIoU_thres',
                 str(miou_thres), 'on', which, 'sets']) + '.txt')
        np.random.seed(42)
        if split == 'train' and not self._check_filtered_file():
            pre_transform = transforms.Compose([
                tr.CropFromMask(crop_elems=('image', 'gt'),
                                relax=20,
                                zero_pad=True,
                                jitters_bound=None),
                tr.FixedResize(resolutions={
                    'crop_image': (512, 512),
                    'crop_gt': (512, 512)
                }),
                tr.Normalize(elems='crop_image'),
                tr.ToTensor(),
            ])
            tmp = self.transform
            self.transform = pre_transform
            self._filter_obj_list()
            self.transform = tmp

        np.random.seed(42)
        # Change the object list of the base class so that its __getitem__
        # picks up the filtered targets and we don't need to reinvent the wheel.
        print('target list len:{}'.format(len(self.target_list)))
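The methods above belong to a VOCSegmentation subclass whose name is not shown in this snippet; a hypothetical instantiation (the class name ClickVOC is assumed purely for illustration) would look like:

dataset = ClickVOC(split='train', sbox='sbox_513_8925.pth.tar',
                   miou_thres=0.85, which='click')
print(len(dataset.target_list))  # objects kept after the mIoU filtering step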
Example #5
    def transform_tr(self, sample):
        composed_transforms_tr = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.ScaleNRotate(rots=(-20, 20), scales=(.75, 1.25)),
            tr.CropFromMask(crop_elems=('image', 'gt'), relax=20, zero_pad=True),
            tr.FixedResize(resolutions={'crop_image': (256, 256), 'crop_gt': (256, 256)}),
            tr.Normalize(elems='crop_image'),
            tr.ToTensor()
        ])
        return composed_transforms_tr(sample)
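As with transform_val in Example #3, the method takes a sample dict; the only difference is the flip/scale/rotate augmentation applied before cropping, so the crops vary between epochs:

augmented = self.transform_tr({'image': image_np, 'gt': mask_np})  # numpy inputs assumed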
# Training the network
if resume_epoch != nEpochs:
    # Logging into Tensorboard
    log_dir = os.path.join(save_dir, 'models', datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
    # writer = SummaryWriter(log_dir=log_dir)

    # Use the following optimizer
    optimizer = optim.SGD(train_params, lr=p['lr'], momentum=p['momentum'], weight_decay=p['wd'])
    p['optimizer'] = str(optimizer)

    # Preparation of the data loaders
    composed_transforms_tr = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.ScaleNRotate(rots=(-20, 20), scales=(.75, 1.25)),
        tr.CropFromMask(crop_elems=('image', 'gt'), relax=relax_crop, zero_pad=zero_pad_crop),
        tr.FixedResize(resolutions={'crop_image': (512, 512), 'crop_gt': (512, 512)}),
        tr.ExtremePoints(sigma=10, pert=5, elem='crop_gt'),
        tr.ToImage(norm_elem='extreme_points'),
        tr.ConcatInputs(elems=('crop_image', 'extreme_points')),
        tr.ToTensor()])
    composed_transforms_ts = transforms.Compose([
        tr.CropFromMask(crop_elems=('image', 'gt'), relax=relax_crop, zero_pad=zero_pad_crop),
        tr.FixedResize(resolutions={'crop_image': (512, 512), 'crop_gt': (512, 512)}),
        tr.ExtremePoints(sigma=10, pert=0, elem='crop_gt'),
        tr.ToImage(norm_elem='extreme_points'),
        tr.ConcatInputs(elems=('crop_image', 'extreme_points')),
        tr.ToTensor()])

    voc_train = pascal.VOCSegmentation(split='train', transform=composed_transforms_tr)
    voc_val = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts)
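The snippet stops after building the datasets; a sketch of the loaders that usually follow (batch size and worker count are assumptions, mirroring Example #1):

trainloader = DataLoader(voc_train, batch_size=4, shuffle=True, num_workers=2)
testloader = DataLoader(voc_val, batch_size=1, shuffle=False, num_workers=2)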
Example #7
        return 'VOC2012(split=' + str(self.split) + ',area_thres=' + str(
            self.area_thres) + ')'


if __name__ == '__main__':
    import matplotlib.pyplot as plt
    import dataloaders.helpers as helpers
    import torch
    from torchvision import transforms

    transform = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.ScaleNRotate(rots=(-20, 20), scales=(.75, 1.25)),
        # tr.CropFromMask(crop_elems=('image', 'gt'), relax=30, zero_pad=False, jitters_bound=(10, 30)),
        tr.CropFromMask(crop_elems=('image', 'gt'),
                        relax=30,
                        zero_pad=False,
                        jitters_bound=(30, 31)),
        # tr.CropFromMask(crop_elems=('image', 'gt'), relax=30, zero_pad=False, jitters_bound=None),
        tr.FixedResize(resolutions={
            'crop_image': (256, 256),
            'crop_gt': (256, 256)
        }),
        tr.Normalize(elems='crop_image'),
        # tr.ToImage(norm_elem=('pos_map', 'neg_map')),
    ])

    dataset = VOCSegmentation(split=['val'], transform=transform, retname=True)
    # dataset = VOCSegmentation(split=['train', 'val'], retname=True)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=False)
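The script is truncated here; a hedged sketch of a typical inspection loop over the first few samples (the keys below are produced by the transforms above, and the default collate turns the numpy arrays into tensors):

for i, sample in enumerate(dataloader):
    print(i, sample['crop_image'].shape, sample['crop_gt'].shape)
    if i == 2:
        break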
Example #8
                backbone='resnet101',
                output_stride=16,
                sync_bn=None,
                freeze_bn=False)

# Load the pretrained weights
model_path = os.path.join(save_dir, 'models',
                          modelName + '_epoch-' + str(resume_epoch - 1) + '.pth')
print("Initializing weights from: {}".format(model_path))
pretrain_dict = torch.load(model_path)
net.load_state_dict(pretrain_dict)
net.to(device)

# Generate result of the validation images
net.eval()
composed_transforms_ts = transforms.Compose([
    tr.CropFromMask(crop_elems=('image', 'gt', 'void_pixels'), relax=30, zero_pad=True),
    tr.FixedResize(resolutions={'gt': None,
                                'crop_image': (512, 512),
                                'crop_gt': (512, 512),
                                'crop_void_pixels': (512, 512)},
                   flagvals={'gt': cv2.INTER_LINEAR,
                             'crop_image': cv2.INTER_LINEAR,
                             'crop_gt': cv2.INTER_LINEAR,
                             'crop_void_pixels': cv2.INTER_LINEAR}),
    tr.IOGPoints(sigma=10, elem='crop_gt', pad_pixel=10),
    tr.ToImage(norm_elem='IOG_points'),
    tr.ConcatInputs(elems=('crop_image', 'IOG_points')),
    tr.ToTensor()])
db_test = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts, retname=True)
testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1)

save_dir_res = os.path.join(save_dir, 'Results')
if not os.path.exists(save_dir_res):
    os.makedirs(save_dir_res)
save_dir_res_list = [save_dir_res]
print('Testing Network')
with torch.no_grad():
    for ii, sample_batched in enumerate(testloader):       
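        # (Body truncated in the original. A hedged sketch of a typical
        # iteration follows; the save path below is illustrative only.)
        inputs = sample_batched['concat'].to(device)
        output = net(inputs)[-1]
        pred = torch.sigmoid(output).cpu().numpy()[0, 0]
        np.save(os.path.join(save_dir_res, 'pred_{}.npy'.format(ii)), pred)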
Example #9
    net = tosnet.tosnet_resnet50(n_inputs=cfg['num_inputs'],
                                 n_classes=cfg['num_classes'],
                                 os=16,
                                 pretrained=None)
    print('Loading from snapshot: {}'.format(args.weights))
    net.load_state_dict(
        torch.load(args.weights, map_location=lambda storage, loc: storage))
    net.to(device)
    net.eval()

    # Setup data transformations
    composed_transforms = [
        tr.IdentityTransform(tr_elems=['gt'], prefix='ori_'),
        tr.CropFromMask(crop_elems=['image', 'gt'],
                        relax=cfg['relax_crop'],
                        zero_pad=cfg['zero_pad_crop'],
                        adaptive_relax=cfg['adaptive_relax'],
                        prefix=''),
        tr.Resize(resize_elems=['image', 'gt', 'void_pixels'],
                  min_size=cfg['min_size'],
                  max_size=cfg['max_size']),
        tr.ComputeImageGradient(elem='image'),
        tr.ExtremePoints(sigma=10, pert=0, elem='gt'),
        tr.GaussianTransform(tr_elems=['extreme_points'],
                             mask_elem='gt',
                             sigma=10,
                             tr_name='points'),
        tr.FixedResizePoints(
            resolutions={'extreme_points': (cfg['lr_size'], cfg['lr_size'])},
            mask_elem='gt',
            prefix='lr_'),
Example #10
def make_data_loader(args, **kwargs):
    crop_size = args.crop_size
    gt_size = args.gt_size
    if args.dataset == 'pascal' or args.dataset == 'click':
        composed_transforms_tr = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.ScaleNRotate(rots=(-20, 20), scales=(.75, 1.25)),
            tr.CropFromMask(crop_elems=('image', 'gt'),
                            relax=20,
                            zero_pad=True,
                            jitters_bound=(40, 70)),
            tr.FixedResize(
                resolutions={
                    'crop_image': (crop_size, crop_size),
                    'crop_gt': (gt_size, gt_size)
                }),
            tr.Normalize(elems='crop_image'),
            tr.ToTensor()
        ])
        composed_transforms_val = transforms.Compose([
            tr.CropFromMask(crop_elems=('image', 'gt'),
                            relax=20,
                            zero_pad=True,
                            jitters_bound=(50, 51)),
            tr.FixedResize(
                resolutions={
                    'crop_image': (crop_size, crop_size),
                    'crop_gt': (gt_size, gt_size)
                }),
            tr.Normalize(elems='crop_image'),
            tr.ToTensor()
        ])
        train_set = pascal.VOCSegmentation(split='train',
                                           transform=composed_transforms_tr)
        if args.dataset == 'click':
            train_set.reset_target_list(args)
        val_set = pascal.VOCSegmentation(split='val',
                                         transform=composed_transforms_val)
        if args.use_sbd:
            sbd_train = sbd.SBDSegmentation(args, split=['train', 'val'])
            train_set = combine_dbs.CombineDBs([train_set, sbd_train],
                                               excluded=[val_set])

        train_loader = DataLoader(train_set,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  **kwargs)
        val_loader = DataLoader(val_set,
                                batch_size=args.batch_size,
                                shuffle=False,
                                **kwargs)
        test_loader = None
        NUM_CLASSES = 2
        return train_loader, val_loader, test_loader, NUM_CLASSES

    elif args.dataset == 'grabcut':
        composed_transforms_val = transforms.Compose([
            tr.CropFromMask(crop_elems=('image', 'gt'),
                            relax=20,
                            zero_pad=True,
                            jitters_bound=(50, 51)),
            tr.FixedResize(
                resolutions={
                    'crop_image': (crop_size, crop_size),
                    'crop_gt': (gt_size, gt_size)
                }),
            tr.Normalize(elems='crop_image'),
            tr.ToTensor()
        ])
        val_set = grab_berkeley_eval.GrabBerkely(
            which='grabcut', transform=composed_transforms_val)
        val_loader = DataLoader(val_set,
                                batch_size=args.batch_size,
                                shuffle=False,
                                **kwargs)
        test_loader = None
        train_loader = None
        NUM_CLASSES = 2
        return train_loader, val_loader, test_loader, NUM_CLASSES

    elif args.dataset == 'bekeley':
        composed_transforms_val = transforms.Compose([
            tr.CropFromMask(crop_elems=('image', 'gt'),
                            relax=20,
                            zero_pad=True,
                            jitters_bound=(50, 51)),
            tr.FixedResize(
                resolutions={
                    'crop_image': (crop_size, crop_size),
                    'crop_gt': (gt_size, gt_size)
                }),
            tr.Normalize(elems='crop_image'),
            tr.ToTensor()
        ])
        val_set = grab_berkeley_eval.GrabBerkely(
            which='bekeley', transform=composed_transforms_val)
        val_loader = DataLoader(val_set,
                                batch_size=args.batch_size,
                                shuffle=False,
                                **kwargs)
        test_loader = None
        train_loader = None
        NUM_CLASSES = 2
        return train_loader, val_loader, test_loader, NUM_CLASSES

    elif args.dataset == 'cityscapes':
        train_set = cityscapes.CityscapesSegmentation(args, split='train')
        val_set = cityscapes.CityscapesSegmentation(args, split='val')
        test_set = cityscapes.CityscapesSegmentation(args, split='test')
        num_class = train_set.NUM_CLASSES
        train_loader = DataLoader(train_set,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  **kwargs)
        val_loader = DataLoader(val_set,
                                batch_size=args.batch_size,
                                shuffle=False,
                                **kwargs)
        test_loader = DataLoader(test_set,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 **kwargs)

        return train_loader, val_loader, test_loader, num_class

    elif args.dataset == 'coco':
        val_set = coco_eval.COCOSegmentation(split='val', cat=args.coco_part)
        num_class = 2
        train_loader = None
        val_loader = DataLoader(val_set,
                                batch_size=args.batch_size,
                                shuffle=False,
                                **kwargs)
        test_loader = None
        return train_loader, val_loader, test_loader, num_class

    # elif args.dataset == 'click':
    #     train_set = click_dataset.ClickDataset(split='train')
    #     val_set = click_dataset.ClickDataset(split='val')
    #     num_class = 2
    #     train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
    #     val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
    #     test_loader = None
    #     return train_loader, val_loader, test_loader, num_class

    else:
        raise NotImplementedError
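A usage sketch (assumes args carries the attributes referenced above, e.g. dataset, crop_size, gt_size, batch_size; the extra keyword arguments are forwarded to DataLoader):

train_loader, val_loader, test_loader, n_classes = make_data_loader(
    args, num_workers=2, pin_memory=True)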
Example #11
        'params': tosnet.get_10x_lr_params(net),
        'lr': args.lr * 10
    }]
    optimizer = optim.SGD(train_params,
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()
    p['optimizer'] = str(optimizer)

    # Setup data transformations
    composed_transforms = [
        tr.RandomHorizontalFlip(),
        tr.CropFromMask(crop_elems=['image', 'gt', 'thin', 'void_pixels'],
                        relax=args.relax_crop,
                        zero_pad=args.zero_pad_crop,
                        adaptive_relax=args.adaptive_relax,
                        prefix=''),
        tr.Resize(resize_elems=['image', 'gt', 'thin', 'void_pixels'],
                  min_size=args.min_size,
                  max_size=args.max_size),
        tr.ComputeImageGradient(elem='image'),
        tr.ExtremePoints(sigma=10, pert=5, elem='gt'),
        tr.GaussianTransform(tr_elems=['extreme_points'],
                             mask_elem='gt',
                             sigma=10,
                             tr_name='points'),
        tr.RandomCrop(
            num_thin=args.num_thin_samples,
            num_non_thin=args.num_non_thin_samples,
            crop_size=args.roi_size,
Example #12
def process(image_name):

    # Select the GPU by id; the code falls back to CPU automatically when CUDA is unavailable
    gpu_id = 0
    device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        print('Using GPU: {} '.format(gpu_id))

    # Setting parameters
    resume_epoch = 100  # test epoch
    nInputChannels = 5  # Number of input channels (RGB + heatmap of IOG points)

    # Network definition
    modelName = 'IOG_pascal'
    net = Network(nInputChannels=nInputChannels,
                  num_classes=1,
                  backbone='resnet101',
                  output_stride=16,
                  sync_bn=None,
                  freeze_bn=False)

    # load pretrain_dict
    pretrain_dict = torch.load('IOG_PASCAL_SBD.pth')

    net.load_state_dict(pretrain_dict)
    # net.to(device)

    # Generate result of the validation images
    net.eval()

    image = np.array(Image.open(image_name).convert('RGB'))
    # Swap channels for display: OpenCV windows expect BGR order
    im_bgr = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # symmetric swap, RGB -> BGR
    roi = cv2.selectROI(im_bgr)
    image = image.astype(np.float32)

    bbox = np.zeros_like(image[..., 0])
    # bbox[0: 130, 220: 320] = 1 # for ponny
    # bbox[220: 390, 370: 570] = 1
    bbox[int(roi[1]):int(roi[1]+roi[3]), int(roi[0]):int(roi[0]+roi[2])] = 1
    void_pixels = 1 - bbox
    sample = {'image': image, 'gt': bbox, 'void_pixels': void_pixels}

    trns = transforms.Compose([
        tr.CropFromMask(crop_elems=('image', 'gt', 'void_pixels'), relax=30, zero_pad=True),
        tr.FixedResize(resolutions={'gt': None,
                                    'crop_image': (512, 512),
                                    'crop_gt': (512, 512),
                                    'crop_void_pixels': (512, 512)},
                       flagvals={'gt': cv2.INTER_LINEAR,
                                 'crop_image': cv2.INTER_LINEAR,
                                 'crop_gt': cv2.INTER_LINEAR,
                                 'crop_void_pixels': cv2.INTER_LINEAR}),
        tr.IOGPoints(sigma=10, elem='crop_gt', pad_pixel=10),
        tr.ToImage(norm_elem='IOG_points'),
        tr.ConcatInputs(elems=('crop_image', 'IOG_points')),
        tr.ToTensor()])

    tr_sample = trns(sample)

    inputs = tr_sample['concat'][None]  # add a batch dimension
    # inputs = inputs.to(device)
    outputs = net.forward(inputs)[-1]
    # outputs = fine_out.to(torch.device('cpu'))
    pred = np.transpose(outputs.data.numpy()[0, :, :, :], (1, 2, 0))
    pred = 1 / (1 + np.exp(-pred))  # sigmoid
    pred = np.squeeze(pred)
    gt = tens2image(tr_sample['gt'])
    bbox = get_bbox(gt, pad=30, zero_pad=True)
    result = crop2fullmask(pred, bbox, gt, zero_pad=True, relax=0, mask_relax=False)

    light = np.zeros_like(image)
    light[:, :, 2] = 255.

    alpha = 0.5

    blending = (alpha * light + (1 - alpha) * image) * result[..., None] + (1 - result[..., None]) * image

    blending[blending > 255.] = 255

    cv2.imshow('resulting segmentation', cv2.cvtColor(blending.astype(np.uint8), cv2.COLOR_RGB2BGR))
    cv2.waitKey(0)
    cv2.destroyAllWindows()
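A usage sketch: the function is interactive (an OpenCV window opens for the box selection), so it must run locally with a display; the file name below is illustrative only:

process('dog.jpg')  # draw a box in the selectROI window, press SPACE or ENTER to confirm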
Example #13
def extract_hard_example(args, batch_size=8, recal=False):
    transform = transforms.Compose([
        tr.CropFromMask(crop_elems=('image', 'gt'),
                        relax=20,
                        zero_pad=True,
                        jitters_bound=(40, 70)),
        tr.FixedResize(resolutions={
            'crop_image': (513, 513),
            'crop_gt': (513, 513)
        }),
        tr.Normalize(elems='crop_image'),
        tr.ToTensor(),
    ])
    dataset = VOCSegmentation(root=Path.db_root_dir('pascal'),
                              split='train',
                              transform=transform,
                              download=False,
                              preprocess=False,
                              area_thres=500,
                              retname=True,
                              suppress_void_pixels=True,
                              which_part=args.which,
                              default=False)
    click_list_file = os.path.join(
        dataset.root, dataset.BASE_DIR, 'ImageSets', 'Segmentation', '-'.join([
            args.sbox, 'mIoU_thres',
            str(args.low_thres),
            str(args.high_thres), 'on', args.which, 'sets'
        ]) + '.txt')
    if os.path.isfile(click_list_file) and not recal:
        return
    device = torch.device("cuda:" +
                          str(0) if torch.cuda.is_available() else "cpu")
    sbox_net = DeepLabX(pretrain=False)
    path = 'run/' + args.sbox
    sbox_net.load_state_dict(
        torch.load(path, map_location=device)['state_dict'])
    sbox_net = sbox_net.to(device)
    sbox_net.eval()
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    tbar = tqdm(dataloader, desc='\r')
    hard_examples = []
    n_hard = 0
    for i, sample in enumerate(tbar):
        image, gt = sample['crop_image'], sample['crop_gt']
        image, gt = image.to(device), gt.to(device)
        with torch.no_grad():
            pred, _, _ = sbox_net(image)

        pred = F.interpolate(pred,
                             size=gt.size()[-2:],
                             mode='bilinear',
                             align_corners=True)
        pred = pred.data.cpu().numpy()
        target = gt.cpu().numpy()
        pred = np.argmax(pred, axis=1)
        for j in range(pred.shape[0]):
            matrix = _generate_matrix(target[j], pred[j])
            iou = Mean_Intersection_over_Union(matrix)
            if args.low_thres <= iou <= args.high_thres:
                print('object {} (IoU {:.4f}) needs refinement; '
                      'added as hard example #{}'.format(
                          i * batch_size + j, iou, n_hard))
                n_hard += 1
                hard_examples.append(dataset.target_list[i * batch_size + j])

    with open(click_list_file, 'w') as outfile:
        outfile.write(json.dumps(hard_examples))

    print('{} objects in total need to be refined by the click net'.format(
        len(hard_examples)))
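A usage sketch (assumes args provides which, sbox, low_thres and high_thres as referenced above):

extract_hard_example(args, batch_size=8)              # writes the hard-example list once
extract_hard_example(args, batch_size=8, recal=True)  # forces recomputation if the list exists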