Example 1
def preprocess_image(im_path, n_levels, crop_shape=None, padding=None, aff_ref='FS', dist_map=False):

    # read image and corresponding info
    im, shape, aff, n_dims, n_channels, header, im_res = utils.get_volume_info(im_path, return_volume=True)

    if padding:
        im = edit_volumes.pad_volume(im, padding_shape=padding)
        pad_shape = im.shape[:n_dims]
    else:
        pad_shape = shape

    # make sure crop_shape (or pad_shape when no cropping is given) is divisible by 2**n_levels
    if crop_shape is not None:
        crop_shape = utils.reformat_to_list(crop_shape, length=n_dims, dtype='int')
        if not all([pad_shape[i] >= crop_shape[i] for i in range(len(pad_shape))]):
            crop_shape = [min(pad_shape[i], crop_shape[i]) for i in range(n_dims)]
        if not all([size % (2**n_levels) == 0 for size in crop_shape]):
            crop_shape = [utils.find_closest_number_divisible_by_m(size, 2 ** n_levels) for size in crop_shape]
    else:
        if not all([size % (2**n_levels) == 0 for size in pad_shape]):
            crop_shape = [utils.find_closest_number_divisible_by_m(size, 2 ** n_levels) for size in pad_shape]

    # crop image if necessary
    if crop_shape is not None:
        im, crop_idx = edit_volumes.crop_volume(im, cropping_shape=crop_shape, return_crop_idx=True)
    else:
        crop_idx = None

    # align image to training axes and directions
    if n_dims > 2:
        if aff_ref == 'FS':
            aff_ref = np.array([[-1., 0., 0., 0.], [0., 0., 1., 0.], [0., -1., 0., 0.], [0., 0., 0., 1.]])
            im = edit_volumes.align_volume_to_ref(im, aff, aff_ref=aff_ref, return_aff=False, n_dims=n_dims)
        elif aff_ref == 'identity':
            aff_ref = np.eye(4)
            im = edit_volumes.align_volume_to_ref(im, aff, aff_ref=aff_ref, return_aff=False, n_dims=n_dims)

    # normalise image
    if n_channels == 1:
        m = np.min(im)
        M = np.max(im)
        if M == m:
            im = np.zeros(im.shape)
        else:
            im = (im - m) / (M - m)
    else:
        for i in range(im.shape[-1]):
            if (not dist_map) or (i % 2 == 0):  # even channels are intensities, odd channels are distance maps
                channel = im[..., i]
                m = np.min(channel)
                M = np.max(channel)
                if M == m:
                    im[..., i] = np.zeros(channel.shape)
                else:
                    im[..., i] = (channel - m) / (M - m)

    # add batch and channel axes
    im = utils.add_axis(im) if n_channels > 1 else utils.add_axis(im, axis=[0, -1])

    return im, aff, header, im_res, n_channels, n_dims, shape, pad_shape, crop_idx
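A minimal, self-contained sketch of the two key steps above: rounding shapes to a multiple of 2**n_levels and min-max normalisation. The helper below only assumes that utils.find_closest_number_divisible_by_m rounds to the nearest multiple; everything else is plain numpy with made-up shapes.

import numpy as np

def closest_divisible(size, m):
    # assumed behaviour of utils.find_closest_number_divisible_by_m:
    # round size to the nearest positive multiple of m
    return int(max(m, m * round(size / m)))

n_levels = 5
shape = (197, 233, 189)
crop_shape = [closest_divisible(s, 2 ** n_levels) for s in shape]
print(crop_shape)  # [192, 224, 192]: every axis now divisible by 2**5

im = 100. * np.random.rand(*shape)
m, M = np.min(im), np.max(im)
im = np.zeros(im.shape) if M == m else (im - m) / (M - m)  # rescale to [0, 1]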
Example 2
def build_model_inputs(path_images,
                       path_label_maps,
                       batchsize=1):

    # read image info (number of dimensions and channels)
    _, _, n_dims, n_channels, _, _ = utils.get_volume_info(path_images[0])

    # Generate!
    while True:

        # randomly pick as many images as batchsize
        indices = npr.randint(len(path_label_maps), size=batchsize)

        # initialise input lists
        list_images = list()
        list_label_maps = list()

        for idx in indices:

            # add image
            image = utils.load_volume(path_images[idx], aff_ref=np.eye(4))
            if n_channels > 1:
                list_images.append(utils.add_axis(image, axis=0))
            else:
                list_images.append(utils.add_axis(image, axis=[0, -1]))

            # add labels
            labels = utils.load_volume(path_label_maps[idx], dtype='int', aff_ref=np.eye(4))
            list_label_maps.append(utils.add_axis(labels, axis=[0, -1]))

        # build list of inputs of augmentation model
        list_inputs = [list_images, list_label_maps]
        if batchsize > 1:  # concatenate individual input types if batchsize > 1
            list_inputs = [np.concatenate(item, 0) for item in list_inputs]
        else:
            list_inputs = [item[0] for item in list_inputs]

        yield list_inputs
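This generator loops forever and is meant to be pulled with next(). Below is a hedged, in-memory mimic of its batching logic on toy numpy arrays; all names are illustrative and not part of the codebase.

import numpy as np
import numpy.random as npr

def toy_input_generator(volumes, label_maps, batchsize=2):
    # same pattern as build_model_inputs: sample indices, add batch/channel
    # axes, then concatenate along the batch axis
    while True:
        indices = npr.randint(len(volumes), size=batchsize)
        images = [volumes[i][np.newaxis, ..., np.newaxis] for i in indices]
        labels = [label_maps[i][np.newaxis, ..., np.newaxis] for i in indices]
        yield [np.concatenate(images, 0), np.concatenate(labels, 0)]

vols = [np.random.rand(32, 32, 32) for _ in range(4)]
maps = [np.random.randint(0, 3, (32, 32, 32)) for _ in range(4)]
batch_images, batch_labels = next(toy_input_generator(vols, maps))
print(batch_images.shape, batch_labels.shape)  # (2, 32, 32, 32, 1) for both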
Example 3
    def __init__(self,
                 labels_dir,
                 generation_labels=None,
                 output_labels=None,
                 n_neutral_labels=None,
                 batchsize=1,
                 n_channels=1,
                 target_res=None,
                 output_shape=None,
                 output_div_by_n=None,
                 prior_distributions='uniform',
                 generation_classes=None,
                 prior_means=None,
                 prior_stds=None,
                 use_specific_stats_for_channel=False,
                 mix_prior_and_random=False,
                 flipping=True,
                 scaling_bounds=0.15,
                 rotation_bounds=15,
                 shearing_bounds=.012,
                 translation_bounds=False,
                 nonlin_std=4.,
                 nonlin_shape_factor=0.0625,
                 randomise_res=False,
                 build_distance_maps=False,
                 data_res=None,
                 thickness=None,
                 downsample=False,
                 blur_range=1.15,
                 bias_field_std=0.5,
                 bias_shape_factor=0.025):
        """
        This class is a wrapper around the labels_to_image_model model. It contains the GPU model that generates images
        from label maps, and a python generator that supplies the input data for this model.
        To generate pairs of image/labels you can just call the method generate_image() on an object of this class.

        :param labels_dir: path of folder with all input label maps, or to a single label map.

        # IMPORTANT !!!
        # Each time we provide a parameter with separate values for each axis (e.g. with a numpy array or a sequence),
        # these values refer to the RAS axes.

        # label maps-related parameters
        :param generation_labels: (optional) list of all possible label values in the input label maps.
        Default is None, in which case the label values are read directly from the provided label maps.
        If not None, can be a sequence or a 1d numpy array, or the path to a 1d numpy array.
        If flipping is true (i.e. right/left flipping is enabled), generation_labels should be organised as follows:
        background label first, then non-sided labels (e.g. CSF, brainstem, etc.), then all the structures of the same
        hemisphere (can be left or right), and finally all the corresponding contralateral structures in the same order.
        :param output_labels: (optional) list of all the label values to keep in the output label maps (in no particular
        order). Should be a subset of the values contained in generation_labels.
        Label values that are in generation_labels but not in output_labels are reset to zero.
        Can be a sequence, a 1d numpy array, or the path to a 1d numpy array.
        By default output labels are equal to generation labels.
        :param n_neutral_labels: (optional) number of non-sided generation labels.
        Default is total number of label values.

        # output-related parameters
        :param batchsize: (optional) number of images to generate per mini-batch. Default is 1.
        :param n_channels: (optional) number of channels to be synthesised. Default is 1.
        :param target_res: (optional) target resolution of the generated images and corresponding label maps.
        If None, the outputs will have the same resolution as the input label maps.
        Can be a number (isotropic resolution), a sequence, a 1d numpy array, or the path to a 1d numpy array.
        :param output_shape: (optional) shape of the output image, obtained by randomly cropping the generated image.
        Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or the path to a 1d numpy array.
        Default is None, where no cropping is performed.
        :param output_div_by_n: (optional) forces the output shape to be divisible by this value. It overwrites
        output_shape if necessary. Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or
        the path to a 1d numpy array.

        # GMM-sampling parameters
        :param generation_classes: (optional) Indices regrouping generation labels into classes of the same intensity
        distribution. Regrouped labels will thus share the same Gaussian when sampling a new image. Can be a sequence, a
        1d numpy array, or the path to a 1d numpy array. It should have the same length as generation_labels, and
        contain values between 0 and K-1, where K is the total number of classes.
        Default is None, where every label is assigned its own class (K=len(generation_labels)).
        :param prior_distributions: (optional) type of distribution from which we sample the GMM parameters.
        Can either be 'uniform', or 'normal'. Default is 'uniform'.
        :param prior_means: (optional) hyperparameters controlling the prior distributions of the GMM means. Because
        these prior distributions are uniform or normal, they are each defined by 2 hyperparameters. Thus prior_means
        can be:
        1) a sequence of length 2, directly defining the two hyperparameters: [min, max] if prior_distributions is
        uniform, [mean, std] if the distribution is normal. The GMM means are then independently sampled at each
        mini-batch from the same distribution.
        2) an array of shape (2, K), where K is the number of classes (K=len(generation_labels) if generation_classes is
        not given). The mean of the Gaussian distribution associated with class k in [0, ..., K-1] is sampled at each
        mini-batch from U(prior_means[0,k], prior_means[1,k]) if prior_distributions is uniform, and from
        N(prior_means[0,k], prior_means[1,k]) if prior_distributions is normal.
        3) an array of shape (2*n_mod, K), where each block of two rows is associated with hyperparameters derived
        from different modalities. In this case, if use_specific_stats_for_channel is False, we first randomly select a
        modality from the n_mod possibilities, and we sample the GMM means as in 2).
        If use_specific_stats_for_channel is True, each block of two rows corresponds to a different channel
        (n_mod=n_channels), and we use the corresponding block for each channel rather than drawing it randomly.
        4) the path to such a numpy array.
        Default is None, which corresponds to prior_means = [25, 225].
        :param prior_stds: (optional) same as prior_means but for the standard deviations of the GMM.
        Default is None, which corresponds to prior_stds = [5, 25].
        :param use_specific_stats_for_channel: (optional) whether the i-th block of two rows in the prior arrays must
        only be used to generate the i-th channel. If True, n_mod should be equal to n_channels. Default is False.
        :param mix_prior_and_random: (optional) if prior_means is not None, enables resetting the priors to their
        default values for half of the cases, thus generating images of random contrast.

        # spatial deformation parameters
        :param flipping: (optional) whether to introduce right/left random flipping. Default is True.
        :param scaling_bounds: (optional) range of the random scaling to apply at each mini-batch. The scaling factor for
        each dimension is sampled from a uniform distribution of predefined bounds. Can either be:
        1) a number, in which case the scaling factor is independently sampled from the uniform distribution of bounds
        [1-scaling_bounds, 1+scaling_bounds] for each dimension.
        2) a sequence, in which case the scaling factor is sampled from the uniform distribution of bounds
        (1-scaling_bounds[i], 1+scaling_bounds[i]) for the i-th dimension.
        3) a numpy array of shape (2, n_dims), in which case the scaling factor is sampled from the uniform distribution
         of bounds (scaling_bounds[0, i], scaling_bounds[1, i]) for the i-th dimension.
        4) False, in which case scaling is completely turned off.
        Default is scaling_bounds = 0.15 (case 1)
        :param rotation_bounds: (optional) same as scaling_bounds but for the rotation angle, except that for cases 1
        and 2, the bounds are centred on 0 rather than 1, i.e. [0-rotation_bounds[i], 0+rotation_bounds[i]].
        Default is rotation_bounds = 15.
        :param shearing_bounds: (optional) same as scaling bounds. Default is shearing_bounds = 0.012.
        :param translation_bounds: (optional) same as scaling bounds. Default is translation_bounds = False, but we
        encourage using it when cropping is deactivated (i.e. when output_shape=None in BrainGenerator).
        :param nonlin_std: (optional) Maximum value for the standard deviation of the normal distribution from which we
        sample the first tensor for synthesising the deformation field. Set to 0 if you wish to completely turn the
        elastic deformation off.
        :param nonlin_shape_factor: (optional) if nonlin_std is strictly positive, ratio between the shape of the input
        label maps and the shape of the sampled non-linear tensor.

        # blurring/resampling parameters
        :param randomise_res: (optional) whether to mimic images that would have been 1) acquired at low resolution, and
        2) resampled to high resolution. The low resolution is uniformly sampled at each mini-batch from [1mm, 9mm].
        In that process, the images generated by sampling the GMM are:
        1) blurred at the sampled LR, 2) downsampled to LR, and 3) resampled to target_res.
        :param data_res: (optional) specific acquisition resolution to mimic, as opposed to the random resolution
        sampled when randomise_res is True. This triggers a blurring which mimics the acquisition resolution, but
        downsampling is optional (see param downsample). Default for data_res is None, where images are only slightly
        blurred. If the generated images are uni-modal, data_res can be a number (isotropic acquisition resolution), a
        sequence, a 1d numpy array, or the path to a 1d numpy array. In the multi-modal case, it should be given as a
        numpy array (or a path) of size (n_mod, n_dims), where each row is the acquisition resolution of the
        corresponding channel.
        :param thickness: (optional) if data_res is provided, we can further specify the slice thickness of the low
        resolution images to mimic. Must be provided in the same format as data_res. Default thickness = data_res.
        :param downsample: (optional) whether to actually downsample the volume images to data_res after blurring.
        Default is False, except when thickness is provided, and thickness < data_res.
        :param blur_range: (optional) Randomise the standard deviation of the blurring kernels (whether data_res is
        given or not). At each mini-batch, the standard deviations of the blurring kernels are multiplied by a
        coefficient sampled from a uniform distribution with bounds [1/blur_range, blur_range].
        If None, no randomisation. Default is 1.15.

        # bias field parameters
        :param bias_field_std: (optional) If strictly positive, this triggers the corruption of synthesised images with
        a bias field. It is obtained by sampling a first small tensor from a normal distribution, resizing it to full
        size, and rescaling it to positive values by taking the voxel-wise exponential. bias_field_std designates the
        std dev of the normal distribution from which we sample the first tensor. Set to 0 to deactivate bias field.
        :param bias_shape_factor: (optional) If bias_field_std is strictly positive, this designates the ratio between
        the size of the input label maps and the size of the first sampled tensor for synthesising the bias field.
        """

        # prepare data files
        self.labels_paths = utils.list_images_in_folder(labels_dir)

        # generation parameters
        self.labels_shape, self.aff, self.n_dims, _, self.header, self.atlas_res = \
            utils.get_volume_info(self.labels_paths[0], aff_ref=np.eye(4))
        self.n_channels = n_channels
        if generation_labels is not None:
            self.generation_labels = utils.load_array_if_path(generation_labels)
        else:
            self.generation_labels, _ = utils.get_list_labels(labels_dir=labels_dir)
        if output_labels is not None:
            self.output_labels = utils.load_array_if_path(output_labels)
        else:
            self.output_labels = self.generation_labels
        if n_neutral_labels is not None:
            self.n_neutral_labels = n_neutral_labels
        else:
            self.n_neutral_labels = self.generation_labels.shape[0]
        self.target_res = utils.load_array_if_path(target_res)
        self.batchsize = batchsize
        # preliminary operations
        self.flipping = flipping
        self.output_shape = utils.load_array_if_path(output_shape)
        self.output_div_by_n = output_div_by_n
        # GMM parameters
        self.prior_distributions = prior_distributions
        if generation_classes is not None:
            self.generation_classes = utils.load_array_if_path(generation_classes)
            assert self.generation_classes.shape == self.generation_labels.shape, \
                'if provided, generation_classes should have the same shape as generation_labels'
            unique_classes = np.unique(self.generation_classes)
            assert np.array_equal(unique_classes, np.arange(np.max(unique_classes)+1)), \
                'generation_classes should be a linear range between 0 and its maximum value.'
        else:
            self.generation_classes = np.arange(self.generation_labels.shape[0])
        self.prior_means = utils.load_array_if_path(prior_means)
        self.prior_stds = utils.load_array_if_path(prior_stds)
        self.use_specific_stats_for_channel = use_specific_stats_for_channel
        # linear transformation parameters
        self.scaling_bounds = utils.load_array_if_path(scaling_bounds)
        self.rotation_bounds = utils.load_array_if_path(rotation_bounds)
        self.shearing_bounds = utils.load_array_if_path(shearing_bounds)
        self.translation_bounds = utils.load_array_if_path(translation_bounds)
        # elastic transformation parameters
        self.nonlin_std = nonlin_std
        self.nonlin_shape_factor = nonlin_shape_factor
        # blurring parameters
        self.randomise_res = randomise_res
        self.build_distance_maps = build_distance_maps
        self.data_res = utils.load_array_if_path(data_res)
        assert not (self.randomise_res & (self.data_res is not None)), \
            'randomise_res and data_res cannot be provided at the same time'
        self.thickness = utils.load_array_if_path(thickness)
        self.downsample = downsample
        self.blur_range = blur_range
        # bias field parameters
        self.bias_field_std = bias_field_std
        self.bias_shape_factor = bias_shape_factor

        # build transformation model
        self.labels_to_image_model, self.model_output_shape = self._build_labels_to_image_model()

        # build generator for model inputs
        self.model_inputs_generator = self._build_model_inputs_generator(mix_prior_and_random)

        # build brain generator
        self.brain_generator = self._build_brain_generator()
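To make the prior_means conventions above concrete, here is a small numpy sketch of case 2), the (2, K) array: row 0 holds the first hyperparameter of each class, row 1 the second. The numbers are made up for illustration; the real sampling happens in utils.draw_value_from_distribution.

import numpy as np
import numpy.random as npr

K = 4
prior_means = np.array([[20., 40., 90., 140.],    # row 0: lower bounds (or means)
                        [60., 80., 130., 180.]])  # row 1: upper bounds (or stds)
uniform = True  # i.e. prior_distributions == 'uniform'
if uniform:
    means = npr.uniform(prior_means[0], prior_means[1])  # U(low, high) per class
else:
    means = npr.normal(prior_means[0], prior_means[1])   # N(mean, std) per class
print(means.shape)  # (4,): one GMM mean per class, redrawn at every mini-batch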
Example 4
def supervised_training(image_dir,
                        labels_dir,
                        model_dir,
                        path_segmentation_labels=None,
                        batchsize=1,
                        output_shape=None,
                        flipping=True,
                        scaling_bounds=.15,
                        rotation_bounds=15,
                        shearing_bounds=.012,
                        translation_bounds=False,
                        nonlin_std=3.,
                        nonlin_shape_factor=.04,
                        bias_field_std=.5,
                        bias_shape_factor=.025,
                        n_levels=5,
                        nb_conv_per_level=2,
                        conv_size=3,
                        unet_feat_count=24,
                        feat_multiplier=2,
                        dropout=0,
                        activation='elu',
                        lr=1e-4,
                        lr_decay=0,
                        wl2_epochs=5,
                        dice_epochs=100,
                        steps_per_epoch=1000,
                        checkpoint=None,
                        reinitialise_momentum=False,
                        freeze_layers=False):

    # check epochs
    assert (wl2_epochs > 0) | (dice_epochs > 0), \
        'either wl2_epochs or dice_epochs must be positive, had {0} and {1}'.format(wl2_epochs, dice_epochs)

    # prepare data files
    path_images = utils.list_images_in_folder(image_dir)
    path_labels = utils.list_images_in_folder(labels_dir)
    assert len(path_images) == len(path_labels), "There should be as many images as label maps."

    # get label lists
    label_list, n_neutral_labels = utils.get_list_labels(label_list=path_segmentation_labels, labels_dir=labels_dir,
                                                         FS_sort=True)
    n_labels = np.size(label_list)

    # create augmentation model and input generator
    im_shape, _, _, n_channels, _, _ = utils.get_volume_info(path_images[0], aff_ref=np.eye(4))
    augmentation_model = build_augmentation_model(im_shape,
                                                  n_channels,
                                                  label_list,
                                                  n_neutral_labels,
                                                  output_shape=output_shape,
                                                  output_div_by_n=2 ** n_levels,
                                                  flipping=flipping,
                                                  aff=np.eye(4),
                                                  scaling_bounds=scaling_bounds,
                                                  rotation_bounds=rotation_bounds,
                                                  shearing_bounds=shearing_bounds,
                                                  translation_bounds=translation_bounds,
                                                  nonlin_std=nonlin_std,
                                                  nonlin_shape_factor=nonlin_shape_factor,
                                                  bias_field_std=bias_field_std,
                                                  bias_shape_factor=bias_shape_factor)
    unet_input_shape = augmentation_model.output[0].get_shape().as_list()[1:]

    # prepare the segmentation model
    unet_model = nrn_models.unet(nb_features=unet_feat_count,
                                 input_shape=unet_input_shape,
                                 nb_levels=n_levels,
                                 conv_size=conv_size,
                                 nb_labels=n_labels,
                                 feat_mult=feat_multiplier,
                                 nb_conv_per_level=nb_conv_per_level,
                                 conv_dropout=dropout,
                                 batch_norm=-1,
                                 activation=activation,
                                 input_model=augmentation_model)

    # input generator
    input_generator = utils.build_training_generator(build_model_inputs(path_images, path_labels, batchsize), batchsize)

    # pre-training with weighted L2, input is fit to the softmax rather than the probabilities
    if wl2_epochs > 0:
        wl2_model = models.Model(unet_model.inputs, [unet_model.get_layer('unet_likelihood').output])
        wl2_model = metrics.metrics_model(wl2_model, label_list, 'wl2')
        train_model(wl2_model, input_generator, lr, lr_decay, wl2_epochs, steps_per_epoch, model_dir, 'wl2', checkpoint)
        checkpoint = os.path.join(model_dir, 'wl2_%03d.h5' % wl2_epochs)

    # freeze all layers but last if necessary (use -2 because the very last layer only applies softmax activation)
    if freeze_layers:
        for layer in unet_model.layers[:-2]:
            layer.trainable = False

    # fine-tuning with dice metric
    dice_model = metrics.metrics_model(unet_model, label_list, 'dice')
    train_model(dice_model, input_generator, lr, lr_decay, dice_epochs, steps_per_epoch, model_dir, 'dice', checkpoint,
                reinitialise_momentum=reinitialise_momentum)
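For context on the Dice fine-tuning stage above, a self-contained numpy sketch of the standard soft Dice score on posterior probability maps. This illustrates the formula only, not the exact metrics.metrics_model implementation.

import numpy as np

def soft_dice(prob, one_hot, eps=1e-7):
    # soft Dice per label: 2*sum(p*g) / (sum(p) + sum(g)), summed over
    # batch and spatial axes, leaving one score per label
    axes = tuple(range(prob.ndim - 1))
    num = 2. * np.sum(prob * one_hot, axis=axes)
    den = np.sum(prob, axis=axes) + np.sum(one_hot, axis=axes) + eps
    return num / den

prob = np.random.dirichlet(np.ones(3), size=(2, 8, 8, 8))  # (2, 8, 8, 8, 3) posteriors
gt = np.eye(3)[np.random.randint(0, 3, (2, 8, 8, 8))]      # one-hot ground truth
print(soft_dice(prob, gt))  # three values in [0, 1], one per label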
Example 5
def build_model_inputs(path_label_maps,
                       n_labels,
                       batchsize=1,
                       n_channels=1,
                       generation_classes=None,
                       prior_distributions='uniform',
                       prior_means=None,
                       prior_stds=None,
                       use_specific_stats_for_channel=False,
                       mix_prior_and_random=False,
                       apply_linear_trans=True,
                       scaling_bounds=None,
                       rotation_bounds=None,
                       shearing_bounds=None,
                       background_paths=None):
    """
    This function builds a generator to be fed to the lab2im model. It generates all the required inputs,
    according to the operations performed in the model.
    :param path_label_maps: list of the paths of the input label maps.
    :param n_labels: number of labels in the input label maps.
    :param batchsize: (optional) number of images to generate per mini-batch. Default is 1.
    :param n_channels: (optional) number of channels to be synthesised. Default is 1.
    :param generation_classes: (optional) Indices regrouping generation labels into classes of the same intensity
    distribution. Regrouped labels will thus share the same Gaussian when sampling a new image. Can be a sequence or a
    1d numpy array. It should have the same length as generation_labels, and contain values between 0 and K-1, where K
    is the total number of classes. Default is None, where every label is assigned its own class.
    :param prior_distributions: (optional) type of distribution from which we sample the GMM parameters.
    Can either be 'uniform', or 'normal'. Default is 'uniform'.
    :param prior_means: (optional) hyperparameters controlling the prior distributions of the GMM means. Because
    these prior distributions are uniform or normal, they are each defined by 2 hyperparameters. Thus prior_means can be:
    1) a sequence of length 2, directly defining the two hyperparameters: [min, max] if prior_distributions is
    uniform, [mean, std] if the distribution is normal. The GMM means are then independently sampled at each
    mini-batch from the same distribution.
    2) an array of shape (2, K), where K is the number of classes (K=len(generation_labels) if generation_classes is
    not given). The mean of the Gaussian distribution associated with class k in [0, ..., K-1] is sampled at each
    mini-batch from U(prior_means[0,k], prior_means[1,k]) if prior_distributions is uniform, or from
    N(prior_means[0,k], prior_means[1,k]) if prior_distributions is normal.
    3) an array of shape (2*n_mod, K), where each block of two rows is associated with hyperparameters derived
    from different modalities. In this case, if use_specific_stats_for_channel is False, we first randomly select a
    modality from the n_mod possibilities, and we sample the GMM means as in 2).
    If use_specific_stats_for_channel is True, each block of two rows corresponds to a different channel
    (n_mod=n_channels), and we use the corresponding block for each channel rather than drawing it randomly.
    4) the path to such a numpy array.
    Default is None, which corresponds to prior_means = [25, 225].
    :param prior_stds: (optional) same as prior_means but for the standard deviations of the GMM.
    Default is None, which corresponds to prior_stds = [5, 25].
    :param use_specific_stats_for_channel: (optional) whether the i-th block of two rows in the prior arrays must only
    be used to generate the i-th channel. If True, n_mod should be equal to n_channels. Default is False.
    :param mix_prior_and_random: (optional) if prior_means is not None, enables resetting the priors to their default
    values for half of the cases, thus generating images of random contrast.
    :param apply_linear_trans: (optional) whether to apply affine deformation. Default is True.
    :param scaling_bounds: (optional) if apply_linear_trans is True, the scaling factor for each dimension is
    sampled from a uniform distribution of predefined bounds. Can either be:
    1) a number, in which case the scaling factor is independently sampled from the uniform distribution of bounds
    (1-scaling_bounds, 1+scaling_bounds) for each dimension.
    2) a sequence, in which case the scaling factor is sampled from the uniform distribution of bounds
    (1-scaling_bounds[i], 1+scaling_bounds[i]) for the i-th dimension.
    3) a numpy array of shape (2, n_dims), in which case the scaling factor is sampled from the uniform distribution
     of bounds (scaling_bounds[0, i], scaling_bounds[1, i]) for the i-th dimension.
    If None (default), scaling_bounds = 0.15.
    :param rotation_bounds: (optional) same as scaling_bounds but for the rotation angle, except that for cases 1
    and 2, the bounds are centred on 0 rather than 1, i.e. (0-rotation_bounds[i], 0+rotation_bounds[i]).
    If None (default), rotation_bounds = 15.
    :param shearing_bounds: (optional) same as scaling bounds. If None (default), shearing_bounds = 0.01.
    :param background_paths: (optional) list of paths of label maps to replace the soft brain tissues (label 258) with.
    """

    # get label info
    _, _, n_dims, _, _, _ = utils.get_volume_info(path_label_maps[0])

    # allocate unique class to each label if generation classes is not given
    if generation_classes is None:
        generation_classes = np.arange(n_labels)

    # Generate!
    while True:

        # randomly pick as many images as batchsize
        indices = npr.randint(len(path_label_maps), size=batchsize)

        # initialise input lists
        list_label_maps = []
        list_means = []
        list_stds = []
        list_affine_transforms = []

        for idx in indices:

            # add labels to inputs
            y = utils.load_volume(path_label_maps[idx], dtype='int', aff_ref=np.eye(4))
            if background_paths is not None:
                idx_258 = np.where(y == 258)
                if np.any(idx_258):
                    background = utils.load_volume(background_paths[npr.randint(len(background_paths))],
                                                   dtype='int', aff_ref=np.eye(4))
                    background_shape = background.shape
                    if np.all(np.array(background_shape) == background_shape[0]):  # flip if same dimensions
                        background = np.flip(background, tuple([i for i in range(3) if np.random.normal() > 0]))
                    assert background.shape == y.shape, 'background patches should have the same shape as training ' \
                                                        'labels. Had {0} and {1}'.format(background.shape, y.shape)
                    y[idx_258] = background[idx_258]
            list_label_maps.append(utils.add_axis(y, axis=-2))

            # add means and standard deviations to inputs
            means = np.empty((n_labels, 0))
            stds = np.empty((n_labels, 0))
            for channel in range(n_channels):

                # retrieve channel specific stats if necessary
                if isinstance(prior_means, np.ndarray):
                    if (prior_means.shape[0] > 2) & use_specific_stats_for_channel:
                        if prior_means.shape[0] / 2 != n_channels:
                            raise ValueError("the number of blocks in prior_means does not match n_channels. This "
                                             "message is printed because use_specific_stats_for_channel is True.")
                        tmp_prior_means = prior_means[2 * channel:2 * channel + 2, :]
                    else:
                        tmp_prior_means = prior_means
                else:
                    tmp_prior_means = prior_means
                if (prior_means is not None) & mix_prior_and_random & (npr.uniform() > 0.5):
                    tmp_prior_means = None
                if isinstance(prior_stds, np.ndarray):
                    if (prior_stds.shape[0] > 2) & use_specific_stats_for_channel:
                        if prior_stds.shape[0] / 2 != n_channels:
                            raise ValueError("the number of blocks in prior_stds does not match n_channels. This "
                                             "message is printed because use_specific_stats_for_channel is True.")
                        tmp_prior_stds = prior_stds[2 * channel:2 * channel + 2, :]
                    else:
                        tmp_prior_stds = prior_stds
                else:
                    tmp_prior_stds = prior_stds
                if (prior_stds is not None) & mix_prior_and_random & (npr.uniform() > 0.5):
                    tmp_prior_stds = None

                # draw means and std devs from priors
                tmp_classes_means = utils.draw_value_from_distribution(tmp_prior_means, n_labels, prior_distributions,
                                                                       125., 100., positive_only=True)
                tmp_classes_stds = utils.draw_value_from_distribution(tmp_prior_stds, n_labels, prior_distributions,
                                                                      15., 10., positive_only=True)
                tmp_means = utils.add_axis(tmp_classes_means[generation_classes], -1)
                tmp_stds = utils.add_axis(tmp_classes_stds[generation_classes], -1)
                means = np.concatenate([means, tmp_means], axis=1)
                stds = np.concatenate([stds, tmp_stds], axis=1)
            list_means.append(utils.add_axis(means))
            list_stds.append(utils.add_axis(stds))

            # add linear transform to inputs
            if apply_linear_trans:
                # get affine transformation: rotate, scale, shear (translation done during random cropping)
                scaling = utils.draw_value_from_distribution(scaling_bounds, size=n_dims, centre=1, default_range=.15)
                if n_dims == 2:
                    rotation = utils.draw_value_from_distribution(rotation_bounds, default_range=15.0)
                else:
                    rotation = utils.draw_value_from_distribution(rotation_bounds, size=n_dims, default_range=15.0)
                shearing = utils.draw_value_from_distribution(shearing_bounds, size=n_dims**2-n_dims, default_range=.01)
                affine_transform = utils.create_affine_transformation_matrix(n_dims, scaling, rotation, shearing)
                list_affine_transforms.append(utils.add_axis(affine_transform))

        # build list of inputs of augmentation model
        list_inputs = [list_label_maps, list_means, list_stds]
        if apply_linear_trans:
            list_inputs.append(list_affine_transforms)

        # concatenate individual input types if batchsize > 1
        if batchsize > 1:
            list_inputs = [np.concatenate(item, 0) for item in list_inputs]
        else:
            list_inputs = [item[0] for item in list_inputs]

        yield list_inputs
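The fancy-indexing step tmp_classes_means[generation_classes] is what makes regrouped labels share a Gaussian; a tiny standalone illustration with made-up values:

import numpy as np
import numpy.random as npr

generation_classes = np.array([0, 0, 1, 2])      # 4 labels grouped into K=3 classes
classes_means = npr.uniform(25., 225., size=3)   # one mean drawn per class
label_means = classes_means[generation_classes]  # broadcast class means to labels
print(label_means)  # first two entries identical: labels 0 and 1 share class 0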
Example 6
def predict(path_images,
            path_model,
            segmentation_label_list,
            dist_map=False,
            path_segmentations=None,
            path_posteriors=None,
            path_volumes=None,
            segmentation_names_list=None,
            padding=None,
            cropping=None,
            resample=None,
            aff_ref='FS',
            sigma_smoothing=0,
            keep_biggest_component=False,
            conv_size=3,
            n_levels=5,
            nb_conv_per_level=2,
            unet_feat_count=24,
            feat_multiplier=2,
            activation='elu',
            gt_folder=None,
            evaluation_label_list=None,
            compute_distances=False,
            recompute=True,
            verbose=True):
    """
    This function uses trained models to segment images.
    It is crucial that the inputs match the architecture parameters of the trained model.
    :param path_images: path of the images to segment. Can be the path to a directory or the path to a single image.
    :param path_model: path of the trained model.
    :param segmentation_label_list: List of labels for which to compute Dice scores. It should contain the same values
    as the segmentation label list used for training the network.
    Can be a sequence, a 1d numpy array, or the path to a numpy 1d array.
    :param dist_map: (optional) whether the input will contain distance map channels (interleaved with the intensity
    channels). Default is False.
    :param path_segmentations: (optional) path where segmentations will be written.
    Should be a directory if path_images is a directory, and a file if path_images is a file.
    Must not be None if path_posteriors is None.
    :param path_posteriors: (optional) path where posteriors will be written.
    Should be a directory if path_images is a directory, and a file if path_images is a file.
    Must not be None if path_segmentations is None.
    :param path_volumes: (optional) path of a csv file where the soft volumes of all segmented regions will be written.
    The rows of the csv file correspond to subjects, and the columns correspond to segmentation labels.
    The soft volume of a structure corresponds to the sum of its predicted probability map.
    :param segmentation_names_list: (optional) List of names corresponding to the names of the segmentation labels.
    Only used when path_volumes is provided. Must be of the same size as segmentation_label_list. Can be given as a
    list, a numpy array of strings, or the path to such a numpy array. Default is None.
    :param padding: (optional) pad the images to the specified shape before predicting the segmentation maps.
    Can be an int, a sequence or a 1d numpy array.
    :param cropping: (optional) crop the images to the specified shape before predicting the segmentation maps.
    If padding and cropping are specified, images are padded before being cropped.
    Can be an int, a sequence or a 1d numpy array.
    :param resample: (optional) resample the images to the specified resolution before predicting the segmentation maps.
    Can be an int, a sequence or a 1d numpy array.
    :param aff_ref: (optional) type of affine matrix of the images used for training. By default this is set to the
    FreeSurfer orientation ('FS'), as it was the configuration in which SynthSeg was trained. However, newer models
    are trained on data aligned with the identity vox2ras matrix, so for these you need to set aff_ref to 'identity'.
    :param sigma_smoothing: (optional) If strictly positive, the posteriors are smoothed with a gaussian kernel of the
    specified standard deviation.
    :param keep_biggest_component: (optional) whether to only keep the biggest component in the predicted segmentation.
    :param conv_size: (optional) size of unet's convolution masks. Default is 3.
    :param n_levels: (optional) number of levels for unet. Default is 5.
    :param nb_conv_per_level: (optional) number of convolution layers per level. Default is 2.
    :param unet_feat_count: (optional) number of features for the first layer of the unet. Default is 24.
    :param feat_multiplier: (optional) multiplicative factor for the number of features at each new level. Default is 2.
    :param activation: (optional) activation function. Can be 'elu', 'relu'.
    :param gt_folder: (optional) folder containing ground truth files for evaluation.
    A numpy array containing all dice scores (labels in rows, subjects in columns) will be written either to
    segmentations_dir (if not None), or to posteriors_dir.
    :param evaluation_label_list: (optional) if gt_folder is provided, you can evaluate the Dice scores on a subset of
    the segmentation labels by providing another label list here. Can be a sequence, a 1d numpy array, or the path to a
    numpy 1d array. Default is the same as segmentation_label_list.
    :param compute_distances: (optional) whether to also compute Hausdorff and mean surface distances when evaluating
    against gt_folder. Default is False.
    :param recompute: (optional) whether to recompute segmentations that were already computed. This also applies to
    Dice scores, if gt_folder is not None. Default is True.
    :param verbose: (optional) whether to print out info about the remaining number of cases.
    """

    # prepare output filepaths
    images_to_segment, path_segmentations, path_posteriors, path_volumes, compute = \
        prepare_output_files(path_images, path_segmentations, path_posteriors, path_volumes, recompute)

    # get label and classes lists
    label_list, n_neutral_labels = utils.get_list_labels(label_list=segmentation_label_list, FS_sort=True)
    if evaluation_label_list is None:
        evaluation_label_list = segmentation_label_list

    # prepare volume file if needed
    if path_volumes is not None:
        if segmentation_names_list is not None:
            csv_header = [[''] + utils.reformat_to_list(segmentation_names_list, load_as_numpy=True)]
            csv_header += [[''] + [str(lab) for lab in label_list[1:]]]
        else:
            csv_header = [['subjects'] + [str(lab) for lab in label_list[1:]]]
        with open(path_volumes, 'w') as csvFile:
            writer = csv.writer(csvFile)
            writer.writerows(csv_header)

    # perform segmentation
    net = None
    previous_model_input_shape = None
    loop_info = utils.LoopInfo(len(images_to_segment), 10, 'predicting', True)
    for idx, (path_image, path_segmentation, path_posterior, tmp_compute) in enumerate(zip(images_to_segment,
                                                                                           path_segmentations,
                                                                                           path_posteriors,
                                                                                           compute)):
        # compute segmentation only if needed
        if tmp_compute:

            # preprocess image and get information
            image, aff, h, im_res, n_channels, n_dims, shape, pad_shape, crop_idx = \
                preprocess_image(path_image, n_levels, cropping, padding, aff_ref=aff_ref, dist_map=dist_map)
            model_input_shape = list(image.shape[1:])

            # prepare net for first image or if input's size has changed
            if (net is None) | (previous_model_input_shape != model_input_shape):

                # check for image size compatibility
                if (net is not None) & (previous_model_input_shape != model_input_shape) & verbose:
                    print('image of different shape from previous ones, redefining the network')
                previous_model_input_shape = model_input_shape

                # build network
                net = build_model(path_model, model_input_shape, resample, im_res, n_levels, len(label_list), conv_size,
                                  nb_conv_per_level, unet_feat_count, feat_multiplier, activation, sigma_smoothing)

            if verbose:
                loop_info.update(idx)

            # predict posteriors
            prediction_patch = net.predict(image)

            # get posteriors and segmentation
            seg, posteriors = postprocess(prediction_patch, pad_shape, shape, crop_idx, n_dims, label_list,
                                          keep_biggest_component, aff, aff_ref=aff_ref,
                                          keep_biggest_of_each_group=keep_biggest_component,
                                          n_neutral_labels=n_neutral_labels)

            # write results to disk
            if path_segmentation is not None:
                utils.save_volume(seg.astype('int'), aff, h, path_segmentation)
            if path_posterior is not None:
                if n_channels > 1:
                    posteriors = utils.add_axis(posteriors, axis=[0, -1])
                utils.save_volume(posteriors.astype('float'), aff, h, path_posterior)

        else:
            if path_volumes is not None:
                posteriors, _, _, _, _, _, im_res = utils.get_volume_info(path_posterior, True, aff_ref=np.eye(4))
            else:
                posteriors = im_res = None

        # compute volumes
        if path_volumes is not None:
            volumes = np.sum(posteriors[..., 1:], axis=tuple(range(0, len(posteriors.shape) - 1)))
            volumes = np.around(volumes * np.prod(im_res), 3)
            row = [os.path.basename(path_image).replace('.nii.gz', '')] + [str(vol) for vol in volumes]
            with open(path_volumes, 'a') as csvFile:
                writer = csv.writer(csvFile)
                writer.writerow(row)

    # evaluate
    if gt_folder is not None:

        # find path evaluation folder
        path_first_result = path_segmentations[0] if (path_segmentations[0] is not None) else path_posteriors[0]
        eval_folder = os.path.dirname(path_first_result)

        # compute evaluation metrics
        evaluate.dice_evaluation(gt_folder,
                                 eval_folder,
                                 evaluation_label_list,
                                 compute_distances=compute_distances,
                                 compute_score_whole_structure=False,
                                 path_dice=os.path.join(eval_folder, 'dice.npy'),
                                 path_hausdorff=os.path.join(eval_folder, 'hausdorff.npy'),
                                 path_mean_distance=os.path.join(eval_folder, 'mean_distance.npy'),
                                 recompute=recompute,
                                 verbose=verbose)
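A short numpy illustration of the soft-volume computation used for the csv file above: each structure's volume is the sum of its posterior probability map times the voxel volume. Values here are synthetic.

import numpy as np

posteriors = np.random.dirichlet(np.ones(4), size=(16, 16, 16))  # (16, 16, 16, 4)
im_res = np.array([1., 1., 1.])                         # voxel size in mm
volumes = np.sum(posteriors[..., 1:], axis=(0, 1, 2))   # skip background channel 0
volumes = np.around(volumes * np.prod(im_res), 3)       # soft volumes in mm^3
print(volumes)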
Example 7
def preprocess_image(im_path, n_levels, crop_shape=None, padding=None):

    # read image and corresponding info
    im, shape, aff, n_dims, n_channels, header, im_res = utils.get_volume_info(
        im_path, True, np.eye(4))

    if padding:
        im = edit_volumes.pad_volume(im, padding_shape=padding)
        pad_shape = im.shape[:n_dims]
    else:
        pad_shape = shape

    # make sure crop_shape (or pad_shape when no cropping is given) is divisible by 2**n_levels
    if crop_shape is not None:
        crop_shape = utils.reformat_to_list(crop_shape,
                                            length=n_dims,
                                            dtype='int')
        if not all(
            [pad_shape[i] >= crop_shape[i] for i in range(len(pad_shape))]):
            crop_shape = [
                min(pad_shape[i], crop_shape[i]) for i in range(n_dims)
            ]
        if not all([size % (2**n_levels) == 0 for size in crop_shape]):
            crop_shape = [
                utils.find_closest_number_divisible_by_m(size, 2**n_levels)
                for size in crop_shape
            ]
    else:
        if not all([size % (2**n_levels) == 0 for size in pad_shape]):
            crop_shape = [
                utils.find_closest_number_divisible_by_m(size, 2**n_levels)
                for size in pad_shape
            ]

    # crop image if necessary
    if crop_shape is not None:
        im, crop_idx = edit_volumes.crop_volume(im,
                                                cropping_shape=crop_shape,
                                                return_crop_idx=True)
    else:
        crop_idx = None

    # normalise image
    if n_channels == 1:
        m = np.min(im)
        M = np.max(im)
        if M == m:
            im = np.zeros(im.shape)
        else:
            im = (im - m) / (M - m)
    else:
        for i in range(im.shape[-1]):
            channel = im[..., i]
            m = np.min(channel)
            M = np.max(channel)
            if M == m:
                im[..., i] = np.zeros(channel.shape)
            else:
                im[..., i] = (channel - m) / (M - m)

    # add batch and channel axes
    im = utils.add_axis(im) if n_channels > 1 else utils.add_axis(im,
                                                                  axis=[0, -1])

    return im, aff, header, im_res, n_channels, n_dims, shape, pad_shape, crop_idx
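Why the divisibility constraint exists: each U-Net level halves the spatial axes, so an axis must survive n_levels-1 stride-2 poolings without leaving the integers. A plain-Python check, assuming the usual stride-2 pooling between levels:

n_levels = 5
size = 192  # an axis length already rounded to a multiple of 2**n_levels
for level in range(n_levels - 1):
    assert size % 2 == 0, 'axis no longer even at level %d' % level
    size //= 2  # stride-2 pooling between consecutive levels
print(size)  # 12: still integral at the coarsest level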
Example 8
def preprocess_image(im_path, n_levels, crop_shape=None, padding=None):

    # read image and corresponding info
    im, shape, aff, n_dims, n_channels, header, labels_res = utils.get_volume_info(
        im_path, return_volume=True)

    if padding:
        if n_channels == 1:
            im = np.pad(im, padding, mode='constant')
            pad_shape = im.shape
        else:
            im = np.pad(im,
                        tuple([(padding, padding)] * n_dims + [(0, 0)]),
                        mode='constant')
            pad_shape = im.shape[:-1]
    else:
        pad_shape = shape

    # make sure crop_shape (or pad_shape when no cropping is given) is divisible by 2**n_levels
    if crop_shape is not None:
        crop_shape = utils.reformat_to_list(crop_shape,
                                            length=n_dims,
                                            dtype='int')
        if not all(
            [pad_shape[i] >= crop_shape[i] for i in range(len(pad_shape))]):
            crop_shape = [
                min(pad_shape[i], crop_shape[i]) for i in range(n_dims)
            ]
            print(
                'cropping dimensions are larger than image size, changing cropping size to {}'
                .format(crop_shape))
        if not all([size % (2**n_levels) == 0 for size in crop_shape]):
            crop_shape = [
                utils.find_closest_number_divisible_by_m(size, 2**n_levels)
                for size in crop_shape
            ]
    else:
        if not all([size % (2**n_levels) == 0 for size in pad_shape]):
            crop_shape = [
                utils.find_closest_number_divisible_by_m(size, 2**n_levels)
                for size in pad_shape
            ]

    # crop image if necessary
    if crop_shape is not None:
        crop_idx = np.round(
            (pad_shape - np.array(crop_shape)) / 2).astype('int')
        crop_idx = np.concatenate((crop_idx, crop_idx + crop_shape), axis=0)
        im = edit_volumes.crop_volume_with_idx(im, crop_idx=crop_idx)
    else:
        crop_idx = None

    # align image
    # ref_axes = np.array([0, 2, 1])
    # ref_signs = np.array([-1, 1, -1])
    # im_axes, img_signs = utils.get_ras_axis_and_signs(aff, n_dims=n_dims)
    # im = edit_volume.align_volume_to_ref(im, ref_axes, ref_signs, im_axes, img_signs)

    # normalise image
    m = np.min(im)
    M = np.max(im)
    if M == m:
        im = np.zeros(im.shape)
    else:
        im = (im - m) / (M - m)

    # add batch and channel axes
    if n_channels > 1:
        im = utils.add_axis(im)
    else:
        im = utils.add_axis(im, -2)

    return im, aff, header, n_channels, n_dims, shape, pad_shape, crop_shape, crop_idx
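The centre-crop indices built above pack the lower and upper bounds into a single vector of length 2*n_dims; a standalone numpy demo of the same arithmetic:

import numpy as np

pad_shape = np.array([200, 200, 200])
crop_shape = np.array([160, 160, 160])
crop_idx = np.round((pad_shape - crop_shape) / 2).astype('int')  # [20, 20, 20]
crop_idx = np.concatenate((crop_idx, crop_idx + crop_shape))     # lower then upper bounds
vol = np.random.rand(*pad_shape)
cropped = vol[crop_idx[0]:crop_idx[3], crop_idx[1]:crop_idx[4], crop_idx[2]:crop_idx[5]]
print(cropped.shape)  # (160, 160, 160)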
Example 9
def training(image_dir,
             labels_dir,
             model_dir,
             path_label_list=None,
             save_label_list=None,
             n_neutral_labels=1,
             batchsize=1,
             target_res=None,
             output_shape=None,
             flipping=True,
             flip_rl_only=False,
             scaling_bounds=0.15,
             rotation_bounds=15,
             enable_90_rotations=False,
             shearing_bounds=.012,
             translation_bounds=False,
             nonlin_std=3.,
             nonlin_shape_factor=.04,
             bias_field_std=.3,
             bias_shape_factor=.025,
             same_bias_for_all_channels=False,
             augment_intensities=True,
             noise_std=1.,
             augment_channels_separately=True,
             n_levels=5,
             nb_conv_per_level=2,
             conv_size=3,
             unet_feat_count=24,
             feat_multiplier=1,
             dropout=0,
             activation='elu',
             lr=1e-4,
             lr_decay=0,
             wl2_epochs=5,
             dice_epochs=200,
             steps_per_epoch=1000,
             checkpoint=None):
    """
    This function trains a neural network with aggressively augmented images. The model is implemented on the GPU and
    contains three sub-models: one for augmentation, one neural network (UNet), and one for computing the loss function.
    The network is pre-trained with a weighted sum of squared errors, in order to bring the weights into a favorable
    optimisation landscape. The training then continues with a soft dice loss function.

    :param image_dir: path of folder with all input images, or to a single image (if only one training example)
    :param labels_dir: path of folder with all input label maps, or to a single label map (if only one training example)
    Label maps and images are matched by sorting order.
    :param model_dir: path of a directory where the models will be saved during training.

    #---------------------------------------------- Generation parameters ----------------------------------------------
    # output-related parameters
    :param path_label_list: (optional) path to a numpy array containing all the label values to be segmented.
    By default, this is computed by taking all the label values in the training label maps.
    :param save_label_list: (optional) path where to write the computed list of segmentation labels.
    :param n_neutral_labels: (optional) number of non-sided labels in label_list. This is used for determining which
    label values to swap when right/left flipping the training examples. Default is 1 (to account for the background).
    :param batchsize: (optional) number of images per mini-batch. Default is 1.
    :param target_res: (optional) target resolution at which to produce the segmentation label maps. The training data
    will be resampled to this resolution before being run through the network. If None, no resampling is performed.
    Can be a number (isotropic resolution), or the path to a 1d numpy array.
    :param output_shape: (optional) desired shape of the output image, obtained by randomly cropping the generated image.
    Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or the path to a 1d numpy array.

    # Augmentation parameters
    :param flipping: (optional) whether to introduce random flipping. Default is True.
    :param flip_rl_only: (optional) if flipping is True, whether to flip only in the right/left axis. Default is False.
    :param scaling_bounds: (optional) range of the random scaling applied at each mini-batch. The scaling factor for
    each dimension is sampled from a uniform distribution of predefined bounds. scaling_bounds can either be:
    1) a number, in which case the scaling factor is independently sampled from the uniform distribution of bounds
    (1-scaling_bounds, 1+scaling_bounds) for each dimension.
    2) the path to a numpy array of shape (2, n_dims), in which case the scaling factor is sampled from the uniform
    distribution of bounds (scaling_bounds[0, i], scaling_bounds[1, i]) for the i-th dimension.
    3) False, in which case scaling is completely turned off.
    Default is scaling_bounds = 0.15
    :param rotation_bounds: (optional) same as scaling_bounds but for the rotation angle, except that for case 1 the
    bounds are centred on 0 rather than 1, i.e. (0-rotation_bounds[i], 0+rotation_bounds[i]).
    Default is rotation_bounds = 15.
    :param enable_90_rotations: (optional) whether to rotate the input by a random angle chosen in {0, 90, 180, 270}.
    This is done regardless of the value of rotation_bounds. If True, a different value is sampled for each dimension.
    :param shearing_bounds: (optional) same as scaling bounds. Default is shearing_bounds = 0.012.
    :param translation_bounds: (optional) same as scaling bounds. Default is translation_bounds = False, but we
    encourage using it when cropping is deactivated (i.e. when output_shape=None in BrainGenerator).
    :param nonlin_std: (optional) maximum value for the standard deviation of the normal distribution from which we
    sample the first tensor for synthesising the deformation field. Set to 0 to completely turn it off.
    :param nonlin_shape_factor: (optional) ratio between the size of the input label maps and the size of the sampled
    tensor for synthesising the elastic deformation field.
    :param bias_field_std: (optional) If strictly positive, this triggers the corruption of images with a bias field.
    It is obtained by sampling a first small tensor from a normal distribution, resizing it to full size, and rescaling
    it to positive values by taking the voxel-wise exponential. bias_field_std designates the std dev of the normal
    distribution from which we sample the first tensor. Set to 0 to deactivate the bias field.
    :param bias_shape_factor: (optional) If bias_field_std is not False, this designates the ratio between the size
    of the input label maps and the size of the first sampled tensor for synthesising the bias field.
    :param same_bias_for_all_channels: (optional) If bias_field_std is not False, whether to apply the same bias field
    to all channels or not.
    :param augment_intensitites: (optional) whether to augment the intensities of the images with gamma augmentation.
    :param noise_std: (optional) if augment_intensitites is True, maximum value for the standard deviation of the normal
    distribution from which we sample a Gaussian white noise. Set to False to deactivate white noise augmentation.
    Default value is 1.
    :param augment_channels_separately: (optional) whether to augment the intensities of each channel independently.
    Only applied if augment_intensitites is True, and the training images have several channels. Default is True.

    # ------------------------------------------ UNet architecture parameters ------------------------------------------
    :param n_levels: (optional) number of levels for the UNet. Default is 5.
    :param nb_conv_per_level: (optional) number of convolutional layers per level. Default is 2.
    :param conv_size: (optional) size of the convolution kernels. Default is 2.
    :param unet_feat_count: (optional) number of features for the first layer of the UNet. Default is 24.
    :param feat_multiplier: (optional) multiply the number of features by this number at each new level. Default is 1.
    :param dropout: (optional) probability of dropout for the UNet. Default is 0, where no dropout is applied.
    :param activation: (optional) activation function. Can be 'elu', 'relu'.

    # ----------------------------------------------- Training parameters ----------------------------------------------
    :param lr: (optional) learning rate for the training. Default is 1e-4.
    :param lr_decay: (optional) learning rate decay. Default is 0, where no decay is applied.
    :param wl2_epochs: (optional) number of epochs for which the network (except the soft-max layer) is trained with the
    L2-norm loss function. Default is 5.
    :param dice_epochs: (optional) number of epochs with the soft Dice loss function. Default is 100.
    :param steps_per_epoch: (optional) number of steps per epoch. Default is 1000. Since no online validation is
    possible, this is equivalent to the frequency at which the models are saved.
    :param checkpoint: (optional) path of an already saved model to load before starting the training.
    """

    # check epochs
    assert (wl2_epochs > 0) | (dice_epochs > 0), \
        'either wl2_epochs or dice_epochs must be positive, had {0} and {1}'.format(wl2_epochs, dice_epochs)

    # prepare data files
    path_images = utils.list_images_in_folder(image_dir)
    path_label_maps = utils.list_images_in_folder(labels_dir)
    assert len(path_images) == len(path_label_maps), 'not the same number of training images and label maps.'

    # read info from image and get label list
    im_shape, _, _, n_channels, _, image_res = utils.get_volume_info(path_images[0], aff_ref=np.eye(4))
    label_list, _ = utils.get_list_labels(path_label_list, labels_dir=labels_dir, save_label_list=save_label_list)
    n_labels = np.size(label_list)

    # prepare model folder
    utils.mkdir(model_dir)

    # transformation model
    augmentation_model = build_augmentation_model(im_shape=im_shape,
                                                  n_channels=n_channels,
                                                  label_list=label_list,
                                                  n_neutral_labels=n_neutral_labels,
                                                  image_res=image_res,
                                                  target_res=target_res,
                                                  output_shape=output_shape,
                                                  output_div_by_n=2 ** n_levels,
                                                  flipping=flipping,
                                                  flip_rl_only=flip_rl_only,
                                                  aff=np.eye(4),
                                                  scaling_bounds=scaling_bounds,
                                                  rotation_bounds=rotation_bounds,
                                                  enable_90_rotations=enable_90_rotations,
                                                  shearing_bounds=shearing_bounds,
                                                  translation_bounds=translation_bounds,
                                                  nonlin_std=nonlin_std,
                                                  nonlin_shape_factor=nonlin_shape_factor,
                                                  bias_field_std=bias_field_std,
                                                  bias_shape_factor=bias_shape_factor,
                                                  same_bias_for_all_channels=same_bias_for_all_channels,
                                                  apply_intensity_augmentation=augment_intensitites,
                                                  noise_std=noise_std,
                                                  augment_channels_separately=augment_channels_separately)
    unet_input_shape = augmentation_model.output[0].get_shape().as_list()[1:]

    # prepare the segmentation model
    unet_model = nrn_models.unet(nb_features=unet_feat_count,
                                 input_shape=unet_input_shape,
                                 nb_levels=n_levels,
                                 conv_size=conv_size,
                                 nb_labels=n_labels,
                                 feat_mult=feat_multiplier,
                                 nb_conv_per_level=nb_conv_per_level,
                                 conv_dropout=dropout,
                                 batch_norm=-1,
                                 activation=activation,
                                 input_model=augmentation_model)

    # input generator
    train_example_gen = image_seg_generator(path_images=path_images,
                                            path_labels=path_label_maps,
                                            batchsize=batchsize,
                                            n_channels=n_channels)
    input_generator = utils.build_training_generator(train_example_gen, batchsize)

    # pre-training with weighted L2, input is fit to the softmax rather than the probabilities
    if wl2_epochs > 0:
        wl2_model = models.Model(unet_model.inputs, [unet_model.get_layer('unet_likelihood').output])
        wl2_model = metrics.metrics_model(input_model=wl2_model, metrics='wl2')
        train_model(wl2_model, input_generator, lr, lr_decay, wl2_epochs, steps_per_epoch, model_dir, 'wl2', checkpoint)
        checkpoint = os.path.join(model_dir, 'wl2_%03d.h5' % wl2_epochs)

    # fine-tuning with dice metric
    dice_model = metrics.metrics_model(input_model=unet_model, metrics='dice')
    train_model(dice_model, input_generator, lr, lr_decay, dice_epochs, steps_per_epoch, model_dir, 'dice', checkpoint)
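
The two-stage schedule above first fits the network's pre-softmax output with a weighted L2 loss, then fine-tunes with a soft Dice loss built by metrics.metrics_model. For intuition only, here is a minimal numpy sketch of one common soft-Dice formulation; the actual loss used by the library may differ in details such as label weighting:

import numpy as np

def soft_dice_loss(probs, one_hot, eps=1e-6):
    # probs and one_hot have shape (batch, *spatial, n_labels)
    spatial_axes = tuple(range(1, probs.ndim - 1))
    intersection = np.sum(probs * one_hot, axis=spatial_axes)
    denominator = np.sum(probs + one_hot, axis=spatial_axes)
    dice_per_label = (2. * intersection + eps) / (denominator + eps)
    return 1. - np.mean(dice_per_label)  # average over batch and labels
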
Esempio n. 10
def training(image_dir,
             labels_dir,
             cropping=None,
             flipping=True,
             scaling_range=0.07,
             rotation_range=10,
             shearing_range=0.01,
             nonlin_std_dev=3,
             nonlin_shape_fact=0.04,
             crop_channel_2=None,
             conv_size=3,
             n_levels=5,
             nb_conv_per_level=2,
             feat_multiplier=2,
             dropout=0,
             unet_feat_count=24,
             no_batch_norm=False,
             lr=1e-4,
             lr_decay=0,
             batch_size=1,
             wl2_epochs=50,
             dice_epochs=500,
             steps_per_epoch=100,
             background_weight=1e-4,
             include_background=False,
             load_model_file=None,
             initial_epoch_wl2=0,
             initial_epoch_dice=0,
             path_label_list=None,
             model_dir=None):

    # check epochs
    assert (wl2_epochs > 0) | (dice_epochs > 0), \
        'either wl2_epochs or dice_epochs must be positive, had {0} and {1}'.format(wl2_epochs, dice_epochs)

    # prepare data files
    image_paths = utils.list_images_in_folder(image_dir)
    labels_paths = utils.list_images_in_folder(labels_dir)
    assert len(image_paths) == len(labels_paths), "There should be as many images as label maps."

    # get label and classes lists
    rotation_range = utils.load_array_if_path(rotation_range)
    scaling_range = utils.load_array_if_path(scaling_range)
    crop_channel_2 = utils.load_array_if_path(crop_channel_2)
    label_list, n_neutral_labels = utils.get_list_labels(label_list=path_label_list, FS_sort=True)
    n_labels = np.size(label_list)

    # prepare model folder
    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    # prepare log folder
    log_dir = os.path.join(model_dir, 'logs')
    if not os.path.isdir(log_dir):
        os.mkdir(log_dir)

    # create augmentation model and input generator
    im_shape, aff, _, n_channels, _, _ = utils.get_volume_info(image_paths[0])
    augmentation_model, unet_input_shape = labels_to_image_model(im_shape=im_shape,
                                                                 n_channels=n_channels,
                                                                 crop_shape=cropping,
                                                                 label_list=label_list,
                                                                 n_neutral_labels=n_neutral_labels,
                                                                 vox2ras=aff,
                                                                 nonlin_shape_factor=nonlin_shape_fact,
                                                                 crop_channel2=crop_channel_2,
                                                                 output_div_by_n=2 ** n_levels,
                                                                 flipping=flipping)

    model_input_generator = build_model_input_generator(images_paths=image_paths,
                                                        labels_paths=labels_paths,
                                                        n_channels=n_channels,
                                                        im_shape=im_shape,
                                                        scaling_range=scaling_range,
                                                        rotation_range=rotation_range,
                                                        shearing_range=shearing_range,
                                                        nonlin_shape_fact=nonlin_shape_fact,
                                                        nonlin_std_dev=nonlin_std_dev,
                                                        batch_size=batch_size)
    training_generator = utils.build_training_generator(model_input_generator, batch_size)

    # prepare the segmentation model
    if no_batch_norm:
        batch_norm_dim = None
    else:
        batch_norm_dim = -1
    unet_model = nrn_models.unet(nb_features=unet_feat_count,
                                 input_shape=unet_input_shape,
                                 nb_levels=n_levels,
                                 conv_size=conv_size,
                                 nb_labels=n_labels,
                                 feat_mult=feat_multiplier,
                                 dilation_rate_mult=1,
                                 nb_conv_per_level=nb_conv_per_level,
                                 conv_dropout=dropout,
                                 batch_norm=batch_norm_dim,
                                 input_model=augmentation_model)

    # pre-training with weighted L2, input is fit to the softmax rather than the probabilities
    if wl2_epochs > 0:
        wl2_model = Model(unet_model.inputs, [unet_model.get_layer('unet_likelihood').output])
        wl2_model = metrics_model.metrics_model(input_shape=unet_input_shape[:-1] + [n_labels],
                                                segmentation_label_list=label_list,
                                                input_model=wl2_model,
                                                metrics='weighted_l2',
                                                weight_background=background_weight,
                                                name='metrics_model')
        if load_model_file is not None:
            print('loading', load_model_file)
            wl2_model.load_weights(load_model_file)
        train_model(wl2_model, training_generator, lr, lr_decay, wl2_epochs, steps_per_epoch, model_dir, log_dir,
                    'wl2', initial_epoch_wl2)

    # fine-tuning with dice metric
    if dice_epochs > 0:
        dice_model = metrics_model.metrics_model(input_shape=unet_input_shape[:-1] + [n_labels],
                                                 segmentation_label_list=label_list,
                                                 input_model=unet_model,
                                                 include_background=include_background,
                                                 name='metrics_model')
        if wl2_epochs > 0:
            last_wl2_model_name = os.path.join(model_dir, 'wl2_%03d.h5' % wl2_epochs)
            dice_model.load_weights(last_wl2_model_name, by_name=True)
        elif load_model_file is not None:
            print('loading', load_model_file)
            dice_model.load_weights(load_model_file)
        train_model(dice_model, training_generator, lr, lr_decay, dice_epochs, steps_per_epoch, model_dir, log_dir,
                    'dice', initial_epoch_dice)
Esempio n. 11
def preprocess_image(im_path,
                     n_levels,
                     target_res,
                     crop=None,
                     padding=None,
                     flip=False,
                     path_resample=None):

    # read image and corresponding info
    im, _, aff, n_dims, n_channels, header, im_res = utils.get_volume_info(
        im_path, True)

    # resample image if necessary
    if target_res is not None:
        target_res = np.squeeze(
            utils.reformat_to_n_channels_array(target_res, n_dims))
        if np.any((im_res > target_res + 0.05) | (im_res < target_res - 0.05)):
            im_res = target_res
            im, aff = edit_volumes.resample_volume(im, aff, im_res)
            if path_resample is not None:
                utils.save_volume(im, aff, header, path_resample)

    # align image
    im = edit_volumes.align_volume_to_ref(im,
                                          aff,
                                          aff_ref=np.eye(4),
                                          n_dims=n_dims)
    shape = list(im.shape)

    # pad image if specified
    if padding:
        im = edit_volumes.pad_volume(im, padding_shape=padding)
        pad_shape = im.shape[:n_dims]
    else:
        pad_shape = shape

    # check that patch_shape or im_shape are divisible by 2**n_levels
    if crop is not None:
        crop = utils.reformat_to_list(crop, length=n_dims, dtype='int')
        if not all([pad_shape[i] >= crop[i] for i in range(len(pad_shape))]):
            crop = [min(pad_shape[i], crop[i]) for i in range(n_dims)]
        if not all([size % (2**n_levels) == 0 for size in crop]):
            crop = [
                utils.find_closest_number_divisible_by_m(size, 2**n_levels)
                for size in crop
            ]
    else:
        if not all([size % (2**n_levels) == 0 for size in pad_shape]):
            crop = [
                utils.find_closest_number_divisible_by_m(size, 2**n_levels)
                for size in pad_shape
            ]

    # crop image if necessary
    if crop is not None:
        im, crop_idx = edit_volumes.crop_volume(im,
                                                cropping_shape=crop,
                                                return_crop_idx=True)
    else:
        crop_idx = None

    # normalise image
    if n_channels == 1:
        im = edit_volumes.rescale_volume(im,
                                         new_min=0.,
                                         new_max=1.,
                                         min_percentile=0.5,
                                         max_percentile=99.5)
    else:
        for i in range(im.shape[-1]):
            im[..., i] = edit_volumes.rescale_volume(im[..., i],
                                                     new_min=0.,
                                                     new_max=1.,
                                                     min_percentile=0.5,
                                                     max_percentile=99.5)

    # flip image along right/left axis
    if flip & (n_dims > 2):
        im_flipped = edit_volumes.flip_volume(im,
                                              direction='rl',
                                              aff=np.eye(4))
        im_flipped = utils.add_axis(
            im_flipped) if n_channels > 1 else utils.add_axis(im_flipped,
                                                              axis=[0, -1])
    else:
        im_flipped = None

    # add batch and channel axes
    im = utils.add_axis(im) if n_channels > 1 else utils.add_axis(im,
                                                                  axis=[0, -1])

    return im, aff, header, im_res, n_channels, n_dims, shape, pad_shape, crop_idx, im_flipped
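
The divisibility checks above exist because each of the n_levels resolution levels of the UNet halves the spatial dimensions, so input shapes must be multiples of 2**n_levels for the decoder to recover the input size. A small sketch, using a hypothetical stand-in for utils.find_closest_number_divisible_by_m (the real helper may round differently):

import numpy as np

n_levels = 5  # five resolution levels, so shapes must be divisible by 2**5 = 32

def closest_divisible_by_m(size, m):
    # hypothetical stand-in: round to the nearest multiple of m
    return int(np.round(size / m) * m)

print(closest_divisible_by_m(150, 2 ** n_levels))  # 160
print(closest_divisible_by_m(96, 2 ** n_levels))   # 96, already divisible
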
Esempio n. 12
def predict(path_images,
            path_segmentations,
            path_model,
            segmentation_labels,
            n_neutral_labels=None,
            path_posteriors=None,
            path_resampled=None,
            path_volumes=None,
            segmentation_label_names=None,
            padding=None,
            cropping=None,
            target_res=1.,
            gradients=False,
            flip=True,
            topology_classes=None,
            sigma_smoothing=0.5,
            keep_biggest_component=True,
            conv_size=3,
            n_levels=5,
            nb_conv_per_level=2,
            unet_feat_count=24,
            feat_multiplier=2,
            activation='elu',
            gt_folder=None,
            evaluation_labels=None,
            mask_folder=None,
            list_incorrect_labels=None,
            list_correct_labels=None,
            compute_distances=False,
            recompute=True,
            verbose=True):
    """
    This function uses trained models to segment images.
    It is crucial that the inputs match the architecture parameters of the trained model.
    :param path_images: path of the images to segment. Can be the path to a directory or the path to a single image.
    :param path_segmentations: path where segmentations will be written.
    Should be a dir if path_images is a dir, and a file if path_images is a file.
    :param path_model: path of the trained model.
    :param segmentation_labels: List of labels for which to compute Dice scores. It should be the same list as the
    segmentation_labels used in training.
    :param n_neutral_labels: (optional) if the label maps contain some right/left specific labels and if test-time
    flipping is applied (see parameter 'flip'), please provide the number of non-sided labels (including background).
    It should be the same value as for training. Default is None.
    :param path_posteriors: (optional) path where posteriors will be written.
    Should be a dir if path_images is a dir, and a file if path_images is a file.
    :param path_resampled: (optional) path where images resampled to 1mm isotropic will be written.
    We emphasise that images are resampled as soon as the resolution in one of the axes is not in the range [0.9; 1.1].
    Should be a dir if path_images is a dir, and a file if path_images is a file. Default is None, where resampled
    images are not saved.
    :param path_volumes: (optional) path of a csv file where the soft volumes of all segmented regions will be written.
    The rows of the csv file correspond to subjects, and the columns correspond to segmentation labels.
    The soft volume of a structure corresponds to the sum of its predicted probability map.
    :param segmentation_label_names: (optional) List of names corresponding to the segmentation labels.
    Only used when path_volumes is provided. Must be of the same size as segmentation_labels. Can be given as a
    list, a numpy array of strings, or the path to such a numpy array. Default is None.
    :param padding: (optional) pad the images to the specified shape before predicting the segmentation maps.
    Can be an int, a sequence or a 1d numpy array.
    :param cropping: (optional) crop the images to the specified shape before predicting the segmentation maps.
    If padding and cropping are specified, images are padded before being cropped.
    Can be an int, a sequence or a 1d numpy array.
    :param target_res: (optional) target resolution at which the network operates (and thus the resolution of the
    output segmentations). This must match the resolution of the training data! target_res is used to automatically
    resample the images with resolutions outside [target_res-0.05, target_res+0.05].
    Can be a number, a sequence, or a 1d numpy array. Set to None to disable the automatic resampling. Default is 1mm.
    :param flip: (optional) whether to perform test-time augmentation, where the input image is segmented along with
    a right/left flipped version of it. If set to True (default), be careful because this requires more memory.
    :param topology_classes: (optional) List of classes corresponding to all segmentation labels, in order to group
    them into classes, for each of which we apply a smoothed version of the biggest-connected-component operation.
    Can be a sequence, a 1d numpy array, or the path to a numpy 1d array in the same order as segmentation_labels.
    Default is None, where no topological analysis is performed.
    :param sigma_smoothing: (optional) If not None, the posteriors are smoothed with a gaussian kernel of the specified
    standard deviation.
    :param keep_biggest_component: (optional) whether to only keep the biggest component in the predicted segmentation.
    This is applied independently of topology_classes, and it is applied to the whole segmentation.
    :param conv_size: (optional) size of the UNet's convolution kernels. Default is 3.
    :param n_levels: (optional) number of levels for the UNet. Default is 5.
    :param nb_conv_per_level: (optional) number of convolution layers per level. Default is 2.
    :param unet_feat_count: (optional) number of features for the first layer of the UNet. Default is 24.
    :param feat_multiplier: (optional) multiplicative factor for the number of features at each new level. Default is 2.
    :param activation: (optional) activation function. Can be 'elu', 'relu'.
    :param gt_folder: (optional) path of the ground truth label maps corresponding to the input images. Should be a dir
    if path_images is a dir, or a file if path_images is a file.
    Providing a gt_folder will trigger a Dice evaluation, where scores will be written along with the path_segmentations.
    Specifically, the scores are contained in a numpy array, where labels are in rows, and subjects in columns.
    :param evaluation_labels: (optional) if gt_folder is provided, you can evaluate the Dice scores on a subset of the
    segmentation labels, by providing another label list here. Can be a sequence, a 1d numpy array, or the path to a
    numpy 1d array. Default is np.unique(segmentation_labels).
    :param mask_folder: (optional) path of masks that will be used to mask out some parts of the obtained segmentations
    during the evaluation. Default is None, where nothing is masked.
    :param list_incorrect_labels: (optional) this option allows replacing some label values in the obtained
    segmentations with other label values. Can be a list, a 1d numpy array, or the path to such an array.
    :param list_correct_labels: (optional) list of values to correct the labels specified in list_incorrect_labels.
    Correct values must have the same order as their corresponding value in list_incorrect_labels.
    :param compute_distances: (optional) whether to add Hausdorff and mean surface distance evaluations to the default
    Dice evaluation. Default is False.
    :param recompute: (optional) whether to recompute segmentations that were already computed. This also applies to
    Dice scores, if gt_folder is not None. Default is True.
    :param verbose: (optional) whether to print out info about the remaining number of cases.
    """

    # prepare input/output filepaths
    path_images, path_segmentations, path_posteriors, path_resampled, path_volumes, compute = \
        prepare_output_files(path_images, path_segmentations, path_posteriors, path_resampled, path_volumes, recompute)

    # get label list
    segmentation_labels, _ = utils.get_list_labels(
        label_list=segmentation_labels)
    n_labels = len(segmentation_labels)

    # get unique label values, and build a correspondence table between contralateral structures if necessary
    if (n_neutral_labels is not None) & flip:
        n_sided_labels = int((n_labels - n_neutral_labels) / 2)
        lr_corresp = np.stack([
            segmentation_labels[n_neutral_labels:n_neutral_labels +
                                n_sided_labels],
            segmentation_labels[n_neutral_labels + n_sided_labels:]
        ])
        segmentation_labels, indices = np.unique(segmentation_labels,
                                                 return_index=True)
        lr_corresp_unique, lr_corresp_indices = np.unique(lr_corresp[0, :],
                                                          return_index=True)
        lr_corresp_unique = np.stack(
            [lr_corresp_unique, lr_corresp[1, lr_corresp_indices]])
        lr_corresp_unique = lr_corresp_unique[:, 1:] if not np.all(
            lr_corresp_unique[:, 0]) else lr_corresp_unique
        lr_indices = np.zeros_like(lr_corresp_unique)
        for i in range(lr_corresp_unique.shape[0]):
            for j, lab in enumerate(lr_corresp_unique[i]):
                lr_indices[i, j] = np.where(segmentation_labels == lab)[0]
    else:
        segmentation_labels, indices = np.unique(segmentation_labels,
                                                 return_index=True)
        lr_indices = None

    # prepare topology classes
    if topology_classes is not None:
        topology_classes = utils.load_array_if_path(
            topology_classes, load_as_numpy=True)[indices]

    # prepare volume file if needed
    if path_volumes is not None:
        if segmentation_label_names is not None:
            segmentation_label_names = utils.load_array_if_path(
                segmentation_label_names)[indices]
            csv_header = [[''] + segmentation_label_names[1:].tolist()]
            csv_header += [[''] +
                           [str(lab) for lab in segmentation_labels[1:]]]
        else:
            csv_header = [['subjects'] +
                          [str(lab) for lab in segmentation_labels[1:]]]
        with open(path_volumes, 'w') as csvFile:
            writer = csv.writer(csvFile)
            writer.writerows(csv_header)

    # build network
    _, _, n_dims, n_channels, _, _ = utils.get_volume_info(path_images[0])
    model_input_shape = [None] * n_dims + [n_channels]
    net = build_model(path_model, model_input_shape, n_levels,
                      len(segmentation_labels), conv_size, nb_conv_per_level,
                      unet_feat_count, feat_multiplier, activation,
                      sigma_smoothing, gradients)

    # perform segmentation
    loop_info = utils.LoopInfo(len(path_images), 10, 'predicting', True)
    for idx, (path_image, path_segmentation, path_posterior, path_resample, tmp_compute) in \
            enumerate(zip(path_images, path_segmentations, path_posteriors, path_resampled, compute)):

        # compute segmentation only if needed
        if tmp_compute:
            if verbose:
                loop_info.update(idx)

            # preprocessing
            image, aff, h, im_res, _, _, shape, pad_shape, crop_idx, im_flipped = \
                preprocess_image(path_image, n_levels, target_res, cropping, padding, flip, path_resample)

            # prediction
            prediction_patch = net.predict(image)
            prediction_patch_flip = net.predict(im_flipped) if flip else None

            # postprocessing
            seg, posteriors = postprocess(
                prediction_patch,
                pad_shape,
                shape,
                crop_idx,
                n_dims,
                segmentation_labels,
                lr_indices,
                keep_biggest_component,
                aff,
                topology_classes=topology_classes,
                post_patch_flip=prediction_patch_flip)

            # write results to disk
            if path_segmentation is not None:
                utils.save_volume(seg,
                                  aff,
                                  h,
                                  path_segmentation,
                                  dtype='int32')
            if path_posterior is not None:
                if n_channels > 1:
                    posteriors = utils.add_axis(posteriors, axis=[0, -1])
                utils.save_volume(posteriors,
                                  aff,
                                  h,
                                  path_posterior,
                                  dtype='float32')

        else:
            if path_volumes is not None:
                posteriors, _, _, _, _, _, im_res = utils.get_volume_info(
                    path_posterior, True, aff_ref=np.eye(4))
            else:
                posteriors = im_res = None

        # compute volumes
        if path_volumes is not None:
            volumes = np.sum(posteriors[..., 1:],
                             axis=tuple(range(0,
                                              len(posteriors.shape) - 1)))
            volumes = np.around(volumes * np.prod(im_res), 3)
            row = [os.path.basename(path_image).replace('.nii.gz', '')
                   ] + [str(vol) for vol in volumes]
            with open(path_volumes, 'a') as csvFile:
                writer = csv.writer(csvFile)
                writer.writerow(row)

    # evaluate
    if gt_folder is not None:

        # find the folder where segmentations were saved, and get the labels on which to evaluate
        eval_folder = os.path.dirname(path_segmentations[0])
        if evaluation_labels is None:
            evaluation_labels = segmentation_labels

        # set path of result arrays for surface distance if necessary
        if compute_distances:
            path_hausdorff = os.path.join(eval_folder, 'hausdorff.npy')
            path_hausdorff_99 = os.path.join(eval_folder, 'hausdorff_99.npy')
            path_hausdorff_95 = os.path.join(eval_folder, 'hausdorff_95.npy')
            path_mean_distance = os.path.join(eval_folder, 'mean_distance.npy')
        else:
            path_hausdorff = path_hausdorff_99 = path_hausdorff_95 = path_mean_distance = None

        # compute evaluation metrics
        evaluate.evaluation(gt_folder,
                            eval_folder,
                            evaluation_labels,
                            mask_dir=mask_folder,
                            path_dice=os.path.join(eval_folder, 'dice.npy'),
                            path_hausdorff=path_hausdorff,
                            path_hausdorff_99=path_hausdorff_99,
                            path_hausdorff_95=path_hausdorff_95,
                            path_mean_distance=path_mean_distance,
                            list_incorrect_labels=list_incorrect_labels,
                            list_correct_labels=list_correct_labels,
                            recompute=recompute,
                            verbose=verbose)
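
To see what the right/left correspondence table built at the top of predict() looks like, here is a toy example with invented label values: two neutral labels first, then two left structures, then their two right counterparts, as the label ordering convention requires:

import numpy as np

segmentation_labels = np.array([0, 16, 2, 3, 41, 42])  # toy list: background, brainstem, 2 left, 2 right
n_neutral_labels = 2
n_sided_labels = (len(segmentation_labels) - n_neutral_labels) // 2
lr_corresp = np.stack([segmentation_labels[n_neutral_labels:n_neutral_labels + n_sided_labels],
                       segmentation_labels[n_neutral_labels + n_sided_labels:]])
print(lr_corresp)
# [[ 2  3]
#  [41 42]]  -> in the flipped prediction, the channel of label 2 is swapped with that of 41, and 3 with 42
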
Esempio n. 13
def build_model_inputs(path_label_maps,
                       n_labels,
                       batchsize=1,
                       n_channels=1,
                       generation_classes=None,
                       prior_distributions='uniform',
                       prior_means=None,
                       prior_stds=None,
                       use_specific_stats_for_channel=False,
                       mix_prior_and_random=False):
    """
    This function builds a generator to be fed to the lab2im model. It generates all the required inputs,
    according to the operations performed in the model.
    :param path_label_maps: list of the paths of the input label maps.
    :param n_labels: number of labels in the input label maps.
    :param batchsize: (optional) numbers of images to generate per mini-batch. Default is 1.
    :param n_channels: (optional) number of channels to be synthetised. Default is 1.
    :param generation_classes: (optional) Indices regrouping generation labels into classes of same intensity
    distribution. Regrouped labels will thus share the same Gaussian when sampling a new image. Can be a sequence or a
    1d numpy array. It should have the same length as generation_labels, and contain values between 0 and K-1, where K
    is the total number of classes. Default is all labels have different classes.
    :param prior_distributions: (optional) type of distribution from which we sample the GMM parameters.
    Can either be 'uniform', or 'normal'. Default is 'uniform'.
    :param prior_means: (optional) hyperparameters controlling the prior distributions of the GMM means. Because
    these prior distributions are uniform or normal, they are each defined by 2 hyperparameters. Thus prior_means can be:
    1) a sequence of length 2, directly defining the two hyperparameters: [min, max] if prior_distributions is
    uniform, [mean, std] if the distribution is normal. The GMM means are independently sampled at each
    mini-batch from the same distribution.
    2) an array of shape (2, K), where K is the number of classes (K=len(generation_labels) if generation_classes is
    not given). The mean of the Gaussian distribution associated to class k in [0, ...K-1] is sampled at each mini-batch
    from U(prior_means[0,k], prior_means[1,k]) if prior_distributions is uniform, or from
    N(prior_means[0,k], prior_means[1,k]) if prior_distributions is normal.
    3) an array of shape (2*n_mod, K), where each block of two rows is associated to hyperparameters derived
    from different modalities. In this case, if use_specific_stats_for_channel is False, we first randomly select a
    modality from the n_mod possibilities, and we sample the GMM means like in 2).
    If use_specific_stats_for_channel is True, each block of two rows corresponds to a different channel
    (n_mod=n_channels), thus we select the corresponding block to each channel rather than randomly drawing it.
    4) the path to such a numpy array.
    Default is None, which corresponds to prior_means = [25, 225].
    :param prior_stds: (optional) same as prior_means but for the standard deviations of the GMM.
    Default is None, which corresponds to prior_stds = [5, 25].
    :param use_specific_stats_for_channel: (optional) whether the i-th block of two rows in the prior arrays must be
    only used to generate the i-th channel. If True, n_mod should be equal to n_channels. Default is False.
    :param mix_prior_and_random: (optional) if prior_means is not None, this resets the priors to their default
    values for half of the cases, thus generating images of random contrast.
    """

    # get label info
    _, _, n_dims, _, _, _ = utils.get_volume_info(path_label_maps[0])

    # allocate unique class to each label if generation classes is not given
    if generation_classes is None:
        generation_classes = np.arange(n_labels)

    # Generate!
    while True:

        # randomly pick as many images as batchsize
        indices = npr.randint(len(path_label_maps), size=batchsize)

        # initialise input lists
        list_label_maps = []
        list_means = []
        list_stds = []

        for idx in indices:

            # add labels to inputs
            lab = utils.load_volume(path_label_maps[idx], dtype='int', aff_ref=np.eye(4))
            list_label_maps.append(utils.add_axis(lab, axis=[0, -1]))

            # add means and standard deviations to inputs
            means = np.empty((1, n_labels, 0))
            stds = np.empty((1, n_labels, 0))
            for channel in range(n_channels):

                # retrieve channel specific stats if necessary
                if isinstance(prior_means, np.ndarray):
                    if (prior_means.shape[0] > 2) & use_specific_stats_for_channel:
                        if prior_means.shape[0] / 2 != n_channels:
                            raise ValueError("the number of blocks in prior_means does not match n_channels. This "
                                             "message is printed because use_specific_stats_for_channel is True.")
                        tmp_prior_means = prior_means[2 * channel:2 * channel + 2, :]
                    else:
                        tmp_prior_means = prior_means
                else:
                    tmp_prior_means = prior_means
                if (prior_means is not None) & mix_prior_and_random & (npr.uniform() > 0.5):
                    tmp_prior_means = None
                if isinstance(prior_stds, np.ndarray):
                    if (prior_stds.shape[0] > 2) & use_specific_stats_for_channel:
                        if prior_stds.shape[0] / 2 != n_channels:
                            raise ValueError("the number of blocks in prior_stds does not match n_channels. This "
                                             "message is printed because use_specific_stats_for_channel is True.")
                        tmp_prior_stds = prior_stds[2 * channel:2 * channel + 2, :]
                    else:
                        tmp_prior_stds = prior_stds
                else:
                    tmp_prior_stds = prior_stds
                if (prior_stds is not None) & mix_prior_and_random & (npr.uniform() > 0.5):
                    tmp_prior_stds = None

                # draw means and std devs from priors
                tmp_classes_means = utils.draw_value_from_distribution(tmp_prior_means, n_labels, prior_distributions,
                                                                       125., 100., positive_only=True)
                tmp_classes_stds = utils.draw_value_from_distribution(tmp_prior_stds, n_labels, prior_distributions,
                                                                      15., 10., positive_only=True)
                if npr.uniform() > 0.95:  # reset the background to 0 in 5% of cases
                    tmp_classes_means[0] = 0
                    tmp_classes_stds[0] = 0
                tmp_means = utils.add_axis(tmp_classes_means[generation_classes], axis=[0, -1])
                tmp_stds = utils.add_axis(tmp_classes_stds[generation_classes], axis=[0, -1])
                means = np.concatenate([means, tmp_means], axis=-1)
                stds = np.concatenate([stds, tmp_stds], axis=-1)
            list_means.append(means)
            list_stds.append(stds)

        # build list of inputs for generation model
        list_inputs = [list_label_maps, list_means, list_stds]
        if batchsize > 1:  # concatenate each input type if batchsize > 1
            list_inputs = [np.concatenate(item, 0) for item in list_inputs]
        else:
            list_inputs = [item[0] for item in list_inputs]

        yield list_inputs
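
As an illustration of the prior array layouts described in the docstring above, here is a sketch with invented bounds for K = 4 classes and n_mod = 2 modalities; with use_specific_stats_for_channel=True, the generator slices out the block of two rows matching each channel, exactly as done in the loop above:

import numpy as np

# invented bounds, for illustration only
prior_means = np.array([[25., 90., 130., 60.],    # modality 0, lower bounds
                        [75., 130., 170., 100.],  # modality 0, upper bounds
                        [10., 60., 100., 40.],    # modality 1, lower bounds
                        [50., 100., 140., 80.]])  # modality 1, upper bounds
channel = 1
block = prior_means[2 * channel:2 * channel + 2, :]  # rows 2 and 3
class_means = np.random.uniform(block[0], block[1])  # one mean per class for this channel
print(class_means.shape)  # (4,)
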
Esempio n. 14
    def __init__(self,
                 labels_dir,
                 generation_labels=None,
                 output_labels=None,
                 n_neutral_labels=None,
                 padding_margin=None,
                 batch_size=1,
                 n_channels=1,
                 target_res=None,
                 output_shape=None,
                 output_div_by_n=None,
                 prior_distributions='uniform',
                 generation_classes=None,
                 prior_means=None,
                 prior_stds=None,
                 use_specific_stats_for_channel=False,
                 flipping=True,
                 apply_linear_trans=True,
                 scaling_bounds=None,
                 rotation_bounds=None,
                 shearing_bounds=None,
                 apply_nonlin_trans=True,
                 nonlin_std=3.,
                 nonlin_shape_factor=0.0625,
                 blur_background=True,
                 data_res=None,
                 thickness=None,
                 downsample=False,
                 blur_range=1.15,
                 crop_channel_2=None,
                 apply_bias_field=True,
                 bias_field_std=0.3,
                 bias_shape_factor=0.025):
        """
        This class is a wrapper around the labels_to_image_model model. It contains the GPU model that generates
        images from label maps, and a python generator that supplies the input data for this model.
        To generate pairs of image/labels you can just call the method generate_image() on an object of this class.

        :param labels_dir: path of folder with all input label maps, or to a single label map.

        # IMPORTANT !!!
        # Each time we provide a parameter with separate values for each axis (e.g. with a numpy array or a sequence),
        # these values refer to the axes of the raw label map (i.e. once it has been loaded in python).
        # Depending on the label map orientation, the axes of its raw array may or may not correspond to the RAS axes.

        # label maps-related parameters
        :param generation_labels: (optional) list of all possible label values in the input label maps.
        Default is None, where the label values are taken directly from the provided label maps.
        If not None, can be a sequence or a 1d numpy array, or the path to a 1d numpy array.
        If flipping is True (i.e. right/left flipping is enabled), generation_labels should be organised as follows:
        background label first, then non-sided labels (e.g. CSF, brainstem, etc.), then all the structures of the same
        hemisphere (can be left or right), and finally all the corresponding contralateral structures in the same order.
        :param output_labels: (optional) list of all the label values to keep in the output label maps (in no particular
        order). Should be a subset of the values contained in generation_labels.
        Label values that are in generation_labels but not in output_labels are reset to zero.
        Can be a sequence, a 1d numpy array, or the path to a 1d numpy array.
        :param n_neutral_labels: (optional) number of non-sided generation labels.
        Default is total number of label values.
        :param padding_margin: (optional) margin by which to pad the input labels with zeros.
        Padding is applied prior to any other operation.
        Can be an integer (same padding in all dimensions), a sequence, a 1d numpy array, or the path to a 1d numpy
        array. Default is no padding.

        # output-related parameters
        :param batch_size: (optional) numbers of images to generate per mini-batch. Default is 1.
        :param n_channels: (optional) number of channels to be synthetised. Default is 1.
        :param target_res: (optional) target resolution of the generated images and corresponding label maps.
        If None, the outputs will have the same resolution as the input label maps.
        Can be a number (isotropic resolution), a sequence, a 1d numpy array, or the path to a 1d numpy array.
        :param output_shape: (optional) shape of the output image, obtained by randomly cropping the generated image.
        Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or the path to a 1d numpy array.
        :param output_div_by_n: (optional) forces the output shape to be divisible by this value. It overwrites
        output_shape if necessary. Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or
        the path to a 1d numpy array.

        # GMM-sampling parameters
        :param generation_classes: (optional) Indices regrouping generation labels into classes of same intensity
        distribution. Regrouped labels will thus share the same Gaussian when sampling a new image. Can be a sequence, a
        1d numpy array, or the path to a 1d numpy array. It should have the same length as generation_labels, and
        contain values between 0 and K-1, where K is the total number of classes.
        Default is all labels have different classes (K=len(generation_labels)).
        :param prior_distributions: (optional) type of distribution from which we sample the GMM parameters.
        Can either be 'uniform', or 'normal'. Default is 'uniform'.
        :param prior_means: (optional) hyperparameters controlling the prior distributions of the GMM means. Because
        these prior distributions are uniform or normal, they are each defined by 2 hyperparameters. Thus prior_means
        can be:
        1) a sequence of length 2, directly defining the two hyperparameters: [min, max] if prior_distributions is
        uniform, [mean, std] if the distribution is normal. The GMM means are independently sampled at each
        mini-batch from the same distribution.
        2) an array of shape (2, K), where K is the number of classes (K=len(generation_labels) if generation_classes is
        not given). The mean of the Gaussian distribution associated to class k in [0, ...K-1] is sampled at each
        mini-batch from U(prior_means[0,k], prior_means[1,k]) if prior_distributions is uniform, and from
        N(prior_means[0,k], prior_means[1,k]) if prior_distributions is normal.
        3) an array of shape (2*n_mod, K), where each block of two rows is associated to hyperparameters derived
        from different modalities. In this case, if use_specific_stats_for_channel is False, we first randomly select a
        modality from the n_mod possibilities, and we sample the GMM means like in 2).
        If use_specific_stats_for_channel is True, each block of two rows corresponds to a different channel
        (n_mod=n_channels), thus we select the corresponding block to each channel rather than randomly drawing it.
        4) the path to such a numpy array.
        Default is None, which corresponds to prior_means = [25, 225].
        :param prior_stds: (optional) same as prior_means but for the standard deviations of the GMM.
        Default is None, which corresponds to prior_stds = [5, 25].
        :param use_specific_stats_for_channel: (optional) whether the i-th block of two rows in the prior arrays must be
        only used to generate the i-th channel. If True, n_mod should be equal to n_channels. Default is False.

        # spatial deformation parameters
        :param flipping: (optional) whether to introduce right/left random flipping. Default is True.
        :param apply_linear_trans: (optional) whether to apply affine deformation. Default is True.
        :param scaling_bounds: (optional) if apply_linear_trans is True, the scaling factor for each dimension is
        sampled from a uniform distribution of predefined bounds. Can either be:
        1) a number, in which case the scaling factor is independently sampled from the uniform distribution of bounds
        (1-scaling_bounds, 1+scaling_bounds) for each dimension.
        2) a sequence, in which case the scaling factor is sampled from the uniform distribution of bounds
        (1-scaling_bounds[i], 1+scaling_bounds[i]) for the i-th dimension.
        3) a numpy array of shape (2, n_dims), in which case the scaling factor is sampled from the uniform distribution
         of bounds (scaling_bounds[0, i], scaling_bounds[1, i]) for the i-th dimension.
        4) the path to such a numpy array.
        If None (default), scaling_bounds = 0.15.
        :param rotation_bounds: (optional) same as scaling bounds but for the rotation angle, except that for cases 1
        and 2, the bounds are centred on 0 rather than 1, i.e. (0-rotation_bounds[i], 0+rotation_bounds[i]).
        If None (default), rotation_bounds = 15.
        :param shearing_bounds: (optional) same as scaling bounds. If None (default), shearing_bounds = 0.01.
        :param apply_nonlin_trans: (optional) whether to apply a non-linear elastic deformation.
        If True, a diffeomorphic deformation field is obtained by first sampling a small tensor from the normal
        distribution, resizing it to image size, and integrating it. Default is True.
        :param nonlin_std: (optional) If apply_nonlin_trans is True, maximum value for the standard deviation of the
        normal distribution from which we sample the first tensor for synthesising the deformation field.
        :param nonlin_shape_factor: (optional) If apply_nonlin_trans is True, ratio between the size of the input label
        maps and the size of the sampled tensor for synthesising the deformation field.

        # blurring/resampling parameters
        :param blur_background: (optional) whether to produce an unrealistic background or not.
        If True, the background is generated/blurred with the other labels, according to the values of prior_means and
        prior_stds. Also, it is reset to zero-background with a probability of 0.2.
        If False, the background is reset to zero, or can be replaced by a low-intensity background with a probability
        of 0.5. Additionally we correct for edge blurring effects.
        Default is True.
        :param data_res: (optional) acquisition resolution to mimic. If provided, the images sampled from the GMM are
        blurred to mimic data that would be: 1) acquired at the given acquisition resolution, and 2) resampled at
        target_res.
        Default is None, where images are isotropically blurred to introduce some spatial correlation between voxels.
        If the generated images are uni-modal, data_res can be a number (isotropic acquisition resolution), a sequence,
        a 1d numpy array, or the path to a 1d numpy array. In the multi-modal case, it should be given as a numpy array
        (or a path) of size (n_mod, n_dims), where each row is the acquisition resolution of the corresponding channel.
        :param thickness: (optional) if data_res is provided, we can further specify the slice thickness of the low
        resolution images to mimic.
        If the generated images are uni-modal, thickness can be a number (isotropic slice thickness), a sequence, a 1d
        numpy array, or the path to a 1d numpy array. In the multi-modal case, it should be given as a numpy array (or
        a path) of size (n_mod, n_dims), where each row is the slice thickness of the corresponding channel.
        :param downsample: (optional) whether to actually downsample the volume image to data_res.
        Default is False, except when thickness is provided, and thickness < data_res.
        :param blur_range: (optional) Randomise the standard deviation of the blurring kernels (whether data_res is
        given or not). At each mini-batch, the standard deviations of the blurring kernels are multiplied by a
        coefficient sampled from a uniform distribution with bounds [1/blur_range, blur_range].
        If None, no randomisation. Default is 1.15.
        :param crop_channel_2: (optional) stats for cropping the second channel along the anterior-posterior axis.
        Should be a vector of length 4, with the bounds of the uniform distributions for cropping the front and back of
        the image (in percentage). Default is None, where no cropping is applied.

        # bias field parameters
        :param apply_bias_field: (optional) whether to apply a bias field to the final image. Default is True.
        If True, the bias field is obtained by sampling a first tensor from a normal distribution, resizing it to image
        size, and rescaling the values to positive numbers by taking the voxel-wise exponential.
        :param bias_field_std: (optional) If apply_bias_field is True, maximum value for the standard deviation of the
        normal distribution from which we sample the first tensor for synthesising the bias field.
        :param bias_shape_factor: (optional) If apply_bias_field is True, ratio between the size of the input
        label maps and the size of the sampled tensor for synthesising the bias field.
        """

        # prepare data files
        if ('.nii.gz' in labels_dir) | ('.nii' in labels_dir) | (
                '.mgz' in labels_dir) | ('.npz' in labels_dir):
            self.labels_paths = [labels_dir]
        else:
            self.labels_paths = utils.list_images_in_folder(labels_dir)
        assert len(self.labels_paths) > 0, "Could not find any training data"

        # generation parameters
        self.labels_shape, self.aff, self.n_dims, _, self.header, self.atlas_res = \
            utils.get_volume_info(self.labels_paths[0])
        self.n_channels = n_channels
        if generation_labels is not None:
            self.generation_labels = utils.load_array_if_path(
                generation_labels)
        else:
            self.generation_labels = utils.get_list_labels(
                labels_dir=labels_dir)
        if output_labels is not None:
            self.output_labels = utils.load_array_if_path(output_labels)
        else:
            self.output_labels = self.generation_labels
        if n_neutral_labels is not None:
            self.n_neutral_labels = n_neutral_labels
        else:
            self.n_neutral_labels = self.generation_labels.shape[0]
        self.target_res = utils.load_array_if_path(target_res)
        # preliminary operations
        self.padding_margin = utils.load_array_if_path(padding_margin)
        self.flipping = flipping
        self.output_shape = utils.load_array_if_path(output_shape)
        self.output_div_by_n = output_div_by_n
        # GMM parameters
        self.prior_distributions = prior_distributions
        if generation_classes is not None:
            self.generation_classes = utils.load_array_if_path(
                generation_classes)
            assert self.generation_classes.shape == self.generation_labels.shape, \
                'if provided, generation_classes should have the same shape as generation_labels'
            unique_classes = np.unique(self.generation_classes)
            assert np.array_equal(unique_classes, np.arange(np.max(unique_classes)+1)), \
                'generation_classes should be a linear range between 0 and its maximum value.'
        else:
            self.generation_classes = np.arange(
                self.generation_labels.shape[0])
        self.prior_means = utils.load_array_if_path(prior_means)
        self.prior_stds = utils.load_array_if_path(prior_stds)
        self.use_specific_stats_for_channel = use_specific_stats_for_channel
        # linear transformation parameters
        self.apply_linear_trans = apply_linear_trans
        self.scaling_bounds = utils.load_array_if_path(scaling_bounds)
        self.rotation_bounds = utils.load_array_if_path(rotation_bounds)
        self.shearing_bounds = utils.load_array_if_path(shearing_bounds)
        # elastic transformation parameters
        self.apply_nonlin_trans = apply_nonlin_trans
        self.nonlin_std = nonlin_std
        self.nonlin_shape_factor = nonlin_shape_factor
        # blurring parameters
        self.blur_background = blur_background
        self.data_res = utils.load_array_if_path(data_res)
        self.thickness = utils.load_array_if_path(thickness)
        self.downsample = downsample
        self.blur_range = blur_range
        self.crop_second_channel = utils.load_array_if_path(crop_channel_2)
        # bias field parameters
        self.apply_bias_field = apply_bias_field
        self.bias_field_std = bias_field_std
        self.bias_shape_factor = bias_shape_factor

        # build transformation model
        self.labels_to_image_model, self.model_output_shape = self._build_labels_to_image_model(
        )

        # build generator for model inputs
        self.model_inputs_generator = self._build_model_inputs_generator(
            batch_size)

        # build brain generator
        self.brain_generator = self._build_brain_generator()
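
A minimal usage sketch of this class, assuming it is exposed as BrainGenerator (the name the docstrings above refer to) and that generate_image() returns an image/label-map pair; the import path and data path below are placeholders:

# assumed import path; adapt to where the class lives in your copy of the code
# from SynthSeg.brain_generator import BrainGenerator

brain_generator = BrainGenerator(labels_dir='/data/training_label_maps',  # placeholder path
                                 target_res=1.,
                                 flipping=True,
                                 bias_field_std=0.3)
image, labels = brain_generator.generate_image()  # one synthetic image and its label map
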
Esempio n. 15
def preprocess_image(im_path, n_levels, crop_shape=None, padding=None, aff_ref='FS'):

    # read image and corresponding info
    im, shape, aff, n_dims, n_channels, header, im_res = utils.get_volume_info(im_path, return_volume=True)

    if padding:
        if n_channels == 1:
            im = np.pad(im, padding, mode='constant')
            pad_shape = im.shape
        else:
            im = np.pad(im, tuple([(padding, padding)] * n_dims + [(0, 0)]), mode='constant')
            pad_shape = im.shape[:-1]
    else:
        pad_shape = shape

    # check that patch_shape or im_shape are divisible by 2**n_levels
    if crop_shape is not None:
        crop_shape = utils.reformat_to_list(crop_shape, length=n_dims, dtype='int')
        if not all([pad_shape[i] >= crop_shape[i] for i in range(len(pad_shape))]):
            crop_shape = [min(pad_shape[i], crop_shape[i]) for i in range(n_dims)]
            print('cropping dimensions are higher than image size, changing cropping size to {}'.format(crop_shape))
        if not all([size % (2**n_levels) == 0 for size in crop_shape]):
            crop_shape = [utils.find_closest_number_divisible_by_m(size, 2 ** n_levels) for size in crop_shape]
    else:
        if not all([size % (2**n_levels) == 0 for size in pad_shape]):
            crop_shape = [utils.find_closest_number_divisible_by_m(size, 2 ** n_levels) for size in pad_shape]

    # crop image if necessary
    if crop_shape is not None:
        crop_idx = np.round((pad_shape - np.array(crop_shape)) / 2).astype('int')
        crop_idx = np.concatenate((crop_idx, crop_idx + crop_shape), axis=0)
        im = edit_volumes.crop_volume_with_idx(im, crop_idx=crop_idx)
    else:
        crop_idx = None

    # align image to training axes and directions
    if n_dims > 2:
        if aff_ref == 'FS':
            aff_ref = np.array([[-1., 0., 0., 0.], [0., 0., 1., 0.], [0., -1., 0., 0.], [0., 0., 0., 1.]])
            im = edit_volumes.align_volume_to_ref(im, aff, aff_ref=aff_ref, return_aff=False)
        elif aff_ref == 'identity':
            aff_ref = np.eye(4)
            im = edit_volumes.align_volume_to_ref(im, aff, aff_ref=aff_ref, return_aff=False)
        elif aff_ref == 'MS':
            aff_ref = np.array([[-1., 0., 0., 0.], [0., -1., 0., 0.], [0., 0., 1., 0.], [0., 0., 0., 1.]])
            im = edit_volumes.align_volume_to_ref(im, aff, aff_ref=aff_ref, return_aff=False)

    # normalise image
    if n_channels == 1:
        m = np.min(im)
        M = np.max(im)
        if M == m:
            im = np.zeros(im.shape)
        else:
            im = (im - m) / (M - m)
    if n_channels > 1:
        for i in range(im.shape[-1]):
            channel = im[..., i]
            m = np.min(channel)
            M = np.max(channel)
            if M == m:
                im[..., i] = np.zeros(channel.shape)
            else:
                im[..., i] = (channel - m) / (M - m)

    # add batch and channel axes
    if n_channels > 1:
        im = utils.add_axis(im)
    else:
        im = utils.add_axis(im, -2)

    return im, aff, header, im_res, n_channels, n_dims, shape, pad_shape, crop_shape, crop_idx
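
For intuition, a worked example of the centre-crop index computation used above, with arbitrary shapes; crop_idx stores the n_dims lower bounds followed by the n_dims upper bounds:

import numpy as np

pad_shape = np.array([180, 220, 180])
crop_shape = [160, 160, 160]
crop_idx = np.round((pad_shape - np.array(crop_shape)) / 2).astype('int')
crop_idx = np.concatenate((crop_idx, crop_idx + crop_shape), axis=0)
print(crop_idx)  # [ 10  30  10 170 190 170]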