def ddf_dvf_forward(
    backbone: tf.keras.Model,
    moving_image: tf.Tensor,
    fixed_image: tf.Tensor,
    moving_label: (tf.Tensor, None),
    moving_image_size: tuple,
    fixed_image_size: tuple,
    output_dvf: bool,
) -> [(tf.Tensor, None), tf.Tensor, tf.Tensor, (tf.Tensor, None), tf.Tensor]:
    """
    Perform the network forward pass.

    The moving image is resized to ``fixed_image_size`` only to build the
    backbone input. Both the moving image and the moving label are then warped
    from their original (moving-size) tensors with the same ddf. This keeps the
    predicted image and label resampled at identical coordinates.

    :param backbone: model architecture object, e.g. model.backbone.local_net
    :param moving_image: tensor of shape (batch, m_dim1, m_dim2, m_dim3)
    :param fixed_image: tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param moving_label: tensor of shape (batch, m_dim1, m_dim2, m_dim3) or None
    :param moving_image_size: tuple like (m_dim1, m_dim2, m_dim3);
        currently unused, kept for interface compatibility
    :param fixed_image_size: tuple like (f_dim1, f_dim2, f_dim3)
    :param output_dvf: bool, if true, the backbone output is treated as a dvf
        and converted to a ddf by layer.IntDVF; if false, it is used as the ddf
    :return: tuple(dvf, ddf, pred_fixed_image, pred_fixed_label, grid_fixed), where
        - dvf is the dense velocity field of shape (batch, f_dim1, f_dim2, f_dim3, 3),
          or None when output_dvf is False
        - ddf is the dense displacement field of shape (batch, f_dim1, f_dim2, f_dim3, 3)
        - pred_fixed_image is the predicted (warped) moving image of shape
          (batch, f_dim1, f_dim2, f_dim3)
        - pred_fixed_label is the predicted (warped) moving label of shape
          (batch, f_dim1, f_dim2, f_dim3), or None when moving_label is None
        - grid_fixed is the reference grid of shape (f_dim1, f_dim2, f_dim3, 3)
    """
    # add a channel axis so the images can be resized and concatenated
    moving_image_ch = tf.expand_dims(
        moving_image, axis=4
    )  # (batch, m_dim1, m_dim2, m_dim3, 1)
    fixed_image_ch = tf.expand_dims(
        fixed_image, axis=4
    )  # (batch, f_dim1, f_dim2, f_dim3, 1)

    # the backbone needs both images on the same grid, so the moving image is
    # resized to the fixed size for the network input only
    moving_image_resized = layer_util.resize3d(
        image=moving_image_ch, size=fixed_image_size
    )  # (batch, f_dim1, f_dim2, f_dim3, 1)

    # ddf, dvf
    backbone_inputs = tf.concat(
        [moving_image_resized, fixed_image_ch], axis=4
    )  # (batch, f_dim1, f_dim2, f_dim3, 2)
    backbone_out = backbone(inputs=backbone_inputs)  # (batch, f_dim1, f_dim2, f_dim3, 3)
    if output_dvf:
        dvf = backbone_out  # (batch, f_dim1, f_dim2, f_dim3, 3)
        ddf = layer.IntDVF(fixed_image_size=fixed_image_size)(
            dvf
        )  # (batch, f_dim1, f_dim2, f_dim3, 3)
    else:
        dvf = None
        ddf = backbone_out  # (batch, f_dim1, f_dim2, f_dim3, 3)

    # prediction, (batch, f_dim1, f_dim2, f_dim3)
    warping = layer.Warping(fixed_image_size=fixed_image_size)
    grid_fixed = tf.squeeze(warping.grid_ref, axis=0)  # (f_dim1, f_dim2, f_dim3, 3)
    # warp the original (un-resized) moving image, exactly like the moving
    # label below, so image and label are sampled in the same coordinate frame
    pred_fixed_image = warping(inputs=[ddf, moving_image])
    pred_fixed_label = (
        warping(inputs=[ddf, moving_label]) if moving_label is not None else None
    )
    return dvf, ddf, pred_fixed_image, pred_fixed_label, grid_fixed
def test_initDVF():
    """
    Test the layer.IntDVF class, its default attributes and its call() method.

    The output shape is compared as a whole tuple. Comparing dimension by
    dimension through zip() would stop at the shorter sequence and let an
    output with the wrong rank pass.
    """
    batch_size = 5
    fixed_image_size = (32, 32, 16)
    ndims = len(fixed_image_size)

    model = layer.IntDVF(fixed_image_size)
    assert isinstance(model._warping, layer.Warping)
    assert model._num_steps == 7

    inputs = np.ones((batch_size, *fixed_image_size, ndims))
    output = model.call(inputs)
    expected_shape = (batch_size, *fixed_image_size, ndims)
    assert tuple(output.shape) == expected_shape
def build_model(self):
    """
    Build the model to be saved as self._model.

    The backbone output is treated as a dvf. It is optionally passed through
    self._resize_interpolate when a truthy "control_points" entry is present in
    the backbone config, then converted to a ddf by layer.IntDVF. The ddf warps
    the moving image, and also the moving label when self.labeled is true.

    :return: tf.keras.Model mapping self._inputs to a dict of outputs with
        keys "dvf", "ddf", "pred_fixed_image" and, if labeled, "pred_fixed_label".
    """
    # build inputs
    self._inputs = self.build_inputs()
    moving_image = self._inputs["moving_image"]
    fixed_image = self._inputs["fixed_image"]

    # Pop from a copy: popping from self.config["backbone"] directly would
    # remove "control_points" from the stored config, so a later rebuild (or
    # anything reading self.config) would silently lose the setting.
    backbone_config = dict(self.config["backbone"])
    control_points = backbone_config.pop("control_points", False)

    # build dvf, then ddf
    backbone_inputs = self.concat_images(moving_image, fixed_image)
    backbone = REGISTRY.build_backbone(
        config=backbone_config,
        default_args=dict(
            image_size=self.fixed_image_size,
            out_channels=3,
            # zero-initialised output kernel (presumably so the initial dvf is
            # zero, i.e. an identity transform — confirm)
            out_kernel_initializer="zeros",
            out_activation=None,
        ),
    )
    dvf = backbone(inputs=backbone_inputs)
    if control_points:
        dvf = self._resize_interpolate(dvf, control_points)
    ddf = layer.IntDVF(fixed_image_size=self.fixed_image_size)(dvf)

    # build outputs
    self._warping = layer.Warping(fixed_image_size=self.fixed_image_size)
    pred_fixed_image = self._warping(inputs=[ddf, moving_image])
    self._outputs = dict(dvf=dvf, ddf=ddf, pred_fixed_image=pred_fixed_image)

    if self.labeled:
        moving_label = self._inputs["moving_label"]
        self._outputs["pred_fixed_label"] = self._warping(
            inputs=[ddf, moving_label]
        )

    return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)
def forward(_backbone, _moving_image, _moving_label, _fixed_image):
    """
    Run the backbone on a moving/fixed pair and warp the moving image and label.

    Note: ``fixed_image_size`` is not a parameter; it is read from the
    enclosing scope.

    :param _backbone: backbone model whose output is treated as a dvf
    :param _moving_image: [batch, m_dim1, m_dim2, m_dim3]
    :param _moving_label: [batch, m_dim1, m_dim2, m_dim3]
    :param _fixed_image: [batch, f_dim1, f_dim2, f_dim3]
    :return: tuple (dvf, ddf, pred_fixed_image, pred_fixed_label)
    """
    # the moving image is resized only for the backbone input
    resizer = layer.Resize3d(size=fixed_image_size)
    moving_on_fixed_grid = resizer(
        inputs=tf.expand_dims(_moving_image, axis=4)
    )  # [batch, f_dim1, f_dim2, f_dim3, 1]
    fixed_with_channel = tf.expand_dims(
        _fixed_image, axis=4
    )  # [batch, f_dim1, f_dim2, f_dim3, 1]
    stacked_pair = tf.concat(
        [moving_on_fixed_grid, fixed_with_channel], axis=4
    )  # [batch, f_dim1, f_dim2, f_dim3, 2]

    dvf = _backbone(inputs=stacked_pair)  # [batch, f_dim1, f_dim2, f_dim3, 3]
    ddf = layer.IntDVF(fixed_image_size=fixed_image_size)(dvf)

    # predicted image and label, each [batch, f_dim1, f_dim2, f_dim3]
    image_warper = layer.Warping(fixed_image_size=fixed_image_size)
    warped_image = image_warper([ddf, _moving_image])
    label_warper = layer.Warping(fixed_image_size=fixed_image_size)
    warped_label = label_warper([ddf, _moving_label])

    return dvf, ddf, warped_image, warped_label
def test_forward(self):
    """
    Check that IntDVF keeps the input shape and reports the expected config.
    """
    fixed_image_size = (8, 9, 10)
    # batch of 2 with 3 channels
    input_shape = (2,) + fixed_image_size + (3,)

    int_layer = layer.IntDVF(fixed_image_size=fixed_image_size)
    outputs = int_layer(tf.ones(shape=input_shape))
    assert outputs.shape == input_shape

    expected_config = {
        "fixed_image_size": fixed_image_size,
        "num_steps": 7,
        "name": "int_dvf",
        "trainable": True,
        "dtype": "float32",
    }
    assert int_layer.get_config() == expected_config
def test_err(self):
    """Constructing IntDVF with a 2-element fixed_image_size raises AssertionError."""
    two_dim_size = (2, 3)
    with pytest.raises(AssertionError):
        layer.IntDVF(fixed_image_size=two_dim_size)