Exemplo n.º 1
0
def deform_tensor(tensor,
                  affine_trans=None,
                  apply_elastic_trans=True,
                  interp_method='linear',
                  nonlin_std=2.,
                  nonlin_shape_factor=.0625):
    """This function spatially deforms a tensor with a combination of affine and elastic transformations.
    At least one of the two transformations must be requested.
    :param tensor: input tensor to deform. Expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel].
    :param affine_trans: (optional) tensor of shape [batchsize, n_dims+1, n_dims+1] corresponding to an affine
    transformation. Default is None, no affine transformation is applied.
    :param apply_elastic_trans: (optional) whether to deform the input tensor with a diffeomorphic elastic
    transformation. If True the following steps occur:
    1) a small-size SVF is sampled from a centred normal distribution of random standard deviation.
    2) it is resized with linear interpolation to half the shape of the input tensor (or to the shape of the small
    field, whichever is larger along each axis)
    3) it is integrated to obtain a diffeomorphic transformation
    4) finally, it is resized (again with linear interpolation) to full image size
    Default is True. Any falsy value (False, None, 0) disables the elastic transformation.
    :param interp_method: (optional) interpolation method when deforming the input tensor. Can be 'linear', or 'nearest'
    :param nonlin_std: (optional) maximum value of the standard deviation of the normal distribution from which we
    sample the small-size SVF.
    :param nonlin_shape_factor: (optional) ratio between the shape of the input tensor and the shape of the small field
    for elastic deformation.
    :return: the deformed tensor, as returned by the SpatialTransformer layer.
    :raises AssertionError: if affine_trans is None and apply_elastic_trans is falsy.
    """

    # logical `or` rather than bitwise `|`: the latter raises a TypeError for non-boolean flags such as None
    assert (affine_trans is not None) or apply_elastic_trans, 'affine_trans or elastic_trans should be provided'

    # reformat tensor and get its shape
    tensor = KL.Lambda(lambda x: tf.cast(x, dtype='float32'))(tensor)
    tensor._keras_shape = tuple(tensor.get_shape().as_list())
    volume_shape = tensor.get_shape().as_list()[1: -1]
    n_dims = len(volume_shape)
    trans_inputs = [tensor]

    # add affine deformation to inputs list
    if affine_trans is not None:
        trans_inputs.append(affine_trans)

    # prepare non-linear deformation field and add it to inputs list
    if apply_elastic_trans:

        # sample small field from normal distribution of specified std dev
        small_shape = utils.get_resample_shape(volume_shape, nonlin_shape_factor, n_dims)
        tensor_shape = KL.Lambda(lambda x: tf.shape(x))(tensor)
        # split the dynamic shape into [batchsize] and the remaining dimensions, only the former is reused
        split_shape = KL.Lambda(lambda x: tf.split(x, [1, n_dims + 1]))(tensor_shape)
        nonlin_shape = KL.Lambda(lambda x: tf.concat([x, tf.convert_to_tensor(small_shape)], axis=0))(split_shape[0])
        # the std dev itself is random, drawn uniformly below nonlin_std
        nonlin_std_prior = KL.Lambda(lambda x: tf.random.uniform((1, 1), maxval=nonlin_std))([])
        elastic_trans = KL.Lambda(lambda x: tf.random.normal(x[0], stddev=x[1]))([nonlin_shape, nonlin_std_prior])
        elastic_trans._keras_shape = tuple(elastic_trans.get_shape().as_list())

        # reshape this field to image size and integrate it
        resize_shape = [max(int(volume_shape[i]/2), small_shape[i]) for i in range(n_dims)]
        nonlin_field = nrn_layers.Resize(size=resize_shape, interp_method='linear')(elastic_trans)
        nonlin_field = nrn_layers.VecInt()(nonlin_field)
        nonlin_field = nrn_layers.Resize(size=volume_shape, interp_method='linear')(nonlin_field)
        trans_inputs.append(nonlin_field)

    # apply deformations
    return nrn_layers.SpatialTransformer(interp_method=interp_method)(trans_inputs)
Exemplo n.º 2
0
def resample_tensor(tensor,
                    resample_shape,
                    interp_method='linear',
                    subsample_res=None,
                    volume_res=None):
    """This function resamples a volume to resample_shape. It does not apply any pre-filtering.
    A prior downsampling step can be added if subsample_res is specified. In this case, volume_res should also be
    specified, in order to calculate the downsampling ratio.
    :param tensor: tensor to resample. Its first and last dimensions are treated as batch and channel dimensions.
    :param resample_shape: shape to resample the input tensor to. This can be a list, tuple or numpy array of size
    (n_dims,), where n_dims excludes the batchsize and channels dimensions.
    :param interp_method: interpolation method for resampling, 'linear' or 'nearest'
    :param subsample_res: if not None, this triggers a downsampling of the volume (with nearest neighbour
    interpolation), prior to the resampling step. list or numpy array of size (n_dims,).
    :param volume_res: if subsample_res is not None, this should be provided to compute downsampling ratio.
     list or numpy array of size (n_dims,).
    :return: resampled volume
    :raises AssertionError: if subsample_res is given without volume_res, or if their lengths differ.
    """

    # reformat resolutions to lists
    subsample_res = utils.reformat_to_list(subsample_res)
    volume_res = utils.reformat_to_list(volume_res)
    if subsample_res is not None:
        assert volume_res is not None, 'volume_res must be given when providing a subsampling resolution.'
        assert len(subsample_res) == len(volume_res), 'subsample_res and volume_res must have the same length, ' \
                                                      'had {0}, and {1}'.format(len(subsample_res), len(volume_res))

    # reformat the target shape to a plain list: comparing a numpy array with `!=` is elementwise (ambiguous truth
    # value), and a tuple never compares equal to a list, which would trigger a redundant second resize
    if hasattr(resample_shape, 'tolist'):
        resample_shape = resample_shape.tolist()
    elif isinstance(resample_shape, tuple):
        resample_shape = list(resample_shape)
    n_dims = len(resample_shape)

    # downsample image
    downsample_shape = None
    tensor_shape = tensor.get_shape().as_list()[1:-1]
    if (subsample_res is not None) and (subsample_res != volume_res):

        # get shape at which we downsample
        downsample_shape = [
            int(tensor_shape[i] * volume_res[i] / subsample_res[i])
            for i in range(n_dims)
        ]

        # downsample volume
        tensor._keras_shape = tuple(tensor.get_shape().as_list())
        tensor = nrn_layers.Resize(size=downsample_shape,
                                   interp_method='nearest')(tensor)

    # resample image at target resolution (skipped if the downsampling already produced the target shape)
    if resample_shape != downsample_shape:
        tensor._keras_shape = tuple(tensor.get_shape().as_list())
        tensor = nrn_layers.Resize(size=resample_shape,
                                   interp_method=interp_method)(tensor)

    return tensor
Exemplo n.º 3
0
def resample_tensor(tensor,
                    resample_shape,
                    interp_method='linear',
                    subsample_res=None,
                    volume_res=None,
                    subsample_interp_method='nearest',
                    n_dims=3):
    """This function resamples a volume to resample_shape. It does not apply any pre-filtering.
    A prior downsampling step can be added if subsample_res is specified. In this case, volume_res should also be
    specified, in order to calculate the downsampling ratio.
    :param tensor: tensor to resample. Its first and last dimensions are treated as batch and channel dimensions.
    :param resample_shape: list, tuple or numpy array of size (n_dims,)
    :param interp_method: interpolation method for resampling, 'linear' or 'nearest'
    :param subsample_res: if not None, this triggers a downsampling of the volume, prior to the resampling step.
    list, tuple or numpy array of size (n_dims,).
    :param volume_res: if subsample_res is not None, this should be provided to compute downsampling ratio.
     list, tuple or numpy array of size (n_dims,).
    :param subsample_interp_method: interpolation method for downsampling, 'linear' or 'nearest'
    :param n_dims: number of dimensions of the initial image (excluding batch and channel dimensions)
    :return: resampled volume
    :raises AssertionError: if subsample_res is given without volume_res.
    """

    def _as_list(values):
        """Convert numpy arrays and tuples to plain lists, other objects are returned unchanged."""
        if hasattr(values, 'tolist'):
            return values.tolist()
        if isinstance(values, tuple):
            return list(values)
        return values

    # plain lists compare safely with `!=` (numpy arrays compare elementwise, tuples never equal lists)
    resample_shape = _as_list(resample_shape)

    # downsample image
    downsample_shape = None
    if subsample_res is not None:

        # validate before any use of volume_res, and accept lists as well as numpy arrays
        assert volume_res is not None, 'volume_res must be given when providing a subsampling resolution.'
        subsample_res = _as_list(subsample_res)
        volume_res = _as_list(volume_res)

        if subsample_res != volume_res:

            # get shape at which we downsample
            tensor_shape = tensor.get_shape().as_list()[1:-1]
            downsample_factor = [
                volume_res[i] / subsample_res[i] for i in range(n_dims)
            ]
            downsample_shape = [
                int(tensor_shape[i] * downsample_factor[i])
                for i in range(n_dims)
            ]

            # downsample volume
            tensor._keras_shape = tuple(tensor.get_shape().as_list())
            tensor = nrn_layers.Resize(
                size=downsample_shape,
                interp_method=subsample_interp_method)(tensor)

    # resample image at target resolution (skipped if the downsampling already produced the target shape)
    if resample_shape != downsample_shape:
        tensor._keras_shape = tuple(tensor.get_shape().as_list())
        tensor = nrn_layers.Resize(size=resample_shape,
                                   interp_method=interp_method)(tensor)

    return tensor
def bias_field_augmentation(tensor, bias_field_std=.3, bias_shape_factor=.025):
    """Corrupt the input tensor with a smooth, strictly positive multiplicative field. Steps:
    1) a small tensor is sampled from a centred normal distribution whose standard deviation is itself random,
    2) it is resized with linear interpolation to the spatial shape of the input,
    3) it is made positive by taking the voxel-wise exponential,
    4) it is multiplied with the input tensor.
    :param tensor: input tensor. Expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel].
    :param bias_field_std: (optional) upper bound of the uniform distribution from which the standard deviation of the
    normal distribution is drawn.
    :param bias_shape_factor: (optional) ratio between the shape of the input tensor and the shape of the small
    sampled tensor.
    :return: a biased tensor
    """

    # store the static shape on the tensor and extract its spatial part
    static_shape = tensor.get_shape().as_list()
    tensor._keras_shape = tuple(static_shape)
    volume_shape = static_shape[1: -1]
    n_dims = len(volume_shape)

    # shape of the low-resolution field (third argument: a single channel)
    low_res_shape = utils.get_resample_shape(volume_shape, bias_shape_factor, 1)

    # dynamic shape of the sampled field: [batchsize] followed by the low-resolution shape
    dynamic_shape = KL.Lambda(lambda t: tf.shape(t))(tensor)
    shape_parts = KL.Lambda(lambda s: tf.split(s, [1, n_dims + 1]))(dynamic_shape)
    batch_part = shape_parts[0]
    sampling_shape = KL.Lambda(
        lambda b: tf.concat([b, tf.convert_to_tensor(low_res_shape)], axis=0))(batch_part)

    # draw the std dev, then the field itself
    sampled_std = KL.Lambda(lambda _: tf.random.uniform((1, 1), maxval=bias_field_std))([])
    field = KL.Lambda(lambda args: tf.random.normal(args[0], stddev=args[1]))([sampling_shape, sampled_std])
    field._keras_shape = tuple(field.get_shape().as_list())

    # upsample to image size and map to positive values
    field = nrn_layers.Resize(size=volume_shape, interp_method='linear')(field)
    field._keras_shape = tuple(field.get_shape().as_list())
    field = KL.Lambda(lambda f: K.exp(f))(field)

    # apply the multiplicative corruption
    return KL.multiply([field, tensor])
Exemplo n.º 5
0
def postprocessing_model(unet, posteriors_patch_shape, resample,
                         sigma_smoothing, n_dims):
    """Extend a network with optional resampling and Gaussian smoothing of its output.
    :param unet: model whose inputs are reused and whose (first) output is post-processed.
    :param posteriors_patch_shape: shape to resize the output to; its last entry is dropped before resizing
    (presumably the channel dimension).
    :param resample: if not None, the output is resized to posteriors_patch_shape[:-1].
    :param sigma_smoothing: standard deviation(s) of the Gaussian blurring applied to each output channel.
    Set to 0 to skip the smoothing.
    :param n_dims: number of spatial dimensions, used to build and apply the blurring kernels.
    :return: a keras Model mapping the inputs of unet to the post-processed output.
    """

    # grab the network inputs and its (first) output
    model_inputs = unet.inputs
    posteriors = unet.outputs
    if isinstance(posteriors, list):
        posteriors = posteriors[0]

    # optional resizing
    if resample is not None:
        resize_layer = nrn_layers.Resize(size=posteriors_patch_shape[:-1],
                                         name='post_resample')
        posteriors = resize_layer(posteriors)

    # optional smoothing, applied channel by channel
    if sigma_smoothing != 0:

        # one single-channel tensor per label
        n_labels = posteriors.get_shape().as_list()[-1]
        splitter = KL.Lambda(lambda t: tf.split(t, [1] * n_labels, axis=-1),
                             name='resample_split')
        per_label = splitter(posteriors)

        # 1d Gaussian kernels, one sigma per dimension
        sigmas = utils.reformat_to_list(sigma_smoothing, length=n_dims)
        kernels = l2i_et.get_gaussian_1d_kernels(sigmas)

        # blur every channel and stack them back one at a time
        posteriors = l2i_et.blur_tensor(per_label[0], kernels, n_dims)
        for label_idx in range(1, n_labels):
            blurred = l2i_et.blur_tensor(per_label[label_idx], kernels, n_dims)
            posteriors = KL.concatenate([posteriors, blurred],
                                        axis=-1,
                                        name='cat_blurring_%s' % label_idx)

    # wrap everything in a model
    return Model(inputs=model_inputs, outputs=posteriors)
Exemplo n.º 6
0
def preprocessing_model(resample, model_input_shape, header, n_channels,
                        n_dims, n_levels):
    """Build a model that resizes its input so that each spatial dimension is divisible by 2**n_levels.
    :param resample: target resolution, either a single value or a list/tuple of length 1 or n_dims.
    :param model_input_shape: shape of the model input; its first n_dims entries are the spatial dimensions.
    :param header: mapping whose 'pixdim' entry provides the image resolution at indices 1 to n_dims.
    :param n_channels: number of channels, appended to the returned shape.
    :param n_dims: number of spatial dimensions.
    :param n_levels: the resized shape is made divisible by 2**n_levels.
    :return: the preprocessing keras Model, and the shape of its output with the channel dimension appended.
    """

    voxel_size = header['pixdim'][1:n_dims + 1]

    # reformat the target resolution to one value per dimension
    if not isinstance(resample, (list, tuple)):
        resample = [resample]
    if len(resample) != 1:
        assert len(resample) == n_dims, \
            'new_resolution must be of length 1 or n_dims ({}): got {}'.format(n_dims, len(resample))
    else:
        resample = resample * n_dims

    # ratio between the image resolution and the target resolution
    resample_factor = []
    for axis in range(n_dims):
        resample_factor.append(voxel_size[axis] / float(resample[axis]))

    # corresponding shape, adjusted to be divisible by 2**n_levels
    divisor = 2**n_levels
    pre_resample_shape = []
    for axis in range(n_dims):
        pre_resample_shape.append(
            utils.find_closest_number_divisible_by_m(
                resample_factor[axis] * model_input_shape[axis],
                divisor,
                smaller_ans=False))

    # zoom factors that lead exactly to the adjusted shape
    zoom = [
        pre_resample_shape[axis] / model_input_shape[axis]
        for axis in range(n_dims)
    ]

    # assemble the model
    im_input = KL.Input(shape=model_input_shape, name='pre_resample_input')
    resize_layer = nrn_layers.Resize(zoom_factor=zoom, name='pre_resample')
    model_preprocessing = Model(inputs=im_input, outputs=resize_layer(im_input))

    # returned shape includes the channel dimension
    return model_preprocessing, list(pre_resample_shape) + [n_channels]
Exemplo n.º 7
0
def labels_to_image_model(labels_shape,
                          n_channels,
                          generation_labels,
                          output_labels,
                          n_neutral_labels,
                          atlas_res,
                          target_res,
                          output_shape=None,
                          output_div_by_n=None,
                          flipping=True,
                          aff=None,
                          scaling_bounds=0.15,
                          rotation_bounds=15,
                          shearing_bounds=0.012,
                          translation_bounds=False,
                          nonlin_std=4.,
                          nonlin_shape_factor=.0625,
                          randomise_res=False,
                          buil_distance_maps=False,
                          data_res=None,
                          thickness=None,
                          downsample=False,
                          blur_range=1.15,
                          bias_field_std=.5,
                          bias_shape_factor=.025):
    """
    This function builds a keras/tensorflow model to generate images from provided label maps.
    The images are generated by sampling a Gaussian Mixture Model (of given parameters), conditioned on the label map.
    The model takes as inputs:
        -a label map ('labels_input'),
        -a tensor containing the means of the Gaussian Mixture Model for each label and channel ('means_input'),
        -a tensor containing the standard deviations of the Gaussian Mixture Model for each label and channel
        ('std_devs_input').
    The spatial deformation, bias field and resolution parameters are sampled inside the model by dedicated layers.
    The model returns:
        -the generated image ('image_out'), after intensity augmentation with normalisation enabled,
        -the corresponding label map ('labels_out'), with only the labels present in output_labels (the other are
        reset to zero).
    # IMPORTANT !!!
    # Each time we provide a parameter with separate values for each axis (e.g. with a numpy array or a sequence),
    # these values refer to the RAS axes.
    :param labels_shape: shape of the input label maps. Can be a sequence or a 1d numpy array.
    :param n_channels: number of channels to be synthetised.
    :param generation_labels: list of all possible label values in the input label maps.
    Can be a sequence or a 1d numpy array. It should be organised as follows: background label first, then
    non-sided labels (e.g. CSF, brainstem, etc.), then all the structures of the same hemisphere (can be left or right),
    and finally all the corresponding contralateral structures (in the same order).
    :param output_labels: list of all the label values to keep in the output label maps, in no particular order.
    Should be a subset of the values contained in generation_labels.
    Label values that are in generation_labels but not in output_labels are reset to zero.
    Can be a sequence or a 1d numpy array.
    :param n_neutral_labels: number of non-sided generation labels.
    :param atlas_res: resolution of the input label maps.
    Can be a number (isotropic resolution), a sequence, or a 1d numpy array.
    :param target_res: target resolution of the generated images and corresponding label maps.
    Can be a number (isotropic resolution), a sequence, or a 1d numpy array. If None, atlas_res is used.
    :param output_shape: (optional) desired shape of the output image, obtained by randomly cropping the generated image
    Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or the path to a 1d numpy array.
    Default is None, where no cropping is performed.
    :param output_div_by_n: (optional) forces the output shape to be divisible by this value. It overwrites output_shape
    if necessary. Can be an integer (same size in all dimensions), a sequence, or a 1d numpy array.
    :param flipping: (optional) whether to introduce right/left random flipping
    :param aff: (optional) example of an (n_dims+1)x(n_dims+1) affine matrix of one of the input label map.
    Used to find brain's right/left axis. Should be given if flipping is True.
    :param scaling_bounds: (optional) range of the random scaling to apply at each mini-batch. The scaling factor for
    each dimension is sampled from a uniform distribution of predefined bounds. Can either be:
    1) a number, in which case the scaling factor is independently sampled from the uniform distribution of bounds
    [1-scaling_bounds, 1+scaling_bounds] for each dimension.
    2) a sequence, in which case the scaling factor is sampled from the uniform distribution of bounds
    (1-scaling_bounds[i], 1+scaling_bounds[i]) for the i-th dimension.
    3) a numpy array of shape (2, n_dims), in which case the scaling factor is sampled from the uniform distribution
     of bounds (scaling_bounds[0, i], scaling_bounds[1, i]) for the i-th dimension.
    4) False, in which case scaling is completely turned off.
    Default is scaling_bounds = 0.15 (case 1)
    :param rotation_bounds: (optional) same as scaling bounds but for the rotation angle, except that for cases 1
    and 2, the bounds are centred on 0 rather than 1, i.e. [0+rotation_bounds[i], 0-rotation_bounds[i]].
    Default is rotation_bounds = 15.
    :param shearing_bounds: (optional) same as scaling bounds. Default is shearing_bounds = 0.012.
    :param translation_bounds: (optional) same as scaling bounds. Default is translation_bounds = False, but we
    encourage using it when cropping is deactivated (i.e. when output_shape=None in BrainGenerator).
    :param nonlin_std: (optional) Maximum value for the standard deviation of the normal distribution from which we
    sample the first tensor for synthesising the deformation field. Set to 0 if you wish to completely turn the elastic
    deformation off.
    :param nonlin_shape_factor: (optional) if nonlin_std is strictly positive, factor between the shapes of the input
    label maps and the shape of the input non-linear tensor.
    :param randomise_res: (optional) whether to mimic images that would have been 1) acquired at low resolution, and
    2) resampled to high resolution. The low resolution is randomly resampled at each minibatch, up to 9mm in each
    direction. In that process, the images generated by sampling the GMM are 1) blurred at the sampled LR,
    2) downsampled at LR, and 3) resampled at target_resolution.
    :param buil_distance_maps: (optional) only used when randomise_res is True. If True, the MimicAcquisition layer is
    asked for a second output, which is appended to the image channels right after the corresponding channel
    (presumably a map of distances to the low resolution grid; confirm against MimicAcquisition). Default is False.
    :param data_res: (optional) specific acquisition resolution to mimic, as opposed to random resolution sampled when
    randomise_res is True. This triggers a blurring to mimic the specified acquisition resolution, but the downsampling
    is optional (see param downsample). Default for data_res is None, where images are slightly blurred.
    If the generated images are uni-modal, data_res can be a number (isotropic acquisition resolution), a sequence, a 1d
    numpy array, or the path to a 1d numpy array. In the multi-modal case, it should be given as a numpy array (or a
    path) of size (n_mod, n_dims), where each row is the acquisition resolution of the corresponding channel.
    :param thickness: (optional) if data_res is provided, we can further specify the slice thickness of the low
    resolution images to mimic. Must be provided in the same format as data_res. Default thickness = data_res.
    :param downsample: (optional) whether to actually downsample the volume images to data_res after blurring.
    Default is False, except when thickness is provided, and thickness < data_res.
    :param blur_range: (optional) Randomise the standard deviation of the blurring kernels, (whether data_res is given
    or not). At each mini_batch, the standard deviation of the blurring kernels are multiplied by a coefficient sampled
    from a uniform distribution with bounds [1/blur_range, blur_range]. If None, no randomisation. Default is 1.15.
    :param bias_field_std: (optional) If strictly positive, this triggers the corruption of synthesised images with a
    bias field. It is obtained by sampling a first small tensor from a normal distribution, resizing it to full size,
    and rescaling it to positive values by taking the voxel-wise exponential. bias_field_std designates the std dev of
    the normal distribution from which we sample the first tensor. Set to 0 to deactivate bias field corruption.
    :param bias_shape_factor: (optional) If bias_field_std is strictly positive, this designates the ratio between the
    size of the input label maps and the size of the first sampled tensor for synthesising the bias field.
    :return: a keras Model with inputs [labels_input, means_input, std_devs_input] and outputs [image, labels].
    """

    # reformat resolutions
    labels_shape = utils.reformat_to_list(labels_shape)
    n_dims, _ = utils.get_dims(labels_shape)
    atlas_res = utils.reformat_to_n_channels_array(atlas_res, n_dims,
                                                   n_channels)
    data_res = atlas_res if (
        data_res is None) else utils.reformat_to_n_channels_array(
            data_res, n_dims, n_channels)
    thickness = data_res if (
        thickness is None) else utils.reformat_to_n_channels_array(
            thickness, n_dims, n_channels)
    # when downsample is not requested, it is still enabled for channels where thickness < data_res along some axis
    downsample = utils.reformat_to_list(
        downsample,
        n_channels) if downsample else (np.min(thickness - data_res, 1) < 0)
    atlas_res = atlas_res[0]
    target_res = atlas_res if (
        target_res is None) else utils.reformat_to_n_channels_array(
            target_res, n_dims)[0]

    # get shapes
    crop_shape, output_shape = get_shapes(labels_shape, output_shape,
                                          atlas_res, target_res,
                                          output_div_by_n)

    # create new_label_list and corresponding LUT to make sure that labels go from 0 to N-1
    new_generation_labels, lut = utils.rearrange_label_list(generation_labels)

    # define model inputs
    labels_input = KL.Input(shape=labels_shape + [1], name='labels_input')
    means_input = KL.Input(shape=list(new_generation_labels.shape) +
                           [n_channels],
                           name='means_input')
    stds_input = KL.Input(shape=list(new_generation_labels.shape) +
                          [n_channels],
                          name='std_devs_input')

    # convert labels to new_label_list
    labels = l2i_et.convert_labels(labels_input, lut)

    # deform labels
    if (scaling_bounds is not False) | (rotation_bounds is not False) | (shearing_bounds is not False) | \
       (translation_bounds is not False) | (nonlin_std > 0):
        labels._keras_shape = tuple(labels.get_shape().as_list())
        labels = layers.RandomSpatialDeformation(
            scaling_bounds=scaling_bounds,
            rotation_bounds=rotation_bounds,
            shearing_bounds=shearing_bounds,
            translation_bounds=translation_bounds,
            nonlin_std=nonlin_std,
            nonlin_shape_factor=nonlin_shape_factor,
            inter_method='nearest')(labels)

    # cropping
    if crop_shape != labels_shape:
        labels._keras_shape = tuple(labels.get_shape().as_list())
        labels = layers.RandomCrop(crop_shape)(labels)

    # flipping
    if flipping:
        assert aff is not None, 'aff should not be None if flipping is True'
        labels._keras_shape = tuple(labels.get_shape().as_list())
        labels = layers.RandomFlip(
            get_ras_axes(aff, n_dims)[0], True, new_generation_labels,
            n_neutral_labels)(labels)

    # build synthetic image
    labels._keras_shape = tuple(labels.get_shape().as_list())
    image = layers.SampleConditionalGMM()([labels, means_input, stds_input])

    # apply bias field
    if bias_field_std > 0:
        image._keras_shape = tuple(image.get_shape().as_list())
        image = layers.BiasFieldCorruption(bias_field_std, bias_shape_factor,
                                           False)(image)

    # intensity augmentation
    image._keras_shape = tuple(image.get_shape().as_list())
    image = layers.IntensityAugmentation(clip=300,
                                         normalise=True,
                                         gamma_std=.4,
                                         separate_channels=True)(image)

    # loop over channels
    channels = list()
    split = KL.Lambda(lambda x: tf.split(x, [1] * n_channels, axis=-1))(
        image) if (n_channels > 1) else [image]
    for i, channel in enumerate(split):

        channel._keras_shape = tuple(channel.get_shape().as_list())

        if randomise_res:
            # maximum low resolution, one value per spatial dimension (was hard-coded to 3 dimensions, which broke
            # the division by atlas_res for non-3D label maps)
            max_res = np.array([9.] * n_dims)
            resolution, blur_res = layers.SampleResolution(
                atlas_res, max_res, .05, return_thickness=True)(means_input)
            sigma = l2i_et.blurring_sigma_for_downsampling(atlas_res,
                                                           resolution,
                                                           thickness=blur_res)
            channel = layers.DynamicGaussianBlur(
                0.75 * max_res / np.array(atlas_res),
                blur_range)([channel, sigma])
            if buil_distance_maps:
                channel, dist = layers.MimicAcquisition(
                    atlas_res, atlas_res, output_shape,
                    True)([channel, resolution])
                channels.extend([channel, dist])
            else:
                channel = layers.MimicAcquisition(atlas_res, atlas_res,
                                                  output_shape,
                                                  False)([channel, resolution])
                channels.append(channel)

        else:
            sigma = l2i_et.blurring_sigma_for_downsampling(
                atlas_res, data_res[i], thickness=thickness[i])
            channel = layers.GaussianBlur(sigma, blur_range)(channel)
            if downsample[i]:
                # bind the resolution of this channel as a default argument: a plain closure over the loop
                # variable would see the last value of i if the lambda is ever re-executed after the loop
                resolution = KL.Lambda(
                    lambda x, res=data_res[i]: tf.convert_to_tensor(
                        res, dtype='float32'))([])
                channel = layers.MimicAcquisition(atlas_res, data_res[i],
                                                  output_shape)(
                                                      [channel, resolution])
            elif output_shape != crop_shape:
                channel = nrn_layers.Resize(size=output_shape)(channel)
            channels.append(channel)

    # concatenate all channels back
    image = KL.Lambda(lambda x: tf.concat(x, -1))(
        channels) if len(channels) > 1 else channels[0]

    # resample labels at target resolution
    if crop_shape != output_shape:
        labels = l2i_et.resample_tensor(labels,
                                        output_shape,
                                        interp_method='nearest')

    # convert labels back to original values and reset unwanted labels to zero
    labels = l2i_et.convert_labels(labels, generation_labels)
    labels._keras_shape = tuple(labels.get_shape().as_list())
    reset_values = [v for v in generation_labels if v not in output_labels]
    labels = layers.ResetValuesToZero(reset_values, name='labels_out')(labels)

    # build model (dummy layer enables to keep the labels when plugging this model to other models)
    image = KL.Lambda(lambda x: x[0], name='image_out')([image, labels])
    brain_model = Model(inputs=[labels_input, means_input, stds_input],
                        outputs=[image, labels])

    return brain_model
Exemplo n.º 8
0
def build_model(model_file, input_shape, resample, im_res, n_levels, n_lab, conv_size, nb_conv_per_level,
                unet_feat_count, feat_multiplier, activation, sigma_smoothing):
    """Build a UNet, load its weights, and wrap it with optional resampling and smoothing steps.
    :param model_file: path of the file containing the weights of the UNet. Must exist.
    :param input_shape: shape of the input image, with the channel dimension last.
    :param resample: resolution at which the image is resampled before being fed to the UNet. If None (after
    reformatting by utils.reformat_to_list), no resampling is performed.
    :param im_res: resolution of the input image, one value per spatial dimension. Only used if resample is not None.
    :param n_levels: number of levels of the UNet. The resampled shape is made divisible by 2**n_levels.
    :param n_lab: number of labels predicted by the UNet.
    :param conv_size: passed to the UNet constructor as conv_size.
    :param nb_conv_per_level: passed to the UNet constructor as nb_conv_per_level.
    :param unet_feat_count: passed to the UNet constructor as nb_features.
    :param feat_multiplier: passed to the UNet constructor as feat_mult.
    :param activation: passed to the UNet constructor as activation.
    :param sigma_smoothing: standard deviation of the Gaussian blurring applied to the output. 0 to skip it.
    :return: keras Model going from the input image to the (optionally resampled and smoothed) UNet output.
    :raises AssertionError: if model_file does not exist.
    """

    assert os.path.isfile(model_file), "The provided model path does not exist."

    # initialisation
    net = None
    n_dims, n_channels = utils.get_dims(input_shape, max_channels=10)
    resample = utils.reformat_to_list(resample, length=n_dims)

    # remember the spatial shape of the input: input_shape is overwritten below with the resampled shape, which
    # would otherwise turn the final "resample to initial resolution" step into a resize to the resampled shape
    initial_shape = list(input_shape[:-1])

    # build preprocessing model
    if resample is not None:
        im_input = KL.Input(shape=input_shape, name='pre_resample_input')
        resample_factor = [im_res[i] / float(resample[i]) for i in range(n_dims)]
        resample_shape = [utils.find_closest_number_divisible_by_m(resample_factor[i] * input_shape[i],
                          2 ** n_levels, smaller_ans=False) for i in range(n_dims)]
        resampled = nrn_layers.Resize(size=resample_shape, name='pre_resample')(im_input)
        net = Model(inputs=im_input, outputs=resampled)
        input_shape = resample_shape + [n_channels]

    # build UNet (plugged on top of the preprocessing model if there is one)
    net = nrn_models.unet(nb_features=unet_feat_count,
                          input_shape=input_shape,
                          nb_levels=n_levels,
                          conv_size=conv_size,
                          nb_labels=n_lab,
                          name='unet',
                          prefix=None,
                          feat_mult=feat_multiplier,
                          pool_size=2,
                          use_logp=True,
                          padding='same',
                          dilation_rate_mult=1,
                          activation=activation,
                          use_residuals=False,
                          final_pred_activation='softmax',
                          nb_conv_per_level=nb_conv_per_level,
                          add_prior_layer=False,
                          add_prior_layer_reg=0,
                          layer_nb_feats=None,
                          conv_dropout=0,
                          batch_norm=-1,
                          input_model=net)
    net.load_weights(model_file, by_name=True)

    # build postprocessing model
    if (resample is not None) or (sigma_smoothing != 0):

        # get UNet output
        input_tensor = net.inputs
        last_tensor = net.output

        # resample to initial resolution
        if resample is not None:
            last_tensor = nrn_layers.Resize(size=initial_shape, name='post_resample')(last_tensor)

        # smooth posteriors
        if sigma_smoothing != 0:
            last_tensor._keras_shape = tuple(last_tensor.get_shape().as_list())
            last_tensor = layers.GaussianBlur(sigma=sigma_smoothing)(last_tensor)

        # build model
        net = Model(inputs=input_tensor, outputs=last_tensor)

    return net
Exemplo n.º 9
0
def labels_to_image_model(im_shape,
                          n_channels,
                          crop_shape,
                          label_list,
                          n_neutral_labels,
                          vox2ras,
                          nonlin_shape_factor=0.0625,
                          crop_channel2=None,
                          output_div_by_n=None,
                          flipping=True):
    """This function builds a keras model that jointly deforms an image and its label map, and then augments the
    intensities of the image. The model takes four inputs: the image ('image_input'), the label map ('labels_input'),
    an affine matrix ('aff_input'), and a small-size field ('nonlin_input'). The small field is resized, integrated
    (VecInt), and resized again to im_shape before being applied together with the affine matrix.
    :param im_shape: shape of the input volumes without channels. Must be a list, as it is concatenated with lists.
    :param n_channels: number of channels of the input image.
    :param crop_shape: cropping shape, reformatted by get_shapes (together with im_shape and output_div_by_n).
    If the reformatted value is None, no cropping is performed.
    :param label_list: label values of the input label maps. Labels are mapped to the indices provided by
    utils.rearrange_label_list before being deformed, and gathered back from label_list at the end.
    :param n_neutral_labels: forwarded to l2i_sa.label_map_random_flipping (presumably the number of non-sided labels,
    to be confirmed against that helper).
    :param vox2ras: affine matrix of the volumes, used to find the axis along which image and labels are flipped.
    :param nonlin_shape_factor: (optional) factor given to utils.get_resample_shape to compute the shape of the
    'nonlin_input' field from im_shape.
    :param crop_channel2: (optional) if not None, the second channel (index 1) of the image is processed with
    l2i_sa.restrict_tensor using this value.
    :param output_div_by_n: (optional) forwarded to get_shapes.
    :param flipping: (optional) whether to apply the random flipping of l2i_sa.label_map_random_flipping to the labels,
    and the same flip to the image.
    :return: the keras model (outputs: augmented image 'image_out', and deformed labels 'labels_out'), and the shape of
    the output image without the batch dimension.
    """

    # get shapes
    n_dims, _ = utils.get_dims(im_shape)
    crop_shape = get_shapes(crop_shape, im_shape, output_div_by_n)
    deformation_field_size = utils.get_resample_shape(im_shape, nonlin_shape_factor, len(im_shape))

    # create LUT to make sure that labels go from 0 to N-1 (the rearranged label list itself is not needed here)
    _, lut = utils.rearrange_label_list(label_list)

    # define mandatory inputs
    image_input = KL.Input(shape=im_shape + [n_channels], name='image_input')
    labels_input = KL.Input(shape=im_shape + [1], name='labels_input')
    aff_in = KL.Input(shape=(n_dims + 1, n_dims + 1), name='aff_input')
    nonlin_field_in = KL.Input(shape=deformation_field_size, name='nonlin_input')
    list_inputs = [image_input, labels_input, aff_in, nonlin_field_in]

    # convert labels to new_label_list
    labels = KL.Lambda(lambda x: tf.gather(tf.convert_to_tensor(lut, dtype='int32'),
                                           tf.cast(x, dtype='int32')))(labels_input)

    # deform labels and image with the same affine + non-linear transform
    image_input._keras_shape = tuple(image_input.get_shape().as_list())
    labels._keras_shape = tuple(labels.get_shape().as_list())
    labels = KL.Lambda(lambda x: tf.cast(x, dtype='float'))(labels)
    # the small field is integrated at half the image size, unless the small field is already bigger than that
    # (zip stops at the shortest sequence, so only the first len(im_shape) entries of the field size are used)
    resize_shape = [max(int(im_dim / 2), field_dim) for im_dim, field_dim in zip(im_shape, deformation_field_size)]
    nonlin_field = nrn_layers.Resize(size=resize_shape, interp_method='linear')(nonlin_field_in)
    nonlin_field = nrn_layers.VecInt()(nonlin_field)
    nonlin_field = nrn_layers.Resize(size=im_shape, interp_method='linear')(nonlin_field)
    image = nrn_layers.SpatialTransformer(interp_method='linear')([image_input, aff_in, nonlin_field])
    labels = nrn_layers.SpatialTransformer(interp_method='nearest')([labels, aff_in, nonlin_field])
    labels = KL.Lambda(lambda x: tf.cast(x, dtype='int32'))(labels)

    # cropping (labels are sliced at the indices sampled for the image, so both stay aligned)
    if crop_shape is not None:
        image, crop_idx = l2i_sa.random_cropping(image, crop_shape, n_dims)
        labels = KL.Lambda(lambda x: tf.slice(x[0], begin=tf.cast(x[1], dtype='int32'),
                           size=tf.convert_to_tensor([-1] + crop_shape + [-1], dtype='int32')))([labels, crop_idx])
    else:
        crop_shape = im_shape

    # flipping (the image is reversed only when the labels have been flipped)
    if flipping:
        labels, flip = l2i_sa.label_map_random_flipping(labels, label_list, n_neutral_labels, vox2ras, n_dims)
        ras_axes, _ = edit_volumes.get_ras_axes_and_signs(vox2ras, n_dims)
        flip_axis = [ras_axes[0] + 1]  # presumably +1 skips the batch dimension
        image = KL.Lambda(lambda y: K.switch(y[0], K.reverse(y[1], axes=flip_axis), y[1]))([flip, image])

    # convert labels back to original values
    labels = KL.Lambda(lambda x: tf.gather(tf.convert_to_tensor(label_list, dtype='int32'),
                                           tf.cast(x, dtype='int32')), name='labels_out')(labels)

    # intensity augmentation
    image = KL.Lambda(lambda x: K.clip(x, 0, 300), name='clipping')(image)

    # loop over channels
    if n_channels > 1:
        split = KL.Lambda(lambda x: tf.split(x, [1] * n_channels, axis=-1))(image)
    else:
        split = [image]
    processed_channels = list()
    for i, channel in enumerate(split):

        # normalise and shift intensities (each channel is processed independently)
        channel = l2i_ia.min_max_normalisation(channel)
        channel = KL.Lambda(lambda x: K.random_uniform((1,), .85, 1.1) * x + K.random_uniform((1,), -.3, .3))(channel)
        channel = KL.Lambda(lambda x: K.clip(x, 0, 1))(channel)
        channel = l2i_ia.gamma_augmentation(channel)

        # randomly crop sides of second channel (tested on the channel index, not on the tensor itself)
        if (crop_channel2 is not None) and (i == 1):
            channel = l2i_sa.restrict_tensor(channel, crop_channel2, n_dims)

        processed_channels.append(channel)

    # concatenate all channels back, and clip output (include labels to keep it when plugging to other models)
    if n_channels > 1:
        image = KL.concatenate(processed_channels)
    else:
        image = processed_channels[0]
    image = KL.Lambda(lambda x: K.clip(x[0], 0, 1), name='image_out')([image, labels])

    # build model
    brain_model = Model(inputs=list_inputs, outputs=[image, labels])
    # shape of returned images
    output_shape = image.get_shape().as_list()[1:]

    return brain_model, output_shape
Exemplo n.º 10
0
def resample_tensor(tensor,
                    resample_shape,
                    interp_method='linear',
                    subsample_res=None,
                    volume_res=None,
                    build_reliability_map=False):
    """This function resamples a volume to resample_shape. It does not apply any pre-filtering.
    A prior downsampling step can be added if subsample_res is specified. In this case, volume_res should also be
    specified, in order to calculate the downsampling ratio. A reliability map can also be returned to indicate which
    slices were interpolated during resampling from the downsampled to final tensor.
    :param tensor: tensor to resample. Its shape without the first and last dimensions is taken as the spatial shape.
    :param resample_shape: list or numpy array of size (n_dims,)
    :param interp_method: (optional) interpolation method for resampling, 'linear' (default) or 'nearest'
    :param subsample_res: (optional) if not None, this triggers a downsampling of the volume, prior to the resampling
    step. List or numpy array of size (n_dims,). Default is None.
    :param volume_res: (optional) if subsample_res is not None, this should be provided to compute downsampling ratio.
    list or numpy array of size (n_dims,). Default is None.
    :param build_reliability_map: whether to return reliability map along with the resampled tensor. This map indicates
    which slices of the resampled tensor are interpolated (0=interpolated, 1=real slice, in between=degree of realness).
    The map is reshaped to the shape of the resampled tensor, so it has to contain the same number of elements.
    :return: resampled volume, with reliability map if necessary.
    """

    # reformat resolutions to lists
    subsample_res = utils.reformat_to_list(subsample_res)
    volume_res = utils.reformat_to_list(volume_res)

    # reformat resample_shape to a plain list: comparing a numpy array to a list with != gives an element-wise array,
    # whose truth value is ambiguous and would make the shape tests below raise
    resample_shape = np.asarray(resample_shape).tolist()
    n_dims = len(resample_shape)

    # downsample image
    tensor_shape = tensor.get_shape().as_list()[1:-1]
    downsample_shape = tensor_shape  # will be modified if we actually downsample

    if subsample_res is not None:
        assert volume_res is not None, 'volume_res must be given when providing a subsampling resolution.'
        assert len(subsample_res) == len(volume_res), 'subsample_res and volume_res must have the same length, ' \
                                                      'had {0}, and {1}'.format(len(subsample_res), len(volume_res))
        if subsample_res != volume_res:

            # get shape at which we downsample
            downsample_shape = [int(tensor_shape[i] * volume_res[i] / subsample_res[i]) for i in range(n_dims)]

            # downsample volume
            tensor._keras_shape = tuple(tensor.get_shape().as_list())
            tensor = nrn_layers.Resize(size=downsample_shape, interp_method='nearest')(tensor)

    # resample image at target resolution
    if resample_shape != downsample_shape:  # if we didn't downsample, downsample_shape = tensor_shape
        tensor._keras_shape = tuple(tensor.get_shape().as_list())
        tensor = nrn_layers.Resize(size=resample_shape, interp_method=interp_method)(tensor)

    # compute reliability maps if necessary and return results
    if build_reliability_map:

        # compute maps only if we downsampled
        if downsample_shape != tensor_shape:

            # compute upsampling factors
            upsampling_factors = np.array(resample_shape) / np.array(downsample_shape)

            # build reliability map as the outer product of one 1d profile per dimension
            reliability_map = 1
            for i in range(n_dims):
                # positions of the downsampled slices in the resampled grid, and their two neighbouring indices
                loc_float = np.arange(0, resample_shape[i], upsampling_factors[i])
                loc_floor = np.int32(np.floor(loc_float))
                loc_ceil = np.int32(np.clip(loc_floor + 1, 0, resample_shape[i] - 1))
                # each position splits a weight of 1 between its two neighbours, based on its distance to them
                tmp_reliability_map = np.zeros(resample_shape[i])
                tmp_reliability_map[loc_floor] = 1 - (loc_float - loc_floor)
                tmp_reliability_map[loc_ceil] = tmp_reliability_map[loc_ceil] + (loc_float - loc_floor)
                # put the profile along axis i so that the product broadcasts, for any number of dimensions
                broadcast_shape = [1] * n_dims
                broadcast_shape[i] = resample_shape[i]
                reliability_map = reliability_map * np.reshape(tmp_reliability_map, broadcast_shape)
            target_shape = KL.Lambda(lambda x: tf.shape(x))(tensor)
            mask = KL.Lambda(lambda x: tf.reshape(tf.convert_to_tensor(reliability_map, dtype='float32'),
                                                  shape=x))(target_shape)

        # otherwise just return an all-one tensor
        else:
            mask = KL.Lambda(lambda x: tf.ones_like(x))(tensor)

        return tensor, mask

    else:
        return tensor
Exemplo n.º 11
0
def build_model(model_file, input_shape, resample, im_res, n_levels, n_lab, conv_size, nb_conv_per_level,
                unet_feat_count, feat_multiplier, no_batch_norm, activation, sigma_smoothing):
    """This function builds a keras model made of a UNet, with optional pre- and post-processing steps.
    If resample is not None, the input is first resized before the UNet, and the UNet output is resized back to the
    initial shape of the input. If sigma_smoothing is not 0, each output map is blurred separately.
    :param model_file: path of the file from which the weights are loaded (by layer name).
    :param input_shape: shape of the input, including channels as the last element.
    :param resample: resolution at which the UNet is applied (reformatted to a list of length n_dims), or None to
    disable resampling.
    :param im_res: resolution of the input; the resize factor along each dimension is im_res[i] / resample[i].
    :param n_levels: number of levels of the UNet. The resampled shape is made divisible by 2 ** n_levels.
    :param n_lab: number of output labels of the UNet.
    :param conv_size: forwarded to nrn_models.unet.
    :param nb_conv_per_level: forwarded to nrn_models.unet.
    :param unet_feat_count: forwarded to nrn_models.unet as nb_features.
    :param feat_multiplier: forwarded to nrn_models.unet as feat_mult.
    :param no_batch_norm: if True the UNet is built with batch_norm=None, otherwise with batch_norm=-1.
    :param activation: forwarded to nrn_models.unet.
    :param sigma_smoothing: standard deviation of the gaussian kernels used to blur the output maps (0 to disable).
    :return: the keras model.
    """

    # initialisation
    net = None
    n_dims, n_channels = utils.get_dims(input_shape, max_channels=3)
    resample = utils.reformat_to_list(resample, length=n_dims)
    # shape to restore after the UNet; only set (and only used) when the input is resampled
    initial_shape = None

    # build preprocessing model
    if resample is not None:
        # input_shape is overwritten below with the resampled shape, so keep the initial spatial shape aside
        initial_shape = list(input_shape[:-1])
        im_input = KL.Input(shape=input_shape, name='pre_resample_input')
        resample_factor = [im_res[i] / float(resample[i]) for i in range(n_dims)]
        resample_shape = [utils.find_closest_number_divisible_by_m(resample_factor[i] * input_shape[i],
                          2 ** n_levels, smaller_ans=False) for i in range(n_dims)]
        resampled = nrn_layers.Resize(size=resample_shape, name='pre_resample')(im_input)
        net = Model(inputs=im_input, outputs=resampled)
        input_shape = resample_shape + [n_channels]

    # build UNet (on top of the preprocessing model if there is one)
    if no_batch_norm:
        batch_norm_dim = None
    else:
        batch_norm_dim = -1
    net = nrn_models.unet(nb_features=unet_feat_count,
                          input_shape=input_shape,
                          nb_levels=n_levels,
                          conv_size=conv_size,
                          nb_labels=n_lab,
                          name='unet',
                          prefix=None,
                          feat_mult=feat_multiplier,
                          pool_size=2,
                          use_logp=True,
                          padding='same',
                          dilation_rate_mult=1,
                          activation=activation,
                          use_residuals=False,
                          final_pred_activation='softmax',
                          nb_conv_per_level=nb_conv_per_level,
                          add_prior_layer=False,
                          add_prior_layer_reg=0,
                          layer_nb_feats=None,
                          conv_dropout=0,
                          batch_norm=batch_norm_dim,
                          input_model=net)
    net.load_weights(model_file, by_name=True)

    # build postprocessing model
    if (resample is not None) or (sigma_smoothing != 0):

        # get UNet output
        input_tensor = net.inputs
        last_tensor = net.output

        # resample to initial resolution (using the shape saved before input_shape was overwritten)
        if resample is not None:
            last_tensor = nrn_layers.Resize(size=initial_shape, name='post_resample')(last_tensor)

        # smooth each posteriors map separately
        if sigma_smoothing != 0:
            kernels_list = l2i_et.get_gaussian_1d_kernels(utils.reformat_to_list(sigma_smoothing, length=n_dims))
            split = KL.Lambda(lambda x: tf.split(x, [1] * n_lab, axis=-1), name='resample_split')(last_tensor)
            last_tensor = l2i_et.blur_tensor(split[0], kernels_list, n_dims)
            for i in range(1, n_lab):
                temp_blurred = l2i_et.blur_tensor(split[i], kernels_list, n_dims)
                last_tensor = KL.concatenate([last_tensor, temp_blurred], axis=-1, name='cat_blurring_%s' % i)

        # build model
        net = Model(inputs=input_tensor, outputs=last_tensor)

    return net
Exemplo n.º 12
0
def build_augmentation_model(im_shape,
                             n_channels,
                             segmentation_labels,
                             n_neutral_labels,
                             atlas_res,
                             target_res,
                             output_shape=None,
                             output_div_by_n=None,
                             flipping=True,
                             aff=None,
                             scaling_bounds=0.15,
                             rotation_bounds=15,
                             shearing_bounds=0.012,
                             translation_bounds=False,
                             nonlin_std=3.,
                             nonlin_shape_factor=.0625,
                             data_res=None,
                             thickness=None,
                             downsample=False,
                             blur_range=1.03,
                             bias_field_std=.5,
                             bias_shape_factor=.025):
    """This function builds a keras model that augments pairs of images and label maps. The model takes two inputs
    ('image_input' and 'labels_input'), and applies in order: a random spatial deformation, an optional random crop,
    an optional random flip, an optional bias field corruption, and an intensity augmentation of the image.
    If data_res is given, each channel of the image is additionally blurred, and then either passed through
    layers.MimicAcquisition or resized to output_shape; labels are resampled to output_shape if necessary.
    :param im_shape: shape of the input volumes without channels (reformatted to a list).
    :param n_channels: number of channels of the input image.
    :param segmentation_labels: forwarded to layers.RandomFlip.
    :param n_neutral_labels: forwarded to layers.RandomFlip.
    :param atlas_res: resolution of the input volumes, forwarded to get_shapes and to the blurring/resampling steps.
    :param target_res: target resolution, forwarded to get_shapes. It is only used if data_res is not None; otherwise
    (or if target_res is None) atlas_res is used instead.
    :param output_shape: (optional) forwarded to get_shapes, which returns the crop shape and final output shape.
    :param output_div_by_n: (optional) forwarded to get_shapes.
    :param flipping: (optional) whether to apply layers.RandomFlip. If True, aff must be provided.
    :param aff: (optional) affine matrix given to get_ras_axes to find the flipping axis.
    :param scaling_bounds: (optional) forwarded to layers.RandomSpatialDeformation.
    :param rotation_bounds: (optional) forwarded to layers.RandomSpatialDeformation.
    :param shearing_bounds: (optional) forwarded to layers.RandomSpatialDeformation.
    :param translation_bounds: (optional) forwarded to layers.RandomSpatialDeformation.
    :param nonlin_std: (optional) forwarded to layers.RandomSpatialDeformation.
    :param nonlin_shape_factor: (optional) forwarded to layers.RandomSpatialDeformation.
    :param data_res: (optional) resolution of the data to mimic, reformatted to one row per channel. If None, no
    blurring or resampling is applied.
    :param thickness: (optional) slice thickness, reformatted to one row per channel. Defaults to data_res.
    :param downsample: (optional) whether to pass each channel through layers.MimicAcquisition. If falsy, channels
    for which the thickness is smaller than data_res along at least one dimension are selected.
    :param blur_range: (optional) forwarded to layers.GaussianBlur.
    :param bias_field_std: (optional) forwarded to layers.BiasFieldCorruption, which is skipped if not positive.
    :param bias_shape_factor: (optional) forwarded to layers.BiasFieldCorruption.
    :return: the keras model, with inputs [image_input, labels_input] and outputs [image_out, labels_out].
    """

    # reformat resolutions and get shapes
    im_shape = utils.reformat_to_list(im_shape)
    n_dims, _ = utils.get_dims(im_shape)
    if data_res is not None:
        data_res = utils.reformat_to_n_channels_array(data_res, n_dims, n_channels)
        if thickness is None:
            thickness = data_res
        else:
            thickness = utils.reformat_to_n_channels_array(thickness, n_dims, n_channels)
        if downsample:
            downsample = utils.reformat_to_list(downsample, n_channels)
        else:
            downsample = np.min(thickness - data_res, 1) < 0
        if target_res is None:
            target_res = atlas_res
        else:
            target_res = utils.reformat_to_n_channels_array(target_res, n_dims)[0]
    else:
        # no resampling is performed without data_res, so the output stays at atlas resolution
        target_res = atlas_res

    # get shapes
    crop_shape, output_shape = get_shapes(im_shape, output_shape, atlas_res, target_res, output_div_by_n)

    # define model inputs
    image_input = KL.Input(shape=im_shape + [n_channels], name='image_input')
    labels_input = KL.Input(shape=im_shape + [1], name='labels_input', dtype='int32')

    # deform labels and image with the same transform (nearest interpolation for labels, linear for the image)
    labels, image = layers.RandomSpatialDeformation(
        scaling_bounds=scaling_bounds,
        rotation_bounds=rotation_bounds,
        shearing_bounds=shearing_bounds,
        translation_bounds=translation_bounds,
        nonlin_std=nonlin_std,
        nonlin_shape_factor=nonlin_shape_factor,
        inter_method=['nearest', 'linear'])([labels_input, image_input])

    # cropping
    if crop_shape != im_shape:
        labels._keras_shape = tuple(labels.get_shape().as_list())
        image._keras_shape = tuple(image.get_shape().as_list())
        labels, image = layers.RandomCrop(crop_shape)([labels, image])

    # flipping
    if flipping:
        assert aff is not None, 'aff should not be None if flipping is True'
        labels._keras_shape = tuple(labels.get_shape().as_list())
        image._keras_shape = tuple(image.get_shape().as_list())
        labels, image = layers.RandomFlip(get_ras_axes(aff, n_dims)[0], [True, False],
                                          segmentation_labels, n_neutral_labels)([labels, image])

    # apply bias field
    if bias_field_std > 0:
        image._keras_shape = tuple(image.get_shape().as_list())
        image = layers.BiasFieldCorruption(bias_field_std, bias_shape_factor, False)(image)

    # intensity augmentation
    image._keras_shape = tuple(image.get_shape().as_list())
    image = layers.IntensityAugmentation(6,
                                         clip=False,
                                         normalise=True,
                                         gamma_std=.4,
                                         separate_channels=True)(image)

    # if necessary, loop over channels to 1) blur, 2) downsample to simulated LR, and 3) upsample to target
    if data_res is not None:
        channels = list()
        if n_channels > 1:
            split = KL.Lambda(lambda x: tf.split(x, [1] * n_channels, axis=-1))(image)
        else:
            split = [image]
        for i, channel in enumerate(split):

            # blur
            channel._keras_shape = tuple(channel.get_shape().as_list())
            sigma = l2i_et.blurring_sigma_for_downsampling(atlas_res, data_res[i], thickness=thickness[i])
            channel = layers.GaussianBlur(sigma, blur_range)(channel)

            # resample
            if downsample[i]:
                # the resolution is bound as a default argument: a plain closure over i would read the resolution
                # of the last channel if the layer function is executed again after this loop has finished
                resolution = KL.Lambda(lambda x, res=data_res[i]: tf.convert_to_tensor(res, dtype='float32'))([])
                channel = layers.MimicAcquisition(atlas_res, data_res[i], output_shape)([channel, resolution])
            elif output_shape != crop_shape:
                channel = nrn_layers.Resize(size=output_shape)(channel)
            channels.append(channel)

        # concatenate all channels back
        if len(channels) > 1:
            image = KL.Lambda(lambda x: tf.concat(x, -1))(channels)
        else:
            image = channels[0]

        # resample labels at target resolution
        if crop_shape != output_shape:
            labels = l2i_et.resample_tensor(labels, output_shape, interp_method='nearest')

    # build model (dummy layer enables to keep the labels when plugging this model to other models)
    labels = KL.Lambda(lambda x: tf.cast(x, dtype='int32'), name='labels_out')(labels)
    image = KL.Lambda(lambda x: x[0], name='image_out')([image, labels])
    brain_model = models.Model(inputs=[image_input, labels_input], outputs=[image, labels])

    return brain_model