Example #1
0
def run_test(batch_size=64, train_steps=100, device=torch.device("cuda:0")):
    """Train a small 2D UNet for one epoch on synthetic data using an Ignite trainer.

    Args:
        batch_size: number of synthetic samples per batch.
        train_steps: length of the synthetic dataset, i.e. samples per epoch.
        device: device the network is moved to and trained on.

    Returns:
        ``trainer.state.output`` after the single epoch (what it holds depends on the
        trainer's output transform — presumably the loss of the last iteration).
    """

    class _SyntheticSegmentation(Dataset):
        """Dataset that generates a fresh random 128x128 image/mask pair for every index."""

        def __len__(self):
            return train_steps

        def __getitem__(self, _unused_id):
            image, mask = create_test_image_2d(
                128, 128, noise_max=1, num_objs=4, num_seg_classes=1
            )
            # prepend a channel axis; the mask is additionally cast to float32
            return image[np.newaxis], mask[np.newaxis].astype(np.float32)

    network = UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(4, 8, 16, 32),
        strides=(2, 2, 2),
        num_res_units=2,
    ).to(device)

    dice_loss = DiceLoss(do_sigmoid=True)
    optimizer = torch.optim.Adam(network.parameters(), 1e-4)
    loader = DataLoader(_SyntheticSegmentation(), batch_size=batch_size)

    # the trailing positional False is forwarded unchanged (presumably Ignite's `non_blocking`)
    trainer = create_supervised_trainer(network, optimizer, dice_loss, device, False)
    trainer.run(loader, 1)
    return trainer.state.output
Example #2
0
    def test_epistemic_scoring(self):
        """Train a tiny 3D UNet, then check the entropy volume from EpistemicScoring.

        The score volume must have the input's spatial size, its sum must be a
        ``np.float32`` scalar, and that sum must exceed 3.0.
        """
        spatial_size = (20, 20, 20)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        keys = ["image", "label"]
        train_data = self.get_data(10, spatial_size)
        test_data = self.get_data(1, spatial_size)

        # dictionary-based transforms for training, array-based equivalents for inference
        train_transforms = Compose([
            AddChanneld(keys),
            CropForegroundd(keys, source_key="image"),
            DivisiblePadd(keys, 4),
        ])
        infer_transforms = Compose([AddChannel(), CropForeground(), DivisiblePad(4)])

        # cropped samples may end up with different sizes, so the collate function pads them
        train_loader = DataLoader(
            CacheDataset(train_data, train_transforms),
            batch_size=2,
            collate_fn=pad_list_data_collate,
        )

        model = UNet(3, 1, 1, channels=(6, 6), strides=(2, 2)).to(device)
        loss_function = DiceLoss(sigmoid=True)
        optimizer = torch.optim.Adam(model.parameters(), 1e-3)

        for _ in trange(10):
            running_loss = 0
            for batch in train_loader:
                images = batch["image"].to(device)
                labels = batch["label"].to(device)
                optimizer.zero_grad()
                batch_loss = loss_function(model(images), labels)
                batch_loss.backward()
                optimizer.step()
                running_loss += batch_loss.item()
            running_loss /= len(train_loader)

        scorer = EpistemicScoring(
            model=model,
            transforms=infer_transforms,
            roi_size=[20, 20, 20],
            num_samples=10,
        )
        # score a stack made of the same test image repeated three times
        image_stack = np.array([test_data["image"]] * 3)
        score_3d = scorer.entropy_3d_volume(image_stack)
        score_3d_sum = np.sum(score_3d)

        self.assertEqual(score_3d.shape, spatial_size)
        self.assertIsInstance(score_3d_sum, np.float32)
        self.assertGreater(score_3d_sum, 3.0)
def run_test(batch_size=64, train_steps=200, device=torch.device("cuda:0")):
    """Train a small 2D UNet for one epoch on augmented synthetic data with a manual loop.

    Args:
        batch_size: number of samples per batch.
        train_steps: length of the synthetic dataset, i.e. samples per epoch.
        device: device the network is moved to and trained on.

    Returns:
        ``(mean_loss, num_steps)``: the mean of the per-batch losses and the number of
        batches processed.
    """

    class _AugmentedSynthetic(Dataset):
        """Dataset producing a random image/mask pair, both passed through `transforms`."""

        def __init__(self, transforms):
            self.transforms = transforms

        def __len__(self):
            return train_steps

        def __getitem__(self, _unused_id):
            image, mask = create_test_image_2d(
                128, 128, noise_max=1, num_objs=4, num_seg_classes=1
            )
            # Re-seed the transforms with one shared seed before each call so the image
            # and its mask go through the same random crop/rotation.
            shared_seed = np.random.randint(2147483647)
            transformed = []
            for array in (image, mask):
                self.transforms.set_random_state(seed=shared_seed)
                transformed.append(self.transforms(array))
            return transformed[0], transformed[1]

    network = UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(4, 8, 16, 32),
        strides=(2, 2, 2),
        num_res_units=2,
    ).to(device)

    dice_loss = DiceLoss(do_sigmoid=True)
    optimizer = torch.optim.Adam(network.parameters(), 1e-2)
    augment = Compose([
        AddChannel(),
        ScaleIntensity(),
        RandSpatialCrop((96, 96), random_size=False),
        RandRotate90(),
        ToTensor(),
    ])

    loader = DataLoader(_AugmentedSynthetic(augment), batch_size=batch_size, shuffle=True)

    network.train()
    total_loss = 0
    steps = 0
    for steps, (images, masks) in enumerate(loader, start=1):
        optimizer.zero_grad()
        batch_loss = dice_loss(network(images.to(device)), masks.to(device))
        batch_loss.backward()
        optimizer.step()
        total_loss += batch_loss.item()

    return total_loss / steps, steps
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
# Loss: equally weighted (0.5 / 0.5) Dice + cross-entropy, with labels converted to one-hot.
loss_function = DiceCELoss(
    include_background=True,
    to_onehot_y=True,
    softmax=True,
    lambda_dice=0.5,
    lambda_ce=0.5,
)
optimizer = torch.optim.Adam(model.parameters(), 1e-3)
dice_metric = DiceMetric(include_background=False, reduction="mean")
# Scale the learning rate by 0.5 when the monitored quantity stops increasing ("max" mode).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5)

# ## Execute a typical PyTorch training process

epoch_num = 300
val_interval = 2
best_metric = best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=2)
post_label = AsDiscrete(to_onehot=True, n_classes=2)
Example #5
0
def _create_brats_train_loader(args):
    """Build the BraTS training DataLoader (random crop, flip and intensity augmentation).

    Reads ``args.dir``, ``args.cache_rate``, ``args.batch_size`` and ``args.workers``.
    """
    train_transforms = Compose([
        # load 4 Nifti images and stack them together
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Spacingd(keys=["image", "label"],
                 pixdim=(1.5, 1.5, 2.0),
                 mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        RandSpatialCropd(keys=["image", "label"],
                         roi_size=[128, 128, 64],
                         random_size=False),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
        ToTensord(keys=["image", "label"]),
    ])
    train_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=train_transforms,
        section="training",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=True,
    )
    return DataLoader(train_ds,
                      batch_size=args.batch_size,
                      shuffle=True,
                      num_workers=args.workers,
                      pin_memory=True)


def _create_brats_val_loader(args):
    """Build the BraTS validation DataLoader (deterministic center crop, no augmentation).

    Reads ``args.dir``, ``args.cache_rate``, ``args.batch_size`` and ``args.workers``.
    """
    val_transforms = Compose([
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Spacingd(keys=["image", "label"],
                 pixdim=(1.5, 1.5, 2.0),
                 mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        CenterSpatialCropd(keys=["image", "label"], roi_size=[128, 128, 64]),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ToTensord(keys=["image", "label"]),
    ])
    val_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=val_transforms,
        section="validation",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=False,
    )
    return DataLoader(val_ds,
                      batch_size=args.batch_size,
                      shuffle=False,
                      num_workers=args.workers,
                      pin_memory=True)


def _create_brats_model(args, device):
    """Instantiate the network selected by ``args.network`` and move it to ``device``.

    ``"UNet"`` selects a 3D UNet with residual units; any other value selects
    SegResNet. Both take 4 input channels and produce 3 output channels.
    """
    if args.network == "UNet":
        return UNet(
            dimensions=3,
            in_channels=4,
            out_channels=3,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
    return SegResNet(in_channels=4,
                     out_channels=3,
                     init_filters=16,
                     dropout_prob=0.2).to(device)


def main_worker(args):
    """Per-process entry point for distributed (NCCL) BraTS training and validation.

    Every GPU runs one process. Only rank 0 writes TensorBoard logs and saves
    ``final_model.pth``; processes with ``args.local_rank != 0`` have stdout/stderr
    redirected to ``os.devnull``.

    Reads from ``args``: ``local_rank``, ``dir``, ``log_dir``, ``network``, ``lr``,
    ``epochs``, ``print_freq``, ``val_interval``, plus the attributes used by the
    loader helpers and by ``train()``.

    Raises:
        FileNotFoundError: if ``args.dir`` does not exist.
    """
    # disable logging for processes except 0 on every node
    if args.local_rank != 0:
        # intentionally left open: it backs sys.stdout/sys.stderr for the rest of the process
        devnull = open(os.devnull, "w")
        sys.stdout = sys.stderr = devnull
    if not os.path.exists(args.dir):
        raise FileNotFoundError(f"Missing directory {args.dir}")

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    total_start = time.time()
    train_loader = _create_brats_train_loader(args)
    val_loader = _create_brats_val_loader(args)

    is_main_process = dist.get_rank() == 0
    # Logging for TensorBoard (rank 0 only); bound explicitly to None elsewhere
    writer = SummaryWriter(log_dir=args.log_dir) if is_main_process else None

    # create UNet/SegResNet, DiceLoss and Adam optimizer
    device = torch.device(f"cuda:{args.local_rank}")
    model = _create_brats_model(args, device)
    loss_function = DiceLoss(to_onehot_y=False,
                             sigmoid=True,
                             squared_pred=True)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=1e-5,
                                 amsgrad=True)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[args.local_rank])

    # start a typical PyTorch training
    total_epoch = args.epochs
    best_metric = -1000000
    best_metric_epoch = -1
    epoch_time = AverageMeter("Time", ":6.3f")
    progress = ProgressMeter(total_epoch, [epoch_time], prefix="Epoch: ")
    end = time.time()
    print(f"Time elapsed before training: {end-total_start}")
    for epoch in range(total_epoch):

        train_loss = train(train_loader, model, loss_function, optimizer,
                           epoch, args, device)
        epoch_time.update(time.time() - end)

        if epoch % args.print_freq == 0:
            progress.display(epoch)

        if is_main_process:
            writer.add_scalar("Loss/train", train_loss, epoch)

        if (epoch + 1) % args.val_interval == 0:
            # evaluated on every rank; only rank 0 logs the result
            metric, metric_tc, metric_wt, metric_et = evaluate(
                model, val_loader, device)

            if is_main_process:
                writer.add_scalar("Mean Dice/val", metric, epoch)
                writer.add_scalar("Mean Dice TC/val", metric_tc, epoch)
                writer.add_scalar("Mean Dice WT/val", metric_wt, epoch)
                writer.add_scalar("Mean Dice ET/val", metric_et, epoch)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                print(
                    f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                    f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
                    f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
                )
        end = time.time()
        print(f"Time elapsed after epoch {epoch + 1} is {end - total_start}")

    if is_main_process:
        print(
            f"train completed, best_metric: {best_metric:.4f}  at epoch: {best_metric_epoch}"
        )
        # all processes should see same parameters as they all start from same
        # random parameters and gradients are synchronized in backward passes,
        # therefore, saving it in one process is sufficient
        # NOTE(review): this saves the DistributedDataParallel wrapper's state_dict;
        # confirm the loading side expects its key layout (vs. model.module.state_dict()).
        torch.save(model.state_dict(), "final_model.pth")
        # close() flushes pending events and releases the event file
        writer.close()
    dist.destroy_process_group()
    def test_test_time_augmentation(self):
        """Train a tiny 2D UNet, then check the outputs of TestTimeAugmentation.

        Checks that `mode`, `mean` and `std` have shape ``(1,) + input_size``, that
        `mode` contains exactly the values 0 and 1, that `mean` lies in [0, 1], and
        that `vvc` is a float.
        """
        input_size = (20, 40)  # test different input data shape to pad list collate
        keys = ["image", "label"]
        num_training_ims = 10

        train_data = self.get_data(num_training_ims, input_size)
        test_data = self.get_data(1, input_size)
        device = "cuda" if torch.cuda.is_available() else "cpu"

        transforms = Compose(
            [
                AddChanneld(keys),
                RandAffined(
                    keys,
                    prob=1.0,
                    spatial_size=(30, 30),
                    rotate_range=(np.pi / 3, np.pi / 3),
                    translate_range=(3, 3),
                    scale_range=((0.8, 1), (0.8, 1)),
                    padding_mode="zeros",
                    mode=("bilinear", "nearest"),
                    as_tensor_output=False,
                ),
                CropForegroundd(keys, source_key="image"),
                DivisiblePadd(keys, 4),
            ]
        )

        train_ds = CacheDataset(train_data, transforms)
        # output might be different size, so pad so that they match
        train_loader = DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate)

        model = UNet(2, 1, 1, channels=(6, 6), strides=(2, 2)).to(device)
        loss_function = DiceLoss(sigmoid=True)
        optimizer = torch.optim.Adam(model.parameters(), 1e-3)

        # Brief training so the model produces non-trivial predictions; the loss
        # value itself is not checked.
        num_epochs = 10
        for _ in trange(num_epochs):
            for batch_data in train_loader:
                inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()

        post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

        tt_aug = TestTimeAugmentation(
            transform=transforms,
            batch_size=5,
            num_workers=0,
            inferrer_fn=model,
            device=device,
            to_tensor=True,
            output_device="cpu",
            post_func=post_trans,
        )
        mode, mean, std, vvc = tt_aug(test_data)
        self.assertEqual(mode.shape, (1,) + input_size)
        self.assertEqual(mean.shape, (1,) + input_size)
        # Compare as a list: `all(np.unique(mode) == (0, 1))` errors out (shape-mismatched
        # elementwise comparison) instead of failing cleanly when there are not exactly
        # two unique values.
        self.assertEqual(list(np.unique(mode)), [0, 1])
        self.assertGreaterEqual(mean.min(), 0.0)
        self.assertLessEqual(mean.max(), 1.0)
        self.assertEqual(std.shape, (1,) + input_size)
        self.assertIsInstance(vvc, float)
Example #7
0
class UNet2DSegmenter(AbstractBaseLearner):
    """Segmenter based on the U-Net architecture."""

    def __init__(
            self,
            architecture: SegmentationArchitectures = SegmentationArchitectures.ResidualUNet2D,
            loss: SegmentationLosses = SegmentationLosses.GeneralizedDiceLoss,
            optimizer: Optimizers = Optimizers.Adam,
            mask_type: MaskType = MaskType.TIFF_LABELS,
            in_channels: int = 1,
            out_channels: int = 3,
            roi_size: Tuple[int, int] = (384, 384),
            num_filters_in_first_layer: int = 16,
            learning_rate: float = 0.001,
            weight_decay: float = 0.0001,
            momentum: float = 0.9,
            num_epochs: int = 400,
            batch_sizes: Tuple[int, int, int, int] = (8, 1, 1, 1),
            num_workers: Tuple[int, int, int, int] = (4, 4, 1, 1),
            validation_step: int = 2,
            sliding_window_batch_size: int = 4,
            class_names: Tuple[str, ...] = ("Background", "Object", "Border"),
            experiment_name: str = "Unet",
            model_name: str = "best_model",
            seed: int = 4294967295,
            working_dir: str = '.',
            stdout: TextIOWrapper = sys.stdout,
            stderr: TextIOWrapper = sys.stderr
    ):
        """Constructor.

        Stores the configuration only; transforms, datasets, data loaders, model,
        loss, optimizer and metric are all left as None here and created later.

        @param architecture: SegmentationArchitectures
            Core network architecture: one of (SegmentationArchitectures.ResidualUNet2D, SegmentationArchitectures.AttentionUNet2D)

        @param loss: SegmentationLosses
            Loss function: currently only SegmentationLosses.GeneralizedDiceLoss is supported

        @param optimizer: Optimizers
            Optimizer: one of (Optimizers.Adam, Optimizers.SGD)

        @param mask_type: MaskType
            Type of mask: defines file type, mask geometry and the way pixels
            are assigned to the various classes.

            @see qu.data.model.MaskType

        @param in_channels: int, optional: default = 1
            Number of channels in the input (e.g. 1 for gray-value images).

        @param out_channels: int, optional: default = 3
            Number of channels in the output (classes).

        @param roi_size: Tuple[int, int], optional: default = (384, 384)
            Crop area (and input size of the U-Net network) used for training and validation/prediction.

        @param num_filters_in_first_layer: int, optional: default = 16
            Number of filters in the first layer. Every subsequent layer doubles the number of filters.

        @param learning_rate: float, optional: default = 1e-3
            Initial learning rate for the optimizer.

        @param weight_decay: float, optional: default = 1e-4
            Weight decay (weight penalty) applied by the optimizer.
            Used by the Adam optimizer.

        @param momentum: float, optional: default = 0.9
            Momentum of the accelerated gradient for the optimizer.
            Used by the SGD optimizer.

        @param num_epochs: int, optional: default = 400
            Number of epochs for training.

        @param batch_sizes: Tuple[int, int, int, int], optional: default = (8, 1, 1, 1)
            Batch sizes for training, validation, testing, and prediction, respectively.
            Must contain (at least) four values.

        @param num_workers: Tuple[int, int, int, int], optional: default = (4, 4, 1, 1)
            Number of workers for training, validation, testing, and prediction, respectively.
            Must contain (at least) four values.

        @param validation_step: int, optional: default = 2
            Validation is run every `validation_step` epochs.

        @param sliding_window_batch_size: int, optional: default = 4
            Batch size passed to sliding window inference during validation and prediction.

        @param class_names: Tuple[str, ...], optional: default = ("Background", "Object", "Border")
            Name of the classes for logging validation curves. Only the first
            `out_channels` names are used; missing names are filled with "Unknown".

        @param experiment_name: str, optional: default = "Unet"
            Name of the experiment that maps to the folder that contains training information (to
            be used by tensorboard). Please note, current datetime will be appended.

        @param model_name: str, optional: default = "best_model"
            Name of the file that stores the best model. Please note, current datetime will be appended
            (before the extension).

        @param seed: int, optional; default = 4294967295
            Random seed passed to monai's set_determinism().

        @param working_dir: str, optional, default = "."
            Working folder where to save the model weights and the logs for tensorboard.
            It is resolved to an absolute path.

        @param stdout: TextIOWrapper, optional: default = sys.stdout
            Stream progress messages are printed to.

        @param stderr: TextIOWrapper, optional: default = sys.stderr
            Stream stored for error output.
        """

        # Call base constructor
        super().__init__()

        # Standard pipe wrappers
        self._stdout = stdout
        self._stderr = stderr

        # Device (initialize as "cpu")
        self._device = "cpu"

        # Architecture, loss function and optimizer
        self._option_architecture = architecture
        self._option_loss = loss
        self._option_optimizer = optimizer
        self._learning_rate = learning_rate
        self._weight_decay = weight_decay
        self._momentum = momentum

        # Mask type
        self._mask_type = mask_type

        # Input and output channels
        self._in_channels = in_channels
        self._out_channels = out_channels

        # Define hyper parameters
        self._roi_size = roi_size
        self._num_filters_in_first_layer = num_filters_in_first_layer
        self._training_batch_size = batch_sizes[0]
        self._validation_batch_size = batch_sizes[1]
        self._test_batch_size = batch_sizes[2]
        self._prediction_batch_size = batch_sizes[3]
        self._training_num_workers = num_workers[0]
        self._validation_num_workers = num_workers[1]
        self._test_num_workers = num_workers[2]
        self._prediction_num_workers = num_workers[3]
        self._n_epochs = num_epochs
        self._validation_step = validation_step
        self._sliding_window_batch_size = sliding_window_batch_size

        # Other parameters: one name per output channel, padded with "Unknown"
        self._class_names = out_channels * ["Unknown"]
        for i in range(min(out_channels, len(class_names))):
            self._class_names[i] = class_names[i]

        # Set monai seed
        set_determinism(seed=seed)

        # All file names
        self._train_image_names: list = []
        self._train_mask_names: list = []
        self._validation_image_names: list = []
        self._validation_mask_names: list = []
        self._test_image_names: list = []
        self._test_mask_names: list = []

        # Transforms
        self._train_image_transforms = None
        self._train_mask_transforms = None
        self._validation_image_transforms = None
        self._validation_mask_transforms = None
        self._test_image_transforms = None
        self._test_mask_transforms = None
        self._prediction_image_transforms = None
        self._validation_post_transforms = None
        self._test_post_transforms = None
        self._prediction_post_transforms = None

        # Datasets and data loaders
        self._train_dataset = None
        self._train_dataloader = None
        self._validation_dataset = None
        self._validation_dataloader = None
        self._test_dataset = None
        self._test_dataloader = None
        self._prediction_dataset = None
        self._prediction_dataloader = None

        # Set model architecture, loss function, metric and optimizer
        self._model = None
        self._training_loss_function = None
        self._optimizer = None
        self._validation_metric = None

        # Working directory, model file name and experiment name for Tensorboard logs.
        # The file names will be redefined at the beginning of the training.
        self._working_dir = Path(working_dir).resolve()
        self._raw_experiment_name = experiment_name
        self._raw_model_file_name = model_name

        # Keep track of the full path of the best model
        self._best_model = ''

        # Keep track of last error message
        self._message = ""

    def train(self) -> bool:
        """Train the model on the configured data, validating every `validation_step` epochs.

        The caller may run this in a worker thread; nothing in this method starts one.

        Validation uses sliding-window inference over `roi_size`. Whenever the
        global validation metric improves, the model weights are saved to the
        best-model file (also stored in `self._best_model`).

        @return True if training completed, False if the training/validation data
            are missing or inconsistent (the reason is stored in `self._message`).
        """

        # Free memory on the GPU
        self._clear_session()

        # Check that the data is set properly
        if len(self._train_image_names) == 0 or \
                len(self._train_mask_names) == 0 or \
                len(self._validation_image_names) == 0 or \
                len(self._validation_mask_names) == 0:
            self._message = "No training/validation data found."
            return False

        # Plain `!=` comparisons: the previous `a != b == 0` was a chained comparison
        # (`a != b and b == 0`) that could never be true after the emptiness check
        # above, so mismatched image/mask lists were silently accepted.
        if len(self._train_image_names) != len(self._train_mask_names):
            self._message = "The number of training images does not match the number of training masks."
            return False

        if len(self._validation_image_names) != len(self._validation_mask_names):
            self._message = "The number of validation images does not match the number of validation masks."
            return False

        # NOTE(review): unlike test_predict(), this method does not select a CUDA device;
        # self._device keeps its current value ("cpu" after construction) unless one of
        # the _define_*() helpers changes it — confirm this is intended.

        # Define the transforms
        self._define_training_transforms()

        # Define the datasets and data loaders
        self._define_training_data_loaders()

        # Instantiate the model
        self._define_model()

        # Define the loss function
        self._define_training_loss()

        # Define the optimizer (with default parameters)
        self._define_optimizer()

        # Define the validation metric
        self._define_validation_metric()

        # Define experiment name and model name
        experiment_name, model_file_name = self._prepare_experiment_and_model_names()

        # Keep track of the best model file name
        self._best_model = model_file_name

        # Best global validation metric so far and the (1-based) epoch it occurred in
        best_metric = -1
        best_metric_epoch = -1

        epoch_loss_values = list()
        metric_values = list()

        # Number of batches per epoch (ceiling division). It does not change between
        # batches, so compute it once; integer arithmetic also avoids printing e.g.
        # "Batch 1/2.0" when the dataset size is an exact multiple of the batch size.
        batch_size = self._train_dataloader.batch_size
        epoch_len = (len(self._train_dataset) + batch_size - 1) // batch_size

        # Initialize TensorBoard's SummaryWriter (closed even if training raises)
        writer = SummaryWriter(experiment_name)

        try:
            for epoch in range(self._n_epochs):

                # Inform
                self._print_header(f"Epoch {epoch + 1}/{self._n_epochs}")

                # Switch to training mode
                self._model.train()

                epoch_loss = 0
                step = 0
                for batch_data in self._train_dataloader:

                    # Update step
                    step += 1

                    # Get the next batch and move it to device
                    inputs, labels = batch_data[0].to(self._device), batch_data[1].to(self._device)

                    # Zero the gradient buffers
                    self._optimizer.zero_grad()

                    # Forward pass
                    outputs = self._model(inputs)

                    # Calculate the loss
                    loss = self._training_loss_function(outputs, labels)

                    # Back-propagate
                    loss.backward()

                    # Update weights (optimize)
                    self._optimizer.step()

                    # Update and store metrics
                    epoch_loss += loss.item()
                    print(f"Batch {step}/{epoch_len}: train_loss = {loss.item():.4f}", file=self._stdout)

                # Guard against an empty loader (e.g. drop_last with a tiny dataset)
                if step > 0:
                    epoch_loss /= step
                epoch_loss_values.append(epoch_loss)
                print(f"Average loss = {epoch_loss:.4f}", file=self._stdout)
                writer.add_scalar("average_train_loss", epoch_loss, epoch + 1)

                # Validation
                if (epoch + 1) % self._validation_step == 0:

                    self._print_header("Validation")

                    # Switch to evaluation mode
                    self._model.eval()

                    # Make sure not to update the gradients
                    with torch.no_grad():

                        # Global metric accumulators (sum weighted by the non-NaN count)
                        metric_sum = 0.0
                        metric_count = 0

                        # Per-class metric accumulators
                        metric_sum_classes = self._out_channels * [0.0]
                        metric_count_classes = self._out_channels * [0]

                        for val_data in self._validation_dataloader:

                            # Get the next batch and move it to device
                            val_images, val_labels = val_data[0].to(self._device), val_data[1].to(self._device)

                            # Apply sliding inference over ROI size
                            val_outputs = sliding_window_inference(
                                val_images,
                                self._roi_size,
                                self._sliding_window_batch_size,
                                self._model
                            )
                            val_outputs = self._validation_post_transforms(val_outputs)

                            # Compute overall metric
                            value, not_nans = self._validation_metric(
                                y_pred=val_outputs,
                                y=val_labels
                            )
                            not_nans = not_nans.item()
                            metric_count += not_nans
                            metric_sum += value.item() * not_nans

                            # Compute metric for each class
                            for c in range(self._out_channels):
                                value_obj, not_nans = self._validation_metric(
                                    y_pred=val_outputs[:, c:c + 1],
                                    y=val_labels[:, c:c + 1]
                                )
                                not_nans = not_nans.item()
                                metric_count_classes[c] += not_nans
                                metric_sum_classes[c] += value_obj.item() * not_nans

                        # Average the accumulated metrics. A count can total 0 (presumably
                        # when no batch yields a non-NaN value, e.g. a class absent from
                        # every validation mask); report NaN instead of raising
                        # ZeroDivisionError and aborting the training.
                        metric = metric_sum / metric_count if metric_count > 0 else float("nan")
                        metric_values.append(metric)
                        metric_classes = [
                            metric_sum_classes[c] / metric_count_classes[c]
                            if metric_count_classes[c] > 0 else float("nan")
                            for c in range(self._out_channels)
                        ]

                        # Print summary
                        print(f"Global metric = {metric:.4f} ", file=self._stdout)
                        for c in range(self._out_channels):
                            print(f"Class '{self._class_names[c]}' metric = {metric_classes[c]:.4f} ", file=self._stdout)

                        # Do we have the best metric so far? (NaN never compares greater.)
                        if metric > best_metric:
                            best_metric = metric
                            best_metric_epoch = epoch + 1
                            torch.save(
                                self._model.state_dict(),
                                model_file_name
                            )
                            print(f"New best global metric = {best_metric:.4f} at epoch: {best_metric_epoch}", file=self._stdout)
                            print(f"Saved best model '{Path(model_file_name).name}'", file=self._stdout)

                        # Add validation loss and metrics to log
                        writer.add_scalar("val_mean_dice_loss", metric, epoch + 1)
                        for c in range(self._out_channels):
                            metric_name = f"val_{self._class_names[c].lower()}_metric"
                            writer.add_scalar(metric_name, metric_classes[c], epoch + 1)

            print(f"Training completed. Best_metric = {best_metric:.4f} at epoch: {best_metric_epoch}", file=self._stdout)
        finally:
            writer.close()

        # Return success
        return True

    def test_predict(
            self,
            target_folder: Union[Path, str] = '',
            model_path: Union[Path, str] = ''
    ) -> bool:
        """Run prediction on predefined test data.

        @param target_folder: Path|str, optional: default = ''
            Path to the folder where to store the predicted images. If not specified,
            if defaults to '{working_dir}/predictions'. See constructor.

        @param model_path: Path|str, optional: default = ''
            Full path to the model to use. If omitted and a training was
            just run, the path to the model with the best metric is
            already stored and will be used.

            @see get_best_model_path()

        @return True if the prediction was successful, False otherwise.
        """

        # Inform
        self._print_header("Test prediction")

        # Get the device
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # If the model is not in memory, instantiate it first
        if self._model is None:
            self._define_model()

        # If the path to the best model was not set, use current one (if set)
        if model_path == '':
            model_path = self.get_best_model_path()

        # Try loading the model weights: they must be compatible
        # with the model in memory
        try:
            checkpoint = torch.load(
                model_path,
                map_location=torch.device('cpu')
            )
            self._model.load_state_dict(checkpoint)
            print(f"Loaded best metric model {model_path}.", file=self._stdout)
        except Exception as e:
            self._message = "Error: there was a problem loading the model! Aborting."
            return False

        # If the target folder is not specified, set it to the standard predictions out
        if target_folder == '':
            target_folder = Path(self._working_dir) / "tests"
        else:
            target_folder = Path(target_folder)
        target_folder.mkdir(parents=True, exist_ok=True)

        # Switch to evaluation mode
        self._model.eval()

        indx = 0

        # Make sure not to update the gradients
        with torch.no_grad():
            for test_data in self._test_dataloader:

                # Get the next batch and move it to device
                test_images, test_masks = test_data[0].to(self._device), test_data[1].to(self._device)

                # Apply sliding inference over ROI size
                test_outputs = sliding_window_inference(
                    test_images,
                    self._roi_size,
                    self._sliding_window_batch_size,
                    self._model
                )
                test_outputs = self._test_post_transforms(test_outputs)

                # Retrieve the image from the GPU (if needed)
                pred = test_outputs.cpu().numpy().squeeze()

                # Prepare the output file name
                basename = os.path.splitext(os.path.basename(self._test_image_names[indx]))[0]
                basename = basename.replace('train_', 'pred_')

                # Convert to label image
                label_img = self._prediction_to_label_tiff_image(pred)

                # Save label image as tiff file
                label_file_name = os.path.join(
                    str(target_folder),
                    basename + '.tif')
                with TiffWriter(label_file_name) as tif:
                    tif.save(label_img)

                # Inform
                print(f"Saved {str(target_folder)}/{basename}.tif", file=self._stdout)

                # Update the index
                indx += 1

        # Inform
        print(f"Test prediction completed.", file=self._stdout)

        # Return success
        return True

    def predict(self,
                input_folder: Union[Path, str],
                target_folder: Union[Path, str],
                model_path: Union[Path, str]
                ) -> bool:
        """Run prediction on the '*.tif' images found in a folder.

        @param input_folder: Path|str
            Path to the folder containing the images to predict.

        @param target_folder: Path|str
            Path to the folder where to store the predicted images. It is
            created if it does not exist; an empty string is rejected.

        @param model_path: Path|str
            Full path to the model to use.

        @return True if the prediction was successful, False otherwise
            (see get_message() for the reason of a failure).
        """
        # Inform
        self._print_header("Prediction")

        # Get the device
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # If the model is not in memory, instantiate it first
        if self._model is None:
            self._define_model()

        # Try loading the model weights: they must be compatible
        # with the model in memory.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        try:
            checkpoint = torch.load(
                model_path,
                map_location=torch.device('cpu')
            )
            self._model.load_state_dict(checkpoint)
            print(f"Loaded best metric model {model_path}.", file=self._stdout)
        except Exception as e:
            # Surface the reason instead of silently discarding it
            print(f"Could not load model '{model_path}': {e}", file=self._stdout)
            self._message = "Error: there was a problem loading the model! Aborting."
            return False

        # Make sure the target folder is specified
        if isinstance(target_folder, str) and target_folder == '':
            self._message = "Error: please specify a valid target folder! Aborting."
            return False

        target_folder = Path(target_folder)
        target_folder.mkdir(parents=True, exist_ok=True)

        # Get prediction dataloader
        if not self._define_prediction_data_loaders(input_folder):
            self._message = "Error: could not instantiate prediction dataloader! Aborting."
            return False

        # Switch to evaluation mode
        self._model.eval()

        # Running index into self._prediction_image_names: the prediction
        # dataloader does not shuffle, so outputs follow the file name order.
        indx = 0

        # Make sure not to update the gradients
        with torch.no_grad():
            for prediction_data in self._prediction_dataloader:

                # Get the next batch and move it to device
                prediction_images = prediction_data.to(self._device)

                # Apply sliding inference over ROI size
                prediction_outputs = sliding_window_inference(
                    prediction_images,
                    self._roi_size,
                    self._sliding_window_batch_size,
                    self._model
                )
                prediction_outputs = self._prediction_post_transforms(prediction_outputs)

                # A batch may contain more than one image: save one label
                # image per batch element, each with its own file name.
                for prediction_output in prediction_outputs:

                    # Retrieve the image from the GPU (if needed)
                    pred = prediction_output.cpu().numpy().squeeze()

                    # Prepare the output file name
                    basename = os.path.splitext(os.path.basename(self._prediction_image_names[indx]))[0]
                    basename = "pred_" + basename

                    # Convert to label image
                    label_img = self._prediction_to_label_tiff_image(pred)

                    # Save label image as tiff file
                    label_file_name = os.path.join(
                        str(target_folder),
                        basename + '.tif')
                    with TiffWriter(label_file_name) as tif:
                        tif.save(label_img)

                    # Inform
                    print(f"Saved {str(target_folder)}/{basename}.tif", file=self._stdout)

                    # Update the index
                    indx += 1

        # Inform
        print("Prediction completed.", file=self._stdout)

        # Return success
        return True

    def set_training_data(self,
                          train_image_names,
                          train_mask_names,
                          val_image_names,
                          val_mask_names,
                          test_image_names,
                          test_mask_names) -> None:
        """Store the file names of the training, validation and test sets.

        Images and masks are paired by position, so each image list must have
        the same length as its mask list. All pairs are validated before
        anything is stored, so a failure leaves the previous data untouched.

        @param train_image_names: list
            List of training image names.

        @param train_mask_names: list
            List of training mask names.

        @param val_image_names: list
            List of validation image names.

        @param val_mask_names: list
            List of validation mask names.

        @param test_image_names: list
            List of test image names.

        @param test_mask_names: list
            List of test mask names.

        @raise ValueError if the number of images and masks of any set differ.
        """

        # (images, masks, error raised on length mismatch), checked in order
        pairs_to_check = (
            (train_image_names, train_mask_names,
             "The number of training images does not match the number of training masks."),
            (val_image_names, val_mask_names,
             "The number of validation images does not match the number of validation masks."),
            (test_image_names, test_mask_names,
             "The number of test images does not match the number of test masks."),
        )
        for images, masks, error_message in pairs_to_check:
            if len(images) != len(masks):
                raise ValueError(error_message)

        # Everything is consistent: store the names
        self._train_image_names, self._train_mask_names = train_image_names, train_mask_names
        self._validation_image_names, self._validation_mask_names = val_image_names, val_mask_names
        self._test_image_names, self._test_mask_names = test_image_names, test_mask_names

    @staticmethod
    def _prediction_to_label_tiff_image(prediction):
        """Convert a one-hot prediction stack into a uint16 label image.

        The conversion is delegated to one_hot_stack_to_label_image, with the
        channels on the first axis and the first channel treated as background.
        Nothing is written to disk here; callers save the result as TIFF.
        """
        return one_hot_stack_to_label_image(
            prediction,
            dtype=np.uint16,
            channels_first=True,
            first_index_is_background=True,
        )

    def _define_training_transforms(self):
        """Define and initialize all training data transforms.

          * training set images transform
          * training set masks transform
          * validation set images transform
          * validation set masks transform
          * validation set images post-transform
          * test set images transform
          * test set masks transform
          * test set images post-transform

        The prediction transforms are defined by _define_prediction_transforms().

        @raise ValueError if the mask type is unknown.
        @raise NotImplementedError if the mask type is H5_ONE_HOT (not supported yet).
        """

        if self._mask_type == MaskType.UNKNOWN:
            raise ValueError("The mask type is unknown. Cannot continue!")

        # The H5_ONE_HOT type would require a dedicated mask loader that does not exist yet
        if self._mask_type == MaskType.H5_ONE_HOT:
            raise NotImplementedError("HDF5 one-hot masks are not supported yet!")

        # The mask loader depends on the mask type
        mask_loader = LoadMask(self._mask_type)

        # Label masks are converted to one-hot; all other types are passed through.
        # Identity must be an *instance*: Compose calls every element with the data,
        # and calling the Identity class would try to construct it with the data.
        if self._mask_type in (MaskType.TIFF_LABELS, MaskType.NUMPY_LABELS):
            mask_transform = ToOneHot(num_classes=self._out_channels)
        else:
            mask_transform = Identity()

        # Define transforms for training
        self._train_image_transforms = Compose(
            [
                LoadImage(image_only=True),
                ScaleIntensity(),
                AddChannel(),
                RandSpatialCrop(self._roi_size, random_size=False),
                RandRotate90(prob=0.5, spatial_axes=(0, 1)),
                ToTensor()
            ]
        )
        self._train_mask_transforms = Compose(
            [
                mask_loader,
                mask_transform,
                RandSpatialCrop(self._roi_size, random_size=False),
                RandRotate90(prob=0.5, spatial_axes=(0, 1)),
                ToTensor()
            ]
        )

        # Define transforms for validation
        self._validation_image_transforms = Compose(
            [
                LoadImage(image_only=True),
                ScaleIntensity(),
                AddChannel(),
                ToTensor()
            ]
        )
        self._validation_mask_transforms = Compose(
            [
                mask_loader,
                mask_transform,
                ToTensor()
            ]
        )

        # Define transforms for testing
        self._test_image_transforms = Compose(
            [
                LoadImage(image_only=True),
                ScaleIntensity(),
                AddChannel(),
                ToTensor()
            ]
        )
        self._test_mask_transforms = Compose(
            [
                mask_loader,
                mask_transform,
                ToTensor()
            ]
        )

        # Post transforms: softmax activation followed by thresholding
        self._validation_post_transforms = Compose(
            [
                Activations(softmax=True),
                AsDiscrete(threshold_values=True)
            ]
        )

        self._test_post_transforms = Compose(
            [
                Activations(softmax=True),
                AsDiscrete(threshold_values=True)
            ]
        )

    def _define_training_data_loaders(self) -> bool:
        """Initialize training, validation and test datasets and data loaders.

        @Note: in Windows, it is essential to set `persistent_workers=True` in the data loaders!
            It is only enabled for loaders with `num_workers > 0`, because torch's
            DataLoader rejects `persistent_workers=True` when no workers are used.

        @return True if datasets and data loaders could be instantiated, False otherwise
            (when any image or mask name list is empty; all datasets and data loaders
            are then reset to None).
        """

        # Optimize arguments
        if sys.platform == 'win32':
            persistent_workers = True
            pin_memory = False
        else:
            persistent_workers = False
            pin_memory = torch.cuda.is_available()

        name_lists = (
            self._train_image_names,
            self._train_mask_names,
            self._validation_image_names,
            self._validation_mask_names,
            self._test_image_names,
            self._test_mask_names,
        )
        if any(len(names) == 0 for names in name_lists):

            self._train_dataset = None
            self._train_dataloader = None
            self._validation_dataset = None
            self._validation_dataloader = None
            self._test_dataset = None
            self._test_dataloader = None

            return False

        def make_loader(dataset, batch_size, num_workers):
            # Non-shuffling loader with the platform-specific options chosen above
            return DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=False,
                num_workers=num_workers,
                persistent_workers=persistent_workers and num_workers > 0,
                pin_memory=pin_memory
            )

        # Training
        # NOTE(review): the training loader does not shuffle (shuffle=False);
        # confirm this is intentional, training batches are usually shuffled.
        self._train_dataset = ArrayDataset(
            self._train_image_names,
            self._train_image_transforms,
            self._train_mask_names,
            self._train_mask_transforms
        )
        self._train_dataloader = make_loader(
            self._train_dataset,
            self._training_batch_size,
            self._training_num_workers
        )

        # Validation
        self._validation_dataset = ArrayDataset(
            self._validation_image_names,
            self._validation_image_transforms,
            self._validation_mask_names,
            self._validation_mask_transforms
        )
        self._validation_dataloader = make_loader(
            self._validation_dataset,
            self._validation_batch_size,
            self._validation_num_workers
        )

        # Test
        self._test_dataset = ArrayDataset(
            self._test_image_names,
            self._test_image_transforms,
            self._test_mask_names,
            self._test_mask_transforms
        )
        self._test_dataloader = make_loader(
            self._test_dataset,
            self._test_batch_size,
            self._test_num_workers
        )

        return True

    def _define_prediction_transforms(self):
        """Define and initialize the prediction data transforms.

          * prediction set images transform: load the image, scale its
            intensity, add a channel dimension and convert it to a tensor
          * prediction set images post-transform: softmax activation followed
            by thresholding

        The results are stored in self._prediction_image_transforms and
        self._prediction_post_transforms; nothing is returned.
        """

        image_steps = [
            LoadImage(image_only=True),
            ScaleIntensity(),
            AddChannel(),
            ToTensor(),
        ]
        self._prediction_image_transforms = Compose(image_steps)

        post_steps = [
            Activations(softmax=True),
            AsDiscrete(threshold_values=True),
        ]
        self._prediction_post_transforms = Compose(post_steps)

    def _define_prediction_data_loaders(
            self,
            prediction_folder_path: Union[Path, str]
    ) -> bool:
        """Initialize prediction dataset and data loader.

        @param prediction_folder_path: Path|str
            Folder containing the images to predict. All '*.tif' files directly
            inside it are used, in natural sort order.

        @Note: in Windows, it is essential to set `persistent_workers=True` in the data loaders!
            It is only enabled when `num_workers > 0`, because torch's DataLoader
            rejects `persistent_workers=True` when no workers are used.

        @return True if dataset and data loader could be instantiated, False otherwise
            (folder does not exist or contains no '*.tif' files). On failure the
            prediction dataset and data loader are reset to None.
        """

        # Check that the path exists; on failure, drop any loader left over
        # from a previous call so that it cannot be used by mistake
        prediction_folder_path = Path(prediction_folder_path)
        if not prediction_folder_path.is_dir():
            self._prediction_dataset = None
            self._prediction_dataloader = None
            return False

        # Scan for images (natural sort order, e.g. 'img2' before 'img10')
        self._prediction_image_names = natsorted(
            glob(str(prediction_folder_path / "*.tif"))
        )

        if len(self._prediction_image_names) == 0:

            self._prediction_dataset = None
            self._prediction_dataloader = None

            return False

        # Optimize arguments
        if sys.platform == 'win32':
            persistent_workers = True
            pin_memory = False
        else:
            persistent_workers = False
            pin_memory = torch.cuda.is_available()

        # Persistent workers only make sense (and are only accepted) with workers
        num_workers = self._test_num_workers
        persistent_workers = persistent_workers and num_workers > 0

        # Define the transforms
        self._define_prediction_transforms()

        # Prediction
        self._prediction_dataset = Dataset(
            self._prediction_image_names,
            self._prediction_image_transforms
        )
        self._prediction_dataloader = DataLoader(
            self._prediction_dataset,
            batch_size=self._test_batch_size,
            shuffle=False,
            num_workers=num_workers,
            persistent_workers=persistent_workers,
            pin_memory=pin_memory
        )

        return True

    def get_message(self):
        """Return the most recent error message stored on this object."""
        last_message = self._message
        return last_message

    def get_best_model_path(self):
        """Return the full path to the best model stored on this object."""
        best_model_path = self._best_model
        return best_model_path

    def _clear_session(self) -> None:
        """Try clearing the CUDA memory cache when not running on the CPU.

        self._device may be a torch.device or a device string. A torch.device
        never compares equal to a plain string, so its `.type` is compared
        instead; strings (or anything without `.type`) are compared as-is.
        """
        device_type = getattr(self._device, "type", self._device)
        if device_type != "cpu":
            torch.cuda.empty_cache()

    def _define_model(self) -> None:
        """Instantiate the requested segmentation architecture on the best available device.

        Supported architectures:
          * SegmentationArchitectures.ResidualUNet2D: MONAI UNet with 5 levels whose
            channel counts double from self._num_filters_in_first_layer, stride 2
            between levels and 2 residual units.
          * SegmentationArchitectures.AttentionUNet2D: Attention U-Net.

        @raise ValueError if the architecture is not supported.
        """

        # Select the device
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device '{self._device}'.", file=self._stdout)

        # Try to free memory on the GPU. Compare the device *type*: a
        # torch.device never compares equal to a plain string.
        if self._device.type != "cpu":
            torch.cuda.empty_cache()

        # Instantiate the requested model
        if self._option_architecture == SegmentationArchitectures.ResidualUNet2D:
            # Monai's UNet: one stride-2 downsampling between consecutive levels
            channels = tuple(self._num_filters_in_first_layer * 2 ** i for i in range(5))
            self._model = UNet(
                dimensions=2,
                in_channels=self._in_channels,
                out_channels=self._out_channels,
                channels=channels,
                strides=(2,) * (len(channels) - 1),
                num_res_units=2
            ).to(self._device)

        elif self._option_architecture == SegmentationArchitectures.AttentionUNet2D:

            # Attention U-Net
            self._model = AttentionUNet2D(
                img_ch=self._in_channels,
                output_ch=self._out_channels,
                n1=self._num_filters_in_first_layer
            ).to(self._device)

        else:
            raise ValueError(f"Unexpected architecture {self._option_architecture}! Aborting.")

    def _define_training_loss(self) -> None:
        """Instantiate the training loss selected by self._option_loss.

        Only SegmentationLosses.GeneralizedDiceLoss is supported. It is configured
        with softmax on the predictions, targets used as given (to_onehot_y=False),
        the background channel included, and batch=True.

        @raise ValueError if the loss option is not supported.
        """
        if self._option_loss != SegmentationLosses.GeneralizedDiceLoss:
            raise ValueError(f"Unknown loss option {self._option_loss}! Aborting.")

        self._training_loss_function = GeneralizedDiceLoss(
            include_background=True,
            to_onehot_y=False,
            softmax=True,
            batch=True,
        )

    def _define_optimizer(self) -> None:
        """Instantiate the optimizer selected by self._option_optimizer.

        Does nothing if the model has not been defined yet.

          * Optimizers.Adam: AMSGrad variant, using self._learning_rate and
            self._weight_decay.
          * Optimizers.SGD: using self._learning_rate and self._momentum.

        @raise ValueError if the optimizer option is not supported.
        """
        model = self._model
        if model is None:
            return

        option = self._option_optimizer
        if option == Optimizers.Adam:
            optimizer = Adam(
                model.parameters(),
                self._learning_rate,
                weight_decay=self._weight_decay,
                amsgrad=True
            )
        elif option == Optimizers.SGD:
            # NOTE(review): unlike Adam, SGD does not receive self._weight_decay;
            # confirm this is intentional.
            optimizer = SGD(
                model.parameters(),
                lr=self._learning_rate,
                momentum=self._momentum
            )
        else:
            raise ValueError(f"Unknown optimizer option {self._option_optimizer}! Aborting.")

        self._optimizer = optimizer

    def _define_validation_metric(self):
        """Instantiate the Dice metric used during validation.

        The background channel is included and results are averaged
        (reduction="mean").
        """
        metric = DiceMetric(reduction="mean", include_background=True)
        self._validation_metric = metric

    def _prepare_experiment_and_model_names(self) -> Tuple[str, str]:
        """Build time-stamped experiment and model file paths.

        Both paths live in '{working_dir}/runs' (created if needed) and end with
        the current date and time formatted as '%Y%m%d_%H%M%S':

          * experiment: '{raw_experiment_name}_{architecture}_{timestamp}', or
            '{architecture}_{timestamp}' when no raw experiment name is set;
          * model: '{stem of raw model file name}_{timestamp}.pth'.

        @return experiment_file_name, model_file_name (full paths, as strings)
        """
        runs_dir = Path(self._working_dir) / "runs"
        runs_dir.mkdir(parents=True, exist_ok=True)

        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        architecture = str(self._option_architecture)

        if self._raw_experiment_name != "":
            experiment_stem = f"{self._raw_experiment_name}_{architecture}_{stamp}"
        else:
            experiment_stem = f"{architecture}_{stamp}"

        model_stem = Path(self._raw_model_file_name).stem

        return (
            str(runs_dir / experiment_stem),
            str(runs_dir / f"{model_stem}_{stamp}.pth"),
        )

    def _print_header(self, header_text, line_length=80, file=None):
        """Print a section header: the text framed by two dashed lines.

        @param header_text: str
            Text of the header.

        @param line_length: int, optional: default = 80
            Number of dashes in each framing line.

        @param file: file-like, optional: default = None
            Stream to print to; defaults to self._stdout. All three lines go
            to the same stream.
        """
        if file is None:
            file = self._stdout
        separator = line_length * '-'
        print(separator, file=file)
        print(f"{header_text}", file=file)
        print(separator, file=file)
Exemple #8
0
    def configure(self):
        """Build the network, data pipelines and MONAI training/evaluation engines.

        Steps, in order:
          * select the device (self.set_device()) and build a 3D UNet with 1 input
            and 2 output channels; when self.multi_gpu is set, wrap it in
            DistributedDataParallel;
          * build the training data pipeline from the "training" section of
            self.data_list_file_path (partitioned across ranks when multi-GPU);
          * build the validation data pipeline from the "validation" section;
          * create self.eval_engine (SupervisedEvaluator with sliding-window
            inference) and self.train_engine (SupervisedTrainer with Adam, DiceLoss
            and a StepLR schedule), the latter running validation every
            self.val_interval epochs;
          * on processes with self.local_rank > 0, lower both engines' log level
            to WARNING.
        """
        self.set_device()
        # 3D UNet, 1 input channel, 2 output channels, batch normalization
        network = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm=Norm.BATCH,
        ).to(self.device)
        if self.multi_gpu:
            network = DistributedDataParallel(
                module=network,
                device_ids=[self.device],
                find_unused_parameters=False,
            )

        # Training transforms: load, channel-first, resample to 1mm spacing,
        # clip intensities to [-57, 164] and rescale to [0, 1], crop to the
        # foreground, then sample 4 random 96^3 patches (pos=1, neg=1) and
        # randomly shift intensities.
        train_transforms = Compose([
            LoadImaged(keys=("image", "label")),
            EnsureChannelFirstd(keys=("image", "label")),
            Spacingd(keys=("image", "label"),
                     pixdim=[1.0, 1.0, 1.0],
                     mode=["bilinear", "nearest"]),
            ScaleIntensityRanged(
                keys="image",
                a_min=-57,
                a_max=164,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=("image", "label"), source_key="image"),
            RandCropByPosNegLabeld(
                keys=("image", "label"),
                label_key="label",
                spatial_size=(96, 96, 96),
                pos=1,
                neg=1,
                num_samples=4,
                image_key="image",
                image_threshold=0,
            ),
            RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
            ToTensord(keys=("image", "label")),
        ])
        train_datalist = load_decathlon_datalist(self.data_list_file_path,
                                                 True, "training")
        # In multi-GPU mode, each rank keeps only its own shard of the list
        if self.multi_gpu:
            train_datalist = partition_dataset(
                data=train_datalist,
                shuffle=True,
                num_partitions=dist.get_world_size(),
                even_divisible=True,
            )[dist.get_rank()]
        train_ds = CacheDataset(
            data=train_datalist,
            transform=train_transforms,
            cache_num=32,
            cache_rate=1.0,
            num_workers=4,
        )
        train_data_loader = DataLoader(
            train_ds,
            batch_size=2,
            shuffle=True,
            num_workers=4,
        )
        # Validation transforms: same preprocessing as training but without
        # resampling, patch sampling or random augmentation
        val_transforms = Compose([
            LoadImaged(keys=("image", "label")),
            EnsureChannelFirstd(keys=("image", "label")),
            ScaleIntensityRanged(
                keys="image",
                a_min=-57,
                a_max=164,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=("image", "label"), source_key="image"),
            ToTensord(keys=("image", "label")),
        ])

        val_datalist = load_decathlon_datalist(self.data_list_file_path, True,
                                               "validation")
        # NOTE(review): assuming MONAI's positional order (data, transform,
        # cache_num, cache_rate, num_workers), cache_rate=0.0 means no validation
        # item is cached despite cache_num=9 -- confirm this is intended.
        val_ds = CacheDataset(val_datalist, val_transforms, 9, 0.0, 4)
        val_data_loader = DataLoader(
            val_ds,
            batch_size=1,
            shuffle=False,
            num_workers=4,
        )
        # Softmax + argmax on predictions; one-hot encode predictions and labels (2 classes)
        post_transform = Compose([
            Activationsd(keys="pred", softmax=True),
            AsDiscreted(
                keys=["pred", "label"],
                argmax=[True, False],
                to_onehot=True,
                n_classes=2,
            ),
        ])
        # metric: mean Dice excluding the background channel
        key_val_metric = {
            "val_mean_dice":
            MeanDice(
                include_background=False,
                output_transform=lambda x: (x["pred"], x["label"]),
                device=self.device,
            )
        }
        # NOTE(review): in multi-GPU mode these handlers (including the
        # CheckpointSaver) are created on every rank; confirm that only one
        # rank is expected to write checkpoints.
        val_handlers = [
            StatsHandler(output_transform=lambda x: None),
            CheckpointSaver(
                save_dir=self.ckpt_dir,
                save_dict={"model": network},
                save_key_metric=True,
            ),
            TensorBoardStatsHandler(log_dir=self.ckpt_dir,
                                    output_transform=lambda x: None),
        ]
        self.eval_engine = SupervisedEvaluator(
            device=self.device,
            val_data_loader=val_data_loader,
            network=network,
            inferer=SlidingWindowInferer(
                roi_size=[160, 160, 160],
                sw_batch_size=4,
                overlap=0.5,
            ),
            post_transform=post_transform,
            key_val_metric=key_val_metric,
            val_handlers=val_handlers,
            amp=self.amp,
        )

        optimizer = torch.optim.Adam(network.parameters(), self.learning_rate)
        loss_function = DiceLoss(to_onehot_y=True, softmax=True)
        # Multiply the learning rate by 0.1 every 5000 scheduler steps
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       step_size=5000,
                                                       gamma=0.1)
        # Training handlers: LR scheduling, epoch-level validation every
        # self.val_interval epochs, and loss logging (console + TensorBoard)
        train_handlers = [
            LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
            ValidationHandler(validator=self.eval_engine,
                              interval=self.val_interval,
                              epoch_level=True),
            StatsHandler(tag_name="train_loss",
                         output_transform=lambda x: x["loss"]),
            TensorBoardStatsHandler(
                log_dir=self.ckpt_dir,
                tag_name="train_loss",
                output_transform=lambda x: x["loss"],
            ),
        ]

        self.train_engine = SupervisedTrainer(
            device=self.device,
            max_epochs=self.max_epochs,
            train_data_loader=train_data_loader,
            network=network,
            optimizer=optimizer,
            loss_function=loss_function,
            inferer=SimpleInferer(),
            post_transform=post_transform,
            key_train_metric=None,
            train_handlers=train_handlers,
            amp=self.amp,
        )

        # Only the process with local_rank 0 logs at the default level
        if self.local_rank > 0:
            self.train_engine.logger.setLevel(logging.WARNING)
            self.eval_engine.logger.setLevel(logging.WARNING)
    def test_train_timing(self):
        images = sorted(glob(os.path.join(self.data_dir, "img*.nii.gz")))
        segs = sorted(glob(os.path.join(self.data_dir, "seg*.nii.gz")))
        train_files = [{
            "image": img,
            "label": seg
        } for img, seg in zip(images[:32], segs[:32])]
        val_files = [{
            "image": img,
            "label": seg
        } for img, seg in zip(images[-9:], segs[-9:])]

        device = torch.device("cuda:0")
        # define transforms for train and validation
        train_transforms = Compose([
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Spacingd(keys=["image", "label"],
                     pixdim=(1.0, 1.0, 1.0),
                     mode=("bilinear", "nearest")),
            ScaleIntensityd(keys="image"),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            # pre-compute foreground and background indexes
            # and cache them to accelerate training
            FgBgToIndicesd(keys="label", fg_postfix="_fg", bg_postfix="_bg"),
            # change to execute transforms with Tensor data
            EnsureTyped(keys=["image", "label"]),
            # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
            ToDeviced(keys=["image", "label"], device=device),
            # randomly crop out patch samples from big
            # image based on pos / neg ratio
            # the image centers of negative samples
            # must be in valid image area
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=(64, 64, 64),
                pos=1,
                neg=1,
                num_samples=4,
                fg_indices_key="label_fg",
                bg_indices_key="label_bg",
            ),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(keys=["image", "label"], prob=0.5),
            RandRotate90d(keys=["image", "label"],
                          prob=0.5,
                          spatial_axes=(1, 2)),
            RandZoomd(keys=["image", "label"],
                      prob=0.5,
                      min_zoom=0.8,
                      max_zoom=1.2,
                      keep_size=True),
            RandRotated(
                keys=["image", "label"],
                prob=0.5,
                range_x=np.pi / 4,
                mode=("bilinear", "nearest"),
                align_corners=True,
                dtype=np.float64,
            ),
            RandAffined(keys=["image", "label"],
                        prob=0.5,
                        rotate_range=np.pi / 2,
                        mode=("bilinear", "nearest")),
            RandGaussianNoised(keys="image", prob=0.5),
            RandStdShiftIntensityd(keys="image",
                                   prob=0.5,
                                   factors=0.05,
                                   nonzero=True),
        ])

        val_transforms = Compose([
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Spacingd(keys=["image", "label"],
                     pixdim=(1.0, 1.0, 1.0),
                     mode=("bilinear", "nearest")),
            ScaleIntensityd(keys="image"),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            EnsureTyped(keys=["image", "label"]),
            # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
            ToDeviced(keys=["image", "label"], device=device),
        ])

        max_epochs = 5
        learning_rate = 2e-4
        val_interval = 1  # do validation for every epoch

        # set CacheDataset, ThreadDataLoader and DiceCE loss for MONAI fast training
        train_ds = CacheDataset(data=train_files,
                                transform=train_transforms,
                                cache_rate=1.0,
                                num_workers=8)
        val_ds = CacheDataset(data=val_files,
                              transform=val_transforms,
                              cache_rate=1.0,
                              num_workers=5)
        # disable multi-workers because `ThreadDataLoader` works with multi-threads
        train_loader = ThreadDataLoader(train_ds,
                                        num_workers=0,
                                        batch_size=4,
                                        shuffle=True)
        val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)

        loss_function = DiceCELoss(to_onehot_y=True,
                                   softmax=True,
                                   squared_pred=True,
                                   batch=True)
        model = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm=Norm.BATCH,
        ).to(device)

        # Novograd paper suggests to use a bigger LR than Adam,
        # because Adam does normalization by element-wise second moments
        optimizer = Novograd(model.parameters(), learning_rate * 10)
        scaler = torch.cuda.amp.GradScaler()

        post_pred = Compose(
            [EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
        post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

        dice_metric = DiceMetric(include_background=True,
                                 reduction="mean",
                                 get_not_nans=False)

        best_metric = -1
        total_start = time.time()
        for epoch in range(max_epochs):
            epoch_start = time.time()
            print("-" * 10)
            print(f"epoch {epoch + 1}/{max_epochs}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step_start = time.time()
                step += 1
                optimizer.zero_grad()
                # set AMP for training
                with torch.cuda.amp.autocast():
                    outputs = model(batch_data["image"])
                    loss = loss_function(outputs, batch_data["label"])
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                epoch_loss += loss.item()
                epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}"
                      f" step time: {(time.time() - step_start):.4f}")
            epoch_loss /= step
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    for val_data in val_loader:
                        roi_size = (96, 96, 96)
                        sw_batch_size = 4
                        # set AMP for validation
                        with torch.cuda.amp.autocast():
                            val_outputs = sliding_window_inference(
                                val_data["image"], roi_size, sw_batch_size,
                                model)

                        val_outputs = [
                            post_pred(i) for i in decollate_batch(val_outputs)
                        ]
                        val_labels = [
                            post_label(i)
                            for i in decollate_batch(val_data["label"])
                        ]
                        dice_metric(y_pred=val_outputs, y=val_labels)

                    metric = dice_metric.aggregate().item()
                    dice_metric.reset()
                    if metric > best_metric:
                        best_metric = metric
                    print(
                        f"epoch: {epoch + 1} current mean dice: {metric:.4f}, best mean dice: {best_metric:.4f}"
                    )
            print(
                f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}"
            )

        total_time = time.time() - total_start
        print(
            f"train completed, best_metric: {best_metric:.4f} total time: {total_time:.4f}"
        )
        # test expected metrics
        self.assertGreater(best_metric, 0.95)
Exemple #10
0
# Training configuration.
max_epochs, learning_rate, val_interval = 6, 1e-4, 2

# 3D residual U-Net with batch norm: 1 input channel, 2 output classes.
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
)
model = model.to(device)

loss_function = DiceCELoss(to_onehot_y=True, softmax=True, squared_pred=True, batch=True)
# Novograd is given a 10x larger learning rate than the base value.
optimizer = Novograd(model.parameters(), learning_rate * 10)
# Gradient scaler for mixed-precision (AMP) training.
scaler = torch.cuda.amp.GradScaler()
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

# Post-processing: predictions become one-hot via argmax, labels are one-hot encoded.
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

# Bookkeeping for the training loop.
best_metric = best_metric_epoch = -1
best_metrics_epochs_and_time = [[] for _ in range(3)]
epoch_loss_values, metric_values = [], []
class UNet_DF(pl.LightningModule):
    """3D U-Net segmentation model wrapped in a PyTorch Lightning module.

    ``hparams`` is expected to provide ``patch_size``, ``batch_size``,
    ``num_workers`` and ``lr`` (all of which are read below).
    """

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams

        # NOTE(review): 258 looks like a typo for 256; left unchanged so that
        # existing checkpoints keep loading -- confirm before changing.
        self.unet = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=2,
            channels=(64, 128, 258, 512, 1024),
            strides=(2, 2, 2, 2),
            norm=monai.networks.layers.Norm.BATCH,
            dropout=0,
        )
        # Middle-slice predictions gathered in validation_step and logged as
        # an image grid in validation_epoch_end.
        self.sample_masks = []

    # Data setup
    def setup(self, stage):
        """Read the image/mask list, split it 85/15 and build cached datasets."""
        data_df = pd.read_csv(
            '/data/shared/prostate/yale_prostate/input_lists/MR_yale.csv')

        train_imgs = data_df['IMAGE'][0:295].tolist()
        train_masks = data_df['SEGM'][0:295].tolist()

        train_dicts = [{
            'image': image,
            'mask': mask
        } for (image, mask) in zip(train_imgs, train_masks)]

        # NOTE: no random_state is passed, so the split differs between runs.
        train_dicts, val_dicts = train_test_split(train_dicts, test_size=0.15)

        # Basic transforms, kept as a flat list. Each dataset receives ONE flat
        # Compose: CacheDataset only pre-computes the leading non-random
        # transforms of its top-level Compose, and a nested Compose is itself
        # Randomizable, so nesting would stop caching at the first element and
        # re-load every volume from disk on each access.
        data_keys = ["image", "mask"]
        data_transforms = [
            LoadNiftid(keys=data_keys),
            AddChanneld(keys=data_keys),
            NormalizeIntensityd(keys="image"),
            # NOTE: also used for the validation set, so validation is run on
            # random patches (num_samples per volume), not whole volumes.
            RandCropByPosNegLabeld(keys=data_keys,
                                   label_key="mask",
                                   spatial_size=self.hparams.patch_size,
                                   num_samples=4,
                                   image_key="image"),
        ]

        self.train_dataset = monai.data.CacheDataset(
            data=train_dicts,
            transform=Compose(data_transforms + [ToTensord(keys=data_keys)]),
            cache_rate=1.0)

        self.val_dataset = monai.data.CacheDataset(
            data=val_dicts,
            transform=Compose(data_transforms + [ToTensord(keys=data_keys)]),
            cache_rate=1.0)

    def train_dataloader(self):
        return monai.data.DataLoader(self.train_dataset,
                                     batch_size=self.hparams.batch_size,
                                     shuffle=True,
                                     num_workers=self.hparams.num_workers)

    def val_dataloader(self):
        return monai.data.DataLoader(self.val_dataset,
                                     batch_size=self.hparams.batch_size,
                                     num_workers=self.hparams.num_workers)

    # Training setup
    def forward(self, image):
        return self.unet(image)

    def criterion(self, y_hat, y):
        """Sum of a softmax Dice loss (one-hot target) and a focal loss."""
        dice_loss = monai.losses.DiceLoss(to_onehot_y=True, softmax=True)
        focal_loss = monai.losses.FocalLoss()
        return dice_loss(y_hat, y) + focal_loss(y_hat, y)

    def training_step(self, batch, batch_idx):
        inputs, labels = batch['image'], batch['mask']
        outputs = self(inputs)
        loss = self.criterion(outputs, labels)

        self.logger.log_metrics({"loss/train": loss}, self.global_step)

        return {'loss': loss}

    def configure_optimizers(self):
        lr = self.hparams.lr
        optimizer = torch.optim.Adam(self.unet.parameters(), lr=lr)
        return optimizer

    def validation_step(self, batch, batch_idx):
        inputs, labels = (
            batch["image"],
            batch["mask"],
        )
        outputs = self(inputs)

        # Sample the middle slice (along the last axis) of the first predicted
        # mask in the batch; skipped at epoch 0.
        if self.current_epoch != 0:
            predicted = outputs[0].argmax(0)
            middle = predicted.shape[2] // 2
            self.sample_masks.append(
                predicted[:, :, middle].unsqueeze(0).detach())

        loss = self.criterion(outputs, labels)
        return {"val_loss": loss}

    def validation_epoch_end(self, outputs):
        """Log the mean validation loss and a grid of the sampled masks."""
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        self.logger.log_metrics({"loss/val": avg_loss}, self.current_epoch)

        if self.current_epoch != 0:
            grid = torchvision.utils.make_grid(self.sample_masks)
            self.logger.experiment.add_image('sample_masks', grid,
                                             self.current_epoch)
            self.sample_masks = []

        return {"val_loss": avg_loss}
class MaskGAN(pl.LightningModule):
    """3D U-Net mask generator trained with a Dice loss plus an adversarial term.

    The discriminator learns to tell ground-truth masks from generated ones.
    ``hparams`` is expected to provide ``patch_size``, ``batch_size``,
    ``num_workers`` and ``lr`` (all of which are read below).
    """

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams

        # NOTE(review): 258 looks like a typo for 256; left unchanged so that
        # existing checkpoints keep loading -- confirm before changing.
        self.generator = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=2,
            channels=(64, 128, 258, 512, 1024),
            strides=(2, 2, 2, 2),
            norm=monai.networks.layers.Norm.BATCH,
            dropout=0,
        )

        # NOTE(review): the discriminator is fed masks whose channel axis was
        # removed (argmax(1) / squeeze(1) below), so with in_shape=patch_size
        # the first spatial axis plays the role of the channel axis -- confirm
        # this is intended.
        self.discriminator = Discriminator(
            in_shape=self.hparams.patch_size,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            norm=monai.networks.layers.Norm.BATCH,
        )

        # Generator output from the latest generator step, reused by the
        # following discriminator step on the same batch.
        self.generated_masks = None
        # Middle-slice predictions gathered in validation_step.
        self.sample_masks = []

    # Data setup
    def setup(self, stage):
        """Read the image/mask list, split it 85/15 and build cached datasets."""
        data_df = pd.read_csv(
            '/data/shared/prostate/yale_prostate/input_lists/MR_yale.csv')

        train_imgs = data_df['IMAGE'][0:295].tolist()
        train_masks = data_df['SEGM'][0:295].tolist()

        train_dicts = [{
            'image': image,
            'mask': mask
        } for (image, mask) in zip(train_imgs, train_masks)]

        # NOTE: no random_state is passed, so the split differs between runs.
        train_dicts, val_dicts = train_test_split(train_dicts, test_size=0.15)

        # Basic transforms, kept as a flat list. Each dataset receives ONE flat
        # Compose: CacheDataset only pre-computes the leading non-random
        # transforms of its top-level Compose, and a nested Compose is itself
        # Randomizable, so nesting would stop caching at the first element and
        # re-load every volume from disk on each access.
        data_keys = ["image", "mask"]
        data_transforms = [
            LoadNiftid(keys=data_keys),
            AddChanneld(keys=data_keys),
            NormalizeIntensityd(keys="image"),
            # NOTE: also used for the validation set, so validation is run on
            # random patches (num_samples per volume), not whole volumes.
            RandCropByPosNegLabeld(keys=data_keys,
                                   label_key="mask",
                                   spatial_size=self.hparams.patch_size,
                                   num_samples=4,
                                   image_key="image"),
        ]

        self.train_dataset = monai.data.CacheDataset(
            data=train_dicts,
            transform=Compose(data_transforms + [ToTensord(keys=data_keys)]),
            cache_rate=1.0)

        self.val_dataset = monai.data.CacheDataset(
            data=val_dicts,
            transform=Compose(data_transforms + [ToTensord(keys=data_keys)]),
            cache_rate=1.0)

    def train_dataloader(self):
        return monai.data.DataLoader(self.train_dataset,
                                     batch_size=self.hparams.batch_size,
                                     shuffle=True,
                                     num_workers=self.hparams.num_workers)

    def val_dataloader(self):
        return monai.data.DataLoader(self.val_dataset,
                                     batch_size=self.hparams.batch_size,
                                     num_workers=self.hparams.num_workers)

    # Training setup
    def forward(self, image):
        return self.generator(image)

    def generator_loss(self, y_hat, y):
        """Softmax Dice loss between generator logits and the label mask."""
        dice_loss = monai.losses.DiceLoss(to_onehot_y=True, softmax=True)
        return dice_loss(y_hat, y)

    def adversarial_loss(self, y_hat, y):
        """Binary cross-entropy; ``y_hat`` must already be a probability."""
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Generator update for optimizer 0, discriminator update for optimizer 1."""
        inputs, labels = batch['image'], batch['mask']
        batch_size = inputs.size(0)
        # All auxiliary tensors are created directly on the batch's device
        # (works on CPU too and avoids a GPU -> CPU -> GPU round trip).
        device = inputs.device

        # Generator training
        if optimizer_idx == 0:
            self.generated_masks = self(inputs)

            # Loss from difference between real and generated masks
            g_loss = self.generator_loss(self.generated_masks, labels)

            # Loss from discriminator
            # The generator wants the discriminator to be wrong,
            # so the wrong labels are used
            # NOTE(review): argmax is not differentiable, so this adversarial
            # term sends no gradient back to the generator; only g_loss trains
            # it. A soft mask (e.g. softmax foreground probability) would be
            # needed for the adversarial signal to reach the generator.
            fake_labels = torch.ones(batch_size, 1, device=device)
            d_loss = self.adversarial_loss(
                self.discriminator(self.generated_masks.argmax(1).float()),
                fake_labels)

            avg_loss = g_loss + 0.5 * d_loss

            self.logger.log_metrics({"g_train/g_loss": g_loss},
                                    self.global_step)
            self.logger.log_metrics({"g_train/d_loss": d_loss},
                                    self.global_step)
            self.logger.log_metrics({"g_train/tot_loss": avg_loss},
                                    self.global_step)
            return {'loss': avg_loss}

        # Discriminator training
        # Learning real masks
        real_labels = torch.ones(batch_size, 1, device=device)
        real_loss = self.adversarial_loss(
            self.discriminator(labels.squeeze(1).float()), real_labels)

        # Learning "fake" masks produced in the preceding generator step
        fake_labels = torch.zeros(batch_size, 1, device=device)
        fake_loss = self.adversarial_loss(
            self.discriminator(
                self.generated_masks.argmax(1).detach().float()),
            fake_labels)

        avg_loss = real_loss + fake_loss

        self.logger.log_metrics({"d_train/real_loss": real_loss},
                                self.global_step)
        self.logger.log_metrics({"d_train/fake_loss": fake_loss},
                                self.global_step)
        self.logger.log_metrics({"d_train/tot_loss": avg_loss},
                                self.global_step)

        return {'loss': avg_loss}

    def configure_optimizers(self):
        """One Adam optimizer for the generator (index 0), one for the discriminator (index 1)."""
        lr = self.hparams.lr
        g_optimizer = torch.optim.Adam(self.generator.parameters(), lr=lr)
        d_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=lr)
        return [g_optimizer, d_optimizer], []

    def validation_step(self, batch, batch_idx):
        inputs, labels = (
            batch["image"],
            batch["mask"],
        )
        outputs = self(inputs)

        # Sample the middle slice (along the last axis) of the first predicted
        # mask in the batch; skipped at epoch 0.
        if self.current_epoch != 0:
            predicted = outputs[0].argmax(0)
            middle = predicted.shape[2] // 2
            self.sample_masks.append(
                predicted[:, :, middle].unsqueeze(0).detach())

        loss = self.generator_loss(outputs, labels)
        return {"val_loss": loss}

    def validation_epoch_end(self, outputs):
        """Log the mean validation loss and a grid of the sampled masks."""
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        self.logger.log_metrics({"val/loss": avg_loss}, self.current_epoch)

        if self.current_epoch != 0:
            grid = torchvision.utils.make_grid(self.sample_masks)
            self.logger.experiment.add_image('sample_masks', grid,
                                             self.current_epoch)
            self.sample_masks = []

        return {"val_loss": avg_loss}
Exemple #13
0
def train_process(fast=False):
    """Train a 3D U-Net segmenter and track the best validation mean Dice.

    Uses the module-level ``train_files`` / ``val_files`` lists and the
    ``transformations()`` helper. Whenever the validation mean Dice improves,
    the model weights are saved to ``'sLUMRTL644.pth'``.

    Args:
        fast: Unused; kept for interface compatibility.

    Returns:
        ``(epoch_num, epoch_loss_values, metric_values,
        best_metrics_epochs_and_time)`` where ``best_metrics_epochs_and_time``
        is ``[best metrics, epochs at which they occurred, []]`` -- the third
        list is never filled by this function.
    """
    epoch_num = 10
    val_interval = 1
    train_trans, val_trans = transformations()
    train_ds = Dataset(data=train_files, transform=train_trans)
    val_ds = Dataset(data=val_files, transform=val_trans)

    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=1)
    # Batches per epoch (the last may be partial); constant, so computed once.
    epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n1 = 16
    model = UNet(dimensions=3,
                 in_channels=1,
                 out_channels=2,
                 channels=(n1 * 1, n1 * 2, n1 * 4, n1 * 8, n1 * 16),
                 strides=(2, 2, 2, 2)).to(device)
    loss_function = DiceLoss(to_onehot_y=True, softmax=True)
    post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=2)
    post_label = AsDiscrete(to_onehot=True, n_classes=2)
    optimizer = torch.optim.Adam(model.parameters(), 1e-4, weight_decay=1e-5)

    best_metric = -1
    best_metric_epoch = -1
    best_metrics_epochs_and_time = [[], [], []]
    epoch_loss_values = list()
    metric_values = list()
    # Validation rounds since the last improvement. Initialised here because
    # the `else` branch below increments it, which can happen on the very
    # first round (e.g. a NaN metric never compares greater than best_metric).
    epochs_no_improve = 0

    for epoch in range(epoch_num):
        print(f"epoch {epoch + 1}/{epoch_num}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data['image'].to(
                device), batch_data['label'].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.
                metric_count = 0
                for val_data in val_loader:
                    val_inputs, val_labels = val_data['image'].to(
                        device), val_data['label'].to(device)
                    val_outputs = model(val_inputs)
                    val_outputs = post_pred(val_outputs)
                    val_labels = post_label(val_labels)
                    value = compute_meandice(y_pred=val_outputs,
                                             y=val_labels,
                                             include_background=False)
                    # NOTE(review): NaN entries (e.g. empty ground truth)
                    # propagate into metric_sum -- confirm this is acceptable.
                    metric_count += len(value)
                    metric_sum += value.sum().item()
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    epochs_no_improve = 0
                    best_metric_epoch = epoch + 1
                    best_metrics_epochs_and_time[0].append(best_metric)
                    best_metrics_epochs_and_time[1].append(best_metric_epoch)
                    torch.save(model.state_dict(), 'sLUMRTL644.pth')
                else:
                    epochs_no_improve += 1

            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f" best mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
            )

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )
    return epoch_num, epoch_loss_values, metric_values, best_metrics_epochs_and_time
Exemple #14
0
# Plain PyTorch setup: a UNet trained with DiceLoss and the Adam optimizer.
device = torch.device("cuda:0")
max_epochs, learning_rate, val_interval = 6, 1e-4, 2

# 3D residual U-Net with batch norm: 1 input channel, 2 output classes.
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
)
model = model.to(device)

loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = Adam(model.parameters(), learning_rate)
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

# Post-processing: predictions become one-hot via argmax, labels are one-hot encoded.
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

# Bookkeeping for the training loop and TensorBoard logging.
best_metric = best_metric_epoch = -1
best_metrics_epochs_and_time = [[] for _ in range(3)]
epoch_loss_values, metric_values, epoch_times = [], [], []
total_start = time.time()
writer = SummaryWriter(log_dir=out_dir)
Exemple #15
0
class UNet2DRestorer(AbstractBaseLearner):
    """Restorer based on the U-Net architecture."""
    def __init__(self,
                 in_channels: int = 1,
                 out_channels: int = 1,
                 roi_size: Tuple[int, int] = (384, 384),
                 num_epochs: int = 400,
                 batch_sizes: Tuple[int, int, int, int] = (8, 1, 1, 1),
                 num_workers: Tuple[int, int, int, int] = (4, 4, 1, 1),
                 validation_step: int = 2,
                 sliding_window_batch_size: int = 4,
                 experiment_name: str = "",
                 model_name: str = "best_model",
                 seed: int = 4294967295,
                 working_dir: str = '.',
                 stdout: TextIOWrapper = sys.stdout,
                 stderr: TextIOWrapper = sys.stderr):
        """Constructor.

        @param in_channels: int, optional: default = 1
            Number of channels in the input (e.g. 1 for gray-value images).

        @param out_channels: int, optional: default = 1
            Number of channels in the output.

        @param roi_size: Tuple[int, int], optional: default = (384, 384)
            Crop area (and input size of the U-Net network) used for training and validation/prediction.

        @param num_epochs: int, optional: default = 400
            Number of epochs for training.

        @param batch_sizes: Tuple[int, int, int, int], optional: default = (8, 1, 1, 1)
            Batch sizes for training, validation, testing, and prediction, respectively.

        @param num_workers: Tuple[int, int, int, int], optional: default = (4, 4, 1, 1)
            Number of workers for training, validation, testing, and prediction, respectively.

        @param validation_step: int, optional: default = 2
            Validation is run every `validation_step` epochs.

        @param sliding_window_batch_size: int, optional: default = 4
            Batch size passed to the sliding window inference used during validation and prediction.

        @param experiment_name: str, optional: default = ""
            Name of the experiment that maps to the folder that contains training information (to
            be used by tensorboard). The final name is derived at the start of training
            (presumably with the current datetime appended).

        @param model_name: str, optional: default = "best_model"
            Name of the file that stores the best model. The final file name is derived at the
            start of training (presumably with the current datetime and an extension added).

        @param seed: int, optional; default = 4294967295
            Random seed passed to MONAI's set_determinism() to control deterministic training.

        @param working_dir: str, optional, default = "."
            Working folder where to save the model weights and the logs for tensorboard.

        @param stdout: TextIOWrapper, optional, default = sys.stdout
            Stream that receives progress messages.

        @param stderr: TextIOWrapper, optional, default = sys.stderr
            Stream stored for error output.
        """

        # Call base constructor
        super().__init__()

        # Standard pipe wrappers
        self._stdout = stdout
        self._stderr = stderr

        # Device (initialize as "cpu")
        self._device = "cpu"

        # Input and output channels
        self._in_channels = in_channels
        self._out_channels = out_channels

        # Define hyper parameters
        self._roi_size = roi_size
        self._training_batch_size = batch_sizes[0]
        self._validation_batch_size = batch_sizes[1]
        self._test_batch_size = batch_sizes[2]
        self._prediction_batch_size = batch_sizes[3]
        self._training_num_workers = num_workers[0]
        self._validation_num_workers = num_workers[1]
        self._test_num_workers = num_workers[2]
        self._prediction_num_workers = num_workers[3]
        self._n_epochs = num_epochs
        self._validation_step = validation_step
        self._sliding_window_batch_size = sliding_window_batch_size

        # Set monai seed
        set_determinism(seed=seed)

        # All file names
        self._train_image_names: list = []
        self._train_target_names: list = []
        self._validation_image_names: list = []
        self._validation_target_names: list = []
        self._test_image_names: list = []
        self._test_target_names: list = []

        # Transforms
        self._train_image_transforms = None
        self._train_target_transforms = None
        self._validation_image_transforms = None
        self._validation_target_transforms = None
        self._test_image_transforms = None
        self._test_target_transforms = None
        self._prediction_image_transforms = None
        self._validation_post_transforms = None
        self._test_post_transforms = None
        self._prediction_post_transforms = None

        # Datasets and data loaders
        self._train_dataset = None
        self._train_dataloader = None
        self._validation_dataset = None
        self._validation_dataloader = None
        self._test_dataset = None
        self._test_dataloader = None
        self._prediction_dataset = None
        self._prediction_dataloader = None

        # Set model architecture, loss function, metric and optimizer
        self._model = None
        self._training_loss_function = None
        self._optimizer = None
        self._validation_metric = None

        # Working directory, model file name and experiment name for Tensorboard logs.
        # The file names will be redefined at the beginning of the training.
        self._working_dir = Path(working_dir).resolve()
        self._raw_experiment_name = experiment_name
        self._raw_model_file_name = model_name

        # Keep track of the full path of the best model
        self._best_model = ''

        # Keep track of last error message
        self._message = ""

    def train(self) -> bool:
        """Run the training loop with periodic validation.

        Training runs synchronously within this call for the configured number
        of epochs; nothing here dispatches it to another thread (TODO confirm
        whether callers do). Every `validation_step` epochs the model is
        evaluated on the validation set with sliding-window inference, and
        whenever the validation loss reaches a new minimum the weights are
        saved to the model file. Average training loss and validation loss are
        logged to TensorBoard.

        @return True if training completed, False if the training/validation
            file lists are empty or have mismatched lengths (the reason is
            stored in self._message).
        """
        import math

        # Reset state from any previous session (presumably frees GPU memory)
        self._clear_session()

        # Check that the data is set properly
        if len(self._train_image_names) == 0 or \
                len(self._train_target_names) == 0 or \
                len(self._validation_image_names) == 0 or \
                len(self._validation_target_names) == 0:
            self._message = "No training/validation data found."
            return False

        # Plain `!=` on purpose: a chained `a != b == 0` would additionally
        # require `b == 0`, which the emptiness check above already rules out,
        # so a mismatch would never be reported.
        if len(self._train_image_names) != len(self._train_target_names):
            self._message = "The number of training images does not match the number of training targets."
            return False

        if len(self._validation_image_names) != len(
                self._validation_target_names):
            self._message = "The number of validation images does not match the number of validation targets."
            return False

        # Define the transforms
        self._define_transforms()

        # Define the datasets and data loaders
        self._define_training_data_loaders()

        # Instantiate the model
        self._define_model()

        # Define the loss function
        self._define_training_loss()

        # Define the optimizer (with default parameters)
        self._define_optimizer()

        # Define the validation metric
        self._define_validation_metric()

        # Define experiment name and model name
        experiment_name, model_file_name = self._prepare_experiment_and_model_names(
        )

        # Keep track of the best model file name
        self._best_model = model_file_name

        # Lowest validation loss so far and the (1-based) epoch it occurred in.
        # `np.inf` (not `np.Inf`, which NumPy 2.0 removed).
        lowest_validation_loss = np.inf
        lowest_validation_epoch = -1

        epoch_loss_values = list()
        validation_loss_values = list()

        # Batches per epoch (the last may be partial); only used for progress
        # output and constant across the loop, so computed once.
        epoch_len = math.ceil(
            len(self._train_dataset) / self._train_dataloader.batch_size)

        # Initialize TensorBoard's SummaryWriter
        writer = SummaryWriter(experiment_name)
        try:
            for epoch in range(self._n_epochs):

                # Inform
                self._print_header(f"Epoch {epoch + 1}/{self._n_epochs}")

                # Switch to training mode
                self._model.train()

                epoch_loss = 0
                step = 0
                for batch_data in self._train_dataloader:

                    # Update step
                    step += 1

                    # Get the next batch and move it to device
                    inputs, labels = batch_data[0].to(
                        self._device), batch_data[1].to(self._device)

                    # Zero the gradient buffers
                    self._optimizer.zero_grad()

                    # Forward pass
                    outputs = self._model(inputs)

                    # Calculate the loss
                    loss = self._training_loss_function(outputs, labels)

                    # Back-propagate
                    loss.backward()

                    # Update weights (optimize)
                    self._optimizer.step()

                    # Update and store metrics
                    epoch_loss += loss.item()

                    print(
                        f"Batch {step}/{epoch_len}: train_loss = {loss.item():.4f}",
                        file=self._stdout)

                epoch_loss /= step
                epoch_loss_values.append(epoch_loss)
                print(f"Average loss = {epoch_loss:.4f}", file=self._stdout)
                writer.add_scalar("average_train_loss", epoch_loss, epoch + 1)

                # Validation
                if (epoch + 1) % self._validation_step == 0:

                    self._print_header("Validation")

                    # Switch to evaluation mode
                    self._model.eval()

                    # Make sure not to update the gradients
                    with torch.no_grad():

                        # Global validation loss
                        validation_loss_sum = 0.0
                        validation_loss_count = 0

                        for val_data in self._validation_dataloader:

                            # Get the next batch and move it to device
                            val_images, val_labels = val_data[0].to(
                                self._device), val_data[1].to(self._device)

                            # Apply sliding inference over ROI size
                            val_outputs = sliding_window_inference(
                                val_images, self._roi_size,
                                self._sliding_window_batch_size, self._model)
                            val_outputs = self._validation_post_transforms(
                                val_outputs)

                            # Calculate the validation loss
                            val_loss = self._training_loss_function(
                                val_outputs, val_labels)

                            # Add to the current loss
                            validation_loss_count += 1
                            validation_loss_sum += val_loss.item()

                        # Global validation loss
                        validation_loss = validation_loss_sum / validation_loss_count
                        validation_loss_values.append(validation_loss)

                        # Print summary
                        print(f"Validation loss = {validation_loss:.4f} ",
                              file=self._stdout)

                        # Do we have the best metric so far?
                        if validation_loss < lowest_validation_loss:
                            lowest_validation_loss = validation_loss
                            lowest_validation_epoch = epoch + 1
                            torch.save(self._model.state_dict(),
                                       model_file_name)
                            print(
                                f"New lowest validation loss = {lowest_validation_loss:.4f} at epoch: {lowest_validation_epoch}",
                                file=self._stdout)
                            print(
                                f"Saved best model '{Path(model_file_name).name}'",
                                file=self._stdout)

                        # Add validation loss and metrics to log
                        writer.add_scalar("val_mean_loss", validation_loss,
                                          epoch + 1)

            print(
                f"Training completed. Lowest validation loss = {lowest_validation_loss:.4f} at epoch: {lowest_validation_epoch}",
                file=self._stdout)
        finally:
            # Always release the TensorBoard event file, even if training raises.
            writer.close()

        # Return success
        return True

    def test_predict(self,
                     target_folder: Union[Path, str] = '',
                     model_path: Union[Path, str] = '') -> bool:
        """Run prediction on predefined test data.

        @param target_folder: Path|str, optional: default = ''
            Path to the folder where to store the predicted images. If not specified,
            if defaults to '{working_dir}/predictions'. See constructor.

        @param model_path: Path|str, optional: default = ''
            Full path to the model to use. If omitted and a training was
            just run, the path to the model with the best metric is
            already stored and will be used.

            @see get_best_model_path()

        @return True if the prediction was successful, False otherwise.
        """

        # Inform
        self._print_header("Test prediction")

        # Get the device
        self._device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        # If the model is not in memory, instantiate it first
        if self._model is None:
            self._define_model()

        # If the path to the best model was not set, use current one (if set)
        if model_path == '':
            model_path = self.get_best_model_path()

        # Try loading the model weights: they must be compatible
        # with the model in memory
        try:
            checkpoint = torch.load(model_path,
                                    map_location=torch.device('cpu'))
            self._model.load_state_dict(checkpoint)
            print(f"Loaded best metric model {model_path}.", file=self._stdout)
        except Exception as e:
            self._message = "Error: there was a problem loading the model! Aborting."
            return False

        # If the target folder is not specified, set it to the standard predictions out
        if target_folder == '':
            target_folder = Path(self._working_dir) / "tests"
        else:
            target_folder = Path(target_folder)
        target_folder.mkdir(parents=True, exist_ok=True)

        # Switch to evaluation mode
        self._model.eval()

        indx = 0

        # Make sure not to update the gradients
        with torch.no_grad():
            for test_data in self._test_dataloader:

                # Get the next batch and move it to device
                test_images, test_masks = test_data[0].to(
                    self._device), test_data[1].to(self._device)

                # Apply sliding inference over ROI size
                test_outputs = sliding_window_inference(
                    test_images, self._roi_size,
                    self._sliding_window_batch_size, self._model)
                test_outputs = self._test_post_transforms(test_outputs)

                # The ToNumpy() transform already causes the Tensor
                # to be gathered from the GPU to the CPU
                pred = test_outputs.squeeze()

                # Prepare the output file name
                basename = os.path.splitext(
                    os.path.basename(self._test_image_names[indx]))[0]
                basename = basename.replace('train_', 'pred_')

                # Save label image as tiff file
                pred_file_name = os.path.join(str(target_folder),
                                              basename + '.tif')
                with TiffWriter(pred_file_name) as tif:
                    tif.save(pred)

                # Inform
                print(f"Saved {str(target_folder)}/{basename}.tif",
                      file=self._stdout)

                # Update the index
                indx += 1

        # Inform
        print(f"Test prediction completed.", file=self._stdout)

        # Return success
        return True

    def predict(self, input_folder: Union[Path, str],
                target_folder: Union[Path, str],
                model_path: Union[Path, str]) -> bool:
        """Run prediction on all '*.tif' images found in a folder.

        @param input_folder: Path|str
            Path to the folder containing the images to predict.

        @param target_folder: Path|str
            Path to the folder where to store the predicted images. It is
            created if it does not exist; an empty string is rejected.

        @param model_path: Path|str
            Full path to the model to use.

        @return True if the prediction was successful, False otherwise
            (the reason is then available through get_message()).
        """
        # Inform
        self._print_header("Prediction")

        # Get the device
        self._device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        # If the model is not in memory, instantiate it first
        if self._model is None:
            self._define_model()

        # Try loading the model weights: they must be compatible
        # with the model in memory
        try:
            checkpoint = torch.load(model_path,
                                    map_location=torch.device('cpu'))
            self._model.load_state_dict(checkpoint)
        except Exception as e:
            # Keep the cause so that it can be diagnosed via get_message()
            self._message = f"Error: there was a problem loading the model! Aborting. ({e})"
            return False
        print(f"Loaded best metric model {model_path}.", file=self._stdout)

        # Make sure the target folder is specified
        if isinstance(target_folder, str) and target_folder == '':
            self._message = "Error: please specify a valid target folder! Aborting."
            return False

        target_folder = Path(target_folder)
        target_folder.mkdir(parents=True, exist_ok=True)

        # Get prediction dataloader
        if not self._define_prediction_data_loaders(input_folder):
            self._message = "Error: could not instantiate prediction dataloader! Aborting."
            return False

        # Switch to evaluation mode
        self._model.eval()

        image_index = 0

        # Make sure not to update the gradients
        with torch.no_grad():
            for prediction_data in self._prediction_dataloader:

                # Get the next batch and move it to device
                prediction_images = prediction_data.to(self._device)

                # Apply sliding inference over ROI size
                prediction_outputs = sliding_window_inference(
                    prediction_images, self._roi_size,
                    self._sliding_window_batch_size, self._model)

                # A batch may contain more than one image: post-process and
                # save each one separately so it is matched with its own name
                for prediction_output in prediction_outputs:

                    # The ToNumpy() transform already causes the Tensor
                    # to be gathered from the GPU to the CPU
                    pred = self._prediction_post_transforms(
                        prediction_output).squeeze()

                    # Prepare the output file name
                    basename = os.path.splitext(
                        os.path.basename(
                            self._prediction_image_names[image_index]))[0]
                    basename = "pred_" + basename

                    # Save label image as tiff file
                    pred_file_name = os.path.join(str(target_folder),
                                                  basename + '.tif')
                    with TiffWriter(pred_file_name) as tif:
                        tif.save(pred)

                    # Inform
                    print(f"Saved {str(target_folder)}/{basename}.tif",
                          file=self._stdout)

                    # Update the index
                    image_index += 1

        # Inform
        print("Prediction completed.", file=self._stdout)

        # Return success
        return True

    def set_training_data(self, train_image_names, train_mask_names,
                          val_image_names, val_mask_names, test_image_names,
                          test_mask_names) -> None:
        """Set all training files names.

        All pairs are validated first; nothing is stored unless every image
        list has the same length as its mask list.

        @param train_image_names: list
            List of training image names.

        @param train_mask_names: list
            List of training mask names.

        @param val_image_names: list
            List of validation image names.

        @param val_mask_names: list
            List of validation mask names.

        @param test_image_names: list
            List of test image names.

        @param test_mask_names: list
            List of test mask names.

        @raise ValueError if an image list and its mask list differ in length.
        """

        # First validate all data (in the order training, validation, test)
        pairs = (
            ("training", train_image_names, train_mask_names),
            ("validation", val_image_names, val_mask_names),
            ("test", test_image_names, test_mask_names),
        )
        for kind, image_names, mask_names in pairs:
            if len(image_names) != len(mask_names):
                raise ValueError(
                    f"The number of {kind} images does not match the number of {kind} masks."
                )

        # Training data
        self._train_image_names = train_image_names
        self._train_target_names = train_mask_names

        # Validation data
        self._validation_image_names = val_image_names
        self._validation_target_names = val_mask_names

        # Test data
        self._test_image_names = test_image_names
        self._test_target_names = test_mask_names

    @staticmethod
    def _prediction_to_label_tiff_image(prediction):
        """Convert a one-hot prediction stack into a uint16 label image.

        Nothing is written to disk; the caller is responsible for saving the
        returned array (e.g. as a TIFF file).

        @param prediction: one-hot stack with the channel axis first and the
            first channel treated as background (as passed to
            one_hot_stack_to_label_image below).

        @return the label image with dtype numpy.uint16.
        """

        # Convert to label image
        label_img = one_hot_stack_to_label_image(
            prediction,
            first_index_is_background=True,
            channels_first=True,
            dtype=np.uint16)

        return label_img

    def _define_transforms(self):
        """Define and initialize all data transforms.

          * training set images transform
          * training set targets transform
          * validation set images transform
          * validation set targets transform
          * validation set images post-transform
          * test set images transform
          * test set targets transform
          * test set images post-transform
          * prediction set images transform
          * prediction set images post-transform

        All inputs (including prediction images) are loaded, rescaled from
        [0, 65535] to [0, 1] and given a leading channel axis, so that the
        network sees the same intensity range at training and inference time.

        The results are stored on the instance; nothing is returned.
        """

        def _load_scale_add_channel():
            """Fresh instances of the preprocessing shared by every input."""
            return [
                LoadImage(image_only=True),
                ScaleIntensityRange(0, 65535, 0.0, 1.0, clip=False),
                AddChannel(),
            ]

        # Define transforms for training (random crop/rotation for augmentation)
        self._train_image_transforms = Compose(_load_scale_add_channel() + [
            RandSpatialCrop(self._roi_size, random_size=False),
            RandRotate90(prob=0.5, spatial_axes=(0, 1)),
            ToTensor()
        ])
        self._train_target_transforms = Compose(_load_scale_add_channel() + [
            RandSpatialCrop(self._roi_size, random_size=False),
            RandRotate90(prob=0.5, spatial_axes=(0, 1)),
            ToTensor()
        ])

        # Define transforms for validation
        self._validation_image_transforms = Compose(
            _load_scale_add_channel() + [ToTensor()])
        self._validation_target_transforms = Compose(
            _load_scale_add_channel() + [ToTensor()])

        # Define transforms for testing
        self._test_image_transforms = Compose(
            _load_scale_add_channel() + [ToTensor()])
        self._test_target_transforms = Compose(
            _load_scale_add_channel() + [ToTensor()])

        # Define transforms for prediction: the same intensity rescaling as
        # the images the model was trained on is required here as well
        self._prediction_image_transforms = Compose(
            _load_scale_add_channel() + [ToTensor()])

        # Post transforms
        self._validation_post_transforms = Compose([Identity()])

        self._test_post_transforms = Compose(
            [ToNumpy(), ScaleIntensity(0, 65535)])

        self._prediction_post_transforms = Compose(
            [ToNumpy(), ScaleIntensity(0, 65535)])

    def _define_training_data_loaders(self) -> bool:
        """Initialize training, validation and test datasets and data loaders.

        @Note: in Windows, it is essential to set `persistent_workers=True` in the data loaders!
            It is only enabled when the corresponding number of workers is
            positive, since PyTorch rejects persistent workers otherwise.

        @return True if datasets and data loaders could be instantiated, False
            otherwise (any of the file name lists is empty; all dataset and
            dataloader attributes are then reset to None).
        """

        # Optimize arguments
        if sys.platform == 'win32':
            persistent_workers = True
            pin_memory = False
        else:
            persistent_workers = False
            pin_memory = torch.cuda.is_available()

        def _make_loader(dataset, batch_size, num_workers):
            """Build a non-shuffling DataLoader with the platform options."""
            return DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=False,
                num_workers=num_workers,
                # DataLoader raises ValueError for persistent workers
                # when num_workers == 0
                persistent_workers=persistent_workers and num_workers > 0,
                pin_memory=pin_memory)

        if len(self._train_image_names) == 0 or \
                len(self._train_target_names) == 0 or \
                len(self._validation_image_names) == 0 or \
                len(self._validation_target_names) == 0 or \
                len(self._test_image_names) == 0 or \
                len(self._test_target_names) == 0:

            self._train_dataset = None
            self._train_dataloader = None
            self._validation_dataset = None
            self._validation_dataloader = None
            self._test_dataset = None
            self._test_dataloader = None

            return False

        # Training
        # NOTE(review): the training loader does not shuffle; confirm this is
        # intended (random crops/rotations are the only per-epoch variation).
        self._train_dataset = ArrayDataset(self._train_image_names,
                                           self._train_image_transforms,
                                           self._train_target_names,
                                           self._train_target_transforms)
        self._train_dataloader = _make_loader(self._train_dataset,
                                              self._training_batch_size,
                                              self._training_num_workers)

        # Validation
        self._validation_dataset = ArrayDataset(
            self._validation_image_names, self._validation_image_transforms,
            self._validation_target_names, self._validation_target_transforms)
        self._validation_dataloader = _make_loader(
            self._validation_dataset, self._validation_batch_size,
            self._validation_num_workers)

        # Test
        self._test_dataset = ArrayDataset(self._test_image_names,
                                          self._test_image_transforms,
                                          self._test_target_names,
                                          self._test_target_transforms)
        self._test_dataloader = _make_loader(self._test_dataset,
                                             self._test_batch_size,
                                             self._test_num_workers)

        return True

    def _define_prediction_data_loaders(
            self, prediction_folder_path: Union[Path, str]) -> bool:
        """Initialize prediction datasets and data loaders.

        The folder is scanned (non-recursively) for '*.tif' files, which are
        sorted in natural order and stored in `self._prediction_image_names`.

        @Note: in Windows, it is essential to set `persistent_workers=True` in the data loaders!
            It is only enabled when the number of workers is positive, since
            PyTorch rejects persistent workers otherwise.

        @param prediction_folder_path: Path|str
            Folder containing the images to predict.

        @return True if datasets and data loaders could be instantiated, False
            otherwise (the dataset and dataloader are then set to None).
        """

        # Drop any state left by a previous call, so that a failure never
        # leaves a stale dataloader around
        self._prediction_dataset = None
        self._prediction_dataloader = None

        # Check that the path exists
        prediction_folder_path = Path(prediction_folder_path)
        if not prediction_folder_path.is_dir():
            self._prediction_image_names = []
            return False

        # Scan for images
        self._prediction_image_names = natsorted(
            glob(str(prediction_folder_path / "*.tif")))

        if len(self._prediction_image_names) == 0:
            return False

        # Optimize arguments
        if sys.platform == 'win32':
            persistent_workers = True
            pin_memory = False
        else:
            persistent_workers = False
            pin_memory = torch.cuda.is_available()

        num_workers = self._test_num_workers

        # Prediction
        self._prediction_dataset = Dataset(self._prediction_image_names,
                                           self._prediction_image_transforms)
        self._prediction_dataloader = DataLoader(
            self._prediction_dataset,
            batch_size=self._test_batch_size,
            shuffle=False,
            num_workers=num_workers,
            # DataLoader raises ValueError for persistent workers when
            # num_workers == 0
            persistent_workers=persistent_workers and num_workers > 0,
            pin_memory=pin_memory)

        return True

    def get_message(self):
        """Give back the most recently recorded error message."""
        last_message = self._message
        return last_message

    def get_best_model_path(self):
        """Give back the stored full path of the best model."""
        best_model_path = self._best_model
        return best_model_path

    def _clear_session(self) -> None:
        """Try clearing the CUDA memory cache; do nothing on CPU or unset device."""
        # self._device is normally a torch.device, which never compares equal
        # to the plain string "cpu"; inspect its textual form instead
        # (e.g. "cpu", "cuda", "cuda:0").
        device = getattr(self, "_device", None)
        if device is not None and str(device).startswith("cuda"):
            torch.cuda.empty_cache()

    def _define_model(self) -> None:
        """Instantiate the 2D U-Net architecture on the selected device.

        Also (re)selects `self._device`: CUDA if available, CPU otherwise.
        """

        # Select the device
        self._device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device '{self._device}'.", file=self._stdout)

        # Try to free memory on the GPU. Compare the device *type*: a
        # torch.device never compares equal to the plain string "cpu".
        if self._device.type != "cpu":
            torch.cuda.empty_cache()

        # Monai's UNet
        self._model = UNet(dimensions=2,
                           in_channels=self._in_channels,
                           out_channels=self._out_channels,
                           channels=(16, 32, 64, 128, 256),
                           strides=(2, 2, 2, 2),
                           num_res_units=2).to(self._device)

    def _define_training_loss(self) -> None:
        """Set up the training criterion: mean absolute error (L1 loss)."""
        criterion = L1Loss()
        self._training_loss_function = criterion

    def _define_optimizer(self,
                          learning_rate: float = 1e-3,
                          weight_decay: float = 1e-4) -> None:
        """Define the Adam optimizer (AMSGrad variant) for the current model.

        Does nothing if no model has been instantiated yet.

        @param learning_rate: float, optional, default = 1e-3
            Initial learning rate for the optimizer.

        @param weight_decay: float, optional, default = 1e-4
            Weight decay (L2 penalty) applied to the model parameters.
        """

        if self._model is None:
            return

        self._optimizer = Adam(self._model.parameters(),
                               learning_rate,
                               weight_decay=weight_decay,
                               amsgrad=True)

    def _define_validation_metric(self):
        """Create the validation Dice metric (background included, mean-reduced)."""
        dice_options = dict(include_background=True, reduction="mean")
        self._validation_metric = DiceMetric(**dice_options)

    def _prepare_experiment_and_model_names(self) -> Tuple[str, str]:
        """Build time-stamped experiment and model paths under '<working_dir>/runs'.

        The 'runs' folder is created if needed. The experiment name is
        '<raw_experiment_name>_<YYYYmmdd_HHMMSS>' (or just the time stamp if
        the raw name is empty); the model file is '<raw_model_stem>_<stamp>.pth'.

        @return experiment_file_name, model_file_name (both full paths as str)
        """

        # All experiment artifacts live in the "runs" subfolder
        runs_folder = Path(self._working_dir) / "runs"
        runs_folder.mkdir(parents=True, exist_ok=True)

        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        if self._raw_experiment_name == "":
            experiment_stem = stamp
        else:
            experiment_stem = f"{self._raw_experiment_name}_{stamp}"

        model_stem = Path(self._raw_model_file_name).stem

        return (str(runs_folder / experiment_stem),
                str(runs_folder / f"{model_stem}_{stamp}.pth"))

    def _print_header(self, header_text, line_length=80, file=None):
        """Print a section header framed by two lines of dashes.

        @param header_text: str
            Title to print.

        @param line_length: int, optional, default = 80
            Number of dashes in each framing line.

        @param file: file-like, optional, default = None
            Stream to print to; defaults to the object's standard output.
        """
        if file is None:
            file = self._stdout
        separator = line_length * '-'
        # All three lines go to the same stream (the title previously always
        # went to self._stdout, ignoring `file`)
        print(separator, file=file)
        print(f"{header_text}", file=file)
        print(separator, file=file)
def main(config):
    """Train a 3D UNet segmentation network on PET/CT volumes.

    Settings are read from the nested `config` mapping (sections 'path',
    'preprocessing', 'model' and 'training'). Checkpoints are written to the
    time-stamped folder '<training_model_folder>/<now>' and TensorBoard logs
    to its 'logs' sub-folder.

    @param config: dict
        Nested configuration mapping.
    """
    now = datetime.now().strftime("%Y%m%d-%H:%M:%S")

    # path
    csv_path = config['path']['csv_path']

    # NOTE(review): read but not used below; resuming from pre-trained
    # weights is not implemented yet.
    trained_model_path = config['path'][
        'trained_model_path']  # if None, trained from scratch
    training_model_folder = os.path.join(
        config['path']['training_model_folder'], now)  # '/path/to/folder'
    os.makedirs(training_model_folder, exist_ok=True)
    logdir = os.path.join(training_model_folder, 'logs')
    os.makedirs(logdir, exist_ok=True)

    # PET CT scan params
    image_shape = tuple(config['preprocessing']['image_shape'])  # (x, y, z)
    in_channels = config['preprocessing']['in_channels']
    voxel_spacing = tuple(
        config['preprocessing']
        ['voxel_spacing'])  # (4.8, 4.8, 4.8)  # in millimeter, (x, y, z)
    # NOTE(review): data_augment, resize, normalize and number_class are read
    # but not used; augmentation and normalization are always applied below.
    data_augment = config['preprocessing'][
        'data_augment']  # True  # for training dataset only
    resize = config['preprocessing']['resize']  # True  # not use yet
    origin = config['preprocessing']['origin']  # how to set the new origin
    normalize = config['preprocessing'][
        'normalize']  # True  # whether or not to normalize the inputs
    number_class = config['preprocessing']['number_class']  # 2

    # CNN params
    architecture = config['model']['architecture']  # 'unet' or 'vnet'

    # NOTE(review): cnn_params is prepared but the UNet below is still
    # configured with hard-coded values.
    cnn_params = config['model'][architecture]['cnn_params']
    # transform list to tuple
    for key, value in cnn_params.items():
        if isinstance(value, list):
            cnn_params[key] = tuple(value)

    # Training params
    epochs = config['training']['epochs']
    batch_size = config['training']['batch_size']
    shuffle = config['training']['shuffle']
    # NOTE(review): opt_params is not used; the learning rate is hard-coded.
    opt_params = config['training']["optimizer"]["opt_params"]

    # Get Data
    DM = DataManager(csv_path=csv_path)
    train_images_paths, val_images_paths, test_images_paths = DM.get_train_val_test(
        wrap_with_dict=True)

    # Input preprocessing
    # use data augmentation for training
    train_transforms = Compose([  # read img + meta info
        LoadNifti(keys=["pet_img", "ct_img", "mask_img"]),
        Roi2Mask(keys=['pet_img', 'mask_img'],
                 method='otsu',
                 tval=0.0,
                 idx_channel=0),
        ResampleReshapeAlign(target_shape=image_shape,
                             target_voxel_spacing=voxel_spacing,
                             keys=['pet_img', "ct_img", 'mask_img'],
                             origin=origin,
                             origin_key='pet_img'),
        Sitk2Numpy(keys=['pet_img', 'ct_img', 'mask_img']),
        # user can also add other random transforms
        RandAffined(keys=("pet_img", "ct_img", "mask_img"),
                    spatial_size=None,
                    prob=0.4,
                    rotate_range=(0, np.pi / 30, np.pi / 15),
                    shear_range=None,
                    translate_range=(10, 10, 10),
                    scale_range=(0.1, 0.1, 0.1),
                    mode=("bilinear", "bilinear", "nearest"),
                    padding_mode="border"),
        # normalize input
        ScaleIntensityRanged(
            keys=["pet_img"],
            a_min=0.0,
            a_max=25.0,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        ScaleIntensityRanged(
            keys=["ct_img"],
            a_min=-1000.0,
            a_max=1000.0,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        # Prepare for neural network
        ConcatModality(keys=['pet_img', 'ct_img']),
        AddChanneld(keys=["mask_img"]),  # Add channel to the first axis
        ToTensord(keys=["image", "mask_img"]),
    ])
    # without data augmentation for validation
    val_transforms = Compose([  # read img + meta info
        LoadNifti(keys=["pet_img", "ct_img", "mask_img"]),
        Roi2Mask(keys=['pet_img', 'mask_img'],
                 method='otsu',
                 tval=0.0,
                 idx_channel=0),
        ResampleReshapeAlign(target_shape=image_shape,
                             target_voxel_spacing=voxel_spacing,
                             keys=['pet_img', "ct_img", 'mask_img'],
                             origin=origin,
                             origin_key='pet_img'),
        Sitk2Numpy(keys=['pet_img', 'ct_img', 'mask_img']),
        # normalize input
        ScaleIntensityRanged(
            keys=["pet_img"],
            a_min=0.0,
            a_max=25.0,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        ScaleIntensityRanged(
            keys=["ct_img"],
            a_min=-1000.0,
            a_max=1000.0,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        # Prepare for neural network
        ConcatModality(keys=['pet_img', 'ct_img']),
        AddChanneld(keys=["mask_img"]),  # Add channel to the first axis
        ToTensord(keys=["image", "mask_img"]),
    ])

    # create a training data loader
    train_ds = monai.data.CacheDataset(data=train_images_paths,
                                       transform=train_transforms,
                                       cache_rate=0.5)
    train_loader = monai.data.DataLoader(train_ds,
                                         batch_size=batch_size,
                                         shuffle=shuffle,
                                         num_workers=2)
    # create a validation data loader
    val_ds = monai.data.CacheDataset(data=val_images_paths,
                                     transform=val_transforms,
                                     cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds,
                                       batch_size=batch_size,
                                       num_workers=2)

    # Model
    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        dimensions=3,  # 3D
        in_channels=in_channels,
        out_channels=1,
        kernel_size=5,
        channels=(8, 16, 32, 64, 128),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True, squared_pred=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)

    def prepare_batch(batch):
        """Extract the (input, target) pair from a collated batch dict.

        The transforms above store the network input under 'image' and the
        target under 'mask_img' (not the engines' default 'label' key), so
        both trainer and evaluator must use this function.
        """
        return batch['image'], batch['mask_img']

    # training
    val_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
    ])
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=logdir,
                                output_transform=lambda x: None),
        CheckpointSaver(save_dir=training_model_folder,
                        save_dict={
                            "net": net,
                            "opt": opt
                        },
                        save_key_metric=True),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        prepare_batch=prepare_batch,
        inferer=SimpleInferer(),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice":
            MeanDice(include_background=True,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={
            "val_acc":
            Accuracy(output_transform=lambda x: (x["pred"], x["label"])),
            "val_precision":
            Precision(output_transform=lambda x: (x["pred"], x["label"])),
            "val_recall":
            Recall(output_transform=lambda x: (x["pred"], x["label"]))
        },
        val_handlers=val_handlers,
    )

    train_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
    ])
    train_handlers = [
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(log_dir=logdir,
                                tag_name="train_loss",
                                output_transform=lambda x: x["loss"]),
        CheckpointSaver(save_dir=training_model_folder,
                        save_dict={
                            "net": net,
                            "opt": opt
                        },
                        save_interval=2,
                        epoch_level=True),
    ]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=epochs,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        prepare_batch=prepare_batch,
        inferer=SimpleInferer(),
        post_transform=train_post_transforms,
        key_train_metric={
            "train_mean_dice":
            MeanDice(include_background=True,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={
            "train_acc":
            Accuracy(output_transform=lambda x: (x["pred"], x["label"])),
            "train_precision":
            Precision(output_transform=lambda x: (x["pred"], x["label"])),
            "train_recall":
            Recall(output_transform=lambda x: (x["pred"], x["label"]))
        },
        train_handlers=train_handlers,
        # AMP needs PyTorch >= 1.6 (and FP16 support on the GPU)
        amp=monai.config.get_torch_version_tuple() >= (1, 6),
    )
    trainer.run()