def get_shapes(im_shape, crop_shape, output_div_by_n):
    """Compute the shape to which an image of shape im_shape should be cropped.

    :param im_shape: shape of the image (anything accepted by utils.reformat_to_list).
    :param crop_shape: (optional) requested crop shape. It is clipped to im_shape in every dimension.
    If None, the image shape itself is used as the starting point (i.e. no cropping is requested).
    :param output_div_by_n: (optional) if not None, every dimension of the returned shape is rounded down to the
    closest number divisible by this value.
    :return: the resulting crop shape, as a list.
    """
    im_shape = utils.reformat_to_list(im_shape)
    n_dims = len(im_shape)

    if crop_shape is None:
        # nothing requested: only crop when the image shape has to be made divisible by output_div_by_n
        if output_div_by_n is None:
            return im_shape
        return [utils.find_closest_number_divisible_by_m(dim, output_div_by_n, smaller_ans=True)
                for dim in im_shape]

    # the crop can never be larger than the image itself
    requested = utils.reformat_to_list(crop_shape, length=n_dims, dtype='int')
    clipped = [min(im_shape[d], requested[d]) for d in range(n_dims)]
    if output_div_by_n is None:
        return clipped

    # round every dimension down so that it is divisible by output_div_by_n
    divisible = [utils.find_closest_number_divisible_by_m(dim, output_div_by_n, smaller_ans=True)
                 for dim in clipped]
    if divisible == clipped:
        return clipped
    print('crop_shape {0} not divisible by {1}, changed to {2}'.format(clipped, output_div_by_n, divisible))
    return divisible
Exemple #2
0
def draw_learning_curve(path_tensorboard_files, architecture_names, fontsize=18):
    """Plot the learning curves of several trainings on a single figure.

    For every training, each scalar logged under the tag 'loss', 'accuracy' or 'epoch_loss' is read from its
    tensorboard event file, and 1 - value is plotted against the index of that scalar.
    :param path_tensorboard_files: list of tensorboard files corresponding to the models to plot.
    :param architecture_names: list of the names of the models (used as legend entries).
    :param fontsize: (optional) fontsize used for the graph.
    """
    tb_files = utils.reformat_to_list(path_tensorboard_files)
    names = utils.reformat_to_list(architecture_names)
    assert len(tb_files) == len(names), 'names and tensorboard lists should have same length'

    tracked_tags = ('loss', 'accuracy', 'epoch_loss')
    plt.figure()
    for tb_file, arch_name in zip(tb_files, names):
        # silence tensorflow's logger while reading the event file
        logging.getLogger('tensorflow').disabled = True
        values = [value.simple_value
                  for event in summary_iterator(tb_file)
                  for value in event.summary.value
                  if value.tag in tracked_tags]
        plt.plot(1 - np.array(values), label=arch_name, linewidth=2)

    # decorate and display the figure
    plt.grid()
    plt.legend(fontsize=fontsize)
    plt.xlabel('Epochs', fontsize=fontsize)
    plt.ylabel('Soft Dice scores', fontsize=fontsize)
    plt.tick_params(axis='both', labelsize=fontsize)
    plt.title('Validation curves', fontsize=fontsize)
    plt.tight_layout(pad=0)
    plt.show()
Exemple #3
0
def sample_intensity_stats_from_image(image, segmentation, labels_list, classes_list=None, keep_strictly_positive=True):
    """This function takes an image and corresponding segmentation as inputs. It estimates robust statistics of the
    intensity (median and median absolute deviation) for all specified label values. Labels can share the same
    statistics by being regrouped into K classes.
    :param image: image from which to evaluate the intensity statistics.
    :param segmentation: segmentation of the input image. Must have the same size as image.
    :param labels_list: list of labels for which to evaluate the intensity statistics.
    Can be a sequence, a 1d numpy array, or the path to a 1d numpy array.
    :param classes_list: (optional) enables to regroup structures into classes of similar intensity statistics.
    Intensities associated to regrouped labels will thus contribute to the same Gaussian during statistics estimation.
    Can be a sequence, a 1d numpy array, or the path to a 1d numpy array.
    It should have the same length as labels_list, and contain values between 0 and K-1, where K is the total number of
    classes. Default is all labels have different classes (K=len(labels_list)).
    :param keep_strictly_positive: (optional) whether to only keep strictly positive intensity values when
    computing stats. This doesn't apply to class 0 (i.e. the first label in labels_list when classes_list is not
    provided), for which we keep positive and zero values, as we consider it to be the background label.
    :return: a numpy array of size (2, K), the first row being the median intensity of each class (robust estimation of
    the mean), and the second being the median absolute deviation (robust estimation of std). Classes without any
    (kept) voxel get 0 for both statistics.
    :raises ValueError: if classes_list does not contain exactly the values 0..K-1.
    """

    # reformat labels and classes
    labels_list = np.array(utils.reformat_to_list(labels_list, load_as_numpy=True, dtype='int'))
    if classes_list is not None:
        classes_list = np.array(utils.reformat_to_list(classes_list, load_as_numpy=True, dtype='int'))
    else:
        classes_list = np.arange(labels_list.shape[0])
    assert len(classes_list) == len(labels_list), 'labels and classes lists should have the same length'

    # classes must be exactly 0..K-1, so that the stats of class k can be stored at index k of the output
    unique_classes = np.unique(classes_list)
    n_classes = len(unique_classes)
    if not np.array_equal(unique_classes, np.arange(n_classes)):
        raise ValueError('classes_list should only contain values between 0 and K-1, '
                         'where K is the total number of classes. Here K = %d' % n_classes)

    # compute stats of specified classes
    means = np.zeros(n_classes)
    stds = np.zeros(n_classes)
    for idx, tmp_class in enumerate(unique_classes):

        # gather the intensities of all labels of the current class with a single concatenation
        # (the leading float64 empty array keeps the result dtype identical to incremental float64 accumulation)
        class_labels = labels_list[classes_list == tmp_class]
        intensities = np.concatenate([np.empty(0)] + [image[segmentation == label] for label in class_labels])

        # class 0 is considered as background, so its zero intensities are always kept
        if tmp_class and keep_strictly_positive:
            intensities = intensities[intensities > 0]

        # robust stats, ignoring NaNs
        if len(intensities) != 0:
            means[idx] = np.nanmedian(intensities)
            stds[idx] = median_absolute_deviation(intensities, nan_policy='omit')

    return np.stack([means, stds])
Exemple #4
0
def draw_learning_curve(path_tensorboard_files,
                        architecture_names,
                        figsize=(11, 6),
                        fontsize=18,
                        y_lim=None,
                        remove_legend=False):
    """This function draws the learning curve of several trainings on the same graph.
    For each model, every scalar logged under the tag 'loss', 'accuracy' or 'epoch_loss' is read, and 1 - value is
    plotted against the step at which it was logged.
    :param path_tensorboard_files: list of tensorboard files corresponding to the models to plot. Each entry can itself
    be a list of files (e.g. for a training that was resumed), in which case their values are concatenated in order.
    :param architecture_names: list of the names of the models (used as legend entries).
    :param figsize: (optional) size of the figure to draw.
    :param fontsize: (optional) fontsize used for the graph.
    :param y_lim: (optional) sequence [bottom, top] giving the limits of the y axis. 0.01 is added to top so that
    curves reaching the top value remain visible. Default is None, where matplotlib chooses the limits.
    :param remove_legend: (optional) whether to hide the legend. Default is False.
    """

    # reformat inputs
    path_tensorboard_files = utils.reformat_to_list(path_tensorboard_files)
    architecture_names = utils.reformat_to_list(architecture_names)
    assert len(path_tensorboard_files) == len(
        architecture_names
    ), 'names and tensorboard lists should have same length'

    # silence tensorflow's logger while reading the event files (note: this setting is not restored afterwards)
    logging.getLogger('tensorflow').disabled = True
    tracked_tags = ('loss', 'accuracy', 'epoch_loss')

    # loop over architectures
    plt.figure(figsize=figsize)
    for path_tensorboard_file, name in zip(path_tensorboard_files,
                                           architecture_names):

        path_tensorboard_file = utils.reformat_to_list(path_tensorboard_file)

        # extract loss at the end of all epochs
        list_losses = list()
        list_epochs = list()
        for path in path_tensorboard_file:
            for e in summary_iterator(path):
                for v in e.summary.value:
                    if v.tag in tracked_tags:
                        list_losses.append(v.simple_value)
                        list_epochs.append(e.step)
        plt.plot(np.array(list_epochs),
                 1 - np.array(list_losses),
                 label=name,
                 linewidth=2)

    # finalise plot
    plt.grid()
    if not remove_legend:
        plt.legend(fontsize=fontsize)
    plt.xlabel('Epochs', fontsize=fontsize)
    plt.ylabel('Soft Dice scores', fontsize=fontsize)
    if y_lim is not None:
        plt.ylim(y_lim[0], y_lim[1] + 0.01)  # set bottom/top limits of the y axis
    plt.tick_params(axis='both', labelsize=fontsize)
    plt.title('Learning curves', fontsize=fontsize)
    plt.tight_layout(pad=1)
    plt.show()
def get_shapes(crop_shape, im_shape, div_by_n):
    """Compute the (spatial) shape to which an image should be cropped.

    :param crop_shape: (optional) requested crop shape, clipped to the image shape in every spatial dimension.
    :param im_shape: shape of the image. It may end with a channel dimension (as detected by utils.get_dims), which is
    ignored here: only the n_dims spatial dimensions are considered.
    :param div_by_n: (optional) if not None, every dimension of the crop shape is rounded down to the closest number
    divisible by this value.
    :return: the crop shape as a list of n_dims values, or None when crop_shape is None and no cropping is needed.
    """
    n_dims, _ = utils.get_dims(im_shape)
    # keep only the spatial dimensions, as a list, so that comparisons below are element-wise list comparisons
    # (a tuple never equals a list, and comparing a list with a numpy array gives an ambiguous array)
    spatial_shape = list(im_shape[:n_dims])

    # crop_shape specified
    if crop_shape is not None:
        crop_shape = utils.reformat_to_list(crop_shape,
                                            length=n_dims,
                                            dtype='int')
        crop_shape = [min(spatial_shape[i], crop_shape[i]) for i in range(n_dims)]
        # make sure output shape is divisible by div_by_n
        if div_by_n is not None:
            tmp_shape = [
                utils.find_closest_number_divisible_by_m(s,
                                                         div_by_n,
                                                         smaller_ans=True)
                for s in crop_shape
            ]
            if crop_shape != tmp_shape:
                print('crop shape {0} not divisible by {1}, changed to {2}'.
                      format(crop_shape, div_by_n, tmp_shape))
                crop_shape = tmp_shape

    # no crop_shape, so no cropping unless image shape is not divisible by div_by_n
    else:
        if div_by_n is not None:
            # only spatial dimensions are rounded: the channel dimension (if any) must not be cropped
            tmp_shape = [
                utils.find_closest_number_divisible_by_m(s,
                                                         div_by_n,
                                                         smaller_ans=True)
                for s in spatial_shape
            ]
            if tmp_shape != spatial_shape:
                print('image shape {0} not divisible by {1}, cropped to {2}'.
                      format(im_shape, div_by_n, tmp_shape))
                crop_shape = tmp_shape
    return crop_shape
Exemple #6
0
def preprocess_image(im_path, n_levels, crop_shape=None, padding=None, aff_ref='FS', dist_map=False):
    """Load an image and prepare it to be fed to the segmentation network.

    Steps: optional padding, cropping so that every spatial dimension is divisible by 2**n_levels, alignment to the
    training orientation (only when n_dims > 2), per-channel min-max normalisation, and addition of batch/channel axes.
    :param im_path: path of the image to load.
    :param n_levels: number of levels of the unet; the cropped shape is made divisible by 2**n_levels.
    :param crop_shape: (optional) shape to crop the (padded) image to. It is clipped to the padded image shape.
    :param padding: (optional) shape to pad the image to before cropping. Can be an int, a sequence or a 1d numpy array.
    :param aff_ref: (optional) 'FS' or 'identity', reference orientation the image is aligned to. Any other value
    leaves the image orientation unchanged.
    :param dist_map: (optional) whether channels alternate between intensities and distance maps. If True, only
    even-indexed channels are normalised.
    :return: im, aff, header, im_res, n_channels, n_dims, shape (original shape), pad_shape (shape after padding),
    crop_idx (cropping indices, or None if the image was not cropped).
    """

    # read image and corresponding info
    im, shape, aff, n_dims, n_channels, header, im_res = utils.get_volume_info(im_path, return_volume=True)

    # pad image if necessary. A plain `if padding:` raises for numpy arrays with several elements, so such inputs
    # are handled explicitly; all other inputs keep their usual truthiness.
    if padding is not None and (np.size(padding) > 1 or bool(padding)):
        im = edit_volumes.pad_volume(im, padding_shape=padding)
        pad_shape = im.shape[:n_dims]
    else:
        pad_shape = shape

    # check that patch_shape or im_shape are divisible by 2**n_levels
    if crop_shape is not None:
        crop_shape = utils.reformat_to_list(crop_shape, length=n_dims, dtype='int')
        if not all(pad_shape[i] >= crop_shape[i] for i in range(len(pad_shape))):
            crop_shape = [min(pad_shape[i], crop_shape[i]) for i in range(n_dims)]
        if not all(size % (2 ** n_levels) == 0 for size in crop_shape):
            crop_shape = [utils.find_closest_number_divisible_by_m(size, 2 ** n_levels) for size in crop_shape]
    else:
        if not all(size % (2 ** n_levels) == 0 for size in pad_shape):
            crop_shape = [utils.find_closest_number_divisible_by_m(size, 2 ** n_levels) for size in pad_shape]

    # crop image if necessary
    if crop_shape is not None:
        im, crop_idx = edit_volumes.crop_volume(im, cropping_shape=crop_shape, return_crop_idx=True)
    else:
        crop_idx = None

    # align image to training axes and directions
    if n_dims > 2:
        if aff_ref == 'FS':
            aff_ref = np.array([[-1., 0., 0., 0.], [0., 0., 1., 0.], [0., -1., 0., 0.], [0., 0., 0., 1.]])
            im = edit_volumes.align_volume_to_ref(im, aff, aff_ref=aff_ref, return_aff=False, n_dims=n_dims)
        elif aff_ref == 'identity':
            aff_ref = np.eye(4)
            im = edit_volumes.align_volume_to_ref(im, aff, aff_ref=aff_ref, return_aff=False, n_dims=n_dims)

    # normalise image to [0, 1] (constant images/channels are set to 0)
    if n_channels == 1:
        m = np.min(im)
        M = np.max(im)
        if M == m:
            im = np.zeros(im.shape)
        else:
            im = (im - m) / (M - m)
    else:
        for i in range(im.shape[-1]):
            # with distance maps, channels alternate intensity/distance: only intensity (even) channels are normalised
            if not dist_map or i % 2 == 0:
                channel = im[..., i]
                m = np.min(channel)
                M = np.max(channel)
                if M == m:
                    im[..., i] = np.zeros(channel.shape)
                else:
                    im[..., i] = (channel - m) / (M - m)

    # add batch and channel axes
    im = utils.add_axis(im) if n_channels > 1 else utils.add_axis(im, axis=[0, -1])

    return im, aff, header, im_res, n_channels, n_dims, shape, pad_shape, crop_idx
Exemple #7
0
def postprocessing_model(unet, posteriors_patch_shape, resample,
                         sigma_smoothing, n_dims):
    """Build a Keras model that appends post-processing layers to the output of a unet.

    The returned model has the same inputs as unet. Its output is the first output of unet, optionally resized to
    posteriors_patch_shape[:-1] (when resample is not None), and optionally blurred channel by channel with the
    kernels returned by l2i_et.get_gaussian_1d_kernels (when sigma_smoothing != 0).
    :param unet: Keras model whose (first) output is post-processed.
    :param posteriors_patch_shape: target shape of the posteriors; all but its last element are passed to Resize.
    :param resample: if not None, the output is resized to posteriors_patch_shape[:-1].
    :param sigma_smoothing: smoothing parameter, reformatted to one value per spatial dimension.
    NOTE(review): compared to 0 with `!=`, so a sequence such as [0, 0, 0] still enables smoothing, and a multi-element
    numpy array would make this comparison ambiguous — confirm the intended input types.
    :param n_dims: number of spatial dimensions.
    :return: the post-processing Keras Model.
    """

    # get output from unet
    input_tensor = unet.inputs
    last_tensor = unet.outputs
    if isinstance(last_tensor, list):
        last_tensor = last_tensor[0]

    # resample to original resolution
    if resample is not None:
        last_tensor = nrn_layers.Resize(size=posteriors_patch_shape[:-1],
                                        name='post_resample')(last_tensor)

    # smooth posteriors
    if sigma_smoothing != 0:

        # split the output into one single-channel tensor per label channel
        n_labels = last_tensor.get_shape().as_list()[-1]
        split = KL.Lambda(lambda x: tf.split(x, [1] * n_labels, axis=-1),
                          name='resample_split')(last_tensor)

        # create gaussian blurring kernel (one sigma per spatial dimension)
        sigma_smoothing = utils.reformat_to_list(sigma_smoothing,
                                                 length=n_dims)
        kernels_list = l2i_et.get_gaussian_1d_kernels(sigma_smoothing)

        # blur each label channel separately, then concatenate them back in their original order
        last_tensor = l2i_et.blur_tensor(split[0], kernels_list, n_dims)
        for i in range(1, n_labels):
            temp_blurred = l2i_et.blur_tensor(split[i], kernels_list, n_dims)
            last_tensor = KL.concatenate([last_tensor, temp_blurred],
                                         axis=-1,
                                         name='cat_blurring_%s' % i)

    # build model
    model_postprocessing = Model(inputs=input_tensor, outputs=last_tensor)

    return model_postprocessing
Exemple #8
0
def predict(path_images,
            path_model,
            segmentation_label_list,
            dist_map=False,
            path_segmentations=None,
            path_posteriors=None,
            path_volumes=None,
            segmentation_names_list=None,
            padding=None,
            cropping=None,
            resample=None,
            aff_ref='FS',
            sigma_smoothing=0,
            keep_biggest_component=False,
            conv_size=3,
            n_levels=5,
            nb_conv_per_level=2,
            unet_feat_count=24,
            feat_multiplier=2,
            activation='elu',
            gt_folder=None,
            evaluation_label_list=None,
            compute_distances=False,
            recompute=True,
            verbose=True):
    """
    This function uses trained models to segment images.
    It is crucial that the inputs match the architecture parameters of the trained model.
    :param path_images: path of the images to segment. Can be the path to a directory or the path to a single image.
    :param path_model: path to the trained model.
    :param segmentation_label_list: List of the labels segmented by the network. It should contain the same values
    as the segmentation label list used for training the network.
    Can be a sequence, a 1d numpy array, or the path to a numpy 1d array.
    :param dist_map: (optional) whether the input will contain distance maps channels (between each intensity channels)
    Default is False.
    :param path_segmentations: (optional) path where segmentations will be written.
    Should be a dir, if path_images is a dir, and a file if path_images is a file.
    Should not be None, if path_posteriors is None.
    :param path_posteriors: (optional) path where posteriors will be written.
    Should be a dir, if path_images is a dir, and a file if path_images is a file.
    Should not be None, if path_segmentations is None.
    :param path_volumes: (optional) path of a csv file where the soft volumes of all segmented regions will be written.
    The rows of the csv file correspond to subjects, and the columns correspond to segmentation labels (the first
    label of the sorted label list does not get a column).
    The soft volume of a structure corresponds to the sum of its predicted probability map, multiplied by the voxel
    volume.
    :param segmentation_names_list: (optional) List of names corresponding to the names of the segmentation labels.
    Only used when path_volumes is provided. Must be of the same size as segmentation_label_list. Can be given as a
    list, a numpy array of strings, or the path to such a numpy array. Default is None.
    :param padding: (optional) pad the images to the specified shape before predicting the segmentation maps.
    Can be an int, a sequence or a 1d numpy array.
    :param cropping: (optional) crop the images to the specified shape before predicting the segmentation maps.
    If padding and cropping are specified, images are padded before being cropped.
    Can be an int, a sequence or a 1d numpy array.
    :param resample: (optional) resample the images to the specified resolution before predicting the segmentation maps.
    Can be an int, a sequence or a 1d numpy array.
    :param aff_ref: (optional) type of affine matrix of the images used for training. By default this is set to the
    FreeSurfer orientation ('FS'), as it was the configuration in which SynthSeg was trained. However, the new models
    are now trained on data aligned with identity vox2ras matrix, so you need to change aff_ref to 'identity'.
    :param sigma_smoothing: (optional) If not 0, the posteriors are smoothed with a gaussian kernel of the specified
    standard deviation.
    :param keep_biggest_component: (optional) whether to only keep the biggest component in the predicted segmentation.
    :param conv_size: (optional) size of unet's convolution masks. Default is 3.
    :param n_levels: (optional) number of levels for unet. Default is 5.
    :param nb_conv_per_level: (optional) number of convolution layers per level. Default is 2.
    :param unet_feat_count: (optional) number of features for the first layer of the unet. Default is 24.
    :param feat_multiplier: (optional) multiplicative factor for the number of feature for each new level. Default is 2.
    :param activation: (optional) activation function. Can be 'elu', 'relu'.
    :param gt_folder: (optional) folder containing ground truth files for evaluation.
    A numpy array containing all dice scores (labels in rows, subjects in columns) will be written either at
    segmentations_dir (if not None), or posteriors_dir.
    :param evaluation_label_list: (optional) if gt_folder is not None, you can evaluate the Dice scores on a subset of
    the segmentation labels, by providing another label list here. Can be a sequence, a 1d numpy array, or the path to
    a numpy 1d array. Default is the same as segmentation_label_list.
    :param compute_distances: (optional) if gt_folder is not None, whether to also compute Hausdorff and mean surface
    distances (saved next to the Dice scores). Default is False.
    :param recompute: (optional) whether to recompute segmentations that were already computed. This also applies to
    Dice scores, if gt_folder is not None. Default is True.
    :param verbose: (optional) whether to print out info about the remaining number of cases.
    """

    # prepare output filepaths
    images_to_segment, path_segmentations, path_posteriors, path_volumes, compute = \
        prepare_output_files(path_images, path_segmentations, path_posteriors, path_volumes, recompute)

    # get label and classes lists
    label_list, n_neutral_labels = utils.get_list_labels(label_list=segmentation_label_list, FS_sort=True)
    if evaluation_label_list is None:
        evaluation_label_list = segmentation_label_list

    # prepare volume file if needed (the first label, presumably background, gets no column)
    if path_volumes is not None:
        if segmentation_names_list is not None:
            csv_header = [[''] + utils.reformat_to_list(segmentation_names_list, load_as_numpy=True)]
            csv_header += [[''] + [str(lab) for lab in label_list[1:]]]
        else:
            csv_header = [['subjects'] + [str(lab) for lab in label_list[1:]]]
        # newline='' is required by the csv module, otherwise extra blank rows appear on '\r\n' platforms
        with open(path_volumes, 'w', newline='') as csv_file:
            csv.writer(csv_file).writerows(csv_header)

    # perform segmentation
    net = None
    previous_model_input_shape = None
    loop_info = utils.LoopInfo(len(images_to_segment), 10, 'predicting', True)
    for idx, (path_image, path_segmentation, path_posterior, tmp_compute) in enumerate(zip(images_to_segment,
                                                                                           path_segmentations,
                                                                                           path_posteriors,
                                                                                           compute)):
        # compute segmentation only if needed
        if tmp_compute:

            # preprocess image and get information
            image, aff, h, im_res, n_channels, n_dims, shape, pad_shape, crop_idx = \
                preprocess_image(path_image, n_levels, cropping, padding, aff_ref=aff_ref, dist_map=dist_map)
            model_input_shape = list(image.shape[1:])

            # prepare net for first image or if input's size has changed
            if net is None or previous_model_input_shape != model_input_shape:

                # here net is not None implies that the input shape has changed
                if net is not None and verbose:
                    print('image of different shape as previous ones, redefining network')
                previous_model_input_shape = model_input_shape

                # build network
                net = build_model(path_model, model_input_shape, resample, im_res, n_levels, len(label_list), conv_size,
                                  nb_conv_per_level, unet_feat_count, feat_multiplier, activation, sigma_smoothing)

            if verbose:
                loop_info.update(idx)

            # predict posteriors
            prediction_patch = net.predict(image)

            # get posteriors and segmentation
            seg, posteriors = postprocess(prediction_patch, pad_shape, shape, crop_idx, n_dims, label_list,
                                          keep_biggest_component, aff, aff_ref=aff_ref,
                                          keep_biggest_of_each_group=keep_biggest_component,
                                          n_neutral_labels=n_neutral_labels)

            # write results to disk
            if path_segmentation is not None:
                utils.save_volume(seg.astype('int'), aff, h, path_segmentation)
            if path_posterior is not None:
                if n_channels > 1:
                    posteriors = utils.add_axis(posteriors, axis=[0, -1])
                utils.save_volume(posteriors.astype('float'), aff, h, path_posterior)

        else:
            # reuse previously saved posteriors to compute volumes
            if path_volumes is not None:
                posteriors, _, _, _, _, _, im_res = utils.get_volume_info(path_posterior, True, aff_ref=np.eye(4))
            else:
                posteriors = im_res = None

        # compute volumes (sum of posteriors over all spatial axes, scaled by the voxel volume)
        if path_volumes is not None:
            volumes = np.sum(posteriors[..., 1:], axis=tuple(range(0, len(posteriors.shape) - 1)))
            volumes = np.around(volumes * np.prod(im_res), 3)
            row = [os.path.basename(path_image).replace('.nii.gz', '')] + [str(vol) for vol in volumes]
            with open(path_volumes, 'a', newline='') as csv_file:
                csv.writer(csv_file).writerow(row)

    # evaluate
    if gt_folder is not None:

        # find path evaluation folder
        path_first_result = path_segmentations[0] if (path_segmentations[0] is not None) else path_posteriors[0]
        eval_folder = os.path.dirname(path_first_result)

        # compute evaluation metrics
        evaluate.dice_evaluation(gt_folder,
                                 eval_folder,
                                 evaluation_label_list,
                                 compute_distances=compute_distances,
                                 compute_score_whole_structure=False,
                                 path_dice=os.path.join(eval_folder, 'dice.npy'),
                                 path_hausdorff=os.path.join(eval_folder, 'hausdorff.npy'),
                                 path_mean_distance=os.path.join(eval_folder, 'mean_distance.npy'),
                                 recompute=recompute,
                                 verbose=verbose)
Exemple #9
0
    def __init__(self,
                 labels_dir,
                 generation_labels=None,
                 n_neutral_labels=None,
                 output_labels=None,
                 subjects_prob=None,
                 patch_dir=None,
                 batchsize=1,
                 n_channels=1,
                 target_res=None,
                 output_shape=None,
                 output_div_by_n=None,
                 prior_distributions='uniform',
                 generation_classes=None,
                 prior_means=None,
                 prior_stds=None,
                 use_specific_stats_for_channel=False,
                 mix_prior_and_random=False,
                 flipping=True,
                 scaling_bounds=.15,
                 rotation_bounds=15,
                 shearing_bounds=.012,
                 translation_bounds=False,
                 nonlin_std=3.,
                 nonlin_shape_factor=.04,
                 randomise_res=True,
                 max_res_iso=4.,
                 max_res_aniso=8.,
                 data_res=None,
                 thickness=None,
                 downsample=False,
                 blur_range=1.03,
                 bias_field_std=.5,
                 bias_shape_factor=.025,
                 return_gradients=False):
        """
        This class is wrapper around the labels_to_image_model model. It contains the GPU model that generates images
        from labels maps, and a python generator that suplies the input data for this model.
        To generate pairs of image/labels you can just call the method generate_image() on an object of this class.

        :param labels_dir: path of folder with all input label maps, or to a single label map.

        # IMPORTANT !!!
        # Each time we provide a parameter with separate values for each axis (e.g. with a numpy array or a sequence),
        # these values refer to the RAS axes.

        # label maps-related parameters
        :param generation_labels: (optional) list of all possible label values in the input label maps.
        Default is None, where the label values are directly gotten from the provided label maps.
        If not None, can be a sequence or a 1d numpy array, or the path to a 1d numpy array.
        If flipping is true (i.e. right/left flipping is enabled), generation_labels should be organised as follows:
        background label first, then non-sided labels (e.g. CSF, brainstem, etc.), then all the structures of the same
        hemisphere (can be left or right), and finally all the corresponding contralateral structures in the same order.
        :param n_neutral_labels: (optional) number of non-sided generation labels. This is important only if you use
        flipping augmentation. Default is total number of label values.
        :param output_labels: (optional) list of the same length as generation_labels to indicate which values to use in
        the label maps returned by this function, i.e. all occurences of generation_labels[i] in the input label maps
        will be converted to output_labels[i] in the returned label maps. Examples:
        Set output_labels[i] to zero if you wish to erase the value generation_labels[i] from the returned label maps.
        Set output_labels[i]=generation_labels[i] to keep the value generation_labels[i] in the returned maps.
        Can be a list or a 1d numpy array. By default output_labels is equal to generation_labels.
        :param subjects_prob: (optional) relative order of importance (doesn't have to be probabilistic), with which to
        pick the provided label maps at each minibatch. Can be a sequence, a 1D numpy array, or the path to such an
        array, and it must be as long as path_label_maps. By default, all label maps are chosen with the same importance

        # output-related parameters
        :param batchsize: (optional) numbers of images to generate per mini-batch. Default is 1.
        :param n_channels: (optional) number of channels to be synthetised. Default is 1.
        :param target_res: (optional) target resolution of the generated images and corresponding label maps.
        If None, the outputs will have the same resolution as the input label maps.
        Can be a number (isotropic resolution), a sequence, a 1d numpy array, or the path to a 1d numpy array.
        :param output_shape: (optional) shape of the output image, obtained by randomly cropping the generated image.
        Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or the path to a 1d numpy array.
        Default is None, where no cropping is performed.
        :param output_div_by_n: (optional) forces the output shape to be divisible by this value. It overwrites
        output_shape if necessary. Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or
        the path to a 1d numpy array.

        # GMM-sampling parameters
        :param generation_classes: (optional) Indices regrouping generation labels into classes of same intensity
        distribution. Regouped labels will thus share the same Gaussian when samling a new image. Can be a sequence, a
        1d numpy array, or the path to a 1d numpy array. It should have the same length as generation_labels, and
        contain values between 0 and K-1, where K is the total number of classes.
        Default is all labels have different classes (K=len(generation_labels)).
        :param prior_distributions: (optional) type of distribution from which we sample the GMM parameters.
        Can either be 'uniform', or 'normal'. Default is 'uniform'.
        :param prior_means: (optional) hyperparameters controlling the prior distributions of the GMM means. Because
        these prior distributions are uniform or normal, they require by 2 hyperparameters. Thus prior_means can be:
        1) a sequence of length 2, directly defining the two hyperparameters: [min, max] if prior_distributions is
        uniform, [mean, std] if the distribution is normal. The GMM means of are independently sampled at each
        mini_batch from the same distribution.
        2) an array of shape (2, K), where K is the number of classes (K=len(generation_labels) if generation_classes is
        not given). The mean of the Gaussian distribution associated to class k in [0, ...K-1] is sampled at each
        mini-batch from U(prior_means[0,k], prior_means[1,k]) if prior_distributions is uniform, and from
        N(prior_means[0,k], prior_means[1,k]) if prior_distributions is normal.
        3) an array of shape (2*n_mod, K), where each block of two rows is associated to hyperparameters derived
        from different modalities. In this case, if use_specific_stats_for_channel is False, we first randomly select a
        modality from the n_mod possibilities, and we sample the GMM means like in 2).
        If use_specific_stats_for_channel is True, each block of two rows correspond to a different channel
        (n_mod=n_channels), thus we select the corresponding block to each channel rather than randomly drawing it.
        4) the path to such a numpy array.
        Default is None, which corresponds to prior_means = [25, 225].
        :param prior_stds: (optional) same as prior_means but for the standard deviations of the GMM.
        Default is None, which corresponds to prior_stds = [5, 25].
        :param use_specific_stats_for_channel: (optional) whether the i-th block of two rows in the prior arrays must be
        only used to generate the i-th channel. If True, n_mod should be equal to n_channels. Default is False.
        :param mix_prior_and_random: (optional) if prior_means is not None, enables to reset the priors to their default
        values for half of thes cases, and thus generate images of random contrast.

        # spatial deformation parameters
        :param flipping: (optional) whether to introduce right/left random flipping. Default is True.
        :param scaling_bounds: (optional) range of the random saling to apply at each mini-batch. The scaling factor for
        each dimension is sampled from a uniform distribution of predefined bounds. Can either be:
        1) a number, in which case the scaling factor is independently sampled from the uniform distribution of bounds
        [1-scaling_bounds, 1+scaling_bounds] for each dimension.
        2) a sequence, in which case the scaling factor is sampled from the uniform distribution of bounds
        (1-scaling_bounds[i], 1+scaling_bounds[i]) for the i-th dimension.
        3) a numpy array of shape (2, n_dims), in which case the scaling factor is sampled from the uniform distribution
         of bounds (scaling_bounds[0, i], scaling_bounds[1, i]) for the i-th dimension.
        4) False, in which case scaling is completely turned off.
        Default is scaling_bounds = 0.15 (case 1)
        :param rotation_bounds: (optional) same as scaling bounds but for the rotation angle, except that for cases 1
        and 2, the bounds are centred on 0 rather than 1, i.e. [0+rotation_bounds[i], 0-rotation_bounds[i]].
        Default is rotation_bounds = 15.
        :param shearing_bounds: (optional) same as scaling bounds. Default is shearing_bounds = 0.012.
        :param translation_bounds: (optional) same as scaling bounds. Default is translation_bounds = False, but we
        encourage using it when cropping is deactivated (i.e. when output_shape=None in BrainGenerator).
        :param nonlin_std: (optional) Maximum value for the standard deviation of the normal distribution from which we
        sample the first tensor for synthesising the deformation field. Set to 0 if you wish to completely turn the
        elastic deformation off.
        :param nonlin_shape_factor: (optional) if nonlin_std is strictly positive, factor between the shapes of the
        input label maps and the shape of the input non-linear tensor.

        # blurring/resampling parameters
        :param randomise_res: (optional) whether to mimic images that would have been 1) acquired at low resolution, and
        2) resampled to high esolution. The low resolution is uniformly resampled at each minibatch from [1mm, 9mm].
        In that process, the images generated by sampling the GMM are:
        1) blurred at the sampled LR, 2) downsampled at LR, and 3) resampled at target_resolution.
        :param max_res_iso: (optional) If randomise_res is True, this enables to control the upper bound of the uniform
        distribution from which we sample the random resolution U(min_res, max_res_iso), where min_res is the resolution
        of the input label maps. Must be a number, and default is 4. Set to None to deactivate it, but if randomise_res
        is True, at least one of max_res_iso or max_res_aniso must be given.
        :param max_res_aniso: If randomise_res is True, we this enables to downsample the input volumes to a random LR
        in only 1 (random) direction. This is done by randomly selecting a direction i in range [0, n_dims-1], and
        sampling a value in the corresponding uniform distribution U(min_res[i], max_res_aniso[i]), where min_res is the
        resolution of the input label maps. Can be a number, a sequence, or a 1d numpy array. Set to None to deactivate
        it, but if randomise_res is True, at least one of max_res_iso or max_res_aniso must be given.
        :param data_res: (optional) specific acquisition resolution to mimic, as opposed to random resolution sampled
        when randomis_res is True. This triggers a blurring which mimics the acquisition resolution, but downsampling
        is optional (see param downsample). Default for data_res is None, where images are slighlty blurred.
        If the generated images are uni-modal, data_res can be a number (isotropic acquisition resolution), a sequence,
        a 1d numpy array, or the path to a 1d numy array. In the multi-modal case, it should be given as a umpy array (
        or a path) of size (n_mod, n_dims), where each row is the acquisition resolution of the corresponding channel.
        :param thickness: (optional) if data_res is provided, we can further specify the slice thickness of the low
        resolution images to mimic. Must be provided in the same format as data_res. Default thickness = data_res.
        :param downsample: (optional) whether to actually downsample the volume images to data_res after blurring.
        Default is False, except when thickness is provided, and thickness < data_res.
        :param blur_range: (optional) Randomise the standard deviation of the blurring kernels, (whether data_res is
        given or not). At each mini_batch, the standard deviation of the blurring kernels are multiplied by a
        coefficient sampled from a uniform distribution with bounds [1/blur_range, blur_range].
        If None, no randomisation. Default is 1.15.

        # bias field parameters
        :param bias_field_std: (optional) If strictly positive, this triggers the corruption of synthesised images with
        a bias field. It is obtained by sampling a first small tensor from a normal distribution, resizing it to full
        size, and rescaling it to positive values by taking the voxel-wise exponential. bias_field_std designates the
        std dev of the normal distribution from which we sample the first tensor. Set to 0 to deactivate bias field.
        :param bias_shape_factor: (optional) If bias_field_std is strictly positive, this designates the ratio between
        the size of the input label maps and the size of the first sampled tensor for synthesising the bias field.

        :param return_gradients: (optional) whether to return the synthetic image or the magnitude of its spatial
        gradient (computed with Sobel kernels).
        """

        # prepare data files
        self.labels_paths = utils.list_images_in_folder(labels_dir)
        self.path_patches = utils.list_images_in_folder(patch_dir) if (patch_dir is not None) else None
        if subjects_prob is not None:
            self.subjects_prob = np.array(utils.reformat_to_list(subjects_prob, load_as_numpy=True), dtype='float32')
            assert len(self.subjects_prob) == len(self.labels_paths), \
                'subjects_prob should have the same length as labels_path, ' \
                'had {} and {}'.format(len(self.subjects_prob), len(self.labels_paths))

        # generation parameters
        self.labels_shape, self.aff, self.n_dims, _, self.header, self.atlas_res = \
            utils.get_volume_info(self.labels_paths[0], aff_ref=np.eye(4))
        self.n_channels = n_channels
        if generation_labels is not None:
            self.generation_labels = utils.load_array_if_path(generation_labels)
        else:
            self.generation_labels, _ = utils.get_list_labels(labels_dir=labels_dir)
        if output_labels is not None:
            self.output_labels = utils.load_array_if_path(output_labels)
        else:
            self.output_labels = self.generation_labels
        if n_neutral_labels is not None:
            self.n_neutral_labels = n_neutral_labels
        else:
            self.n_neutral_labels = self.generation_labels.shape[0]
        self.target_res = utils.load_array_if_path(target_res)
        self.batchsize = batchsize
        # preliminary operations
        self.flipping = flipping
        self.output_shape = utils.load_array_if_path(output_shape)
        self.output_div_by_n = output_div_by_n
        # GMM parameters
        self.prior_distributions = prior_distributions
        if generation_classes is not None:
            self.generation_classes = utils.load_array_if_path(generation_classes)
            assert self.generation_classes.shape == self.generation_labels.shape, \
                'if provided, generation_classes should have the same shape as generation_labels'
            unique_classes = np.unique(self.generation_classes)
            assert np.array_equal(unique_classes, np.arange(np.max(unique_classes)+1)), \
                'generation_classes should a linear range between 0 and its maximum value.'
        else:
            self.generation_classes = np.arange(self.generation_labels.shape[0])
        self.prior_means = utils.load_array_if_path(prior_means)
        self.prior_stds = utils.load_array_if_path(prior_stds)
        self.use_specific_stats_for_channel = use_specific_stats_for_channel
        # linear transformation parameters
        self.scaling_bounds = utils.load_array_if_path(scaling_bounds)
        self.rotation_bounds = utils.load_array_if_path(rotation_bounds)
        self.shearing_bounds = utils.load_array_if_path(shearing_bounds)
        self.translation_bounds = utils.load_array_if_path(translation_bounds)
        # elastic transformation parameters
        self.nonlin_std = nonlin_std
        self.nonlin_shape_factor = nonlin_shape_factor
        # blurring parameters
        self.randomise_res = randomise_res
        self.max_res_iso = max_res_iso
        self.max_res_aniso = max_res_aniso
        self.data_res = utils.load_array_if_path(data_res)
        assert not (self.randomise_res & (self.data_res is not None)), \
            'randomise_res and data_res cannot be provided at the same time'
        self.thickness = utils.load_array_if_path(thickness)
        self.downsample = downsample
        self.blur_range = blur_range
        # bias field parameters
        self.bias_field_std = bias_field_std
        self.bias_shape_factor = bias_shape_factor
        self.return_gradients = return_gradients

        # build transformation model
        self.labels_to_image_model, self.model_output_shape = self._build_labels_to_image_model()

        # build generator for model inputs
        self.model_inputs_generator = self._build_model_inputs_generator(mix_prior_and_random)

        # build brain generator
        self.brain_generator = self._build_brain_generator()
def _run_mri_convert(mri_convert, path_in, path_out, options, verbose):
    """Run FreeSurfer's mri_convert to convert path_in into path_out with extra command-line options.

    Every token is shell-quoted, so paths containing spaces or shell metacharacters are passed intact.
    :param mri_convert: path to the mri_convert executable.
    :param path_in: path of the input volume.
    :param path_out: path where the converted volume is written.
    :param options: sequence of additional command-line tokens, e.g. ['-odt', 'float'].
    :param verbose: if False, stdout and stderr of mri_convert are discarded.
    :return: the exit status returned by os.system.
    """
    import shlex
    tokens = [mri_convert, path_in, path_out] + list(options)
    cmd = ' '.join(shlex.quote(str(token)) for token in tokens)
    if not verbose:
        cmd += ' >/dev/null 2>&1'
    return os.system(cmd)


def preprocess_adni_hippo(path_t1,
                          path_t2,
                          path_aseg,
                          result_dir,
                          target_res,
                          padding_margin=85,
                          remove=False,
                          path_freesurfer='/usr/local/freesurfer/',
                          verbose=True,
                          recompute=True):
    """This function builds a T1+T2 multimodal image from the ADNI dataset.
    It first rescales intensities of each channel between 0 and 255.
    It then resamples the T2 image (which are 0.4*0.4*2.0 resolution) to target resolution.
    The obtained T2 is then padded in all directions by the padding_margin param (typically large 85).
    The T1 and aseg are then resampled like the T2 using mri_convert.
    Now that the T1, T2 and asegs are aligned and at the same resolution, we crop them around the right and left hippo.
    The masks used on the T1 and T2 are computed from the unflipped crop, and right-side crops (image and aseg) are
    then flipped right/left before being saved.
    Finally, the T1 and T2 are concatenated into one single multimodal image.
    :param path_t1: path input T1 (typically at 1mm isotropic)
    :param path_t2: path input T2 (typically cropped around the hippo in sagittal axis, 0.4x0.4x2.0)
    :param path_aseg: path input segmentation (typically at 1mm isotropic)
    :param result_dir: path of directory where prepared images and labels will be written. Missing parent
    directories are created.
    :param target_res: resolution at which to resample the label maps, and the images.
    Can be a number (isotropic resolution), a sequence, or a 1d numpy array.
    :param padding_margin: (optional) number of voxels used to pad the resampled T2 in every direction.
    :param remove: (optional) whether to delete temporary files. Default is False.
    :param path_freesurfer: (optional) path of FreeSurfer home, to use mri_convert
    :param verbose: (optional) whether to print out mri_convert output when resampling images
    :param recompute: (optional) whether to recompute result files even if they already exist
    """

    # create results dir (and any missing parents)
    os.makedirs(result_dir, exist_ok=True)

    path_test_im_right = os.path.join(result_dir, 'hippo_right.nii.gz')
    path_test_aseg_right = os.path.join(result_dir, 'hippo_right_aseg.nii.gz')
    path_test_im_left = os.path.join(result_dir, 'hippo_left.nii.gz')
    path_test_aseg_left = os.path.join(result_dir, 'hippo_left_aseg.nii.gz')
    final_outputs = [path_test_im_right, path_test_aseg_right, path_test_im_left, path_test_aseg_left]
    if recompute or not all(os.path.isfile(path) for path in final_outputs):

        # set up FreeSurfer
        os.environ['FREESURFER_HOME'] = path_freesurfer
        # NOTE(review): the setup script runs in a child shell, so environment changes it makes do not reach this
        # process; mri_convert is therefore called through its full path below.
        os.system(os.path.join(path_freesurfer, 'SetUpFreeSurfer.sh'))
        mri_convert = os.path.join(path_freesurfer, 'bin/mri_convert.bin')

        # rescale T1 and T2 intensities
        path_t1_rescaled = os.path.join(result_dir, 't1_rescaled.nii.gz')
        path_t2_rescaled = os.path.join(result_dir, 't2_rescaled.nii.gz')
        for path_in, path_out in ((path_t1, path_t1_rescaled), (path_t2, path_t2_rescaled)):
            if recompute or not os.path.isfile(path_out):
                im, aff, h = utils.load_volume(path_in, im_only=False)
                im = edit_volumes.rescale_volume(im)
                utils.save_volume(im, aff, h, path_out)

        # resample T2 to target res
        path_t2_resampled = os.path.join(result_dir, 't2_rescaled_resampled.nii.gz')
        if recompute or not os.path.isfile(path_t2_resampled):
            voxsize = [str(r) for r in utils.reformat_to_list(target_res, length=3)]
            _run_mri_convert(mri_convert, path_t2_rescaled, path_t2_resampled,
                             ['--voxsize'] + voxsize + ['-odt', 'float'], verbose)

        # pad T2, and shift the affine translation so that world coordinates of existing voxels are unchanged
        path_t2_padded = os.path.join(result_dir, 't2_rescaled_resampled_padded.nii.gz')
        if recompute or not os.path.isfile(path_t2_padded):
            t2, aff, h = utils.load_volume(path_t2_resampled, im_only=False)
            t2_padded = np.pad(t2, padding_margin, 'constant')
            aff[:3, -1] = aff[:3, -1] - (aff[:3, :3] @ (padding_margin * np.ones((3, 1)))).T
            utils.save_volume(t2_padded, aff, h, path_t2_padded)

        # resample T1 and aseg on the grid of the padded T2
        path_t1_resampled = os.path.join(result_dir, 't1_rescaled_resampled.nii.gz')
        if recompute or not os.path.isfile(path_t1_resampled):
            _run_mri_convert(mri_convert, path_t1_rescaled, path_t1_resampled,
                             ['-rl', path_t2_padded, '-odt', 'float'], verbose)
        path_aseg_resampled = os.path.join(result_dir, 'aseg_resampled.nii.gz')
        if recompute or not os.path.isfile(path_aseg_resampled):
            _run_mri_convert(mri_convert, path_aseg, path_aseg_resampled,
                             ['-rl', path_t2_padded, '-rt', 'nearest', '-odt', 'float'], verbose)

        # crop images around each hippocampus (labels 17 and 53) and concatenate T1 and T2
        for lab, side in zip([17, 53], ['left', 'right']):
            path_test_image = os.path.join(result_dir, 'hippo_{}.nii.gz'.format(side))
            path_test_aseg = os.path.join(result_dir, 'hippo_{}_aseg.nii.gz'.format(side))
            if recompute or not (os.path.isfile(path_test_image) and os.path.isfile(path_test_aseg)):
                aseg, aff, h = utils.load_volume(path_aseg_resampled, im_only=False)
                crop_aseg, cropping, tmp_aff = edit_volumes.crop_volume_around_region(aseg,
                                                                                      margin=30,
                                                                                      masking_labels=lab,
                                                                                      aff=aff)
                # flip right-side crops for saving, but keep crop_aseg unflipped for masking the images below
                saved_aseg = edit_volumes.flip_volume(crop_aseg, direction='rl', aff=tmp_aff) \
                    if side == 'right' else crop_aseg
                utils.save_volume(saved_aseg, tmp_aff, h, path_test_aseg)
                if recompute or not os.path.isfile(path_test_image):
                    channels = list()
                    for path_channel in (path_t1_resampled, path_t2_padded):
                        channel = utils.load_volume(path_channel)
                        channel = edit_volumes.crop_volume_with_idx(channel, crop_idx=cropping)
                        # mask with the unflipped aseg so that mask and image are spatially aligned
                        channel = edit_volumes.mask_volume(channel, crop_aseg, dilate=6, erode=5)
                        if side == 'right':
                            channel = edit_volumes.flip_volume(channel, direction='rl', aff=tmp_aff)
                        channels.append(channel)
                    test_image = np.stack(channels, axis=-1)
                    utils.save_volume(test_image, tmp_aff, h, path_test_image)

        # remove intermediate files that were actually produced
        if remove:
            list_files_to_remove = [path_t1_rescaled,
                                    path_t2_rescaled,
                                    path_t2_resampled,
                                    path_t2_padded,
                                    path_t1_resampled,
                                    path_aseg_resampled]
            for path in list_files_to_remove:
                if os.path.isfile(path):
                    os.remove(path)
Exemple #11
0
def preprocess_image(im_path,
                     n_levels,
                     target_res,
                     crop=None,
                     padding=None,
                     flip=False,
                     path_resample=None):
    """Load an image and prepare it as network input.

    The image is optionally resampled to target_res, aligned to the identity affine, optionally padded, cropped so
    that its spatial shape is divisible by 2**n_levels, rescaled to [0, 1] per channel (0.5 / 99.5 percentiles),
    and given batch and channel axes.
    :param im_path: path of the image to preprocess.
    :param n_levels: number of resolution levels of the network; the spatial shape is made divisible by 2**n_levels.
    :param target_res: resolution to resample to, or None to keep the native resolution. Resampling only happens if
    the image resolution is more than 0.05 away from target_res along at least one axis.
    :param crop: (optional) cropping shape (number or sequence), clipped to the (padded) image shape.
    :param padding: (optional) padding shape passed to edit_volumes.pad_volume.
    :param flip: (optional) whether to also return a right/left flipped copy of the image (only when n_dims > 2).
    :param path_resample: (optional) path where the resampled image is saved, if resampling happens.
    :return: im, aff, header, im_res, n_channels, n_dims, shape (spatial shape after alignment and before padding),
    pad_shape (spatial shape after padding), crop_idx (None if no cropping), im_flipped (None if not flipped).
    """

    # read image and corresponding info
    im, _, aff, n_dims, n_channels, header, im_res = utils.get_volume_info(
        im_path, True)

    # resample image if necessary
    if target_res is not None:
        target_res = np.squeeze(
            utils.reformat_to_n_channels_array(target_res, n_dims))
        if np.any((im_res > target_res + 0.05) | (im_res < target_res - 0.05)):
            im_res = target_res
            im, aff = edit_volumes.resample_volume(im, aff, im_res)
            if path_resample is not None:
                utils.save_volume(im, aff, header, path_resample)

    # align image
    im = edit_volumes.align_volume_to_ref(im,
                                          aff,
                                          aff_ref=np.eye(4),
                                          n_dims=n_dims)
    # keep only the spatial dimensions: for multi-channel images the last axis holds the channels, and including it
    # would make pad_shape longer than crop (n_dims entries) below
    shape = list(im.shape[:n_dims])

    # pad image if specified
    if padding:
        im = edit_volumes.pad_volume(im, padding_shape=padding)
        pad_shape = im.shape[:n_dims]
    else:
        pad_shape = shape

    # check that patch_shape or im_shape are divisible by 2**n_levels
    if crop is not None:
        crop = utils.reformat_to_list(crop, length=n_dims, dtype='int')
        if not all([pad_shape[i] >= crop[i] for i in range(len(pad_shape))]):
            crop = [min(pad_shape[i], crop[i]) for i in range(n_dims)]
        if not all([size % (2**n_levels) == 0 for size in crop]):
            crop = [
                utils.find_closest_number_divisible_by_m(size, 2**n_levels)
                for size in crop
            ]
    else:
        if not all([size % (2**n_levels) == 0 for size in pad_shape]):
            crop = [
                utils.find_closest_number_divisible_by_m(size, 2**n_levels)
                for size in pad_shape
            ]

    # crop image if necessary
    if crop is not None:
        im, crop_idx = edit_volumes.crop_volume(im,
                                                cropping_shape=crop,
                                                return_crop_idx=True)
    else:
        crop_idx = None

    # normalise image
    if n_channels == 1:
        im = edit_volumes.rescale_volume(im,
                                         new_min=0.,
                                         new_max=1.,
                                         min_percentile=0.5,
                                         max_percentile=99.5)
    else:
        # channels are written back in place, so an integer array would truncate the [0, 1] values to 0 or 1
        if not np.issubdtype(im.dtype, np.floating):
            im = im.astype('float32')
        for i in range(im.shape[-1]):
            im[..., i] = edit_volumes.rescale_volume(im[..., i],
                                                     new_min=0.,
                                                     new_max=1.,
                                                     min_percentile=0.5,
                                                     max_percentile=99.5)

    # flip image along right/left axis
    if flip and n_dims > 2:
        im_flipped = edit_volumes.flip_volume(im,
                                              direction='rl',
                                              aff=np.eye(4))
        im_flipped = utils.add_axis(
            im_flipped) if n_channels > 1 else utils.add_axis(im_flipped,
                                                              axis=[0, -1])
    else:
        im_flipped = None

    # add batch and channel axes
    im = utils.add_axis(im) if n_channels > 1 else utils.add_axis(im,
                                                                  axis=[0, -1])

    return im, aff, header, im_res, n_channels, n_dims, shape, pad_shape, crop_idx, im_flipped
Exemple #12
0
def preprocess_image(im_path, n_levels, crop_shape=None, padding=None):
    """Load an image, pad/crop it to a shape divisible by 2**n_levels, min-max normalise it, and add batch/channel axes.

    :param im_path: path of the image to preprocess. It is read with utils.get_volume_info using np.eye(4) as
    reference affine (presumably aligning the volume to it — confirm against utils).
    :param n_levels: number of resolution levels of the network; the spatial shape is made divisible by 2**n_levels.
    :param crop_shape: (optional) cropping shape (number or sequence), clipped to the (padded) image shape.
    :param padding: (optional) padding shape passed to edit_volumes.pad_volume.
    :return: im, aff, header, im_res, n_channels, n_dims, shape, pad_shape, crop_idx (None if no cropping).
    Each channel of im is linearly rescaled to [0, 1]; constant channels become all zeros.
    """

    def normalise(x):
        # min-max scaling to [0, 1]; constant inputs map to zeros to avoid a division by zero
        x_min = np.min(x)
        x_max = np.max(x)
        if x_max == x_min:
            return np.zeros(x.shape)
        return (x - x_min) / (x_max - x_min)

    # read image and corresponding info
    im, shape, aff, n_dims, n_channels, header, im_res = utils.get_volume_info(
        im_path, True, np.eye(4))

    if padding:
        im = edit_volumes.pad_volume(im, padding_shape=padding)
        pad_shape = im.shape[:n_dims]
    else:
        pad_shape = shape

    # check that patch_shape or im_shape are divisible by 2**n_levels
    if crop_shape is not None:
        crop_shape = utils.reformat_to_list(crop_shape,
                                            length=n_dims,
                                            dtype='int')
        if not all(
            [pad_shape[i] >= crop_shape[i] for i in range(len(pad_shape))]):
            crop_shape = [
                min(pad_shape[i], crop_shape[i]) for i in range(n_dims)
            ]
        if not all([size % (2**n_levels) == 0 for size in crop_shape]):
            crop_shape = [
                utils.find_closest_number_divisible_by_m(size, 2**n_levels)
                for size in crop_shape
            ]
    else:
        if not all([size % (2**n_levels) == 0 for size in pad_shape]):
            crop_shape = [
                utils.find_closest_number_divisible_by_m(size, 2**n_levels)
                for size in pad_shape
            ]

    # crop image if necessary
    if crop_shape is not None:
        im, crop_idx = edit_volumes.crop_volume(im,
                                                cropping_shape=crop_shape,
                                                return_crop_idx=True)
    else:
        crop_idx = None

    # normalise image
    if n_channels == 1:
        im = normalise(im)
    else:
        # channels are written back in place, so an integer array would truncate the [0, 1] values to 0 or 1
        if not np.issubdtype(im.dtype, np.floating):
            im = im.astype('float32')
        for i in range(im.shape[-1]):
            im[..., i] = normalise(im[..., i])

    # add batch and channel axes
    im = utils.add_axis(im) if n_channels > 1 else utils.add_axis(im,
                                                                  axis=[0, -1])

    return im, aff, header, im_res, n_channels, n_dims, shape, pad_shape, crop_idx
Exemple #13
0
def build_augmentation_model(im_shape,
                             n_channels,
                             segmentation_labels,
                             n_neutral_labels,
                             atlas_res,
                             target_res,
                             output_shape=None,
                             output_div_by_n=None,
                             flipping=True,
                             aff=None,
                             scaling_bounds=0.15,
                             rotation_bounds=15,
                             shearing_bounds=0.012,
                             translation_bounds=False,
                             nonlin_std=3.,
                             nonlin_shape_factor=.0625,
                             data_res=None,
                             thickness=None,
                             downsample=False,
                             blur_range=1.03,
                             bias_field_std=.5,
                             bias_shape_factor=.025):
    """Build a Keras model that jointly augments an image and its label map.

    The model applies, in order: a random spatial deformation (nearest interpolation for labels, linear for the
    image), a random crop (only if the computed crop shape differs from im_shape), an optional right/left flip, an
    optional bias field corruption, an intensity augmentation, and, if data_res is given, a per-channel blurring
    followed by either a simulated acquisition at data_res or a resize to output_shape. Labels are then resampled
    with nearest interpolation to output_shape if needed.
    :param im_shape: spatial shape of the inputs. The image input has shape im_shape + [n_channels], the label input
    im_shape + [1] (int32).
    :param n_channels: number of image channels.
    :param segmentation_labels: labels passed to RandomFlip (used with n_neutral_labels when flipping).
    :param n_neutral_labels: number of neutral labels passed to RandomFlip.
    :param atlas_res: resolution of the inputs.
    :param target_res: resolution of the outputs. Only used when data_res is given; otherwise atlas_res is used.
    :param output_shape: (optional) requested output shape, passed to get_shapes.
    :param output_div_by_n: (optional) passed to get_shapes.
    :param flipping: (optional) whether to apply random right/left flipping. Requires aff.
    :param aff: (optional) affine used to find the right/left axis when flipping.
    :param scaling_bounds, rotation_bounds, shearing_bounds, translation_bounds, nonlin_std, nonlin_shape_factor:
    (optional) passed to layers.RandomSpatialDeformation.
    :param data_res: (optional) acquisition resolution to mimic, reformatted to one row per channel.
    :param thickness: (optional) slice thickness, same format as data_res. Defaults to data_res.
    :param downsample: (optional) whether to simulate the acquisition at data_res. If falsy, it is enabled for the
    channels where thickness is smaller than data_res along at least one axis.
    :param blur_range: (optional) passed to layers.GaussianBlur.
    :param bias_field_std: (optional) bias field is applied only if strictly positive.
    :param bias_shape_factor: (optional) passed to layers.BiasFieldCorruption.
    :return: a Keras model with inputs [image_input, labels_input] and outputs [image ('image_out'),
    labels ('labels_out', int32)].
    """

    # reformat resolutions and get shapes
    im_shape = utils.reformat_to_list(im_shape)
    n_dims, _ = utils.get_dims(im_shape)
    if data_res is not None:
        data_res = utils.reformat_to_n_channels_array(data_res, n_dims,
                                                      n_channels)
        thickness = data_res if thickness is None else utils.reformat_to_n_channels_array(
            thickness, n_dims, n_channels)
        downsample = utils.reformat_to_list(
            downsample,
            n_channels) if downsample else np.min(thickness - data_res, 1) < 0
        target_res = atlas_res if (
            target_res is None) else utils.reformat_to_n_channels_array(
                target_res, n_dims)[0]
    else:
        target_res = atlas_res

    # get shapes
    crop_shape, output_shape = get_shapes(im_shape, output_shape, atlas_res,
                                          target_res, output_div_by_n)

    # define model inputs
    image_input = KL.Input(shape=im_shape + [n_channels], name='image_input')
    labels_input = KL.Input(shape=im_shape + [1],
                            name='labels_input',
                            dtype='int32')

    # deform labels
    labels, image = layers.RandomSpatialDeformation(
        scaling_bounds=scaling_bounds,
        rotation_bounds=rotation_bounds,
        shearing_bounds=shearing_bounds,
        translation_bounds=translation_bounds,
        nonlin_std=nonlin_std,
        nonlin_shape_factor=nonlin_shape_factor,
        inter_method=['nearest', 'linear'])([labels_input, image_input])

    # cropping
    if crop_shape != im_shape:
        labels._keras_shape = tuple(labels.get_shape().as_list())
        image._keras_shape = tuple(image.get_shape().as_list())
        labels, image = layers.RandomCrop(crop_shape)([labels, image])

    # flipping
    if flipping:
        assert aff is not None, 'aff should not be None if flipping is True'
        labels._keras_shape = tuple(labels.get_shape().as_list())
        image._keras_shape = tuple(image.get_shape().as_list())
        labels, image = layers.RandomFlip(
            get_ras_axes(aff, n_dims)[0], [True, False], segmentation_labels,
            n_neutral_labels)([labels, image])

    # apply bias field
    if bias_field_std > 0:
        image._keras_shape = tuple(image.get_shape().as_list())
        image = layers.BiasFieldCorruption(bias_field_std, bias_shape_factor,
                                           False)(image)

    # intensity augmentation
    image._keras_shape = tuple(image.get_shape().as_list())
    image = layers.IntensityAugmentation(6,
                                         clip=False,
                                         normalise=True,
                                         gamma_std=.4,
                                         separate_channels=True)(image)

    # if necessary, loop over channels to 1) blur, 2) downsample to simulated LR, and 3) upsample to target
    if data_res is not None:
        channels = list()
        split = KL.Lambda(lambda x: tf.split(x, [1] * n_channels, axis=-1))(
            image) if (n_channels > 1) else [image]
        for i, channel in enumerate(split):

            # blur
            channel._keras_shape = tuple(channel.get_shape().as_list())
            sigma = l2i_et.blurring_sigma_for_downsampling(
                atlas_res, data_res[i], thickness=thickness[i])
            channel = layers.GaussianBlur(sigma, blur_range)(channel)

            # resample
            if downsample[i]:
                # bind this channel's resolution now (default argument): referencing the loop variable i inside the
                # lambda would resolve it whenever the Lambda is re-executed (e.g. when the model is called on new
                # inputs), at which point i is the last channel index
                resolution = KL.Lambda(lambda x, res=data_res[i]: tf.convert_to_tensor(
                    res, dtype='float32'))([])
                channel = layers.MimicAcquisition(atlas_res, data_res[i],
                                                  output_shape)(
                                                      [channel, resolution])
            elif output_shape != crop_shape:
                channel = nrn_layers.Resize(size=output_shape)(channel)
            channels.append(channel)

        # concatenate all channels back
        image = KL.Lambda(lambda x: tf.concat(x, -1))(
            channels) if len(channels) > 1 else channels[0]

        # resample labels at target resolution
        if crop_shape != output_shape:
            labels = l2i_et.resample_tensor(labels,
                                            output_shape,
                                            interp_method='nearest')

    # build model (dummy layer enables to keep the labels when plugging this model to other models)
    labels = KL.Lambda(lambda x: tf.cast(x, dtype='int32'),
                       name='labels_out')(labels)
    image = KL.Lambda(lambda x: x[0], name='image_out')([image, labels])
    brain_model = models.Model(inputs=[image_input, labels_input],
                               outputs=[image, labels])

    return brain_model
Exemple #14
0
def preprocess_image(im_path, n_levels, crop_shape=None, padding=None, aff_ref='FS'):
    """Load an image and prepare it to be fed to a segmentation network.

    Steps:
    - optional zero-padding of the spatial axes;
    - central cropping so that the spatial shape is divisible by 2**n_levels;
    - alignment to a reference orientation (images with more than 2 dimensions only);
    - min-max normalisation between 0 and 1, independently for each channel;
    - addition of the batch axis (and, for single-channel images, the channel axis).
    :param im_path: path of the image to preprocess.
    :param n_levels: number of resolution levels of the network. Spatial sizes are adjusted with
    utils.find_closest_number_divisible_by_m so that they are divisible by 2**n_levels.
    :param crop_shape: (optional) shape of the central crop. It is clipped to the (padded) image shape if too large.
    If None, the image is only cropped when its (padded) shape is not divisible by 2**n_levels.
    :param padding: (optional) zero-padding passed to np.pad for the spatial axes. For multi-channel images it must be a
    single number of voxels, added on both sides of every spatial axis.
    :param aff_ref: (optional) reference orientation: 'FS', 'identity' or 'MS'. Any other value leaves the orientation
    unchanged.
    :return: the preprocessed image, the affine matrix and header of the input image, its resolution, its number of
    channels and of dimensions, its shape as returned by utils.get_volume_info, its shape after padding, the cropping
    shape (None if no cropping) and the cropping indices (None if no cropping).
    """

    # read image and corresponding info
    im, shape, aff, n_dims, n_channels, header, im_res = utils.get_volume_info(im_path, return_volume=True)

    # zero-pad the spatial axes (the channel axis of multi-channel images is left untouched)
    if padding:
        if n_channels == 1:
            im = np.pad(im, padding, mode='constant')
            pad_shape = im.shape
        else:
            im = np.pad(im, tuple([(padding, padding)] * n_dims + [(0, 0)]), mode='constant')
            pad_shape = im.shape[:-1]
    else:
        pad_shape = shape

    # check that crop_shape (or the padded image shape if no cropping is requested) is divisible by 2**n_levels
    if crop_shape is not None:
        crop_shape = utils.reformat_to_list(crop_shape, length=n_dims, dtype='int')
        if not all([pad_shape[i] >= crop_shape[i] for i in range(len(pad_shape))]):
            crop_shape = [min(pad_shape[i], crop_shape[i]) for i in range(n_dims)]
            print('cropping dimensions are higher than image size, changing cropping size to {}'.format(crop_shape))
        if not all([size % (2**n_levels) == 0 for size in crop_shape]):
            crop_shape = [utils.find_closest_number_divisible_by_m(size, 2 ** n_levels) for size in crop_shape]
    else:
        if not all([size % (2**n_levels) == 0 for size in pad_shape]):
            crop_shape = [utils.find_closest_number_divisible_by_m(size, 2 ** n_levels) for size in pad_shape]

    # crop image if necessary (the crop is centred in the padded image)
    if crop_shape is not None:
        crop_idx = np.round((pad_shape - np.array(crop_shape)) / 2).astype('int')
        crop_idx = np.concatenate((crop_idx, crop_idx + crop_shape), axis=0)
        im = edit_volumes.crop_volume_with_idx(im, crop_idx=crop_idx)
    else:
        crop_idx = None

    # align image to training axes and directions
    if n_dims > 2:
        if aff_ref == 'FS':
            aff_ref = np.array([[-1., 0., 0., 0.], [0., 0., 1., 0.], [0., -1., 0., 0.], [0., 0., 0., 1.]])
            im = edit_volumes.align_volume_to_ref(im, aff, aff_ref=aff_ref, return_aff=False)
        elif aff_ref == 'identity':
            aff_ref = np.eye(4)
            im = edit_volumes.align_volume_to_ref(im, aff, aff_ref=aff_ref, return_aff=False)
        elif aff_ref == 'MS':
            aff_ref = np.array([[-1., 0., 0., 0.], [0., -1., 0., 0.], [0., 0., 1., 0.], [0., 0., 0., 1.]])
            im = edit_volumes.align_volume_to_ref(im, aff, aff_ref=aff_ref, return_aff=False)

    # normalise image between 0 and 1 (each channel independently)
    if n_channels == 1:
        m = np.min(im)
        M = np.max(im)
        if M == m:
            im = np.zeros(im.shape)
        else:
            im = (im - m) / (M - m)
    elif n_channels > 1:
        # channels are normalised in place, so im must have a floating dtype: writing the normalised values back into
        # an integer array would truncate them to 0/1
        if not np.issubdtype(im.dtype, np.floating):
            im = im.astype('float')
        for i in range(im.shape[-1]):
            channel = im[..., i]
            m = np.min(channel)
            M = np.max(channel)
            if M == m:
                im[..., i] = np.zeros(channel.shape)
            else:
                im[..., i] = (channel - m) / (M - m)

    # add batch and channel axes
    if n_channels > 1:
        im = utils.add_axis(im)
    else:
        im = utils.add_axis(im, -2)

    return im, aff, header, im_res, n_channels, n_dims, shape, pad_shape, crop_shape, crop_idx
Exemple #15
0
def metrics_model(input_shape,
                  segmentation_label_list,
                  input_model=None,
                  loss_cropping=None,
                  metrics='dice',
                  weight_background=None,
                  include_background=False,
                  name=None,
                  prefix=None,
                  validation_on_real_images=False):
    """Build a model computing a segmentation loss between the predictions of a model and ground-truth labels.

    :param input_shape: shape of the predictions without batch axis, i.e. spatial dimensions followed by the number of
    labels. Must be a list (it is concatenated with [1] to build the shape of the label input).
    :param segmentation_label_list: numpy array of the label values predicted by the network. Ground-truth labels are
    mapped to 0..N-1 (N = len(segmentation_label_list)) with the lookup table of utils.rearrange_label_list, then
    one-hot encoded.
    :param input_model: (optional) model whose (first) output is the prediction. When validation_on_real_images is
    False, the ground truth is read from its 'labels_out' layer.
    NOTE(review): both ground-truth branches below rely on input_model, so input_model=None does not appear to be
    usable in practice - confirm with callers.
    :param loss_cropping: (optional) size of the central patch on which the loss is computed (one value, or one value
    per spatial axis).
    :param metrics: (optional) 'dice' (average soft Dice loss) or 'wl2' (weighted L2 loss).
    :param weight_background: offset added to the voxel weights of the 'wl2' loss
    (weight = 1 - background probability + weight_background). Required for 'wl2'.
    :param include_background: (optional) whether the first label (index 0) contributes to the mean Dice loss.
    :param name: (optional) name of the returned model.
    :param prefix: (optional) prefix of the input layer name. Defaults to name.
    :param validation_on_real_images: (optional) if True, the ground-truth labels are given as an additional model input
    named 'labels_input' instead of being read from input_model.
    :return: a keras Model whose output is the loss ('average_mean_dice_loss' for 'dice', 'wl2' for 'wl2').
    """

    # naming the model
    model_name = name
    if prefix is None:
        prefix = model_name

    # first layer: input
    name = '%s_input' % prefix
    if input_model is None:
        input_tensor = KL.Input(shape=input_shape, name=name)
        last_tensor = input_tensor
    else:
        input_tensor = input_model.inputs
        last_tensor = input_model.outputs
        if isinstance(last_tensor, list):
            last_tensor = last_tensor[0]
        last_tensor = KL.Reshape(input_shape,
                                 name='predicted_output')(last_tensor)

    # get ground-truth labels
    if validation_on_real_images:
        labels_gt = KL.Input(shape=input_shape[:-1] + [1], name='labels_input')
        input_tensor = [input_tensor[0], labels_gt]
    else:
        labels_gt = input_model.get_layer('labels_out').output

    # convert gt labels to 0...N-1 values
    n_labels = segmentation_label_list.shape[0]
    _, lut = utils.rearrange_label_list(segmentation_label_list)
    labels_gt = l2i_et.convert_labels(labels_gt, lut)

    # convert gt labels to probabilistic values
    labels_gt = KL.Lambda(lambda x: tf.one_hot(
        tf.cast(x, dtype='int32'), depth=n_labels, axis=-1))(labels_gt)
    labels_gt = KL.Reshape(input_shape)(labels_gt)
    labels_gt = KL.Lambda(lambda x: K.clip(
        x / K.sum(x, axis=-1, keepdims=True), K.epsilon(), 1))(labels_gt)

    # crop output to evaluate loss function in centre patch
    if loss_cropping is not None:
        # format loss_cropping (one size per spatial axis)
        labels_shape = labels_gt.get_shape().as_list()[1:-1]
        n_dims, _ = utils.get_dims(labels_shape)
        crop_size = utils.reformat_to_list(loss_cropping, length=n_dims, dtype='int')
        # centre the patch in every spatial axis; batch and channel axes are kept entirely (begin 0, size -1)
        begin_idx = [0] + [int((labels_shape[i] - crop_size[i]) / 2) for i in range(n_dims)] + [0]
        slice_size = [-1] + crop_size + [-1]
        # perform cropping
        labels_gt = KL.Lambda(lambda x: tf.slice(
            x,
            begin=tf.convert_to_tensor(begin_idx, dtype='int32'),
            size=tf.convert_to_tensor(slice_size, dtype='int32')))(
                labels_gt)
        last_tensor = KL.Lambda(lambda x: tf.slice(
            x,
            begin=tf.convert_to_tensor(begin_idx, dtype='int32'),
            size=tf.convert_to_tensor(slice_size, dtype='int32')))(
                last_tensor)

    # metrics is computed as part of the model
    if metrics == 'dice':

        # make sure predicted values are probabilistic
        last_tensor = KL.Lambda(lambda x: K.clip(
            x / K.sum(x, axis=-1, keepdims=True), K.epsilon(), 1))(last_tensor)

        # compute soft dice: 2*sum(x*y) / sum(x^2 + y^2), summed over all spatial axes
        top = KL.Lambda(lambda x: 2 * x[0] * x[1])([labels_gt, last_tensor])
        bottom = KL.Lambda(lambda x: K.square(x[0]) + K.square(x[1]))(
            [labels_gt, last_tensor])
        for _ in range(len(input_shape) - 1):
            top = KL.Lambda(lambda x: K.sum(x, axis=1))(top)
            bottom = KL.Lambda(lambda x: K.sum(x, axis=1))(bottom)
        last_tensor = KL.Lambda(lambda x: x[0] / K.maximum(x[1], 0.001),
                                name='dice')([top, bottom])  # 1d vector

        # compute mean dice loss (first label excluded unless include_background)
        if include_background:
            w = np.ones([n_labels]) / n_labels
        else:
            w = np.ones([n_labels]) / (n_labels - 1)
            w[0] = 0.0
        last_tensor = KL.Lambda(lambda x: 1 - x, name='dice_loss')(last_tensor)
        last_tensor = KL.Lambda(lambda x: K.sum(
            x * tf.convert_to_tensor(w, dtype='float32'), axis=1),
                                name='mean_dice_loss')(last_tensor)
        # average mean dice loss over mini batch
        last_tensor = KL.Lambda(lambda x: K.mean(x),
                                name='average_mean_dice_loss')(last_tensor)

    elif metrics == 'wl2':
        # compute weighted l2 loss
        weights = KL.Lambda(lambda x: K.expand_dims(1 - x[
            ..., 0] + weight_background))(labels_gt)
        normaliser = KL.Lambda(lambda x: K.sum(x[0]) * K.int_shape(x[1])[-1])(
            [weights, last_tensor])
        last_tensor = KL.Lambda(
            # targets are the one-hot probabilities mapped from [0, 1] to [-3, 3]
            lambda x: K.sum(x[2] * K.square(x[1] - (x[0] * 6 - 3))) / x[3],
            name='wl2')([labels_gt, last_tensor, weights, normaliser])

    else:
        raise Exception(
            'metrics should either be "dice" or "wl2", got {}'.format(metrics))

    # create the model and return
    model = Model(inputs=input_tensor, outputs=last_tensor, name=model_name)
    return model
Exemple #16
0
def build_intensity_stats(list_image_dir,
                          list_labels_dir,
                          result_dir,
                          estimation_labels,
                          estimation_classes=None,
                          max_channel=3,
                          rescale=True):
    """This function aims at estimating the intensity distributions of K different structure types from a set of images.
    The distribution of each structure type is modelled as a Gaussian, parametrised by a mean and a standard deviation.
    Because the intensity distribution of structures can vary across images, we additionally use Gaussian priors for the
    parameters of each Gaussian distribution. Therefore, the intensity distribution of each structure type is described
    by 4 parameters: a mean/std for the mean intensity, and a mean/std for the standard deviation.
    This function uses a set of images along with corresponding segmentations to estimate the 4*K parameters.
    Additionally, it can estimate the 4*K parameters for several image datasets, that we call here n_datasets.
    This function writes 2 numpy arrays (prior_means.npy and prior_stds.npy in result_dir) of size (2*n_datasets, K),
    one with the evaluated means/std for the mean intensities, and one with the means/std for the standard deviations.
    In these arrays, each block of two rows refers to a different dataset (or to a different modality of a multimodal
    dataset, in which case a dataset contributes one block per modality).
    Within each block of two rows, the first row represents the mean, and the second represents the std.
    :param list_image_dir: path of folders with images for intensity distribution estimation.
    Can be the path of a single directory (n_datasets=1), or a list of folders, each being a separate dataset.
    Images can be multimodal, in which case each modality is treated as a different dataset, i.e. each modality will
    have a separate block (of size (2, K)) in the result arrays.
    :param list_labels_dir: path of folders with label maps corresponding to input images.
    If list_image_dir is a list of several folders, list_labels_dir can either be a list of folders (one for each image
    folder), or the path to a single folder, which will be used for all datasets.
    If a dataset has multimodal images, the same label map is applied to all modalities.
    :param result_dir: path of directory where estimated priors will be written.
    :param estimation_labels: labels to estimate intensity statistics from.
    Can be a sequence, a 1d numpy array, or the path to a 1d numpy array.
    :param estimation_classes: (optional) enables to regroup structures into classes of similar intensity statistics.
    Intensities associated to regrouped labels will thus contribute to the same Gaussian during statistics estimation.
    Can be a sequence, a 1d numpy array, or the path to a 1d numpy array.
    It should have the same length as estimation_labels, and contain values between 0 and K-1, where K is the total
    number of classes. Default is all labels have different classes (K=len(estimation_labels)).
    :param max_channel: (optional) maximum number of channels to consider if the data is multispectral. Default is 3.
    :param rescale: (optional) whether to rescale images between 0 and 255 before intensity estimation
    :return: the two arrays written to disk: prior_means and prior_stds.
    """

    # handle results directories
    utils.mkdir(result_dir)

    # reformat image/labels dir into lists (a single labels dir is reused for all image dirs)
    list_image_dir = utils.reformat_to_list(list_image_dir)
    list_labels_dir = utils.reformat_to_list(list_labels_dir, length=len(list_image_dir))

    # reformat list estimation labels and classes
    estimation_labels = np.array(utils.reformat_to_list(estimation_labels, load_as_numpy=True, dtype='int'))
    if estimation_classes is not None:
        estimation_classes = np.array(utils.reformat_to_list(estimation_classes, load_as_numpy=True, dtype='int'))
    else:
        estimation_classes = np.arange(estimation_labels.shape[0])
    assert len(estimation_classes) == len(estimation_labels), 'estimation labels and classes should be of same length'

    # check that classes exactly cover 0..K-1
    unique_estimation_classes = np.unique(estimation_classes)
    n_classes = len(unique_estimation_classes)
    if not np.array_equal(unique_estimation_classes, np.arange(n_classes)):
        raise ValueError('estimation_classes should only contain values between 0 and K-1, '
                         'where K is the total number of classes. Here K = %d' % n_classes)

    # loop over dataset
    list_datasets_prior_means = list()
    list_datasets_prior_stds = list()
    for image_dir, labels_dir in zip(list_image_dir, list_labels_dir):

        # get prior stats for dataset
        tmp_prior_means, tmp_prior_stds = sample_intensity_stats_from_single_dataset(image_dir,
                                                                                     labels_dir,
                                                                                     estimation_labels,
                                                                                     estimation_classes,
                                                                                     max_channel=max_channel,
                                                                                     rescale=rescale)

        # add stats arrays to list of datasets-wise statistics
        list_datasets_prior_means.append(tmp_prior_means)
        list_datasets_prior_stds.append(tmp_prior_stds)

    # stack all datasets/modalities together along the row axis
    prior_means = np.concatenate(list_datasets_prior_means, axis=0)
    prior_stds = np.concatenate(list_datasets_prior_stds, axis=0)

    # save files
    np.save(os.path.join(result_dir, 'prior_means.npy'), prior_means)
    np.save(os.path.join(result_dir, 'prior_stds.npy'), prior_stds)

    return prior_means, prior_stds
Exemple #17
0
def sample_intensity_stats_from_single_dataset(image_dir, labels_dir, labels_list, classes_list=None, max_channel=3,
                                               rescale=True):
    """Estimate Gaussian priors on the intensity statistics of K structure classes over a single image dataset.

    Each structure class is modelled by a Gaussian (mean, std) in every image. For every image and every channel, the
    per-class mean and std are measured with sample_intensity_stats_from_image. Their spread across the dataset then
    gives the priors: a mean/std for the class means, and a mean/std for the class standard deviations.
    Labels can be grouped into classes sharing the same statistics. Multi-channel images get separate statistics
    for each channel.
    :param image_dir: path of the directory containing the images.
    :param labels_dir: path of the directory containing the segmentations of the images, matched with them by sorting
    order.
    :param labels_list: labels for which statistics are estimated.
    Can be a sequence, a 1d numpy array, or the path to a 1d numpy array.
    :param classes_list: (optional) class of each label in labels_list (same length).
    Can be a sequence, a 1d numpy array, or the path to a 1d numpy array.
    Its values must cover exactly 0..K-1. Intensities of labels sharing a class contribute to the same Gaussian.
    Default is one class per label (K=len(labels_list)).
    :param max_channel: (optional) maximum number of channels to consider if the data is multispectral. Default is 3.
    :param rescale: (optional) whether to rescale images between 0 and 255 before intensity estimation
    :return: two numpy arrays of shape (2*n_channels, K): prior_means (statistics of the class means) and prior_stds
    (statistics of the class standard deviations). For channel c, row 2*c holds the mean over images and row 2*c+1
    holds the std over images.
    """

    # image and label-map paths must come in matching pairs
    path_images = utils.list_images_in_folder(image_dir)
    path_labels = utils.list_images_in_folder(labels_dir)
    assert len(path_images) == len(path_labels), 'image and labels folders do not have the same number of files'

    # labels and classes as 1d integer arrays (one class per label unless specified)
    labels_list = np.array(utils.reformat_to_list(labels_list, load_as_numpy=True, dtype='int'))
    if classes_list is None:
        classes_list = np.arange(labels_list.shape[0])
    else:
        classes_list = np.array(utils.reformat_to_list(classes_list, load_as_numpy=True, dtype='int'))
    assert len(classes_list) == len(labels_list), 'labels and classes lists should have the same length'

    # classes must be exactly 0, 1, ..., K-1
    class_values = np.unique(classes_list)
    n_classes = len(class_values)
    if not np.array_equal(class_values, np.arange(n_classes)):
        raise ValueError('classes_list should only contain values between 0 and K-1, '
                         'where K is the total number of classes. Here K = %d' % n_classes)

    # per-image statistics, indexed as [image, class, channel]; the channel count is read from the first image
    _, n_channels = utils.get_dims(utils.load_volume(path_images[0]).shape, max_channels=max_channel)
    n_images = len(path_images)
    image_means = np.zeros((n_images, n_classes, n_channels))
    image_stds = np.zeros((n_images, n_classes, n_channels))

    progress = utils.LoopInfo(n_images, 10, 'estimating', print_time=True)
    for img_idx, (im_file, seg_file) in enumerate(zip(path_images, path_labels)):
        progress.update(img_idx)

        volume = utils.load_volume(im_file)
        segmentation = utils.load_volume(seg_file)
        if n_channels == 1:
            volume = utils.add_axis(volume, -1)

        for c in range(n_channels):
            channel_volume = volume[..., c]
            if rescale:
                channel_volume = edit_volumes.rescale_volume(channel_volume)
            channel_stats = sample_intensity_stats_from_image(channel_volume, segmentation, labels_list,
                                                              classes_list=classes_list)
            image_means[img_idx, :, c] = channel_stats[0, :]
            image_stds[img_idx, :, c] = channel_stats[1, :]

    # interleave rows per channel: even rows = mean over images, odd rows = std over images
    prior_means = np.zeros((2 * n_channels, n_classes))
    prior_stds = np.zeros((2 * n_channels, n_classes))
    prior_means[0::2, :] = np.mean(image_means, axis=0).T
    prior_means[1::2, :] = np.std(image_means, axis=0).T
    prior_stds[0::2, :] = np.mean(image_stds, axis=0).T
    prior_stds[1::2, :] = np.std(image_stds, axis=0).T

    return prior_means, prior_stds
def build_augmentation_model(im_shape,
                             n_channels,
                             label_list,
                             image_res,
                             target_res=None,
                             output_shape=None,
                             output_div_by_n=None,
                             n_neutral_labels=1,
                             flipping=True,
                             flip_rl_only=False,
                             aff=None,
                             scaling_bounds=0.15,
                             rotation_bounds=15,
                             enable_90_rotations=False,
                             shearing_bounds=0.012,
                             translation_bounds=False,
                             nonlin_std=3.,
                             nonlin_shape_factor=.0625,
                             bias_field_std=.3,
                             bias_shape_factor=0.025,
                             same_bias_for_all_channels=False,
                             apply_intensity_augmentation=True,
                             noise_std=1.,
                             augment_channels_separately=True):
    """Build a keras model that randomly augments an image together with its label map.

    The model takes an image (spatial shape im_shape, n_channels channels) and an int32 label map (one channel).
    Labels are first mapped to 0..N-1 with the lookup table of utils.rearrange_label_list. The following steps are
    then applied, each one only if enabled by the arguments:
    - random flipping (restricted to the right/left axis of aff if flip_rl_only);
    - one-hot encoding of the labels, which are concatenated to the image so that both undergo the same spatial
      transforms;
    - random spatial deformation;
    - random cropping;
    - blurring of the image channels and resampling from image_res to target_res;
    - bias field corruption;
    - intensity augmentation.
    :return: a keras Model with inputs [image_input, labels_input] and outputs [image, labels], where labels are the
    transformed one-hot label channels.
    """

    # shapes and resolutions as lists (target resolution defaults to the image resolution)
    im_shape = utils.reformat_to_list(im_shape)
    n_dims, _ = utils.get_dims(im_shape)
    image_res = utils.reformat_to_list(image_res, length=n_dims)
    if target_res is None:
        target_res = image_res
    else:
        target_res = utils.reformat_to_list(target_res, length=n_dims)

    # cropping and output shapes
    cropping_shape, output_shape = get_shapes(im_shape, output_shape, image_res, target_res, output_div_by_n)
    im_shape = im_shape + [n_channels]
    channel_axis = len(im_shape)

    # lookup table mapping the label values to 0..N-1
    new_label_list, lut = utils.rearrange_label_list(label_list)

    # model inputs
    image_input = KL.Input(shape=im_shape, name='image_input')
    labels_input = KL.Input(shape=im_shape[:-1] + [1], name='labels_input', dtype='int32')
    labels = l2i_et.convert_labels(labels_input, lut)

    # random flipping, optionally restricted to the right/left axis
    image = image_input
    if flipping:
        flip_axis = int(edit_volumes.get_ras_axes(aff, n_dims)[0]) if flip_rl_only else None
        labels, image = layers.RandomFlip(flip_axis, [True, False], new_label_list,
                                          n_neutral_labels)([labels, image_input])

    # one-hot labels stacked after the image channels, so that they share the spatial transforms
    labels = KL.Lambda(lambda x: tf.one_hot(tf.cast(x[..., 0], dtype='int32'), depth=len(label_list), axis=-1))(labels)
    image = KL.concatenate([image, labels], axis=channel_axis)

    # random spatial deformation, unless every deformation type is disabled
    apply_deformation = (scaling_bounds is not False) | (rotation_bounds is not False) | \
        (shearing_bounds is not False) | (translation_bounds is not False) | (nonlin_std > 0) | enable_90_rotations
    if apply_deformation:
        image._keras_shape = tuple(image.get_shape().as_list())
        image = layers.RandomSpatialDeformation(scaling_bounds=scaling_bounds,
                                                rotation_bounds=rotation_bounds,
                                                shearing_bounds=shearing_bounds,
                                                translation_bounds=translation_bounds,
                                                enable_90_rotations=enable_90_rotations,
                                                nonlin_std=nonlin_std,
                                                nonlin_shape_factor=nonlin_shape_factor)(image)

    # random cropping
    if cropping_shape != im_shape[:-1]:
        image._keras_shape = tuple(image.get_shape().as_list())
        image = layers.RandomCrop(cropping_shape)(image)

    # resampling to output_shape; only the image channels are blurred beforehand
    if cropping_shape != output_shape:
        sigma = l2i_et.blurring_sigma_for_downsampling(image_res, target_res)
        parts = KL.Lambda(lambda x: tf.split(x, [n_channels, -1], axis=channel_axis))(image)
        blurred = parts[0]
        blurred._keras_shape = tuple(blurred.get_shape().as_list())
        blurred = layers.GaussianBlur(sigma=sigma)(blurred)
        image = KL.concatenate([blurred, parts[-1]])
        image = l2i_et.resample_tensor(image, output_shape)

    # separate image channels from label channels
    image, labels = KL.Lambda(lambda x: tf.split(x, [n_channels, -1], axis=channel_axis), name='splitting')(image)

    # bias field corruption
    if bias_field_std > 0:
        image._keras_shape = tuple(image.get_shape().as_list())
        image = layers.BiasFieldCorruption(bias_field_std, bias_shape_factor, same_bias_for_all_channels)(image)

    # intensity augmentation
    if apply_intensity_augmentation:
        image._keras_shape = tuple(image.get_shape().as_list())
        image = layers.IntensityAugmentation(noise_std, gamma_std=0.5,
                                             separate_channels=augment_channels_separately)(image)

    return Model(inputs=[image_input, labels_input], outputs=[image, labels])
Exemple #19
0
def build_model(model_file, input_shape, resample, im_res, n_levels, n_lab, conv_size, nb_conv_per_level,
                unet_feat_count, feat_multiplier, activation, sigma_smoothing):
    """Build the segmentation UNet, load its weights, and add optional pre/post-processing layers.

    :param model_file: path of the weights file. Weights are loaded by layer name.
    :param input_shape: shape of the input image without batch axis (spatial dimensions, then channels).
    :param resample: (optional) resolution at which the UNet operates (one value, or one value per spatial axis).
    If given, the input (at resolution im_res) is resized to this resolution. Its spatial shape is adjusted with
    utils.find_closest_number_divisible_by_m(..., smaller_ans=False) to be divisible by 2**n_levels. The UNet output
    is then resized back to the initial spatial shape of the input.
    :param im_res: resolution of the input image (one value per spatial axis).
    :param n_levels: number of UNet levels.
    :param n_lab: number of output labels of the UNet.
    :param conv_size: UNet hyper-parameter, passed to nrn_models.unet.
    :param nb_conv_per_level: UNet hyper-parameter, passed to nrn_models.unet.
    :param unet_feat_count: UNet hyper-parameter, passed to nrn_models.unet.
    :param feat_multiplier: UNet hyper-parameter, passed to nrn_models.unet.
    :param activation: UNet hyper-parameter, passed to nrn_models.unet.
    :param sigma_smoothing: standard deviation of the Gaussian smoothing applied to the output posteriors
    (0 disables it).
    :return: a keras Model.
    """

    assert os.path.isfile(model_file), "The provided model path does not exist."

    # initialisation
    net = None
    n_dims, n_channels = utils.get_dims(input_shape, max_channels=10)
    resample = utils.reformat_to_list(resample, length=n_dims)
    # input_shape is overwritten below with the resampled shape, so remember the initial spatial shape: the
    # post-processing step must resize the output back to it
    initial_spatial_shape = list(input_shape[:n_dims])

    # build preprocessing model
    if resample is not None:
        im_input = KL.Input(shape=input_shape, name='pre_resample_input')
        resample_factor = [im_res[i] / float(resample[i]) for i in range(n_dims)]
        resample_shape = [utils.find_closest_number_divisible_by_m(resample_factor[i] * input_shape[i],
                          2 ** n_levels, smaller_ans=False) for i in range(n_dims)]
        resampled = nrn_layers.Resize(size=resample_shape, name='pre_resample')(im_input)
        net = Model(inputs=im_input, outputs=resampled)
        input_shape = resample_shape + [n_channels]

    # build UNet (plugged on top of the preprocessing model if there is one)
    net = nrn_models.unet(nb_features=unet_feat_count,
                          input_shape=input_shape,
                          nb_levels=n_levels,
                          conv_size=conv_size,
                          nb_labels=n_lab,
                          name='unet',
                          prefix=None,
                          feat_mult=feat_multiplier,
                          pool_size=2,
                          use_logp=True,
                          padding='same',
                          dilation_rate_mult=1,
                          activation=activation,
                          use_residuals=False,
                          final_pred_activation='softmax',
                          nb_conv_per_level=nb_conv_per_level,
                          add_prior_layer=False,
                          add_prior_layer_reg=0,
                          layer_nb_feats=None,
                          conv_dropout=0,
                          batch_norm=-1,
                          input_model=net)
    net.load_weights(model_file, by_name=True)

    # build postprocessing model
    if resample is not None or sigma_smoothing != 0:

        # get UNet output
        input_tensor = net.inputs
        last_tensor = net.output

        # resample back to the initial spatial shape (i.e. the initial resolution)
        if resample is not None:
            last_tensor = nrn_layers.Resize(size=initial_spatial_shape, name='post_resample')(last_tensor)

        # smooth posteriors
        if sigma_smoothing != 0:
            last_tensor._keras_shape = tuple(last_tensor.get_shape().as_list())
            last_tensor = layers.GaussianBlur(sigma=sigma_smoothing)(last_tensor)

        # build model
        net = Model(inputs=input_tensor, outputs=last_tensor)

    return net
def labels_to_image_model(labels_shape,
                          n_channels,
                          generation_labels,
                          output_labels,
                          n_neutral_labels,
                          atlas_res,
                          target_res,
                          output_shape=None,
                          output_div_by_n=None,
                          flipping=True,
                          aff=None,
                          scaling_bounds=0.15,
                          rotation_bounds=15,
                          shearing_bounds=0.012,
                          translation_bounds=False,
                          nonlin_std=4.,
                          nonlin_shape_factor=.0625,
                          randomise_res=False,
                          buil_distance_maps=False,
                          data_res=None,
                          thickness=None,
                          downsample=False,
                          blur_range=1.15,
                          bias_field_std=.5,
                          bias_shape_factor=.025):
    """
    This function builds a keras/tensorflow model to generate images from provided label maps.
    The images are generated by sampling a Gaussian Mixture Model (of given parameters), conditioned on the label map.
    The model takes three inputs:
        -a label map of shape labels_shape + [1],
        -a tensor of shape (n_generation_labels, n_channels) containing the means of the Gaussian Mixture Model,
        -a tensor of shape (n_generation_labels, n_channels) containing the standard deviations of the Gaussian
        Mixture Model.
    All augmentations (spatial deformation, cropping, flipping, bias field, intensity augmentation, blurring and
    resolution mimicking) are sampled by the layers of the model itself, so they require no additional inputs.
    The model returns:
        -the generated image normalised between 0 and 1.
        -the corresponding label map, with only the labels present in output_labels (the others are reset to zero).
    # IMPORTANT !!!
    # Each time we provide a parameter with separate values for each axis (e.g. with a numpy array or a sequence),
    # these values refer to the RAS axes.
    :param labels_shape: shape of the input label maps. Can be a sequence or a 1d numpy array.
    :param n_channels: number of channels to be synthesised.
    :param generation_labels: list of all possible label values in the input label maps.
    Can be a sequence or a 1d numpy array. It should be organised as follows: background label first, then
    non-sided labels (e.g. CSF, brainstem, etc.), then all the structures of the same hemisphere (can be left or right),
    and finally all the corresponding contralateral structures (in the same order).
    :param output_labels: list of all the label values to keep in the output label maps, in no particular order.
    Should be a subset of the values contained in generation_labels.
    Label values that are in generation_labels but not in output_labels are reset to zero.
    Can be a sequence or a 1d numpy array.
    :param n_neutral_labels: number of non-sided generation labels.
    :param atlas_res: resolution of the input label maps.
    Can be a number (isotropic resolution), a sequence, or a 1d numpy array.
    :param target_res: target resolution of the generated images and corresponding label maps. If None, atlas_res is
    used. Can be a number (isotropic resolution), a sequence, or a 1d numpy array.
    :param output_shape: (optional) desired shape of the output image, obtained by randomly cropping the generated image
    Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or the path to a 1d numpy array.
    Default is None, where no cropping is performed.
    :param output_div_by_n: (optional) forces the output shape to be divisible by this value. It overwrites output_shape
    if necessary. Can be an integer (same size in all dimensions), a sequence, or a 1d numpy array.
    :param flipping: (optional) whether to introduce right/left random flipping
    :param aff: (optional) example of an (n_dims+1)x(n_dims+1) affine matrix of one of the input label maps.
    Used to find the brain's right/left axis. Must be given if flipping is True.
    :param scaling_bounds: (optional) range of the random scaling to apply at each mini-batch. The scaling factor for
    each dimension is sampled from a uniform distribution of predefined bounds. Can either be:
    1) a number, in which case the scaling factor is independently sampled from the uniform distribution of bounds
    [1-scaling_bounds, 1+scaling_bounds] for each dimension.
    2) a sequence, in which case the scaling factor is sampled from the uniform distribution of bounds
    (1-scaling_bounds[i], 1+scaling_bounds[i]) for the i-th dimension.
    3) a numpy array of shape (2, n_dims), in which case the scaling factor is sampled from the uniform distribution
     of bounds (scaling_bounds[0, i], scaling_bounds[1, i]) for the i-th dimension.
    4) False, in which case scaling is completely turned off.
    Default is scaling_bounds = 0.15 (case 1)
    :param rotation_bounds: (optional) same as scaling bounds but for the rotation angle, except that for cases 1
    and 2, the bounds are centred on 0 rather than 1, i.e. [0+rotation_bounds[i], 0-rotation_bounds[i]].
    Default is rotation_bounds = 15.
    :param shearing_bounds: (optional) same as scaling bounds. Default is shearing_bounds = 0.012.
    :param translation_bounds: (optional) same as scaling bounds. Default is translation_bounds = False, but we
    encourage using it when cropping is deactivated (i.e. when output_shape=None in BrainGenerator).
    :param nonlin_std: (optional) Maximum value for the standard deviation of the normal distribution from which we
    sample the first tensor for synthesising the deformation field. Set to 0 if you wish to completely turn the elastic
    deformation off.
    :param nonlin_shape_factor: (optional) if nonlin_std is strictly positive, factor between the shapes of the input
    label maps and the shape of the input non-linear tensor.
    :param randomise_res: (optional) whether to mimic images that would have been 1) acquired at low resolution, and
    2) resampled to high resolution. The low resolution is randomly sampled at each minibatch, with an upper bound of
    9mm along every axis. In that process, the images generated by sampling the GMM are 1) blurred at the sampled LR,
    2) downsampled at LR, and 3) resampled at target_resolution.
    :param buil_distance_maps: (optional) only used when randomise_res is True. If True, MimicAcquisition returns a
    second tensor for each channel (presumably a distance map, as the parameter name suggests - to confirm), which is
    inserted in the output image right after the corresponding channel. Default is False.
    :param data_res: (optional) specific acquisition resolution to mimic, as opposed to random resolution sampled when
    randomise_res is True. This triggers a blurring to mimic the specified acquisition resolution, but the downsampling
    is optional (see param downsample). Default for data_res is None, where images are slightly blurred.
    If the generated images are uni-modal, data_res can be a number (isotropic acquisition resolution), a sequence, a 1d
    numpy array, or the path to a 1d numpy array. In the multi-modal case, it should be given as a numpy array (or a
    path) of size (n_mod, n_dims), where each row is the acquisition resolution of the corresponding channel.
    :param thickness: (optional) if data_res is provided, we can further specify the slice thickness of the low
    resolution images to mimic. Must be provided in the same format as data_res. Default thickness = data_res.
    :param downsample: (optional) whether to actually downsample the volume images to data_res after blurring.
    Default is False, except when thickness is provided, and thickness < data_res (along at least one axis).
    :param blur_range: (optional) Randomise the standard deviation of the blurring kernels, (whether data_res is given
    or not). At each mini_batch, the standard deviation of the blurring kernels are multiplied by a coefficient sampled
    from a uniform distribution with bounds [1/blur_range, blur_range]. If None, no randomisation. Default is 1.15.
    :param bias_field_std: (optional) If strictly positive, this triggers the corruption of synthesised images with a
    bias field. It is obtained by sampling a first small tensor from a normal distribution, resizing it to full size,
    and rescaling it to positive values by taking the voxel-wise exponential. bias_field_std designates the std dev of
    the normal distribution from which we sample the first tensor. Set to 0 to deactivate bias field corruption.
    :param bias_shape_factor: (optional) If bias_field_std is strictly positive, this designates the ratio between the
    size of the input label maps and the size of the first sampled tensor for synthesising the bias field.
    """

    # reformat resolutions
    labels_shape = utils.reformat_to_list(labels_shape)
    n_dims, _ = utils.get_dims(labels_shape)
    atlas_res = utils.reformat_to_n_channels_array(atlas_res, n_dims,
                                                   n_channels)
    data_res = atlas_res if (
        data_res is None) else utils.reformat_to_n_channels_array(
            data_res, n_dims, n_channels)
    thickness = data_res if (
        thickness is None) else utils.reformat_to_n_channels_array(
            thickness, n_dims, n_channels)
    # when downsample is not requested, downsample the channels whose thickness is smaller than data_res on some axis
    downsample = utils.reformat_to_list(
        downsample,
        n_channels) if downsample else (np.min(thickness - data_res, 1) < 0)
    atlas_res = atlas_res[0]
    target_res = atlas_res if (
        target_res is None) else utils.reformat_to_n_channels_array(
            target_res, n_dims)[0]

    # get shapes
    crop_shape, output_shape = get_shapes(labels_shape, output_shape,
                                          atlas_res, target_res,
                                          output_div_by_n)

    # create new_label_list and corresponding LUT to make sure that labels go from 0 to N-1
    new_generation_labels, lut = utils.rearrange_label_list(generation_labels)

    # define model inputs
    labels_input = KL.Input(shape=labels_shape + [1], name='labels_input')
    means_input = KL.Input(shape=list(new_generation_labels.shape) +
                           [n_channels],
                           name='means_input')
    stds_input = KL.Input(shape=list(new_generation_labels.shape) +
                          [n_channels],
                          name='std_devs_input')

    # convert labels to new_label_list
    labels = l2i_et.convert_labels(labels_input, lut)

    # deform labels
    if (scaling_bounds is not False) | (rotation_bounds is not False) | (shearing_bounds is not False) | \
       (translation_bounds is not False) | (nonlin_std > 0):
        labels._keras_shape = tuple(labels.get_shape().as_list())
        labels = layers.RandomSpatialDeformation(
            scaling_bounds=scaling_bounds,
            rotation_bounds=rotation_bounds,
            shearing_bounds=shearing_bounds,
            translation_bounds=translation_bounds,
            nonlin_std=nonlin_std,
            nonlin_shape_factor=nonlin_shape_factor,
            inter_method='nearest')(labels)

    # cropping
    if crop_shape != labels_shape:
        labels._keras_shape = tuple(labels.get_shape().as_list())
        labels = layers.RandomCrop(crop_shape)(labels)

    # flipping
    if flipping:
        assert aff is not None, 'aff should not be None if flipping is True'
        labels._keras_shape = tuple(labels.get_shape().as_list())
        labels = layers.RandomFlip(
            get_ras_axes(aff, n_dims)[0], True, new_generation_labels,
            n_neutral_labels)(labels)

    # build synthetic image
    labels._keras_shape = tuple(labels.get_shape().as_list())
    image = layers.SampleConditionalGMM()([labels, means_input, stds_input])

    # apply bias field
    if bias_field_std > 0:
        image._keras_shape = tuple(image.get_shape().as_list())
        image = layers.BiasFieldCorruption(bias_field_std, bias_shape_factor,
                                           False)(image)

    # intensity augmentation
    image._keras_shape = tuple(image.get_shape().as_list())
    image = layers.IntensityAugmentation(clip=300,
                                         normalise=True,
                                         gamma_std=.4,
                                         separate_channels=True)(image)

    # loop over channels
    channels = list()
    split = KL.Lambda(lambda x: tf.split(x, [1] * n_channels, axis=-1))(
        image) if (n_channels > 1) else [image]
    for i, channel in enumerate(split):

        channel._keras_shape = tuple(channel.get_shape().as_list())

        if randomise_res:
            # upper bound of the sampled acquisition resolution: 9mm along every axis (works for any n_dims, not only 3D)
            max_res = np.array([9.] * n_dims)
            resolution, blur_res = layers.SampleResolution(
                atlas_res, max_res, .05, return_thickness=True)(means_input)
            sigma = l2i_et.blurring_sigma_for_downsampling(atlas_res,
                                                           resolution,
                                                           thickness=blur_res)
            channel = layers.DynamicGaussianBlur(
                0.75 * max_res / np.array(atlas_res),
                blur_range)([channel, sigma])
            if buil_distance_maps:
                channel, dist = layers.MimicAcquisition(
                    atlas_res, atlas_res, output_shape,
                    True)([channel, resolution])
                channels.extend([channel, dist])
            else:
                channel = layers.MimicAcquisition(atlas_res, atlas_res,
                                                  output_shape,
                                                  False)([channel, resolution])
                channels.append(channel)

        else:
            sigma = l2i_et.blurring_sigma_for_downsampling(
                atlas_res, data_res[i], thickness=thickness[i])
            channel = layers.GaussianBlur(sigma, blur_range)(channel)
            if downsample[i]:
                # bind data_res[i] as a default argument: a bare closure over the loop variable i would resolve to the
                # last channel if the layer function were ever re-executed (e.g. when rebuilding from a saved config)
                resolution = KL.Lambda(lambda x, res=data_res[i]: tf.convert_to_tensor(
                    res, dtype='float32'))([])
                channel = layers.MimicAcquisition(atlas_res, data_res[i],
                                                  output_shape)(
                                                      [channel, resolution])
            elif output_shape != crop_shape:
                channel = nrn_layers.Resize(size=output_shape)(channel)
            channels.append(channel)

    # concatenate all channels back
    image = KL.Lambda(lambda x: tf.concat(x, -1))(
        channels) if len(channels) > 1 else channels[0]

    # resample labels at target resolution
    if crop_shape != output_shape:
        labels = l2i_et.resample_tensor(labels,
                                        output_shape,
                                        interp_method='nearest')

    # convert labels back to original values and reset unwanted labels to zero
    labels = l2i_et.convert_labels(labels, generation_labels)
    labels._keras_shape = tuple(labels.get_shape().as_list())
    reset_values = [v for v in generation_labels if v not in output_labels]
    labels = layers.ResetValuesToZero(reset_values, name='labels_out')(labels)

    # build model (dummy layer enables to keep the labels when plugging this model to other models)
    image = KL.Lambda(lambda x: x[0], name='image_out')([image, labels])
    brain_model = Model(inputs=[labels_input, means_input, stds_input],
                        outputs=[image, labels])

    return brain_model
def build_augmentation_model(im_shape,
                             n_channels,
                             segmentation_labels,
                             n_neutral_labels,
                             output_shape=None,
                             output_div_by_n=None,
                             flipping=True,
                             aff=None,
                             scaling_bounds=0.15,
                             rotation_bounds=15,
                             shearing_bounds=0.012,
                             translation_bounds=False,
                             nonlin_std=3.,
                             nonlin_shape_factor=.0625,
                             bias_field_std=.3,
                             bias_shape_factor=.025):
    """Build a keras model that jointly augments an image and its label map.

    The model takes two inputs, an image of shape im_shape + [n_channels] and a label map of shape im_shape + [1], and
    applies in this order:
        1) a random spatial deformation shared by image and labels (nearest interpolation for the labels, linear for
        the image), skipped when scaling/rotation/shearing/translation bounds are all False and nonlin_std <= 0,
        2) a random crop to the shape computed by get_shapes, only if that shape differs from im_shape,
        3) an optional random flip along the first axis returned by get_ras_axes(aff, n_dims), swapping label values
        in the label map only,
        4) a bias field corruption of the image, only if bias_field_std > 0,
        5) an intensity augmentation of the image (normalise=True, gamma_std=.5, separate channels).
    :param im_shape: spatial shape of the input image and label map (without the channel dimension).
    :param n_channels: number of channels of the input image.
    :param segmentation_labels: label values of the input label maps. They are remapped to 0..N-1 through the LUT
    returned by utils.rearrange_label_list, and mapped back to their original values before the output.
    :param n_neutral_labels: number of non-sided labels, forwarded to RandomFlip.
    :param output_shape: (optional) desired crop shape, reconciled with im_shape and output_div_by_n by get_shapes.
    :param output_div_by_n: (optional) value by which the crop shape must be divisible (see get_shapes).
    :param flipping: (optional) whether to apply random flipping. Default is True.
    :param aff: (optional) affine matrix of the input volumes, used to find the flipping axis. Required if flipping.
    :param scaling_bounds: (optional) forwarded to RandomSpatialDeformation.
    :param rotation_bounds: (optional) forwarded to RandomSpatialDeformation.
    :param shearing_bounds: (optional) forwarded to RandomSpatialDeformation.
    :param translation_bounds: (optional) forwarded to RandomSpatialDeformation.
    :param nonlin_std: (optional) forwarded to RandomSpatialDeformation.
    :param nonlin_shape_factor: (optional) forwarded to RandomSpatialDeformation.
    :param bias_field_std: (optional) forwarded to BiasFieldCorruption; set to 0 to disable the bias field.
    :param bias_shape_factor: (optional) forwarded to BiasFieldCorruption.
    :return: a keras model with inputs [image_input, labels_input] and outputs [image, labels], where the image output
    layer is named 'image_out' and the labels (cast to int32) are produced by the layer 'labels_out'.
    """

    def _refresh_keras_shape(tensor):
        # keep the legacy _keras_shape attribute in sync with the static tf shape (presumably read by custom layers)
        tensor._keras_shape = tuple(tensor.get_shape().as_list())

    # spatial shapes
    shape = utils.reformat_to_list(im_shape)
    n_dims, _ = utils.get_dims(shape)
    target_crop = get_shapes(shape, output_shape, output_div_by_n)

    # LUT mapping the segmentation labels to contiguous values 0..N-1
    contiguous_labels, lut = utils.rearrange_label_list(segmentation_labels)

    # model inputs
    image_in = KL.Input(shape=shape+[n_channels], name='image_input')
    labels_in = KL.Input(shape=shape + [1], name='labels_input')

    lab = convert_labels(labels_in, lut)
    img = image_in

    # joint spatial deformation
    deform = (scaling_bounds is not False) | (rotation_bounds is not False) | (shearing_bounds is not False) | \
             (translation_bounds is not False) | (nonlin_std > 0)
    if deform:
        _refresh_keras_shape(lab)
        deformer = layers.RandomSpatialDeformation(scaling_bounds=scaling_bounds,
                                                   rotation_bounds=rotation_bounds,
                                                   shearing_bounds=shearing_bounds,
                                                   translation_bounds=translation_bounds,
                                                   nonlin_std=nonlin_std,
                                                   nonlin_shape_factor=nonlin_shape_factor,
                                                   inter_method=['nearest', 'linear'])
        lab, img = deformer([lab, img])

    # joint random crop
    if target_crop != shape:
        _refresh_keras_shape(lab)
        _refresh_keras_shape(img)
        lab, img = layers.RandomCrop(target_crop)([lab, img])

    # joint random flip (label values are swapped in the label map only)
    if flipping:
        assert aff is not None, 'aff should not be None if flipping is True'
        _refresh_keras_shape(lab)
        _refresh_keras_shape(img)
        flip_axis = get_ras_axes(aff, n_dims)[0]
        flipper = layers.RandomFlip(flip_axis=flip_axis, swap_labels=[True, False],
                                    label_list=contiguous_labels, n_neutral_labels=n_neutral_labels)
        lab, img = flipper([lab, img])

    # bias field corruption of the image
    if bias_field_std > 0:
        _refresh_keras_shape(img)
        img = layers.BiasFieldCorruption(bias_field_std, bias_shape_factor, False)(img)
        img = KL.Lambda(lambda x: tf.cast(x, dtype='float32'), name='image_biased')(img)

    # intensity augmentation of the image
    _refresh_keras_shape(img)
    img = layers.IntensityAugmentation(10, clip=False, normalise=True, gamma_std=.5, separate_channels=True)(img)
    img = KL.Lambda(lambda x: tf.cast(x, dtype='float32'), name='image_augmented')(img)

    # map the labels back to their original values
    lab = convert_labels(lab, segmentation_labels)

    # named outputs (the dummy image layer keeps the labels in the graph when plugging this model into others)
    lab = KL.Lambda(lambda x: tf.cast(x, dtype='int32'), name='labels_out')(lab)
    img = KL.Lambda(lambda x: x[0], name='image_out')([img, lab])

    return models.Model(inputs=[image_in, labels_in], outputs=[img, lab])
Exemple #22
0
def evaluation(gt_dir,
               seg_dir,
               label_list,
               mask_dir=None,
               compute_score_whole_structure=False,
               path_dice=None,
               path_hausdorff=None,
               path_hausdorff_99=None,
               path_hausdorff_95=None,
               path_mean_distance=None,
               crop_margin_around_gt=10,
               list_incorrect_labels=None,
               list_correct_labels=None,
               use_nearest_label=False,
               recompute=True,
               verbose=True):
    """This function computes Dice scores, as well as surface distances, between two sets of labels maps in gt_dir
    (ground truth) and seg_dir (typically predictions). Labels maps in both folders are matched by sorting order.
    The resulting scores are saved at the specified locations.
    Surface distances are only computed if at least one distance path is provided. For labels missing from either the
    gt or the segmentation, all distances are set to the largest dimension of the (cropped) gt volume.
    :param gt_dir: path of directory with gt label maps
    :param seg_dir: path of directory with label maps to compare to gt_dir. Matched to gt label maps by sorting order.
    :param label_list: list of label values for which to compute evaluation metrics. Can be a sequence, a 1d numpy
    array, or the path to such array.
    :param mask_dir: (optional) path of directory with masks of areas to ignore for each evaluated segmentation.
    Matched to gt label maps by sorting order. Default is None, where nothing is masked.
    :param compute_score_whole_structure: (optional) whether to also compute the selected scores for the whole segmented
    structure (i.e. scores are computed for a single structure obtained by regrouping all non-zero values). If True, the
    resulting scores are added as an extra row to the result matrices. Default is False.
    :param path_dice: path where the resulting Dice will be written as numpy array.
    Default is None, where the array is not saved.
    :param path_hausdorff: path where the resulting Hausdorff distances will be written as numpy array.
    Default is None, where the array is not saved.
    :param path_hausdorff_99: same as for path_hausdorff but for the 99th percentile of the boundary distance.
    :param path_hausdorff_95: same as for path_hausdorff but for the 95th percentile of the boundary distance.
    :param path_mean_distance: path where the resulting mean distances will be written as numpy array.
    Default is None, where the array is not saved.
    :param crop_margin_around_gt: (optional) margin by which to crop around the gt volumes, in order to compute the
    scores more efficiently. If None, no cropping is performed.
    :param list_incorrect_labels: (optional) this option enables to replace some label values in the maps in seg_dir by
    other label values. Can be a list, a 1d numpy array, or the path to such an array.
    The incorrect labels can then be replaced either by specified values, or by the nearest value (see below).
    :param list_correct_labels: (optional) list of values to correct the labels specified in list_incorrect_labels.
    Correct values must have the same order as their corresponding value in list_incorrect_labels.
    :param use_nearest_label: (optional) whether to correct the incorrect label values with the nearest labels.
    :param recompute: (optional) whether to recompute the already existing results. If False, results whose file
    already exists are neither recomputed nor overwritten. Default is True.
    :param verbose: (optional) whether to print out info about the remaining number of cases.
    """

    # a result needs computing if its path is given and either recompute is True or the file does not exist yet
    def _to_compute(path):
        return (path is not None) and (recompute or not os.path.isfile(path))

    # Dice scores are always computed when no path_dice is given (they are cheap and drive the main loop)
    compute_dice = True if path_dice is None else _to_compute(path_dice)
    compute_hausdorff = _to_compute(path_hausdorff)
    compute_hausdorff_99 = _to_compute(path_hausdorff_99)
    compute_hausdorff_95 = _to_compute(path_hausdorff_95)
    compute_mean_dist = _to_compute(path_mean_distance)
    compute_hd = [
        compute_hausdorff, compute_hausdorff_99, compute_hausdorff_95
    ]
    compute_distances = any(compute_hd) | compute_mean_dist

    if compute_dice | compute_distances:

        # get list label maps to compare
        path_gt_labels = utils.list_images_in_folder(gt_dir)
        path_segs = utils.list_images_in_folder(seg_dir)
        path_gt_labels = utils.reformat_to_list(path_gt_labels,
                                                length=len(path_segs))
        if len(path_gt_labels) != len(path_segs):
            print(
                'gt and segmentation folders must have the same amount of label maps.'
            )
        if mask_dir is not None:
            path_masks = utils.list_images_in_folder(mask_dir)
            if len(path_masks) != len(path_segs):
                print('not the same amount of masks and segmentations.')
        else:
            path_masks = [None] * len(path_segs)

        # load labels list
        label_list, _ = utils.get_list_labels(label_list=label_list,
                                              FS_sort=True,
                                              labels_dir=gt_dir)
        n_labels = len(label_list)
        # value given to masked voxels: outside label_list, so masked areas are ignored by the per-label scores
        max_label = np.max(label_list) + 1

        # initialise result matrices (one extra row for the whole structure if requested)
        n_rows = n_labels + 1 if compute_score_whole_structure else n_labels
        max_dists = np.zeros((n_rows, len(path_segs), 3))
        mean_dists = np.zeros((n_rows, len(path_segs)))
        dice_coefs = np.zeros((n_rows, len(path_segs)))

        # loop over segmentations
        loop_info = utils.LoopInfo(len(path_segs),
                                   10,
                                   'evaluating',
                                   print_time=True)
        for idx, (path_gt, path_seg, path_mask) in enumerate(
                zip(path_gt_labels, path_segs, path_masks)):
            if verbose:
                loop_info.update(idx)

            # load gt labels and segmentation
            gt_labels = utils.load_volume(path_gt, dtype='int')
            seg = utils.load_volume(path_seg, dtype='int')
            if path_mask is not None:
                mask = utils.load_volume(path_mask, dtype='bool')
                gt_labels[mask] = max_label
                seg[mask] = max_label

            # crop images
            if crop_margin_around_gt is not None:
                gt_labels, cropping = edit_volumes.crop_volume_around_region(
                    gt_labels, margin=crop_margin_around_gt)
                seg = edit_volumes.crop_volume_with_idx(seg, cropping)

            if list_incorrect_labels is not None:
                seg = edit_volumes.correct_label_map(seg,
                                                     list_incorrect_labels,
                                                     list_correct_labels,
                                                     use_nearest_label)

            # compute Dice scores
            dice_coefs[:n_labels, idx] = fast_dice(gt_labels, seg, label_list)

            # compute Dice scores for whole structures
            if compute_score_whole_structure:
                temp_gt = (gt_labels > 0) * 1
                temp_seg = (seg > 0) * 1
                dice_coefs[-1, idx] = dice(temp_gt, temp_seg)
            else:
                temp_gt = temp_seg = None

            # compute average and Hausdorff distances
            if compute_distances:

                # compute unique label values
                unique_gt_labels = np.unique(gt_labels)
                unique_seg_labels = np.unique(seg)

                # compute max/mean surface distances for all labels
                for index, label in enumerate(label_list):
                    if (label in unique_gt_labels) & (label
                                                      in unique_seg_labels):
                        mask_gt = np.where(gt_labels == label, True, False)
                        mask_seg = np.where(seg == label, True, False)
                        tmp_max_dists, mean_dists[index,
                                                  idx] = surface_distances(
                                                      mask_gt, mask_seg,
                                                      [100, 99, 95])
                        max_dists[index, idx, :] = np.array(tmp_max_dists)
                    else:
                        # label missing from gt or segmentation: use the largest volume dimension as penalty
                        mean_dists[index, idx] = max(gt_labels.shape)
                        max_dists[index, idx, :] = np.array(
                            [max(gt_labels.shape)] * 3)

                # compute max/mean distances for whole structure
                if compute_score_whole_structure:
                    tmp_max_dists, mean_dists[-1, idx] = surface_distances(
                        temp_gt, temp_seg, [100, 99, 95])
                    max_dists[-1, idx, :] = np.array(tmp_max_dists)

        # write only the results computed in this call, so existing files are never overwritten with placeholder zeros
        if (path_dice is not None) and compute_dice:
            utils.mkdir(os.path.dirname(path_dice))
            np.save(path_dice, dice_coefs)
        if compute_hausdorff:
            utils.mkdir(os.path.dirname(path_hausdorff))
            np.save(path_hausdorff, max_dists[..., 0])
        if compute_hausdorff_99:
            utils.mkdir(os.path.dirname(path_hausdorff_99))
            np.save(path_hausdorff_99, max_dists[..., 1])
        if compute_hausdorff_95:
            utils.mkdir(os.path.dirname(path_hausdorff_95))
            np.save(path_hausdorff_95, max_dists[..., 2])
        if compute_mean_dist:
            utils.mkdir(os.path.dirname(path_mean_distance))
            np.save(path_mean_distance, mean_dists)
Exemple #23
0
def plot_validation_curves(list_net_validation_dirs,
                           eval_indices=None,
                           skip_first_dice_row=True,
                           size_max_circle=100,
                           figsize=(11, 6),
                           fontsize=18,
                           remove_legend=False):
    """This function plots the validation curves of several networks, based on the results of validate_training().
    It takes as input a list of validation folders (one for each network), each containing subfolders with dice scores
    for the corresponding validated epoch.
    :param list_net_validation_dirs: list of all the validation folders of the trainings to plot.
    :param eval_indices: (optional) compute the average Dice on a subset of labels indicated by the specified
    indices. Can be a sequence, 1d numpy array, or the path to such an array.
    :param skip_first_dice_row: if eval_indices is None, skip the first row of the dice matrices (usually background)
    :param size_max_circle: (optional) size of the marker for epochs achieving the best validation scores.
    :param figsize: (optional) size of the figure to draw.
    :param fontsize: (optional) fontsize used for the graph.
    :param remove_legend: (optional) whether to hide the legend of the graph. Default is False."""

    if eval_indices is not None:
        eval_indices = utils.reformat_to_list(eval_indices, load_as_numpy=True)

    # loop over architectures
    plt.figure(figsize=figsize)
    for net_val_dir in list_net_validation_dirs:

        net_name = os.path.basename(os.path.dirname(net_val_dir))
        list_epochs_dir = utils.list_subfolders(net_val_dir, whole_path=False)

        # loop over epochs, keeping only those that have been validated
        list_net_dice_scores = list()
        list_epochs = list()
        for epoch_dir in list_epochs_dir:

            path_epoch_dice = os.path.join(net_val_dir, epoch_dir, 'dice.npy')
            if not os.path.isfile(path_epoch_dice):
                continue
            dice_scores = np.load(path_epoch_dice)
            if eval_indices is not None:
                dice_scores = dice_scores[eval_indices, :]
            elif skip_first_dice_row:
                dice_scores = dice_scores[1:, :]
            list_net_dice_scores.append(np.mean(dice_scores))
            # epoch number is made of all the digits in the subfolder name
            list_epochs.append(int(re.sub('[^0-9]', '', epoch_dir)))

        # plot validation scores for current architecture
        if list_net_dice_scores:  # check that archi has been validated for at least 1 epoch
            # sort by epoch number, since the subfolder listing is not guaranteed to follow numerical order
            # (e.g. 'epoch_100' sorts before 'epoch_20' alphabetically), which would draw a zig-zag curve
            order = np.argsort(list_epochs, kind='stable')
            list_epochs = np.array(list_epochs)[order]
            list_net_dice_scores = np.array(list_net_dice_scores)[order]
            max_score = np.max(list_net_dice_scores)
            epoch_max_score = list_epochs[np.argmax(list_net_dice_scores)]
            print('\n' + net_name)
            print('epoch max score: %d' % epoch_max_score)
            print('max score: %0.3f' % max_score)
            plt.plot(list_epochs, list_net_dice_scores, label=net_name)
            plt.scatter(epoch_max_score, max_score, s=size_max_circle)

    # finalise plot
    plt.grid()
    plt.tick_params(axis='both', labelsize=fontsize)
    plt.ylabel('Dice scores', fontsize=fontsize)
    plt.xlabel('Epochs', fontsize=fontsize)
    plt.title('Validation curves', fontsize=fontsize)
    if not remove_legend:
        plt.legend(fontsize=fontsize)
    plt.tight_layout(pad=1)
    plt.show()
Exemple #24
0
def plot_validation_curves(list_validation_dirs,
                           architecture_names=None,
                           eval_indices=None,
                           skip_first_dice_row=True,
                           size_max_circle=100,
                           figsize=(11, 6),
                           y_lim=None,
                           fontsize=18,
                           list_linestyles=None,
                           list_colours=None,
                           plot_legend=False):
    """This function plots the validation curves of several networks, based on the results of validate_training().
    It takes as input a list of validation folders (one for each network), each containing subfolders with dice scores
    for the corresponding validated epoch.
    :param list_validation_dirs: list of all the validation folders of the trainings to plot.
    :param architecture_names: (optional) names of the networks, printed with the best scores and used as legend
    labels. Default is os.path.basename(os.path.dirname(d)) for each validation folder d.
    :param eval_indices: (optional) compute the average Dice on a subset of labels indicated by the specified indices.
    Can be a 1d numpy array, the path to such an array, or a list of 1d numpy arrays as long as list_validation_dirs.
    A provided list is not modified.
    :param skip_first_dice_row: if eval_indices is None, skip the first row of the dice matrices (usually background)
    :param size_max_circle: (optional) size of the marker for epochs achieving the best validation scores.
    :param figsize: (optional) size of the figure to draw.
    :param y_lim: (optional) (bottom, top) limits of the y axis. A margin of 0.01 is added to the top limit.
    :param fontsize: (optional) fontsize used for the graph.
    :param list_linestyles: (optional) linestyle of each curve. A single value is applied to all curves.
    :param list_colours: (optional) colour of each curve. A single value is applied to all curves.
    :param plot_legend: (optional) False for no legend, True to label every curve, or an integer k to only label the
    first k curves."""

    n_curves = len(list_validation_dirs)

    # reformat eval_indices to one entry (array or None) per curve, without mutating the caller's list
    if eval_indices is not None:
        if isinstance(eval_indices, (np.ndarray, str)):
            if isinstance(eval_indices, str):
                eval_indices = np.load(eval_indices)
            eval_indices = np.squeeze(
                utils.reformat_to_n_channels_array(eval_indices,
                                                   n_dims=len(eval_indices)))
            eval_indices = [eval_indices] * n_curves
        elif isinstance(eval_indices, list):
            formatted_indices = list()
            for e in eval_indices:
                if not isinstance(e, np.ndarray):
                    raise TypeError(
                        'if provided as a list, eval_indices should only contain numpy arrays'
                    )
                formatted_indices.append(
                    np.squeeze(utils.reformat_to_n_channels_array(e, n_dims=len(e))))
            eval_indices = formatted_indices
        else:
            raise TypeError(
                'eval_indices can be a numpy array, a path to a numpy array, or a list of numpy arrays.'
            )
    else:
        eval_indices = [None] * n_curves

    # reformat model names
    if architecture_names is None:
        architecture_names = [
            os.path.basename(os.path.dirname(d)) for d in list_validation_dirs
        ]
    else:
        architecture_names = utils.reformat_to_list(architecture_names,
                                                    n_curves)

    # prepare legend labels ('_nolegend_' hides a curve from the legend)
    if plot_legend is False:
        list_legend_labels = ['_nolegend_'] * n_curves
    elif plot_legend is True:
        list_legend_labels = architecture_names
    else:
        list_legend_labels = [
            '_nolegend_' if i >= plot_legend else architecture_names[i]
            for i in range(n_curves)
        ]

    # prepare linestyles (a single value is broadcast, otherwise zip below would only plot the first curve)
    if list_linestyles is not None:
        list_linestyles = utils.reformat_to_list(list_linestyles)
        if len(list_linestyles) == 1:
            list_linestyles = list_linestyles * n_curves
    else:
        list_linestyles = [None] * n_curves

    # prepare curve colours (same broadcasting as linestyles)
    if list_colours is not None:
        list_colours = utils.reformat_to_list(list_colours)
        if len(list_colours) == 1:
            list_colours = list_colours * n_curves
    else:
        list_colours = [None] * n_curves

    # loop over architectures
    plt.figure(figsize=figsize)
    for net_val_dir, net_name, linestyle, colour, legend_label, eval_idx in zip(
            list_validation_dirs, architecture_names, list_linestyles,
            list_colours, list_legend_labels, eval_indices):

        list_epochs_dir = utils.list_subfolders(net_val_dir, whole_path=False)

        # loop over epochs
        list_net_dice_scores = list()
        list_epochs = list()
        for epoch_dir in list_epochs_dir:

            # only consider epochs that have been validated
            path_epoch_dice = os.path.join(net_val_dir, epoch_dir, 'dice.npy')
            if not os.path.isfile(path_epoch_dice):
                continue
            dice = np.load(path_epoch_dice)
            if eval_idx is not None:
                score = np.mean(dice[eval_idx, :])
            elif skip_first_dice_row:
                score = np.mean(dice[1:, :])
            else:
                score = np.mean(dice)
            list_net_dice_scores.append(score)
            list_epochs.append(int(re.sub('[^0-9]', '', epoch_dir)))

        # plot validation scores for current architecture
        if list_net_dice_scores:  # check that archi has been validated for at least 1 epoch
            list_net_dice_scores = np.array(list_net_dice_scores)
            list_epochs = np.array(list_epochs)
            # sort epochs and drop duplicates, keeping the matching scores
            list_epochs, unique_idx = np.unique(list_epochs, return_index=True)
            list_net_dice_scores = list_net_dice_scores[unique_idx]
            max_score = np.max(list_net_dice_scores)
            epoch_max_score = list_epochs[np.argmax(list_net_dice_scores)]
            print('\n' + net_name)
            print('epoch max score: %d' % epoch_max_score)
            print('max score: %0.3f' % max_score)
            plt.plot(list_epochs,
                     list_net_dice_scores,
                     label=legend_label,
                     linestyle=linestyle,
                     color=colour)
            plt.scatter(epoch_max_score,
                        max_score,
                        s=size_max_circle,
                        color=colour)

    # finalise plot
    plt.grid()
    plt.tick_params(axis='both', labelsize=fontsize)
    plt.ylabel('Dice scores', fontsize=fontsize)
    plt.xlabel('Epochs', fontsize=fontsize)
    if y_lim is not None:
        plt.ylim(y_lim[0], y_lim[1] + 0.01)  # set bottom/top limits of the y axis
    plt.title('Validation curves', fontsize=fontsize)
    if plot_legend:
        plt.legend(fontsize=fontsize)
    plt.tight_layout(pad=1)
    plt.show()
Exemple #25
0
def labels_to_image_model(labels_shape,
                          n_channels,
                          generation_labels,
                          output_labels,
                          n_neutral_labels,
                          atlas_res,
                          target_res,
                          output_shape=None,
                          output_div_by_n=None,
                          padding_margin=None,
                          flipping=True,
                          aff=None,
                          apply_linear_trans=True,
                          apply_nonlin_trans=True,
                          nonlin_std=3.,
                          nonlin_shape_factor=.0625,
                          blur_background=True,
                          data_res=None,
                          thickness=None,
                          downsample=False,
                          blur_range=1.15,
                          crop_channel2=None,
                          apply_bias_field=True,
                          bias_field_std=.3,
                          bias_shape_factor=.025):
    """
    This function builds a keras/tensorflow model to generate images from provided label maps.
    The images are generated by sampling a Gaussian Mixture Model (of given parameters), conditioned on the label map.
    The model will take as inputs:
        -a label map
        -a vector containing the means of the Gaussian Mixture Model for each label,
        -a vector containing the standard deviations of the Gaussian Mixture Model for each label,
        -if apply_linear_trans is True: a batch*(n_dims+1)*(n_dims+1) affine matrix
    No additional model input is created here for the non-linear deformation or the bias field: they are handled by
    the deformation and bias-field augmentation helpers called inside the model.
    The model returns:
        -the generated image, min-max normalised between 0 and 1 and then gamma-augmented.
        -the corresponding label map, with only the labels present in output_labels (the other are reset to zero).
    :param labels_shape: shape of the input label maps. Can be a sequence or a 1d numpy array.
    :param n_channels: number of channels to be synthetised.
    :param generation_labels: numpy array of all possible label values in the input label maps.
    It should be organised as follows: background label first, then non-sided labels (e.g. CSF, brainstem, etc.),
    then all the structures of the same hemisphere (can be left or right), and finally all the corresponding
    contralateral structures (in the same order).
    :param output_labels: list of all the label values to keep in the output label maps, in no particular order.
    Should be a subset of the values contained in generation_labels.
    Label values that are in generation_labels but not in output_labels are reset to zero.
    Can be a sequence or a 1d numpy array.
    :param n_neutral_labels: number of non-sided generation labels.
    :param atlas_res: resolution of the input label maps.
    Can be a number (isotropic resolution), a sequence, or a 1d numpy array.
    :param target_res: target resolution of the generated images and corresponding label maps.
    Can be a number (isotropic resolution), a sequence, or a 1d numpy array. If None, atlas_res is used.
    :param output_shape: (optional) desired shape of the output image, obtained by randomly cropping the generated image
    Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or the path to a 1d numpy array.
    :param output_div_by_n: (optional) forces the output shape to be divisible by this value. It overwrites output_shape
    if necessary. Can be an integer (same size in all dimensions), a sequence, or a 1d numpy array.
    :param padding_margin: (optional) margin by which to pad the input labels with zeros.
    Padding is applied prior to any other operation.
    Can be an integer (same padding in all dimensions), a sequence, or a 1d numpy array. Default is no padding.
    :param flipping: (optional) whether to introduce right/left random flipping
    :param aff: (optional) example of an (n_dims+1)x(n_dims+1) affine matrix of one of the input label map.
    Used to find brain's right/left axis. Must be given if flipping is True.
    :param apply_linear_trans: (optional) whether to linearly deform the input label maps prior to generation.
    If true, the model will take an additional input of size batch*(n_dims+1)*(n_dims+1). Default is True.
    :param apply_nonlin_trans: (optional) whether to non-linearly deform the input label maps prior to generation.
    Default is True.
    :param nonlin_std: (optional) If apply_nonlin_trans is True, maximum value for the standard deviation of the normal
    distribution from which we sample the first tensor for synthesising the deformation field.
    :param nonlin_shape_factor: (optional) if apply_nonlin_trans is True, factor between the shapes of the
    input label maps and the shape of the non-linear tensor.
    :param blur_background: (optional) If True, the background is blurred with the other labels, and can be reset to
    zero with a probability of 0.2. If False, the background is not blurred (we apply an edge blurring correction), and
    can be replaced by a low-intensity background.
    :param data_res: (optional) acquisition resolution to mimic. If provided, the images sampled from the GMM are
    blurred to mimic data that would be: 1) acquired at the given acquisition resolution, and 2) resampled at
    target_resolution.
    Default is None, where images are isotropically blurred to introduce some spatial correlation between voxels.
    If the generated images are uni-modal, data_res can be a number (isotropic acquisition resolution), a sequence, a 1d
    numpy array, or the path to a 1d numpy array. In the multi-modal case, it should be given as a numpy array (or a
    path) of size (n_mod, n_dims), where each row is the acquisition resolution of the corresponding channel.
    :param thickness: (optional) if data_res is provided, we can further specify the slice thickness of the low
    resolution images to mimic.
    If the generated images are uni-modal, thickness can be a number (isotropic thickness), a sequence, a 1d
    numpy array, or the path to a 1d numpy array. In the multi-modal case, it should be given as a numpy array (or a
    path) of size (n_mod, n_dims), where each row is the slice thickness of the corresponding channel.
    :param downsample: (optional) whether to actually downsample the volume image to data_res.
    Default is False, except when thickness is provided, and thickness < data_res.
    :param blur_range: (optional) Randomise the standard deviation of the blurring kernels, (whether data_res is given
    or not). At each mini_batch, the standard deviation of the blurring kernels are multiplied by a coefficient sampled
    from a uniform distribution with bounds [1/blur_range, blur_range]. If None, no randomisation. Default is 1.15.
    :param crop_channel2: (optional) stats for cropping second channel along the anterior-posterior axis.
    Should be a vector of length 4, with bounds of uniform distribution for cropping the front and back of the image
    (in percentage). None is no cropping.
    :param apply_bias_field: (optional) whether to apply a bias field to the generated image. Default is True.
    :param bias_field_std: (optional) If apply_bias_field is True, maximum value for the standard deviation of the
    normal distribution from which we sample the first tensor for synthesising the bias field.
    :param bias_shape_factor: (optional) if apply_bias_field is True, factor between the shapes of the
    input label maps and the shape of the bias field tensor.
    """

    # reformat resolutions
    labels_shape = utils.reformat_to_list(labels_shape)
    n_dims, _ = utils.get_dims(labels_shape)
    atlas_res = utils.reformat_to_n_channels_array(atlas_res,
                                                   n_dims=n_dims,
                                                   n_channels=n_channels)
    if data_res is None:  # data_res assumed to be the same as the atlas
        data_res = atlas_res
    else:
        data_res = utils.reformat_to_n_channels_array(data_res,
                                                      n_dims=n_dims,
                                                      n_channels=n_channels)
    atlas_res = atlas_res[0]
    if downsample:  # same as data_res if we want to actually downsample the synthetic image
        downsample_res = data_res
    else:  # set downsample_res to None if downsampling is not necessary
        downsample_res = None
    if target_res is None:
        target_res = atlas_res
    else:
        target_res = utils.reformat_to_n_channels_array(target_res, n_dims)[0]
    thickness = utils.reformat_to_n_channels_array(thickness,
                                                   n_dims=n_dims,
                                                   n_channels=n_channels)

    # get shapes
    crop_shape, output_shape, padding_margin = get_shapes(
        labels_shape, output_shape, atlas_res, target_res, padding_margin,
        output_div_by_n)

    # create new_label_list and corresponding LUT to make sure that labels go from 0 to N-1
    n_generation_labels = generation_labels.shape[0]
    new_generation_label_list, lut = utils.rearrange_label_list(
        generation_labels)

    # define model inputs
    labels_input = KL.Input(shape=labels_shape + [1], name='labels_input')
    means_input = KL.Input(shape=list(new_generation_label_list.shape) +
                           [n_channels],
                           name='means_input')
    std_devs_input = KL.Input(shape=list(new_generation_label_list.shape) +
                              [n_channels],
                              name='std_devs_input')
    list_inputs = [labels_input, means_input, std_devs_input]
    if apply_linear_trans:
        aff_in = KL.Input(shape=(n_dims + 1, n_dims + 1), name='aff_input')
        list_inputs.append(aff_in)
    else:
        aff_in = None

    # convert labels to new_label_list
    labels = l2i_et.convert_labels(labels_input, lut)

    # pad labels (no padding on the batch and channel axes)
    if padding_margin is not None:
        pad = np.transpose(np.array([[0] + padding_margin + [0]] * 2))
        labels = KL.Lambda(lambda x: tf.pad(
            x, tf.cast(tf.convert_to_tensor(pad), dtype='int32')),
                           name='pad')(labels)
        labels_shape = labels.get_shape().as_list()[1:n_dims + 1]

    # deform labels
    if apply_linear_trans or apply_nonlin_trans:
        labels = l2i_sp.deform_tensor(labels, aff_in, apply_nonlin_trans,
                                      'nearest', nonlin_std,
                                      nonlin_shape_factor)
    labels = KL.Lambda(lambda x: tf.cast(x, dtype='int32'))(labels)

    # cropping
    if crop_shape != labels_shape:
        labels, _ = l2i_sp.random_cropping(labels, crop_shape, n_dims)

    if flipping:
        assert aff is not None, 'aff should not be None if flipping is True'
        labels, _ = l2i_sp.label_map_random_flipping(
            labels, new_generation_label_list, n_neutral_labels, aff, n_dims)

    # build synthetic image
    image = l2i_gmm.sample_gmm_conditioned_on_labels(labels, means_input,
                                                     std_devs_input,
                                                     n_generation_labels,
                                                     n_channels)

    # loop over channels
    if n_channels > 1:
        split = KL.Lambda(lambda x: tf.split(x, [1] * n_channels, axis=-1))(
            image)
    else:
        split = [image]
    # foreground mask (labels > 0), used when blurring the channels
    mask = KL.Lambda(
        lambda x: tf.where(tf.greater(x, 0), tf.ones_like(x, dtype='float32'),
                           tf.zeros_like(x, dtype='float32')))(labels)
    processed_channels = list()
    for i, channel in enumerate(split):

        # reset edges of second channels to zero
        if (crop_channel2 is not None) and (
                i == 1):  # randomly crop sides of second channel
            channel, tmp_mask = l2i_sp.restrict_tensor(
                channel, axes=3, boundaries=crop_channel2)
        else:
            tmp_mask = None

        # blur channel
        if thickness is not None:
            sigma = utils.get_std_blurring_mask_for_downsampling(
                data_res[i], atlas_res, thickness=thickness[i])
        else:
            sigma = utils.get_std_blurring_mask_for_downsampling(
                data_res[i], atlas_res)
        kernels_list = l2i_et.get_gaussian_1d_kernels(
            sigma, blurring_range=blur_range)
        channel = l2i_et.blur_channel(channel, mask, kernels_list, n_dims,
                                      blur_background)
        if (crop_channel2 is not None) and (i == 1):
            channel = KL.multiply([channel, tmp_mask])

        # resample channel to output_shape, so that it matches the labels which are resampled to output_shape below.
        # The channel is first downsampled to the acquisition resolution if downsample is True, or if the slice
        # thickness is smaller than data_res in at least one dimension. Previously, no resampling was applied at all
        # when thickness was None and downsample was False, which left the image and labels with different shapes.
        if downsample_res is not None:
            subsample_res = downsample_res[i]
        elif (thickness is not None) and (min([
                thickness[i][dim_idx] - data_res[i][dim_idx]
                for dim_idx in range(n_dims)
        ]) < 0):
            subsample_res = data_res[i]
        else:
            subsample_res = None
        channel = l2i_et.resample_tensor(channel,
                                         output_shape,
                                         'linear',
                                         subsample_res,
                                         atlas_res,
                                         n_dims=n_dims)

        # apply bias field
        if apply_bias_field:
            channel = l2i_ia.bias_field_augmentation(channel, bias_field_std,
                                                     bias_shape_factor)

        # intensity augmentation
        channel = KL.Lambda(lambda x: K.clip(x, 0, 300))(channel)
        channel = l2i_ia.min_max_normalisation(channel)
        processed_channels.append(l2i_ia.gamma_augmentation(channel, std=0.5))

    # concatenate all channels back
    if n_channels > 1:
        image = KL.concatenate(processed_channels)
    else:
        image = processed_channels[0]

    # resample labels at target resolution (nearest interpolation keeps label values intact)
    if crop_shape != output_shape:
        labels = KL.Lambda(lambda x: tf.cast(x, dtype='float32'))(labels)
        labels = l2i_et.resample_tensor(labels,
                                        output_shape,
                                        interp_method='nearest',
                                        n_dims=n_dims)
    # convert labels back to original values and reset unwanted labels to zero
    labels = l2i_et.convert_labels(labels, generation_labels)
    labels_to_reset = [
        lab for lab in generation_labels if lab not in output_labels
    ]
    labels = l2i_et.reset_label_values_to_zero(labels, labels_to_reset)
    labels = KL.Lambda(lambda x: tf.cast(x, dtype='int32'),
                       name='labels_out')(labels)

    # build model (dummy layer enables to keep the labels when plugging this model to other models)
    image = KL.Lambda(lambda x: x[0], name='image_out')([image, labels])
    brain_model = keras.Model(inputs=list_inputs, outputs=[image, labels])

    return brain_model
Exemple #26
0
def build_model(model_file, input_shape, resample, im_res, n_levels, n_lab, conv_size, nb_conv_per_level,
                unet_feat_count, feat_multiplier, no_batch_norm, activation, sigma_smoothing):
    """Build the segmentation network and load its weights.

    The network is made of an optional pre-processing resampling step, a UNet, and an optional post-processing step
    (resampling back to the input shape and/or Gaussian smoothing of each posterior map).
    :param model_file: path of the weights to load (loaded by layer name).
    :param input_shape: shape of the input image: spatial dimensions followed by the channel dimension.
    :param resample: (optional) resolution at which the UNet is applied. If None, no resampling is performed.
    :param im_res: resolution of the input image (indexed per spatial dimension).
    :param n_levels: number of levels of the UNet. Resampled shapes are rounded up to a multiple of 2 ** n_levels.
    :param n_lab: number of output labels (channels of the posterior maps).
    :param conv_size: size of the UNet convolution kernels.
    :param nb_conv_per_level: number of convolutions per UNet level.
    :param unet_feat_count: number of features passed to the UNet (nb_features).
    :param feat_multiplier: feature multiplier between UNet levels (feat_mult).
    :param no_batch_norm: if True, batch normalisation is disabled in the UNet.
    :param activation: activation passed to the UNet.
    :param sigma_smoothing: standard deviation(s) of the Gaussian kernels used to smooth each posterior map.
    0 means no smoothing.
    :return: a keras model taking the image as input and returning the posterior maps.
    """

    # initialisation
    net = None
    n_dims, n_channels = utils.get_dims(input_shape, max_channels=3)
    resample = utils.reformat_to_list(resample, length=n_dims)
    # keep the spatial shape of the input image: input_shape is replaced by the resampled shape below, but the
    # posteriors must be resampled back to this initial shape
    initial_spatial_shape = list(input_shape[:n_dims])

    # build preprocessing model
    if resample is not None:
        im_input = KL.Input(shape=input_shape, name='pre_resample_input')
        resample_factor = [im_res[i] / float(resample[i]) for i in range(n_dims)]
        resample_shape = [utils.find_closest_number_divisible_by_m(resample_factor[i] * input_shape[i],
                          2 ** n_levels, smaller_ans=False) for i in range(n_dims)]
        resampled = nrn_layers.Resize(size=resample_shape, name='pre_resample')(im_input)
        net = Model(inputs=im_input, outputs=resampled)
        input_shape = resample_shape + [n_channels]

    # build UNet
    if no_batch_norm:
        batch_norm_dim = None
    else:
        batch_norm_dim = -1
    net = nrn_models.unet(nb_features=unet_feat_count,
                          input_shape=input_shape,
                          nb_levels=n_levels,
                          conv_size=conv_size,
                          nb_labels=n_lab,
                          name='unet',
                          prefix=None,
                          feat_mult=feat_multiplier,
                          pool_size=2,
                          use_logp=True,
                          padding='same',
                          dilation_rate_mult=1,
                          activation=activation,
                          use_residuals=False,
                          final_pred_activation='softmax',
                          nb_conv_per_level=nb_conv_per_level,
                          add_prior_layer=False,
                          add_prior_layer_reg=0,
                          layer_nb_feats=None,
                          conv_dropout=0,
                          batch_norm=batch_norm_dim,
                          input_model=net)
    net.load_weights(model_file, by_name=True)

    # build postprocessing model
    if (resample is not None) or (sigma_smoothing != 0):

        # get UNet output
        input_tensor = net.inputs
        last_tensor = net.output

        # resample to initial resolution (the original spatial shape, not the resampled input_shape)
        if resample is not None:
            last_tensor = nrn_layers.Resize(size=initial_spatial_shape, name='post_resample')(last_tensor)

        # smooth each posteriors map separately
        if sigma_smoothing != 0:
            kernels_list = l2i_et.get_gaussian_1d_kernels(utils.reformat_to_list(sigma_smoothing, length=n_dims))
            split = KL.Lambda(lambda x: tf.split(x, [1] * n_lab, axis=-1), name='resample_split')(last_tensor)
            last_tensor = l2i_et.blur_tensor(split[0], kernels_list, n_dims)
            for i in range(1, n_lab):
                temp_blurred = l2i_et.blur_tensor(split[i], kernels_list, n_dims)
                last_tensor = KL.concatenate([last_tensor, temp_blurred], axis=-1, name='cat_blurring_%s' % i)

        # build model
        net = Model(inputs=input_tensor, outputs=last_tensor)

    return net
def get_shapes(labels_shape, output_shape, atlas_res, target_res,
               output_div_by_n):
    """Compute the shape to crop the label maps to, and the shape of the generated outputs.

    :param labels_shape: spatial shape of the input label maps (one entry per dimension).
    :param output_shape: (optional) requested output shape. It is clipped to the size of the (resampled) label maps,
    and made divisible by output_div_by_n if needed. If None, it is derived from labels_shape.
    :param atlas_res: resolution of the input label maps.
    :param target_res: resolution of the outputs.
    :param output_div_by_n: (optional) if not None, output shapes are made divisible by this value
    (find_closest_number_divisible_by_m with smaller_ans=True).
    :return: cropping_shape (at atlas resolution) and output_shape (at target resolution).
    """

    # resolutions as lists, so that they can be compared and indexed dimension-wise
    atlas_res = utils.reformat_to_list(atlas_res)
    n_dims = len(atlas_res)
    target_res = utils.reformat_to_list(target_res)

    # atlas-to-target resampling factor, None when both resolutions are identical
    resample_factor = None
    if atlas_res != target_res:
        resample_factor = [atlas_res[d] / float(target_res[d]) for d in range(n_dims)]

    def _make_divisible(shape):
        # closest values divisible by output_div_by_n, taking the smaller answer
        return [utils.find_closest_number_divisible_by_m(s, output_div_by_n, smaller_ans=True) for s in shape]

    def _to_atlas_space(shape):
        # shape at target resolution -> corresponding shape at atlas resolution
        return [int(np.around(shape[d] / resample_factor[d], 0)) for d in range(n_dims)]

    if output_shape is None:
        # no requested shape: keep the whole label map, unless it has to be made divisible by output_div_by_n
        if output_div_by_n is None:
            cropping_shape = labels_shape
            if resample_factor is None:
                output_shape = cropping_shape
            else:
                output_shape = [int(cropping_shape[d] * resample_factor[d]) for d in range(n_dims)]
        elif resample_factor is None:
            cropping_shape = _make_divisible(labels_shape)
            output_shape = cropping_shape
        else:
            output_shape = _make_divisible([int(labels_shape[d] * resample_factor[d]) for d in range(n_dims)])
            cropping_shape = _to_atlas_space(output_shape)
        return cropping_shape, output_shape

    # requested shape: it cannot be larger than the (resampled) label map
    output_shape = utils.reformat_to_list(output_shape, length=n_dims, dtype='int')
    if resample_factor is None:
        output_shape = [min(labels_shape[d], output_shape[d]) for d in range(n_dims)]
    else:
        output_shape = [min(int(labels_shape[d] * resample_factor[d]), output_shape[d]) for d in range(n_dims)]

    if output_div_by_n is not None:
        divisible_shape = _make_divisible(output_shape)
        if divisible_shape != output_shape:
            print('output shape {0} not divisible by {1}, changed to {2}'.format(output_shape, output_div_by_n,
                                                                                divisible_shape))
            output_shape = divisible_shape

    cropping_shape = output_shape if resample_factor is None else _to_atlas_space(output_shape)
    return cropping_shape, output_shape
Exemple #28
0
def get_shapes(labels_shape, output_shape, atlas_res, target_res,
               padding_margin, output_div_by_n):
    """Compute the cropping shape, output shape, and padding margin of the label-to-image model.

    :param labels_shape: spatial shape of the input label maps (one entry per dimension).
    :param output_shape: (optional) requested shape of the outputs. If None, the label maps are not cropped.
    :param atlas_res: resolution of the input label maps, one value per dimension (sequence or 1d numpy array).
    :param target_res: resolution of the outputs, one value per dimension (sequence or 1d numpy array).
    :param padding_margin: (optional) margin by which the label maps are padded on each side of every dimension.
    :param output_div_by_n: (optional) if not None, the output shape is made divisible by this value: with the smaller
    answer when output_shape is given, with the larger answer otherwise.
    :return: cropping_shape (at atlas resolution, after padding), output_shape (at target resolution), and
    padding_margin reformatted as a list (or None).
    """

    # accept plain sequences as well as numpy arrays for the resolutions (.tolist() is used below)
    atlas_res = np.asarray(atlas_res)
    target_res = np.asarray(target_res)
    n_dims = len(atlas_res)

    # get new labels shape if padding (margin applied on both sides)
    if padding_margin is not None:
        padding_margin = utils.reformat_to_list(padding_margin,
                                                length=n_dims,
                                                dtype='int')
        labels_shape = [
            labels_shape[i] + 2 * padding_margin[i] for i in range(n_dims)
        ]

    # get resampling factor
    if atlas_res.tolist() != target_res.tolist():
        resample_factor = [
            atlas_res[i] / float(target_res[i]) for i in range(n_dims)
        ]
    else:
        resample_factor = None

    # output shape specified, need to get cropping shape, and resample shape if necessary
    if output_shape is not None:
        output_shape = utils.reformat_to_list(output_shape,
                                              length=n_dims,
                                              dtype='int')

        # make sure that output shape is smaller or equal to label shape
        if resample_factor is not None:
            output_shape = [
                min(int(labels_shape[i] * resample_factor[i]), output_shape[i])
                for i in range(n_dims)
            ]
        else:
            output_shape = [
                min(labels_shape[i], output_shape[i]) for i in range(n_dims)
            ]

        # make sure output shape is divisible by output_div_by_n
        if output_div_by_n is not None:
            tmp_shape = [
                utils.find_closest_number_divisible_by_m(s,
                                                         output_div_by_n,
                                                         smaller_ans=True)
                for s in output_shape
            ]
            if output_shape != tmp_shape:
                print('output shape {0} not divisible by {1}, changed to {2}'.
                      format(output_shape, output_div_by_n, tmp_shape))
                output_shape = tmp_shape

        # get cropping and resample shape
        if resample_factor is not None:
            cropping_shape = [
                int(np.around(output_shape[i] / resample_factor[i], 0))
                for i in range(n_dims)
            ]
        else:
            cropping_shape = output_shape

    # no output shape specified: the labels are never cropped, the output shape is the (resampled) labels shape,
    # made divisible by output_div_by_n with the larger answer if required
    else:
        cropping_shape = labels_shape
        if resample_factor is not None:
            output_shape = [
                int(np.around(cropping_shape[i] * resample_factor[i], 0))
                for i in range(n_dims)
            ]
        else:
            output_shape = cropping_shape
        # make sure output shape is divisible by output_div_by_n
        if output_div_by_n is not None:
            output_shape = [
                utils.find_closest_number_divisible_by_m(s,
                                                         output_div_by_n,
                                                         smaller_ans=False)
                for s in output_shape
            ]

    return cropping_shape, output_shape, padding_margin
Exemple #29
0
def surface_distances(x,
                      y,
                      hausdorff_percentile=None,
                      return_coordinate_max_distance=False):
    """Computes the maximum boundary distance (Hausdorff distance), and the average boundary distance of two masks.

    Both masks are cropped to the union of their bounding boxes (computed with edit_volumes) and then zero-padded by
    one voxel. The padding guarantees that mask voxels lying on the border of the cropped volume still have a
    background neighbour, and are therefore detected as surface voxels.

    :param x: numpy array (boolean or 0/1)
    :param y: numpy array (boolean or 0/1), same shape as x
    :param hausdorff_percentile: (optional) percentile (from 0 to 100) for which to compute the Hausdorff distance.
    Set this to 100 to compute the real Hausdorff distance (default). Can also be a list, in which case HD is computed
    for each of the provided values.
    :param return_coordinate_max_distance: (optional) when set to True, the function also returns the coordinates of
    the surface voxel(s) with the highest distance. This is only computed when 100 is among the requested percentiles.
    The coordinates are expressed in the frame of the volume cropped around both masks, not of the input arrays.
    :return: max_dist, mean_dist(, coordinate_max_distance)
    max_dist: scalar with HD computed for the given percentile (or list if several percentiles were given).
    mean_dist: scalar with average surface distance.
    coordinate_max_distance: only returned if return_coordinate_max_distance is True and 100 is among the requested
    percentiles.
    If no cropping box is found for x or y (presumably because that mask is empty), the distances are set to the
    largest dimension of x, and only (max_dist, mean_dist) is returned."""

    assert x.shape == y.shape, 'both inputs should have same size, had {} and {}'.format(
        x.shape, y.shape)
    n_dims = len(x.shape)
    # value used when a distance cannot be computed
    default_dist = max(x.shape)

    hausdorff_percentile = 100 if hausdorff_percentile is None else hausdorff_percentile
    hausdorff_percentile = utils.reformat_to_list(hausdorff_percentile)

    # crop x and y around ROI
    _, crop_x = edit_volumes.crop_volume_around_region(x)
    _, crop_y = edit_volumes.crop_volume_around_region(y)

    # set distances to maximum volume shape if they are not defined
    # (keep the documented "list when several percentiles" contract)
    if crop_x is None or crop_y is None:
        if len(hausdorff_percentile) == 1:
            return default_dist, default_dist
        return [default_dist] * len(hausdorff_percentile), default_dist

    # crop both masks to the union of their bounding boxes
    crop = np.concatenate([
        np.minimum(crop_x, crop_y)[:n_dims],
        np.maximum(crop_x, crop_y)[n_dims:]
    ])
    x = edit_volumes.crop_volume_with_idx(x, crop)
    y = edit_volumes.crop_volume_with_idx(y, crop)

    # Pad with one voxel of background. The crop can be tight around the masks, and distance_transform_edt only
    # measures distances to background voxels inside the array. Without padding, mask voxels on the array border
    # could never reach distance 1 and would be missed as surface voxels (a box-shaped mask would have no surface).
    x = np.pad(x, 1, mode='constant')
    y = np.pad(y, 1, mode='constant')

    # detect edge: mask voxels at distance exactly 1 from the background
    x_dist_int = distance_transform_edt(x * 1)
    x_edge = (x_dist_int == 1) * 1
    y_dist_int = distance_transform_edt(y * 1)
    y_edge = (y_dist_int == 1) * 1

    # calculate distance from edge
    x_dist = distance_transform_edt(np.logical_not(x_edge))
    y_dist = distance_transform_edt(np.logical_not(y_edge))

    # find distances from the 2 surfaces (boolean indexing follows the same C order as np.where below)
    x_dists_to_y = y_dist[x_edge == 1]
    y_dists_to_x = x_dist[y_edge == 1]
    all_dists = np.concatenate([x_dists_to_y, y_dists_to_x])

    max_dist = list()
    coordinate_max_distance = None
    for hd_percentile in hausdorff_percentile:

        # find max distance from the 2 surfaces
        if hd_percentile == 100:
            current_max = np.max(all_dists)
            max_dist.append(current_max)

            if return_coordinate_max_distance:
                # compare against this scalar, not the accumulated list of results
                idx_max_distance_x = np.where(x_dists_to_y == current_max)[0]
                if idx_max_distance_x.size != 0:
                    indices_x_surface = np.where(x_edge == 1)
                    coordinate_max_distance = np.stack(
                        indices_x_surface).transpose()[idx_max_distance_x]
                else:
                    indices_y_surface = np.where(y_edge == 1)
                    idx_max_distance_y = np.where(y_dists_to_x == current_max)[0]
                    coordinate_max_distance = np.stack(
                        indices_y_surface).transpose()[idx_max_distance_y]
                # undo the one-voxel padding so coordinates refer to the cropped volume
                coordinate_max_distance = coordinate_max_distance - 1

        # find percentile of max distance
        else:
            max_dist.append(np.percentile(all_dists, hd_percentile))

    # find average distance between 2 surfaces
    if x_dists_to_y.shape[0] > 0:
        x_mean_dist_to_y = np.mean(x_dists_to_y)
    else:
        x_mean_dist_to_y = default_dist
    if y_dists_to_x.shape[0] > 0:
        y_mean_dist_to_x = np.mean(y_dists_to_x)
    else:
        y_mean_dist_to_x = default_dist
    mean_dist = (x_mean_dist_to_y + y_mean_dist_to_x) / 2

    # convert max dist back to scalar if HD only computed for 1 percentile
    if len(max_dist) == 1:
        max_dist = max_dist[0]

    # return coordinate of max distance if necessary
    if coordinate_max_distance is not None:
        return max_dist, mean_dist, coordinate_max_distance
    else:
        return max_dist, mean_dist
Exemple #30
0
def preprocess_image(im_path, n_levels, crop_shape=None, padding=None):
    """Read an image and prepare it for inference.

    The image is optionally zero-padded, then centrally cropped when needed so that every spatial dimension is
    divisible by 2**n_levels. Intensities are min-max rescaled to [0, 1] (an all-zero array when the image is
    constant), and batch/channel axes are added with utils.add_axis.

    :param im_path: path of the image, read with utils.get_volume_info.
    :param n_levels: the spatial shape after cropping is made divisible by 2**n_levels.
    :param crop_shape: (optional) requested cropping shape. It is clipped to the (padded) image shape, then adjusted
    with utils.find_closest_number_divisible_by_m. If None, the image is only cropped when its (padded) shape is not
    already divisible by 2**n_levels.
    :param padding: (optional) width of constant padding. For single-channel images it is passed as-is to np.pad.
    For multi-channel images only the spatial axes are padded. No padding when falsy.
    :return: im, aff, header, n_channels, n_dims, shape (as read), pad_shape (spatial shape after padding),
    crop_shape (None when no cropping was applied), crop_idx (start indices followed by end indices, or None).
    """
    im, shape, aff, n_dims, n_channels, header, _labels_res = utils.get_volume_info(
        im_path, return_volume=True)

    # optional padding (spatial axes only for multi-channel images)
    if not padding:
        pad_shape = shape
    elif n_channels == 1:
        im = np.pad(im, padding, mode='constant')
        pad_shape = im.shape
    else:
        pad_widths = tuple([(padding, padding)] * n_dims + [(0, 0)])
        im = np.pad(im, pad_widths, mode='constant')
        pad_shape = im.shape[:-1]

    divisor = 2**n_levels

    # decide on the cropping shape so that it is divisible by 2**n_levels
    if crop_shape is None:
        if any([size % divisor != 0 for size in pad_shape]):
            crop_shape = [
                utils.find_closest_number_divisible_by_m(size, divisor)
                for size in pad_shape
            ]
    else:
        crop_shape = utils.reformat_to_list(crop_shape,
                                            length=n_dims,
                                            dtype='int')
        too_big = [pad_shape[i] < crop_shape[i] for i in range(len(pad_shape))]
        if any(too_big):
            crop_shape = [
                min(pad_shape[i], crop_shape[i]) for i in range(n_dims)
            ]
            print(
                'cropping dimensions are higher than image size, changing cropping size to {}'
                .format(crop_shape))
        if any([size % divisor != 0 for size in crop_shape]):
            crop_shape = [
                utils.find_closest_number_divisible_by_m(size, divisor)
                for size in crop_shape
            ]

    # central crop, if a cropping shape was settled on
    if crop_shape is None:
        crop_idx = None
    else:
        crop_size = np.array(crop_shape)
        start = np.round((pad_shape - crop_size) / 2).astype('int')
        crop_idx = np.concatenate((start, start + crop_shape), axis=0)
        im = edit_volumes.crop_volume_with_idx(im, crop_idx=crop_idx)

    # min-max intensity normalisation
    im_min, im_max = np.min(im), np.max(im)
    im = np.zeros(im.shape) if im_max == im_min else (im - im_min) / (im_max - im_min)

    # add batch and channel axes
    im = utils.add_axis(im) if n_channels > 1 else utils.add_axis(im, -2)

    return im, aff, header, n_channels, n_dims, shape, pad_shape, crop_shape, crop_idx