Esempio n. 1
0
    def test_metrics_writer(self):
        """Check that an explicitly supplied SummaryWriter receives the logs.

        The handler is given a writer rooted at a temporary directory, so the
        default ``./runs`` directory must not be created.
        """
        default_dir = os.path.join('.', 'runs')
        # start clean so the final "not created" check is meaningful
        shutil.rmtree(default_dir, ignore_errors=True)
        with tempfile.TemporaryDirectory() as temp_dir:

            # set up engine
            def _train_func(engine, batch):
                return batch + 1.0

            engine = Engine(_train_func)

            # set up dummy metric
            @engine.on(Events.EPOCH_COMPLETED)
            def _update_metric(engine):
                current_metric = engine.state.metrics.get('acc', 0.1)
                engine.state.metrics['acc'] = current_metric + 0.1

            # set up testing handler
            writer = SummaryWriter(log_dir=temp_dir)
            try:
                stats_handler = TensorBoardStatsHandler(
                    writer,
                    output_transform=lambda x: {'loss': x * 2.0},
                    global_epoch_transform=lambda x: x * 3.0)
                stats_handler.attach(engine)
                engine.run(range(3), max_epochs=2)
            finally:
                # flush pending events and release the file handle before the
                # temporary directory is removed
                writer.close()
            # check logging output: look for files *inside* temp_dir, since
            # glob(temp_dir) alone only proves the directory exists
            self.assertTrue(len(glob.glob(os.path.join(temp_dir, '*'))) > 0)
            self.assertFalse(os.path.exists(default_dir))
    def test_metrics_writer(self):
        """Log iteration loss and a custom state attribute through an explicit writer."""
        with tempfile.TemporaryDirectory() as tempdir:

            # set up engine
            def _train_func(engine, batch):
                return [batch + 1.0]

            engine = Engine(_train_func)

            # set up dummy metric
            @engine.on(Events.EPOCH_COMPLETED)
            def _update_metric(engine):
                current_metric = engine.state.metrics.get("acc", 0.1)
                engine.state.metrics["acc"] = current_metric + 0.1
                # extra state attribute picked up via `state_attributes` below
                engine.state.test = current_metric

            # set up testing handler
            writer = SummaryWriter(log_dir=tempdir)
            try:
                stats_handler = TensorBoardStatsHandler(
                    summary_writer=writer,
                    iteration_log=True,
                    epoch_log=False,
                    output_transform=lambda x: {"loss": x[0] * 2.0},
                    global_epoch_transform=lambda x: x * 3.0,
                    state_attributes=["test"],
                )
                stats_handler.attach(engine)
                engine.run(range(3), max_epochs=2)
            finally:
                # close even when the run fails so the temp dir can be removed
                writer.close()
            # check logging output: look for files inside tempdir, since
            # glob(tempdir) alone only proves the directory exists
            self.assertTrue(len(glob.glob(os.path.join(tempdir, "*"))) > 0)
Esempio n. 3
0
    def test_metrics_writer(self):
        """Check that SummaryWriter (re)creates a missing log dir and writes into it."""
        # reserve a unique path, then delete it so the writer has to create it
        tempdir = tempfile.mkdtemp()
        shutil.rmtree(tempdir, ignore_errors=True)
        try:
            # set up engine
            def _train_func(engine, batch):
                return batch + 1.0

            engine = Engine(_train_func)

            # set up dummy metric
            @engine.on(Events.EPOCH_COMPLETED)
            def _update_metric(engine):
                current_metric = engine.state.metrics.get("acc", 0.1)
                engine.state.metrics["acc"] = current_metric + 0.1

            # set up testing handler
            writer = SummaryWriter(log_dir=tempdir)
            try:
                stats_handler = TensorBoardStatsHandler(
                    writer, output_transform=lambda x: {"loss": x * 2.0}, global_epoch_transform=lambda x: x * 3.0
                )
                stats_handler.attach(engine)
                engine.run(range(3), max_epochs=2)
            finally:
                # release the event file before the directory is deleted
                writer.close()
            # check logging output
            self.assertTrue(os.path.exists(tempdir))
            self.assertTrue(len(glob.glob(os.path.join(tempdir, "*"))) > 0)
        finally:
            # always clean up, even when an assertion above fails
            shutil.rmtree(tempdir, ignore_errors=True)
Esempio n. 4
0
    def train_handlers(self, context: Context):
        """Build the list of handlers attached to the trainer.

        Order: the optional LR-scheduler handler, then (local rank 0 only) a
        console and a TensorBoard ``train_loss`` logger, then (when an
        evaluator is configured) an epoch-level validation handler.
        """
        result: List[Any] = []

        scheduler_handler = self.lr_scheduler_handler(context)
        if scheduler_handler:
            result.append(scheduler_handler)

        # loss loggers only on local rank 0 (presumably to avoid duplicate
        # output in distributed runs)
        if context.local_rank == 0:
            result += [
                StatsHandler(
                    tag_name="train_loss",
                    output_transform=from_engine(["loss"], first=True),
                ),
                TensorBoardStatsHandler(
                    log_dir=context.events_dir,
                    tag_name="train_loss",
                    output_transform=from_engine(["loss"], first=True),
                ),
            ]

        if not context.evaluator:
            return result

        logger.info(
            f"{context.local_rank} - Adding Validation to run every '{self._val_interval}' interval"
        )
        result.append(
            ValidationHandler(
                self._val_interval,
                validator=context.evaluator,
                epoch_level=True,
            )
        )
        return result
Esempio n. 5
0
 def val_handlers(self, context: Context):
     """Return the evaluator's logging handlers, or None on non-zero local ranks.

     The handlers are only constructed on local rank 0. Building a
     TensorBoardStatsHandler from ``log_dir`` sets up its own writer (MONAI
     creates a SummaryWriter when none is passed), so constructing it on the
     other ranks and then discarding it would leave stray writers behind.
     """
     if context.local_rank != 0:
         return None
     return [
         # output_transform returning None disables per-iteration output
         StatsHandler(output_transform=lambda x: None),
         TensorBoardStatsHandler(log_dir=context.events_dir,
                                 output_transform=lambda x: None),
     ]
Esempio n. 6
0
    def test_metrics_print(self):
        """Check that a handler constructed from ``log_dir`` writes files there."""
        with tempfile.TemporaryDirectory() as tempdir:

            # set up engine
            def _train_func(engine, batch):
                return batch + 1.0

            engine = Engine(_train_func)

            # set up dummy metric
            @engine.on(Events.EPOCH_COMPLETED)
            def _update_metric(engine):
                current_metric = engine.state.metrics.get("acc", 0.1)
                engine.state.metrics["acc"] = current_metric + 0.1

            # set up testing handler
            stats_handler = TensorBoardStatsHandler(log_dir=tempdir)
            stats_handler.attach(engine)
            engine.run(range(3), max_epochs=2)
            # check logging output inside tempdir; glob(tempdir) alone would
            # only prove that the directory itself exists
            self.assertTrue(len(glob.glob(os.path.join(tempdir, "*"))) > 0)
Esempio n. 7
0
File: trainer.py Progetto: ckbr0/RIS
    def run(self, date=None) -> str:
        """Create a fresh output directory, register the training handlers and train.

        Args:
            date: optional ``datetime`` printed as the training start time;
                defaults to the current time.

        Returns:
            The run's output directory,
            ``<out_dir>/training/<out_name>_<YYYY-MM-DD>_<HH-MM-SS>``.
        """
        now = date if date is not None else datetime.datetime.now()
        datetime_string = now.strftime('%d/%m/%Y %H:%M:%S')
        print(f'Training started: {datetime_string}')

        # NOTE(review): the directory stamp uses the current time, not `date`;
        # confirm whether callers expect it to follow `date`.
        timedate_info = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        training_dir = os.path.join(self.out_dir, 'training')
        # exist_ok avoids the exists()/mkdir() race when runs start concurrently
        os.makedirs(training_dir, exist_ok=True)
        self.output_dir = os.path.join(training_dir, self.out_name + '_' + timedate_info)
        # plain mkdir: never silently reuse an existing run directory
        os.mkdir(self.output_dir)

        self.validator.output_dir = self.output_dir

        # create a writer only if none was supplied; the validator falls back
        # to the trainer's writer when it has none of its own
        if self.summary_writer is None:
            self.summary_writer = SummaryWriter(log_dir=self.output_dir)
        if self.validator.summary_writer is None:
            self.validator.summary_writer = self.summary_writer

        handlers = [
            MetricLogger(self.output_dir, validator=self.validator),
            ValidationHandler(
                validator=self.validator,
                start=self.validation_epoch,
                interval=self.validation_interval
            ),
            StatsHandler(tag_name="loss", output_transform=lambda x: x["loss"]),
            TensorBoardStatsHandler(
                summary_writer=self.summary_writer,
                tag_name="Loss",
                output_transform=lambda x: x["loss"]
            ),
        ]
        save_dict = {'network': self.network, 'optimizer': self.optimizer}
        if self.lr_scheduler is not None:
            # scheduler handler goes first, and its state is checkpointed too
            handlers.insert(0, LrScheduleHandler(lr_scheduler=self.lr_scheduler, print_lr=True))
            save_dict['lr_scheduler'] = self.lr_scheduler
        handlers.append(
            CheckpointSaver(save_dir=self.output_dir, save_dict=save_dict, save_interval=1, n_saved=1)
        )
        self._register_handlers(handlers)

        super().run()
        return self.output_dir
Esempio n. 8
0
    def test_metrics_print(self):
        """Check that a handler given no writer or log_dir logs under ``./runs``."""
        default_dir = os.path.join('.', 'runs')
        # remove leftovers so the existence check proves the handler created it
        shutil.rmtree(default_dir, ignore_errors=True)
        try:
            # set up engine
            def _train_func(engine, batch):
                return batch + 1.0

            engine = Engine(_train_func)

            # set up dummy metric
            @engine.on(Events.EPOCH_COMPLETED)
            def _update_metric(engine):
                current_metric = engine.state.metrics.get('acc', 0.1)
                engine.state.metrics['acc'] = current_metric + 0.1

            # set up testing handler
            stats_handler = TensorBoardStatsHandler()
            stats_handler.attach(engine)
            engine.run(range(3), max_epochs=2)
            # check logging output
            self.assertTrue(os.path.exists(default_dir))
            self.assertTrue(len(glob.glob(os.path.join(default_dir, '*'))) > 0)
        finally:
            # always clean up so a failure does not leak ./runs into other tests
            shutil.rmtree(default_dir, ignore_errors=True)
Esempio n. 9
0
    def run(self, global_epoch: int) -> None:
        """Run one validation pass, registering the handlers on the first epoch.

        The logging, checkpoint, metric-saving and early-stop handlers are only
        registered when ``global_epoch == 1`` (presumably so that repeated calls
        do not attach them twice).
        """
        is_first_epoch = global_epoch == 1
        if is_first_epoch:
            self._register_handlers([
                StatsHandler(),
                TensorBoardStatsHandler(summary_writer=self.summary_writer),
                CheckpointSaver(
                    save_dir=self.output_dir,
                    save_dict={"network": self.network},
                    save_key_metric=True,
                ),
                MetricsSaver(
                    save_dir=self.output_dir,
                    metrics=['Valid_AUC', 'Valid_ACC'],
                ),
                self.early_stop_handler,
            ])
        return super().run(global_epoch=global_epoch)
def main():
    """Train a 3D UNet on synthetic NIfTI data with ignite and MONAI handlers.

    Generates 40 random image/segmentation pairs in a temporary directory,
    trains on the first 20 and validates on the last 20 every epoch (with
    early stopping), logs to the console and TensorBoard, and always removes
    the temporary directory afterwards, even if training fails.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs
    tempdir = tempfile.mkdtemp()
    try:
        print('generating synthetic data to {} (this may take a while)'.format(tempdir))
        for i in range(40):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

        images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
        segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))

        # define transforms for image and segmentation
        train_imtrans = Compose([
            ScaleIntensity(),
            AddChannel(),
            RandSpatialCrop((96, 96, 96), random_size=False),
            ToTensor()
        ])
        train_segtrans = Compose([
            AddChannel(),
            RandSpatialCrop((96, 96, 96), random_size=False),
            ToTensor()
        ])
        val_imtrans = Compose([
            ScaleIntensity(),
            AddChannel(),
            Resize((96, 96, 96)),
            ToTensor()
        ])
        val_segtrans = Compose([
            AddChannel(),
            Resize((96, 96, 96)),
            ToTensor()
        ])

        # define nifti dataset, data loader
        check_ds = NiftiDataset(images, segs, transform=train_imtrans, seg_transform=train_segtrans)
        check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
        im, seg = monai.utils.misc.first(check_loader)
        print(im.shape, seg.shape)

        # create a training data loader
        train_ds = NiftiDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans)
        train_loader = DataLoader(train_ds, batch_size=5, shuffle=True, num_workers=8,
                                  pin_memory=torch.cuda.is_available())
        # create a validation data loader
        val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans)
        val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())

        # create UNet, DiceLoss and Adam optimizer
        net = monai.networks.nets.UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        )
        loss = monai.losses.DiceLoss(do_sigmoid=True)
        lr = 1e-3
        opt = torch.optim.Adam(net.parameters(), lr)
        # fall back to CPU so the example also runs on machines without CUDA
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        # ignite trainer expects batch=(img, seg) and returns output=loss at every iteration,
        # user can add output_transform to return other values, like: y_pred, y, etc.
        trainer = create_supervised_trainer(net, opt, loss, device, False)

        # adding checkpoint handler to save models (network params and optimizer stats) during training
        checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False)
        trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                  handler=checkpoint_handler,
                                  to_save={'net': net, 'opt': opt})

        # StatsHandler prints loss at every iteration and print metrics at every epoch,
        # we don't set metrics for trainer here, so just print loss, user can also customize print functions
        # and can use output_transform to convert engine.state.output if it's not a loss value
        train_stats_handler = StatsHandler(name='trainer')
        train_stats_handler.attach(trainer)

        # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
        train_tensorboard_stats_handler = TensorBoardStatsHandler()
        train_tensorboard_stats_handler.attach(trainer)

        validation_every_n_epochs = 1
        # Set parameters for validation
        metric_name = 'Mean_Dice'
        # add evaluation metric to the evaluator engine
        val_metrics = {metric_name: MeanDice(add_sigmoid=True, to_onehot_y=False)}

        # ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration,
        # user can add output_transform to return other values
        evaluator = create_supervised_evaluator(net, val_metrics, device, True)

        @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
        def run_validation(engine):
            evaluator.run(val_loader)

        # add early stopping handler to evaluator
        early_stopper = EarlyStopping(patience=4,
                                      score_function=stopping_fn_from_metric(metric_name),
                                      trainer=trainer)
        evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

        # add stats event handler to print validation stats via evaluator
        val_stats_handler = StatsHandler(
            name='evaluator',
            output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
            global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer
        val_stats_handler.attach(evaluator)

        # add handler to record metrics to TensorBoard at every validation epoch
        val_tensorboard_stats_handler = TensorBoardStatsHandler(
            output_transform=lambda x: None,  # no need to plot loss value, so disable per iteration output
            global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer
        val_tensorboard_stats_handler.attach(evaluator)

        # add handler to draw the first image and the corresponding label and model output in the last batch
        # here we draw the 3D output as GIF format along Depth axis, at every validation epoch
        val_tensorboard_image_handler = TensorBoardImageHandler(
            batch_transform=lambda batch: (batch[0], batch[1]),
            output_transform=lambda output: predict_segmentation(output[0]),
            global_iter_transform=lambda x: trainer.state.epoch
        )
        evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=val_tensorboard_image_handler)

        train_epochs = 30
        trainer.run(train_loader, train_epochs)
    finally:
        # remove the synthetic data even when training raises; ignore_errors
        # keeps a cleanup failure from masking the original exception
        shutil.rmtree(tempdir, ignore_errors=True)
Esempio n. 11
0
def main(tempdir):
    """Train and validate a 3D UNet with MONAI's SupervisedTrainer/SupervisedEvaluator.

    Writes 40 synthetic image/label NIfTI pairs into ``tempdir``, trains on the
    first 20 for 5 epochs and validates on the last 20 every 2 epochs. Logs,
    TensorBoard events and checkpoints are written under ``./runs/``.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # --- dataset: 40 random image / mask pairs written into tempdir ---
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for idx in range(40):
        volume, mask = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        nib.save(nib.Nifti1Image(volume, np.eye(4)), os.path.join(tempdir, f"img{idx:d}.nii.gz"))
        nib.save(nib.Nifti1Image(mask, np.eye(4)), os.path.join(tempdir, f"seg{idx:d}.nii.gz"))

    image_paths = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
    label_paths = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    def _as_records(imgs, lbls):
        # pair image/label paths into the dict format the *d transforms expect
        return [dict(image=img, label=lbl) for img, lbl in zip(imgs, lbls)]

    train_files = _as_records(image_paths[:20], label_paths[:20])
    val_files = _as_records(image_paths[-20:], label_paths[-20:])

    # image and segmentation transforms
    train_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys="image"),
        RandCropByPosNegLabeld(
            keys=["image", "label"], label_key="label", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
        ),
        RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
        ToTensord(keys=["image", "label"]),
    ])
    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys="image"),
        ToTensord(keys=["image", "label"]),
    ])

    # batch_size=2 with RandCropByPosNegLabeld(num_samples=4) yields 2 x 4 training crops per step
    train_loader = monai.data.DataLoader(
        monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5),
        batch_size=2,
        shuffle=True,
        num_workers=4,
    )
    val_loader = monai.data.DataLoader(
        monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0),
        batch_size=1,
        num_workers=4,
    )

    # --- network, loss, optimizer, LR schedule ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)

    runs_dir = "./runs/"
    # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP
    use_amp = monai.utils.get_torch_version_tuple() >= (1, 6)

    # --- evaluator ---
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=Compose([
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]),
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={"val_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]))},
        val_handlers=[
            StatsHandler(output_transform=lambda x: None),
            TensorBoardStatsHandler(log_dir=runs_dir, output_transform=lambda x: None),
            TensorBoardImageHandler(
                log_dir=runs_dir,
                batch_transform=lambda x: (x["image"], x["label"]),
                output_transform=lambda x: x["pred"],
            ),
            CheckpointSaver(save_dir=runs_dir, save_dict={"net": net}, save_key_metric=True),
        ],
        amp=use_amp,
    )

    # --- trainer ---
    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        post_transform=Compose([
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]),
        key_train_metric={"train_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]))},
        train_handlers=[
            LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
            ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
            StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
            TensorBoardStatsHandler(log_dir=runs_dir, tag_name="train_loss", output_transform=lambda x: x["loss"]),
            CheckpointSaver(save_dir=runs_dir, save_dict={"net": net, "opt": opt}, save_interval=2, epoch_level=True),
        ],
        amp=use_amp,
    )
    trainer.run()
Esempio n. 12
0
    def train(self,
              train_info,
              valid_info,
              hyperparameters,
              run_data_check=False):
        """Train a ModelCT classifier and save its checkpoints and metric history.

        Args:
            train_info: iterable of ``(item, label)`` pairs used for training;
                labels are read with ``int(label)`` and counted as 0 or 1.
            valid_info: pairs of the same form used for validation.
            hyperparameters: dict with the keys ``learning_rate``,
                ``weight_decay``, ``total_epoch``, ``multiplicator``,
                ``batch_size``, ``validation_epoch``, ``validation_interval``,
                ``H`` and ``L``.
            run_data_check: if True, show the first training batch and exit the
                process instead of training.

        Returns:
            Path of the directory holding checkpoints and the saved ``.npy``
            metric histories.

        Raises:
            ValueError: if ``train_info`` holds no label-1 samples, since the
                class-imbalance weight ``negative / positive`` is undefined.
        """
        logging.basicConfig(stream=sys.stdout, level=logging.INFO)

        if not run_data_check:
            start_dt = datetime.datetime.now()
            start_dt_string = start_dt.strftime('%d/%m/%Y %H:%M:%S')
            print(f'Training started: {start_dt_string}')

            # 1. Create folders to save the model. The folder stamp comes from
            # the same instant as the start time, so its date and time parts
            # can never come from different days.
            timedate_info = start_dt.strftime('%Y-%m-%d_%H-%M-%S')
            path_to_model = os.path.join(
                self.out_dir, 'trained_models',
                self.unique_name + '_' + timedate_info)
            # raises if the directory already exists, so runs never overwrite
            os.makedirs(path_to_model)

        # 2. Load hyperparameters
        learning_rate = hyperparameters['learning_rate']
        weight_decay = hyperparameters['weight_decay']
        total_epoch = hyperparameters['total_epoch']
        multiplicator = hyperparameters['multiplicator']
        batch_size = hyperparameters['batch_size']
        validation_epoch = hyperparameters['validation_epoch']
        validation_interval = hyperparameters['validation_interval']
        H = hyperparameters['H']
        L = hyperparameters['L']

        # 3. Consider class imbalance
        negative, positive = 0, 0
        for _, label in train_info:
            if int(label) == 0:
                negative += 1
            elif int(label) == 1:
                positive += 1
        if positive == 0:
            raise ValueError(
                'train_info contains no positive (label 1) samples; '
                'cannot compute pos_weight')

        # BCEWithLogitsLoss multiplies the positive-class term by this ratio
        pos_weight = torch.Tensor([(negative / positive)]).to(self.device)

        # 4. Create train and validation loaders (both use `batch_size`)
        train_data = get_data_from_info(self.image_data_dir, self.seg_data_dir,
                                        train_info)
        valid_data = get_data_from_info(self.image_data_dir, self.seg_data_dir,
                                        valid_info)
        large_image_splitter(train_data, self.cache_dir)

        set_determinism(seed=100)
        train_trans, valid_trans = self.transformations(H, L)
        train_dataset = PersistentDataset(
            data=train_data[:],
            transform=train_trans,
            cache_dir=self.persistent_dataset_dir)
        valid_dataset = PersistentDataset(
            data=valid_data[:],
            transform=valid_trans,
            cache_dir=self.persistent_dataset_dir)

        train_loader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  pin_memory=self.pin_memory,
                                  num_workers=self.num_workers,
                                  collate_fn=PadListDataCollate(
                                      Method.SYMMETRIC, NumpyPadMode.CONSTANT))
        # Validation keeps a fixed order. The collate pads each batch to its
        # largest item, so shuffling would change the padded batches (and the
        # metrics used for checkpoint selection) from one validation run to the next.
        valid_loader = DataLoader(valid_dataset,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  pin_memory=self.pin_memory,
                                  num_workers=self.num_workers,
                                  collate_fn=PadListDataCollate(
                                      Method.SYMMETRIC, NumpyPadMode.CONSTANT))

        # Perform data checks
        if run_data_check:
            check_data = monai.utils.misc.first(train_loader)
            print(check_data["image"].shape, check_data["label"])
            # the first batch can hold fewer than batch_size items
            for i in range(check_data["image"].shape[0]):
                multi_slice_viewer(
                    check_data["image"][i, 0, :, :, :],
                    check_data["image_meta_dict"]["filename_or_obj"][i])
            # sys.exit is always available; the bare exit() builtin depends on `site`
            sys.exit()

        # 5. Prepare model
        model = ModelCT().to(self.device)

        # 6. Define loss function, optimizer and scheduler.
        # pos_weight must be passed by keyword: the first positional parameter
        # of BCEWithLogitsLoss is `weight`, which only rescales every element
        # uniformly and does nothing for class imbalance.
        loss_function = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=learning_rate,
                                     weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                           multiplicator,
                                                           last_epoch=-1)
        # 7. Create post validation transforms and handlers
        path_to_tensorboard = os.path.join(self.out_dir, 'tensorboard')
        writer = SummaryWriter(log_dir=path_to_tensorboard)
        valid_post_transforms = Compose([
            Activationsd(keys="pred", sigmoid=True),
        ])
        valid_handlers = [
            StatsHandler(output_transform=lambda x: None),
            TensorBoardStatsHandler(summary_writer=writer,
                                    output_transform=lambda x: None),
            CheckpointSaver(save_dir=path_to_model,
                            save_dict={"model": model},
                            save_key_metric=True),
            MetricsSaver(save_dir=path_to_model,
                         metrics=['Valid_AUC', 'Valid_ACC']),
        ]
        # 8. Create validator. Metric names must match the ones MetricsSaver
        # and the final np.save calls read ("Valid_AUC", "Valid_ACC").
        discrete = AsDiscrete(threshold_values=True)
        evaluator = SupervisedEvaluator(
            device=self.device,
            val_data_loader=valid_loader,
            network=model,
            post_transform=valid_post_transforms,
            key_val_metric={
                "Valid_AUC":
                ROCAUC(output_transform=lambda x: (x["pred"], x["label"]))
            },
            additional_metrics={
                "Valid_ACC":
                Accuracy(output_transform=lambda x:
                         (discrete(x["pred"]), x["label"]))
            },
            val_handlers=valid_handlers,
            amp=self.amp,
        )
        # 9. Create trainer

        # Loss function does the last sigmoid, so we dont need it here.
        train_post_transforms = Compose([
            # Empty
        ])
        metric_logger = MetricLogger(evaluator=evaluator)
        train_handlers = [
            metric_logger,
            LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
            ValidationHandlerCT(validator=evaluator,
                                start=validation_epoch,
                                interval=validation_interval,
                                epoch_level=True),
            StatsHandler(tag_name="loss",
                         output_transform=lambda x: x["loss"]),
            TensorBoardStatsHandler(summary_writer=writer,
                                    tag_name="Train_Loss",
                                    output_transform=lambda x: x["loss"]),
            CheckpointSaver(save_dir=path_to_model,
                            save_dict={
                                "model": model,
                                "opt": optimizer
                            },
                            save_interval=1,
                            n_saved=1),
        ]

        trainer = SupervisedTrainer(
            device=self.device,
            max_epochs=total_epoch,
            train_data_loader=train_loader,
            network=model,
            optimizer=optimizer,
            loss_function=loss_function,
            post_transform=train_post_transforms,
            train_handlers=train_handlers,
            amp=self.amp,
        )
        # 10. Run trainer; always flush and close the shared TensorBoard writer
        try:
            trainer.run()
        finally:
            writer.close()
        # 11. Save results
        np.save(os.path.join(path_to_model, 'AUCS.npy'),
                np.array(metric_logger.metrics['Valid_AUC']))
        np.save(os.path.join(path_to_model, 'ACCS.npy'),
                np.array(metric_logger.metrics['Valid_ACC']))
        np.save(os.path.join(path_to_model, 'LOSSES.npy'), np.array(metric_logger.loss))
        np.save(os.path.join(path_to_model, 'PARAMETERS.npy'), np.array(hyperparameters))

        return path_to_model
Esempio n. 13
0
def run_training_test(root_dir, device="cuda:0", amp=False):
    """Train a 3D UNet with MONAI workflow engines and return the best validation metric.

    The first 20 ``img*.nii.gz`` / ``seg*.nii.gz`` pairs found in ``root_dir`` are used
    for training and the last 20 for validation. Checkpoints and TensorBoard logs of
    both engines are written to ``root_dir``.

    Args:
        root_dir: directory containing the NIfTI images/segmentations; also the output directory.
        device: torch device string used by the network, trainer and evaluator.
        amp: when truthy, enable automatic mixed precision in both engines.

    Returns:
        ``evaluator.state.best_metric`` once training has finished.
    """
    keys = ["image", "label"]

    def _as_dicts(image_paths, label_paths):
        # pair each image path with the segmentation path at the same sorted position
        return [{"image": img, "label": seg} for img, seg in zip(image_paths, label_paths)]

    image_paths = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    label_paths = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    train_files = _as_dicts(image_paths[:20], label_paths[:20])
    val_files = _as_dicts(image_paths[-20:], label_paths[-20:])

    def _loading_steps():
        # common prefix of both pipelines: load, move channel axis first, rescale intensity
        return [
            LoadNiftid(keys=keys),
            AsChannelFirstd(keys=keys, channel_dim=-1),
            ScaleIntensityd(keys=keys),
        ]

    train_transforms = Compose(_loading_steps() + [
        RandCropByPosNegLabeld(keys=keys, label_key="label", spatial_size=[96, 96, 96],
                               pos=1, neg=1, num_samples=4),
        RandRotate90d(keys=keys, prob=0.5, spatial_axes=[0, 2]),
        ToTensord(keys=keys),
    ])
    val_transforms = Compose(_loading_steps() + [ToTensord(keys=keys)])

    # 2 volumes per batch, each yielding num_samples=4 crops -> 2 x 4 training patches
    train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # UNet + Dice loss (sigmoid applied inside the loss) + Adam with step decay
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)

    def _binarize_prediction():
        # sigmoid -> threshold -> keep only the largest connected component of label 1
        return Compose([
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ])

    def _pred_and_label(output):
        # metrics consume a (prediction, ground truth) pair from the engine output dict
        return output["pred"], output["label"]

    use_amp = bool(amp)

    val_post_transforms = _binarize_prediction()
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=root_dir, output_transform=lambda x: None),
        TensorBoardImageHandler(
            log_dir=root_dir,
            batch_transform=lambda x: (x["image"], x["label"]),
            output_transform=lambda x: x["pred"],
        ),
        CheckpointSaver(save_dir=root_dir, save_dict={"net": net}, save_key_metric=True),
    ]
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={"val_mean_dice": MeanDice(include_background=True, output_transform=_pred_and_label)},
        additional_metrics={"val_acc": Accuracy(output_transform=_pred_and_label)},
        val_handlers=val_handlers,
        amp=use_amp,
    )

    train_post_transforms = _binarize_prediction()
    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        # validate every 2 training epochs
        ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(log_dir=root_dir, tag_name="train_loss", output_transform=lambda x: x["loss"]),
        CheckpointSaver(save_dir=root_dir, save_dict={"net": net, "opt": opt}, save_interval=2, epoch_level=True),
    ]
    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        post_transform=train_post_transforms,
        key_train_metric={"train_acc": Accuracy(output_transform=_pred_and_label)},
        train_handlers=train_handlers,
        amp=use_amp,
    )
    trainer.run()

    return evaluator.state.best_metric
Esempio n. 14
0
def main():
    """Train a 3D UNet on synthetic volumes using plain Ignite engines plus MONAI handlers.

    Generates 40 synthetic image/segmentation NIfTI pairs in a temporary directory,
    trains on the first 20 and runs validation on the last 20 every
    ``validation_every_n_iters`` trainer iterations, with early stopping on the
    validation Mean Dice. Network/optimizer checkpoints are written to ``./runs/``.
    The temporary data directory is removed on exit, including when training raises.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # Create a temporary directory holding 40 random image/mask pairs. The context
    # manager removes it on exit even if data generation or training raises; a plain
    # mkdtemp()/rmtree() pair would leak the directory on any exception.
    with tempfile.TemporaryDirectory() as tempdir:
        print(f"generating synthetic data to {tempdir} (this may take a while)")
        for i in range(40):
            im, seg = create_test_image_3d(128,
                                           128,
                                           128,
                                           num_seg_classes=1,
                                           channel_dim=-1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

        images = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
        segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
        train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])]
        val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-20:], segs[-20:])]

        # define transforms for image and segmentation
        train_transforms = Compose([
            LoadNiftid(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys=["img", "seg"]),
            RandCropByPosNegLabeld(keys=["img", "seg"],
                                   label_key="seg",
                                   spatial_size=[96, 96, 96],
                                   pos=1,
                                   neg=1,
                                   num_samples=4),
            RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
            ToTensord(keys=["img", "seg"]),
        ])
        val_transforms = Compose([
            LoadNiftid(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys=["img", "seg"]),
            ToTensord(keys=["img", "seg"]),
        ])

        # define dataset, data loader
        check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
        # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
        check_loader = DataLoader(check_ds,
                                  batch_size=2,
                                  num_workers=4,
                                  collate_fn=list_data_collate,
                                  pin_memory=torch.cuda.is_available())
        check_data = monai.utils.misc.first(check_loader)
        print(check_data["img"].shape, check_data["seg"].shape)

        # create a training data loader
        train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
        # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
        train_loader = DataLoader(
            train_ds,
            batch_size=2,
            shuffle=True,
            num_workers=4,
            collate_fn=list_data_collate,
            pin_memory=torch.cuda.is_available(),
        )
        # create a validation data loader
        val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
        val_loader = DataLoader(val_ds,
                                batch_size=5,
                                num_workers=8,
                                collate_fn=list_data_collate,
                                pin_memory=torch.cuda.is_available())

        # create UNet, DiceLoss and Adam optimizer
        net = monai.networks.nets.UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        )
        loss = monai.losses.DiceLoss(sigmoid=True)
        lr = 1e-3
        opt = torch.optim.Adam(net.parameters(), lr)
        # Fall back to CPU when no GPU is present instead of failing on "cuda:0".
        # pin_memory above is already conditional on CUDA availability.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Ignite trainer expects batch=(img, seg) and returns output=loss at every iteration,
        # user can add output_transform to return other values, like: y_pred, y, etc.
        def prepare_batch(batch, device=None, non_blocking=False):
            # unpack the dict-style batch into the (input, target) pair Ignite expects
            return _prepare_batch((batch["img"], batch["seg"]), device, non_blocking)

        trainer = create_supervised_trainer(net,
                                            opt,
                                            loss,
                                            device,
                                            False,
                                            prepare_batch=prepare_batch)

        # adding checkpoint handler to save models (network params and optimizer stats) during training
        checkpoint_handler = ModelCheckpoint("./runs/",
                                             "net",
                                             n_saved=10,
                                             require_empty=False)
        trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                  handler=checkpoint_handler,
                                  to_save={"net": net, "opt": opt})

        # StatsHandler prints loss at every iteration and print metrics at every epoch,
        # we don't set metrics for trainer here, so just print loss, user can also customize print functions
        # and can use output_transform to convert engine.state.output if it's not loss value
        train_stats_handler = StatsHandler(name="trainer")
        train_stats_handler.attach(trainer)

        # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
        train_tensorboard_stats_handler = TensorBoardStatsHandler()
        train_tensorboard_stats_handler.attach(trainer)

        validation_every_n_iters = 5
        # set parameters for validation
        metric_name = "Mean_Dice"
        # add evaluation metric to the evaluator engine
        val_metrics = {metric_name: MeanDice(sigmoid=True, to_onehot_y=False)}

        # Ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration,
        # user can add output_transform to return other values
        evaluator = create_supervised_evaluator(net,
                                                val_metrics,
                                                device,
                                                True,
                                                prepare_batch=prepare_batch)

        @trainer.on(Events.ITERATION_COMPLETED(every=validation_every_n_iters))
        def run_validation(engine):
            evaluator.run(val_loader)

        # add early stopping handler to evaluator
        early_stopper = EarlyStopping(
            patience=4,
            score_function=stopping_fn_from_metric(metric_name),
            trainer=trainer)
        evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                    handler=early_stopper)

        # add stats event handler to print validation stats via evaluator
        val_stats_handler = StatsHandler(
            name="evaluator",
            output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
            global_epoch_transform=lambda x: trainer.state.epoch,
        )  # fetch global epoch number from trainer
        val_stats_handler.attach(evaluator)

        # add handler to record metrics to TensorBoard at every validation epoch;
        # validation runs per trainer iteration, so the trainer iteration is used as the step
        val_tensorboard_stats_handler = TensorBoardStatsHandler(
            output_transform=lambda x: None,  # no need to plot loss value, so disable per iteration output
            global_epoch_transform=lambda x: trainer.state.iteration,
        )  # fetch global iteration number from trainer
        val_tensorboard_stats_handler.attach(evaluator)

        # add handler to draw the first image and the corresponding label and model output in the last batch
        # here we draw the 3D output as GIF format along the depth axis, every 2 validation iterations.
        # NOTE(review): the step is the trainer *epoch* while validation runs every few trainer
        # iterations, so several validations within one epoch log images at the same step.
        val_tensorboard_image_handler = TensorBoardImageHandler(
            batch_transform=lambda batch: (batch["img"], batch["seg"]),
            output_transform=lambda output: predict_segmentation(output[0]),
            global_iter_transform=lambda x: trainer.state.epoch,
        )
        evaluator.add_event_handler(event_name=Events.ITERATION_COMPLETED(every=2),
                                    handler=val_tensorboard_image_handler)

        train_epochs = 5
        state = trainer.run(train_loader, train_epochs)
        print(state)
Esempio n. 15
0
# Script-level Ignite setup. net, opt, loss and device are defined earlier in this
# example (not visible in this excerpt). The trailing positional bool is presumably
# ignite's `non_blocking` flag — confirm against the ignite version in use.
trainer = create_supervised_trainer(net, opt, loss, device, False)

# adding checkpoint handler to save models (network params and optimizer stats) during training
# (keeps the 10 most recent checkpoints under ./runs/, reusing the directory if it already has files)
checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False)
trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                          handler=checkpoint_handler,
                          to_save={'net': net, 'opt': opt})

# StatsHandler prints loss at every iteration and prints metrics at every epoch.
# No metrics are set on the trainer here, so only the loss is printed. Users can also
# customize the print functions, and can use output_transform to convert
# engine.state.output if it is not a loss value.
train_stats_handler = StatsHandler(name='trainer')
train_stats_handler.attach(trainer)

# TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
train_tensorboard_stats_handler = TensorBoardStatsHandler()
train_tensorboard_stats_handler.attach(trainer)

# set parameters for validation
# (not used within this excerpt; presumably consumed by a validation hook later in the example)
validation_every_n_epochs = 1

metric_name = 'Accuracy'
# add evaluation metric to the evaluator engine
val_metrics = {metric_name: Accuracy()}
# ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
# user can add output_transform to return other values
evaluator = create_supervised_evaluator(net, val_metrics, device, True)
val_stats_handler = StatsHandler(
    name='evaluator',
Esempio n. 16
0
def train(cfg):
    """Train a ResNet18-based patch classifier on whole-slide images.

    Builds MONAI preprocessing, ``PatchWSIDataset`` datasets and data loaders from the
    decathlon-style datalist in ``cfg["dataset_json"]``. It then runs a
    ``SupervisedTrainer`` with a ``SupervisedEvaluator`` that validates after every
    epoch. Checkpoints and TensorBoard logs of BOTH engines are written to the
    directory returned by ``create_log_dir(cfg)``.

    Args:
        cfg: configuration mapping. Keys read directly here: ``dataset_json``,
            ``data_root``, ``region_size``, ``grid_shape``, ``patch_size``,
            ``use_openslide``, ``num_workers``, ``batch_size``, ``pretrain``,
            ``novograd``, ``lr``, ``amp`` and ``n_epochs``. ``create_log_dir`` and
            ``set_device`` may read more. ``cfg["amp"]`` is overwritten with the AMP
            flag actually used.

    Raises:
        ValueError: if the training data loader yields no batch.
    """
    # Single output directory for this run; every handler below writes here so that
    # trainer and evaluator checkpoints/logs end up together.
    log_dir = create_log_dir(cfg)
    device = set_device(cfg)
    # --------------------------------------------------------------------------
    # Data Loading and Preprocessing
    # --------------------------------------------------------------------------
    # Training: colour jitter (torchvision, needs a tensor), then numpy-based flips,
    # rotations and zoom, and finally rescale [0, 255] -> [-1, 1].
    train_preprocess = Compose([
        ToTensorD(keys="image"),
        TorchVisionD(keys="image",
                     name="ColorJitter",
                     brightness=64.0 / 255.0,
                     contrast=0.75,
                     saturation=0.25,
                     hue=0.04),
        ToNumpyD(keys="image"),
        RandFlipD(keys="image", prob=0.5),
        RandRotate90D(keys="image", prob=0.5),
        CastToTypeD(keys="image", dtype=np.float32),
        RandZoomD(keys="image", prob=0.5, min_zoom=0.9, max_zoom=1.1),
        ScaleIntensityRangeD(keys="image",
                             a_min=0.0,
                             a_max=255.0,
                             b_min=-1.0,
                             b_max=1.0),
        ToTensorD(keys=("image", "label")),
    ])
    # Validation: same intensity rescaling, no augmentation.
    valid_preprocess = Compose([
        CastToTypeD(keys="image", dtype=np.float32),
        ScaleIntensityRangeD(keys="image",
                             a_min=0.0,
                             a_max=255.0,
                             b_min=-1.0,
                             b_max=1.0),
        ToTensorD(keys=("image", "label")),
    ])

    # Datalists for both splits from the same decathlon-style JSON file.
    train_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="training",
        base_dir=cfg["data_root"],
    )
    valid_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="validation",
        base_dir=cfg["data_root"],
    )

    image_reader_name = "openslide" if cfg["use_openslide"] else "cuCIM"
    train_dataset = PatchWSIDataset(
        train_json_info_list,
        cfg["region_size"],
        cfg["grid_shape"],
        cfg["patch_size"],
        train_preprocess,
        image_reader_name=image_reader_name,
    )
    valid_dataset = PatchWSIDataset(
        valid_json_info_list,
        cfg["region_size"],
        cfg["grid_shape"],
        cfg["patch_size"],
        valid_preprocess,
        image_reader_name=image_reader_name,
    )

    # NOTE(review): the training loader does not shuffle; confirm this is intended.
    train_dataloader = DataLoader(train_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=True)
    valid_dataloader = DataLoader(valid_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=True)

    # Sanity check: fetch one batch and report shapes/types before training.
    first_sample = first(train_dataloader)
    if first_sample is None:
        raise ValueError("First sample is None!")

    print("image: ")
    print("    shape", first_sample["image"].shape)
    print("    type: ", type(first_sample["image"]))
    print("    dtype: ", first_sample["image"].dtype)
    print("labels: ")
    print("    shape", first_sample["label"].shape)
    print("    type: ", type(first_sample["label"]))
    print("    dtype: ", first_sample["label"].dtype)
    print(f"batch size: {cfg['batch_size']}")
    print(f"train number of batches: {len(train_dataloader)}")
    print(f"valid number of batches: {len(valid_dataloader)}")

    # --------------------------------------------------------------------------
    # Deep Learning Classification Model
    # --------------------------------------------------------------------------
    model = TorchVisionFCModel("resnet18",
                               num_classes=1,
                               use_conv=True,
                               pretrained=cfg["pretrain"])
    model = model.to(device)

    # single-logit binary classification; the sigmoid lives inside the loss
    loss_func = torch.nn.BCEWithLogitsLoss()
    loss_func = loss_func.to(device)

    if cfg["novograd"]:
        optimizer = Novograd(model.parameters(), cfg["lr"])
    else:
        optimizer = SGD(model.parameters(), lr=cfg["lr"], momentum=0.9)

    # Enable AMP only when requested AND supported (torch >= 1.6). The result is
    # written back to cfg so callers observe the setting that was actually used.
    cfg["amp"] = bool(cfg["amp"]) and monai.utils.get_torch_version_tuple() >= (1, 6)

    scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                               T_max=cfg["n_epochs"])

    def _binary_postprocessing():
        # sigmoid on the logits, then threshold at 0.5 to obtain a binary prediction
        return Compose([
            ActivationsD(keys="pred", sigmoid=True),
            AsDiscreteD(keys="pred", threshold=0.5),
        ])

    # --------------------------------------------
    # Ignite Trainer/Evaluator
    # --------------------------------------------
    val_handlers = [
        CheckpointSaver(save_dir=log_dir,
                        save_dict={"net": model},
                        save_key_metric=True),
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=log_dir,
                                output_transform=lambda x: None),
    ]
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=valid_dataloader,
        network=model,
        postprocessing=_binary_postprocessing(),
        key_val_metric={
            "val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))
        },
        val_handlers=val_handlers,
        amp=cfg["amp"],
    )

    # Trainer handlers write to the same run directory as the evaluator handlers
    # (log_dir) rather than bypassing create_log_dir via cfg["logdir"].
    train_handlers = [
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        CheckpointSaver(save_dir=log_dir,
                        save_dict={"net": model, "opt": optimizer},
                        save_interval=1,
                        epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=from_engine(["loss"], first=True)),
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        TensorBoardStatsHandler(log_dir=log_dir,
                                tag_name="train_loss",
                                output_transform=from_engine(["loss"], first=True)),
    ]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=cfg["n_epochs"],
        train_data_loader=train_dataloader,
        network=model,
        optimizer=optimizer,
        loss_function=loss_func,
        postprocessing=_binary_postprocessing(),
        key_train_metric={
            "train_acc": Accuracy(output_transform=from_engine(["pred", "label"]))
        },
        train_handlers=train_handlers,
        amp=cfg["amp"],
    )
    trainer.run()
Esempio n. 17
0
def run_training_test(root_dir, device="cuda:0"):
    """Train a small 2D GAN on the ``img*.nii.gz`` images found in ``root_dir``.

    Args:
        root_dir: directory holding the real training images. Checkpoints and
            TensorBoard logs are also written here.
        device: torch device the generator and discriminator run on.

    Returns:
        The ``GanTrainer`` state after ``num_epochs`` epochs.
    """
    real_images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    # One datalist entry per image path. Iterate the paths directly: the previous
    # zip(real_images) wrapped every path in a pointless 1-tuple.
    train_files = [{"reals": img} for img in real_images]

    # prepare real data
    train_transforms = Compose([
        LoadNiftid(keys=["reals"]),
        AsChannelFirstd(keys=["reals"]),
        ScaleIntensityd(keys=["reals"]),
        RandFlipd(keys=["reals"], prob=0.5),
        ToTensord(keys=["reals"]),
    ])
    train_ds = monai.data.CacheDataset(data=train_files,
                                       transform=train_transforms,
                                       cache_rate=0.5)
    train_loader = monai.data.DataLoader(train_ds,
                                         batch_size=2,
                                         shuffle=True,
                                         num_workers=4)

    learning_rate = 2e-4
    betas = (0.5, 0.999)
    real_label = 1
    fake_label = 0

    # create discriminator
    disc_net = Discriminator(in_shape=(1, 64, 64),
                             channels=(8, 16, 32, 64, 1),
                             strides=(2, 2, 2, 2, 1),
                             num_res_units=1,
                             kernel_size=5).to(device)
    disc_net.apply(normal_init)
    disc_opt = torch.optim.Adam(disc_net.parameters(),
                                learning_rate,
                                betas=betas)
    disc_loss_criterion = torch.nn.BCELoss()

    def discriminator_loss(gen_images, real_images):
        """Average BCE of the discriminator on real (-> real_label) and generated (-> fake_label) images."""
        real = real_images.new_full((real_images.shape[0], 1), real_label)
        gen = gen_images.new_full((gen_images.shape[0], 1), fake_label)
        realloss = disc_loss_criterion(disc_net(real_images), real)
        # detach so this loss does not backpropagate into the generator
        genloss = disc_loss_criterion(disc_net(gen_images.detach()), gen)
        return torch.div(torch.add(realloss, genloss), 2)

    # create generator
    latent_size = 64
    gen_net = Generator(latent_shape=latent_size,
                        start_shape=(latent_size, 8, 8),
                        channels=[32, 16, 8, 1],
                        strides=[2, 2, 2, 1])
    gen_net.apply(normal_init)
    # squash generator output through a sigmoid
    gen_net.conv.add_module("activation", torch.nn.Sigmoid())
    gen_net = gen_net.to(device)
    gen_opt = torch.optim.Adam(gen_net.parameters(),
                               learning_rate,
                               betas=betas)
    gen_loss_criterion = torch.nn.BCELoss()

    def generator_loss(gen_images):
        """BCE pushing the discriminator's output on generated images towards real_label."""
        output = disc_net(gen_images)
        cats = output.new_full(output.shape, real_label)
        return gen_loss_criterion(output, cats)

    def _gan_losses(output):
        # keep only the generator and discriminator losses for logging
        return {Keys.GLOSS: output[Keys.GLOSS], Keys.DLOSS: output[Keys.DLOSS]}

    key_train_metric = None

    train_handlers = [
        StatsHandler(name="training_loss", output_transform=_gan_losses),
        TensorBoardStatsHandler(log_dir=root_dir,
                                tag_name="training_loss",
                                output_transform=_gan_losses),
        CheckpointSaver(save_dir=root_dir,
                        save_dict={"g_net": gen_net, "d_net": disc_net},
                        save_interval=2,
                        epoch_level=True),
    ]

    # number of discriminator updates per training step, passed as d_train_steps
    disc_train_steps = 2
    num_epochs = 5

    trainer = GanTrainer(
        device,
        num_epochs,
        train_loader,
        gen_net,
        gen_opt,
        generator_loss,
        disc_net,
        disc_opt,
        discriminator_loss,
        d_train_steps=disc_train_steps,
        latent_shape=latent_size,
        key_train_metric=key_train_metric,
        train_handlers=train_handlers,
    )
    trainer.run()

    return trainer.state
Esempio n. 18
0
def main(tempdir):
    """Train a 3D UNet on synthetic array data with Ignite engines and MONAI handlers.

    Writes 40 synthetic ``im*.nii.gz`` / ``seg*.nii.gz`` pairs into ``tempdir``. The
    directory is supplied by the caller, who is responsible for removing it. Trains on
    the first 20 pairs and validates on the last 20 after every epoch, with early
    stopping on the validation Mean Dice. Checkpoints are written to ``./runs_array/``.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # fill the caller-provided directory with 40 random image/mask pairs
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for idx in range(40):
        volume, mask = create_test_image_3d(128, 128, 128, num_seg_classes=1)
        nib.save(nib.Nifti1Image(volume, np.eye(4)), os.path.join(tempdir, f"im{idx:d}.nii.gz"))
        nib.save(nib.Nifti1Image(mask, np.eye(4)), os.path.join(tempdir, f"seg{idx:d}.nii.gz"))

    image_files = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    mask_files = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    roi = (96, 96, 96)
    # training pipelines: random fixed-size crop; validation pipelines: resize to the same size
    train_imtrans = Compose([ScaleIntensity(), AddChannel(), RandSpatialCrop(roi, random_size=False), EnsureType()])
    train_segtrans = Compose([AddChannel(), RandSpatialCrop(roi, random_size=False), EnsureType()])
    val_imtrans = Compose([ScaleIntensity(), AddChannel(), Resize(roi), EnsureType()])
    val_segtrans = Compose([AddChannel(), Resize(roi), EnsureType()])

    pin = torch.cuda.is_available()

    # peek at one training batch to check shapes
    check_ds = ImageDataset(image_files, mask_files, transform=train_imtrans, seg_transform=train_segtrans)
    check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=pin)
    sample_img, sample_seg = monai.utils.misc.first(check_loader)
    print(sample_img.shape, sample_seg.shape)

    train_ds = ImageDataset(image_files[:20], mask_files[:20], transform=train_imtrans, seg_transform=train_segtrans)
    train_loader = DataLoader(train_ds, batch_size=5, shuffle=True, num_workers=8, pin_memory=pin)
    val_ds = ImageDataset(image_files[-20:], mask_files[-20:], transform=val_imtrans, seg_transform=val_segtrans)
    val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=pin)

    # UNet + Dice loss (sigmoid inside the loss) + Adam
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True)
    lr = 1e-3
    opt = torch.optim.Adam(net.parameters(), lr)

    # the Ignite trainer consumes (img, seg) batches and outputs the loss each iteration
    trainer = create_supervised_trainer(net, opt, loss, device, False)

    # keep the 10 latest network/optimizer checkpoints, one per epoch
    checkpoint_handler = ModelCheckpoint("./runs_array/", "net", n_saved=10, require_empty=False)
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler,
                              to_save={"net": net, "opt": opt})

    def _identity(value):
        return value

    def _no_output(_):
        # suppress per-iteration output for the evaluator handlers
        return None

    def _trainer_epoch(_):
        # report evaluator results against the trainer's epoch counter
        return trainer.state.epoch

    # per-iteration loss and per-epoch metrics, printed and plotted for the trainer
    StatsHandler(name="trainer", output_transform=_identity).attach(trainer)
    TensorBoardStatsHandler(output_transform=_identity).attach(trainer)

    validation_every_n_epochs = 1
    metric_name = "Mean_Dice"
    val_metrics = {metric_name: MeanDice()}

    post_pred = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    post_label = Compose([EnsureType(), AsDiscrete(threshold=0.5)])

    def _postprocess_output(x, y, y_pred):
        # split the batch into items, then binarise predictions and labels separately
        predictions = [post_pred(item) for item in decollate_batch(y_pred)]
        labels = [post_label(item) for item in decollate_batch(y)]
        return predictions, labels

    evaluator = create_supervised_evaluator(net, val_metrics, device, True,
                                            output_transform=_postprocess_output)

    @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
    def _validate(engine):
        evaluator.run(val_loader)

    # stop training when the validation Mean Dice has not improved for 4 validations
    early_stopper = EarlyStopping(patience=4, score_function=stopping_fn_from_metric(metric_name), trainer=trainer)
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

    StatsHandler(name="evaluator", output_transform=_no_output,
                 global_epoch_transform=_trainer_epoch).attach(evaluator)
    TensorBoardStatsHandler(output_transform=_no_output,
                            global_epoch_transform=_trainer_epoch).attach(evaluator)

    # after each validation run, draw image / label / prediction of the last batch as a GIF along depth
    val_tensorboard_image_handler = TensorBoardImageHandler(
        batch_transform=lambda batch: (batch[0], batch[1]),
        output_transform=lambda output: output[0],
        global_iter_transform=_trainer_epoch,
    )
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=val_tensorboard_image_handler)

    train_epochs = 30
    state = trainer.run(train_loader, train_epochs)
    print(state)
Esempio n. 19
0
def main():
    """Train a 3D DenseNet121 to classify gender from IXI T1 volumes using Ignite.

    The first 10 volumes form the training set and the last 10 the validation
    set. Losses and metrics are reported to stdout and TensorBoard, network and
    optimizer checkpoints are written to ``./runs_array/`` after every epoch, and
    an EarlyStopping handler (patience 4) on validation accuracy can halt training.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    # folder that holds the IXI-T1 volumes
    data_path = os.sep.join([".", "workspace", "data", "medical", "ixi", "IXI-T1"])
    file_names = (
        "IXI314-IOP-0889-T1.nii.gz",
        "IXI249-Guys-1072-T1.nii.gz",
        "IXI609-HH-2600-T1.nii.gz",
        "IXI173-HH-1590-T1.nii.gz",
        "IXI020-Guys-0700-T1.nii.gz",
        "IXI342-Guys-0909-T1.nii.gz",
        "IXI134-Guys-0780-T1.nii.gz",
        "IXI577-HH-2661-T1.nii.gz",
        "IXI066-Guys-0731-T1.nii.gz",
        "IXI130-HH-1528-T1.nii.gz",
        "IXI607-Guys-1097-T1.nii.gz",
        "IXI175-HH-1570-T1.nii.gz",
        "IXI385-HH-2078-T1.nii.gz",
        "IXI344-Guys-0905-T1.nii.gz",
        "IXI409-Guys-0960-T1.nii.gz",
        "IXI584-Guys-1129-T1.nii.gz",
        "IXI253-HH-1694-T1.nii.gz",
        "IXI092-HH-1436-T1.nii.gz",
        "IXI574-IOP-1156-T1.nii.gz",
        "IXI585-Guys-1130-T1.nii.gz",
    )
    images = [os.sep.join([data_path, name]) for name in file_names]

    # 2 binary labels for gender classification: man and woman
    labels = np.array([0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=np.int64)

    def _volume_pipeline(augment):
        # scale intensity, add a channel dim, resize to 96^3; optionally add a random 90-degree rotation
        steps = [ScaleIntensity(), AddChannel(), Resize((96, 96, 96))]
        if augment:
            steps.append(RandRotate90())
        steps.append(EnsureType())
        return Compose(steps)

    train_transforms = _volume_pipeline(augment=True)
    val_transforms = _volume_pipeline(augment=False)

    use_cuda = torch.cuda.is_available()

    # sanity check: pull one training batch and show its type, shape and labels
    check_ds = ImageDataset(image_files=images, labels=labels, transform=train_transforms)
    check_loader = DataLoader(check_ds, batch_size=2, num_workers=2, pin_memory=use_cuda)
    sample_images, sample_labels = monai.utils.misc.first(check_loader)
    print(type(sample_images), sample_images.shape, sample_labels)

    # DenseNet121 trained with CrossEntropyLoss and Adam
    device = torch.device("cuda" if use_cuda else "cpu")
    model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)
    criterion = torch.nn.CrossEntropyLoss()
    learning_rate = 1e-5
    optimizer = torch.optim.Adam(model.parameters(), learning_rate)

    # Ignite trainer consumes batch=(img, label) and emits the loss as its per-iteration output
    trainer = create_supervised_trainer(model, optimizer, criterion, device, False)

    # keep up to 10 checkpoints of network params and optimizer state, one per epoch
    checkpoint_handler = ModelCheckpoint("./runs_array/", "net", n_saved=10, require_empty=False)
    trainer.add_event_handler(
        event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={"net": model, "opt": optimizer}
    )

    # the trainer has no metrics, so these handlers only report the loss (identity transform)
    StatsHandler(name="trainer", output_transform=lambda x: x).attach(trainer)
    TensorBoardStatsHandler(output_transform=lambda x: x).attach(trainer)

    validation_every_n_epochs = 1
    metric_name = "Accuracy"
    # Ignite evaluator consumes batch=(img, label) and emits (y_pred, y) for the metrics
    evaluator = create_supervised_evaluator(model, {metric_name: Accuracy()}, device, True)

    def _no_iteration_output(_output):
        # validation loss is not reported, so per-iteration logging is disabled
        return None

    def _trainer_epoch(_epoch):
        # tag evaluator results with the trainer's epoch instead of the evaluator's own counter
        return trainer.state.epoch

    StatsHandler(
        name="evaluator",
        output_transform=_no_iteration_output,
        global_epoch_transform=_trainer_epoch,
    ).attach(evaluator)
    TensorBoardStatsHandler(
        output_transform=_no_iteration_output,
        global_epoch_transform=_trainer_epoch,
    ).attach(evaluator)

    # stop the trainer when validation accuracy has not improved for 4 evaluations
    early_stopper = EarlyStopping(patience=4, score_function=stopping_fn_from_metric(metric_name), trainer=trainer)
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

    val_ds = ImageDataset(image_files=images[-10:], labels=labels[-10:], transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=2, num_workers=2, pin_memory=use_cuda)

    @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
    def run_validation(engine):
        evaluator.run(val_loader)

    train_ds = ImageDataset(image_files=images[:10], labels=labels[:10], transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2, pin_memory=use_cuda)

    train_epochs = 30
    state = trainer.run(train_loader, train_epochs)
    print(state)
Esempio n. 20
0
    def configure(self):
        """Build the training and evaluation engines for this configuration.

        Creates a 3D UNet (wrapped in ``DistributedDataParallel`` when
        ``self.multi_gpu`` is set) and the train/validation loaders built from
        ``self.data_list_file_path``. Stores the resulting engines on
        ``self.eval_engine`` and ``self.train_engine``. Ranks other than 0 have
        their engine loggers lowered to WARNING.
        """
        self.set_device()
        model = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm=Norm.BATCH,
        ).to(self.device)
        if self.multi_gpu:
            model = DistributedDataParallel(
                module=model,
                device_ids=[self.device],
                find_unused_parameters=False,
            )

        keys = ("image", "label")

        def _ct_window():
            # clip intensities to [-57, 164] and rescale them linearly to [0, 1]
            return ScaleIntensityRanged(
                keys="image", a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True
            )

        train_transforms = Compose([
            LoadImaged(keys=keys),
            EnsureChannelFirstd(keys=keys),
            Spacingd(keys=keys, pixdim=[1.0, 1.0, 1.0], mode=["bilinear", "nearest"]),
            _ct_window(),
            CropForegroundd(keys=keys, source_key="image"),
            RandCropByPosNegLabeld(
                keys=keys,
                label_key="label",
                spatial_size=(96, 96, 96),
                pos=1,
                neg=1,
                num_samples=4,
                image_key="image",
                image_threshold=0,
            ),
            RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
            ToTensord(keys=keys),
        ])
        train_datalist = load_decathlon_datalist(self.data_list_file_path, True, "training")
        if self.multi_gpu:
            # every rank trains on its own evenly sized shard of the training list
            shards = partition_dataset(
                data=train_datalist,
                shuffle=True,
                num_partitions=dist.get_world_size(),
                even_divisible=True,
            )
            train_datalist = shards[dist.get_rank()]
        train_ds = CacheDataset(
            data=train_datalist,
            transform=train_transforms,
            cache_num=32,
            cache_rate=1.0,
            num_workers=4,
        )
        train_data_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)

        # NOTE(review): unlike training, validation does not resample with Spacingd — confirm intended
        val_transforms = Compose([
            LoadImaged(keys=keys),
            EnsureChannelFirstd(keys=keys),
            _ct_window(),
            CropForegroundd(keys=keys, source_key="image"),
            ToTensord(keys=keys),
        ])
        val_datalist = load_decathlon_datalist(self.data_list_file_path, True, "validation")
        # NOTE(review): cache_rate=0.0 presumably leaves the validation cache empty — confirm intended
        val_ds = CacheDataset(
            data=val_datalist,
            transform=val_transforms,
            cache_num=9,
            cache_rate=0.0,
            num_workers=4,
        )
        val_data_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=4)

        # softmax + argmax on predictions, one-hot (2 classes) for both predictions and labels
        post_transform = Compose([
            Activationsd(keys="pred", softmax=True),
            AsDiscreted(
                keys=["pred", "label"],
                argmax=[True, False],
                to_onehot=True,
                n_classes=2,
            ),
        ])

        def _pred_and_label(output):
            return output["pred"], output["label"]

        key_val_metric = {
            "val_mean_dice": MeanDice(
                include_background=False,
                output_transform=_pred_and_label,
                device=self.device,
            )
        }
        val_handlers = [
            StatsHandler(output_transform=lambda x: None),
            CheckpointSaver(save_dir=self.ckpt_dir, save_dict={"model": model}, save_key_metric=True),
            TensorBoardStatsHandler(log_dir=self.ckpt_dir, output_transform=lambda x: None),
        ]
        self.eval_engine = SupervisedEvaluator(
            device=self.device,
            val_data_loader=val_data_loader,
            network=model,
            inferer=SlidingWindowInferer(roi_size=[160, 160, 160], sw_batch_size=4, overlap=0.5),
            post_transform=post_transform,
            key_val_metric=key_val_metric,
            val_handlers=val_handlers,
            amp=self.amp,
        )

        optimizer = torch.optim.Adam(model.parameters(), self.learning_rate)
        loss_function = DiceLoss(to_onehot_y=True, softmax=True)
        # decay the learning rate by 10x every 5000 scheduler steps
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5000, gamma=0.1)
        train_handlers = [
            LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
            ValidationHandler(validator=self.eval_engine, interval=self.val_interval, epoch_level=True),
            StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
            TensorBoardStatsHandler(
                log_dir=self.ckpt_dir,
                tag_name="train_loss",
                output_transform=lambda x: x["loss"],
            ),
        ]

        self.train_engine = SupervisedTrainer(
            device=self.device,
            max_epochs=self.max_epochs,
            train_data_loader=train_data_loader,
            network=model,
            optimizer=optimizer,
            loss_function=loss_function,
            inferer=SimpleInferer(),
            post_transform=post_transform,
            key_train_metric=None,
            train_handlers=train_handlers,
            amp=self.amp,
        )

        if self.local_rank > 0:
            # only rank 0 reports progress; other ranks log warnings and above
            for engine in (self.train_engine, self.eval_engine):
                engine.logger.setLevel(logging.WARNING)
Esempio n. 21
0
def create_trainer(args):
    """Build a Deepgrow ``SupervisedTrainer`` (with an attached evaluator) from CLI args.

    Args:
        args: parsed arguments; this function reads ``seed``, ``multi_gpu``,
            ``local_rank``, ``use_gpu``, ``roi_size``, ``model_size``, ``dimensions``,
            ``network``, ``channels``, ``resume``, ``model_filepath``, ``output``,
            ``image_interval``, ``save_interval``, ``max_val_interactions``,
            ``learning_rate``, ``val_freq``, ``epochs``, ``max_train_interactions``
            and ``amp``.

    Returns:
        The configured ``SupervisedTrainer``. Its ``ValidationHandler`` runs the evaluator.
    """
    set_determinism(seed=args.seed)

    multi_gpu = args.multi_gpu
    local_rank = args.local_rank
    if multi_gpu:
        dist.init_process_group(backend="nccl", init_method="env://")
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cuda" if args.use_gpu else "cpu")

    pre_transforms = get_pre_transforms(args.roi_size, args.model_size,
                                        args.dimensions)
    click_transforms = get_click_transforms()
    post_transform = get_post_transforms()

    train_loader, val_loader = get_loaders(args, pre_transforms)

    # define training components
    network = get_network(args.network, args.channels,
                          args.dimensions).to(device)
    if multi_gpu:
        network = torch.nn.parallel.DistributedDataParallel(
            network, device_ids=[local_rank], output_device=local_rank)

    if args.resume:
        logging.info('%s:: Loading Network...', local_rank)
        # Map every saved storage onto this process's device. The previous
        # {"cuda:0": "cuda:<rank>"} mapping failed on CPU-only runs and left
        # tensors saved on other GPUs unmapped.
        # NOTE(review): with multi_gpu the state dict is loaded into the DDP
        # wrapper, so checkpoint keys must carry the "module." prefix — confirm
        # against how model_filepath was produced.
        network.load_state_dict(
            torch.load(args.model_filepath, map_location=device))

    # define event-handlers for engine
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=args.output,
                                output_transform=lambda x: None),
        DeepgrowStatsHandler(log_dir=args.output,
                             tag_name='val_dice',
                             image_interval=args.image_interval),
        CheckpointSaver(save_dir=args.output,
                        save_dict={"net": network},
                        save_key_metric=True,
                        save_final=True,
                        save_interval=args.save_interval,
                        final_filename='model.pt')
    ]
    # only rank 0 logs and checkpoints during validation
    val_handlers = val_handlers if local_rank == 0 else None

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=network,
        iteration_update=Interaction(
            transforms=click_transforms,
            max_interactions=args.max_val_interactions,
            key_probability='probability',
            train=False),
        inferer=SimpleInferer(),
        post_transform=post_transform,
        key_val_metric={
            "val_dice":
            MeanDice(include_background=False,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        val_handlers=val_handlers)

    loss_function = DiceLoss(sigmoid=True, squared_pred=True)
    optimizer = torch.optim.Adam(network.parameters(), args.learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=5000,
                                                   gamma=0.1)

    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator,
                          interval=args.val_freq,
                          epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(log_dir=args.output,
                                tag_name="train_loss",
                                output_transform=lambda x: x["loss"]),
        CheckpointSaver(save_dir=args.output,
                        save_dict={
                            "net": network,
                            "opt": optimizer,
                            "lr": lr_scheduler
                        },
                        save_interval=args.save_interval * 2,
                        save_final=True,
                        final_filename='checkpoint.pt'),
    ]
    # non-zero ranks keep only LR scheduling and validation (no logging/checkpointing)
    train_handlers = train_handlers if local_rank == 0 else train_handlers[:2]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=args.epochs,
        train_data_loader=train_loader,
        network=network,
        iteration_update=Interaction(
            transforms=click_transforms,
            max_interactions=args.max_train_interactions,
            key_probability='probability',
            train=True),
        optimizer=optimizer,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        post_transform=post_transform,
        amp=args.amp,
        key_train_metric={
            "train_dice":
            MeanDice(include_background=False,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        train_handlers=train_handlers,
    )
    return trainer
def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4):
    """Run a short UNet segmentation workflow on ``img*/seg*.nii.gz`` files in ``root_dir``.

    The first 20 image/label pairs are used for training and the last 20 for
    validation. Training lasts 5 epochs, with validation every 2 epochs.
    TensorBoard logs and checkpoints are written to ``root_dir``.

    Args:
        root_dir: folder containing the input NIfTI files; also the log/checkpoint folder.
        device: torch device used for training and evaluation.
        amp: enable automatic mixed precision (float16 autocast) when truthy.
        num_workers: worker processes for both data loaders.

    Returns:
        ``evaluator.state.best_metric``, the best validation mean Dice observed.
    """
    images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    train_files = [{"image": img, "label": seg} for img, seg in zip(images[:20], segs[:20])]
    val_files = [{"image": img, "label": seg} for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys=["image", "label"]),
            RandCropByPosNegLabeld(
                keys=["image", "label"], label_key="label", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
            ToTensord(keys=["image", "label"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys=["image", "label"]),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # create a training data loader
    train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=num_workers)
    # create a validation data loader
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=num_workers)

    # create UNet, DiceLoss and Adam optimizer
    net = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)
    # shared by the trainer and evaluator TensorBoard stats handlers; closed after the run
    summary_writer = SummaryWriter(log_dir=root_dir)

    val_postprocessing = Compose(
        [
            ToTensord(keys=["pred", "label"]),
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold=0.5),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )

    class _TestEvalIterEvents:
        # no-op handler that only checks the evaluator exposes FORWARD_COMPLETED
        def attach(self, engine):
            engine.add_event_handler(IterationEvents.FORWARD_COMPLETED, self._forward_completed)

        def _forward_completed(self, engine):
            pass

    val_handlers = [
        StatsHandler(iteration_log=False),
        TensorBoardStatsHandler(summary_writer=summary_writer, iteration_log=False),
        TensorBoardImageHandler(
            log_dir=root_dir, batch_transform=from_engine(["image", "label"]), output_transform=from_engine("pred")
        ),
        CheckpointSaver(save_dir=root_dir, save_dict={"net": net}, save_key_metric=True),
        _TestEvalIterEvents(),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        postprocessing=val_postprocessing,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=from_engine(["pred", "label"]))
        },
        additional_metrics={"val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
        metric_cmp_fn=lambda cur, prev: cur >= prev,  # if greater or equal, treat as new best metric
        val_handlers=val_handlers,
        amp=bool(amp),
        to_kwargs={"memory_format": torch.preserve_format},
        amp_kwargs={"dtype": torch.float16 if bool(amp) else torch.float32},
    )

    train_postprocessing = Compose(
        [
            ToTensord(keys=["pred", "label"]),
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold=0.5),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )

    class _TestTrainIterEvents:
        # no-op handler that only checks the trainer exposes all iteration sub-events
        def attach(self, engine):
            engine.add_event_handler(IterationEvents.FORWARD_COMPLETED, self._forward_completed)
            engine.add_event_handler(IterationEvents.LOSS_COMPLETED, self._loss_completed)
            engine.add_event_handler(IterationEvents.BACKWARD_COMPLETED, self._backward_completed)
            engine.add_event_handler(IterationEvents.MODEL_COMPLETED, self._model_completed)

        def _forward_completed(self, engine):
            pass

        def _loss_completed(self, engine):
            pass

        def _backward_completed(self, engine):
            pass

        def _model_completed(self, engine):
            pass

    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=from_engine("loss", first=True)),
        TensorBoardStatsHandler(
            summary_writer=summary_writer, tag_name="train_loss", output_transform=from_engine("loss", first=True)
        ),
        CheckpointSaver(save_dir=root_dir, save_dict={"net": net, "opt": opt}, save_interval=2, epoch_level=True),
        _TestTrainIterEvents(),
    ]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        postprocessing=train_postprocessing,
        key_train_metric={"train_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
        train_handlers=train_handlers,
        amp=bool(amp),
        optim_set_to_none=True,
        to_kwargs={"memory_format": torch.preserve_format},
        amp_kwargs={"dtype": torch.float16 if bool(amp) else torch.float32},
    )
    try:
        trainer.run()
    finally:
        # flush pending TensorBoard events and release the event file, even if training fails
        summary_writer.close()

    return evaluator.state.best_metric
Esempio n. 23
0
def main(tempdir):
    """Train a 3D UNet segmentation model with MONAI's SupervisedTrainer/SupervisedEvaluator workflow.

    Validation runs every 2 epochs. Logs and checkpoints are written to ``./runs/``.

    Args:
        tempdir: not referenced in this body. NOTE(review): presumably the folder
            holding the data behind ``train_files``/``val_files`` — confirm.
    """
    # evaluator-only components, imported here so this example stays self-contained
    from monai.engines import SupervisedEvaluator
    from monai.handlers import MeanDice, TensorBoardImageHandler
    from monai.inferers import SlidingWindowInferer

    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    ################################ DATASET ################################
    # NOTE(review): train_files, train_transforms, val_files and val_transforms are
    # not defined in this function; they must exist at module level or this raises NameError.
    train_ds = CacheDataset(data=train_files,
                            transform=train_transforms,
                            cache_rate=0.5)
    train_loader = DataLoader(train_ds,
                              batch_size=2,
                              shuffle=True,
                              num_workers=4)
    val_ds = CacheDataset(data=val_files,
                          transform=val_transforms,
                          cache_rate=1.0)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)
    ################################ DATASET ################################

    ################################ NETWORK ################################
    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    ################################ NETWORK ################################

    ################################ LOSS ################################
    loss = monai.losses.DiceLoss(sigmoid=True)
    ################################ LOSS ################################

    ################################ OPT ################################
    opt = torch.optim.Adam(net.parameters(), 1e-3)
    ################################ OPT ################################

    ################################ LR ################################
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)
    ################################ LR ################################

    # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP
    use_amp = monai.utils.get_torch_version_tuple() >= (1, 6)

    ################################ Evaluation ################################
    # The evaluator used to be left as `...`, which ValidationHandler rejects,
    # so the workflow could never start; it is now fully defined.
    val_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
    ])
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir="./runs/",
                                output_transform=lambda x: None),
        TensorBoardImageHandler(
            log_dir="./runs/",
            batch_transform=lambda x: (x["image"], x["label"]),
            output_transform=lambda x: x["pred"],
        ),
        CheckpointSaver(save_dir="./runs/",
                        save_dict={"net": net},
                        save_key_metric=True),
    ]
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96),
                                     sw_batch_size=4,
                                     overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice":
            MeanDice(include_background=True,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={
            "val_acc":
            Accuracy(output_transform=lambda x: (x["pred"], x["label"]))
        },
        val_handlers=val_handlers,
        amp=use_amp,
    )
    ################################ Evaluation ################################

    train_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
    ])
    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(log_dir="./runs/",
                                tag_name="train_loss",
                                output_transform=lambda x: x["loss"]),
        CheckpointSaver(save_dir="./runs/",
                        save_dict={
                            "net": net,
                            "opt": opt
                        },
                        save_interval=2,
                        epoch_level=True),
    ]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        post_transform=train_post_transforms,
        key_train_metric={
            "train_acc":
            Accuracy(output_transform=lambda x: (x["pred"], x["label"]))
        },
        train_handlers=train_handlers,
        amp=use_amp,
    )
    trainer.run()
Esempio n. 24
0
def main():
    """Train a 3D DenseNet121 gender classifier on dict-style IXI T1 samples with Ignite.

    The first 10 volumes are used for training and the last 10 for validation.
    Checkpoints go to ``./runs/`` every epoch. Accuracy and ROC AUC are reported
    on validation, and EarlyStopping (patience 4) watches validation accuracy.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    images = [
        "/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI609-HH-2600-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI173-HH-1590-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI020-Guys-0700-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI342-Guys-0909-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI134-Guys-0780-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI577-HH-2661-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI066-Guys-0731-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI130-HH-1528-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz",
    ]
    # 2 binary labels for gender classification: man and woman
    # int64 explicitly: CrossEntropyLoss needs long targets, and np.array's
    # default integer dtype is platform dependent (int32 on Windows)
    labels = np.array(
        [0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0],
        dtype=np.int64)
    train_files = [{
        "img": img,
        "label": label
    } for img, label in zip(images[:10], labels[:10])]
    val_files = [{
        "img": img,
        "label": label
    } for img, label in zip(images[-10:], labels[-10:])]

    # define transforms for image
    train_transforms = Compose([
        LoadNiftid(keys=["img"]),
        AddChanneld(keys=["img"]),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"], spatial_size=(96, 96, 96)),
        RandRotate90d(keys=["img"], prob=0.8, spatial_axes=[0, 2]),
        ToTensord(keys=["img"]),
    ])
    val_transforms = Compose([
        LoadNiftid(keys=["img"]),
        AddChanneld(keys=["img"]),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"], spatial_size=(96, 96, 96)),
        ToTensord(keys=["img"]),
    ])

    # define dataset, data loader
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    check_loader = DataLoader(check_ds,
                              batch_size=2,
                              num_workers=4,
                              pin_memory=torch.cuda.is_available())
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["label"])

    # fall back to CPU when no GPU is present instead of failing on "cuda:0"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # create DenseNet121, CrossEntropyLoss and Adam optimizer;
    # the network is moved to the device before the optimizer captures its parameters
    net = monai.networks.nets.densenet.densenet121(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
    ).to(device)
    loss = torch.nn.CrossEntropyLoss()
    lr = 1e-5
    opt = torch.optim.Adam(net.parameters(), lr)

    # Ignite trainer expects batch=(img, label) and returns output=loss at every iteration,
    # user can add output_transform to return other values, like: y_pred, y, etc.
    def prepare_batch(batch, device=None, non_blocking=False):
        # unpack the dict batch into the (img, label) tuple Ignite expects
        return _prepare_batch((batch["img"], batch["label"]), device,
                              non_blocking)

    trainer = create_supervised_trainer(net,
                                        opt,
                                        loss,
                                        device,
                                        False,
                                        prepare_batch=prepare_batch)

    # adding checkpoint handler to save models (network params and optimizer stats) during training
    checkpoint_handler = ModelCheckpoint("./runs/",
                                         "net",
                                         n_saved=10,
                                         require_empty=False)
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  "net": net,
                                  "opt": opt
                              })

    # StatsHandler prints loss at every iteration and print metrics at every epoch,
    # we don't set metrics for trainer here, so just print loss, user can also customize print functions
    # and can use output_transform to convert engine.state.output if it's not loss value
    train_stats_handler = StatsHandler(name="trainer")
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
    train_tensorboard_stats_handler = TensorBoardStatsHandler()
    train_tensorboard_stats_handler.attach(trainer)

    # set parameters for validation
    validation_every_n_epochs = 1

    metric_name = "Accuracy"
    # add evaluation metric to the evaluator engine
    val_metrics = {
        metric_name: Accuracy(),
        "AUC": ROCAUC(to_onehot_y=True, add_softmax=True)
    }
    # Ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
    # user can add output_transform to return other values
    evaluator = create_supervised_evaluator(net,
                                            val_metrics,
                                            device,
                                            True,
                                            prepare_batch=prepare_batch)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x:
        None,  # no need to print loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch,
    )  # fetch global epoch number from trainer
    val_stats_handler.attach(evaluator)

    # add handler to record metrics to TensorBoard at every epoch
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        output_transform=lambda x:
        None,  # no need to plot loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch,
    )  # fetch global epoch number from trainer
    val_tensorboard_stats_handler.attach(evaluator)

    # add early stopping handler to evaluator
    early_stopper = EarlyStopping(
        patience=4,
        score_function=stopping_fn_from_metric(metric_name),
        trainer=trainer)
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                handler=early_stopper)

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=2,
                            num_workers=4,
                            pin_memory=torch.cuda.is_available())

    @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
    def run_validation(engine):
        evaluator.run(val_loader)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(train_ds,
                              batch_size=2,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=torch.cuda.is_available())

    train_epochs = 30
    trainer.run(train_loader, train_epochs)
Esempio n. 25
0
def main():
    """
    Basic UNet as implemented in MONAI for Fetal Brain Segmentation, but using
    ignite to manage training and validation loop and checkpointing.

    Reads a YAML configuration file given with ``--config``. It then builds
    training and validation pipelines on ``monai.data.PersistentDataset``,
    trains a 2D MONAI UNet with a Dice loss and Adam, and periodically
    validates. Validation uses either a sliding window over the volume or
    random 2D slices, depending on ``training.sliding_window_validation``.

    Raises:
        FileNotFoundError: if ``training.model_to_load`` is set but the file does not exist.
        ValueError: if the training or validation data list is empty.
    """
    # ------------------------------------------------------------------
    # Read input and configuration parameters
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser(
        description='Run basic UNet with MONAI - Ignite version.')
    parser.add_argument('--config',
                        dest='config',
                        metavar='config',
                        type=str,
                        help='config file')
    args = parser.parse_args()

    # NOTE(review): yaml.FullLoader is more permissive than yaml.safe_load;
    # only load configuration files from trusted sources.
    with open(args.config) as f:
        config_info = yaml.load(f, Loader=yaml.FullLoader)

    # print to log the parameter setups
    print(yaml.dump(config_info))

    # GPU params
    cuda_device = config_info['device']['cuda_device']
    num_workers = config_info['device']['num_workers']
    # training and validation params
    # NOTE(review): loss_type is read (so the key is required) but not used below;
    # DiceLoss is always used.
    loss_type = config_info['training']['loss_type']
    batch_size_train = config_info['training']['batch_size_train']
    batch_size_valid = config_info['training']['batch_size_valid']
    lr = float(config_info['training']['lr'])
    lr_decay = config_info['training']['lr_decay']
    if lr_decay is not None:
        lr_decay = float(lr_decay)
    nr_train_epochs = config_info['training']['nr_train_epochs']
    validation_every_n_epochs = config_info['training'][
        'validation_every_n_epochs']
    sliding_window_validation = config_info['training'][
        'sliding_window_validation']
    # an absent key and an explicit null both mean "train from scratch"
    model_to_load = config_info['training'].get('model_to_load')
    if model_to_load is not None and not os.path.exists(model_to_load):
        raise FileNotFoundError(
            "cannot find model: {}".format(model_to_load))
    seed = config_info['training'].get('manual_seed')
    # data params
    data_root = config_info['data']['data_root']
    training_list = config_info['data']['training_list']
    validation_list = config_info['data']['validation_list']
    # model saving: a fresh timestamped sub-directory per run
    out_model_dir = os.path.join(
        config_info['output']['out_model_dir'],
        datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' +
        config_info['output']['output_subfix'])
    print("Saving to directory ", out_model_dir)
    if 'cache_dir' in config_info['output'].keys():
        out_cache_dir = config_info['output']['cache_dir']
    else:
        out_cache_dir = os.path.join(out_model_dir, 'persistent_cache')
    max_nr_models_saved = config_info['output']['max_nr_models_saved']
    val_image_to_tensorboad = config_info['output']['val_image_to_tensorboad']

    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    torch.cuda.set_device(cuda_device)
    if seed is not None:
        # set manual seed if required (both numpy and torch)
        set_determinism(seed=seed)

    # ------------------------------------------------------------------
    # Data Preparation
    # ------------------------------------------------------------------
    # create cache directory to store results for Persistent Dataset
    persistent_cache: Path = Path(out_cache_dir)
    persistent_cache.mkdir(parents=True, exist_ok=True)

    # create training and validation data lists
    train_files = create_data_list(data_folder_list=data_root,
                                   subject_list=training_list,
                                   img_postfix='_Image',
                                   label_postfix='_Label')

    print(len(train_files))
    if not train_files:
        raise ValueError(
            "No training data found for subject list {}".format(training_list))
    print(train_files[0])
    print(train_files[-1])

    val_files = create_data_list(data_folder_list=data_root,
                                 subject_list=validation_list,
                                 img_postfix='_Image',
                                 label_postfix='_Label')
    print(len(val_files))
    if not val_files:
        raise ValueError(
            "No validation data found for subject list {}".format(validation_list))
    print(val_files[0])
    print(val_files[-1])

    # data preprocessing for training:
    # - convert data to right format [batch, channel, dim, dim, dim]
    # - apply whitening
    # - resize to (96, 96) in-plane (preserve z-direction)
    # - define 2D patches to be extracted
    # - add data augmentation (random rotation and random flip)
    # - squeeze to 2D
    train_transforms = Compose([
        LoadNiftid(keys=['img', 'seg']),
        AddChanneld(keys=['img', 'seg']),
        NormalizeIntensityd(keys=['img']),
        Resized(keys=['img', 'seg'],
                spatial_size=[96, 96],
                interp_order=[1, 0],
                anti_aliasing=[True, False]),
        RandSpatialCropd(keys=['img', 'seg'],
                         roi_size=[96, 96, 1],
                         random_size=False),
        RandRotated(keys=['img', 'seg'],
                    degrees=90,
                    prob=0.2,
                    spatial_axes=[0, 1],
                    interp_order=[1, 0],
                    reshape=False),
        RandFlipd(keys=['img', 'seg'], spatial_axis=[0, 1]),
        SqueezeDimd(keys=['img', 'seg'], dim=-1),
        ToTensord(keys=['img', 'seg'])
    ])
    # create a training data loader
    train_ds = monai.data.PersistentDataset(data=train_files,
                                            transform=train_transforms,
                                            cache_dir=persistent_cache)
    train_loader = DataLoader(train_ds,
                              batch_size=batch_size_train,
                              shuffle=True,
                              num_workers=num_workers,
                              collate_fn=list_data_collate,
                              pin_memory=torch.cuda.is_available())

    # data preprocessing for validation:
    # - convert data to right format [batch, channel, dim, dim, dim]
    # - apply whitening
    # - resize to (96, 96) in-plane (preserve z-direction)
    if sliding_window_validation:
        val_transforms = Compose([
            LoadNiftid(keys=['img', 'seg']),
            AddChanneld(keys=['img', 'seg']),
            NormalizeIntensityd(keys=['img']),
            Resized(keys=['img', 'seg'],
                    spatial_size=[96, 96],
                    interp_order=[1, 0],
                    anti_aliasing=[True, False]),
            ToTensord(keys=['img', 'seg'])
        ])
        do_shuffle = False
        collate_fn_to_use = None
    else:
        # - add extraction of 2D slices from validation set to emulate how loss is computed at training
        val_transforms = Compose([
            LoadNiftid(keys=['img', 'seg']),
            AddChanneld(keys=['img', 'seg']),
            NormalizeIntensityd(keys=['img']),
            Resized(keys=['img', 'seg'],
                    spatial_size=[96, 96],
                    interp_order=[1, 0],
                    anti_aliasing=[True, False]),
            RandSpatialCropd(keys=['img', 'seg'],
                             roi_size=[96, 96, 1],
                             random_size=False),
            SqueezeDimd(keys=['img', 'seg'], dim=-1),
            ToTensord(keys=['img', 'seg'])
        ])
        do_shuffle = True
        collate_fn_to_use = list_data_collate
    # create a validation data loader
    val_ds = monai.data.PersistentDataset(data=val_files,
                                          transform=val_transforms,
                                          cache_dir=persistent_cache)
    val_loader = DataLoader(val_ds,
                            batch_size=batch_size_valid,
                            shuffle=do_shuffle,
                            collate_fn=collate_fn_to_use,
                            num_workers=num_workers)

    # ------------------------------------------------------------------
    # Network preparation
    # ------------------------------------------------------------------
    # Create UNet, DiceLoss and Adam optimizer.
    net = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )

    loss_function = monai.losses.DiceLoss(do_sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), lr)
    device = torch.cuda.current_device()
    if lr_decay is not None:
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=opt,
                                                              gamma=lr_decay,
                                                              last_epoch=-1)

    # ------------------------------------------------------------------
    # Set ignite trainer
    # ------------------------------------------------------------------

    # function to manage batch at training: unpack the dict batch into (image, label)
    def prepare_batch(batch, device=None, non_blocking=False):
        return _prepare_batch((batch['img'], batch['seg']), device,
                              non_blocking)

    trainer = create_supervised_trainer(model=net,
                                        optimizer=opt,
                                        loss_fn=loss_function,
                                        device=device,
                                        non_blocking=False,
                                        prepare_batch=prepare_batch)

    # always save models (network params and optimizer stats) during training,
    # including when resuming from a previously saved checkpoint
    checkpoint_handler = ModelCheckpoint(out_model_dir,
                                         'net',
                                         n_saved=max_nr_models_saved,
                                         require_empty=False)
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  'net': net,
                                  'opt': opt
                              })
    # optionally restore network and optimizer state before training starts
    if model_to_load is not None:
        checkpoint_loader = CheckpointLoader(load_path=model_to_load,
                                             load_dict={
                                                 'net': net,
                                                 'opt': opt,
                                             })
        checkpoint_loader.attach(trainer)

    # StatsHandler prints loss at every iteration and print metrics at every epoch
    train_stats_handler = StatsHandler(name='trainer')
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
    writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train"))
    train_tensorboard_stats_handler = TensorBoardStatsHandler(
        summary_writer=writer_train)
    train_tensorboard_stats_handler.attach(trainer)

    if lr_decay is not None:
        print("Using Exponential LR decay")
        lr_schedule_handler = LrScheduleHandler(lr_scheduler,
                                                print_lr=True,
                                                name="lr_scheduler",
                                                writer=writer_train)
        lr_schedule_handler.attach(trainer)

    # ------------------------------------------------------------------
    # Set ignite evaluator to perform validation at training
    # ------------------------------------------------------------------
    # add evaluation metric to the evaluator engine
    val_metrics = {
        "Loss": 1.0 - MeanDice(add_sigmoid=True, to_onehot_y=False),
        "Mean_Dice": MeanDice(add_sigmoid=True, to_onehot_y=False)
    }

    def _sliding_window_processor(engine, batch):
        # run the network over each full volume with a 2D (96, 96, 1) sliding window
        net.eval()
        with torch.no_grad():
            val_images, val_labels = batch['img'].to(device), batch['seg'].to(
                device)
            roi_size = (96, 96, 1)
            seg_probs = sliding_window_inference(val_images, roi_size,
                                                 batch_size_valid, net)
            return seg_probs, val_labels

    if sliding_window_validation:
        # use sliding window inference at validation
        print("3D evaluator is used")
        net.to(device)
        evaluator = Engine(_sliding_window_processor)
        for name, metric in val_metrics.items():
            metric.attach(evaluator, name)
    else:
        # ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration,
        # user can add output_transform to return other values
        print("2D evaluator is used")
        evaluator = create_supervised_evaluator(model=net,
                                                metrics=val_metrics,
                                                device=device,
                                                non_blocking=True,
                                                prepare_batch=prepare_batch)

    # Number of iterations per training epoch. len(train_loader) counts the final
    # partial batch (the loader keeps it), so validation lines up with epoch ends.
    # A floor division of len(train_ds) by batch_size would undercount and can be 0.
    epoch_len = len(train_loader)
    validation_every_n_iters = validation_every_n_epochs * epoch_len

    @trainer.on(Events.ITERATION_COMPLETED(every=validation_every_n_iters))
    def run_validation(engine):
        evaluator.run(val_loader)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name='evaluator',
        output_transform=lambda x:
        None,  # no need to print loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch
    )  # fetch global epoch number from trainer
    val_stats_handler.attach(evaluator)

    # add handler to record metrics to TensorBoard at every validation epoch
    writer_valid = SummaryWriter(log_dir=os.path.join(out_model_dir, "valid"))
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        summary_writer=writer_valid,
        output_transform=lambda x:
        None,  # no need to plot loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.iteration
    )  # fetch global iteration number from trainer
    val_tensorboard_stats_handler.attach(evaluator)

    # Optionally log validation image, label and predicted segmentation to TensorBoard.
    # The handler is attached to every validation iteration (every=1), and the
    # trainer's epoch is used as the global step.
    if val_image_to_tensorboad:
        val_tensorboard_image_handler = TensorBoardImageHandler(
            summary_writer=writer_valid,
            batch_transform=lambda batch: (batch['img'], batch['seg']),
            output_transform=lambda output: predict_segmentation(output[0]),
            global_iter_transform=lambda x: trainer.state.epoch)
        evaluator.add_event_handler(
            event_name=Events.ITERATION_COMPLETED(every=1),
            handler=val_tensorboard_image_handler)

    # ------------------------------------------------------------------
    # Run training
    # ------------------------------------------------------------------
    trainer.run(train_loader, nr_train_epochs)
    print("Done!")
Esempio n. 26
0
def run_training(train_file_list, valid_file_list, config_info):
    """
    Pipeline to train a dynUNet segmentation model in MONAI. It is composed of the following main blocks:
        * Data Preparation: Extract the filenames and prepare the training/validation processing transforms
        * Load Data: Load training and validation data to PyTorch DataLoader
        * Network Preparation: Define the network, loss function, optimiser and learning rate scheduler
        * MONAI Evaluator: Initialise the dynUNet evaluator, i.e. the class providing utilities to perform validation
            during training. Attach handlers to save the best model on the validation set. A 2D sliding window approach
            on the 3D volume is used at evaluation. The mean 3D Dice is used as validation metric.
        * MONAI Trainer: Initialise the dynUNet trainer, i.e. the class providing utilities to perform the training loop.
        * Run training: The MONAI trainer is run, performing training and validation during training.
    Args:
        train_file_list: .txt or .csv file (with no header) storing two-columns filenames for training:
            image filename in the first column and segmentation filename in the second column.
            The two columns should be separated by a comma.
            See monaifbs/config/mock_train_file_list_for_dynUnet_training.txt for an example of the expected format.
        valid_file_list: .txt or .csv file (with no header) storing two-columns filenames for validation:
            image filename in the first column and segmentation filename in the second column.
            The two columns should be separated by a comma.
            See monaifbs/config/mock_valid_file_list_for_dynUnet_training.txt for an example of the expected format.
        config_info: dict, contains configuration parameters for sampling, network and training.
            See monaifbs/config/monai_dynUnet_training_config.yml for an example of the expected fields.
    Raises:
        FileNotFoundError: if ``training.model_to_load`` is set but the file does not exist.
        ValueError: if the training or validation file list is empty, or if the validation batch size is not 1.
    """

    # ------------------------------------------------------------------
    # Read input and configuration parameters
    # ------------------------------------------------------------------
    # print MONAI config information
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print_config()

    # print to log the parameter setups
    print(yaml.dump(config_info))

    # extract network parameters, perform checks/set defaults if not present and print them to log
    if 'seg_labels' in config_info['training'].keys():
        seg_labels = config_info['training']['seg_labels']
    else:
        seg_labels = [1]
    nr_out_channels = len(seg_labels)
    print("Considering the following {} labels in the segmentation: {}".format(nr_out_channels, seg_labels))
    # 2D in-plane patch extended with a singleton last dimension -> [N, M, 1]
    patch_size = config_info["training"]["inplane_size"] + [1]
    print("Considering patch size = {}".format(patch_size))

    spacing = config_info["training"]["spacing"]
    print("Bringing all images to spacing = {}".format(spacing))

    if 'model_to_load' in config_info['training'].keys() and config_info['training']['model_to_load'] is not None:
        model_to_load = config_info['training']['model_to_load']
        if not os.path.exists(model_to_load):
            raise FileNotFoundError("Cannot find model: {}".format(model_to_load))
        else:
            print("Loading model from {}".format(model_to_load))
    else:
        model_to_load = None

    # set up either GPU or CPU usage
    if torch.cuda.is_available():
        cuda_index = torch.cuda.current_device()
        print("\n#### GPU INFORMATION ###")
        print("Using device number: {}, name: {}\n".format(cuda_index, torch.cuda.get_device_name()))
        # use the same device that was just reported (previously hard-coded to cuda:0)
        current_device = torch.device("cuda", cuda_index)
    else:
        current_device = torch.device("cpu")
        print("Using device: {}".format(current_device))

    # set determinism if required
    if 'manual_seed' in config_info['training'].keys() and config_info['training']['manual_seed'] is not None:
        seed = config_info['training']['manual_seed']
    else:
        seed = None
    if seed is not None:
        print("Using determinism with seed = {}\n".format(seed))
        set_determinism(seed=seed)

    # ------------------------------------------------------------------
    # Setup data output directory
    # ------------------------------------------------------------------
    out_model_dir = os.path.join(config_info['output']['out_dir'],
                                 datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' +
                                 config_info['output']['out_postfix'])
    print("Saving to directory {}\n".format(out_model_dir))
    # create cache directory to store results for Persistent Dataset
    if 'cache_dir' in config_info['output'].keys():
        out_cache_dir = config_info['output']['cache_dir']
    else:
        out_cache_dir = os.path.join(out_model_dir, 'persistent_cache')
    persistent_cache: Path = Path(out_cache_dir)
    persistent_cache.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------
    # Data preparation
    # ------------------------------------------------------------------
    # Read the input files for training and validation
    print("*** Loading input data for training...")

    train_files = create_data_list_of_dictionaries(train_file_list)
    print("Number of inputs for training = {}".format(len(train_files)))
    if not train_files:
        raise ValueError("No training inputs found in {}".format(train_file_list))

    val_files = create_data_list_of_dictionaries(valid_file_list)
    print("Number of inputs for validation = {}".format(len(val_files)))
    if not val_files:
        raise ValueError("No validation inputs found in {}".format(valid_file_list))

    # Define MONAI processing transforms for the training data. This includes:
    # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3
    # - CropForegroundd: Reduce the background from the MR image
    # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the
    #       last direction (lowest resolution) to avoid introducing motion artefact resampling errors
    # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed
    # - NormalizeIntensityd: Apply whitening
    # - RandSpatialCropd: Crop a random patch from the input with size [B, C, N, M, 1]
    # - SqueezeDimd: Convert the 3D patch to a 2D one as input to the network (i.e. bring it to size [B, C, N, M])
    # - Apply data augmentation (RandZoomd, RandRotated, RandGaussianNoised, RandGaussianSmoothd, RandScaleIntensityd,
    #       RandFlipd)
    # - ToTensor: convert to pytorch tensor
    train_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            InPlaneSpacingd(
                keys=["image", "label"],
                pixdim=spacing,
                mode=("bilinear", "nearest"),
            ),
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size,
                        mode=["constant", "edge"]),
            NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
            RandSpatialCropd(keys=["image", "label"], roi_size=patch_size, random_size=False),
            SqueezeDimd(keys=["image", "label"], dim=-1),
            RandZoomd(
                keys=["image", "label"],
                min_zoom=0.9,
                max_zoom=1.2,
                mode=("bilinear", "nearest"),
                align_corners=(True, None),
                prob=0.16,
            ),
            RandRotated(keys=["image", "label"], range_x=90, range_y=90, prob=0.2,
                        keep_size=True, mode=["bilinear", "nearest"],
                        padding_mode=["zeros", "border"]),
            RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
            RandGaussianSmoothd(
                keys=["image"],
                sigma_x=(0.5, 1.15),
                sigma_y=(0.5, 1.15),
                sigma_z=(0.5, 1.15),
                prob=0.15,
            ),
            RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.15),
            RandFlipd(["image", "label"], spatial_axis=[0, 1], prob=0.5),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # Define MONAI processing transforms for the validation data
    # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3
    # - CropForegroundd: Reduce the background from the MR image
    # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the
    #       last direction (lowest resolution) to avoid introducing motion artefact resampling errors
    # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed
    # - NormalizeIntensityd: Apply whitening
    # - ToTensor: convert to pytorch tensor
    # NOTE: The validation data is kept 3D as a 2D sliding window approach is used throughout the volume at inference
    val_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            InPlaneSpacingd(
                keys=["image", "label"],
                pixdim=spacing,
                mode=("bilinear", "nearest"),
            ),
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size, mode=["constant", "edge"]),
            NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # ------------------------------------------------------------------
    # Load data
    # ------------------------------------------------------------------
    # create training data loader
    train_ds = PersistentDataset(data=train_files, transform=train_transforms,
                                 cache_dir=persistent_cache)
    train_loader = DataLoader(train_ds,
                              batch_size=config_info['training']['batch_size_train'],
                              shuffle=True,
                              num_workers=config_info['device']['num_workers'])
    check_train_data = misc.first(train_loader)
    print("Training data tensor shapes:")
    print("Image = {}; Label = {}".format(check_train_data["image"].shape, check_train_data["label"].shape))

    # create validation data loader
    if config_info['training']['batch_size_valid'] != 1:
        raise ValueError("Batch sizes different from 1 at validation are currently not supported")
    val_ds = PersistentDataset(data=val_files, transform=val_transforms, cache_dir=persistent_cache)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            shuffle=False,
                            num_workers=config_info['device']['num_workers'])
    check_valid_data = misc.first(val_loader)
    print("Validation data tensor shapes (Example):")
    print("Image = {}; Label = {}\n".format(check_valid_data["image"].shape, check_valid_data["label"].shape))

    # ------------------------------------------------------------------
    # Network preparation
    # ------------------------------------------------------------------
    print("*** Preparing the network ...")
    # automatically extracts the strides and kernels based on nnU-Net empirical rules:
    # keep downsampling (stride 2) an axis while its spacing is within 2x of the finest
    # axis and its size is still >= 8; stop when no axis can be downsampled any more
    spacings = spacing[:2]
    sizes = patch_size[:2]
    strides, kernels = [], []
    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])

    # initialise the network
    net = DynUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=nr_out_channels,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],
        norm_name="instance",
        deep_supervision=True,
        deep_supr_num=2,
        res_block=False,
    ).to(current_device)
    print(net)

    # define the loss function
    loss_function = choose_loss_function(nr_out_channels, config_info)

    # define the optimiser and the learning rate scheduler
    nr_train_epochs = config_info['training']['nr_train_epochs']
    opt = torch.optim.SGD(net.parameters(), lr=float(config_info['training']['lr']), momentum=0.95)
    # polynomial ("poly") LR decay; the base is clamped at 0 because a negative
    # float raised to 0.9 yields a complex number in Python
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        opt, lr_lambda=lambda epoch: max(0.0, 1 - epoch / nr_train_epochs) ** 0.9
    )

    # ------------------------------------------------------------------
    # MONAI evaluator
    # ------------------------------------------------------------------
    print("*** Preparing the dynUNet evaluator engine...\n")
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=os.path.join(out_model_dir, "valid"),
                                output_transform=lambda x: None,
                                global_epoch_transform=lambda x: trainer.state.iteration),
        CheckpointSaver(save_dir=out_model_dir, save_dict={"net": net, "opt": opt}, save_key_metric=True,
                        file_prefix='best_valid'),
    ]
    if config_info['output']['val_image_to_tensorboad']:
        val_handlers.append(TensorBoardImageHandler(log_dir=os.path.join(out_model_dir, "valid"),
                                                    batch_transform=lambda x: (x["image"], x["label"]),
                                                    output_transform=lambda x: x["pred"], interval=2))

    # Define customized evaluator
    class DynUNetEvaluator(SupervisedEvaluator):
        def _iteration(self, engine, batchdata):
            # test-time augmentation: average the prediction over the input and
            # its flips along dims 2, 3 and (2, 3)
            inputs, targets = self.prepare_batch(batchdata)
            inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)
            flip_inputs_1 = torch.flip(inputs, dims=(2,))
            flip_inputs_2 = torch.flip(inputs, dims=(3,))
            flip_inputs_3 = torch.flip(inputs, dims=(2, 3))

            def _compute_pred():
                pred = self.inferer(inputs, self.network)
                # use random flipping as data augmentation at inference
                flip_pred_1 = torch.flip(self.inferer(flip_inputs_1, self.network), dims=(2,))
                flip_pred_2 = torch.flip(self.inferer(flip_inputs_2, self.network), dims=(3,))
                flip_pred_3 = torch.flip(self.inferer(flip_inputs_3, self.network), dims=(2, 3))
                return (pred + flip_pred_1 + flip_pred_2 + flip_pred_3) / 4

            # execute forward computation
            self.network.eval()
            with torch.no_grad():
                if self.amp:
                    with torch.cuda.amp.autocast():
                        predictions = _compute_pred()
                else:
                    predictions = _compute_pred()
            return {"image": inputs, "label": targets, "pred": predictions}

    evaluator = DynUNetEvaluator(
        device=current_device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer2D(roi_size=patch_size, sw_batch_size=4, overlap=0.0),
        post_transform=None,
        key_val_metric={
            "Mean_dice": MeanDice(
                include_background=False,
                to_onehot_y=True,
                mutually_exclusive=True,
                output_transform=lambda x: (x["pred"], x["label"]),
            )
        },
        val_handlers=val_handlers,
        amp=False,
    )

    # ------------------------------------------------------------------
    # MONAI trainer
    # ------------------------------------------------------------------
    print("*** Preparing the dynUNet trainer engine...\n")

    validation_every_n_epochs = config_info['training']['validation_every_n_epochs']
    # Number of iterations per epoch. len(train_loader) counts the final partial
    # batch (the loader keeps it), so validation lines up with epoch ends. A floor
    # division of len(train_ds) by batch_size would undercount and can be 0,
    # which makes ValidationHandler's interval invalid.
    epoch_len = len(train_loader)
    validation_every_n_iters = validation_every_n_epochs * epoch_len

    # define event handlers for the trainer
    writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train"))
    train_handlers = [
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=validation_every_n_iters, epoch_level=False),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(summary_writer=writer_train,
                                log_dir=os.path.join(out_model_dir, "train"), tag_name="Loss",
                                output_transform=lambda x: x["loss"],
                                global_epoch_transform=lambda x: trainer.state.iteration),
        CheckpointSaver(save_dir=out_model_dir, save_dict={"net": net, "opt": opt},
                        save_final=True,
                        save_interval=2, epoch_level=True,
                        n_saved=config_info['output']['max_nr_models_saved']),
    ]
    if model_to_load is not None:
        train_handlers.append(CheckpointLoader(load_path=model_to_load, load_dict={"net": net, "opt": opt}))

    # define customized trainer
    class DynUNetTrainer(SupervisedTrainer):
        def _iteration(self, engine, batchdata):
            inputs, targets = self.prepare_batch(batchdata)
            inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)

            def _compute_loss(preds, label):
                # deep supervision: the label is resized to each auxiliary output and
                # the i-th output's loss is weighted by 0.5 ** i
                labels = [label] + [interpolate(label, pred.shape[2:]) for pred in preds[1:]]
                return sum([0.5 ** i * self.loss_function(p, l) for i, (p, l) in enumerate(zip(preds, labels))])

            self.network.train()
            self.optimizer.zero_grad()
            if self.amp and self.scaler is not None:
                with torch.cuda.amp.autocast():
                    predictions = self.inferer(inputs, self.network)
                    loss = _compute_loss(predictions, targets)
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                predictions = self.inferer(inputs, self.network)
                loss = _compute_loss(predictions, targets).mean()
                loss.backward()
                self.optimizer.step()
            return {"image": inputs, "label": targets, "pred": predictions, "loss": loss.item()}

    trainer = DynUNetTrainer(
        device=current_device,
        max_epochs=nr_train_epochs,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        post_transform=None,
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=False,
    )

    # ------------------------------------------------------------------
    # Run training
    # ------------------------------------------------------------------
    print("*** Run training...")
    trainer.run()
    print("Done!")
def main(config):
    """Train a 3D UNet on PET/CT volumes with validation after every epoch.

    The run is configured from the nested ``config`` mapping (sections
    ``'path'``, ``'preprocessing'``, ``'model'`` and ``'training'``). Train
    and validation samples come from ``DataManager(csv_path=...)``.

    Outputs go to a timestamped folder
    ``<config['path']['training_model_folder']>/<now>``. Checkpoints are
    written to the folder itself and TensorBoard event files to its ``logs``
    sub-folder.

    Args:
        config: nested configuration mapping. List values in
            ``config['model'][architecture]['cnn_params']`` are converted to
            tuples in place.
    """
    now = datetime.now().strftime("%Y%m%d-%H:%M:%S")

    # path
    csv_path = config['path']['csv_path']

    # NOTE(review): read but not used yet -- no checkpoint is loaded here,
    # so training always starts from scratch.
    trained_model_path = config['path']['trained_model_path']
    # Timestamped output folder so consecutive runs do not overwrite each other.
    training_model_folder = os.path.join(
        config['path']['training_model_folder'], now)
    logdir = os.path.join(training_model_folder, 'logs')
    # Also creates training_model_folder, as the parent of logdir.
    os.makedirs(logdir, exist_ok=True)

    # PET CT scan params
    preprocessing = config['preprocessing']
    image_shape = tuple(preprocessing['image_shape'])  # (x, y, z)
    in_channels = preprocessing['in_channels']
    voxel_spacing = tuple(preprocessing['voxel_spacing'])  # in millimeter, (x, y, z)
    # NOTE(review): the settings below are read but not wired into the
    # pipeline yet. Augmentation and intensity scaling are always applied,
    # and the resampling origin is fixed to 'head'.
    data_augment = preprocessing['data_augment']
    resize = preprocessing['resize']
    origin = preprocessing['origin']
    normalize = preprocessing['normalize']
    number_class = preprocessing['number_class']

    # CNN params
    architecture = config['model']['architecture']  # 'unet' or 'vnet'
    # NOTE(review): cnn_params is normalised here, but the UNet below is
    # built from hard-coded hyper-parameters.
    cnn_params = config['model'][architecture]['cnn_params']
    for key, value in cnn_params.items():
        if isinstance(value, list):
            cnn_params[key] = tuple(value)

    # Training params
    epochs = config['training']['epochs']
    batch_size = config['training']['batch_size']
    shuffle = config['training']['shuffle']
    # NOTE(review): opt_params is not used yet; Adam runs with lr=1e-3.
    opt_params = config['training']["optimizer"]["opt_params"]

    # Get Data
    data_manager = DataManager(csv_path=csv_path)
    train_images_paths, val_images_paths, test_images_paths = data_manager.get_train_val_test(
        wrap_with_dict=True)

    # Input preprocessing. Each builder returns fresh transform instances, so
    # the train and validation pipelines never share objects.
    def _load_and_align():
        """Read images, derive the mask, resample/align and convert to numpy."""
        return [
            LoadNifti(keys=["pet_img", "ct_img", "mask_img"]),
            Roi2Mask(keys=['pet_img', 'mask_img'],
                     method='otsu',
                     tval=0.0,
                     idx_channel=0),
            ResampleReshapeAlign(target_shape=image_shape,
                                 target_voxel_spacing=voxel_spacing,
                                 keys=['pet_img', "ct_img", 'mask_img'],
                                 origin='head',
                                 origin_key='pet_img'),
            Sitk2Numpy(keys=['pet_img', 'ct_img', 'mask_img']),
        ]

    def _scale_intensities():
        """Clip PET to [0, 25] and CT to [-1000, 1000], then rescale each to [0, 1]."""
        return [
            ScaleIntensityRanged(keys=["pet_img"],
                                 a_min=0.0,
                                 a_max=25.0,
                                 b_min=0.0,
                                 b_max=1.0,
                                 clip=True),
            ScaleIntensityRanged(keys=["ct_img"],
                                 a_min=-1000.0,
                                 a_max=1000.0,
                                 b_min=0.0,
                                 b_max=1.0,
                                 clip=True),
        ]

    def _to_network_inputs():
        """Merge PET/CT into 'image', add a mask channel and convert to tensors."""
        return [
            # presumably writes the concatenated modalities under "image",
            # which ToTensord reads below
            ConcatModality(keys=['pet_img', 'ct_img']),
            AddChanneld(keys=["mask_img"]),  # Add channel to the first axis
            ToTensord(keys=["image", "mask_img"]),
        ]

    # Random augmentation is applied to the training pipeline only.
    augmentation = [
        RandAffined(keys=("pet_img", "ct_img", "mask_img"),
                    spatial_size=None,
                    prob=0.4,
                    rotate_range=(0, np.pi / 30, np.pi / 15),
                    shear_range=None,
                    translate_range=(10, 10, 10),
                    scale_range=(0.1, 0.1, 0.1),
                    mode=("bilinear", "bilinear", "nearest"),
                    padding_mode="border"),
    ]
    train_transforms = Compose(_load_and_align() + augmentation
                               + _scale_intensities() + _to_network_inputs())
    val_transforms = Compose(_load_and_align() + _scale_intensities()
                             + _to_network_inputs())

    # create a training data loader
    train_ds = monai.data.CacheDataset(data=train_images_paths,
                                       transform=train_transforms,
                                       cache_rate=0.5)
    train_loader = monai.data.DataLoader(train_ds,
                                         batch_size=batch_size,
                                         shuffle=shuffle,
                                         num_workers=2)
    # create a validation data loader
    val_ds = monai.data.CacheDataset(data=val_images_paths,
                                     transform=val_transforms,
                                     cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds,
                                       batch_size=batch_size,
                                       num_workers=2)

    # Model
    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        dimensions=3,  # 3D
        in_channels=in_channels,
        out_channels=1,
        kernel_size=5,
        channels=(8, 16, 32, 64, 128),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True, squared_pred=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)

    # The samples carry the target under "mask_img", not "label". Both engines
    # must use this adapter. Without it the evaluator has no target, and the
    # metrics reading x["label"] fail.
    def _prepare_batch(batchdata):
        return batchdata['image'], batchdata['mask_img']

    def _pred_and_label(output):
        return output["pred"], output["label"]

    # validation
    val_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
    ])
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=logdir,
                                output_transform=lambda x: None),
        CheckpointSaver(save_dir=training_model_folder,
                        save_dict={
                            "net": net,
                            "opt": opt
                        },
                        save_key_metric=True),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        prepare_batch=_prepare_batch,
        inferer=SimpleInferer(),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice":
            MeanDice(include_background=True, output_transform=_pred_and_label)
        },
        additional_metrics={
            "val_acc": Accuracy(output_transform=_pred_and_label),
            "val_precision": Precision(output_transform=_pred_and_label),
            "val_recall": Recall(output_transform=_pred_and_label),
        },
        val_handlers=val_handlers,
    )

    # training
    train_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
    ])
    train_handlers = [
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(log_dir=logdir,
                                tag_name="train_loss",
                                output_transform=lambda x: x["loss"]),
        CheckpointSaver(save_dir=training_model_folder,
                        save_dict={
                            "net": net,
                            "opt": opt
                        },
                        save_interval=2,
                        epoch_level=True),
    ]

    # torch.cuda.amp needs a CUDA device and PyTorch >= 1.6; otherwise train
    # in full precision.
    use_amp = (device.type == "cuda"
               and monai.config.get_torch_version_tuple() >= (1, 6))

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=epochs,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss_function,
        prepare_batch=_prepare_batch,
        inferer=SimpleInferer(),
        post_transform=train_post_transforms,
        key_train_metric={
            "train_mean_dice":
            MeanDice(include_background=True, output_transform=_pred_and_label)
        },
        additional_metrics={
            "train_acc": Accuracy(output_transform=_pred_and_label),
            "train_precision": Precision(output_transform=_pred_and_label),
            "train_recall": Recall(output_transform=_pred_and_label),
        },
        train_handlers=train_handlers,
        amp=use_amp,
    )
    trainer.run()