def call(self, inputs, training=None, mask=None):
        """
        Build GlobalNet graph based on built layers.

        :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
        :param training: None or bool.
        :param mask: None or tf.Tensor.
        :return: tf.Tensor, shape = (batch, dim1, dim2, dim3, 3)
        """
        # down sample from level 0 to E
        h_in = inputs
        for level in range(self._extract_max_level):  # level 0 to E - 1
            h_in, _ = self._downsample_blocks[level](inputs=h_in,
                                                     training=training)
        h_out = self._conv3d_block(inputs=h_in,
                                   training=training)  # level E of encoding

        # predict affine parameters theta of shape = [batch, 4, 3]
        self.theta = self._dense_layer(h_out)
        self.theta = tf.reshape(self.theta, shape=(-1, 4, 3))

        # warp the reference grid with affine parameters to output a ddf
        grid_warped = layer_util.warp_grid(self.reference_grid, self.theta)
        output = grid_warped - self.reference_grid
        return output
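
A minimal usage sketch (not part of the original code) of how the returned ddf is typically applied: add the displacement field back onto the reference grid and resample the moving image at the warped locations. The synthetic image size, the identity theta and the `deepreg.model.layer_util` import path are assumptions; `get_reference_grid`, `warp_grid` and `resample` are the same helpers used in the other examples on this page.

import numpy as np
import tensorflow as tf
from deepreg.model import layer_util  # assumed import path

# synthetic inputs, for illustration only
fixed_image_size = (2, 3, 4)
moving_image = tf.random.uniform(shape=(1,) + fixed_image_size)  # (batch, dim1, dim2, dim3)
theta = tf.constant(np.eye(4, 3).reshape((1, 4, 3)), dtype=tf.float32)  # identity affine

grid_ref = layer_util.get_reference_grid(grid_size=fixed_image_size)  # (dim1, dim2, dim3, 3)
ddf = layer_util.warp_grid(grid_ref, theta) - grid_ref  # same construction as in call() above
warped = layer_util.resample(vol=moving_image, loc=grid_ref + ddf)  # (batch, dim1, dim2, dim3)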
Example #2
def test_warp_grid():
    """
    Test warp_grid by checking its output
    against a simple precomputed case.
    """
    grid = tf.constant(
        np.array(
            [[[[0, 0, 0], [0, 0, 1], [0, 0, 2]],
              [[0, 1, 0], [0, 1, 1], [0, 1, 2]]]],
            dtype=np.float32,
        ))  # shape = (1, 2, 3, 3)
    theta = tf.constant(
        np.array(
            [[
                [0.86, 0.75, 0.48],
                [0.07, 0.98, 0.01],
                [0.72, 0.52, 0.97],
                [0.12, 0.4, 0.04],
            ]],
            dtype=np.float32,
        ))  # shape = (1, 4, 3)
    expected = tf.constant(
        np.array(
            [[[
                [[0.12, 0.4, 0.04], [0.84, 0.92, 1.01], [1.56, 1.44, 1.98]],
                [[0.19, 1.38, 0.05], [0.91, 1.9, 1.02], [1.63, 2.42, 1.99]],
            ]]],
            dtype=np.float32,
        ))  # shape = (1, 1, 2, 3, 3)
    got = layer_util.warp_grid(grid=grid, theta=theta)
    assert check_equal(got, expected)
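
The expected values above can be reproduced by hand: each grid coordinate is treated as a homogeneous point [x, y, z, 1] and multiplied by theta. A small numpy check of that arithmetic, independent of layer_util:

import numpy as np

grid_np = np.array(
    [[[[0, 0, 0], [0, 0, 1], [0, 0, 2]],
      [[0, 1, 0], [0, 1, 1], [0, 1, 2]]]],
    dtype=np.float32,
)  # (1, 2, 3, 3)
theta_np = np.array(
    [[[0.86, 0.75, 0.48],
      [0.07, 0.98, 0.01],
      [0.72, 0.52, 0.97],
      [0.12, 0.4, 0.04]]],
    dtype=np.float32,
)  # (1, 4, 3)

# append a homogeneous coordinate of 1 to every grid point -> (1, 2, 3, 4)
grid_h = np.concatenate([grid_np, np.ones(grid_np.shape[:-1] + (1,), np.float32)], axis=-1)
# each warped point is [x, y, z, 1] @ theta; theta's batch dim broadcasts over the grid
warped = np.einsum("...j,bjk->b...k", grid_h, theta_np)  # (1, 1, 2, 3, 3)
assert np.allclose(warped[0, 0, 0, 1], [0.84, 0.92, 1.01])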
    def test_non_identical(self):
        theta = tf.constant(
            np.array(
                [
                    [
                        [0.86, 0.75, 0.48],
                        [0.07, 0.98, 0.01],
                        [0.72, 0.52, 0.97],
                        [0.12, 0.4, 0.04],
                    ]
                ],
                dtype=np.float32,
            )
        )  # shape = (1, 4, 3)
        expected = tf.constant(
            np.array(
                [
                    [
                        [
                            [[0.12, 0.4, 0.04], [0.84, 0.92, 1.01], [1.56, 1.44, 1.98]],
                            [[0.19, 1.38, 0.05], [0.91, 1.9, 1.02], [1.63, 2.42, 1.99]],
                        ]
                    ]
                ],
                dtype=np.float32,
            )
        )  # shape = (1, 1, 2, 3, 3)
        got = layer_util.warp_grid(grid=self.grid, theta=theta)
        assert is_equal_tf(got, expected)
Example #4
    def transform(image: tf.Tensor, grid_ref: tf.Tensor,
                  params: tf.Tensor) -> tf.Tensor:
        """
        Transforms the reference grid and then resamples the image.

        :param image: shape = (batch, dim1, dim2, dim3)
        :param grid_ref: shape = (dim1, dim2, dim3, 3)
        :param params: shape = (batch, 4, 3)
        :return: shape = (batch, dim1, dim2, dim3)
        """
        return resample(vol=image, loc=warp_grid(grid_ref, params))
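
A minimal usage sketch for this helper (not from the original source), using the same layer_util calls that appear in the later examples; the import path and the synthetic image are assumptions:

import tensorflow as tf
from deepreg.model import layer_util  # assumed import path

image = tf.random.uniform(shape=(1, 2, 3, 4))                          # (batch, dim1, dim2, dim3)
grid_ref = layer_util.get_reference_grid(grid_size=image.shape[1:4])   # (dim1, dim2, dim3, 3)
params = layer_util.random_transform_generator(batch_size=1, scale=0.2)  # (batch, 4, 3)
# equivalent to transform(image=image, grid_ref=grid_ref, params=params)
augmented = layer_util.resample(vol=image, loc=layer_util.warp_grid(grid_ref, params))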
Example #5
    def _transform(image, grid_ref, transforms):
        """

        :param image: shape = [batch, dim1, dim2, dim3]
        :param grid_ref: shape = [dim1, dim2, dim3, 3]
        :param transforms: shape = [batch, 4, 3]
        :return: shape = [batch, dim1, dim2, dim3]
        """
        transformed = layer_util.resample(vol=image,
                                          loc=layer_util.warp_grid(grid_ref, transforms))
        return transformed
Example #6
    def _transform(image, grid_ref, transforms):
        """
        Resamples an input image from the reference grid by the series
        of input transforms.

        :param image: shape = (batch, dim1, dim2, dim3)
        :param grid_ref: shape = [dim1, dim2, dim3, 3]
        :param transforms: shape = [batch, 4, 3]
        :return: shape = (batch, dim1, dim2, dim3)
        """
        transformed = layer_util.resample(vol=image,
                                          loc=layer_util.warp_grid(
                                              grid_ref, transforms))
        return transformed
Example #7
def train_step(grid, weights, optimizer, mov, fix):
    """
    Train step function for backprop using gradient tape

    :param grid: reference grid returned by layer_util.get_reference_grid
    :param weights: trainable affine parameters [1, 4, 3]
    :param optimizer: tf.optimizers
    :param mov: moving image [1, m_dim1, m_dim2, m_dim3]
    :param fix: fixed image [1, f_dim1, f_dim2, f_dim3]
    :return: loss, the image dissimilarity to minimise
    """
    with tf.GradientTape() as tape:
        pred = layer_util.resample(vol=mov, loc=layer_util.warp_grid(grid, weights))
        loss = image_loss.dissimilarity_fn(
            y_true=fix, y_pred=pred, name=image_loss_name
        )
    gradients = tape.gradient(loss, [weights])
    optimizer.apply_gradients(zip(gradients, [weights]))
    return loss
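
A sketch of the optimisation loop that would drive this train_step; the learning rate, iteration count and variable initialisation are illustrative assumptions, and grid_ref, moving_image and fixed_image are prepared as in the data-loading snippet further below:

import numpy as np
import tensorflow as tf

learning_rate = 0.1  # illustrative values
total_iter = 100

# trainable affine parameters, started from the identity transform
var_affine = tf.Variable(
    initial_value=np.eye(4, 3).reshape((1, 4, 3)), trainable=True, dtype=tf.float32
)
optimiser = tf.optimizers.Adam(learning_rate)

for step in range(total_iter):
    loss = train_step(grid_ref, var_affine, optimiser, moving_image, fixed_image)
    if step % 10 == 0:
        print(f"step {step}: loss = {loss}")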
Example #8
    def call(self, inputs: Union[tf.Tensor, List],
             **kwargs) -> Tuple[tf.Tensor, tf.Tensor]:
        """

        :param inputs: a tensor or a list of tensor with length 1
        :param kwargs: additional args
        :return: ddf and theta

            - ddf has shape (batch, dim1, dim2, dim3, 3)
            - theta has shape (batch, 4, 3)
        """
        if isinstance(inputs, list):
            inputs = inputs[0]
        theta = self._dense(self._flatten(inputs))
        theta = tf.reshape(theta, shape=(-1, 4, 3))
        # warp the reference grid with affine parameters to output a ddf
        grid_warped = layer_util.warp_grid(self.reference_grid, theta)
        ddf = grid_warped - self.reference_grid
        return ddf, theta
Example #9
    def test_identical(self):
        theta = tf.constant(np.eye(4, 3).reshape((1, 4, 3)), dtype=tf.float32)
        expected = self.grid[None, ...]  # shape = (1, 1, 2, 3, 3)
        got = layer_util.warp_grid(grid=self.grid, theta=theta)
        assert is_equal_tf(got, expected)
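
The identity parameterisation np.eye(4, 3) works because, in the homogeneous form used by warp_grid, [x, y, z, 1] @ eye(4, 3) = [x, y, z], so the warped grid equals the reference grid. A quick numpy check, independent of layer_util:

import numpy as np

theta_identity = np.eye(4, 3)            # 3x3 identity block plus a zero translation row
point = np.array([2.0, 3.0, 4.0, 1.0])   # a voxel coordinate in homogeneous form
assert np.allclose(point @ theta_identity, point[:3])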
Example #10
if not os.path.exists(DATA_PATH):
    raise FileNotFoundError("Download the data using the demo_data.py script")
if not os.path.exists(FILE_PATH):
    raise FileNotFoundError("Download the data using the demo_data.py script")

fid = h5py.File(FILE_PATH, "r")
fixed_image = tf.cast(tf.expand_dims(fid["image"], axis=0), dtype=tf.float32)
fixed_image = (fixed_image - tf.reduce_min(fixed_image)) / (
    tf.reduce_max(fixed_image) - tf.reduce_min(fixed_image)
)  # normalisation to [0,1]

# generate a randomly affine-transformed moving image
fixed_image_size = fixed_image.shape
transform_random = layer_util.random_transform_generator(batch_size=1, scale=0.2)
grid_ref = layer_util.get_reference_grid(grid_size=fixed_image_size[1:4])
grid_random = layer_util.warp_grid(grid_ref, transform_random)
moving_image = layer_util.resample(vol=fixed_image, loc=grid_random)
# warp the labels to get ground-truth using the same random affine, for validation
fixed_labels = tf.cast(tf.expand_dims(fid["label"], axis=0), dtype=tf.float32)
moving_labels = tf.stack(
    [
        layer_util.resample(vol=fixed_labels[..., idx], loc=grid_random)
        for idx in range(fixed_labels.shape[4])
    ],
    axis=4,
)


## optimisation
@tf.function
def train_step(grid, weights, optimizer, mov, fix):