def __call__(self, sample):

        # Fixed range of scales
        if self.resolutions is None:
            return sample

        elems = list(sample.keys())

        for elem in elems:

            if 'meta' in elem or 'bbox' in elem or (
                    'extreme_points_coord' in elem
                    and elem not in self.resolutions):
                continue
            if 'extreme_points_coord' in elem and elem in self.resolutions:
                bbox = sample['bbox']
                crop_size = np.array(
                    [bbox[3] - bbox[1] + 1, bbox[4] - bbox[2] + 1])
                res = np.array(self.resolutions[elem]).astype(np.float32)
            sample[elem] = np.round(
                sample[elem] * res / crop_size).astype(int)  # np.int was removed in NumPy 1.24
                continue
            if elem in self.resolutions:
                if self.resolutions[elem] is None:
                    continue
                if isinstance(sample[elem], list):
                    if sample[elem][0].ndim == 3:
                        output_size = np.append(self.resolutions[elem],
                                                [3, len(sample[elem])])
                    else:
                        output_size = np.append(self.resolutions[elem],
                                                len(sample[elem]))
                    tmp = sample[elem]
                    sample[elem] = np.zeros(output_size, dtype=np.float32)
                    for ii, crop in enumerate(tmp):
                        if self.flagvals is None:
                            sample[elem][..., ii] = helpers.fixed_resize(
                                crop, self.resolutions[elem])
                        else:
                            sample[elem][..., ii] = helpers.fixed_resize(
                                crop,
                                self.resolutions[elem],
                                flagval=self.flagvals[elem])
                else:
                    if self.flagvals is None:
                        sample[elem] = helpers.fixed_resize(
                            sample[elem], self.resolutions[elem])
                    else:
                        sample[elem] = helpers.fixed_resize(
                            sample[elem],
                            self.resolutions[elem],
                            flagval=self.flagvals[elem])
            else:
                del sample[elem]

        return sample
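
A minimal usage sketch for this transform; the class name FixedResize and its constructor signature are assumptions inferred from the attributes referenced above (self.resolutions, self.flagvals):

# Hypothetical usage; FixedResize is an assumed name for the class above.
import numpy as np

transform = FixedResize(resolutions={'image': (512, 512), 'gt': (512, 512)},
                        flagvals=None)
sample = {'image': np.zeros((300, 400, 3), dtype=np.float32),
          'gt': np.zeros((300, 400), dtype=np.float32),
          'meta': {'name': 'example'}}
sample = transform(sample)  # 'image' and 'gt' resized to 512x512; 'meta' kept as-is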
Example #2
def get_mask(image, extreme_points_ori, pad=50, thres=0.8):
    modelName = 'dextr_pascal-sbd'
    gpu_id = 0
    device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() else "cpu")

    #  Create the network and load the weights
    net = resnet.resnet101(1, nInputChannels=4, classifier='psp')
    print("Initializing weights from: {}".format(os.path.join(Path.models_dir(), modelName + '.pth')))
    state_dict_checkpoint = torch.load(os.path.join(Path.models_dir(), modelName + '.pth'),
                                       map_location=lambda storage, loc: storage)
    # Remove the prefix .module from the model when it is trained using DataParallel
    if 'module.' in list(state_dict_checkpoint.keys())[0]:
        new_state_dict = OrderedDict()
        for k, v in state_dict_checkpoint.items():
            name = k[7:]  # remove `module.` from multi-gpu training
            new_state_dict[name] = v
    else:
        new_state_dict = state_dict_checkpoint
    net.load_state_dict(new_state_dict)
    net.eval()
    net.to(device)
    
    with torch.no_grad():
        results = []
        #  Crop image to the bounding box from the extreme points and resize
        bbox = helpers.get_bbox(image, points=extreme_points_ori, pad=pad, zero_pad=True)
        crop_image = helpers.crop_from_bbox(image, bbox, zero_pad=True)
        resize_image = helpers.fixed_resize(crop_image, (512, 512)).astype(np.float32)

        #  Generate extreme point heat map normalized to image values
        extreme_points = extreme_points_ori - [
            np.min(extreme_points_ori[:, 0]),
            np.min(extreme_points_ori[:, 1])
        ] + [pad, pad]
        extreme_points = (
            512 * extreme_points *
            [1 / crop_image.shape[1], 1 / crop_image.shape[0]]).astype(int)
        extreme_heatmap = helpers.make_gt(resize_image, extreme_points, sigma=10)
        extreme_heatmap = helpers.cstm_normalize(extreme_heatmap, 255)

        #  Concatenate inputs and convert to tensor
        input_dextr = np.concatenate((resize_image, extreme_heatmap[:, :, np.newaxis]), axis=2)
        inputs = torch.from_numpy(input_dextr.transpose((2, 0, 1))[np.newaxis, ...])

        # Run a forward pass
        inputs = inputs.to(device)
        outputs = net.forward(inputs)
        outputs = interpolate(outputs, size=(512, 512), mode='bilinear', align_corners=True)
        outputs = outputs.to(torch.device('cpu'))

        pred = np.transpose(outputs.data.numpy()[0, ...], (1, 2, 0))
        pred = 1 / (1 + np.exp(-pred))
        pred = np.squeeze(pred)
        result = helpers.crop2fullmask(pred, bbox, im_size=image.shape[:2], zero_pad=True, relax=pad) > thres
        results.append(result)

        return results, bbox
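
A usage sketch for get_mask; the image path reuses ims/dog-cat.jpg from Example #8's defaults, and the four extreme points are made-up illustrative values:

# Hypothetical usage of get_mask defined above.
import numpy as np
from PIL import Image

image = np.array(Image.open('ims/dog-cat.jpg'))
# Four (x, y) extreme points of one object: left, top, right, bottom
extreme_points = np.array([[10, 120], [200, 15], [390, 130], [205, 250]])
results, bbox = get_mask(image, extreme_points, pad=50, thres=0.8)
print(bbox, results[0].shape, results[0].dtype)  # bounding box and boolean mask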
Example #3

    def __call__(self, sample):
        # Fixed range of scales
        if self.resolutions is None:
            return sample
        elems = list(sample.keys())
        for elem in elems:
            if elem == 'meta':
                continue
            if elem in self.resolutions:
                if self.resolutions[elem] is None:
                    continue
                if self.flagvals is None:
                    sample[elem] = helpers.fixed_resize(
                        sample[elem], self.resolutions[elem])
                else:
                    sample[elem] = helpers.fixed_resize(
                        sample[elem],
                        self.resolutions[elem],
                        flagval=self.flagvals[elem])
            else:
                del sample[elem]
        return sample
Example #4
    def __call__(self, sample):

        # Fixed range of scales
        if self.resolutions is None:
            return sample
        elems = list(sample.keys())
        for elem in elems:
            if elem in self.resolutions:
                if self.resolutions[elem] is None:
                    continue
                if isinstance(sample[elem], list):
                    if sample[elem][0].ndim == 3:
                        output_size = np.append(self.resolutions[elem],
                                                [3, len(sample[elem])])
                    else:
                        output_size = np.append(self.resolutions[elem],
                                                len(sample[elem]))
                    tmp = sample[elem]
                    sample[elem] = np.zeros(output_size, dtype=np.float32)
                    for ii, crop in enumerate(tmp):
                        if self.flagvals is None:
                            sample[elem][..., ii] = helpers.fixed_resize(
                                crop, self.resolutions[elem])
                        else:
                            sample[elem][..., ii] = helpers.fixed_resize(
                                crop,
                                self.resolutions[elem],
                                flagval=self.flagvals[elem])
                else:
                    if self.flagvals is None:
                        sample[elem] = helpers.fixed_resize(
                            sample[elem], self.resolutions[elem])
                    else:
                        sample[elem] = helpers.fixed_resize(
                            sample[elem],
                            self.resolutions[elem],
                            flagval=self.flagvals[elem])
        return sample
Example #5
def get_inputs(image, bbox, expt, pad):
    crop_image = helpers.crop_from_bbox(image, bbox, zero_pad=True)
    resize_image = helpers.fixed_resize(crop_image,
                                        (512, 512)).astype(np.float32)

    #  Generate extreme point heat map normalized to image values
    extreme_points = expt - [np.min(expt[:, 0]),
                             np.min(expt[:, 1])] + [pad, pad]
    extreme_points = (
        512 * extreme_points *
        [1 / crop_image.shape[1], 1 / crop_image.shape[0]]).astype(int)
    extreme_heatmap = helpers.make_gt(resize_image, extreme_points, sigma=10)
    extreme_heatmap = helpers.cstm_normalize(extreme_heatmap, 255)

    #  Concatenate inputs and convert to tensor
    input_dextr = np.concatenate(
        (resize_image, extreme_heatmap[:, :, np.newaxis]), axis=2)
    inputs = torch.tensor(input_dextr.transpose((2, 0, 1))[np.newaxis, ...])
    return inputs
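
A sketch of how get_inputs plugs into the inference loop, mirroring the surrounding examples; net and device are assumed to be the loaded DEXTR model and its torch device:

# Hypothetical usage of get_inputs defined above.
bbox = helpers.get_bbox(image, points=expt, pad=pad, zero_pad=True)
inputs = get_inputs(image, bbox, expt, pad)
with torch.no_grad():
    outputs = net(inputs.to(device))  # forward pass on the 4-channel input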
Example #6
        begin = time()
        if extreme_points_ori.shape[0] < 4:
            if len(results) > 0:
                helpers.save_mask(results, 'demo.png')
                print('Saving mask annotation in demo.png and exiting...')
            else:
                print('Exiting...')
            sys.exit()

        #  Crop image to the bounding box from the extreme points and resize
        bbox = helpers.get_bbox(image,
                                points=extreme_points_ori,
                                pad=pad,
                                zero_pad=True)
        crop_image = helpers.crop_from_bbox(image, bbox, zero_pad=True)
        resize_image = helpers.fixed_resize(crop_image,
                                            (512, 512)).astype(np.float32)

        #  Generate extreme point heat map normalized to image values
        extreme_points = extreme_points_ori - [
            np.min(extreme_points_ori[:, 0]),
            np.min(extreme_points_ori[:, 1])
        ] + [pad, pad]
        extreme_points = (
            512 * extreme_points *
            [1 / crop_image.shape[1], 1 / crop_image.shape[0]]).astype(int)
        extreme_heatmap = helpers.make_gt(resize_image,
                                          extreme_points,
                                          sigma=10)
        extreme_heatmap = helpers.cstm_normalize(extreme_heatmap, 255)

        #  Concatenate inputs and convert to tensor
Example #7
    def _segment(self, image, extreme_points_ori):
        with torch.no_grad():
            # Crop the image
            h, w = image.shape[:2]
            if self.cfg['adaptive_relax']:
                mean_shape = np.mean(image.shape[:2])
                relax = int(self.cfg['relax_crop'] * mean_shape / 428.)
            else:
                relax = self.cfg['relax_crop']
            bbox = helpers.get_bbox(image,
                                    points=extreme_points_ori,
                                    pad=relax,
                                    zero_pad=self.cfg['zero_pad_crop'])
            crop_image = helpers.crop_from_bbox(
                image, bbox, zero_pad=self.cfg['zero_pad_crop'])

            # Compute the offsets of extreme points
            bounds = (0, 0, w - 1, h - 1)
            bbox_valid = (max(bbox[0], bounds[0]), max(bbox[1], bounds[1]),
                          min(bbox[2], bounds[2]), min(bbox[3], bounds[3]))
            if self.cfg['zero_pad_crop']:
                offsets = (-bbox[0], -bbox[1])
            else:
                offsets = (-bbox_valid[0], -bbox_valid[1])
            crop_extreme_points = extreme_points_ori + offsets

            # Resize
            if (np.minimum(h, w) < self.cfg['min_size']) or (np.maximum(
                    h, w) > self.cfg['max_size']):
                sc1 = self.cfg['min_size'] / np.minimum(h, w)
                sc2 = self.cfg['max_size'] / np.maximum(h, w)
                if sc1 > 1:
                    sc = sc1
                else:
                    sc = np.maximum(sc1, sc2)
                resize_image = cv2.resize(crop_image, (0, 0),
                                          fx=sc,
                                          fy=sc,
                                          interpolation=cv2.INTER_LINEAR)
                points = crop_extreme_points * sc
            else:
                resize_image = crop_image
                points = crop_extreme_points
            h2, w2 = resize_image.shape[:2]

            # Compute image gradient using Sobel filter
            img_r = resize_image[:, :, 0]
            img_g = resize_image[:, :, 1]
            img_b = resize_image[:, :, 2]
            grad_r = helpers.imgradient(img_r)[0]
            grad_g = helpers.imgradient(img_g)[0]
            grad_b = helpers.imgradient(img_b)[0]
            image_grad = np.sqrt(grad_r**2 + grad_g**2 + grad_b**2)
            # Normalize to [0,1]
            image_grad = (image_grad - image_grad.min()) / (image_grad.max() -
                                                            image_grad.min())

            # Convert extreme points to Gaussian heatmaps
            heatmap = helpers.gaussian_transform(resize_image,
                                                 points,
                                                 sigma=10)

            # Resize to a fixed resolution for global context extraction
            resolution = (self.cfg['lr_size'], self.cfg['lr_size'])
            lr_points = np.array([resolution[0] / w2, resolution[1] / h2
                                  ]) * points
            lr_image = helpers.fixed_resize(resize_image, resolution)

            # Convert the extreme points to Gaussian heatmaps
            lr_heatmap = helpers.gaussian_transform(lr_image,
                                                    lr_points,
                                                    sigma=10)

            # Normalize inputs
            heatmap = 255 * (heatmap - heatmap.min()) / (heatmap.max() -
                                                         heatmap.min() + 1e-10)
            lr_heatmap = 255 * (lr_heatmap - lr_heatmap.min()) / (
                lr_heatmap.max() - lr_heatmap.min() + 1e-10)
            image_grad = 255 * (image_grad - image_grad.min()) / (
                image_grad.max() - image_grad.min() + 1e-10)

            # Concatenate the inputs (1, H, W, C)
            concat_lr = np.concatenate([lr_image, lr_heatmap[:, :, None]],
                                       axis=-1)[None, :].transpose(0, 3, 1, 2)
            concat = np.concatenate([resize_image, heatmap[:, :, None]],
                                    axis=-1)[None, :].transpose(0, 3, 1, 2)
            grad = np.concatenate([resize_image, image_grad[:, :, None]],
                                  axis=-1)[None, :].transpose(0, 3, 1, 2)

            # Convert to PyTorch tensors
            concat_lr = torch.from_numpy(concat_lr).float().to(self.device)
            concat = torch.from_numpy(concat).float().to(self.device)
            grad = torch.from_numpy(grad).float().to(self.device)

            # Forward pass
            outs = self.net.forward(concat, grad, concat_lr, roi=None)[1]
            output = torch.sigmoid(outs).cpu().numpy().squeeze()

            # Project back to original image space
            result = helpers.crop2fullmask(output,
                                           bbox,
                                           im_size=image.shape[:2],
                                           zero_pad=True,
                                           relax=relax)

        return result
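
A usage sketch for _segment; the wrapper class name Segmenter and the config values below are assumptions inferred from the self.cfg keys referenced in the method:

# Hypothetical usage; Segmenter is an assumed name for the class owning _segment.
cfg = {'adaptive_relax': True, 'relax_crop': 50, 'zero_pad_crop': True,
       'min_size': 512, 'max_size': 1024, 'lr_size': 512}
segmenter = Segmenter(cfg)  # assumed to set self.cfg, self.net and self.device
result = segmenter._segment(image, extreme_points)  # float mask in [0, 1]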
Example #8
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-i',
                        '--image',
                        type=str,
                        default='ims/dog-cat.jpg',
                        help='path to image')
    parser.add_argument('--model-name', type=str, default='dextr_pascal-sbd')
    parser.add_argument('-o',
                        '--output',
                        type=str,
                        default='results',
                        help='path where results will be saved')
    parser.add_argument('--pad', type=int, default=50, help='padding size')
    parser.add_argument('--thres', type=float, default=.9)
    parser.add_argument('--gpu-id', type=int, default=0)
    parser.add_argument('--anchors',
                        type=int,
                        default=5,
                        help='amount of points to set')
    parser.add_argument(
        '--anchor-points',
        type=str,
        default=None,
        help='path to folder of anchor points (tracking points)')
    parser.add_argument(
        '--use-frame-info',
        # argparse's type=bool treats any non-empty string as True, so the
        # flag is parsed explicitly here
        type=lambda s: s.lower() in ('true', '1', 'yes'),
        default=True,
        help='whether to use the frame number from the csv file or not')
    parser.add_argument('--corrections',
                        action='store_true',
                        help='show a popup asking whether to correct the mask')
    parser.add_argument(
        '--cut',
        action='store_true',
        help='if set, save the cut-out image instead of the mask as png')

    opt = parser.parse_args()
    modelName = opt.model_name
    pad = opt.pad
    thres = opt.thres
    gpu_id = opt.gpu_id
    device = torch.device("cuda:" +
                          str(gpu_id) if torch.cuda.is_available() else "cpu")

    #  Create the network and load the weights
    net = resnet.resnet101(1, nInputChannels=4, classifier='psp')
    print("Initializing weights from: {}".format(
        os.path.join(Path.models_dir(), modelName + '.pth')))
    state_dict_checkpoint = torch.load(
        os.path.join(Path.models_dir(), modelName + '.pth'),
        map_location=lambda storage, loc: storage)
    # Remove the prefix .module from the model when it is trained using DataParallel
    if 'module.' in list(state_dict_checkpoint.keys())[0]:
        new_state_dict = OrderedDict()
        for k, v in state_dict_checkpoint.items():
            name = k[7:]  # remove `module.` from multi-gpu training
            new_state_dict[name] = v
    else:
        new_state_dict = state_dict_checkpoint
    net.load_state_dict(new_state_dict)
    net.eval()
    net.to(device)

    #  Read image and click the points
    if os.path.isfile(opt.image):
        images = [opt.image]
    else:
        images = sorted(glob.glob(opt.image + '/*.*'))
    if opt.anchor_points:
        tracks = sorted(glob.glob(opt.anchor_points + '/*.csv'))
        frames, X, Y = [], [], []
        for i in range(len(tracks)):
            f, x, y = np.loadtxt(tracks[i], delimiter=',', unpack=True)
            frames.append(f.tolist())
            X.append(x.tolist())
            Y.append(y.tolist())
        anchorPoints = []
        uframes = np.unique(np.hstack([np.array(a) for a in frames])).tolist()
        # print(uframes)
        for i in range(len(uframes)):
            extreme_points = []
            for j in range(len(frames)):
                try:
                    ind = frames[j].index(uframes[i])
                    extreme_points.append([X[j][ind], Y[j][ind]])
                except ValueError:
                    continue
            anchorPoints.append(np.array(extreme_points))

    for i, img in enumerate(images):

        if opt.use_frame_info and opt.anchor_points is not None:
            file_number = int(re.sub(r'\D', '', img))
            if file_number not in uframes:
                print(img, 'skipped')
                continue

        if opt.anchor_points is None:
            plt.figure()
        while True:
            image = np.array(Image.open(img))
            mask_path = os.path.join(opt.output, os.path.split(img)[1])
            if opt.anchor_points is None:
                plt.ion()
                plt.axis('off')
                plt.imshow(image)
                plt.title(
                    'Click the four extreme points of the objects\nHit enter/middle mouse button when done (do not close the window)'
                )

            results = []

            with torch.no_grad():
                # while 1:
                if opt.anchor_points:
                    if opt.use_frame_info:
                        try:
                            index = uframes.index(file_number)
                        except ValueError:
                            print(
                                'Could not find data for frame %i. Using frame %i instead.'
                                % (file_number, i))
                            index = i
                    else:
                        index = i
                    extreme_points_ori = anchorPoints[index].astype(int)
                else:
                    extreme_points_ori = np.array(
                        plt.ginput(opt.anchors, timeout=0)).astype(int)

                # print(extreme_points_ori,extreme_points_ori.shape)
                #  Crop image to the bounding box from the extreme points and resize
                bbox = helpers.get_bbox(image,
                                        points=extreme_points_ori,
                                        pad=pad,
                                        zero_pad=False)
                crop_image = helpers.crop_from_bbox(image, bbox, zero_pad=True)
                resize_image = helpers.fixed_resize(
                    crop_image, (512, 512)).astype(np.float32)

                #  Generate extreme point heat map normalized to image values
                extreme_points = extreme_points_ori - [
                    np.min(extreme_points_ori[:, 0]),
                    np.min(extreme_points_ori[:, 1])
                ] + [pad, pad]
                extreme_points = (
                    512 * extreme_points *
                    [1 / crop_image.shape[1],
                     1 / crop_image.shape[0]]).astype(int)
                extreme_heatmap = helpers.make_gt(resize_image,
                                                  extreme_points,
                                                  sigma=10)
                extreme_heatmap = helpers.cstm_normalize(extreme_heatmap, 255)

                #  Concatenate inputs and convert to tensor
                input_dextr = np.concatenate(
                    (resize_image, extreme_heatmap[:, :, np.newaxis]), axis=2)
                inputs = torch.from_numpy(
                    input_dextr.transpose((2, 0, 1))[np.newaxis, ...])

                # Run a forward pass
                inputs = inputs.to(device)
                outputs = net.forward(inputs)
                outputs = interpolate(outputs,
                                      size=(512, 512),
                                      mode='bilinear',
                                      align_corners=True)
                outputs = outputs.to(torch.device('cpu'))

                pred = np.transpose(outputs.data.numpy()[0, ...], (1, 2, 0))
                pred = 1 / (1 + np.exp(-pred))
                pred = np.squeeze(pred)
                result = helpers.crop2fullmask(pred,
                                               bbox,
                                               im_size=image.shape[:2],
                                               zero_pad=True,
                                               relax=pad) > thres

                results.append(result)
                # Plot the results
                plt.imshow(helpers.overlay_masks(image / 255, results))
                plt.plot(extreme_points_ori[:, 0], extreme_points_ori[:, 1],
                         'gx')

                if not opt.cut:
                    helpers.save_mask(results, mask_path)
                else:
                    Image.fromarray(
                        np.concatenate(
                            (image, 255 * result[..., None].astype(int)),
                            axis=2).astype(np.uint8)).save(mask_path, 'png')
                '''if len(extreme_points_ori) < 4:
                        if len(results) > 0:
                            helpers.save_mask(results, 'demo.png')
                            print('Saving mask annotation in demo.png and exiting...')
                        else:
                            print('Exiting...')
                        sys.exit()'''
            if opt.anchor_points is None:
                plt.close()
            if opt.corrections:
                if easygui.ynbox(image=mask_path):
                    break
            else:
                break
        print(img, 'done')
Example #9
File: dextr.py  Project: 4rshdeep/dextr-api
def dextr_helper(img_url=IMG_URL, extreme_pts=EXTREME_PTS):
    """
    @params 
    img_url - string containing url to the image
    extreme_pts - list of (x, y) extreme coordinate tuples 

    @returns tuple - (bbox, mask, pred)
    bbox (x_min, y_min, x_max, y_max) is the bounding box generated from he extreme points
    mask is a boolean numpy array indicating presence of instance 
    pred is the classification result 
    """

    response = requests.get(img_url)
    image = np.array(Image.open(io.BytesIO(response.content)))

    with torch.no_grad():
        extreme_points_ori = np.array(extreme_pts).astype(int)

        #  Crop image to the bounding box from the extreme points and resize
        bbox = helpers.get_bbox(image,
                                points=extreme_points_ori,
                                pad=pad,
                                zero_pad=True)
        crop_image = helpers.crop_from_bbox(image, bbox, zero_pad=True)
        resize_image = helpers.fixed_resize(crop_image,
                                            (512, 512)).astype(np.float32)

        class_prediction = get_prediction_numpy(crop_image)
        # print("Class Prediction is : {}".format(class_prediction))

        # this is the bounding box to return (with 0 padding)
        actual_bbox = helpers.get_bbox(image,
                                       points=extreme_points_ori,
                                       pad=0,
                                       zero_pad=True)

        #  Generate extreme point heat map normalized to image values
        extreme_points = extreme_points_ori - [
            np.min(extreme_points_ori[:, 0]),
            np.min(extreme_points_ori[:, 1])
        ] + [pad, pad]
        extreme_points = (
            512 * extreme_points *
            [1 / crop_image.shape[1], 1 / crop_image.shape[0]]).astype(int)
        extreme_heatmap = helpers.make_gt(resize_image,
                                          extreme_points,
                                          sigma=10)
        extreme_heatmap = helpers.cstm_normalize(extreme_heatmap, 255)

        #  Concatenate inputs and convert to tensor
        input_dextr = np.concatenate(
            (resize_image, extreme_heatmap[:, :, np.newaxis]), axis=2)
        inputs = torch.from_numpy(
            input_dextr.transpose((2, 0, 1))[np.newaxis, ...])

        # Run a forward pass
        inputs = inputs.to(device)
        outputs = net.forward(inputs)
        # F.upsample is deprecated in recent PyTorch; interpolate is equivalent
        outputs = interpolate(outputs,
                              size=(512, 512),
                              mode='bilinear',
                              align_corners=True)
        outputs = outputs.to(torch.device('cpu'))

        pred = np.transpose(outputs.data.numpy()[0, ...], (1, 2, 0))
        pred = 1 / (1 + np.exp(-pred))
        pred = np.squeeze(pred)

        # Here result is of the shape of image, where True implies that part should be in the segment
        result = helpers.crop2fullmask(
            pred, bbox, im_size=image.shape[:2], zero_pad=True,
            relax=pad) > thres

        return (actual_bbox, result, class_prediction)
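
A usage sketch for dextr_helper; IMG_URL and EXTREME_PTS are the module-level defaults referenced in the signature, assumed to be defined elsewhere in dextr.py:

# Hypothetical usage of dextr_helper defined above.
bbox, mask, pred = dextr_helper(img_url=IMG_URL, extreme_pts=EXTREME_PTS)
print('bbox:', bbox)               # (x_min, y_min, x_max, y_max)
print('mask pixels:', mask.sum())  # count of True pixels in the instance mask
print('class:', pred)              # classification result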
Example #10
def demo(net, image_path='ims/soccer.jpg'):
    pad = 50
    thres = 0.8
    #  Read image and click the points
    image = np.array(Image.open(image_path))
    plt.ion()
    plt.axis('off')
    plt.imshow(image)
    plt.title(
        'Click the four extreme points of the objects\nHit enter when done (do not close the window)'
    )
    results = []
    while True:
        extreme_points_ori = np.array(plt.ginput(4, timeout=0)).astype(int)
        begin = time()
        if extreme_points_ori.shape[0] < 4:
            if len(results) > 0:
                helpers.save_mask(results, 'demo.png')
                print('Saving mask annotation in demo.png and exiting...')
            else:
                print('Exiting...')
            sys.exit()

        #  Crop image to the bounding box from the extreme points and resize
        bbox = helpers.get_bbox(image,
                                points=extreme_points_ori,
                                pad=pad,
                                zero_pad=True)
        crop_image = helpers.crop_from_bbox(image, bbox, zero_pad=True)
        resize_image = helpers.fixed_resize(crop_image,
                                            (512, 512)).astype(np.float32)

        #  Generate extreme point heat map normalized to image values
        extreme_points = extreme_points_ori - [
            np.min(extreme_points_ori[:, 0]),
            np.min(extreme_points_ori[:, 1])
        ] + [pad, pad]
        extreme_points = (
            512 * extreme_points *
            [1 / crop_image.shape[1], 1 / crop_image.shape[0]]).astype(int)
        extreme_heatmap = helpers.make_gt(resize_image,
                                          extreme_points,
                                          sigma=10)
        extreme_heatmap = helpers.cstm_normalize(extreme_heatmap, 255)

        #  Concatenate inputs and convert to tensor
        input_dextr = np.concatenate(
            (resize_image, extreme_heatmap[:, :, np.newaxis]), axis=2)
        inputs = torch.from_numpy(
            input_dextr.transpose((2, 0, 1))[np.newaxis, ...])

        # Run a forward pass
        outputs = net.forward(inputs)
        # F.upsample is deprecated in recent PyTorch; interpolate is equivalent
        outputs = interpolate(outputs,
                              size=(512, 512),
                              mode='bilinear',
                              align_corners=True)
        outputs = torch.sigmoid(outputs)
        outputs = outputs.to(torch.device('cpu'))

        pred = np.transpose(outputs.data.numpy()[0, ...], (1, 2, 0))
        # torch.sigmoid was already applied to the outputs above
        pred = np.squeeze(pred)
        result = helpers.crop2fullmask(
            pred, bbox, im_size=image.shape[:2], zero_pad=True,
            relax=pad) > thres

        results.append(result)

        # Plot the results
        plt.imshow(helpers.overlay_masks(image / 255, results))
        plt.plot(extreme_points_ori[:, 0], extreme_points_ori[:, 1], 'gx')
        print('Time to plot: ', time() - begin, ' seconds.')
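
A setup sketch for calling demo(), mirroring the weight-loading code in Example #2 and Example #8; the checkpoint name dextr_pascal-sbd and the default image path come from those examples:

# Hypothetical setup for demo() defined above, reusing the loading pattern
# from the other examples.
net = resnet.resnet101(1, nInputChannels=4, classifier='psp')
state_dict = torch.load(
    os.path.join(Path.models_dir(), 'dextr_pascal-sbd.pth'),
    map_location=lambda storage, loc: storage)
net.load_state_dict(state_dict)
net.eval()
demo(net, image_path='ims/soccer.jpg')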