def test_two_elements() -> None:
    """
    Check that a two-component MixtureLoss equals the weighted sum of its component losses.

    The components are a CrossEntropy loss and a SoftDice loss, both built via
    ModelTrainingStepsForSegmentation.construct_non_mixture_loss_function with power=None.
    """
    config = DummyModel()
    element1 = ModelTrainingStepsForSegmentation.construct_non_mixture_loss_function(
        config, SegmentationLoss.CrossEntropy, power=None)
    element2 = ModelTrainingStepsForSegmentation.construct_non_mixture_loss_function(
        config, SegmentationLoss.SoftDice, power=None)
    weight1, weight2 = 0.3, 0.7
    mixture = MixtureLoss([(weight1, element1), (weight2, element2)])
    # Both tensors have shape (1, 2, 3). The -1e9 logits are meant to act as
    # near-certain negatives (NOTE(review): axis semantics assumed, confirm).
    target = torch.tensor([[[0, 0, 1], [1, 1, 0]]], dtype=torch.float32)
    logits = torch.tensor([[[-1e9, -1e9, 0], [0, 0, 0]]], dtype=torch.float32)
    # Evaluate each component and the mixture on identical inputs.
    element1_loss = element1(logits, target)
    element2_loss = element2(logits, target)
    mixture_loss = mixture(logits, target)
    assert torch.isclose(weight1 * element1_loss + weight2 * element2_loss,
                         mixture_loss)


def test_single_element() -> None:
    """
    Check that a MixtureLoss with a single component of weight 1.0 reproduces that component's loss.
    """
    config = DummyModel()
    element = ModelTrainingStepsForSegmentation.construct_non_mixture_loss_function(
        config, SegmentationLoss.CrossEntropy, power=None)
    mixture = MixtureLoss([(1.0, element)])
    # Same (1, 2, 3)-shaped inputs as in test_two_elements.
    target = torch.tensor([[[0, 0, 1], [1, 1, 0]]], dtype=torch.float32)
    logits = torch.tensor([[[-1e9, -1e9, 0], [0, 0, 0]]], dtype=torch.float32)
    # Compare the bare component against the mixture wrapping it.
    element_loss = element(logits, target)
    mixture_loss = mixture(logits, target)
    assert torch.isclose(element_loss, mixture_loss)


# Example #3
def test_construct_loss_function() -> None:
    """
    Check that construct_loss_function builds a MixtureLoss for a Mixture config.

    The MixtureLoss must have one component per configured MixtureLossComponent,
    and the component weights must be normalized so that they sum to 1.
    """
    import math

    model_config = DummyModel()
    model_config.loss_type = SegmentationLoss.Mixture
    # Weights deliberately do not sum to 1.0, so normalization is actually exercised.
    weights = [1.5, 0.5]
    model_config.mixture_loss_components = [
        MixtureLossComponent(weights[0], SegmentationLoss.CrossEntropy, 0.2),
        MixtureLossComponent(weights[1], SegmentationLoss.SoftDice, 0.1)]
    loss_fn = ModelTrainingStepsForSegmentation.construct_loss_function(model_config)
    assert isinstance(loss_fn, MixtureLoss)
    # Check the length first so that zip() below cannot silently truncate.
    assert len(loss_fn.components) == len(weights)
    total = sum(weights)
    for component, raw_weight in zip(loss_fn.components, weights):
        # Use a tolerance rather than exact float equality, because the implementation
        # may normalize with different but equivalent arithmetic.
        assert math.isclose(component[0], raw_weight / total)


def create_model_training_steps(model_config: ModelConfigBase,
                                train_val_params: TrainValidateParameters) -> ModelTrainingStepsBase:
    """
    Pick and instantiate the training-steps implementation that matches the config type.

    Dispatch order: segmentation configs first, then scalar configs. Among scalar
    configs, sequence configs get the sequence-specific implementation.

    :param model_config: Model config whose concrete type selects the implementation.
    :param train_val_params: Train/Val parameters forwarded to the chosen implementation.
    :return: A ModelTrainingStepsForSegmentation, ModelTrainingStepsForSequenceModel or
        ModelTrainingStepsForScalarModel, constructed from (model_config, train_val_params).
    :raises NotImplementedError: If model_config is neither a SegmentationModelBase nor a ScalarModelBase.
    """
    if isinstance(model_config, SegmentationModelBase):
        return ModelTrainingStepsForSegmentation(model_config, train_val_params)

    if not isinstance(model_config, ScalarModelBase):
        raise NotImplementedError(f"There are no model training steps defined for config type {type(model_config)}")

    # From here on the config is known to be a ScalarModelBase.
    if isinstance(model_config, SequenceModelBase):
        return ModelTrainingStepsForSequenceModel(model_config, train_val_params)
    return ModelTrainingStepsForScalarModel(model_config, train_val_params)