예제 #1
0
def test_similarity_fn_unknown_loss():
    """
    Check that asking for a dissimilarity function with an
    unknown loss name raises a ValueError.
    """
    # "random" is not expected to be a supported loss name
    unknown_loss_config = {"name": "random"}
    with pytest.raises(ValueError):
        label.get_dissimilarity_fn(unknown_loss_config)
예제 #2
0
 def test_unknown_cases(self):
     """
     Check that an unknown loss name makes the dissimilarity
     function lookup raise a ValueError whose message mentions
     the unknown loss type.
     """
     with pytest.raises(ValueError) as raised:
         label.get_dissimilarity_fn({"name": "random"})
     # inspect the message only after the context manager has captured it
     message = str(raised.value)
     assert "Unknown loss type" in message
예제 #3
0
def test_similarity_fn_single_scale():
    """
    Check that a single-scale config with a valid loss name
    yields a plain Python function.
    """
    single_scale_config = {"name": "single_scale", "single_scale": "jaccard"}
    loss_fn = label.get_dissimilarity_fn(single_scale_config)
    assert isinstance(loss_fn, FunctionType)
예제 #4
0
파일: util.py 프로젝트: zcemycl/DeepReg
def add_label_loss(
    model: tf.keras.Model,
    grid_fixed: tf.Tensor,
    fixed_label: (tf.Tensor, None),
    pred_fixed_label: (tf.Tensor, None),
    loss_config: dict,
) -> tf.keras.Model:
    """
    Attach the label dissimilarity loss and the label metrics to the model.

    Nothing is added when no fixed label is provided.

    :param model: tf.keras.Model
    :param grid_fixed: tensor of shape (f_dim1, f_dim2, f_dim3, 3)
    :param fixed_label: tensor of shape (batch, f_dim1, f_dim2, f_dim3),
        or None, in which case the model is returned untouched
    :param pred_fixed_label: tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param loss_config: config for loss; the entry
        ``loss_config["dissimilarity"]["label"]`` is passed to
        ``label_loss.get_dissimilarity_fn`` and supplies the ``"weight"``
    :return: the model that was passed in
    """
    # NOTE(review): the tuple annotations ``(tf.Tensor, None)`` presumably
    # mean "tensor or None"; they are kept as-is to leave the signature alone.

    # guard clause: without a fixed label there is nothing to compare
    if fixed_label is None:
        return model

    label_config = loss_config["dissimilarity"]["label"]

    # loss: mean of the dissimilarity values, raw and weighted
    dissimilarity_fn = label_loss.get_dissimilarity_fn(config=label_config)
    loss_label = tf.reduce_mean(
        dissimilarity_fn(y_true=fixed_label, y_pred=pred_fixed_label)
    )
    weighted_loss_label = loss_label * label_config["weight"]
    model.add_loss(weighted_loss_label)
    for loss_value, loss_name in (
        (loss_label, "loss/label_dissimilarity"),
        (weighted_loss_label, "loss/weighted_label_dissimilarity"),
    ):
        model.add_metric(loss_value, name=loss_name, aggregation="mean")

    # metrics: computed first, then registered in a fixed order
    label_pair = {"y_true": fixed_label, "y_pred": pred_fixed_label}
    dice_binary = label_loss.dice_score(binary=True, **label_pair)
    dice_float = label_loss.dice_score(binary=False, **label_pair)
    tre = label_loss.compute_centroid_distance(grid=grid_fixed, **label_pair)
    foreground_label = label_loss.foreground_proportion(y=fixed_label)
    foreground_pred = label_loss.foreground_proportion(y=pred_fixed_label)

    for metric_value, metric_name in (
        (dice_binary, "metric/dice_binary"),
        (dice_float, "metric/dice_float"),
        (tre, "metric/tre"),
        (foreground_label, "metric/foreground_label"),
        (foreground_pred, "metric/foreground_pred"),
    ):
        model.add_metric(metric_value, name=metric_name, aggregation="mean")

    return model
예제 #5
0
 def test_known_cases(self, config):
     """
     Check that the loss function built from a supported config
     returns a result of shape (batch_size,) for the fixture labels.
     """
     dissimilarity_fn = label.get_dissimilarity_fn(config)
     result = dissimilarity_fn(y_true=self.y_true, y_pred=self.y_pred)
     expected_shape = (self.batch_size, )
     assert result.shape == expected_shape