Beispiel #1
0
 def test_forward(self, moving_image_size, fixed_image_size):
     """
     Warp an all-ones moving image with an all-ones ddf and check that the
     result has shape (batch, *fixed_image_size).
     """
     batch_size = 2
     moving = tf.ones(shape=(batch_size, ) + moving_image_size)
     ddf = tf.ones(shape=(batch_size, ) + fixed_image_size + (3, ))
     warp = layer.Warping(fixed_image_size=fixed_image_size)
     warped = warp([ddf, moving])
     expected_shape = (batch_size, *fixed_image_size)
     assert warped.shape == expected_shape
Beispiel #2
0
def ddf_dvf_forward(
    backbone: tf.keras.Model,
    moving_image: tf.Tensor,
    fixed_image: tf.Tensor,
    moving_label: (tf.Tensor, None),
    moving_image_size: tuple,
    fixed_image_size: tuple,
    output_dvf: bool,
) -> [(tf.Tensor, None), tf.Tensor, tf.Tensor, (tf.Tensor, None), tf.Tensor]:
    """
    Run the forward pass of a network that predicts either a ddf or a dvf.

    The moving image is resized onto the fixed grid and stacked with the fixed
    image as channels. The backbone output is used as the ddf directly, or,
    when ``output_dvf`` is true, treated as a dvf and integrated into a ddf with
    ``layer.IntDVF``. The ddf then warps the moving image, and the moving label
    when one is given.

    :param backbone: model architecture object, e.g. model.backbone.local_net
    :param moving_image: tensor of shape (batch, m_dim1, m_dim2, m_dim3)
    :param fixed_image:  tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param moving_label: tensor of shape (batch, m_dim1, m_dim2, m_dim3) or None
    :param moving_image_size: tuple like (m_dim1, m_dim2, m_dim3); not read by this function
    :param fixed_image_size: tuple like (f_dim1, f_dim2, f_dim3)
    :param output_dvf: bool, if true, model outputs dvf, if false, model outputs ddf
    :return: tuple(dvf, ddf, pred_fixed_image, pred_fixed_label, fixed_grid), where
    - dvf is the dense velocity field of shape (batch, f_dim1, f_dim2, f_dim3, 3), or None when output_dvf is false
    - ddf is the dense displacement field of shape (batch, f_dim1, f_dim2, f_dim3, 3)
    - pred_fixed_image is the predicted (warped) moving image of shape (batch, f_dim1, f_dim2, f_dim3)
    - pred_fixed_label is the predicted (warped) moving label of shape (batch, f_dim1, f_dim2, f_dim3), or None
    - fixed_grid is the grid of shape (f_dim1, f_dim2, f_dim3, 3)
    """
    channel_axis = 4

    # add a trailing channel axis; it is squeezed away again before warping
    moving_5d = tf.expand_dims(moving_image, axis=channel_axis)  # (batch, m_dim1, m_dim2, m_dim3, 1)
    fixed_5d = tf.expand_dims(fixed_image, axis=channel_axis)  # (batch, f_dim1, f_dim2, f_dim3, 1)

    # resample the moving image onto the fixed image size
    moving_5d = layer_util.resize3d(image=moving_5d, size=fixed_image_size)  # (batch, f_dim1, f_dim2, f_dim3, 1)

    stacked = tf.concat([moving_5d, fixed_5d], axis=channel_axis)  # (batch, f_dim1, f_dim2, f_dim3, 2)
    field = backbone(inputs=stacked)  # (batch, f_dim1, f_dim2, f_dim3, 3)

    # interpret the backbone output either as a ddf or as a dvf to integrate
    dvf = None
    ddf = field
    if output_dvf:
        dvf = field
        ddf = layer.IntDVF(fixed_image_size=fixed_image_size)(dvf)  # (batch, f_dim1, f_dim2, f_dim3, 3)

    # one Warping layer serves both the image and the label
    warping = layer.Warping(fixed_image_size=fixed_image_size)
    grid_fixed = tf.squeeze(warping.grid_ref, axis=0)  # (f_dim1, f_dim2, f_dim3, 3)
    pred_fixed_image = warping(inputs=[ddf, tf.squeeze(moving_5d, axis=channel_axis)])
    # NOTE(review): unlike the image, moving_label is warped at its original
    # (moving) size, not resized to fixed_image_size first -- confirm that
    # Warping handles the size mismatch consistently with the resized image.
    pred_fixed_label = None
    if moving_label is not None:
        pred_fixed_label = warping(inputs=[ddf, moving_label])
    return dvf, ddf, pred_fixed_image, pred_fixed_label, grid_fixed
Beispiel #3
0
 def test_get_config(self):
     """
     Check that Warping.get_config() returns the expected configuration dict.
     """
     expected = {
         "fixed_image_size": (2, 3, 4),
         "name": "warping",
         "trainable": True,
         "dtype": "float32",
     }
     warping = layer.Warping(fixed_image_size=(2, 3, 4))
     assert warping.get_config() == expected
Beispiel #4
0
def conditional_forward(
    backbone: tf.keras.Model,
    moving_image: tf.Tensor,
    fixed_image: tf.Tensor,
    moving_label: (tf.Tensor, None),
    moving_image_size: tuple,
    fixed_image_size: tuple,
) -> [tf.Tensor, tf.Tensor]:
    """
    Perform the network forward pass of a conditional model.

    The backbone receives the moving image, the fixed image and the moving
    label stacked as channels and predicts the fixed label directly; no ddf
    is produced.

    :param backbone: model architecture object, e.g. model.backbone.local_net
    :param moving_image: tensor of shape (batch, m_dim1, m_dim2, m_dim3)
    :param fixed_image:  tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param moving_label: tensor of shape (batch, m_dim1, m_dim2, m_dim3); required,
        because it is one of the backbone input channels
    :param moving_image_size: tuple like (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: tuple like (f_dim1, f_dim2, f_dim3)
    :return: (pred_fixed_label, fixed_grid), where

      - pred_fixed_label is the fixed label predicted by the backbone, of shape (batch, f_dim1, f_dim2, f_dim3)
      - fixed_grid is the grid of shape (f_dim1, f_dim2, f_dim3, 3)
    :raises ValueError: if moving_label is None
    """
    if moving_label is None:
        # the label is concatenated into the backbone input below, so it cannot
        # be omitted; fail early with a clear message instead of inside tf
        raise ValueError("conditional_forward requires a moving_label, got None")

    # expand dims
    # need to be squeezed later for warping
    moving_image = tf.expand_dims(moving_image,
                                  axis=4)  # (batch, m_dim1, m_dim2, m_dim3, 1)
    fixed_image = tf.expand_dims(fixed_image,
                                 axis=4)  # (batch, f_dim1, f_dim2, f_dim3, 1)
    moving_label = tf.expand_dims(moving_label,
                                  axis=4)  # (batch, m_dim1, m_dim2, m_dim3, 1)

    # resize moving image and label onto the fixed grid only when sizes differ
    if moving_image_size != fixed_image_size:
        moving_image = layer_util.resize3d(
            image=moving_image,
            size=fixed_image_size)  # (batch, f_dim1, f_dim2, f_dim3, 1)
        moving_label = layer_util.resize3d(
            image=moving_label,
            size=fixed_image_size)  # (batch, f_dim1, f_dim2, f_dim3, 1)

    # conditional: predict the fixed label from the three stacked channels
    inputs = tf.concat([moving_image, fixed_image, moving_label],
                       axis=4)  # (batch, f_dim1, f_dim2, f_dim3, 3)
    pred_fixed_label = backbone(
        inputs=inputs)  # (batch, f_dim1, f_dim2, f_dim3, 1)
    pred_fixed_label = tf.squeeze(pred_fixed_label,
                                  axis=4)  # (batch, f_dim1, f_dim2, f_dim3)

    # the Warping layer is only used here to obtain its reference grid
    warping = layer.Warping(fixed_image_size=fixed_image_size)
    grid_fixed = tf.squeeze(warping.grid_ref,
                            axis=0)  # (f_dim1, f_dim2, f_dim3, 3)

    return pred_fixed_label, grid_fixed
Beispiel #5
0
    def forward(_backbone, _moving_image, _moving_label, _fixed_image):
        """
        Predict a ddf and use it to warp the moving image and the moving label.

        Relies on ``fixed_image_size`` from the enclosing scope.

        :param _backbone: network mapping the stacked (moving, fixed) images to a ddf
        :param _moving_image: [batch, m_dim1, m_dim2, m_dim3]
        :param _moving_label: [batch, m_dim1, m_dim2, m_dim3]
        :param _fixed_image:  [batch, f_dim1, f_dim2, f_dim3]
        :return: (_ddf, _pred_fixed_image, _pred_fixed_label), where _ddf is
            [batch, f_dim1, f_dim2, f_dim3, 3] and both predictions are
            [batch, f_dim1, f_dim2, f_dim3]
        """
        # ddf: resize the moving image onto the fixed size, then stack it with
        # the fixed image along a new channel axis
        resized_moving = layer.Resize3d(size=fixed_image_size)(inputs=tf.expand_dims(_moving_image, axis=4))
        backbone_input = tf.concat([resized_moving, tf.expand_dims(_fixed_image, axis=4)],
                                   axis=4)  # [batch, f_dim1, f_dim2, f_dim3, 2]
        _ddf = _backbone(inputs=backbone_input)  # [batch, f_dim1, f_dim2, f_dim3, 3]

        # prediction image and label shape = [batch, f_dim1, f_dim2, f_dim3]
        # a single Warping layer (configured only by fixed_image_size) serves both tensors
        warping = layer.Warping(fixed_image_size=fixed_image_size)
        _pred_fixed_image = warping([_ddf, _moving_image])
        _pred_fixed_label = warping([_ddf, _moving_label])

        return _ddf, _pred_fixed_image, _pred_fixed_label
def affine_forward(
    backbone: tf.keras.Model,
    moving_image: tf.Tensor,
    fixed_image: tf.Tensor,
    moving_label: (tf.Tensor, None),
    moving_image_size: tuple,
    fixed_image_size: tuple,
):
    """
    Run the forward pass of an affine registration network.

    The moving image is resized onto the fixed grid and stacked with the fixed
    image; the backbone returns a ddf and exposes the affine matrix as
    ``backbone.theta``. The ddf warps the moving image, and the moving label
    when one is given.

    :param backbone: model architecture object, e.g. model.backbone.local_net
    :param moving_image: tensor of shape (batch, m_dim1, m_dim2, m_dim3)
    :param fixed_image:  tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param moving_label: tensor of shape (batch, m_dim1, m_dim2, m_dim3) or None
    :param moving_image_size: tuple like (m_dim1, m_dim2, m_dim3); not read by this function
    :param fixed_image_size: tuple like (f_dim1, f_dim2, f_dim3)
    :return: tuple(affine, ddf, pred_fixed_image, pred_fixed_label, fixed_grid), where

      - affine is the affine transformation matrix predicted by the network (batch, 4, 3)
      - ddf is the dense displacement field of shape (batch, f_dim1, f_dim2, f_dim3, 3)
      - pred_fixed_image is the predicted (warped) moving image of shape (batch, f_dim1, f_dim2, f_dim3)
      - pred_fixed_label is the predicted (warped) moving label of shape (batch, f_dim1, f_dim2, f_dim3), or None
      - fixed_grid is the grid of shape (f_dim1, f_dim2, f_dim3, 3)
    """
    channel_axis = 4

    # add a trailing channel axis; it is squeezed away again before warping
    moving_5d = tf.expand_dims(moving_image, axis=channel_axis)  # (batch, m_dim1, m_dim2, m_dim3, 1)
    fixed_5d = tf.expand_dims(fixed_image, axis=channel_axis)  # (batch, f_dim1, f_dim2, f_dim3, 1)

    # resample the moving image onto the fixed image size
    moving_5d = layer_util.resize3d(image=moving_5d, size=fixed_image_size)  # (batch, f_dim1, f_dim2, f_dim3, 1)

    stacked = tf.concat([moving_5d, fixed_5d], axis=channel_axis)  # (batch, f_dim1, f_dim2, f_dim3, 2)
    ddf = backbone(inputs=stacked)  # (batch, f_dim1, f_dim2, f_dim3, 3)
    # theta is read after the call, once the backbone has produced it
    affine = backbone.theta

    # one Warping layer serves both the image and the label
    warping = layer.Warping(fixed_image_size=fixed_image_size)
    grid_fixed = tf.squeeze(warping.grid_ref, axis=0)  # (f_dim1, f_dim2, f_dim3, 3)
    pred_fixed_image = warping(inputs=[ddf, tf.squeeze(moving_5d, axis=channel_axis)])
    # NOTE(review): moving_label is warped at its original (moving) size, not
    # resized like the image -- confirm Warping handles that consistently.
    if moving_label is None:
        pred_fixed_label = None
    else:
        pred_fixed_label = warping(inputs=[ddf, moving_label])
    return affine, ddf, pred_fixed_image, pred_fixed_label, grid_fixed
Beispiel #7
0
    def build_model(self):
        """
        Build the ddf registration network and return it as a ``tf.keras.Model``.

        The backbone predicts a ddf from the stacked moving/fixed images; a
        GlobalNet backbone additionally returns ``theta``. The ddf warps the
        moving image, and also the moving label when ``self.labeled`` is true.
        The model inputs are stored in ``self._inputs`` and the outputs dict in
        ``self._outputs``.
        """
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]  # (batch, m_dim1, m_dim2, m_dim3)
        fixed_image = self._inputs["fixed_image"]  # (batch, f_dim1, f_dim2, f_dim3)

        # pop removes "control_points" from self.config, so the backbone
        # builder below never receives it
        control_points = self.config["backbone"].pop("control_points", False)
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=self.config["backbone"],
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )

        if not isinstance(backbone, GlobalNet):
            # (batch, f_dim1, f_dim2, f_dim3, 3)
            ddf = backbone(inputs=backbone_inputs)
            if control_points:
                ddf = self._resize_interpolate(ddf, control_points)
            self._outputs = dict(ddf=ddf)
        else:
            # ddf (batch, f_dim1, f_dim2, f_dim3, 3) and affine parameters theta
            ddf, theta = backbone(inputs=backbone_inputs)
            self._outputs = dict(ddf=ddf, theta=theta)

        # warped predictions, (batch, f_dim1, f_dim2, f_dim3)
        warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        self._outputs["pred_fixed_image"] = warping(inputs=[ddf, moving_image])

        if self.labeled:
            moving_label = self._inputs["moving_label"]
            self._outputs["pred_fixed_label"] = warping(inputs=[ddf, moving_label])

        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)
Beispiel #8
0
def test_initDVF():
    """
    Test the layer.IntDVF class, its default attributes and its call() method.
    """

    batch_size = 5
    fixed_image_size = (32, 32, 16)
    ndims = len(fixed_image_size)

    model = layer.IntDVF(fixed_image_size)

    # compare against the class itself instead of building a throwaway Warping layer
    assert isinstance(model._warping, layer.Warping)
    assert model._num_steps == 7

    inputs = np.ones((batch_size, *fixed_image_size, ndims))
    output = model.call(inputs)
    # compare whole shapes: zip() would silently ignore a missing or extra dimension
    assert output.shape == (batch_size, *fixed_image_size, ndims)
Beispiel #9
0
    def build_model(self):
        """
        Build the dvf registration network and return it as a ``tf.keras.Model``.

        The backbone predicts a dvf (optionally resized via control points),
        ``layer.IntDVF`` integrates it into a ddf, and the ddf warps the moving
        image, and also the moving label when ``self.labeled`` is true. The
        shared Warping layer is kept on ``self._warping``.
        """
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]
        # pop removes "control_points" from self.config, so the backbone
        # builder below never receives it
        control_points = self.config["backbone"].pop("control_points", False)

        # dvf from the backbone, then integrated into a ddf
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=self.config["backbone"],
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )
        dvf = backbone(inputs=backbone_inputs)
        if control_points:
            dvf = self._resize_interpolate(dvf, control_points)
        ddf = layer.IntDVF(fixed_image_size=self.fixed_image_size)(dvf)

        self._warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        # warped image, (batch, f_dim1, f_dim2, f_dim3)
        pred_fixed_image = self._warping(inputs=[ddf, moving_image])
        self._outputs = dict(dvf=dvf, ddf=ddf, pred_fixed_image=pred_fixed_image)

        if self.labeled:
            # warped label, (batch, f_dim1, f_dim2, f_dim3)
            moving_label = self._inputs["moving_label"]
            self._outputs["pred_fixed_label"] = self._warping(inputs=[ddf, moving_label])

        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)
Beispiel #10
0
def test_warping():
    """
    Test the layer.Warping class, its default attributes and its call() method.
    """
    batch_size = 5
    fixed_image_size = (32, 32, 16)
    moving_image_size = (24, 24, 16)
    ndims = len(moving_image_size)

    grid_size = (1, ) + fixed_image_size + (3, )
    model = layer.Warping(fixed_image_size)

    # reference grid: a leading axis of 1, the fixed size, then a 3-vector per voxel;
    # compare whole shapes, since zip() would ignore a rank mismatch
    assert model.grid_ref.shape == grid_size

    # inputs of all ones: a ddf on the fixed grid and a moving image of another size
    inputs = [
        np.ones((batch_size, *fixed_image_size, ndims), dtype="float32"),
        np.ones((batch_size, *moving_image_size), dtype="float32"),
    ]
    #  Get outputs by calling
    output = model.call(inputs)
    #  Expected shape is (batch_size, *fixed_image_size) == (5, 32, 32, 16)
    assert output.shape == (batch_size, *fixed_image_size)
Beispiel #11
0
        loss_image = image_loss.dissimilarity_fn(
            y_true=fix, y_pred=pred, name=image_loss_name
        )
        loss_deform = deform_loss.local_displacement_energy(weights, deform_loss_name)
        loss = loss_image + weight_deform_loss * loss_deform
    gradients = tape.gradient(loss, [weights])
    optimizer.apply_gradients(zip(gradients, [weights]))
    return loss, loss_image, loss_deform


# ddf as trainable weights: the displacement field itself is optimised as a
# tf.Variable instead of being predicted by a network.
# NOTE(review): assumes fixed_image is (batch, dim1, dim2, dim3) -- the [1:4]
# slice below is taken to drop the batch axis; confirm where fixed_image is loaded.
fixed_image_size = fixed_image.shape
# small random initial displacements (mean 0, stddev 1e-3)
initializer = tf.random_normal_initializer(mean=0, stddev=1e-3)
# ddf has the fixed image shape plus a trailing axis of size 3
var_ddf = tf.Variable(initializer(fixed_image_size + [3]), name="ddf", trainable=True)

# Warping is configured with the spatial part of the fixed image shape only
warping = layer.Warping(fixed_image_size=fixed_image_size[1:4])
optimiser = tf.optimizers.Adam(learning_rate)
for step in range(total_iter):
    loss_opt, loss_image_opt, loss_deform_opt = train_step(
        warping, var_ddf, optimiser, moving_image, fixed_image
    )
    if (step % 50) == 0:  # print info
        tf.print(
            "Step",
            step,
            "loss",
            loss_opt,
            image_loss_name,
            loss_image_opt,
            deform_loss_name,
            loss_deform_opt,