def get_shapes(crop_shape, im_shape, div_by_n):
    n_dims, _ = utils.get_dims(im_shape)
    # crop_shape specified
    if crop_shape is not None:
        crop_shape = utils.reformat_to_list(crop_shape,
                                            length=n_dims,
                                            dtype='int')
        crop_shape = [min(im_shape[i], crop_shape[i]) for i in range(n_dims)]
        # make sure output shape is divisible by output_div_by_n
        if div_by_n is not None:
            tmp_shape = [
                utils.find_closest_number_divisible_by_m(s,
                                                         div_by_n,
                                                         smaller_ans=True)
                for s in crop_shape
            ]
            if crop_shape != tmp_shape:
                print('crop shape {0} not divisible by {1}, changed to {2}'.
                      format(crop_shape, div_by_n, tmp_shape))
                crop_shape = tmp_shape
    # no crop_shape, so no cropping unless image shape is not divisible by output_div_by_n
    else:
        if div_by_n is not None:
            tmp_shape = [
                utils.find_closest_number_divisible_by_m(s,
                                                         div_by_n,
                                                         smaller_ans=True)
                for s in im_shape
            ]
            if tmp_shape != im_shape:
                print('image shape {0} not divisible by {1}, cropped to {2}'.
                      format(im_shape, div_by_n, tmp_shape))
                crop_shape = tmp_shape
    return crop_shape
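
# Usage sketch for get_shapes (added for illustration, not part of the original source).
# It assumes utils.find_closest_number_divisible_by_m(s, m, smaller_ans=True) rounds s
# down to the nearest multiple of m; round_down_to_multiple below reproduces that
# assumed behaviour so the example runs standalone.
def round_down_to_multiple(s, m):
    return int(s // m) * m

# A requested crop of 150 per axis on a [160, 160, 192] image with div_by_n=32 is first
# capped by the image shape, then rounded down to multiples of 32:
assert [round_down_to_multiple(min(s, c), 32)
        for s, c in zip([160, 160, 192], [150, 150, 150])] == [128, 128, 128]
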
def labels_to_image_model(labels_shape,
                          n_channels,
                          generation_labels,
                          output_labels,
                          n_neutral_labels,
                          atlas_res,
                          target_res,
                          output_shape=None,
                          output_div_by_n=None,
                          flipping=True,
                          aff=None,
                          scaling_bounds=0.15,
                          rotation_bounds=15,
                          shearing_bounds=0.012,
                          translation_bounds=False,
                          nonlin_std=4.,
                          nonlin_shape_factor=.0625,
                          randomise_res=False,
                          build_distance_maps=False,
                          data_res=None,
                          thickness=None,
                          downsample=False,
                          blur_range=1.15,
                          bias_field_std=.5,
                          bias_shape_factor=.025):
    """
    This function builds a keras/tensorflow model to generate images from provided label maps.
    The images are generated by sampling a Gaussian Mixture Model (of given parameters), conditioned on the label map.
    The model will take as inputs:
        -a label map
        -a vector containing the means of the Gaussian Mixture Model for each label,
        -a vector containing the standard deviations of the Gaussian Mixture Model for each label,
        -if apply_affine_deformation is True: a batch*(n_dims+1)*(n_dims+1) affine matrix
        -if apply_non_linear_deformation is True: a small non linear field of size batch*(dim_1*...*dim_n)*n_dims that
        will be resampled to labels size and integrated, to obtain a diffeomorphic elastic deformation.
        -if apply_bias_field is True: a small bias field of size batch*(dim_1*...*dim_n)*1 that will be resampled to
        labels size and multiplied with the image, to add "bias-field" noise.
    The model returns:
        -the generated image normalised between 0 and 1.
        -the corresponding label map, with only the labels present in output_labels (the others are reset to zero).
    # IMPORTANT !!!
    # Each time we provide a parameter with separate values for each axis (e.g. with a numpy array or a sequence),
    # these values refer to the RAS axes.
    :param labels_shape: shape of the input label maps. Can be a sequence or a 1d numpy array.
    :param n_channels: number of channels to be synthesised.
    :param generation_labels: (optional) list of all possible label values in the input label maps.
    Default is None, where the label values are read directly from the provided label maps.
    If not None, can be a sequence or a 1d numpy array. It should be organised as follows: background label first, then
    non-sided labels (e.g. CSF, brainstem, etc.), then all the structures of the same hemisphere (can be left or right),
    and finally all the corresponding contralateral structures (in the same order).
    :param output_labels: list of all the label values to keep in the output label maps, in no particular order.
    Should be a subset of the values contained in generation_labels.
    Label values that are in generation_labels but not in output_labels are reset to zero.
    Can be a sequence or a 1d numpy array. By default output_labels is equal to generation_labels.
    :param n_neutral_labels: number of non-sided generation labels.
    :param atlas_res: resolution of the input label maps.
    Can be a number (isotropic resolution), a sequence, or a 1d numpy array.
    :param target_res: target resolution of the generated images and corresponding label maps.
    Can be a number (isotropic resolution), a sequence, or a 1d numpy array.
    :param output_shape: (optional) desired shape of the output image, obtained by randomly cropping the generated image.
    Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or the path to a 1d numpy array.
    Default is None, where no cropping is performed.
    :param output_div_by_n: (optional) forces the output shape to be divisible by this value. It overwrites output_shape
    if necessary. Can be an integer (same size in all dimensions), a sequence, or a 1d numpy array.
    :param flipping: (optional) whether to introduce right/left random flipping
    :param aff: (optional) example of an (n_dims+1)x(n_dims+1) affine matrix of one of the input label maps.
    Used to find brain's right/left axis. Should be given if flipping is True.
    :param scaling_bounds: (optional) range of the random scaling to apply at each mini-batch. The scaling factor for
    each dimension is sampled from a uniform distribution of predefined bounds. Can either be:
    1) a number, in which case the scaling factor is independently sampled from the uniform distribution of bounds
    [1-scaling_bounds, 1+scaling_bounds] for each dimension.
    2) a sequence, in which case the scaling factor is sampled from the uniform distribution of bounds
    (1-scaling_bounds[i], 1+scaling_bounds[i]) for the i-th dimension.
    3) a numpy array of shape (2, n_dims), in which case the scaling factor is sampled from the uniform distribution
     of bounds (scaling_bounds[0, i], scaling_bounds[1, i]) for the i-th dimension.
    4) False, in which case scaling is completely turned off.
    Default is scaling_bounds = 0.15 (case 1)
    :param rotation_bounds: (optional) same as scaling bounds but for the rotation angle, except that for cases 1
    and 2, the bounds are centred on 0 rather than 1, i.e. [-rotation_bounds[i], +rotation_bounds[i]].
    Default is rotation_bounds = 15.
    :param shearing_bounds: (optional) same as scaling bounds. Default is shearing_bounds = 0.012.
    :param translation_bounds: (optional) same as scaling bounds. Default is translation_bounds = False, but we
    encourage using it when cropping is deactivated (i.e. when output_shape=None in BrainGenerator).
    :param nonlin_std: (optional) Maximum value for the standard deviation of the normal distribution from which we
    sample the first tensor for synthesising the deformation field. Set to 0 if you wish to completely turn the elastic
    deformation off.
    :param nonlin_shape_factor: (optional) if nonlin_std is strictly positive, factor between the shapes of the input
    label maps and the shape of the input non-linear tensor.
    :param randomise_res: (optional) whether to mimic images that would have been 1) acquired at low resolution, and
    2) resampled to high resolution. The low resolution is uniformly sampled at each mini-batch from [1mm, 9mm].
    In that process, the images generated by sampling the GMM are 1) blurred at the sampled LR, 2) downsampled at LR,
    and 3) resampled at target_resolution.
    :param data_res: (optional) specific acquisition resolution to mimic, as opposed to random resolution sampled when
    randomise_res is True. This triggers a blurring to mimic the specified acquisition resolution, but the downsampling
    is optional (see param downsample). Default for data_res is None, where images are slightly blurred.
    If the generated images are uni-modal, data_res can be a number (isotropic acquisition resolution), a sequence, a 1d
    numpy array, or the path to a 1d numpy array. In the multi-modal case, it should be given as a numpy array (or a
    path) of size (n_mod, n_dims), where each row is the acquisition resolution of the corresponding channel.
    :param thickness: (optional) if data_res is provided, we can further specify the slice thickness of the low
    resolution images to mimic. Must be provided in the same format as data_res. Default thickness = data_res.
    :param downsample: (optional) whether to actually downsample the volume images to data_res after blurring.
    Default is False, except when thickness is provided, and thickness < data_res.
    :param blur_range: (optional) Randomise the standard deviation of the blurring kernels, (whether data_res is given
    or not). At each mini-batch, the standard deviations of the blurring kernels are multiplied by a coefficient sampled
    from a uniform distribution with bounds [1/blur_range, blur_range]. If None, no randomisation. Default is 1.15.
    :param bias_field_std: (optional) If strictly positive, this triggers the corruption of synthesised images with a
    bias field. It is obtained by sampling a first small tensor from a normal distribution, resizing it to full size,
    and rescaling it to positive values by taking the voxel-wise exponential. bias_field_std designates the std dev of
    the normal distribution from which we sample the first tensor. Set to 0 to deactivate bias field corruption.
    :param bias_shape_factor: (optional) If bias_field_std is strictly positive, this designates the ratio between the
    size of the input label maps and the size of the first sampled tensor for synthesising the bias field.
    """

    # reformat resolutions
    labels_shape = utils.reformat_to_list(labels_shape)
    n_dims, _ = utils.get_dims(labels_shape)
    atlas_res = utils.reformat_to_n_channels_array(atlas_res, n_dims,
                                                   n_channels)
    data_res = atlas_res if (
        data_res is None) else utils.reformat_to_n_channels_array(
            data_res, n_dims, n_channels)
    thickness = data_res if (
        thickness is None) else utils.reformat_to_n_channels_array(
            thickness, n_dims, n_channels)
    downsample = utils.reformat_to_list(
        downsample,
        n_channels) if downsample else (np.min(thickness - data_res, 1) < 0)
    atlas_res = atlas_res[0]
    target_res = atlas_res if (
        target_res is None) else utils.reformat_to_n_channels_array(
            target_res, n_dims)[0]

    # get shapes
    crop_shape, output_shape = get_shapes(labels_shape, output_shape,
                                          atlas_res, target_res,
                                          output_div_by_n)

    # create new_label_list and corresponding LUT to make sure that labels go from 0 to N-1
    new_generation_labels, lut = utils.rearrange_label_list(generation_labels)

    # define model inputs
    labels_input = KL.Input(shape=labels_shape + [1], name='labels_input')
    means_input = KL.Input(shape=list(new_generation_labels.shape) +
                           [n_channels],
                           name='means_input')
    stds_input = KL.Input(shape=list(new_generation_labels.shape) +
                          [n_channels],
                          name='std_devs_input')

    # convert labels to new_label_list
    labels = l2i_et.convert_labels(labels_input, lut)

    # deform labels
    if (scaling_bounds is not False) | (rotation_bounds is not False) | (shearing_bounds is not False) | \
       (translation_bounds is not False) | (nonlin_std > 0):
        labels._keras_shape = tuple(labels.get_shape().as_list())
        labels = layers.RandomSpatialDeformation(
            scaling_bounds=scaling_bounds,
            rotation_bounds=rotation_bounds,
            shearing_bounds=shearing_bounds,
            translation_bounds=translation_bounds,
            nonlin_std=nonlin_std,
            nonlin_shape_factor=nonlin_shape_factor,
            inter_method='nearest')(labels)

    # cropping
    if crop_shape != labels_shape:
        labels._keras_shape = tuple(labels.get_shape().as_list())
        labels = layers.RandomCrop(crop_shape)(labels)

    # flipping
    if flipping:
        assert aff is not None, 'aff should not be None if flipping is True'
        labels._keras_shape = tuple(labels.get_shape().as_list())
        labels = layers.RandomFlip(
            get_ras_axes(aff, n_dims)[0], True, new_generation_labels,
            n_neutral_labels)(labels)

    # build synthetic image
    labels._keras_shape = tuple(labels.get_shape().as_list())
    image = layers.SampleConditionalGMM()([labels, means_input, stds_input])

    # apply bias field
    if bias_field_std > 0:
        image._keras_shape = tuple(image.get_shape().as_list())
        image = layers.BiasFieldCorruption(bias_field_std, bias_shape_factor,
                                           False)(image)

    # intensity augmentation
    image._keras_shape = tuple(image.get_shape().as_list())
    image = layers.IntensityAugmentation(clip=300,
                                         normalise=True,
                                         gamma_std=.4,
                                         separate_channels=True)(image)

    # loop over channels
    channels = list()
    split = KL.Lambda(lambda x: tf.split(x, [1] * n_channels, axis=-1))(
        image) if (n_channels > 1) else [image]
    for i, channel in enumerate(split):

        channel._keras_shape = tuple(channel.get_shape().as_list())

        if randomise_res:
            max_res = np.array([9.] * 3)
            resolution, blur_res = layers.SampleResolution(
                atlas_res, max_res, .05, return_thickness=True)(means_input)
            sigma = l2i_et.blurring_sigma_for_downsampling(atlas_res,
                                                           resolution,
                                                           thickness=blur_res)
            channel = layers.DynamicGaussianBlur(
                0.75 * max_res / np.array(atlas_res),
                blur_range)([channel, sigma])
            if build_distance_maps:
                channel, dist = layers.MimicAcquisition(
                    atlas_res, atlas_res, output_shape,
                    True)([channel, resolution])
                channels.extend([channel, dist])
            else:
                channel = layers.MimicAcquisition(atlas_res, atlas_res,
                                                  output_shape,
                                                  False)([channel, resolution])
                channels.append(channel)

        else:
            sigma = l2i_et.blurring_sigma_for_downsampling(
                atlas_res, data_res[i], thickness=thickness[i])
            channel = layers.GaussianBlur(sigma, blur_range)(channel)
            if downsample[i]:
                resolution = KL.Lambda(lambda x: tf.convert_to_tensor(
                    data_res[i], dtype='float32'))([])
                channel = layers.MimicAcquisition(atlas_res, data_res[i],
                                                  output_shape)(
                                                      [channel, resolution])
            elif output_shape != crop_shape:
                channel = nrn_layers.Resize(size=output_shape)(channel)
            channels.append(channel)

    # concatenate all channels back
    image = KL.Lambda(lambda x: tf.concat(x, -1))(
        channels) if len(channels) > 1 else channels[0]

    # resample labels at target resolution
    if crop_shape != output_shape:
        labels = l2i_et.resample_tensor(labels,
                                        output_shape,
                                        interp_method='nearest')

    # convert labels back to original values and reset unwanted labels to zero
    labels = l2i_et.convert_labels(labels, generation_labels)
    labels._keras_shape = tuple(labels.get_shape().as_list())
    reset_values = [v for v in generation_labels if v not in output_labels]
    labels = layers.ResetValuesToZero(reset_values, name='labels_out')(labels)

    # build model (dummy layer enables to keep the labels when plugging this model to other models)
    image = KL.Lambda(lambda x: x[0], name='image_out')([image, labels])
    brain_model = Model(inputs=[labels_input, means_input, stds_input],
                        outputs=[image, labels])

    return brain_model
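
# Hedged usage sketch for the generative model above (illustration only: label_map,
# gen_labels and the GMM statistics are hypothetical placeholders). The input shapes
# follow from the KL.Input definitions above: labels are (1, *labels_shape, 1), and
# means/stds are (1, n_labels, n_channels).
#
#     import numpy as np
#     model = labels_to_image_model(labels_shape=[160, 160, 192], n_channels=1,
#                                   generation_labels=gen_labels, output_labels=gen_labels,
#                                   n_neutral_labels=18, atlas_res=1., target_res=1.)
#     batch_labels = label_map[np.newaxis, ..., np.newaxis]
#     means = np.random.uniform(0, 225, (1, len(gen_labels), 1))
#     stds = np.random.uniform(5, 25, (1, len(gen_labels), 1))
#     image, out_labels = model.predict([batch_labels, means, stds])
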
# Example 3
    n_labels = segmentation_label_list.shape[0]
    _, lut = utils.rearrange_label_list(segmentation_label_list)
    labels_gt = KL.Lambda(lambda x: tf.gather(tf.convert_to_tensor(lut, dtype='int32'),
                                              tf.cast(x, dtype='int32')), name='metric_convert_labels')(labels_gt)

    # convert gt labels to probabilistic values
    labels_gt = KL.Lambda(lambda x: tf.one_hot(tf.cast(x, dtype='int32'), depth=n_labels, axis=-1))(labels_gt)
    labels_gt = KL.Reshape(input_shape)(labels_gt)
    labels_gt = KL.Lambda(lambda x: K.clip(x / K.sum(x, axis=-1, keepdims=True), K.epsilon(), 1),
                          name='prob_target')(labels_gt)

    # crop output to evaluate loss function in centre patch
    if loss_cropping is not None:
        # format loss_cropping
        labels_shape = labels_gt.get_shape().as_list()[1:-1]
        n_dims, _ = utils.get_dims(labels_shape)
        if isinstance(loss_cropping, (int, float)):
            loss_cropping = [loss_cropping] * n_dims
        if isinstance(loss_cropping, (list, tuple)):
            if len(loss_cropping) == 1:
                loss_cropping = loss_cropping * n_dims
            elif len(loss_cropping) != n_dims:
                raise TypeError('loss_cropping should be float, list of size 1 or {0}, or None. '
                                'Had {1}'.format(n_dims, loss_cropping))
        # perform cropping
        begin_idx = [int((labels_shape[i] - loss_cropping[i]) / 2) for i in range(n_dims)]
        labels_gt = KL.Lambda(
            lambda x: tf.slice(x, begin=tf.convert_to_tensor([0] + begin_idx + [0], dtype='int32'),
                               size=tf.convert_to_tensor([-1] + loss_cropping + [-1], dtype='int32')),
            name='cropping_gt')(labels_gt)
        last_tensor = KL.Lambda(
            lambda x: tf.slice(x, begin=tf.convert_to_tensor([0] + begin_idx + [0], dtype='int32'),
                               size=tf.convert_to_tensor([-1] + loss_cropping + [-1], dtype='int32')),
            name='cropping_pred')(last_tensor)
# Example 4
def build_augmentation_model(im_shape,
                             n_channels,
                             segmentation_labels,
                             n_neutral_labels,
                             output_shape=None,
                             output_div_by_n=None,
                             flipping=True,
                             aff=None,
                             scaling_bounds=0.15,
                             rotation_bounds=15,
                             shearing_bounds=0.012,
                             translation_bounds=False,
                             nonlin_std=3.,
                             nonlin_shape_factor=.0625,
                             bias_field_std=.3,
                             bias_shape_factor=.025):

    # reformat resolutions and get shapes
    im_shape = utils.reformat_to_list(im_shape)
    n_dims, _ = utils.get_dims(im_shape)
    crop_shape = get_shapes(im_shape, output_shape, output_div_by_n)

    # create new_label_list and corresponding LUT to make sure that labels go from 0 to N-1
    new_seg_labels, lut = utils.rearrange_label_list(segmentation_labels)

    # define model inputs
    image_input = KL.Input(shape=im_shape+[n_channels], name='image_input')
    labels_input = KL.Input(shape=im_shape + [1], name='labels_input')

    # convert labels to new_label_list
    labels = convert_labels(labels_input, lut)

    # deform labels
    if (scaling_bounds is not False) | (rotation_bounds is not False) | (shearing_bounds is not False) | \
       (translation_bounds is not False) | (nonlin_std > 0):
        labels._keras_shape = tuple(labels.get_shape().as_list())
        labels, image = layers.RandomSpatialDeformation(scaling_bounds=scaling_bounds,
                                                        rotation_bounds=rotation_bounds,
                                                        shearing_bounds=shearing_bounds,
                                                        translation_bounds=translation_bounds,
                                                        nonlin_std=nonlin_std,
                                                        nonlin_shape_factor=nonlin_shape_factor,
                                                        inter_method=['nearest', 'linear'])([labels, image_input])
    else:
        image = image_input

    # crop labels
    if crop_shape != im_shape:
        labels._keras_shape = tuple(labels.get_shape().as_list())
        image._keras_shape = tuple(image.get_shape().as_list())
        labels, image = layers.RandomCrop(crop_shape)([labels, image])

    # flip labels
    if flipping:
        assert aff is not None, 'aff should not be None if flipping is True'
        labels._keras_shape = tuple(labels.get_shape().as_list())
        image._keras_shape = tuple(image.get_shape().as_list())
        labels, image = layers.RandomFlip(flip_axis=get_ras_axes(aff, n_dims)[0], swap_labels=[True, False],
                                          label_list=new_seg_labels, n_neutral_labels=n_neutral_labels)([labels, image])

    # apply bias field
    if bias_field_std > 0:
        image._keras_shape = tuple(image.get_shape().as_list())
        image = layers.BiasFieldCorruption(bias_field_std, bias_shape_factor, False)(image)
        image = KL.Lambda(lambda x: tf.cast(x, dtype='float32'), name='image_biased')(image)

    # intensity augmentation
    image._keras_shape = tuple(image.get_shape().as_list())
    image = layers.IntensityAugmentation(10, clip=False, normalise=True, gamma_std=.5, separate_channels=True)(image)
    image = KL.Lambda(lambda x: tf.cast(x, dtype='float32'), name='image_augmented')(image)

    # convert labels back to original values and reset unwanted labels to zero
    labels = convert_labels(labels, segmentation_labels)

    # build model (dummy layer enables to keep the labels when plugging this model to other models)
    labels = KL.Lambda(lambda x: tf.cast(x, dtype='int32'), name='labels_out')(labels)
    image = KL.Lambda(lambda x: x[0], name='image_out')([image, labels])
    brain_model = models.Model(inputs=[image_input, labels_input], outputs=[image, labels])

    return brain_model
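
# Hedged usage sketch (illustration only: seg_labels, image_batch and labels_batch are
# hypothetical placeholders). The augmentation model takes a real image and its label
# map, and returns their jointly deformed and augmented versions:
#
#     aug_model = build_augmentation_model(im_shape=[160, 160, 192], n_channels=1,
#                                          segmentation_labels=seg_labels,
#                                          n_neutral_labels=18, aff=np.eye(4))
#     image_out, labels_out = aug_model.predict([image_batch, labels_batch])
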
# Example 5
def build_model(model_file, input_shape, resample, im_res, n_levels, n_lab, conv_size, nb_conv_per_level,
                unet_feat_count, feat_multiplier, activation, sigma_smoothing):

    assert os.path.isfile(model_file), "The provided model path does not exist."

    # initialisation
    net = None
    n_dims, n_channels = utils.get_dims(input_shape, max_channels=10)
    resample = utils.reformat_to_list(resample, length=n_dims)

    # build preprocessing model
    if resample is not None:
        im_input = KL.Input(shape=input_shape, name='pre_resample_input')
        resample_factor = [im_res[i] / float(resample[i]) for i in range(n_dims)]
        resample_shape = [utils.find_closest_number_divisible_by_m(resample_factor[i] * input_shape[i],
                          2 ** n_levels, smaller_ans=False) for i in range(n_dims)]
        resampled = nrn_layers.Resize(size=resample_shape, name='pre_resample')(im_input)
        net = Model(inputs=im_input, outputs=resampled)
        input_shape = resample_shape + [n_channels]

    # build UNet
    net = nrn_models.unet(nb_features=unet_feat_count,
                          input_shape=input_shape,
                          nb_levels=n_levels,
                          conv_size=conv_size,
                          nb_labels=n_lab,
                          name='unet',
                          prefix=None,
                          feat_mult=feat_multiplier,
                          pool_size=2,
                          use_logp=True,
                          padding='same',
                          dilation_rate_mult=1,
                          activation=activation,
                          use_residuals=False,
                          final_pred_activation='softmax',
                          nb_conv_per_level=nb_conv_per_level,
                          add_prior_layer=False,
                          add_prior_layer_reg=0,
                          layer_nb_feats=None,
                          conv_dropout=0,
                          batch_norm=-1,
                          input_model=net)
    net.load_weights(model_file, by_name=True)

    # build postprocessing model
    if (resample is not None) | (sigma_smoothing != 0):

        # get UNet output
        input_tensor = net.inputs
        last_tensor = net.output

        # resample to initial resolution
        if resample is not None:
            last_tensor = nrn_layers.Resize(size=input_shape[:-1], name='post_resample')(last_tensor)

        # smooth posteriors
        if sigma_smoothing != 0:
            last_tensor._keras_shape = tuple(last_tensor.get_shape().as_list())
            last_tensor = layers.GaussianBlur(sigma=sigma_smoothing)(last_tensor)

        # build model
        net = Model(inputs=input_tensor, outputs=last_tensor)

    return net
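
# Hedged usage sketch (illustration only; the path and hyper-parameters are hypothetical
# and must match those used at training time):
#
#     net = build_model(model_file='/path/to/weights.h5', input_shape=[160, 160, 192, 1],
#                       resample=None, im_res=[1., 1., 1.], n_levels=5, n_lab=32,
#                       conv_size=3, nb_conv_per_level=2, unet_feat_count=24,
#                       feat_multiplier=2, activation='elu', sigma_smoothing=0.5)
#     posteriors = net.predict(image_batch)
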
# Example 6
def labels_to_image_model(labels_shape,
                          n_channels,
                          generation_labels,
                          output_labels,
                          n_neutral_labels,
                          atlas_res,
                          target_res,
                          output_shape=None,
                          output_div_by_n=None,
                          padding_margin=None,
                          flipping=True,
                          aff=None,
                          apply_linear_trans=True,
                          apply_nonlin_trans=True,
                          nonlin_std=3.,
                          nonlin_shape_factor=.0625,
                          blur_background=True,
                          data_res=None,
                          thickness=None,
                          downsample=False,
                          blur_range=1.15,
                          crop_channel2=None,
                          apply_bias_field=True,
                          bias_field_std=.3,
                          bias_shape_factor=.025):
    """
    This function builds a keras/tensorflow model to generate images from provided label maps.
    The images are generated by sampling a Gaussian Mixture Model (of given parameters), conditioned on the label map.
    The model will take as inputs:
        -a label map
        -a vector containing the means of the Gaussian Mixture Model for each label,
        -a vector containing the standard deviations of the Gaussian Mixture Model for each label,
        -if apply_affine_deformation is True: a batch*(n_dims+1)*(n_dims+1) affine matrix
        -if apply_non_linear_deformation is True: a small non linear field of size batch*(dim_1*...*dim_n)*n_dims that
        will be resampled to labels size and integrated, to obtain a diffeomorphic elastic deformation.
        -if apply_bias_field is True: a small bias field of size batch*(dim_1*...*dim_n)*1 that will be resampled to
        labels size and multiplied with the image, to add "bias-field" noise.
    The model returns:
        -the generated image normalised between 0 and 1.
        -the corresponding label map, with only the labels present in output_labels (the others are reset to zero).
    :param labels_shape: shape of the input label maps. Can be a sequence or a 1d numpy array.
    :param n_channels: number of channels to be synthesised.
    :param generation_labels: (optional) list of all possible label values in the input label maps.
    Default is None, where the label values are read directly from the provided label maps.
    If not None, can be a sequence or a 1d numpy array. It should be organised as follows: background label first, then
    non-sided labels (e.g. CSF, brainstem, etc.), then all the structures of the same hemisphere (can be left or right),
    and finally all the corresponding contralateral structures (in the same order).
    :param output_labels: list of all the label values to keep in the output label maps, in no particular order.
    Should be a subset of the values contained in generation_labels.
    Label values that are in generation_labels but not in output_labels are reset to zero.
    Can be a sequence or a 1d numpy array.
    :param n_neutral_labels: number of non-sided generation labels.
    :param atlas_res: resolution of the input label maps.
    Can be a number (isotropic resolution), a sequence, or a 1d numpy array.
    :param target_res: target resolution of the generated images and corresponding label maps.
    Can be a number (isotropic resolution), a sequence, or a 1d numpy array.
    :param output_shape: (optional) desired shape of the output image, obtained by randomly cropping the generated image.
    Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or the path to a 1d numpy array.
    :param output_div_by_n: (optional) forces the output shape to be divisible by this value. It overwrites output_shape
    if necessary. Can be an integer (same size in all dimensions), a sequence, or a 1d numpy array.
    :param padding_margin: (optional) margin by which to pad the input labels with zeros.
    Padding is applied prior to any other operation.
    Can be an integer (same padding in all dimensions), a sequence, or a 1d numpy array. Default is no padding.
    :param flipping: (optional) whether to introduce right/left random flipping
    :param aff: (optional) example of an (n_dims+1)x(n_dims+1) affine matrix of one of the input label maps.
    Used to find brain's right/left axis. Should be given if flipping is True.
    :param apply_linear_trans: (optional) whether to linearly deform the input label maps prior to generation.
    If true, the model will take an additional input of size batch*(n_dims+1)*(n_dims+1). Default is True.
    :param apply_nonlin_trans: (optional) whether to non-linearly deform the input label maps prior to generation.
    If true, the model will take an additional input of size batch*(dim_1*...*dim_n)*n_dims. Default is True.
    :param nonlin_std: (optional) If apply_nonlin_trans is True, maximum value for the standard deviation of the normal
    distribution from which we sample the first tensor for synthesising the deformation field.
    :param nonlin_shape_factor: (optional) if apply_non_linear_deformation is True, factor between the shapes of the
    input label maps and the shape of the input non-linear tensor.
    :param blur_background: (optional) If True, the background is blurred with the other labels, and can be reset to
    zero with a probability of 0.2. If False, the background is not blurred (we apply an edge blurring correction), and
    can be replaced by a low-intensity background.
    :param data_res: (optional) acquisition resolution to mimic. If provided, the images sampled from the GMM are
    blurred to mimic data that would be: 1) acquired at the given acquisition resolution, and 2) resampled at
    target_resolution.
    Default is None, where images are isotropically blurred to introduce some spatial correlation between voxels.
    If the generated images are uni-modal, data_res can be a number (isotropic acquisition resolution), a sequence, a 1d
    numpy array, or the path to a 1d numpy array. In the multi-modal case, it should be given as a numpy array (or a
    path) of size (n_mod, n_dims), where each row is the acquisition resolution of the corresponding channel.
    :param thickness: (optional) if data_res is provided, we can further specify the slice thickness of the low
    resolution images to mimic.
    If the generated images are uni-modal, thickness can be a number (isotropic thickness), a sequence, a 1d
    numpy array, or the path to a 1d numpy array. In the multi-modal case, it should be given as a numpy array (or a
    path) of size (n_mod, n_dims), where each row is the slice thickness of the corresponding channel.
    :param downsample: (optional) whether to actually downsample the volume image to data_res.
    Default is False, except when thickness is provided, and thickness < data_res.
    :param blur_range: (optional) Randomise the standard deviation of the blurring kernels, (whether data_res is given
    or not). At each mini_batch, the standard deviation of the blurring kernels are multiplied by a coefficient sampled
    from a uniform distribution with bounds [1/blur_range, blur_range]. If None, no randomisation. Default is 1.15.
    :param crop_channel2: (optional) stats for cropping second channel along the anterior-posterior axis.
    Should be a vector of length 4, with bounds of uniform distribution for cropping the front and back of the image
    (in percentage). None means no cropping.
    :param apply_bias_field: (optional) whether to apply a bias field to the generated image.
    If true, the model will take an additional input of size batch*(dim_1*...*dim_n)*1. Default is True.
    :param bias_field_std: (optional) If apply_bias_field is True, maximum value for the standard deviation of the
    normal distribution from which we sample the first tensor for synthesising the bias field.
    :param bias_shape_factor: (optional) if apply_bias_field is True, factor between the shapes of the
    input label maps and the shape of the input bias field tensor.
    """

    # reformat resolutions
    labels_shape = utils.reformat_to_list(labels_shape)
    n_dims, _ = utils.get_dims(labels_shape)
    atlas_res = utils.reformat_to_n_channels_array(atlas_res,
                                                   n_dims=n_dims,
                                                   n_channels=n_channels)
    if data_res is None:  # data_res assumed to be the same as the atlas
        data_res = atlas_res
    else:
        data_res = utils.reformat_to_n_channels_array(data_res,
                                                      n_dims=n_dims,
                                                      n_channels=n_channels)
    atlas_res = atlas_res[0]
    if downsample:  # same as data_res if we want to actually downsample the synthetic image
        downsample_res = data_res
    else:  # set downsample_res to None if downsampling is not necessary
        downsample_res = None
    if target_res is None:
        target_res = atlas_res
    else:
        target_res = utils.reformat_to_n_channels_array(target_res, n_dims)[0]
    thickness = utils.reformat_to_n_channels_array(thickness,
                                                   n_dims=n_dims,
                                                   n_channels=n_channels)

    # get shapes
    crop_shape, output_shape, padding_margin = get_shapes(
        labels_shape, output_shape, atlas_res, target_res, padding_margin,
        output_div_by_n)

    # create new_label_list and corresponding LUT to make sure that labels go from 0 to N-1
    n_generation_labels = generation_labels.shape[0]
    new_generation_label_list, lut = utils.rearrange_label_list(
        generation_labels)

    # define model inputs
    labels_input = KL.Input(shape=labels_shape + [1], name='labels_input')
    means_input = KL.Input(shape=list(new_generation_label_list.shape) +
                           [n_channels],
                           name='means_input')
    std_devs_input = KL.Input(shape=list(new_generation_label_list.shape) +
                              [n_channels],
                              name='std_devs_input')
    list_inputs = [labels_input, means_input, std_devs_input]
    if apply_linear_trans:
        aff_in = KL.Input(shape=(n_dims + 1, n_dims + 1), name='aff_input')
        list_inputs.append(aff_in)
    else:
        aff_in = None

    # convert labels to new_label_list
    labels = l2i_et.convert_labels(labels_input, lut)

    # pad labels
    if padding_margin is not None:
        pad = np.transpose(np.array([[0] + padding_margin + [0]] * 2))
        labels = KL.Lambda(lambda x: tf.pad(
            x, tf.cast(tf.convert_to_tensor(pad), dtype='int32')),
                           name='pad')(labels)
        labels_shape = labels.get_shape().as_list()[1:n_dims + 1]

    # deform labels
    if apply_linear_trans | apply_nonlin_trans:
        labels = l2i_sp.deform_tensor(labels, aff_in, apply_nonlin_trans,
                                      'nearest', nonlin_std,
                                      nonlin_shape_factor)
    labels = KL.Lambda(lambda x: tf.cast(x, dtype='int32'))(labels)

    # cropping
    if crop_shape != labels_shape:
        labels, _ = l2i_sp.random_cropping(labels, crop_shape, n_dims)

    if flipping:
        assert aff is not None, 'aff should not be None if flipping is True'
        labels, _ = l2i_sp.label_map_random_flipping(
            labels, new_generation_label_list, n_neutral_labels, aff, n_dims)

    # build synthetic image
    image = l2i_gmm.sample_gmm_conditioned_on_labels(labels, means_input,
                                                     std_devs_input,
                                                     n_generation_labels,
                                                     n_channels)

    # loop over channels
    if n_channels > 1:
        split = KL.Lambda(lambda x: tf.split(x, [1] * n_channels, axis=-1))(
            image)
    else:
        split = [image]
    mask = KL.Lambda(
        lambda x: tf.where(tf.greater(x, 0), tf.ones_like(x, dtype='float32'),
                           tf.zeros_like(x, dtype='float32')))(labels)
    processed_channels = list()
    for i, channel in enumerate(split):

        # reset edges of second channels to zero
        if (crop_channel2 is not None) & (
                i == 1):  # randomly crop sides of second channel
            channel, tmp_mask = l2i_sp.restrict_tensor(
                channel, axes=3, boundaries=crop_channel2)
        else:
            tmp_mask = None

        # blur channel
        if thickness is not None:
            sigma = utils.get_std_blurring_mask_for_downsampling(
                data_res[i], atlas_res, thickness=thickness[i])
        else:
            sigma = utils.get_std_blurring_mask_for_downsampling(
                data_res[i], atlas_res)
        kernels_list = l2i_et.get_gaussian_1d_kernels(
            sigma, blurring_range=blur_range)
        channel = l2i_et.blur_channel(channel, mask, kernels_list, n_dims,
                                      blur_background)
        if (crop_channel2 is not None) & (i == 1):
            channel = KL.multiply([channel, tmp_mask])

        # resample channel
        if downsample_res is not None:
            channel = l2i_et.resample_tensor(channel,
                                             output_shape,
                                             'linear',
                                             downsample_res[i],
                                             atlas_res,
                                             n_dims=n_dims)
        else:
            if thickness is not None:
                diff = [
                    thickness[i][dim_idx] - data_res[i][dim_idx]
                    for dim_idx in range(n_dims)
                ]
                if min(diff) < 0:
                    channel = l2i_et.resample_tensor(channel,
                                                     output_shape,
                                                     'linear',
                                                     data_res[i],
                                                     atlas_res,
                                                     n_dims=n_dims)
                else:
                    channel = l2i_et.resample_tensor(channel, output_shape,
                                                     'linear', None, atlas_res,
                                                     n_dims)

        # apply bias field
        if apply_bias_field:
            channel = l2i_ia.bias_field_augmentation(channel, bias_field_std,
                                                     bias_shape_factor)

        # intensity augmentation
        channel = KL.Lambda(lambda x: K.clip(x, 0, 300))(channel)
        channel = l2i_ia.min_max_normalisation(channel)
        processed_channels.append(l2i_ia.gamma_augmentation(channel, std=0.5))

    # concatenate all channels back
    if n_channels > 1:
        image = KL.concatenate(processed_channels)
    else:
        image = processed_channels[0]

    # resample labels at target resolution
    if crop_shape != output_shape:
        labels = KL.Lambda(lambda x: tf.cast(x, dtype='float32'))(labels)
        labels = l2i_et.resample_tensor(labels,
                                        output_shape,
                                        interp_method='nearest',
                                        n_dims=3)
    # convert labels back to original values and reset unwanted labels to zero
    labels = l2i_et.convert_labels(labels, generation_labels)
    labels_to_reset = [
        lab for lab in generation_labels if lab not in output_labels
    ]
    labels = l2i_et.reset_label_values_to_zero(labels, labels_to_reset)
    labels = KL.Lambda(lambda x: tf.cast(x, dtype='int32'),
                       name='labels_out')(labels)

    # build model (dummy layer enables to keep the labels when plugging this model to other models)
    image = KL.Lambda(lambda x: x[0], name='image_out')([image, labels])
    brain_model = keras.Model(inputs=list_inputs, outputs=[image, labels])

    return brain_model
def build_model_input_generator(images_paths,
                                labels_paths,
                                n_channels,
                                im_shape,
                                scaling_range=None,
                                rotation_range=None,
                                shearing_range=None,
                                nonlin_shape_fact=0.0625,
                                nonlin_std_dev=3,
                                batch_size=1):

    # Generate!
    while True:

        # randomly pick as many images as batch_size
        indices = npr.randint(len(images_paths), size=batch_size)

        # initialise input tensors
        images_all = []
        labels_all = []
        aff_all = []
        nonlinear_field_all = []

        for idx in indices:

            # add image
            image = utils.load_volume(images_paths[idx])
            if n_channels > 1:
                images_all.append(utils.add_axis(image, axis=0))
            else:
                images_all.append(utils.add_axis(image, axis=-2))

            # add labels
            labels = utils.load_volume(labels_paths[idx], dtype='int')
            labels_all.append(utils.add_axis(labels, axis=-2))

            # get affine transformation: rotate, scale, shear (translation done during random cropping)
            n_dims, _ = utils.get_dims(im_shape)
            scaling = utils.draw_value_from_distribution(scaling_range, size=n_dims, centre=1, default_range=.15)
            if n_dims == 2:
                rotation_angle = utils.draw_value_from_distribution(rotation_range, default_range=15.0)
            else:
                rotation_angle = utils.draw_value_from_distribution(rotation_range, size=n_dims, default_range=15.0)
            shearing = utils.draw_value_from_distribution(shearing_range, size=n_dims ** 2 - n_dims, default_range=.01)
            aff = utils.create_affine_transformation_matrix(n_dims, scaling, rotation_angle, shearing)
            aff_all.append(utils.add_axis(aff))

            # add non linear field
            deform_shape = utils.get_resample_shape(im_shape, nonlin_shape_fact, len(im_shape))
            nonlinear_field = npr.normal(loc=0, scale=nonlin_std_dev * npr.rand(), size=deform_shape)
            nonlinear_field_all.append(utils.add_axis(nonlinear_field))

        # build list of inputs of the augmentation model
        inputs_vals = [images_all, labels_all, aff_all, nonlinear_field_all]

        # put images and labels (concatenated if batch_size>1) into a tuple of 2 elements: (cat_images, cat_labels)
        if batch_size > 1:
            inputs_vals = [np.concatenate(item, 0) for item in inputs_vals]
        else:
            inputs_vals = [item[0] for item in inputs_vals]

        yield inputs_vals
def labels_to_image_model(im_shape,
                          n_channels,
                          crop_shape,
                          label_list,
                          n_neutral_labels,
                          vox2ras,
                          nonlin_shape_factor=0.0625,
                          crop_channel2=None,
                          output_div_by_n=None,
                          flipping=True):

    # get shapes
    n_dims, _ = utils.get_dims(im_shape)
    crop_shape = get_shapes(crop_shape, im_shape, output_div_by_n)
    deformation_field_size = utils.get_resample_shape(im_shape, nonlin_shape_factor, len(im_shape))

    # create new_label_list and corresponding LUT to make sure that labels go from 0 to N-1
    new_label_list, lut = utils.rearrange_label_list(label_list)

    # define mandatory inputs
    image_input = KL.Input(shape=im_shape+[n_channels], name='image_input')
    labels_input = KL.Input(shape=im_shape + [1], name='labels_input')
    aff_in = KL.Input(shape=(n_dims + 1, n_dims + 1), name='aff_input')
    nonlin_field_in = KL.Input(shape=deformation_field_size, name='nonlin_input')
    list_inputs = [image_input, labels_input, aff_in, nonlin_field_in]

    # convert labels to new_label_list
    labels = KL.Lambda(lambda x: tf.gather(tf.convert_to_tensor(lut, dtype='int32'),
                                           tf.cast(x, dtype='int32')))(labels_input)

    # deform labels
    image_input._keras_shape = tuple(image_input.get_shape().as_list())
    labels._keras_shape = tuple(labels.get_shape().as_list())
    labels = KL.Lambda(lambda x: tf.cast(x, dtype='float'))(labels)
    resize_shape = [max(int(im_shape[i] / 2), deformation_field_size[i]) for i in range(len(im_shape))]
    nonlin_field = nrn_layers.Resize(size=resize_shape, interp_method='linear')(nonlin_field_in)
    nonlin_field = nrn_layers.VecInt()(nonlin_field)
    nonlin_field = nrn_layers.Resize(size=im_shape, interp_method='linear')(nonlin_field)
    image = nrn_layers.SpatialTransformer(interp_method='linear')([image_input, aff_in, nonlin_field])
    labels = nrn_layers.SpatialTransformer(interp_method='nearest')([labels, aff_in, nonlin_field])
    labels = KL.Lambda(lambda x: tf.cast(x, dtype='int32'))(labels)

    # cropping
    if crop_shape is not None:
        image, crop_idx = l2i_sa.random_cropping(image, crop_shape, n_dims)
        labels = KL.Lambda(lambda x: tf.slice(x[0], begin=tf.cast(x[1], dtype='int32'),
                           size=tf.convert_to_tensor([-1] + crop_shape + [-1], dtype='int32')))([labels, crop_idx])
    else:
        crop_shape = im_shape

    # flipping
    if flipping:
        labels, flip = l2i_sa.label_map_random_flipping(labels, label_list, n_neutral_labels, vox2ras, n_dims)
        ras_axes, _ = edit_volumes.get_ras_axes_and_signs(vox2ras, n_dims)
        flip_axis = [ras_axes[0] + 1]
        image = KL.Lambda(lambda y: K.switch(y[0],
                                             KL.Lambda(lambda x: K.reverse(x, axes=flip_axis))(y[1]),
                                             y[1]))([flip, image])

    # convert labels back to original values
    labels = KL.Lambda(lambda x: tf.gather(tf.convert_to_tensor(label_list, dtype='int32'),
                                           tf.cast(x, dtype='int32')), name='labels_out')(labels)

    # intensity augmentation
    image = KL.Lambda(lambda x: K.clip(x, 0, 300), name='clipping')(image)

    # loop over channels
    if n_channels > 1:
        split = KL.Lambda(lambda x: tf.split(x, [1] * n_channels, axis=-1))(image)
    else:
        split = [image]
    processed_channels = list()
    for i, channel in enumerate(split):

        # normalise and shift intensities
        channel = l2i_ia.min_max_normalisation(channel)
        channel = KL.Lambda(lambda x: K.random_uniform((1,), .85, 1.1) * x + K.random_uniform((1,), -.3, .3))(channel)
        channel = KL.Lambda(lambda x: K.clip(x, 0, 1))(channel)
        channel = l2i_ia.gamma_augmentation(channel)

        # randomly crop sides of second channel
        if (crop_channel2 is not None) & (i == 1):
            channel = l2i_sa.restrict_tensor(channel, crop_channel2, n_dims)
        processed_channels.append(channel)

    # concatenate all channels back, and clip output (include labels to keep it when plugging to other models)
    if n_channels > 1:
        image = KL.concatenate(processed_channels)
    else:
        image = processed_channels[0]
    image = KL.Lambda(lambda x: K.clip(x[0], 0, 1), name='image_out')([image, labels])

    # build model
    brain_model = Model(inputs=list_inputs, outputs=[image, labels])
    # shape of returned images
    output_shape = image.get_shape().as_list()[1:]

    return brain_model, output_shape
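
# Hedged sketch (illustration only) of how build_model_input_generator above pairs with
# this model: the generator yields [images, labels, affines, non-linear fields], which
# matches list_inputs = [image_input, labels_input, aff_in, nonlin_field_in].
#
#     gen = build_model_input_generator(images_paths, labels_paths, n_channels=1,
#                                       im_shape=[160, 160, 192], batch_size=1)
#     model, out_shape = labels_to_image_model([160, 160, 192], 1, crop_shape=None,
#                                              label_list=label_list, n_neutral_labels=18,
#                                              vox2ras=np.eye(4))
#     image, labels = model.predict(next(gen))
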
# Example 9
def metrics_model(input_shape,
                  segmentation_label_list,
                  input_model=None,
                  loss_cropping=None,
                  metrics='dice',
                  weight_background=None,
                  include_background=False,
                  name=None,
                  prefix=None,
                  validation_on_real_images=False):

    # naming the model
    model_name = name
    if prefix is None:
        prefix = model_name

    # first layer: input
    name = '%s_input' % prefix
    if input_model is None:
        input_tensor = KL.Input(shape=input_shape, name=name)
        last_tensor = input_tensor
    else:
        input_tensor = input_model.inputs
        last_tensor = input_model.outputs
        if isinstance(last_tensor, list):
            last_tensor = last_tensor[0]
        last_tensor = KL.Reshape(input_shape,
                                 name='predicted_output')(last_tensor)

    # get deformed labels
    n_labels = input_shape[-1]
    if validation_on_real_images:
        labels_gt = KL.Input(shape=input_shape[:-1] + [1], name='labels_input')
        input_tensor = [input_tensor[0], labels_gt]
    else:
        labels_gt = input_model.get_layer('labels_out').output

    # convert gt labels to 0...N-1 values
    n_labels = segmentation_label_list.shape[0]
    _, lut = utils.rearrange_label_list(segmentation_label_list)
    labels_gt = l2i_et.convert_labels(labels_gt, lut)

    # convert gt labels to probabilistic values
    labels_gt = KL.Lambda(lambda x: tf.one_hot(
        tf.cast(x, dtype='int32'), depth=n_labels, axis=-1))(labels_gt)
    labels_gt = KL.Reshape(input_shape)(labels_gt)
    labels_gt = KL.Lambda(lambda x: K.clip(
        x / K.sum(x, axis=-1, keepdims=True), K.epsilon(), 1))(labels_gt)

    # crop output to evaluate loss function in centre patch
    if loss_cropping is not None:
        # format loss_cropping
        labels_shape = labels_gt.get_shape().as_list()[1:-1]
        n_dims, _ = utils.get_dims(labels_shape)
        loss_cropping = [-1] + utils.reformat_to_list(loss_cropping,
                                                      length=n_dims) + [-1]
        # perform cropping
        # loss_cropping was padded with a leading -1 above, hence the i + 1 offset
        begin_idx = [0] + [
            int((labels_shape[i] - loss_cropping[i + 1]) / 2)
            for i in range(n_dims)
        ] + [0]
        labels_gt = KL.Lambda(lambda x: tf.slice(
            x,
            begin=tf.convert_to_tensor(begin_idx, dtype='int32'),
            size=tf.convert_to_tensor(loss_cropping, dtype='int32')))(
                labels_gt)
        last_tensor = KL.Lambda(lambda x: tf.slice(
            x,
            begin=tf.convert_to_tensor(begin_idx, dtype='int32'),
            size=tf.convert_to_tensor(loss_cropping, dtype='int32')))(
                last_tensor)

    # metrics is computed as part of the model
    if metrics == 'dice':

        # make sure predicted values are probabilistic
        last_tensor = KL.Lambda(lambda x: K.clip(
            x / K.sum(x, axis=-1, keepdims=True), K.epsilon(), 1))(last_tensor)

        # compute dice
        top = KL.Lambda(lambda x: 2 * x[0] * x[1])([labels_gt, last_tensor])
        bottom = KL.Lambda(lambda x: K.square(x[0]) + K.square(x[1]))(
            [labels_gt, last_tensor])
        for dims_to_sum in range(len(input_shape) - 1):
            top = KL.Lambda(lambda x: K.sum(x, axis=1))(top)
            bottom = KL.Lambda(lambda x: K.sum(x, axis=1))(bottom)
        last_tensor = KL.Lambda(lambda x: x[0] / K.maximum(x[1], 0.001),
                                name='dice')([top, bottom])  # 1d vector

        # compute mean dice loss
        if include_background:
            w = np.ones([n_labels]) / n_labels
        else:
            w = np.ones([n_labels]) / (n_labels - 1)
            w[0] = 0.0
        last_tensor = KL.Lambda(lambda x: 1 - x, name='dice_loss')(last_tensor)
        last_tensor = KL.Lambda(lambda x: K.sum(
            x * tf.convert_to_tensor(w, dtype='float32'), axis=1),
                                name='mean_dice_loss')(last_tensor)
        # average mean dice loss over mini batch
        last_tensor = KL.Lambda(lambda x: K.mean(x),
                                name='average_mean_dice_loss')(last_tensor)

    elif metrics == 'wl2':
        # compute weighted l2 loss
        weights = KL.Lambda(lambda x: K.expand_dims(1 - x[
            ..., 0] + weight_background))(labels_gt)
        normaliser = KL.Lambda(lambda x: K.sum(x[0]) * K.int_shape(x[1])[-1])(
            [weights, last_tensor])
        last_tensor = KL.Lambda(
            # lambda x: K.sum(x[2] * K.square(x[1] - (x[0] * 30 - 15))) / x[3],
            lambda x: K.sum(x[2] * K.square(x[1] - (x[0] * 6 - 3))) / x[3],
            name='wl2')([labels_gt, last_tensor, weights, normaliser])

    else:
        raise Exception(
            'metrics should either be "dice" or "wl2", got {}'.format(metrics))

    # create the model and return
    model = Model(inputs=input_tensor, outputs=last_tensor, name=model_name)
    return model
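
# Because the dice/wl2 loss is computed inside the model itself (the model's output is
# the scalar loss value), a plausible way to train it is with an identity loss that
# simply returns the prediction. This is a common pattern for such models, shown here
# as an assumption rather than the original training code:
#
#     model = metrics_model(input_shape=[160, 160, 192, 32],
#                           segmentation_label_list=seg_labels,
#                           input_model=generative_model, metrics='dice', name='metrics')
#     model.compile(optimizer='adam', loss=lambda y_true, y_pred: y_pred)
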
# Example 10
def sample_intensity_stats_from_single_dataset(image_dir, labels_dir, labels_list, classes_list=None, max_channel=3,
                                               rescale=True):
    """This function aims at estimating the intensity distributions of K different structure types from a set of images.
    The distribution of each structure type is modelled as a Gaussian, parametrised by a mean and a standard deviation.
    Because the intensity distribution of structures can vary across images, we additionally use Gaussian priors for the
    parameters of each Gaussian distribution. Therefore, the intensity distribution of each structure type is described
    by 4 parameters: a mean/std for the mean intensity, and a mean/std for the std deviation.
    This function uses a set of images along with corresponding segmentations to estimate the 4*K parameters.
    Structures can share the same statistics by being regrouped into classes of similar structure types.
    Images can be multi-modal (n_channels), in which case different statistics are estimated for each modality.
    :param image_dir: path of directory with images to estimate the intensity distribution
    :param labels_dir: path of directory with segmentation of input images.
    They are matched with images by sorting order.
    :param labels_list: list of labels for which to evaluate mean and std intensity.
    Can be a sequence, a 1d numpy array, or the path to a 1d numpy array.
    :param classes_list: (optional) enables regrouping structures into classes of similar intensity statistics.
    Intensities associated with regrouped labels will thus contribute to the same Gaussian during statistics estimation.
    Can be a sequence, a 1d numpy array, or the path to a 1d numpy array.
    It should have the same length as labels_list, and contain values between 0 and K-1, where K is the total number of
    classes. Default is all labels have different classes (K=len(labels_list)).
    :param max_channel: (optional) maximum number of channels to consider if the data is multispectral. Default is 3.
    :param rescale: (optional) whether to rescale images between 0 and 255 before intensity estimation
    :return: 2 numpy arrays of size (2*n_channels, K), one with the evaluated means/std for the mean
    intensity, and one for the mean/std for the standard deviation.
    Each block of two rows correspond to a different modality (channel). For each block of two rows, the first row
    represents the mean, and the second represents the std.
    """

    # list files
    path_images = utils.list_images_in_folder(image_dir)
    path_labels = utils.list_images_in_folder(labels_dir)
    assert len(path_images) == len(path_labels), 'image and labels folders do not have the same number of files'

    # reformat list labels and classes
    labels_list = np.array(utils.reformat_to_list(labels_list, load_as_numpy=True, dtype='int'))
    if classes_list is not None:
        classes_list = np.array(utils.reformat_to_list(classes_list, load_as_numpy=True, dtype='int'))
    else:
        classes_list = np.arange(labels_list.shape[0])
    assert len(classes_list) == len(labels_list), 'labels and classes lists should have the same length'

    # get unique classes
    unique_classes, unique_indices = np.unique(classes_list, return_index=True)
    n_classes = len(unique_classes)
    if not np.array_equal(unique_classes, np.arange(n_classes)):
        raise ValueError('classes_list should only contain values between 0 and K-1, '
                         'where K is the total number of classes. Here K = %d' % n_classes)

    # initialise result arrays
    n_dims, n_channels = utils.get_dims(utils.load_volume(path_images[0]).shape, max_channels=max_channel)
    means = np.zeros((len(path_images), n_classes, n_channels))
    stds = np.zeros((len(path_images), n_classes, n_channels))

    # loop over images
    loop_info = utils.LoopInfo(len(path_images), 10, 'estimating', print_time=True)
    for idx, (path_im, path_la) in enumerate(zip(path_images, path_labels)):
        loop_info.update(idx)

        # load image and label map
        image = utils.load_volume(path_im)
        la = utils.load_volume(path_la)
        if n_channels == 1:
            image = utils.add_axis(image, -1)

        # loop over channels
        for channel in range(n_channels):
            im = image[..., channel]
            if rescale:
                im = edit_volumes.rescale_volume(im)
            stats = sample_intensity_stats_from_image(im, la, labels_list, classes_list=classes_list)
            means[idx, :, channel] = stats[0, :]
            stds[idx, :, channel] = stats[1, :]

    # compute prior parameters for mean/std
    mean_means = np.mean(means, axis=0)
    std_means = np.std(means, axis=0)
    mean_stds = np.mean(stds, axis=0)
    std_stds = np.std(stds, axis=0)

    # regroup prior parameters in two different arrays: one for the mean and one for the std
    prior_means = np.zeros((2 * n_channels, n_classes))
    prior_stds = np.zeros((2 * n_channels, n_classes))
    for channel in range(n_channels):
        prior_means[2 * channel, :] = mean_means[:, channel]
        prior_means[2 * channel + 1, :] = std_means[:, channel]
        prior_stds[2 * channel, :] = mean_stds[:, channel]
        prior_stds[2 * channel + 1, :] = std_stds[:, channel]

    return prior_means, prior_stds
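# Hedged usage sketch (directory names and label values are assumptions, not
# part of this codebase): estimate GMM priors where labels 17 and 53 (e.g.
# left/right hippocampus in FreeSurfer labelling) share one class, so K=2.
#
# prior_means, prior_stds = sample_intensity_stats_from_single_dataset(
#     image_dir='data/images',
#     labels_dir='data/labels',
#     labels_list=[0, 17, 53],
#     classes_list=[0, 1, 1])
# # prior_means[2 * c] / prior_means[2 * c + 1]: mean / std across images of
# # the class means for channel c; prior_stds is laid out identically for the
# # class standard deviations.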
def build_augmentation_model(im_shape,
                             n_channels,
                             label_list,
                             image_res,
                             target_res=None,
                             output_shape=None,
                             output_div_by_n=None,
                             n_neutral_labels=1,
                             flipping=True,
                             flip_rl_only=False,
                             aff=None,
                             scaling_bounds=0.15,
                             rotation_bounds=15,
                             enable_90_rotations=False,
                             shearing_bounds=0.012,
                             translation_bounds=False,
                             nonlin_std=3.,
                             nonlin_shape_factor=.0625,
                             bias_field_std=.3,
                             bias_shape_factor=0.025,
                             same_bias_for_all_channels=False,
                             apply_intensity_augmentation=True,
                             noise_std=1.,
                             augment_channels_separately=True):

    # reformat resolutions
    im_shape = utils.reformat_to_list(im_shape)
    n_dims, _ = utils.get_dims(im_shape)
    image_res = utils.reformat_to_list(image_res, length=n_dims)
    target_res = image_res if target_res is None else utils.reformat_to_list(
        target_res, length=n_dims)

    # get shapes
    cropping_shape, output_shape = get_shapes(im_shape, output_shape,
                                              image_res, target_res,
                                              output_div_by_n)
    im_shape = im_shape + [n_channels]

    # create new_label_list and corresponding LUT to make sure that labels go from 0 to N-1
    new_label_list, lut = utils.rearrange_label_list(label_list)

    # define model inputs
    image_input = KL.Input(shape=im_shape, name='image_input')
    labels_input = KL.Input(shape=im_shape[:-1] + [1],
                            name='labels_input',
                            dtype='int32')

    # convert labels to new_label_list
    labels = l2i_et.convert_labels(labels_input, lut)

    # flipping
    if flipping:
        if flip_rl_only:
            labels, image = layers.RandomFlip(
                int(edit_volumes.get_ras_axes(aff, n_dims)[0]), [True, False],
                new_label_list, n_neutral_labels)([labels, image_input])
        else:
            labels, image = layers.RandomFlip(
                None, [True, False], new_label_list,
                n_neutral_labels)([labels, image_input])
    else:
        image = image_input

    # convert labels to one-hot probability maps and concatenate them to the image
    labels = KL.Lambda(lambda x: tf.one_hot(
        tf.cast(x[..., 0], dtype='int32'), depth=len(label_list), axis=-1))(
            labels)
    image = KL.concatenate([image, labels], axis=len(im_shape))

    # spatial deformation
    if (scaling_bounds is not False) | (rotation_bounds is not False) | (shearing_bounds is not False) | \
       (translation_bounds is not False) | (nonlin_std > 0) | enable_90_rotations:
        image._keras_shape = tuple(image.get_shape().as_list())
        image = layers.RandomSpatialDeformation(
            scaling_bounds=scaling_bounds,
            rotation_bounds=rotation_bounds,
            shearing_bounds=shearing_bounds,
            translation_bounds=translation_bounds,
            enable_90_rotations=enable_90_rotations,
            nonlin_std=nonlin_std,
            nonlin_shape_factor=nonlin_shape_factor)(image)

    # cropping
    if cropping_shape != im_shape[:-1]:
        image._keras_shape = tuple(image.get_shape().as_list())
        image = layers.RandomCrop(cropping_shape)(image)

    # resampling (image blurred separately)
    if cropping_shape != output_shape:
        sigma = l2i_et.blurring_sigma_for_downsampling(image_res, target_res)
        split = KL.Lambda(
            lambda x: tf.split(x, [n_channels, -1], axis=len(im_shape)))(image)
        image = split[0]
        image._keras_shape = tuple(image.get_shape().as_list())
        image = layers.GaussianBlur(sigma=sigma)(image)
        image = KL.concatenate([image, split[-1]])
        image = l2i_et.resample_tensor(image, output_shape)

    # split tensor between image and labels
    image, labels = KL.Lambda(
        lambda x: tf.split(x, [n_channels, -1], axis=len(im_shape)),
        name='splitting')(image)

    # apply bias field
    if bias_field_std > 0:
        image._keras_shape = tuple(image.get_shape().as_list())
        image = layers.BiasFieldCorruption(bias_field_std, bias_shape_factor,
                                           same_bias_for_all_channels)(image)

    # intensity augmentation
    if apply_intensity_augmentation:
        image._keras_shape = tuple(image.get_shape().as_list())
        image = layers.IntensityAugmentation(
            noise_std,
            gamma_std=0.5,
            separate_channels=augment_channels_separately)(image)

    # build model
    im_trans_model = Model(inputs=[image_input, labels_input],
                           outputs=[image, labels])

    return im_trans_model
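# Illustrative call (shapes, resolution and label values are assumptions):
# build an augmenter for single-channel 160^3 volumes at 1 mm isotropic and
# run it on a batch of images with their label maps.
#
# aug_model = build_augmentation_model(im_shape=[160, 160, 160],
#                                      n_channels=1,
#                                      label_list=np.arange(10),
#                                      image_res=1.)
# image_aug, labels_aug = aug_model.predict([image_batch, labels_batch])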
# Example 12
def build_model(model_file, input_shape, resample, im_res, n_levels, n_lab, conv_size, nb_conv_per_level,
                unet_feat_count, feat_multiplier, no_batch_norm, activation, sigma_smoothing):

    # initialisation
    net = None
    n_dims, n_channels = utils.get_dims(input_shape, max_channels=3)
    resample = utils.reformat_to_list(resample, length=n_dims)

    # build preprocessing model
    if resample is not None:
        im_input = KL.Input(shape=input_shape, name='pre_resample_input')
        resample_factor = [im_res[i] / float(resample[i]) for i in range(n_dims)]
        resample_shape = [utils.find_closest_number_divisible_by_m(resample_factor[i] * input_shape[i],
                          2 ** n_levels, smaller_ans=False) for i in range(n_dims)]
        resampled = nrn_layers.Resize(size=resample_shape, name='pre_resample')(im_input)
        net = Model(inputs=im_input, outputs=resampled)
        input_shape = resample_shape + [n_channels]

    # build UNet
    if no_batch_norm:
        batch_norm_dim = None
    else:
        batch_norm_dim = -1
    net = nrn_models.unet(nb_features=unet_feat_count,
                          input_shape=input_shape,
                          nb_levels=n_levels,
                          conv_size=conv_size,
                          nb_labels=n_lab,
                          name='unet',
                          prefix=None,
                          feat_mult=feat_multiplier,
                          pool_size=2,
                          use_logp=True,
                          padding='same',
                          dilation_rate_mult=1,
                          activation=activation,
                          use_residuals=False,
                          final_pred_activation='softmax',
                          nb_conv_per_level=nb_conv_per_level,
                          add_prior_layer=False,
                          add_prior_layer_reg=0,
                          layer_nb_feats=None,
                          conv_dropout=0,
                          batch_norm=batch_norm_dim,
                          input_model=net)
    net.load_weights(model_file, by_name=True)

    # build postprocessing model
    if (resample is not None) | (sigma_smoothing != 0):

        # get UNet output
        input_tensor = net.inputs
        last_tensor = net.output

        # resample to initial resolution
        if resample is not None:
            last_tensor = nrn_layers.Resize(size=input_shape[:-1], name='post_resample')(last_tensor)

        # smooth each posterior map separately
        if sigma_smoothing != 0:
            kernels_list = l2i_et.get_gaussian_1d_kernels(utils.reformat_to_list(sigma_smoothing, length=n_dims))
            split = KL.Lambda(lambda x: tf.split(x, [1] * n_lab, axis=-1), name='resample_split')(last_tensor)
            last_tensor = l2i_et.blur_tensor(split[0], kernels_list, n_dims)
            for i in range(1, n_lab):
                temp_blurred = l2i_et.blur_tensor(split[i], kernels_list, n_dims)
                last_tensor = KL.concatenate([last_tensor, temp_blurred], axis=-1, name='cat_blurring_%s' % i)

        # build model
        net = Model(inputs=input_tensor, outputs=last_tensor)

    return net
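# Sketch of a typical call (file name, shapes and architecture values are
# illustrative assumptions, not defaults of this codebase):
#
# net = build_model(model_file='models/seg_weights.h5',
#                   input_shape=[160, 160, 160, 1],
#                   resample=None,
#                   im_res=[1., 1., 1.],
#                   n_levels=5, n_lab=10, conv_size=3, nb_conv_per_level=2,
#                   unet_feat_count=24, feat_multiplier=2,
#                   no_batch_norm=False, activation='elu',
#                   sigma_smoothing=0.5)
# posteriors = net.predict(image_batch)  # (1, *vol_shape, n_lab) soft maps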
# Example 13
def build_augmentation_model(im_shape,
                             n_channels,
                             segmentation_labels,
                             n_neutral_labels,
                             atlas_res,
                             target_res,
                             output_shape=None,
                             output_div_by_n=None,
                             flipping=True,
                             aff=None,
                             scaling_bounds=0.15,
                             rotation_bounds=15,
                             shearing_bounds=0.012,
                             translation_bounds=False,
                             nonlin_std=3.,
                             nonlin_shape_factor=.0625,
                             data_res=None,
                             thickness=None,
                             downsample=False,
                             blur_range=1.03,
                             bias_field_std=.5,
                             bias_shape_factor=.025):

    # reformat resolutions and get shapes
    im_shape = utils.reformat_to_list(im_shape)
    n_dims, _ = utils.get_dims(im_shape)
    if data_res is not None:
        data_res = utils.reformat_to_n_channels_array(data_res, n_dims,
                                                      n_channels)
        thickness = data_res if thickness is None else utils.reformat_to_n_channels_array(
            thickness, n_dims, n_channels)
        # by default, only downsample channels whose slice thickness is smaller
        # than their voxel spacing (i.e. acquisitions with gaps between slices)
        downsample = utils.reformat_to_list(
            downsample,
            n_channels) if downsample else np.min(thickness - data_res, 1) < 0
        target_res = atlas_res if (
            target_res is None) else utils.reformat_to_n_channels_array(
                target_res, n_dims)[0]
    else:
        target_res = atlas_res

    # get shapes
    crop_shape, output_shape = get_shapes(im_shape, output_shape, atlas_res,
                                          target_res, output_div_by_n)

    # define model inputs
    image_input = KL.Input(shape=im_shape + [n_channels], name='image_input')
    labels_input = KL.Input(shape=im_shape + [1],
                            name='labels_input',
                            dtype='int32')

    # deform labels
    labels, image = layers.RandomSpatialDeformation(
        scaling_bounds=scaling_bounds,
        rotation_bounds=rotation_bounds,
        shearing_bounds=shearing_bounds,
        translation_bounds=translation_bounds,
        nonlin_std=nonlin_std,
        nonlin_shape_factor=nonlin_shape_factor,
        inter_method=['nearest', 'linear'])([labels_input, image_input])

    # cropping
    if crop_shape != im_shape:
        labels._keras_shape = tuple(labels.get_shape().as_list())
        image._keras_shape = tuple(image.get_shape().as_list())
        labels, image = layers.RandomCrop(crop_shape)([labels, image])

    # flipping
    if flipping:
        assert aff is not None, 'aff should not be None if flipping is True'
        labels._keras_shape = tuple(labels.get_shape().as_list())
        image._keras_shape = tuple(image.get_shape().as_list())
        labels, image = layers.RandomFlip(
            get_ras_axes(aff, n_dims)[0], [True, False], segmentation_labels,
            n_neutral_labels)([labels, image])

    # apply bias field
    if bias_field_std > 0:
        image._keras_shape = tuple(image.get_shape().as_list())
        image = layers.BiasFieldCorruption(bias_field_std, bias_shape_factor,
                                           False)(image)

    # intensity augmentation
    image._keras_shape = tuple(image.get_shape().as_list())
    image = layers.IntensityAugmentation(6,
                                         clip=False,
                                         normalise=True,
                                         gamma_std=.4,
                                         separate_channels=True)(image)

    # if necessary, loop over channels to 1) blur, 2) downsample to simulated LR, and 3) upsample to target
    if data_res is not None:
        channels = list()
        split = KL.Lambda(lambda x: tf.split(x, [1] * n_channels, axis=-1))(
            image) if (n_channels > 1) else [image]
        for i, channel in enumerate(split):

            # blur
            channel._keras_shape = tuple(channel.get_shape().as_list())
            sigma = l2i_et.blurring_sigma_for_downsampling(
                atlas_res, data_res[i], thickness=thickness[i])
            channel = layers.GaussianBlur(sigma, blur_range)(channel)

            # resample
            if downsample[i]:
                resolution = KL.Lambda(lambda x: tf.convert_to_tensor(
                    data_res[i], dtype='float32'))([])
                channel = layers.MimicAcquisition(atlas_res, data_res[i],
                                                  output_shape)(
                                                      [channel, resolution])
            elif output_shape != crop_shape:
                channel = nrn_layers.Resize(size=output_shape)(channel)
            channels.append(channel)

        # concatenate all channels back
        image = KL.Lambda(lambda x: tf.concat(x, -1))(
            channels) if len(channels) > 1 else channels[0]

        # resample labels at target resolution
        if crop_shape != output_shape:
            labels = l2i_et.resample_tensor(labels,
                                            output_shape,
                                            interp_method='nearest')

    # build model (the dummy layer keeps the labels in the graph when plugging this model into other models)
    labels = KL.Lambda(lambda x: tf.cast(x, dtype='int32'),
                       name='labels_out')(labels)
    image = KL.Lambda(lambda x: x[0], name='image_out')([image, labels])
    brain_model = models.Model(inputs=[image_input, labels_input],
                               outputs=[image, labels])

    return brain_model
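# Hedged example (all values are assumptions): simulate 3 mm-thick axial
# acquisitions from 1 mm isotropic data; the downsampled channel is then
# brought back to the 1 mm target resolution via MimicAcquisition.
#
# aug_model = build_augmentation_model(im_shape=[160, 160, 160],
#                                      n_channels=1,
#                                      segmentation_labels=np.arange(10),
#                                      n_neutral_labels=4,
#                                      atlas_res=1.,
#                                      target_res=1.,
#                                      aff=np.eye(4),
#                                      data_res=[1., 1., 3.],
#                                      thickness=[1., 1., 3.],
#                                      downsample=True)
# image_aug, labels_aug = aug_model.predict([image_batch, labels_batch])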