Beispiel #1
0
def resize_inputs(inputs: dict, moving_image_size: tuple,
                  fixed_image_size: tuple):
    """
    Resize the images, and the labels when available, to the given sizes.

    Moving data are resized to moving_image_size and fixed data to
    fixed_image_size. The indices entry is passed through untouched.

    :param inputs: dict, read with the following keys
        - moving_image, documented shape = (None, None, None)
        - fixed_image, documented shape = (None, None, None)
        - indices, documented shape = (num_indices, )
        - moving_label / fixed_label, optional, documented shape =
          (None, None, None); the sample is handled as unlabeled when
          moving_label is missing or None
    :param moving_image_size: tuple, (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: tuple, (f_dim1, f_dim2, f_dim3)
    :return: dict with
        - moving_image, shape = (m_dim1, m_dim2, m_dim3)
        - fixed_image, shape = (f_dim1, f_dim2, f_dim3)
        - moving_label, shape = (m_dim1, m_dim2, m_dim3), labeled only
        - fixed_label, shape = (f_dim1, f_dim2, f_dim3), labeled only
        - indices, shape = (num_indices, )
    """
    keys = ("moving_image", "fixed_image", "moving_label", "fixed_label",
            "indices")
    raw_moving, raw_fixed, raw_moving_label, raw_fixed_label, indices = (
        inputs.get(key) for key in keys)

    # images are always resized, before looking at the labels
    resized = {
        "moving_image": layer_util.resize3d(image=raw_moving,
                                            size=moving_image_size),
        "fixed_image": layer_util.resize3d(image=raw_fixed,
                                           size=fixed_image_size),
    }

    if raw_moving_label is not None:  # labeled
        resized["moving_label"] = layer_util.resize3d(image=raw_moving_label,
                                                      size=moving_image_size)
        resized["fixed_label"] = layer_util.resize3d(image=raw_fixed_label,
                                                     size=fixed_image_size)

    # indices go last so the key order matches for both cases
    resized["indices"] = indices
    return resized
Beispiel #2
0
def conditional_forward(
    backbone: tf.keras.Model,
    moving_image: tf.Tensor,
    fixed_image: tf.Tensor,
    moving_label: (tf.Tensor, None),
    moving_image_size: tuple,
    fixed_image_size: tuple,
) -> [tf.Tensor, tf.Tensor]:
    """
    Run the forward pass of the conditional model.

    The moving image, fixed image and moving label each get a channel axis,
    the moving data are resized to the fixed image size when the two sizes
    differ, and the three volumes are concatenated along the channel axis
    before being fed to the backbone.

    :param backbone: model architecture object, e.g. model.backbone.local_net
    :param moving_image: tensor of shape (batch, m_dim1, m_dim2, m_dim3)
    :param fixed_image:  tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param moving_label: tensor of shape (batch, m_dim1, m_dim2, m_dim3).
      NOTE(review): the annotation allows None, but the label is expanded
      and concatenated unconditionally, so None is presumably not usable
      here - confirm against callers.
    :param moving_image_size: tuple like (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: tuple like (f_dim1, f_dim2, f_dim3)
    :return: (pred_fixed_label, fixed_grid), where

      - pred_fixed_label is the backbone output with the channel axis
        squeezed, of shape (batch, f_dim1, f_dim2, f_dim3)
      - fixed_grid is the reference grid of a Warping layer built for
        fixed_image_size, of shape (f_dim1, f_dim2, f_dim3, 3)
    """

    # append a channel axis to every volume -> (batch, dim1, dim2, dim3, 1)
    moving_image, fixed_image, moving_label = (
        tf.expand_dims(volume, axis=4)
        for volume in (moving_image, fixed_image, moving_label))

    # bring the moving data onto the fixed image size when they differ
    if moving_image_size != fixed_image_size:
        # (batch, f_dim1, f_dim2, f_dim3, 1)
        moving_image = layer_util.resize3d(image=moving_image,
                                           size=fixed_image_size)
        # (batch, f_dim1, f_dim2, f_dim3, 1)
        moving_label = layer_util.resize3d(image=moving_label,
                                           size=fixed_image_size)

    # conditional input, (batch, f_dim1, f_dim2, f_dim3, 3)
    backbone_input = tf.concat([moving_image, fixed_image, moving_label],
                               axis=4)

    # backbone gives (batch, f_dim1, f_dim2, f_dim3, 1), drop the channel axis
    pred_fixed_label = tf.squeeze(backbone(inputs=backbone_input), axis=4)

    # the warping layer is only built for its reference grid
    reference_grid = layer.Warping(fixed_image_size=fixed_image_size).grid_ref
    grid_fixed = tf.squeeze(reference_grid,
                            axis=0)  # (f_dim1, f_dim2, f_dim3, 3)

    return pred_fixed_label, grid_fixed
Beispiel #3
0
def resize_inputs(
    inputs: Dict[str, tf.Tensor], moving_image_size: tuple, fixed_image_size: tuple
) -> Dict[str, tf.Tensor]:
    """
    Resize the images, and the labels when available, to the given sizes.

    Moving data are resized to moving_image_size and fixed data to
    fixed_image_size. The indices entry is passed through untouched.

    :param inputs: dict with the required keys
        - moving_image, documented shape = (None, None, None)
        - fixed_image, documented shape = (None, None, None)
        - indices, documented shape = (num_indices, )
        and, for labeled samples only,
        - moving_label, documented shape = (None, None, None)
        - fixed_label, documented shape = (None, None, None)
        The sample is handled as unlabeled when the key moving_label
        is absent.
    :param moving_image_size: tuple, (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: tuple, (f_dim1, f_dim2, f_dim3)
    :return: dict with
        - moving_image, shape = (m_dim1, m_dim2, m_dim3)
        - fixed_image, shape = (f_dim1, f_dim2, f_dim3)
        - moving_label, shape = (m_dim1, m_dim2, m_dim3), labeled only
        - fixed_label, shape = (f_dim1, f_dim2, f_dim3), labeled only
        - indices, shape = (num_indices, )
    """
    raw_moving, raw_fixed, indices = (
        inputs["moving_image"],
        inputs["fixed_image"],
        inputs["indices"],
    )

    # images are always resized, before looking at the labels
    outputs = {
        "moving_image": resize3d(image=raw_moving, size=moving_image_size),
        "fixed_image": resize3d(image=raw_fixed, size=fixed_image_size),
    }

    if "moving_label" in inputs:  # labeled
        raw_moving_label, raw_fixed_label = (
            inputs["moving_label"],
            inputs["fixed_label"],
        )
        outputs["moving_label"] = resize3d(
            image=raw_moving_label, size=moving_image_size
        )
        outputs["fixed_label"] = resize3d(image=raw_fixed_label, size=fixed_image_size)

    # indices go last so the key order matches for both cases
    outputs["indices"] = indices
    return outputs
def test_resize3d():
    """
    Test resize3d by confirming the output shapes.

    Each case is (input_shape, size, expected output shape); all of them
    are expected to pass.
    """
    cases = [
        # different size, without channel nor batch
        ((1, 3, 5), (2, 4, 6), (2, 4, 6)),
        # different size, without channel
        ((1, 1, 3, 5), (2, 4, 6), (1, 2, 4, 6)),
        # different size, with one channel
        ((1, 1, 3, 5, 1), (2, 4, 6), (1, 2, 4, 6, 1)),
        # different size, with multiple channels
        ((1, 1, 3, 5, 3), (2, 4, 6), (1, 2, 4, 6, 3)),
        # same size, without channel nor batch
        ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
        # same size, without channel
        ((1, 1, 3, 5), (1, 3, 5), (1, 1, 3, 5)),
        # same size, with one channel
        ((1, 1, 3, 5, 1), (1, 3, 5), (1, 1, 3, 5, 1)),
        # same size, with multiple channels
        ((1, 1, 3, 5, 3), (1, 3, 5), (1, 1, 3, 5, 3)),
    ]

    for input_shape, size, output_shape in cases:
        got = layer_util.resize3d(image=tf.ones(input_shape), size=size)
        assert got.shape == output_shape
Beispiel #5
0
def ddf_dvf_forward(
    backbone: tf.keras.Model,
    moving_image: tf.Tensor,
    fixed_image: tf.Tensor,
    moving_label: (tf.Tensor, None),
    moving_image_size: tuple,
    fixed_image_size: tuple,
    output_dvf: bool,
) -> [(tf.Tensor, None), tf.Tensor, tf.Tensor, (tf.Tensor, None), tf.Tensor]:
    """
    Run the forward pass of the ddf / dvf model.

    The moving image is resized to the fixed image size, concatenated with
    the fixed image along a new channel axis and fed to the backbone. The
    backbone output is used as the ddf directly, or treated as a dvf and
    passed through layer.IntDVF to obtain the ddf. The ddf is then used to
    warp the resized moving image and, when given, the moving label.

    :param backbone: model architecture object, e.g. model.backbone.local_net
    :param moving_image: tensor of shape (batch, m_dim1, m_dim2, m_dim3)
    :param fixed_image:  tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param moving_label: tensor of shape (batch, m_dim1, m_dim2, m_dim3) or None
    :param moving_image_size: tuple like (m_dim1, m_dim2, m_dim3), not read
      by this function
    :param fixed_image_size: tuple like (f_dim1, f_dim2, f_dim3)
    :param output_dvf: bool, if true, model outputs dvf, if false, model outputs ddf
    :return: tuple(dvf, ddf, pred_fixed_image, pred_fixed_label, fixed_grid), where
    - dvf is the dense velocity field of shape (batch, f_dim1, f_dim2, f_dim3, 3),
      or None when output_dvf is false
    - ddf is the dense displacement field of shape (batch, f_dim1, f_dim2, f_dim3, 3)
    - pred_fixed_image is the predicted (warped) moving image of shape (batch, f_dim1, f_dim2, f_dim3)
    - pred_fixed_label is the predicted (warped) moving label of shape (batch, f_dim1, f_dim2, f_dim3),
      or None when moving_label is None
    - fixed_grid is the grid of shape(f_dim1, f_dim2, f_dim3, 3)
    """

    # add a channel axis; it is squeezed again before warping
    moving_image = tf.expand_dims(
        moving_image, axis=4)  # (batch, m_dim1, m_dim2, m_dim3, 1)
    fixed_image = tf.expand_dims(
        fixed_image, axis=4)  # (batch, f_dim1, f_dim2, f_dim3, 1)

    # bring the moving image onto the fixed image size
    moving_image = layer_util.resize3d(
        image=moving_image,
        size=fixed_image_size)  # (batch, f_dim1, f_dim2, f_dim3, 1)

    # backbone input is (batch, f_dim1, f_dim2, f_dim3, 2),
    # its output is (batch, f_dim1, f_dim2, f_dim3, 3)
    backbone_out = backbone(
        inputs=tf.concat([moving_image, fixed_image], axis=4))

    # by default the backbone output is the ddf and there is no dvf
    dvf, ddf = None, backbone_out
    if output_dvf:
        dvf = backbone_out
        ddf = layer.IntDVF(fixed_image_size=fixed_image_size)(
            dvf)  # (batch, f_dim1, f_dim2, f_dim3, 3)

    # prediction, (batch, f_dim1, f_dim2, f_dim3)
    warping = layer.Warping(fixed_image_size=fixed_image_size)
    grid_fixed = tf.squeeze(warping.grid_ref,
                            axis=0)  # (f_dim1, f_dim2, f_dim3, 3)
    pred_fixed_image = warping(inputs=[ddf, tf.squeeze(moving_image, axis=4)])

    # the label is optional and only warped when provided
    pred_fixed_label = None
    if moving_label is not None:
        pred_fixed_label = warping(inputs=[ddf, moving_label])

    return dvf, ddf, pred_fixed_image, pred_fixed_label, grid_fixed
Beispiel #6
0
 def call(self, inputs, **kwargs):
     """
     Apply the layer's 3D convolution, then resize to the output shape.

     :param inputs: shape = [batch, dim1, dim2, dim3, channels]
     :param kwargs: not used by this method
     :return: shape = [batch, out_dim1, out_dim2, out_dim3, channels]
     """
     # convolve first, the resize to self._output_shape runs on the result
     convolved = self._conv3d(inputs=inputs)
     return layer_util.resize3d(image=convolved, size=self._output_shape)
Beispiel #7
0
 def call(self, inputs, **kwargs) -> tf.Tensor:
     """
     Apply the layer's 3D convolution, then resize to the output shape.

     :param inputs: shape = (batch, dim1, dim2, dim3, channels)
     :param kwargs: additional arguments, not used by this method.
     :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels)
     """
     # resize3d is given the convolved tensor and self._output_shape
     return layer_util.resize3d(
         image=self._conv3d(inputs=inputs), size=self._output_shape
     )
def affine_forward(
    backbone: tf.keras.Model,
    moving_image: tf.Tensor,
    fixed_image: tf.Tensor,
    moving_label: (tf.Tensor, None),
    moving_image_size: tuple,
    fixed_image_size: tuple,
):
    """
    Run the forward pass of the affine model.

    The moving image is resized to the fixed image size, concatenated with
    the fixed image along a new channel axis and fed to the backbone, whose
    output is used as the ddf. The affine parameters are read from
    backbone.theta after the backbone has been called. The ddf is then used
    to warp the resized moving image and, when given, the moving label.

    :param backbone: model architecture object, e.g. model.backbone.local_net;
      it must expose a theta attribute
    :param moving_image: tensor of shape (batch, m_dim1, m_dim2, m_dim3)
    :param fixed_image:  tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param moving_label: tensor of shape (batch, m_dim1, m_dim2, m_dim3) or None
    :param moving_image_size: tuple like (m_dim1, m_dim2, m_dim3), not read
      by this function
    :param fixed_image_size: tuple like (f_dim1, f_dim2, f_dim3)
    :return: tuple(affine, ddf, pred_fixed_image, pred_fixed_label, fixed_grid), where

      - affine is the affine transformation matrix predicted by the network (batch, 4, 3)
      - ddf is the dense displacement field of shape (batch, f_dim1, f_dim2, f_dim3, 3)
      - pred_fixed_image is the predicted (warped) moving image of shape (batch, f_dim1, f_dim2, f_dim3)
      - pred_fixed_label is the predicted (warped) moving label of shape (batch, f_dim1, f_dim2, f_dim3),
        or None when moving_label is None
      - fixed_grid is the grid of shape(f_dim1, f_dim2, f_dim3, 3)
    """

    # add a channel axis; it is squeezed again before warping
    moving_image = tf.expand_dims(
        moving_image, axis=4)  # (batch, m_dim1, m_dim2, m_dim3, 1)
    fixed_image = tf.expand_dims(
        fixed_image, axis=4)  # (batch, f_dim1, f_dim2, f_dim3, 1)

    # bring the moving image onto the fixed image size
    moving_image = layer_util.resize3d(
        image=moving_image,
        size=fixed_image_size)  # (batch, f_dim1, f_dim2, f_dim3, 1)

    # backbone input is (batch, f_dim1, f_dim2, f_dim3, 2),
    # its output, the ddf, is (batch, f_dim1, f_dim2, f_dim3, 3)
    ddf = backbone(inputs=tf.concat([moving_image, fixed_image], axis=4))
    # theta is read only after the backbone call above
    affine = backbone.theta

    # prediction, (batch, f_dim1, f_dim2, f_dim3)
    warping = layer.Warping(fixed_image_size=fixed_image_size)
    grid_fixed = tf.squeeze(warping.grid_ref,
                            axis=0)  # (f_dim1, f_dim2, f_dim3, 3)
    pred_fixed_image = warping(inputs=[ddf, tf.squeeze(moving_image, axis=4)])

    # the label is optional and only warped when provided
    pred_fixed_label = None
    if moving_label is not None:
        pred_fixed_label = warping(inputs=[ddf, moving_label])

    return affine, ddf, pred_fixed_image, pred_fixed_label, grid_fixed
Beispiel #9
0
 def call(self, inputs, **kwargs) -> tf.Tensor:
     """
     Resize the inputs and sum groups of channels.

     The resized tensor is split along the channel axis into self._stride
     parts, which are then summed element-wise.

     :param inputs: shape = (batch, dim1, dim2, dim3, channels)
     :param kwargs: additional arguments, not used by this method.
     :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels//stride)
     :raises ValueError: if channels is not a multiple of self._stride
     """
     num_channels = inputs.shape[4]
     if num_channels % self._stride != 0:
         raise ValueError("The channel dimension can not be divided by the stride")

     # (batch, out_dim1, out_dim2, out_dim3, channels)
     resized = layer_util.resize3d(image=inputs, size=self._output_shape)
     # self._stride tensors of (batch, out_dim1, out_dim2, out_dim3, channels//stride)
     channel_groups = tf.split(resized, num_or_size_splits=self._stride, axis=4)
     # stack the groups on a new last axis and sum over it
     stacked = tf.stack(channel_groups, axis=5)
     return tf.reduce_sum(stacked, axis=5)
Beispiel #10
0
    def concat_images(
        self,
        moving_image: tf.Tensor,
        fixed_image: tf.Tensor,
        moving_label: Optional[tf.Tensor] = None,
    ) -> tf.Tensor:
        """
        Adjust image shape and concatenate them together.

        Every given volume gets a trailing channel axis. The moving image
        and the moving label are additionally resized to
        self.fixed_image_size; the fixed image is not resized.

        :param moving_image: registration source
        :param fixed_image: registration target
        :param moving_label: optional, only used for conditional model.
        :return: the volumes concatenated along axis 4, in the order
            moving image, fixed image, moving label (if given)
        """
        # channel axis added, then resized to the fixed image size
        moving_image = layer_util.resize3d(
            image=tf.expand_dims(moving_image, axis=4),
            size=self.fixed_image_size)

        # channel axis added, size left as it is
        fixed_image = tf.expand_dims(fixed_image, axis=4)

        channels = [moving_image, fixed_image]

        if moving_label is not None:
            # same treatment as the moving image
            moving_label = layer_util.resize3d(
                image=tf.expand_dims(moving_label, axis=4),
                size=self.fixed_image_size)
            channels.append(moving_label)

        # 2 or 3 entries on the last axis, depending on the label
        return tf.concat(channels, axis=4)
Beispiel #11
0
 def call(self, inputs, **kwargs) -> tf.Tensor:
     """
     Convolve the inputs with self.kernel, then resize to the output shape.

     The convolution uses unit strides and "SAME" padding.

     :param inputs: tensor passed to tf.nn.conv3d
     :param kwargs: additional arguments, not used by this method.
     :return: the convolved tensor resized to self._output_shape
     """
     unit_strides = (1, 1, 1, 1, 1)
     convolved = tf.nn.conv3d(
         inputs, self.kernel, strides=unit_strides, padding="SAME"
     )
     resized = layer_util.resize3d(image=convolved, size=self._output_shape)
     return resized
def test_resize3d():
    """
    Test resize3d by confirming the output shapes.

    The passing cases are (input_shape, size, expected output shape);
    the failing cases are (input_shape, size, expected error message).
    """
    passing_cases = [
        # different size, without channel nor batch
        ((1, 3, 5), (2, 4, 6), (2, 4, 6)),
        # different size, without channel
        ((1, 1, 3, 5), (2, 4, 6), (1, 2, 4, 6)),
        # different size, with one channel
        ((1, 1, 3, 5, 1), (2, 4, 6), (1, 2, 4, 6, 1)),
        # different size, with multiple channels
        ((1, 1, 3, 5, 3), (2, 4, 6), (1, 2, 4, 6, 3)),
        # same size, without channel nor batch
        ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
        # same size, without channel
        ((1, 1, 3, 5), (1, 3, 5), (1, 1, 3, 5)),
        # same size, with one channel
        ((1, 1, 3, 5, 1), (1, 3, 5), (1, 1, 3, 5, 1)),
        # same size, with multiple channels
        ((1, 1, 3, 5, 3), (1, 3, 5), (1, 1, 3, 5, 3)),
    ]
    for input_shape, size, output_shape in passing_cases:
        got = layer_util.resize3d(image=tf.ones(input_shape), size=size)
        assert got.shape == output_shape

    failing_cases = [
        # image with an unsupported number of dimensions
        (
            (1, 1),
            (1, 1, 1),
            "resize3d takes input image of dimension 3 or 4 or 5",
        ),
        # size with an unsupported length
        (
            (1, 1, 1),
            (1, 1),
            "resize3d takes size of type tuple/list and of length 3",
        ),
    ]
    for input_shape, size, message in failing_cases:
        with pytest.raises(ValueError) as err_info:
            layer_util.resize3d(image=tf.ones(input_shape), size=size)
        assert message in str(err_info.value)