def __call__(self, sample):
    """Add a point heat map (and optionally a binary position map) to ``sample``.

    The spatial size is taken from the first two dimensions of
    ``sample[self.mask_elem]``.  Every key listed in ``self.tr_elems`` whose
    value is not ``None`` contributes one heat map; contributions are merged
    with an element-wise maximum.  The merged map is stored under
    ``self.tr_name``; when ``self.return_pos`` is set, a map holding 1 at
    every (integer-truncated) point is stored under ``self.tr_name + '_pos'``.

    Returns the same ``sample`` dict, modified in place.
    """
    reference = sample[self.mask_elem]
    height, width = reference.shape[:2]
    heatmap = np.zeros((height, width))
    positions = np.zeros((height, width))

    for key in self.tr_elems:
        points = sample[key]
        if points is None:
            continue

        if self.approx:
            # Approximate transform -- noted as the faster option.
            contribution = helpers.gaussian_transform(
                reference, points, sigma=self.sigma)
        else:
            contribution = helpers.make_gt(
                reference, points, sigma=self.sigma,
                one_mask_per_point=False)
        heatmap = np.maximum(heatmap, contribution)

        if self.return_pos:
            # Column 0 indexes image columns (x), column 1 indexes rows (y).
            pixel_coords = points.astype(int)
            positions[pixel_coords[:, 1], pixel_coords[:, 0]] = 1

    sample[self.tr_name] = heatmap
    if self.return_pos:
        sample[self.tr_name + '_pos'] = positions
    return sample
def __call__(self, sample):
    """Store an extreme-point heat map under ``sample['extreme_points']``.

    ``sample[self.elem]`` is the object mask.  A 3-D mask (several objects
    per image) is rejected with ``ValueError``.  An all-zero mask produces an
    all-zero map of the same shape and dtype; otherwise the points returned
    by ``helpers.extreme_points`` are rendered with ``helpers.make_gt``.

    Returns the same ``sample`` dict, modified in place.
    """
    mask = sample[self.elem]
    if mask.ndim == 3:
        raise ValueError('ExtremePoints not implemented for multiple object per image.')

    if np.max(mask) == 0:
        # Empty mask: there are no points to render.
        # TODO: handle one_mask_per_point case
        heatmap = np.zeros(mask.shape, dtype=mask.dtype)
    else:
        points = helpers.extreme_points(mask, self.pert)
        heatmap = helpers.make_gt(mask, points, sigma=self.sigma,
                                  one_mask_per_point=False)

    sample['extreme_points'] = heatmap
    return sample
def get_mask(image, extreme_points_ori, pad=50, thres=0.8):
    """Segment one object in ``image`` from its extreme points.

    The 'dextr_pascal-sbd' weights are loaded from ``Path.models_dir()`` on
    every call, the image is cropped around the points, resized to 512x512,
    stacked with a point heat map and passed through the network.

    Args:
        image: image array; only ``image.shape[:2]`` and the pixel data are
            used here (presumably H x W x 3 -- confirm against callers).
        extreme_points_ori: N x 2 sequence/array of (x, y) points in image
            coordinates.
        pad: padding in pixels added around the points' bounding box.
        thres: threshold applied to the sigmoid output to binarise the mask.

    Returns:
        ``(results, bbox)`` where ``results`` is a one-element list holding
        the boolean full-size mask and ``bbox`` is the value returned by
        ``helpers.get_bbox``.
    """
    modelName = 'dextr_pascal-sbd'
    gpu_id = 0
    device = torch.device("cuda:" + str(gpu_id) if torch.cuda.is_available() else "cpu")
    weights_path = os.path.join(Path.models_dir(), modelName + '.pth')

    # Accept plain lists as well as arrays; the column slicing below needs
    # an ndarray.
    extreme_points_ori = np.asarray(extreme_points_ori)

    # Create the network and load the weights
    net = resnet.resnet101(1, nInputChannels=4, classifier='psp')
    print("Initializing weights from: {}".format(weights_path))
    state_dict_checkpoint = torch.load(weights_path,
                                       map_location=lambda storage, loc: storage)

    # Remove the prefix .module from the model when it is trained using DataParallel
    if 'module.' in list(state_dict_checkpoint.keys())[0]:
        prefix = 'module.'
        new_state_dict = OrderedDict()
        for k, v in state_dict_checkpoint.items():
            # Only strip a real leading prefix; blindly slicing k[7:] would
            # corrupt any key that does not start with it.
            name = k[len(prefix):] if k.startswith(prefix) else k
            new_state_dict[name] = v
    else:
        new_state_dict = state_dict_checkpoint
    net.load_state_dict(new_state_dict)
    net.eval()
    net.to(device)

    with torch.no_grad():
        results = []

        # Crop image to the bounding box from the extreme points and resize
        bbox = helpers.get_bbox(image, points=extreme_points_ori, pad=pad, zero_pad=True)
        crop_image = helpers.crop_from_bbox(image, bbox, zero_pad=True)
        resize_image = helpers.fixed_resize(crop_image, (512, 512)).astype(np.float32)

        # Generate extreme point heat map normalized to image values
        extreme_points = extreme_points_ori - [np.min(extreme_points_ori[:, 0]),
                                               np.min(extreme_points_ori[:, 1])] + [pad, pad]
        # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the
        # exact type it used to alias.
        extreme_points = (512 * extreme_points * [1 / crop_image.shape[1],
                                                  1 / crop_image.shape[0]]).astype(int)
        extreme_heatmap = helpers.make_gt(resize_image, extreme_points, sigma=10)
        extreme_heatmap = helpers.cstm_normalize(extreme_heatmap, 255)

        # Concatenate inputs and convert to tensor
        input_dextr = np.concatenate((resize_image, extreme_heatmap[:, :, np.newaxis]), axis=2)
        inputs = torch.from_numpy(input_dextr.transpose((2, 0, 1))[np.newaxis, ...])

        # Run a forward pass
        inputs = inputs.to(device)
        outputs = net.forward(inputs)
        outputs = interpolate(outputs, size=(512, 512), mode='bilinear', align_corners=True)
        outputs = outputs.to(torch.device('cpu'))

        # Logits -> probabilities, then paste the crop back into the full frame.
        pred = np.transpose(outputs.data.numpy()[0, ...], (1, 2, 0))
        pred = 1 / (1 + np.exp(-pred))
        pred = np.squeeze(pred)
        result = helpers.crop2fullmask(pred, bbox, im_size=image.shape[:2],
                                       zero_pad=True, relax=pad) > thres
        results.append(result)

    return results, bbox
def __call__(self, sample):
    """Store the IOG point heat map under ``sample['IOG_points']``.

    ``sample[self.elem]`` is the object mask.  A 3-D mask (several objects
    per image) is rejected with ``ValueError``.  An all-zero mask produces a
    zero array of shape ``(H, W, 2)`` with the mask's dtype; otherwise the
    points from ``helpers.iog_points`` are rendered with ``helpers.make_gt``.

    Returns the same ``sample`` dict, modified in place.
    """
    mask = sample[self.elem]
    if mask.ndim == 3:
        raise ValueError('IOGPoints not implemented for multiple object per image.')

    rows, cols = mask.shape[0], mask.shape[1]
    if np.max(mask) == 0:
        # Empty mask: emit two all-zero channels.
        # TODO: handle one_mask_per_point case
        point_map = np.zeros([rows, cols, 2], dtype=mask.dtype)
    else:
        points = helpers.iog_points(mask, self.pad_pixel)
        point_map = helpers.make_gt(mask, points, sigma=self.sigma,
                                    one_mask_per_point=False)

    sample['IOG_points'] = point_map
    return sample
def __call__(self, sample):
    """Store a point heat map under ``sample['extreme_points']``.

    ``sample[self.elem]`` is the object mask.  The points are produced
    according to ``self.type``:

    * ``'mask'``       -- ``helpers.get_mask_sample_points`` (fixed argument 50)
    * ``'normal'``     -- ``helpers.extreme_points`` with ``self.pert``
    * ``'bbox'``       -- ``helpers.get_bbox_sample_points`` with ``self.num_pts``
    * ``'polygon'``    -- polygon points, each shifted by a random offset
      drawn from ``[-self.pert, self.pert)`` on both axes
    * ``'mask_noise'`` -- ``helpers.get_mask_noise_sample_masks``

    The points are rendered with ``helpers.make_gt``.  When ``self.vis`` is
    truthy the points and the resulting map are shown with matplotlib.

    Raises:
        ValueError: if the mask is 3-D (several objects per image) or
            ``self.type`` is not one of the values above.

    Returns the same ``sample`` dict, modified in place.
    """
    _target = sample[self.elem]
    if _target.ndim == 3:
        raise ValueError(
            'ExtremePoints not implemented for multiple object per image.')

    if np.max(_target) == 0:
        # Empty mask: there are no points to render.
        # TODO: handle one_mask_per_point case
        sample['extreme_points'] = np.zeros(
            _target.shape, dtype=_target.dtype)
        return sample

    if self.type == 'mask':
        _points = helpers.get_mask_sample_points(_target, 50)
    elif self.type == 'normal':
        _points = helpers.extreme_points(_target, self.pert)
    elif self.type == 'bbox':
        _points = helpers.get_bbox_sample_points(_target, self.num_pts)
    elif self.type == 'polygon':
        _polygons = helpers.mask_to_poly(_target, visualize=False)
        _non_pert_points = helpers.get_polygon_points(
            _polygons, self.num_pts, _target.shape)
        if self.pert > 0:
            _pert_points = [
                point + (np.random.randint(-self.pert, self.pert),
                         np.random.randint(-self.pert, self.pert))
                for point in _non_pert_points
            ]
        else:
            # randint(-0, 0) has an empty range and raises; a perturbation
            # of zero simply means "use the points unchanged".
            _pert_points = _non_pert_points
        _points = np.array(_pert_points)
    elif self.type == 'mask_noise':
        _points = helpers.get_mask_noise_sample_masks(
            _target, self.num_pts, ratio=0.2)
    else:
        # Previously this fell through and failed later with an
        # UnboundLocalError on ``_points``.
        raise ValueError(
            'Unknown point sampling type: {!r}'.format(self.type))

    visualize = bool(self.vis)
    if visualize:
        # Imported lazily so that matplotlib is only required when
        # visualisation is actually requested.
        import matplotlib.pyplot as plt
        plt.imshow(_target)
        plt.scatter(_points[:, 0], _points[:, 1])
        plt.show()

    sample['extreme_points'] = helpers.make_gt(
        _target, _points, sigma=self.sigma, one_mask_per_point=False)

    if visualize:
        plt.imshow(sample['extreme_points'])
        plt.show()

    return sample
def __call__(self, sample):
    """Store the distance-map encoding under ``sample['distance_map']``.

    ``sample[self.elem]`` is the object mask.  A 3-D mask (several objects
    per image) is rejected with ``ValueError``.  An all-zero mask produces a
    zero array of shape ``(H, W, 2)`` with the mask's dtype; otherwise the
    result of ``helpers.GetDistanceMap`` is rendered with
    ``helpers.make_gt``.

    Returns the same ``sample`` dict, modified in place.
    """
    mask = sample[self.elem]
    if mask.ndim == 3:
        raise ValueError(
            'distance_map not implemented for multiple object per image.')

    rows, cols = mask.shape[0], mask.shape[1]
    if np.max(mask) == 0:
        # Empty mask: emit two all-zero channels.
        encoded = np.zeros([rows, cols, 2], dtype=mask.dtype)
    else:
        points = helpers.GetDistanceMap(mask, self.pad_pixel)
        encoded = helpers.make_gt(mask, points, sigma=self.sigma,
                                  one_mask_per_point=False)

    sample['distance_map'] = encoded
    return sample
def get_inputs(image, bbox, expt, pad):
    """Build the network input tensor for one object.

    The image is cropped to ``bbox``, resized to 512x512 and stacked with a
    heat map of the extreme points as an extra trailing channel.

    Args:
        image: image array passed to ``helpers.crop_from_bbox``.
        bbox: bounding box understood by ``helpers.crop_from_bbox``
            (presumably the one computed from ``expt`` with the same
            ``pad`` -- confirm against the caller).
        expt: N x 2 sequence/array of (x, y) extreme points in image
            coordinates.
        pad: padding in pixels that was applied around the points.

    Returns:
        A ``torch.Tensor`` with a leading batch axis of size 1 and channels
        first (the crop's channels followed by the heat map).
    """
    # Accept plain lists as well as arrays; the column slicing below needs
    # an ndarray.
    expt = np.asarray(expt)

    crop_image = helpers.crop_from_bbox(image, bbox, zero_pad=True)
    resize_image = helpers.fixed_resize(crop_image, (512, 512)).astype(np.float32)

    # Generate extreme point heat map normalized to image values
    extreme_points = expt - [np.min(expt[:, 0]), np.min(expt[:, 1])] + [pad, pad]
    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the exact
    # type it used to alias.
    extreme_points = (512 * extreme_points * [1 / crop_image.shape[1],
                                              1 / crop_image.shape[0]]).astype(int)
    extreme_heatmap = helpers.make_gt(resize_image, extreme_points, sigma=10)
    extreme_heatmap = helpers.cstm_normalize(extreme_heatmap, 255)

    # Concatenate inputs and convert to tensor (H, W, C) -> (1, C, H, W)
    input_dextr = np.concatenate(
        (resize_image, extreme_heatmap[:, :, np.newaxis]), axis=2)
    inputs = torch.tensor(input_dextr.transpose((2, 0, 1))[np.newaxis, ...])
    return inputs
pad=pad, zero_pad=True) crop_image = helpers.crop_from_bbox(image, bbox, zero_pad=True) resize_image = helpers.fixed_resize(crop_image, (512, 512)).astype(np.float32) # Generate extreme point heat map normalized to image values extreme_points = extreme_points_ori - [ np.min(extreme_points_ori[:, 0]), np.min(extreme_points_ori[:, 1]) ] + [pad, pad] extreme_points = ( 512 * extreme_points * [1 / crop_image.shape[1], 1 / crop_image.shape[0]]).astype(np.int) extreme_heatmap = helpers.make_gt(resize_image, extreme_points, sigma=10) extreme_heatmap = helpers.cstm_normalize(extreme_heatmap, 255) # Concatenate inputs and convert to tensor input_dextr = np.concatenate( (resize_image, extreme_heatmap[:, :, np.newaxis]), axis=2) inputs = torch.from_numpy( input_dextr.transpose((2, 0, 1))[np.newaxis, ...]) # Run a forward pass inputs = inputs.to(device) outputs = net.forward(inputs) outputs = upsample(outputs, size=(512, 512), mode='bilinear',
def main():
    """Command-line entry point: segment objects from extreme/anchor points.

    For every input image the points are either clicked interactively
    (matplotlib ``ginput``) or read from CSV track files
    (``--anchor-points``; each file holds ``frame,x,y`` rows).  The image is
    cropped around the points, run through the network and the thresholded
    mask is saved into ``--output`` -- either as a mask (``helpers.save_mask``)
    or, with ``--cut``, as the image with the mask appended as an extra
    channel.  With ``--corrections`` a dialog asks whether to accept each
    result; rejecting it repeats the image.
    """

    def _str2bool(value):
        # argparse's ``type=bool`` calls bool() on the raw string, so every
        # non-empty value -- including "False" -- would parse as True.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', 'off', ''):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got %r' % value)

    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--image', type=str, default='ims/dog-cat.jpg',
                        help='path to image')
    parser.add_argument('--model-name', type=str, default='dextr_pascal-sbd')
    parser.add_argument('-o', '--output', type=str, default='results',
                        help='path where results will be saved')
    parser.add_argument('--pad', type=int, default=50, help='padding size')
    parser.add_argument('--thres', type=float, default=.9)
    parser.add_argument('--gpu-id', type=int, default=0)
    parser.add_argument('--anchors', type=int, default=5,
                        help='amount of points to set')
    parser.add_argument(
        '--anchor-points', type=str, default=None,
        help='path to folder of anchor points (tracking points)')
    parser.add_argument(
        '--use-frame-info', type=_str2bool, default=True,
        help='wheter to use the frame number from the csv file or not')
    parser.add_argument('--corrections', action='store_true',
                        help='toggle popup message wheater to correct or not')
    parser.add_argument(
        '--cut', action='store_true',
        help='if used, will save the cutted image instead of the mask as png')
    opt = parser.parse_args()

    modelName = opt.model_name
    pad = opt.pad
    thres = opt.thres
    gpu_id = opt.gpu_id
    device = torch.device("cuda:" + str(gpu_id)
                          if torch.cuda.is_available() else "cpu")
    # One consistent notion of "anchor mode" (the original mixed
    # ``is not None`` and truthiness checks, which disagree for '').
    use_anchors = bool(opt.anchor_points)

    # Create the network and load the weights
    weights_path = os.path.join(Path.models_dir(), modelName + '.pth')
    net = resnet.resnet101(1, nInputChannels=4, classifier='psp')
    print("Initializing weights from: {}".format(weights_path))
    state_dict_checkpoint = torch.load(
        weights_path, map_location=lambda storage, loc: storage)

    # Remove the prefix .module from the model when it is trained using DataParallel
    if 'module.' in list(state_dict_checkpoint.keys())[0]:
        prefix = 'module.'
        new_state_dict = OrderedDict()
        for k, v in state_dict_checkpoint.items():
            # Only strip a real leading prefix; blindly slicing k[7:] would
            # corrupt any key that does not start with it.
            name = k[len(prefix):] if k.startswith(prefix) else k
            new_state_dict[name] = v
    else:
        new_state_dict = state_dict_checkpoint
    net.load_state_dict(new_state_dict)
    net.eval()
    net.to(device)

    # Collect the input images
    if os.path.isfile(opt.image):
        images = [opt.image]
    else:
        images = sorted(glob.glob(opt.image + '/*.*'))

    # Load the anchor tracks and regroup them per frame
    if use_anchors:
        tracks = sorted(glob.glob(opt.anchor_points + '/*.csv'))
        if not tracks:
            parser.error(
                'no .csv track files found in %r' % opt.anchor_points)
        frames, X, Y = [], [], []
        for track in tracks:
            f, x, y = np.loadtxt(track, delimiter=',', unpack=True)
            # A single-row file yields 0-d arrays; keep everything list-like.
            frames.append(np.atleast_1d(f).tolist())
            X.append(np.atleast_1d(x).tolist())
            Y.append(np.atleast_1d(y).tolist())

        uframes = np.unique(
            np.hstack([np.array(a) for a in frames])).tolist()
        anchorPoints = []
        for frame in uframes:
            frame_points = []
            for j, track_frames in enumerate(frames):
                try:
                    ind = track_frames.index(frame)
                except ValueError:
                    # This track has no point for the frame.
                    continue
                frame_points.append([X[j][ind], Y[j][ind]])
            anchorPoints.append(np.array(frame_points))

    # Make sure the destination exists before anything is saved into it.
    if opt.output:
        os.makedirs(opt.output, exist_ok=True)

    for i, img in enumerate(images):
        if use_anchors:
            if opt.use_frame_info:
                # Take the frame number from the file name only; digits in
                # directory names or the extension must not leak into it.
                stem = os.path.splitext(os.path.basename(img))[0]
                digits = re.sub(r'\D', '', stem)
                if not digits or int(digits) not in uframes:
                    print(img, 'skipped')
                    continue
                index = uframes.index(int(digits))
            else:
                index = i
                if index >= len(anchorPoints):
                    # More images than frames with anchor data.
                    print(img, 'skipped')
                    continue

        image = np.array(Image.open(img))
        mask_path = os.path.join(opt.output, os.path.split(img)[1])

        if not use_anchors:
            plt.figure()

        while True:
            if not use_anchors:
                plt.ion()
                plt.axis('off')
                plt.imshow(image)
                plt.title(
                    'Click the four extreme points of the objects\nHit enter/middle mouse button when done (do not close the window)'
                )

            results = []
            with torch.no_grad():
                # ``np.int`` was removed in NumPy 1.24; the builtin ``int``
                # is the exact type it used to alias.
                if use_anchors:
                    extreme_points_ori = anchorPoints[index].astype(int)
                else:
                    extreme_points_ori = np.array(
                        plt.ginput(opt.anchors, timeout=0)).astype(int)

                # Crop image to the bounding box from the extreme points and resize
                bbox = helpers.get_bbox(image,
                                        points=extreme_points_ori,
                                        pad=pad,
                                        zero_pad=False)
                crop_image = helpers.crop_from_bbox(image, bbox, zero_pad=True)
                resize_image = helpers.fixed_resize(
                    crop_image, (512, 512)).astype(np.float32)

                # Generate extreme point heat map normalized to image values
                extreme_points = extreme_points_ori - [
                    np.min(extreme_points_ori[:, 0]),
                    np.min(extreme_points_ori[:, 1])
                ] + [pad, pad]
                extreme_points = (
                    512 * extreme_points *
                    [1 / crop_image.shape[1], 1 / crop_image.shape[0]]
                ).astype(int)
                extreme_heatmap = helpers.make_gt(resize_image,
                                                  extreme_points,
                                                  sigma=10)
                extreme_heatmap = helpers.cstm_normalize(extreme_heatmap, 255)

                # Concatenate inputs and convert to tensor
                input_dextr = np.concatenate(
                    (resize_image, extreme_heatmap[:, :, np.newaxis]), axis=2)
                inputs = torch.from_numpy(
                    input_dextr.transpose((2, 0, 1))[np.newaxis, ...])

                # Run a forward pass
                inputs = inputs.to(device)
                outputs = net.forward(inputs)
                outputs = interpolate(outputs,
                                      size=(512, 512),
                                      mode='bilinear',
                                      align_corners=True)
                outputs = outputs.to(torch.device('cpu'))

                # Logits -> probabilities, then paste back into the full frame
                pred = np.transpose(outputs.data.numpy()[0, ...], (1, 2, 0))
                pred = 1 / (1 + np.exp(-pred))
                pred = np.squeeze(pred)
                result = helpers.crop2fullmask(pred,
                                               bbox,
                                               im_size=image.shape[:2],
                                               zero_pad=True,
                                               relax=pad) > thres
                results.append(result)

            # Plot the results
            plt.imshow(helpers.overlay_masks(image / 255, results))
            plt.plot(extreme_points_ori[:, 0], extreme_points_ori[:, 1], 'gx')

            if not opt.cut:
                helpers.save_mask(results, mask_path)
            else:
                # Append the mask (0/255) as an extra channel of the image.
                cut_image = np.concatenate(
                    (image, 255 * result[..., None].astype(int)), 2)
                Image.fromarray(cut_image.astype(np.uint8)).save(
                    mask_path, 'png')

            if not use_anchors:
                plt.close()

            if opt.corrections:
                # Repeat the image until the user accepts the result.
                if easygui.ynbox(image=mask_path):
                    break
            else:
                break

        print(img, 'done')
def dextr_helper(img_url=IMG_URL, extreme_pts=EXTREME_PTS):
    """Download an image and segment/classify the object given by extreme points.

    Relies on the module-level ``net``, ``device``, ``pad`` and ``thres``
    objects as well as ``get_prediction_numpy``.

    Args:
        img_url: string containing the URL of the image.
        extreme_pts: list of (x, y) extreme coordinate tuples.

    Returns:
        tuple ``(bbox, mask, pred)``:
            * ``bbox`` -- bounding box computed from the extreme points with
              no padding (presumably ``(x_min, y_min, x_max, y_max)`` --
              confirm against ``helpers.get_bbox``);
            * ``mask`` -- boolean numpy array of the image's height and
              width, True where the instance is present;
            * ``pred`` -- the classification result returned by
              ``get_prediction_numpy`` for the padded crop.

    Raises:
        requests.HTTPError: if the download returns an error status.
    """
    response = requests.get(img_url)
    # Fail with a clear HTTP error instead of trying to decode an error page
    # as an image.
    response.raise_for_status()
    image = np.array(Image.open(io.BytesIO(response.content)))

    with torch.no_grad():
        # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the
        # exact type it used to alias.
        extreme_points_ori = np.array(extreme_pts).astype(int)

        # Crop image to the bounding box from the extreme points and resize
        bbox = helpers.get_bbox(image,
                                points=extreme_points_ori,
                                pad=pad,
                                zero_pad=True)
        crop_image = helpers.crop_from_bbox(image, bbox, zero_pad=True)
        resize_image = helpers.fixed_resize(crop_image,
                                            (512, 512)).astype(np.float32)

        class_prediction = get_prediction_numpy(crop_image)

        # this is the bounding box to return (with 0 padding)
        actual_bbox = helpers.get_bbox(image,
                                       points=extreme_points_ori,
                                       pad=0,
                                       zero_pad=True)

        # Generate extreme point heat map normalized to image values
        extreme_points = extreme_points_ori - [
            np.min(extreme_points_ori[:, 0]),
            np.min(extreme_points_ori[:, 1])
        ] + [pad, pad]
        extreme_points = (
            512 * extreme_points *
            [1 / crop_image.shape[1], 1 / crop_image.shape[0]]).astype(int)
        extreme_heatmap = helpers.make_gt(resize_image,
                                          extreme_points,
                                          sigma=10)
        extreme_heatmap = helpers.cstm_normalize(extreme_heatmap, 255)

        # Concatenate inputs and convert to tensor
        input_dextr = np.concatenate(
            (resize_image, extreme_heatmap[:, :, np.newaxis]), axis=2)
        inputs = torch.from_numpy(
            input_dextr.transpose((2, 0, 1))[np.newaxis, ...])

        # Run a forward pass
        inputs = inputs.to(device)
        outputs = net.forward(inputs)
        outputs = upsample(outputs,
                           size=(512, 512),
                           mode='bilinear',
                           align_corners=True)
        outputs = outputs.to(torch.device('cpu'))

        # Logits -> probabilities
        pred = np.transpose(outputs.data.numpy()[0, ...], (1, 2, 0))
        pred = 1 / (1 + np.exp(-pred))
        pred = np.squeeze(pred)

        # Here result is of the shape of image, where True implies that part
        # should be in the segment
        result = helpers.crop2fullmask(
            pred, bbox, im_size=image.shape[:2], zero_pad=True,
            relax=pad) > thres

    return (actual_bbox, result, class_prediction)
def demo(net, image_path='ims/soccer.jpg'):
    """Interactive demo: click four extreme points per object and show the mask.

    The loop keeps asking for four clicks.  When fewer than four points are
    supplied, the masks collected so far (if any) are saved to 'demo.png'
    and the process exits through ``sys.exit()``, so this function does not
    return normally.

    Args:
        net: the segmentation network; it is called through ``net.forward``
            on a 1 x C x 512 x 512 float tensor.
        image_path: path of the image to annotate.
    """
    pad = 50
    thres = 0.8

    # Read image and click the points
    image = np.array(Image.open(image_path))
    plt.ion()
    plt.axis('off')
    plt.imshow(image)
    plt.title(
        'Click the four extreme points of the objects\nHit enter when done (do not close the window)'
    )

    # Run on whatever device the network lives on; a CUDA model cannot
    # consume a CPU tensor.
    first_param = next(net.parameters(), None)

    results = []
    while True:
        # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the
        # exact type it used to alias.
        extreme_points_ori = np.array(plt.ginput(4, timeout=0)).astype(int)
        begin = time()
        if extreme_points_ori.shape[0] < 4:
            if len(results) > 0:
                helpers.save_mask(results, 'demo.png')
                print('Saving mask annotation in demo.png and exiting...')
            else:
                print('Exiting...')
            sys.exit()

        # Crop image to the bounding box from the extreme points and resize
        bbox = helpers.get_bbox(image,
                                points=extreme_points_ori,
                                pad=pad,
                                zero_pad=True)
        crop_image = helpers.crop_from_bbox(image, bbox, zero_pad=True)
        resize_image = helpers.fixed_resize(crop_image,
                                            (512, 512)).astype(np.float32)

        # Generate extreme point heat map normalized to image values
        extreme_points = extreme_points_ori - [
            np.min(extreme_points_ori[:, 0]),
            np.min(extreme_points_ori[:, 1])
        ] + [pad, pad]
        extreme_points = (
            512 * extreme_points *
            [1 / crop_image.shape[1], 1 / crop_image.shape[0]]).astype(int)
        extreme_heatmap = helpers.make_gt(resize_image,
                                          extreme_points,
                                          sigma=10)
        extreme_heatmap = helpers.cstm_normalize(extreme_heatmap, 255)

        # Concatenate inputs and convert to tensor
        input_dextr = np.concatenate(
            (resize_image, extreme_heatmap[:, :, np.newaxis]), axis=2)
        inputs = torch.from_numpy(
            input_dextr.transpose((2, 0, 1))[np.newaxis, ...])
        if first_param is not None:
            inputs = inputs.to(first_param.device)

        # Run a forward pass; inference only, so skip autograd bookkeeping.
        with torch.no_grad():
            outputs = net.forward(inputs)
            outputs = upsample(outputs,
                               size=(512, 512),
                               mode='bilinear',
                               align_corners=True)
            outputs = torch.sigmoid(outputs)
        outputs = outputs.to(torch.device('cpu'))

        pred = np.transpose(outputs.data.numpy()[0, ...], (1, 2, 0))
        pred = np.squeeze(pred)
        result = helpers.crop2fullmask(
            pred, bbox, im_size=image.shape[:2], zero_pad=True,
            relax=pad) > thres
        results.append(result)

        # Plot the results
        plt.imshow(helpers.overlay_masks(image / 255, results))
        plt.plot(extreme_points_ori[:, 0], extreme_points_ori[:, 1], 'gx')
        print('Time to plot: ', time() - begin, ' seconds.')