def test_ill_opts(self):
    """FocalLoss is expected to raise ValueError for unsupported ``reduction`` values."""
    pred = torch.ones((1, 2, 3))
    label = torch.ones((1, 1, 3))
    # an unrecognised string first, then ``None``; each must raise ValueError
    for bad_reduction in ("unknown", None):
        with self.assertRaisesRegex(ValueError, ""):
            FocalLoss(reduction=bad_reduction)(pred, label)
def test_bin_seg_3d(self):
    """A near-perfect binary 3D prediction must give a focal loss close to zero.

    The check is run twice: once with integer labels (``to_onehot_y=True``)
    and once with an explicit one-hot target (``to_onehot_y=False``).
    """
    num_classes = 2  # labels 0, 1

    # one 4x4 plane with a 2x2 foreground square, stacked three times
    plane = [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]
    target = torch.tensor([plane, plane, plane])

    # prepend the batch dimension (batch size = 1 here)
    target = target.unsqueeze(0)  # shape (1, 3, 4, 4)

    # channel-first one-hot encoding of the labels
    target_one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3)

    # confident logits: +500 for the true class, -500 for the other one
    pred_very_good = 1000 * target_one_hot.float() - 500.0

    # initialize the focal losses for both target encodings
    loss = FocalLoss(to_onehot_y=True)
    loss_onehot = FocalLoss(to_onehot_y=False)

    # integer labels get an explicit channel axis
    target = target.unsqueeze(1)  # shape (1, 1, 3, 4, 4)

    # focal loss for pred_very_good should be close to 0
    loss_from_labels = float(loss(pred_very_good, target).cpu())
    self.assertAlmostEqual(loss_from_labels, 0.0, places=3)

    loss_from_one_hot = float(loss_onehot(pred_very_good, target_one_hot).cpu())
    self.assertAlmostEqual(loss_from_one_hot, 0.0, places=3)
def test_ill_shape(self):
    """FocalLoss is expected to raise ValueError for the target shapes listed below."""
    pred = torch.ones((1, 2, 3))
    # target shapes the test expects to be rejected for a (1, 2, 3) input
    for rejected_shape in ((1, 3), (1, 2, 3)):
        label = torch.ones(rejected_shape)
        with self.assertRaisesRegex(ValueError, ""):
            FocalLoss(reduction="mean")(pred, label)
def test_ill_class_weight(self):
    """FocalLoss is expected to raise ValueError for each weight configuration below."""
    pred = torch.ones((1, 4, 3, 3))
    label = torch.ones((1, 4, 3, 3))
    # (include_background, weight) pairs that must be rejected for 4-channel data
    rejected_configs = (
        (True, (1.0, 1.0, 2.0)),  # three weights, background included
        (False, (1.0, 1.0, 1.0, 1.0)),  # four weights, background excluded
        (False, (1.0, 1.0, -1.0)),  # contains a negative weight
    )
    for include_bg, class_weights in rejected_configs:
        with self.assertRaisesRegex(ValueError, ""):
            FocalLoss(include_background=include_bg, weight=class_weights)(pred, label)
def test_bin_seg_2d(self):
    """A near-perfect binary 2D prediction must give a focal loss close to zero."""
    # 4x4 label map with a 2x2 foreground square
    label_rows = [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]
    target = torch.tensor(label_rows)

    # prepend the batch dimension (batch size = 1 here)
    target = target.unsqueeze(0)  # shape (1, 4, 4)

    # channel-first one-hot labels scaled into very confident logits
    one_hot_labels = F.one_hot(target, num_classes=2).permute(0, 3, 1, 2)
    pred_very_good = 1000 * one_hot_labels.float()

    # initialize the focal loss with its default options
    loss = FocalLoss()

    # the target gets an explicit channel axis
    target = target.unsqueeze(1)  # shape (1, 1, 4, 4)

    # focal loss for pred_very_good should be close to 0
    focal_loss_good = float(loss.forward(pred_very_good, target).cpu())
    self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
def test_consistency_with_cross_entropy_2d_no_reduction(self):
    """For gamma=0 the focal loss reduces to the cross entropy loss.

    Compares the element-wise (``reduction="none"``) outputs of FocalLoss and
    ``nn.BCEWithLogitsLoss`` on 100 random batches and requires the largest
    element-wise absolute difference to be numerically zero.
    """
    import numpy as np

    focal_loss = FocalLoss(to_onehot_y=False, gamma=0.0, reduction="none", weight=1.0)
    ce = nn.BCEWithLogitsLoss(reduction="none")
    max_error = 0
    class_num = 10
    batch_size = 128
    for _ in range(100):
        # Create a random tensor of shape (batch_size, class_num, 8, 4)
        x = torch.rand(batch_size, class_num, 8, 4, requires_grad=True)
        # Create a random batch of binary targets with the same shape
        labels = torch.randint(low=0, high=2, size=(batch_size, class_num, 8, 4)).float()
        if torch.cuda.is_available():
            x = x.cuda()
            labels = labels.cuda()
        output0 = focal_loss(x, labels)
        output1 = ce(x, labels)
        a = output0.cpu().detach().numpy()
        b = output1.cpu().detach().numpy()
        # running element-wise maximum of the absolute difference
        max_error = np.maximum(np.abs(a - b), max_error)

    # A unittest assertion is used instead of a bare ``assert`` so the check
    # is not stripped when Python runs with ``-O``.
    self.assertTrue(
        np.allclose(max_error, 0),
        msg=f"largest element-wise difference: {np.max(max_error)}",
    )
def test_result_onehot_target_include_bg(self):
    """DiceFocalLoss must equal ``DiceLoss + lambda_focal * FocalLoss``.

    Checked with background included and the target passed without one-hot
    conversion, for every combination of reduction, focal weight and
    lambda_focal listed below.
    """
    size = [3, 3, 5, 5]
    label = torch.randint(low=0, high=2, size=size)
    pred = torch.randn(size)
    lambda_values = [0.5, 1.0, 1.5]
    for reduction in ["sum", "mean", "none"]:
        common_params = dict(include_background=True, to_onehot_y=False, reduction=reduction)
        weight_options = [None, torch.tensor([1.0, 1.0, 2.0]), (3, 2.0, 1)]
        # every (weight, lambda) pair, weights varying slowest
        combos = [(w, lam) for w in weight_options for lam in lambda_values]
        for focal_weight, lambda_focal in combos:
            dice_focal = DiceFocalLoss(
                focal_weight=focal_weight,
                gamma=1.0,
                lambda_focal=lambda_focal,
                **common_params,
            )
            dice = DiceLoss(**common_params)
            focal = FocalLoss(weight=focal_weight, gamma=1.0, **common_params)

            result = dice_focal(pred, label)
            dice_term = dice(pred, label)
            focal_term = focal(pred, label)
            expected_val = dice_term + lambda_focal * focal_term
            np.testing.assert_allclose(result, expected_val)
def test_consistency_with_cross_entropy_2d(self):
    """For gamma=0 the focal loss reduces to the cross entropy loss"""
    focal_loss = FocalLoss(to_onehot_y=False, gamma=0.0, reduction="mean", weight=1.0)
    ce = nn.BCEWithLogitsLoss(reduction="mean")
    class_num = 10
    batch_size = 128
    sample_shape = (batch_size, class_num, 8, 4)
    max_error = 0
    for _ in range(100):
        # random scores of shape (batch_size, class_num, 8, 4)
        x = torch.rand(*sample_shape, requires_grad=True)
        # random binary targets of the same shape
        labels = torch.randint(low=0, high=2, size=sample_shape).float()
        if torch.cuda.is_available():
            x, labels = x.cuda(), labels.cuda()
        focal_value = float(focal_loss(x, labels).cpu().detach())
        ce_value = float(ce(x, labels).cpu().detach())
        # keep the largest discrepancy seen so far
        max_error = max(max_error, abs(focal_value - ce_value))
    self.assertAlmostEqual(max_error, 0.0, places=3)
def __init__(self, focal):
    """Create the dice, focal and cross-entropy losses.

    Args:
        focal: stored as ``self.use_focal``; how it is consumed is decided
            outside this method.
    """
    super().__init__()
    self.use_focal = focal
    # dice configuration: no background, softmax, one-hot target, batch=True
    dice_options = dict(include_background=False, softmax=True, to_onehot_y=True, batch=True)
    self.dice = DiceLoss(**dice_options)
    self.focal = FocalLoss(gamma=2.0)
    self.cross_entropy = nn.CrossEntropyLoss()
def test_multi_class_seg_2d(self):
    """A near-perfect multi-class 2D prediction must give a focal loss close to zero.

    The check is run with integer labels (``to_onehot_y=True``) and with an
    explicit one-hot target (``to_onehot_y=False``).
    """
    num_classes = 6  # labels 0 to 5

    # 4x4 label map using classes 0-4
    label_rows = [[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]]
    target = torch.tensor(label_rows)

    # prepend the batch dimension (batch size = 1 here)
    target = target.unsqueeze(0)  # shape (1, 4, 4)

    # channel-first one-hot labels, also used as the dense target
    target_one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2)
    pred_very_good = 1000 * target_one_hot.float()

    # initialize the focal losses for both target encodings
    loss = FocalLoss(to_onehot_y=True)
    loss_onehot = FocalLoss(to_onehot_y=False)

    # integer labels get an explicit channel axis
    target = target.unsqueeze(1)  # shape (1, 1, 4, 4)

    # focal loss for pred_very_good should be close to 0
    loss_from_labels = float(loss(pred_very_good, target).cpu())
    self.assertAlmostEqual(loss_from_labels, 0.0, places=3)

    loss_from_one_hot = float(loss_onehot(pred_very_good, target_one_hot).cpu())
    self.assertAlmostEqual(loss_from_one_hot, 0.0, places=3)
def test_consistency_with_cross_entropy_classification(self):
    """For gamma=0 the focal loss is expected to match ``nn.CrossEntropyLoss``."""
    focal_loss = FocalLoss(gamma=0.0, reduction="mean")
    ce = nn.CrossEntropyLoss(reduction="mean")
    class_num = 10
    batch_size = 128
    max_error = 0
    for _ in range(100):
        # random scores of shape (batch_size, class_num)
        x = torch.rand(batch_size, class_num, requires_grad=True)
        # random class indices of shape (batch_size, 1)
        labels = torch.randint(low=0, high=class_num, size=(batch_size, 1)).long()
        if torch.cuda.is_available():
            x, labels = x.cuda(), labels.cuda()
        # CrossEntropyLoss is given the labels without the trailing axis
        focal_value = float(focal_loss.forward(x, labels).cpu().detach())
        ce_value = float(ce.forward(x, labels[:, 0]).cpu().detach())
        # keep the largest discrepancy seen so far
        max_error = max(max_error, abs(focal_value - ce_value))
    self.assertAlmostEqual(max_error, 0.0, places=3)
def test_foreground(self):
    """Compare the focal loss with and without the background channel.

    Prediction and target are identical two-channel maps except for the
    centre pixel, whose target class is flipped.
    """
    ones = torch.ones(1, 1, 5, 5)
    zeros = torch.zeros(1, 1, 5, 5)
    pred = torch.cat((ones, zeros), dim=1)
    target = torch.cat((ones, zeros), dim=1)
    # flip the centre pixel of the target from channel 0 to channel 1
    target[:, 0, 2, 2] = 0
    target[:, 1, 2, 2] = 1

    with_background = FocalLoss(to_onehot_y=False, include_background=True)(pred, target)
    without_background = FocalLoss(to_onehot_y=False, include_background=False)(pred, target)

    # reference values, checked to three decimal places
    for loss_value, expected in ((with_background, 0.1116), (without_background, 0.1733)):
        self.assertAlmostEqual(float(loss_value.cpu()), expected, places=3)
def test_consistency_with_cross_entropy_classification_01(self):
    """For gamma=0.1 the focal loss must differ from the cross entropy loss."""
    focal_loss = FocalLoss(to_onehot_y=True, gamma=0.1, reduction="mean")
    ce = nn.BCEWithLogitsLoss(reduction="mean")
    class_num = 10
    batch_size = 128
    max_error = 0
    for _ in range(100):
        # random scores of shape (batch_size, class_num)
        x = torch.rand(batch_size, class_num, requires_grad=True)
        # random class indices of shape (batch_size, 1)
        labels = torch.randint(low=0, high=class_num, size=(batch_size, 1)).long()
        if torch.cuda.is_available():
            x, labels = x.cuda(), labels.cuda()
        focal_value = float(focal_loss(x, labels).cpu().detach())
        # the BCE reference gets the one-hot encoded labels
        ce_value = float(ce(x, one_hot(labels, num_classes=class_num)).cpu().detach())
        # keep the largest discrepancy seen so far
        max_error = max(max_error, abs(focal_value - ce_value))
    self.assertNotAlmostEqual(max_error, 0.0, places=3)
def test_result_no_onehot_no_bg(self):
    """GeneralizedDiceFocalLoss must equal ``GeneralizedDiceLoss + lambda_focal * FocalLoss``.

    Checked with background excluded and integer labels converted to one-hot
    by the losses, for every combination of reduction, focal weight and
    lambda_focal listed below.
    """
    size = [3, 3, 5, 5]
    label = torch.randint(low=0, high=2, size=size)
    # collapse the channel axis into a single index channel
    label = torch.argmax(label, dim=1, keepdim=True)
    pred = torch.randn(size)
    lambda_values = [0.5, 1.0, 1.5]
    for reduction in ["sum", "mean", "none"]:
        common_params = dict(include_background=False, to_onehot_y=True, reduction=reduction)
        weight_options = [2.0, torch.tensor([1.0, 2.0]), (2.0, 1)]
        # every (weight, lambda) pair, weights varying slowest
        combos = [(w, lam) for w in weight_options for lam in lambda_values]
        for focal_weight, lambda_focal in combos:
            generalized_dice_focal = GeneralizedDiceFocalLoss(
                focal_weight=focal_weight,
                lambda_focal=lambda_focal,
                **common_params,
            )
            generalized_dice = GeneralizedDiceLoss(**common_params)
            focal = FocalLoss(weight=focal_weight, **common_params)

            result = generalized_dice_focal(pred, label)
            dice_term = generalized_dice(pred, label)
            focal_term = focal(pred, label)
            expected_val = dice_term + lambda_focal * focal_term
            np.testing.assert_allclose(result, expected_val)
def test_script(self):
    """Pass a default FocalLoss through ``test_script_save``.

    NOTE(review): presumably a TorchScript save/load round trip - confirm
    against the helper's definition.
    """
    # the same all-ones tensor is passed twice after the loss (presumably input and target)
    sample = torch.ones(2, 2, 8, 8)
    test_script_save(FocalLoss(), sample, sample)
def __init__(self, focal):
    """Create the dice, cross-entropy and focal losses.

    Args:
        focal: stored as ``self.use_focal``; how it is consumed is decided
            outside this method.
    """
    super().__init__()
    self.use_focal = focal
    # all three losses are created regardless of ``focal``
    self.dice = DiceLoss()
    self.cross_entropy = nn.CrossEntropyLoss()
    self.focal = FocalLoss(gamma=2.0)
def __init__(self, focal):
    """Create a dice loss plus either a focal or a BCE-with-logits loss.

    Args:
        focal: when truthy ``self.ce`` is a ``FocalLoss(gamma=2.0)`` without
            one-hot conversion, otherwise an ``nn.BCEWithLogitsLoss``.
    """
    super().__init__()
    self.dice = DiceLoss(sigmoid=True, batch=True)
    # ``self.ce`` holds whichever per-element loss was requested
    if focal:
        self.ce = FocalLoss(gamma=2.0, to_onehot_y=False)
    else:
        self.ce = nn.BCEWithLogitsLoss()
def __init__(self, loss):
    """Store the loss selector and create the focal and dice losses.

    Args:
        loss: stored unchanged as ``self.loss``; its meaning is defined by
            the code that reads it.
    """
    super().__init__()
    self.loss = loss
    self.focal = FocalLoss(gamma=2.0)
    # both dice variants share every option except background handling
    shared_dice_options = dict(softmax=True, to_onehot_y=True, batch=True)
    self.dice_bg = DiceLoss(include_background=True, **shared_dice_options)
    self.dice_nbg = DiceLoss(include_background=False, **shared_dice_options)
def test_convergence(self):
    """
    Assess whether the gradient of the loss function is correct by training
    a one layer neural network to segment a single image.
    The loss must decrease in more than 90% of the SGD steps.
    """
    learning_rate = 0.001
    max_iter = 20

    # simple 3d example: one 4x4 plane with a 2x2 foreground square, three slices
    plane = [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]
    target_seg = torch.tensor([plane, plane, plane])
    target_seg = torch.unsqueeze(target_seg, dim=0)
    # the image is an affine function of the segmentation
    image = (12 * target_seg + 27).float()
    num_classes = 2
    num_voxels = 3 * 4 * 4

    # a single linear layer mapping all voxels to per-class scores
    class OnelayerNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(num_voxels, num_voxels * num_classes)

        def forward(self, x):
            flat = x.view(-1, num_voxels)
            scores = self.layer(flat)
            return scores.view(-1, num_classes, 3, 4, 4)

    # network, loss and SGD optimizer
    net = OnelayerNet()
    loss = FocalLoss()
    optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)

    # train the network and record the loss of every step
    loss_history = []
    for _ in range(max_iter):
        optimizer.zero_grad()
        output = net(image)
        loss_val = loss(output, target_seg)
        loss_val.backward()
        optimizer.step()
        loss_history.append(loss_val.item())

    # count the SGD steps in which the loss decreases
    consecutive = zip(loss_history, loss_history[1:])
    num_decreasing_steps = sum(1 for before, after in consecutive if before > after)
    decreasing_steps_ratio = float(num_decreasing_steps) / (len(loss_history) - 1)

    # verify that the loss is decreasing for sufficiently many SGD steps
    self.assertTrue(decreasing_steps_ratio > 0.9)
def train(n_feat, crop_size, bs, ep, optimizer="rmsprop", lr=5e-4, pretrain=None):
    """Train a pipelined UNet with a weighted dice + focal loss and validate periodically.

    The model is partitioned across all visible CUDA devices with GPipe.
    Every 10th epoch the per-class mean dice is computed on the validation
    set, and for each foreground class a checkpoint is saved whenever its
    score improves.

    Args:
        n_feat: feature-count argument forwarded to ``UNetPipe``.
        crop_size: comma-separated integers, e.g. ``"64,64,64"``. If any entry
            is ``-1`` the full image is used instead of cropped windows.
        bs: batch size when using full images, otherwise the number of
            windows sampled per image.
        ep: number of training epochs.
        optimizer: ``"rmsprop"`` or ``"momentum"`` (case-insensitive).
        lr: learning rate of the optimizer.
        pretrain: optional path of a checkpoint whose ``"weight"`` entry holds
            a state dict; entries with keys unknown to the model are ignored.

    Raises:
        ValueError: if ``optimizer`` is not one of the supported names.
    """
    model_name = f"./HaN_{n_feat}_{bs}_{ep}_{crop_size}_{lr}_"
    print(f"save the best model as '{model_name}' during training.")

    crop_size = [int(cz) for cz in crop_size.split(",")]
    print(f"input image crop_size: {crop_size}")

    # starting training set loader
    train_images = ImageLabelDataset(path=TRAIN_PATH, n_class=N_CLASSES)
    if np.any([cz == -1 for cz in crop_size]):  # using full image
        train_transform = Compose([
            AddChannelDict(keys="image"),
            Rand3DElasticd(
                keys=("image", "label"),
                spatial_size=crop_size,
                sigma_range=(10, 50),  # 30
                magnitude_range=(600, 1200),  # 1000
                prob=0.8,
                rotate_range=(np.pi / 12, np.pi / 12, np.pi / 12),
                shear_range=(np.pi / 18, np.pi / 18, np.pi / 18),
                translate_range=tuple(sz * 0.05 for sz in crop_size),
                scale_range=(0.2, 0.2, 0.2),
                mode=("bilinear", "nearest"),
                padding_mode=("border", "zeros"),
            ),
        ])
        train_dataset = Dataset(train_images, transform=train_transform)
        # when bs > 1, the loader assumes that the full image sizes are the same across the dataset
        train_dataloader = torch.utils.data.DataLoader(train_dataset, num_workers=4, batch_size=bs, shuffle=True)
    else:
        # draw balanced foreground/background window samples according to the ground truth label
        train_transform = Compose([
            AddChannelDict(keys="image"),
            SpatialPadd(keys=("image", "label"), spatial_size=crop_size),  # ensure image size >= crop_size
            RandCropByPosNegLabeld(keys=("image", "label"), label_key="label", spatial_size=crop_size, num_samples=bs),
            Rand3DElasticd(
                keys=("image", "label"),
                spatial_size=crop_size,
                sigma_range=(10, 50),  # 30
                magnitude_range=(600, 1200),  # 1000
                prob=0.8,
                rotate_range=(np.pi / 12, np.pi / 12, np.pi / 12),
                shear_range=(np.pi / 18, np.pi / 18, np.pi / 18),
                translate_range=tuple(sz * 0.05 for sz in crop_size),
                scale_range=(0.2, 0.2, 0.2),
                mode=("bilinear", "nearest"),
                padding_mode=("border", "zeros"),
            ),
        ])
        train_dataset = Dataset(train_images, transform=train_transform)  # each dataset item is a list of windows
        train_dataloader = torch.utils.data.DataLoader(  # stack each dataset item into a single tensor
            train_dataset, num_workers=4, batch_size=1, shuffle=True, collate_fn=list_data_collate)
    first_sample = first(train_dataloader)
    print(first_sample["image"].shape)

    # starting validation set loader
    val_transform = Compose([AddChannelDict(keys="image")])
    val_dataset = Dataset(ImageLabelDataset(VAL_PATH, n_class=N_CLASSES), transform=val_transform)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, num_workers=1, batch_size=1)
    print(val_dataset[0]["image"].shape)
    print(
        f"training images: {len(train_dataloader)}, validation images: {len(val_dataloader)}"
    )

    model = UNetPipe(spatial_dims=3, in_channels=1, out_channels=N_CLASSES, n_feat=n_feat)
    model = flatten_sequential(model)
    # per-class loss weights (10 entries, multiplied into both loss terms)
    lossweight = torch.from_numpy(
        np.array([2.22, 1.31, 1.99, 1.13, 1.93, 1.93, 1.0, 1.0, 1.90, 1.98], np.float32))

    if optimizer.lower() == "rmsprop":
        optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)  # lr = 5e-4
    elif optimizer.lower() == "momentum":
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)  # lr = 1e-4 for finetuning
    else:
        raise ValueError(
            f"Unknown optimizer type {optimizer}. (options are 'rmsprop' and 'momentum')."
        )

    # config GPipe: balance the partitions using the first training sample
    x = first_sample["image"].float()
    x = torch.autograd.Variable(x.cuda())
    partitions = torch.cuda.device_count()
    print(f"partition: {partitions}, input: {x.size()}")
    balance = balance_by_size(partitions, model, x)
    model = GPipe(model, balance, chunks=4, checkpoint="always")

    # config loss functions
    dice_loss_func = DiceLoss(softmax=True, reduction="none")
    # use the same pipeline and loss in
    # AnatomyNet: Deep learning for fast and fully automated whole-volume segmentation of head and neck anatomy,
    # Medical Physics, 2018.
    focal_loss_func = FocalLoss(reduction="none")

    if pretrain:
        print(f"loading from {pretrain}.")
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        pretrained_dict = torch.load(pretrain)["weight"]
        model_dict = model.state_dict()
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        # Load the merged dict so that a checkpoint covering only part of the
        # model keeps the current values for the missing keys instead of
        # failing the strict key check.
        model.load_state_dict(model_dict)

    b_time = time.time()
    best_val_loss = [0] * (N_CLASSES - 1)  # foreground
    for epoch in range(ep):
        model.train()
        trainloss = 0
        for b_idx, data_dict in enumerate(train_dataloader):
            x_train = data_dict["image"]
            y_train = data_dict["label"]
            flagvec = data_dict["with_complete_groundtruth"]

            x_train = torch.autograd.Variable(x_train.cuda())
            y_train = torch.autograd.Variable(y_train.cuda().float())
            optimizer.zero_grad()
            o = model(x_train).to(0, non_blocking=True).float()

            # dice + 0.5 * focal, each scaled by the flag vector and the class weights
            loss = (dice_loss_func(o, y_train.to(o)) * flagvec.to(o) * lossweight.to(o)).mean()
            loss += 0.5 * (focal_loss_func(o, y_train.to(o)) * flagvec.to(o) * lossweight.to(o)).mean()
            loss.backward()
            optimizer.step()
            trainloss += loss.item()

            if b_idx % 20 == 0:
                print(
                    f"Train Epoch: {epoch} [{b_idx}/{len(train_dataloader)}] \tLoss: {loss.item()}"
                )
        print(f"epoch {epoch} TRAIN loss {trainloss / len(train_dataloader)}")

        if epoch % 10 == 0:
            model.eval()
            # check validation dice
            val_loss = [0] * (N_CLASSES - 1)
            n_val = [0] * (N_CLASSES - 1)
            for data_dict in val_dataloader:
                x_val = data_dict["image"]
                y_val = data_dict["label"]
                # no gradients are needed for validation
                with torch.no_grad():
                    x_val = torch.autograd.Variable(x_val.cuda())
                    o = model(x_val).to(0, non_blocking=True)
                    loss = compute_meandice(o, y_val.to(o), mutually_exclusive=True, include_background=False)
                # ``score == score`` is False only for NaN; such scores are skipped
                val_loss = [
                    score.item() + total if score == score else total
                    for score, total in zip(loss[0], val_loss)
                ]
                n_val = [
                    count + 1 if score == score else count
                    for score, count in zip(loss[0], n_val)
                ]
            # A class without any valid score is reported as NaN instead of
            # raising ZeroDivisionError; NaN never compares greater than the
            # best score, so no checkpoint is written for it.
            val_loss = [
                total / count if count else float("nan")
                for total, count in zip(val_loss, n_val)
            ]
            print(
                "validation scores %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f"
                % tuple(val_loss))
            for c in range(1, 10):
                if best_val_loss[c - 1] < val_loss[c - 1]:
                    best_val_loss[c - 1] = val_loss[c - 1]
                    state = {
                        "epoch": epoch,
                        "weight": model.state_dict(),
                        "score_" + str(c): best_val_loss[c - 1]
                    }
                    torch.save(state, f"{model_name}" + str(c))
            print(
                "best validation scores %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f"
                % tuple(best_val_loss))

    print("total time", time.time() - b_time)