Beispiel #1
0
def ddf_dvf_forward(
    backbone: tf.keras.Model,
    moving_image: tf.Tensor,
    fixed_image: tf.Tensor,
    moving_label: (tf.Tensor, None),
    moving_image_size: tuple,
    fixed_image_size: tuple,
    output_dvf: bool,
) -> [(tf.Tensor, None), tf.Tensor, tf.Tensor, (tf.Tensor, None), tf.Tensor]:
    """
    Perform the network forward pass.

    The moving image is resized onto the fixed image grid only to build the
    backbone input. Warping is applied to the original (un-resized) moving
    image, in the same way as the moving label, so that the predicted fixed
    image and the predicted fixed label are resampled from the same frame
    even when moving_image_size differs from fixed_image_size.

    :param backbone: model architecture object, e.g. model.backbone.local_net
    :param moving_image: tensor of shape (batch, m_dim1, m_dim2, m_dim3)
    :param fixed_image:  tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param moving_label: tensor of shape (batch, m_dim1, m_dim2, m_dim3) or None
    :param moving_image_size: tuple like (m_dim1, m_dim2, m_dim3); not used in
        the computation, kept for interface compatibility
    :param fixed_image_size: tuple like (f_dim1, f_dim2, f_dim3)
    :param output_dvf: bool, if true, model outputs dvf, if false, model outputs ddf
    :return: tuple(dvf, ddf, pred_fixed_image, pred_fixed_label, fixed_grid), where
    - dvf is the dense velocity field of shape (batch, f_dim1, f_dim2, f_dim3, 3),
      or None when output_dvf is false
    - ddf is the dense displacement field of shape (batch, f_dim1, f_dim2, f_dim3, 3)
    - pred_fixed_image is the predicted (warped) moving image of shape (batch, f_dim1, f_dim2, f_dim3)
    - pred_fixed_label is the predicted (warped) moving label of shape (batch, f_dim1, f_dim2, f_dim3),
      or None when moving_label is None
    - fixed_grid is the grid of shape (f_dim1, f_dim2, f_dim3, 3)
    """
    # Add a channel axis so the two images can be resized and concatenated.
    moving_image_ch = tf.expand_dims(moving_image, axis=4)  # (batch, m_dim1, m_dim2, m_dim3, 1)
    fixed_image_ch = tf.expand_dims(fixed_image, axis=4)  # (batch, f_dim1, f_dim2, f_dim3, 1)

    # The backbone needs both images on the fixed grid.
    moving_image_resized = layer_util.resize3d(
        image=moving_image_ch, size=fixed_image_size
    )  # (batch, f_dim1, f_dim2, f_dim3, 1)

    # ddf, dvf
    inputs = tf.concat(
        [moving_image_resized, fixed_image_ch], axis=4
    )  # (batch, f_dim1, f_dim2, f_dim3, 2)
    backbone_out = backbone(inputs=inputs)  # (batch, f_dim1, f_dim2, f_dim3, 3)
    if output_dvf:
        dvf = backbone_out  # (batch, f_dim1, f_dim2, f_dim3, 3)
        # Integrate the velocity field into a displacement field.
        ddf = layer.IntDVF(fixed_image_size=fixed_image_size)(
            dvf
        )  # (batch, f_dim1, f_dim2, f_dim3, 3)
    else:
        dvf = None
        ddf = backbone_out  # (batch, f_dim1, f_dim2, f_dim3, 3)

    # prediction, (batch, f_dim1, f_dim2, f_dim3)
    warping = layer.Warping(fixed_image_size=fixed_image_size)
    grid_fixed = tf.squeeze(warping.grid_ref, axis=0)  # (f_dim1, f_dim2, f_dim3, 3)
    # Warp the ORIGINAL moving image, not the resized copy, so that image and
    # label are sampled from the same (moving) frame by the same ddf.
    pred_fixed_image = warping(inputs=[ddf, moving_image])
    pred_fixed_label = (
        warping(inputs=[ddf, moving_label]) if moving_label is not None else None
    )
    return dvf, ddf, pred_fixed_image, pred_fixed_label, grid_fixed
Beispiel #2
0
def test_initDVF():
    """
    Test the layer.IntDVF class, its default attributes and its call() method.

    The call() output must have exactly the same shape as the input,
    (batch, f_dim1, f_dim2, f_dim3, ndims).
    """

    batch_size = 5
    fixed_image_size = (32, 32, 16)
    ndims = len(fixed_image_size)

    model = layer.IntDVF(fixed_image_size)

    assert isinstance(model._warping, layer.Warping)
    assert model._num_steps == 7

    inputs = np.ones((batch_size, *fixed_image_size, ndims))
    output = model.call(inputs)
    # Compare the full shapes. An element-wise zip comparison silently
    # ignores extra or missing dimensions, because zip stops at the shorter
    # sequence.
    assert tuple(output.shape) == (batch_size, *fixed_image_size, ndims)
Beispiel #3
0
    def build_model(self):
        """
        Build the model to be saved as self._model.

        The backbone predicts a dense velocity field (dvf), optionally
        interpolated from control points. The dvf is integrated into a dense
        displacement field (ddf), which warps the moving image and, when the
        data are labeled, the moving label.

        :return: tf.keras.Model mapping self._inputs to a dict with keys
            "dvf", "ddf", "pred_fixed_image" and, if self.labeled,
            "pred_fixed_label".
        """
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]

        # Pop "control_points" from a shallow copy so self.config is left
        # untouched. Popping in place would make any later call to
        # build_model silently lose the control-point setting.
        backbone_config = dict(self.config["backbone"])
        control_points = backbone_config.pop("control_points", False)

        # build dvf with the backbone, then integrate it into a ddf
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=backbone_config,
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )
        dvf = backbone(inputs=backbone_inputs)
        if control_points:
            dvf = self._resize_interpolate(dvf, control_points)
        ddf = layer.IntDVF(fixed_image_size=self.fixed_image_size)(dvf)

        # build outputs
        self._warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        # NOTE(review): presumably (batch, f_dim1, f_dim2, f_dim3) - confirm
        pred_fixed_image = self._warping(inputs=[ddf, moving_image])

        self._outputs = dict(dvf=dvf,
                             ddf=ddf,
                             pred_fixed_image=pred_fixed_image)

        if not self.labeled:
            return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

        # the label is warped with the same ddf and warping layer as the image
        moving_label = self._inputs["moving_label"]
        pred_fixed_label = self._warping(inputs=[ddf, moving_label])

        self._outputs["pred_fixed_label"] = pred_fixed_label
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)
Beispiel #4
0
    def forward(_backbone, _moving_image, _moving_label, _fixed_image):
        """
        Run the backbone on a moving/fixed pair and warp the moving data.

        ``fixed_image_size`` is read from the enclosing scope.

        :param _backbone: network mapping a [batch, f_dim1, f_dim2, f_dim3, 2]
            input to a dense velocity field [batch, f_dim1, f_dim2, f_dim3, 3]
        :param _moving_image: [batch, m_dim1, m_dim2, m_dim3]
        :param _moving_label: [batch, m_dim1, m_dim2, m_dim3], or None
        :param _fixed_image:  [batch, f_dim1, f_dim2, f_dim3]
        :return: tuple (dvf, ddf, pred_fixed_image, pred_fixed_label);
            pred_fixed_label is None when _moving_label is None
        """
        # The moving image is resized onto the fixed grid only for the backbone input.
        moving_resized = layer.Resize3d(size=fixed_image_size)(
            inputs=tf.expand_dims(_moving_image, axis=4)
        )  # [batch, f_dim1, f_dim2, f_dim3, 1]
        backbone_input = tf.concat(
            [moving_resized, tf.expand_dims(_fixed_image, axis=4)], axis=4
        )  # [batch, f_dim1, f_dim2, f_dim3, 2]
        _dvf = _backbone(inputs=backbone_input)  # [batch, f_dim1, f_dim2, f_dim3, 3]
        _ddf = layer.IntDVF(fixed_image_size=fixed_image_size)(_dvf)

        # predicted image and label shape = [batch, f_dim1, f_dim2, f_dim3]
        # A single warping layer serves both, since both use the same fixed grid.
        warping = layer.Warping(fixed_image_size=fixed_image_size)
        _pred_fixed_image = warping([_ddf, _moving_image])
        _pred_fixed_label = (
            warping([_ddf, _moving_label]) if _moving_label is not None else None
        )

        return _dvf, _ddf, _pred_fixed_image, _pred_fixed_label
Beispiel #5
0
    def test_forward(self):
        """
        Check that IntDVF keeps the input shape and exposes the expected config.
        """
        image_size = (8, 9, 10)
        shape = (2,) + image_size + (3,)

        int_layer = layer.IntDVF(fixed_image_size=image_size)

        result = int_layer(tf.ones(shape=shape))
        assert result.shape == shape

        expected_config = {
            "fixed_image_size": image_size,
            "num_steps": 7,
            "name": "int_dvf",
            "trainable": True,
            "dtype": "float32",
        }
        assert int_layer.get_config() == expected_config
Beispiel #6
0
 def test_err(self):
     """IntDVF must reject a fixed_image_size that is not three-dimensional."""
     two_dim_size = (2, 3)
     with pytest.raises(AssertionError):
         layer.IntDVF(fixed_image_size=two_dim_size)