def call(self, inputs, training=None, mask=None):
    """
    Run the GlobalNet forward pass: encode the image, regress affine
    parameters, and turn the affinely-warped reference grid into a DDF.

    :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
    :param training: None or bool, forwarded to the encoder blocks.
    :param mask: None or tf.Tensor; not used by this method.
    :return: tf.Tensor, shape = (batch, dim1, dim2, dim3, 3)
    """
    # encoder: levels 0 .. E-1 are downsampling blocks; their skip outputs are discarded
    features = inputs
    for depth in range(self._extract_max_level):
        features, _ = self._downsample_blocks[depth](inputs=features, training=training)

    # level E of the encoder
    bottleneck = self._conv3d_block(inputs=features, training=training)

    # regress the affine parameters and keep them on the layer as (batch, 4, 3)
    self.theta = tf.reshape(self._dense_layer(bottleneck), shape=(-1, 4, 3))

    # the DDF is the warped reference grid minus the reference grid itself
    warped = layer_util.warp_grid(self.reference_grid, self.theta)
    return warped - self.reference_grid
def test_warp_grid():
    """
    Check warp_grid against a small, hand-computed case.

    The expected values equal the homogeneous coordinate [x, y, z, 1] of each
    grid point multiplied by the (4, 3) matrix theta.
    """
    grid_values = np.array(
        [[[[0, 0, 0], [0, 0, 1], [0, 0, 2]], [[0, 1, 0], [0, 1, 1], [0, 1, 2]]]],
        dtype=np.float32,
    )  # shape = (1, 2, 3, 3)
    theta_values = np.array(
        [
            [
                [0.86, 0.75, 0.48],
                [0.07, 0.98, 0.01],
                [0.72, 0.52, 0.97],
                [0.12, 0.4, 0.04],
            ]
        ],
        dtype=np.float32,
    )  # shape = (1, 4, 3)
    expected_values = np.array(
        [
            [
                [
                    [[0.12, 0.4, 0.04], [0.84, 0.92, 1.01], [1.56, 1.44, 1.98]],
                    [[0.19, 1.38, 0.05], [0.91, 1.9, 1.02], [1.63, 2.42, 1.99]],
                ]
            ]
        ],
        dtype=np.float32,
    )  # shape = (1, 1, 2, 3, 3)

    got = layer_util.warp_grid(
        grid=tf.constant(grid_values), theta=tf.constant(theta_values)
    )
    assert check_equal(got, tf.constant(expected_values))
def test_non_identical(self):
    """A non-trivial affine theta must map self.grid to precomputed points."""
    theta_values = np.array(
        [
            [
                [0.86, 0.75, 0.48],
                [0.07, 0.98, 0.01],
                [0.72, 0.52, 0.97],
                [0.12, 0.4, 0.04],
            ]
        ],
        dtype=np.float32,
    )  # shape = (1, 4, 3)
    expected_values = np.array(
        [
            [
                [
                    [[0.12, 0.4, 0.04], [0.84, 0.92, 1.01], [1.56, 1.44, 1.98]],
                    [[0.19, 1.38, 0.05], [0.91, 1.9, 1.02], [1.63, 2.42, 1.99]],
                ]
            ]
        ],
        dtype=np.float32,
    )  # shape = (1, 1, 2, 3, 3)

    warped = layer_util.warp_grid(grid=self.grid, theta=tf.constant(theta_values))
    assert is_equal_tf(warped, tf.constant(expected_values))
def transform(image: tf.Tensor, grid_ref: tf.Tensor, params: tf.Tensor) -> tf.Tensor:
    """
    Warp the reference grid with the affine parameters, then sample the
    image at the warped locations.

    :param image: shape = (batch, dim1, dim2, dim3)
    :param grid_ref: shape = (dim1, dim2, dim3, 3)
    :param params: shape = (batch, 4, 3)
    :return: shape = (batch, dim1, dim2, dim3)
    """
    sample_locations = warp_grid(grid_ref, params)
    return resample(vol=image, loc=sample_locations)
def _transform(image, grid_ref, transforms):
    """
    Sample the image on the reference grid after warping it by the transforms.

    :param image: shape = [batch, dim1, dim2, dim3]
    :param grid_ref: shape = [dim1, dim2, dim3, 3]
    :param transforms: shape = [batch, 4, 3]
    :return: shape = [batch, dim1, dim2, dim3]
    """
    warped_grid = layer_util.warp_grid(grid_ref, transforms)
    return layer_util.resample(vol=image, loc=warped_grid)
def _transform(image, grid_ref, transforms):
    """
    Resample an input image at the reference grid warped by the input transforms.

    :param image: shape = (batch, dim1, dim2, dim3)
    :param grid_ref: shape = [dim1, dim2, dim3, 3]
    :param transforms: shape = [batch, 4, 3]
    :return: shape = (batch, dim1, dim2, dim3)
    """
    sampling_points = layer_util.warp_grid(grid_ref, transforms)
    resampled = layer_util.resample(vol=image, loc=sampling_points)
    return resampled
def train_step(grid, weights, optimizer, mov, fix):
    """
    Perform one gradient step on the affine parameters.

    :param grid: reference grid returned by layer_util.get_reference_grid
    :param weights: trainable affine parameters [1, 4, 3]
    :param optimizer: tf.optimizers instance that applies the update
    :param mov: moving image [1, m_dim1, m_dim2, m_dim3]
    :param fix: fixed image [1, f_dim1, f_dim2, f_dim3]
    :return loss: image dissimilarity being minimised
    """
    trainable = [weights]
    with tf.GradientTape() as tape:
        warped_grid = layer_util.warp_grid(grid, weights)
        warped_moving = layer_util.resample(vol=mov, loc=warped_grid)
        # image_loss_name is not a parameter; it is read from the enclosing module scope
        loss = image_loss.dissimilarity_fn(
            y_true=fix, y_pred=warped_moving, name=image_loss_name
        )
    grads = tape.gradient(loss, trainable)
    optimizer.apply_gradients(zip(grads, trainable))
    return loss
def call(self, inputs: Union[tf.Tensor, List], **kwargs) -> Tuple[tf.Tensor, tf.Tensor]:
    """
    Regress affine parameters from the input features and derive a DDF.

    :param inputs: a tensor or a list of tensor with length 1
    :param kwargs: additional args (not used here)
    :return: ddf and theta
        - ddf has shape (batch, dim1, dim2, dim3, 3)
        - theta has shape (batch, 4, 3)
    """
    # only the first element of a list input is used
    features = inputs[0] if isinstance(inputs, list) else inputs

    flat = self._flatten(features)
    theta = tf.reshape(self._dense(flat), shape=(-1, 4, 3))

    # the DDF is the displacement from the reference grid to its affine warp
    ddf = layer_util.warp_grid(self.reference_grid, theta) - self.reference_grid
    return ddf, theta
def test_identical(self):
    """An identity affine (np.eye(4, 3)) must leave self.grid unchanged."""
    identity = np.eye(4, 3)[np.newaxis, ...]  # shape = (1, 4, 3)
    theta = tf.constant(identity, dtype=tf.float32)
    got = layer_util.warp_grid(grid=self.grid, theta=theta)
    # add the batch axis to the grid: shape = (1, 1, 2, 3, 3)
    assert is_equal_tf(got, self.grid[None, ...])
if not os.path.exists(DATA_PATH): raise ("Download the data using demo_data.py script") if not os.path.exists(FILE_PATH): raise ("Download the data using demo_data.py script") fid = h5py.File(FILE_PATH, "r") fixed_image = tf.cast(tf.expand_dims(fid["image"], axis=0), dtype=tf.float32) fixed_image = (fixed_image - tf.reduce_min(fixed_image)) / ( tf.reduce_max(fixed_image) - tf.reduce_min(fixed_image) ) # normalisation to [0,1] # generate a radomly-affine-transformed moving image fixed_image_size = fixed_image.shape transform_random = layer_util.random_transform_generator(batch_size=1, scale=0.2) grid_ref = layer_util.get_reference_grid(grid_size=fixed_image_size[1:4]) grid_random = layer_util.warp_grid(grid_ref, transform_random) moving_image = layer_util.resample(vol=fixed_image, loc=grid_random) # warp the labels to get ground-truth using the same random affine, for validation fixed_labels = tf.cast(tf.expand_dims(fid["label"], axis=0), dtype=tf.float32) moving_labels = tf.stack( [ layer_util.resample(vol=fixed_labels[..., idx], loc=grid_random) for idx in range(fixed_labels.shape[4]) ], axis=4, ) ## optimisation @tf.function def train_step(grid, weights, optimizer, mov, fix):