def run_interaction(self, train, compose):
        """Train for one epoch with an ``Interaction`` as the iteration update.

        Five identical samples (image and label are 1x2x2x2 arrays of ones) are
        prepared with the seed/guidance transforms below. A CPU ``SupervisedTrainer``
        then runs one epoch, delegating each iteration to ``Interaction`` with
        ``max_interactions=5``.

        Args:
            train: forwarded to ``Interaction(train=...)``.
            compose: if truthy, the per-interaction transforms are wrapped in a
                ``Compose`` before being given to ``Interaction``; otherwise the
                plain list is passed.
        """
        samples = [
            {
                "image": np.ones((1, 2, 2, 2)).astype(np.float32),
                "label": np.ones((1, 2, 2, 2)),
            }
            for _ in range(5)
        ]
        model = torch.nn.Linear(2, 2)
        learning_rate = 1e-3
        optimizer = torch.optim.SGD(model.parameters(), learning_rate)
        criterion = torch.nn.L1Loss()

        # Per-sample preprocessing: valid slices -> initial seed point ->
        # guidance signal on the image -> tensors.
        pre_transforms = Compose([
            FindAllValidSlicesd(label="label", sids="sids"),
            AddInitialSeedPointd(label="label", guidance="guidance", sids="sids"),
            AddGuidanceSignald(image="image", guidance="guidance"),
            ToTensord(keys=("image", "label")),
        ])
        loader = torch.utils.data.DataLoader(
            Dataset(samples, transform=pre_transforms), batch_size=5)

        # Transforms handed to Interaction for use between inner iterations.
        click_transforms = [
            Activationsd(keys="pred", sigmoid=True),
            ToNumpyd(keys=["image", "label", "pred"]),
            FindDiscrepancyRegionsd(label="label", pred="pred", discrepancy="discrepancy"),
            AddRandomGuidanced(guidance="guidance", discrepancy="discrepancy", probability="probability"),
            AddGuidanceSignald(image="image", guidance="guidance"),
            ToTensord(keys=("image", "label")),
        ]
        if compose:
            click_transforms = Compose(click_transforms)

        interaction = Interaction(transforms=click_transforms, train=train, max_interactions=5)
        # Whichever form was passed, six transforms are expected inside Interaction.
        self.assertEqual(len(interaction.transforms.transforms), 6,
                         "Mismatch in expected transforms")

        trainer = SupervisedTrainer(
            device=torch.device("cpu"),
            max_epochs=1,
            train_data_loader=loader,
            network=model,
            optimizer=optimizer,
            loss_function=criterion,
            iteration_update=interaction,
        )
        # add_one is attached to both inner-iteration events; the best_metric value
        # checked at the end presumably counts how often it fired.
        for event in (IterationEvents.INNER_ITERATION_STARTED,
                      IterationEvents.INNER_ITERATION_COMPLETED):
            trainer.add_event_handler(event, add_one)

        trainer.run()
        self.assertIsNotNone(trainer.state.batch[0].get("guidance"),
                             "guidance is missing")
        self.assertEqual(trainer.state.best_metric, 9)
 def prepare_data(self):
     """Build cached train/validation datasets from the list files in ``data_dir``.

     ``train_imgs.txt`` and ``train_masks.txt`` in ``self.hparams.data_dir`` list one
     path per line, and each path is prefixed with ``data_dir``. Images and masks are
     paired line by line (``zip`` stops at the shorter list). 20% of the pairs are held
     out for validation, and both splits are wrapped in ``monai.data.CacheDataset``.
     Training samples are 4 random patches of ``self.hparams.patch_size`` per volume
     (pos=0.8, neg=0.2). Validation uses a centre crop of the same size.

     Blank lines in the list files are skipped. If kept, they would become the bare
     ``data_dir`` path, which is not an image.

     NOTE(review): if this is a PyTorch Lightning module, ``prepare_data`` only runs
     in one process, so attributes assigned here may be missing in the others.
     ``setup`` is the usual place for this. ``train_test_split`` is also unseeded,
     so the split differs between runs.
     """
     data_dir = self.hparams.data_dir

     def read_paths(list_name):
         # One path per line: strip trailing whitespace, skip empty lines, prefix data_dir.
         with open(data_dir + list_name, 'r') as f:
             return [data_dir + entry for entry in (line.rstrip() for line in f) if entry]

     # Train imgs/masks
     train_imgs = read_paths('train_imgs.txt')
     train_masks = read_paths('train_masks.txt')
     train_dicts = [{'image': image, 'mask': mask} for (image, mask) in zip(train_imgs, train_masks)]
     train_dicts, val_dicts = train_test_split(train_dicts, test_size=0.2)

     # Basic transforms shared by both splits: load, add a channel axis, and map the
     # image's 25th..75th intensity percentile range to [-0.5, 0.5].
     data_keys = ["image", "mask"]
     data_transforms = Compose(
         [
             LoadNiftid(keys=data_keys),
             AddChanneld(keys=data_keys),
             ScaleIntensityRangePercentilesd(
                 keys='image',
                 lower=25,
                 upper=75,
                 b_min=-0.5,
                 b_max=0.5
             )
         ]
     )

     self.train_dataset = monai.data.CacheDataset(
         data=train_dicts,
         transform=Compose(
             [
                 data_transforms,
                 RandCropByPosNegLabeld(
                     keys=data_keys,
                     label_key="mask",
                     spatial_size=self.hparams.patch_size,
                     num_samples=4,
                     image_key="image",
                     pos=0.8,
                     neg=0.2
                 ),
                 ToTensord(keys=data_keys)
             ]
         ),
         cache_rate=1.0
     )

     self.val_dataset = monai.data.CacheDataset(
         data=val_dicts,
         transform=Compose(
             [
                 data_transforms,
                 CenterSpatialCropd(keys=data_keys, roi_size=self.hparams.patch_size),
                 ToTensord(keys=data_keys)
             ]
         ),
         cache_rate=1.0
     )
Beispiel #3
0
    def transformations(self, H, L):
        """Build the training and validation transform pipelines for CT volumes.

        Args:
            H: CT window width (``CTWindowd(width=H)`` for validation).
            L: CT window level (``CTWindowd(level=L)`` for validation).

        Returns:
            ``(train_transforms, valid_transforms)``, both flattened ``Compose``
            objects that share the same ``basic_transforms`` prefix (load,
            segmentation, channel axis, foreground crop, relative Z crop).
            Training uses ``RandCTWindowd`` with width range ``(H - 50, H + 50)``
            and level range ``(L - 25, L + 25)``, plus flip/affine augmentation.
            Validation uses the fixed window ``(H, L)`` and no random transforms.
        """
        basic_transforms = Compose([
            # Load image
            LoadImaged(keys=["image"]),

            # Segmentation
            CTSegmentation(keys=["image"]),
            AddChanneld(keys=["image"]),

            # Crop foreground based on seg image.
            CropForegroundd(keys=["image"],
                            source_key="image",
                            margin=(30, 30, 0)),

            # Crop the image along Z; relative_z_roi = (% from the bottom, % from the top)
            RelativeAsymmetricZCropd(keys=["image"],
                                     relative_z_roi=(0.15, 0.25)),
        ])

        train_transforms = Compose([
            basic_transforms,

            # Normalize to a (randomly perturbed) CT window
            # https://radiopaedia.org/articles/windowing-ct
            RandCTWindowd(keys=["image"],
                          prob=1.0,
                          width=(H - 50, H + 50),
                          level=(L - 25, L + 25)),

            # Possibly useful augmentation
            RandAxisFlipd(keys=["image"], prob=0.1),
            RandAffined(
                keys=["image"],
                prob=0.25,
                rotate_range=(0, 0, np.pi / 16),
                shear_range=(0.05, 0.05, 0.0),
                translate_range=(10, 10, 0),
                scale_range=(0.05, 0.05, 0.0),
                spatial_size=(-1, -1, -1),
                padding_mode="zeros",
            ),
            ToTensord(keys=["image"]),
        ]).flatten()

        # NOTE: No random transforms in the validation data
        valid_transforms = Compose([
            basic_transforms,

            # Normalize to the CT window
            # https://radiopaedia.org/articles/windowing-ct
            CTWindowd(keys=["image"], width=H, level=L),
            ToTensord(keys=["image"]),
        ]).flatten()

        return train_transforms, valid_transforms
Beispiel #4
0
def transformations():
    """Return ``(train_transforms, val_transforms)`` for image/label dictionaries.

    The two pipelines are separate but identical ``Compose`` instances. Each loads
    ``image`` and ``label``, adds a channel axis and converts both to tensors.
    No augmentation is applied.
    """
    def build_pipeline():
        keys = ['image', 'label']
        return Compose([
            LoadImaged(keys=keys),
            AddChanneld(keys=keys),
            ToTensord(keys=keys),
        ])

    return build_pipeline(), build_pipeline()
    def setup(self, stage):
        """Create cached train/validation datasets from the list files in ``data/``.

        ``data/train_imgs.txt`` and ``data/train_masks.txt`` list one path per line.
        The paths are used as written (not prefixed with ``data/``). Images and masks
        are paired line by line, 20% of the pairs go to validation, and each volume
        yields 4 random 256x256x16 crops.

        Args:
            stage: not used (presumably the PyTorch Lightning stage name).

        NOTE(review): the random crop is part of ``data_transforms``, so it also
        applies to the validation set and makes validation non-deterministic.
        """
        data_dir = 'data/'

        def read_paths(list_name):
            # One path per line; drop trailing whitespace and skip blank lines.
            with open(data_dir + list_name, 'r') as f:
                return [entry for entry in (line.rstrip() for line in f) if entry]

        # Train imgs/masks
        train_imgs = read_paths('train_imgs.txt')
        train_masks = read_paths('train_masks.txt')

        train_dicts = [{'image': image, 'mask': mask} for (image, mask) in zip(train_imgs, train_masks)]

        train_dicts, val_dicts = train_test_split(train_dicts, test_size=0.2)

        # Basic transforms
        data_keys = ["image", "mask"]
        data_transforms = Compose(
            [
                LoadNiftid(keys=data_keys),
                AddChanneld(keys=data_keys),
                NormalizeIntensityd(keys="image"),
                # The crop size keyword is ``spatial_size`` (as in the other
                # RandCropByPosNegLabeld calls); ``size`` is not accepted by
                # current MONAI releases.
                RandCropByPosNegLabeld(
                    keys=data_keys, label_key="mask", spatial_size=(256, 256, 16), num_samples=4, image_key="image"
                ),
            ]
        )

        self.train_dataset = monai.data.CacheDataset(
            data=train_dicts,
            transform=Compose(
                [
                    data_transforms,
                    self.augmentations,
                    ToTensord(keys=data_keys)
                ]
            ),
            cache_rate=1.0
        )

        self.val_dataset = monai.data.CacheDataset(
            data=val_dicts,
            transform=Compose(
                [
                    data_transforms,
                    ToTensord(keys=data_keys)
                ]
            ),
            cache_rate=1.0
        )
    def prepare_data(self):
        """Build cached train/validation datasets from the Yale prostate MR list.

        The first 295 rows of ``MR_yale.csv`` supply image (``IMAGE``) and
        segmentation (``SEGM``) paths, and 20% of the pairs are held out for
        validation. Training draws 4 random patches of ``self.hparams.patch_size``
        per volume (pos=0.8, neg=0.2). Validation uses a centre crop of the same size.
        """
        listing = pd.read_csv(
            '/data/shared/prostate/yale_prostate/input_lists/MR_yale.csv')
        # Only the first 295 rows are used.
        first_rows = listing.iloc[:295]
        train_dicts = [
            dict(image=image, mask=mask)
            for image, mask in zip(first_rows['IMAGE'].tolist(),
                                   first_rows['SEGM'].tolist())
        ]
        train_dicts, val_dicts = train_test_split(train_dicts, test_size=0.2)

        data_keys = ["image", "mask"]
        # Shared by both splits: load, add a channel axis, and map the image's
        # 25th..75th intensity percentile range to [-0.5, 0.5].
        data_transforms = Compose([
            LoadNiftid(keys=data_keys),
            AddChanneld(keys=data_keys),
            ScaleIntensityRangePercentilesd(keys='image',
                                            lower=25,
                                            upper=75,
                                            b_min=-0.5,
                                            b_max=0.5),
        ])

        def cached(items, crop):
            # Shared preprocessing, then the split-specific crop, then tensors.
            return monai.data.CacheDataset(
                data=items,
                transform=Compose([data_transforms, crop, ToTensord(keys=data_keys)]),
                cache_rate=1.0)

        self.train_dataset = cached(
            train_dicts,
            RandCropByPosNegLabeld(keys=data_keys,
                                   label_key="mask",
                                   spatial_size=self.hparams.patch_size,
                                   num_samples=4,
                                   image_key="image",
                                   pos=0.8,
                                   neg=0.2))
        self.val_dataset = cached(
            val_dicts,
            CenterSpatialCropd(keys=data_keys,
                               roi_size=self.hparams.patch_size))
Beispiel #7
0
    def test_tranform_dict(self, input):
        """Check that NVTX ``Range`` decoration leaves a dict pipeline's output unchanged.

        ``Range`` wraps the individual transforms inside the pipeline, and the
        pipeline itself is then decorated three different ways. Every result must
        be a tensor equal to the output of the plain pipeline.
        """
        pipeline = Compose([
            Range("random flip dict")(Flipd(keys="image")),
            Range()(ToTensord("image")),
        ])
        # Reference output from the not-yet-decorated pipeline.
        reference = pipeline(input)["image"]

        # All three decorated variants are built before any of them is called.
        decorated = [
            Range()(pipeline),
            Range("Transforms2")(pipeline),
            Range(name="Transforms3", methods="__call__")(pipeline),
        ]
        results = [variant(input)["image"] for variant in decorated]

        self.assertIsInstance(reference, torch.Tensor)
        for result in results:
            self.assertIsInstance(result, torch.Tensor)
        for result in results:
            np.testing.assert_equal(reference.numpy(), result.numpy())
Beispiel #8
0
 def val_pre_transforms(self, context: Context):
     """Return the validation pre-transforms (as a list) for image/label pairs.

     Loads both volumes with the ITK reader and normalizes label names using
     ``self._labels``. It then moves channels first, reorients to RAS, windows and
     rescales image intensities, and resizes to ``self.spatial_size``. Finally it
     adds the simulated-click guidance and keeps only ``image``, ``label``,
     ``guidance`` and ``label_names``.

     Args:
         context: not used here.
     """
     pair = ("image", "label")
     return [
         LoadImaged(keys=pair, reader="ITKReader"),
         NormalizeLabelsInDatasetd(keys="label", label_names=self._labels),
         EnsureChannelFirstd(keys=pair),
         Orientationd(keys=["image", "label"], axcodes="RAS"),
         # Fixed [-175, 250] window mapped to [0, 1] with clipping. This is
         # presumably a CT (HU) window and may not work well for MR images.
         ScaleIntensityRanged(keys="image", a_min=-175, a_max=250,
                              b_min=0.0, b_max=1.0, clip=True),
         # Area interpolation for the image, nearest for the label.
         Resized(keys=pair, spatial_size=self.spatial_size,
                 mode=("area", "nearest")),
         # Click simulation: valid slices, initial seed points, guidance channels.
         FindAllValidSlicesMissingLabelsd(keys="label", sids="sids"),
         AddInitialSeedPointMissingLabelsd(keys="label", guidance="guidance",
                                           sids="sids"),
         AddGuidanceSignalCustomd(keys="image", guidance="guidance",
                                  number_intensity_ch=self.number_intensity_ch),
         ToTensord(keys=pair),
         SelectItemsd(keys=("image", "label", "guidance", "label_names")),
     ]
Beispiel #9
0
    def _get_loader(self, folders):
        """Build a batch-size-1 loader over image/segmentation pairs found in *folders*.

        Each folder is globbed for ``*_im.nii.gz`` images and ``*_seg.nii.gz``
        segmentations. Image ``<stem>_im.nii.gz`` is paired with
        ``<stem>_seg.nii.gz``, and pairs are ordered by stem.

        Both lists are sorted by the shared stem rather than by full basename,
        because the ``_im``/``_seg`` suffixes can sort differently. For example,
        ``a_im`` < ``a_j_im`` but ``a_j_seg`` < ``a_seg``, which would silently
        pair an image with another case's segmentation.

        Raises:
            ValueError: if the image and segmentation stems do not match one-to-one.
        """
        image_suffix, seg_suffix = "_im.nii.gz", "_seg.nii.gz"

        def stem(path, suffix):
            # File name without the image/segmentation suffix, e.g. "case1".
            return os.path.basename(path)[: -len(suffix)]

        images = []
        segs = []
        for folder in folders:
            images += glob(os.path.join(folder, "*" + image_suffix))
            segs += glob(os.path.join(folder, "*" + seg_suffix))
        # sorted() is stable, so equal stems keep their folder order in both lists.
        images = sorted(images, key=lambda p: stem(p, image_suffix))
        segs = sorted(segs, key=lambda p: stem(p, seg_suffix))

        if [stem(p, image_suffix) for p in images] != [stem(p, seg_suffix) for p in segs]:
            raise ValueError(
                f"image/segmentation files in {folders!r} do not pair up by name: "
                f"{len(images)} images vs {len(segs)} segmentations"
            )

        files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

        transforms = Compose([
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ])

        ds = CacheDataset(data=files, transform=transforms)
        loader = DataLoader(ds,
                            batch_size=1,
                            num_workers=4,
                            collate_fn=list_data_collate)

        return loader
Beispiel #10
0
    def test_decollation(self, batch_size=2, num_workers=2):
        """Check that decollated batches match the corresponding dataset items.

        Six copies of one 100x101 synthetic 2D image are used; when nibabel is
        available they are converted with ``make_nifti_image`` and read back from
        disk. The samples go through channel/pad/flip/tensor transforms. Each batch
        is decollated with both ``decollate_batch`` and ``Decollated``, and every
        element is compared with ``dataset[index]``.
        """
        image = create_test_image_2d(100, 101)[0]
        items = [{"image": make_nifti_image(image) if has_nib else image} for _ in range(6)]

        pipeline = Compose([
            AddChanneld("image"),
            SpatialPadd("image", 150),
            RandFlipd("image", prob=1.0, spatial_axis=1),
            ToTensord("image"),
        ])
        if has_nib:
            # With nibabel present, read from disk first.
            pipeline = Compose([LoadImaged("image"), pipeline])

        dataset = CacheDataset(items, pipeline, progress=False)
        loader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=num_workers)

        for batch_index, batch in enumerate(loader):
            offset = batch_index * batch_size
            for decollated in (decollate_batch(batch), Decollated()(batch)):
                for position, sample in enumerate(decollated):
                    self.check_match(dataset[offset + position], sample)
    def test_decollation(self, *transforms):
        """Check decollated batches against dataset items for a custom transform chain.

        The given ``transforms`` run between ``AddChanneld(KEYS)`` and
        ``ToTensord(KEYS)`` over ``self.data``, with nibabel loading prepended when
        available. Batches of 2 are decollated with both ``decollate_batch`` and
        ``Decollated``, and each element is compared with ``dataset[index]``.
        """
        batch_size, num_workers = 2, 2

        pipeline = Compose([AddChanneld(KEYS), Compose(transforms), ToTensord(KEYS)])
        if has_nib:
            # With nibabel present, read from disk first.
            pipeline = Compose([LoadImaged("image"), pipeline])

        dataset = CacheDataset(self.data, pipeline, progress=False)
        loader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=num_workers)

        for batch_index, batch in enumerate(loader):
            offset = batch_index * batch_size
            for decollated in (decollate_batch(batch), Decollated()(batch)):
                for position, sample in enumerate(decollated):
                    self.check_match(dataset[offset + position], sample)
Beispiel #12
0
 def pre_transforms(self):
     """Return the ``Compose`` of (randomized) pre-transforms applied to ``image``.

     Loads with the nibabel reader, adds a channel axis, applies a random affine
     rotation, a random flip and a random rotation, then resizes to
     ``self.spatial_size``. When ``self.deepedit`` is set, ``DiscardAddGuidanced``
     is inserted before the final ``ToTensord``.
     """
     steps = [
         LoadImaged(keys="image", reader="nibabelreader"),
         AddChanneld(keys="image"),
         # Spacing is probably unnecessary because Resized sets the size at the end.
         # Spacingd(keys="image", pixdim=self.spacing),
         RandAffined(
             keys="image",
             prob=1,
             rotate_range=(np.pi / 4,) * 3,
             padding_mode="zeros",
             as_tensor_output=False,
         ),
         RandFlipd(keys="image", prob=0.5, spatial_axis=0),
         # NOTE(review): MONAI's RandRotated ranges are in radians, so (-5, 5) spans
         # more than a full turn per axis. If +/-5 degrees was intended, use
         # np.deg2rad(5). prob is left at RandRotated's default.
         RandRotated(keys="image", range_x=(-5, 5), range_y=(-5, 5), range_z=(-5, 5)),
         Resized(keys="image", spatial_size=self.spatial_size),
     ]
     # Extra step when these transforms serve DeepEdit test-time augmentation.
     deepedit_steps = [DiscardAddGuidanced(keys="image")] if self.deepedit else []
     return Compose(steps + deepedit_steps + [ToTensord(keys="image")])
    def test_collation(self, _, transform, collate_fn, ndim):
        """Check that image and label transform records agree after batching.

        Args:
            _: unused (presumably the parameterized test-case name).
            transform: dictionary transform under test.
            collate_fn: passed to the DataLoader. When falsy, the samples are first
                padded/cropped to 100 and converted to tensors, presumably so
                that default collation can stack them.
            ndim: 3 selects ``self.data_3d``; any other value selects ``self.data_2d``.
        """
        data = self.data_3d if ndim == 3 else self.data_2d
        if collate_fn:
            modified_transform = transform
        else:
            modified_transform = Compose(
                [transform,
                 ResizeWithPadOrCropd(KEYS, 100),
                 ToTensord(KEYS)])

        # num workers = 0 for mac or gpu transforms
        num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2

        dataset = CacheDataset(data,
                               transform=modified_transform,
                               progress=False)
        # num_workers is passed by keyword: only MONAI's DataLoader takes it as the
        # second positional argument; torch's DataLoader takes batch_size there.
        loader = DataLoader(dataset,
                            num_workers=num_workers,
                            batch_size=self.batch_size,
                            collate_fn=collate_fn)

        for item in loader:
            np.testing.assert_array_equal(
                item["image_transforms"][0]["do_transforms"],
                item["label_transforms"][0]["do_transforms"])
def run_inference_test(root_dir, device=torch.device("cuda:0")):
    """Run sliding-window UNet inference over ``root_dir`` and return the mean Dice.

    ``im*.nii.gz`` and ``seg*.nii.gz`` files are paired in sorted order. Model
    weights come from ``root_dir/best_metric_model.pth``, and thresholded
    predictions (sigmoid >= 0.5) are written to ``root_dir/output``.

    Args:
        root_dir: directory holding the images, segmentations and model weights.
        device: device for the model and inputs. The checkpoint is mapped onto it
            when loaded, so GPU-saved weights can be evaluated on a CPU device.

    Returns:
        The Dice value averaged over the non-NaN entries reported by ``DiceMetric``.

    Raises:
        ZeroDivisionError: if no non-NaN Dice value was accumulated (e.g. no files).
    """
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        # resampling with align_corners=True or dtype=float64 will generate
        # slightly different results between PyTorch 1.5 and 1.6
        Spacingd(keys=["img", "seg"],
                 pixdim=[1.2, 0.8, 0.7],
                 mode=["bilinear", "nearest"],
                 dtype=np.float32),
        # Only the image is intensity-scaled: min-max scaling the label map
        # would change any non-binary label values.
        ScaleIntensityd(keys="img"),
        ToTensord(keys=["img", "seg"]),
    ])
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference needs 1 input image per iteration
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    dice_metric = DiceMetric(include_background=True,
                             to_onehot_y=False,
                             sigmoid=True,
                             reduction="mean")

    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    # Map the stored tensors onto ``device`` so a checkpoint saved from a GPU can
    # be loaded when evaluating on CPU.
    model.load_state_dict(torch.load(model_filename, map_location=device))
    model.eval()
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        # resampling with align_corners=True or dtype=float64 will generate
        # slightly different results between PyTorch 1.5 and 1.6
        saver = NiftiSaver(output_dir=os.path.join(root_dir, "output"),
                           dtype=np.float32)
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(
                device), val_data["seg"].to(device)
            # define sliding window size and batch size for windows inference
            sw_batch_size, roi_size = 4, (96, 96, 96)
            val_outputs = sliding_window_inference(val_images, roi_size,
                                                   sw_batch_size, model)
            value = dice_metric(y_pred=val_outputs, y=val_labels)
            # Weight the batch mean by its count of non-NaN entries.
            not_nans = dice_metric.not_nans.item()
            metric_count += not_nans
            metric_sum += value.item() * not_nans
            val_outputs = (val_outputs.sigmoid() >= 0.5).float()
            saver.save_batch(val_outputs, val_data["img_meta_dict"])
        metric = metric_sum / metric_count
    return metric
Beispiel #15
0
 def train_pre_transforms(self, context: Context):
     """Return the training pre-transforms (as a list) for image/label pairs.

     Loads both as uint8, applies ``FilterImaged(min_size=5)`` to the image, and
     moves image channels first while adding a channel axis to the label. It then
     applies torchvision ``ColorJitter`` and a random 90-degree rotation, and
     rescales image intensities from [0, 255] to [-1, 1]. Finally it adds an
     initial seed point and guidance signal (3 intensity channels).

     Args:
         context: not used here.
     """
     pair = ("image", "label")
     loading = [
         LoadImaged(keys=pair, dtype=np.uint8),
         FilterImaged(keys="image", min_size=5),
         AsChannelFirstd(keys="image"),
         AddChanneld(keys="label"),
     ]
     # The torchvision wrapper is bracketed by a tensor round-trip
     # (ToTensord ... ToNumpyd) before the numpy-side rotation.
     augmentation = [
         ToTensord(keys="image"),
         TorchVisiond(keys="image",
                      name="ColorJitter",
                      brightness=64.0 / 255.0,
                      contrast=0.75,
                      saturation=0.25,
                      hue=0.04),
         ToNumpyd(keys="image"),
         RandRotate90d(keys=pair, prob=0.5, spatial_axes=(0, 1)),
     ]
     finishing = [
         ScaleIntensityRangeD(keys="image",
                              a_min=0.0,
                              a_max=255.0,
                              b_min=-1.0,
                              b_max=1.0),
         AddInitialSeedPointExd(label="label", guidance="guidance"),
         AddGuidanceSignald(image="image",
                            guidance="guidance",
                            number_intensity_ch=3),
         EnsureTyped(keys=pair),
     ]
     return loading + augmentation + finishing
Beispiel #16
0
    def __init__(self, tranforms):
        """Instantiate a pipeline from a sequence of transform names.

        Each name in ``tranforms`` is looked up in the table below, and the matching
        transform is built with the fixed arguments shown there. The resulting list
        is stored in ``self.tranform_list`` and passed to the parent constructor.

        Args:
            tranforms: iterable of transform names, in pipeline order.

        Raises:
            ValueError: for a name that is not in the table.
        """
        self.tranform_list = []
        # (name, factory) pairs; a factory is only called when its name is requested.
        registry = (
            ('LoadImaged', lambda: LoadImaged(keys=["image", "label"])),
            ('AsChannelFirstd', lambda: AsChannelFirstd(keys="image")),
            ('ConvertToMultiChannelBasedOnBratsClassesd',
             lambda: ConvertToMultiChannelBasedOnBratsClassesd(keys="label")),
            ('Spacingd',
             lambda: Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest"))),
            ('Orientationd', lambda: Orientationd(keys=["image", "label"], axcodes="RAS")),
            ('CenterSpatialCropd', lambda: CenterSpatialCropd(keys=["image", "label"], roi_size=[128, 128, 64])),
            ('NormalizeIntensityd', lambda: NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True)),
            ('ToTensord', lambda: ToTensord(keys=["image", "label"])),
            ('Activations', lambda: Activations(sigmoid=True)),
            ('AsDiscrete', lambda: AsDiscrete(threshold_values=True)),
        )
        for tranform in tranforms:
            # The first matching entry wins, as in an if/elif chain.
            factory = next((build for name, build in registry if name == tranform), None)
            if factory is None:
                raise ValueError(
                    f"Unsupported tranform: {tranform}. Please add it to support it."
                )
            self.tranform_list.append(factory())

        super().__init__(self.tranform_list)
Beispiel #17
0
def run_inference_test(root_dir, device=torch.device("cuda:0")):
    """Run sliding-window UNet inference over ``root_dir`` and return the mean Dice.

    ``im*.nii.gz`` and ``seg*.nii.gz`` files are paired in sorted order. Model
    weights come from ``root_dir/best_metric_model.pth``, and thresholded
    predictions (sigmoid >= 0.5) are written to ``root_dir/output``.

    Args:
        root_dir: directory holding the images, segmentations and model weights.
        device: device for the model and inputs. The checkpoint is mapped onto it
            when loaded, so GPU-saved weights can be evaluated on a CPU device.

    Returns:
        The sum of all per-item Dice values divided by their count.

    Raises:
        ZeroDivisionError: if no Dice values were computed (e.g. no input files).
    """
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        # Only the image is intensity-scaled: min-max scaling the label map
        # would change any non-binary label values.
        ScaleIntensityd(keys="img"),
        ToTensord(keys=["img", "seg"]),
    ])
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference needs 1 input image per iteration
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            collate_fn=list_data_collate,
                            pin_memory=torch.cuda.is_available())

    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    # Map the stored tensors onto ``device`` so a checkpoint saved from a GPU can
    # be loaded when evaluating on CPU.
    model.load_state_dict(torch.load(model_filename, map_location=device))
    model.eval()
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        saver = NiftiSaver(output_dir=os.path.join(root_dir, "output"),
                           dtype=int)
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(
                device), val_data["seg"].to(device)
            # define sliding window size and batch size for windows inference
            sw_batch_size, roi_size = 4, (96, 96, 96)
            val_outputs = sliding_window_inference(val_images, roi_size,
                                                   sw_batch_size, model)
            value = compute_meandice(y_pred=val_outputs,
                                     y=val_labels,
                                     include_background=True,
                                     to_onehot_y=False,
                                     add_sigmoid=True)
            metric_count += len(value)
            metric_sum += value.sum().item()
            val_outputs = (val_outputs.sigmoid() >= 0.5).float()
            saver.save_batch(
                val_outputs, {
                    "filename_or_obj": val_data["img.filename_or_obj"],
                    "affine": val_data["img.affine"]
                })
        metric = metric_sum / metric_count
    return metric
    def setup(self, stage):
        """Create cached train/validation datasets from the list files in ``data/``.

        ``data/train_imgs.txt`` and ``data/train_masks.txt`` list one path per line.
        The paths are used as written (not prefixed with ``data/``). Images and masks
        are paired line by line, 20% of the pairs go to validation, and each volume
        yields 4 random crops of ``self.hparams.patch_size``.

        Args:
            stage: not used (presumably the PyTorch Lightning stage name).

        NOTE(review): the random crop is part of ``data_transforms``, so it also
        applies to the validation set and makes validation non-deterministic.
        """
        # Local import: the loader is required below and may not be imported at
        # module level.
        from monai.transforms import LoadImaged

        data_dir = "data/"

        def read_paths(list_name):
            # One path per line; drop trailing whitespace and skip blank lines.
            with open(data_dir + list_name, "r") as f:
                return [entry for entry in (line.rstrip() for line in f) if entry]

        # Train imgs/masks
        train_imgs = read_paths("train_imgs.txt")
        train_masks = read_paths("train_masks.txt")

        train_dicts = [{
            "image": image,
            "mask": mask
        } for (image, mask) in zip(train_imgs, train_masks)]

        train_dicts, val_dicts = train_test_split(train_dicts, test_size=0.2)

        # Basic transforms
        data_keys = ["image", "mask"]
        data_transforms = Compose([
            # The dicts hold file paths, so the volumes must be loaded before any
            # array transform. Without this, AddChanneld would receive a path string.
            LoadImaged(keys=data_keys),
            AddChanneld(keys=data_keys),
            NormalizeIntensityd(keys="image"),
            RandCropByPosNegLabeld(
                keys=data_keys,
                label_key="mask",
                spatial_size=self.hparams.patch_size,
                num_samples=4,
                image_key="image",
            ),
        ])

        self.train_dataset = monai.data.CacheDataset(
            data=train_dicts,
            transform=Compose([data_transforms,
                               ToTensord(keys=data_keys)]),
            cache_rate=1.0,
        )

        self.val_dataset = monai.data.CacheDataset(
            data=val_dicts,
            transform=Compose([data_transforms,
                               ToTensord(keys=data_keys)]),
            cache_rate=1.0,
        )
Beispiel #19
0
    def test_decollation_dict(self, *transforms):
        """Check decollation of a cached dict dataset built with *transforms*.

        The given transforms run between ``AddChanneld(KEYS)`` and
        ``ToTensord(KEYS)``. When nibabel is available, the images are loaded from
        disk first. The comparison itself is delegated to ``self.check_decollate``.
        """
        pipeline = Compose([AddChanneld(KEYS), Compose(transforms), ToTensord(KEYS)])
        if has_nib:
            # With nibabel present, read from disk first.
            pipeline = Compose([LoadImaged("image"), pipeline])

        self.check_decollate(dataset=CacheDataset(self.data_dict, pipeline, progress=False))
def run_inference_test(root_dir, device="cuda:0"):
    """Run sliding-window UNet inference over ``root_dir`` and return the mean Dice.

    ``im*.nii.gz`` and ``seg*.nii.gz`` files are paired in sorted order. Model
    weights come from ``root_dir/best_metric_model.pth``. Post-processed
    predictions (sigmoid, threshold 0.5) are saved with ``SaveImage`` under
    ``root_dir/output`` using the ``seg`` postfix.

    Args:
        root_dir: directory holding the images, segmentations and model weights.
        device: device for the model and inputs. The checkpoint is mapped onto it
            when loaded, so GPU-saved weights can be evaluated on a CPU device.

    Returns:
        ``dice_metric.aggregate().item()``, the mean Dice over all batches.
    """
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            EnsureChannelFirstd(keys=["img", "seg"]),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference needs 1 input image per iteration
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    # Map the stored tensors onto ``device`` so a checkpoint saved from a GPU can
    # be loaded when evaluating on CPU.
    model.load_state_dict(torch.load(model_filename, map_location=device))
    with eval_mode(model):
        # resampling with align_corners=True or dtype=float64 will generate
        # slightly different results between PyTorch 1.5 and 1.6
        saver = SaveImage(
            output_dir=os.path.join(root_dir, "output"),
            dtype=np.float32,
            output_ext=".nii.gz",
            output_postfix="seg",
            mode="bilinear",
        )
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            # define sliding window size and batch size for windows inference
            sw_batch_size, roi_size = 4, (96, 96, 96)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            # decollate prediction into a list
            val_outputs = [val_post_tran(i) for i in decollate_batch(val_outputs)]
            val_meta = decollate_batch(val_data[PostFix.meta("img")])
            # compute metrics
            dice_metric(y_pred=val_outputs, y=val_labels)
            for img, meta in zip(val_outputs, val_meta):  # save a decollated batch of files
                saver(img, meta)

    return dice_metric.aggregate().item()
def main(tempdir):
    """Evaluate a pretrained 2D UNet on synthetic PNG data and save predictions.

    Five synthetic image/segmentation pairs are written to ``tempdir``.
    Sliding-window inference then runs with the weights from
    ``best_metric_model_segmentation2d_dict.pth`` (resolved relative to the
    current working directory). The function prints the averaged Dice metric
    and saves the thresholded predictions to ``./output`` through ``PNGSaver``.

    Args:
        tempdir: directory in which the synthetic PNG files are generated.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        Image.fromarray(im.astype("uint8")).save(os.path.join(tempdir, f"img{i:d}.png"))
        Image.fromarray(seg.astype("uint8")).save(os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AddChanneld(keys=["img", "seg"]),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference needs to input 1 image in every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    # map_location lets a checkpoint saved from a GPU be restored on a CPU-only
    # host, matching the CPU fallback chosen for `device` above.
    model.load_state_dict(torch.load("best_metric_model_segmentation2d_dict.pth", map_location=device))

    # sliding window size and number of windows per forward pass; these do not
    # change between iterations, so they are defined once outside the loop
    roi_size = (96, 96)
    sw_batch_size = 4

    model.eval()
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        saver = PNGSaver(output_dir="./output")
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            value = dice_metric(y_pred=val_outputs, y=val_labels)
            metric_count += len(value)
            metric_sum += value.item() * len(value)
            # threshold the sigmoid of the network output at 0.5 before saving
            val_outputs = val_outputs.sigmoid() >= 0.5
            saver.save_batch(val_outputs)
        metric = metric_sum / metric_count
        print("evaluation metric:", metric)
    def test_values(self):
        """Check the MedNIST "test" section, downloading it first if necessary.

        The test verifies the dataset length, item keys and image shape, both
        with and without a transform. It also checks the class count and a
        seeded split. Finally, it confirms that construction with
        ``download=False`` raises ``RuntimeError`` once the extracted
        ``MedNIST`` directory has been deleted.
        """
        testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                   "testing_data")
        transform = Compose([
            LoadImaged(keys="image"),
            AddChanneld(keys="image"),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ])

        def _test_dataset(dataset):
            self.assertEqual(
                len(dataset),
                int(MEDNIST_FULL_DATASET_LENGTH * dataset.test_frac))
            # fetch the first item once instead of re-indexing for every check
            item = dataset[0]
            self.assertIn("image", item)
            self.assertIn("label", item)
            self.assertIn(PostFix.meta("image"), item)
            self.assertTupleEqual(item["image"].shape, (1, 64, 64))

        with skip_if_downloading_fails():
            data = MedNISTDataset(root_dir=testing_dir,
                                  transform=transform,
                                  section="test",
                                  download=True,
                                  copy_cache=False)

        _test_dataset(data)

        # testing from the already downloaded copy (download=False),
        # passing root_dir as a Path
        data = MedNISTDataset(root_dir=Path(testing_dir),
                              transform=transform,
                              section="test",
                              download=False)
        self.assertEqual(data.get_num_classes(), 6)
        _test_dataset(data)
        # without a transform there is no added channel dimension
        data = MedNISTDataset(root_dir=testing_dir,
                              section="test",
                              download=False)
        self.assertTupleEqual(data[0]["image"].shape, (64, 64))
        # test same dataset length with different random seed
        data = MedNISTDataset(root_dir=testing_dir,
                              transform=transform,
                              section="test",
                              download=False,
                              seed=42)
        _test_dataset(data)
        self.assertEqual(data[0]["class_name"], "AbdomenCT")
        self.assertEqual(data[0]["label"].cpu().item(), 0)
        shutil.rmtree(os.path.join(testing_dir, "MedNIST"))
        # with the data removed and download=False, construction must fail;
        # assertRaises makes the test fail if no exception is raised at all
        with self.assertRaises(RuntimeError) as ctx:
            MedNISTDataset(root_dir=testing_dir,
                           transform=transform,
                           section="test",
                           download=False)
        print(str(ctx.exception))
        self.assertTrue(str(ctx.exception).startswith("Cannot find dataset directory"))
Beispiel #23
0
def main():
    """Evaluate a pretrained DenseNet121 gender classifier on ten IXI T1 volumes.

    The weights are loaded from ``best_metric_model.pth`` in the current
    working directory, and a class is predicted for each volume. The function
    prints the fraction of predictions that match the hard-coded labels and
    writes the predictions to ``./output`` with ``CSVSaver``.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    image_dir = os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1"])
    image_names = [
        "IXI607-Guys-1097-T1.nii.gz",
        "IXI175-HH-1570-T1.nii.gz",
        "IXI385-HH-2078-T1.nii.gz",
        "IXI344-Guys-0905-T1.nii.gz",
        "IXI409-Guys-0960-T1.nii.gz",
        "IXI584-Guys-1129-T1.nii.gz",
        "IXI253-HH-1694-T1.nii.gz",
        "IXI092-HH-1436-T1.nii.gz",
        "IXI574-IOP-1156-T1.nii.gz",
        "IXI585-Guys-1130-T1.nii.gz",
    ]
    # same paths as joining every component with os.sep in one call
    images = [os.sep.join([image_dir, name]) for name in image_names]

    # 2 binary labels for gender classification: man and woman
    labels = np.array([0, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=np.int64)
    val_files = [{"img": img, "label": label} for img, label in zip(images, labels)]

    # Define transforms for image
    val_transforms = Compose(
        [
            LoadNiftid(keys=["img"]),
            AddChanneld(keys=["img"]),
            ScaleIntensityd(keys=["img"]),
            Resized(keys=["img"], spatial_size=(96, 96, 96)),
            ToTensord(keys=["img"]),
        ]
    )

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())

    # Create DenseNet121; fall back to the CPU when CUDA is unavailable
    # (pin_memory above is already conditioned on CUDA availability)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.densenet.densenet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)

    # map_location allows GPU-saved weights to be restored on a CPU-only host
    model.load_state_dict(torch.load("best_metric_model.pth", map_location=device))
    model.eval()
    with torch.no_grad():
        num_correct = 0.0
        metric_count = 0
        saver = CSVSaver(output_dir="./output")
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["label"].to(device)
            val_outputs = model(val_images).argmax(dim=1)
            value = torch.eq(val_outputs, val_labels)
            metric_count += len(value)
            num_correct += value.sum().item()
            saver.save_batch(val_outputs, val_data["img_meta_dict"])
        metric = num_correct / metric_count
        print("evaluation metric:", metric)
        saver.finalize()
Beispiel #24
0
    def test_values(self):
        """Check the Task04_Hippocampus "validation" section of DecathlonDataset.

        Data is downloaded into ``testing_data`` when missing, and the test is
        skipped if that download fails. It then verifies length, keys and
        shapes with and without a transform. Finally, it confirms that loading
        with ``download=False`` raises ``RuntimeError`` after the task
        directory has been deleted.
        """
        testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                   "testing_data")
        transform = Compose([
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ])

        def _test_dataset(dataset):
            self.assertEqual(len(dataset), 52)
            # fetch the first item once instead of re-indexing for every check
            item = dataset[0]
            self.assertIn("image", item)
            self.assertIn("label", item)
            self.assertIn("image_meta_dict", item)
            self.assertTupleEqual(item["image"].shape, (1, 33, 47, 34))

        try:  # will start downloading if testing_dir doesn't have the Decathlon files
            data = DecathlonDataset(
                root_dir=testing_dir,
                task="Task04_Hippocampus",
                transform=transform,
                section="validation",
                download=True,
            )
        except (ContentTooShortError, HTTPError, RuntimeError) as e:
            print(str(e))
            if isinstance(e, RuntimeError):
                # FIXME: skip MD5 check as current downloading method may fail
                self.assertTrue(str(e).startswith("MD5 check"))
            # report as skipped (not passed) due to network connection errors
            self.skipTest(f"dataset download failed: {e}")

        _test_dataset(data)
        data = DecathlonDataset(root_dir=testing_dir,
                                task="Task04_Hippocampus",
                                transform=transform,
                                section="validation",
                                download=False)
        _test_dataset(data)
        # without a transform there is no added channel dimension
        data = DecathlonDataset(root_dir=testing_dir,
                                task="Task04_Hippocampus",
                                section="validation",
                                download=False)
        self.assertTupleEqual(data[0]["image"].shape, (33, 47, 34))
        shutil.rmtree(os.path.join(testing_dir, "Task04_Hippocampus"))
        # with the data removed and download=False, construction must fail;
        # assertRaises makes the test fail if no exception is raised at all
        with self.assertRaises(RuntimeError) as ctx:
            DecathlonDataset(
                root_dir=testing_dir,
                task="Task04_Hippocampus",
                transform=transform,
                section="validation",
                download=False,
            )
        print(str(ctx.exception))
        self.assertTrue(str(ctx.exception).startswith("Cannot find dataset directory"))
Beispiel #25
0
def get_image_transforms():
    """Return the preprocessing chain applied to the ``'img'`` entry of a sample dict.

    The chain has four steps, in order:

    1. Load the image with an ITK-based reader.
    2. Apply ``EnsureChannelFirstd``.
    3. Apply ``ScaleIntensityd``.
    4. Convert the result with ``ToTensord``.

    Returns:
        The steps chained together in a ``Compose``.
    """
    image_keys = ['img']
    steps = [
        LoadImaged(keys=image_keys, reader=monai.data.ITKReader()),
        EnsureChannelFirstd(keys=image_keys),
        ScaleIntensityd(keys=image_keys),
        ToTensord(keys=image_keys),
    ]
    return Compose(steps)
Beispiel #26
0
    def get_transforms(self):
        """Build the train, validation and test preprocessing pipelines.

        Every pipeline starts with the same four steps:

        - load the NIfTI ``image``/``label`` pair;
        - add a channel dimension;
        - reorient to RAS;
        - normalize the image intensity.

        Every pipeline ends by converting both entries to tensors. Training
        and validation also pad to ``self.pad_crop_shape`` and take a random
        crop of that size. Training alone adds a random flip along spatial
        axis 0 with probability 0.5.

        Returns:
            ``(train_transforms, val_transforms, test_transforms)``, each a ``Compose``.
        """
        self.logger.info("Getting transforms...")
        # Setup transforms of data sets
        pair = ["image", "label"]

        def _common_head():
            # fresh transform instances per call, so pipelines share no objects
            return [
                LoadNiftid(keys=pair),
                AddChanneld(keys=pair),
                Orientationd(keys=pair, axcodes="RAS"),
                NormalizeIntensityd(keys=["image"]),
            ]

        def _pad():
            return SpatialPadd(keys=pair, spatial_size=self.pad_crop_shape)

        def _crop():
            # NOTE(review): the validation pipeline also uses this random crop,
            # so validation inputs differ between passes -- confirm intended.
            return RandSpatialCropd(keys=pair,
                                    roi_size=self.pad_crop_shape,
                                    random_center=True,
                                    random_size=False)

        def _to_tensor():
            return ToTensord(keys=pair)

        train_transforms = Compose(
            _common_head()
            + [
                _pad(),
                RandFlipd(keys=pair, prob=0.5, spatial_axis=0),
                _crop(),
                _to_tensor(),
            ]
        )
        val_transforms = Compose(_common_head() + [_pad(), _crop(), _to_tensor()])
        test_transforms = Compose(_common_head() + [_to_tensor()])

        return train_transforms, val_transforms, test_transforms
Beispiel #27
0
    def test_values(self):
        """Check DecathlonDataset (Task04_Hippocampus, validation) in a temporary directory.

        The data is downloaded into a fresh temporary directory, and the test
        is skipped if the download fails because of network or permission
        problems. It then verifies length, keys and shapes with and without a
        transform. Finally, it confirms that loading with ``download=False``
        raises ``RuntimeError`` once the task directory has been removed.
        """
        tempdir = tempfile.mkdtemp()
        # remove the temporary directory even when an assertion fails midway
        self.addCleanup(shutil.rmtree, tempdir, ignore_errors=True)
        transform = Compose([
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ])

        def _test_dataset(dataset):
            self.assertEqual(len(dataset), 52)
            # fetch the first item once instead of re-indexing for every check
            item = dataset[0]
            self.assertIn("image", item)
            self.assertIn("label", item)
            self.assertIn("image_meta_dict", item)
            self.assertTupleEqual(item["image"].shape, (1, 33, 47, 34))

        try:
            data = DecathlonDataset(root_dir=tempdir,
                                    task="Task04_Hippocampus",
                                    transform=transform,
                                    section="validation",
                                    download=True)
        except RuntimeError as e:
            if str(e).startswith(
                    "download failed due to network issue or permission denied."
            ):
                self.skipTest(f"dataset download failed: {e}")
            # any other RuntimeError is a real failure; previously it was
            # swallowed and `data` was left unbound
            raise

        _test_dataset(data)
        data = DecathlonDataset(root_dir=tempdir,
                                task="Task04_Hippocampus",
                                transform=transform,
                                section="validation",
                                download=False)
        _test_dataset(data)
        # without a transform there is no added channel dimension
        data = DecathlonDataset(root_dir=tempdir,
                                task="Task04_Hippocampus",
                                section="validation",
                                download=False)
        self.assertTupleEqual(data[0]["image"].shape, (33, 47, 34))
        shutil.rmtree(os.path.join(tempdir, "Task04_Hippocampus"))
        # with the data removed and download=False, construction must fail;
        # assertRaises makes the test fail if no exception is raised at all
        with self.assertRaises(RuntimeError) as ctx:
            DecathlonDataset(root_dir=tempdir,
                             task="Task04_Hippocampus",
                             transform=transform,
                             section="validation",
                             download=False)
        print(str(ctx.exception))
        self.assertTrue(
            str(ctx.exception).startswith("can not find dataset directory"))
Beispiel #28
0
    def test_values(self):
        """Check 5-fold CrossValidation over DecathlonDataset (Task04_Hippocampus).

        The test is skipped if the initial download fails. It then checks
        the fold sizes and first-item shapes for several fold selections,
        including one that uses a load-only transform, which adds no channel
        dimension.
        """
        testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                   "testing_data")
        train_transform = Compose([
            LoadImaged(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ])
        val_transform = LoadImaged(keys=["image", "label"])

        def _test_dataset(dataset):
            self.assertEqual(len(dataset), 52)
            # fetch the first item once instead of re-indexing for every check
            item = dataset[0]
            self.assertIn("image", item)
            self.assertIn("label", item)
            self.assertIn("image_meta_dict", item)
            self.assertTupleEqual(item["image"].shape, (1, 34, 49, 41))

        cvdataset = CrossValidation(
            dataset_cls=DecathlonDataset,
            nfolds=5,
            seed=12345,
            root_dir=testing_dir,
            task="Task04_Hippocampus",
            section="validation",
            transform=train_transform,
            download=True,
        )

        try:  # will start downloading if testing_dir doesn't have the Decathlon files
            data = cvdataset.get_dataset(folds=0)
        except (ContentTooShortError, HTTPError, RuntimeError) as e:
            print(str(e))
            if isinstance(e, RuntimeError):
                # FIXME: skip MD5 check as current downloading method may fail
                self.assertTrue(str(e).startswith("md5 check"))
            # report as skipped (not passed) due to network connection errors
            self.skipTest(f"dataset download failed: {e}")

        _test_dataset(data)

        # test training data for fold [1, 2, 3, 4] of 5 splits
        data = cvdataset.get_dataset(folds=[1, 2, 3, 4])
        self.assertTupleEqual(data[0]["image"].shape, (1, 35, 52, 33))
        self.assertEqual(len(data), 208)
        # test train / validation for fold 4 of 5 splits
        data = cvdataset.get_dataset(folds=[4],
                                     transform=val_transform,
                                     download=False)
        # val_transform doesn't add the channel dim to shape
        self.assertTupleEqual(data[0]["image"].shape, (38, 53, 30))
        self.assertEqual(len(data), 52)
        data = cvdataset.get_dataset(folds=[0, 1, 2, 3])
        self.assertTupleEqual(data[0]["image"].shape, (1, 34, 49, 41))
        self.assertEqual(len(data), 208)
 def test_deep_copy(self):
     """Check that every sample from RandSpatialCropSamplesd carries a full trace.

     A single 1x10x11x12 array passes through ``ToTensord`` followed by the
     sampler. The test asserts that three samples come back. It also asserts
     that each sample's ``img_transforms`` entry has one record per transform
     in the pipeline.
     """
     n_samples = 3
     crop_sampler = RandSpatialCropSamplesd(
         keys=["img"], roi_size=(3, 3, 3), num_samples=n_samples, random_center=True, random_size=False
     )
     pipeline = Compose([ToTensord(keys="img"), crop_sampler])
     outputs = pipeline({"img": np.ones((1, 10, 11, 12))})
     self.assertEqual(len(outputs), n_samples)
     expected_trace_len = len(pipeline)
     for out in outputs:
         self.assertEqual(len(out["img_transforms"]), expected_trace_len)
Beispiel #30
0
 def train_post_transforms(self, context: Context):
     """Return the post-processing transforms for training outputs.

     The transforms run in this order:

     1. Convert ``pred`` and ``label`` to tensors.
     2. Apply a softmax to ``pred``.
     3. Discretize both entries with one-hot encoding over 2 classes; argmax
        is taken for ``pred`` only.

     ``context`` is accepted for interface compatibility and is not used.
     """
     both = ("pred", "label")
     to_tensor = ToTensord(keys=both)
     softmax = Activationsd(keys="pred", softmax=True)
     discretize = AsDiscreted(
         keys=both,
         argmax=(True, False),
         to_onehot=True,
         n_classes=2,
     )
     return [to_tensor, softmax, discretize]