Ejemplo n.º 1
0
    def run(self) -> None:
        """Run testing: prepare the output directory, register the
        metric-saving handlers, then delegate to the parent ``run``.
        """
        # Output goes to <out_dir>/testing/<basename of load_dir>.
        testing_dir = os.path.join(self.out_dir, 'testing')
        # makedirs(exist_ok=True) avoids the check-then-create race of
        # ``os.path.exists`` + ``os.mkdir`` and also creates missing parents.
        os.makedirs(testing_dir, exist_ok=True)
        basename = os.path.basename(self.load_dir)
        self.output_dir = os.path.join(testing_dir, basename)
        # Persist the test metrics as CSV files in the output directory.
        self._register_handlers([
            MetricsSaver(save_dir=self.output_dir,
                         metrics=['Test_AUC', 'Test_ACC'])
        ])

        super().run()
Ejemplo n.º 2
0
    def run(self, global_epoch: int) -> None:
        """Run one validation pass.

        On the first global epoch only, attach the stats/TensorBoard
        logging, checkpointing, metric-saving and early-stopping handlers
        before delegating to the parent ``run``.
        """
        # Handlers must be registered exactly once, so do it on epoch 1.
        if global_epoch == 1:
            self._register_handlers([
                StatsHandler(),
                TensorBoardStatsHandler(
                    summary_writer=self.summary_writer,
                ),
                CheckpointSaver(
                    save_dir=self.output_dir,
                    save_dict={"network": self.network},
                    save_key_metric=True,
                ),
                MetricsSaver(
                    save_dir=self.output_dir,
                    metrics=['Valid_AUC', 'Valid_ACC'],
                ),
                self.early_stop_handler,
            ])

        return super().run(global_epoch=global_epoch)
    def _run(self, tempdir):
        """Two-rank distributed smoke test for ``MetricsSaver``.

        The ranks hold different numbers of items (1 vs 2) to exercise
        uneven per-rank data lengths; rank 0 then verifies the CSV files
        that the handler wrote into ``tempdir``.
        """
        # Very long filenames (~900 chars each) exercise oversized-name handling.
        fnames = ["aaa" * 300, "bbb" * 301, "ccc" * 302]

        metrics_saver = MetricsSaver(
            save_dir=tempdir,
            metrics=["metric1", "metric2"],
            metric_details=["metric3", "metric4"],
            batch_transform=lambda x: x["image_meta_dict"],
            summary_ops="*",
        )

        # No-op step function: metrics are injected by the handlers below.
        def _val_func(engine, batch):
            pass

        engine = Engine(_val_func)

        if dist.get_rank() == 0:
            data = [{"image_meta_dict": {"filename_or_obj": [fnames[0]]}}]

            @engine.on(Events.EPOCH_COMPLETED)
            def _save_metrics0(engine):
                engine.state.metrics = {"metric1": 1, "metric2": 2}
                engine.state.metric_details = {
                    "metric3": torch.tensor([[1, 2]]),
                    "metric4": torch.tensor([[5, 6]]),
                }

        if dist.get_rank() == 1:
            # different ranks have different data length
            data = [
                {
                    "image_meta_dict": {
                        "filename_or_obj": [fnames[1]]
                    }
                },
                {
                    "image_meta_dict": {
                        "filename_or_obj": [fnames[2]]
                    }
                },
            ]

            @engine.on(Events.EPOCH_COMPLETED)
            def _save_metrics1(engine):
                engine.state.metrics = {"metric1": 1, "metric2": 2}
                engine.state.metric_details = {
                    "metric3": torch.tensor([[2, 3], [3, 4]]),
                    "metric4": torch.tensor([[6, 7], [7, 8]]),
                }

        metrics_saver.attach(engine)
        engine.run(data, max_epochs=1)

        # Only rank 0 verifies the written files.
        if dist.get_rank() == 0:
            # check the metrics.csv and content
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metrics.csv")))
            with open(os.path.join(tempdir, "metrics.csv")) as f:
                f_csv = csv.reader(f)
                for i, row in enumerate(f_csv):
                    self.assertEqual(row, [f"metric{i + 1}\t{i + 1}"])
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metric3_raw.csv")))
            # check the metric_raw.csv and content
            with open(os.path.join(tempdir, "metric3_raw.csv")) as f:
                f_csv = csv.reader(f)
                for i, row in enumerate(f_csv):
                    # skip the header row; data rows are tab-delimited
                    if i > 0:
                        expected = [
                            f"{fnames[i-1]}\t{float(i)}\t{float(i + 1)}\t{i + 0.5}"
                        ]
                        self.assertEqual(row, expected)
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metric3_summary.csv")))
            # check the metric_summary.csv and content
            with open(os.path.join(tempdir, "metric3_summary.csv")) as f:
                f_csv = csv.reader(f)
                for i, row in enumerate(f_csv):
                    if i == 1:
                        self.assertEqual(row, [
                            "class0\t1.0000\t1.0000\t1.0000\t1.0000\t1.0000\t0.0000"
                        ])
                    elif i == 2:
                        self.assertEqual(row, [
                            "class1\t2.0000\t2.0000\t2.0000\t2.0000\t2.0000\t0.0000"
                        ])
                    elif i == 3:
                        self.assertEqual(row, [
                            "mean\t1.5000\t1.5000\t1.5000\t1.5000\t1.5000\t0.0000"
                        ])
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metric4_raw.csv")))
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metric4_summary.csv")))
Ejemplo n.º 4
0
    def train(self,
              train_info,
              valid_info,
              hyperparameters,
              run_data_check=False):
        """Train the CT classification model and save checkpoints/metrics.

        Args:
            train_info: iterable of ``(info, label)`` pairs for training;
                labels are 0/1 and drive the class-imbalance weighting.
            valid_info: like ``train_info``, used for validation.
            hyperparameters: dict with keys 'learning_rate', 'weight_decay',
                'total_epoch', 'multiplicator', 'batch_size',
                'validation_epoch', 'validation_interval', 'H' and 'L'.
            run_data_check: when True, only visualize the first batch and
                terminate the process — nothing is trained or saved.

        Returns:
            The directory path where the model and metric arrays were saved.
        """
        logging.basicConfig(stream=sys.stdout, level=logging.INFO)

        if not run_data_check:
            start_dt = datetime.datetime.now()
            start_dt_string = start_dt.strftime('%d/%m/%Y %H:%M:%S')
            print(f'Training started: {start_dt_string}')

            # 1. Create folders to save the model
            timedate_info = str(
                datetime.datetime.now()).split(' ')[0] + '_' + str(
                    datetime.datetime.now().strftime("%H:%M:%S")).replace(
                        ':', '-')
            path_to_model = os.path.join(
                self.out_dir, 'trained_models',
                self.unique_name + '_' + timedate_info)
            # makedirs also creates the 'trained_models' parent when it is
            # missing; os.mkdir would raise FileNotFoundError in that case.
            os.makedirs(path_to_model, exist_ok=True)

        # 2. Load hyperparameters
        learning_rate = hyperparameters['learning_rate']
        weight_decay = hyperparameters['weight_decay']
        total_epoch = hyperparameters['total_epoch']
        multiplicator = hyperparameters['multiplicator']
        batch_size = hyperparameters['batch_size']
        validation_epoch = hyperparameters['validation_epoch']
        validation_interval = hyperparameters['validation_interval']
        H = hyperparameters['H']
        L = hyperparameters['L']

        # 3. Consider class imbalance
        negative, positive = 0, 0
        for _, label in train_info:
            if int(label) == 0:
                negative += 1
            elif int(label) == 1:
                positive += 1

        # NOTE(review): assumes at least one positive sample in train_info,
        # otherwise this divides by zero — confirm with the data pipeline.
        pos_weight = torch.Tensor([(negative / positive)]).to(self.device)

        # 4. Create train and validation loaders, batch_size = 10 for validation loader (10 central slices)
        train_data = get_data_from_info(self.image_data_dir, self.seg_data_dir,
                                        train_info)
        valid_data = get_data_from_info(self.image_data_dir, self.seg_data_dir,
                                        valid_info)
        large_image_splitter(train_data, self.cache_dir)

        set_determinism(seed=100)
        train_trans, valid_trans = self.transformations(H, L)
        train_dataset = PersistentDataset(
            data=train_data[:],
            transform=train_trans,
            cache_dir=self.persistent_dataset_dir)
        valid_dataset = PersistentDataset(
            data=valid_data[:],
            transform=valid_trans,
            cache_dir=self.persistent_dataset_dir)

        train_loader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  pin_memory=self.pin_memory,
                                  num_workers=self.num_workers,
                                  collate_fn=PadListDataCollate(
                                      Method.SYMMETRIC, NumpyPadMode.CONSTANT))
        valid_loader = DataLoader(valid_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  pin_memory=self.pin_memory,
                                  num_workers=self.num_workers,
                                  collate_fn=PadListDataCollate(
                                      Method.SYMMETRIC, NumpyPadMode.CONSTANT))

        # Perform data checks: visualize the first batch, then terminate.
        if run_data_check:
            check_data = monai.utils.misc.first(train_loader)
            print(check_data["image"].shape, check_data["label"])
            for i in range(batch_size):
                multi_slice_viewer(
                    check_data["image"][i, 0, :, :, :],
                    check_data["image_meta_dict"]["filename_or_obj"][i])
            exit()

        # 5. Prepare model
        model = ModelCT().to(self.device)

        # 6. Define loss function, optimizer and scheduler
        loss_function = torch.nn.BCEWithLogitsLoss(
            pos_weight)  # pos_weight for class imbalance
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=learning_rate,
                                     weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                           multiplicator,
                                                           last_epoch=-1)
        # 7. Create post validation transforms and handlers
        path_to_tensorboard = os.path.join(self.out_dir, 'tensorboard')
        writer = SummaryWriter(log_dir=path_to_tensorboard)
        valid_post_transforms = Compose([
            Activationsd(keys="pred", sigmoid=True),
        ])
        valid_handlers = [
            StatsHandler(output_transform=lambda x: None),
            TensorBoardStatsHandler(summary_writer=writer,
                                    output_transform=lambda x: None),
            CheckpointSaver(save_dir=path_to_model,
                            save_dict={"model": model},
                            save_key_metric=True),
            MetricsSaver(save_dir=path_to_model,
                         metrics=['Valid_AUC', 'Valid_ACC']),
        ]
        # 8. Create validator
        discrete = AsDiscrete(threshold_values=True)
        evaluator = SupervisedEvaluator(
            device=self.device,
            val_data_loader=valid_loader,
            network=model,
            post_transform=valid_post_transforms,
            key_val_metric={
                "Valid_AUC":
                ROCAUC(output_transform=lambda x: (x["pred"], x["label"]))
            },
            # Named 'Valid_ACC' (was 'Valid_Accuracy') so the key matches
            # both the MetricsSaver metric list above and
            # logger.metrics['Valid_ACC'] read when saving results below.
            additional_metrics={
                "Valid_ACC":
                Accuracy(output_transform=lambda x:
                         (discrete(x["pred"]), x["label"]))
            },
            val_handlers=valid_handlers,
            amp=self.amp,
        )
        # 9. Create trainer

        # Loss function does the last sigmoid, so we dont need it here.
        train_post_transforms = Compose([
            # Empty
        ])
        logger = MetricLogger(evaluator=evaluator)
        train_handlers = [
            logger,
            LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
            ValidationHandlerCT(validator=evaluator,
                                start=validation_epoch,
                                interval=validation_interval,
                                epoch_level=True),
            StatsHandler(tag_name="loss",
                         output_transform=lambda x: x["loss"]),
            TensorBoardStatsHandler(summary_writer=writer,
                                    tag_name="Train_Loss",
                                    output_transform=lambda x: x["loss"]),
            CheckpointSaver(save_dir=path_to_model,
                            save_dict={
                                "model": model,
                                "opt": optimizer
                            },
                            save_interval=1,
                            n_saved=1),
        ]

        trainer = SupervisedTrainer(
            device=self.device,
            max_epochs=total_epoch,
            train_data_loader=train_loader,
            network=model,
            optimizer=optimizer,
            loss_function=loss_function,
            post_transform=train_post_transforms,
            train_handlers=train_handlers,
            amp=self.amp,
        )
        # 10. Run trainer
        trainer.run()
        # 11. Save results (os.path.join is portable, unlike '+' with '/')
        np.save(os.path.join(path_to_model, 'AUCS.npy'),
                np.array(logger.metrics['Valid_AUC']))
        np.save(os.path.join(path_to_model, 'ACCS.npy'),
                np.array(logger.metrics['Valid_ACC']))
        np.save(os.path.join(path_to_model, 'LOSSES.npy'),
                np.array(logger.loss))
        np.save(os.path.join(path_to_model, 'PARAMETERS.npy'),
                np.array(hyperparameters))

        return path_to_model
Ejemplo n.º 5
0
def train(data_folder=".", model_folder="runs", continue_training=False):
    """Run a segmentation training pipeline with synthetic-lesion augmentation.

    Args:
        data_folder: folder containing ``*_ct.nii.gz`` images and
            ``*_seg.nii.gz`` labels.
        model_folder: folder where checkpoints are saved (and loaded from
            when resuming).
        continue_training: when True, resume from the latest ``*.pt``
            checkpoint found in ``model_folder``.
    """

    #/== files for synthesis
    # NOTE(review): hard-coded Google Drive paths — this only runs in the
    # original Colab environment; parameterize before reuse.
    path_parent = Path(
        '/content/drive/My Drive/Datasets/covid19/COVID-19-20_augs_cea/')
    path_synthesis = Path(
        path_parent /
        'CeA_BASE_grow=1_bg=-1.00_step=-1.0_scale=-1.0_seed=1.0_ch0_1=-1_ch1_16=-1_ali_thr=0.1'
    )
    scans_syns = os.listdir(path_synthesis)
    decreasing_sequence = get_decreasing_sequence(255, splits=20)
    keys2 = ("image", "label", "synthetic_lesion")
    # READ THE SYTHETIC HEALTHY TEXTURE
    path_synthesis_old = '/content/drive/My Drive/Datasets/covid19/results/cea_synthesis/patient0/'
    texture_orig = np.load(f'{path_synthesis_old}texture.npy.npz')
    texture_orig = texture_orig.f.arr_0
    # Shift the texture to be positive, then mirror-pad its borders.
    texture = texture_orig + np.abs(np.min(texture_orig)) + .07
    texture = np.pad(texture, ((100, 100), (100, 100)), mode='reflect')
    print(f'type(texture) = {type(texture)}, {np.shape(texture)}')
    #==/

    # Dataset capped at 10 image/label pairs (see #OMM marker).
    images = sorted(glob.glob(os.path.join(data_folder,
                                           "*_ct.nii.gz"))[:10])  #OMM
    labels = sorted(glob.glob(os.path.join(data_folder,
                                           "*_seg.nii.gz"))[:10])  #OMM
    logging.info(
        f"training: image/label ({len(images)}) folder: {data_folder}")

    amp = True  # auto. mixed precision
    keys = ("image", "label")
    # 80/20 train/validation split over the sorted file list.
    train_frac, val_frac = 0.8, 0.2
    n_train = int(train_frac * len(images)) + 1
    n_val = min(len(images) - n_train, int(val_frac * len(images)))
    logging.info(
        f"training: train {n_train} val {n_val}, folder: {data_folder}")

    train_files = [{
        keys[0]: img,
        keys[1]: seg
    } for img, seg in zip(images[:n_train], labels[:n_train])]
    val_files = [{
        keys[0]: img,
        keys[1]: seg
    } for img, seg in zip(images[-n_val:], labels[-n_val:])]

    # create a training data loader
    batch_size = 1  # XX was 2
    logging.info(f"batch size {batch_size}")
    # "synthesis" mode wires the synthetic-lesion data into the transforms.
    train_transforms = get_xforms("synthesis", keys, keys2, path_synthesis,
                                  decreasing_sequence, scans_syns, texture)
    train_ds = monai.data.CacheDataset(data=train_files,
                                       transform=train_transforms)
    train_loader = monai.data.DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
        # collate_fn=pad_list_data_collate,
    )

    # create a validation data loader
    val_transforms = get_xforms("val", keys)
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(
        val_ds,
        batch_size=
        1,  # image-level batch to the sliding window method, not the window-level batch
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )

    # create BasicUNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = get_net().to(device)

    # if continue training, resume from the newest checkpoint (lexicographic sort)
    if continue_training:
        ckpts = sorted(glob.glob(os.path.join(model_folder, "*.pt")))
        ckpt = ckpts[-1]
        logging.info(f"continue training using {ckpt}.")
        net.load_state_dict(torch.load(ckpt, map_location=device))

    # max_epochs, lr, momentum = 500, 1e-4, 0.95
    # NOTE: momentum is logged below but not passed to Adam, so it is unused.
    max_epochs, lr, momentum = 20, 1e-4, 0.95  #OMM
    logging.info(f"epochs {max_epochs}, lr {lr}, momentum {momentum}")
    opt = torch.optim.Adam(net.parameters(), lr=lr)

    # create evaluator (to be used to measure model quality during training)
    val_post_transform = monai.transforms.Compose([
        AsDiscreted(keys=("pred", "label"),
                    argmax=(True, False),
                    to_onehot=True,
                    n_classes=2)
    ])
    val_handlers = [
        ProgressBar(),
        MetricsSaver(save_dir="./metrics_val", metrics="*"),
        CheckpointSaver(save_dir=model_folder,
                        save_dict={"net": net},
                        save_key_metric=True,
                        key_metric_n_saved=6),
    ]
    evaluator = monai.engines.SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=get_inferer(),
        post_transform=val_post_transform,
        key_val_metric={
            "val_mean_dice":
            MeanDice(include_background=False,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        val_handlers=val_handlers,
        amp=amp,
    )

    # evaluator as an event handler of the trainer: validates every epoch
    train_handlers = [
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        # MetricsSaver(save_dir="./metrics_train", metrics="*"),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x["loss"]),
    ]
    trainer = monai.engines.SupervisedTrainer(
        device=device,
        max_epochs=max_epochs,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=DiceCELoss(),
        inferer=get_inferer(),
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=amp,
    )
    trainer.run()
Ejemplo n.º 6
0
    def _run(self, tempdir):
        """Two-rank distributed test for ``MetricsSaver`` with all-gather.

        Each rank produces metric details of a different length; an
        EPOCH_COMPLETED handler gathers them across ranks before the saver
        fires, then rank 0 verifies the CSV files written to ``tempdir``.
        """
        my_rank = dist.get_rank()
        # Very long filenames (~900 chars each) exercise oversized-name handling.
        fnames = ["aaa" * 300, "bbb" * 301, "ccc" * 302]

        metrics_saver = MetricsSaver(
            save_dir=tempdir,
            metrics=["metric1", "metric2"],
            metric_details=["metric3", "metric4"],
            batch_transform=lambda x: x[PostFix.meta("image")],
            summary_ops="*",
            delimiter="\t",
        )

        # No-op step function: metrics are injected by the handlers below.
        def _val_func(engine, batch):
            pass

        engine = Engine(_val_func)

        if my_rank == 0:
            data = [{PostFix.meta("image"): {"filename_or_obj": [fnames[0]]}}]

            @engine.on(Events.EPOCH_COMPLETED)
            def _save_metrics0(engine):
                engine.state.metrics = {"metric1": 1, "metric2": 2}
                engine.state.metric_details = {
                    "metric3": torch.tensor([[1, 2]]),
                    "metric4": torch.tensor([[5, 6]])
                }

        if my_rank == 1:
            # different ranks have different data length
            data = [
                {
                    PostFix.meta("image"): {
                        "filename_or_obj": [fnames[1]]
                    }
                },
                {
                    PostFix.meta("image"): {
                        "filename_or_obj": [fnames[2]]
                    }
                },
            ]

            @engine.on(Events.EPOCH_COMPLETED)
            def _save_metrics1(engine):
                engine.state.metrics = {"metric1": 1, "metric2": 2}
                engine.state.metric_details = {
                    "metric3": torch.tensor([[2, 3], [3, 4]]),
                    "metric4": torch.tensor([[6, 7], [7, 8]]),
                }

        # Gather the per-rank metric details across all ranks (handles the
        # uneven lengths) so the saver sees the complete set.
        @engine.on(Events.EPOCH_COMPLETED)
        def _all_gather(engine):
            scores = engine.state.metric_details["metric3"]
            engine.state.metric_details[
                "metric3"] = evenly_divisible_all_gather(data=scores,
                                                         concat=True)
            scores = engine.state.metric_details["metric4"]
            engine.state.metric_details[
                "metric4"] = evenly_divisible_all_gather(data=scores,
                                                         concat=True)

        metrics_saver.attach(engine)
        engine.run(data, max_epochs=1)

        # Only rank 0 verifies the written files.
        if my_rank == 0:
            # check the metrics.csv and content
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metrics.csv")))
            with open(os.path.join(tempdir, "metrics.csv")) as f:
                f_csv = csv.reader(f)
                for i, row in enumerate(f_csv):
                    self.assertEqual(row, [f"metric{i + 1}\t{i + 1}"])
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metric3_raw.csv")))
            # check the metric_raw.csv and content
            with open(os.path.join(tempdir, "metric3_raw.csv")) as f:
                f_csv = csv.reader(f)
                for i, row in enumerate(f_csv):
                    # skip the header row; values are formatted to 4 decimals
                    if i > 0:
                        expected = [
                            f"{fnames[i-1]}\t{float(i):.4f}\t{float(i + 1):.4f}\t{i + 0.5:.4f}"
                        ]
                        self.assertEqual(row, expected)
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metric3_summary.csv")))
            # check the metric_summary.csv and content
            with open(os.path.join(tempdir, "metric3_summary.csv")) as f:
                f_csv = csv.reader(f)
                for i, row in enumerate(f_csv):
                    if i == 1:
                        self.assertEqual(row, [
                            "class0\t2.0000\t2.0000\t3.0000\t1.0000\t2.8000\t0.8165\t3.0000"
                        ])
                    elif i == 2:
                        self.assertEqual(row, [
                            "class1\t3.0000\t3.0000\t4.0000\t2.0000\t3.8000\t0.8165\t3.0000"
                        ])
                    elif i == 3:
                        self.assertEqual(row, [
                            "mean\t2.5000\t2.5000\t3.5000\t1.5000\t3.3000\t0.8165\t3.0000"
                        ])
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metric4_raw.csv")))
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metric4_summary.csv")))
        # Keep ranks in sync before the test tears down the process group.
        dist.barrier()
Ejemplo n.º 7
0
    def test_content(self):
        """Single-process test of ``MetricsSaver`` CSV output.

        Runs a no-op engine over two items, injects metrics and per-image
        metric details (metric4 contains a NaN), then verifies the saved
        CSV files and their contents.
        """
        with tempfile.TemporaryDirectory() as tempdir:
            metrics_saver = MetricsSaver(
                save_dir=tempdir,
                metrics=["metric1", "metric2"],
                metric_details=["metric3", "metric4"],
                batch_transform=lambda x: x["image_meta_dict"],
                summary_ops=[
                    "mean", "median", "max", "5percentile", "95percentile",
                    "notnans"
                ],
            )
            # set up engine: two items, one filename each
            data = [
                {
                    "image_meta_dict": {
                        "filename_or_obj": ["filepath1"]
                    }
                },
                {
                    "image_meta_dict": {
                        "filename_or_obj": ["filepath2"]
                    }
                },
            ]

            # No-op step function: metrics are injected by the handler below.
            def _val_func(engine, batch):
                pass

            engine = Engine(_val_func)

            @engine.on(Events.EPOCH_COMPLETED)
            def _save_metrics(engine):
                engine.state.metrics = {"metric1": 1, "metric2": 2}
                engine.state.metric_details = {
                    "metric3":
                    torch.tensor([[1, 2], [2, 3]]),
                    "metric4":
                    torch.tensor([[5, 6], [7, torch.tensor(float("nan"))]]),
                }

            metrics_saver.attach(engine)
            engine.run(data, max_epochs=1)

            # check the metrics.csv and content
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metrics.csv")))
            with open(os.path.join(tempdir, "metrics.csv")) as f:
                f_csv = csv.reader(f)
                for i, row in enumerate(f_csv):
                    self.assertEqual(row, [f"metric{i + 1}\t{i + 1}"])
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metric3_raw.csv")))
            # check the metric_raw.csv and content
            with open(os.path.join(tempdir, "metric3_raw.csv")) as f:
                f_csv = csv.reader(f)
                for i, row in enumerate(f_csv):
                    # skip the header row; data rows are tab-delimited
                    if i > 0:
                        self.assertEqual(row, [
                            f"filepath{i}\t{float(i)}\t{float(i + 1)}\t{i + 0.5}"
                        ])
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metric3_summary.csv")))
            # check the metric4_summary.csv content (NaN-aware: 'notnans'
            # is 1.0000 for class1 because one of its values is NaN)
            with open(os.path.join(tempdir, "metric4_summary.csv")) as f:
                f_csv = csv.reader(f)
                for i, row in enumerate(f_csv):
                    if i == 1:
                        self.assertEqual(row, [
                            "class0\t6.0000\t6.0000\t7.0000\t5.1000\t6.9000\t2.0000"
                        ])
                    elif i == 2:
                        self.assertEqual(row, [
                            "class1\t6.0000\t6.0000\t6.0000\t6.0000\t6.0000\t1.0000"
                        ])
                    elif i == 3:
                        self.assertEqual(row, [
                            "mean\t6.2500\t6.2500\t7.0000\t5.5750\t6.9250\t2.0000"
                        ])
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metric4_raw.csv")))
            # NOTE(review): duplicates the metric3_summary existence check
            # above — possibly intended to be metric4_summary; confirm.
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metric3_summary.csv")))
Ejemplo n.º 8
0
    post_transform=val_post_transform,
    key_val_metric={
        "val_mean_dice":
        MeanDice(include_background=False,
                 output_transform=lambda x: (x["pred"], x["label"]))
    },
    val_handlers=val_handlers,
    amp=amp,
)

# %%
# Register the evaluator as a training event handler so validation runs
# every epoch; also log the per-iteration loss and save all metrics.
train_handlers = [
    ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
    StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
    MetricsSaver(save_dir=model_folder, metrics='*')
]
trainer = monai.engines.SupervisedTrainer(
    device=device,
    max_epochs=max_epochs,
    train_data_loader=train_loader,
    network=net,
    optimizer=opt,
    loss_function=DiceCELoss(),
    inferer=get_inferer(),
    key_train_metric=None,
    # key_train_metric={
    #         "train_mean_dice": MeanDice(include_background=False, output_transform=lambda x: (x["pred"], x["label"]))
    #     },
    train_handlers=train_handlers,
    amp=amp,
Ejemplo n.º 9
0
def train(data_folder=".", model_folder="runs", continue_training=False):
    """Run a segmentation training pipeline.

    Args:
        data_folder: folder containing ``*_ct.nii.gz`` images and
            ``*_seg.nii.gz`` labels.
        model_folder: folder where checkpoints are saved (and loaded from
            when resuming).
        continue_training: when True, resume from the latest ``*.pt``
            checkpoint found in ``model_folder``.
    """

    images = sorted(glob.glob(os.path.join(data_folder, "*_ct.nii.gz"))) #OMM
    labels = sorted(glob.glob(os.path.join(data_folder, "*_seg.nii.gz"))) #OMM
    logging.info(f"training: image/label ({len(images)}) folder: {data_folder}")

    amp = True  # auto. mixed precision
    keys = ("image", "label")
    # 80/20 train/validation split over the sorted file list.
    train_frac, val_frac = 0.8, 0.2
    n_train = int(train_frac * len(images)) + 1
    n_val = min(len(images) - n_train, int(val_frac * len(images)))
    logging.info(f"training: train {n_train} val {n_val}, folder: {data_folder}")

    train_files = [{keys[0]: img, keys[1]: seg} for img, seg in zip(images[:n_train], labels[:n_train])]
    val_files = [{keys[0]: img, keys[1]: seg} for img, seg in zip(images[-n_val:], labels[-n_val:])]

    # create a training data loader
    batch_size = 2
    logging.info(f"batch size {batch_size}")
    train_transforms = get_xforms("train", keys)
    train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms)
    train_loader = monai.data.DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )

    # create a validation data loader
    val_transforms = get_xforms("val", keys)
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(
        val_ds,
        batch_size=1,  # image-level batch to the sliding window method, not the window-level batch
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )

    # create BasicUNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = get_net().to(device)

    # if continue training, resume from the newest checkpoint (lexicographic sort)
    if continue_training:
        ckpts = sorted(glob.glob(os.path.join(model_folder, "*.pt")))
        ckpt = ckpts[-1]
        logging.info(f"continue training using {ckpt}.")
        net.load_state_dict(torch.load(ckpt, map_location=device))

    # max_epochs, lr, momentum = 500, 1e-4, 0.95
    # NOTE: momentum is logged below but not passed to Adam, so it is unused.
    max_epochs, lr, momentum = 20, 1e-4, 0.95 #OMM
    logging.info(f"epochs {max_epochs}, lr {lr}, momentum {momentum}")
    opt = torch.optim.Adam(net.parameters(), lr=lr)

    # create evaluator (to be used to measure model quality during training)
    val_post_transform = monai.transforms.Compose(
        [AsDiscreted(keys=("pred", "label"), argmax=(True, False), to_onehot=True, n_classes=2)]
    )
    val_handlers = [
        ProgressBar(),
        MetricsSaver(save_dir="./metrics_val", metrics="*"),
        CheckpointSaver(save_dir=model_folder, save_dict={"net": net}, save_key_metric=True, key_metric_n_saved=6),
    ]
    evaluator = monai.engines.SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=get_inferer(),
        post_transform=val_post_transform,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=False, output_transform=lambda x: (x["pred"], x["label"]))
        },
        val_handlers=val_handlers,
        amp=amp,
    )

    # evaluator as an event handler of the trainer: validates every epoch
    train_handlers = [
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        # MetricsSaver(save_dir="./metrics_train", metrics="*"),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
    ]
    trainer = monai.engines.SupervisedTrainer(
        device=device,
        max_epochs=max_epochs,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=DiceCELoss(),
        inferer=get_inferer(),
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=amp,
    )
    trainer.run()