예제 #1
0
def surface_distances(x, y, hausdorff_percentile=1):
    """Computes the maximum boundary distance (Hausdorff distance), and the average boundary distance of two masks.
    :param x: numpy array (boolean or 0/1).
    :param y: numpy array (boolean or 0/1). Must have the same shape as x.
    :param hausdorff_percentile: (optional) fraction (between 0 and 1) of the sorted boundary distances at which the
    maximum distance is read. Default is 1, where the true maximum over both surfaces is returned.
    :return: max_dist, mean_dist
    Both values are set to the largest dimension of the inputs when no cropping region is found for x or y. When a
    surface turns out to be empty after cropping, the largest dimension of the cropped volume is used in place of the
    distances that cannot be measured (for the maximum as well as for the mean)."""

    assert x.shape == y.shape, 'both inputs should have same size, had {} and {}'.format(
        x.shape, y.shape)
    n_dims = len(x.shape)

    # crop x and y around ROI
    _, crop_x = edit_volumes.crop_volume_around_region(x)
    _, crop_y = edit_volumes.crop_volume_around_region(y)

    # set distances to maximum volume shape if they are not defined
    if (crop_x is None) or (crop_y is None):
        return max(x.shape), max(x.shape)

    # merge the two cropping boxes (smallest lower bounds, largest upper bounds) and crop both masks with the result
    # NOTE(review): presumably cropping indices are laid out as [min_0, ..., min_n, max_0, ..., max_n] - confirm
    crop = np.concatenate([
        np.minimum(crop_x, crop_y)[:n_dims],
        np.maximum(crop_x, crop_y)[n_dims:]
    ])
    x = edit_volumes.crop_volume_with_idx(x, crop)
    y = edit_volumes.crop_volume_with_idx(y, crop)

    # value used whenever a surface is empty and its distances cannot be measured
    undefined_dist = max(x.shape)

    # detect edge (foreground voxels at distance 1 from the background)
    x_dist_int = distance_transform_edt(x * 1)
    x_edge = (x_dist_int == 1) * 1
    y_dist_int = distance_transform_edt(y * 1)
    y_edge = (y_dist_int == 1) * 1

    # calculate distance from edge
    x_dist = distance_transform_edt(np.logical_not(x_edge))
    y_dist = distance_transform_edt(np.logical_not(y_edge))

    # find distances from the 2 surfaces
    x_dists_to_y = y_dist[x_edge == 1]
    y_dists_to_x = x_dist[y_edge == 1]

    # find max distance from the 2 surfaces (np.max and indexing both fail on empty arrays, hence the guards)
    if hausdorff_percentile == 1:
        x_max_dist_to_y = np.max(x_dists_to_y) if x_dists_to_y.shape[0] > 0 else undefined_dist
        y_max_dist_to_x = np.max(y_dists_to_x) if y_dists_to_x.shape[0] > 0 else undefined_dist
        max_dist = np.maximum(x_max_dist_to_y, y_max_dist_to_x)
    else:
        dists = np.sort(np.concatenate([x_dists_to_y, y_dists_to_x]))
        if dists.shape[0] > 0:
            idx_max = min(int(dists.shape[0] * hausdorff_percentile),
                          dists.shape[0] - 1)
            max_dist = dists[idx_max]
        else:
            max_dist = undefined_dist

    # find average distance between 2 surfaces
    if x_dists_to_y.shape[0] > 0:
        x_mean_dist_to_y = np.mean(x_dists_to_y)
    else:
        x_mean_dist_to_y = undefined_dist
    if y_dists_to_x.shape[0] > 0:
        y_mean_dist_to_x = np.mean(y_dists_to_x)
    else:
        y_mean_dist_to_x = undefined_dist
    mean_dist = (x_mean_dist_to_y + y_mean_dist_to_x) / 2

    return max_dist, mean_dist
예제 #2
0
def dice_evaluation(gt_dir,
                    seg_dir,
                    path_label_list,
                    path_result_dice_array):
    """Computes Dice scores for all labels contained in path_label_list. Files in gt_dir and seg_dir are matched by
    sorting order.
    :param gt_dir: folder containing ground truth files.
    :param seg_dir: folder containing evaluation files.
    :param path_label_list: path of numpy vector containing all labels to compute the Dice for.
    :param path_result_dice_array: path where the resulting Dice will be written as numpy array.
    :return: numpy array containing all dice scores (labels in rows, subjects in columns).
    """

    # create result folder (dirname is empty for a bare file name, in which case there is nothing to create;
    # makedirs also creates missing parent folders)
    result_dir = os.path.dirname(path_result_dice_array)
    if result_dir:
        os.makedirs(result_dir, exist_ok=True)

    # get list label maps to compare
    path_gt_labels = utils.list_images_in_folder(gt_dir)
    path_segs = utils.list_images_in_folder(seg_dir)
    if len(path_gt_labels) != len(path_segs):
        print('different number of files in data folders, had {} and {}'.format(len(path_gt_labels), len(path_segs)))

    # load labels list
    label_list, _ = utils.get_list_labels(label_list=path_label_list, FS_sort=True, labels_dir=gt_dir)
    label_list_sorted = np.sort(label_list)

    # initialise result matrix
    dice_coefs = np.zeros((label_list.shape[0], len(path_segs)))

    # loop over segmentations
    for idx, (path_gt, path_seg) in enumerate(zip(path_gt_labels, path_segs)):
        utils.print_loop_info(idx, len(path_segs), 10)

        # load gt labels and segmentation
        gt_labels = utils.load_volume(path_gt, dtype='int')
        seg = utils.load_volume(path_seg, dtype='int')
        # crop images around the ground truth, and apply the same cropping to the segmentation
        gt_labels, cropping = edit_volumes.crop_volume_around_region(gt_labels, margin=10)
        seg = edit_volumes.crop_volume_with_idx(seg, cropping)
        # compute dice scores on the sorted labels, then put them back in the order of label_list
        tmp_dice = fast_dice(gt_labels, seg, label_list_sorted)
        dice_coefs[:, idx] = tmp_dice[np.searchsorted(label_list_sorted, label_list)]

    # write dice results
    np.save(path_result_dice_array, dice_coefs)

    return dice_coefs
예제 #3
0
def postprocess(prediction, pad_shape, im_shape, crop, n_dims, labels,
                keep_biggest_component, aff):
    """Turns a network prediction into a label map and a posterior volume.
    :param prediction: network output. Singleton dimensions are squeezed out, and the last axis is taken as the
    label axis (argmax is computed over it).
    :param pad_shape: shape of the volume in which the predicted patch is pasted when crop is given.
    :param im_shape: shape to which the segmentation is cropped back when it differs from pad_shape.
    :param crop: cropping indices of the patch inside the padded volume (lower bounds first, then upper bounds),
    or None if the prediction already covers the whole volume.
    :param n_dims: number of spatial dimensions (pasting is done for 2 and 3 only).
    :param labels: numpy array mapping each output channel index to its label value.
    :param keep_biggest_component: whether to only keep the largest connected component of the channels 1 to 5, and
    of the channels above 5.
    :param aff: affine matrix given as reference when realigning the outputs (only used if n_dims > 2).
    :return: seg, posteriors
    """

    # get posteriors and segmentation
    patch_posteriors = np.squeeze(prediction)
    patch_seg = patch_posteriors.argmax(-1)

    # keep biggest connected component (use it with smoothing!)
    if keep_biggest_component:
        low_channels = (patch_seg > 0) & (patch_seg < 6)
        high_channels = patch_seg > 5
        kept = edit_volumes.get_largest_connected_component(low_channels)
        kept = kept | edit_volumes.get_largest_connected_component(high_channels)
        patch_seg *= kept

    # paste patches back to matrix of original image size
    if crop is None:
        seg = patch_seg
        posteriors = patch_posteriors
    else:
        seg = np.zeros(shape=pad_shape, dtype='int32')
        posteriors = np.zeros(shape=[*pad_shape, labels.shape[0]])
        posteriors[..., 0] = 1  # place background around patch

        # build the spatial window of the patch (nothing is pasted for other dimensionalities)
        if n_dims == 2:
            window = (slice(crop[0], crop[2]), slice(crop[1], crop[3]))
        elif n_dims == 3:
            window = (slice(crop[0], crop[3]),
                      slice(crop[1], crop[4]),
                      slice(crop[2], crop[5]))
        else:
            window = None
        if window is not None:
            seg[window] = patch_seg
            posteriors[window] = patch_posteriors

    # convert channel indices to label values
    seg = labels[seg.astype('int')].astype('int')

    # remove the padding, centred on the volume
    # NOTE(review): only seg is cropped here, posteriors keep the padded shape - confirm this is intended
    if im_shape != pad_shape:
        lower = [int((p - i) / 2) for (p, i) in zip(pad_shape, im_shape)]
        upper = [lo + i for (lo, i) in zip(lower, im_shape)]
        seg = edit_volumes.crop_volume_with_idx(seg, lower + upper)

    # align prediction back to first orientation
    if n_dims > 2:
        identity = np.eye(4)
        seg = edit_volumes.align_volume_to_ref(seg, identity, aff_ref=aff)
        posteriors = edit_volumes.align_volume_to_ref(posteriors,
                                                      np.eye(4),
                                                      aff_ref=aff,
                                                      n_dims=n_dims)

    return seg, posteriors
예제 #4
0
def surface_distances(x,
                      y,
                      hausdorff_percentile=None,
                      return_coordinate_max_distance=False):
    """Computes the maximum boundary distance (Hausdorff distance), and the average boundary distance of two masks.
    :param x: numpy array (boolean or 0/1)
    :param y: numpy array (boolean or 0/1). Must have the same shape as x.
    :param hausdorff_percentile: (optional) percentile (from 0 to 100) for which to compute the Hausdorff distance.
    Set this to 100 to compute the real Hausdorff distance (default). Can also be a list, where HD will be computed for
    the provided values.
    :param return_coordinate_max_distance: (optional) when set to true, the function will return the coordinates of the
    voxel(s) with the highest distance (only if 100 is among the values of hausdorff_percentile). The coordinates are
    expressed in the cropped volume.
    :return: max_dist, mean_dist(, coordinate_max_distance)
    max_dist: scalar with HD computed for the given percentile (or list if hausdorff_percentile was given as a list).
    mean_dist: scalar with average surface distance
    coordinate_max_distance: only returned if return_coordinate_max_distance is True and the coordinates could be
    computed. It is never returned when no cropping region is found for x or y: in this case the function returns two
    scalars set to the largest dimension of the inputs."""

    assert x.shape == y.shape, 'both inputs should have same size, had {} and {}'.format(
        x.shape, y.shape)
    n_dims = len(x.shape)

    hausdorff_percentile = 100 if hausdorff_percentile is None else hausdorff_percentile
    hausdorff_percentile = utils.reformat_to_list(hausdorff_percentile)

    # crop x and y around ROI
    _, crop_x = edit_volumes.crop_volume_around_region(x)
    _, crop_y = edit_volumes.crop_volume_around_region(y)

    # set distances to maximum volume shape if they are not defined
    if (crop_x is None) or (crop_y is None):
        return max(x.shape), max(x.shape)

    # merge the two cropping boxes (smallest lower bounds, largest upper bounds) and crop both masks with the result
    # NOTE(review): presumably cropping indices are laid out as [min_0, ..., min_n, max_0, ..., max_n] - confirm
    crop = np.concatenate([
        np.minimum(crop_x, crop_y)[:n_dims],
        np.maximum(crop_x, crop_y)[n_dims:]
    ])
    x = edit_volumes.crop_volume_with_idx(x, crop)
    y = edit_volumes.crop_volume_with_idx(y, crop)

    # detect edge (foreground voxels at distance 1 from the background)
    x_dist_int = distance_transform_edt(x * 1)
    x_edge = (x_dist_int == 1) * 1
    y_dist_int = distance_transform_edt(y * 1)
    y_edge = (y_dist_int == 1) * 1

    # calculate distance from edge
    x_dist = distance_transform_edt(np.logical_not(x_edge))
    y_dist = distance_transform_edt(np.logical_not(y_edge))

    # find distances from the 2 surfaces
    x_dists_to_y = y_dist[x_edge == 1]
    y_dists_to_x = x_dist[y_edge == 1]

    # distances of both surfaces pooled together (does not depend on the percentile, so computed once)
    all_dists = np.concatenate([x_dists_to_y, y_dists_to_x])

    max_dist = list()
    coordinate_max_distance = None
    for hd_percentile in hausdorff_percentile:

        # no surface voxel at all: distances cannot be measured, fall back on the largest dimension of the volume
        if all_dists.shape[0] == 0:
            max_dist.append(max(x.shape))

        # find max distance from the 2 surfaces
        elif hd_percentile == 100:
            current_max = np.max(all_dists)
            max_dist.append(current_max)

            if return_coordinate_max_distance:
                # compare against the scalar just computed, not against max_dist, which is a list that may already
                # contain the values obtained for other percentiles
                idx_max_distance_x = np.where(x_dists_to_y == current_max)[0]
                if idx_max_distance_x.size != 0:
                    indices_x_surface = np.where(x_edge == 1)
                    coordinate_max_distance = np.stack(
                        indices_x_surface).transpose()[idx_max_distance_x]
                else:
                    # the maximum is not reached on the surface of x, so it is reached on the surface of y
                    indices_y_surface = np.where(y_edge == 1)
                    idx_max_distance_y = np.where(y_dists_to_x == current_max)[0]
                    coordinate_max_distance = np.stack(
                        indices_y_surface).transpose()[idx_max_distance_y]

        # find percentile of max distance
        else:
            max_dist.append(np.percentile(all_dists, hd_percentile))

    # find average distance between 2 surfaces
    if x_dists_to_y.shape[0] > 0:
        x_mean_dist_to_y = np.mean(x_dists_to_y)
    else:
        x_mean_dist_to_y = max(x.shape)
    if y_dists_to_x.shape[0] > 0:
        y_mean_dist_to_x = np.mean(y_dists_to_x)
    else:
        y_mean_dist_to_x = max(x.shape)
    mean_dist = (x_mean_dist_to_y + y_mean_dist_to_x) / 2

    # convert max dist back to scalar if HD only computed for 1 percentile
    if len(max_dist) == 1:
        max_dist = max_dist[0]

    # return coordinate of max distance if necessary
    if coordinate_max_distance is not None:
        return max_dist, mean_dist, coordinate_max_distance
    else:
        return max_dist, mean_dist
예제 #5
0
def evaluation(gt_dir,
               seg_dir,
               label_list,
               mask_dir=None,
               compute_score_whole_structure=False,
               path_dice=None,
               path_hausdorff=None,
               path_hausdorff_99=None,
               path_hausdorff_95=None,
               path_mean_distance=None,
               crop_margin_around_gt=10,
               list_incorrect_labels=None,
               list_correct_labels=None,
               use_nearest_label=False,
               recompute=True,
               verbose=True):
    """This function computes Dice scores, as well as surface distances, between two sets of labels maps in gt_dir
    (ground truth) and seg_dir (typically predictions). Labels maps in both folders are matched by sorting order.
    The resulting scores are saved at the specified locations.
    :param gt_dir: path of directory with gt label maps
    :param seg_dir: path of directory with label maps to compare to gt_dir. Matched to gt label maps by sorting order.
    :param label_list: list of label values for which to compute evaluation metrics. Can be a sequence, a 1d numpy
    array, or the path to such array.
    :param mask_dir: (optional) path of directory with masks of areas to ignore for each evaluated segmentation.
    Matched to gt label maps by sorting order. Default is None, where nothing is masked.
    :param compute_score_whole_structure: (optional) whether to also compute the selected scores for the whole segmented
    structure (i.e. scores are computed for a single structure obtained by regrouping all non-zero values). If True, the
    resulting scores are added as an extra row to the result matrices. Default is False.
    :param path_dice: path where the resulting Dice will be written as numpy array.
    Default is None, where the array is not saved.
    :param path_hausdorff: path where the resulting Hausdorff distances will be written as numpy array.
    Default is None, where the array is neither computed nor saved.
    :param path_hausdorff_99: same as for path_hausdorff but for the 99th percentile of the boundary distance.
    :param path_hausdorff_95: same as for path_hausdorff but for the 95th percentile of the boundary distance.
    :param path_mean_distance: path where the resulting mean distances will be written as numpy array.
    Default is None, where the array is neither computed nor saved.
    :param crop_margin_around_gt: (optional) margin by which to crop around the gt volumes, in order to compute the
    scores more efficiently. If None, no cropping is performed.
    :param list_incorrect_labels: (optional) this option enables to replace some label values in the maps in seg_dir by
    other label values. Can be a list, a 1d numpy array, or the path to such an array.
    The incorrect labels can then be replaced either by specified values, or by the nearest value (see below).
    :param list_correct_labels: (optional) list of values to correct the labels specified in list_incorrect_labels.
    Correct values must have the same order as their corresponding value in list_incorrect_labels.
    :param use_nearest_label: (optional) whether to correct the incorrect label values with the nearest labels.
    :param recompute: (optional) whether to recompute the already existing results. Default is True. When False, a
    result file that already exists is left untouched.
    :param verbose: (optional) whether to print out info about the remaining number of cases.
    """

    def _must_compute(path):
        # a metric is requested when a path is given for it, and its file is either missing or to be recomputed
        return (path is not None) and (recompute or not os.path.isfile(path))

    # check what to (re)compute. Dice scores are always computed when no path is given for them.
    compute_dice = (path_dice is None) or _must_compute(path_dice)
    compute_hausdorff = _must_compute(path_hausdorff)
    compute_hausdorff_99 = _must_compute(path_hausdorff_99)
    compute_hausdorff_95 = _must_compute(path_hausdorff_95)
    compute_mean_dist = _must_compute(path_mean_distance)
    compute_hd = [
        compute_hausdorff, compute_hausdorff_99, compute_hausdorff_95
    ]
    compute_distances = any(compute_hd) or compute_mean_dist

    if compute_dice or compute_distances:

        # get list label maps to compare
        path_gt_labels = utils.list_images_in_folder(gt_dir)
        path_segs = utils.list_images_in_folder(seg_dir)
        path_gt_labels = utils.reformat_to_list(path_gt_labels,
                                                length=len(path_segs))
        if len(path_gt_labels) != len(path_segs):
            print(
                'gt and segmentation folders must have the same amount of label maps.'
            )
        if mask_dir is not None:
            path_masks = utils.list_images_in_folder(mask_dir)
            if len(path_masks) != len(path_segs):
                print('not the same amount of masks and segmentations.')
        else:
            path_masks = [None] * len(path_segs)

        # load labels list
        label_list, _ = utils.get_list_labels(label_list=label_list,
                                              FS_sort=True,
                                              labels_dir=gt_dir)
        n_labels = len(label_list)
        # value given to masked voxels: not in label_list, so these voxels are ignored by the label-wise scores
        max_label = np.max(label_list) + 1

        # initialise result matrices (one extra row for the whole structure if required)
        n_rows = n_labels + 1 if compute_score_whole_structure else n_labels
        max_dists = np.zeros((n_rows, len(path_segs), 3))
        mean_dists = np.zeros((n_rows, len(path_segs)))
        dice_coefs = np.zeros((n_rows, len(path_segs)))

        # loop over segmentations
        loop_info = utils.LoopInfo(len(path_segs),
                                   10,
                                   'evaluating',
                                   print_time=True)
        for idx, (path_gt, path_seg, path_mask) in enumerate(
                zip(path_gt_labels, path_segs, path_masks)):
            if verbose:
                loop_info.update(idx)

            # load gt labels and segmentation
            gt_labels = utils.load_volume(path_gt, dtype='int')
            seg = utils.load_volume(path_seg, dtype='int')
            if path_mask is not None:
                mask = utils.load_volume(path_mask, dtype='bool')
                gt_labels[mask] = max_label
                seg[mask] = max_label

            # crop images
            if crop_margin_around_gt is not None:
                gt_labels, cropping = edit_volumes.crop_volume_around_region(
                    gt_labels, margin=crop_margin_around_gt)
                seg = edit_volumes.crop_volume_with_idx(seg, cropping)

            if list_incorrect_labels is not None:
                seg = edit_volumes.correct_label_map(seg,
                                                     list_incorrect_labels,
                                                     list_correct_labels,
                                                     use_nearest_label)

            # compute Dice scores
            dice_coefs[:n_labels, idx] = fast_dice(gt_labels, seg, label_list)

            # compute Dice scores for whole structures
            if compute_score_whole_structure:
                temp_gt = (gt_labels > 0) * 1
                temp_seg = (seg > 0) * 1
                dice_coefs[-1, idx] = dice(temp_gt, temp_seg)
            else:
                temp_gt = temp_seg = None

            # compute average and Hausdorff distances
            if compute_distances:

                # compute unique label values
                unique_gt_labels = np.unique(gt_labels)
                unique_seg_labels = np.unique(seg)

                # compute max/mean surface distances for all labels
                for index, label in enumerate(label_list):
                    if (label in unique_gt_labels) and (label in unique_seg_labels):
                        mask_gt = gt_labels == label
                        mask_seg = seg == label
                        tmp_max_dists, mean_dists[index, idx] = surface_distances(
                            mask_gt, mask_seg, [100, 99, 95])
                        max_dists[index, idx, :] = np.array(tmp_max_dists)
                    else:
                        # label missing from one of the maps: use the largest dimension of the volume
                        mean_dists[index, idx] = max(gt_labels.shape)
                        max_dists[index, idx, :] = np.array(
                            [max(gt_labels.shape)] * 3)

                # compute max/mean distances for whole structure
                if compute_score_whole_structure:
                    tmp_max_dists, mean_dists[-1, idx] = surface_distances(
                        temp_gt, temp_seg, [100, 99, 95])
                    max_dists[-1, idx, :] = np.array(tmp_max_dists)

        # write results (only the requested ones, so that existing files are not overwritten when recompute is False)
        if compute_dice and (path_dice is not None):
            utils.mkdir(os.path.dirname(path_dice))
            np.save(path_dice, dice_coefs)
        if compute_hausdorff:
            utils.mkdir(os.path.dirname(path_hausdorff))
            np.save(path_hausdorff, max_dists[..., 0])
        if compute_hausdorff_99:
            utils.mkdir(os.path.dirname(path_hausdorff_99))
            np.save(path_hausdorff_99, max_dists[..., 1])
        if compute_hausdorff_95:
            utils.mkdir(os.path.dirname(path_hausdorff_95))
            np.save(path_hausdorff_95, max_dists[..., 2])
        if compute_mean_dist:
            utils.mkdir(os.path.dirname(path_mean_distance))
            np.save(path_mean_distance, mean_dists)
예제 #6
0
def prepare_hippo_training_atlases(labels_dir,
                                   result_dir,
                                   image_dir=None,
                                   image_result_dir=None,
                                   smooth=True,
                                   crop_margin=50,
                                   recompute=True):
    """This function prepares training label maps from CobraLab. It first crops each atlas around the right and left
    hippocampi, with a margin (the right crops are flipped). It then equalises the shape of these atlases by cropping
    them to the smallest shape obtained at the first step. Finally it realigns the obtained atlases, corrects their
    label values, and optionally smooths them.
    :param labels_dir: path of directory with label maps to prepare
    :param result_dir: path of directory where prepared atlases will be written. Intermediate crops are written in a
    'first_cropping' sub-folder.
    :param image_dir: (optional) path of directory with images corresponding to the label maps to prepare.
    This can be used to prepare a dataset of real images for supervised training.
    :param image_result_dir: (optional) path of directory where images corresponding to prepared atlases will be
    written. Must be provided if image_dir is specified.
    :param smooth: (optional) whether to smooth the final cropped label maps
    :param crop_margin: (optional) margin to add around hippocampi when cropping
    :param recompute: (optional) whether to recompute result files even if they already exist"""

    # create results dir, and the sub-folders holding the intermediate crops
    if not os.path.exists(result_dir):
        os.mkdir(result_dir)
    tmp_result_dir = os.path.join(result_dir, 'first_cropping')
    if not os.path.exists(tmp_result_dir):
        os.mkdir(tmp_result_dir)
    if image_dir is not None:
        assert image_result_dir is not None, 'image_result_dir should not be None if image_dir is specified'
        if not os.path.exists(image_result_dir):
            os.mkdir(image_result_dir)
        tmp_image_result_dir = os.path.join(image_result_dir, 'first_cropping')
        if not os.path.exists(tmp_image_result_dir):
            os.mkdir(tmp_image_result_dir)
    else:
        tmp_image_result_dir = None

    # list labels and images (images are replaced by None placeholders when no image_dir is given)
    labels_paths = utils.list_images_in_folder(labels_dir)
    if image_dir is not None:
        path_images = utils.list_images_in_folder(image_dir)
    else:
        path_images = [None] * len(labels_paths)

    # crop all atlases around hippo
    print('\ncropping around hippo')
    # two rows per atlas (left crop, then right crop), storing the shape of each crop
    shape_array = np.zeros((len(labels_paths)*2, 3))
    for idx, (path_label, path_image) in enumerate(zip(labels_paths, path_images)):
        utils.print_loop_info(idx, len(labels_paths), 1)

        # crop left hippo first (labels 20101 to 20108)
        path_label_first_crop_l = os.path.join(tmp_result_dir,
                                               os.path.basename(path_label).replace('.nii', '_left.nii'))
        lab, aff, h = utils.load_volume(path_label, im_only=False)
        lab_l, croppping_idx, aff_l = edit_volumes.crop_volume_around_region(lab, crop_margin,
                                                                             list(range(20101, 20109)), aff=aff)
        if (not os.path.exists(path_label_first_crop_l)) | recompute:
            utils.save_volume(lab_l, aff_l, h, path_label_first_crop_l)
        else:
            # reuse the crop already saved on disk
            lab_l = utils.load_volume(path_label_first_crop_l)
        if path_image is not None:
            path_image_first_crop_l = os.path.join(tmp_image_result_dir,
                                                   os.path.basename(path_image).replace('.nii', '_left.nii'))
            if (not os.path.exists(path_image_first_crop_l)) | recompute:
                # crop the image with the indices found on the label map
                im, aff, h = utils.load_volume(path_image, im_only=False)
                im, aff = edit_volumes.crop_volume_with_idx(im, croppping_idx, aff)
                utils.save_volume(im, aff, h, path_image_first_crop_l)
        shape_array[2*idx, :] = np.array(lab_l.shape)

        # crop right hippo and flip them (labels 20001 to 20008)
        path_label_first_crop_r = os.path.join(tmp_result_dir,
                                               os.path.basename(path_label).replace('.nii', '_right_flipped.nii'))
        # reload the label map, since aff and h may have been overwritten by the image above
        lab, aff, h = utils.load_volume(path_label, im_only=False)
        lab_r, croppping_idx, aff_r = edit_volumes.crop_volume_around_region(lab, crop_margin,
                                                                             list(range(20001, 20009)), aff=aff)
        if (not os.path.exists(path_label_first_crop_r)) | recompute:
            lab_r = edit_volumes.flip_volume(lab_r, direction='rl', aff=aff_r)
            utils.save_volume(lab_r, aff_r, h, path_label_first_crop_r)
        else:
            # reuse the crop already saved on disk
            lab_r = utils.load_volume(path_label_first_crop_r)
        if path_image is not None:
            # NOTE(review): images are suffixed '_right' whereas labels are suffixed '_right_flipped' - confirm that
            # the sorted lists built after this loop still pair each image with its label map
            path_image_first_crop_r = os.path.join(tmp_image_result_dir,
                                                   os.path.basename(path_image).replace('.nii', '_right.nii'))
            if (not os.path.exists(path_image_first_crop_r)) | recompute:
                im, aff, h = utils.load_volume(path_image, im_only=False)
                im, aff = edit_volumes.crop_volume_with_idx(im, croppping_idx, aff)
                im = edit_volumes.flip_volume(im, direction='rl', aff=aff)
                utils.save_volume(im, aff, h, path_image_first_crop_r)
        shape_array[2*idx+1, :] = np.array(lab_r.shape)

    # list cropped files
    path_labels_first_cropped = utils.list_images_in_folder(tmp_result_dir)
    if tmp_image_result_dir is not None:
        path_images_first_cropped = utils.list_images_in_folder(tmp_image_result_dir)
    else:
        path_images_first_cropped = [None] * len(path_labels_first_cropped)

    # crop all label maps to same size
    print('\nequalising shapes')
    # target shape is the smallest size found along each axis over all the first crops
    new_shape = np.min(shape_array, axis=0).astype('int32')
    for i, (path_label, path_image) in enumerate(zip(path_labels_first_cropped, path_images_first_cropped)):
        utils.print_loop_info(i, len(path_labels_first_cropped), 1)

        # get cropping indices (window of size new_shape centred in the volume)
        path_lab_cropped = os.path.join(result_dir, os.path.basename(path_label))
        lab, aff, h = utils.load_volume(path_label, im_only=False)
        lab_shape = lab.shape
        min_cropping = np.array([np.maximum(int((lab_shape[i]-new_shape[i])/2), 0) for i in range(3)])
        max_cropping = np.array([min_cropping[i] + new_shape[i] for i in range(3)])

        # crop labels and realign on adni format
        if (not os.path.exists(path_lab_cropped)) | recompute:
            lab, aff = edit_volumes.crop_volume_with_idx(lab, np.concatenate([min_cropping, max_cropping]), aff)
            # realign on adni format (flip last axis and overwrite the rotation/scaling part of the affine)
            lab = np.flip(lab, axis=2)
            aff[0:3, 0:3] = np.array([[-0.6, 0, 0], [0, 0, -0.6], [0, -0.6, 0]])
            utils.save_volume(lab, aff, h, path_lab_cropped)

        # crop image and realign on adni format
        if path_image is not None:
            path_im_cropped = os.path.join(image_result_dir, os.path.basename(path_image))
            if (not os.path.exists(path_im_cropped)) | recompute:
                im, aff, h = utils.load_volume(path_image, im_only=False)
                im, aff = edit_volumes.crop_volume_with_idx(im, np.concatenate([min_cropping, max_cropping]), aff)
                im = np.flip(im, axis=2)
                aff[0:3, 0:3] = np.array([[-0.6, 0, 0], [0, 0, -0.6], [0, -0.6, 0]])
                # NOTE(review): lab is only cropped and flipped when the label branch above was executed; otherwise
                # it is still the volume loaded from path_label, whose shape may differ from im - confirm
                im = edit_volumes.mask_volume(im, lab)
                utils.save_volume(im, aff, h, path_im_cropped)

    # correct all labels to left values (done in place in result_dir)
    print('\ncorrecting labels')
    list_incorrect_labels = [77, 80, 251, 252, 253, 254, 255, 29, 41, 42, 43, 44, 46, 47, 49, 50, 51, 52, 54, 58, 60,
                             61, 62, 63, 7012, 20001, 20002, 20004, 20005, 20006, 20007, 20008]
    list_correct_labels = [2, 3, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5, 7, 8, 10, 11, 12, 13, 18, 26, 28, 2, 30, 31, 20108,
                           20101, 20102, 20104, 20105, 20106, 20107, 20108]
    edit_volumes.correct_labels_in_dir(result_dir, list_incorrect_labels, list_correct_labels, result_dir)

    # smooth labels (done in place in result_dir)
    if smooth:
        print('\nsmoothing labels')
        edit_volumes.smooth_labels_in_dir(result_dir, result_dir)
예제 #7
0
def preprocess_adni_hippo(path_t1,
                          path_t2,
                          path_aseg,
                          result_dir,
                          target_res,
                          padding_margin=85,
                          remove=False,
                          path_freesurfer='/usr/local/freesurfer/',
                          verbose=True,
                          recompute=True):
    """This function builds a T1+T2 multimodal image from the ADNI dataset.
    It first rescales the intensities of the T1 and T2 (with edit_volumes.rescale_volume).
    It then resamples the T2 image (which are 0.4*0.4*2.0 resolution) to target resolution.
    The obtained T2 is then zero-padded in all directions by the padding_margin param (typically large 85).
    The T1 and aseg are then resampled like the padded T2 using mri_convert.
    Now that the T1, T2 and asegs are aligned and at the same resolution, we crop them around the left (label 17) and
    right (label 53) hippo, with a fixed margin of 30 voxels. The right crops are flipped in the right/left direction.
    Finally, the T1 and T2 are stacked into one single multimodal image (channels on the last axis).
    The results are saved in result_dir as hippo_{left,right}.nii.gz and hippo_{left,right}_aseg.nii.gz.
    :param path_t1: path input T1 (typically at 1mm isotropic)
    :param path_t2: path input T2 (typically cropped around the hippo in sagittal axis, 0.4x0.4x2.0)
    :param path_aseg: path input segmentation (typically at 1mm isotropic)
    :param result_dir: path of directory where prepared images and labels will be written.
    It is created (along with missing parent directories) if it does not exist.
    :param target_res: resolution at which to resample the label maps, and the images.
    Can be a number (isotropic resolution), a sequence, or a 1d numpy array.
    :param padding_margin: (optional) number of voxels by which the resampled T2 is zero-padded on every side.
    Note that the margin used when cropping around the hippocampi is not this parameter, it is fixed to 30.
    :param remove: (optional) whether to delete the intermediate files once done. Default is False.
    :param path_freesurfer: (optional) path of FreeSurfer home, to use mri_convert
    :param verbose: (optional) whether to print out mri_convert output when resampling images
    :param recompute: (optional) whether to recompute result files even if they already exists
    """

    # create results dir (makedirs also creates missing parents, and is a no-op if the directory already exists)
    os.makedirs(result_dir, exist_ok=True)

    # final outputs: nothing is done if they all exist already and recompute is False
    final_paths = [os.path.join(result_dir, name) for name in ('hippo_right.nii.gz',
                                                               'hippo_right_aseg.nii.gz',
                                                               'hippo_left.nii.gz',
                                                               'hippo_left_aseg.nii.gz')]
    if recompute or not all(os.path.isfile(path) for path in final_paths):

        # set up FreeSurfer
        os.environ['FREESURFER_HOME'] = path_freesurfer
        os.system(os.path.join(path_freesurfer, 'SetUpFreeSurfer.sh'))
        mri_convert = os.path.join(path_freesurfer, 'bin/mri_convert.bin')

        def run_mri_convert(path_in, path_out, options):
            """Call mri_convert on path_in/path_out with the given option string (output hidden if not verbose)."""
            cmd = mri_convert + ' ' + path_in + ' ' + path_out + ' ' + options
            if not verbose:
                cmd += ' >/dev/null 2>&1'
            _ = os.system(cmd)

        # rescale T1
        path_t1_rescaled = os.path.join(result_dir, 't1_rescaled.nii.gz')
        if (not os.path.isfile(path_t1_rescaled)) | recompute:
            im, aff, h = utils.load_volume(path_t1, im_only=False)
            im = edit_volumes.rescale_volume(im)
            utils.save_volume(im, aff, h, path_t1_rescaled)
        # rescale T2
        path_t2_rescaled = os.path.join(result_dir, 't2_rescaled.nii.gz')
        if (not os.path.isfile(path_t2_rescaled)) | recompute:
            im, aff, h = utils.load_volume(path_t2, im_only=False)
            im = edit_volumes.rescale_volume(im)
            utils.save_volume(im, aff, h, path_t2_rescaled)

        # resample T2 to target res
        path_t2_resampled = os.path.join(result_dir, 't2_rescaled_resampled.nii.gz')
        if (not os.path.isfile(path_t2_resampled)) | recompute:
            str_res = ' '.join([str(r) for r in utils.reformat_to_list(target_res, length=3)])
            run_mri_convert(path_t2_rescaled, path_t2_resampled, '--voxsize ' + str_res + ' -odt float')

        # pad T2
        path_t2_padded = os.path.join(result_dir, 't2_rescaled_resampled_padded.nii.gz')
        if (not os.path.isfile(path_t2_padded)) | recompute:
            t2, aff, h = utils.load_volume(path_t2_resampled, im_only=False)
            t2_padded = np.pad(t2, padding_margin, 'constant')
            # the first voxel moved by -padding_margin along each axis: shift the translation of the affine accordingly
            aff[:3, -1] = aff[:3, -1] - (aff[:3, :3] @ (padding_margin * np.ones((3, 1)))).T
            utils.save_volume(t2_padded, aff, h, path_t2_padded)

        # resample T1 and aseg accordingly (i.e. like the padded T2)
        path_t1_resampled = os.path.join(result_dir, 't1_rescaled_resampled.nii.gz')
        if (not os.path.isfile(path_t1_resampled)) | recompute:
            run_mri_convert(path_t1_rescaled, path_t1_resampled, '-rl ' + path_t2_padded + ' -odt float')
        path_aseg_resampled = os.path.join(result_dir, 'aseg_resampled.nii.gz')
        if (not os.path.isfile(path_aseg_resampled)) | recompute:
            # nearest neighbour interpolation, as we resample a label map
            run_mri_convert(path_aseg, path_aseg_resampled, '-rl ' + path_t2_padded + ' -rt nearest -odt float')

        # crop images and concatenate T1 and T2
        for lab, side in zip([17, 53], ['left', 'right']):
            path_test_image = os.path.join(result_dir, 'hippo_{}.nii.gz'.format(side))
            path_test_aseg = os.path.join(result_dir, 'hippo_{}_aseg.nii.gz'.format(side))
            if (not os.path.isfile(path_test_image)) | (not os.path.isfile(path_test_aseg)) | recompute:
                aseg, aff, h = utils.load_volume(path_aseg_resampled, im_only=False)
                tmp_aseg, cropping, tmp_aff = edit_volumes.crop_volume_around_region(aseg,
                                                                                     margin=30,
                                                                                     masking_labels=lab,
                                                                                     aff=aff)
                if side == 'right':
                    tmp_aseg = edit_volumes.flip_volume(tmp_aseg, direction='rl', aff=tmp_aff)
                utils.save_volume(tmp_aseg, tmp_aff, h, path_test_aseg)
                if (not os.path.isfile(path_test_image)) | recompute:
                    # NOTE(review): for the right side, tmp_aseg has already been flipped at this point, whereas t1 and
                    # t2 are only flipped after being masked with it - confirm that this is intended.
                    t1 = utils.load_volume(path_t1_resampled)
                    t1 = edit_volumes.crop_volume_with_idx(t1, crop_idx=cropping)
                    t1 = edit_volumes.mask_volume(t1, tmp_aseg, dilate=6, erode=5)
                    t2 = utils.load_volume(path_t2_padded)
                    t2 = edit_volumes.crop_volume_with_idx(t2, crop_idx=cropping)
                    t2 = edit_volumes.mask_volume(t2, tmp_aseg, dilate=6, erode=5)
                    if side == 'right':
                        t1 = edit_volumes.flip_volume(t1, direction='rl', aff=tmp_aff)
                        t2 = edit_volumes.flip_volume(t2, direction='rl', aff=tmp_aff)
                    # channels are stacked on a new last axis (T1 first, then T2)
                    test_image = np.stack([t1, t2], axis=-1)
                    utils.save_volume(test_image, tmp_aff, h, path_test_image)

        # remove unnecessary files
        if remove:
            list_files_to_remove = [path_t1_rescaled,
                                    path_t2_rescaled,
                                    path_t2_resampled,
                                    path_t2_padded,
                                    path_t1_resampled,
                                    path_aseg_resampled]
            for path in list_files_to_remove:
                os.remove(path)
예제 #8
0
def postprocess(prediction, pad_shape, im_shape, crop, n_dims, labels, keep_biggest_component,
                aff, aff_ref='FS', keep_biggest_of_each_group=True, n_neutral_labels=None):
    """Turn the posteriors predicted on a patch into a label map of the size of the original image.
    :param prediction: predicted posteriors, with the labels on the last axis. Singleton axes are squeezed out.
    np.squeeze gives a view for numpy inputs, so the masking/renormalisation below also edit this array in place.
    :param pad_shape: shape of the (padded) image in which the patch is pasted back.
    :param im_shape: shape of the image before padding, the segmentation is cropped back to this shape.
    :param crop: indices of the patch in the padded image ([starts..., stops...], 2*n_dims values).
    None if the image was not cropped.
    :param n_dims: number of spatial dimensions. Pasting the patch back is only implemented for 2 and 3.
    :param labels: numpy array with the label value of each posterior channel.
    :param keep_biggest_component: whether to only keep the largest connected component of the foreground.
    :param aff: affine matrix of the original image, predictions are aligned back to it (3D only).
    :param aff_ref: (optional) orientation in which the prediction was made, can be 'FS' or 'identity'.
    No realignment is performed for any other value.
    :param keep_biggest_of_each_group: (optional) whether to zero the posteriors outside of the largest connected
    component of each (hard-coded) group of channels.
    :param n_neutral_labels: (optional) the hard-coded groups are extended with a second copy of the groups found
    from this index onwards, unless it is equal to len(labels).
    :return: the label map (with values taken from labels), and the posteriors.
    """

    # get posteriors of the patch
    post_patch = np.squeeze(prediction)

    # zero the posteriors outside of the largest connected component of each topological class
    if keep_biggest_of_each_group:
        class_of_channel = np.array([0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 7, 8, 9, 10, 11, 12, 13, 14, 5])
        if n_neutral_labels != len(labels):
            # duplicate the classes after the neutral ones, with values that do not clash with the existing ones
            sided = class_of_channel[n_neutral_labels:]
            shift = np.max(sided) - np.min(sided) + 1
            class_of_channel = np.concatenate([class_of_channel, sided + shift])
        above_threshold = post_patch > 0.2
        # the first (smallest) class is skipped
        for group in np.unique(class_of_channel)[1:]:
            channels = np.flatnonzero(class_of_channel == group)
            group_mask = np.any(above_threshold[..., channels], axis=-1)
            group_mask = edit_volumes.get_largest_connected_component(group_mask)
            for channel in channels:
                post_patch[..., channel] *= group_mask

    # renormalise posteriors (in place) and get hard segmentation
    post_patch /= np.sum(post_patch, axis=-1, keepdims=True)
    seg_patch = np.argmax(post_patch, axis=-1)

    # keep biggest connected component (use it with smoothing!)
    if keep_biggest_component:
        foreground = edit_volumes.get_largest_connected_component(seg_patch > 0)
        seg_patch = seg_patch * foreground

    # align prediction back to first orientation
    if n_dims > 2:
        if aff_ref == 'FS':
            prediction_aff = np.array([[-1., 0., 0., 0.], [0., 0., 1., 0.], [0., -1., 0., 0.], [0., 0., 0., 1.]])
        elif aff_ref == 'identity':
            prediction_aff = np.eye(4)
        else:
            prediction_aff = None
        if prediction_aff is not None:
            seg_patch = edit_volumes.align_volume_to_ref(seg_patch, prediction_aff, aff_ref=aff, return_aff=False)
            post_patch = edit_volumes.align_volume_to_ref(post_patch, prediction_aff, aff_ref=aff, n_dims=n_dims)

    # paste patches back to matrix of original image size
    if crop is None:
        seg = seg_patch
        posteriors = post_patch
    else:
        seg = np.zeros(shape=pad_shape, dtype='int32')
        posteriors = np.zeros(shape=[*pad_shape, labels.shape[0]])
        posteriors[..., 0] = 1  # place background around patch
        if n_dims in (2, 3):
            window = tuple(slice(crop[axis], crop[axis + n_dims]) for axis in range(n_dims))
            seg[window] = seg_patch
            posteriors[window] = post_patch

    # convert channel indices to label values
    seg = labels[seg.astype('int')].astype('int')

    # remove the padding
    # NOTE(review): only seg is cropped back to im_shape, the returned posteriors keep the padded shape - confirm
    # that callers expect this.
    if im_shape != pad_shape:
        starts = [int((p - i) / 2) for (p, i) in zip(pad_shape, im_shape)]
        stops = [s + i for (s, i) in zip(starts, im_shape)]
        seg = edit_volumes.crop_volume_with_idx(seg, starts + stops)

    return seg, posteriors
예제 #9
0
def dice_evaluation(gt_dir,
                    seg_dir,
                    label_list,
                    compute_distances=False,
                    compute_score_whole_structure=False,
                    path_dice=None,
                    path_hausdorff=None,
                    path_mean_distance=None,
                    crop_margin_around_gt=10,
                    recompute=True,
                    verbose=True):
    """Compute Dice scores (and optionally surface distances) between the label maps of two folders.
    The label maps of gt_dir (ground truth) and seg_dir (typically predictions) are paired in the order in which
    they are listed. If the two folders do not contain the same number of label maps, a message is printed and only
    the first pairs are evaluated.
    :param gt_dir: path of directory with gt label maps
    :param seg_dir: path of directory with label maps to compare to gt_dir.
    :param label_list: list of label values for which to compute evaluation metrics. Can be a sequence, a 1d numpy
    array, or the path to such array.
    :param compute_distances: (optional) whether to compute distances (Hausdorff and mean distance) between the
    surfaces of GT and predicted labels. When a label is missing from either label map, both distances are set to
    the largest dimension of the (cropped) gt volume. Default is False.
    :param compute_score_whole_structure: (optional) whether to also compute the selected scores for the whole
    segmented structure (i.e. all non-zero values regrouped into a single structure). If True, these scores are
    stored in an extra last row of the result matrices. Default is False.
    :param path_dice: (optional) path where the Dice scores will be written as numpy array.
    Default is None, where the array is not saved.
    :param path_hausdorff: (optional) path where the Hausdorff distances will be written as numpy array (only if
    compute_distances is True). Default is None, where the array is not saved.
    :param path_mean_distance: (optional) path where the mean distances will be written as numpy array (only if
    compute_distances is True). Default is None, where the array is not saved.
    :param crop_margin_around_gt: (optional) margin by which to crop around the gt volumes, in order to compute the
    scores more efficiently. If None, no cropping is performed.
    :param recompute: (optional) whether to recompute the already existing results. Default is True.
    If False, and if all the requested result files exist, the scores are simply loaded from these files.
    :param verbose: (optional) whether to print out info about the remaining number of cases.
    :return: a tuple (dice, hausdorff, mean_distance) of numpy arrays with labels in rows and subjects in columns.
    The last two elements are None if compute_distances is False.
    """

    def has_no_result(path):
        """A score has no saved result if no path was given for it, or if there is no file at this path."""
        return (path is None) or (not os.path.isfile(path))

    # check whether to recompute
    missing_dice = has_no_result(path_dice)
    missing_hausdorff = missing_mean_dist = False
    if compute_distances:
        missing_hausdorff = has_no_result(path_hausdorff)
        missing_mean_dist = has_no_result(path_mean_distance)

    # all requested results are already on disk: reload them and stop here
    if not (missing_dice or missing_hausdorff or missing_mean_dist or recompute):
        dice_coefs = np.load(path_dice)
        if compute_distances:
            max_dists = np.load(path_hausdorff)
            mean_dists = np.load(path_mean_distance)
            return dice_coefs, max_dists, mean_dists
        return dice_coefs, None, None

    # get list label maps to compare
    path_gt_labels = utils.list_images_in_folder(gt_dir)
    path_segs = utils.list_images_in_folder(seg_dir)
    if len(path_gt_labels) != len(path_segs):
        print('gt and segmentation folders must have the same amount of label maps.')

    # load labels list
    label_list, _ = utils.get_list_labels(label_list=label_list, FS_sort=True, labels_dir=gt_dir)
    n_labels = len(label_list)

    # initialise result matrices (one more row if we also score the whole structure)
    n_rows = n_labels + 1 if compute_score_whole_structure else n_labels
    n_subjects = len(path_segs)
    max_dists = np.zeros((n_rows, n_subjects))
    mean_dists = np.zeros((n_rows, n_subjects))
    dice_coefs = np.zeros((n_rows, n_subjects))

    # loop over segmentations
    loop_info = utils.LoopInfo(n_subjects, 10, 'evaluating')
    for idx, (path_gt, path_seg) in enumerate(zip(path_gt_labels, path_segs)):
        if verbose:
            loop_info.update(idx)

        # load gt labels and segmentation
        gt_labels = utils.load_volume(path_gt, dtype='int')
        seg = utils.load_volume(path_seg, dtype='int')

        # crop both volumes around the gt
        if crop_margin_around_gt is not None:
            gt_labels, cropping = edit_volumes.crop_volume_around_region(gt_labels, margin=crop_margin_around_gt)
            seg = edit_volumes.crop_volume_with_idx(seg, cropping)

        # compute Dice scores for all labels
        dice_coefs[:n_labels, idx] = fast_dice(gt_labels, seg, label_list)

        # compute Dice scores for whole structures
        whole_gt = whole_seg = None
        if compute_score_whole_structure:
            whole_gt = (gt_labels > 0) * 1
            whole_seg = (seg > 0) * 1
            dice_coefs[-1, idx] = dice(whole_gt, whole_seg)

        # nothing else to do for this subject if distances are not required
        if not compute_distances:
            continue

        # compute max/mean surface distances for all labels
        labels_in_gt = np.unique(gt_labels)
        labels_in_seg = np.unique(seg)
        for row, label in enumerate(label_list):
            if (label in labels_in_gt) and (label in labels_in_seg):
                max_dists[row, idx], mean_dists[row, idx] = surface_distances(gt_labels == label, seg == label)
            else:
                # label absent from at least one volume: fall back to largest dimension of the volume
                max_dists[row, idx] = max(gt_labels.shape)
                mean_dists[row, idx] = max(gt_labels.shape)

        # compute max/mean distances for whole structure
        if compute_score_whole_structure:
            max_dists[-1, idx], mean_dists[-1, idx] = surface_distances(whole_gt, whole_seg)

    # write results
    if path_dice is not None:
        utils.mkdir(os.path.dirname(path_dice))
        np.save(path_dice, dice_coefs)
    if compute_distances and path_hausdorff is not None:
        utils.mkdir(os.path.dirname(path_hausdorff))
        np.save(path_hausdorff, max_dists)
    if compute_distances and path_mean_distance is not None:
        utils.mkdir(os.path.dirname(path_mean_distance))
        np.save(path_mean_distance, mean_dists)

    if compute_distances:
        return dice_coefs, max_dists, mean_dists
    return dice_coefs, None, None
예제 #10
0
def preprocess_image(im_path, n_levels, crop_shape=None, padding=None):
    """Load an image and prepare it to be given to the network: optional padding, centre-cropping to a shape
    divisible by 2**n_levels, rescaling of the intensities to [0, 1], and addition of the extra axes.
    :param im_path: path of the image to prepare.
    :param n_levels: the spatial shape given to the network is made divisible by 2**n_levels.
    :param crop_shape: (optional) shape to which the image is centre-cropped. It is capped to the (padded) image
    shape, and then adjusted to be divisible by 2**n_levels. If None, the image is only cropped when its shape is
    not divisible by 2**n_levels.
    :param padding: (optional) margin by which to zero-pad the spatial dimensions of the image before cropping.
    :return: the prepared image, its affine matrix, its header, its number of channels, its number of dimensions,
    its initial shape, its shape after padding, the shape it was cropped to (None if not cropped), and the cropping
    indices (None if not cropped).
    """

    # read image and corresponding info
    im, shape, aff, n_dims, n_channels, header, _ = utils.get_volume_info(im_path, return_volume=True)

    # pad the spatial dimensions only (the channel axis of multi-channel images is left as is)
    pad_shape = shape
    if padding:
        single_channel = n_channels == 1
        pad_width = padding if single_channel else tuple([(padding, padding)] * n_dims + [(0, 0)])
        im = np.pad(im, pad_width, mode='constant')
        pad_shape = im.shape if single_channel else im.shape[:-1]

    # cap the requested crop shape to the size of the image
    if crop_shape is not None:
        crop_shape = utils.reformat_to_list(crop_shape, length=n_dims, dtype='int')
        if any([pad_shape[i] < crop_shape[i] for i in range(len(pad_shape))]):
            crop_shape = [min(pad_shape[i], crop_shape[i]) for i in range(n_dims)]
            print('cropping dimensions are higher than image size, changing cropping size to {}'.format(crop_shape))

    # make sure that the shape given to the network (crop_shape if any, image shape otherwise) is divisible by
    # 2**n_levels
    shape_to_check = pad_shape if crop_shape is None else crop_shape
    if any([size % (2**n_levels) != 0 for size in shape_to_check]):
        crop_shape = [utils.find_closest_number_divisible_by_m(size, 2**n_levels) for size in shape_to_check]

    # crop image if necessary (crop centred in the padded image)
    crop_idx = None
    if crop_shape is not None:
        crop_start = np.round((pad_shape - np.array(crop_shape)) / 2).astype('int')
        crop_idx = np.concatenate((crop_start, crop_start + crop_shape), axis=0)
        im = edit_volumes.crop_volume_with_idx(im, crop_idx=crop_idx)

    # normalise image between 0 and 1 (a constant image is set to 0)
    lowest = np.min(im)
    highest = np.max(im)
    im = np.zeros(im.shape) if highest == lowest else (im - lowest) / (highest - lowest)

    # add batch and channel axes
    im = utils.add_axis(im) if n_channels > 1 else utils.add_axis(im, -2)

    return im, aff, header, n_channels, n_dims, shape, pad_shape, crop_shape, crop_idx
예제 #11
0
파일: predict.py 프로젝트: BBillot/SynthSeg
def postprocess(post_patch,
                pad_shape,
                im_shape,
                crop,
                n_dims,
                segmentation_labels,
                lr_indices,
                keep_biggest_component,
                aff,
                topology_classes=True,
                post_patch_flip=None):
    """Turn the posteriors predicted on a patch into a label map and posteriors of the size of the original image.
    :param post_patch: predicted posteriors, with the labels on the last axis. Singleton axes are squeezed out.
    :param pad_shape: shape of the (padded) image in which the patch is pasted back.
    :param im_shape: shape of the image before padding, the outputs are cropped back to this shape.
    :param crop: indices of the patch in the padded image ([starts..., stops...], 2*n_dims values).
    None if the image was not cropped.
    :param n_dims: number of spatial dimensions. Pasting the patch back is only implemented for 2 and 3.
    :param segmentation_labels: numpy array with the label value of each posterior channel.
    :param lr_indices: numpy array with the indices of the channels to swap in the posteriors of the flipped image.
    Can be None, only used if post_patch_flip is given.
    :param keep_biggest_component: whether to zero the non-background posteriors outside of the largest connected
    component of the foreground.
    :param aff: affine matrix of the original image, the outputs are aligned back to it (3D only).
    :param topology_classes: (optional) array giving the topological class of each channel. Posteriors are zeroed
    outside of the largest connected component of each class (except for the first one). Nothing is masked with the
    default value True (it only triggers the renormalisation), and this step is skipped if None.
    :param post_patch_flip: (optional) posteriors predicted on the right/left flipped image, they are flipped back
    and averaged with post_patch.
    :return: the label map (with values taken from segmentation_labels), and the posteriors.
    """

    # get posteriors, and average them with those obtained on the flipped image (if any)
    post_patch = np.squeeze(post_patch)
    flip_was_used = post_patch_flip is not None
    if flip_was_used:
        flipped = edit_volumes.flip_volume(np.squeeze(post_patch_flip), direction='rl', aff=np.eye(4))
        if lr_indices is not None:
            # swap the channels given by lr_indices
            flipped[..., lr_indices.flatten()] = flipped[..., lr_indices[::-1].flatten()]
        post_patch = 0.5 * (post_patch + flipped)

    # keep biggest connected component (use it with smoothing!)
    if keep_biggest_component:
        foreground_post = post_patch[..., 1:]
        foreground_mask = np.sum(foreground_post, axis=-1) > 0.25
        foreground_mask = edit_volumes.get_largest_connected_component(foreground_mask)
        foreground_mask = np.stack([foreground_mask] * foreground_post.shape[-1], axis=-1)
        post_patch[..., 1:] = edit_volumes.mask_volume(foreground_post, mask=foreground_mask)

    # reset posteriors to zero outside the largest connected component of each topological class
    if topology_classes is not None:
        above_threshold = post_patch > 0.25
        for group in np.unique(topology_classes)[1:]:
            channels = np.where(topology_classes == group)[0]
            group_mask = np.any(above_threshold[..., channels], axis=-1)
            group_mask = edit_volumes.get_largest_connected_component(group_mask)
            for channel in channels:
                post_patch[..., channel] *= group_mask

    # renormalise posteriors (only if they were modified above) and get hard segmentation
    if flip_was_used or keep_biggest_component or (topology_classes is not None):
        post_patch /= np.sum(post_patch, axis=-1, keepdims=True)
    seg_patch = np.argmax(post_patch, axis=-1)

    # paste patches back to matrix of original image size
    if crop is None:
        seg = seg_patch
        posteriors = post_patch
    else:
        seg = np.zeros(shape=pad_shape, dtype='int32')
        posteriors = np.zeros(shape=[*pad_shape, segmentation_labels.shape[0]])
        posteriors[..., 0] = 1  # place background around patch
        if n_dims in (2, 3):
            window = tuple(slice(crop[axis], crop[axis + n_dims]) for axis in range(n_dims))
            seg[window] = seg_patch
            posteriors[window] = post_patch

    # convert channel indices to label values
    seg = segmentation_labels[seg.astype('int')].astype('int')

    # remove the padding
    if im_shape != pad_shape:
        starts = [int((p - i) / 2) for (p, i) in zip(pad_shape, im_shape)]
        stops = [s + i for (s, i) in zip(starts, im_shape)]
        bounds = starts + stops
        seg = edit_volumes.crop_volume_with_idx(seg, bounds)
        posteriors = edit_volumes.crop_volume_with_idx(posteriors, bounds, n_dims=n_dims)

    # align prediction back to first orientation
    if n_dims > 2:
        seg = edit_volumes.align_volume_to_ref(seg, aff=np.eye(4), aff_ref=aff, n_dims=n_dims, return_aff=False)
        posteriors = edit_volumes.align_volume_to_ref(posteriors, aff=np.eye(4), aff_ref=aff, n_dims=n_dims)

    return seg, posteriors
예제 #12
0
파일: predict.py 프로젝트: mu40/SynthSeg
def preprocess_image(im_path, n_levels, crop_shape=None, padding=None, aff_ref='FS'):
    """Load an image and prepare it to be given to the network: optional padding, centre-cropping to a shape
    divisible by 2**n_levels, alignment to a reference orientation, rescaling of the intensities of each channel
    to [0, 1], and addition of the extra axes.
    :param im_path: path of the image to prepare.
    :param n_levels: the spatial shape given to the network is made divisible by 2**n_levels.
    :param crop_shape: (optional) shape to which the image is centre-cropped. It is capped to the (padded) image
    shape, and then adjusted to be divisible by 2**n_levels. If None, the image is only cropped when its shape is
    not divisible by 2**n_levels.
    :param padding: (optional) margin by which to zero-pad the spatial dimensions of the image before cropping.
    :param aff_ref: (optional) orientation to which 3D images are aligned, can be 'FS', 'identity' or 'MS'.
    No alignment is performed for any other value.
    :return: the prepared image, its affine matrix, its header, its resolution, its number of channels, its number
    of dimensions, its initial shape, its shape after padding, the shape it was cropped to (None if not cropped),
    and the cropping indices (None if not cropped).
    """

    # read image and corresponding info
    im, shape, aff, n_dims, n_channels, header, im_res = utils.get_volume_info(im_path, return_volume=True)

    # pad the spatial dimensions only (the channel axis of multi-channel images is left as is)
    if padding:
        if n_channels == 1:
            im = np.pad(im, padding, mode='constant')
            pad_shape = im.shape
        else:
            im = np.pad(im, tuple([(padding, padding)] * n_dims + [(0, 0)]), mode='constant')
            pad_shape = im.shape[:-1]
    else:
        pad_shape = shape

    # check that patch_shape or im_shape are divisible by 2**n_levels
    if crop_shape is not None:
        crop_shape = utils.reformat_to_list(crop_shape, length=n_dims, dtype='int')
        if not all([pad_shape[i] >= crop_shape[i] for i in range(len(pad_shape))]):
            crop_shape = [min(pad_shape[i], crop_shape[i]) for i in range(n_dims)]
            print('cropping dimensions are higher than image size, changing cropping size to {}'.format(crop_shape))
        if not all([size % (2**n_levels) == 0 for size in crop_shape]):
            crop_shape = [utils.find_closest_number_divisible_by_m(size, 2 ** n_levels) for size in crop_shape]
    else:
        if not all([size % (2**n_levels) == 0 for size in pad_shape]):
            crop_shape = [utils.find_closest_number_divisible_by_m(size, 2 ** n_levels) for size in pad_shape]

    # crop image if necessary (crop centred in the padded image)
    if crop_shape is not None:
        crop_idx = np.round((pad_shape - np.array(crop_shape)) / 2).astype('int')
        crop_idx = np.concatenate((crop_idx, crop_idx + crop_shape), axis=0)
        im = edit_volumes.crop_volume_with_idx(im, crop_idx=crop_idx)
    else:
        crop_idx = None

    # align image to training axes and directions
    if n_dims > 2:
        reference_affines = {
            'FS': np.array([[-1., 0., 0., 0.], [0., 0., 1., 0.], [0., -1., 0., 0.], [0., 0., 0., 1.]]),
            'identity': np.eye(4),
            'MS': np.array([[-1., 0., 0., 0.], [0., -1., 0., 0.], [0., 0., 1., 0.], [0., 0., 0., 1.]]),
        }
        # only the string keywords are recognised, the image is left as is for any other value
        if isinstance(aff_ref, str) and aff_ref in reference_affines:
            aff_ref = reference_affines[aff_ref]
            im = edit_volumes.align_volume_to_ref(im, aff, aff_ref=aff_ref, return_aff=False)

    # normalise image between 0 and 1 (constant images/channels are set to 0)
    if n_channels == 1:
        m = np.min(im)
        M = np.max(im)
        if M == m:
            im = np.zeros(im.shape)
        else:
            im = (im - m) / (M - m)
    if n_channels > 1:
        # the normalised channels are written back into im, which must therefore be able to hold fractional values:
        # with an integer array, the assignments below would silently truncate them to 0 or 1.
        if not np.issubdtype(im.dtype, np.floating):
            im = im.astype('float64')
        for i in range(im.shape[-1]):
            channel = im[..., i]
            m = np.min(channel)
            M = np.max(channel)
            if M == m:
                im[..., i] = np.zeros(channel.shape)
            else:
                im[..., i] = (channel - m) / (M - m)

    # add batch and channel axes
    if n_channels > 1:
        im = utils.add_axis(im)
    else:
        im = utils.add_axis(im, -2)

    return im, aff, header, im_res, n_channels, n_dims, shape, pad_shape, crop_shape, crop_idx