def deform_tensor(tensor, affine_trans=None, apply_elastic_trans=True, interp_method='linear', nonlin_std=2.,
                  nonlin_shape_factor=.0625):
    """This function spatially deforms a tensor with a combination of affine and elastic transformations.
    :param tensor: input tensor to deform. Expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel].
    :param affine_trans: (optional) tensor of shape [batchsize, n_dims+1, n_dims+1] corresponding to an affine
    transformation. Default is None, where no affine transformation is applied.
    :param apply_elastic_trans: (optional) whether to deform the input tensor with a diffeomorphic elastic
    transformation. If True the following steps occur:
    1) a small-size SVF is sampled from a centred normal distribution of random standard deviation,
    2) it is resized with trilinear interpolation to half the shape of the input tensor,
    3) it is integrated to obtain a diffeomorphic transformation,
    4) finally, it is resized (again with trilinear interpolation) to full image size.
    Default is True.
    :param interp_method: (optional) interpolation method when deforming the input tensor. Can be 'linear' or 'nearest'.
    :param nonlin_std: (optional) maximum value of the standard deviation of the normal distribution from which we
    sample the small-size SVF.
    :param nonlin_shape_factor: (optional) ratio between the shape of the input tensor and the shape of the small
    field for elastic deformation.
    :return: tensor of the same shape as the input tensor.
    """

    assert (affine_trans is not None) | apply_elastic_trans, 'affine_trans or elastic_trans should be provided'

    # reformat tensor and get its shape
    tensor = KL.Lambda(lambda x: tf.cast(x, dtype='float32'))(tensor)
    tensor._keras_shape = tuple(tensor.get_shape().as_list())
    volume_shape = tensor.get_shape().as_list()[1: -1]
    n_dims = len(volume_shape)
    trans_inputs = [tensor]

    # add affine deformation to inputs list
    if affine_trans is not None:
        trans_inputs.append(affine_trans)

    # prepare non-linear deformation field and add it to inputs list
    if apply_elastic_trans:

        # sample small field from normal distribution of specified std dev
        small_shape = utils.get_resample_shape(volume_shape, nonlin_shape_factor, n_dims)
        tensor_shape = KL.Lambda(lambda x: tf.shape(x))(tensor)
        split_shape = KL.Lambda(lambda x: tf.split(x, [1, n_dims + 1]))(tensor_shape)
        nonlin_shape = KL.Lambda(lambda x: tf.concat([x, tf.convert_to_tensor(small_shape)], axis=0))(split_shape[0])
        nonlin_std_prior = KL.Lambda(lambda x: tf.random.uniform((1, 1), maxval=nonlin_std))([])
        elastic_trans = KL.Lambda(lambda x: tf.random.normal(x[0], stddev=x[1]))([nonlin_shape, nonlin_std_prior])
        elastic_trans._keras_shape = tuple(elastic_trans.get_shape().as_list())

        # resize this field to half the image size, integrate it, and resize it back to full image size
        resize_shape = [max(int(volume_shape[i] / 2), small_shape[i]) for i in range(n_dims)]
        nonlin_field = nrn_layers.Resize(size=resize_shape, interp_method='linear')(elastic_trans)
        nonlin_field = nrn_layers.VecInt()(nonlin_field)
        nonlin_field = nrn_layers.Resize(size=volume_shape, interp_method='linear')(nonlin_field)
        trans_inputs.append(nonlin_field)

    # apply deformations
    return nrn_layers.SpatialTransformer(interp_method=interp_method)(trans_inputs)
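

# Usage sketch (not part of the original module): deform_tensor is meant to be called while building a Keras graph,
# with `tensor` and `affine_trans` being symbolic tensors. The toy shapes (a 32^3 single-channel volume, a 4x4
# affine) and the helper name `_deform_tensor_example` are illustrative assumptions; the example also assumes the
# same keras imports (KL, Model) used elsewhere in this file.
def _deform_tensor_example():
    """Builds a toy model that deforms a volume with a provided affine and a random elastic transform."""
    vol_in = KL.Input(shape=[32, 32, 32, 1], name='vol_in')    # [batch, x, y, z, channel]
    aff_in = KL.Input(shape=(4, 4), name='aff_in')             # [batch, n_dims + 1, n_dims + 1]
    deformed = deform_tensor(vol_in,
                             affine_trans=aff_in,
                             apply_elastic_trans=True,
                             interp_method='linear',
                             nonlin_std=2.,
                             nonlin_shape_factor=.0625)
    return Model(inputs=[vol_in, aff_in], outputs=deformed)

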
def labels_to_image_model(im_shape,
                          n_channels,
                          crop_shape,
                          label_list,
                          n_neutral_labels,
                          vox2ras,
                          nonlin_shape_factor=0.0625,
                          crop_channel2=None,
                          output_div_by_n=None,
                          flipping=True):
    """Builds a keras model that spatially deforms an image and its label map with a given affine transform and a
    small SVF (integrated into a diffeomorphic field), then randomly crops and flips them, and finally applies
    intensity augmentation (clipping, normalisation, random scaling/shifting, gamma) to the image.
    Returns the model along with the shape of its image output."""

    # get shapes
    n_dims, _ = utils.get_dims(im_shape)
    crop_shape = get_shapes(crop_shape, im_shape, output_div_by_n)
    deformation_field_size = utils.get_resample_shape(im_shape, nonlin_shape_factor, len(im_shape))

    # create new_label_list and corresponding LUT to make sure that labels go from 0 to N-1
    new_label_list, lut = utils.rearrange_label_list(label_list)

    # define mandatory inputs
    image_input = KL.Input(shape=im_shape + [n_channels], name='image_input')
    labels_input = KL.Input(shape=im_shape + [1], name='labels_input')
    aff_in = KL.Input(shape=(n_dims + 1, n_dims + 1), name='aff_input')
    nonlin_field_in = KL.Input(shape=deformation_field_size, name='nonlin_input')
    list_inputs = [image_input, labels_input, aff_in, nonlin_field_in]

    # convert labels to new_label_list
    labels = KL.Lambda(lambda x: tf.gather(tf.convert_to_tensor(lut, dtype='int32'),
                                           tf.cast(x, dtype='int32')))(labels_input)

    # deform labels
    image_input._keras_shape = tuple(image_input.get_shape().as_list())
    labels._keras_shape = tuple(labels.get_shape().as_list())
    labels = KL.Lambda(lambda x: tf.cast(x, dtype='float'))(labels)
    resize_shape = [max(int(im_shape[i] / 2), deformation_field_size[i]) for i in range(len(im_shape))]
    nonlin_field = nrn_layers.Resize(size=resize_shape, interp_method='linear')(nonlin_field_in)
    nonlin_field = nrn_layers.VecInt()(nonlin_field)
    nonlin_field = nrn_layers.Resize(size=im_shape, interp_method='linear')(nonlin_field)
    image = nrn_layers.SpatialTransformer(interp_method='linear')([image_input, aff_in, nonlin_field])
    labels = nrn_layers.SpatialTransformer(interp_method='nearest')([labels, aff_in, nonlin_field])
    labels = KL.Lambda(lambda x: tf.cast(x, dtype='int32'))(labels)

    # cropping
    if crop_shape is not None:
        image, crop_idx = l2i_sa.random_cropping(image, crop_shape, n_dims)
        labels = KL.Lambda(lambda x: tf.slice(x[0], begin=tf.cast(x[1], dtype='int32'),
                                              size=tf.convert_to_tensor([-1] + crop_shape + [-1],
                                                                        dtype='int32')))([labels, crop_idx])
    else:
        crop_shape = im_shape

    # flipping
    if flipping:
        labels, flip = l2i_sa.label_map_random_flipping(labels, label_list, n_neutral_labels, vox2ras, n_dims)
        ras_axes, _ = edit_volumes.get_ras_axes_and_signs(vox2ras, n_dims)
        flip_axis = [ras_axes[0] + 1]
        image = KL.Lambda(lambda y: K.switch(y[0],
                                             KL.Lambda(lambda x: K.reverse(x, axes=flip_axis))(y[1]),
                                             y[1]))([flip, image])

    # convert labels back to original values
    labels = KL.Lambda(lambda x: tf.gather(tf.convert_to_tensor(label_list, dtype='int32'),
                                           tf.cast(x, dtype='int32')), name='labels_out')(labels)

    # intensity augmentation
    image = KL.Lambda(lambda x: K.clip(x, 0, 300), name='clipping')(image)

    # loop over channels
    if n_channels > 1:
        split = KL.Lambda(lambda x: tf.split(x, [1] * n_channels, axis=-1))(image)
    else:
        split = [image]
    processed_channels = list()
    for i, channel in enumerate(split):

        # normalise and shift intensities
        channel = l2i_ia.min_max_normalisation(channel)
        channel = KL.Lambda(lambda x: K.random_uniform((1,), .85, 1.1) * x + K.random_uniform((1,), -.3, .3))(channel)
        channel = KL.Lambda(lambda x: K.clip(x, 0, 1))(channel)
        channel = l2i_ia.gamma_augmentation(channel)

        # randomly crop sides of second channel
        if (crop_channel2 is not None) & (i == 1):
            channel = l2i_sa.restrict_tensor(channel, crop_channel2, n_dims)

        processed_channels.append(channel)

    # concatenate all channels back, and clip output (include labels to keep it when plugging to other models)
    if n_channels > 1:
        image = KL.concatenate(processed_channels)
    else:
        image = processed_channels[0]
    image = KL.Lambda(lambda x: K.clip(x[0], 0, 1), name='image_out')([image, labels])

    # build model
    brain_model = Model(inputs=list_inputs, outputs=[image, labels])

    # shape of returned images
    output_shape = image.get_shape().as_list()[1:]

    return brain_model, output_shape
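

# Usage sketch (not part of the original module): building the augmentation model for a single-channel 3-D
# image/label pair. The shapes, label values and the helper name `_labels_to_image_example` are illustrative
# assumptions; in practice `label_list` and `n_neutral_labels` must follow whatever ordering
# `l2i_sa.label_map_random_flipping` expects, and `vox2ras` comes from the image header.
def _labels_to_image_example():
    """Returns a toy augmentation model and the shape of its image output."""
    import numpy as np  # imported locally in case numpy is not imported at file level

    im_shape = [160, 160, 160]
    label_list = np.array([0, 2, 3, 4, 17, 41, 42, 43, 53])   # hypothetical label values
    model, out_shape = labels_to_image_model(im_shape=im_shape,
                                             n_channels=1,
                                             crop_shape=[128, 128, 128],
                                             label_list=label_list,
                                             n_neutral_labels=1,
                                             vox2ras=np.eye(4),
                                             nonlin_shape_factor=0.0625,
                                             flipping=True)
    # model inputs: [image, labels, affine, small SVF]; outputs: [augmented image, deformed labels]
    return model, out_shape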