def add_image_loss(
    model: tf.keras.Model,
    fixed_image: tf.Tensor,
    pred_fixed_image: tf.Tensor,
    loss_config: dict,
) -> tf.keras.Model:
    """
    Attach the image dissimilarity loss and its metrics to the model.

    Nothing is added when the configured image weight is not positive.

    :param model: tf.keras.Model
    :param fixed_image: tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param pred_fixed_image: tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param loss_config: config for loss; the entry under
        ["dissimilarity"]["image"] provides "weight" and is forwarded
        in full as keyword arguments to the dissimilarity function
    :return: the model that was passed in
    """
    image_config = loss_config["dissimilarity"]["image"]

    # guard clause: a non-positive weight disables the image loss entirely
    if not (image_config["weight"] > 0):
        return model

    dissimilarity = image_loss.dissimilarity_fn(
        y_true=fixed_image,
        y_pred=pred_fixed_image,
        **image_config,
    )
    raw_loss = tf.reduce_mean(dissimilarity)
    scaled_loss = raw_loss * image_config["weight"]

    # only the weighted value contributes to the training objective,
    # both values are tracked as metrics
    model.add_loss(scaled_loss)
    for metric_value, metric_name in (
        (raw_loss, "loss/image_dissimilarity"),
        (scaled_loss, "loss/weighted_image_dissimilarity"),
    ):
        model.add_metric(metric_value, name=metric_name, aggregation="mean")
    return model
def add_ddf_loss(
    model: tf.keras.Model, ddf: tf.Tensor, loss_config: dict
) -> tf.keras.Model:
    """
    Attach the ddf regularization loss and its metrics to the model.

    :param model: tf.keras.Model
    :param ddf: tensor of shape (batch, m_dim1, m_dim2, m_dim3, 3)
    :param loss_config: config for loss; the entry under
        ["regularization"] provides "weight" and is forwarded in full
        as keyword arguments to the displacement energy function
    :return: the model that was passed in
    """
    reg_config = loss_config["regularization"]

    energy = deform_loss.local_displacement_energy(ddf, **reg_config)
    raw_loss = tf.reduce_mean(energy)
    scaled_loss = raw_loss * reg_config["weight"]

    # only the weighted value contributes to the training objective,
    # both values are tracked as metrics
    model.add_loss(scaled_loss)
    for metric_value, metric_name in (
        (raw_loss, "loss/regularization"),
        (scaled_loss, "loss/weighted_regularization"),
    ):
        model.add_metric(metric_value, name=metric_name, aggregation="mean")
    return model
def add_label_loss(
    model: tf.keras.Model,
    grid_fixed: tf.Tensor,
    fixed_label: (tf.Tensor, None),
    pred_fixed_label: (tf.Tensor, None),
    loss_config: dict,
) -> tf.keras.Model:
    """
    Attach the label dissimilarity loss and label metrics to the model.

    The model is returned untouched when no fixed label is given.

    :param model: tf.keras.Model
    :param grid_fixed: tensor of shape (f_dim1, f_dim2, f_dim3, 3)
    :param fixed_label: tensor of shape (batch, f_dim1, f_dim2, f_dim3),
        or None
    :param pred_fixed_label: tensor of shape
        (batch, f_dim1, f_dim2, f_dim3)
    :param loss_config: config for loss; the entry under
        ["dissimilarity"]["label"] selects the dissimilarity function
        and provides "weight"
    :return: the model that was passed in
    """
    # guard clause: without a ground-truth label there is nothing to add
    if fixed_label is None:
        return model

    label_config = loss_config["dissimilarity"]["label"]
    dissimilarity_fn = label_loss.get_dissimilarity_fn(config=label_config)
    raw_loss = tf.reduce_mean(
        dissimilarity_fn(y_true=fixed_label, y_pred=pred_fixed_label)
    )
    scaled_loss = raw_loss * label_config["weight"]

    # only the weighted value contributes to the training objective,
    # both values are tracked as metrics
    model.add_loss(scaled_loss)
    for metric_value, metric_name in (
        (raw_loss, "loss/label_dissimilarity"),
        (scaled_loss, "loss/weighted_label_dissimilarity"),
    ):
        model.add_metric(metric_value, name=metric_name, aggregation="mean")

    # metrics: every value is computed before any of them is registered
    label_metrics = [
        (
            "metric/dice_binary",
            label_loss.dice_score(
                y_true=fixed_label, y_pred=pred_fixed_label, binary=True
            ),
        ),
        (
            "metric/dice_float",
            label_loss.dice_score(
                y_true=fixed_label, y_pred=pred_fixed_label, binary=False
            ),
        ),
        (
            "metric/tre",
            label_loss.compute_centroid_distance(
                y_true=fixed_label, y_pred=pred_fixed_label, grid=grid_fixed
            ),
        ),
        (
            "metric/foreground_label",
            label_loss.foreground_proportion(y=fixed_label),
        ),
        (
            "metric/foreground_pred",
            label_loss.foreground_proportion(y=pred_fixed_label),
        ),
    ]
    for metric_name, metric_value in label_metrics:
        model.add_metric(metric_value, name=metric_name, aggregation="mean")
    return model