Example #1
 def nifti_rw(self, test_data, reader, writer, dtype, resample=True):
     test_data = test_data.astype(dtype)
     ndim = len(test_data.shape) - 1
     for p in TEST_NDARRAYS:
         output_ext = ".nii.gz"
         filepath = f"testfile_{ndim}d"
         saver = SaveImage(output_dir=self.test_dir,
                           output_ext=output_ext,
                           resample=resample,
                           separate_folder=False,
                           writer=writer)
          saver(
              p(test_data),
              {
                  "filename_or_obj": f"{filepath}.png",
                  "affine": np.eye(4),
                  "original_affine": np.array(
                      [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
                  ),
              },
          )
         saved_path = os.path.join(self.test_dir,
                                   filepath + "_trans" + output_ext)
         self.assertTrue(os.path.exists(saved_path))
         loader = LoadImage(reader=reader, squeeze_non_spatial_dims=True)
         data, meta = loader(saved_path)
         if meta["original_channel_dim"] == -1:
             _test_data = moveaxis(test_data, 0, -1)
         else:
             _test_data = test_data[0]
         if resample:
             _test_data = moveaxis(_test_data, 0, 1)
         assert_allclose(data, _test_data)
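Note: in this example the `original_affine` permutes the first two spatial axes relative to the identity `affine`, which is why the test applies `moveaxis(_test_data, 0, 1)` when resample=True. A minimal sketch of that permutation (illustrative, not part of the test):

import numpy as np

original_affine = np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
# the rotation part is the identity with rows 0 and 1 swapped, i.e. it
# exchanges the first two spatial axes, so SaveImage resamples into that
# orientation before writing
print(np.array_equal(original_affine[:3, :3], np.eye(3)[[1, 0, 2]]))  # True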
Example #2
    def test_saved_content(self, test_data, meta_data, output_ext, resample,
                           save_batch):
        with tempfile.TemporaryDirectory() as tempdir:
            trans = SaveImage(
                output_dir=tempdir,
                output_ext=output_ext,
                resample=resample,
                save_batch=save_batch,
            )
            trans(test_data, meta_data)

            if save_batch:
                for i in range(8):
                    filepath = os.path.join(
                        "testfile" + str(i),
                        "testfile" + str(i) + "_trans" + output_ext)
                    self.assertTrue(
                        os.path.exists(os.path.join(tempdir, filepath)))
            else:
                if meta_data is not None:
                    filepath = os.path.join(
                        "testfile0", "testfile0" + "_trans" + output_ext)
                else:
                    filepath = os.path.join("0", "0" + "_trans" + output_ext)
                self.assertTrue(os.path.exists(os.path.join(tempdir,
                                                            filepath)))
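The assertions above rely on SaveImage's default naming convention; a small sketch of the expected path computation (paths are hypothetical, assuming the default output_postfix="trans" and separate_folder=True):

import os

def expected_path(output_dir, name, output_ext):
    # {output_dir}/{name}/{name}_trans{output_ext}, where {name} is the stem of
    # "filename_or_obj" from the metadata, or a running index ("0", "1", ...)
    # when no metadata is supplied
    return os.path.join(output_dir, name, f"{name}_trans{output_ext}")

print(expected_path("/tmp/out", "testfile0", ".nii.gz"))
# /tmp/out/testfile0/testfile0_trans.nii.gz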
Example #3
 def png_rw(self, test_data, reader, writer, dtype, resample=True):
     test_data = test_data.astype(dtype)
     ndim = len(test_data.shape) - 1
     for p in TEST_NDARRAYS:
         output_ext = ".png"
         filepath = f"testfile_{ndim}d"
         saver = SaveImage(output_dir=self.test_dir,
                           output_ext=output_ext,
                           resample=resample,
                           separate_folder=False,
                           writer=writer)
         saver(p(test_data), {
             "filename_or_obj": f"{filepath}.png",
             "spatial_shape": (6, 8)
         })
         saved_path = os.path.join(self.test_dir,
                                   filepath + "_trans" + output_ext)
         self.assertTrue(os.path.exists(saved_path))
         loader = LoadImage(reader=reader)
         data, meta = loader(saved_path)
         if meta["original_channel_dim"] == -1:
             _test_data = moveaxis(test_data, 0, -1)
         else:
             _test_data = test_data[0]
         assert_allclose(data, _test_data)
Example #4
def run_inference_test(root_dir, device="cuda:0"):
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            EnsureChannelFirstd(keys=["img", "seg"]),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference needs to process one image per iteration
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    model.load_state_dict(torch.load(model_filename))
    with eval_mode(model):
        # resampling with align_corners=True or dtype=float64 will generate
        # slightly different results between PyTorch 1.5 and 1.6
        saver = SaveImage(
            output_dir=os.path.join(root_dir, "output"),
            dtype=np.float32,
            output_ext=".nii.gz",
            output_postfix="seg",
            mode="bilinear",
        )
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            # define sliding window size and batch size for window inference
            sw_batch_size, roi_size = 4, (96, 96, 96)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            # decollate prediction into a list
            val_outputs = [val_post_tran(i) for i in decollate_batch(val_outputs)]
            val_meta = decollate_batch(val_data[PostFix.meta("img")])
            # compute metrics
            dice_metric(y_pred=val_outputs, y=val_labels)
            for img, meta in zip(val_outputs, val_meta):  # save a decollated batch of files
                saver(img, meta)

    return dice_metric.aggregate().item()
Example #5
def main(tempdir):
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        Image.fromarray((im * 255).astype("uint8")).save(os.path.join(tempdir, f"img{i:d}.png"))
        Image.fromarray((seg * 255).astype("uint8")).save(os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))

    # define transforms for image and segmentation
    imtrans = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])
    segtrans = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])
    val_ds = ArrayDataset(images, imtrans, segs, segtrans)
    # sliding window inference for one image per iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    saver = SaveImage(output_dir="./output", output_ext=".png", output_postfix="seg")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model.load_state_dict(torch.load("best_metric_model_segmentation2d_array.pth"))
    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            # define sliding window size and batch size for window inference
            roi_size = (96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            val_labels = decollate_batch(val_labels)
            # compute metric for current iteration
            dice_metric(y_pred=val_outputs, y=val_labels)
            for val_output in val_outputs:
                saver(val_output)
        # aggregate the final mean dice result
        print("evaluation metric:", dice_metric.aggregate().item())
        # reset the status
        dice_metric.reset()
Example #6
    def test_saved_content(self, test_data, meta_data, output_ext, resample):
        with tempfile.TemporaryDirectory() as tempdir:
            trans = SaveImage(
                output_dir=tempdir,
                output_ext=output_ext,
                resample=resample,
                separate_folder=False,  # test saving into the same folder
            )
            trans(test_data, meta_data)

            filepath = "testfile0" if meta_data is not None else "0"
            self.assertTrue(
                os.path.exists(
                    os.path.join(tempdir, filepath + "_trans" + output_ext)))
Example #7
def run_test(batch_size, img_name, seg_name, output_dir, device="cuda:0"):
    ds = ImageDataset([img_name], [seg_name],
                      transform=AddChannel(),
                      seg_transform=AddChannel(),
                      image_only=True)
    loader = DataLoader(ds, batch_size=1, pin_memory=torch.cuda.is_available())

    net = UNet(spatial_dims=3,
               in_channels=1,
               out_channels=1,
               channels=(4, 8, 16, 32),
               strides=(2, 2, 2),
               num_res_units=2).to(device)
    roi_size = (16, 32, 48)
    sw_batch_size = batch_size

    saver = SaveImage(output_dir=output_dir,
                      output_ext=".nii.gz",
                      output_postfix="seg")

    def _sliding_window_processor(_engine, batch):
        img = batch[0]  # first item from ImageDataset is the input image
        with eval_mode(net):
            seg_probs = sliding_window_inference(img.to(device),
                                                 roi_size,
                                                 sw_batch_size,
                                                 net,
                                                 device=device)
            return predict_segmentation(seg_probs)

    def save_func(engine):
        if pytorch_after(1, 9, 1):
            for m in engine.state.output:
                saver(m)
        else:
            saver(engine.state.output[0])

    infer_engine = Engine(_sliding_window_processor)
    infer_engine.add_event_handler(Events.ITERATION_COMPLETED, save_func)
    infer_engine.run(loader)

    basename = os.path.basename(img_name)[:-len(".nii.gz")]
    saved_name = os.path.join(output_dir, basename, f"{basename}_seg.nii.gz")
    return saved_name
Example #8
 def nrrd_rw(self, test_data, reader, writer, dtype, resample=True):
     test_data = test_data.astype(dtype)
     ndim = len(test_data.shape)
     for p in TEST_NDARRAYS:
         output_ext = ".nrrd"
         filepath = f"testfile_{ndim}d"
         saver = SaveImage(output_dir=self.test_dir,
                           output_ext=output_ext,
                           resample=resample,
                           separate_folder=False,
                           writer=writer)
          test_data = MetaTensor(
              p(test_data),
              meta={"filename_or_obj": f"{filepath}{output_ext}", "spatial_shape": test_data.shape},
          )
         saver(test_data)
         saved_path = os.path.join(self.test_dir,
                                   filepath + "_trans" + output_ext)
         loader = LoadImage(image_only=True, reader=reader)
         data = loader(saved_path)
         assert_allclose(data, torch.as_tensor(test_data))
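Unlike the earlier examples, which pass a separate metadata dict, this one wraps the array in a MetaTensor so the metadata travels with the data and saver(test_data) needs no second argument. A minimal sketch of that pattern (illustrative values):

import torch
from monai.data import MetaTensor

img = MetaTensor(torch.zeros(1, 6, 8), meta={"filename_or_obj": "testfile.nrrd"})
print(img.meta["filename_or_obj"])  # metadata is carried by the tensor itself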
Example #9
    def __call__(
        self,
        img: NdarrayTensor,
        meta_data: Optional[Dict] = None,
        prefix: Optional[str] = None,
        data_type: Optional[bool] = None,
        data_shape: Optional[bool] = None,
        value_range: Optional[bool] = None,
        data_value: Optional[bool] = None,
        additional_info: Optional[Callable] = None,
    ) -> NdarrayTensor:
        # invoke the parent transform's __call__ (calling __init__ here would return None)
        img = super().__call__(img, prefix, data_type, data_shape, value_range,
                               data_value, additional_info)

        if self.save_data:
            saver = SaveImage(
                output_dir=self.save_data_dir,
                output_postfix=self.prefix,
                output_ext=".nii.gz",
                resample=False,
            )
            saver(img, meta_data)

        return img
Example #10
    def __init__(
        self,
        output_dir: str = "./",
        output_postfix: str = "seg",
        output_ext: str = ".nii.gz",
        resample: bool = True,
        mode: Union[GridSampleMode, InterpolateMode, str] = "nearest",
        padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
        scale: Optional[int] = None,
        dtype: DtypeLike = np.float64,
        output_dtype: DtypeLike = np.float32,
        squeeze_end_dims: bool = True,
        data_root_dir: str = "",
        batch_transform: Callable = lambda x: x,
        output_transform: Callable = lambda x: x,
        name: Optional[str] = None,
    ) -> None:
        """
        Args:
            output_dir: output image directory.
            output_postfix: a string appended to all output file names, defaults to `seg`.
            output_ext: output file extension name, available extensions: `.nii.gz`, `.nii`, `.png`.
            resample: whether to resample before saving the data array.
                If saving in PNG format, resampling is based on the `spatial_shape` from the metadata;
                if saving in NIfTI format, it is based on the `original_affine` from the metadata.
            mode: This option is used when ``resample = True``. Defaults to ``"nearest"``.

                - NIfTI files {``"bilinear"``, ``"nearest"``}
                    Interpolation mode to calculate output values.
                    See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
                - PNG files {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}
                    The interpolation mode.
                    See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate

            padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``.

                - NIfTI files {``"zeros"``, ``"border"``, ``"reflection"``}
                    Padding mode for outside grid values.
                    See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
                - PNG files
                    This option is ignored.

            scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling
                to [0, 255] (uint8) or [0, 65535] (uint16). Defaults to None (no scaling).
                It's used for the PNG format only.
            dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
                If None, the data type of the input data is used.
                It's used for the NIfTI format only.
            output_dtype: data type for saving data. Defaults to ``np.float32``; it's used for the NIfTI format only.
            squeeze_end_dims: if True, any trailing singleton dimensions will be removed (after the channel
                has been moved to the end). So if the input is (C,H,W,D), this will be altered to (H,W,D,C),
                and then if C==1, it will be saved as (H,W,D). If D is also 1, it will be saved as (H,W).
                If False, the image will always be saved as (H,W,D,C).
                It's used for the NIfTI format only.
            data_root_dir: if not empty, it specifies the beginning parts of the input file's
                absolute path. It's used to compute `input_file_rel_path`, the relative path from
                `data_root_dir` to the file, to preserve the folder structure when saving, in case
                there are files in different folders with the same file names. For example, with:
                input_file_name: /foo/bar/test1/image.nii,
                output_postfix: seg
                output_ext: nii.gz
                output_dir: /output,
                data_root_dir: /foo/bar,
                the output will be: /output/test1/image/image_seg.nii.gz
            batch_transform: a callable that is used to transform the
                ignite.engine.batch into the expected format to extract the meta_data dictionary.
            output_transform: a callable that is used to transform the
                ignite.engine.output into the expected image data.
                The first dimension of this transform's output will be treated as the
                batch dimension. Each item in the batch will be saved individually.
            name: identifier of logging.logger to use, defaulting to `engine.logger`.

        """
        self._saver = SaveImage(
            output_dir=output_dir,
            output_postfix=output_postfix,
            output_ext=output_ext,
            resample=resample,
            mode=mode,
            padding_mode=padding_mode,
            scale=scale,
            dtype=dtype,
            output_dtype=output_dtype,
            squeeze_end_dims=squeeze_end_dims,
            data_root_dir=data_root_dir,
            save_batch=True,
        )
        self.batch_transform = batch_transform
        self.output_transform = output_transform

        self.logger = logging.getLogger(name)
        self._name = name
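The data_root_dir behaviour documented above can be summarized in a short sketch (using the illustrative paths from the docstring):

from monai.transforms import SaveImage

saver = SaveImage(
    output_dir="/output",
    output_postfix="seg",
    output_ext=".nii.gz",
    data_root_dir="/foo/bar",  # stripped from the input path to keep the folder structure
)
# for an input whose filename_or_obj is /foo/bar/test1/image.nii, the relative
# path "test1/image" is preserved, so the file is written to:
#   /output/test1/image/image_seg.nii.gz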
Example #11
    def __init__(
        self,
        output_dir: str = "./",
        output_postfix: str = "seg",
        output_ext: str = ".nii.gz",
        resample: bool = True,
        mode: Union[GridSampleMode, InterpolateMode, str] = "nearest",
        padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
        scale: Optional[int] = None,
        dtype: DtypeLike = np.float64,
        output_dtype: DtypeLike = np.float32,
        batch_transform: Callable = lambda x: x,
        output_transform: Callable = lambda x: x,
        name: Optional[str] = None,
    ) -> None:
        """
        Args:
            output_dir: output image directory.
            output_postfix: a string appended to all output file names, defaults to `seg`.
            output_ext: output file extension name, available extensions: `.nii.gz`, `.nii`, `.png`.
            resample: whether to resample before saving the data array.
                If saving in PNG format, resampling is based on the `spatial_shape` from the metadata;
                if saving in NIfTI format, it is based on the `original_affine` from the metadata.
            mode: This option is used when ``resample = True``. Defaults to ``"nearest"``.

                - NIfTI files {``"bilinear"``, ``"nearest"``}
                    Interpolation mode to calculate output values.
                    See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
                - PNG files {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}
                    The interpolation mode.
                    See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate

            padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``.

                - NIfTI files {``"zeros"``, ``"border"``, ``"reflection"``}
                    Padding mode for outside grid values.
                    See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
                - PNG files
                    This option is ignored.

            scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling
                to [0, 255] (uint8) or [0, 65535] (uint16). Defaults to None (no scaling).
                It's used for the PNG format only.
            dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
                If None, the data type of the input data is used.
                It's used for the NIfTI format only.
            output_dtype: data type for saving data. Defaults to ``np.float32``; it's used for the NIfTI format only.
            batch_transform: a callable that is used to transform the
                ignite.engine.batch into the expected format to extract the meta_data dictionary.
            output_transform: a callable that is used to transform the
                ignite.engine.output into the expected image data.
                The first dimension of this transform's output will be treated as the
                batch dimension. Each item in the batch will be saved individually.
            name: identifier of logging.logger to use, defaulting to `engine.logger`.

        """
        self._saver = SaveImage(
            output_dir=output_dir,
            output_postfix=output_postfix,
            output_ext=output_ext,
            resample=resample,
            mode=mode,
            padding_mode=padding_mode,
            scale=scale,
            dtype=dtype,
            output_dtype=output_dtype,
            save_batch=True,
        )
        self.batch_transform = batch_transform
        self.output_transform = output_transform

        self.logger = logging.getLogger(name)
        self._name = name
Example #12
def main(tempdir):
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    # define sliding window size and batch size for window inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    save_image = SaveImage(output_dir="tempdir", output_ext=".nii.gz", output_postfix="seg")

    def _sliding_window_processor(engine, batch):
        net.eval()
        with torch.no_grad():
            val_images, val_labels = batch["img"].to(device), batch["seg"].to(device)
            seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
            seg_probs = [post_trans(i) for i in decollate_batch(seg_probs)]
            meta_dicts = decollate_batch(batch["img_meta_dict"])
            for seg_prob, meta in zip(seg_probs, meta_dicts):
                save_image(seg_prob, meta)
            return seg_probs, val_labels

    evaluator = Engine(_sliding_window_processor)

    # add evaluation metric to the evaluator engine
    MeanDice().attach(evaluator, "Mean_Dice")

    # StatsHandler prints loss at every iteration and metrics at every epoch;
    # the evaluator has no loss to report, so only metrics are printed (print functions can be customized)
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
    )
    val_stats_handler.attach(evaluator)

    # the model was trained by "unet_training_dict" example
    CheckpointLoader(load_path="./runs_dict/net_checkpoint_50.pt", load_dict={"net": net}).attach(evaluator)

    # sliding window inference for one image at every iteration
    val_loader = DataLoader(
        val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available()
    )
    state = evaluator.run(val_loader)
    print(state)
Example #13
def main(tempdir):
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference needs to process one image per iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    saver = SaveImage(output_dir="./output", output_ext=".nii.gz", output_postfix="seg")
    # use a single default device; uncomment the next line to use all available GPUs
    devices = [torch.device("cuda" if torch.cuda.is_available() else "cpu")]
    # devices = get_devices_spec(None)
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(devices[0])

    model.load_state_dict(torch.load("best_metric_model_segmentation3d_dict.pth"))

    # if we have multiple GPUs, set data parallel to execute sliding window inference
    if len(devices) > 1:
        model = torch.nn.DataParallel(model, device_ids=devices)

    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(devices[0]), val_data["seg"].to(devices[0])
            # define sliding window size and batch size for window inference
            roi_size = (96, 96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            val_labels = decollate_batch(val_labels)
            meta_data = decollate_batch(val_data["img_meta_dict"])
            # compute metric for current iteration
            dice_metric(y_pred=val_outputs, y=val_labels)
            for val_output, data in zip(val_outputs, meta_data):
                saver(val_output, data)
        # aggregate the final mean dice result
        print("evaluation metric:", dice_metric.aggregate().item())
        # reset the status
        dice_metric.reset()
Example #14
def run_inference_test(root_dir,
                       model_file,
                       device="cuda:0",
                       amp=False,
                       num_workers=4):
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"image": img, "label": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys=["image", "label"]),
    ])

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds,
                                       batch_size=1,
                                       num_workers=num_workers)

    # create UNet, DiceLoss and Adam optimizer
    net = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    val_postprocessing = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold=0.5),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        # test the case that `pred` in `engine.state.output`, while `image_meta_dict` in `engine.state.batch`
        SaveImaged(keys="pred",
                   output_dir=root_dir,
                   output_postfix="seg_transform"),
    ])
    val_handlers = [
        StatsHandler(iteration_log=False),
        CheckpointLoader(load_path=f"{model_file}", load_dict={"net": net}),
    ]

    saver = SaveImage(output_dir=root_dir, output_postfix="seg_handler")

    def save_func(engine):
        for o in from_engine("pred")(engine.state.output):
            saver(o)

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96),
                                     sw_batch_size=4,
                                     overlap=0.5),
        postprocessing=val_postprocessing,
        key_val_metric={
            "val_mean_dice": MeanDice(
                include_background=True,
                output_transform=from_engine(["pred", "label"]),
            )
        },
        additional_metrics={
            "val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))
        },
        val_handlers=val_handlers,
        amp=bool(amp),
    )
    evaluator.add_event_handler(Events.ITERATION_COMPLETED, save_func)
    evaluator.run()

    return evaluator.state.best_metric
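save_func above relies on from_engine to pull the predictions out of the decollated engine output; a tiny sketch of that behaviour (toy values):

from monai.handlers import from_engine

# after decollation, engine.state.output is a list of per-item dicts
decollated = [{"pred": 1, "label": 0}, {"pred": 2, "label": 1}]
print(from_engine("pred")(decollated))  # [1, 2]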