connected = True
from_same_vessel = False
bifurcations_allowed = False
save_vertices_indxs = False

# Setting other parameters
nEpochs = 2000  # Number of epochs for training
numHGScales = 4  # How many times to downsample inside each HourGlass
useTest = 1  # See evolution of the test set when training?
testBatch = 1  # Testing Batch
nTestInterval = 10  # Run on test set every nTestInterval iterations
# Select which GPU, -1 if CPU. SGE_GPU is presumably exported by the SGE job
# scheduler; fall back to the CPU (-1) instead of raising KeyError when the
# script is run outside of it.
gpu_id = int(os.environ.get('SGE_GPU', -1))
snapshot = 200  # Store a model every snapshot epochs

# Network definition
# NOTE(review): `p` (with 'numHG' and 'Block') must be defined earlier in the
# script — it is not visible here.
net = nt.Net_SHG(p['numHG'], numHGScales, p['Block'], 128, 1)

if gpu_id >= 0:
    torch.cuda.set_device(device=gpu_id)
    net.cuda()

# Loss function definition
# NOTE(review): `size_average` is deprecated in PyTorch >= 0.4.1, where the
# equivalent is reduction='mean'. It is kept because this file also targets
# older releases (it uses torch.autograd.Variable).
criterion = nn.MSELoss(size_average=True)

# Use the following optimizer
optimizer = optim.RMSprop(net.parameters(), lr=0.00005, alpha=0.99,
                          momentum=0.0)

# Preparation of the data loaders
# Define augmentation transformations as a composition
def get_most_confident_outputs(img_id, patch_center_row, patch_center_col,
                               confident_th, gpu_id, connected_same_vessel):
    """Return the confident connection peaks predicted for a DRIVE patch.

    A 64x64 crop centred on (``patch_center_row``, ``patch_center_col``) is
    cut from DRIVE test image ``img_id``, scaled with ``/ 255 - 0.5`` and
    passed through a stacked-hourglass network restored from
    ``./results_dir_vessels/``. Local maxima of the last hourglass output
    above ``median + 10 * std`` (sigma-clipped statistics) are detected, and
    those whose value exceeds ``confident_th`` are kept.

    Args:
        img_id: Number of the DRIVE test image (file ``%02d_test.tif``).
        patch_center_row: Row of the patch centre in image coordinates.
        patch_center_col: Column of the patch centre in image coordinates.
        confident_th: Peaks whose value is not above this are discarded.
        gpu_id: CUDA device to run on; a negative value keeps everything on
            the CPU.
        connected_same_vessel: When true, load the
            "HourGlass-connected-same-vessel" model, otherwise the
            "HourGlass-connected" model.

    Returns:
        A ``Table`` with columns ``x_peak``, ``y_peak`` and ``peak_value``,
        ordered by decreasing ``peak_value``. Coordinates are those of the
        network output, i.e. relative to the crop. The table is empty when
        the crop does not fit inside the image or no peak is confident.
    """
    patch_size = 64
    x_tmp = int(patch_center_col - patch_size / 2)
    y_tmp = int(patch_center_row - patch_size / 2)

    x_peaks, y_peaks, peak_values = [], [], []

    root_dir = './gt_dbs/DRIVE/'
    img = Image.open(
        os.path.join(root_dir, 'test', 'images', '%02d_test.tif' % img_id))
    img = np.array(img, dtype=np.float32)
    h, w = img.shape[:2]

    # Only run the network when the whole crop lies inside the image
    # (negative offsets would silently wrap around under numpy slicing).
    # A crop that touches the border exactly is still valid.
    if (x_tmp >= 0 and y_tmp >= 0
            and x_tmp + patch_size <= w and y_tmp + patch_size <= h):
        img_crop = img[y_tmp:y_tmp + patch_size, x_tmp:x_tmp + patch_size, :]
        img_crop = img_crop.transpose((2, 0, 1))  # HWC -> CHW
        img_crop = torch.from_numpy(img_crop).unsqueeze(0)  # add batch dim
        inputs = img_crop / 255 - 0.5

        # Forward pass of the mini-batch
        inputs = Variable(inputs)
        if gpu_id >= 0:
            inputs = inputs.cuda()

        # These settings only serve to rebuild the name the model was saved
        # under; keep keys, values and insertion order unchanged.
        p = {}
        p['useRandom'] = 1  # Shuffle Images
        p['useAug'] = 0  # Use Random rotations in [-30, 30] and scaling in [.75, 1.25]
        p['inputRes'] = (64, 64)  # Input Resolution
        p['outputRes'] = (64, 64)  # Output Resolution (same as input)
        p['g_size'] = 64  # Higher means narrower Gaussian
        p['trainBatch'] = 1  # Number of Images in each mini-batch
        p['numHG'] = 2  # Number of Stacked Hourglasses
        p['Block'] = 'ConvBlock'  # Select: 'ConvBlock', 'BasicBlock', 'BottleNeck'
        p['GTmasks'] = 0  # Use GT Vessel Segmentations as input instead of Retinal Images

        model_dir = './results_dir_vessels/'
        if connected_same_vessel:
            modelName = tb.construct_name(p, "HourGlass-connected-same-vessel")
        else:
            modelName = tb.construct_name(p, "HourGlass-connected")
        numHGScales = 4  # How many times to downsample inside each HourGlass
        net = nt.Net_SHG(p['numHG'], numHGScales, p['Block'], 128, 1)
        epoch = 1800
        # Join model_dir only once: joining it twice with a relative
        # model_dir produced ".../results_dir_vessels/./results_dir_vessels/...".
        model_path = os.path.join(
            model_dir, modelName + '_epoch-' + str(epoch) + '.pth')
        # NOTE(review): the weights are re-read from disk on every call;
        # consider caching the network if this is called once per patch.
        net.load_state_dict(
            torch.load(model_path, map_location=lambda storage, loc: storage))
        if gpu_id >= 0:
            net = net.cuda()

        # NOTE(review): the network is never switched to eval() mode, so any
        # BatchNorm/Dropout layers in nt.Net_SHG run in training mode here.
        # confident_th was presumably tuned on this behaviour — confirm
        # before changing it.
        output = net.forward(inputs)
        # Keep only the prediction of the last stacked hourglass.
        pred = np.squeeze(
            np.transpose(output[-1].cpu().data.numpy()[0, :, :, :], (1, 2, 0)))

        _mean, median, std = sigma_clipped_stats(pred, sigma=3.0)
        threshold = median + (10.0 * std)
        sources = find_peaks(pred, threshold, box_size=3)

        # Newer photutils releases return None instead of an empty table when
        # no pixel exceeds the threshold; that simply means "no peaks".
        if sources is not None:
            peak_value = sources['peak_value']
            # Walk the peaks from strongest to weakest and stop at the first
            # one that is not confident enough.
            for idx in np.argsort(peak_value)[::-1]:
                if peak_value[idx] > confident_th:
                    x_peaks.append(sources['x_peak'][idx])
                    y_peaks.append(sources['y_peak'][idx])
                    peak_values.append(peak_value[idx])
                else:
                    break

    return Table([x_peaks, y_peaks, peak_values],
                 names=('x_peak', 'y_peak', 'peak_value'))
def get_most_confident_outputs(img_filename, patch_center_row,
                               patch_center_col, confident_th, gpu_id):
    """Return the confident connection peaks predicted for a road-image patch.

    A 64x64 crop centred on (``patch_center_row``, ``patch_center_col``) is
    cut from ``./gt_dbs/MassachusettsRoads/test/images/<img_filename>``,
    scaled with ``/ 255 - 0.5`` and passed through a stacked-hourglass network
    restored from ``./results_dir/``. Local maxima of the last hourglass
    output above ``median + 10 * std`` (sigma-clipped statistics) are
    detected, and those whose value exceeds ``confident_th`` are kept.

    The optional visualizations read module-level names that the calling
    script is assumed to define: ``visualize_graph_step_by_step``,
    ``visualize_evolution``, ``mask_graph``, ``iter`` (which shadows the
    builtin) and ``directory``.

    Args:
        img_filename: File name of the test image inside the images folder.
        patch_center_row: Row of the patch centre in image coordinates.
        patch_center_col: Column of the patch centre in image coordinates.
        confident_th: Peaks whose value is not above this are discarded.
        gpu_id: CUDA device to run on; a negative value keeps everything on
            the CPU.

    Returns:
        A ``Table`` with columns ``x_peak``, ``y_peak`` and ``peak_value``,
        ordered by decreasing ``peak_value``. Coordinates are those of the
        network output, i.e. relative to the crop. The table is empty when
        the crop does not fit inside the image or no peak is confident.
    """
    patch_size = 64
    x_tmp = int(patch_center_col - patch_size / 2)
    y_tmp = int(patch_center_row - patch_size / 2)

    x_peaks, y_peaks, peak_values = [], [], []

    root_dir = './gt_dbs/MassachusettsRoads/test/images/'
    img = Image.open(os.path.join(root_dir, img_filename))
    img = np.array(img, dtype=np.float32)
    h, w = img.shape[:2]

    # Only run the network when the whole crop lies inside the image
    # (negative offsets would silently wrap around under numpy slicing).
    # A crop that touches the border exactly is still valid.
    if (x_tmp >= 0 and y_tmp >= 0
            and x_tmp + patch_size <= w and y_tmp + patch_size <= h):
        img_crop = img[y_tmp:y_tmp + patch_size, x_tmp:x_tmp + patch_size, :]
        img_crop = img_crop.transpose((2, 0, 1))  # HWC -> CHW
        img_crop = torch.from_numpy(img_crop).unsqueeze(0)  # add batch dim
        inputs = img_crop / 255 - 0.5

        # Forward pass of the mini-batch
        inputs = Variable(inputs)
        if gpu_id >= 0:
            inputs = inputs.cuda()

        # These settings only serve to rebuild the name the model was saved
        # under; keep keys, values and insertion order unchanged.
        p = {}
        p['useRandom'] = 1  # Shuffle Images
        p['useAug'] = 0  # Use Random rotations in [-30, 30] and scaling in [.75, 1.25]
        p['inputRes'] = (64, 64)  # Input Resolution
        p['outputRes'] = (64, 64)  # Output Resolution (same as input)
        p['g_size'] = 64  # Higher means narrower Gaussian
        p['trainBatch'] = 1  # Number of Images in each mini-batch
        p['numHG'] = 2  # Number of Stacked Hourglasses
        p['Block'] = 'ConvBlock'  # Select: 'ConvBlock', 'BasicBlock', 'BottleNeck'
        p['GTmasks'] = 0  # Use GT segmentations as input instead of the images

        model_dir = './results_dir/'
        modelName = tb.construct_name(p, "HourGlass")
        numHGScales = 4  # How many times to downsample inside each HourGlass
        net = nt.Net_SHG(p['numHG'], numHGScales, p['Block'], 128, 1)
        epoch = 130
        # Join model_dir only once: joining it twice with a relative
        # model_dir produced "./results_dir/./results_dir/...".
        model_path = os.path.join(
            model_dir, modelName + '_epoch-' + str(epoch) + '.pth')
        # NOTE(review): the weights are re-read from disk on every call;
        # consider caching the network if this is called once per patch.
        net.load_state_dict(
            torch.load(model_path, map_location=lambda storage, loc: storage))
        if gpu_id >= 0:
            net = net.cuda()

        # NOTE(review): the network is never switched to eval() mode, so any
        # BatchNorm/Dropout layers in nt.Net_SHG run in training mode here.
        # confident_th was presumably tuned on this behaviour — confirm
        # before changing it.
        output = net.forward(inputs)
        # Keep only the prediction of the last stacked hourglass.
        pred = np.squeeze(
            np.transpose(output[-1].cpu().data.numpy()[0, :, :, :], (1, 2, 0)))

        _mean, median, std = sigma_clipped_stats(pred, sigma=3.0)
        threshold = median + (10.0 * std)
        sources = find_peaks(pred, threshold, box_size=3)

        # Newer photutils releases return None instead of an empty table when
        # no pixel exceeds the threshold; that simply means "no peaks".
        if sources is not None:
            peak_value = sources['peak_value']
            # Walk the peaks from strongest to weakest and stop at the first
            # one that is not confident enough.
            for idx in np.argsort(peak_value)[::-1]:
                if peak_value[idx] > confident_th:
                    x_peaks.append(sources['x_peak'][idx])
                    y_peaks.append(sources['y_peak'][idx])
                    peak_values.append(peak_value[idx])
                else:
                    break

        if visualize_graph_step_by_step:
            # Left: full image with the current graph skeleton and the crop
            # outline. Right: the crop with its confident peaks (red) and
            # the crop centre (cyan).
            _, axes = plt.subplots(1, 2)
            axes[0].imshow(img.astype(np.uint8))
            mask_graph_skel = skeletonize(mask_graph > 0)
            skel_idx = np.argwhere(mask_graph_skel == 1)
            axes[0].scatter(skel_idx[:, 1], skel_idx[:, 0],
                            color='red', marker='+')
            axes[0].add_patch(
                patches.Rectangle((x_tmp, y_tmp), patch_size, patch_size,
                                  fill=False, color='cyan', linewidth=5))
            img_crop_array = img[y_tmp:y_tmp + patch_size,
                                 x_tmp:x_tmp + patch_size, :]
            axes[1].imshow(img_crop_array.astype(np.uint8),
                           interpolation='nearest')
            axes[1].plot(x_peaks, y_peaks, ls='none', color='red',
                         marker='+', ms=25, markeredgewidth=10)
            axes[1].plot(patch_size // 2, patch_size // 2, ls='none',
                         color='cyan', marker='+', ms=25, markeredgewidth=10)
            plt.show()

        if visualize_evolution:
            # Save a snapshot for the first 20 iterations, then every 20th
            # up to 200, and every 100th iteration throughout.
            if iter < 20 or (iter < 200 and iter % 20 == 0) or iter % 100 == 0:
                plt.figure(figsize=(12, 12), dpi=60)
                plt.imshow(img.astype(np.uint8))
                mask_graph_skeleton = skeletonize(mask_graph > 0)
                indxs_skel = np.argwhere(mask_graph_skeleton == 1)
                plt.scatter(indxs_skel[:, 1], indxs_skel[:, 0],
                            color='red', marker='+')
                plt.axis('off')
                plt.savefig(directory + 'iter_%05d.png' % iter,
                            bbox_inches='tight')
                plt.close()

    return Table([x_peaks, y_peaks, peak_values],
                 names=('x_peak', 'y_peak', 'peak_value'))