Example #1
import pytest
import torch
from typing import Dict

# Shared imports for the examples on this page. BinaryCrossEntropyWithLogitsLoss,
# CrossEntropyLoss, MSELoss, SupervisedLearningCriterion, ScalarLoss and
# ScalarModelBase are defined in the project under test.
def test_weighted_binary_cross_entropy_loss_multi_label() -> None:
    # Class 0 has 2 positive examples, class 1 has none
    target = torch.tensor([[1, 0], [1, 0], [0, 0]], dtype=torch.float32)
    smoothed_target = torch.tensor([[0.9, 0.1], [0.9, 0.1], [0.1, 0.1]],
                                   dtype=torch.float32)
    logits = torch.tensor([[-10, 1], [-10, 1], [10, 0]], dtype=torch.float32)
    weighted_non_smoothed_loss_fn: SupervisedLearningCriterion = \
        BinaryCrossEntropyWithLogitsLoss(num_classes=2,
                                         smoothing_eps=0,
                                         class_counts={1.0: 0, 0.0: 2},
                                         num_train_samples=target.shape[0])
    weighted_smoothed_loss_fn: SupervisedLearningCriterion = \
        BinaryCrossEntropyWithLogitsLoss(num_classes=2,
                                         smoothing_eps=0.1,
                                         class_counts={1.0: 0, 0.0: 2},
                                         num_train_samples=target.shape[0])
    non_weighted_smoothed_loss_fn: SupervisedLearningCriterion = \
        BinaryCrossEntropyWithLogitsLoss(num_classes=2,
                                         smoothing_eps=0.1,
                                         class_counts=None)
    w_loss1 = weighted_non_smoothed_loss_fn(logits, smoothed_target)
    w_loss2 = weighted_smoothed_loss_fn(logits, target)
    w_loss3 = non_weighted_smoothed_loss_fn(logits, target)
    positive_class_weights = weighted_smoothed_loss_fn.get_positive_class_weights()  # type: ignore
    assert torch.isclose(w_loss1, w_loss2)
    assert not torch.isclose(w_loss2, w_loss3)
    assert torch.equal(positive_class_weights, torch.tensor([0.5, 1]))
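The expected weights in the last assertion are consistent with the usual pos_weight convention of torch.nn.BCEWithLogitsLoss: negatives divided by positives per class, computed from the raw (unsmoothed) class_counts, with a class that has no positive examples falling back to a weight of 1. A minimal sketch of that arithmetic, where compute_pos_weights is a hypothetical helper rather than the library's API:

import torch

def compute_pos_weights(class_counts: dict, num_train_samples: int) -> torch.Tensor:
    # Weight per class: negatives / positives, defaulting to 1.0 when a class
    # has no positive examples (avoids division by zero).
    weights = []
    for class_key in sorted(class_counts):
        positives = class_counts[class_key]
        negatives = num_train_samples - positives
        weights.append(negatives / positives if positives > 0 else 1.0)
    return torch.tensor(weights, dtype=torch.float32)

# Example #1: class 0 has 2 positives out of 3 samples -> (3 - 2) / 2 = 0.5;
# class 1 has no positives -> 1.0
print(compute_pos_weights({0.0: 2, 1.0: 0}, num_train_samples=3))  # tensor([0.5000, 1.0000])
# Example #2 below: 5 positives out of 6 samples -> (6 - 5) / 5 = 0.2
print(compute_pos_weights({1.0: 5}, num_train_samples=6))          # tensor([0.2000])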
Example #2
def test_weighted_binary_cross_entropy_loss_forward_smoothing() -> None:
    target = torch.tensor([[1], [1], [1], [1], [1], [0]], dtype=torch.float32)
    smoothed_target = torch.tensor([[0.9], [0.9], [0.9], [0.9], [0.9], [0.1]],
                                   dtype=torch.float32)
    logits = torch.tensor([[-10], [-10], [0], [0], [0], [0]],
                          dtype=torch.float32)
    weighted_non_smoothed_loss_fn: SupervisedLearningCriterion = \
        BinaryCrossEntropyWithLogitsLoss(num_classes=1,
                                         smoothing_eps=0,
                                         class_counts={1.0: 5},
                                         num_train_samples=target.shape[0])
    weighted_smoothed_loss_fn: SupervisedLearningCriterion = \
        BinaryCrossEntropyWithLogitsLoss(num_classes=1,
                                         smoothing_eps=0.1,
                                         class_counts={1.0: 5},
                                         num_train_samples=target.shape[0])
    non_weighted_smoothed_loss_fn: SupervisedLearningCriterion = \
        BinaryCrossEntropyWithLogitsLoss(num_classes=1,
                                         smoothing_eps=0.1,
                                         class_counts=None)
    w_loss1 = weighted_non_smoothed_loss_fn(logits, smoothed_target)
    w_loss2 = weighted_smoothed_loss_fn(logits, target)
    w_loss3 = non_weighted_smoothed_loss_fn(logits, target)
    positive_class_weights = weighted_smoothed_loss_fn.get_positive_class_weights()  # type: ignore
    assert torch.isclose(w_loss1, w_loss2)
    assert not torch.isclose(w_loss2, w_loss3)
    assert torch.all(positive_class_weights == torch.tensor([[0.2]]))
Example #3
def create_scalar_loss_function(config: ScalarModelBase) -> torch.nn.Module:
    """
    Returns a torch module that computes a loss function for classification and regression models.
    """
    if config.loss_type == ScalarLoss.BinaryCrossEntropyWithLogits:
        return BinaryCrossEntropyWithLogitsLoss(num_classes=len(config.class_names),
                                                smoothing_eps=config.label_smoothing_eps)
    elif config.loss_type == ScalarLoss.WeightedCrossEntropyWithLogits:
        return BinaryCrossEntropyWithLogitsLoss(
            num_classes=len(config.class_names),
            smoothing_eps=config.label_smoothing_eps,
            class_counts=config.get_training_class_counts(),
            num_train_samples=config.get_total_number_of_training_samples())
    elif config.loss_type == ScalarLoss.MeanSquaredError:
        return MSELoss()
    else:
        raise NotImplementedError(f"Loss type {config.loss_type} is not implemented")
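For illustration, a sketch of how this factory is consumed; config stands in for any ScalarModelBase whose loss_type, class_names and label_smoothing_eps have been set elsewhere, and the tensors are placeholders:

# Sketch only: every branch of create_scalar_loss_function returns a module
# with the same (logits, target) call signature.
loss_fn = create_scalar_loss_function(config)
logits = torch.tensor([[0.5]])
target = torch.tensor([[1.0]])
loss = loss_fn(logits, target)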
Example #4
@pytest.mark.parametrize("is_segmentation", [True, False])
def test_cross_entropy_loss_forward_smoothing(is_segmentation: bool) -> None:
    target = torch.tensor([[[0, 0, 1], [1, 1, 0]]], dtype=torch.float32)
    smoothed_target = torch.tensor([[[0.1, 0.1, 0.9], [0.9, 0.9, 0.1]]], dtype=torch.float32)
    logits = torch.tensor([[[-10, -10, 0], [0, 0, 0]]], dtype=torch.float32)

    barely_smoothed_loss_fn: SupervisedLearningCriterion = BinaryCrossEntropyWithLogitsLoss(smoothing_eps=0)
    smoothed_loss_fn: SupervisedLearningCriterion = BinaryCrossEntropyWithLogitsLoss(smoothing_eps=0.1)
    if is_segmentation:
        # The two loss values are only expected to be the same when no class weighting takes place,
        # because weighting is done on the *unsmoothed* target values.
        # We can't use a completely unsmoothed loss function because it won't like non-one-hot targets.
        barely_smoothed_loss_fn = CrossEntropyLoss(class_weight_power=0.0, smoothing_eps=1e-9)
        smoothed_loss_fn = CrossEntropyLoss(class_weight_power=0.0, smoothing_eps=0.1)

    loss1 = barely_smoothed_loss_fn(logits, smoothed_target)
    loss2 = smoothed_loss_fn(logits, target)
    assert torch.isclose(loss1, loss2)
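The asserted equivalence rests on standard two-sided label smoothing: each hard label y is replaced by y * (1 - eps) + (1 - y) * eps, so with eps = 0.1 a 1 becomes 0.9 and a 0 becomes 0.1. A quick check that this formula reproduces smoothed_target from the test (the library's internals may differ in detail):

import torch

eps = 0.1
target = torch.tensor([[[0, 0, 1], [1, 1, 0]]], dtype=torch.float32)
smoothed = target * (1 - eps) + (1 - target) * eps
print(smoothed)  # tensor([[[0.1000, 0.1000, 0.9000], [0.9000, 0.9000, 0.1000]]])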
Example #5
def test_weighted_binary_cross_entropy_loss_multi_target() -> None:
    target = torch.tensor([[[1], [0]], [[1], [0]], [[0], [0]]], dtype=torch.float32)
    smoothed_target = torch.tensor([[[0.9], [0.1]], [[0.9], [0.1]], [[0.1], [0.1]]], dtype=torch.float32)
    logits = torch.tensor([[[-10], [1]], [[-10], [1]], [[10], [0]]], dtype=torch.float32)
    weighted_non_smoothed_loss_fn: SupervisedLearningCriterion = \
        BinaryCrossEntropyWithLogitsLoss(smoothing_eps=0, class_counts={1.0: 2, 0.0: 4})
    weighted_smoothed_loss_fn: SupervisedLearningCriterion = \
        BinaryCrossEntropyWithLogitsLoss(smoothing_eps=0.1, class_counts={1.0: 2, 0.0: 4})
    non_weighted_smoothed_loss_fn: SupervisedLearningCriterion = \
        BinaryCrossEntropyWithLogitsLoss(smoothing_eps=0.1, class_counts=None)
    w_loss1 = weighted_non_smoothed_loss_fn(logits, smoothed_target)
    w_loss2 = weighted_smoothed_loss_fn(logits, target)
    w_loss3 = non_weighted_smoothed_loss_fn(logits, target)
    positive_class_weights = weighted_smoothed_loss_fn.get_positive_class_weights()  # type: ignore
    assert torch.isclose(w_loss1, w_loss2)
    assert not torch.isclose(w_loss2, w_loss3)
    assert torch.all(positive_class_weights == torch.tensor(2))
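With a single class, the positive weight collapses to one scalar that is shared across both target positions. Reading the class_counts keys as label values (an assumption consistent with the final assertion), the six label entries split into 2 ones and 4 zeros:

class_counts = {1.0: 2, 0.0: 4}
pos_weight = class_counts[0.0] / class_counts[1.0]  # negatives / positives = 2.0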
Example #6
# Illustrative parametrization: any case where len(class_counts) != num_classes
# triggers the ValueError checked below.
@pytest.mark.parametrize(["num_classes", "class_counts"],
                         [(1, {1.0: 100, 0.0: 200})])
def test_invalid_initialization(num_classes: int,
                                class_counts: Dict[float, int]) -> None:
    with pytest.raises(ValueError) as ex:
        BinaryCrossEntropyWithLogitsLoss(num_classes=num_classes,
                                         smoothing_eps=0,
                                         class_counts=class_counts,
                                         num_train_samples=10)
    assert f"Have {num_classes} classes but got counts for {len(class_counts)} classes" in str(
        ex)
Example #7
def create_loss_function(self) -> torch.nn.Module:
    """
    Returns a torch module that computes a loss function.
    Depending on the chosen loss function, the required data type for the labels tensor is set in self.
    """
    if self.model_config.loss_type == ScalarLoss.BinaryCrossEntropyWithLogits:
        return BinaryCrossEntropyWithLogitsLoss(
            smoothing_eps=self.model_config.label_smoothing_eps)
    elif self.model_config.loss_type == ScalarLoss.WeightedCrossEntropyWithLogits:
        return BinaryCrossEntropyWithLogitsLoss(
            smoothing_eps=self.model_config.label_smoothing_eps,
            class_counts=self.model_config.get_training_class_counts())
    elif self.model_config.loss_type == ScalarLoss.MeanSquaredError:
        self.label_tensor_dtype = torch.float32
        return MSELoss()
    else:
        raise NotImplementedError(
            f"Loss type {self.model_config.loss_type} is not implemented")