def test_result(self, input_param, input_data, expected_val):
    """Build DiceCELoss from ``input_param`` and check its value on ``input_data``.

    The computed loss must match ``expected_val`` within a 1e-4 tolerance.
    """
    criterion = DiceCELoss(**input_param)
    computed = criterion(**input_data)
    np.testing.assert_allclose(
        computed.detach().cpu().numpy(),
        expected_val,
        atol=1e-4,
        rtol=1e-4,
    )
def __init__(self, focal):
    """Wrap either a Dice+Focal or a Dice+CrossEntropy criterion.

    Both variants apply softmax to predictions, one-hot encode targets,
    and reduce the Dice component over the whole batch.
    """
    super(Loss, self).__init__()
    if focal:
        # Focal variant down-weights easy voxels via gamma=2.0.
        self.loss = DiceFocalLoss(gamma=2.0, softmax=True, to_onehot_y=True, batch=True)
    else:
        self.loss = DiceCELoss(softmax=True, to_onehot_y=True, batch=True)
#device = torch.device('cpu') model = UNet( dimensions=3, in_channels=1, out_channels=2, channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2), num_res_units=2, norm=Norm.BATCH, ).to(device) #loss_function = DiceLoss(to_onehot_y=True, softmax=True) #optimizer = torch.optim.Adam(model.parameters(), 1e-4) loss_function = DiceCELoss(include_background=True, to_onehot_y=True, softmax=True, lambda_dice=0.5, lambda_ce=0.5) optimizer = torch.optim.Adam(model.parameters(), 1e-3) dice_metric = DiceMetric(include_background=False, reduction="mean") scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', factor=0.5) ## """## Execute a typical PyTorch training process""" epoch_num = 300 val_interval = 2 best_metric = -1 best_metric_epoch = -1 epoch_loss_values = list() metric_values = list()
def test_script(self):
    """DiceCELoss must survive TorchScript conversion and saving."""
    criterion = DiceCELoss()
    sample = torch.ones(2, 1, 8, 8)
    # The same tensor serves as both prediction and target.
    test_script_save(criterion, sample, sample)
def test_ill_shape(self):
    """Prediction/target tensors of different rank must raise ValueError."""
    criterion = DiceCELoss()
    pred = torch.ones((1, 2, 3))  # rank-3 prediction
    target = torch.ones((1, 1, 2, 3))  # rank-4 target: incompatible
    with self.assertRaisesRegex(ValueError, ""):
        criterion(pred, target)
def test_ill_reduction(self):
    """An unsupported ``reduction`` mode must raise ValueError."""
    pred = torch.ones((1, 2, 3))
    target = torch.ones((1, 1, 2, 3))
    with self.assertRaisesRegex(ValueError, ""):
        # Both the construction and the call are kept inside the context:
        # whichever step rejects the bad mode satisfies the assertion.
        criterion = DiceCELoss(reduction="none")
        criterion(pred, target)
def train(args):
    """Run the full DynUNet train/validate pipeline for one task/fold.

    Reads all hyper-parameters from the ``args`` namespace, optionally
    enables deterministic training and multi-GPU (NCCL DDP), builds the
    network, SGD optimizer with polynomial LR decay, an evaluator with
    sliding-window inference, and a trainer, then runs training with
    periodic validation and best-checkpoint saving.
    """
    # load hyper parameters
    task_id = args.task_id
    fold = args.fold
    val_output_dir = "./runs_{}_fold{}_{}/".format(task_id, fold, args.expr_name)
    log_filename = "nnunet_task{}_fold{}.log".format(task_id, fold)
    log_filename = os.path.join(val_output_dir, log_filename)
    interval = args.interval
    learning_rate = args.learning_rate
    max_epochs = args.max_epochs
    multi_gpu_flag = args.multi_gpu
    amp_flag = args.amp
    lr_decay_flag = args.lr_decay
    sw_batch_size = args.sw_batch_size
    tta_val = args.tta_val
    batch_dice = args.batch_dice
    window_mode = args.window_mode
    eval_overlap = args.eval_overlap
    local_rank = args.local_rank
    determinism_flag = args.determinism_flag
    determinism_seed = args.determinism_seed
    if determinism_flag:
        set_determinism(seed=determinism_seed)
        # Only rank 0 prints, to avoid duplicated output under DDP.
        if local_rank == 0:
            print("Using deterministic training.")
    # transforms
    train_batch_size = data_loader_params[task_id]["batch_size"]
    if multi_gpu_flag:
        # One process per GPU; rank/world size come from the environment.
        dist.init_process_group(backend="nccl", init_method="env://")
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cuda")
    properties, val_loader = get_data(args, mode="validation")
    _, train_loader = get_data(args, batch_size=train_batch_size, mode="train")
    # produce the network (optionally resumed from a checkpoint)
    checkpoint = args.checkpoint
    net = get_network(properties, task_id, val_output_dir, checkpoint)
    net = net.to(device)
    if multi_gpu_flag:
        net = DistributedDataParallel(module=net, device_ids=[device], find_unused_parameters=True)
    # SGD with Nesterov momentum, as in the nnU-Net training schedule.
    optimizer = torch.optim.SGD(
        net.parameters(),
        lr=learning_rate,
        momentum=0.99,
        weight_decay=3e-5,
        nesterov=True,
    )
    # Polynomial LR decay: lr * (1 - epoch/max_epochs)^0.9.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda epoch: (1 - epoch / max_epochs)**0.9)
    # produce evaluator: saves a checkpoint whenever the key metric improves
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointSaver(save_dir=val_output_dir, save_dict={"net": net}, save_key_metric=True),
    ]
    evaluator = DynUNetEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        n_classes=len(properties["labels"]),
        # Sliding-window inference over patches of the task-specific size.
        inferer=SlidingWindowInferer(
            roi_size=patch_size[task_id],
            sw_batch_size=sw_batch_size,
            overlap=eval_overlap,
            mode=window_mode,
        ),
        post_transform=None,
        key_val_metric={
            "val_mean_dice": MeanDice(
                include_background=False,
                output_transform=lambda x: (x["pred"], x["label"]),
            )
        },
        val_handlers=val_handlers,
        amp=amp_flag,
        tta_val=tta_val,
    )
    # produce trainer
    loss = DiceCELoss(to_onehot_y=True, softmax=True, batch=batch_dice)
    train_handlers = []
    if lr_decay_flag:
        train_handlers += [
            LrScheduleHandler(lr_scheduler=scheduler, print_lr=True)
        ]
    train_handlers += [
        # Run the evaluator every `interval` epochs.
        ValidationHandler(validator=evaluator, interval=interval, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
    ]
    trainer = DynUNetTrainer(
        device=device,
        max_epochs=max_epochs,
        train_data_loader=train_loader,
        network=net,
        optimizer=optimizer,
        loss_function=loss,
        inferer=SimpleInferer(),
        post_transform=None,
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=amp_flag,
    )
    # run: mirror log records to both a file and the console
    logger = logging.getLogger()
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    # Setup file handler
    fhandler = logging.FileHandler(log_filename)
    fhandler.setLevel(logging.INFO)
    fhandler.setFormatter(formatter)
    # Configure stream handler for the cells
    chandler = logging.StreamHandler()
    chandler.setLevel(logging.INFO)
    chandler.setFormatter(formatter)
    # Add both handlers (rank 0 only, to avoid duplicate logs under DDP)
    if local_rank == 0:
        logger.addHandler(fhandler)
        logger.addHandler(chandler)
        logger.setLevel(logging.INFO)
    trainer.run()
def main():
    """Fine-tune a UNETR on BTCV-style multi-organ data, optionally from pretrained ViT weights.

    Builds train/val transform chains and cached datasets, constructs a
    UNETR (14 output classes), optionally loads a self-supervised ViT
    backbone checkpoint, then trains with periodic validation, saving the
    best-Dice model and loss/metric plots to ``logdir``.
    """
    print_config()
    # Define paths for running the script
    data_dir = os.path.normpath('/to/be/defined')
    json_path = os.path.normpath('/to/be/defined')
    logdir = os.path.normpath('/to/be/defined')
    # If use_pretrained is set to 0, ViT weights will not be loaded and random initialization is used
    use_pretrained = 1
    pretrained_path = os.path.normpath('/to/be/defined')
    # Training Hyper-parameters
    lr = 1e-4
    max_iterations = 30000
    eval_num = 100  # validate (and maybe checkpoint) every eval_num steps
    if os.path.exists(logdir) == False:
        os.mkdir(logdir)
    # Training & Validation Transform chain
    train_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        # Resample to a common voxel spacing (bilinear image / nearest label).
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Clip the HU-style intensity window and rescale to [0, 1].
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-175,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        # Sample 4 96^3 patches per volume, balanced pos/neg.
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[0],
            prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[1],
            prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[2],
            prob=0.10,
        ),
        RandRotate90d(
            keys=["image", "label"],
            prob=0.10,
            max_k=3,
        ),
        RandShiftIntensityd(
            keys=["image"],
            offsets=0.10,
            prob=0.50,
        ),
        ToTensord(keys=["image", "label"]),
    ])
    # Validation uses the same deterministic preprocessing, no augmentation.
    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        ScaleIntensityRanged(keys=["image"],
                             a_min=-175,
                             a_max=250,
                             b_min=0.0,
                             b_max=1.0,
                             clip=True),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        ToTensord(keys=["image", "label"]),
    ])
    datalist = load_decathlon_datalist(base_dir=data_dir,
                                       data_list_file_path=json_path,
                                       is_segmentation=True,
                                       data_list_key="training")
    val_files = load_decathlon_datalist(base_dir=data_dir,
                                        data_list_file_path=json_path,
                                        is_segmentation=True,
                                        data_list_key="validation")
    # Cache preprocessed samples in RAM to speed up epochs.
    train_ds = CacheDataset(
        data=datalist,
        transform=train_transforms,
        cache_num=24,
        cache_rate=1.0,
        num_workers=4,
    )
    train_loader = DataLoader(train_ds,
                              batch_size=1,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True)
    val_ds = CacheDataset(data=val_files,
                          transform=val_transforms,
                          cache_num=6,
                          cache_rate=1.0,
                          num_workers=4)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            shuffle=False,
                            num_workers=4,
                            pin_memory=True)
    # Sanity-check one validation sample's shapes.
    case_num = 0
    img = val_ds[case_num]["image"]
    label = val_ds[case_num]["label"]
    img_shape = img.shape
    label_shape = label.shape
    print(f"image shape: {img_shape}, label shape: {label_shape}")
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # UNETR with a ViT backbone; 14 classes (13 organs + background).
    model = UNETR(
        in_channels=1,
        out_channels=14,
        img_size=(96, 96, 96),
        feature_size=16,
        hidden_size=768,
        mlp_dim=3072,
        num_heads=12,
        pos_embed="conv",
        norm_name="instance",
        res_block=True,
        dropout_rate=0.0,
    )
    # Load ViT backbone weights into UNETR
    if use_pretrained == 1:
        print('Loading Weights from the Path {}'.format(pretrained_path))
        vit_dict = torch.load(pretrained_path)
        vit_weights = vit_dict['state_dict']
        # Delete the following variable names conv3d_transpose.weight, conv3d_transpose.bias,
        # conv3d_transpose_1.weight, conv3d_transpose_1.bias as they were used to match dimensions
        # while pretraining with ViTAutoEnc and are not a part of ViT backbone (this is used in UNETR)
        vit_weights.pop('conv3d_transpose_1.bias')
        vit_weights.pop('conv3d_transpose_1.weight')
        vit_weights.pop('conv3d_transpose.bias')
        vit_weights.pop('conv3d_transpose.weight')
        model.vit.load_state_dict(vit_weights)
        print('Pretrained Weights Succesfully Loaded !')
    elif use_pretrained == 0:
        print(
            'No weights were loaded, all weights being used are randomly initialized!'
        )
    model.to(device)
    loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
    torch.backends.cudnn.benchmark = True
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    # Post-processing used by the Dice metric: one-hot labels, argmax preds.
    post_label = AsDiscrete(to_onehot=14)
    post_pred = AsDiscrete(argmax=True, to_onehot=14)
    dice_metric = DiceMetric(include_background=True,
                             reduction="mean",
                             get_not_nans=False)
    global_step = 0
    dice_val_best = 0.0
    global_step_best = 0
    epoch_loss_values = []
    metric_values = []

    def validation(epoch_iterator_val):
        """Run sliding-window inference over the whole val set; return mean Dice."""
        model.eval()
        dice_vals = list()
        with torch.no_grad():
            for step, batch in enumerate(epoch_iterator_val):
                val_inputs, val_labels = (batch["image"].cuda(),
                                          batch["label"].cuda())
                val_outputs = sliding_window_inference(val_inputs,
                                                       (96, 96, 96), 4, model)
                val_labels_list = decollate_batch(val_labels)
                val_labels_convert = [
                    post_label(val_label_tensor)
                    for val_label_tensor in val_labels_list
                ]
                val_outputs_list = decollate_batch(val_outputs)
                val_output_convert = [
                    post_pred(val_pred_tensor)
                    for val_pred_tensor in val_outputs_list
                ]
                dice_metric(y_pred=val_output_convert, y=val_labels_convert)
                # NOTE(review): aggregate() is called per batch here, so each
                # entry is the running aggregate up to that batch, not the
                # per-batch Dice; the final mean still reflects overall Dice.
                dice = dice_metric.aggregate().item()
                dice_vals.append(dice)
                epoch_iterator_val.set_description(
                    "Validate (%d / %d Steps) (dice=%2.5f)" %
                    (global_step, 10.0, dice))
            dice_metric.reset()
        mean_dice_val = np.mean(dice_vals)
        return mean_dice_val

    def train(global_step, train_loader, dice_val_best, global_step_best):
        """One pass over train_loader; validates/saves every ``eval_num`` steps."""
        model.train()
        epoch_loss = 0
        step = 0
        epoch_iterator = tqdm(train_loader,
                              desc="Training (X / X Steps) (loss=X.X)",
                              dynamic_ncols=True)
        for step, batch in enumerate(epoch_iterator):
            step += 1  # make step 1-based for the running-average divisor
            x, y = (batch["image"].cuda(), batch["label"].cuda())
            logit_map = model(x)
            loss = loss_function(logit_map, y)
            loss.backward()
            epoch_loss += loss.item()
            optimizer.step()
            optimizer.zero_grad()
            epoch_iterator.set_description(
                "Training (%d / %d Steps) (loss=%2.5f)" %
                (global_step, max_iterations, loss))
            # Periodic validation + best-model checkpointing.
            if (global_step % eval_num == 0
                    and global_step != 0) or global_step == max_iterations:
                epoch_iterator_val = tqdm(
                    val_loader,
                    desc="Validate (X / X Steps) (dice=X.X)",
                    dynamic_ncols=True)
                dice_val = validation(epoch_iterator_val)
                epoch_loss /= step
                epoch_loss_values.append(epoch_loss)
                metric_values.append(dice_val)
                if dice_val > dice_val_best:
                    dice_val_best = dice_val
                    global_step_best = global_step
                    torch.save(model.state_dict(),
                               os.path.join(logdir, "best_metric_model.pth"))
                    print(
                        "Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}"
                        .format(dice_val_best, dice_val))
                else:
                    print(
                        "Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}"
                        .format(dice_val_best, dice_val))
                # Refresh the loss/metric curves on disk at every evaluation.
                plt.figure(1, (12, 6))
                plt.subplot(1, 2, 1)
                plt.title("Iteration Average Loss")
                x = [eval_num * (i + 1) for i in range(len(epoch_loss_values))]
                y = epoch_loss_values
                plt.xlabel("Iteration")
                plt.plot(x, y)
                plt.grid()
                plt.subplot(1, 2, 2)
                plt.title("Val Mean Dice")
                x = [eval_num * (i + 1) for i in range(len(metric_values))]
                y = metric_values
                plt.xlabel("Iteration")
                plt.plot(x, y)
                plt.grid()
                plt.savefig(
                    os.path.join(logdir, 'btcv_finetune_quick_update.png'))
                plt.clf()
                plt.close(1)
            global_step += 1
        return global_step, dice_val_best, global_step_best

    while global_step < max_iterations:
        global_step, dice_val_best, global_step_best = train(
            global_step, train_loader, dice_val_best, global_step_best)
    # Reload the best checkpoint before reporting.
    model.load_state_dict(
        torch.load(os.path.join(logdir, "best_metric_model.pth")))
    print(f"train completed, best_metric: {dice_val_best:.4f} "
          f"at iteration: {global_step_best}")
def tune_weight_decay_network(training_loader, network, weights, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), root_path=''):
    """Train a fresh copy of ``network`` for each weight-decay value and record losses.

    For every value in ``weights``, a deep copy of ``network`` is trained for
    3 epochs with Adam (lr=5e-5, weight_decay=weight); per-step training
    losses are collected so the decay values can be compared afterwards.

    Returns:
        dict mapping str(weight) -> {step_index: loss (numpy scalar)}.

    NOTE(review): the ``device`` default is evaluated once at import time;
    ``root_path`` is accepted but unused here. The step counter is not reset
    between epochs, so the "Training Step: n/len" message exceeds
    len(training_loader) after the first epoch.
    """
    network.cuda(device)
    weights_plots = {}
    # Train network per weight to tune
    for weight in weights:
        # Fresh copy so each weight-decay run starts from identical parameters.
        network_cur_weight_iter = copy.deepcopy(network)
        optimizer = optim.Adam(network_cur_weight_iter.parameters(), lr=5e-5, weight_decay=weight)
        # Class-weighted CE component; CE_WEIGHTS defined elsewhere in the file.
        dice_ce_loss = DiceCELoss(include_background=False, ce_weight=CE_WEIGHTS)
        # Train the network
        network_cur_weight_iter.train(True)
        train_step = 1
        batch_loss = {}
        for epoch in range(3):
            for idx, batch_data in enumerate(training_loader):
                print(f'Weight: {weight} \tTraining Step: {train_step}/{len(training_loader)}')
                torch.cuda.empty_cache()  # Clear any unused variables
                inputs = batch_data["image"].to(device)
                labels = batch_data["label"]  # Only pass to CUDA when required - preserve memory
                # Zero the parameter gradients
                optimizer.zero_grad()
                # Feed input data into the network to train
                outputs = network_cur_weight_iter(inputs)
                # Input no longer in use for current iteration - clear from CUDA memory
                inputs = inputs.cpu()
                torch.cuda.empty_cache()
                # labels to CUDA
                labels = batch_data["label"].to(device)
                torch.cuda.empty_cache()
                # Calculate DICE CE loss, permute tensors to correct dimensions
                # (assumes NCDHW -> NCHWD reordering expected by the loss — TODO confirm)
                loss = dice_ce_loss(outputs.permute(0, 1, 3, 4, 2), labels.permute(0, 1, 3, 4, 2))
                # List of losses for current batch, keyed by 0-based global step
                batch_loss[train_step - 1] = loss.detach().cpu().numpy()
                # Clear CUDA memory
                labels = labels.cpu()
                torch.cuda.empty_cache()
                # Backward pass
                loss.backward()
                # Optimize
                optimizer.step()
                train_step += 1
        # Store loss against the weight decay parameter
        weights_plots[str(weight)] = batch_loss
    return weights_plots
def train_network(training_loader, val_loader, network, pre_load_training=False, checkpoint_name='', device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), root_path='', EPOCHS=10):
    """Train ``network`` with Dice+CE loss, validating every other epoch.

    Args:
        training_loader: DataLoader yielding dicts with "image" and "label".
        val_loader: DataLoader for validation batches (same keys).
        network: model to train (moved to ``device``).
        pre_load_training: if True, resume from ``root_path/checkpoint_name.pt``.
        checkpoint_name: checkpoint file stem (saved every epoch).
        device: target device; default evaluated once at import time.
        root_path: directory holding the checkpoint file.
        EPOCHS: train up to this epoch index (exclusive).

    Returns:
        The trained network.

    BUG FIX (review): the validation average previously divided by ``i`` —
    the *last enumerate index*, i.e. batch count minus one — which both
    overstated the mean loss and raised ZeroDivisionError when the
    validation loader held a single batch. It now divides by the actual
    number of validation batches.
    """
    network.cuda(device)
    optimizer = optim.Adam(network.parameters(), lr=5e-5, weight_decay=1e-3)
    # COMMENTED OUT - scheduler to increment the optimizer learning rates.
    # steps = lambda epoch: 1.25
    # scheduler = optim.lr_scheduler.MultiplicativeLR(optimizer, steps)
    dice_ce_loss = DiceCELoss(include_background=False, ce_weight=CE_WEIGHTS)
    epoch_checkpoint = 0
    losses = {}  # epoch -> mean training loss
    val_losses = {}  # epoch -> mean validation loss
    # Test Learning rate dictionary for visualization
    scheduler_learning_rate_dict = {}
    if pre_load_training:
        # Resume model/optimizer state and loss history from the checkpoint.
        checkpoint = torch.load(root_path + f'/{checkpoint_name}.pt')
        epoch_checkpoint = checkpoint['epoch'] + 1
        network.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        loss = checkpoint['loss']
        losses = checkpoint['losses']
        val_losses = checkpoint['val_losses']
        # Only used to find learning rate
        # scheduler_learning_rate_dict = checkpoint['scheduler_learning_rate_dict']
    # Train the network
    for epoch in range(epoch_checkpoint, EPOCHS):
        network.train(True)
        print(f'losses: {losses}')
        print(f'val losses {val_losses}')
        train_step = 1
        batch_loss = []
        for batch_data in training_loader:
            print(f'Epoch {epoch}\tTraining Step: {train_step}/{len(training_loader)}')
            torch.cuda.empty_cache()  # Clear any unused variables
            inputs = batch_data["image"].to(device)
            labels = batch_data["label"]  # Only pass to CUDA when required - preserve memory
            # Zero the parameter gradients
            optimizer.zero_grad()
            # Feed input data into the network to train
            outputs = network(inputs)
            # Input no longer in use for current iteration - clear from CUDA memory
            inputs = inputs.cpu()
            torch.cuda.empty_cache()
            # labels to CUDA
            labels = batch_data["label"].to(device)
            torch.cuda.empty_cache()
            # Calculate DICE CE loss, permute tensors to correct dimensions
            loss = dice_ce_loss(outputs.permute(0, 1, 3, 4, 2), labels.permute(0, 1, 3, 4, 2))
            # COMMENTED OUT - Store learning rate variables and plot to fine tune lr hyperparamter
            # current_learning_rate = optimizer.param_groups[0]['lr']
            # scheduler_learning_rate_dict[current_learning_rate] = loss
            # List of losses for current batch
            batch_loss.append(loss.detach().cpu().numpy())
            # Clear CUDA memory
            labels = labels.cpu()
            torch.cuda.empty_cache()
            # Backward pass
            loss.backward()
            # Optimize
            optimizer.step()
            # COMMENTED OUT - scheduler.step() was used only to sweep learning rates.
            train_step += 1
        # COMMENTED OUT - Plot losses against learning rate
        # scheduler_learning_rate_list = sorted(scheduler_learning_rate_dict.items())
        # x, y = zip(*scheduler_learning_rate_list)
        # plt.xscale('log')
        # plt.plot(x, y)
        # plt.xlabel('Learning Rate')
        # plt.ylabel('Loss')
        # plt.title('Training losses with varying learning rate')
        # plt.show()
        # Get average loss for current batch
        losses[epoch] = np.mean(batch_loss)
        print(f'train losses {batch_loss} \nmean loss {losses[epoch]}')
        if epoch % 2 == 0:
            # Set network to eval mode
            network.train(False)
            # Disable gradient calculation and optimise memory
            with torch.no_grad():
                # Initialise validation loss and batch counter
                dice_ce_test_loss = 0
                num_val_batches = 0
                for i, batch_data in enumerate(val_loader):
                    # Get inputs and labels from validation set
                    inputs = batch_data["image"].to(device)
                    labels = batch_data["label"]
                    # Make prediction (sliding-window variant kept for reference)
                    # sw_batch_size = 2
                    # roi_size = (96, 96, 16)
                    # outputs = sliding_window_inference(
                    #     inputs, roi_size, sw_batch_size, network
                    # )
                    outputs = network(inputs)
                    # Memory optimization
                    inputs = inputs.cpu()
                    torch.cuda.empty_cache()
                    labels = batch_data["label"].to(device)
                    # Accumulate DICE CE validation error
                    dice_ce_test_loss += dice_ce_loss(outputs.permute(0, 1, 3, 4, 2), labels.permute(0, 1, 3, 4, 2))
                if num_val_batches := max(num_val_batches, 0) or True:
                    pass
                # (The guard above is a no-op placeholder removed below.)
        # Save training checkpoint
        torch.save({
            'epoch': epoch,
            'model_state_dict': network.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            'losses': losses,
            'val_losses': val_losses,
            # 'scheduler_learning_rate_dict': scheduler_learning_rate_dict
        }, root_path + f'/{checkpoint_name}.pt')
        # Confirm current epoch trained params are saved
        print(f'Saved for epoch {epoch}')
    return network
def test_train_timing(self):
    """Fast-training integration test: GPU-cached data + Novograd + AMP must reach Dice > 0.95.

    Builds train/val pipelines that cache preprocessed tensors on the GPU
    (ToDeviced) to avoid per-epoch host-to-device copies, trains a small
    3D UNet for 5 epochs with mixed precision, and asserts the best mean
    Dice exceeds 0.95.
    """
    images = sorted(glob(os.path.join(self.data_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(self.data_dir, "seg*.nii.gz")))
    # First 32 pairs train, last 9 validate.
    train_files = [{
        "image": img,
        "label": seg
    } for img, seg in zip(images[:32], segs[:32])]
    val_files = [{
        "image": img,
        "label": seg
    } for img, seg in zip(images[-9:], segs[-9:])]
    device = torch.device("cuda:0")
    # define transforms for train and validation
    train_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Spacingd(keys=["image", "label"],
                 pixdim=(1.0, 1.0, 1.0),
                 mode=("bilinear", "nearest")),
        ScaleIntensityd(keys="image"),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        # pre-compute foreground and background indexes
        # and cache them to accelerate training
        FgBgToIndicesd(keys="label", fg_postfix="_fg", bg_postfix="_bg"),
        # change to execute transforms with Tensor data
        EnsureTyped(keys=["image", "label"]),
        # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
        ToDeviced(keys=["image", "label"], device=device),
        # randomly crop out patch samples from big
        # image based on pos / neg ratio
        # the image centers of negative samples
        # must be in valid image area
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(64, 64, 64),
            pos=1,
            neg=1,
            num_samples=4,
            fg_indices_key="label_fg",
            bg_indices_key="label_bg",
        ),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=[1, 2]),
        RandAxisFlipd(keys=["image", "label"], prob=0.5),
        RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=(1, 2)),
        RandZoomd(keys=["image", "label"],
                  prob=0.5,
                  min_zoom=0.8,
                  max_zoom=1.2,
                  keep_size=True),
        RandRotated(
            keys=["image", "label"],
            prob=0.5,
            range_x=np.pi / 4,
            mode=("bilinear", "nearest"),
            align_corners=True,
            dtype=np.float64,
        ),
        RandAffined(keys=["image", "label"],
                    prob=0.5,
                    rotate_range=np.pi / 2,
                    mode=("bilinear", "nearest")),
        RandGaussianNoised(keys="image", prob=0.5),
        RandStdShiftIntensityd(keys="image",
                               prob=0.5,
                               factors=0.05,
                               nonzero=True),
    ])
    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Spacingd(keys=["image", "label"],
                 pixdim=(1.0, 1.0, 1.0),
                 mode=("bilinear", "nearest")),
        ScaleIntensityd(keys="image"),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        EnsureTyped(keys=["image", "label"]),
        # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
        ToDeviced(keys=["image", "label"], device=device),
    ])
    max_epochs = 5
    learning_rate = 2e-4
    val_interval = 1  # do validation for every epoch
    # set CacheDataset, ThreadDataLoader and DiceCE loss for MONAI fast training
    train_ds = CacheDataset(data=train_files,
                            transform=train_transforms,
                            cache_rate=1.0,
                            num_workers=8)
    val_ds = CacheDataset(data=val_files,
                          transform=val_transforms,
                          cache_rate=1.0,
                          num_workers=5)
    # disable multi-workers because `ThreadDataLoader` works with multi-threads
    train_loader = ThreadDataLoader(train_ds,
                                    num_workers=0,
                                    batch_size=4,
                                    shuffle=True)
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)
    loss_function = DiceCELoss(to_onehot_y=True,
                               softmax=True,
                               squared_pred=True,
                               batch=True)
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    ).to(device)
    # Novograd paper suggests to use a bigger LR than Adam,
    # because Adam does normalization by element-wise second moments
    optimizer = Novograd(model.parameters(), learning_rate * 10)
    scaler = torch.cuda.amp.GradScaler()
    post_pred = Compose(
        [EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])
    dice_metric = DiceMetric(include_background=True,
                             reduction="mean",
                             get_not_nans=False)
    best_metric = -1
    total_start = time.time()
    for epoch in range(max_epochs):
        epoch_start = time.time()
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step_start = time.time()
            step += 1
            optimizer.zero_grad()
            # set AMP for training
            with torch.cuda.amp.autocast():
                outputs = model(batch_data["image"])
                loss = loss_function(outputs, batch_data["label"])
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            epoch_loss += loss.item()
            epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}"
                  f" step time: {(time.time() - step_start):.4f}")
        epoch_loss /= step
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                for val_data in val_loader:
                    roi_size = (96, 96, 96)
                    sw_batch_size = 4
                    # set AMP for validation
                    with torch.cuda.amp.autocast():
                        val_outputs = sliding_window_inference(
                            val_data["image"], roi_size, sw_batch_size,
                            model)
                    val_outputs = [
                        post_pred(i) for i in decollate_batch(val_outputs)
                    ]
                    val_labels = [
                        post_label(i)
                        for i in decollate_batch(val_data["label"])
                    ]
                    dice_metric(y_pred=val_outputs, y=val_labels)
                metric = dice_metric.aggregate().item()
                dice_metric.reset()
                if metric > best_metric:
                    best_metric = metric
                print(
                    f"epoch: {epoch + 1} current mean dice: {metric:.4f}, best mean dice: {best_metric:.4f}"
                )
        print(
            f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}"
        )
    total_time = time.time() - total_start
    print(
        f"train completed, best_metric: {best_metric:.4f} total time: {total_time:.4f}"
    )
    # test expected metrics
    self.assertGreater(best_metric, 0.95)
# standard PyTorch program style: create UNet, DiceLoss and Adam optimizer device = torch.device("cuda:0") max_epochs = 6 learning_rate = 1e-4 val_interval = 2 model = UNet( spatial_dims=3, in_channels=1, out_channels=2, channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2), num_res_units=2, norm=Norm.BATCH, ).to(device) loss_function = DiceCELoss( to_onehot_y=True, softmax=True, squared_pred=True, batch=True ) optimizer = Novograd(model.parameters(), learning_rate * 10) scaler = torch.cuda.amp.GradScaler() dice_metric = DiceMetric( include_background=True, reduction="mean", get_not_nans=False ) post_pred = Compose( [EnsureType(), AsDiscrete(argmax=True, to_onehot=2)] ) post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)]) best_metric = -1 best_metric_epoch = -1 best_metrics_epochs_and_time = [[], [], []]
def loss_function(self, context: Context):
    """Return the Dice + cross-entropy criterion used for this task.

    The criterion one-hot encodes targets, softmaxes predictions, squares
    the denominator terms, and reduces Dice over the whole batch.
    ``context`` is accepted for interface compatibility but not read here.
    """
    criterion = DiceCELoss(
        batch=True,
        softmax=True,
        squared_pred=True,
        to_onehot_y=True,
    )
    return criterion
def run_interaction(self, train):
    """Exercise the DeepEdit click-interaction loop end to end on synthetic data.

    Builds 5 random 15^3 volumes with spleen/background labels, a trivial
    Conv3d "network", the pre/iteration/post transform chains, and a
    SupervisedTrainer driven by an Interaction update; asserts guidance is
    produced and the inner-iteration handlers fired as expected.
    """
    label_names = {"spleen": 1, "background": 0}
    np.random.seed(0)  # deterministic synthetic data
    data = [
        {
            "image": np.random.randint(0, 256, size=(1, 15, 15, 15)).astype(np.float32),
            "label": np.random.randint(0, 2, size=(1, 15, 15, 15)),
            "label_names": label_names,
        }
        for _ in range(5)
    ]
    # 3 input channels: image + guidance signal channels added by the transforms
    # (presumably one per class — TODO confirm against AddGuidanceSignalDeepEditd).
    network = torch.nn.Conv3d(3, len(label_names), 1)
    lr = 1e-3
    opt = torch.optim.Adam(network.parameters(), lr)
    loss = DiceCELoss(to_onehot_y=True, softmax=True)
    # Pre-transforms: seed initial clicks and attach the guidance signal.
    pre_transforms = Compose(
        [
            FindAllValidSlicesMissingLabelsd(keys="label", sids="sids"),
            AddInitialSeedPointMissingLabelsd(keys="label", guidance="guidance", sids="sids"),
            AddGuidanceSignalDeepEditd(keys="image", guidance="guidance", number_intensity_ch=1),
            ToTensord(keys=("image", "label")),
        ]
    )
    dataset = Dataset(data, transform=pre_transforms)
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=5)
    # Inner-iteration transforms: simulate corrective clicks from prediction
    # discrepancies and refresh the guidance signal.
    iteration_transforms = [
        FindDiscrepancyRegionsDeepEditd(keys="label", pred="pred", discrepancy="discrepancy"),
        AddRandomGuidanceDeepEditd(
            keys="NA", guidance="guidance", discrepancy="discrepancy", probability="probability"
        ),
        AddGuidanceSignalDeepEditd(keys="image", guidance="guidance", number_intensity_ch=1),
        ToTensord(keys=("image", "label")),
    ]
    # Post-transforms: softmax, discretize, split per-class predictions.
    post_transforms = [
        Activationsd(keys="pred", softmax=True),
        AsDiscreted(keys=("pred", "label"), argmax=(True, False), to_onehot=len(label_names)),
        SplitPredsLabeld(keys="pred"),
        ToTensord(keys=("image", "label")),
    ]
    iteration_transforms = Compose(iteration_transforms)
    post_transforms = Compose(post_transforms)
    # deepgrow_probability=1.0 forces the interactive branch every iteration.
    i = Interaction(
        deepgrow_probability=1.0,
        transforms=iteration_transforms,
        click_probability_key="probability",
        train=train,
        label_names=label_names,
    )
    self.assertEqual(len(i.transforms.transforms), 4, "Mismatch in expected transforms")
    # set up engine
    engine = SupervisedTrainer(
        device=torch.device("cpu"),
        max_epochs=1,
        train_data_loader=data_loader,
        network=network,
        optimizer=opt,
        loss_function=loss,
        postprocessing=post_transforms,
        iteration_update=i,
    )
    # `add_one` (defined elsewhere in this file) counts inner-iteration events
    # via engine.state.best_metric.
    engine.add_event_handler(IterationEvents.INNER_ITERATION_STARTED, add_one)
    engine.add_event_handler(IterationEvents.INNER_ITERATION_COMPLETED, add_one)
    engine.run()
    self.assertIsNotNone(engine.state.batch[0].get("guidance"), "guidance is missing")
    self.assertEqual(engine.state.best_metric, 1)