Esempio n. 1
0
def add_image_loss(
    model: tf.keras.Model,
    fixed_image: tf.Tensor,
    pred_fixed_image: tf.Tensor,
    loss_config: dict,
) -> tf.keras.Model:
    """
    Attach the weighted image dissimilarity loss (and its metrics) to the model.

    Nothing is attached when the configured image weight is not positive.

    :param model: tf.keras.Model to which the loss and metrics are added
    :param fixed_image: tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param pred_fixed_image: tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param loss_config: config for loss; loss_config["dissimilarity"]["image"]
        is forwarded in full as keyword arguments to
        image_loss.dissimilarity_fn, and its "weight" entry scales the loss
    :return: the same model instance, with loss/metrics added in place
    """
    image_cfg = loss_config["dissimilarity"]["image"]
    image_weight = image_cfg["weight"]
    if not image_weight > 0:
        # A non-positive weight disables the image term entirely.
        return model

    per_sample = image_loss.dissimilarity_fn(
        y_true=fixed_image,
        y_pred=pred_fixed_image,
        **image_cfg,
    )
    loss_image = tf.reduce_mean(per_sample)
    weighted_loss_image = loss_image * image_weight

    model.add_loss(weighted_loss_image)
    # Log both the raw and the weighted value so their scales can be compared.
    for value, metric_name in (
        (loss_image, "loss/image_dissimilarity"),
        (weighted_loss_image, "loss/weighted_image_dissimilarity"),
    ):
        model.add_metric(value, name=metric_name, aggregation="mean")
    return model
Esempio n. 2
0
def add_ddf_loss(model: tf.keras.Model, ddf: tf.Tensor,
                 loss_config: dict) -> tf.keras.Model:
    """
    Attach the weighted regularization loss of ddf (and its metrics) to the model.

    :param model: tf.keras.Model to which the loss and metrics are added
    :param ddf: tensor of shape (batch, m_dim1, m_dim2, m_dim3, 3)
    :param loss_config: config for loss; every entry of
        loss_config["regularization"] is forwarded as a keyword argument to
        deform_loss.local_displacement_energy, and its "weight" entry scales
        the loss
    :return: the same model instance, with loss/metrics added in place
    """
    reg_cfg = loss_config["regularization"]
    energy = deform_loss.local_displacement_energy(ddf, **reg_cfg)
    loss_reg = tf.reduce_mean(energy)
    weighted_loss_reg = loss_reg * reg_cfg["weight"]

    model.add_loss(weighted_loss_reg)
    # Log both the raw and the weighted value so their scales can be compared.
    for value, metric_name in (
        (loss_reg, "loss/regularization"),
        (weighted_loss_reg, "loss/weighted_regularization"),
    ):
        model.add_metric(value, name=metric_name, aggregation="mean")
    return model
Esempio n. 3
0
def add_label_loss(
    model: tf.keras.Model,
    grid_fixed: tf.Tensor,
    fixed_label: (tf.Tensor, None),
    pred_fixed_label: (tf.Tensor, None),
    loss_config: dict,
) -> tf.keras.Model:
    """
    Add the weighted label dissimilarity loss and label metrics to the model.

    Nothing is added when fixed_label is None.

    :param model: tf.keras.Model to which the loss and metrics are added
    :param grid_fixed: tensor of shape (f_dim1, f_dim2, f_dim3, 3)
    :param fixed_label: tensor of shape (batch, f_dim1, f_dim2, f_dim3), or None
    :param pred_fixed_label: tensor of shape (batch, f_dim1, f_dim2, f_dim3);
        required whenever fixed_label is not None
    :param loss_config: config for loss; loss_config["dissimilarity"]["label"]
        is passed to label_loss.get_dissimilarity_fn, and its "weight" entry
        scales the loss
    :return: the same model instance, with loss/metrics added in place
    :raises ValueError: if fixed_label is given but pred_fixed_label is None
    """
    if fixed_label is None:
        return model
    if pred_fixed_label is None:
        # Fail early with a clear message instead of deep inside the TF ops.
        raise ValueError(
            "pred_fixed_label must be provided when fixed_label is not None")

    label_config = loss_config["dissimilarity"]["label"]
    dissimilarity_fn = label_loss.get_dissimilarity_fn(config=label_config)
    loss_label = tf.reduce_mean(
        dissimilarity_fn(y_true=fixed_label, y_pred=pred_fixed_label))
    weighted_loss_label = loss_label * label_config["weight"]

    model.add_loss(weighted_loss_label)
    # Log both the raw and the weighted value so their scales can be compared.
    model.add_metric(loss_label,
                     name="loss/label_dissimilarity",
                     aggregation="mean")
    model.add_metric(
        weighted_loss_label,
        name="loss/weighted_label_dissimilarity",
        aggregation="mean",
    )

    _add_label_metrics(model=model,
                       grid_fixed=grid_fixed,
                       fixed_label=fixed_label,
                       pred_fixed_label=pred_fixed_label)
    return model


def _add_label_metrics(
    model: tf.keras.Model,
    grid_fixed: tf.Tensor,
    fixed_label: tf.Tensor,
    pred_fixed_label: tf.Tensor,
) -> None:
    """
    Attach label evaluation metrics (not losses) to the model.

    Adds binary and soft dice scores, the centroid distance (logged as
    "metric/tre"), and the foreground proportions of both labels.
    """
    dice_binary = label_loss.dice_score(y_true=fixed_label,
                                        y_pred=pred_fixed_label,
                                        binary=True)
    dice_float = label_loss.dice_score(y_true=fixed_label,
                                       y_pred=pred_fixed_label,
                                       binary=False)
    tre = label_loss.compute_centroid_distance(y_true=fixed_label,
                                               y_pred=pred_fixed_label,
                                               grid=grid_fixed)
    foreground_label = label_loss.foreground_proportion(y=fixed_label)
    foreground_pred = label_loss.foreground_proportion(y=pred_fixed_label)

    for value, metric_name in (
        (dice_binary, "metric/dice_binary"),
        (dice_float, "metric/dice_float"),
        (tre, "metric/tre"),
        (foreground_label, "metric/foreground_label"),
        (foreground_pred, "metric/foreground_pred"),
    ):
        model.add_metric(value, name=metric_name, aggregation="mean")