def test_parse_multiple_loss(self):
    """A list of label losses must come back from ``parse_label_loss`` unchanged.

    The expected value is a deep copy of the input taken *before* the call.
    Comparing against ``loss_config`` itself would pass trivially if
    ``parse_label_loss`` returned (or mutated and returned) its own argument.
    """
    from copy import deepcopy

    loss_config = {
        "label": [
            {
                "name": "dice",
                "weight": 1.0,
            },
            {
                "name": "cross-entropy",
                "weight": 1.0,
            },
        ],
    }
    # Snapshot before parsing so in-place edits cannot hide a regression.
    expected = deepcopy(loss_config)
    got = parse_label_loss(loss_config=loss_config)
    assert got == expected
def test_parse_background_weight(self):
    """The outdated ``neg_weight`` key must be reported as ``background_weight``.

    Every other entry of the label loss (name, weight) is expected to be kept
    as-is.
    """
    legacy = {"label": {"name": "dice", "weight": 1.0, "neg_weight": 2.0}}
    wanted = {"label": {"name": "dice", "weight": 1.0, "background_weight": 2.0}}
    assert parse_label_loss(loss_config=legacy) == wanted
def test_parse_old_loss(self, name_loss: str, expected_config: dict):
    """Parse an old-style label loss config holding nested scale sub-configs.

    The label entry carries both a ``single_scale`` and a ``multi_scale``
    sub-config, and ``name_loss`` is stored under ``"name"``. Presumably
    ``name_loss`` selects which sub-config applies. ``name_loss`` and
    ``expected_config`` are supplied from outside this method (likely a
    parametrize decorator, not visible here).
    """
    label_cfg = {
        "name": name_loss,
        "single_scale": {"loss_type": "dice_generalized"},
        "multi_scale": {"loss_type": "mean-squared", "loss_scales": [0, 1]},
    }
    # Only the multi-scale variant is given an explicit weight.
    if name_loss == "multi_scale":
        label_cfg["weight"] = 2.0
    parsed = parse_label_loss(loss_config={"label": label_cfg})
    assert parsed == expected_config