def add_image_loss(
    model: tf.keras.Model,
    fixed_image: tf.Tensor,
    pred_fixed_image: tf.Tensor,
    loss_config: dict,
) -> tf.keras.Model:
    """
    Attach the image dissimilarity loss, plus raw and weighted metrics, to the model.

    The loss is only added when its configured weight is strictly positive.

    :param model: tf.keras.Model
    :param fixed_image: tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param pred_fixed_image: tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param loss_config: config for loss
    :return: the same model, with loss and metrics registered
    """
    image_cfg = loss_config["dissimilarity"]["image"]
    if image_cfg["weight"] > 0:
        # per-sample dissimilarity, averaged over the batch
        loss_image = tf.reduce_mean(
            image_loss.dissimilarity_fn(
                y_true=fixed_image,
                y_pred=pred_fixed_image,
                **image_cfg,
            )
        )
        weighted_loss_image = loss_image * image_cfg["weight"]
        model.add_loss(weighted_loss_image)
        # expose both the unweighted and the weighted values for monitoring
        model.add_metric(
            loss_image, name="loss/image_dissimilarity", aggregation="mean"
        )
        model.add_metric(
            weighted_loss_image,
            name="loss/weighted_image_dissimilarity",
            aggregation="mean",
        )
    return model
def train_step(warper, weights, optimizer, mov, fix):
    """
    Run one backprop step on the DDF weights using a gradient tape.

    :param warper: warping function returned from layer.Warping
    :param weights: trainable ddf [1, f_dim1, f_dim2, f_dim3, 3]
    :param optimizer: tf.optimizers
    :param mov: moving image [1, m_dim1, m_dim2, m_dim3]
    :param fix: fixed image [1, f_dim1, f_dim2, f_dim3]
    :return: loss: overall loss to optimise
             loss_image: image dissimilarity
             loss_deform: deformation regularisation
    """
    with tf.GradientTape() as tape:
        warped = warper(inputs=[weights, mov])
        loss_image = image_loss.dissimilarity_fn(
            y_true=fix, y_pred=warped, name=image_loss_name
        )
        loss_deform = deform_loss.local_displacement_energy(weights, deform_loss_name)
        # total objective: dissimilarity plus weighted regularisation
        loss = loss_image + weight_deform_loss * loss_deform
    grads = tape.gradient(loss, [weights])
    optimizer.apply_gradients(zip(grads, [weights]))
    return loss, loss_image, loss_deform
def test_dissimilarity_fn():
    """
    Check dissimilarity_fn against precomputed values for both supported
    names: local normalized cross correlation ("lncc") and sum of squared
    differences ("ssd"), plus the zero-dissimilarity identity cases.
    """
    # lncc, differing images
    y_true_ncc = tf.convert_to_tensor(
        np.array(range(12)).reshape((2, 1, 2, 3)), dtype=tf.float32
    )
    y_pred_ncc = tf.convert_to_tensor(0.6 * np.ones((2, 1, 2, 3)), dtype=tf.float32)
    got_ncc = image.dissimilarity_fn(y_true_ncc, y_pred_ncc, "lncc")
    assert is_equal_tf(got_ncc, [-0.68002254, -0.9608879])

    # ssd, differing images
    y_true_ssd = tf.convert_to_tensor(np.zeros((2, 1, 2, 3)), dtype=tf.float32)
    y_pred_ssd = tf.convert_to_tensor(0.6 * np.ones((2, 1, 2, 3)), dtype=tf.float32)
    got_ssd = image.dissimilarity_fn(y_true_ssd, y_pred_ssd, "ssd")
    assert is_equal_tf(got_ssd, [0.36, 0.36])

    # TODO gmi diff images

    # identical images: lncc saturates at -1
    got_same_ncc = image.dissimilarity_fn(y_pred_ssd, y_pred_ssd, "lncc")
    assert is_equal_tf(got_same_ncc, [-1, -1])

    # identical images: ssd is exactly zero
    got_same_ssd = image.dissimilarity_fn(y_true_ssd, y_true_ssd, "ssd")
    assert is_equal_tf(got_same_ssd, [0, 0])

    # identical images: gmi is exactly zero
    ones = tf.ones([4, 3, 3, 3])
    got_same_gmi = image.dissimilarity_fn(ones, ones, "gmi")
    assert is_equal_tf(got_same_gmi, [0, 0, 0, 0])

    # an unrecognised loss name is rejected
    with pytest.raises(AssertionError):
        image.dissimilarity_fn(
            y_true_ssd, y_pred_ssd, "some random string that isn't ssd or lncc"
        )
def train_step(grid, weights, optimizer, mov, fix):
    """
    Run one backprop step on the affine parameters using a gradient tape.

    :param grid: reference grid return from layer_util.get_reference_grid
    :param weights: trainable affine parameters [1, 4, 3]
    :param optimizer: tf.optimizers
    :param mov: moving image [1, m_dim1, m_dim2, m_dim3]
    :param fix: fixed image [1, f_dim1, f_dim2, f_dim3]
    :return loss: image dissimilarity to minimise
    """
    with tf.GradientTape() as tape:
        # warp the reference grid by the affine transform, then resample mov
        warped_grid = layer_util.warp_grid(grid, weights)
        warped = layer_util.resample(vol=mov, loc=warped_grid)
        loss = image_loss.dissimilarity_fn(
            y_true=fix, y_pred=warped, name=image_loss_name
        )
    grads = tape.gradient(loss, [weights])
    optimizer.apply_gradients(zip(grads, [weights]))
    return loss
def test_error(self):
    """An unknown (empty) loss name must raise ValueError mentioning the type."""
    # pytest.raises(match=...) re.search-es the exception message
    with pytest.raises(ValueError, match="Unknown loss type"):
        image.dissimilarity_fn(self.y_true, self.y_pred, "")
def test_output(self, y_true, y_pred, name, expected, tol):
    """Computed dissimilarity matches the precomputed expectation within tol."""
    assert is_equal_tf(
        image.dissimilarity_fn(y_true, y_pred, name), expected, atol=tol
    )