def test_forward(self, moving_image_size, fixed_image_size):
    """
    Warping a moving image with a DDF defined on the fixed grid must yield an
    output with the fixed image's spatial shape.
    """
    n_batch = 2
    displacement_shape = (n_batch, ) + fixed_image_size + (3, )
    moving = tf.ones(shape=(n_batch, ) + moving_image_size)
    displacement = tf.ones(shape=displacement_shape)

    warp_layer = layer.Warping(fixed_image_size=fixed_image_size)
    warped = warp_layer([displacement, moving])

    assert warped.shape == (n_batch, *fixed_image_size)
def ddf_dvf_forward(
    backbone: tf.keras.Model,
    moving_image: tf.Tensor,
    fixed_image: tf.Tensor,
    moving_label: (tf.Tensor, None),
    moving_image_size: tuple,
    fixed_image_size: tuple,
    output_dvf: bool,
) -> [(tf.Tensor, None), tf.Tensor, tf.Tensor, (tf.Tensor, None), tf.Tensor]:
    """
    Perform the network forward pass.

    The moving image is resized onto the fixed grid before being fed to the
    backbone, so the predicted field refers to that resized frame. The moving
    label, when given, is resized the same way so that the image and the label
    are both warped from the same frame by the same ddf.

    :param backbone: model architecture object, e.g. model.backbone.local_net
    :param moving_image: tensor of shape (batch, m_dim1, m_dim2, m_dim3)
    :param fixed_image: tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param moving_label: tensor of shape (batch, m_dim1, m_dim2, m_dim3) or None
    :param moving_image_size: tuple like (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: tuple like (f_dim1, f_dim2, f_dim3)
    :param output_dvf: bool, if true, model outputs dvf, if false, model outputs ddf
    :return: tuple(dvf, ddf, pred_fixed_image, pred_fixed_label, fixed_grid), where
      - dvf is the dense velocity field of shape (batch, f_dim1, f_dim2, f_dim3, 3),
        or None when output_dvf is False
      - ddf is the dense displacement field of shape (batch, f_dim1, f_dim2, f_dim3, 3)
      - pred_fixed_image is the predicted (warped) moving image of shape
        (batch, f_dim1, f_dim2, f_dim3)
      - pred_fixed_label is the predicted (warped) moving label of shape
        (batch, f_dim1, f_dim2, f_dim3), or None when moving_label is None
      - fixed_grid is the grid of shape (f_dim1, f_dim2, f_dim3, 3)
    """
    # add a channel axis; it is squeezed again before warping
    moving_image = tf.expand_dims(moving_image, axis=4)  # (batch, m_dim1, m_dim2, m_dim3, 1)
    fixed_image = tf.expand_dims(fixed_image, axis=4)  # (batch, f_dim1, f_dim2, f_dim3, 1)

    # bring the moving image (and label) onto the fixed grid
    if moving_image_size != fixed_image_size:
        moving_image = layer_util.resize3d(
            image=moving_image, size=fixed_image_size
        )  # (batch, f_dim1, f_dim2, f_dim3, 1)
        if moving_label is not None:
            # the ddf is predicted from the resized image, so the label must be
            # warped from the same resized frame as the image
            moving_label = tf.squeeze(
                layer_util.resize3d(
                    image=tf.expand_dims(moving_label, axis=4),
                    size=fixed_image_size,
                ),
                axis=4,
            )  # (batch, f_dim1, f_dim2, f_dim3)

    # ddf, dvf
    inputs = tf.concat([moving_image, fixed_image], axis=4)  # (batch, f_dim1, f_dim2, f_dim3, 2)
    backbone_out = backbone(inputs=inputs)  # (batch, f_dim1, f_dim2, f_dim3, 3)
    if output_dvf:
        dvf = backbone_out  # (batch, f_dim1, f_dim2, f_dim3, 3)
        # integrate the velocity field into a displacement field
        ddf = layer.IntDVF(fixed_image_size=fixed_image_size)(dvf)  # (batch, f_dim1, f_dim2, f_dim3, 3)
    else:
        dvf = None
        ddf = backbone_out  # (batch, f_dim1, f_dim2, f_dim3, 3)

    # prediction, (batch, f_dim1, f_dim2, f_dim3)
    warping = layer.Warping(fixed_image_size=fixed_image_size)
    grid_fixed = tf.squeeze(warping.grid_ref, axis=0)  # (f_dim1, f_dim2, f_dim3, 3)
    pred_fixed_image = warping(inputs=[ddf, tf.squeeze(moving_image, axis=4)])
    pred_fixed_label = (
        warping(inputs=[ddf, moving_label]) if moving_label is not None else None
    )
    return dvf, ddf, pred_fixed_image, pred_fixed_label, grid_fixed
def test_get_config(self):
    """
    Warping.get_config must report fixed_image_size together with the
    base-layer entries name, trainable and dtype.
    """
    size = (2, 3, 4)
    expected = {
        "fixed_image_size": (2, 3, 4),
        "name": "warping",
        "trainable": True,
        "dtype": "float32",
    }
    actual = layer.Warping(fixed_image_size=size).get_config()
    assert actual == expected
def conditional_forward(
    backbone: tf.keras.Model,
    moving_image: tf.Tensor,
    fixed_image: tf.Tensor,
    moving_label: tf.Tensor,
    moving_image_size: tuple,
    fixed_image_size: tuple,
) -> [tf.Tensor, tf.Tensor]:
    """
    Perform the network forward pass.

    The backbone directly predicts the fixed label from the concatenation of
    the moving image, fixed image and moving label, so the moving label is a
    required input here.

    :param backbone: model architecture object, e.g. model.backbone.local_net
    :param moving_image: tensor of shape (batch, m_dim1, m_dim2, m_dim3)
    :param fixed_image: tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param moving_label: tensor of shape (batch, m_dim1, m_dim2, m_dim3); must not be None
    :param moving_image_size: tuple like (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: tuple like (f_dim1, f_dim2, f_dim3)
    :return: (pred_fixed_label, fixed_grid), where
      - pred_fixed_label is the predicted (warped) moving label of shape
        (batch, f_dim1, f_dim2, f_dim3)
      - fixed_grid is the grid of shape (f_dim1, f_dim2, f_dim3, 3)
    :raises ValueError: if moving_label is None
    """
    if moving_label is None:
        # fail with a clear message instead of an opaque tensor-conversion error
        raise ValueError(
            "conditional_forward requires moving_label, "
            "it is concatenated into the backbone input"
        )

    # add a channel axis, needed for resizing and concatenation
    moving_image = tf.expand_dims(moving_image, axis=4)  # (batch, m_dim1, m_dim2, m_dim3, 1)
    fixed_image = tf.expand_dims(fixed_image, axis=4)  # (batch, f_dim1, f_dim2, f_dim3, 1)
    moving_label = tf.expand_dims(moving_label, axis=4)  # (batch, m_dim1, m_dim2, m_dim3, 1)

    # adjust moving image and label onto the fixed grid
    if moving_image_size != fixed_image_size:
        moving_image = layer_util.resize3d(
            image=moving_image, size=fixed_image_size
        )  # (batch, f_dim1, f_dim2, f_dim3, 1)
        moving_label = layer_util.resize3d(
            image=moving_label, size=fixed_image_size
        )  # (batch, f_dim1, f_dim2, f_dim3, 1)

    # conditional
    inputs = tf.concat(
        [moving_image, fixed_image, moving_label], axis=4
    )  # (batch, f_dim1, f_dim2, f_dim3, 3)
    pred_fixed_label = backbone(inputs=inputs)  # (batch, f_dim1, f_dim2, f_dim3, 1)
    pred_fixed_label = tf.squeeze(pred_fixed_label, axis=4)  # (batch, f_dim1, f_dim2, f_dim3)

    # the Warping layer is only used here to obtain its reference grid
    warping = layer.Warping(fixed_image_size=fixed_image_size)
    grid_fixed = tf.squeeze(warping.grid_ref, axis=0)  # (f_dim1, f_dim2, f_dim3, 3)
    return pred_fixed_label, grid_fixed
def forward(_backbone, _moving_image, _moving_label, _fixed_image):
    """
    Predict a ddf from a moving/fixed image pair and warp the moving image and
    label with it.

    Note: ``fixed_image_size`` is not a parameter; it is read from the
    enclosing scope.

    :param _backbone: network mapping the stacked images
        [batch, f_dim1, f_dim2, f_dim3, 2] to a ddf [batch, f_dim1, f_dim2, f_dim3, 3]
    :param _moving_image: [batch, m_dim1, m_dim2, m_dim3]
    :param _moving_label: [batch, m_dim1, m_dim2, m_dim3]
    :param _fixed_image: [batch, f_dim1, f_dim2, f_dim3]
    :return: tuple(ddf, pred_fixed_image, pred_fixed_label), where
      - ddf has shape [batch, f_dim1, f_dim2, f_dim3, 3]
      - pred_fixed_image and pred_fixed_label have shape [batch, f_dim1, f_dim2, f_dim3]
    """
    # ddf: the moving image is resized onto the fixed grid and stacked with
    # the fixed image along a new channel axis
    moving_resized = layer.Resize3d(size=fixed_image_size)(
        inputs=tf.expand_dims(_moving_image, axis=4)
    )  # [batch, f_dim1, f_dim2, f_dim3, 1]
    backbone_input = tf.concat(
        [moving_resized, tf.expand_dims(_fixed_image, axis=4)], axis=4
    )  # [batch, f_dim1, f_dim2, f_dim3, 2]
    _ddf = _backbone(inputs=backbone_input)  # [batch, f_dim1, f_dim2, f_dim3, 3]

    # prediction image and label, shape = [batch, f_dim1, f_dim2, f_dim3];
    # one Warping layer serves both inputs
    warping = layer.Warping(fixed_image_size=fixed_image_size)
    _pred_fixed_image = warping([_ddf, _moving_image])
    _pred_fixed_label = warping([_ddf, _moving_label])
    return _ddf, _pred_fixed_image, _pred_fixed_label
def affine_forward(
    backbone: tf.keras.Model,
    moving_image: tf.Tensor,
    fixed_image: tf.Tensor,
    moving_label: (tf.Tensor, None),
    moving_image_size: tuple,
    fixed_image_size: tuple,
):
    """
    Perform the network forward pass for an affine-predicting backbone.

    The moving image is resized onto the fixed grid before being fed to the
    backbone, so the predicted ddf refers to that resized frame. The moving
    label, when given, is resized the same way so that the image and the label
    are both warped from the same frame by the same ddf.

    :param backbone: model architecture object that exposes the predicted
        affine parameters as ``backbone.theta`` once it has been called
    :param moving_image: tensor of shape (batch, m_dim1, m_dim2, m_dim3)
    :param fixed_image: tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param moving_label: tensor of shape (batch, m_dim1, m_dim2, m_dim3) or None
    :param moving_image_size: tuple like (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: tuple like (f_dim1, f_dim2, f_dim3)
    :return: tuple(affine, ddf, pred_fixed_image, pred_fixed_label, fixed_grid), where
      - affine is the affine transformation matrix predicted by the network (batch, 4, 3)
      - ddf is the dense displacement field of shape (batch, f_dim1, f_dim2, f_dim3, 3)
      - pred_fixed_image is the predicted (warped) moving image of shape
        (batch, f_dim1, f_dim2, f_dim3)
      - pred_fixed_label is the predicted (warped) moving label of shape
        (batch, f_dim1, f_dim2, f_dim3), or None when moving_label is None
      - fixed_grid is the grid of shape (f_dim1, f_dim2, f_dim3, 3)
    """
    # add a channel axis; it is squeezed again before warping
    moving_image = tf.expand_dims(moving_image, axis=4)  # (batch, m_dim1, m_dim2, m_dim3, 1)
    fixed_image = tf.expand_dims(fixed_image, axis=4)  # (batch, f_dim1, f_dim2, f_dim3, 1)

    # bring the moving image (and label) onto the fixed grid
    if moving_image_size != fixed_image_size:
        moving_image = layer_util.resize3d(
            image=moving_image, size=fixed_image_size
        )  # (batch, f_dim1, f_dim2, f_dim3, 1)
        if moving_label is not None:
            # the ddf is predicted from the resized image, so the label must be
            # warped from the same resized frame as the image
            moving_label = tf.squeeze(
                layer_util.resize3d(
                    image=tf.expand_dims(moving_label, axis=4),
                    size=fixed_image_size,
                ),
                axis=4,
            )  # (batch, f_dim1, f_dim2, f_dim3)

    # ddf
    inputs = tf.concat([moving_image, fixed_image], axis=4)  # (batch, f_dim1, f_dim2, f_dim3, 2)
    ddf = backbone(inputs=inputs)  # (batch, f_dim1, f_dim2, f_dim3, 3)
    # theta is read after the call; presumably the forward pass populates it
    affine = backbone.theta

    # prediction, (batch, f_dim1, f_dim2, f_dim3)
    warping = layer.Warping(fixed_image_size=fixed_image_size)
    grid_fixed = tf.squeeze(warping.grid_ref, axis=0)  # (f_dim1, f_dim2, f_dim3, 3)
    pred_fixed_image = warping(inputs=[ddf, tf.squeeze(moving_image, axis=4)])
    pred_fixed_label = (
        warping(inputs=[ddf, moving_label]) if moving_label is not None else None
    )
    return affine, ddf, pred_fixed_image, pred_fixed_label, grid_fixed
def build_model(self):
    """
    Build the model to be saved as self._model.

    The backbone predicts a ddf from the concatenated moving/fixed images (a
    GlobalNet backbone additionally returns ``theta``). The ddf then warps the
    moving image, and also the moving label when the model is labeled.

    :return: tf.keras.Model mapping self._inputs to self._outputs
    """
    # build inputs
    self._inputs = self.build_inputs()
    moving_image = self._inputs["moving_image"]  # (batch, m_dim1, m_dim2, m_dim3)
    fixed_image = self._inputs["fixed_image"]  # (batch, f_dim1, f_dim2, f_dim3)

    # build ddf
    # pop control_points from a shallow copy: popping self.config["backbone"]
    # in place would drop the option for any later rebuild or config reader
    backbone_config = dict(self.config["backbone"])
    control_points = backbone_config.pop("control_points", False)
    backbone_inputs = self.concat_images(moving_image, fixed_image)
    backbone = REGISTRY.build_backbone(
        config=backbone_config,
        default_args=dict(
            image_size=self.fixed_image_size,
            out_channels=3,
            out_kernel_initializer="zeros",
            out_activation=None,
        ),
    )

    if isinstance(backbone, GlobalNet):
        # per sample: (f_dim1, f_dim2, f_dim3, 3), (4, 3)
        ddf, theta = backbone(inputs=backbone_inputs)
        self._outputs = dict(ddf=ddf, theta=theta)
    else:
        # per sample: (f_dim1, f_dim2, f_dim3, 3)
        ddf = backbone(inputs=backbone_inputs)
        if control_points:
            ddf = self._resize_interpolate(ddf, control_points)
        self._outputs = dict(ddf=ddf)

    # build outputs
    warping = layer.Warping(fixed_image_size=self.fixed_image_size)
    # per sample: (f_dim1, f_dim2, f_dim3)
    pred_fixed_image = warping(inputs=[ddf, moving_image])
    self._outputs["pred_fixed_image"] = pred_fixed_image

    if not self.labeled:
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    # per sample: (f_dim1, f_dim2, f_dim3)
    moving_label = self._inputs["moving_label"]
    pred_fixed_label = warping(inputs=[ddf, moving_label])
    self._outputs["pred_fixed_label"] = pred_fixed_label
    return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)
def test_initDVF():
    """
    Test the layer.IntDVF class, its default attributes and its call() method.
    """
    batch_size = 5
    fixed_image_size = (32, 32, 16)
    ndims = len(fixed_image_size)

    model = layer.IntDVF(fixed_image_size)
    # check against the class directly instead of building a throwaway layer
    assert isinstance(model._warping, layer.Warping)
    assert model._num_steps == 7

    inputs = np.ones((batch_size, *fixed_image_size, ndims))
    output = model.call(inputs)
    # compare whole shapes: zip() would silently ignore a rank mismatch
    assert tuple(output.shape) == (batch_size, *fixed_image_size, ndims)
def build_model(self):
    """
    Build the model to be saved as self._model.

    The backbone predicts a dvf from the concatenated moving/fixed images. The
    dvf is optionally interpolated from control points and then integrated
    into a ddf by layer.IntDVF. The ddf warps the moving image, and also the
    moving label when the model is labeled.

    :return: tf.keras.Model mapping self._inputs to self._outputs
    """
    # build inputs
    self._inputs = self.build_inputs()
    moving_image = self._inputs["moving_image"]
    fixed_image = self._inputs["fixed_image"]

    # pop control_points from a shallow copy: popping self.config["backbone"]
    # in place would drop the option for any later rebuild or config reader
    backbone_config = dict(self.config["backbone"])
    control_points = backbone_config.pop("control_points", False)

    # build ddf
    backbone_inputs = self.concat_images(moving_image, fixed_image)
    backbone = REGISTRY.build_backbone(
        config=backbone_config,
        default_args=dict(
            image_size=self.fixed_image_size,
            out_channels=3,
            out_kernel_initializer="zeros",
            out_activation=None,
        ),
    )
    dvf = backbone(inputs=backbone_inputs)
    if control_points:
        dvf = self._resize_interpolate(dvf, control_points)
    ddf = layer.IntDVF(fixed_image_size=self.fixed_image_size)(dvf)

    # build outputs
    self._warping = layer.Warping(fixed_image_size=self.fixed_image_size)
    # (batch, f_dim1, f_dim2, f_dim3)
    pred_fixed_image = self._warping(inputs=[ddf, moving_image])
    self._outputs = dict(dvf=dvf, ddf=ddf, pred_fixed_image=pred_fixed_image)

    if not self.labeled:
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    # (batch, f_dim1, f_dim2, f_dim3)
    moving_label = self._inputs["moving_label"]
    pred_fixed_label = self._warping(inputs=[ddf, moving_label])
    self._outputs["pred_fixed_label"] = pred_fixed_label
    return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)
def test_warping():
    """
    Test the layer.Warping class, its default attributes and its call() method.
    """
    batch_size = 5
    fixed_image_size = (32, 32, 16)
    moving_image_size = (24, 24, 16)
    ndims = len(moving_image_size)
    grid_size = (1, ) + fixed_image_size + (3, )

    model = layer.Warping(fixed_image_size)
    # reference grid: (1, f_dim1, f_dim2, f_dim3, 3)
    assert tuple(model.grid_ref.shape) == grid_size

    # inputs of all ones: a ddf on the fixed grid and a smaller moving image
    inputs = [
        np.ones((batch_size, *fixed_image_size, ndims), dtype="float32"),
        np.ones((batch_size, *moving_image_size), dtype="float32"),
    ]

    output = model.call(inputs)
    # expected shape is (5, 32, 32, 16); compare whole shapes because zip()
    # would silently ignore a rank mismatch
    assert tuple(output.shape) == (batch_size, *fixed_image_size)
loss_image = image_loss.dissimilarity_fn( y_true=fix, y_pred=pred, name=image_loss_name ) loss_deform = deform_loss.local_displacement_energy(weights, deform_loss_name) loss = loss_image + weight_deform_loss * loss_deform gradients = tape.gradient(loss, [weights]) optimizer.apply_gradients(zip(gradients, [weights])) return loss, loss_image, loss_deform # ddf as trainable weights fixed_image_size = fixed_image.shape initializer = tf.random_normal_initializer(mean=0, stddev=1e-3) var_ddf = tf.Variable(initializer(fixed_image_size + [3]), name="ddf", trainable=True) warping = layer.Warping(fixed_image_size=fixed_image_size[1:4]) optimiser = tf.optimizers.Adam(learning_rate) for step in range(total_iter): loss_opt, loss_image_opt, loss_deform_opt = train_step( warping, var_ddf, optimiser, moving_image, fixed_image ) if (step % 50) == 0: # print info tf.print( "Step", step, "loss", loss_opt, image_loss_name, loss_image_opt, deform_loss_name, loss_deform_opt,