def test_compute(self, input_params, expected):
        data = [
            {
                "image": torch.tensor([[[[2.0], [3.0]]]]),
                "filename": "test1"
            },
            {
                "image": torch.tensor([[[[6.0], [8.0]]]]),
                "filename": "test2"
            },
        ]
        # set up engine, PostProcessing handler works together with post_transform of engine
        engine = SupervisedEvaluator(
            device=torch.device("cpu:0"),
            val_data_loader=data,
            epoch_length=2,
            network=torch.nn.PReLU(),
            post_transform=Compose([Activationsd(keys="pred", sigmoid=True)]),
            val_handlers=[PostProcessing(**input_params)],
        )
        engine.run()

        torch.testing.assert_allclose(engine.state.output["pred"], expected)
        filename = engine.state.output.get("filename_bak")
        if filename is not None:
            self.assertEqual(filename, "test2")
Exemplo n.º 2
0
 def test_content(self, input_args, expected_value):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     dataloader = [{
         "image": torch.tensor([1, 2]),
         "label": torch.tensor([3, 4]),
         "extra1": torch.tensor([5, 6]),
         "extra2": 16,
         "extra3": "test",
     }]
     # set up engine
     evaluator = SupervisedEvaluator(
         device=device,
         val_data_loader=dataloader,
         epoch_length=1,
         network=TestNet(),
         non_blocking=True,
         prepare_batch=PrepareBatchExtraInput(**input_args),
         decollate=False,
     )
     evaluator.run()
     output = evaluator.state.output
     assert_allclose(output["image"], torch.tensor([1, 2], device=device))
     assert_allclose(output["label"], torch.tensor([3, 4], device=device))
     for k, v in output["pred"].items():
         if isinstance(v, torch.Tensor):
             assert_allclose(v, expected_value[k].to(device))
         else:
             self.assertEqual(v, expected_value[k])
Exemplo n.º 3
0
    def test_compute(self, input_params, decollate, expected):
        data = [
            {"image": torch.tensor([[[[2.0], [3.0]]]]), "filename": ["test1"]},
            {"image": torch.tensor([[[[6.0], [8.0]]]]), "filename": ["test2"]},
        ]
        # set up engine, PostProcessing handler works together with postprocessing transforms of engine
        engine = SupervisedEvaluator(
            device=torch.device("cpu:0"),
            val_data_loader=data,
            epoch_length=2,
            network=torch.nn.PReLU(),
            postprocessing=Compose([Activationsd(keys="pred", sigmoid=True)]),
            val_handlers=[PostProcessing(**input_params)],
            decollate=decollate,
        )
        engine.run()

        if isinstance(engine.state.output, list):
            # test decollated list items
            for o, e in zip(engine.state.output, expected):
                torch.testing.assert_allclose(o["pred"], e)
                filename = o.get("filename_bak")
                if filename is not None:
                    self.assertEqual(filename, "test2")
        else:
            # test batch data
            torch.testing.assert_allclose(engine.state.output["pred"], expected)
Exemplo n.º 4
0
 def test_empty_data(self):
     dataloader = []
     evaluator = SupervisedEvaluator(
         val_data_loader=dataloader,
         device=torch.device("cpu"),
         epoch_length=0,
         network=TestNet(),
         non_blocking=False,
         prepare_batch=PrepareBatchDefault(),
         decollate=False,
     )
     evaluator.run()
Exemplo n.º 5
0
    def test_compute(self, data, expected):
        # Set up handlers
        handlers = [
            # Mark with Ignite Event
            MarkHandler(Events.STARTED),
            # Mark with literal
            MarkHandler("EPOCH_STARTED"),
            # Mark with literal and providing the message
            MarkHandler("EPOCH_STARTED", "Start of the epoch"),
            # Define a range using one prefix (between BATCH_STARTED and BATCH_COMPLETED)
            RangeHandler("Batch"),
            # Define a range using a pair of events
            RangeHandler((Events.STARTED, Events.COMPLETED)),
            # Define a range using a pair of literals
            RangeHandler(("GET_BATCH_STARTED", "GET_BATCH_COMPLETED"),
                         msg="Batching!"),
            # Define a range using a pair of literal and events
            RangeHandler(("GET_BATCH_STARTED", Events.COMPLETED)),
            # Define the start of range using literal
            RangePushHandler("ITERATION_STARTED"),
            # Define the start of range using event
            RangePushHandler(Events.ITERATION_STARTED, "Iteration 2"),
            # Define the start of range using literals and providing message
            RangePushHandler("EPOCH_STARTED", "Epoch 2"),
            # Define the end of range using Ignite Event
            RangePopHandler(Events.ITERATION_COMPLETED),
            RangePopHandler(Events.EPOCH_COMPLETED),
            # Define the end of range using literal
            RangePopHandler("ITERATION_COMPLETED"),
            # Other handlers
            StatsHandler(tag_name="train",
                         output_transform=from_engine(["label"], first=True)),
        ]

        # Set up an engine
        engine = SupervisedEvaluator(
            device=torch.device("cpu:0"),
            val_data_loader=data,
            epoch_length=1,
            network=torch.nn.PReLU(),
            postprocessing=lambda x: dict(pred=x["pred"] + 1.0),
            decollate=True,
            val_handlers=handlers,
        )
        # Run the engine
        engine.run()

        # Get the output from the engine
        output = engine.state.output[0]

        torch.testing.assert_allclose(output["pred"], expected)
Exemplo n.º 6
0
    def _create_evaluator(self, context: Context):
        evaluator = None
        if context.val_datalist and len(context.val_datalist) > 0:
            val_hanlders: List = self.val_handlers(context)
            if context.local_rank == 0:
                val_hanlders.append(
                    CheckpointSaver(
                        save_dir=context.output_dir,
                        save_dict={self._model_dict_key: context.network},
                        save_key_metric=True,
                        key_metric_filename=self._key_metric_filename,
                        n_saved=5,
                    ))

            evaluator = SupervisedEvaluator(
                device=context.device,
                val_data_loader=self.val_data_loader(context),
                network=context.network,
                inferer=self.val_inferer(context),
                postprocessing=self._validate_transforms(
                    self.val_post_transforms(context), "Validation", "post"),
                key_val_metric=self.val_key_metric(context),
                additional_metrics=self.val_additional_metrics(context),
                val_handlers=val_hanlders,
                iteration_update=self.val_iteration_update(context),
                event_names=self.event_names(context),
            )
        return evaluator
 def test_compute(self, dataloaders):
     device = torch.device(f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu")
     dataloader = dataloaders[dist.get_rank()]
     # set up engine
     evaluator = SupervisedEvaluator(
         device=device,
         val_data_loader=dataloader,
         epoch_length=len(dataloader),
         network=TestNet(),
         non_blocking=False,
         prepare_batch=PrepareBatchDefault(),
         decollate=False,
     )
     evaluator.run()
     output = evaluator.state.output
     if len(dataloader) > 0:
         assert_allclose(output["image"], dataloader[-1]["image"].to(device=device))
         assert_allclose(output["label"], dataloader[-1]["label"].to(device=device))
Exemplo n.º 8
0
    def test_compute(self):
        data = [
            {
                "image": torch.tensor([[[[2.0], [3.0]]]]),
                "filename": ["test1"]
            },
            {
                "image": torch.tensor([[[[6.0], [8.0]]]]),
                "filename": ["test2"]
            },
        ]

        handlers = [
            DecollateBatch(event="MODEL_COMPLETED"),
            PostProcessing(transform=Compose([
                Activationsd(keys="pred", sigmoid=True),
                CopyItemsd(keys="filename", times=1, names="filename_bak"),
                AsDiscreted(keys="pred",
                            threshold_values=True,
                            to_onehot=True,
                            num_classes=2),
            ])),
        ]
        # set up engine, PostProcessing handler works together with postprocessing transforms of engine
        engine = SupervisedEvaluator(
            device=torch.device("cpu:0"),
            val_data_loader=data,
            epoch_length=2,
            network=torch.nn.PReLU(),
            # set decollate=False and execute some postprocessing first, then decollate in handlers
            postprocessing=lambda x: dict(pred=x["pred"] + 1.0),
            decollate=False,
            val_handlers=handlers,
        )
        engine.run()

        expected = torch.tensor([[[[1.0], [1.0]], [[0.0], [0.0]]]])

        for o, e in zip(engine.state.output, expected):
            torch.testing.assert_allclose(o["pred"], e)
            filename = o.get("filename_bak")
            if filename is not None:
                self.assertEqual(filename, "test2")
 def test_content(self):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     dataloader = [{
         "image": torch.tensor([1, 2]),
         "label": torch.tensor([3, 4]),
         "extra1": torch.tensor([5, 6]),
         "extra2": 16,
         "extra3": "test",
     }]
     # set up engine
     evaluator = SupervisedEvaluator(
         device=device,
         val_data_loader=dataloader,
         epoch_length=1,
         network=TestNet(),
         non_blocking=False,
         prepare_batch=PrepareBatchDefault(),
         decollate=False,
     )
     evaluator.run()
     output = evaluator.state.output
     assert_allclose(output["image"], torch.tensor([1, 2], device=device))
     assert_allclose(output["label"], torch.tensor([3, 4], device=device))
Exemplo n.º 10
0
    def train(self,
              train_info,
              valid_info,
              hyperparameters,
              run_data_check=False):

        logging.basicConfig(stream=sys.stdout, level=logging.INFO)

        if not run_data_check:
            start_dt = datetime.datetime.now()
            start_dt_string = start_dt.strftime('%d/%m/%Y %H:%M:%S')
            print(f'Training started: {start_dt_string}')

            # 1. Create folders to save the model
            timedate_info = str(
                datetime.datetime.now()).split(' ')[0] + '_' + str(
                    datetime.datetime.now().strftime("%H:%M:%S")).replace(
                        ':', '-')
            path_to_model = os.path.join(
                self.out_dir, 'trained_models',
                self.unique_name + '_' + timedate_info)
            os.mkdir(path_to_model)

        # 2. Load hyperparameters
        learning_rate = hyperparameters['learning_rate']
        weight_decay = hyperparameters['weight_decay']
        total_epoch = hyperparameters['total_epoch']
        multiplicator = hyperparameters['multiplicator']
        batch_size = hyperparameters['batch_size']
        validation_epoch = hyperparameters['validation_epoch']
        validation_interval = hyperparameters['validation_interval']
        H = hyperparameters['H']
        L = hyperparameters['L']

        # 3. Consider class imbalance
        negative, positive = 0, 0
        for _, label in train_info:
            if int(label) == 0:
                negative += 1
            elif int(label) == 1:
                positive += 1

        pos_weight = torch.Tensor([(negative / positive)]).to(self.device)

        # 4. Create train and validation loaders, batch_size = 10 for validation loader (10 central slices)

        train_data = get_data_from_info(self.image_data_dir, self.seg_data_dir,
                                        train_info)
        valid_data = get_data_from_info(self.image_data_dir, self.seg_data_dir,
                                        valid_info)
        large_image_splitter(train_data, self.cache_dir)

        set_determinism(seed=100)
        train_trans, valid_trans = self.transformations(H, L)
        train_dataset = PersistentDataset(
            data=train_data[:],
            transform=train_trans,
            cache_dir=self.persistent_dataset_dir)
        valid_dataset = PersistentDataset(
            data=valid_data[:],
            transform=valid_trans,
            cache_dir=self.persistent_dataset_dir)

        train_loader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  pin_memory=self.pin_memory,
                                  num_workers=self.num_workers,
                                  collate_fn=PadListDataCollate(
                                      Method.SYMMETRIC, NumpyPadMode.CONSTANT))
        valid_loader = DataLoader(valid_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  pin_memory=self.pin_memory,
                                  num_workers=self.num_workers,
                                  collate_fn=PadListDataCollate(
                                      Method.SYMMETRIC, NumpyPadMode.CONSTANT))

        # Perform data checks
        if run_data_check:
            check_data = monai.utils.misc.first(train_loader)
            print(check_data["image"].shape, check_data["label"])
            for i in range(batch_size):
                multi_slice_viewer(
                    check_data["image"][i, 0, :, :, :],
                    check_data["image_meta_dict"]["filename_or_obj"][i])
            exit()
        """c = 1
        for d in train_loader:
            img = d["image"]
            seg = d["seg"][0]
            seg, _ = nrrd.read(seg)
            img_name = d["image_meta_dict"]["filename_or_obj"][0]
            print(c, "Name:", img_name, "Size:", img.nelement()*img.element_size()/1024/1024, "MB", "shape:", img.shape)
            multi_slice_viewer(img[0, 0, :, :, :], d["image_meta_dict"]["filename_or_obj"][0])
            #multi_slice_viewer(seg, d["image_meta_dict"]["filename_or_obj"][0])
            c += 1
        exit()"""

        # 5. Prepare model
        model = ModelCT().to(self.device)

        # 6. Define loss function, optimizer and scheduler
        loss_function = torch.nn.BCEWithLogitsLoss(
            pos_weight)  # pos_weight for class imbalance
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=learning_rate,
                                     weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                           multiplicator,
                                                           last_epoch=-1)
        # 7. Create post validation transforms and handlers
        path_to_tensorboard = os.path.join(self.out_dir, 'tensorboard')
        writer = SummaryWriter(log_dir=path_to_tensorboard)
        valid_post_transforms = Compose([
            Activationsd(keys="pred", sigmoid=True),
        ])
        valid_handlers = [
            StatsHandler(output_transform=lambda x: None),
            TensorBoardStatsHandler(summary_writer=writer,
                                    output_transform=lambda x: None),
            CheckpointSaver(save_dir=path_to_model,
                            save_dict={"model": model},
                            save_key_metric=True),
            MetricsSaver(save_dir=path_to_model,
                         metrics=['Valid_AUC', 'Valid_ACC']),
        ]
        # 8. Create validatior
        discrete = AsDiscrete(threshold_values=True)
        evaluator = SupervisedEvaluator(
            device=self.device,
            val_data_loader=valid_loader,
            network=model,
            post_transform=valid_post_transforms,
            key_val_metric={
                "Valid_AUC":
                ROCAUC(output_transform=lambda x: (x["pred"], x["label"]))
            },
            additional_metrics={
                "Valid_Accuracy":
                Accuracy(output_transform=lambda x:
                         (discrete(x["pred"]), x["label"]))
            },
            val_handlers=valid_handlers,
            amp=self.amp,
        )
        # 9. Create trainer

        # Loss function does the last sigmoid, so we dont need it here.
        train_post_transforms = Compose([
            # Empty
        ])
        logger = MetricLogger(evaluator=evaluator)
        train_handlers = [
            logger,
            LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
            ValidationHandlerCT(validator=evaluator,
                                start=validation_epoch,
                                interval=validation_interval,
                                epoch_level=True),
            StatsHandler(tag_name="loss",
                         output_transform=lambda x: x["loss"]),
            TensorBoardStatsHandler(summary_writer=writer,
                                    tag_name="Train_Loss",
                                    output_transform=lambda x: x["loss"]),
            CheckpointSaver(save_dir=path_to_model,
                            save_dict={
                                "model": model,
                                "opt": optimizer
                            },
                            save_interval=1,
                            n_saved=1),
        ]

        trainer = SupervisedTrainer(
            device=self.device,
            max_epochs=total_epoch,
            train_data_loader=train_loader,
            network=model,
            optimizer=optimizer,
            loss_function=loss_function,
            post_transform=train_post_transforms,
            train_handlers=train_handlers,
            amp=self.amp,
        )
        # 10. Run trainer
        trainer.run()
        # 11. Save results
        np.save(path_to_model + '/AUCS.npy',
                np.array(logger.metrics['Valid_AUC']))
        np.save(path_to_model + '/ACCS.npy',
                np.array(logger.metrics['Valid_ACC']))
        np.save(path_to_model + '/LOSSES.npy', np.array(logger.loss))
        np.save(path_to_model + '/PARAMETERS.npy', np.array(hyperparameters))

        return path_to_model
Exemplo n.º 11
0
def run_training_test(root_dir, device="cuda:0", amp=False):
    images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    train_files = [{
        "image": img,
        "label": seg
    } for img, seg in zip(images[:20], segs[:20])]
    val_files = [{
        "image": img,
        "label": seg
    } for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose([
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys=["image", "label"]),
        RandCropByPosNegLabeld(keys=["image", "label"],
                               label_key="label",
                               spatial_size=[96, 96, 96],
                               pos=1,
                               neg=1,
                               num_samples=4),
        RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
        ToTensord(keys=["image", "label"]),
    ])
    val_transforms = Compose([
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys=["image", "label"]),
        ToTensord(keys=["image", "label"]),
    ])

    # create a training data loader
    train_ds = monai.data.CacheDataset(data=train_files,
                                       transform=train_transforms,
                                       cache_rate=0.5)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = monai.data.DataLoader(train_ds,
                                         batch_size=2,
                                         shuffle=True,
                                         num_workers=4)
    # create a validation data loader
    val_ds = monai.data.CacheDataset(data=val_files,
                                     transform=val_transforms,
                                     cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # create UNet, DiceLoss and Adam optimizer
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)

    val_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
    ])
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=root_dir,
                                output_transform=lambda x: None),
        TensorBoardImageHandler(log_dir=root_dir,
                                batch_transform=lambda x:
                                (x["image"], x["label"]),
                                output_transform=lambda x: x["pred"]),
        CheckpointSaver(save_dir=root_dir,
                        save_dict={"net": net},
                        save_key_metric=True),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96),
                                     sw_batch_size=4,
                                     overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice":
            MeanDice(include_background=True,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={
            "val_acc":
            Accuracy(output_transform=lambda x: (x["pred"], x["label"]))
        },
        val_handlers=val_handlers,
        amp=True if amp else False,
    )

    train_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
    ])
    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(log_dir=root_dir,
                                tag_name="train_loss",
                                output_transform=lambda x: x["loss"]),
        CheckpointSaver(save_dir=root_dir,
                        save_dict={
                            "net": net,
                            "opt": opt
                        },
                        save_interval=2,
                        epoch_level=True),
    ]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        post_transform=train_post_transforms,
        key_train_metric={
            "train_acc":
            Accuracy(output_transform=lambda x: (x["pred"], x["label"]))
        },
        train_handlers=train_handlers,
        amp=True if amp else False,
    )
    trainer.run()

    return evaluator.state.best_metric
Exemplo n.º 12
0
def run_inference_test(root_dir, model_file, device="cuda:0", amp=False):
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{
        "image": img,
        "label": seg
    } for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys=["image", "label"]),
        ToTensord(keys=["image", "label"]),
    ])

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # create UNet, DiceLoss and Adam optimizer
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    val_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
    ])
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointLoader(load_path=f"{model_file}", load_dict={"net": net}),
        SegmentationSaver(
            output_dir=root_dir,
            batch_transform=lambda batch: batch["image_meta_dict"],
            output_transform=lambda output: output["pred"],
        ),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96),
                                     sw_batch_size=4,
                                     overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice":
            MeanDice(include_background=True,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={
            "val_acc":
            Accuracy(output_transform=lambda x: (x["pred"], x["label"]))
        },
        val_handlers=val_handlers,
        amp=True if amp else False,
    )
    evaluator.run()

    return evaluator.state.best_metric
def run_inference_test(root_dir, model_file, device="cuda:0", amp=False, num_workers=4):
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"image": img, "label": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys=["image", "label"]),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=num_workers)

    # create UNet, DiceLoss and Adam optimizer
    net = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    val_postprocessing = Compose(
        [
            ToTensord(keys=["pred", "label"]),
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold=0.5),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
            # test the case that `pred` in `engine.state.output`, while `image_meta_dict` in `engine.state.batch`
            SaveImaged(
                keys="pred", meta_keys=PostFix.meta("image"), output_dir=root_dir, output_postfix="seg_transform"
            ),
        ]
    )
    val_handlers = [
        StatsHandler(iteration_log=False),
        CheckpointLoader(load_path=f"{model_file}", load_dict={"net": net}),
        SegmentationSaver(
            output_dir=root_dir,
            output_postfix="seg_handler",
            batch_transform=from_engine(PostFix.meta("image")),
            output_transform=from_engine("pred"),
        ),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        postprocessing=val_postprocessing,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=from_engine(["pred", "label"]))
        },
        additional_metrics={"val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
        val_handlers=val_handlers,
        amp=bool(amp),
    )
    evaluator.run()

    return evaluator.state.best_metric
def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4):
    images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    train_files = [{"image": img, "label": seg} for img, seg in zip(images[:20], segs[:20])]
    val_files = [{"image": img, "label": seg} for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys=["image", "label"]),
            RandCropByPosNegLabeld(
                keys=["image", "label"], label_key="label", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
            ToTensord(keys=["image", "label"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys=["image", "label"]),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # create a training data loader
    train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=num_workers)
    # create a validation data loader
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=num_workers)

    # create UNet, DiceLoss and Adam optimizer
    net = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)
    summary_writer = SummaryWriter(log_dir=root_dir)

    val_postprocessing = Compose(
        [
            ToTensord(keys=["pred", "label"]),
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold=0.5),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )

    class _TestEvalIterEvents:
        def attach(self, engine):
            engine.add_event_handler(IterationEvents.FORWARD_COMPLETED, self._forward_completed)

        def _forward_completed(self, engine):
            pass

    val_handlers = [
        StatsHandler(iteration_log=False),
        TensorBoardStatsHandler(summary_writer=summary_writer, iteration_log=False),
        TensorBoardImageHandler(
            log_dir=root_dir, batch_transform=from_engine(["image", "label"]), output_transform=from_engine("pred")
        ),
        CheckpointSaver(save_dir=root_dir, save_dict={"net": net}, save_key_metric=True),
        _TestEvalIterEvents(),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        postprocessing=val_postprocessing,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=from_engine(["pred", "label"]))
        },
        additional_metrics={"val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
        metric_cmp_fn=lambda cur, prev: cur >= prev,  # if greater or equal, treat as new best metric
        val_handlers=val_handlers,
        amp=bool(amp),
        to_kwargs={"memory_format": torch.preserve_format},
        amp_kwargs={"dtype": torch.float16 if bool(amp) else torch.float32},
    )

    train_postprocessing = Compose(
        [
            ToTensord(keys=["pred", "label"]),
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold=0.5),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )

    class _TestTrainIterEvents:
        def attach(self, engine):
            engine.add_event_handler(IterationEvents.FORWARD_COMPLETED, self._forward_completed)
            engine.add_event_handler(IterationEvents.LOSS_COMPLETED, self._loss_completed)
            engine.add_event_handler(IterationEvents.BACKWARD_COMPLETED, self._backward_completed)
            engine.add_event_handler(IterationEvents.MODEL_COMPLETED, self._model_completed)

        def _forward_completed(self, engine):
            pass

        def _loss_completed(self, engine):
            pass

        def _backward_completed(self, engine):
            pass

        def _model_completed(self, engine):
            pass

    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=from_engine("loss", first=True)),
        TensorBoardStatsHandler(
            summary_writer=summary_writer, tag_name="train_loss", output_transform=from_engine("loss", first=True)
        ),
        CheckpointSaver(save_dir=root_dir, save_dict={"net": net, "opt": opt}, save_interval=2, epoch_level=True),
        _TestTrainIterEvents(),
    ]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        postprocessing=train_postprocessing,
        key_train_metric={"train_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
        train_handlers=train_handlers,
        amp=bool(amp),
        optim_set_to_none=True,
        to_kwargs={"memory_format": torch.preserve_format},
        amp_kwargs={"dtype": torch.float16 if bool(amp) else torch.float32},
    )
    trainer.run()

    return evaluator.state.best_metric
Exemplo n.º 15
0
    def configure(self):
        self.set_device()
        network = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm=Norm.BATCH,
        ).to(self.device)
        if self.multi_gpu:
            network = DistributedDataParallel(
                module=network,
                device_ids=[self.device],
                find_unused_parameters=False,
            )

        train_transforms = Compose([
            LoadImaged(keys=("image", "label")),
            EnsureChannelFirstd(keys=("image", "label")),
            Spacingd(keys=("image", "label"),
                     pixdim=[1.0, 1.0, 1.0],
                     mode=["bilinear", "nearest"]),
            ScaleIntensityRanged(
                keys="image",
                a_min=-57,
                a_max=164,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=("image", "label"), source_key="image"),
            RandCropByPosNegLabeld(
                keys=("image", "label"),
                label_key="label",
                spatial_size=(96, 96, 96),
                pos=1,
                neg=1,
                num_samples=4,
                image_key="image",
                image_threshold=0,
            ),
            RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
            ToTensord(keys=("image", "label")),
        ])
        train_datalist = load_decathlon_datalist(self.data_list_file_path,
                                                 True, "training")
        if self.multi_gpu:
            train_datalist = partition_dataset(
                data=train_datalist,
                shuffle=True,
                num_partitions=dist.get_world_size(),
                even_divisible=True,
            )[dist.get_rank()]
        train_ds = CacheDataset(
            data=train_datalist,
            transform=train_transforms,
            cache_num=32,
            cache_rate=1.0,
            num_workers=4,
        )
        train_data_loader = DataLoader(
            train_ds,
            batch_size=2,
            shuffle=True,
            num_workers=4,
        )
        val_transforms = Compose([
            LoadImaged(keys=("image", "label")),
            EnsureChannelFirstd(keys=("image", "label")),
            ScaleIntensityRanged(
                keys="image",
                a_min=-57,
                a_max=164,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=("image", "label"), source_key="image"),
            ToTensord(keys=("image", "label")),
        ])

        val_datalist = load_decathlon_datalist(self.data_list_file_path, True,
                                               "validation")
        val_ds = CacheDataset(val_datalist, val_transforms, 9, 0.0, 4)
        val_data_loader = DataLoader(
            val_ds,
            batch_size=1,
            shuffle=False,
            num_workers=4,
        )
        post_transform = Compose([
            Activationsd(keys="pred", softmax=True),
            AsDiscreted(
                keys=["pred", "label"],
                argmax=[True, False],
                to_onehot=True,
                n_classes=2,
            ),
        ])
        # metric
        key_val_metric = {
            "val_mean_dice":
            MeanDice(
                include_background=False,
                output_transform=lambda x: (x["pred"], x["label"]),
                device=self.device,
            )
        }
        val_handlers = [
            StatsHandler(output_transform=lambda x: None),
            CheckpointSaver(
                save_dir=self.ckpt_dir,
                save_dict={"model": network},
                save_key_metric=True,
            ),
            TensorBoardStatsHandler(log_dir=self.ckpt_dir,
                                    output_transform=lambda x: None),
        ]
        self.eval_engine = SupervisedEvaluator(
            device=self.device,
            val_data_loader=val_data_loader,
            network=network,
            inferer=SlidingWindowInferer(
                roi_size=[160, 160, 160],
                sw_batch_size=4,
                overlap=0.5,
            ),
            post_transform=post_transform,
            key_val_metric=key_val_metric,
            val_handlers=val_handlers,
            amp=self.amp,
        )

        optimizer = torch.optim.Adam(network.parameters(), self.learning_rate)
        loss_function = DiceLoss(to_onehot_y=True, softmax=True)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       step_size=5000,
                                                       gamma=0.1)
        train_handlers = [
            LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
            ValidationHandler(validator=self.eval_engine,
                              interval=self.val_interval,
                              epoch_level=True),
            StatsHandler(tag_name="train_loss",
                         output_transform=lambda x: x["loss"]),
            TensorBoardStatsHandler(
                log_dir=self.ckpt_dir,
                tag_name="train_loss",
                output_transform=lambda x: x["loss"],
            ),
        ]

        self.train_engine = SupervisedTrainer(
            device=self.device,
            max_epochs=self.max_epochs,
            train_data_loader=train_data_loader,
            network=network,
            optimizer=optimizer,
            loss_function=loss_function,
            inferer=SimpleInferer(),
            post_transform=post_transform,
            key_train_metric=None,
            train_handlers=train_handlers,
            amp=self.amp,
        )

        if self.local_rank > 0:
            self.train_engine.logger.setLevel(logging.WARNING)
            self.eval_engine.logger.setLevel(logging.WARNING)
Exemplo n.º 16
0
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask paris
    tempdir = tempfile.mkdtemp()
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128,
                                       128,
                                       128,
                                       num_seg_classes=1,
                                       channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    val_files = [{
        "image": img,
        "label": seg
    } for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys=["image", "label"]),
        ToTensord(keys=["image", "label"]),
    ])

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda:0")
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    val_post_transforms = Compose([
        Activationsd(keys="pred", output_postfix="act", sigmoid=True),
        AsDiscreted(keys="pred_act",
                    output_postfix="dis",
                    threshold_values=True),
        KeepLargestConnectedComponentd(keys="pred_act_dis",
                                       applied_values=[1],
                                       output_postfix=None),
    ])
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointLoader(load_path="./runs/net_key_metric=0.9101.pth",
                         load_dict={"net": net}),
        SegmentationSaver(
            output_dir="./runs/",
            batch_transform=lambda x: {
                "filename_or_obj": x["image.filename_or_obj"],
                "affine": x["image.affine"],
                "original_affine": x["image.original_affine"],
                "spatial_shape": x["image.spatial_shape"],
            },
            output_transform=lambda x: x["pred_act_dis"],
        ),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96),
                                     sw_batch_size=4,
                                     overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice":
            MeanDice(include_background=True,
                     output_transform=lambda x:
                     (x["pred_act_dis"], x["label"]))
        },
        additional_metrics={
            "val_acc":
            Accuracy(
                output_transform=lambda x: (x["pred_act_dis"], x["label"]))
        },
        val_handlers=val_handlers,
    )
    evaluator.run()
    shutil.rmtree(tempdir)
def evaluate(args):
    if args.local_rank == 0 and not os.path.exists(args.dir):
        # create 16 random image, mask paris for evaluation
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(16):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    # initialize the distributed evaluation process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    val_files = [{"image": img, "label": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # create a evaluation data loader
    val_ds = Dataset(data=val_files, transform=val_transforms)
    # create a evaluation data sampler
    val_sampler = DistributedSampler(val_ds, shuffle=False)
    # sliding window inference need to input 1 image in every iteration
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=True, sampler=val_sampler)

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    # wrap the model with DistributedDataParallel module
    net = DistributedDataParallel(net, device_ids=[device])

    val_post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )
    val_handlers = [
        CheckpointLoader(
            load_path="./runs/checkpoint_epoch=4.pt",
            load_dict={"net": net},
            # config mapping to expected GPU device
            map_location={"cuda:0": f"cuda:{args.local_rank}"},
        ),
    ]
    if dist.get_rank() == 0:
        logging.basicConfig(stream=sys.stdout, level=logging.INFO)
        val_handlers.extend(
            [
                StatsHandler(output_transform=lambda x: None),
                SegmentationSaver(
                    output_dir="./runs/",
                    batch_transform=lambda batch: batch["image_meta_dict"],
                    output_transform=lambda output: output["pred"],
                ),
            ]
        )

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(
                include_background=True,
                output_transform=lambda x: (x["pred"], x["label"]),
                device=device,
            )
        },
        additional_metrics={"val_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]), device=device)},
        val_handlers=val_handlers,
        # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation
        amp=True if monai.config.get_torch_version_tuple() >= (1, 6) else False,
    )
    evaluator.run()
    dist.destroy_process_group()
Exemplo n.º 18
0
def main(tempdir):
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    ################################ DATASET ################################
    # create a temporary directory and 40 random image, mask pairs
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))
        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    train_files = [{"image": img, "label": seg} for img, seg in zip(images[:20], segs[:20])]
    val_files = [{"image": img, "label": seg} for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys="image"),
            RandCropByPosNegLabeld(
                keys=["image", "label"], label_key="label", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
            ToTensord(keys=["image", "label"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # create a training data loader
    train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
    # create a validation data loader
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    ################################ DATASET ################################
    
    ################################ NETWORK ################################
    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    ################################ NETWORK ################################
    
    ################################ LOSS ################################    
    loss = monai.losses.DiceLoss(sigmoid=True)
    ################################ LOSS ################################
    
    ################################ OPT ################################
    opt = torch.optim.Adam(net.parameters(), 1e-3)
    ################################ OPT ################################
    
    ################################ LR ################################
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)
    ################################ LR ################################
    
    val_post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir="./runs/", output_transform=lambda x: None),
        TensorBoardImageHandler(
            log_dir="./runs/",
            batch_transform=lambda x: (x["image"], x["label"]),
            output_transform=lambda x: x["pred"],
        ),
        CheckpointSaver(save_dir="./runs/", save_dict={"net": net}, save_key_metric=True),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={"val_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]))},
        val_handlers=val_handlers,
        # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation
        amp=True if monai.utils.get_torch_version_tuple() >= (1, 6) else False,
    )

    train_post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )
    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(log_dir="./runs/", tag_name="train_loss", output_transform=lambda x: x["loss"]),
        CheckpointSaver(save_dir="./runs/", save_dict={"net": net, "opt": opt}, save_interval=2, epoch_level=True),
    ]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        post_transform=train_post_transforms,
        key_train_metric={"train_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]))},
        train_handlers=train_handlers,
        # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP training
        amp=True if monai.utils.get_torch_version_tuple() >= (1, 6) else False,
    )
    trainer.run()
Exemplo n.º 19
0
    def train(index):

        # ---------- Build the nn-Unet network ------------

        if opt.resolution is None:
            sizes, spacings = opt.patch_size, opt.spacing
        else:
            sizes, spacings = opt.patch_size, opt.resolution

        strides, kernels = [], []

        while True:
            spacing_ratio = [sp / min(spacings) for sp in spacings]
            stride = [
                2 if ratio <= 2 and size >= 8 else 1
                for (ratio, size) in zip(spacing_ratio, sizes)
            ]
            kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
            if all(s == 1 for s in stride):
                break
            sizes = [i / j for i, j in zip(sizes, stride)]
            spacings = [i * j for i, j in zip(spacings, stride)]
            kernels.append(kernel)
            strides.append(stride)
        strides.insert(0, len(spacings) * [1])
        kernels.append(len(spacings) * [3])

        net = monai.networks.nets.DynUNet(
            spatial_dims=3,
            in_channels=opt.in_channels,
            out_channels=opt.out_channels,
            kernel_size=kernels,
            strides=strides,
            upsample_kernel_size=strides[1:],
            res_block=True,
            # act=act_type,
            # norm=Norm.BATCH,
        ).to(device)

        from torch.autograd import Variable
        from torchsummaryX import summary

        data = Variable(
            torch.randn(int(opt.batch_size), int(opt.in_channels),
                        int(opt.patch_size[0]), int(opt.patch_size[1]),
                        int(opt.patch_size[2]))).cuda()

        out = net(data)
        summary(net, data)
        print("out size: {}".format(out.size()))

        # if opt.preload is not None:
        #     net.load_state_dict(torch.load(opt.preload))

        # ---------- ------------------------ ------------

        optim = torch.optim.Adam(net.parameters(), lr=opt.lr)
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optim, lr_lambda=lambda epoch: (1 - epoch / opt.epochs)**0.9)

        loss_function = monai.losses.DiceCELoss(sigmoid=True)

        val_post_transforms = Compose([
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            # KeepLargestConnectedComponentd(keys="pred", applied_labels=[1])
        ])

        val_handlers = [
            StatsHandler(output_transform=lambda x: None),
            CheckpointSaver(save_dir="./runs/",
                            save_dict={"net": net},
                            save_key_metric=True),
        ]

        evaluator = SupervisedEvaluator(
            device=device,
            val_data_loader=val_loaders[index],
            network=net,
            inferer=SlidingWindowInferer(roi_size=opt.patch_size,
                                         sw_batch_size=opt.batch_size,
                                         overlap=0.5),
            post_transform=val_post_transforms,
            key_val_metric={
                "val_mean_dice":
                MeanDice(
                    include_background=True,
                    output_transform=lambda x: (x["pred"], x["label"]),
                )
            },
            val_handlers=val_handlers)

        train_post_transforms = Compose([
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            # KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ])

        train_handlers = [
            ValidationHandler(validator=evaluator,
                              interval=5,
                              epoch_level=True),
            LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
            StatsHandler(tag_name="train_loss",
                         output_transform=lambda x: x["loss"]),
            CheckpointSaver(save_dir="./runs/",
                            save_dict={
                                "net": net,
                                "opt": optim
                            },
                            save_final=True,
                            epoch_level=True),
        ]

        trainer = SupervisedTrainer(
            device=device,
            max_epochs=opt.epochs,
            train_data_loader=train_loaders[index],
            network=net,
            optimizer=optim,
            loss_function=loss_function,
            inferer=SimpleInferer(),
            post_transform=train_post_transforms,
            amp=False,
            train_handlers=train_handlers,
        )
        trainer.run()
        return net
Exemplo n.º 20
0
def train(cfg):
    log_dir = create_log_dir(cfg)
    device = set_device(cfg)
    # --------------------------------------------------------------------------
    # Data Loading and Preprocessing
    # --------------------------------------------------------------------------
    # __________________________________________________________________________
    # Build MONAI preprocessing
    train_preprocess = Compose([
        ToTensorD(keys="image"),
        TorchVisionD(keys="image",
                     name="ColorJitter",
                     brightness=64.0 / 255.0,
                     contrast=0.75,
                     saturation=0.25,
                     hue=0.04),
        ToNumpyD(keys="image"),
        RandFlipD(keys="image", prob=0.5),
        RandRotate90D(keys="image", prob=0.5),
        CastToTypeD(keys="image", dtype=np.float32),
        RandZoomD(keys="image", prob=0.5, min_zoom=0.9, max_zoom=1.1),
        ScaleIntensityRangeD(keys="image",
                             a_min=0.0,
                             a_max=255.0,
                             b_min=-1.0,
                             b_max=1.0),
        ToTensorD(keys=("image", "label")),
    ])
    valid_preprocess = Compose([
        CastToTypeD(keys="image", dtype=np.float32),
        ScaleIntensityRangeD(keys="image",
                             a_min=0.0,
                             a_max=255.0,
                             b_min=-1.0,
                             b_max=1.0),
        ToTensorD(keys=("image", "label")),
    ])
    # __________________________________________________________________________
    # Create MONAI dataset
    train_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="training",
        base_dir=cfg["data_root"],
    )
    valid_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="validation",
        base_dir=cfg["data_root"],
    )

    train_dataset = PatchWSIDataset(
        train_json_info_list,
        cfg["region_size"],
        cfg["grid_shape"],
        cfg["patch_size"],
        train_preprocess,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )
    valid_dataset = PatchWSIDataset(
        valid_json_info_list,
        cfg["region_size"],
        cfg["grid_shape"],
        cfg["patch_size"],
        valid_preprocess,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )

    # __________________________________________________________________________
    # DataLoaders
    train_dataloader = DataLoader(train_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=True)
    valid_dataloader = DataLoader(valid_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=True)

    # __________________________________________________________________________
    # Get sample batch and some info
    first_sample = first(train_dataloader)
    if first_sample is None:
        raise ValueError("Fist sample is None!")

    print("image: ")
    print("    shape", first_sample["image"].shape)
    print("    type: ", type(first_sample["image"]))
    print("    dtype: ", first_sample["image"].dtype)
    print("labels: ")
    print("    shape", first_sample["label"].shape)
    print("    type: ", type(first_sample["label"]))
    print("    dtype: ", first_sample["label"].dtype)
    print(f"batch size: {cfg['batch_size']}")
    print(f"train number of batches: {len(train_dataloader)}")
    print(f"valid number of batches: {len(valid_dataloader)}")

    # --------------------------------------------------------------------------
    # Deep Learning Classification Model
    # --------------------------------------------------------------------------
    # __________________________________________________________________________
    # initialize model
    model = TorchVisionFCModel("resnet18",
                               num_classes=1,
                               use_conv=True,
                               pretrained=cfg["pretrain"])
    model = model.to(device)

    # loss function
    loss_func = torch.nn.BCEWithLogitsLoss()
    loss_func = loss_func.to(device)

    # optimizer
    if cfg["novograd"]:
        optimizer = Novograd(model.parameters(), cfg["lr"])
    else:
        optimizer = SGD(model.parameters(), lr=cfg["lr"], momentum=0.9)

    # AMP scaler
    if cfg["amp"]:
        cfg["amp"] = True if monai.utils.get_torch_version_tuple() >= (
            1, 6) else False
    else:
        cfg["amp"] = False

    scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                               T_max=cfg["n_epochs"])

    # --------------------------------------------
    # Ignite Trainer/Evaluator
    # --------------------------------------------
    # Evaluator
    val_handlers = [
        CheckpointSaver(save_dir=log_dir,
                        save_dict={"net": model},
                        save_key_metric=True),
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=log_dir,
                                output_transform=lambda x: None),
    ]
    val_postprocessing = Compose([
        ActivationsD(keys="pred", sigmoid=True),
        AsDiscreteD(keys="pred", threshold=0.5)
    ])
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=valid_dataloader,
        network=model,
        postprocessing=val_postprocessing,
        key_val_metric={
            "val_acc":
            Accuracy(output_transform=from_engine(["pred", "label"]))
        },
        val_handlers=val_handlers,
        amp=cfg["amp"],
    )

    # Trainer
    train_handlers = [
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        CheckpointSaver(save_dir=cfg["logdir"],
                        save_dict={
                            "net": model,
                            "opt": optimizer
                        },
                        save_interval=1,
                        epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=from_engine(["loss"], first=True)),
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        TensorBoardStatsHandler(log_dir=cfg["logdir"],
                                tag_name="train_loss",
                                output_transform=from_engine(["loss"],
                                                             first=True)),
    ]
    train_postprocessing = Compose([
        ActivationsD(keys="pred", sigmoid=True),
        AsDiscreteD(keys="pred", threshold=0.5)
    ])

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=cfg["n_epochs"],
        train_data_loader=train_dataloader,
        network=model,
        optimizer=optimizer,
        loss_function=loss_func,
        postprocessing=train_postprocessing,
        key_train_metric={
            "train_acc":
            Accuracy(output_transform=from_engine(["pred", "label"]))
        },
        train_handlers=train_handlers,
        amp=cfg["amp"],
    )
    trainer.run()
Exemplo n.º 21
0
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask paris
    tempdir = tempfile.mkdtemp()
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_3d(128,
                                       128,
                                       128,
                                       num_seg_classes=1,
                                       channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))
        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    train_files = [{
        Keys.IMAGE: img,
        Keys.LABEL: seg
    } for img, seg in zip(images[:20], segs[:20])]
    val_files = [{
        Keys.IMAGE: img,
        Keys.LABEL: seg
    } for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose([
        LoadNiftid(keys=[Keys.IMAGE, Keys.LABEL]),
        AsChannelFirstd(keys=[Keys.IMAGE, Keys.LABEL], channel_dim=-1),
        ScaleIntensityd(keys=[Keys.IMAGE, Keys.LABEL]),
        RandCropByPosNegLabeld(keys=[Keys.IMAGE, Keys.LABEL],
                               label_key=Keys.LABEL,
                               size=[96, 96, 96],
                               pos=1,
                               neg=1,
                               num_samples=4),
        RandRotate90d(keys=[Keys.IMAGE, Keys.LABEL],
                      prob=0.5,
                      spatial_axes=[0, 2]),
        ToTensord(keys=[Keys.IMAGE, Keys.LABEL]),
    ])
    val_transforms = Compose([
        LoadNiftid(keys=[Keys.IMAGE, Keys.LABEL]),
        AsChannelFirstd(keys=[Keys.IMAGE, Keys.LABEL], channel_dim=-1),
        ScaleIntensityd(keys=[Keys.IMAGE, Keys.LABEL]),
        ToTensord(keys=[Keys.IMAGE, Keys.LABEL]),
    ])

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(train_ds,
                              batch_size=2,
                              shuffle=True,
                              num_workers=4,
                              collate_fn=list_data_collate)
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            collate_fn=list_data_collate)

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda:0")
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(do_sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)

    val_handlers = [StatsHandler(output_transform=lambda x: None)]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96),
                                     sw_batch_size=4,
                                     overlap=0.5),
        val_handlers=val_handlers,
        key_val_metric={
            "val_mean_dice":
            MeanDice(include_background=True,
                     add_sigmoid=True,
                     output_transform=lambda x: (x[Keys.PRED], x[Keys.LABEL]))
        },
        additional_metrics=None,
    )

    train_handlers = [
        ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x[Keys.INFO][Keys.LOSS]),
    ]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        train_handlers=train_handlers,
        amp=False,
        key_train_metric=None,
    )
    trainer.run()

    shutil.rmtree(tempdir)
Exemplo n.º 22
0
def create_trainer(args):
    set_determinism(seed=args.seed)

    multi_gpu = args.multi_gpu
    local_rank = args.local_rank
    if multi_gpu:
        dist.init_process_group(backend="nccl", init_method="env://")
        device = torch.device("cuda:{}".format(local_rank))
        torch.cuda.set_device(device)
    else:
        device = torch.device("cuda" if args.use_gpu else "cpu")

    pre_transforms = get_pre_transforms(args.roi_size, args.model_size,
                                        args.dimensions)
    click_transforms = get_click_transforms()
    post_transform = get_post_transforms()

    train_loader, val_loader = get_loaders(args, pre_transforms)

    # define training components
    network = get_network(args.network, args.channels,
                          args.dimensions).to(device)
    if multi_gpu:
        network = torch.nn.parallel.DistributedDataParallel(
            network, device_ids=[local_rank], output_device=local_rank)

    if args.resume:
        logging.info('{}:: Loading Network...'.format(local_rank))
        map_location = {"cuda:0": "cuda:{}".format(local_rank)}
        network.load_state_dict(
            torch.load(args.model_filepath, map_location=map_location))

    # define event-handlers for engine
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=args.output,
                                output_transform=lambda x: None),
        DeepgrowStatsHandler(log_dir=args.output,
                             tag_name='val_dice',
                             image_interval=args.image_interval),
        CheckpointSaver(save_dir=args.output,
                        save_dict={"net": network},
                        save_key_metric=True,
                        save_final=True,
                        save_interval=args.save_interval,
                        final_filename='model.pt')
    ]
    val_handlers = val_handlers if local_rank == 0 else None

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=network,
        iteration_update=Interaction(
            transforms=click_transforms,
            max_interactions=args.max_val_interactions,
            key_probability='probability',
            train=False),
        inferer=SimpleInferer(),
        post_transform=post_transform,
        key_val_metric={
            "val_dice":
            MeanDice(include_background=False,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        val_handlers=val_handlers)

    loss_function = DiceLoss(sigmoid=True, squared_pred=True)
    optimizer = torch.optim.Adam(network.parameters(), args.learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=5000,
                                                   gamma=0.1)

    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator,
                          interval=args.val_freq,
                          epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(log_dir=args.output,
                                tag_name="train_loss",
                                output_transform=lambda x: x["loss"]),
        CheckpointSaver(save_dir=args.output,
                        save_dict={
                            "net": network,
                            "opt": optimizer,
                            "lr": lr_scheduler
                        },
                        save_interval=args.save_interval * 2,
                        save_final=True,
                        final_filename='checkpoint.pt'),
    ]
    train_handlers = train_handlers if local_rank == 0 else train_handlers[:2]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=args.epochs,
        train_data_loader=train_loader,
        network=network,
        iteration_update=Interaction(
            transforms=click_transforms,
            max_interactions=args.max_train_interactions,
            key_probability='probability',
            train=True),
        optimizer=optimizer,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        post_transform=post_transform,
        amp=args.amp,
        key_train_metric={
            "train_dice":
            MeanDice(include_background=False,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        train_handlers=train_handlers,
    )
    return trainer
Exemplo n.º 23
0
def main(tempdir):
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    val_files = [{"image": img, "label": seg} for img, seg in zip(images, segs)]

    # model file path
    model_file = glob("./runs/net_key_metric*")[0]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    val_post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointLoader(load_path=model_file, load_dict={"net": net}),
        SegmentationSaver(
            output_dir="./runs/",
            batch_transform=lambda batch: batch["image_meta_dict"],
            output_transform=lambda output: output["pred"],
        ),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={"val_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]))},
        val_handlers=val_handlers,
        # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation
        amp=True if monai.config.get_torch_version_tuple() >= (1, 6) else False,
    )
    evaluator.run()
def main(config):
    now = datetime.now().strftime("%Y%m%d-%H:%M:%S")

    # path
    csv_path = config['path']['csv_path']

    trained_model_path = config['path'][
        'trained_model_path']  # if None, trained from scratch
    training_model_folder = os.path.join(
        config['path']['training_model_folder'], now)  # '/path/to/folder'
    if not os.path.exists(training_model_folder):
        os.makedirs(training_model_folder)
    logdir = os.path.join(training_model_folder, 'logs')
    if not os.path.exists(logdir):
        os.makedirs(logdir)

    # PET CT scan params
    image_shape = tuple(config['preprocessing']['image_shape'])  # (x, y, z)
    in_channels = config['preprocessing']['in_channels']
    voxel_spacing = tuple(
        config['preprocessing']
        ['voxel_spacing'])  # (4.8, 4.8, 4.8)  # in millimeter, (x, y, z)
    data_augment = config['preprocessing'][
        'data_augment']  # True  # for training dataset only
    resize = config['preprocessing']['resize']  # True  # not use yet
    origin = config['preprocessing']['origin']  # how to set the new origin
    normalize = config['preprocessing'][
        'normalize']  # True  # whether or not to normalize the inputs
    number_class = config['preprocessing']['number_class']  # 2

    # CNN params
    architecture = config['model']['architecture']  # 'unet' or 'vnet'

    cnn_params = config['model'][architecture]['cnn_params']
    # transform list to tuple
    for key, value in cnn_params.items():
        if isinstance(value, list):
            cnn_params[key] = tuple(value)

    # Training params
    epochs = config['training']['epochs']
    batch_size = config['training']['batch_size']
    shuffle = config['training']['shuffle']
    opt_params = config['training']["optimizer"]["opt_params"]

    # Get Data
    DM = DataManager(csv_path=csv_path)
    train_images_paths, val_images_paths, test_images_paths = DM.get_train_val_test(
        wrap_with_dict=True)

    # Input preprocessing
    # use data augmentation for training
    train_transforms = Compose([  # read img + meta info
        LoadNifti(keys=["pet_img", "ct_img", "mask_img"]),
        Roi2Mask(keys=['pet_img', 'mask_img'],
                 method='otsu',
                 tval=0.0,
                 idx_channel=0),
        ResampleReshapeAlign(target_shape=image_shape,
                             target_voxel_spacing=voxel_spacing,
                             keys=['pet_img', "ct_img", 'mask_img'],
                             origin='head',
                             origin_key='pet_img'),
        Sitk2Numpy(keys=['pet_img', 'ct_img', 'mask_img']),
        # user can also add other random transforms
        RandAffined(keys=("pet_img", "ct_img", "mask_img"),
                    spatial_size=None,
                    prob=0.4,
                    rotate_range=(0, np.pi / 30, np.pi / 15),
                    shear_range=None,
                    translate_range=(10, 10, 10),
                    scale_range=(0.1, 0.1, 0.1),
                    mode=("bilinear", "bilinear", "nearest"),
                    padding_mode="border"),
        # normalize input
        ScaleIntensityRanged(
            keys=["pet_img"],
            a_min=0.0,
            a_max=25.0,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        ScaleIntensityRanged(
            keys=["ct_img"],
            a_min=-1000.0,
            a_max=1000.0,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        # Prepare for neural network
        ConcatModality(keys=['pet_img', 'ct_img']),
        AddChanneld(keys=["mask_img"]),  # Add channel to the first axis
        ToTensord(keys=["image", "mask_img"]),
    ])
    # without data augmentation for validation
    val_transforms = Compose([  # read img + meta info
        LoadNifti(keys=["pet_img", "ct_img", "mask_img"]),
        Roi2Mask(keys=['pet_img', 'mask_img'],
                 method='otsu',
                 tval=0.0,
                 idx_channel=0),
        ResampleReshapeAlign(target_shape=image_shape,
                             target_voxel_spacing=voxel_spacing,
                             keys=['pet_img', "ct_img", 'mask_img'],
                             origin='head',
                             origin_key='pet_img'),
        Sitk2Numpy(keys=['pet_img', 'ct_img', 'mask_img']),
        # normalize input
        ScaleIntensityRanged(
            keys=["pet_img"],
            a_min=0.0,
            a_max=25.0,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        ScaleIntensityRanged(
            keys=["ct_img"],
            a_min=-1000.0,
            a_max=1000.0,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        # Prepare for neural network
        ConcatModality(keys=['pet_img', 'ct_img']),
        AddChanneld(keys=["mask_img"]),  # Add channel to the first axis
        ToTensord(keys=["image", "mask_img"]),
    ])

    # create a training data loader
    train_ds = monai.data.CacheDataset(data=train_images_paths,
                                       transform=train_transforms,
                                       cache_rate=0.5)
    # use batch_size=2 to load images to generate 2 x 4 images for network training
    train_loader = monai.data.DataLoader(train_ds,
                                         batch_size=batch_size,
                                         shuffle=shuffle,
                                         num_workers=2)
    # create a validation data loader
    val_ds = monai.data.CacheDataset(data=val_images_paths,
                                     transform=val_transforms,
                                     cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds,
                                       batch_size=batch_size,
                                       num_workers=2)

    # Model
    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        dimensions=3,  # 3D
        in_channels=in_channels,
        out_channels=1,
        kernel_size=5,
        channels=(8, 16, 32, 64, 128),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True, squared_pred=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)

    # training
    val_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
    ])
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir="./runs/",
                                output_transform=lambda x: None),
        # TensorBoardImageHandler(
        #     log_dir="./runs/",
        #     batch_transform=lambda x: (x["image"], x["label"]),
        #     output_transform=lambda x: x["pred"],
        # ),
        CheckpointSaver(save_dir="./runs/",
                        save_dict={
                            "net": net,
                            "opt": opt
                        },
                        save_key_metric=True),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SimpleInferer(),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice":
            MeanDice(include_background=True,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={
            "val_acc":
            Accuracy(output_transform=lambda x: (x["pred"], x["label"])),
            "val_precision":
            Precision(output_transform=lambda x: (x["pred"], x["label"])),
            "val_recall":
            Recall(output_transform=lambda x: (x["pred"], x["label"]))
        },
        val_handlers=val_handlers,
        # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation
        # amp=True if monai.config.get_torch_version_tuple() >= (1, 6) else False,
    )

    train_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
    ])
    train_handlers = [
        # LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(log_dir="./runs/",
                                tag_name="train_loss",
                                output_transform=lambda x: x["loss"]),
        CheckpointSaver(save_dir="./runs/",
                        save_dict={
                            "net": net,
                            "opt": opt
                        },
                        save_interval=2,
                        epoch_level=True),
    ]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        prepare_batch=lambda x: (x['image'], x['mask_img']),
        inferer=SimpleInferer(),
        post_transform=train_post_transforms,
        key_train_metric={
            "train_mean_dice":
            MeanDice(include_background=True,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={
            "train_acc":
            Accuracy(output_transform=lambda x: (x["pred"], x["label"])),
            "train_precision":
            Precision(output_transform=lambda x: (x["pred"], x["label"])),
            "train_recall":
            Recall(output_transform=lambda x: (x["pred"], x["label"]))
        },
        train_handlers=train_handlers,
        # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP training
        amp=True if monai.config.get_torch_version_tuple() >=
        (1, 6) else False,
    )
    trainer.run()