Code Example #1
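Tests SupervisedEvaluator with PrepareBatchExtraInput, which forwards extra batch keys (both tensors and plain Python values) to the network; the assertions check the image, the label, and every entry of the prediction dict on the target device.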
    def test_content(self, input_args, expected_value):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        dataloader = [{
            "image": torch.tensor([1, 2]),
            "label": torch.tensor([3, 4]),
            "extra1": torch.tensor([5, 6]),
            "extra2": 16,
            "extra3": "test",
        }]
        # set up engine
        evaluator = SupervisedEvaluator(
            device=device,
            val_data_loader=dataloader,
            epoch_length=1,
            network=TestNet(),
            non_blocking=True,
            prepare_batch=PrepareBatchExtraInput(**input_args),
            decollate=False,
        )
        evaluator.run()
        output = evaluator.state.output
        assert_allclose(output["image"], torch.tensor([1, 2], device=device))
        assert_allclose(output["label"], torch.tensor([3, 4], device=device))
        for k, v in output["pred"].items():
            if isinstance(v, torch.Tensor):
                assert_allclose(v, expected_value[k].to(device))
            else:
                self.assertEqual(v, expected_value[k])
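PrepareBatchExtraInput is a MONAI utility (monai.engines), but TestNet is defined elsewhere in the test module. A minimal sketch that would satisfy the assertions above, assuming the network simply echoes its inputs back as the prediction dict (an illustration, not the original definition):

import torch


class TestNet(torch.nn.Module):
    # hypothetical stand-in: echo the inputs back unchanged so the test can
    # inspect exactly what prepare_batch forwarded to the network
    def forward(self, x: torch.Tensor, t1=None, t2=None, t3=None):
        return {"x": x, "t1": t1, "t2": t2, "t3": t3}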
Code Example #2
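Runs the PostProcessing handler on top of the engine's own postprocessing transforms; the assertions cover both the decollated (list of items) and the collated (single batch) output layouts.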
    def test_compute(self, input_params, decollate, expected):
        data = [
            {"image": torch.tensor([[[[2.0], [3.0]]]]), "filename": ["test1"]},
            {"image": torch.tensor([[[[6.0], [8.0]]]]), "filename": ["test2"]},
        ]
        # set up the engine; the PostProcessing handler works together with the engine's postprocessing transforms
        engine = SupervisedEvaluator(
            device=torch.device("cpu:0"),
            val_data_loader=data,
            epoch_length=2,
            network=torch.nn.PReLU(),
            postprocessing=Compose([Activationsd(keys="pred", sigmoid=True)]),
            val_handlers=[PostProcessing(**input_params)],
            decollate=decollate,
        )
        engine.run()

        if isinstance(engine.state.output, list):
            # test decollated list items
            for o, e in zip(engine.state.output, expected):
                torch.testing.assert_allclose(o["pred"], e)
                filename = o.get("filename_bak")
                if filename is not None:
                    self.assertEqual(filename, "test2")
        else:
            # test batch data
            torch.testing.assert_allclose(engine.state.output["pred"], expected)
Code Example #3
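The same PostProcessing-handler check written against an older MONAI API, where the engine receives its transforms through the post_transform argument instead of postprocessing.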
    def test_compute(self, input_params, expected):
        data = [
            {
                "image": torch.tensor([[[[2.0], [3.0]]]]),
                "filename": "test1"
            },
            {
                "image": torch.tensor([[[[6.0], [8.0]]]]),
                "filename": "test2"
            },
        ]
        # set up the engine; the PostProcessing handler works together with the engine's post_transform
        engine = SupervisedEvaluator(
            device=torch.device("cpu:0"),
            val_data_loader=data,
            epoch_length=2,
            network=torch.nn.PReLU(),
            post_transform=Compose([Activationsd(keys="pred", sigmoid=True)]),
            val_handlers=[PostProcessing(**input_params)],
        )
        engine.run()

        torch.testing.assert_allclose(engine.state.output["pred"], expected)
        filename = engine.state.output.get("filename_bak")
        if filename is not None:
            self.assertEqual(filename, "test2")
Code Example #4
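A smoke test: with an empty data loader and epoch_length=0, the evaluator should complete a run without raising.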
    def test_empty_data(self):
        dataloader = []
        evaluator = SupervisedEvaluator(
            val_data_loader=dataloader,
            device=torch.device("cpu"),
            epoch_length=0,
            network=TestNet(),
            non_blocking=False,
            prepare_batch=PrepareBatchDefault(),
            decollate=False,
        )
        evaluator.run()
Code Example #5
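Attaches MONAI's NVTX profiling handlers (MarkHandler, RangeHandler, RangePushHandler, RangePopHandler) to a mixture of Ignite Events and their string literals, alongside a StatsHandler, then verifies the decollated prediction.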
    def test_compute(self, data, expected):
        # Set up handlers
        handlers = [
            # Mark with Ignite Event
            MarkHandler(Events.STARTED),
            # Mark with literal
            MarkHandler("EPOCH_STARTED"),
            # Mark with literal and providing the message
            MarkHandler("EPOCH_STARTED", "Start of the epoch"),
            # Define a range using one prefix (between BATCH_STARTED and BATCH_COMPLETED)
            RangeHandler("Batch"),
            # Define a range using a pair of events
            RangeHandler((Events.STARTED, Events.COMPLETED)),
            # Define a range using a pair of literals
            RangeHandler(("GET_BATCH_STARTED", "GET_BATCH_COMPLETED"),
                         msg="Batching!"),
            # Define a range using a pair of literal and events
            RangeHandler(("GET_BATCH_STARTED", Events.COMPLETED)),
            # Define the start of range using literal
            RangePushHandler("ITERATION_STARTED"),
            # Define the start of range using event
            RangePushHandler(Events.ITERATION_STARTED, "Iteration 2"),
            # Define the start of range using literals and providing message
            RangePushHandler("EPOCH_STARTED", "Epoch 2"),
            # Define the end of range using Ignite Event
            RangePopHandler(Events.ITERATION_COMPLETED),
            RangePopHandler(Events.EPOCH_COMPLETED),
            # Define the end of range using literal
            RangePopHandler("ITERATION_COMPLETED"),
            # Other handlers
            StatsHandler(tag_name="train",
                         output_transform=from_engine(["label"], first=True)),
        ]

        # Set up an engine
        engine = SupervisedEvaluator(
            device=torch.device("cpu:0"),
            val_data_loader=data,
            epoch_length=1,
            network=torch.nn.PReLU(),
            postprocessing=lambda x: dict(pred=x["pred"] + 1.0),
            decollate=True,
            val_handlers=handlers,
        )
        # Run the engine
        engine.run()

        # Get the output from the engine
        output = engine.state.output[0]

        torch.testing.assert_allclose(output["pred"], expected)
Code Example #6
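A distributed test: each rank selects its own data loader, evaluates on its own device, and checks the image and label of the last batch.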
    def test_compute(self, dataloaders):
        device = torch.device(f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu")
        dataloader = dataloaders[dist.get_rank()]
        # set up engine
        evaluator = SupervisedEvaluator(
            device=device,
            val_data_loader=dataloader,
            epoch_length=len(dataloader),
            network=TestNet(),
            non_blocking=False,
            prepare_batch=PrepareBatchDefault(),
            decollate=False,
        )
        evaluator.run()
        output = evaluator.state.output
        if len(dataloader) > 0:
            assert_allclose(output["image"], dataloader[-1]["image"].to(device=device))
            assert_allclose(output["label"], dataloader[-1]["label"].to(device=device))
Code Example #7
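Defers decollation: the engine runs with decollate=False and a lambda postprocessing, and the DecollateBatch handler (fired on MODEL_COMPLETED) decollates the batch before PostProcessing applies the per-item transforms.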
    def test_compute(self):
        data = [
            {
                "image": torch.tensor([[[[2.0], [3.0]]]]),
                "filename": ["test1"]
            },
            {
                "image": torch.tensor([[[[6.0], [8.0]]]]),
                "filename": ["test2"]
            },
        ]

        handlers = [
            DecollateBatch(event="MODEL_COMPLETED"),
            PostProcessing(transform=Compose([
                Activationsd(keys="pred", sigmoid=True),
                CopyItemsd(keys="filename", times=1, names="filename_bak"),
                AsDiscreted(keys="pred",
                            threshold_values=True,
                            to_onehot=True,
                            num_classes=2),
            ])),
        ]
        # set up the engine; the PostProcessing handler works together with the engine's postprocessing transforms
        engine = SupervisedEvaluator(
            device=torch.device("cpu:0"),
            val_data_loader=data,
            epoch_length=2,
            network=torch.nn.PReLU(),
            # set decollate=False and execute some postprocessing first, then decollate in handlers
            postprocessing=lambda x: dict(pred=x["pred"] + 1.0),
            decollate=False,
            val_handlers=handlers,
        )
        engine.run()

        expected = torch.tensor([[[[1.0], [1.0]], [[0.0], [0.0]]]])

        for o, e in zip(engine.state.output, expected):
            torch.testing.assert_allclose(o["pred"], e)
            filename = o.get("filename_bak")
            if filename is not None:
                self.assertEqual(filename, "test2")
Code Example #8
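The counterpart of Code Example #1 using PrepareBatchDefault: only the image and label tensors are expected in the evaluator output, moved to the target device.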
    def test_content(self):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        dataloader = [{
            "image": torch.tensor([1, 2]),
            "label": torch.tensor([3, 4]),
            "extra1": torch.tensor([5, 6]),
            "extra2": 16,
            "extra3": "test",
        }]
        # set up engine
        evaluator = SupervisedEvaluator(
            device=device,
            val_data_loader=dataloader,
            epoch_length=1,
            network=TestNet(),
            non_blocking=False,
            prepare_batch=PrepareBatchDefault(),
            decollate=False,
        )
        evaluator.run()
        output = evaluator.state.output
        assert_allclose(output["image"], torch.tensor([1, 2], device=device))
        assert_allclose(output["label"], torch.tensor([3, 4], device=device))
Code Example #9
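A complete evaluation workflow: synthetic NIfTI data is generated, a checkpointed UNet is restored, and sliding-window inference runs with MeanDice and Accuracy metrics plus a SegmentationSaver handler.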
def main(tempdir):
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # generate 5 random image/mask pairs in the given temporary directory
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    val_files = [{"image": img, "label": seg} for img, seg in zip(images, segs)]

    # model file path
    model_file = glob("./runs/net_key_metric*")[0]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # create the UNet
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    val_post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointLoader(load_path=model_file, load_dict={"net": net}),
        SegmentationSaver(
            output_dir="./runs/",
            batch_transform=lambda batch: batch["image_meta_dict"],
            output_transform=lambda output: output["pred"],
        ),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={"val_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]))},
        val_handlers=val_handlers,
        # AMP evaluation is not enabled if the GPU has no FP16 support or PyTorch < 1.6
        amp=monai.config.get_torch_version_tuple() >= (1, 6),
    )
    evaluator.run()
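Since main() expects the scratch directory as an argument, a typical entry point (a usage sketch, not part of the original file) would be:

import tempfile

if __name__ == "__main__":
    # create a scratch directory for the synthetic data and clean it up afterwards
    with tempfile.TemporaryDirectory() as tempdir:
        main(tempdir)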
Code Example #10
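The distributed version of the workflow above: NCCL process group, DistributedSampler, DistributedDataParallel, and logging/saving handlers attached on rank 0 only.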
def evaluate(args):
    if args.local_rank == 0 and not os.path.exists(args.dir):
        # create 16 random image/mask pairs for evaluation
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(16):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    # initialize the distributed evaluation process; every GPU runs in its own process
    dist.init_process_group(backend="nccl", init_method="env://")

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    val_files = [{"image": img, "label": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # create an evaluation dataset
    val_ds = Dataset(data=val_files, transform=val_transforms)
    # create an evaluation data sampler
    val_sampler = DistributedSampler(val_ds, shuffle=False)
    # sliding window inference needs one image per iteration, hence batch_size=1
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=True, sampler=val_sampler)

    # create the UNet
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    # wrap the model with DistributedDataParallel module
    net = DistributedDataParallel(net, device_ids=[device])

    val_post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )
    val_handlers = [
        CheckpointLoader(
            load_path="./runs/checkpoint_epoch=4.pt",
            load_dict={"net": net},
            # config mapping to expected GPU device
            map_location={"cuda:0": f"cuda:{args.local_rank}"},
        ),
    ]
    if dist.get_rank() == 0:
        logging.basicConfig(stream=sys.stdout, level=logging.INFO)
        val_handlers.extend(
            [
                StatsHandler(output_transform=lambda x: None),
                SegmentationSaver(
                    output_dir="./runs/",
                    batch_transform=lambda batch: batch["image_meta_dict"],
                    output_transform=lambda output: output["pred"],
                ),
            ]
        )

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(
                include_background=True,
                output_transform=lambda x: (x["pred"], x["label"]),
                device=device,
            )
        },
        additional_metrics={"val_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]), device=device)},
        val_handlers=val_handlers,
        # AMP evaluation is not enabled if the GPU has no FP16 support or PyTorch < 1.6
        amp=monai.config.get_torch_version_tuple() >= (1, 6),
    )
    evaluator.run()
    dist.destroy_process_group()
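Because dist.init_process_group uses init_method="env://", the script relies on the environment variables (RANK, WORLD_SIZE, MASTER_ADDR, ...) set by a distributed launcher; it would typically be started with something like python -m torch.distributed.launch --nproc_per_node=<GPUS_PER_NODE> script.py --dir <DATA_DIR> (the exact flags are illustrative).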
Code Example #11
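A reusable inference-test helper built around the same workflow; it returns the evaluator's best metric so callers can assert on it.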
def run_inference_test(root_dir, model_file, device="cuda:0", amp=False):
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{
        "image": img,
        "label": seg
    } for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys=["image", "label"]),
        ToTensord(keys=["image", "label"]),
    ])

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # create the UNet
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    val_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
    ])
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointLoader(load_path=f"{model_file}", load_dict={"net": net}),
        SegmentationSaver(
            output_dir=root_dir,
            batch_transform=lambda batch: batch["image_meta_dict"],
            output_transform=lambda output: output["pred"],
        ),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96),
                                     sw_batch_size=4,
                                     overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice":
            MeanDice(include_background=True,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={
            "val_acc":
            Accuracy(output_transform=lambda x: (x["pred"], x["label"]))
        },
        val_handlers=val_handlers,
        amp=bool(amp),
    )
    evaluator.run()

    return evaluator.state.best_metric
Code Example #12
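The same helper updated to a newer MONAI API: spatial_dims replaces dimensions, postprocessing replaces post_transform, from_engine selects keys for metrics and handlers, and outputs are saved via both the SaveImaged transform and the SegmentationSaver handler.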
def run_inference_test(root_dir, model_file, device="cuda:0", amp=False, num_workers=4):
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"image": img, "label": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys=["image", "label"]),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=num_workers)

    # create the UNet
    net = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    val_postprocessing = Compose(
        [
            ToTensord(keys=["pred", "label"]),
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold=0.5),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
            # test the case where `pred` is in `engine.state.output` while `image_meta_dict` is in `engine.state.batch`
            SaveImaged(
                keys="pred", meta_keys=PostFix.meta("image"), output_dir=root_dir, output_postfix="seg_transform"
            ),
        ]
    )
    val_handlers = [
        StatsHandler(iteration_log=False),
        CheckpointLoader(load_path=f"{model_file}", load_dict={"net": net}),
        SegmentationSaver(
            output_dir=root_dir,
            output_postfix="seg_handler",
            batch_transform=from_engine(PostFix.meta("image")),
            output_transform=from_engine("pred"),
        ),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        postprocessing=val_postprocessing,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=from_engine(["pred", "label"]))
        },
        additional_metrics={"val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
        val_handlers=val_handlers,
        amp=bool(amp),
    )
    evaluator.run()

    return evaluator.state.best_metric
Code Example #13
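An early-API variant in which each post-transform writes to a new key via output_postfix (pred -> pred_act -> pred_act_dis) and SegmentationSaver reads metadata from flattened image.* keys.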
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 5 random image/mask pairs
    tempdir = tempfile.mkdtemp()
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128,
                                       128,
                                       128,
                                       num_seg_classes=1,
                                       channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    val_files = [{
        "image": img,
        "label": seg
    } for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys=["image", "label"]),
        ToTensord(keys=["image", "label"]),
    ])

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # create the UNet
    device = torch.device("cuda:0")
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    val_post_transforms = Compose([
        Activationsd(keys="pred", output_postfix="act", sigmoid=True),
        AsDiscreted(keys="pred_act",
                    output_postfix="dis",
                    threshold_values=True),
        KeepLargestConnectedComponentd(keys="pred_act_dis",
                                       applied_values=[1],
                                       output_postfix=None),
    ])
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointLoader(load_path="./runs/net_key_metric=0.9101.pth",
                         load_dict={"net": net}),
        SegmentationSaver(
            output_dir="./runs/",
            batch_transform=lambda x: {
                "filename_or_obj": x["image.filename_or_obj"],
                "affine": x["image.affine"],
                "original_affine": x["image.original_affine"],
                "spatial_shape": x["image.spatial_shape"],
            },
            output_transform=lambda x: x["pred_act_dis"],
        ),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96),
                                     sw_batch_size=4,
                                     overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice":
            MeanDice(include_background=True,
                     output_transform=lambda x:
                     (x["pred_act_dis"], x["label"]))
        },
        additional_metrics={
            "val_acc":
            Accuracy(
                output_transform=lambda x: (x["pred_act_dis"], x["label"]))
        },
        val_handlers=val_handlers,
    )
    evaluator.run()
    shutil.rmtree(tempdir)