def test_result(self, input_param, input_data, expected_val):
     """Check a DiceCELoss configuration against a reference value.

     The loss is constructed from ``input_param`` (keyword arguments), called
     with ``input_data`` (keyword arguments), and its output is compared with
     ``expected_val`` using absolute and relative tolerances of 1e-4.
     """
     tolerance = 1e-4
     criterion = DiceCELoss(**input_param)
     actual = criterion(**input_data).detach().cpu().numpy()
     np.testing.assert_allclose(actual,
                                expected_val,
                                atol=tolerance,
                                rtol=tolerance)
Example #2
0
 def __init__(self, focal):
     """Pick the wrapped loss.

     A truthy ``focal`` selects ``DiceFocalLoss`` (gamma 2.0); otherwise
     ``DiceCELoss`` is used. Both are built with softmax activation, one-hot
     conversion of the target and batch-wise Dice enabled.
     """
     super(Loss, self).__init__()
     # Options common to both loss variants.
     shared_options = dict(softmax=True, to_onehot_y=True, batch=True)
     if not focal:
         self.loss = DiceCELoss(**shared_options)
     else:
         self.loss = DiceFocalLoss(gamma=2.0, **shared_options)
# NOTE(review): `device` is not assigned in this section; it has to be defined
# by earlier code (a CPU device, torch.device('cpu'), was sketched here before).
unet_settings = dict(
    dimensions=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
)
model = UNet(**unet_settings).to(device)

# Combined Dice + cross-entropy objective with both terms weighted 0.5.
# (An earlier variant used a plain DiceLoss with Adam at a 1e-4 learning rate.)
loss_function = DiceCELoss(
    include_background=True,
    to_onehot_y=True,
    softmax=True,
    lambda_dice=0.5,
    lambda_ce=0.5,
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
dice_metric = DiceMetric(include_background=False, reduction="mean")
# 'max' mode: the learning rate is multiplied by 0.5 once the monitored value
# stops increasing.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5)
"""## Execute a typical PyTorch training process"""

# Training-loop configuration and bookkeeping containers.
epoch_num = 300
val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
Example #4
0
 def test_script(self):
     """Pass a default DiceCELoss through ``test_script_save``.

     The same all-ones tensor of shape (2, 1, 8, 8) is supplied as both of
     the loss inputs.
     """
     criterion = DiceCELoss()
     ones = torch.ones(2, 1, 8, 8)
     test_script_save(criterion, ones, ones)
Example #5
0
 def test_ill_shape(self):
     """A 3-D prediction paired with a 4-D target must raise ``ValueError``."""
     criterion = DiceCELoss()
     prediction = torch.ones((1, 2, 3))
     target = torch.ones((1, 1, 2, 3))
     with self.assertRaisesRegex(ValueError, ""):
         criterion(prediction, target)
Example #6
0
 def test_ill_reduction(self):
     """Building and calling DiceCELoss with ``reduction="none"`` must raise ``ValueError``."""
     # Construction stays inside the context manager: the error may come from
     # either the constructor or the call.
     with self.assertRaisesRegex(ValueError, ""):
         criterion = DiceCELoss(reduction="none")
         prediction = torch.ones((1, 2, 3))
         target = torch.ones((1, 1, 2, 3))
         criterion(prediction, target)
Example #7
0
def train(args):
    """Train a network for one task/fold with periodic validation.

    Reads every hyper-parameter from ``args``, builds the data loaders, the
    network, an SGD optimizer with a polynomial learning-rate schedule, a
    sliding-window evaluator and the trainer, configures logging on rank 0
    and finally runs the trainer.

    Args:
        args: namespace providing ``task_id``, ``fold``, ``expr_name``,
            ``interval``, ``learning_rate``, ``max_epochs``, ``multi_gpu``,
            ``amp``, ``lr_decay``, ``sw_batch_size``, ``tta_val``,
            ``batch_dice``, ``window_mode``, ``eval_overlap``,
            ``local_rank``, ``determinism_flag``, ``determinism_seed`` and
            ``checkpoint``. It is also forwarded unchanged to ``get_data``.
    """
    # load hyper parameters
    task_id = args.task_id
    fold = args.fold
    val_output_dir = "./runs_{}_fold{}_{}/".format(task_id, fold,
                                                   args.expr_name)
    log_filename = "nnunet_task{}_fold{}.log".format(task_id, fold)
    log_filename = os.path.join(val_output_dir, log_filename)
    interval = args.interval
    learning_rate = args.learning_rate
    max_epochs = args.max_epochs
    multi_gpu_flag = args.multi_gpu
    amp_flag = args.amp
    lr_decay_flag = args.lr_decay
    sw_batch_size = args.sw_batch_size
    tta_val = args.tta_val
    batch_dice = args.batch_dice
    window_mode = args.window_mode
    eval_overlap = args.eval_overlap
    local_rank = args.local_rank
    determinism_flag = args.determinism_flag
    determinism_seed = args.determinism_seed
    if determinism_flag:
        set_determinism(seed=determinism_seed)
        if local_rank == 0:
            print("Using deterministic training.")

    # transforms
    train_batch_size = data_loader_params[task_id]["batch_size"]
    if multi_gpu_flag:
        dist.init_process_group(backend="nccl", init_method="env://")

        # One process per GPU: bind this process to its own device.
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cuda")

    properties, val_loader = get_data(args, mode="validation")
    _, train_loader = get_data(args, batch_size=train_batch_size, mode="train")

    # produce the network
    checkpoint = args.checkpoint
    net = get_network(properties, task_id, val_output_dir, checkpoint)
    net = net.to(device)

    if multi_gpu_flag:
        net = DistributedDataParallel(module=net,
                                      device_ids=[device],
                                      find_unused_parameters=True)

    optimizer = torch.optim.SGD(
        net.parameters(),
        lr=learning_rate,
        momentum=0.99,
        weight_decay=3e-5,
        nesterov=True,
    )

    # Polynomial decay: the base LR is scaled by (1 - epoch / max_epochs) ** 0.9.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda epoch: (1 - epoch / max_epochs)**0.9)
    # produce evaluator
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointSaver(save_dir=val_output_dir,
                        save_dict={"net": net},
                        save_key_metric=True),
    ]

    evaluator = DynUNetEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        n_classes=len(properties["labels"]),
        inferer=SlidingWindowInferer(
            roi_size=patch_size[task_id],
            sw_batch_size=sw_batch_size,
            overlap=eval_overlap,
            mode=window_mode,
        ),
        post_transform=None,
        key_val_metric={
            "val_mean_dice":
            MeanDice(
                include_background=False,
                output_transform=lambda x: (x["pred"], x["label"]),
            )
        },
        val_handlers=val_handlers,
        amp=amp_flag,
        tta_val=tta_val,
    )
    # produce trainer
    loss = DiceCELoss(to_onehot_y=True, softmax=True, batch=batch_dice)
    train_handlers = []
    if lr_decay_flag:
        train_handlers += [
            LrScheduleHandler(lr_scheduler=scheduler, print_lr=True)
        ]

    train_handlers += [
        ValidationHandler(validator=evaluator,
                          interval=interval,
                          epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x["loss"]),
    ]

    trainer = DynUNetTrainer(
        device=device,
        max_epochs=max_epochs,
        train_data_loader=train_loader,
        network=net,
        optimizer=optimizer,
        loss_function=loss,
        inferer=SimpleInferer(),
        post_transform=None,
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=amp_flag,
    )

    # run
    # Only rank 0 logs. The handlers are created inside this branch because
    # ``logging.FileHandler`` opens its file as soon as it is constructed:
    # building it on every rank would leave an unused open file handle on all
    # the other ranks.
    if local_rank == 0:
        logger = logging.getLogger()

        formatter = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s")

        # Setup file handler (the target directory must exist beforehand)
        os.makedirs(val_output_dir, exist_ok=True)
        fhandler = logging.FileHandler(log_filename)
        fhandler.setLevel(logging.INFO)
        fhandler.setFormatter(formatter)

        # Configure stream handler for the cells
        chandler = logging.StreamHandler()
        chandler.setLevel(logging.INFO)
        chandler.setFormatter(formatter)

        # Add both handlers
        logger.addHandler(fhandler)
        logger.addHandler(chandler)
        logger.setLevel(logging.INFO)

    trainer.run()
Example #8
0
def main():
    """Fine-tune a UNETR segmentation model and keep the best checkpoint.

    Builds the training/validation transform chains and cached datasets from
    a Decathlon-style datalist, optionally loads ViT backbone weights into the
    UNETR encoder, then trains until ``max_iterations`` optimizer steps have
    been reached, validating every ``eval_num`` steps. Whenever the validation
    Dice improves, the model weights are written to
    ``<logdir>/best_metric_model.pth``; after every validation a figure with
    the loss and Dice curves is saved to ``logdir``.

    All paths and hyper-parameters are hard-coded at the top of the function;
    the ``'/to/be/defined'`` placeholders must be edited before running.
    """
    print_config()

    # Define paths for running the script
    data_dir = os.path.normpath('/to/be/defined')
    json_path = os.path.normpath('/to/be/defined')
    logdir = os.path.normpath('/to/be/defined')

    # If use_pretrained is set to 0, ViT weights will not be loaded and random initialization is used
    use_pretrained = 1
    pretrained_path = os.path.normpath('/to/be/defined')

    # Training Hyper-parameters
    lr = 1e-4
    max_iterations = 30000
    eval_num = 100

    # Creates missing parent directories too and tolerates an existing one.
    os.makedirs(logdir, exist_ok=True)
    best_model_path = os.path.join(logdir, "best_metric_model.pth")

    # Training & Validation Transform chain
    train_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-175,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[0],
            prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[1],
            prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[2],
            prob=0.10,
        ),
        RandRotate90d(
            keys=["image", "label"],
            prob=0.10,
            max_k=3,
        ),
        RandShiftIntensityd(
            keys=["image"],
            offsets=0.10,
            prob=0.50,
        ),
        ToTensord(keys=["image", "label"]),
    ])
    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        ScaleIntensityRanged(keys=["image"],
                             a_min=-175,
                             a_max=250,
                             b_min=0.0,
                             b_max=1.0,
                             clip=True),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        ToTensord(keys=["image", "label"]),
    ])

    datalist = load_decathlon_datalist(base_dir=data_dir,
                                       data_list_file_path=json_path,
                                       is_segmentation=True,
                                       data_list_key="training")

    val_files = load_decathlon_datalist(base_dir=data_dir,
                                        data_list_file_path=json_path,
                                        is_segmentation=True,
                                        data_list_key="validation")
    train_ds = CacheDataset(
        data=datalist,
        transform=train_transforms,
        cache_num=24,
        cache_rate=1.0,
        num_workers=4,
    )
    train_loader = DataLoader(train_ds,
                              batch_size=1,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True)
    val_ds = CacheDataset(data=val_files,
                          transform=val_transforms,
                          cache_num=6,
                          cache_rate=1.0,
                          num_workers=4)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            shuffle=False,
                            num_workers=4,
                            pin_memory=True)

    # Sanity check: report the shapes of the first validation case.
    case_num = 0
    img = val_ds[case_num]["image"]
    label = val_ds[case_num]["label"]
    img_shape = img.shape
    label_shape = label.shape
    print(f"image shape: {img_shape}, label shape: {label_shape}")

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = UNETR(
        in_channels=1,
        out_channels=14,
        img_size=(96, 96, 96),
        feature_size=16,
        hidden_size=768,
        mlp_dim=3072,
        num_heads=12,
        pos_embed="conv",
        norm_name="instance",
        res_block=True,
        dropout_rate=0.0,
    )

    # Load ViT backbone weights into UNETR
    if use_pretrained == 1:
        print('Loading Weights from the Path {}'.format(pretrained_path))
        # map_location keeps the load working when the checkpoint was written
        # on a device that is not available here.
        vit_dict = torch.load(pretrained_path, map_location=device)
        vit_weights = vit_dict['state_dict']

        #  Delete the following variable names conv3d_transpose.weight, conv3d_transpose.bias,
        #  conv3d_transpose_1.weight, conv3d_transpose_1.bias as they were used to match dimensions
        #  while pretraining with ViTAutoEnc and are not a part of ViT backbone (this is used in UNETR)
        vit_weights.pop('conv3d_transpose_1.bias')
        vit_weights.pop('conv3d_transpose_1.weight')
        vit_weights.pop('conv3d_transpose.bias')
        vit_weights.pop('conv3d_transpose.weight')

        model.vit.load_state_dict(vit_weights)
        print('Pretrained Weights Succesfully Loaded !')

    elif use_pretrained == 0:
        print(
            'No weights were loaded, all weights being used are randomly initialized!'
        )

    model.to(device)

    loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
    torch.backends.cudnn.benchmark = True
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

    post_label = AsDiscrete(to_onehot=14)
    post_pred = AsDiscrete(argmax=True, to_onehot=14)
    dice_metric = DiceMetric(include_background=True,
                             reduction="mean",
                             get_not_nans=False)
    global_step = 0
    dice_val_best = 0.0
    global_step_best = 0
    epoch_loss_values = []
    metric_values = []

    def validation(epoch_iterator_val, current_step):
        """Return the mean Dice over the whole validation iterator.

        ``current_step`` is only used for the progress-bar text. The model is
        left in eval mode; the caller switches it back to train mode.
        """
        model.eval()
        num_batches = 0

        with torch.no_grad():
            for batch in epoch_iterator_val:
                val_inputs = batch["image"].to(device)
                val_labels = batch["label"].to(device)
                val_outputs = sliding_window_inference(val_inputs,
                                                       (96, 96, 96), 4, model)
                val_labels_convert = [
                    post_label(val_label_tensor)
                    for val_label_tensor in decollate_batch(val_labels)
                ]
                val_output_convert = [
                    post_pred(val_pred_tensor)
                    for val_pred_tensor in decollate_batch(val_outputs)
                ]
                dice_metric(y_pred=val_output_convert, y=val_labels_convert)
                num_batches += 1
                # The metric accumulates until reset(), so this is the running
                # mean over the cases seen so far; it is for display only.
                running_dice = dice_metric.aggregate().item()
                epoch_iterator_val.set_description(
                    "Validate (%d / %d Steps) (dice=%2.5f)" %
                    (current_step, 10.0, running_dice))

            # Aggregate once over everything accumulated. Averaging the
            # per-step running means would over-weight the first cases.
            if num_batches:
                mean_dice_val = dice_metric.aggregate().item()
            else:
                mean_dice_val = float("nan")
            dice_metric.reset()

        return mean_dice_val

    def train(global_step, train_loader, dice_val_best, global_step_best):
        """Run one pass over ``train_loader``.

        Returns the updated ``(global_step, dice_val_best, global_step_best)``.
        """
        model.train()
        epoch_loss = 0
        epoch_iterator = tqdm(train_loader,
                              desc="Training (X / X Steps) (loss=X.X)",
                              dynamic_ncols=True)
        for step, batch in enumerate(epoch_iterator, start=1):
            x = batch["image"].to(device)
            y = batch["label"].to(device)
            logit_map = model(x)
            loss = loss_function(logit_map, y)
            loss.backward()
            loss_value = loss.item()
            epoch_loss += loss_value
            optimizer.step()
            optimizer.zero_grad()
            epoch_iterator.set_description(
                "Training (%d / %d Steps) (loss=%2.5f)" %
                (global_step, max_iterations, loss_value))

            if (global_step % eval_num == 0
                    and global_step != 0) or global_step == max_iterations:
                epoch_iterator_val = tqdm(
                    val_loader,
                    desc="Validate (X / X Steps) (dice=X.X)",
                    dynamic_ncols=True)
                dice_val = validation(epoch_iterator_val, global_step)
                # validation() put the model in eval mode; restore train mode
                # for the remaining steps of this pass.
                model.train()

                # Record the average loss so far without clobbering the
                # running accumulator.
                epoch_loss_values.append(epoch_loss / step)
                metric_values.append(dice_val)
                if dice_val > dice_val_best:
                    dice_val_best = dice_val
                    global_step_best = global_step
                    torch.save(model.state_dict(), best_model_path)
                    print(
                        "Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}"
                        .format(dice_val_best, dice_val))
                else:
                    print(
                        "Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}"
                        .format(dice_val_best, dice_val))

                # Refresh the loss / Dice curves on disk.
                plt.figure(1, (12, 6))
                plt.subplot(1, 2, 1)
                plt.title("Iteration Average Loss")
                loss_steps = [
                    eval_num * (i + 1) for i in range(len(epoch_loss_values))
                ]
                plt.xlabel("Iteration")
                plt.plot(loss_steps, epoch_loss_values)
                plt.grid()
                plt.subplot(1, 2, 2)
                plt.title("Val Mean Dice")
                metric_steps = [
                    eval_num * (i + 1) for i in range(len(metric_values))
                ]
                plt.xlabel("Iteration")
                plt.plot(metric_steps, metric_values)
                plt.grid()
                plt.savefig(
                    os.path.join(logdir, 'btcv_finetune_quick_update.png'))
                plt.clf()
                plt.close(1)

            global_step += 1
        return global_step, dice_val_best, global_step_best

    # Each call consumes the whole loader, so the final pass may run a few
    # steps beyond max_iterations.
    while global_step < max_iterations:
        global_step, dice_val_best, global_step_best = train(
            global_step, train_loader, dice_val_best, global_step_best)

    # A checkpoint only exists if some validation improved on the initial 0.0.
    if os.path.exists(best_model_path):
        model.load_state_dict(
            torch.load(best_model_path, map_location=device))
    else:
        print("No best-metric checkpoint was saved; keeping the final weights.")

    print(f"train completed, best_metric: {dice_val_best:.4f} "
          f"at iteration: {global_step_best}")
def tune_weight_decay_network(training_loader, network, weights, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), root_path=''):
    """Train a fresh copy of ``network`` per weight-decay value and log the losses.

    For every value in ``weights`` a deep copy of ``network`` is trained for
    three passes over ``training_loader`` with Adam (lr 5e-5) and a Dice + CE
    loss, and the loss of every training step is recorded.

    Args:
        training_loader: iterable of batches; each batch is indexed with
            ``"image"`` and ``"label"`` and ``len()`` is called on the loader.
        network: model to copy for each run. It is moved to ``device`` in
            place; its parameters are not trained here.
        weights: iterable of weight-decay values to try.
        device: device used for the model and the batches.
        root_path: unused; kept for interface compatibility.

    Returns:
        dict mapping ``str(weight)`` to a dict ``{step_index: loss}``, where
        ``step_index`` counts from 0 across all three passes and ``loss`` is
        the detached loss converted to a NumPy value.
    """
    # ``Module.cuda`` rejects a CPU device, which broke the CPU fallback of
    # the default argument; ``to`` works for any device.
    network.to(device)

    num_passes = 3
    weights_plots = {}

    # Train network per weight to tune
    for weight in weights:
        candidate = copy.deepcopy(network)

        optimizer = optim.Adam(candidate.parameters(), lr=5e-5, weight_decay=weight)
        dice_ce_loss = DiceCELoss(include_background=False, ce_weight=CE_WEIGHTS)

        # Train the network
        candidate.train(True)

        train_step = 1
        batch_loss = {}
        # Store loss against the weight decay parameter (filled in below).
        weights_plots[str(weight)] = batch_loss

        for _ in range(num_passes):
            for batch_data in training_loader:
                print(f'Weight: {weight} \tTraining Step: {train_step}/{len(training_loader)}')

                torch.cuda.empty_cache()  # Clear any unused variables
                inputs = batch_data["image"].to(device)

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Feed input data into the network to train
                outputs = candidate(inputs)

                # Input no longer needed here: drop the local reference
                # (no device-to-host copy required for that).
                del inputs
                torch.cuda.empty_cache()

                # Labels are moved to the device only now to limit peak memory.
                labels = batch_data["label"].to(device)
                torch.cuda.empty_cache()

                # Calculate DICE CE loss; dim 2 of both tensors is moved last.
                # NOTE(review): presumably (B, C, D, H, W) -> (B, C, H, W, D) - confirm.
                loss = dice_ce_loss(outputs.permute(0, 1, 3, 4, 2), labels.permute(0, 1, 3, 4, 2))

                # Record the loss of this step
                batch_loss[train_step - 1] = loss.detach().cpu().numpy()

                # Drop the label reference before the backward pass
                del labels
                torch.cuda.empty_cache()

                # Backward pass
                loss.backward()

                # Optimize
                optimizer.step()

                train_step += 1

    return weights_plots
def train_network(training_loader, val_loader, network, pre_load_training=False, checkpoint_name='', device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), root_path='', EPOCHS=10):
    """Train ``network`` with a Dice + CE loss, validating every second epoch.

    After every epoch a checkpoint (epoch, model and optimizer state, last
    loss, per-epoch training and validation losses) is written to
    ``root_path + '/<checkpoint_name>.pt'``.

    Args:
        training_loader: iterable of batches indexed with ``"image"`` and
            ``"label"``; ``len()`` is called on it for progress output.
        val_loader: iterable of validation batches with the same keys.
        network: model to train; moved to ``device`` and updated in place.
        pre_load_training: when true, resume from the checkpoint file.
        checkpoint_name: base name (without ``.pt``) of the checkpoint file.
        device: device used for the model and the batches.
        root_path: directory prefix of the checkpoint file.
        EPOCHS: exclusive upper bound of the epoch index.

    Returns:
        The trained ``network``.
    """
    # ``Module.cuda`` rejects a CPU device, which broke the CPU fallback of
    # the default argument; ``to`` works for any device.
    network.to(device)

    optimizer = optim.Adam(network.parameters(), lr=5e-5, weight_decay=1e-3)
    # NOTE: a learning-rate range test (a MultiplicativeLR scheduler stepped
    # after every batch, followed by a log-scale plot of loss against learning
    # rate) was sketched here earlier and is intentionally disabled.

    dice_ce_loss = DiceCELoss(include_background=False, ce_weight=CE_WEIGHTS)

    checkpoint_path = root_path + f'/{checkpoint_name}.pt'
    epoch_checkpoint = 0

    losses = {}
    val_losses = {}

    if pre_load_training:
        # map_location lets a checkpoint written on another device be resumed.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        epoch_checkpoint = checkpoint['epoch'] + 1
        network.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        loss = checkpoint['loss']
        losses = checkpoint['losses']
        val_losses = checkpoint['val_losses']

    # Train the network
    for epoch in range(epoch_checkpoint, EPOCHS):
        network.train(True)

        print(f'losses: {losses}')
        print(f'val losses {val_losses}')

        train_step = 1
        batch_loss = []

        for batch_data in training_loader:
            print(f'Epoch {epoch}\tTraining Step: {train_step}/{len(training_loader)}')

            torch.cuda.empty_cache()  # Clear any unused variables
            inputs = batch_data["image"].to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Feed input data into the network to train
            outputs = network(inputs)

            # Input no longer needed here: drop the local reference
            # (no device-to-host copy required for that).
            del inputs
            torch.cuda.empty_cache()

            # Labels are moved to the device only now to limit peak memory.
            labels = batch_data["label"].to(device)
            torch.cuda.empty_cache()

            # Calculate DICE CE loss; dim 2 of both tensors is moved last.
            # NOTE(review): presumably (B, C, D, H, W) -> (B, C, H, W, D) - confirm.
            loss = dice_ce_loss(outputs.permute(0, 1, 3, 4, 2), labels.permute(0, 1, 3, 4, 2))

            # List of losses for current epoch
            batch_loss.append(loss.detach().cpu().numpy())

            # Drop the label reference before the backward pass
            del labels
            torch.cuda.empty_cache()

            # Backward pass
            loss.backward()

            # Optimize
            optimizer.step()

            train_step += 1

        # Get average loss for current epoch
        losses[epoch] = np.mean(batch_loss)
        print(f'train losses {batch_loss} \nmean loss {losses[epoch]}')

        if epoch % 2 == 0:
            # Set network to eval mode
            network.train(False)
            # Disable gradient calculation and optimise memory
            with torch.no_grad():
                # Initialise validation loss
                dice_ce_test_loss = 0
                num_val_batches = 0
                for batch_data in val_loader:
                    # Get inputs from validation set
                    inputs = batch_data["image"].to(device)

                    # Make prediction on the whole volume (a sliding-window
                    # variant with roi (96, 96, 16) was considered earlier).
                    outputs = network(inputs)

                    # Memory optimization; the CPU copy of the last batch is
                    # displayed below.
                    inputs = inputs.cpu()
                    torch.cuda.empty_cache()
                    labels = batch_data["label"].to(device)

                    # Accumulate DICE CE loss validation error
                    dice_ce_test_loss += dice_ce_loss(outputs.permute(0, 1, 3, 4, 2), labels.permute(0, 1, 3, 4, 2))
                    num_val_batches += 1

                if num_val_batches:
                    # Average over the number of batches. Dividing by the last
                    # enumerate index would be off by one (and zero for a
                    # single-batch loader).
                    mean_val_loss = dice_ce_test_loss / num_val_batches
                    val_losses[epoch] = mean_val_loss

                    # Print errors
                    print(
                        "==== Epoch: " + str(epoch) +
                        " | DICE CE loss: " + str(numpy_from_tensor(mean_val_loss)) +
                        " | Total Loss: " + str(numpy_from_tensor(mean_val_loss)) + " =====")  # This is redundant code but will keep here incase we add more losses

                    # View slice at halfway point along dim 2
                    half = outputs.shape[2] // 2

                    # Show predictions for the last validation batch
                    view_slice(numpy_from_tensor(inputs[0, 0, half, :, :]), f'Input Image Epoch {epoch}')
                    view_slice(numpy_from_tensor(outputs[0, 0, half, :, :]), f'Predicted Background Epoch {epoch}')
                    view_slice(numpy_from_tensor(outputs[0, 1, half, :, :]), f'Predicted Pancreas Epoch {epoch}')
                    view_slice(numpy_from_tensor(outputs[0, 2, half, :, :]), f'Predicted Cancer Epoch {epoch}')
                    view_slice(numpy_from_tensor(labels[0, 0, half, :, :]), f'Labels Background Epoch {epoch}')
                    view_slice(numpy_from_tensor(labels[0, 1, half, :, :]), f'Labels Pancreas Epoch {epoch}')
                    view_slice(numpy_from_tensor(labels[0, 2, half, :, :]), f'Labels Cancer Epoch {epoch}')
                else:
                    print(f'Epoch {epoch}: validation loader is empty, validation skipped')

        # Save training checkpoint
        torch.save({
            'epoch': epoch,
            'model_state_dict': network.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            'losses': losses,
            'val_losses': val_losses,
        }, checkpoint_path)

        # Confirm current epoch trained params are saved
        print(f'Saved for epoch {epoch}')

    return network
    def test_train_timing(self):
        images = sorted(glob(os.path.join(self.data_dir, "img*.nii.gz")))
        segs = sorted(glob(os.path.join(self.data_dir, "seg*.nii.gz")))
        train_files = [{
            "image": img,
            "label": seg
        } for img, seg in zip(images[:32], segs[:32])]
        val_files = [{
            "image": img,
            "label": seg
        } for img, seg in zip(images[-9:], segs[-9:])]

        device = torch.device("cuda:0")
        # define transforms for train and validation
        train_transforms = Compose([
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Spacingd(keys=["image", "label"],
                     pixdim=(1.0, 1.0, 1.0),
                     mode=("bilinear", "nearest")),
            ScaleIntensityd(keys="image"),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            # pre-compute foreground and background indexes
            # and cache them to accelerate training
            FgBgToIndicesd(keys="label", fg_postfix="_fg", bg_postfix="_bg"),
            # change to execute transforms with Tensor data
            EnsureTyped(keys=["image", "label"]),
            # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
            ToDeviced(keys=["image", "label"], device=device),
            # randomly crop out patch samples from big
            # image based on pos / neg ratio
            # the image centers of negative samples
            # must be in valid image area
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=(64, 64, 64),
                pos=1,
                neg=1,
                num_samples=4,
                fg_indices_key="label_fg",
                bg_indices_key="label_bg",
            ),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(keys=["image", "label"], prob=0.5),
            RandRotate90d(keys=["image", "label"],
                          prob=0.5,
                          spatial_axes=(1, 2)),
            RandZoomd(keys=["image", "label"],
                      prob=0.5,
                      min_zoom=0.8,
                      max_zoom=1.2,
                      keep_size=True),
            RandRotated(
                keys=["image", "label"],
                prob=0.5,
                range_x=np.pi / 4,
                mode=("bilinear", "nearest"),
                align_corners=True,
                dtype=np.float64,
            ),
            RandAffined(keys=["image", "label"],
                        prob=0.5,
                        rotate_range=np.pi / 2,
                        mode=("bilinear", "nearest")),
            RandGaussianNoised(keys="image", prob=0.5),
            RandStdShiftIntensityd(keys="image",
                                   prob=0.5,
                                   factors=0.05,
                                   nonzero=True),
        ])

        val_transforms = Compose([
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Spacingd(keys=["image", "label"],
                     pixdim=(1.0, 1.0, 1.0),
                     mode=("bilinear", "nearest")),
            ScaleIntensityd(keys="image"),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            EnsureTyped(keys=["image", "label"]),
            # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
            ToDeviced(keys=["image", "label"], device=device),
        ])

        max_epochs = 5
        learning_rate = 2e-4
        val_interval = 1  # do validation for every epoch

        # set CacheDataset, ThreadDataLoader and DiceCE loss for MONAI fast training
        train_ds = CacheDataset(data=train_files,
                                transform=train_transforms,
                                cache_rate=1.0,
                                num_workers=8)
        val_ds = CacheDataset(data=val_files,
                              transform=val_transforms,
                              cache_rate=1.0,
                              num_workers=5)
        # disable multi-workers because `ThreadDataLoader` works with multi-threads
        train_loader = ThreadDataLoader(train_ds,
                                        num_workers=0,
                                        batch_size=4,
                                        shuffle=True)
        val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)

        loss_function = DiceCELoss(to_onehot_y=True,
                                   softmax=True,
                                   squared_pred=True,
                                   batch=True)
        model = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm=Norm.BATCH,
        ).to(device)

        # Novograd paper suggests to use a bigger LR than Adam,
        # because Adam does normalization by element-wise second moments
        optimizer = Novograd(model.parameters(), learning_rate * 10)
        scaler = torch.cuda.amp.GradScaler()

        post_pred = Compose(
            [EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
        post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

        dice_metric = DiceMetric(include_background=True,
                                 reduction="mean",
                                 get_not_nans=False)

        best_metric = -1
        total_start = time.time()
        for epoch in range(max_epochs):
            epoch_start = time.time()
            print("-" * 10)
            print(f"epoch {epoch + 1}/{max_epochs}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step_start = time.time()
                step += 1
                optimizer.zero_grad()
                # set AMP for training
                with torch.cuda.amp.autocast():
                    outputs = model(batch_data["image"])
                    loss = loss_function(outputs, batch_data["label"])
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                epoch_loss += loss.item()
                epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}"
                      f" step time: {(time.time() - step_start):.4f}")
            epoch_loss /= step
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    for val_data in val_loader:
                        roi_size = (96, 96, 96)
                        sw_batch_size = 4
                        # set AMP for validation
                        with torch.cuda.amp.autocast():
                            val_outputs = sliding_window_inference(
                                val_data["image"], roi_size, sw_batch_size,
                                model)

                        val_outputs = [
                            post_pred(i) for i in decollate_batch(val_outputs)
                        ]
                        val_labels = [
                            post_label(i)
                            for i in decollate_batch(val_data["label"])
                        ]
                        dice_metric(y_pred=val_outputs, y=val_labels)

                    metric = dice_metric.aggregate().item()
                    dice_metric.reset()
                    if metric > best_metric:
                        best_metric = metric
                    print(
                        f"epoch: {epoch + 1} current mean dice: {metric:.4f}, best mean dice: {best_metric:.4f}"
                    )
            print(
                f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}"
            )

        total_time = time.time() - total_start
        print(
            f"train completed, best_metric: {best_metric:.4f} total time: {total_time:.4f}"
        )
        # test expected metrics
        self.assertGreater(best_metric, 0.95)
Example #12
0
# Training setup: 3D UNet trained with DiceCELoss and the Novograd optimizer,
# using CUDA AMP gradient scaling.

# --- hyper-parameters ---
max_epochs = 6
val_interval = 2
learning_rate = 1e-4

# --- network ---
device = torch.device("cuda:0")
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
)
model = model.to(device)

# --- loss / optimizer / AMP scaler ---
loss_function = DiceCELoss(batch=True, squared_pred=True, softmax=True, to_onehot_y=True)
# Novograd is driven at 10x the base learning rate defined above.
optimizer = Novograd(model.parameters(), learning_rate * 10)
scaler = torch.cuda.amp.GradScaler()

# --- validation metric and post-processing ---
dice_metric = DiceMetric(get_not_nans=False, reduction="mean", include_background=True)
# Predictions: argmax followed by one-hot with 2 classes; labels: one-hot only.
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

# --- best-result bookkeeping (-1 means "nothing recorded yet") ---
best_metric = -1
best_metric_epoch = -1
# Three parallel lists; presumably best metric, epoch and elapsed time --
# TODO confirm against the training loop that fills them.
best_metrics_epochs_and_time = [[], [], []]
Example #13
0
 def loss_function(self, context: Context):
     """Return a freshly constructed ``DiceCELoss`` used as the loss.

     Args:
         context: accepted for interface compatibility; not read here.

     Returns:
         A ``DiceCELoss`` built with ``to_onehot_y``, ``softmax``,
         ``squared_pred`` and ``batch`` all set to ``True``.
     """
     loss_options = {
         "to_onehot_y": True,
         "softmax": True,
         "squared_pred": True,
         "batch": True,
     }
     return DiceCELoss(**loss_options)
Example #14
0
    def run_interaction(self, train):
        """Run a one-epoch ``SupervisedTrainer`` with a DeepEdit ``Interaction`` and check the result.

        Five synthetic samples are generated from a fixed NumPy seed, a 1x1x1
        ``Conv3d`` is used as the network, and ``add_one`` (defined elsewhere in
        this file) is attached to both inner-iteration events. After the run the
        method asserts that the first batch item carries a ``"guidance"`` entry
        and that ``engine.state.best_metric`` equals 1.

        Args:
            train: forwarded unchanged as ``Interaction(train=...)``.
        """
        label_names = {"spleen": 1, "background": 0}
        num_labels = len(label_names)

        # Fixed seed so the random images/labels are reproducible between runs.
        np.random.seed(0)
        data = []
        for _ in range(5):
            # Draw order (image first, then label) is kept on purpose.
            image = np.random.randint(0, 256, size=(1, 15, 15, 15)).astype(np.float32)
            label = np.random.randint(0, 2, size=(1, 15, 15, 15))
            data.append({"image": image, "label": label, "label_names": label_names})

        # NOTE(review): 3 input channels presumably = 1 image channel + one
        # guidance channel per label -- confirm against AddGuidanceSignalDeepEditd.
        net = torch.nn.Conv3d(3, num_labels, 1)
        optimizer = torch.optim.Adam(net.parameters(), 1e-3)
        criterion = DiceCELoss(to_onehot_y=True, softmax=True)

        # Transforms applied once per sample when the dataset is read.
        pre_transforms = Compose(
            [
                FindAllValidSlicesMissingLabelsd(keys="label", sids="sids"),
                AddInitialSeedPointMissingLabelsd(keys="label", guidance="guidance", sids="sids"),
                AddGuidanceSignalDeepEditd(keys="image", guidance="guidance", number_intensity_ch=1),
                ToTensord(keys=("image", "label")),
            ]
        )
        loader = torch.utils.data.DataLoader(Dataset(data, transform=pre_transforms), batch_size=5)

        # Transforms handed to the Interaction (four entries, asserted below).
        iteration_transforms = Compose(
            [
                FindDiscrepancyRegionsDeepEditd(keys="label", pred="pred", discrepancy="discrepancy"),
                AddRandomGuidanceDeepEditd(
                    keys="NA", guidance="guidance", discrepancy="discrepancy", probability="probability"
                ),
                AddGuidanceSignalDeepEditd(keys="image", guidance="guidance", number_intensity_ch=1),
                ToTensord(keys=("image", "label")),
            ]
        )
        # Transforms passed to the trainer as its post-processing step.
        post_transforms = Compose(
            [
                Activationsd(keys="pred", softmax=True),
                AsDiscreted(keys=("pred", "label"), argmax=(True, False), to_onehot=num_labels),
                SplitPredsLabeld(keys="pred"),
                ToTensord(keys=("image", "label")),
            ]
        )

        interaction = Interaction(
            deepgrow_probability=1.0,
            transforms=iteration_transforms,
            click_probability_key="probability",
            train=train,
            label_names=label_names,
        )
        self.assertEqual(len(interaction.transforms.transforms), 4, "Mismatch in expected transforms")

        # set up engine
        engine = SupervisedTrainer(
            device=torch.device("cpu"),
            max_epochs=1,
            train_data_loader=loader,
            network=net,
            optimizer=optimizer,
            loss_function=criterion,
            postprocessing=post_transforms,
            iteration_update=interaction,
        )
        # Same handler on both inner-iteration events, registered in the
        # original order (STARTED, then COMPLETED).
        for inner_event in (
            IterationEvents.INNER_ITERATION_STARTED,
            IterationEvents.INNER_ITERATION_COMPLETED,
        ):
            engine.add_event_handler(inner_event, add_one)

        engine.run()

        first_item = engine.state.batch[0]
        self.assertIsNotNone(first_item.get("guidance"), "guidance is missing")
        # NOTE(review): the expected value 1 depends on what add_one does to
        # engine.state.best_metric -- verify against its definition.
        self.assertEqual(engine.state.best_metric, 1)