# Example #1
def train_step(warper, weights, optimizer, mov, fix):
    """
    Run a single optimisation step on the trainable ddf via gradient tape.

    :param warper: warping function returned from layer.Warping
    :param weights: trainable ddf [1, f_dim1, f_dim2, f_dim3, 3]
    :param optimizer: tf.optimizers
    :param mov: moving image [1, m_dim1, m_dim2, m_dim3]
    :param fix: fixed image [1, f_dim1, f_dim2, f_dim3]
    :return:
        loss: overall loss to optimise
        loss_image: image dissimilarity
        loss_deform: deformation regularisation
    """
    # only the ddf itself is optimised
    train_vars = [weights]
    with tf.GradientTape() as tape:
        warped = warper(inputs=[weights, mov])
        loss_image = image_loss.dissimilarity_fn(
            y_true=fix, y_pred=warped, name=image_loss_name)
        loss_deform = deform_loss.local_displacement_energy(
            weights, deform_loss_name)
        # total objective: dissimilarity plus weighted regularisation
        loss = loss_image + weight_deform_loss * loss_deform
    grads = tape.gradient(loss, train_vars)
    optimizer.apply_gradients(zip(grads, train_vars))
    return loss, loss_image, loss_deform
# Example #2
def test_local_displacement_energy():
    """test the computation of local displacement energy for ddf"""
    ddf = tf.ones([4, 50, 50, 50, 3])
    expected = tf.zeros([4])

    # a spatially constant ddf has zero derivatives, so each supported
    # energy type must evaluate to zero per batch element
    for energy_type in ["bending", "gradient-l1", "gradient-l2"]:
        got = deform.local_displacement_energy(ddf, energy_type)
        assert is_equal_tf(got, expected)

    # an unsupported energy type must raise
    with pytest.raises(ValueError):
        deform.local_displacement_energy(ddf, "a wrong string")
def add_ddf_loss(model: tf.keras.Model, ddf: tf.Tensor,
                 loss_config: dict) -> tf.keras.Model:
    """
    add regularization loss of ddf into model
    :param model: tf.keras.Model
    :param ddf: tensor of shape (batch, m_dim1, m_dim2, m_dim3, 3)
    :param loss_config: config for loss
    :return: the same model, with the loss and two metrics attached
    """
    reg_config = loss_config["regularization"]
    # NOTE(review): every key of reg_config — presumably including
    # "weight" — is forwarded to local_displacement_energy; confirm
    # its signature accepts these keyword arguments
    loss_reg = tf.reduce_mean(
        deform_loss.local_displacement_energy(ddf, **reg_config))
    weighted_loss_reg = loss_reg * reg_config["weight"]
    model.add_loss(weighted_loss_reg)
    model.add_metric(loss_reg, name="loss/regularization", aggregation="mean")
    model.add_metric(weighted_loss_reg,
                     name="loss/weighted_regularization",
                     aggregation="mean")
    return model