Example #1
    def test_saved_content(self, test_data, output_ext, resample, save_batch):
        with tempfile.TemporaryDirectory() as tempdir:
            trans = SaveImaged(
                keys="img",
                output_dir=tempdir,
                output_ext=output_ext,
                resample=resample,
                save_batch=save_batch,
            )
            trans(test_data)

            if save_batch:
                for i in range(8):
                    filepath = os.path.join(
                        "testfile" + str(i),
                        "testfile" + str(i) + "_trans" + output_ext)
                    self.assertTrue(
                        os.path.exists(os.path.join(tempdir, filepath)))
            else:
                patch_index = test_data["img_meta_dict"].get(
                    "patch_index", None)
                patch_index = f"_{patch_index}" if patch_index is not None else ""
                filepath = os.path.join(
                    "testfile0",
                    "testfile0" + "_trans" + patch_index + output_ext)
                self.assertTrue(os.path.exists(os.path.join(tempdir,
                                                            filepath)))
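A minimal sketch of the kind of `test_data` these parameterized cases pass in (an assumption; the actual test fixtures are not shown here), using the older dict-based metadata convention where an "img" array is paired with an "img_meta_dict":

import numpy as np

test_data = {
    "img": np.random.rand(8, 1, 2, 3, 4),  # a batch of 8 images when save_batch=True
    "img_meta_dict": {
        # SaveImaged derives the output filenames from "filename_or_obj"
        "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)],
    },
}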
Example #2
 def test_correct(self):
     with tempfile.TemporaryDirectory() as temp_dir:
         transforms = Compose([
             LoadImaged(("im1", "im2")),
             EnsureChannelFirstd(("im1", "im2")),
             CopyItemsd(("im2", "im2_meta_dict"),
                        names=("im3", "im3_meta_dict")),
             ResampleToMatchd("im3", "im1_meta_dict"),
             Lambda(update_fname),
             SaveImaged("im3",
                        output_dir=temp_dir,
                        output_postfix="",
                        separate_folder=False),
         ])
         data = transforms({"im1": self.fnames[0], "im2": self.fnames[1]})
         # check that output sizes match
         assert_allclose(data["im1"].shape, data["im3"].shape)
         # and that the meta data has been updated accordingly
         assert_allclose(data["im3"].shape[1:],
                         data["im3_meta_dict"]["spatial_shape"],
                         type_test=False)
         assert_allclose(data["im3_meta_dict"]["affine"],
                         data["im1_meta_dict"]["affine"])
         # check we're different from the original
         self.assertTrue(
             any(i != j
                 for i, j in zip(data["im3"].shape, data["im2"].shape)))
         self.assertTrue(
             any(i != j
                 for i, j in zip(data["im3_meta_dict"]["affine"].flatten(
                 ), data["im2_meta_dict"]["affine"].flatten())))
         # test the inverse
         data = Invertd("im3", transforms, "im3")(data)
         assert_allclose(data["im2"].shape, data["im3"].shape)
Example #3
 def test_correct(self):
     transforms = Compose([
         LoadImaged(("im1", "im2")),
         EnsureChannelFirstd(("im1", "im2")),
         CopyItemsd(("im2"), names=("im3")),
         ResampleToMatchd("im3", "im1"),
         Lambda(update_fname),
         SaveImaged("im3",
                    output_dir=self.tmpdir,
                    output_postfix="",
                    separate_folder=False,
                    resample=False),
     ])
     data = transforms({"im1": self.fnames[0], "im2": self.fnames[1]})
     # check that output sizes match
     assert_allclose(data["im1"].shape, data["im3"].shape)
     # and that the meta data has been updated accordingly
     assert_allclose(data["im3"].affine, data["im1"].affine)
     # check we're different from the original
     self.assertTrue(
         any(i != j for i, j in zip(data["im3"].shape, data["im2"].shape)))
     self.assertTrue(
         any(i != j for i, j in zip(data["im3"].affine.flatten(),
                                    data["im2"].affine.flatten())))
     # test the inverse
     data = Invertd("im3", transforms)(data)
     assert_allclose(data["im2"].shape, data["im3"].shape)
Example #4
    def test_saved_content(self, test_data, output_ext, resample):
        with tempfile.TemporaryDirectory() as tempdir:
            trans = SaveImaged(
                keys=["img", "pred"],
                output_dir=tempdir,
                output_ext=output_ext,
                resample=resample,
                allow_missing_keys=True,
            )
            trans(test_data)

            patch_index = test_data["img"].meta.get("patch_index", None)
            patch_index = f"_{patch_index}" if patch_index is not None else ""
            filepath = os.path.join(
                "testfile0", "testfile0" + "_trans" + patch_index + output_ext)
            self.assertTrue(os.path.exists(os.path.join(tempdir, filepath)))
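This variant reads metadata from `test_data["img"].meta`, i.e. the MetaTensor API rather than a separate `*_meta_dict` entry. A hypothetical input (shapes and names assumed, not taken from the source) would look like:

import torch
from monai.data import MetaTensor

img = MetaTensor(
    torch.randint(0, 255, (1, 2, 3, 4)),
    meta={"filename_or_obj": "testfile0.nii.gz"},
)
test_data = {"img": img}  # "pred" may be absent; allow_missing_keys=True tolerates that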
Example #5
    def test_saved_content(self, test_data, output_ext, resample, save_batch):
        with tempfile.TemporaryDirectory() as tempdir:
            trans = SaveImaged(
                keys="img",
                output_dir=tempdir,
                output_ext=output_ext,
                resample=resample,
                save_batch=save_batch,
            )
            trans(test_data)

            if save_batch:
                for i in range(8):
                    filepath = os.path.join(
                        "testfile" + str(i),
                        "testfile" + str(i) + "_trans" + output_ext)
                    self.assertTrue(
                        os.path.exists(os.path.join(tempdir, filepath)))
            else:
                filepath = os.path.join("testfile0",
                                        "testfile0" + "_trans" + output_ext)
                self.assertTrue(os.path.exists(os.path.join(tempdir,
                                                            filepath)))
Example #6
    def test_correct(self, reader, writer):
        loader = Compose([
            LoadImaged(("im1", "im2"), reader=reader),
            EnsureChannelFirstd(("im1", "im2"))
        ])
        data = loader({"im1": self.fnames[0], "im2": self.fnames[1]})

        with self.assertRaises(ValueError):
            ResampleToMatch(mode=None)(img=data["im2"], img_dst=data["im1"])
        im_mod = ResampleToMatch()(data["im2"], data["im1"])
        saver = SaveImaged("im3",
                           output_dir=self.tmpdir,
                           output_postfix="",
                           separate_folder=False,
                           writer=writer,
                           resample=False)
        im_mod.meta["filename_or_obj"] = get_rand_fname()
        saver({"im3": im_mod})

        saved = nib.load(
            os.path.join(self.tmpdir, im_mod.meta["filename_or_obj"]))
        assert_allclose(data["im1"].shape[1:], saved.shape)
        assert_allclose(saved.header["dim"][:4], np.array([3, 384, 384, 19]))
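`get_rand_fname` comes from the surrounding test module; a plausible stand-in (an assumption, not the original helper) that returns a random NIfTI filename:

import random
import string

def get_rand_fname(length=10, suffix=".nii.gz"):
    # random basename plus a NIfTI suffix, e.g. "qWErtYuIop.nii.gz"
    return "".join(random.choices(string.ascii_letters, k=length)) + suffix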
Example #7
def run_inference_test(root_dir,
                       model_file,
                       device="cuda:0",
                       amp=False,
                       num_workers=4):
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{
        "image": img,
        "label": seg
    } for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys=["image", "label"]),
        ToTensord(keys=["image", "label"]),
    ])

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds,
                                       batch_size=1,
                                       num_workers=num_workers)

    # create UNet and move it to the target device
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    val_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        # test the case where `pred` is in `engine.state.output` while `image_meta_dict` is in `engine.state.batch`
        SaveImaged(
            keys="pred",
            meta_keys="image_meta_dict",
            output_dir=root_dir,
            output_postfix="seg_transform",
            save_batch=True,
        ),
    ])
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointLoader(load_path=f"{model_file}", load_dict={"net": net}),
        SegmentationSaver(
            output_dir=root_dir,
            output_postfix="seg_handler",
            batch_transform=lambda batch: batch["image_meta_dict"],
            output_transform=lambda output: output["pred"],
        ),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96),
                                     sw_batch_size=4,
                                     overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice":
            MeanDice(include_background=True,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={
            "val_acc":
            Accuracy(output_transform=lambda x: (x["pred"], x["label"]))
        },
        val_handlers=val_handlers,
        amp=bool(amp),
    )
    evaluator.run()

    return evaluator.state.best_metric
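A hypothetical call site (the paths and checkpoint name are assumptions; the checkpoint follows the `net_key_metric*` naming used elsewhere in this collection):

best_metric = run_inference_test(
    root_dir="./testdata",
    model_file="./runs/net_key_metric=0.9250.pt",
    device="cuda:0",
    amp=False,
)
print(f"best mean Dice: {best_metric:.4f}")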
Example #8
def main(tempdir):
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # generate 5 random image/mask pairs in the given temporary directory
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128,
                                       128,
                                       128,
                                       num_seg_classes=1,
                                       channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    val_files = [{
        "image": img,
        "label": seg
    } for img, seg in zip(images, segs)]

    # model file path
    model_file = glob("./runs/net_key_metric*")[0]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys="image"),
        EnsureTyped(keys=["image", "label"]),
    ])

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # create UNet and move it to the target device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    val_post_transforms = Compose([
        EnsureTyped(keys="pred"),
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold=0.5),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        SaveImaged(keys="pred",
                   meta_keys="image_meta_dict",
                   output_dir="./runs/")
    ])
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointLoader(load_path=model_file, load_dict={"net": net}),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96),
                                     sw_batch_size=4,
                                     overlap=0.5),
        postprocessing=val_post_transforms,
        key_val_metric={
            "val_mean_dice":
            MeanDice(include_background=True,
                     output_transform=from_engine(["pred", "label"]))
        },
        additional_metrics={
            "val_acc":
            Accuracy(output_transform=from_engine(["pred", "label"]))
        },
        val_handlers=val_handlers,
        # AMP evaluation is not enabled if PyTorch is older than 1.6 (no native AMP support)
        amp=monai.utils.get_torch_version_tuple() >= (1, 6),
    )
    evaluator.run()
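A typical entry point for this script (a sketch; the original `if __name__` block is not shown in this excerpt):

import tempfile

if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tempdir:
        main(tempdir)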
Example #9
def main(tempdir):
    print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, _ = create_test_image_3d(128,
                                     128,
                                     128,
                                     num_seg_classes=1,
                                     channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    files = [{"img": img} for img in images]

    # define pre transforms
    pre_transforms = Compose([
        LoadImaged(keys="img"),
        EnsureChannelFirstd(keys="img"),
        Orientationd(keys="img", axcodes="RAS"),
        Resized(keys="img",
                spatial_size=(96, 96, 96),
                mode="trilinear",
                align_corners=True),
        ScaleIntensityd(keys="img"),
        EnsureTyped(keys="img"),
    ])
    # define dataset and dataloader
    dataset = Dataset(data=files, transform=pre_transforms)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=4)
    # define post transforms
    post_transforms = Compose([
        EnsureTyped(keys="pred"),
        Activationsd(keys="pred", sigmoid=True),
        Invertd(
            # invert the `pred` data field; multiple fields are also supported
            keys="pred",
            transform=pre_transforms,
            # get the previously applied pre_transforms information from the `img`
            # data field, then invert `pred` based on it; the same info can be
            # reused for multiple fields, and different orig_keys per field are
            # also supported
            orig_keys="img",
            # key field in which to save the inverted metadata; every item maps to `keys`
            meta_keys="pred_meta_dict",
            # get the metadata from the `img_meta_dict` field when inverting; for
            # example, the `affine` may be needed to invert a `Spacingd` transform,
            # and multiple fields can share the same metadata to invert
            orig_meta_keys="img_meta_dict",
            # if `meta_keys=None`, use "{keys}_{meta_key_postfix}" as the meta key;
            # if `orig_meta_keys=None`, use "{orig_keys}_{meta_key_postfix}";
            # otherwise this arg is not needed during inverting
            meta_key_postfix="meta_dict",
            # don't switch the interpolation mode to "nearest" when inverting, to
            # ensure a smooth output before the `AsDiscreted` transform runs
            nearest_interp=False,
            # convert to a PyTorch Tensor after inverting
            to_tensor=True,
        ),
        AsDiscreted(keys="pred", threshold=0.5),
        SaveImaged(keys="pred",
                   meta_keys="pred_meta_dict",
                   output_dir="./out",
                   output_postfix="seg",
                   resample=False),
    ])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    net.load_state_dict(
        torch.load("best_metric_model_segmentation3d_dict.pth"))

    net.eval()
    with torch.no_grad():
        for d in dataloader:
            images = d["img"].to(device)
            # define sliding window size and batch size for sliding-window inference
            d["pred"] = sliding_window_inference(inputs=images,
                                                 roi_size=(96, 96, 96),
                                                 sw_batch_size=4,
                                                 predictor=net)
            # decollate the batch data into a list of dictionaries, then execute postprocessing transforms
            d = [post_transforms(i) for i in decollate_batch(d)]
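After the loop, each prediction is written under "./out" using SaveImaged's default one-folder-per-image layout, e.g. "./out/im0/im0_seg.nii.gz" for the synthetic data generated above. A quick sanity check (filename assumed from that naming scheme):

import os
import nibabel as nib

seg = nib.load(os.path.join("out", "im0", "im0_seg.nii.gz"))
print(seg.shape)  # spatial shape should match the original image, since the transforms were inverted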
Example #10
def main(tempdir):
    print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, _ = create_test_image_3d(128,
                                     128,
                                     128,
                                     num_seg_classes=1,
                                     channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    files = [{"img": img} for img in images]

    # define pre transforms
    pre_transforms = Compose([
        LoadImaged(keys="img"),
        EnsureChannelFirstd(keys="img"),
        Orientationd(keys="img", axcodes="RAS"),
        Resized(keys="img",
                spatial_size=(96, 96, 96),
                mode="trilinear",
                align_corners=True),
        ScaleIntensityd(keys="img"),
        ToTensord(keys="img"),
    ])
    # define dataset and dataloader
    dataset = Dataset(data=files, transform=pre_transforms)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=4)
    # define post transforms
    post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
        Invertd(keys="pred",
                transform=pre_transforms,
                loader=dataloader,
                orig_keys="img",
                nearest_interp=True),
        SaveImaged(keys="pred_inverted",
                   output_dir="./output",
                   output_postfix="seg",
                   resample=False),
    ])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    net.load_state_dict(
        torch.load("best_metric_model_segmentation3d_dict.pth"))

    net.eval()
    with torch.no_grad():
        for d in dataloader:
            images = d["img"].to(device)
            # define sliding window size and batch size for sliding-window inference
            d["pred"] = sliding_window_inference(inputs=images,
                                                 roi_size=(96, 96, 96),
                                                 sw_batch_size=4,
                                                 predictor=net)
            # execute post transforms to invert spatial transforms and save to NIfTI files
            post_transforms(d)