def __call__(self, sample):
    msk = sample['gt']
    bbox = helpers.get_bbox(msk)
    # Build a coarse box mask: 0 inside the bounding box, 255 outside
    out = np.ones(msk.shape) * 255
    out[bbox[1]: bbox[3], bbox[0]: bbox[2]] = 0
    sample['bb_mask'] = out.astype(np.float32)
    return sample
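For context, a minimal self-contained sketch of what this transform produces; the inline get_bbox stand-in is an illustrative guess, assuming helpers.get_bbox returns the (x_min, y_min, x_max, y_max) box of the mask's non-zero region, which matches how the slice above indexes it:

import numpy as np

def get_bbox(mask):  # hypothetical stand-in for helpers.get_bbox
    ys, xs = np.where(mask > 0)
    return xs.min(), ys.min(), xs.max() + 1, ys.max() + 1

gt = np.zeros((8, 8), dtype=np.uint8)
gt[2:5, 3:6] = 1
bbox = get_bbox(gt)                           # (3, 2, 6, 5)
bb_mask = np.ones(gt.shape) * 255
bb_mask[bbox[1]:bbox[3], bbox[0]:bbox[2]] = 0
print(bb_mask)                                # 0 inside the box, 255 elsewhere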
Example #2
def gen_init_mask(video_path: str, extreme_points: str) -> str:
    """ Generate a mask for a single image from 4 extreme points.

    Args:
        video_path(str): path to the original video.
        extreme_points(str): coordinates of the 4 extreme points, separated by '|'.
    Returns:
        overlay_save_path(str): path to the generated overlay mask.
    """

    pad = 50
    thres = 0.8

    mask_arr = []

    image = get_first_frame(video_path)
    extreme_points = get_expt_coordinate(extreme_points)
    bbox = helpers.get_bbox(image,
                            points=extreme_points,
                            pad=pad,
                            zero_pad=True)
    inputs = get_inputs(image, bbox, extreme_points, pad)
    outputs = gen_seg(inputs)
    mask = gen_mask(outputs, bbox, image.shape[:2], pad, thres)

    mask_arr.append(mask)

    overlay_mask = helpers.overlay_masks(image / 255, mask_arr) * 255

    mask = Image.fromarray(mask)
    mask.save(MASK_SAVE_PATH)
    overlay_mask = Image.fromarray(overlay_mask.astype("uint8"))
    overlay_mask.save(OVERLAY_SAVE_PATH)

    return OVERLAY_SAVE_PATH
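The snippet leaves get_expt_coordinate undefined; here is a hedged sketch of what it might look like, given only the docstring's '|'-separated format (the ',' inner separator is an assumption):

import numpy as np

def get_expt_coordinate(extreme_points: str) -> np.ndarray:
    # "x1,y1|x2,y2|x3,y3|x4,y4" -> (4, 2) integer array
    pairs = [p.split(',') for p in extreme_points.split('|')]
    return np.array(pairs, dtype=float).astype(int)

print(get_expt_coordinate("10,40|55,12|90,48|52,85"))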
Example #3
def __call__(self, sample):
     self.dz = random.randint(350, 400)
     _target = sample[self.mask_elem]
     if len(np.unique(_target)) == 1:
         sample['crop_image'] = sample['image']
         sample['crop_gt'] = sample['gt']
         return sample
     if _target.ndim == 2:
         _target = np.expand_dims(_target, axis=-1)
     for elem in self.crop_elems:
         _img = sample[elem]
         _crop = []
         ### dynamic relax crop ###
         bbox = helpers.get_bbox(_target)
         d = np.maximum(bbox[2] - bbox[0], bbox[3] - bbox[1])
         sample['temp'] = d
         zoom_factor = self.dz / d
         crop_relax = (512 - d * zoom_factor) / (2 * zoom_factor)
         self.crop_relax = np.ceil(crop_relax).astype(int)
         self.crop_relax = np.maximum(15, self.crop_relax)
         sample['crop_relax'] = self.crop_relax
         ###                    ###
         if self.mask_elem == elem:
             if _img.ndim == 2:
                 _img = np.expand_dims(_img, axis=-1)
             for k in range(0, _target.shape[-1]):
                 _tmp_img = _img[..., k]
                 _tmp_target = _target[..., k]
                 if np.max(_target[..., k]) == 0:
                     _crop.append(np.zeros(_tmp_img.shape,
                                           dtype=_img.dtype))
                 else:
                     _crop.append(
                         helpers.crop_from_mask(_tmp_img,
                                                _tmp_target,
                                                relax=self.crop_relax,
                                                zero_pad=self.zero_pad))
         else:
             for k in range(0, _target.shape[-1]):
                 if np.max(_target[..., k]) == 0:
                     _crop.append(np.zeros(_img.shape, dtype=_img.dtype))
                 else:
                     _tmp_target = _target[..., k]
                     _crop.append(
                         helpers.crop_from_mask(_img,
                                                _tmp_target,
                                                relax=self.crop_relax,
                                                zero_pad=self.zero_pad))
         if len(_crop) == 1:
             sample['crop_' + elem] = _crop[0]
         else:
             sample['crop_' + elem] = _crop
     return sample
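A worked instance of the dynamic-relax arithmetic above, with a made-up object size d (this example hard-codes the 512-pixel crop target and clamps the result to at least 15):

# d and dz are illustrative values, not from the snippet
d = 200                          # longest side of the object's bbox
dz = 375                         # sampled target size inside the 512 crop
zoom_factor = dz / d             # 1.875
crop_relax = (512 - d * zoom_factor) / (2 * zoom_factor)   # ~36.5 px
# -> cropping with relax = ceil(36.5) = 37 on each side and zooming by
#    1.875 makes the relaxed crop fill roughly 512 pixels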
Example #4
def get_mask(image, extreme_points_ori, pad=50, thres=0.8):
    modelName = 'dextr_pascal-sbd'
    gpu_id = 0
    device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() else "cpu")

    #  Create the network and load the weights
    net = resnet.resnet101(1, nInputChannels=4, classifier='psp')
    print("Initializing weights from: {}".format(os.path.join(Path.models_dir(), modelName + '.pth')))
    state_dict_checkpoint = torch.load(os.path.join(Path.models_dir(), modelName + '.pth'),
                                       map_location=lambda storage, loc: storage)
    # Remove the prefix .module from the model when it is trained using DataParallel
    if 'module.' in list(state_dict_checkpoint.keys())[0]:
        new_state_dict = OrderedDict()
        for k, v in state_dict_checkpoint.items():
            name = k[7:]  # remove `module.` from multi-gpu training
            new_state_dict[name] = v
    else:
        new_state_dict = state_dict_checkpoint
    net.load_state_dict(new_state_dict)
    net.eval()
    net.to(device)
    
    with torch.no_grad():
        results = []
        #  Crop image to the bounding box from the extreme points and resize
        bbox = helpers.get_bbox(image, points=extreme_points_ori, pad=pad, zero_pad=True)
        crop_image = helpers.crop_from_bbox(image, bbox, zero_pad=True)
        resize_image = helpers.fixed_resize(crop_image, (512, 512)).astype(np.float32)

        #  Generate extreme point heat map normalized to image values
        extreme_points = extreme_points_ori - [np.min(extreme_points_ori[:, 0]),
                                               np.min(extreme_points_ori[:, 1])] + [pad, pad]
        extreme_points = (512 * extreme_points *
                          [1 / crop_image.shape[1], 1 / crop_image.shape[0]]).astype(int)
        extreme_heatmap = helpers.make_gt(resize_image, extreme_points, sigma=10)
        extreme_heatmap = helpers.cstm_normalize(extreme_heatmap, 255)

        #  Concatenate inputs and convert to tensor
        input_dextr = np.concatenate((resize_image, extreme_heatmap[:, :, np.newaxis]), axis=2)
        inputs = torch.from_numpy(input_dextr.transpose((2, 0, 1))[np.newaxis, ...])

        # Run a forward pass
        inputs = inputs.to(device)
        outputs = net.forward(inputs)
        outputs = interpolate(outputs, size=(512, 512), mode='bilinear', align_corners=True)
        outputs = outputs.to(torch.device('cpu'))

        pred = np.transpose(outputs.data.numpy()[0, ...], (1, 2, 0))
        pred = 1 / (1 + np.exp(-pred))
        pred = np.squeeze(pred)
        result = helpers.crop2fullmask(pred, bbox, im_size=image.shape[:2], zero_pad=True, relax=pad) > thres
        results.append(result)

        return results, bbox
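A hypothetical call to get_mask; the image path and the four (x, y) extreme points below are placeholders for illustration:

import numpy as np
from PIL import Image

image = np.array(Image.open('ims/dog-cat.jpg'))        # placeholder path
extreme_points = np.array([[120, 200], [300, 90],
                           [480, 210], [310, 330]])    # made-up clicks
results, bbox = get_mask(image, extreme_points, pad=50, thres=0.8)
mask = results[0]                                      # boolean, image.shape[:2]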
Example #5
def __call__(self, sample):
     if not self.is_val:
         self.dz = np.array(np.random.randint(self.min_, self.max_)).astype(np.float32)
     _target = sample[self.mask_elem]
     if len(np.unique(_target)) == 1:
         sample['crop_image'] = sample['image']
         sample['crop_gt'] = sample['gt']
         return sample 
     if _target.ndim == 2:
         _target = np.expand_dims(_target, axis=-1)
     for elem in self.crop_elems:
         _img = sample[elem]
         _crop = []
         ### dynamic relax crop ###
         bbox = helpers.get_bbox(_target)
         d = np.maximum(bbox[2] - bbox[0], bbox[3] - bbox[1])
         if d < 1:
             print("Very small objects detected")
             print(sample['id'])
         zoom_factor = self.dz / d
         crop_relax = (self.d - d * zoom_factor) / (2 * zoom_factor)
         crop_relax = np.maximum(crop_relax, self.thresh)
         self.crop_relax = np.ceil(crop_relax).astype(int)
         sample['crop_relax'] = self.crop_relax
         ###                    ###
         if self.mask_elem == elem:
             if _img.ndim == 2:
                 _img = np.expand_dims(_img, axis=-1)
             for k in range(0, _target.shape[-1]):
                 _tmp_img = _img[..., k]
                 _tmp_target = _target[..., k]
                 if np.max(_target[..., k]) == 0:
                     _crop.append(np.zeros(_tmp_img.shape, dtype=_img.dtype))
                 else:
                     _crop.append(helpers.crop_from_mask(_tmp_img, _tmp_target, relax=self.crop_relax, zero_pad=self.zero_pad))
         else:
             for k in range(0, _target.shape[-1]):
                 if np.max(_target[..., k]) == 0:
                     _crop.append(np.zeros(_img.shape, dtype=_img.dtype))
                 else:
                     _tmp_target = _target[..., k]
                     _crop.append(helpers.crop_from_mask(_img, _tmp_target, relax=self.crop_relax, zero_pad=self.zero_pad))
         if len(_crop) == 1:
             sample['crop_' + elem] = _crop[0]
         else:
             sample['crop_' + elem] = _crop
     return sample
Example #6
with torch.no_grad():
    while True:
        extreme_points_ori = np.array(plt.ginput(4, timeout=0)).astype(int)
        begin = time()
        if extreme_points_ori.shape[0] < 4:
            if len(results) > 0:
                helpers.save_mask(results, 'demo.png')
                print('Saving mask annotation in demo.png and exiting...')
            else:
                print('Exiting...')
            sys.exit()

        #  Crop image to the bounding box from the extreme points and resize
        bbox = helpers.get_bbox(image,
                                points=extreme_points_ori,
                                pad=pad,
                                zero_pad=True)
        crop_image = helpers.crop_from_bbox(image, bbox, zero_pad=True)
        resize_image = helpers.fixed_resize(crop_image,
                                            (512, 512)).astype(np.float32)

        #  Generate extreme point heat map normalized to image values
        extreme_points = extreme_points_ori - [
            np.min(extreme_points_ori[:, 0]),
            np.min(extreme_points_ori[:, 1])
        ] + [pad, pad]
        extreme_points = (
            512 * extreme_points *
            [1 / crop_image.shape[1], 1 / crop_image.shape[0]]).astype(int)
        extreme_heatmap = helpers.make_gt(resize_image,
                                          extreme_points,
                                          sigma=10)
Example #7
    def _segment(self, image, extreme_points_ori):
        with torch.no_grad():
            # Crop the image
            h, w = image.shape[:2]
            if self.cfg['adaptive_relax']:
                mean_shape = np.mean(image.shape[:2])
                relax = int(self.cfg['relax_crop'] * mean_shape / 428.)
            else:
                relax = self.cfg['relax_crop']
            bbox = helpers.get_bbox(image,
                                    points=extreme_points_ori,
                                    pad=relax,
                                    zero_pad=self.cfg['zero_pad_crop'])
            crop_image = helpers.crop_from_bbox(
                image, bbox, zero_pad=self.cfg['zero_pad_crop'])

            # Compute the offsets of extreme points
            bounds = (0, 0, w - 1, h - 1)
            bbox_valid = (max(bbox[0], bounds[0]), max(bbox[1], bounds[1]),
                          min(bbox[2], bounds[2]), min(bbox[3], bounds[3]))
            if self.cfg['zero_pad_crop']:
                offsets = (-bbox[0], -bbox[1])
            else:
                offsets = (-bbox_valid[0], -bbox_valid[1])
            crop_extreme_points = extreme_points_ori + offsets

            # Resize
            if (np.minimum(h, w) < self.cfg['min_size']) or (np.maximum(
                    h, w) > self.cfg['max_size']):
                sc1 = self.cfg['min_size'] / np.minimum(h, w)
                sc2 = self.cfg['max_size'] / np.maximum(h, w)
                if sc1 > 1:
                    sc = sc1
                else:
                    sc = np.maximum(sc1, sc2)
                resize_image = cv2.resize(crop_image, (0, 0),
                                          fx=sc,
                                          fy=sc,
                                          interpolation=cv2.INTER_LINEAR)
                points = crop_extreme_points * sc
            else:
                resize_image = crop_image
                points = crop_extreme_points
            h2, w2 = resize_image.shape[:2]

            # Compute image gradient using Sobel filter
            img_r = resize_image[:, :, 0]
            img_g = resize_image[:, :, 1]
            img_b = resize_image[:, :, 2]
            grad_r = helpers.imgradient(img_r)[0]
            grad_g = helpers.imgradient(img_g)[0]
            grad_b = helpers.imgradient(img_b)[0]
            image_grad = np.sqrt(grad_r**2 + grad_g**2 + grad_b**2)
            # Normalize to [0,1]
            image_grad = (image_grad - image_grad.min()) / (image_grad.max() -
                                                            image_grad.min())

            # Convert extreme points to Gaussian heatmaps
            heatmap = helpers.gaussian_transform(resize_image,
                                                 points,
                                                 sigma=10)

            # Resize to a fixed resolution for global context extraction
            resolution = (self.cfg['lr_size'], self.cfg['lr_size'])
            lr_points = np.array([resolution[0] / w2, resolution[1] / h2
                                  ]) * points
            lr_image = helpers.fixed_resize(resize_image, resolution)

            # Convert the extreme points to Gaussian heatmaps
            lr_heatmap = helpers.gaussian_transform(lr_image,
                                                    lr_points,
                                                    sigma=10)

            # Normalize inputs
            heatmap = 255 * (heatmap - heatmap.min()) / (heatmap.max() -
                                                         heatmap.min() + 1e-10)
            lr_heatmap = 255 * (lr_heatmap - lr_heatmap.min()) / (
                lr_heatmap.max() - lr_heatmap.min() + 1e-10)
            image_grad = 255 * (image_grad - image_grad.min()) / (
                image_grad.max() - image_grad.min() + 1e-10)

            # Concatenate the inputs and transpose from (1, H, W, C) to (1, C, H, W)
            concat_lr = np.concatenate([lr_image, lr_heatmap[:, :, None]],
                                       axis=-1)[None, :].transpose(0, 3, 1, 2)
            concat = np.concatenate([resize_image, heatmap[:, :, None]],
                                    axis=-1)[None, :].transpose(0, 3, 1, 2)
            grad = np.concatenate([resize_image, image_grad[:, :, None]],
                                  axis=-1)[None, :].transpose(0, 3, 1, 2)

            # Convert to PyTorch tensors
            concat_lr = torch.from_numpy(concat_lr).float().to(self.device)
            concat = torch.from_numpy(concat).float().to(self.device)
            grad = torch.from_numpy(grad).float().to(self.device)

            # Forward pass
            outs = self.net.forward(concat, grad, concat_lr, roi=None)[1]
            output = torch.sigmoid(outs).cpu().numpy().squeeze()

            # Project back to original image space
            result = helpers.crop2fullmask(output,
                                           bbox,
                                           im_size=image.shape[:2],
                                           zero_pad=True,
                                           relax=relax)

        return result
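For reference, a minimal cfg dict consistent with the keys _segment reads; the values are illustrative guesses, not the authors' defaults:

cfg = {
    'adaptive_relax': True,   # scale relax_crop by mean image side / 428
    'relax_crop': 50,         # base padding around the extreme-point bbox
    'zero_pad_crop': True,    # zero-pad crops that extend past the image
    'min_size': 512,          # resize if min(h, w) falls below this
    'max_size': 1024,         # resize if max(h, w) exceeds this
    'lr_size': 512,           # fixed resolution for global context
}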
Example #8
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-i',
                        '--image',
                        type=str,
                        default='ims/dog-cat.jpg',
                        help='path to image')
    parser.add_argument('--model-name', type=str, default='dextr_pascal-sbd')
    parser.add_argument('-o',
                        '--output',
                        type=str,
                        default='results',
                        help='path where results will be saved')
    parser.add_argument('--pad', type=int, default=50, help='padding size')
    parser.add_argument('--thres', type=float, default=.9)
    parser.add_argument('--gpu-id', type=int, default=0)
    parser.add_argument('--anchors',
                        type=int,
                        default=5,
                        help='number of points to set')
    parser.add_argument(
        '--anchor-points',
        type=str,
        default=None,
        help='path to folder of anchor points (tracking points)')
    parser.add_argument(
        '--use-frame-info',
        # argparse's type=bool treats any non-empty string as True, so
        # parse the flag value explicitly
        type=lambda s: s.lower() not in ('0', 'false', 'no'),
        default=True,
        help='whether to use the frame number from the CSV file')
    parser.add_argument('--corrections',
                        action='store_true',
                        help='show a popup asking whether the mask needs correction')
    parser.add_argument(
        '--cut',
        action='store_true',
        help='if set, save the cut-out image instead of the mask as PNG')

    opt = parser.parse_args()
    modelName = opt.model_name
    pad = opt.pad
    thres = opt.thres
    gpu_id = opt.gpu_id
    device = torch.device("cuda:" +
                          str(gpu_id) if torch.cuda.is_available() else "cpu")

    #  Create the network and load the weights
    net = resnet.resnet101(1, nInputChannels=4, classifier='psp')
    print("Initializing weights from: {}".format(
        os.path.join(Path.models_dir(), modelName + '.pth')))
    state_dict_checkpoint = torch.load(
        os.path.join(Path.models_dir(), modelName + '.pth'),
        map_location=lambda storage, loc: storage)
    # Remove the prefix .module from the model when it is trained using DataParallel
    if 'module.' in list(state_dict_checkpoint.keys())[0]:
        new_state_dict = OrderedDict()
        for k, v in state_dict_checkpoint.items():
            name = k[7:]  # remove `module.` from multi-gpu training
            new_state_dict[name] = v
    else:
        new_state_dict = state_dict_checkpoint
    net.load_state_dict(new_state_dict)
    net.eval()
    net.to(device)

    #  Read image and click the points
    if os.path.isfile(opt.image):
        images = [opt.image]
    else:
        images = sorted(glob.glob(opt.image + '/*.*'))
    if opt.anchor_points:
        tracks = sorted(glob.glob(opt.anchor_points + '/*.csv'))
        frames, X, Y = [], [], []
        for i in range(len(tracks)):
            f, x, y = np.loadtxt(tracks[i], delimiter=',', unpack=True)
            frames.append(f.tolist())
            X.append(x.tolist())
            Y.append(y.tolist())
        anchorPoints = []
        uframes = np.unique(np.hstack([np.array(a) for a in frames])).tolist()
        # print(uframes)
        for i in range(len(uframes)):
            extreme_points = []
            for j in range(len(frames)):
                try:
                    ind = frames[j].index(uframes[i])
                    extreme_points.append([X[j][ind], Y[j][ind]])
                except ValueError:
                    continue
            anchorPoints.append(np.array(extreme_points))

    for i, img in enumerate(images):

        if opt.use_frame_info and opt.anchor_points is not None:
            file_number = int(re.sub(r'\D', '', img))
            if file_number not in uframes:
                print(img, 'skipped')
                continue

        if opt.anchor_points is None:
            plt.figure()
        while True:
            image = np.array(Image.open(img))
            mask_path = os.path.join(opt.output, os.path.split(img)[1])
            if opt.anchor_points is None:
                plt.ion()
                plt.axis('off')
                plt.imshow(image)
                plt.title(
                    'Click the four extreme points of the objects\nHit enter/middle mouse button when done (do not close the window)'
                )

            results = []

            with torch.no_grad():
                # while 1:
                if opt.anchor_points:
                    if opt.use_frame_info:
                        try:
                            index = uframes.index(file_number)
                        except ValueError:
                            print(
                                'Could not find data for frame %i. Use frame %i instead.'
                                % (file_number, i))
                            index = i
                    else:
                        index = i
                    extreme_points_ori = anchorPoints[index].astype(int)
                else:
                    extreme_points_ori = np.array(
                        plt.ginput(opt.anchors, timeout=0)).astype(int)

                # print(extreme_points_ori,extreme_points_ori.shape)
                #  Crop image to the bounding box from the extreme points and resize
                bbox = helpers.get_bbox(image,
                                        points=extreme_points_ori,
                                        pad=pad,
                                        zero_pad=False)
                crop_image = helpers.crop_from_bbox(image, bbox, zero_pad=True)
                resize_image = helpers.fixed_resize(
                    crop_image, (512, 512)).astype(np.float32)

                #  Generate extreme point heat map normalized to image values
                extreme_points = extreme_points_ori - [
                    np.min(extreme_points_ori[:, 0]),
                    np.min(extreme_points_ori[:, 1])
                ] + [pad, pad]
                extreme_points = (
                    512 * extreme_points *
                    [1 / crop_image.shape[1], 1 / crop_image.shape[0]]).astype(int)
                extreme_heatmap = helpers.make_gt(resize_image,
                                                  extreme_points,
                                                  sigma=10)
                extreme_heatmap = helpers.cstm_normalize(extreme_heatmap, 255)

                #  Concatenate inputs and convert to tensor
                input_dextr = np.concatenate(
                    (resize_image, extreme_heatmap[:, :, np.newaxis]), axis=2)
                inputs = torch.from_numpy(
                    input_dextr.transpose((2, 0, 1))[np.newaxis, ...])

                # Run a forward pass
                inputs = inputs.to(device)
                outputs = net.forward(inputs)
                outputs = interpolate(outputs,
                                      size=(512, 512),
                                      mode='bilinear',
                                      align_corners=True)
                outputs = outputs.to(torch.device('cpu'))

                pred = np.transpose(outputs.data.numpy()[0, ...], (1, 2, 0))
                pred = 1 / (1 + np.exp(-pred))
                pred = np.squeeze(pred)
                result = helpers.crop2fullmask(pred,
                                               bbox,
                                               im_size=image.shape[:2],
                                               zero_pad=True,
                                               relax=pad) > thres

                results.append(result)
                # Plot the results
                plt.imshow(helpers.overlay_masks(image / 255, results))
                plt.plot(extreme_points_ori[:, 0], extreme_points_ori[:, 1],
                         'gx')

                if not opt.cut:
                    helpers.save_mask(results, mask_path)
                else:
                    Image.fromarray(
                        np.concatenate(
                            (image, 255 * result[..., None].astype(int)),
                            2).astype(np.uint8)).save(mask_path, 'png')
                '''if len(extreme_points_ori) < 4:
                        if len(results) > 0:
                            helpers.save_mask(results, 'demo.png')
                            print('Saving mask annotation in demo.png and exiting...')
                        else:
                            print('Exiting...')
                        sys.exit()'''
            if opt.anchor_points is None:
                plt.close()
            if opt.corrections:
                if easygui.ynbox(image=mask_path):
                    break
            else:
                break
        print(img, 'done')
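A hypothetical non-interactive invocation of main() using pre-recorded anchor points; the script name and directory paths are placeholders, and only the flags come from the parser above:

import sys

sys.argv = ['annotate.py', '-i', 'frames', '-o', 'results',
            '--anchor-points', 'tracks', '--thres', '0.9']
main()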
Example #9
    with torch.no_grad():
        for ii, sample in enumerate(tqdm(testloader)):

            # Read (image, gt) pairs
            inputs = sample['concat'].to(device)
            inputs_lr = sample['concat_lr'].to(device)
            grads = sample['grad'].to(device)
            metas = sample['meta']

            # Forward pass
            outs = net.forward(inputs, grads, inputs_lr, roi=None)[1]
            assert outs.size()[2:] == inputs.size()[2:]
            output = torch.sigmoid(outs).cpu().numpy().squeeze()

            # Project back to original image space
            relax = sample['meta']['relax'][0].item()
            gt = sample['ori_gt'].numpy().squeeze()
            bbox = helpers.get_bbox(gt, pad=relax, zero_pad=True)
            result = helpers.crop2fullmask(output,
                                           bbox,
                                           gt,
                                           zero_pad=True,
                                           relax=relax)
            result = np.uint8(result * 255)

            # Save results
            imageio.imwrite(os.path.join(save_dir, metas['image'][0] + \
                                '-' + metas['object'][0] + '.png'), result)

    print('Done testing for dataset: {}'.format(args.test_set))
Example #10
def eval_one_result(loader, folder, one_mask_per_image=False, mask_thres=0.5, use_void_pixels=True, custom_box=False):
    def mAPr(per_cat, thresholds):
        n_cat = len(per_cat)
        all_apr = np.zeros(len(thresholds))
        for ii, th in enumerate(thresholds):
            per_cat_recall = np.zeros(n_cat)
            for jj, categ in enumerate(per_cat.keys()):
                per_cat_recall[jj] = np.sum(np.array(per_cat[categ]) > th)/len(per_cat[categ])

            all_apr[ii] = per_cat_recall.mean()

        return all_apr.mean()

    # Allocate
    eval_result = dict()
    eval_result["all_jaccards"] = np.zeros(len(loader))
    eval_result["all_percent"] = np.zeros(len(loader))
    eval_result["meta"] = []
    eval_result["per_categ_jaccard"] = dict()

    # Iterate
    for i, sample in enumerate(loader):

        if i % 500 == 0:
            print('Evaluating: {} of {} objects'.format(i, len(loader)))

        # Load result
        if not one_mask_per_image:
            filename = os.path.join(folder,
                                    sample["meta"]["image"][0] + '-' + sample["meta"]["object"][0] + '.png')
        else:
            filename = os.path.join(folder,
                                    sample["meta"]["image"][0] + '.png')
        mask = np.array(Image.open(filename)).astype(np.float32) / 255.
        gt = np.squeeze(helpers.tens2image(sample["gt"]))
        if use_void_pixels:
            void_pixels = np.squeeze(helpers.tens2image(sample["void_pixels"]))
        if mask.shape != gt.shape:
            mask = cv2.resize(mask, gt.shape[::-1], interpolation=cv2.INTER_CUBIC)

        # Threshold
        mask = (mask > mask_thres)
        if use_void_pixels:
            void_pixels = (void_pixels > 0.5)

        # Evaluate
        if use_void_pixels:
            eval_result["all_jaccards"][i] = evaluation.jaccard(gt, mask, void_pixels)
        else:
            eval_result["all_jaccards"][i] = evaluation.jaccard(gt, mask)

        if custom_box:
            box = np.squeeze(helpers.tens2image(sample["box"]))
            bb = helpers.get_bbox(box)
        else:
            bb = helpers.get_bbox(gt)

        mask_crop = helpers.crop_from_bbox(mask, bb)
        if use_void_pixels:
            non_void_pixels_crop = helpers.crop_from_bbox(np.logical_not(void_pixels), bb)
        gt_crop = helpers.crop_from_bbox(gt, bb)
        if use_void_pixels:
            eval_result["all_percent"][i] = np.sum((gt_crop != mask_crop) & non_void_pixels_crop)/np.sum(non_void_pixels_crop)
        else:
            eval_result["all_percent"][i] = np.sum((gt_crop != mask_crop))/mask_crop.size
        # Store in per category
        if "category" in sample["meta"]:
            cat = sample["meta"]["category"][0]
        else:
            cat = 1
        if cat not in eval_result["per_categ_jaccard"]:
            eval_result["per_categ_jaccard"][cat] = []
        eval_result["per_categ_jaccard"][cat].append(eval_result["all_jaccards"][i])

        # Store meta
        eval_result["meta"].append(sample["meta"])

    # Compute some stats
    eval_result["mAPr0.5"] = mAPr(eval_result["per_categ_jaccard"], [0.5])
    eval_result["mAPr0.7"] = mAPr(eval_result["per_categ_jaccard"], [0.7])
    eval_result["mAPr-vol"] = mAPr(eval_result["per_categ_jaccard"], np.linspace(0.1, 0.9, 9))

    return eval_result
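evaluation.jaccard is not shown in the snippet; here is a minimal sketch of the Jaccard (IoU) metric it presumably computes, including the void-pixel variant used above (implementation assumed, not copied from the library):

import numpy as np

def jaccard(gt, mask, void=None):
    gt, mask = gt.astype(bool), mask.astype(bool)
    if void is not None:
        valid = ~void.astype(bool)          # ignore void pixels entirely
        gt, mask = gt & valid, mask & valid
    inter = np.sum(gt & mask)
    union = np.sum(gt | mask)
    return 1.0 if union == 0 else inter / union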
Example #11
def dextr_helper(img_url=IMG_URL, extreme_pts=EXTREME_PTS):
    """
    @params
    img_url - string containing the URL of the image
    extreme_pts - list of (x, y) extreme coordinate tuples

    @returns tuple - (bbox, mask, pred)
    bbox (x_min, y_min, x_max, y_max) is the bounding box generated from the extreme points
    mask is a boolean numpy array indicating the presence of the instance
    pred is the classification result
    """

    response = requests.get(img_url)
    image = np.array(Image.open(io.BytesIO(response.content)))

    with torch.no_grad():
        extreme_points_ori = np.array(extreme_pts).astype(int)

        #  Crop image to the bounding box from the extreme points and resize
        bbox = helpers.get_bbox(image,
                                points=extreme_points_ori,
                                pad=pad,
                                zero_pad=True)
        crop_image = helpers.crop_from_bbox(image, bbox, zero_pad=True)
        resize_image = helpers.fixed_resize(crop_image,
                                            (512, 512)).astype(np.float32)

        class_prediction = get_prediction_numpy(crop_image)
        # print("Class Prediction is : {}".format(class_prediction))

        # this is the bounding box to return (with 0 padding)
        actual_bbox = helpers.get_bbox(image,
                                       points=extreme_points_ori,
                                       pad=0,
                                       zero_pad=True)

        #  Generate extreme point heat map normalized to image values
        extreme_points = extreme_points_ori - [
            np.min(extreme_points_ori[:, 0]),
            np.min(extreme_points_ori[:, 1])
        ] + [pad, pad]
        extreme_points = (
            512 * extreme_points *
            [1 / crop_image.shape[1], 1 / crop_image.shape[0]]).astype(int)
        extreme_heatmap = helpers.make_gt(resize_image,
                                          extreme_points,
                                          sigma=10)
        extreme_heatmap = helpers.cstm_normalize(extreme_heatmap, 255)

        #  Concatenate inputs and convert to tensor
        input_dextr = np.concatenate(
            (resize_image, extreme_heatmap[:, :, np.newaxis]), axis=2)
        inputs = torch.from_numpy(
            input_dextr.transpose((2, 0, 1))[np.newaxis, ...])

        # Run a forward pass
        inputs = inputs.to(device)
        outputs = net.forward(inputs)
        # upsample is deprecated in torch.nn.functional; use interpolate
        outputs = interpolate(outputs,
                              size=(512, 512),
                              mode='bilinear',
                              align_corners=True)
        outputs = outputs.to(torch.device('cpu'))

        pred = np.transpose(outputs.data.numpy()[0, ...], (1, 2, 0))
        pred = 1 / (1 + np.exp(-pred))
        pred = np.squeeze(pred)

        # Here result is of the shape of image, where True implies that part should be in the segment
        result = helpers.crop2fullmask(
            pred, bbox, im_size=image.shape[:2], zero_pad=True,
            relax=pad) > thres

        return (actual_bbox, result, class_prediction)
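A hypothetical call; IMG_URL and EXTREME_PTS come from module-level globals not shown here, so both arguments are spelled out with placeholder values:

bbox, mask, pred = dextr_helper(
    img_url='https://example.com/dog.jpg',                     # placeholder URL
    extreme_pts=[(120, 200), (300, 90), (480, 210), (310, 330)])
print(bbox, mask.shape, pred)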
Example #12
def demo(net, image_path='ims/soccer.jpg'):
    pad = 50
    thres = 0.8
    #  Read image and click the points
    image = np.array(Image.open(image_path))
    plt.ion()
    plt.axis('off')
    plt.imshow(image)
    plt.title(
        'Click the four extreme points of the objects\nHit enter when done (do not close the window)'
    )
    results = []
    while True:
        extreme_points_ori = np.array(plt.ginput(4, timeout=0)).astype(int)
        begin = time()
        if extreme_points_ori.shape[0] < 4:
            if len(results) > 0:
                helpers.save_mask(results, 'demo.png')
                print('Saving mask annotation in demo.png and exiting...')
            else:
                print('Exiting...')
            sys.exit()

        #  Crop image to the bounding box from the extreme points and resize
        bbox = helpers.get_bbox(image,
                                points=extreme_points_ori,
                                pad=pad,
                                zero_pad=True)
        crop_image = helpers.crop_from_bbox(image, bbox, zero_pad=True)
        resize_image = helpers.fixed_resize(crop_image,
                                            (512, 512)).astype(np.float32)

        #  Generate extreme point heat map normalized to image values
        extreme_points = extreme_points_ori - [
            np.min(extreme_points_ori[:, 0]),
            np.min(extreme_points_ori[:, 1])
        ] + [pad, pad]
        extreme_points = (
            512 * extreme_points *
            [1 / crop_image.shape[1], 1 / crop_image.shape[0]]).astype(int)
        extreme_heatmap = helpers.make_gt(resize_image,
                                          extreme_points,
                                          sigma=10)
        extreme_heatmap = helpers.cstm_normalize(extreme_heatmap, 255)

        #  Concatenate inputs and convert to tensor
        input_dextr = np.concatenate(
            (resize_image, extreme_heatmap[:, :, np.newaxis]), axis=2)
        inputs = torch.from_numpy(
            input_dextr.transpose((2, 0, 1))[np.newaxis, ...])

        # Run a forward pass
        outputs = net.forward(inputs)
        # upsample is deprecated in torch.nn.functional; use interpolate
        outputs = interpolate(outputs,
                              size=(512, 512),
                              mode='bilinear',
                              align_corners=True)
        outputs = torch.sigmoid(outputs)
        outputs = outputs.to(torch.device('cpu'))

        pred = np.transpose(outputs.data.numpy()[0, ...], (1, 2, 0))
        # sigmoid was already applied above via torch.sigmoid
        pred = np.squeeze(pred)
        result = helpers.crop2fullmask(
            pred, bbox, im_size=image.shape[:2], zero_pad=True,
            relax=pad) > thres

        results.append(result)

        # Plot the results
        plt.imshow(helpers.overlay_masks(image / 255, results))
        plt.plot(extreme_points_ori[:, 0], extreme_points_ori[:, 1], 'gx')
        print('Time to plot: ', time() - begin, ' seconds.')
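Typical interactive use, assuming a DEXTR network loaded as in Example #4: click four extreme points in the matplotlib window, press Enter to segment, and confirm fewer than four points to save demo.png and exit:

demo(net, image_path='ims/soccer.jpg')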