Example #1
    def setup(self, stage: Optional[str] = None):
        """Creates persistent training and validation sets based on provided splits.

        Args:
            stage (Optional[str], optional): Stage (e.g., "fit", "test") for more efficient setup. Defaults to None.
        """
        set_determinism(seed=self.seed)
        if stage == "fit" or stage is None:
            train_scans, val_scans = self.splits
            self.train_dicts = [
                {"image": nod["image"], "label": nod[self.target]}
                for scan in train_scans
                for nod in scan["nodules"]
                if nod["annotations"] >= self.min_anns
                and nod[self.target] not in self.exclude_labels
            ]
            self.val_dicts = [
                {"image": nod["image"], "label": nod[self.target]}
                for scan in val_scans
                for nod in scan["nodules"]
                if nod["annotations"] >= self.min_anns
                and nod[self.target] not in self.exclude_labels
            ]
            self.train_ds = PersistentDataset(self.train_dicts,
                                              transform=self.train_transforms,
                                              cache_dir=self.cache_dir)
            self.val_ds = PersistentDataset(self.val_dicts,
                                            transform=self.val_transforms,
                                            cache_dir=self.cache_dir)
            return
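A minimal standalone sketch of the pattern Example #1 relies on (hedged: the item dict, file name, and transform below are placeholders for illustration, not part of the original code; requires MONAI):

    from monai.data import PersistentDataset
    from monai.transforms import Compose, LoadImaged

    # data entries mirror the {"image": ..., "label": ...} dicts built in setup() above
    data_dicts = [{"image": "nodule_000.nii.gz", "label": 1}]  # hypothetical item
    ds = PersistentDataset(data=data_dicts,
                           transform=Compose([LoadImaged(keys=["image"])]),
                           cache_dir="./persistent_cache")
    item = ds[0]  # first access runs the transform and caches the result; later accesses read from cache_dir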
Example #2
    def test_mp_dataset(self):
        print("persistent", dist.get_rank())
        items = [[list(range(i))] for i in range(5)]
        ds = PersistentDataset(items,
                               transform=_InplaceXform(),
                               cache_dir=self.tempdir)
        self.assertEqual(items,
                         [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])
        ds1 = PersistentDataset(items,
                                transform=_InplaceXform(),
                                cache_dir=self.tempdir)
        self.assertEqual(list(ds1), list(ds))
        self.assertEqual(items,
                         [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])

        ds = PersistentDataset(items,
                               transform=_InplaceXform(),
                               cache_dir=self.tempdir,
                               hash_func=json_hashing)
        self.assertEqual(items,
                         [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])
        ds1 = PersistentDataset(items,
                                transform=_InplaceXform(),
                                cache_dir=self.tempdir,
                                hash_func=json_hashing)
        self.assertEqual(list(ds1), list(ds))
        self.assertEqual(items,
                         [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])
Example #3
    def test_cache(self):
        """Test that caching does not modify the hashed items in place."""
        items = [[list(range(i))] for i in range(5)]

        with tempfile.TemporaryDirectory() as tempdir:
            ds = PersistentDataset(items,
                                   transform=_InplaceXform(),
                                   cache_dir=tempdir)
            self.assertEqual(
                items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])
            ds1 = PersistentDataset(items,
                                    transform=_InplaceXform(),
                                    cache_dir=tempdir)
            self.assertEqual(list(ds1), list(ds))
            self.assertEqual(
                items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])

            ds = PersistentDataset(items,
                                   transform=_InplaceXform(),
                                   cache_dir=tempdir,
                                   hash_func=json_hashing)
            self.assertEqual(
                items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])
            ds1 = PersistentDataset(items,
                                    transform=_InplaceXform(),
                                    cache_dir=tempdir,
                                    hash_func=json_hashing)
            self.assertEqual(list(ds1), list(ds))
            self.assertEqual(
                items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])
Example #4
    def test_mp_dataset(self):
        print("persistent", dist.get_rank())
        items = [[list(range(i))] for i in range(5)]
        cache_dir = os.path.join(self.tempdir, "test")
        ds = PersistentDataset(items, transform=_InplaceXform(), cache_dir=cache_dir)
        self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])
        ds1 = PersistentDataset(items, transform=_InplaceXform(), cache_dir=cache_dir)
        self.assertEqual(list(ds1), list(ds))
        self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])
Example #5
    def test_shape(self, expected_shape):
        test_image = nib.Nifti1Image(
            np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4))
        tempdir = tempfile.mkdtemp()
        nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz"))
        nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz"))
        nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz"))
        nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz"))
        nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz"))
        nib.save(test_image, os.path.join(tempdir, "test_extra2.nii.gz"))
        test_data = [
            {
                "image": os.path.join(tempdir, "test_image1.nii.gz"),
                "label": os.path.join(tempdir, "test_label1.nii.gz"),
                "extra": os.path.join(tempdir, "test_extra1.nii.gz"),
            },
            {
                "image": os.path.join(tempdir, "test_image2.nii.gz"),
                "label": os.path.join(tempdir, "test_label2.nii.gz"),
                "extra": os.path.join(tempdir, "test_extra2.nii.gz"),
            },
        ]
        test_transform = Compose([
            LoadNiftid(keys=["image", "label", "extra"]),
            SimulateDelayd(keys=["image", "label", "extra"],
                           delay_time=[1e-7, 1e-6, 1e-5]),
        ])

        dataset_precached = PersistentDataset(data=test_data,
                                              transform=test_transform,
                                              cache_dir=tempdir)
        data1_precached = dataset_precached[0]
        data2_precached = dataset_precached[1]

        dataset_postcached = PersistentDataset(data=test_data,
                                               transform=test_transform,
                                               cache_dir=tempdir)
        data1_postcached = dataset_postcached[0]
        data2_postcached = dataset_postcached[1]
        shutil.rmtree(tempdir)

        self.assertTupleEqual(data1_precached["image"].shape, expected_shape)
        self.assertTupleEqual(data1_precached["label"].shape, expected_shape)
        self.assertTupleEqual(data1_precached["extra"].shape, expected_shape)
        self.assertTupleEqual(data2_precached["image"].shape, expected_shape)
        self.assertTupleEqual(data2_precached["label"].shape, expected_shape)
        self.assertTupleEqual(data2_precached["extra"].shape, expected_shape)

        self.assertTupleEqual(data1_postcached["image"].shape, expected_shape)
        self.assertTupleEqual(data1_postcached["label"].shape, expected_shape)
        self.assertTupleEqual(data1_postcached["extra"].shape, expected_shape)
        self.assertTupleEqual(data2_postcached["image"].shape, expected_shape)
        self.assertTupleEqual(data2_postcached["label"].shape, expected_shape)
        self.assertTupleEqual(data2_postcached["extra"].shape, expected_shape)
Example #6
    def test_shape(self, transform, expected_shape):
        test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4))
        with tempfile.TemporaryDirectory() as tempdir:
            nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz"))
            nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz"))
            nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz"))
            nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz"))
            nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz"))
            nib.save(test_image, os.path.join(tempdir, "test_extra2.nii.gz"))
            test_data = [
                {
                    "image": os.path.join(tempdir, "test_image1.nii.gz"),
                    "label": os.path.join(tempdir, "test_label1.nii.gz"),
                    "extra": os.path.join(tempdir, "test_extra1.nii.gz"),
                },
                {
                    "image": os.path.join(tempdir, "test_image2.nii.gz"),
                    "label": os.path.join(tempdir, "test_label2.nii.gz"),
                    "extra": os.path.join(tempdir, "test_extra2.nii.gz"),
                },
            ]

            cache_dir = os.path.join(tempdir, "cache", "data")
            dataset_precached = PersistentDataset(data=test_data, transform=transform, cache_dir=cache_dir)
            data1_precached = dataset_precached[0]
            data2_precached = dataset_precached[1]

            dataset_postcached = PersistentDataset(data=test_data, transform=transform, cache_dir=cache_dir)
            data1_postcached = dataset_postcached[0]
            data2_postcached = dataset_postcached[1]
            data3_postcached = dataset_postcached[0:2]

            if transform is None:
                self.assertEqual(data1_precached["image"], os.path.join(tempdir, "test_image1.nii.gz"))
                self.assertEqual(data2_precached["label"], os.path.join(tempdir, "test_label2.nii.gz"))
                self.assertEqual(data1_postcached["image"], os.path.join(tempdir, "test_image1.nii.gz"))
                self.assertEqual(data2_postcached["extra"], os.path.join(tempdir, "test_extra2.nii.gz"))
            else:
                self.assertTupleEqual(data1_precached["image"].shape, expected_shape)
                self.assertTupleEqual(data1_precached["label"].shape, expected_shape)
                self.assertTupleEqual(data1_precached["extra"].shape, expected_shape)
                self.assertTupleEqual(data2_precached["image"].shape, expected_shape)
                self.assertTupleEqual(data2_precached["label"].shape, expected_shape)
                self.assertTupleEqual(data2_precached["extra"].shape, expected_shape)

                self.assertTupleEqual(data1_postcached["image"].shape, expected_shape)
                self.assertTupleEqual(data1_postcached["label"].shape, expected_shape)
                self.assertTupleEqual(data1_postcached["extra"].shape, expected_shape)
                self.assertTupleEqual(data2_postcached["image"].shape, expected_shape)
                self.assertTupleEqual(data2_postcached["label"].shape, expected_shape)
                self.assertTupleEqual(data2_postcached["extra"].shape, expected_shape)
                for d in data3_postcached:
                    self.assertTupleEqual(d["image"].shape, expected_shape)
Example #7
    def test_shape(self, transform, expected_shape):
        test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4))
        tempdir = tempfile.mkdtemp()
        nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz"))
        nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz"))
        nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz"))
        nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz"))
        nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz"))
        nib.save(test_image, os.path.join(tempdir, "test_extra2.nii.gz"))
        test_data = [
            {
                "image": os.path.join(tempdir, "test_image1.nii.gz"),
                "label": os.path.join(tempdir, "test_label1.nii.gz"),
                "extra": os.path.join(tempdir, "test_extra1.nii.gz"),
            },
            {
                "image": os.path.join(tempdir, "test_image2.nii.gz"),
                "label": os.path.join(tempdir, "test_label2.nii.gz"),
                "extra": os.path.join(tempdir, "test_extra2.nii.gz"),
            },
        ]

        dataset_precached = PersistentDataset(data=test_data, transform=transform, cache_dir=tempdir)
        data1_precached = dataset_precached[0]
        data2_precached = dataset_precached[1]

        dataset_postcached = PersistentDataset(data=test_data, transform=transform, cache_dir=tempdir)
        data1_postcached = dataset_postcached[0]
        data2_postcached = dataset_postcached[1]
        shutil.rmtree(tempdir)

        if transform is None:
            self.assertEqual(data1_precached["image"], os.path.join(tempdir, "test_image1.nii.gz"))
            self.assertEqual(data2_precached["label"], os.path.join(tempdir, "test_label2.nii.gz"))
            self.assertEqual(data1_postcached["image"], os.path.join(tempdir, "test_image1.nii.gz"))
            self.assertEqual(data2_postcached["extra"], os.path.join(tempdir, "test_extra2.nii.gz"))
        else:
            self.assertTupleEqual(data1_precached["image"].shape, expected_shape)
            self.assertTupleEqual(data1_precached["label"].shape, expected_shape)
            self.assertTupleEqual(data1_precached["extra"].shape, expected_shape)
            self.assertTupleEqual(data2_precached["image"].shape, expected_shape)
            self.assertTupleEqual(data2_precached["label"].shape, expected_shape)
            self.assertTupleEqual(data2_precached["extra"].shape, expected_shape)

            self.assertTupleEqual(data1_postcached["image"].shape, expected_shape)
            self.assertTupleEqual(data1_postcached["label"].shape, expected_shape)
            self.assertTupleEqual(data1_postcached["extra"].shape, expected_shape)
            self.assertTupleEqual(data2_postcached["image"].shape, expected_shape)
            self.assertTupleEqual(data2_postcached["label"].shape, expected_shape)
            self.assertTupleEqual(data2_postcached["extra"].shape, expected_shape)
Example #8
    def setup(self, stage: Optional[str] = None):
        """Set up persistent datasets for training and validation.

        Args:
            stage (Optional[str], optional): Stage in the model lifecycle, e.g., `fit` or `test`. Only the datasets needed for the given stage are created. Defaults to None.
        """
        self.scans = pd.read_csv(
            self.data_dir/"meta/scans.csv", index_col="PatientID")
        set_determinism(seed=self.seed)

        if stage == "fit" or stage is None:
            train_scans, val_scans = self.splits
            self.train_dicts = [{"image": scan["image"], "label": scan["mask"]}
                for scan in train_scans]
            self.val_dicts = [{"image": scan["image"], "label": scan["mask"]}
                for scan in val_scans]
            self.train_ds = PersistentDataset(
                self.train_dicts, transform=self.train_transforms, cache_dir=self.cache_dir)
            self.val_ds = PersistentDataset(
                self.val_dicts, transform=self.val_transforms, cache_dir=self.cache_dir)
        return
Example #9
    def test_cache(self):
        """testing no inplace change to the hashed item"""
        items = [[list(range(i))] for i in range(5)]

        class _InplaceXform(Transform):
            def __call__(self, data):
                if data:
                    data[0] = data[0] + np.pi
                else:
                    data.append(1)
                return data

        with tempfile.TemporaryDirectory() as tempdir:
            ds = PersistentDataset(items,
                                   transform=_InplaceXform(),
                                   cache_dir=tempdir)
            self.assertEqual(
                items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])
            ds1 = PersistentDataset(items,
                                    transform=_InplaceXform(),
                                    cache_dir=tempdir)
            self.assertEqual(list(ds1), list(ds))
            self.assertEqual(
                items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])

            ds = PersistentDataset(items,
                                   transform=_InplaceXform(),
                                   cache_dir=tempdir,
                                   hash_func=json_hashing)
            self.assertEqual(
                items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])
            ds1 = PersistentDataset(items,
                                    transform=_InplaceXform(),
                                    cache_dir=tempdir,
                                    hash_func=json_hashing)
            self.assertEqual(list(ds1), list(ds))
            self.assertEqual(
                items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])
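The hash_func argument controls how each item is mapped to a cache file name: the default pickle_hashing pickles the item, while json_hashing (used above) hashes a sorted JSON dump, which stays stable across runs for JSON-serializable items. A rough sketch of a compatible custom hash function, based on the hash_func(item) -> bytes contract (an assumption for illustration, not code from the example):

    import hashlib
    import json

    def my_hashing(item):
        # item must be JSON-serializable; sort_keys keeps the digest deterministic
        return hashlib.md5(json.dumps(item, sort_keys=True).encode("utf-8")).hexdigest().encode("utf-8")

    ds = PersistentDataset(items, transform=_InplaceXform(), cache_dir=tempdir, hash_func=my_hashing)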
Example #10
    def _dataset(self, context, datalist, replace_rate=0.25):
        if context.multi_gpu:
            world_size = torch.distributed.get_world_size()
            # partition only when there are at least world_size items;
            # otherwise every GPU keeps the full datalist
            if len(datalist) // world_size:
                datalist = partition_dataset(
                    data=datalist,
                    num_partitions=world_size,
                    even_divisible=True)[context.local_rank]

        transforms = self._validate_transforms(
            self.train_pre_transforms(context), "Training", "pre")
        if context.dataset_type == "CacheDataset":
            dataset = CacheDataset(datalist, transforms)
        elif context.dataset_type == "SmartCacheDataset":
            dataset = SmartCacheDataset(datalist, transforms, replace_rate)
        elif context.dataset_type == "PersistentDataset":
            dataset = PersistentDataset(
                datalist,
                transforms,
                cache_dir=os.path.join(context.cache_dir, "pds"))
        else:
            dataset = Dataset(datalist, transforms)
        return dataset, datalist
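For context, partition_dataset is the MONAI utility that splits a datalist into per-rank shards; with even_divisible=True every rank receives the same number of items, repeating some if needed. A hedged illustration:

    from monai.data import partition_dataset

    shards = partition_dataset(data=list(range(10)), num_partitions=4, even_divisible=True)
    # shards[rank] is the portion handled by that GPU; with 10 items and 4 partitions,
    # a few items are duplicated so that every shard ends up the same length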
Example #11
    def train(self,
              train_info,
              valid_info,
              hyperparameters,
              run_data_check=False):

        logging.basicConfig(stream=sys.stdout, level=logging.INFO)

        if not run_data_check:
            start_dt = datetime.datetime.now()
            start_dt_string = start_dt.strftime('%d/%m/%Y %H:%M:%S')
            print(f'Training started: {start_dt_string}')

            # 1. Create folders to save the model
            timedate_info = start_dt.strftime('%Y-%m-%d_%H-%M-%S')
            path_to_model = os.path.join(
                self.out_dir, 'trained_models',
                self.unique_name + '_' + timedate_info)
            os.makedirs(path_to_model)

        # 2. Load hyperparameters
        learning_rate = hyperparameters['learning_rate']
        weight_decay = hyperparameters['weight_decay']
        total_epoch = hyperparameters['total_epoch']
        multiplicator = hyperparameters['multiplicator']
        batch_size = hyperparameters['batch_size']
        validation_epoch = hyperparameters['validation_epoch']
        validation_interval = hyperparameters['validation_interval']
        H = hyperparameters['H']
        L = hyperparameters['L']

        # 3. Consider class imbalance
        negative, positive = 0, 0
        for _, label in train_info:
            if int(label) == 0:
                negative += 1
            elif int(label) == 1:
                positive += 1

        pos_weight = torch.Tensor([(negative / positive)]).to(self.device)

        # 4. Create train and validation data loaders

        train_data = get_data_from_info(self.image_data_dir, self.seg_data_dir,
                                        train_info)
        valid_data = get_data_from_info(self.image_data_dir, self.seg_data_dir,
                                        valid_info)
        large_image_splitter(train_data, self.cache_dir)

        set_determinism(seed=100)
        train_trans, valid_trans = self.transformations(H, L)
        train_dataset = PersistentDataset(
            data=train_data[:],
            transform=train_trans,
            cache_dir=self.persistent_dataset_dir)
        valid_dataset = PersistentDataset(
            data=valid_data[:],
            transform=valid_trans,
            cache_dir=self.persistent_dataset_dir)

        train_loader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  pin_memory=self.pin_memory,
                                  num_workers=self.num_workers,
                                  collate_fn=PadListDataCollate(
                                      Method.SYMMETRIC, NumpyPadMode.CONSTANT))
        valid_loader = DataLoader(valid_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  pin_memory=self.pin_memory,
                                  num_workers=self.num_workers,
                                  collate_fn=PadListDataCollate(
                                      Method.SYMMETRIC, NumpyPadMode.CONSTANT))

        # Perform data checks
        if run_data_check:
            check_data = monai.utils.misc.first(train_loader)
            print(check_data["image"].shape, check_data["label"])
            for i in range(batch_size):
                multi_slice_viewer(
                    check_data["image"][i, 0, :, :, :],
                    check_data["image_meta_dict"]["filename_or_obj"][i])
            exit()
        """c = 1
        for d in train_loader:
            img = d["image"]
            seg = d["seg"][0]
            seg, _ = nrrd.read(seg)
            img_name = d["image_meta_dict"]["filename_or_obj"][0]
            print(c, "Name:", img_name, "Size:", img.nelement()*img.element_size()/1024/1024, "MB", "shape:", img.shape)
            multi_slice_viewer(img[0, 0, :, :, :], d["image_meta_dict"]["filename_or_obj"][0])
            #multi_slice_viewer(seg, d["image_meta_dict"]["filename_or_obj"][0])
            c += 1
        exit()"""

        # 5. Prepare model
        model = ModelCT().to(self.device)

        # 6. Define loss function, optimizer and scheduler
        loss_function = torch.nn.BCEWithLogitsLoss(
            pos_weight)  # pos_weight for class imbalance
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=learning_rate,
                                     weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                           multiplicator,
                                                           last_epoch=-1)
        # 7. Create post validation transforms and handlers
        path_to_tensorboard = os.path.join(self.out_dir, 'tensorboard')
        writer = SummaryWriter(log_dir=path_to_tensorboard)
        valid_post_transforms = Compose([
            Activationsd(keys="pred", sigmoid=True),
        ])
        valid_handlers = [
            StatsHandler(output_transform=lambda x: None),
            TensorBoardStatsHandler(summary_writer=writer,
                                    output_transform=lambda x: None),
            CheckpointSaver(save_dir=path_to_model,
                            save_dict={"model": model},
                            save_key_metric=True),
            MetricsSaver(save_dir=path_to_model,
                         metrics=['Valid_AUC', 'Valid_ACC']),
        ]
        # 8. Create validator
        discrete = AsDiscrete(threshold_values=True)
        evaluator = SupervisedEvaluator(
            device=self.device,
            val_data_loader=valid_loader,
            network=model,
            post_transform=valid_post_transforms,
            key_val_metric={
                "Valid_AUC":
                ROCAUC(output_transform=lambda x: (x["pred"], x["label"]))
            },
            additional_metrics={
                # key must match the 'Valid_ACC' name used by MetricsSaver and the np.save call below
                "Valid_ACC":
                Accuracy(output_transform=lambda x:
                         (discrete(x["pred"]), x["label"]))
            },
            val_handlers=valid_handlers,
            amp=self.amp,
        )
        # 9. Create trainer

        # The loss function applies the final sigmoid, so we don't need it here.
        train_post_transforms = Compose([
            # Empty
        ])
        logger = MetricLogger(evaluator=evaluator)
        train_handlers = [
            logger,
            LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
            ValidationHandlerCT(validator=evaluator,
                                start=validation_epoch,
                                interval=validation_interval,
                                epoch_level=True),
            StatsHandler(tag_name="loss",
                         output_transform=lambda x: x["loss"]),
            TensorBoardStatsHandler(summary_writer=writer,
                                    tag_name="Train_Loss",
                                    output_transform=lambda x: x["loss"]),
            CheckpointSaver(save_dir=path_to_model,
                            save_dict={
                                "model": model,
                                "opt": optimizer
                            },
                            save_interval=1,
                            n_saved=1),
        ]

        trainer = SupervisedTrainer(
            device=self.device,
            max_epochs=total_epoch,
            train_data_loader=train_loader,
            network=model,
            optimizer=optimizer,
            loss_function=loss_function,
            post_transform=train_post_transforms,
            train_handlers=train_handlers,
            amp=self.amp,
        )
        # 10. Run trainer
        trainer.run()
        # 11. Save results
        np.save(path_to_model + '/AUCS.npy',
                np.array(logger.metrics['Valid_AUC']))
        np.save(path_to_model + '/ACCS.npy',
                np.array(logger.metrics['Valid_ACC']))
        np.save(path_to_model + '/LOSSES.npy', np.array(logger.loss))
        np.save(path_to_model + '/PARAMETERS.npy', np.array(hyperparameters))

        return path_to_model
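A quick worked illustration of the pos_weight heuristic from step 3 (not part of the original code): with 90 negative and 10 positive samples, pos_weight = 90 / 10 = 9, so BCEWithLogitsLoss scales each positive term by 9 to counter the imbalance.

    import torch
    loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([9.0]))  # 90 negatives / 10 positives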
Example #12
    def prepare_data(self):
        data_images = sorted([
            os.path.join(data_path, x) for x in os.listdir(data_path)
            if x.startswith("data")
        ])
        data_labels = sorted([
            os.path.join(data_path, x) for x in os.listdir(data_path)
            if x.startswith("label")
        ])
        data_dicts = [{
            "image": image_name,
            "label": label_name,
            "patient": os.path.basename(image_name).replace("data", "").replace(".nii.gz", ""),
        } for image_name, label_name in zip(data_images, data_labels)]
        train_files, val_files = train_val_split(data_dicts)
        print(
            f"Training patients: {len(train_files)}, Validation patients: {len(val_files)}"
        )

        set_determinism(seed=0)

        train_transforms = Compose([
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            Spacingd(keys=["image", "label"],
                     pixdim=PIXDIM,
                     mode=("bilinear", "nearest")),
            DataStatsdWithPatient(keys=["image", "label"]),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-100,
                a_max=300,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=PATCH_SIZE,
                pos=1,
                neg=1,
                num_samples=16,
                image_key="image",
                image_threshold=0,
            ),
            RandFlipd(["image", "label"], spatial_axis=[0, 1, 2], prob=0.5),
            ToTensord(keys=["image", "label"]),
        ])
        val_transforms = Compose([
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            Spacingd(keys=["image", "label"],
                     pixdim=PIXDIM,
                     mode=("bilinear", "nearest")),
            DataStatsdWithPatient(keys=["image", "label"]),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-100,
                a_max=300,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            StoreShaped(keys=['image']),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            ToTensord(keys=["image", "label"]),
        ])

        self.train_ds = PersistentDataset(data=train_files,
                                          transform=train_transforms,
                                          cache_dir=cache_path)
        self.val_ds = PersistentDataset(data=val_files,
                                        transform=val_transforms,
                                        cache_dir=cache_path)
Example #13
File: test.py Project: ckbr0/RIS
def main(train_output):
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print_config()

    # Setup directories
    dirs = setup_directories()

    # Setup torch device
    device, using_gpu = create_device("cuda")

    # Load and randomize images

    # HACKATON image and segmentation data
    hackathon_dir = os.path.join(dirs["data"], 'HACKATHON')
    map_fn = lambda x: (x[0], int(x[1]))
    with open(os.path.join(hackathon_dir, "train.txt"), 'r') as fp:
        train_info_hackathon = [
            map_fn(entry.strip().split(',')) for entry in fp.readlines()
        ]
    image_dir = os.path.join(hackathon_dir, 'images', 'train')
    seg_dir = os.path.join(hackathon_dir, 'segmentations', 'train')
    _train_data_hackathon = get_data_from_info(image_dir,
                                               seg_dir,
                                               train_info_hackathon,
                                               dual_output=False)
    large_image_splitter(_train_data_hackathon, dirs["cache"])

    balance_training_data(_train_data_hackathon, seed=72)

    # PSUF data
    """psuf_dir = os.path.join(dirs["data"], 'psuf')
    with open(os.path.join(psuf_dir, "train.txt"), 'r') as fp:
        train_info = [entry.strip().split(',') for entry in fp.readlines()]
    image_dir = os.path.join(psuf_dir, 'images')
    train_data_psuf = get_data_from_info(image_dir, None, train_info)"""
    # Split data into train, validate and test
    train_split, test_data_hackathon = train_test_split(_train_data_hackathon,
                                                        test_size=0.2,
                                                        shuffle=True,
                                                        random_state=42)
    #train_data_hackathon, valid_data_hackathon = train_test_split(train_split, test_size=0.2, shuffle=True, random_state=43)
    # Setup transforms

    # Crop foreground
    crop_foreground = CropForegroundd(
        keys=["image"],
        source_key="image",
        margin=(5, 5, 0),
        #select_fn = lambda x: x != 0
    )
    # Crop Z
    crop_z = RelativeCropZd(keys=["image"], relative_z_roi=(0.07, 0.12))
    # Window width and level (window center)
    WW, WL = 1500, -600
    ct_window = CTWindowd(keys=["image"], width=WW, level=WL)
    spatial_pad = SpatialPadd(keys=["image"], spatial_size=(-1, -1, 30))
    resize = Resized(keys=["image"],
                     spatial_size=(int(512 * 0.50), int(512 * 0.50), -1),
                     mode="trilinear")

    # Create transforms
    common_transform = Compose([
        LoadImaged(keys=["image"]),
        ct_window,
        CTSegmentation(keys=["image"]),
        AddChanneld(keys=["image"]),
        resize,
        crop_foreground,
        crop_z,
        spatial_pad,
    ])
    hackathon_train_transform = Compose([
        common_transform,
        ToTensord(keys=["image"]),
    ]).flatten()
    psuf_transforms = Compose([
        LoadImaged(keys=["image"]),
        AddChanneld(keys=["image"]),
        ToTensord(keys=["image"]),
    ])

    # Setup data
    #set_determinism(seed=100)
    test_dataset = PersistentDataset(data=test_data_hackathon[:],
                                     transform=hackathon_train_transform,
                                     cache_dir=dirs["persistent"])
    test_loader = DataLoader(test_dataset,
                             batch_size=2,
                             shuffle=True,
                             pin_memory=using_gpu,
                             num_workers=1,
                             collate_fn=PadListDataCollate(
                                 Method.SYMMETRIC, NumpyPadMode.CONSTANT))

    # Setup network, loss function, optimizer and scheduler
    network = nets.DenseNet121(spatial_dims=3, in_channels=1,
                               out_channels=1).to(device)

    # Setup validator and trainer
    valid_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
    ])

    # Setup tester
    tester = Tester(device=device,
                    test_data_loader=test_loader,
                    load_dir=train_output,
                    out_dir=dirs["out"],
                    network=network,
                    post_transform=valid_post_transforms,
                    non_blocking=using_gpu,
                    amp=using_gpu)

    # Run tester
    tester.run()
Example #14
def get_dataflow(seed, data_dir, cache_dir, batch_size):
    img = nib.load(str(data_dir / "average_smwc1.nii"))
    img_data_1 = img.get_fdata()
    img_data_1 = np.expand_dims(img_data_1, axis=0)

    img = nib.load(str(data_dir / "average_smwc2.nii"))
    img_data_2 = img.get_fdata()
    img_data_2 = np.expand_dims(img_data_2, axis=0)

    img = nib.load(str(data_dir / "average_smwc3.nii"))
    img_data_3 = img.get_fdata()
    img_data_3 = np.expand_dims(img_data_3, axis=0)

    mask = np.concatenate((img_data_1, img_data_2, img_data_3))
    mask[mask > 0.3] = 1
    mask[mask <= 0.3] = 0

    # Define transformations
    train_transforms = transforms.Compose([
        transforms.LoadNiftid(keys=["c1", "c2", "c3"]),
        transforms.AddChanneld(keys=["c1", "c2", "c3"]),
        transforms.ConcatItemsd(keys=["c1", "c2", "c3"], name="img"),
        transforms.MaskIntensityd(keys=["img"], mask_data=mask),
        transforms.ScaleIntensityd(keys="img"),
        transforms.ToTensord(keys=["img", "label"])
    ])

    val_transforms = transforms.Compose([
        transforms.LoadNiftid(keys=["c1", "c2", "c3"]),
        transforms.AddChanneld(keys=["c1", "c2", "c3"]),
        transforms.ConcatItemsd(keys=["c1", "c2", "c3"], name="img"),
        transforms.MaskIntensityd(keys=["img"], mask_data=mask),
        transforms.ScaleIntensityd(keys="img"),
        transforms.ToTensord(keys=["img", "label"])
    ])

    # Get img paths
    df = pd.read_csv(data_dir / "banc2019_training_dataset.csv")
    df = df.sample(frac=1, random_state=seed)
    df["NormAge"] = (((df["Age"] - 18) / (92 - 18)) * 2) - 1
    data_dicts = []
    for index, row in df.iterrows():
        study_dir = data_dir / row["Study"] / "derivatives" / "spm"
        subj = list(study_dir.glob(f"sub-{row['Subject']}"))

        if subj == []:
            subj = list(study_dir.glob(f"*sub-{row['Subject']}*"))
            if subj == []:
                subj = list(
                    study_dir.glob(f"*sub-{row['Subject'].rstrip('_S1')}*"))
                if subj == []:
                    if row["Study"] == "SALD":
                        subj = list(
                            study_dir.glob(f"sub-{int(row['Subject']):06d}*"))
                        if subj == []:
                            print(f"{row['Study']} {row['Subject']}")
                            continue
                    else:
                        print(f"{row['Study']} {row['Subject']}")
                        continue

        c1_img = list(subj[0].glob("./smwc1*"))
        c2_img = list(subj[0].glob("./smwc2*"))
        c3_img = list(subj[0].glob("./smwc3*"))

        if c1_img == []:
            print(f"{row['Study']} {row['Subject']}")
            continue
        if c2_img == []:
            print(f"{row['Study']} {row['Subject']}")
            continue
        if c3_img == []:
            print(f"{row['Study']} {row['Subject']}")
            continue

        data_dicts.append({
            "c1": str(c1_img[0]),
            "c2": str(c2_img[0]),
            "c3": str(c3_img[0]),
            "label": row["NormAge"]
        })

    print(f"Found {len(data_dicts)} subjects.")
    val_size = len(data_dicts) // 10
    # Create datasets and dataloaders
    train_ds = PersistentDataset(data=data_dicts[:-val_size],
                                 transform=train_transforms,
                                 cache_dir=cache_dir)
    # train_ds = Dataset(data=data_dicts[:-val_size], transform=train_transforms)
    train_loader = DataLoader(train_ds,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=4,
                              collate_fn=list_data_collate)

    val_ds = PersistentDataset(data=data_dicts[-val_size:],
                               transform=val_transforms,
                               cache_dir=cache_dir)
    # val_ds = Dataset(data=data_dicts[-val_size:], transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=4)

    return train_loader, val_loader
Example #15
    def test_thread_safe(self, persistent_workers, cache_workers,
                         loader_workers):
        expected = [102, 202, 302, 402, 502, 602, 702, 802, 902, 1002]
        _kwg = {
            "persistent_workers": persistent_workers
        } if pytorch_after(1, 8) else {}
        data_list = list(range(1, 11))
        dataset = CacheDataset(data=data_list,
                               transform=_StatefulTransform(),
                               cache_rate=1.0,
                               num_workers=cache_workers,
                               progress=False)
        self.assertListEqual(expected, list(dataset))
        loader = DataLoader(
            CacheDataset(
                data=data_list,
                transform=_StatefulTransform(),
                cache_rate=1.0,
                num_workers=cache_workers,
                progress=False,
            ),
            batch_size=1,
            num_workers=loader_workers,
            **_kwg,
        )
        self.assertListEqual(expected, [y.item() for y in loader])
        self.assertListEqual(expected, [y.item() for y in loader])

        dataset = SmartCacheDataset(
            data=data_list,
            transform=_StatefulTransform(),
            cache_rate=0.7,
            replace_rate=0.5,
            num_replace_workers=cache_workers,
            progress=False,
            shuffle=False,
        )
        self.assertListEqual(expected[:7], list(dataset))
        loader = DataLoader(
            SmartCacheDataset(
                data=data_list,
                transform=_StatefulTransform(),
                cache_rate=0.7,
                replace_rate=0.5,
                num_replace_workers=cache_workers,
                progress=False,
                shuffle=False,
            ),
            batch_size=1,
            num_workers=loader_workers,
            **_kwg,
        )
        self.assertListEqual(expected[:7], [y.item() for y in loader])
        self.assertListEqual(expected[:7], [y.item() for y in loader])

        with tempfile.TemporaryDirectory() as tempdir:
            pdata = PersistentDataset(data=data_list,
                                      transform=_StatefulTransform(),
                                      cache_dir=tempdir)
            self.assertListEqual(expected, list(pdata))
            loader = DataLoader(
                PersistentDataset(data=data_list,
                                  transform=_StatefulTransform(),
                                  cache_dir=tempdir),
                batch_size=1,
                num_workers=loader_workers,
                shuffle=False,
                **_kwg,
            )
            self.assertListEqual(expected, [y.item() for y in loader])
            self.assertListEqual(expected, [y.item() for y in loader])
Example #16
def run_training(train_file_list, valid_file_list, config_info):
    """
    Pipeline to train a dynUNet segmentation model in MONAI. It is composed of the following main blocks:
        * Data Preparation: Extract the filenames and prepare the training/validation processing transforms
        * Load Data: Load training and validation data to PyTorch DataLoader
        * Network Preparation: Define the network, loss function, optimiser and learning rate scheduler
        * MONAI Evaluator: Initialise the dynUNet evaluator, i.e. the class providing utilities to perform validation
            during training. Attach handlers to save the best model on the validation set. A 2D sliding window approach
            on the 3D volume is used at evaluation. The mean 3D Dice is used as validation metric.
        * MONAI Trainer: Initialise the dynUNet trainer, i.e. the class providing utilities to perform the training loop.
        * Run training: The MONAI trainer is run, performing training and validation during training.
    Args:
        train_file_list: .txt or .csv file (with no header) storing two columns of filenames for training:
            image filename in the first column and segmentation filename in the second column.
            The two columns should be separated by a comma.
            See monaifbs/config/mock_train_file_list_for_dynUnet_training.txt for an example of the expected format.
        valid_file_list: .txt or .csv file (with no header) storing two columns of filenames for validation:
            image filename in the first column and segmentation filename in the second column.
            The two columns should be separated by a comma.
            See monaifbs/config/mock_valid_file_list_for_dynUnet_training.txt for an example of the expected format.
        config_info: dict, contains configuration parameters for sampling, network and training.
            See monaifbs/config/monai_dynUnet_training_config.yml for an example of the expected fields.
    """

    """
    Read input and configuration parameters
    """
    # print MONAI config information
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print_config()

    # print to log the parameter setups
    print(yaml.dump(config_info))

    # extract network parameters, perform checks/set defaults if not present and print them to log
    if 'seg_labels' in config_info['training'].keys():
        seg_labels = config_info['training']['seg_labels']
    else:
        seg_labels = [1]
    nr_out_channels = len(seg_labels)
    print("Considering the following {} labels in the segmentation: {}".format(nr_out_channels, seg_labels))
    patch_size = config_info["training"]["inplane_size"] + [1]
    print("Considering patch size = {}".format(patch_size))

    spacing = config_info["training"]["spacing"]
    print("Bringing all images to spacing = {}".format(spacing))

    if 'model_to_load' in config_info['training'].keys() and config_info['training']['model_to_load'] is not None:
        model_to_load = config_info['training']['model_to_load']
        if not os.path.exists(model_to_load):
            raise FileNotFoundError("Cannot find model: {}".format(model_to_load))
        else:
            print("Loading model from {}".format(model_to_load))
    else:
        model_to_load = None

    # set up either GPU or CPU usage
    if torch.cuda.is_available():
        print("\n#### GPU INFORMATION ###")
        print("Using device number: {}, name: {}\n".format(torch.cuda.current_device(), torch.cuda.get_device_name()))
        current_device = torch.device("cuda:0")
    else:
        current_device = torch.device("cpu")
        print("Using device: {}".format(current_device))

    # set determinism if required
    if 'manual_seed' in config_info['training'].keys() and config_info['training']['manual_seed'] is not None:
        seed = config_info['training']['manual_seed']
    else:
        seed = None
    if seed is not None:
        print("Using determinism with seed = {}\n".format(seed))
        set_determinism(seed=seed)

    """
    Setup data output directory
    """
    out_model_dir = os.path.join(config_info['output']['out_dir'],
                                 datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' +
                                 config_info['output']['out_postfix'])
    print("Saving to directory {}\n".format(out_model_dir))
    # create cache directory to store results for Persistent Dataset
    if 'cache_dir' in config_info['output'].keys():
        out_cache_dir = config_info['output']['cache_dir']
    else:
        out_cache_dir = os.path.join(out_model_dir, 'persistent_cache')
    persistent_cache: Path = Path(out_cache_dir)
    persistent_cache.mkdir(parents=True, exist_ok=True)

    """
    Data preparation
    """
    # Read the input files for training and validation
    print("*** Loading input data for training...")

    train_files = create_data_list_of_dictionaries(train_file_list)
    print("Number of inputs for training = {}".format(len(train_files)))

    val_files = create_data_list_of_dictionaries(valid_file_list)
    print("Number of inputs for validation = {}".format(len(val_files)))

    # Define MONAI processing transforms for the training data. This includes:
    # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3
    # - CropForegroundd: Reduce the background from the MR image
    # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the
    #       last direction (lowest resolution) to avoid introducing motion artefact resampling errors
    # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed
    # - NormalizeIntensityd: Apply whitening
    # - RandSpatialCropd: Crop a random patch from the input with size [B, C, N, M, 1]
    # - SqueezeDimd: Convert the 3D patch to a 2D one as input to the network (i.e. bring it to size [B, C, N, M])
    # - Apply data augmentation (RandZoomd, RandRotated, RandGaussianNoised, RandGaussianSmoothd, RandScaleIntensityd,
    #       RandFlipd)
    # - ToTensor: convert to pytorch tensor
    train_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            InPlaneSpacingd(
                keys=["image", "label"],
                pixdim=spacing,
                mode=("bilinear", "nearest"),
            ),
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size,
                        mode=["constant", "edge"]),
            NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
            RandSpatialCropd(keys=["image", "label"], roi_size=patch_size, random_size=False),
            SqueezeDimd(keys=["image", "label"], dim=-1),
            RandZoomd(
                keys=["image", "label"],
                min_zoom=0.9,
                max_zoom=1.2,
                mode=("bilinear", "nearest"),
                align_corners=(True, None),
                prob=0.16,
            ),
            RandRotated(keys=["image", "label"], range_x=90, range_y=90, prob=0.2,
                        keep_size=True, mode=["bilinear", "nearest"],
                        padding_mode=["zeros", "border"]),
            RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
            RandGaussianSmoothd(
                keys=["image"],
                sigma_x=(0.5, 1.15),
                sigma_y=(0.5, 1.15),
                sigma_z=(0.5, 1.15),
                prob=0.15,
            ),
            RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.15),
            RandFlipd(["image", "label"], spatial_axis=[0, 1], prob=0.5),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # Define MONAI processing transforms for the validation data
    # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3
    # - CropForegroundd: Reduce the background from the MR image
    # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the
    #       last direction (lowest resolution) to avoid introducing motion artefact resampling errors
    # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed
    # - NormalizeIntensityd: Apply whitening
    # - ToTensor: convert to pytorch tensor
    # NOTE: The validation data is kept 3D as a 2D sliding window approach is used throughout the volume at inference
    val_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            InPlaneSpacingd(
                keys=["image", "label"],
                pixdim=spacing,
                mode=("bilinear", "nearest"),
            ),
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size, mode=["constant", "edge"]),
            NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
            ToTensord(keys=["image", "label"]),
        ]
    )

    """
    Load data 
    """
    # create training data loader
    train_ds = PersistentDataset(data=train_files, transform=train_transforms,
                                 cache_dir=persistent_cache)
    train_loader = DataLoader(train_ds,
                              batch_size=config_info['training']['batch_size_train'],
                              shuffle=True,
                              num_workers=config_info['device']['num_workers'])
    check_train_data = misc.first(train_loader)
    print("Training data tensor shapes:")
    print("Image = {}; Label = {}".format(check_train_data["image"].shape, check_train_data["label"].shape))

    # create validation data loader
    if config_info['training']['batch_size_valid'] != 1:
        raise Exception("Batch size different from 1 at validation ar currently not supported")
    val_ds = PersistentDataset(data=val_files, transform=val_transforms, cache_dir=persistent_cache)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            shuffle=False,
                            num_workers=config_info['device']['num_workers'])
    check_valid_data = misc.first(val_loader)
    print("Validation data tensor shapes (Example):")
    print("Image = {}; Label = {}\n".format(check_valid_data["image"].shape, check_valid_data["label"].shape))

    """
    Network preparation
    """
    print("*** Preparing the network ...")
    # automatically extracts the strides and kernels based on nnU-Net empirical rules
    spacings = spacing[:2]
    sizes = patch_size[:2]
    strides, kernels = [], []
    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])

    # initialise the network
    net = DynUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=nr_out_channels,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],
        norm_name="instance",
        deep_supervision=True,
        deep_supr_num=2,
        res_block=False,
    ).to(current_device)
    print(net)

    # define the loss function
    loss_function = choose_loss_function(nr_out_channels, config_info)

    # define the optimiser and the learning rate scheduler
    opt = torch.optim.SGD(net.parameters(), lr=float(config_info['training']['lr']), momentum=0.95)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        opt, lr_lambda=lambda epoch: (1 - epoch / config_info['training']['nr_train_epochs']) ** 0.9
    )

    """
    MONAI evaluator
    """
    print("*** Preparing the dynUNet evaluator engine...\n")
    # val_post_transforms = Compose(
    #     [
    #         Activationsd(keys="pred", sigmoid=True),
    #     ]
    # )
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=os.path.join(out_model_dir, "valid"),
                                output_transform=lambda x: None,
                                global_epoch_transform=lambda x: trainer.state.iteration),
        CheckpointSaver(save_dir=out_model_dir, save_dict={"net": net, "opt": opt}, save_key_metric=True,
                        file_prefix='best_valid'),
    ]
    if config_info['output']['val_image_to_tensorboad']:
        val_handlers.append(TensorBoardImageHandler(log_dir=os.path.join(out_model_dir, "valid"),
                                                    batch_transform=lambda x: (x["image"], x["label"]),
                                                    output_transform=lambda x: x["pred"], interval=2))

    # Define customized evaluator
    class DynUNetEvaluator(SupervisedEvaluator):
        def _iteration(self, engine, batchdata):
            inputs, targets = self.prepare_batch(batchdata)
            inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)
            flip_inputs_1 = torch.flip(inputs, dims=(2,))
            flip_inputs_2 = torch.flip(inputs, dims=(3,))
            flip_inputs_3 = torch.flip(inputs, dims=(2, 3))

            def _compute_pred():
                pred = self.inferer(inputs, self.network)
                # use random flipping as data augmentation at inference
                flip_pred_1 = torch.flip(self.inferer(flip_inputs_1, self.network), dims=(2,))
                flip_pred_2 = torch.flip(self.inferer(flip_inputs_2, self.network), dims=(3,))
                flip_pred_3 = torch.flip(self.inferer(flip_inputs_3, self.network), dims=(2, 3))
                return (pred + flip_pred_1 + flip_pred_2 + flip_pred_3) / 4

            # execute forward computation
            self.network.eval()
            with torch.no_grad():
                if self.amp:
                    with torch.cuda.amp.autocast():
                        predictions = _compute_pred()
                else:
                    predictions = _compute_pred()
            return {"image": inputs, "label": targets, "pred": predictions}

    evaluator = DynUNetEvaluator(
        device=current_device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer2D(roi_size=patch_size, sw_batch_size=4, overlap=0.0),
        post_transform=None,
        key_val_metric={
            "Mean_dice": MeanDice(
                include_background=False,
                to_onehot_y=True,
                mutually_exclusive=True,
                output_transform=lambda x: (x["pred"], x["label"]),
            )
        },
        val_handlers=val_handlers,
        amp=False,
    )

    """
    MONAI trainer
    """
    print("*** Preparing the dynUNet trainer engine...\n")
    # train_post_transforms = Compose(
    #     [
    #         Activationsd(keys="pred", sigmoid=True),
    #     ]
    # )

    validation_every_n_epochs = config_info['training']['validation_every_n_epochs']
    epoch_len = len(train_ds) // train_loader.batch_size
    validation_every_n_iters = validation_every_n_epochs * epoch_len

    # define event handlers for the trainer
    writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train"))
    train_handlers = [
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=validation_every_n_iters, epoch_level=False),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(summary_writer=writer_train,
                                log_dir=os.path.join(out_model_dir, "train"), tag_name="Loss",
                                output_transform=lambda x: x["loss"],
                                global_epoch_transform=lambda x: trainer.state.iteration),
        CheckpointSaver(save_dir=out_model_dir, save_dict={"net": net, "opt": opt},
                        save_final=True,
                        save_interval=2, epoch_level=True,
                        n_saved=config_info['output']['max_nr_models_saved']),
    ]
    if model_to_load is not None:
        train_handlers.append(CheckpointLoader(load_path=model_to_load, load_dict={"net": net, "opt": opt}))

    # define customized trainer
    class DynUNetTrainer(SupervisedTrainer):
        def _iteration(self, engine, batchdata):
            inputs, targets = self.prepare_batch(batchdata)
            inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)

            def _compute_loss(preds, label):
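                # deep supervision: downsample the label to each auxiliary output's
                # resolution and sum the per-head losses with weights 0.5 ** i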
                labels = [label] + [interpolate(label, pred.shape[2:]) for pred in preds[1:]]
                return sum([0.5 ** i * self.loss_function(p, l) for i, (p, l) in enumerate(zip(preds, labels))])

            self.network.train()
            self.optimizer.zero_grad()
            if self.amp and self.scaler is not None:
                with torch.cuda.amp.autocast():
                    predictions = self.inferer(inputs, self.network)
                    loss = _compute_loss(predictions, targets).mean()
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                predictions = self.inferer(inputs, self.network)
                loss = _compute_loss(predictions, targets).mean()
                loss.backward()
                self.optimizer.step()
            return {"image": inputs, "label": targets, "pred": predictions, "loss": loss.item()}

    trainer = DynUNetTrainer(
        device=current_device,
        max_epochs=config_info['training']['nr_train_epochs'],
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        post_transform=None,
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=False,
    )

    """
    Run training
    """
    print("*** Run training...")
    trainer.run()
    print("Done!")
Example #17
0
def main():
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print_config()

    # Setup directories
    dirs = setup_directories()

    # Setup torch device
    device, using_gpu = create_device("cuda")

    # Load and randomize images

    # HACKATHON image and segmentation data
    hackathon_dir = os.path.join(dirs["data"], 'HACKATHON')
    map_fn = lambda x: (x[0], int(x[1]))
    with open(os.path.join(hackathon_dir, "train.txt"), 'r') as fp:
        train_info_hackathon = [
            map_fn(entry.strip().split(',')) for entry in fp.readlines()
        ]
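    # train.txt holds one comma-separated "<image id>,<label>" pair per line;
    # map_fn casts the label to int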
    image_dir = os.path.join(hackathon_dir, 'images', 'train')
    seg_dir = os.path.join(hackathon_dir, 'segmentations', 'train')
    _train_data_hackathon = get_data_from_info(image_dir,
                                               seg_dir,
                                               train_info_hackathon,
                                               dual_output=False)
    _train_data_hackathon = large_image_splitter(_train_data_hackathon,
                                                 dirs["cache"])
    copy_list = transform_and_copy(_train_data_hackathon, dirs['cache'])
    balance_training_data2(_train_data_hackathon, copy_list, seed=72)

    # PSUF data
    """psuf_dir = os.path.join(dirs["data"], 'psuf')
    with open(os.path.join(psuf_dir, "train.txt"), 'r') as fp:
        train_info = [entry.strip().split(',') for entry in fp.readlines()]
    image_dir = os.path.join(psuf_dir, 'images')
    train_data_psuf = get_data_from_info(image_dir, None, train_info)"""
    # Split data into train, validate and test
    train_split, test_data_hackathon = train_test_split(_train_data_hackathon,
                                                        test_size=0.2,
                                                        shuffle=True,
                                                        random_state=42)
    train_data_hackathon, valid_data_hackathon = train_test_split(
        train_split, test_size=0.2, shuffle=True, random_state=43)
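    # nested 80/20 splits: 20% held out for test, then 20% of the remainder for
    # validation, i.e. roughly 64% train / 16% validation / 20% test overall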

    #balance_training_data(train_data_hackathon, seed=72)
    #balance_training_data(valid_data_hackathon, seed=73)
    #balance_training_data(test_data_hackathon, seed=74)
    # Setup transforms

    # Crop foreground
    crop_foreground = CropForegroundd(keys=["image"],
                                      source_key="image",
                                      margin=(5, 5, 0),
                                      select_fn=lambda x: x != 0)
    # Crop Z
    crop_z = RelativeCropZd(keys=["image"], relative_z_roi=(0.07, 0.12))
    # Window width and level (window center)
    WW, WL = 1500, -600
    ct_window = CTWindowd(keys=["image"], width=WW, level=WL)
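    # WW=1500 / WL=-600 is a standard lung window for chest CT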
    # Random axis flip
    rand_x_flip = RandFlipd(keys=["image"], spatial_axis=0, prob=0.50)
    rand_y_flip = RandFlipd(keys=["image"], spatial_axis=1, prob=0.50)
    rand_z_flip = RandFlipd(keys=["image"], spatial_axis=2, prob=0.50)
    # Rand affine transform
    rand_affine = RandAffined(keys=["image"],
                              prob=0.5,
                              rotate_range=(0, 0, np.pi / 12),
                              shear_range=(0.07, 0.07, 0.0),
                              translate_range=(0, 0, 0),
                              scale_range=(0.07, 0.07, 0.0),
                              padding_mode="zeros")
    # Pad the image so its z extent is at least 30 (-1 leaves an axis unchanged)
    spatial_pad = SpatialPadd(keys=["image"], spatial_size=(-1, -1, 30))
    resize = Resized(keys=["image"],
                     spatial_size=(int(512 * 0.50), int(512 * 0.50), -1),
                     mode="trilinear")
    # Apply Gaussian noise
    rand_gaussian_noise = RandGaussianNoised(keys=["image"],
                                             prob=0.25,
                                             mean=0.0,
                                             std=0.1)

    # Create transforms
    common_transform = Compose([
        LoadImaged(keys=["image"]),
        ct_window,
        CTSegmentation(keys=["image"]),
        AddChanneld(keys=["image"]),
        resize,
        crop_foreground,
        crop_z,
        spatial_pad,
    ])
    hackathon_train_transform = Compose([
        common_transform,
        rand_x_flip,
        rand_y_flip,
        rand_z_flip,
        rand_affine,
        rand_gaussian_noise,
        ToTensord(keys=["image"]),
    ]).flatten()
    hackathon_valid_transform = Compose([
        common_transform,
        #rand_x_flip,
        #rand_y_flip,
        #rand_z_flip,
        #rand_affine,
        ToTensord(keys=["image"]),
    ]).flatten()
    hackathon_test_transform = Compose([
        common_transform,
        ToTensord(keys=["image"]),
    ]).flatten()
    psuf_transforms = Compose([
        LoadImaged(keys=["image"]),
        AddChanneld(keys=["image"]),
        ToTensord(keys=["image"]),
    ])

    # Setup data
    #set_determinism(seed=100)
    train_dataset = PersistentDataset(data=train_data_hackathon[:],
                                      transform=hackathon_train_transform,
                                      cache_dir=dirs["persistent"])
    valid_dataset = PersistentDataset(data=valid_data_hackathon[:],
                                      transform=hackathon_valid_transform,
                                      cache_dir=dirs["persistent"])
    test_dataset = PersistentDataset(data=test_data_hackathon[:],
                                     transform=hackathon_test_transform,
                                     cache_dir=dirs["persistent"])
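    # PersistentDataset caches the output of the deterministic head of each transform
    # chain on disk (dirs["persistent"]); random transforms still run on every fetch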
    train_loader = DataLoader(
        train_dataset,
        batch_size=4,
        #shuffle=True,
        pin_memory=using_gpu,
        num_workers=2,
        sampler=ImbalancedDatasetSampler(
            train_data_hackathon,
            callback_get_label=lambda x, i: x[i]['_label']),
        collate_fn=PadListDataCollate(Method.SYMMETRIC, NumpyPadMode.CONSTANT))
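    # shuffle stays disabled above because DataLoader forbids shuffle=True together
    # with a custom sampler; ImbalancedDatasetSampler provides its own class-weighted
    # random ordering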
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=4,
        shuffle=False,
        pin_memory=using_gpu,
        num_workers=2,
        sampler=ImbalancedDatasetSampler(
            valid_data_hackathon,
            callback_get_label=lambda x, i: x[i]['_label']),
        collate_fn=PadListDataCollate(Method.SYMMETRIC, NumpyPadMode.CONSTANT))
    test_loader = DataLoader(test_dataset,
                             batch_size=4,
                             shuffle=False,
                             pin_memory=using_gpu,
                             num_workers=2,
                             collate_fn=PadListDataCollate(
                                 Method.SYMMETRIC, NumpyPadMode.CONSTANT))

    # Setup network, loss function, optimizer and scheduler
    network = nets.DenseNet121(spatial_dims=3, in_channels=1,
                               out_channels=1).to(device)
    # pos_weight for class imbalance
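    # (BCEWithLogitsLoss expects pos_weight to broadcast against the class dimension;
    # with a single output logit this is usually a one-element tensor such as n / p)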
    _, n, p = calculate_class_imbalance(train_data_hackathon)
    pos_weight = torch.Tensor([n, p]).to(device)
    loss_function = torch.nn.BCEWithLogitsLoss(pos_weight)
    optimizer = torch.optim.Adam(network.parameters(), lr=1e-4, weight_decay=0)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                       gamma=0.95,
                                                       last_epoch=-1)

    # Setup validator and trainer
    valid_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        #Activationsd(keys="pred", softmax=True),
    ])
    validator = Validator(device=device,
                          val_data_loader=valid_loader,
                          network=network,
                          post_transform=valid_post_transforms,
                          amp=using_gpu,
                          non_blocking=using_gpu)

    trainer = Trainer(device=device,
                      out_dir=dirs["out"],
                      out_name="DenseNet121",
                      max_epochs=120,
                      validation_epoch=1,
                      validation_interval=1,
                      train_data_loader=train_loader,
                      network=network,
                      optimizer=optimizer,
                      loss_function=loss_function,
                      lr_scheduler=None,
                      validator=validator,
                      amp=using_gpu,
                      non_blocking=using_gpu)
    """x_max, y_max, z_max, size_max = 0, 0, 0, 0
    for data in valid_loader:
        image = data["image"]
        label = data["label"]
        print()
        print(len(data['image_transforms']))
        #print(data['image_transforms'])
        print(label)
        shape = image.shape
        x_max = max(x_max, shape[-3])
        y_max = max(y_max, shape[-2])
        z_max = max(z_max, shape[-1])
        size = int(image.nelement()*image.element_size()/1024/1024)
        size_max = max(size_max, size)
        print("shape:", shape, "size:", str(size)+"MB")
        #multi_slice_viewer(image[0, 0, :, :, :], str(label))
    print(x_max, y_max, z_max, str(size_max)+"MB")
    exit()"""

    # Run trainer
    train_output = trainer.run()

    # Setup tester
    tester = Tester(device=device,
                    test_data_loader=test_loader,
                    load_dir=train_output,
                    out_dir=dirs["out"],
                    network=network,
                    post_transform=valid_post_transforms,
                    non_blocking=using_gpu,
                    amp=using_gpu)

    # Run tester
    tester.run()