Esempio n. 1
0
def prepare_hippo_testing_images(main_image_dir,
                                 main_result_dir,
                                 target_res,
                                 padding_margin=85,
                                 delete_intermediate_files=True,
                                 path_freesurfer='/usr/local/freesurfer/',
                                 verbose=True,
                                 recompute=True):
    """Prepare multi-modal images of the left and right hippocampi at the target resolution, for every subject.

    Subjects are expected to be grouped into state subfolders (e.g. AD and healthy). For each subject, the files
    t1.mgz, t2.mgz and aseg.mgz are handed to preprocess_adni_hippo, and the results are written to a mirrored
    folder tree under main_result_dir.
    :param main_image_dir: path of the main directory holding the images to prepare for testing. Expected layout:
    main_image_dir/state_dir(AD or healthy)/subject_dir/images(t1.mgz, t2.mgz, and aseg.mgz)
    :param main_result_dir: path of the main directory where prepared images and labels will be written.
    Layout: main_result_dir/state_dir(AD or healthy)/subject_dir/<files written by preprocess_adni_hippo>
    (per the original documentation: hippo_left.nii.gz, hippo_right.nii.gz, hippo_left_aseg.nii.gz,
    hippo_right_aseg.nii.gz).
    :param target_res: resolution at which to resample the label maps and the images.
    Can be a number (isotropic resolution), a sequence, or a 1d numpy array.
    :param padding_margin: (optional) margin to add around the hippocampi when cropping.
    :param delete_intermediate_files: (optional) whether to delete temporary files. Default is True.
    :param path_freesurfer: (optional) path of the FreeSurfer home, used for mri_convert.
    :param verbose: (optional) whether to print the mri_convert output when resampling images.
    :param recompute: (optional) whether to recompute result files even if they already exist."""

    utils.mkdir(main_result_dir)

    # outer loop: one iteration per state subfolder (e.g. AD / healthy)
    for state in utils.list_subfolders(main_image_dir, whole_path=False):
        state_dir = os.path.join(main_image_dir, state)
        result_state_dir = os.path.join(main_result_dir, state)
        utils.mkdir(result_state_dir)

        # inner loop: one iteration per subject of the current state
        for subject in utils.list_subfolders(state_dir, whole_path=False):
            subject_dir = os.path.join(state_dir, subject)
            result_subject_dir = os.path.join(result_state_dir, subject)
            utils.mkdir(result_subject_dir)

            t1_path, t2_path, aseg_path = [
                os.path.join(subject_dir, file_name)
                for file_name in ('t1.mgz', 't2.mgz', 'aseg.mgz')
            ]

            preprocess_adni_hippo(t1_path,
                                  t2_path,
                                  aseg_path,
                                  result_subject_dir,
                                  target_res,
                                  padding_margin,
                                  remove=delete_intermediate_files,
                                  path_freesurfer=path_freesurfer,
                                  verbose=verbose,
                                  recompute=recompute)
Esempio n. 2
0
def run_validation_on_aseg_gt(list_supervised_model_dir,
                              list_aseg_gt_dir,
                              path_label_list,
                              recompute=False):
    """Evaluate the samseg validation segmentations of several models against aseg ground truths.

    For every model directory, each subfolder of <model_dir>/validation_samseg is compared with the matching
    ground-truth directory through dice_evaluation. The scores are saved in
    <model_dir>/validation/<subfolder name>/dice.npy, and the folders are created if needed.
    :param list_supervised_model_dir: list of model directories, each containing a 'validation_samseg' subfolder.
    :param list_aseg_gt_dir: list of ground-truth directories, one per model directory, in the same order.
    :param path_label_list: label list forwarded to dice_evaluation.
    :param recompute: (optional) whether to recompute the Dice scores even if dice.npy already exists.
    :raises ValueError: if the two directory lists do not have the same length (zip would otherwise silently
    drop models, or iterate over the characters of a single string)."""

    list_supervised_model_dir = list(list_supervised_model_dir)
    list_aseg_gt_dir = list(list_aseg_gt_dir)
    if len(list_supervised_model_dir) != len(list_aseg_gt_dir):
        raise ValueError(
            'list_supervised_model_dir and list_aseg_gt_dir should have the same length, had %d and %d'
            % (len(list_supervised_model_dir), len(list_aseg_gt_dir)))

    # loop over architectures
    for model_dir, gt_dir in zip(list_supervised_model_dir, list_aseg_gt_dir):

        main_samseg_validation_dir = os.path.join(model_dir, 'validation_samseg')
        main_aseg_validation_dir = os.path.join(model_dir, 'validation')
        utils.mkdir(main_aseg_validation_dir)

        # loop over models (one validation subfolder per model/epoch)
        for samseg_validation_subdir in utils.list_subfolders(main_samseg_validation_dir):

            # create the equivalent aseg subdir
            aseg_validation_subdir = os.path.join(
                main_aseg_validation_dir,
                os.path.basename(samseg_validation_subdir))
            utils.mkdir(aseg_validation_subdir)
            path_aseg_dice = os.path.join(aseg_validation_subdir, 'dice.npy')

            # compute dice with aseg gt, unless already done and recomputation is not requested
            if recompute or not os.path.isfile(path_aseg_dice):
                dice_evaluation(gt_dir, samseg_validation_subdir,
                                path_label_list, path_aseg_dice)
Esempio n. 3
0
def validation_on_dilated_lesions(normal_validation_dir,
                                  dilated_validation_dir,
                                  gt_dir,
                                  evaluation_labels,
                                  recompute=True):
    """Dilate the lesions of every validation segmentation, then re-evaluate its Dice scores.

    Each subfolder of normal_validation_dir is processed by dilate_lesions into a subfolder of the same name in
    dilated_validation_dir, and dice_evaluation then writes dice.npy inside that new subfolder.
    :param normal_validation_dir: directory whose subfolders hold the original validation segmentations.
    :param dilated_validation_dir: directory where the dilated segmentations and Dice scores are written.
    :param gt_dir: ground-truth directory forwarded to dice_evaluation.
    :param evaluation_labels: labels forwarded to dice_evaluation.
    :param recompute: (optional) forwarded to both dilate_lesions and dice_evaluation."""

    utils.mkdir(dilated_validation_dir)

    subdirs = utils.list_subfolders(normal_validation_dir)
    progress = utils.LoopInfo(len(subdirs), 5, 'validating', True)

    for position, subdir in enumerate(subdirs):
        progress.update(position)

        dilated_subdir = os.path.join(dilated_validation_dir,
                                      os.path.basename(subdir))
        dilate_lesions(subdir, dilated_subdir, recompute=recompute)

        dice_evaluation(gt_dir,
                        dilated_subdir,
                        evaluation_labels,
                        path_dice=os.path.join(dilated_subdir, 'dice.npy'),
                        recompute=recompute)
Esempio n. 4
0
def plot_validation_curves(list_net_validation_dirs,
                           fontsize=18,
                           size_max_circle=100,
                           skip_first_dice_row=True):
    """This function plots the validation curves of several networks, based on the results of validate_training().
    It takes as input a list of validation folders (one for each network), each containing subfolders with dice scores
    for the corresponding validated epoch.
    :param list_net_validation_dirs: list of all the validation folders of the trainings to plot.
    :param fontsize: (optional) fontsize used for the graph.
    :param size_max_circle: (optional) size of the marker for epochs achieving the best validation scores.
    :param skip_first_dice_row: (optional) whether to exclude the first row of each dice matrix (usually the
    background label) from the averaged score. Default is True."""

    # loop over architectures
    plt.figure()
    for net_val_dir in list_net_validation_dirs:

        net_name = os.path.basename(os.path.dirname(net_val_dir))
        list_epochs_dir = utils.list_subfolders(net_val_dir, whole_path=False)

        # loop over epochs
        list_net_dice_scores = list()
        list_epochs = list()
        for epoch_dir in list_epochs_dir:

            # keep epoch_dir as the bare subfolder name: the epoch number is parsed from it below, and parsing
            # the full path would also pick up any digits of net_val_dir (e.g. 'net2/.../epoch_010' -> 2010)
            path_epoch_dice = os.path.join(net_val_dir, epoch_dir, 'dice.npy')
            if os.path.isfile(path_epoch_dice):
                dice = np.load(path_epoch_dice)
                if skip_first_dice_row:
                    dice = dice[1:, :]
                list_net_dice_scores.append(np.mean(dice))
                list_epochs.append(int(re.sub('[^0-9]', '', epoch_dir)))

        # plot validation scores for current architecture
        if list_net_dice_scores:  # check that archi has been validated for at least 1 epoch
            list_net_dice_scores = np.array(list_net_dice_scores)
            list_epochs = np.array(list_epochs)
            max_score = np.max(list_net_dice_scores)
            epoch_max_score = list_epochs[np.argmax(list_net_dice_scores)]
            print('\n' + net_name)
            print('epoch max score: %d' % epoch_max_score)
            print('max score: %0.2f' % max_score)
            plt.plot(list_epochs, list_net_dice_scores, label=net_name)
            plt.scatter(epoch_max_score, max_score, s=size_max_circle)

    # finalise plot
    plt.tick_params(axis='both', labelsize=fontsize)
    plt.ylabel('Dice scores', fontsize=fontsize)
    plt.xlabel('Epochs', fontsize=fontsize)
    plt.title('Validation curves', fontsize=fontsize)
    plt.legend()
    plt.show()
Esempio n. 5
0
def validation_on_dilated_lesions(normal_validation_dir,
                                  dilated_validation_dir,
                                  gt_dir,
                                  evaluation_labels,
                                  recompute=True):
    """Dilate the lesions of every validation segmentation, then (re)compute its Dice scores.

    Each subfolder of normal_validation_dir is processed by dilate_lesions into a subfolder of the same name in
    dilated_validation_dir. Dice scores are written to dice.npy in that subfolder when the file is missing or
    when recompute is set.
    :param normal_validation_dir: directory whose subfolders hold the original validation segmentations.
    :param dilated_validation_dir: directory where the dilated segmentations and Dice scores are written.
    :param gt_dir: ground-truth directory forwarded to dice_evaluation.
    :param evaluation_labels: labels forwarded to dice_evaluation.
    :param recompute: (optional) whether to redo the dilation and Dice computation even if results exist."""

    utils.mkdir(dilated_validation_dir)

    subdirs = utils.list_subfolders(normal_validation_dir)
    n_subdirs = len(subdirs)

    for position, subdir in enumerate(subdirs):
        utils.print_loop_info(position, n_subdirs, 5)

        dilated_subdir = os.path.join(dilated_validation_dir,
                                      os.path.basename(subdir))
        dilate_lesions(subdir, dilated_subdir, recompute=recompute)

        # only evaluate when scores are missing or recomputation is forced
        path_dice = os.path.join(dilated_subdir, 'dice.npy')
        if recompute or not os.path.isfile(path_dice):
            dice_evaluation(gt_dir, dilated_subdir, evaluation_labels,
                            path_dice)
Esempio n. 6
0
def inter_rater_reproducibility_cross_val_exp(manual_seg_dir,
                                              ref_image_dir=None,
                                              recompute=True):
    """Leave-one-out agreement between the manual segmentations of several raters.

    Each subfolder of manual_seg_dir holds the segmentations of one subject. They are first resampled with
    nearest-neighbour interpolation onto a common reference (edit_volumes.mri_convert_images_in_dir). Then every
    segmentation, binarised as label > 0, is compared with the consensus of all the other raters. The consensus
    is the mean of their distance maps, thresholded at 0. Dice scores are saved as dice.npy (one row per subject)
    in a 'registered_to_t1' or 'realigned' folder next to manual_seg_dir.
    :param manual_seg_dir: directory with one subfolder of manual segmentations per subject.
    :param ref_image_dir: (optional) folder of reference images, paired with the subjects by listing order.
    If None, the first segmentation of each subject is used as reference.
    :param recompute: (optional) whether to recompute even if dice.npy already exists (also forwarded to
    mri_convert_images_in_dir)."""

    subject_dirs = utils.list_subfolders(manual_seg_dir)
    parent_dir = os.path.dirname(manual_seg_dir)

    # choose the output folder and the per-subject references
    if ref_image_dir is None:
        realigned_seg_dir = os.path.join(parent_dir, 'realigned')
        ref_images = [None] * len(subject_dirs)
    else:
        realigned_seg_dir = os.path.join(parent_dir, 'registered_to_t1')
        ref_images = utils.list_images_in_folder(ref_image_dir)
    utils.mkdir(realigned_seg_dir)
    path_dice = os.path.join(realigned_seg_dir, 'dice.npy')

    # nothing to do when results exist and recomputation is not requested
    if os.path.isfile(path_dice) and not recompute:
        return

    all_scores = list()
    for subject_dir, ref_image in zip(subject_dirs, ref_images):

        # resample all segmentations of the subject onto a single reference
        if ref_image is None:
            ref_image = utils.list_images_in_folder(subject_dir)[0]
        result_dir = os.path.join(realigned_seg_dir,
                                  os.path.basename(subject_dir))
        edit_volumes.mri_convert_images_in_dir(subject_dir,
                                               result_dir,
                                               interpolation='nearest',
                                               reference_dir=ref_image,
                                               same_reference=True,
                                               recompute=recompute)

        # load the resampled segmentations and stack their distance maps along a new last axis
        segs = [utils.load_volume(path) for path in utils.list_images_in_folder(result_dir)]
        stacked_maps = np.stack(
            [edit_volumes.compute_distance_map(seg, crop_margin=20) for seg in segs],
            axis=-1)
        rater_ids = np.arange(len(segs))

        # Dice of each rater against the consensus of the remaining raters
        subject_scores = list()
        for rater, seg in enumerate(segs):
            others_mean = np.mean(stacked_maps[..., rater_ids != rater], axis=-1)
            consensus = (others_mean > 0) * 1
            mask = (seg > 0) * 1
            overlap = 2 * np.sum(consensus * mask)
            subject_scores.append(overlap / (np.sum(consensus) + np.sum(mask)))
        all_scores.append(subject_scores)

    np.save(path_dice, np.array(all_scores))
Esempio n. 7
0
def plot_validation_curves(list_net_validation_dirs,
                           eval_indices=None,
                           skip_first_dice_row=True,
                           size_max_circle=100,
                           figsize=(11, 6),
                           fontsize=18,
                           remove_legend=False):
    """This function plots the validation curves of several networks, based on the results of validate_training().
    It takes as input a list of validation folders (one for each network), each containing subfolders with dice scores
    for the corresponding validated epoch.
    :param list_net_validation_dirs: list of all the validation folders of the trainings to plot.
    :param eval_indices: (optional) compute the average Dice on a subset of labels indicated by the specified
    indices. Can be a sequence, 1d numpy array, or the path to such an array.
    :param skip_first_dice_row: if eval_indices is None, skip the first row of the dice matrices (usually background)
    :param size_max_circle: (optional) size of the marker for epochs achieving the best validation scores.
    :param figsize: (optional) size of the figure to draw.
    :param fontsize: (optional) fontsize used for the graph.
    :param remove_legend: (optional) whether to omit the legend. Default is False (legend drawn)."""

    if eval_indices is not None:
        eval_indices = utils.reformat_to_list(eval_indices, load_as_numpy=True)

    # loop over architectures
    plt.figure(figsize=figsize)
    for net_val_dir in list_net_validation_dirs:

        net_name = os.path.basename(os.path.dirname(net_val_dir))
        list_epochs_dir = utils.list_subfolders(net_val_dir, whole_path=False)

        # loop over epochs
        list_net_dice_scores = list()
        list_epochs = list()
        for epoch_dir in list_epochs_dir:

            path_epoch_dice = os.path.join(net_val_dir, epoch_dir, 'dice.npy')
            if os.path.isfile(path_epoch_dice):
                dice = np.load(path_epoch_dice)
                if eval_indices is not None:
                    dice = dice[eval_indices, :]
                elif skip_first_dice_row:
                    dice = dice[1:, :]
                list_net_dice_scores.append(np.mean(dice))
                list_epochs.append(int(re.sub('[^0-9]', '', epoch_dir)))

        # plot validation scores for current architecture
        if list_net_dice_scores:  # check that archi has been validated for at least 1 epoch
            # order points by epoch number: subfolders are listed by name, so non-zero-padded names
            # (e.g. epoch_100 before epoch_20) would otherwise make the curve zigzag
            list_epochs, unique_idx = np.unique(np.array(list_epochs), return_index=True)
            list_net_dice_scores = np.array(list_net_dice_scores)[unique_idx]
            max_score = np.max(list_net_dice_scores)
            epoch_max_score = list_epochs[np.argmax(list_net_dice_scores)]
            print('\n' + net_name)
            print('epoch max score: %d' % epoch_max_score)
            print('max score: %0.3f' % max_score)
            plt.plot(list_epochs, list_net_dice_scores, label=net_name)
            plt.scatter(epoch_max_score, max_score, s=size_max_circle)

    # finalise plot
    plt.grid()
    plt.tick_params(axis='both', labelsize=fontsize)
    plt.ylabel('Dice scores', fontsize=fontsize)
    plt.xlabel('Epochs', fontsize=fontsize)
    plt.title('Validation curves', fontsize=fontsize)
    if not remove_legend:
        plt.legend(fontsize=fontsize)
    plt.tight_layout(pad=1)
    plt.show()
Esempio n. 8
0
def plot_validation_curves(list_validation_dirs,
                           architecture_names=None,
                           eval_indices=None,
                           skip_first_dice_row=True,
                           size_max_circle=100,
                           figsize=(11, 6),
                           y_lim=None,
                           fontsize=18,
                           list_linestyles=None,
                           list_colours=None,
                           plot_legend=False):
    """This function plots the validation curves of several networks, based on the results of validate_training().
    It takes as input a list of validation folders (one for each network), each containing subfolders with dice scores
    for the corresponding validated epoch.
    :param list_validation_dirs: list of all the validation folders of the trainings to plot.
    :param architecture_names: (optional) names of the curves, printed and used as legend labels. Defaults to the
    name of the parent folder of each validation directory.
    :param eval_indices: (optional) compute the average Dice on a subset of labels indicated by the specified indices.
    Can be a 1d numpy array, the path to such an array, or a list of 1d numpy arrays as long as list_validation_dirs.
    :param skip_first_dice_row: if eval_indices is None, skip the first row of the dice matrices (usually background)
    :param size_max_circle: (optional) size of the marker for epochs achieving the best validation scores.
    :param figsize: (optional) size of the figure to draw.
    :param y_lim: (optional) (bottom, top) limits of the y axis; 0.01 is added to top so the maximum stays visible.
    :param fontsize: (optional) fontsize used for the graph.
    :param list_linestyles: (optional) one matplotlib linestyle per curve, or a single one applied to all curves.
    :param list_colours: (optional) one matplotlib colour per curve, or a single one applied to all curves.
    :param plot_legend: (optional) False for no legend, True for a legend with all curves, or an integer k to only
    show the first k curves in the legend."""

    n_curves = len(list_validation_dirs)

    # reformat evaluation indices into one entry (array or None) per curve
    if eval_indices is None:
        eval_indices = [None] * n_curves
    elif isinstance(eval_indices, (np.ndarray, str)):
        if isinstance(eval_indices, str):
            eval_indices = np.load(eval_indices)
        eval_indices = np.squeeze(
            utils.reformat_to_n_channels_array(eval_indices,
                                               n_dims=len(eval_indices)))
        eval_indices = [eval_indices] * n_curves
    elif isinstance(eval_indices, list):
        if not all(isinstance(e, np.ndarray) for e in eval_indices):
            raise TypeError(
                'if provided as a list, eval_indices should only contain numpy arrays'
            )
        if len(eval_indices) != n_curves:
            raise ValueError(
                'if provided as a list, eval_indices should have one array per validation dir (%d), had %d'
                % (n_curves, len(eval_indices)))
        # build a new list instead of overwriting the caller's entries in place
        eval_indices = [
            np.squeeze(utils.reformat_to_n_channels_array(e, n_dims=len(e)))
            for e in eval_indices
        ]
    else:
        raise TypeError(
            'eval_indices can be a numpy array, a path to a numpy array, or a list of numpy arrays.'
        )

    # reformat model names
    if architecture_names is None:
        architecture_names = [
            os.path.basename(os.path.dirname(d)) for d in list_validation_dirs
        ]
    else:
        architecture_names = utils.reformat_to_list(architecture_names, n_curves)

    # prepare legend labels ('_nolegend_' keeps a curve out of the matplotlib legend)
    if plot_legend is False:
        list_legend_labels = ['_nolegend_'] * n_curves
    elif plot_legend is True:
        list_legend_labels = architecture_names
    else:
        list_legend_labels = [
            '_nolegend_' if i >= plot_legend else architecture_names[i]
            for i in range(n_curves)
        ]

    # prepare linestyles and colours, broadcasting a single value to every curve
    # (without the length, a single value became a 1-element list and zip() below plotted only the first curve)
    if list_linestyles is not None:
        list_linestyles = utils.reformat_to_list(list_linestyles, n_curves)
    else:
        list_linestyles = [None] * n_curves
    if list_colours is not None:
        list_colours = utils.reformat_to_list(list_colours, n_curves)
    else:
        list_colours = [None] * n_curves

    # loop over architectures
    plt.figure(figsize=figsize)
    for net_val_dir, net_name, linestyle, colour, legend_label, eval_idx in zip(
            list_validation_dirs, architecture_names, list_linestyles,
            list_colours, list_legend_labels, eval_indices):

        list_epochs_dir = utils.list_subfolders(net_val_dir, whole_path=False)

        # loop over epochs
        list_net_dice_scores = list()
        list_epochs = list()
        for epoch_dir in list_epochs_dir:

            path_epoch_dice = os.path.join(net_val_dir, epoch_dir, 'dice.npy')
            if os.path.isfile(path_epoch_dice):
                dice = np.load(path_epoch_dice)
                if eval_idx is not None:
                    dice = dice[eval_idx, :]
                elif skip_first_dice_row:
                    dice = dice[1:, :]
                list_net_dice_scores.append(np.mean(dice))
                list_epochs.append(int(re.sub('[^0-9]', '', epoch_dir)))

        # plot validation scores for current architecture
        if list_net_dice_scores:  # check that archi has been validated for at least 1 epoch
            # sort (and de-duplicate) by epoch number so the curve is drawn left to right
            list_epochs, unique_idx = np.unique(np.array(list_epochs), return_index=True)
            list_net_dice_scores = np.array(list_net_dice_scores)[unique_idx]
            max_score = np.max(list_net_dice_scores)
            epoch_max_score = list_epochs[np.argmax(list_net_dice_scores)]
            print('\n' + net_name)
            print('epoch max score: %d' % epoch_max_score)
            print('max score: %0.3f' % max_score)
            plt.plot(list_epochs,
                     list_net_dice_scores,
                     label=legend_label,
                     linestyle=linestyle,
                     color=colour)
            plt.scatter(epoch_max_score,
                        max_score,
                        s=size_max_circle,
                        color=colour)

    # finalise plot
    plt.grid()
    plt.tick_params(axis='both', labelsize=fontsize)
    plt.ylabel('Dice scores', fontsize=fontsize)
    plt.xlabel('Epochs', fontsize=fontsize)
    if y_lim is not None:
        plt.ylim(y_lim[0], y_lim[1] + 0.01)  # bottom/top limits of the y axis, with a little headroom on top
    plt.title('Validation curves', fontsize=fontsize)
    if plot_legend:
        plt.legend(fontsize=fontsize)
    plt.tight_layout(pad=1)
    plt.show()