def resize_inputs(inputs: dict, moving_image_size: tuple, fixed_image_size: tuple) -> dict:
    """
    Resize inputs

    :param inputs:
        if labeled:
            moving_image, shape = (None, None, None)
            fixed_image, shape = (None, None, None)
            moving_label, shape = (None, None, None)
            fixed_label, shape = (None, None, None)
            indices, shape = (num_indices, )
        else, unlabeled:
            moving_image, shape = (None, None, None)
            fixed_image, shape = (None, None, None)
            indices, shape = (num_indices, )
    :param moving_image_size: tuple, (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: tuple, (f_dim1, f_dim2, f_dim3)
    :return:
        if labeled:
            moving_image, shape = (m_dim1, m_dim2, m_dim3)
            fixed_image, shape = (f_dim1, f_dim2, f_dim3)
            moving_label, shape = (m_dim1, m_dim2, m_dim3)
            fixed_label, shape = (f_dim1, f_dim2, f_dim3)
            indices, shape = (num_indices, )
        else, unlabeled:
            moving_image, shape = (m_dim1, m_dim2, m_dim3)
            fixed_image, shape = (f_dim1, f_dim2, f_dim3)
            indices, shape = (num_indices, )
    :raises KeyError: if "moving_image" or "fixed_image" is missing from inputs.
    :raises ValueError: if a moving_label is given without a fixed_label.
    """
    # Both images are mandatory: index directly so a missing key fails here
    # with a clear KeyError instead of handing None to resize3d.
    moving_image = inputs["moving_image"]
    fixed_image = inputs["fixed_image"]
    # indices are only passed through; .get keeps the original tolerance of a
    # missing key (the output then carries indices=None).
    indices = inputs.get("indices")

    moving_image = layer_util.resize3d(image=moving_image, size=moving_image_size)
    fixed_image = layer_util.resize3d(image=fixed_image, size=fixed_image_size)

    # A missing (or None) moving_label marks the sample as unlabeled.
    moving_label = inputs.get("moving_label", None)
    if moving_label is None:  # unlabeled
        return dict(moving_image=moving_image, fixed_image=fixed_image, indices=indices)

    # Labeled samples need both labels; fail loudly rather than resizing None.
    fixed_label = inputs.get("fixed_label", None)
    if fixed_label is None:
        raise ValueError("moving_label is provided but fixed_label is missing")

    moving_label = layer_util.resize3d(image=moving_label, size=moving_image_size)
    fixed_label = layer_util.resize3d(image=fixed_label, size=fixed_image_size)
    return dict(
        moving_image=moving_image,
        fixed_image=fixed_image,
        moving_label=moving_label,
        fixed_label=fixed_label,
        indices=indices,
    )
def conditional_forward(
    backbone: tf.keras.Model,
    moving_image: tf.Tensor,
    fixed_image: tf.Tensor,
    moving_label: (tf.Tensor, None),
    moving_image_size: tuple,
    fixed_image_size: tuple,
) -> [tf.Tensor, tf.Tensor]:
    """
    Perform the network forward pass of a conditional model.

    The moving image, fixed image and moving label are stacked along a new
    channel axis and fed to the backbone, whose single-channel output is the
    predicted fixed label.

    :param backbone: model architecture object, e.g. model.backbone.local_net
    :param moving_image: tensor of shape (batch, m_dim1, m_dim2, m_dim3)
    :param fixed_image: tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param moving_label: tensor of shape (batch, m_dim1, m_dim2, m_dim3);
        required, because it is part of the network input.
    :param moving_image_size: tuple like (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: tuple like (f_dim1, f_dim2, f_dim3)
    :return: (pred_fixed_label, fixed_grid), where
        - pred_fixed_label is the predicted (warped) moving label
          of shape (batch, f_dim1, f_dim2, f_dim3)
        - fixed_grid is the grid of shape (f_dim1, f_dim2, f_dim3, 3)
    :raises ValueError: if moving_label is None.
    """
    # The label is concatenated into the network input below, so a conditional
    # model cannot run without it; fail with a clear message instead of an
    # opaque conversion error from tf.expand_dims.
    if moving_label is None:
        raise ValueError("conditional_forward requires moving_label, got None")

    # expand dims
    # need to be squeezed later for warping
    moving_image = tf.expand_dims(moving_image, axis=4)  # (batch, m_dim1, m_dim2, m_dim3, 1)
    fixed_image = tf.expand_dims(fixed_image, axis=4)  # (batch, f_dim1, f_dim2, f_dim3, 1)
    moving_label = tf.expand_dims(moving_label, axis=4)  # (batch, m_dim1, m_dim2, m_dim3, 1)

    # adjust moving image and label so they can be concatenated with the fixed image
    if moving_image_size != fixed_image_size:
        moving_image = layer_util.resize3d(
            image=moving_image, size=fixed_image_size
        )  # (batch, f_dim1, f_dim2, f_dim3, 1)
        moving_label = layer_util.resize3d(
            image=moving_label, size=fixed_image_size
        )  # (batch, f_dim1, f_dim2, f_dim3, 1)

    # conditional
    inputs = tf.concat(
        [moving_image, fixed_image, moving_label], axis=4
    )  # (batch, f_dim1, f_dim2, f_dim3, 3)
    pred_fixed_label = backbone(inputs=inputs)  # (batch, f_dim1, f_dim2, f_dim3, 1)
    pred_fixed_label = tf.squeeze(pred_fixed_label, axis=4)  # (batch, f_dim1, f_dim2, f_dim3)

    # only the reference grid is needed here; no warping is applied
    warping = layer.Warping(fixed_image_size=fixed_image_size)
    grid_fixed = tf.squeeze(warping.grid_ref, axis=0)  # (f_dim1, f_dim2, f_dim3, 3)

    return pred_fixed_label, grid_fixed
def resize_inputs(
    inputs: Dict[str, tf.Tensor], moving_image_size: tuple, fixed_image_size: tuple
) -> Dict[str, tf.Tensor]:
    """
    Resize every image (and label, when present) in a sample dict to its target size.

    Moving tensors are resized to ``moving_image_size`` and fixed tensors to
    ``fixed_image_size``; ``indices`` is passed through unchanged.

    :param inputs:
        if labeled:
            moving_image, shape = (None, None, None)
            fixed_image, shape = (None, None, None)
            moving_label, shape = (None, None, None)
            fixed_label, shape = (None, None, None)
            indices, shape = (num_indices, )
        else, unlabeled:
            moving_image, shape = (None, None, None)
            fixed_image, shape = (None, None, None)
            indices, shape = (num_indices, )
    :param moving_image_size: tuple, (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: tuple, (f_dim1, f_dim2, f_dim3)
    :return:
        if labeled:
            moving_image, shape = (m_dim1, m_dim2, m_dim3)
            fixed_image, shape = (f_dim1, f_dim2, f_dim3)
            moving_label, shape = (m_dim1, m_dim2, m_dim3)
            fixed_label, shape = (f_dim1, f_dim2, f_dim3)
            indices, shape = (num_indices, )
        else, unlabeled:
            moving_image, shape = (m_dim1, m_dim2, m_dim3)
            fixed_image, shape = (f_dim1, f_dim2, f_dim3)
            indices, shape = (num_indices, )
    """
    # All three keys are mandatory; read them up front (KeyError if absent).
    raw_moving, raw_fixed, indices = (
        inputs[key] for key in ("moving_image", "fixed_image", "indices")
    )
    resized_moving = resize3d(image=raw_moving, size=moving_image_size)
    resized_fixed = resize3d(image=raw_fixed, size=fixed_image_size)

    # The presence of the moving_label key decides labeled vs unlabeled.
    if "moving_label" not in inputs:
        return dict(moving_image=resized_moving, fixed_image=resized_fixed, indices=indices)

    raw_moving_label, raw_fixed_label = inputs["moving_label"], inputs["fixed_label"]
    return dict(
        moving_image=resized_moving,
        fixed_image=resized_fixed,
        moving_label=resize3d(image=raw_moving_label, size=moving_image_size),
        fixed_label=resize3d(image=raw_fixed_label, size=fixed_image_size),
        indices=indices,
    )
def test_resize3d():
    """
    Test resize3d by confirming the output shapes.

    Covers inputs of rank 3 (no batch, no channel), rank 4 (no channel) and
    rank 5 (one or several channels), both with a size that differs from the
    input and with a size equal to it.
    """
    # (input_shape, size, expected output shape)
    cases = [
        # different size, without channel nor batch
        ((1, 3, 5), (2, 4, 6), (2, 4, 6)),
        # different size, without channel
        ((1, 1, 3, 5), (2, 4, 6), (1, 2, 4, 6)),
        # different size, with one channel
        ((1, 1, 3, 5, 1), (2, 4, 6), (1, 2, 4, 6, 1)),
        # different size, with multiple channels
        ((1, 1, 3, 5, 3), (2, 4, 6), (1, 2, 4, 6, 3)),
        # same size, without channel nor batch
        ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
        # same size, without channel
        ((1, 1, 3, 5), (1, 3, 5), (1, 1, 3, 5)),
        # same size, with one channel
        ((1, 1, 3, 5, 1), (1, 3, 5), (1, 1, 3, 5, 1)),
        # same size, with multiple channels
        ((1, 1, 3, 5, 3), (1, 3, 5), (1, 1, 3, 5, 3)),
    ]
    for input_shape, size, expected_shape in cases:
        resized = layer_util.resize3d(image=tf.ones(input_shape), size=size)
        assert resized.shape == expected_shape
def ddf_dvf_forward(
    backbone: tf.keras.Model,
    moving_image: tf.Tensor,
    fixed_image: tf.Tensor,
    moving_label: (tf.Tensor, None),
    moving_image_size: tuple,
    fixed_image_size: tuple,
    output_dvf: bool,
) -> [(tf.Tensor, None), tf.Tensor, tf.Tensor, (tf.Tensor, None), tf.Tensor]:
    """
    Run a DDF/DVF registration network and warp the moving inputs.

    The moving image is resized onto the fixed image size and stacked with the
    fixed image as a two-channel input. The backbone output is used as the DDF
    directly, or, when ``output_dvf`` is set, treated as a DVF and integrated
    into a DDF with ``layer.IntDVF``.

    :param backbone: model architecture object, e.g. model.backbone.local_net
    :param moving_image: tensor of shape (batch, m_dim1, m_dim2, m_dim3)
    :param fixed_image: tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param moving_label: tensor of shape (batch, m_dim1, m_dim2, m_dim3) or None
    :param moving_image_size: tuple like (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: tuple like (f_dim1, f_dim2, f_dim3)
    :param output_dvf: bool, if true, model outputs dvf, if false, model outputs ddf
    :return: tuple(dvf, ddf, pred_fixed_image, pred_fixed_label, fixed_grid), where
        - dvf is the dense velocity field of shape (batch, f_dim1, f_dim2, f_dim3, 3),
          or None when output_dvf is false
        - ddf is the dense displacement field of shape (batch, f_dim1, f_dim2, f_dim3, 3)
        - pred_fixed_image is the predicted (warped) moving image of shape
          (batch, f_dim1, f_dim2, f_dim3)
        - pred_fixed_label is the predicted (warped) moving label of shape
          (batch, f_dim1, f_dim2, f_dim3), or None when moving_label is None
        - fixed_grid is the grid of shape (f_dim1, f_dim2, f_dim3, 3)
    """
    # Add a channel axis to both images; the moving one is also resampled onto
    # the fixed size so the two can be concatenated channel-wise.
    moving_5d = layer_util.resize3d(
        image=tf.expand_dims(moving_image, axis=4), size=fixed_image_size
    )  # (batch, f_dim1, f_dim2, f_dim3, 1)
    fixed_5d = tf.expand_dims(fixed_image, axis=4)  # (batch, f_dim1, f_dim2, f_dim3, 1)

    net_out = backbone(
        inputs=tf.concat([moving_5d, fixed_5d], axis=4)
    )  # (batch, f_dim1, f_dim2, f_dim3, 3)

    if not output_dvf:
        dvf, ddf = None, net_out
    else:
        # integrate the velocity field to obtain the displacement field
        dvf = net_out
        ddf = layer.IntDVF(fixed_image_size=fixed_image_size)(dvf)

    warping = layer.Warping(fixed_image_size=fixed_image_size)
    grid_fixed = tf.squeeze(warping.grid_ref, axis=0)  # (f_dim1, f_dim2, f_dim3, 3)
    pred_fixed_image = warping(inputs=[ddf, tf.squeeze(moving_5d, axis=4)])

    pred_fixed_label = None
    if moving_label is not None:
        # NOTE(review): the moving image is warped after being resized to
        # fixed_image_size, but the label is warped at its original size.
        # Confirm Warping handles moving/fixed size mismatches for labels.
        pred_fixed_label = warping(inputs=[ddf, moving_label])

    return dvf, ddf, pred_fixed_image, pred_fixed_label, grid_fixed
def call(self, inputs, **kwargs):
    """
    Convolve the input with the wrapped conv layer, then resize to the output shape.

    :param inputs: shape = [batch, dim1, dim2, dim3, channels]
    :param kwargs: not used by this method.
    :return: shape = [batch, out_dim1, out_dim2, out_dim3, channels]
    """
    convolved = self._conv3d(inputs=inputs)
    return layer_util.resize3d(image=convolved, size=self._output_shape)
def call(self, inputs, **kwargs) -> tf.Tensor:
    """
    Apply ``self._conv3d`` and resize its result to ``self._output_shape``.

    :param inputs: shape = (batch, dim1, dim2, dim3, channels)
    :param kwargs: additional arguments, not used by this method.
    :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels)
    """
    return layer_util.resize3d(
        image=self._conv3d(inputs=inputs),
        size=self._output_shape,
    )
def affine_forward(
    backbone: tf.keras.Model,
    moving_image: tf.Tensor,
    fixed_image: tf.Tensor,
    moving_label: (tf.Tensor, None),
    moving_image_size: tuple,
    fixed_image_size: tuple,
):
    """
    Run an affine registration network and warp the moving inputs.

    The moving image is resized onto the fixed image size and stacked with the
    fixed image as a two-channel input. The backbone returns a DDF; the affine
    parameters are then read from ``backbone.theta``.

    :param backbone: model architecture object, e.g. model.backbone.local_net
    :param moving_image: tensor of shape (batch, m_dim1, m_dim2, m_dim3)
    :param fixed_image: tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param moving_label: tensor of shape (batch, m_dim1, m_dim2, m_dim3) or None
    :param moving_image_size: tuple like (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: tuple like (f_dim1, f_dim2, f_dim3)
    :return: tuple(affine, ddf, pred_fixed_image, pred_fixed_label, fixed_grid), where
        - affine is the affine transformation matrix predicted by the network (batch, 4, 3)
        - ddf is the dense displacement field of shape (batch, f_dim1, f_dim2, f_dim3, 3)
        - pred_fixed_image is the predicted (warped) moving image of shape
          (batch, f_dim1, f_dim2, f_dim3)
        - pred_fixed_label is the predicted (warped) moving label of shape
          (batch, f_dim1, f_dim2, f_dim3), or None when moving_label is None
        - fixed_grid is the grid of shape (f_dim1, f_dim2, f_dim3, 3)
    """
    # Add channel axes; resample the moving image onto the fixed size so the
    # two images can be concatenated channel-wise.
    moving_5d = layer_util.resize3d(
        image=tf.expand_dims(moving_image, axis=4), size=fixed_image_size
    )  # (batch, f_dim1, f_dim2, f_dim3, 1)
    fixed_5d = tf.expand_dims(fixed_image, axis=4)  # (batch, f_dim1, f_dim2, f_dim3, 1)
    stacked = tf.concat([moving_5d, fixed_5d], axis=4)  # (batch, f_dim1, f_dim2, f_dim3, 2)

    ddf = backbone(inputs=stacked)  # (batch, f_dim1, f_dim2, f_dim3, 3)
    # Read theta only after the backbone call; presumably the call sets it.
    # TODO confirm against the affine backbone.
    affine = backbone.theta

    warping = layer.Warping(fixed_image_size=fixed_image_size)
    grid_fixed = tf.squeeze(warping.grid_ref, axis=0)  # (f_dim1, f_dim2, f_dim3, 3)
    pred_fixed_image = warping(inputs=[ddf, tf.squeeze(moving_5d, axis=4)])

    pred_fixed_label = None
    if moving_label is not None:
        # NOTE(review): unlike the image, the label is not resized to
        # fixed_image_size before warping; confirm this is intended.
        pred_fixed_label = warping(inputs=[ddf, moving_label])

    return affine, ddf, pred_fixed_image, pred_fixed_label, grid_fixed
def call(self, inputs, **kwargs) -> tf.Tensor:
    """
    Resize the input spatially, then fold the channel axis by summation.

    After resizing to ``self._output_shape``, the channels are split into
    ``self._stride`` equal groups along axis 4. The groups are summed
    element-wise, which reduces ``channels`` to ``channels // stride``.

    :param inputs: shape = (batch, dim1, dim2, dim3, channels)
    :param kwargs: additional arguments, not used by this method.
    :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels // stride)
    :raises ValueError: if channels is not divisible by ``self._stride``.
    """
    if inputs.shape[4] % self._stride != 0:
        raise ValueError("The channel dimension can not be divided by the stride")
    output = layer_util.resize3d(image=inputs, size=self._output_shape)
    # a list of `stride` tensors, each (batch, out_dim1, out_dim2, out_dim3, channels//stride)
    groups = tf.split(output, num_or_size_splits=self._stride, axis=4)
    # element-wise sum of the groups; add_n avoids materialising a stacked 6-D tensor
    return tf.add_n(groups)
def concat_images(
    self,
    moving_image: tf.Tensor,
    fixed_image: tf.Tensor,
    moving_label: Optional[tf.Tensor] = None,
) -> tf.Tensor:
    """
    Adjust image shape and concatenate them together.

    Every input gets a trailing channel axis. The moving image (and the moving
    label, if given) are then resized to ``self.fixed_image_size``. The results
    are concatenated along the channel axis in the order: moving image, fixed
    image, moving label.

    :param moving_image: registration source, shape (batch, m_dim1, m_dim2, m_dim3)
    :param fixed_image: registration target, shape (batch, f_dim1, f_dim2, f_dim3)
    :param moving_label: optional, only used for conditional model;
        shape (batch, m_dim1, m_dim2, m_dim3)
    :return: tensor of shape (batch, f_dim1, f_dim2, f_dim3, 2),
        or (batch, f_dim1, f_dim2, f_dim3, 3) when moving_label is given
    """
    images = []

    # (batch, m_dim1, m_dim2, m_dim3, 1) -> resized to (batch, f_dim1, f_dim2, f_dim3, 1)
    moving_image = tf.expand_dims(moving_image, axis=4)
    moving_image = layer_util.resize3d(image=moving_image, size=self.fixed_image_size)
    images.append(moving_image)

    # (batch, f_dim1, f_dim2, f_dim3, 1); already at the fixed size
    fixed_image = tf.expand_dims(fixed_image, axis=4)
    images.append(fixed_image)

    # (batch, m_dim1, m_dim2, m_dim3, 1) -> resized to (batch, f_dim1, f_dim2, f_dim3, 1)
    if moving_label is not None:
        moving_label = tf.expand_dims(moving_label, axis=4)
        moving_label = layer_util.resize3d(image=moving_label, size=self.fixed_image_size)
        images.append(moving_label)

    # (batch, f_dim1, f_dim2, f_dim3, 2 or 3)
    images = tf.concat(images, axis=4)
    return images
def call(self, inputs, **kwargs) -> tf.Tensor:
    """
    Convolve with ``self.kernel`` and resize the result to ``self._output_shape``.

    The convolution uses unit strides on every axis and "SAME" padding.

    :param inputs: 5-D input for tf.nn.conv3d, presumably
        (batch, dim1, dim2, dim3, channels); confirm against the layer's build.
    :param kwargs: not used by this method.
    :return: the convolution output resized to ``self._output_shape``.
    """
    unit_strides = (1,) * 5
    convolved = tf.nn.conv3d(inputs, self.kernel, strides=unit_strides, padding="SAME")
    return layer_util.resize3d(image=convolved, size=self._output_shape)
def test_resize3d():
    """
    Test resize3d by confirming the output shapes and its input validation.

    Shape checks cover inputs of rank 3, 4 and 5 (with one or several
    channels), with a target size that differs from or equals the input size.
    Failure checks cover an image of unsupported rank and a size of the wrong
    length.
    """
    # (input_shape, size, expected output shape)
    passing_cases = [
        # different size, without channel nor batch
        ((1, 3, 5), (2, 4, 6), (2, 4, 6)),
        # different size, without channel
        ((1, 1, 3, 5), (2, 4, 6), (1, 2, 4, 6)),
        # different size, with one channel
        ((1, 1, 3, 5, 1), (2, 4, 6), (1, 2, 4, 6, 1)),
        # different size, with multiple channels
        ((1, 1, 3, 5, 3), (2, 4, 6), (1, 2, 4, 6, 3)),
        # same size, without channel nor batch
        ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
        # same size, without channel
        ((1, 1, 3, 5), (1, 3, 5), (1, 1, 3, 5)),
        # same size, with one channel
        ((1, 1, 3, 5, 1), (1, 3, 5), (1, 1, 3, 5, 1)),
        # same size, with multiple channels
        ((1, 1, 3, 5, 3), (1, 3, 5), (1, 1, 3, 5, 3)),
    ]
    for input_shape, size, expected_shape in passing_cases:
        resized = layer_util.resize3d(image=tf.ones(input_shape), size=size)
        assert resized.shape == expected_shape

    # (input_shape, size, expected fragment of the ValueError message)
    failing_cases = [
        # image of unsupported rank
        ((1, 1), (1, 1, 1), "resize3d takes input image of dimension 3 or 4 or 5"),
        # size of the wrong length
        ((1, 1, 1), (1, 1), "resize3d takes size of type tuple/list and of length 3"),
    ]
    for input_shape, size, expected_msg in failing_cases:
        with pytest.raises(ValueError) as err_info:
            layer_util.resize3d(image=tf.ones(input_shape), size=size)
        assert expected_msg in str(err_info.value)