def train_step(warper, weights, optimizer, mov, fix):
    """
    Run one optimisation step on the trainable ddf using a gradient tape.

    The moving image is warped by the current ddf. Two losses are computed:
    an image dissimilarity against the fixed image and a deformation
    regulariser on the ddf. Their weighted sum is then differentiated with
    respect to the ddf, and the optimizer applies the resulting gradient.

    :param warper: warping function returned from layer.Warping
    :param weights: trainable ddf [1, f_dim1, f_dim2, f_dim3, 3]
    :param optimizer: tf.optimizers instance used to apply the gradient
    :param mov: moving image [1, m_dim1, m_dim2, m_dim3]
    :param fix: fixed image [1, f_dim1, f_dim2, f_dim3]
    :return: tuple (loss, loss_image, loss_deform) where
        loss: overall loss to optimise,
        loss_image: image dissimilarity,
        loss_deform: deformation regularisation
    """
    with tf.GradientTape() as tape:
        # Everything that contributes to the loss must be computed inside the
        # tape so the gradient w.r.t. ``weights`` can be recovered below.
        warped = warper(inputs=[weights, mov])
        dissimilarity = image_loss.dissimilarity_fn(
            y_true=fix, y_pred=warped, name=image_loss_name
        )
        regulariser = deform_loss.local_displacement_energy(
            weights, deform_loss_name
        )
        total = dissimilarity + weight_deform_loss * regulariser

    grads = tape.gradient(total, [weights])
    optimizer.apply_gradients(zip(grads, [weights]))
    return total, dissimilarity, regulariser
def test_local_displacement_energy():
    """
    Test deform.local_displacement_energy on a ddf filled with ones.

    For each supported energy type ("bending", "gradient-l1", "gradient-l2")
    the expected result is a zero energy per batch element (shape [4]).
    Because the field is constant, its spatial derivatives are presumably
    zero. An unsupported energy type must raise ValueError.
    """
    ddf_shape = [4, 50, 50, 50, 3]

    for energy_type in ("bending", "gradient-l1", "gradient-l2"):
        ddf = tf.ones(ddf_shape)
        got = deform.local_displacement_energy(ddf, energy_type)
        assert is_equal_tf(got, tf.zeros([4]))

    # not supported energy type
    ddf = tf.ones(ddf_shape)
    with pytest.raises(ValueError):
        deform.local_displacement_energy(ddf, "a wrong string")
def add_ddf_loss(
    model: tf.keras.Model, ddf: tf.Tensor, loss_config: dict
) -> tf.keras.Model:
    """
    Attach the ddf regularisation loss, plus two monitoring metrics, to a model.

    :param model: tf.keras.Model to which the loss and metrics are added
    :param ddf: tensor of shape (batch, m_dim1, m_dim2, m_dim3, 3)
    :param loss_config: config for loss. ``loss_config["regularization"]``
        is unpacked as keyword arguments into
        ``deform_loss.local_displacement_energy`` and must also provide a
        ``"weight"`` entry that scales the averaged energy.
    :return: the same model instance, with the weighted regularisation loss
        and the metrics "loss/regularization" and
        "loss/weighted_regularization" registered
    """
    reg_config = loss_config["regularization"]
    # NOTE(review): the whole dict, including "weight", is forwarded; this
    # presumably relies on local_displacement_energy absorbing extra kwargs.
    energy = deform_loss.local_displacement_energy(ddf, **reg_config)
    loss_reg = tf.reduce_mean(energy)
    weighted_loss_reg = loss_reg * reg_config["weight"]

    model.add_loss(weighted_loss_reg)
    for value, metric_name in (
        (loss_reg, "loss/regularization"),
        (weighted_loss_reg, "loss/weighted_regularization"),
    ):
        model.add_metric(value, name=metric_name, aggregation="mean")
    return model