Example #1
0
def add_image_loss(
    model: tf.keras.Model,
    fixed_image: tf.Tensor,
    pred_fixed_image: tf.Tensor,
    loss_config: dict,
) -> tf.keras.Model:
    """
    Attach the image dissimilarity loss and its metrics to the model.

    Nothing is added unless the configured weight is strictly positive.

    :param model: tf.keras.Model receiving the loss and the metrics
    :param fixed_image: tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param pred_fixed_image: tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param loss_config: config for loss; the entry
        loss_config["dissimilarity"]["image"] is read, its "weight" value
        scales the loss and the whole entry is forwarded as keyword
        arguments to image_loss.dissimilarity_fn
    :return: the model object that was passed in
    """
    image_config = loss_config["dissimilarity"]["image"]
    weight = image_config["weight"]

    # Guard clause: only a strictly positive weight enables the image loss.
    if not weight > 0:
        return model

    dissimilarity = image_loss.dissimilarity_fn(
        y_true=fixed_image,
        y_pred=pred_fixed_image,
        **image_config,
    )
    loss_image = tf.reduce_mean(dissimilarity)
    weighted_loss_image = loss_image * weight

    # Only the weighted value is registered as a loss; both the raw and the
    # weighted values are tracked as metrics.
    model.add_loss(weighted_loss_image)
    tracked = (
        ("loss/image_dissimilarity", loss_image),
        ("loss/weighted_image_dissimilarity", weighted_loss_image),
    )
    for metric_name, metric_value in tracked:
        model.add_metric(metric_value, name=metric_name, aggregation="mean")
    return model
Example #2
0
def train_step(warper, weights, optimizer, mov, fix):
    """
    Run one optimisation step on the trainable ddf using a gradient tape.

    The names image_loss_name, deform_loss_name and weight_deform_loss are
    not parameters; they are read from the enclosing scope.

    :param warper: warping function returned from layer.Warping
    :param weights: trainable ddf [1, f_dim1, f_dim2, f_dim3, 3]
    :param optimizer: tf.optimizers
    :param mov: moving image [1, m_dim1, m_dim2, m_dim3]
    :param fix: fixed image [1, f_dim1, f_dim2, f_dim3]
    :return:
        loss: overall loss to optimise
        loss_image: image dissimilarity
        loss_deform: deformation regularisation
    """
    # The ddf is the only variable being optimised.
    trainable = [weights]

    with tf.GradientTape() as tape:
        warped = warper(inputs=[weights, mov])
        loss_image = image_loss.dissimilarity_fn(
            y_true=fix, y_pred=warped, name=image_loss_name
        )
        loss_deform = deform_loss.local_displacement_energy(
            weights, deform_loss_name
        )
        regularisation = weight_deform_loss * loss_deform
        loss = loss_image + regularisation

    grads = tape.gradient(loss, trainable)
    optimizer.apply_gradients(zip(grads, trainable))
    return loss, loss_image, loss_deform
Example #3
0
def test_dissimilarity_fn():
    """
    Compare image.dissimilarity_fn with precomputed values.

    Covers "lncc" and "ssd" on differing and on identical images, "gmi" on
    identical images, and the error raised for an unrecognised loss name.
    """

    def as_float_tensor(array):
        # Convert a numpy array into a float32 tensor.
        return tf.convert_to_tensor(array, dtype=tf.float32)

    shape = (2, 1, 2, 3)
    lncc_name = "lncc"
    ssd_name = "ssd"

    # lncc diff images
    ramp_image = as_float_tensor(np.array(range(12)).reshape(shape))
    flat_image = as_float_tensor(0.6 * np.ones(shape))
    lncc_diff = image.dissimilarity_fn(ramp_image, flat_image, lncc_name)
    assert is_equal_tf(lncc_diff, [-0.68002254, -0.9608879])

    # ssd diff images
    zero_image = as_float_tensor(np.zeros(shape))
    const_image = as_float_tensor(0.6 * np.ones(shape))
    ssd_diff = image.dissimilarity_fn(zero_image, const_image, ssd_name)
    assert is_equal_tf(ssd_diff, [0.36, 0.36])

    # TODO gmi diff images

    # lncc same image
    lncc_same = image.dissimilarity_fn(const_image, const_image, lncc_name)
    assert is_equal_tf(lncc_same, [-1, -1])

    # ssd same image
    ssd_same = image.dissimilarity_fn(zero_image, zero_image, ssd_name)
    assert is_equal_tf(ssd_same, [0, 0])

    # gmi same image
    ones_image = tf.ones([4, 3, 3, 3])
    gmi_same = image.dissimilarity_fn(ones_image, ones_image, "gmi")
    assert is_equal_tf(gmi_same, [0, 0, 0, 0])

    # unknown func name
    unknown_name = "some random string that isn't ssd or lncc"
    with pytest.raises(AssertionError):
        image.dissimilarity_fn(zero_image, const_image, unknown_name)
Example #4
0
def train_step(grid, weights, optimizer, mov, fix):
    """
    Run one optimisation step on the affine parameters using a gradient tape.

    The name image_loss_name is not a parameter; it is read from the
    enclosing scope.

    :param grid: reference grid return from layer_util.get_reference_grid
    :param weights: trainable affine parameters [1, 4, 3]
    :param optimizer: tf.optimizers
    :param mov: moving image [1, m_dim1, m_dim2, m_dim3]
    :param fix: fixed image [1, f_dim1, f_dim2, f_dim3]
    :return loss: image dissimilarity to minimise
    """
    # The affine parameters are the only variable being optimised.
    trainable = [weights]

    with tf.GradientTape() as tape:
        # Warp the reference grid with the current parameters, then resample
        # the moving image at the warped locations.
        warped_grid = layer_util.warp_grid(grid, weights)
        resampled = layer_util.resample(vol=mov, loc=warped_grid)
        loss = image_loss.dissimilarity_fn(
            y_true=fix, y_pred=resampled, name=image_loss_name
        )

    grads = tape.gradient(loss, trainable)
    optimizer.apply_gradients(zip(grads, trainable))
    return loss
Example #5
0
 def test_error(self):
     """Check that an empty loss name raises a ValueError with a clear message."""
     # unknown func name
     with pytest.raises(ValueError) as raised:
         image.dissimilarity_fn(self.y_true, self.y_pred, "")
     message = str(raised.value)
     assert "Unknown loss type" in message
Example #6
0
 def test_output(self, y_true, y_pred, name, expected, tol):
     """
     Check image.dissimilarity_fn output against an expected value.

     :param y_true: first argument forwarded to image.dissimilarity_fn
     :param y_pred: second argument forwarded to image.dissimilarity_fn
     :param name: loss name forwarded to image.dissimilarity_fn
     :param expected: value the result is compared with
     :param tol: tolerance passed to is_equal_tf as atol
     """
     actual = image.dissimilarity_fn(y_true, y_pred, name)
     matches = is_equal_tf(actual, expected, atol=tol)
     assert matches