def random_walk_epoch(hdf_file, beta, threshold, random_walk=True):
    '''
    Random walk the postprocessed predictions of the previous epoch.

    The seeds in ``hdf_file['postprocessed']`` and the training images from
    ``base_data.hdf5`` (same directory as ``hdf_file``) are passed to
    ``segment``. The result is written into ``hdf_file['random_walked']`` in
    batches of 20. When ``random_walk`` is False, the seeds are copied instead.

    Progress is checkpointed in the ``processed`` / ``processed_to``
    attributes of the ``random_walked`` dataset, so an interrupted run
    resumes after the last completed batch.

    :param hdf_file: Open h5py file of the current recursion. It is closed
                     and reopened by this function unless it has no
                     ``'postprocessed'`` entry.
    :param beta: ``beta`` passed to ``segment``
    :param threshold: ``threshold`` passed to ``segment``
    :param random_walk: If False, copy the seeds instead of random walking
    :return: The data file reopened in read-only mode, or the unchanged
             ``hdf_file`` if it has no ``'postprocessed'`` entry
    '''
    if 'postprocessed' not in hdf_file:
        logging.warning('Attempted to random walk for data file which '
                        'does not have postprocessed predictions from '
                        'previous epoch')
        return hdf_file

    # reopen data file in read/write mode so results can be written back
    data_fpath = hdf_file.filename
    hdf_file.close()
    hdf_file = h5py.File(data_fpath, 'r+')

    # get images from base data file (closed even if reading fails)
    base_data_fpath = os.path.join(os.path.dirname(hdf_file.filename),
                                   'base_data.hdf5')
    with h5py.File(base_data_fpath, 'r') as base_data_file:
        images = np.array(base_data_file['images_train'])

    # get scribble data as output of previous epoch
    seeds = hdf_file['postprocessed']
    random_walked = hdf_file['random_walked']

    # get checkpoint metadata; a missing 'processed_to' means no batch has
    # been completed yet, so start from the beginning
    processed = random_walked.attrs.get('processed')
    processed_to = random_walked.attrs.get('processed_to', 0)
    recursion = utils.get_recursion_from_hdf5(hdf_file)
    if not processed:
        # process in batches of 20
        # doesn't really make a time difference
        logging.info("Random walking for recursion {}".format(recursion))
        batch_size = 20
        num_seeds = len(seeds)
        for start in range(processed_to, num_seeds, batch_size):
            # clamp so the last (partial) batch is logged/checkpointed exactly
            stop = min(start + batch_size, num_seeds)
            if random_walk:
                logging.info(
                    'Random walking range {} to {} of recursion {}'.format(
                        start, stop - 1, recursion))
                random_walked[start:stop, ...] = segment(
                    images[start:stop, ...],
                    seeds[start:stop, ...],
                    beta=beta,
                    threshold=threshold)
            else:
                random_walked[start:stop, ...] = seeds[start:stop, ...]

            # checkpoint after every batch so a restart resumes here
            random_walked.attrs.modify('processed_to', stop)

        random_walked.attrs.modify('processed', True)

    # reopen in read-only mode
    hdf_file.close()
    return h5py.File(data_fpath, 'r')
# Example no. 2
0
def main(exp_config, batch_size=3, slices=None):
    '''
    Save poster images of the model's predictions for a few test slices.

    For each selected test slice, the following are written to
    ``base_path + "/poster/"``:
    - the grayscale image;
    - the ground-truth mask;
    - one prediction per recursion whose checkpoint could be restored. The
      largest connected components are kept, then the result is refined with
      the random walker (``segment``).

    :param exp_config: Experiment configuration (reads ``scribble_data``,
                       ``image_size``, ``model_handle``, ``nlabels``,
                       ``rw_beta`` and ``experiment_name``)
    :param batch_size: Kept for backward compatibility only. The number of
                       slices processed is always ``len(slices)``.
    :param slices: Test slice indices to visualise. Defaults to the fixed
                   poster slices ``[80, 275, 370]``. Indices are de-duplicated
                   and sorted because h5py list indexing requires increasing
                   indices.
    '''
    if slices is None:
        # Fixed slices used for the poster figures.
        slices = [80, 275, 370]
    slices = sorted(set(int(s) for s in slices))
    batch_size = len(slices)

    # Load only the selected slices, then release the file handle.
    with h5py.File(sys_config.project_root + exp_config.scribble_data,
                   'r') as data:
        images = data['images_test'][slices, ...]
        masks = data['masks_test'][slices, ...]

    num_recursions = most_recent_recursion(model_path)

    image_tensor_shape = [batch_size] + list(exp_config.image_size) + [1]
    images_pl = tf.placeholder(tf.float32,
                               shape=image_tensor_shape,
                               name='images')
    mask_pl, softmax_pl = model.predict(images_pl, exp_config.model_handle,
                                        exp_config.nlabels)
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    # Channel k holds the k-th recursion that could actually be restored
    # (channels are filled contiguously; missing recursions are skipped).
    predictions = np.zeros([batch_size] + list(exp_config.image_size) +
                           [num_recursions + 1])

    feed_dict = {
        images_pl: np.expand_dims(images, -1),
    }
    with tf.Session() as sess:
        sess.run(init)
        pred_size = 0
        for recursion in range(num_recursions + 1):
            try:
                # Prefer the best-Dice checkpoint, fall back to the plain one.
                try:
                    checkpoint_path = utils.get_latest_model_checkpoint_path(
                        model_path,
                        'recursion_{}_model_best_dice.ckpt'.format(recursion))
                except Exception:
                    checkpoint_path = utils.get_latest_model_checkpoint_path(
                        model_path,
                        'recursion_{}_model.ckpt'.format(recursion))
                saver.restore(sess, checkpoint_path)
                mask_out, _ = sess.run([mask_pl, softmax_pl],
                                       feed_dict=feed_dict)
                for img_idx in range(batch_size):
                    predictions[img_idx, ..., pred_size] = (
                        image_utils.keep_largest_connected_components(
                            np.squeeze(mask_out[img_idx, ...])))
                print("Classified for recursion {}".format(recursion))
                pred_size += 1
            except Exception as e:
                # Best effort: report and continue with the next recursion.
                print(e)

    num_recursions = pred_size
    # NOTE(review): this figure is never drawn into here; it is kept only in
    # case image_utils.print_* relies on pyplot state - confirm and drop.
    plt.figure()
    path = base_path + "/poster/"

    # Refine every restored prediction with the random walker.
    for recursion in range(num_recursions):
        predictions[..., recursion] = segment(
            images,
            np.squeeze(predictions[..., recursion]),
            beta=exp_config.rw_beta,
            threshold=0)

    name = exp_config.experiment_name
    for r in range(batch_size):
        image_utils.print_grayscale(
            np.squeeze(images[r, ...]), path,
            '{}_{}_image'.format(name, slices[r]))
        image_utils.print_coloured(
            np.squeeze(masks[r, ...]), path,
            '{}_{}_gt'.format(name, slices[r]))
        for recursion in range(num_recursions):
            image_utils.print_coloured(
                np.squeeze(predictions[r, ..., recursion]), path,
                '{}_{}_pred_r{}'.format(name, slices[r], recursion))

    # Show once (previously an infinite loop, so main never returned).
    plt.axis('off')
    plt.show()
# Example no. 3
0
def postprocess(mask_out, images_train, scribbles_train=None):
    '''
    Postprocesses predictions of CNN to create ground truths for recursion

    Steps run in this order. Each step runs only if the corresponding
    ``exp_config`` flag is set:

    - ``rw_intersection``: keep a label only where the prediction agrees
      with a full random walker segmentation of the scribbles.
    - ``rw_reversion``: per image and label, fall back to the scribble if the
      prediction covers fewer pixels than the scribble does.
    - ``keep_largest_cluster``: keep only the largest connected components.
    - ``smooth_edges``: Gaussian-blur and re-threshold each label. The
      largest scribble label is not re-drawn, so it becomes 0.

    :param mask_out: Array of predicted label maps, one per image along axis
                     0. It is copied; the caller's array is never modified.
    :param images_train: Numpy array of training images (only used by the
                         random walker in the ``rw_intersection`` step)
    :param scribbles_train: Numpy array of weakly annotated images with the
                            same leading dimension as ``mask_out``. Required
                            when ``rw_intersection`` or ``rw_reversion`` is
                            enabled.
    :return: Postprocessed label maps with the same shape as ``mask_out``
    :raises ValueError: if a step that needs scribbles is enabled but
                        ``scribbles_train`` is None
    '''
    # Work on a private copy so in-place steps never touch the caller's data.
    mask_out = np.array(mask_out, copy=True)
    # Iterate over the images actually given (the last batch may be smaller
    # than exp_config.batch_size).
    num_images = mask_out.shape[0]

    # get labels present (0 is treated as 'no label')
    if scribbles_train is None:
        if exp_config.rw_intersection or exp_config.rw_reversion:
            raise ValueError('scribbles_train is required when '
                             'rw_intersection or rw_reversion is enabled')
        labels = np.array([], dtype=int)
    else:
        labels = np.unique(scribbles_train)
        labels = labels[labels != 0]

    # use full segmentation of random walker as upper bound
    if exp_config.rw_intersection:
        rw_segmentation = random_walker.segment(images_train,
                                                scribbles_train,
                                                threshold=0,
                                                beta=exp_config.rw_beta)
        intersected = np.zeros_like(mask_out)
        for label in labels:
            intersected[(rw_segmentation == label)
                        & (mask_out == label)] = label
        mask_out = intersected

    # revert to original scribbles for 'bad' predictions
    if exp_config.rw_reversion:
        reverted = np.zeros_like(mask_out)
        for img_id in range(num_images):
            for label in labels:
                predicted = mask_out[img_id, ...] == label
                scribbled = scribbles_train[img_id, ...] == label
                # If the prediction has predicted less than the original
                # scribble, revert to the scribble
                if np.sum(predicted) < np.sum(scribbled):
                    reverted[img_id, scribbled] = label
                else:
                    reverted[img_id, predicted] = label
        mask_out = reverted

    # keep only largest cluster for output
    if exp_config.keep_largest_cluster:
        for img_id in range(num_images):
            mask_out[img_id, ...] = (
                image_utils.keep_largest_connected_components(
                    np.squeeze(mask_out[img_id, ...])))

    if exp_config.smooth_edges:
        # The largest label is excluded from smoothing and therefore ends up
        # as 0. NOTE(review): presumably it marks background in the
        # scribbles - confirm. Guard against an empty label set.
        if labels.size:
            labels = labels[labels != np.max(labels)]
        for img_id in range(num_images):
            mask = mask_out[img_id, ...]
            new_mask = np.zeros_like(mask)
            for label in labels:
                struct = (mask == label).astype(float)
                blurred_struct = gaussian_filter(
                    struct, sigma=exp_config.edge_smoother_sigma)
                # Binarise the blurred structure at the configured threshold.
                new_mask[blurred_struct >=
                         exp_config.edge_smoother_threshold] = label
            mask_out[img_id, ...] = new_mask

    return mask_out
# Example no. 4
0
def _latest_recursion_checkpoint(model_path, recursion):
    '''
    Return the checkpoint path to restore for ``recursion``.

    Tries these names in order and returns the first one found:
    ``recursion_<n>_model_best_xent.ckpt``,
    ``recursion_<n>_model_best_dice.ckpt``, then ``recursion_<n>_model.ckpt``.
    If none is found, the exception from the last lookup propagates.
    '''
    names = ('recursion_{}_model_best_xent.ckpt',
             'recursion_{}_model_best_dice.ckpt',
             'recursion_{}_model.ckpt')
    for name in names[:-1]:
        try:
            return utils.get_latest_model_checkpoint_path(
                model_path, name.format(recursion))
        except Exception:
            continue
    return utils.get_latest_model_checkpoint_path(
        model_path, names[-1].format(recursion))


def main(ws_exp_config, slices, test):
    '''
    Save the predictions of every recursion of a weakly supervised model.

    For each selected slice, the following are written to ``OUTPUT_FOLDER``:
    - the image and the ground truth;
    - the scribble (training split only);
    - for each recursion whose checkpoint could be restored: the raw
      prediction, the prediction with only the largest connected components,
      and that result refined with the random walker.
    The Dice score of the refined prediction is printed.

    :param ws_exp_config: Weak-supervision experiment configuration
    :param slices: Integer slice indices (array-like). Indices beyond the
                   chosen split are dropped. They must be increasing, because
                   h5py list indexing requires it.
    :param test: If True use the test split, otherwise the training split
    '''
    # Load data
    exp_dir = sys_config.project_root + 'acdc_logdir/' + ws_exp_config.experiment_name + '/'

    # Get number of recursions
    num_recursions = acdc_data.most_recent_recursion(
        sys_config.project_root + 'acdc_logdir/' +
        ws_exp_config.experiment_name)
    print(num_recursions)
    num_recursions += 1

    # Get images; the file is closed as soon as the slices are read
    slices = np.asarray(slices)
    scr = None
    with h5py.File(os.path.join(exp_dir, 'base_data.hdf5'), 'r') as base_data:
        if test:
            slices = slices[slices < len(base_data['images_test'])]
            images_key, gt_key = 'images_test', 'masks_test'
            prefix = 'test'
        else:
            slices = slices[slices < len(base_data['images_train'])]
            images_key, gt_key = 'images_train', 'masks_train'
            prefix = 'train'
        if len(slices) == 0:
            print('No valid slices for the {} split - nothing to do'.format(
                prefix))
            return
        images = base_data[images_key][slices, ...]
        gt = base_data[gt_key][slices, ...]
        if not test:
            scr = base_data['scribbles_train'][slices, ...]

    # Must be computed after filtering so it matches the images fed below.
    batch_size = len(slices)

    image_tensor_shape = [batch_size] + list(ws_exp_config.image_size) + [1]
    images_pl = tf.placeholder(tf.float32,
                               shape=image_tensor_shape,
                               name='images')
    feed_dict = {
        images_pl: np.expand_dims(images, -1),
    }

    # Get weak supervision predictions
    mask_pl, softmax_pl = model.predict(images_pl, ws_exp_config.model_handle,
                                        ws_exp_config.nlabels)
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()
    predictions = np.zeros([batch_size] + list(ws_exp_config.image_size) +
                           [num_recursions])
    predictions_klc = np.zeros_like(predictions)
    predictions_rw = np.zeros_like(predictions)
    with tf.Session() as sess:
        sess.run(init)
        for recursion in range(num_recursions):
            try:
                checkpoint_path = _latest_recursion_checkpoint(
                    ws_model_path, recursion)
                saver.restore(sess, checkpoint_path)
                mask_out, _ = sess.run([mask_pl, softmax_pl],
                                       feed_dict=feed_dict)
                predictions[..., recursion] = mask_out
                for i in range(batch_size):
                    predictions_klc[i, :, :, recursion] = (
                        image_utils.keep_largest_connected_components(
                            mask_out[i, ...]))

                predictions_rw[..., recursion] = segment(
                    images,
                    np.squeeze(predictions_klc[..., recursion]),
                    beta=ws_exp_config.rw_beta,
                    threshold=0)

                print("Classified for recursion {}".format(recursion))
            except Exception:
                # Fill with -1 so this recursion is skipped when saving.
                predictions[..., recursion] = -1
                print("Could not find checkpoint for recursion {} - skipping".
                      format(recursion))

    for i in range(batch_size):
        pref = '{}{}'.format(prefix, slices[i])

        print_grayscale(images[i, ...],
                        filepath=OUTPUT_FOLDER,
                        filename='{}_image'.format(pref))
        print_coloured(gt[i, ...],
                       filepath=OUTPUT_FOLDER,
                       filename='{}_gt'.format(pref))
        for recursion in range(num_recursions):
            # -1 marks a recursion that could not be evaluated
            if np.max(predictions[i, :, :, recursion]) < -0.5:
                continue
            print_coloured(predictions[i, :, :, recursion],
                           filepath=OUTPUT_FOLDER,
                           filename="{}_ws_pred_r{}".format(pref, recursion))
            print_coloured(predictions_klc[i, :, :, recursion],
                           filepath=OUTPUT_FOLDER,
                           filename="{}_ws_pred_klc_r{}".format(
                               pref, recursion))
            print_coloured(predictions_rw[i, :, :, recursion],
                           filepath=OUTPUT_FOLDER,
                           filename="{}_ws_pred_klc_rw_r{}".format(
                               pref, recursion))
            print("Dice coefficient for slice {} is {}".format(
                slices[i],
                dice(predictions_rw[i, :, :, recursion], gt[i, ...])))
        if not test:
            print_coloured(scr[i, ...],
                           filepath=OUTPUT_FOLDER,
                           filename='{}_scribble'.format(pref))