Example #1
0
def deform_tensor(tensor,
                  affine_trans=None,
                  apply_elastic_trans=True,
                  interp_method='linear',
                  nonlin_std=2.,
                  nonlin_shape_factor=.0625):
    """Spatially deform a tensor with a combination of affine and elastic transformations.
    At least one of the two transformations must be requested.
    :param tensor: input tensor to deform. Expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel].
    It is cast to float32 before being deformed.
    :param affine_trans: (optional) tensor of shape [batchsize, n_dims+1, n_dims+1] corresponding to an affine
    transformation. Default is None, no affine transformation is applied.
    :param apply_elastic_trans: (optional) whether to deform the input tensor with an elastic transformation.
    If True (or any truthy value) the following steps occur:
    1) a small-size SVF is sampled from a centred normal distribution, whose standard deviation is itself drawn
    uniformly between 0 and nonlin_std.
    2) it is resized with linear interpolation to half the shape of the input tensor (or kept at the small field
    shape along the dimensions where the latter is larger).
    3) it is integrated with nrn_layers.VecInt (presumably yielding a diffeomorphic transformation).
    4) finally, it is resized (again with linear interpolation) to full image size.
    Default is True. If False or None, no elastic transformation is applied.
    :param interp_method: (optional) interpolation method when deforming the input tensor. Can be 'linear', or 'nearest'
    :param nonlin_std: (optional) maximum value of the standard deviation of the normal distribution from which we
    sample the small-size SVF.
    :param nonlin_shape_factor: (optional) ratio between the shape of the input tensor and the shape of the small field
    for elastic deformation.
    :return: the deformed tensor, as returned by nrn_layers.SpatialTransformer (presumably of the same shape as the
    input tensor).
    """

    # logical `or` (not bitwise `|`): a falsy non-boolean flag such as None must fail this check cleanly rather than
    # raise a TypeError
    assert (affine_trans is not None) or apply_elastic_trans, 'affine_trans or elastic_trans should be provided'

    # reformat tensor and get its shape
    tensor = KL.Lambda(lambda x: tf.cast(x, dtype='float32'))(tensor)
    tensor._keras_shape = tuple(tensor.get_shape().as_list())
    volume_shape = tensor.get_shape().as_list()[1: -1]
    n_dims = len(volume_shape)
    trans_inputs = [tensor]

    # add affine deformation to inputs list
    if affine_trans is not None:
        trans_inputs.append(affine_trans)

    # prepare non-linear deformation field and add it to inputs list
    if apply_elastic_trans:

        # sample small field from normal distribution of specified std dev
        small_shape = utils.get_resample_shape(volume_shape, nonlin_shape_factor, n_dims)
        tensor_shape = KL.Lambda(lambda x: tf.shape(x))(tensor)
        # isolate the (dynamic) batch size, which is the first entry of the shape tensor
        split_shape = KL.Lambda(lambda x: tf.split(x, [1, n_dims + 1]))(tensor_shape)
        nonlin_shape = KL.Lambda(lambda x: tf.concat([x, tf.convert_to_tensor(small_shape)], axis=0))(split_shape[0])
        # the std dev of the field is random: drawn uniformly in [0, nonlin_std)
        nonlin_std_prior = KL.Lambda(lambda x: tf.random.uniform((1, 1), maxval=nonlin_std))([])
        elastic_trans = KL.Lambda(lambda x: tf.random.normal(x[0], stddev=x[1]))([nonlin_shape, nonlin_std_prior])
        elastic_trans._keras_shape = tuple(elastic_trans.get_shape().as_list())

        # reshape this field to half image size (never smaller than the sampled field) and integrate it
        resize_shape = [max(int(volume_shape[i] / 2), small_shape[i]) for i in range(n_dims)]
        nonlin_field = nrn_layers.Resize(size=resize_shape, interp_method='linear')(elastic_trans)
        nonlin_field = nrn_layers.VecInt()(nonlin_field)
        nonlin_field = nrn_layers.Resize(size=volume_shape, interp_method='linear')(nonlin_field)
        trans_inputs.append(nonlin_field)

    # apply deformations
    return nrn_layers.SpatialTransformer(interp_method=interp_method)(trans_inputs)
def labels_to_image_model(im_shape,
                          n_channels,
                          crop_shape,
                          label_list,
                          n_neutral_labels,
                          vox2ras,
                          nonlin_shape_factor=0.0625,
                          crop_channel2=None,
                          output_div_by_n=None,
                          flipping=True):
    """Build a keras model that spatially deforms and intensity-augments an image and its label map.
    The model takes four inputs, named 'image_input', 'labels_input', 'aff_input' and 'nonlin_input', and applies the
    following steps: spatial deformation of image and labels (affine + integrated non-linear field), optional random
    cropping, optional random flipping, clipping of the image intensities to [0, 300], then for each channel
    min-max normalisation, random linear intensity shift, clipping to [0, 1] and gamma augmentation.
    :param im_shape: list with the shape of the input image, without batch and channel dimensions.
    :param n_channels: number of channels of the input image.
    :param crop_shape: requested cropping shape. It is reformatted by get_shapes (with im_shape and output_div_by_n);
    if the result is None, no cropping is performed.
    :param label_list: list of the label values present in the label maps. Labels are internally mapped to
    0..N-1 and converted back to these values in the output.
    :param n_neutral_labels: forwarded to l2i_sa.label_map_random_flipping (presumably the number of labels that are
    not left/right specific - to be confirmed against that helper).
    :param vox2ras: matrix given to edit_volumes.get_ras_axes_and_signs to find the axis along which to flip
    (presumably the voxel-to-RAS affine of the inputs).
    :param nonlin_shape_factor: (optional) ratio between im_shape and the shape of the small non-linear field given
    as input to the model.
    :param crop_channel2: (optional) if not None, the second channel is passed through l2i_sa.restrict_tensor with
    this argument (exact semantics defined by that helper). Default is None.
    :param output_div_by_n: (optional) forwarded to get_shapes when computing the cropping shape.
    :param flipping: (optional) whether to randomly flip image and labels. Default is True.
    :return: the keras model, with outputs [image, labels] ('image_out' and 'labels_out'), and the shape of the
    output image without the batch dimension.
    """

    # get shapes
    n_dims, _ = utils.get_dims(im_shape)
    crop_shape = get_shapes(crop_shape, im_shape, output_div_by_n)
    deformation_field_size = utils.get_resample_shape(im_shape, nonlin_shape_factor, len(im_shape))

    # build LUT to make sure that labels go from 0 to N-1 (the rearranged label list itself is not needed here)
    _, lut = utils.rearrange_label_list(label_list)

    # define mandatory inputs
    image_input = KL.Input(shape=im_shape+[n_channels], name='image_input')
    labels_input = KL.Input(shape=im_shape + [1], name='labels_input')
    aff_in = KL.Input(shape=(n_dims + 1, n_dims + 1), name='aff_input')
    nonlin_field_in = KL.Input(shape=deformation_field_size, name='nonlin_input')
    list_inputs = [image_input, labels_input, aff_in, nonlin_field_in]

    # convert labels to new_label_list
    labels = KL.Lambda(lambda x: tf.gather(tf.convert_to_tensor(lut, dtype='int32'),
                                           tf.cast(x, dtype='int32')))(labels_input)

    # deform image and labels with the same transformation (nearest neighbour for labels to keep valid values)
    image_input._keras_shape = tuple(image_input.get_shape().as_list())
    labels._keras_shape = tuple(labels.get_shape().as_list())
    labels = KL.Lambda(lambda x: tf.cast(x, dtype='float'))(labels)
    resize_shape = [max(int(im_shape[i] / 2), deformation_field_size[i]) for i in range(len(im_shape))]
    nonlin_field = nrn_layers.Resize(size=resize_shape, interp_method='linear')(nonlin_field_in)
    nonlin_field = nrn_layers.VecInt()(nonlin_field)
    nonlin_field = nrn_layers.Resize(size=im_shape, interp_method='linear')(nonlin_field)
    image = nrn_layers.SpatialTransformer(interp_method='linear')([image_input, aff_in, nonlin_field])
    labels = nrn_layers.SpatialTransformer(interp_method='nearest')([labels, aff_in, nonlin_field])
    labels = KL.Lambda(lambda x: tf.cast(x, dtype='int32'))(labels)

    # cropping (labels are cropped at the indices drawn for the image, so both stay aligned)
    if crop_shape is not None:
        image, crop_idx = l2i_sa.random_cropping(image, crop_shape, n_dims)
        labels = KL.Lambda(lambda x: tf.slice(x[0], begin=tf.cast(x[1], dtype='int32'),
                           size=tf.convert_to_tensor([-1] + crop_shape + [-1], dtype='int32')))([labels, crop_idx])
    else:
        crop_shape = im_shape

    # flipping (the image is flipped whenever the labels were)
    if flipping:
        labels, flip = l2i_sa.label_map_random_flipping(labels, label_list, n_neutral_labels, vox2ras, n_dims)
        ras_axes, _ = edit_volumes.get_ras_axes_and_signs(vox2ras, n_dims)
        flip_axis = [ras_axes[0] + 1]  # +1 to skip the batch dimension
        image = KL.Lambda(lambda y: K.switch(y[0], K.reverse(y[1], axes=flip_axis), y[1]))([flip, image])

    # convert labels back to original values
    labels = KL.Lambda(lambda x: tf.gather(tf.convert_to_tensor(label_list, dtype='int32'),
                                           tf.cast(x, dtype='int32')), name='labels_out')(labels)

    # intensity augmentation
    image = KL.Lambda(lambda x: K.clip(x, 0, 300), name='clipping')(image)

    # loop over channels
    if n_channels > 1:
        split = KL.Lambda(lambda x: tf.split(x, [1] * n_channels, axis=-1))(image)
    else:
        split = [image]
    processed_channels = list()
    for i, channel in enumerate(split):

        # normalise and shift intensities of the current channel only
        channel = l2i_ia.min_max_normalisation(channel)
        channel = KL.Lambda(lambda x: K.random_uniform((1,), .85, 1.1) * x + K.random_uniform((1,), -.3, .3))(channel)
        channel = KL.Lambda(lambda x: K.clip(x, 0, 1))(channel)
        channel = l2i_ia.gamma_augmentation(channel)

        # randomly crop sides of second channel (selected by its index, not by comparing the tensor itself)
        if (crop_channel2 is not None) and (i == 1):
            channel = l2i_sa.restrict_tensor(channel, crop_channel2, n_dims)

        processed_channels.append(channel)

    # concatenate all channels back, and clip output (include labels to keep it when plugging to other models)
    if n_channels > 1:
        image = KL.concatenate(processed_channels)
    else:
        image = processed_channels[0]
    image = KL.Lambda(lambda x: K.clip(x[0], 0, 1), name='image_out')([image, labels])

    # build model
    brain_model = Model(inputs=list_inputs, outputs=[image, labels])
    # shape of returned images
    output_shape = image.get_shape().as_list()[1:]

    return brain_model, output_shape