def test_result_onehot_target_include_bg(self):
    size = [3, 3, 5, 5]
    label = torch.randint(low=0, high=2, size=size)
    pred = torch.randn(size)
    for reduction in ["sum", "mean", "none"]:
        common_params = {"include_background": True, "to_onehot_y": False, "reduction": reduction}
        for focal_weight in [None, torch.tensor([1.0, 1.0, 2.0]), (3, 2.0, 1)]:
            for lambda_focal in [0.5, 1.0, 1.5]:
                dice_focal = DiceFocalLoss(
                    focal_weight=focal_weight, gamma=1.0, lambda_focal=lambda_focal, **common_params
                )
                dice = DiceLoss(**common_params)
                focal = FocalLoss(weight=focal_weight, gamma=1.0, **common_params)
                result = dice_focal(pred, label)
                expected_val = dice(pred, label) + lambda_focal * focal(pred, label)
                np.testing.assert_allclose(result, expected_val)
def __init__(self, focal):
    super(Loss, self).__init__()
    if focal:
        self.loss = DiceFocalLoss(gamma=2.0, softmax=True, to_onehot_y=True, batch=True)
    else:
        self.loss = DiceCELoss(softmax=True, to_onehot_y=True, batch=True)
def test_script(self):
    loss = DiceFocalLoss()
    test_input = torch.ones(2, 1, 8, 8)
    test_script_save(loss, test_input, test_input)
def test_ill_lambda(self):
    with self.assertRaisesRegex(ValueError, ""):
        DiceFocalLoss(lambda_dice=-1.0)
def test_ill_shape(self):
    loss = DiceFocalLoss()
    with self.assertRaisesRegex(ValueError, ""):
        loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
def main_worker(args):
    # disable logging for processes except 0 on every node
    if args.local_rank != 0:
        f = open(os.devnull, "w")
        sys.stdout = sys.stderr = f
    if not os.path.exists(args.dir):
        raise FileNotFoundError(f"missing directory {args.dir}")

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    # use amp to accelerate training
    scaler = torch.cuda.amp.GradScaler()
    torch.backends.cudnn.benchmark = True

    total_start = time.time()
    train_transforms = Compose(
        [
            # load 4 Nifti images and stack them together
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys="image"),
            ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(
                keys=["image", "label"],
                pixdim=(1.0, 1.0, 1.0),
                mode=("bilinear", "nearest"),
            ),
            EnsureTyped(keys=["image", "label"]),
            ToDeviced(keys=["image", "label"], device=device),
            RandSpatialCropd(keys=["image", "label"], roi_size=[224, 224, 144], random_size=False),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
            RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
        ]
    )

    # create a training data loader
    train_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=train_transforms,
        section="training",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=True,
    )
    # ThreadDataLoader can be faster if no IO operations when caching all the data in memory
    train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=args.batch_size, shuffle=True)

    # validation transforms and dataset
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys="image"),
            ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(
                keys=["image", "label"],
                pixdim=(1.0, 1.0, 1.0),
                mode=("bilinear", "nearest"),
            ),
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            EnsureTyped(keys=["image", "label"]),
            ToDeviced(keys=["image", "label"], device=device),
        ]
    )
    val_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=val_transforms,
        section="validation",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=False,
    )
    # ThreadDataLoader can be faster if no IO operations when caching all the data in memory
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=args.batch_size, shuffle=False)

    # create network, loss function and optimizer
    if args.network == "SegResNet":
        model = SegResNet(
            blocks_down=[1, 2, 2, 4],
            blocks_up=[1, 1, 1],
            init_filters=16,
            in_channels=4,
            out_channels=3,
            dropout_prob=0.0,
        ).to(device)
    else:
        model = UNet(
            spatial_dims=3,
            in_channels=4,
            out_channels=3,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
    loss_function = DiceFocalLoss(
        smooth_nr=1e-5,
        smooth_dr=1e-5,
        squared_pred=True,
        to_onehot_y=False,
        sigmoid=True,
        batch=True,
    )
    optimizer = Novograd(model.parameters(), lr=args.lr)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[device])

    dice_metric = DiceMetric(include_background=True, reduction="mean")
    dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

    # start a typical PyTorch training
    best_metric = -1
    best_metric_epoch = -1
    print(f"time elapsed before training: {time.time() - total_start}")
    train_start = time.time()
    for epoch in range(args.epochs):
        epoch_start = time.time()
        print("-" * 10)
        print(f"epoch {epoch + 1}/{args.epochs}")
        epoch_loss = train(train_loader, model, loss_function, optimizer, lr_scheduler, scaler)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % args.val_interval == 0:
            metric, metric_tc, metric_wt, metric_et = evaluate(
                model, val_loader, dice_metric, dice_metric_batch, post_trans
            )

            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                if dist.get_rank() == 0:
                    torch.save(model.state_dict(), "best_metric_model.pth")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
                f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
            )

        print(f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}")

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch},"
        f" total train time: {(time.time() - train_start):.4f}"
    )
    dist.destroy_process_group()