Ejemplo n.º 1
0
 def test_ill_opts(self):
     """FocalLoss should raise ValueError for an unknown or ``None`` reduction."""
     logits = torch.ones((1, 2, 3))
     labels = torch.ones((1, 1, 3))
     for bad_reduction in ("unknown", None):
         # an empty regex matches any message, so only the exception type is checked
         with self.assertRaisesRegex(ValueError, ""):
             FocalLoss(reduction=bad_reduction)(logits, labels)
Ejemplo n.º 2
0
    def test_bin_seg_3d(self):
        """A near-perfect 3D binary prediction should give a focal loss close to 0.

        Checked both with an integer label map (``to_onehot_y=True``) and with an
        already one-hot target (``to_onehot_y=False``).
        """
        num_classes = 2  # labels 0, 1
        # define a 3D example: three identical 4x4 slices with a 2x2 foreground square
        target = torch.tensor([
            # slice 0
            [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
            # slice 1
            [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
            # slice 2
            [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
        ])
        # add another dimension corresponding to the batch (batch size = 1 here)
        target = target.unsqueeze(0)  # shape (1, H, W, D)
        # one-hot encode and move the class axis to position 1: (1, C, H, W, D)
        target_one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3)
        # logits of +500 for the true class and -500 for the other one
        pred_very_good = 1000 * target_one_hot.float() - 500.0

        # initialize the focal losses (integer-label and one-hot-label variants)
        loss = FocalLoss(to_onehot_y=True)
        loss_onehot = FocalLoss(to_onehot_y=False)

        # focal loss for pred_very_good should be close to 0
        target = target.unsqueeze(1)  # shape (1, 1, H, W, D)
        focal_loss_good = float(loss(pred_very_good, target).cpu())
        self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

        focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu())
        self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
Ejemplo n.º 3
0
 def test_ill_shape(self):
     """FocalLoss should raise ValueError for the incompatible target shapes below."""
     logits = torch.ones((1, 2, 3))
     # both a (1, 3) target and a (1, 2, 3) target are expected to be rejected
     for bad_target in (torch.ones((1, 3)), torch.ones((1, 2, 3))):
         with self.assertRaisesRegex(ValueError, ""):
             FocalLoss(reduction="mean")(logits, bad_target)
Ejemplo n.º 4
0
 def test_ill_class_weight(self):
     """FocalLoss should raise ValueError for ill-specified class weights."""
     logits = torch.ones((1, 4, 3, 3))
     labels = torch.ones((1, 4, 3, 3))
     invalid_configs = (
         # 3 weights for the 4 input channels while the background channel is kept
         (True, (1.0, 1.0, 2.0)),
         # 4 weights with background excluded (presumably 3 channels remain)
         (False, (1.0, 1.0, 1.0, 1.0)),
         # a negative weight
         (False, (1.0, 1.0, -1.0)),
     )
     for include_bg, weights in invalid_configs:
         with self.assertRaisesRegex(ValueError, ""):
             FocalLoss(include_background=include_bg, weight=weights)(logits, labels)
Ejemplo n.º 5
0
    def test_bin_seg_2d(self):
        """A near-perfect 2D binary prediction should give a focal loss close to 0."""
        # define a 2D example: a 4x4 image with a 2x2 foreground square
        target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
        # add another dimension corresponding to the batch (batch size = 1 here)
        target = target.unsqueeze(0)  # shape (1, H, W)
        # logits of 1000 for the true class and 0 otherwise, class axis moved to dim 1
        pred_very_good = 1000 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float()

        # initialize the focal loss with its default settings
        loss = FocalLoss()

        # focal loss for pred_very_good should be close to 0
        target = target.unsqueeze(1)  # shape (1, 1, H, W)
        # call the module itself (not .forward) so any registered hooks run
        focal_loss_good = float(loss(pred_very_good, target).cpu())
        self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
Ejemplo n.º 6
0
    def test_consistency_with_cross_entropy_2d_no_reduction(self):
        """For gamma=0 the focal loss reduces to the cross entropy loss.

        Compares element-wise (``reduction="none"``) against ``BCEWithLogitsLoss``
        over 100 random trials and checks the worst element-wise error is ~0.
        """
        import numpy as np

        focal_loss = FocalLoss(to_onehot_y=False, gamma=0.0, reduction="none", weight=1.0)
        ce = nn.BCEWithLogitsLoss(reduction="none")
        max_error = 0
        class_num = 10
        batch_size = 128
        for _ in range(100):
            # Create a random tensor of shape (batch_size, class_num, 8, 4)
            x = torch.rand(batch_size, class_num, 8, 4, requires_grad=True)
            # Create a random binary target of the same shape
            labels = torch.randint(low=0, high=2, size=(batch_size, class_num, 8, 4)).float()
            if torch.cuda.is_available():
                x = x.cuda()
                labels = labels.cuda()
            a = focal_loss(x, labels).cpu().detach().numpy()
            b = ce(x, labels).cpu().detach().numpy()
            # track the element-wise worst-case discrepancy over all trials
            max_error = np.maximum(np.abs(a - b), max_error)

        # a unittest assertion, unlike a bare ``assert``, is not stripped under ``python -O``
        self.assertTrue(np.allclose(max_error, 0))
Ejemplo n.º 7
0
 def test_result_onehot_target_include_bg(self):
     """DiceFocalLoss should equal DiceLoss + lambda_focal * FocalLoss.

     Uses a one-hot target with the background channel included, across all
     reductions, several focal weights and several ``lambda_focal`` values.
     """
     size = [3, 3, 5, 5]
     label = torch.randint(low=0, high=2, size=size)
     pred = torch.randn(size)
     for reduction in ("sum", "mean", "none"):
         common_params = dict(include_background=True, to_onehot_y=False, reduction=reduction)
         for focal_weight in (None, torch.tensor([1.0, 1.0, 2.0]), (3, 2.0, 1)):
             for lambda_focal in (0.5, 1.0, 1.5):
                 combined = DiceFocalLoss(
                     focal_weight=focal_weight, gamma=1.0, lambda_focal=lambda_focal, **common_params
                 )
                 dice_only = DiceLoss(**common_params)
                 focal_only = FocalLoss(weight=focal_weight, gamma=1.0, **common_params)
                 result = combined(pred, label)
                 # the combined loss is expected to be the weighted sum of its parts
                 expected_val = dice_only(pred, label) + lambda_focal * focal_only(pred, label)
                 np.testing.assert_allclose(result, expected_val)
Ejemplo n.º 8
0
 def test_consistency_with_cross_entropy_2d(self):
     """For gamma=0 the focal loss reduces to the cross entropy loss"""
     focal_loss = FocalLoss(to_onehot_y=False, gamma=0.0, reduction="mean", weight=1.0)
     ce = nn.BCEWithLogitsLoss(reduction="mean")
     class_num, batch_size = 10, 128
     shape = (batch_size, class_num, 8, 4)
     max_error = 0
     for _ in range(100):
         # random logits and a random binary target of the same shape
         logits = torch.rand(*shape, requires_grad=True)
         targets = torch.randint(low=0, high=2, size=shape).float()
         if torch.cuda.is_available():
             logits, targets = logits.cuda(), targets.cuda()
         focal_val = float(focal_loss(logits, targets).cpu().detach())
         ce_val = float(ce(logits, targets).cpu().detach())
         # keep the largest absolute discrepancy seen so far
         max_error = max(max_error, abs(focal_val - ce_val))
     self.assertAlmostEqual(max_error, 0.0, places=3)
 def __init__(self, focal):
     """Build the Dice, focal and cross-entropy loss terms.

     Args:
         focal: stored as ``self.use_focal``; presumably selects focal loss over
             cross-entropy in ``forward`` (not shown here).
     """
     super(Loss, self).__init__()
     dice_options = dict(include_background=False, softmax=True, to_onehot_y=True, batch=True)
     self.dice = DiceLoss(**dice_options)
     # NOTE(review): unlike ``self.dice``, the focal term is built without
     # ``to_onehot_y=True`` — confirm the targets it receives are compatible.
     self.focal = FocalLoss(gamma=2.0)
     self.cross_entropy = nn.CrossEntropyLoss()
     self.use_focal = focal
Ejemplo n.º 10
0
    def test_multi_class_seg_2d(self):
        """A near-perfect multi-class 2D prediction should give a focal loss close to 0.

        Checked both with an integer label map (``to_onehot_y=True``) and with an
        already one-hot target (``to_onehot_y=False``).
        """
        num_classes = 6  # labels 0 to 5
        # define a 2D example; label 5 does not occur, so its channel is all zeros
        target = torch.tensor([[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]])
        # add another dimension corresponding to the batch (batch size = 1 here)
        target = target.unsqueeze(0)  # shape (1, H, W)
        # one-hot encode and move the class axis to position 1: (1, C, H, W)
        target_one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2)
        # logits of 1000 for the true class and 0 otherwise
        pred_very_good = 1000 * target_one_hot.float()

        # initialize the focal losses (integer-label and one-hot-label variants)
        loss = FocalLoss(to_onehot_y=True)
        loss_onehot = FocalLoss(to_onehot_y=False)

        # focal loss for pred_very_good should be close to 0
        target = target.unsqueeze(1)  # shape (1, 1, H, W)

        focal_loss_good = float(loss(pred_very_good, target).cpu())
        self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

        focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu())
        self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
Ejemplo n.º 11
0
 def test_consistency_with_cross_entropy_classification(self):
     """For gamma=0 the focal loss should match ``nn.CrossEntropyLoss`` on class-index targets."""
     focal_loss = FocalLoss(gamma=0.0, reduction="mean")
     ce = nn.CrossEntropyLoss(reduction="mean")
     max_error = 0
     class_num = 10
     batch_size = 128
     for _ in range(100):
         # Create a random scores tensor of shape (batch_size, class_num)
         x = torch.rand(batch_size, class_num, requires_grad=True)
         # Random class indices of shape (batch_size, 1); randint already returns int64
         labels = torch.randint(low=0, high=class_num, size=(batch_size, 1))
         if torch.cuda.is_available():
             x = x.cuda()
             labels = labels.cuda()
         # call the modules themselves (not .forward) so registered hooks run
         a = float(focal_loss(x, labels).cpu().detach())
         # CrossEntropyLoss takes a 1-D tensor of class indices
         b = float(ce(x, labels[:, 0]).cpu().detach())
         max_error = max(max_error, abs(a - b))
     self.assertAlmostEqual(max_error, 0.0, places=3)
Ejemplo n.º 12
0
    def test_foreground(self):
        """Check FocalLoss on a fixed example, with and without the background channel."""
        bg_channel = torch.ones(1, 1, 5, 5)
        fg_channel = torch.zeros(1, 1, 5, 5)
        # target and prediction start identical (every pixel background) ...
        target = torch.cat((bg_channel, fg_channel), dim=1)
        pred = torch.cat((bg_channel, fg_channel), dim=1)
        # ... then the centre pixel of the target is relabelled as foreground
        target[:, 0, 2, 2] = 0
        target[:, 1, 2, 2] = 1

        with_bg = FocalLoss(to_onehot_y=False, include_background=True)(pred, target)
        without_bg = FocalLoss(to_onehot_y=False, include_background=False)(pred, target)
        # reference values pinned by this test
        self.assertAlmostEqual(float(with_bg.cpu()), 0.1116, places=3)
        self.assertAlmostEqual(float(without_bg.cpu()), 0.1733, places=3)
Ejemplo n.º 13
0
 def test_consistency_with_cross_entropy_classification_01(self):
     """For gamma=0.1 the focal loss should measurably differ from BCE-with-logits."""
     focal_loss = FocalLoss(to_onehot_y=True, gamma=0.1, reduction="mean")
     ce = nn.BCEWithLogitsLoss(reduction="mean")
     class_num, batch_size = 10, 128
     max_error = 0
     for _ in range(100):
         # random scores (batch_size, class_num) and class indices (batch_size, 1)
         scores = torch.rand(batch_size, class_num, requires_grad=True)
         classes = torch.randint(low=0, high=class_num, size=(batch_size, 1)).long()
         if torch.cuda.is_available():
             scores, classes = scores.cuda(), classes.cuda()
         focal_val = float(focal_loss(scores, classes).cpu().detach())
         bce_val = float(ce(scores, one_hot(classes, num_classes=class_num)).cpu().detach())
         max_error = max(max_error, abs(focal_val - bce_val))
     self.assertNotAlmostEqual(max_error, 0.0, places=3)
 def test_result_no_onehot_no_bg(self):
     """GeneralizedDiceFocalLoss should equal GeneralizedDiceLoss + lambda_focal * FocalLoss.

     Uses an integer label map (``to_onehot_y=True``) with background excluded.
     """
     size = [3, 3, 5, 5]
     # random binary volume collapsed to a single-channel label map via argmax
     label = torch.argmax(torch.randint(low=0, high=2, size=size), dim=1, keepdim=True)
     pred = torch.randn(size)
     for reduction in ("sum", "mean", "none"):
         common_params = dict(include_background=False, to_onehot_y=True, reduction=reduction)
         for focal_weight in (2.0, torch.tensor([1.0, 2.0]), (2.0, 1)):
             for lambda_focal in (0.5, 1.0, 1.5):
                 combined_loss = GeneralizedDiceFocalLoss(
                     focal_weight=focal_weight, lambda_focal=lambda_focal, **common_params
                 )
                 gdice_loss = GeneralizedDiceLoss(**common_params)
                 focal_loss = FocalLoss(weight=focal_weight, **common_params)
                 result = combined_loss(pred, label)
                 expected_val = gdice_loss(pred, label) + lambda_focal * focal_loss(pred, label)
                 np.testing.assert_allclose(result, expected_val)
Ejemplo n.º 15
0
 def test_script(self):
     """Run a default FocalLoss through the project's ``test_script_save`` helper.

     Presumably a TorchScript save/load round trip — see the helper to confirm.
     """
     focal = FocalLoss()
     # the same all-ones tensor serves as both prediction and target
     sample = torch.ones(2, 2, 8, 8)
     test_script_save(focal, sample, sample)
Ejemplo n.º 16
0
 def __init__(self, focal):
     """Build default Dice, cross-entropy and focal (gamma=2) loss terms.

     Args:
         focal: stored as ``self.use_focal``; presumably chooses between the
             focal and cross-entropy terms in ``forward`` (not shown here).
     """
     super(Loss, self).__init__()
     self.dice = DiceLoss()
     self.cross_entropy = nn.CrossEntropyLoss()
     focal_gamma = 2.0
     self.focal = FocalLoss(gamma=focal_gamma)
     self.use_focal = focal
Ejemplo n.º 17
0
 def __init__(self, focal):
     """Build a sigmoid Dice term plus a second term stored as ``self.ce``.

     Args:
         focal: if truthy, ``self.ce`` is a FocalLoss (gamma=2, no one-hot of the
             target); otherwise it is ``nn.BCEWithLogitsLoss``.
     """
     super(LossBraTS, self).__init__()
     self.dice = DiceLoss(sigmoid=True, batch=True)
     if focal:
         self.ce = FocalLoss(gamma=2.0, to_onehot_y=False)
     else:
         self.ce = nn.BCEWithLogitsLoss()
Ejemplo n.º 18
0
 def __init__(self, loss):
     """Hold a focal term and two Dice variants (with / without background).

     Args:
         loss: stored as ``self.loss``; its meaning is decided by ``forward``
             (not shown here).
     """
     super().__init__()
     self.loss = loss
     self.focal = FocalLoss(gamma=2.0)
     # both Dice variants share these options and differ only in include_background
     shared = dict(softmax=True, to_onehot_y=True, batch=True)
     self.dice_bg = DiceLoss(include_background=True, **shared)
     self.dice_nbg = DiceLoss(include_background=False, **shared)
Ejemplo n.º 19
0
    def test_convergence(self):
        """
        The goal of this test is to assess if the gradient of the loss function
        is correct by testing if we can train a one layer neural network
        to segment one image.
        We verify that the loss is decreasing in almost all SGD steps.
        """
        learning_rate = 0.001
        max_iter = 20

        # define a simple 3D example: three identical 4x4 slices with a 2x2 foreground square
        target_seg = torch.tensor([
            # slice 0
            [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
            # slice 1
            [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
            # slice 2
            [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
        ])
        target_seg = torch.unsqueeze(target_seg, dim=0)  # shape (1, 3, 4, 4)
        # intensity 27 for background voxels and 39 for foreground voxels
        image = 12 * target_seg + 27
        image = image.float()
        num_classes = 2
        num_voxels = 3 * 4 * 4

        # define a one layer model mapping every voxel to num_classes logits
        class OnelayerNet(nn.Module):
            def __init__(self):
                super(OnelayerNet, self).__init__()
                self.layer = nn.Linear(num_voxels, num_voxels * num_classes)

            def forward(self, x):
                x = x.view(-1, num_voxels)
                x = self.layer(x)
                x = x.view(-1, num_classes, 3, 4, 4)
                return x

        net = OnelayerNet()
        loss = FocalLoss()
        optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)

        loss_history = []
        # train the network
        for _ in range(max_iter):
            optimizer.zero_grad()
            output = net(image)
            loss_val = loss(output, target_seg)
            loss_val.backward()
            optimizer.step()
            loss_history.append(loss_val.item())

        # count the number of SGD steps in which the loss decreases
        num_decreasing_steps = sum(prev > nxt for prev, nxt in zip(loss_history, loss_history[1:]))
        decreasing_steps_ratio = num_decreasing_steps / (len(loss_history) - 1)

        # verify that the loss is decreasing for sufficiently many SGD steps
        # (assertGreater reports both values on failure, unlike assertTrue(a > b))
        self.assertGreater(decreasing_steps_ratio, 0.9)
Ejemplo n.º 20
0
def _han_elastic_transform(crop_size):
    """Build the random elastic/affine augmentation shared by both training pipelines.

    Args:
        crop_size: list of ints used as the output ``spatial_size``; each entry
            also sets the translation range for its axis (5% of the entry).
    """
    # NOTE(review): in full-image mode crop_size holds -1, which makes the
    # translate range for that axis -0.05 — confirm this is intended.
    return Rand3DElasticd(
        keys=("image", "label"),
        spatial_size=crop_size,
        sigma_range=(10, 50),  # 30
        magnitude_range=(600, 1200),  # 1000
        prob=0.8,
        rotate_range=(np.pi / 12, np.pi / 12, np.pi / 12),
        shear_range=(np.pi / 18, np.pi / 18, np.pi / 18),
        translate_range=tuple(sz * 0.05 for sz in crop_size),
        scale_range=(0.2, 0.2, 0.2),
        mode=("bilinear", "nearest"),
        padding_mode=("border", "zeros"),
    )


def _han_validate(model, val_dataloader):
    """Return the mean Dice of ``model`` per foreground class over ``val_dataloader``.

    Per-case Dice values that are NaN (as returned by ``compute_meandice``) are
    skipped for that class. A class that is NaN in every case reports
    ``float("nan")`` instead of raising ZeroDivisionError.
    """
    model.eval()
    dice_sum = [0.0] * (N_CLASSES - 1)
    n_valid = [0] * (N_CLASSES - 1)
    # inference only: run the forward pass without building an autograd graph
    with torch.no_grad():
        for data_dict in val_dataloader:
            x_val = data_dict["image"].cuda()
            y_val = data_dict["label"]
            o = model(x_val).to(0, non_blocking=True)
            dice = compute_meandice(o, y_val.to(o), mutually_exclusive=True, include_background=False)
            # dice[0]: scores of the first (and, with batch_size=1, only) case
            for c, d in enumerate(dice[0][: N_CLASSES - 1]):
                if not torch.isnan(d):
                    dice_sum[c] += d.item()
                    n_valid[c] += 1
    return [s / n if n else float("nan") for s, n in zip(dice_sum, n_valid)]


def train(n_feat,
          crop_size,
          bs,
          ep,
          optimizer="rmsprop",
          lr=5e-4,
          pretrain=None):
    """Train a GPipe-pipelined ``UNetPipe`` segmentation model ("HaN" checkpoints).

    Args:
        n_feat: base feature count passed to ``UNetPipe``.
        crop_size: comma-separated crop size string, e.g. ``"96,96,96"``; any
            ``-1`` entry selects full-image training.
        bs: DataLoader batch size in full-image mode; number of windows sampled
            per image (with a DataLoader batch size of 1) in crop mode.
        ep: number of training epochs.
        optimizer: ``"rmsprop"`` or ``"momentum"`` (case-insensitive).
        lr: learning rate.
        pretrain: optional checkpoint path whose ``"weight"`` entry is loaded;
            checkpoint keys absent from the model are ignored and model keys
            absent from the checkpoint keep their initial values.

    Raises:
        ValueError: if ``optimizer`` is not one of the supported names.

    Every 10 epochs the model is validated. For each foreground class ``c``,
    a checkpoint is saved to ``f"{model_name}{c}"`` whenever that class's
    mean validation Dice improves.
    """
    model_name = f"./HaN_{n_feat}_{bs}_{ep}_{crop_size}_{lr}_"
    print(f"save the best model as '{model_name}' during training.")

    crop_size = [int(cz) for cz in crop_size.split(",")]
    print(f"input image crop_size: {crop_size}")

    # starting training set loader
    train_images = ImageLabelDataset(path=TRAIN_PATH, n_class=N_CLASSES)
    if -1 in crop_size:  # using full image
        train_transform = Compose([
            AddChannelDict(keys="image"),
            _han_elastic_transform(crop_size),
        ])
        train_dataset = Dataset(train_images, transform=train_transform)
        # when bs > 1, the loader assumes that the full image sizes are the same across the dataset
        train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                       num_workers=4,
                                                       batch_size=bs,
                                                       shuffle=True)
    else:
        # draw balanced foreground/background window samples according to the ground truth label
        train_transform = Compose([
            AddChannelDict(keys="image"),
            SpatialPadd(keys=("image", "label"), spatial_size=crop_size),  # ensure image size >= crop_size
            RandCropByPosNegLabeld(keys=("image", "label"),
                                   label_key="label",
                                   spatial_size=crop_size,
                                   num_samples=bs),
            _han_elastic_transform(crop_size),
        ])
        train_dataset = Dataset(train_images, transform=train_transform)  # each dataset item is a list of windows
        train_dataloader = torch.utils.data.DataLoader(  # stack each dataset item into a single tensor
            train_dataset,
            num_workers=4,
            batch_size=1,
            shuffle=True,
            collate_fn=list_data_collate)
    first_sample = first(train_dataloader)
    print(first_sample["image"].shape)

    # starting validation set loader
    val_transform = Compose([AddChannelDict(keys="image")])
    val_dataset = Dataset(ImageLabelDataset(VAL_PATH, n_class=N_CLASSES), transform=val_transform)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, num_workers=1, batch_size=1)
    print(val_dataset[0]["image"].shape)
    print(f"training images: {len(train_dataloader)}, validation images: {len(val_dataloader)}")

    model = UNetPipe(spatial_dims=3, in_channels=1, out_channels=N_CLASSES, n_feat=n_feat)
    model = flatten_sequential(model)
    # per-class loss weights; NOTE(review): 10 entries are hard-coded — must agree
    # with N_CLASSES / the channel dimension of the per-class losses; confirm.
    lossweight = torch.from_numpy(
        np.array([2.22, 1.31, 1.99, 1.13, 1.93, 1.93, 1.0, 1.0, 1.90, 1.98], np.float32))

    if optimizer.lower() == "rmsprop":
        optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)  # lr = 5e-4
    elif optimizer.lower() == "momentum":
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)  # lr = 1e-4 for finetuning
    else:
        raise ValueError(
            f"Unknown optimizer type {optimizer}. (options are 'rmsprop' and 'momentum')."
        )

    # config GPipe: partition the model across all visible GPUs, balanced by
    # memory measured on the first training sample
    x = first_sample["image"].float().cuda()
    partitions = torch.cuda.device_count()
    print(f"partition: {partitions}, input: {x.size()}")
    balance = balance_by_size(partitions, model, x)
    model = GPipe(model, balance, chunks=4, checkpoint="always")

    # config loss functions
    dice_loss_func = DiceLoss(softmax=True, reduction="none")
    # use the same pipeline and loss in
    # AnatomyNet: Deep learning for fast and fully automated whole‐volume segmentation of head and neck anatomy,
    # Medical Physics, 2018.
    focal_loss_func = FocalLoss(reduction="none")

    if pretrain:
        print(f"loading from {pretrain}.")
        pretrained_dict = torch.load(pretrain)["weight"]
        model_dict = model.state_dict()
        # keep only checkpoint entries the current model knows about ...
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        # ... overlay them on the model's own state, and load the merged dict so
        # keys missing from the checkpoint don't break strict loading
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

    b_time = time.time()
    best_val_dice = [0] * (N_CLASSES - 1)  # best mean Dice per foreground class (higher is better)
    for epoch in range(ep):
        model.train()
        trainloss = 0
        for b_idx, data_dict in enumerate(train_dataloader):
            x_train = data_dict["image"].cuda()
            y_train = data_dict["label"].cuda().float()
            # presumably masks out cases lacking complete annotations — confirm shape
            flagvec = data_dict["with_complete_groundtruth"]

            optimizer.zero_grad()
            o = model(x_train).to(0, non_blocking=True).float()

            # weighted Dice plus half-weighted focal loss, both masked by flagvec
            loss = (dice_loss_func(o, y_train.to(o)) * flagvec.to(o) * lossweight.to(o)).mean()
            loss += 0.5 * (focal_loss_func(o, y_train.to(o)) * flagvec.to(o) * lossweight.to(o)).mean()
            loss.backward()
            optimizer.step()
            trainloss += loss.item()

            if b_idx % 20 == 0:
                print(f"Train Epoch: {epoch} [{b_idx}/{len(train_dataloader)}] \tLoss: {loss.item()}")
        print(f"epoch {epoch} TRAIN loss {trainloss / len(train_dataloader)}")

        if epoch % 10 == 0:
            val_dice = _han_validate(model, val_dataloader)
            print("validation scores " + ", ".join(f"{v:.4f}" for v in val_dice))
            for c in range(1, N_CLASSES):
                if best_val_dice[c - 1] < val_dice[c - 1]:
                    best_val_dice[c - 1] = val_dice[c - 1]
                    state = {
                        "epoch": epoch,
                        "weight": model.state_dict(),
                        "score_" + str(c): best_val_dice[c - 1],
                    }
                    torch.save(state, f"{model_name}" + str(c))
            print("best validation scores " + ", ".join(f"{v:.4f}" for v in best_val_dice))

    print("total time", time.time() - b_time)