Example #1
import os

import torch
import torch.nn as nn
import torch.optim as optim

# `nt` (network definitions) is a project-local module; the parameter
# dict `p` is assumed to be defined earlier in the script.

# Dataset construction flags
connected = True
from_same_vessel = False
bifurcations_allowed = False
save_vertices_indxs = False

# Setting other parameters
nEpochs = 2000  # Number of epochs for training
numHGScales = 4  # How many times to downsample inside each HourGlass
useTest = 1  # Track test-set performance during training?
testBatch = 1  # Test batch size
nTestInterval = 10  # Run on the test set every nTestInterval iterations
gpu_id = int(os.environ['SGE_GPU'])  # Select which GPU, -1 if CPU
snapshot = 200  # Store a model every snapshot epochs

# Network definition
net = nt.Net_SHG(p['numHG'], numHGScales, p['Block'], 128, 1)
if gpu_id >= 0:
    torch.cuda.set_device(device=gpu_id)
    net.cuda()

# Loss function definition
criterion = nn.MSELoss(reduction='mean')  # mean reduction (size_average=True in older PyTorch)

# Use the following optimizer
optimizer = optim.RMSprop(net.parameters(),
                          lr=0.00005,
                          alpha=0.99,
                          momentum=0.0)

# Preparation of the data loaders
# Define augmentation transformations as a composition
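The snippet ends before the composition is defined. A minimal sketch, assuming torchvision-style transforms (the original project most likely supplies its own transform classes) and the rotation/scaling ranges mentioned in the examples below:

# A sketch only: RandomAffine stands in for the project's own
# augmentation; the ranges match the useAug comment in the examples below.
import torchvision.transforms as transforms

composed_transforms = transforms.Compose([
    # Random rotations in [-30, 30] degrees and scaling in [0.75, 1.25]
    transforms.RandomAffine(degrees=30, scale=(0.75, 1.25)),
    transforms.ToTensor(),
])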
Example #2
import os

import numpy as np
import torch
from astropy.stats import sigma_clipped_stats
from astropy.table import Table
from photutils.detection import find_peaks
from PIL import Image
from torch.autograd import Variable

# `nt` (network definitions) and `tb` (model-naming helper) are
# project-local modules assumed to be imported elsewhere.


def get_most_confident_outputs(img_id, patch_center_row, patch_center_col,
                               confident_th, gpu_id, connected_same_vessel):

    patch_size = 64
    center = (patch_center_col, patch_center_row)

    x_tmp = int(center[0] - patch_size / 2)
    y_tmp = int(center[1] - patch_size / 2)

    confident_connections = {'x_peak': [], 'y_peak': [], 'peak_value': []}

    root_dir = './gt_dbs/DRIVE/'
    img = Image.open(
        os.path.join(root_dir, 'test', 'images', '%02d_test.tif' % img_id))
    img = np.array(img, dtype=np.float32)
    h, w = img.shape[:2]

    # Proceed only if the full patch lies inside the image
    if x_tmp > 0 and y_tmp > 0 and x_tmp + patch_size < w and y_tmp + patch_size < h:

        img_crop = img[y_tmp:y_tmp + patch_size, x_tmp:x_tmp + patch_size, :]

        img_crop = img_crop.transpose((2, 0, 1))  # HWC -> CHW
        img_crop = torch.from_numpy(img_crop)
        img_crop = img_crop.unsqueeze(0)  # add the batch dimension

        inputs = img_crop / 255 - 0.5  # scale to [-0.5, 0.5]

        # Forward pass of the mini-batch
        inputs = Variable(inputs)  # Variable is a no-op in PyTorch >= 0.4

        if gpu_id >= 0:
            inputs = inputs.cuda()

        p = {}
        p['useRandom'] = 1  # Shuffle Images
        p['useAug'] = 0  # Use Random rotations in [-30, 30] and scaling in [.75, 1.25]
        p['inputRes'] = (64, 64)  # Input Resolution
        p['outputRes'] = (64, 64)  # Output Resolution (same as input)
        p['g_size'] = 64  # Higher means narrower Gaussian
        p['trainBatch'] = 1  # Number of Images in each mini-batch
        p['numHG'] = 2  # Number of Stacked Hourglasses
        p['Block'] = 'ConvBlock'  # Select: 'ConvBlock', 'BasicBlock', 'BottleNeck'
        p['GTmasks'] = 0  # Use GT Vessel Segmentations as input instead of Retinal Images
        model_dir = './results_dir_vessels/'
        if connected_same_vessel:
            modelName = tb.construct_name(p, "HourGlass-connected-same-vessel")
        else:
            modelName = tb.construct_name(p, "HourGlass-connected")
        numHGScales = 4  # How many times to downsample inside each HourGlass
        net = nt.Net_SHG(p['numHG'], numHGScales, p['Block'], 128, 1)
        epoch = 1800
        net.load_state_dict(
            torch.load(os.path.join(model_dir,
                                    modelName + '_epoch-' + str(epoch) + '.pth'),
                       map_location=lambda storage, loc: storage))

        if gpu_id >= 0:
            net = net.cuda()

        output = net(inputs)
        # Last hourglass stage's heatmap, reshaped from (1, H, W) to (H, W)
        pred = np.squeeze(
            np.transpose(output[-1].cpu().data.numpy()[0, :, :, :], (1, 2, 0)))

        # Keep peaks at least 10 sigma above the sigma-clipped median response
        _, median, std = sigma_clipped_stats(pred, sigma=3.0)
        threshold = median + (10.0 * std)
        sources = find_peaks(pred, threshold, box_size=3)

        # Collect peaks from most to least confident, stopping at the
        # first one below the confidence threshold
        indxs = np.argsort(sources['peak_value'])[::-1]
        for idx in indxs:
            if sources['peak_value'][idx] > confident_th:
                confident_connections['x_peak'].append(sources['x_peak'][idx])
                confident_connections['y_peak'].append(sources['y_peak'][idx])
                confident_connections['peak_value'].append(
                    sources['peak_value'][idx])
            else:
                break

        confident_connections = Table(
            [confident_connections['x_peak'],
             confident_connections['y_peak'],
             confident_connections['peak_value']],
            names=('x_peak', 'y_peak', 'peak_value'))

    # If the patch fell outside the image, the plain dict of empty lists
    # (not an astropy Table) is returned.
    return confident_connections
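A hypothetical call on CPU; the image id, patch center, and threshold below are illustrative values, not taken from the original project:

# DRIVE test images are numbered 01-20, so img_id=1 loads 01_test.tif.
peaks = get_most_confident_outputs(img_id=1, patch_center_row=200,
                                   patch_center_col=300, confident_th=20,
                                   gpu_id=-1, connected_same_vessel=False)
print(peaks)  # astropy Table with x_peak, y_peak and peak_value columns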
Example #3
import os

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
from astropy.stats import sigma_clipped_stats
from astropy.table import Table
from photutils.detection import find_peaks
from PIL import Image
from skimage.morphology import skeletonize
from torch.autograd import Variable

# `nt` and `tb` are project-local modules assumed to be imported
# elsewhere; visualize_graph_step_by_step, visualize_evolution,
# mask_graph, iter and directory are globals expected to be set by
# the calling script (`iter` shadows the Python builtin).


def get_most_confident_outputs(img_filename, patch_center_row,
                               patch_center_col, confident_th, gpu_id):

    patch_size = 64
    center = (patch_center_col, patch_center_row)

    x_tmp = int(center[0] - patch_size / 2)
    y_tmp = int(center[1] - patch_size / 2)

    confident_connections = {'x_peak': [], 'y_peak': [], 'peak_value': []}

    root_dir = './gt_dbs/MassachusettsRoads/test/images/'
    img = Image.open(os.path.join(root_dir, img_filename))
    img = np.array(img, dtype=np.float32)
    h, w = img.shape[:2]

    # Proceed only if the full patch lies inside the image
    if x_tmp > 0 and y_tmp > 0 and x_tmp + patch_size < w and y_tmp + patch_size < h:

        img_crop = img[y_tmp:y_tmp + patch_size, x_tmp:x_tmp + patch_size, :]

        img_crop = img_crop.transpose((2, 0, 1))  # HWC -> CHW
        img_crop = torch.from_numpy(img_crop)
        img_crop = img_crop.unsqueeze(0)  # add the batch dimension

        inputs = img_crop / 255 - 0.5  # scale to [-0.5, 0.5]

        # Forward pass of the mini-batch
        inputs = Variable(inputs)  # Variable is a no-op in PyTorch >= 0.4

        if gpu_id >= 0:
            inputs = inputs.cuda()

        p = {}
        p['useRandom'] = 1  # Shuffle Images
        p['useAug'] = 0  # Use Random rotations in [-30, 30] and scaling in [.75, 1.25]
        p['inputRes'] = (64, 64)  # Input Resolution
        p['outputRes'] = (64, 64)  # Output Resolution (same as input)
        p['g_size'] = 64  # Higher means narrower Gaussian
        p['trainBatch'] = 1  # Number of Images in each mini-batch
        p['numHG'] = 2  # Number of Stacked Hourglasses
        p['Block'] = 'ConvBlock'  # Select: 'ConvBlock', 'BasicBlock', 'BottleNeck'
        p['GTmasks'] = 0  # Use GT Vessel Segmentations as input instead of Retinal Images
        model_dir = './results_dir/'
        modelName = tb.construct_name(p, "HourGlass")
        numHGScales = 4  # How many times to downsample inside each HourGlass
        net = nt.Net_SHG(p['numHG'], numHGScales, p['Block'], 128, 1)
        epoch = 130
        net.load_state_dict(
            torch.load(os.path.join(model_dir,
                                    modelName + '_epoch-' + str(epoch) + '.pth'),
                       map_location=lambda storage, loc: storage))

        if gpu_id >= 0:
            net = net.cuda()

        output = net(inputs)
        # Last hourglass stage's heatmap, reshaped from (1, H, W) to (H, W)
        pred = np.squeeze(
            np.transpose(output[-1].cpu().data.numpy()[0, :, :, :], (1, 2, 0)))

        # Keep peaks at least 10 sigma above the sigma-clipped median response
        _, median, std = sigma_clipped_stats(pred, sigma=3.0)
        threshold = median + (10.0 * std)
        sources = find_peaks(pred, threshold, box_size=3)

        # Optional step-by-step debug visualization, driven by the
        # module-level globals noted above
        if visualize_graph_step_by_step:
            fig, axes = plt.subplots(1, 2)
            axes[0].imshow(img.astype(np.uint8))
            mask_graph_skel = skeletonize(mask_graph > 0)
            indxs = np.argwhere(mask_graph_skel == 1)
            axes[0].scatter(indxs[:, 1], indxs[:, 0], color='red', marker='+')

            axes[0].add_patch(
                patches.Rectangle((x_tmp, y_tmp),
                                  patch_size,
                                  patch_size,
                                  fill=False,
                                  color='cyan',
                                  linewidth=5))
            img_crop_array = img[y_tmp:y_tmp + patch_size,
                                 x_tmp:x_tmp + patch_size, :]
            axes[1].imshow(img_crop_array.astype(np.uint8),
                           interpolation='nearest')
            tmp_vector_x = []
            tmp_vector_y = []
            for ii in range(0, len(sources['peak_value'])):
                if sources['peak_value'][ii] > confident_th:
                    tmp_vector_x.append(sources['x_peak'][ii])
                    tmp_vector_y.append(sources['y_peak'][ii])
            axes[1].plot(tmp_vector_x,
                         tmp_vector_y,
                         ls='none',
                         color='red',
                         marker='+',
                         ms=25,
                         markeredgewidth=10)
            axes[1].plot(32,
                         32,
                         ls='none',
                         color='cyan',
                         marker='+',
                         ms=25,
                         markeredgewidth=10)
            plt.show()

        # Optionally save snapshots of the growing graph, more often in
        # early iterations (`iter` is the global iteration counter)
        if visualize_evolution:

            if iter < 20 or (iter < 200 and iter % 20 == 0) or iter % 100 == 0:

                plt.figure(figsize=(12, 12), dpi=60)
                plt.imshow(img.astype(np.uint8))
                mask_graph_skeleton = skeletonize(mask_graph > 0)
                indxs_skel = np.argwhere(mask_graph_skeleton == 1)
                plt.scatter(indxs_skel[:, 1],
                            indxs_skel[:, 0],
                            color='red',
                            marker='+')
                plt.axis('off')
                plt.savefig(directory + 'iter_%05d.png' % iter,
                            bbox_inches='tight')
                plt.close()

        # Collect peaks from most to least confident, stopping at the
        # first one below the confidence threshold
        indxs = np.argsort(sources['peak_value'])[::-1]
        for idx in indxs:
            if sources['peak_value'][idx] > confident_th:
                confident_connections['x_peak'].append(sources['x_peak'][idx])
                confident_connections['y_peak'].append(sources['y_peak'][idx])
                confident_connections['peak_value'].append(
                    sources['peak_value'][idx])
            else:
                break

        confident_connections = Table(
            [confident_connections['x_peak'],
             confident_connections['y_peak'],
             confident_connections['peak_value']],
            names=('x_peak', 'y_peak', 'peak_value'))

    return confident_connections
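As with the vessel variant, a hypothetical call; the filename is a placeholder and the coordinates and threshold are illustrative:

# Usage sketch on CPU with the debug visualizations disabled. These two
# flags, like mask_graph, iter and directory, are module-level globals
# the function expects to exist.
visualize_graph_step_by_step = False
visualize_evolution = False
peaks = get_most_confident_outputs('example_image.tiff',
                                   patch_center_row=500,
                                   patch_center_col=500,
                                   confident_th=20, gpu_id=-1)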