def __call__(self, sample):
        _target = sample[self.mask_elem]
        h, w = _target.shape[:2]

        heatmap = np.zeros((h, w))
        pos = np.zeros((h, w))
        for elem in self.tr_elems:
            _points = sample[elem]
            if _points is not None:
                if self.approx:
                    # approximate Gaussian transform: faster than the exact make_gt
                    heatmap = np.maximum(heatmap, helpers.gaussian_transform(
                        _target, _points, sigma=self.sigma))
                else:
                    heatmap = np.maximum(heatmap, helpers.make_gt(
                        _target, _points, sigma=self.sigma, one_mask_per_point=False))
                    
                # record the (integer-rounded) point locations in a binary map
                if self.return_pos:
                    _points = _points.astype(int)
                    pos[_points[:, 1], _points[:, 0]] = 1
        sample[self.tr_name] = heatmap
        if self.return_pos:
            sample[self.tr_name + '_pos'] = pos

        return sample
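For orientation, here is a minimal sketch of the kind of Gaussian ground truth that helpers.make_gt and helpers.gaussian_transform appear to produce (illustrative only; the real helpers may differ in normalization and clipping):

import numpy as np

def gaussian_heatmap(shape, points, sigma=10):
    # Illustrative stand-in for helpers.make_gt: one Gaussian blob per
    # (x, y) point, merged with an element-wise maximum exactly as the
    # transform above merges per-element heatmaps.
    h, w = shape
    ys, xs = np.mgrid[0:h, 0:w]
    heatmap = np.zeros((h, w), dtype=np.float32)
    for x, y in points:
        blob = np.exp(-((xs - x) ** 2 + (ys - y) ** 2) / (2 * sigma ** 2))
        heatmap = np.maximum(heatmap, blob)
    return heatmap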
Example #2
    def __call__(self, sample):
        if sample[self.elem].ndim == 3:
            raise ValueError('ExtremePoints not implemented for multiple objects per image.')
        _target = sample[self.elem]
        if np.max(_target) == 0:
            sample['extreme_points'] = np.zeros(_target.shape, dtype=_target.dtype)  # TODO: handle one_mask_per_point case
        else:
            _points = helpers.extreme_points(_target, self.pert)
            sample['extreme_points'] = helpers.make_gt(_target, _points, sigma=self.sigma, one_mask_per_point=False)

        return sample
Example #3
def get_mask(image, extreme_points_ori, pad=50, thres=0.8):
    modelName = 'dextr_pascal-sbd'
    gpu_id = 0
    device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() else "cpu")

    #  Create the network and load the weights
    net = resnet.resnet101(1, nInputChannels=4, classifier='psp')
    print("Initializing weights from: {}".format(os.path.join(Path.models_dir(), modelName + '.pth')))
    state_dict_checkpoint = torch.load(os.path.join(Path.models_dir(), modelName + '.pth'),
                                       map_location=lambda storage, loc: storage)
    # Remove the 'module.' prefix added when the model was trained with DataParallel
    if 'module.' in list(state_dict_checkpoint.keys())[0]:
        new_state_dict = OrderedDict()
        for k, v in state_dict_checkpoint.items():
            name = k[7:]  # remove `module.` from multi-gpu training
            new_state_dict[name] = v
    else:
        new_state_dict = state_dict_checkpoint
    net.load_state_dict(new_state_dict)
    net.eval()
    net.to(device)
    
    with torch.no_grad():
        results = []
        #  Crop image to the bounding box from the extreme points and resize
        bbox = helpers.get_bbox(image, points=extreme_points_ori, pad=pad, zero_pad=True)
        crop_image = helpers.crop_from_bbox(image, bbox, zero_pad=True)
        resize_image = helpers.fixed_resize(crop_image, (512, 512)).astype(np.float32)

        #  Generate extreme point heat map normalized to image values
        extreme_points = extreme_points_ori - [np.min(extreme_points_ori[:, 0]),
                                               np.min(extreme_points_ori[:, 1])] + [pad, pad]
        extreme_points = (512 * extreme_points *
                          [1 / crop_image.shape[1], 1 / crop_image.shape[0]]).astype(int)
        extreme_heatmap = helpers.make_gt(resize_image, extreme_points, sigma=10)
        extreme_heatmap = helpers.cstm_normalize(extreme_heatmap, 255)

        #  Concatenate inputs and convert to tensor
        input_dextr = np.concatenate((resize_image, extreme_heatmap[:, :, np.newaxis]), axis=2)
        inputs = torch.from_numpy(input_dextr.transpose((2, 0, 1))[np.newaxis, ...])

        # Run a forward pass
        inputs = inputs.to(device)
        outputs = net(inputs)
        outputs = interpolate(outputs, size=(512, 512), mode='bilinear', align_corners=True)
        outputs = outputs.to(torch.device('cpu'))

        pred = np.transpose(outputs.data.numpy()[0, ...], (1, 2, 0))
        pred = 1 / (1 + np.exp(-pred))
        pred = np.squeeze(pred)
        result = helpers.crop2fullmask(pred, bbox, im_size=image.shape[:2], zero_pad=True, relax=pad) > thres
        results.append(result)

        return results, bbox
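A minimal usage sketch for get_mask (the image path and the clicked extreme points below are made-up example values; the function itself assumes the 'dextr_pascal-sbd' weights are present under Path.models_dir()):

import numpy as np
from PIL import Image

image = np.array(Image.open('ims/dog-cat.jpg'))  # example image; any RGB image works
# four hypothetical (x, y) extreme points: left, top, right, bottom of the object
extreme_points_ori = np.array([[60, 200], [200, 50], [340, 180], [210, 330]])
masks, bbox = get_mask(image, extreme_points_ori, pad=50, thres=0.8)
print('mask pixels:', masks[0].sum(), 'bbox:', bbox)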
Example #4
    def __call__(self, sample):

        if sample[self.elem].ndim == 3:
            raise ValueError('IOGPoints not implemented for multiple objects per image.')
        _target = sample[self.elem]

        targetshape = _target.shape
        if np.max(_target) == 0:
            sample['IOG_points'] = np.zeros([targetshape[0], targetshape[1], 2],
                                            dtype=_target.dtype)  # TODO: handle one_mask_per_point case
        else:
            _points = helpers.iog_points(_target, self.pad_pixel)
            sample['IOG_points'] = helpers.make_gt(_target, _points, sigma=self.sigma, one_mask_per_point=False)

        return sample
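A hedged instantiation sketch for the transform above (the class name IOGPoints is taken from its own error message; the constructor arguments elem, pad_pixel, and sigma are inferred from the attributes __call__ uses, so the signature is hypothetical):

import numpy as np

mask = np.zeros((256, 256), dtype=np.uint8)
mask[80:180, 60:200] = 1  # toy binary object mask
tr = IOGPoints(elem='gt', pad_pixel=10, sigma=10)  # hypothetical signature
sample = tr({'gt': mask})  # adds sample['IOG_points'], an (h, w, 2) ground truth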
Example #5
    def __call__(self, sample):
        if sample[self.elem].ndim == 3:
            raise ValueError(
                'ExtremePoints not implemented for multiple objects per image.')
        _target = sample[self.elem]
        if np.max(_target) == 0:
            sample['extreme_points'] = np.zeros(
                _target.shape,
                dtype=_target.dtype)  # TODO: handle one_mask_per_point case
        else:
            import matplotlib.pyplot as plt  # local import, used only for visualization
            if self.type == 'mask':
                _points = helpers.get_mask_sample_points(_target, 50)
            elif self.type == 'normal':
                _points = helpers.extreme_points(_target, self.pert)
            elif self.type == 'bbox':
                _points = helpers.get_bbox_sample_points(_target, self.num_pts)
            elif self.type == 'polygon':
                _polygons = helpers.mask_to_poly(_target, visualize=False)
                _non_pert_points = helpers.get_polygon_points(
                    _polygons, self.num_pts, _target.shape)
                _pert_points = [
                    point + (np.random.randint(-self.pert, self.pert),
                             np.random.randint(-self.pert, self.pert))
                    for point in _non_pert_points
                ]
                _points = np.array(_pert_points)
            elif self.type == 'mask_noise':
                _points = helpers.get_mask_noise_sample_masks(_target,
                                                              self.num_pts,
                                                              ratio=0.2)
            else:
                # guard against _points being undefined below
                raise ValueError('Unknown point type: {}'.format(self.type))
            if self.vis:
                plt.imshow(_target)
                plt.scatter(_points[:, 0], _points[:, 1])
                plt.show()

            sample['extreme_points'] = helpers.make_gt(
                _target, _points, sigma=self.sigma, one_mask_per_point=False)

            if self.vis:
                plt.imshow(sample['extreme_points'])
                plt.show()
        return sample
Example #6
    def __call__(self, sample):

        if sample[self.elem].ndim == 3:
            raise ValueError(
                'distance_map not implemented for multiple objects per image.')
        _target = sample[self.elem]

        targetshape = _target.shape
        if np.max(_target) == 0:
            sample['distance_map'] = np.zeros(
                [targetshape[0], targetshape[1], 2], dtype=_target.dtype)
        else:
            _points = helpers.GetDistanceMap(_target, self.pad_pixel)
            sample['distance_map'] = helpers.make_gt(_target,
                                                     _points,
                                                     sigma=self.sigma,
                                                     one_mask_per_point=False)
        return sample
Example #7
def get_inputs(image, bbox, expt, pad):
    crop_image = helpers.crop_from_bbox(image, bbox, zero_pad=True)
    resize_image = helpers.fixed_resize(crop_image,
                                        (512, 512)).astype(np.float32)

    #  Generate extreme point heat map normalized to image values
    extreme_points = expt - [np.min(expt[:, 0]),
                             np.min(expt[:, 1])] + [pad, pad]
    extreme_points = (
        512 * extreme_points *
        [1 / crop_image.shape[1], 1 / crop_image.shape[0]]).astype(int)
    extreme_heatmap = helpers.make_gt(resize_image, extreme_points, sigma=10)
    extreme_heatmap = helpers.cstm_normalize(extreme_heatmap, 255)

    #  Concatenate inputs and convert to tensor
    input_dextr = np.concatenate(
        (resize_image, extreme_heatmap[:, :, np.newaxis]), axis=2)
    inputs = torch.tensor(input_dextr.transpose((2, 0, 1))[np.newaxis, ...])
    return inputs
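get_inputs only builds the 4-channel tensor; a hedged sketch of how it would be driven (image, net, and device are assumed to be set up as in the surrounding examples, and the points are hypothetical):

expt = np.array([[60, 200], [200, 50], [340, 180], [210, 330]])  # hypothetical clicks
bbox = helpers.get_bbox(image, points=expt, pad=50, zero_pad=True)
inputs = get_inputs(image, bbox, expt, pad=50)
with torch.no_grad():
    outputs = net(inputs.to(device))  # 4-channel input: RGB + extreme-point heatmap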
Example #8
        #  Crop image to the bounding box from the extreme points and resize
        bbox = helpers.get_bbox(image,
                                points=extreme_points_ori,
                                pad=pad,
                                zero_pad=True)
        crop_image = helpers.crop_from_bbox(image, bbox, zero_pad=True)
        resize_image = helpers.fixed_resize(crop_image,
                                            (512, 512)).astype(np.float32)

        #  Generate extreme point heat map normalized to image values
        extreme_points = extreme_points_ori - [
            np.min(extreme_points_ori[:, 0]),
            np.min(extreme_points_ori[:, 1])
        ] + [pad, pad]
        extreme_points = (
            512 * extreme_points *
            [1 / crop_image.shape[1], 1 / crop_image.shape[0]]).astype(int)
        extreme_heatmap = helpers.make_gt(resize_image,
                                          extreme_points,
                                          sigma=10)
        extreme_heatmap = helpers.cstm_normalize(extreme_heatmap, 255)

        #  Concatenate inputs and convert to tensor
        input_dextr = np.concatenate(
            (resize_image, extreme_heatmap[:, :, np.newaxis]), axis=2)
        inputs = torch.from_numpy(
            input_dextr.transpose((2, 0, 1))[np.newaxis, ...])

        # Run a forward pass
        inputs = inputs.to(device)
        outputs = net(inputs)
        # torch.nn.functional.upsample is deprecated; interpolate replaces it
        outputs = interpolate(outputs,
                              size=(512, 512),
                              mode='bilinear',
                              align_corners=True)
Example #9
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-i',
                        '--image',
                        type=str,
                        default='ims/dog-cat.jpg',
                        help='path to image')
    parser.add_argument('--model-name', type=str, default='dextr_pascal-sbd')
    parser.add_argument('-o',
                        '--output',
                        type=str,
                        default='results',
                        help='path where results will be saved')
    parser.add_argument('--pad', type=int, default=50, help='padding size')
    parser.add_argument('--thres', type=float, default=.9)
    parser.add_argument('--gpu-id', type=int, default=0)
    parser.add_argument('--anchors',
                        type=int,
                        default=5,
                        help='amount of points to set')
    parser.add_argument(
        '--anchor-points',
        type=str,
        default=None,
        help='path to folder of anchor points (tracking points)')
    parser.add_argument(
        '--use-frame-info',
        # note: argparse's type=bool treats any non-empty string as True,
        # so parse the flag value explicitly
        type=lambda s: str(s).lower() in ('1', 'true', 'yes'),
        default=True,
        help='whether to use the frame number from the csv file or not')
    parser.add_argument('--corrections',
                        action='store_true',
                        help='show a popup asking whether to correct the result or not')
    parser.add_argument(
        '--cut',
        action='store_true',
        help='if used, save the cut-out image instead of the mask as PNG')

    opt = parser.parse_args()
    modelName = opt.model_name
    pad = opt.pad
    thres = opt.thres
    gpu_id = opt.gpu_id
    device = torch.device("cuda:" +
                          str(gpu_id) if torch.cuda.is_available() else "cpu")

    #  Create the network and load the weights
    net = resnet.resnet101(1, nInputChannels=4, classifier='psp')
    print("Initializing weights from: {}".format(
        os.path.join(Path.models_dir(), modelName + '.pth')))
    state_dict_checkpoint = torch.load(
        os.path.join(Path.models_dir(), modelName + '.pth'),
        map_location=lambda storage, loc: storage)
    # Remove the 'module.' prefix added when the model was trained with DataParallel
    if 'module.' in list(state_dict_checkpoint.keys())[0]:
        new_state_dict = OrderedDict()
        for k, v in state_dict_checkpoint.items():
            name = k[7:]  # remove `module.` from multi-gpu training
            new_state_dict[name] = v
    else:
        new_state_dict = state_dict_checkpoint
    net.load_state_dict(new_state_dict)
    net.eval()
    net.to(device)

    #  Read image and click the points
    if os.path.isfile(opt.image):
        images = [opt.image]
    else:
        images = sorted(glob.glob(opt.image + '/*.*'))
    if opt.anchor_points:
        tracks = sorted(glob.glob(opt.anchor_points + '/*.csv'))
        frames, X, Y = [], [], []
        for i in range(len(tracks)):
            f, x, y = np.loadtxt(tracks[i], delimiter=',', unpack=True)
            frames.append(f.tolist())
            X.append(x.tolist())
            Y.append(y.tolist())
        anchorPoints = []
        uframes = np.unique(np.hstack([np.array(a) for a in frames])).tolist()
        for i in range(len(uframes)):
            extreme_points = []
            for j in range(len(frames)):
                try:
                    ind = frames[j].index(uframes[i])
                    extreme_points.append([X[j][ind], Y[j][ind]])
                except ValueError:
                    continue
            anchorPoints.append(np.array(extreme_points))

    for i, img in enumerate(images):

        if opt.use_frame_info and opt.anchor_points is not None:
            file_number = int(re.sub(r'\D', '', img))
            if file_number not in uframes:
                print(img, 'skipped')
                continue

        if opt.anchor_points is None:
            plt.figure()
        while True:
            image = np.array(Image.open(img))
            mask_path = os.path.join(opt.output, os.path.split(img)[1])
            if opt.anchor_points is None:
                plt.ion()
                plt.axis('off')
                plt.imshow(image)
                plt.title(
                    'Click the four extreme points of the objects\nHit enter/middle mouse button when done (do not close the window)'
                )

            results = []

            with torch.no_grad():
                if opt.anchor_points:
                    if opt.use_frame_info:
                        try:
                            index = uframes.index(file_number)
                        except ValueError:
                            print(
                                'Could not find data for frame %i. Using frame %i instead.'
                                % (file_number, i))
                            index = i
                    else:
                        index = i
                    extreme_points_ori = anchorPoints[index].astype(int)
                else:
                    extreme_points_ori = np.array(
                        plt.ginput(opt.anchors, timeout=0)).astype(int)

                #  Crop image to the bounding box from the extreme points and resize
                bbox = helpers.get_bbox(image,
                                        points=extreme_points_ori,
                                        pad=pad,
                                        zero_pad=False)
                crop_image = helpers.crop_from_bbox(image, bbox, zero_pad=True)
                resize_image = helpers.fixed_resize(
                    crop_image, (512, 512)).astype(np.float32)

                #  Generate extreme point heat map normalized to image values
                extreme_points = extreme_points_ori - [
                    np.min(extreme_points_ori[:, 0]),
                    np.min(extreme_points_ori[:, 1])
                ] + [pad, pad]
                extreme_points = (
                    512 * extreme_points *
                    [1 / crop_image.shape[1], 1 / crop_image.shape[0]]).astype(int)
                extreme_heatmap = helpers.make_gt(resize_image,
                                                  extreme_points,
                                                  sigma=10)
                extreme_heatmap = helpers.cstm_normalize(extreme_heatmap, 255)

                #  Concatenate inputs and convert to tensor
                input_dextr = np.concatenate(
                    (resize_image, extreme_heatmap[:, :, np.newaxis]), axis=2)
                inputs = torch.from_numpy(
                    input_dextr.transpose((2, 0, 1))[np.newaxis, ...])

                # Run a forward pass
                inputs = inputs.to(device)
                outputs = net(inputs)
                outputs = interpolate(outputs,
                                      size=(512, 512),
                                      mode='bilinear',
                                      align_corners=True)
                outputs = outputs.to(torch.device('cpu'))

                pred = np.transpose(outputs.data.numpy()[0, ...], (1, 2, 0))
                pred = 1 / (1 + np.exp(-pred))
                pred = np.squeeze(pred)
                result = helpers.crop2fullmask(pred,
                                               bbox,
                                               im_size=image.shape[:2],
                                               zero_pad=True,
                                               relax=pad) > thres

                results.append(result)
                # Plot the results
                plt.imshow(helpers.overlay_masks(image / 255, results))
                plt.plot(extreme_points_ori[:, 0], extreme_points_ori[:, 1],
                         'gx')

                if not opt.cut:
                    helpers.save_mask(results, mask_path)
                else:
                    Image.fromarray(
                        np.concatenate(
                            (image, 255 * result[..., None].astype(int)),
                            2).astype(np.uint8)).save(mask_path, 'png')
            if opt.anchor_points is None:
                plt.close()
            if opt.corrections:
                if easygui.ynbox(image=mask_path):
                    break
            else:
                break
        print(img, 'done')
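The listing ends inside main(); a conventional entry-point guard (assumed, not shown in the original) would let the script run directly:

if __name__ == '__main__':
    main()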
Example #10
def dextr_helper(img_url=IMG_URL, extreme_pts=EXTREME_PTS):
    """
    @params 
    img_url - string containing url to the image
    extreme_pts - list of (x, y) extreme coordinate tuples 

    @returns tuple - (bbox, mask, pred)
    bbox (x_min, y_min, x_max, y_max) is the bounding box generated from he extreme points
    mask is a boolean numpy array indicating presence of instance 
    pred is the classification result 
    """

    response = requests.get(img_url)
    image = np.array(Image.open(io.BytesIO(response.content)))

    with torch.no_grad():
        extreme_points_ori = np.array(extreme_pts).astype(int)

        #  Crop image to the bounding box from the extreme points and resize
        bbox = helpers.get_bbox(image,
                                points=extreme_points_ori,
                                pad=pad,
                                zero_pad=True)
        crop_image = helpers.crop_from_bbox(image, bbox, zero_pad=True)
        resize_image = helpers.fixed_resize(crop_image,
                                            (512, 512)).astype(np.float32)

        class_prediction = get_prediction_numpy(crop_image)
        # print("Class Prediction is : {}".format(class_prediction))

        # this is the bounding box to return (with 0 padding)
        actual_bbox = helpers.get_bbox(image,
                                       points=extreme_points_ori,
                                       pad=0,
                                       zero_pad=True)

        #  Generate extreme point heat map normalized to image values
        extreme_points = extreme_points_ori - [
            np.min(extreme_points_ori[:, 0]),
            np.min(extreme_points_ori[:, 1])
        ] + [pad, pad]
        extreme_points = (
            512 * extreme_points *
            [1 / crop_image.shape[1], 1 / crop_image.shape[0]]).astype(int)
        extreme_heatmap = helpers.make_gt(resize_image,
                                          extreme_points,
                                          sigma=10)
        extreme_heatmap = helpers.cstm_normalize(extreme_heatmap, 255)

        #  Concatenate inputs and convert to tensor
        input_dextr = np.concatenate(
            (resize_image, extreme_heatmap[:, :, np.newaxis]), axis=2)
        inputs = torch.from_numpy(
            input_dextr.transpose((2, 0, 1))[np.newaxis, ...])

        # Run a forward pass
        inputs = inputs.to(device)
        outputs = net(inputs)
        # torch.nn.functional.upsample is deprecated; interpolate replaces it
        outputs = interpolate(outputs,
                              size=(512, 512),
                              mode='bilinear',
                              align_corners=True)
        outputs = outputs.to(torch.device('cpu'))

        pred = np.transpose(outputs.data.numpy()[0, ...], (1, 2, 0))
        pred = 1 / (1 + np.exp(-pred))
        pred = np.squeeze(pred)

        # Here result is of the shape of image, where True implies that part should be in the segment
        result = helpers.crop2fullmask(
            pred, bbox, im_size=image.shape[:2], zero_pad=True,
            relax=pad) > thres

        return (actual_bbox, result, class_prediction)
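A usage sketch for dextr_helper (the URL and points below are hypothetical placeholders; note that the function also relies on module-level pad, thres, net, and device, set up as in the surrounding examples):

hypothetical_url = 'https://example.com/dog.jpg'  # placeholder URL
hypothetical_pts = [(60, 200), (200, 50), (340, 180), (210, 330)]  # made-up clicks
bbox, mask, pred = dextr_helper(img_url=hypothetical_url,
                                extreme_pts=hypothetical_pts)
print('bbox:', bbox, 'mask pixels:', mask.sum(), 'class:', pred)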
Example #11
def demo(net, image_path='ims/soccer.jpg'):
    pad = 50
    thres = 0.8
    #  Read image and click the points
    image = np.array(Image.open(image_path))
    plt.ion()
    plt.axis('off')
    plt.imshow(image)
    plt.title(
        'Click the four extreme points of the objects\nHit enter when done (do not close the window)'
    )
    results = []
    while True:
        extreme_points_ori = np.array(plt.ginput(4, timeout=0)).astype(int)
        begin = time()
        if extreme_points_ori.shape[0] < 4:
            if len(results) > 0:
                helpers.save_mask(results, 'demo.png')
                print('Saving mask annotation in demo.png and exiting...')
            else:
                print('Exiting...')
            sys.exit()

        #  Crop image to the bounding box from the extreme points and resize
        bbox = helpers.get_bbox(image,
                                points=extreme_points_ori,
                                pad=pad,
                                zero_pad=True)
        crop_image = helpers.crop_from_bbox(image, bbox, zero_pad=True)
        resize_image = helpers.fixed_resize(crop_image,
                                            (512, 512)).astype(np.float32)

        #  Generate extreme point heat map normalized to image values
        extreme_points = extreme_points_ori - [
            np.min(extreme_points_ori[:, 0]),
            np.min(extreme_points_ori[:, 1])
        ] + [pad, pad]
        extreme_points = (
            512 * extreme_points *
            [1 / crop_image.shape[1], 1 / crop_image.shape[0]]).astype(int)
        extreme_heatmap = helpers.make_gt(resize_image,
                                          extreme_points,
                                          sigma=10)
        extreme_heatmap = helpers.cstm_normalize(extreme_heatmap, 255)

        #  Concatenate inputs and convert to tensor
        input_dextr = np.concatenate(
            (resize_image, extreme_heatmap[:, :, np.newaxis]), axis=2)
        inputs = torch.from_numpy(
            input_dextr.transpose((2, 0, 1))[np.newaxis, ...])

        # Run a forward pass
        outputs = net(inputs)
        # torch.nn.functional.upsample is deprecated; interpolate replaces it
        outputs = interpolate(outputs,
                              size=(512, 512),
                              mode='bilinear',
                              align_corners=True)
        outputs = torch.sigmoid(outputs)
        outputs = outputs.to(torch.device('cpu'))

        pred = np.transpose(outputs.data.numpy()[0, ...], (1, 2, 0))
        pred = np.squeeze(pred)
        result = helpers.crop2fullmask(
            pred, bbox, im_size=image.shape[:2], zero_pad=True,
            relax=pad) > thres

        results.append(result)

        # Plot the results
        plt.imshow(helpers.overlay_masks(image / 255, results))
        plt.plot(extreme_points_ori[:, 0], extreme_points_ori[:, 1], 'gx')
        print('Elapsed time:', time() - begin, 'seconds.')
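To drive demo(), the network would be loaded much as in Example #3 (a sketch; the DataParallel-prefix handling is omitted here for brevity, and the model name follows the earlier examples):

net = resnet.resnet101(1, nInputChannels=4, classifier='psp')
state_dict = torch.load(os.path.join(Path.models_dir(), 'dextr_pascal-sbd.pth'),
                        map_location='cpu')
net.load_state_dict(state_dict)
net.eval()  # demo() runs the forward pass on CPU, so no .to(device) is needed
demo(net, image_path='ims/soccer.jpg')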