Code example #1
def construct_non_mixture_loss_function(
        cls, model_config: SegmentationModelBase,
        loss_type: SegmentationLoss,
        power: Optional[float]) -> SupervisedLearningCriterion:
    """
    :param model_config: model configuration to get some parameters from
    :param loss_type: type of loss function
    :param power: value for class_weight_power for the loss function
    :return: instance of loss function
    """
    if loss_type == SegmentationLoss.SoftDice:
        return SoftDiceLoss(class_weight_power=power)
    elif loss_type == SegmentationLoss.CrossEntropy:
        return CrossEntropyLoss(
            class_weight_power=power,
            smoothing_eps=model_config.label_smoothing_eps,
            focal_loss_gamma=None)
    elif loss_type == SegmentationLoss.Focal:
        return CrossEntropyLoss(
            class_weight_power=power,
            smoothing_eps=model_config.label_smoothing_eps,
            focal_loss_gamma=model_config.focal_loss_gamma)
    else:
        raise NotImplementedError(
            "Loss type {} is not implemented".format(loss_type))
Code example #2
def test_get_focal_loss_pixel_weights() -> None:
    """
    Weights for correctly predicted pixels (based on the logits) should be close to zero,
    and wrong predictions should have higher weights. The weights should sum to the number of
    pixels so that the overall scale of the loss function is unchanged.
    """
    x_entropy_loss = CrossEntropyLoss(focal_loss_gamma=2.0)

    target = torch.tensor([[[1, 0, 0], [0, 1, 1]]], dtype=torch.float32)
    logits = torch.tensor([[[0, 0, 0], [-1e9, 0, 0]]], dtype=torch.float32)
    pixel_weights = x_entropy_loss._get_focal_loss_pixel_weights(logits=logits,
                                                                 target=target)
    assert torch.allclose(torch.masked_select(pixel_weights, target.eq(1.0)),
                          torch.tensor([0.00, 1.50, 1.50]))
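The expected weights can be reproduced by hand. Assuming the usual focal weighting (1 - p_t) ** gamma, rescaled so the weights sum to the number of pixels, the sketch below recomputes [0.00, 1.50, 1.50] from the test tensors; it is an illustrative recomputation, not the library's internal implementation.

import torch

gamma = 2.0
target = torch.tensor([[[1, 0, 0], [0, 1, 1]]], dtype=torch.float32)
logits = torch.tensor([[[0, 0, 0], [-1e9, 0, 0]]], dtype=torch.float32)

posteriors = torch.softmax(logits, dim=1)                        # per-pixel class probabilities
p_t = (posteriors * target).sum(dim=1)                           # probability of the true class
raw_weights = (1.0 - p_t) ** gamma                               # [0.00, 0.25, 0.25]
weights = raw_weights * raw_weights.numel() / raw_weights.sum()  # rescale to sum to 3
print(weights)  # tensor([[0.0000, 1.5000, 1.5000]])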
Code example #3
def test_cross_entropy_loss_integration(focal_loss_gamma: float,
                                        loss_upper_bound: float,
                                        class_weight_power: float) -> None:
    """
    Solves a simple linear classification problem by training a multi-layer perceptron.
    Here the training objectives (cross-entropy and focal loss) are tested to verify that they behave
    properly when optimised with a stochastic optimiser.
    """
    # Set a seed
    torch.random.manual_seed(1)

    # Define hyperparameters
    n_samples = 1000
    batch_size = 16
    n_epochs = 40

    # Set the input data (labels 1000x2, features 1000x50, 1000 samples, 50 dimensional features)
    features = torch.cat([torch.randn(n_samples // 2, 50),
                          torch.randn(n_samples // 2, 50) + 1.5], dim=0)
    indices = torch.cat([torch.zeros(n_samples // 2, dtype=torch.long),
                         torch.ones(n_samples // 2, dtype=torch.long)], dim=0)
    labels = torch.nn.functional.one_hot(indices, num_classes=2).float()

    # Shuffle the dataset
    perm = torch.randperm(n_samples)
    features = features[perm, :]
    labels = labels[perm, :]

    # Define a basic model (we don't actually need a non-linear unit to solve this problem)
    net = ToyNet()
    opt = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999))
    loss_fn = CrossEntropyLoss(class_weight_power=class_weight_power, focal_loss_gamma=focal_loss_gamma)

    # Perform forward and backward passes
    net.train()
    epoch_losses = []
    loss = torch.empty(0)  # to ensure never unset
    for epoch_id in range(0, n_epochs):
        for beg_i in range(0, features.size(0), batch_size):
            x_batch = features[beg_i:beg_i + batch_size, :]
            y_batch = labels[beg_i:beg_i + batch_size, :]

            opt.zero_grad()
            # (1) Forward
            y_hat = net(x_batch)
            # (2) Compute diff
            loss = loss_fn(y_hat, y_batch)
            # (3) Compute gradients
            loss.backward()
            # (4) update weights
            opt.step()

        # Add final epoch loss to the list
        epoch_losses.append(loss.data.numpy())

    # Check that the loss decays for this problem
    assert epoch_losses[0] > 0.10
    assert epoch_losses[10] < epoch_losses[0]
    assert epoch_losses[15] < epoch_losses[5]
    assert epoch_losses[n_epochs - 1] < loss_upper_bound
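ToyNet is defined elsewhere in the test module and is not shown in this excerpt; a minimal stand-in consistent with the test (50-dimensional input, 2 output classes) might look like the sketch below.

import torch

class ToyNet(torch.nn.Module):
    """Hypothetical stand-in: a small MLP mapping 50 features to 2 class logits."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(50, 16),
            torch.nn.ReLU(),
            torch.nn.Linear(16, 2))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)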
Code example #4
def test_cross_entropy_loss_forward_smoothing(is_segmentation: bool) -> None:
    target = torch.tensor([[[0, 0, 1], [1, 1, 0]]], dtype=torch.float32)
    smoothed_target = torch.tensor([[[0.1, 0.1, 0.9], [0.9, 0.9, 0.1]]], dtype=torch.float32)
    logits = torch.tensor([[[-10, -10, 0], [0, 0, 0]]], dtype=torch.float32)

    barely_smoothed_loss_fn: SupervisedLearningCriterion = BinaryCrossEntropyWithLogitsLoss(smoothing_eps=0)
    smoothed_loss_fn: SupervisedLearningCriterion = BinaryCrossEntropyWithLogitsLoss(smoothing_eps=0.1)
    if is_segmentation:
        # The two loss values are only expected to be the same when no class weighting takes place,
        # because weighting is done on the *unsmoothed* target values.
        # We can't use a completely unsmoothed loss function because it won't like non-one-hot targets.
        barely_smoothed_loss_fn = CrossEntropyLoss(class_weight_power=0.0, smoothing_eps=1e-9)
        smoothed_loss_fn = CrossEntropyLoss(class_weight_power=0.0, smoothing_eps=0.1)

    loss1 = barely_smoothed_loss_fn(logits, smoothed_target)
    loss2 = smoothed_loss_fn(logits, target)
    assert torch.isclose(loss1, loss2)
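The smoothed_target tensor above is consistent with a symmetric smoothing of the form target * (1 - eps) + (1 - target) * eps; the snippet below simply reproduces it from the hard target as an illustration of that relationship, not as the library's internal code.

import torch

eps = 0.1
target = torch.tensor([[[0, 0, 1], [1, 1, 0]]], dtype=torch.float32)
smoothed = target * (1 - eps) + (1 - target) * eps
print(smoothed)  # tensor([[[0.1000, 0.1000, 0.9000], [0.9000, 0.9000, 0.1000]]])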
Code example #5
def test_cross_entropy_loss_forward_zero_loss() -> None:
    target = torch.tensor([[[0, 0, 0], [1, 1, 1]]], dtype=torch.float32)
    logits = torch.tensor([[[-1e9, -1e9, -1e9], [0, 0, 0]]], dtype=torch.float32)

    # Unweighted cross entropy: confident, correct predictions should give (close to) zero loss
    loss_fn = CrossEntropyLoss(class_weight_power=0.0)
    loss = loss_fn(logits, target)

    assert torch.isclose(loss, torch.tensor([0.000]))
Code example #6
def test_cross_entropy_loss_forward_balanced() -> None:
    # target: one-hot, B=1, C=2, N=3 voxels. First two voxels are class 1, last is class 0.
    target = torch.tensor([[[0, 0, 1], [1, 1, 0]]], dtype=torch.float32)
    # logits: predicting class 1 (correctly) at first two voxels, 50-50 at last voxel.
    logits = torch.tensor([[[-1e9, -1e9, 0], [0, 0, 0]]], dtype=torch.float32)

    # Compute loss values for unbalanced case.
    loss_fn = CrossEntropyLoss(class_weight_power=0.0)
    loss = loss_fn(logits, target)
    # Loss is (nearly) all from last voxel: -log(0.5). This is averaged over all 3 voxels (divide by 3).
    expected = -1 * torch.log(torch.tensor(0.5)) / 3.0
    assert (torch.isclose(loss, expected))

    # Compute loss values for balanced case.
    loss_fn = CrossEntropyLoss(class_weight_power=1.0)
    loss = loss_fn(logits, target)
    # Class weights should be 4/3 for class 0 and 2/3 for class 1 (inverses of the class frequencies,
    # normalized to average 1). The loss comes from the uncertainty on the last voxel, which is class 0,
    # so the balanced loss is the unbalanced one scaled by the class-0 weight of 4/3.
    expected = expected * 4 / 3
    assert (torch.isclose(loss, expected))
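The 4/3 factor can be reconstructed from the class counts: class 0 covers 1 of the 3 voxels and class 1 covers 2, so the inverse frequencies are 3 and 1.5; normalizing them to average 1 gives weights of 4/3 and 2/3, and only the class-0 weight matters because the class-1 voxels contribute (almost) no loss. A small recomputation, as an illustration rather than the library's internal code:

import torch

counts = torch.tensor([1.0, 2.0])             # voxels per class in the target above
inv_freq = counts.sum() / counts              # [3.0, 1.5]
class_weights = inv_freq / inv_freq.mean()    # [4/3, 2/3], averaging to 1
balanced_loss = class_weights[0] * (-torch.log(torch.tensor(0.5)) / 3.0)
print(class_weights, balanced_loss)           # tensor([1.3333, 0.6667]) tensor(0.3081)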
Code example #7
def test_focal_loss_cross_entropy_equivalence(
        use_class_balancing: bool) -> None:
    """
    Focal loss and cross-entropy loss should be equivalent when the gamma parameter is set to zero,
    and this equivalence should hold independently of the class-balancing term.
    """
    power = 1.0 if use_class_balancing else 0.0
    loss_fn_wout_focal_loss = CrossEntropyLoss(class_weight_power=power,
                                               focal_loss_gamma=None)
    loss_fn_w_focal_loss = CrossEntropyLoss(class_weight_power=power,
                                            focal_loss_gamma=0.0)

    class_indices = torch.randint(0, 5, torch.Size([1, 16, 16]))
    target = torch.nn.functional.one_hot(
        class_indices, num_classes=5).float().permute([0, 3, 1, 2])
    logits = torch.rand(torch.Size([1, 5, 16, 16]))

    loss1 = loss_fn_wout_focal_loss(logits, target)
    loss2 = loss_fn_w_focal_loss(logits, target)
    assert (torch.isclose(loss1, loss2))
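Assuming the standard focal-loss form, the equivalence follows directly from the modulation term: with gamma = 0, (1 - p_t) ** gamma equals 1 for every pixel, so the focal loss reduces to the plain -log(p_t) cross-entropy term, and any class-balancing weights scale both losses identically.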
Code example #8
def test_focal_loss_forward_balanced(use_class_balancing: bool,
                                     expected_loss: torch.Tensor) -> None:
    """
    When the logits are the same for both classes, cross entropy yields posterior probabilities of [0.5, 0.5],
    so the loss for that particular pixel equals -log(0.5) (the negative log-likelihood). When the loss terms are
    mean-aggregated across the 3 pixels or across the 2 classes, the result equals -log(0.5)/3 or -log(0.5)/2
    respectively. The other two pixels are correctly predicted, so their loss terms are (close to) zero.
    """
    target = torch.tensor([[[0, 0, 1], [1, 1, 0]]], dtype=torch.float32)
    logits = torch.tensor([[[-1e9, -1e9, 0], [0, 0, 0]]], dtype=torch.float32)

    # Compute loss values for both balanced and unbalanced cases
    loss_fn = CrossEntropyLoss(
        class_weight_power=1.0 if use_class_balancing else 0.0,
        focal_loss_gamma=0.0)
    loss = loss_fn(logits, target)
    assert (torch.isclose(loss, expected_loss))
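The use_class_balancing and expected_loss arguments come from a parametrize decorator that is not shown in this excerpt; based on the docstring, the two parameter pairs would plausibly be the following (a reconstruction, not copied from the test module).

import torch

cases = [(False, -torch.log(torch.tensor(0.5)) / 3.0),   # mean over the 3 pixels
         (True, -torch.log(torch.tensor(0.5)) / 2.0)]    # mean over the 2 classes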
Code example #9
def test_get_class_weights() -> None:
    target = torch.tensor([[2, 2, 1, 2, 2], [3, 3, 3, 3, 3]], dtype=torch.long)
    weights = CrossEntropyLoss._get_class_weights(target_labels=target, num_classes=4)
    assert torch.eq(weights, torch.tensor([0.00, 1.00, 0.25, 0.20])).all()
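The expected weights follow from the class counts in the target: classes 0 to 3 occur 0, 1, 4 and 5 times, and each present class gets the reciprocal of its count while an absent class gets weight 0. A small illustration of that relationship (not the library's internal implementation):

import torch

target = torch.tensor([[2, 2, 1, 2, 2], [3, 3, 3, 3, 3]], dtype=torch.long)
counts = torch.bincount(target.flatten(), minlength=4).float()             # [0., 1., 4., 5.]
weights = torch.where(counts > 0, 1.0 / counts, torch.zeros_like(counts))  # reciprocal counts
print(weights)  # tensor([0.0000, 1.0000, 0.2500, 0.2000])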