def nifti_rw(self, test_data, reader, writer, dtype, resample=True):
    """Round-trip ``test_data`` through a ``.nii.gz`` file and compare the result.

    For every constructor in ``TEST_NDARRAYS`` the data is written with
    ``SaveImage`` (using ``writer``), read back with ``LoadImage`` (using
    ``reader``) and compared with the expected array.

    Args:
        test_data: array to save; it is cast to ``dtype`` first.
        reader: reader forwarded to ``LoadImage``.
        writer: writer forwarded to ``SaveImage``.
        dtype: dtype ``test_data`` is cast to before saving.
        resample: forwarded to ``SaveImage``; when true the expected array has
            its first two axes swapped before the comparison.
    """
    output_ext = ".nii.gz"
    test_data = test_data.astype(dtype)
    filepath = f"testfile_{len(test_data.shape) - 1}d"
    saved_path = os.path.join(self.test_dir, filepath + "_trans" + output_ext)

    for to_array in TEST_NDARRAYS:
        saver = SaveImage(
            output_dir=self.test_dir,
            output_ext=output_ext,
            resample=resample,
            separate_folder=False,
            writer=writer,
        )
        # NOTE(review): original_affine swaps the first two axes, which is
        # presumably why the expectation is transposed when resampling.
        meta_in = {
            "filename_or_obj": f"{filepath}.png",
            "affine": np.eye(4),
            "original_affine": np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]),
        }
        saver(to_array(test_data), meta_in)
        self.assertTrue(os.path.exists(saved_path))

        data, meta = LoadImage(reader=reader, squeeze_non_spatial_dims=True)(saved_path)
        channel_last = meta["original_channel_dim"] == -1
        expected = moveaxis(test_data, 0, -1) if channel_last else test_data[0]
        if resample:
            expected = moveaxis(expected, 0, 1)
        assert_allclose(data, expected)
def test_saved_content(self, test_data, meta_data, output_ext, resample, save_batch):
    """Check that ``SaveImage`` writes the expected files into per-image folders.

    Args:
        test_data: data handed to the transform.
        meta_data: metadata handed to the transform, or ``None``.
        output_ext: extension of the files to produce.
        resample: forwarded to ``SaveImage``.
        save_batch: forwarded to ``SaveImage``; when true, eight files named
            ``testfile0`` .. ``testfile7`` are expected.
    """
    # Work out the expected file stems first; each file lives in a folder
    # carrying the same name as the stem.
    if save_batch:
        stems = ["testfile" + str(index) for index in range(8)]
    elif meta_data is not None:
        stems = ["testfile0"]
    else:
        stems = ["0"]

    with tempfile.TemporaryDirectory() as tempdir:
        saver = SaveImage(
            output_dir=tempdir,
            output_ext=output_ext,
            resample=resample,
            save_batch=save_batch,
        )
        saver(test_data, meta_data)

        for stem in stems:
            relative = os.path.join(stem, stem + "_trans" + output_ext)
            self.assertTrue(os.path.exists(os.path.join(tempdir, relative)))
def png_rw(self, test_data, reader, writer, dtype, resample=True):
    """Round-trip ``test_data`` through a ``.png`` file and compare the result.

    For every constructor in ``TEST_NDARRAYS`` the data is written with
    ``SaveImage`` (using ``writer``), read back with ``LoadImage`` (using
    ``reader``) and compared with the expected array.

    Args:
        test_data: array to save; it is cast to ``dtype`` first.
        reader: reader forwarded to ``LoadImage``.
        writer: writer forwarded to ``SaveImage``.
        dtype: dtype ``test_data`` is cast to before saving.
        resample: forwarded to ``SaveImage``.
    """
    output_ext = ".png"
    test_data = test_data.astype(dtype)
    filepath = f"testfile_{len(test_data.shape) - 1}d"
    saved_path = os.path.join(self.test_dir, filepath + "_trans" + output_ext)

    for to_array in TEST_NDARRAYS:
        saver = SaveImage(
            output_dir=self.test_dir,
            output_ext=output_ext,
            resample=resample,
            separate_folder=False,
            writer=writer,
        )
        meta_in = {"filename_or_obj": f"{filepath}.png", "spatial_shape": (6, 8)}
        saver(to_array(test_data), meta_in)
        self.assertTrue(os.path.exists(saved_path))

        data, meta = LoadImage(reader=reader)(saved_path)
        # The expectation depends on where the loader reports the channel axis.
        if meta["original_channel_dim"] == -1:
            expected = moveaxis(test_data, 0, -1)
        else:
            expected = test_data[0]
        assert_allclose(data, expected)
def run_inference_test(root_dir, device="cuda:0"):
    """Run sliding-window inference over the NIfTI pairs found in ``root_dir``.

    Loads ``best_metric_model.pth`` from ``root_dir`` into a 3D UNet, predicts
    every ``im*.nii.gz`` volume, saves each post-processed prediction under
    ``root_dir/output`` and accumulates the Dice metric against the matching
    ``seg*.nii.gz`` volume.

    Args:
        root_dir: directory holding the images, the labels and the checkpoint.
        device: device the model and every batch are moved to.

    Returns:
        The aggregated Dice metric, converted with ``.item()``.
    """
    image_paths = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    label_paths = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = []
    for image_path, label_path in zip(image_paths, label_paths):
        val_files.append({"img": image_path, "seg": label_path})

    # Pre-processing applied to both the image and its segmentation.
    preprocessing = [
        LoadImaged(keys=["img", "seg"]),
        EnsureChannelFirstd(keys=["img", "seg"]),
        # float32 on purpose: resampling with align_corners=True or
        # dtype=float64 gives slightly different results between
        # PyTorch 1.5 and 1.6.
        Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
        ScaleIntensityd(keys="img"),
        ToTensord(keys=["img", "seg"]),
    ]
    val_ds = monai.data.Dataset(data=val_files, transform=Compose(preprocessing))
    # Sliding-window inference takes one image per iteration.
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    checkpoint_path = os.path.join(root_dir, "best_metric_model.pth")
    model.load_state_dict(torch.load(checkpoint_path))

    # Window size and number of windows evaluated per forward pass.
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    with eval_mode(model):
        # float32 for the same PyTorch 1.5 / 1.6 reproducibility reason as above.
        saver = SaveImage(
            output_dir=os.path.join(root_dir, "output"),
            dtype=np.float32,
            output_ext=".nii.gz",
            output_postfix="seg",
            mode="bilinear",
        )
        for val_data in val_loader:
            val_images = val_data["img"].to(device)
            val_labels = val_data["seg"].to(device)
            raw_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            # Split the batch into single items before post-processing.
            val_outputs = [val_post_tran(item) for item in decollate_batch(raw_outputs)]
            val_meta = decollate_batch(val_data[PostFix.meta("img")])
            dice_metric(y_pred=val_outputs, y=val_labels)
            # Save every decollated prediction together with its metadata.
            for prediction, meta in zip(val_outputs, val_meta):
                saver(prediction, meta)

    return dice_metric.aggregate().item()
def main(tempdir):
    """Evaluate a pretrained 2D UNet on freshly generated synthetic PNG images.

    Five image/segmentation pairs are written to ``tempdir``, the checkpoint
    ``best_metric_model_segmentation2d_array.pth`` is loaded from the current
    working directory, every prediction is saved under ``./output`` and the
    mean Dice metric is printed.

    Args:
        tempdir: directory that receives the generated ``img*.png`` and
            ``seg*.png`` files.
    """
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        Image.fromarray((im * 255).astype("uint8")).save(os.path.join(tempdir, f"img{i:d}.png"))
        Image.fromarray((seg * 255).astype("uint8")).save(os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))

    # define transforms for image and segmentation
    imtrans = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])
    segtrans = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])
    val_ds = ArrayDataset(images, imtrans, segs, segtrans)
    # sliding window inference for one image at every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    saver = SaveImage(output_dir="./output", output_ext=".png", output_postfix="seg")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    # map_location remaps the stored tensors onto the selected device, so a
    # checkpoint saved on a GPU can still be loaded when only the CPU is available.
    state_dict = torch.load("best_metric_model_segmentation2d_array.pth", map_location=device)
    model.load_state_dict(state_dict)

    # define sliding window size and batch size for windows inference
    roi_size = (96, 96)
    sw_batch_size = 4

    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            val_labels = decollate_batch(val_labels)
            # compute metric for current iteration
            dice_metric(y_pred=val_outputs, y=val_labels)
            # no metadata is passed to the saver here
            for val_output in val_outputs:
                saver(val_output)
        # aggregate the final mean dice result
        print("evaluation metric:", dice_metric.aggregate().item())
        # reset the status
        dice_metric.reset()
def test_saved_content(self, test_data, meta_data, output_ext, resample):
    """Check that ``SaveImage`` writes its file directly into the output folder.

    Args:
        test_data: data handed to the transform.
        meta_data: metadata handed to the transform, or ``None``; it decides
            whether the expected file stem is ``testfile0`` or ``0``.
        output_ext: extension of the file to produce.
        resample: forwarded to ``SaveImage``.
    """
    stem = "0" if meta_data is None else "testfile0"
    with tempfile.TemporaryDirectory() as tempdir:
        # separate_folder=False: everything is saved into the same folder.
        saver = SaveImage(
            output_dir=tempdir,
            output_ext=output_ext,
            resample=resample,
            separate_folder=False,
        )
        saver(test_data, meta_data)
        expected_file = os.path.join(tempdir, stem + "_trans" + output_ext)
        self.assertTrue(os.path.exists(expected_file))
def run_test(batch_size, img_name, seg_name, output_dir, device="cuda:0"):
    """Run sliding-window inference on one image through an ignite ``Engine``.

    Args:
        batch_size: number of windows evaluated per forward pass
            (``sw_batch_size``); the data loader itself uses a batch size of 1.
        img_name: path of the input image; its name is expected to end in ``.nii.gz``.
        seg_name: path of the matching segmentation.
        output_dir: directory the predictions are saved into.
        device: device the network and the input are moved to.

    Returns:
        The path ``<output_dir>/<basename>/<basename>_seg.nii.gz`` built from
        ``img_name`` with its ``.nii.gz`` suffix stripped.
    """
    dataset = ImageDataset(
        [img_name],
        [seg_name],
        transform=AddChannel(),
        seg_transform=AddChannel(),
        image_only=True,
    )
    loader = DataLoader(dataset, batch_size=1, pin_memory=torch.cuda.is_available())
    net = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(4, 8, 16, 32),
        strides=(2, 2, 2),
        num_res_units=2,
    ).to(device)
    roi_size = (16, 32, 48)
    sw_batch_size = batch_size

    saver = SaveImage(output_dir=output_dir, output_ext=".nii.gz", output_postfix="seg")

    def _sliding_window_processor(_engine, batch):
        # The first item from ImageDataset is the input image.
        inputs = batch[0].to(device)
        with eval_mode(net):
            seg_probs = sliding_window_inference(inputs, roi_size, sw_batch_size, net, device=device)
            return predict_segmentation(seg_probs)

    def save_func(engine):
        # After PyTorch 1.9.1 every item of the output is saved; otherwise
        # only the first one.
        if not pytorch_after(1, 9, 1):
            saver(engine.state.output[0])
            return
        for item in engine.state.output:
            saver(item)

    infer_engine = Engine(_sliding_window_processor)
    infer_engine.add_event_handler(Events.ITERATION_COMPLETED, save_func)
    infer_engine.run(loader)

    suffix = ".nii.gz"
    file_name = os.path.basename(img_name)
    basename = file_name[: -len(suffix)]
    return os.path.join(output_dir, basename, f"{basename}_seg.nii.gz")
def nrrd_rw(self, test_data, reader, writer, dtype, resample=True):
    """Round-trip ``test_data`` through a ``.nrrd`` file and compare the result.

    For every constructor in ``TEST_NDARRAYS`` the data is wrapped in a
    ``MetaTensor``, written with ``SaveImage`` (using ``writer``), read back
    with ``LoadImage`` (using ``reader``) and compared with what was saved.

    Args:
        test_data: array to save; it is cast to ``dtype`` first.
        reader: reader forwarded to ``LoadImage``.
        writer: writer forwarded to ``SaveImage``.
        dtype: dtype ``test_data`` is cast to before saving.
        resample: forwarded to ``SaveImage``.
    """
    test_data = test_data.astype(dtype)
    ndim = len(test_data.shape)
    output_ext = ".nrrd"
    filepath = f"testfile_{ndim}d"
    saved_path = os.path.join(self.test_dir, filepath + "_trans" + output_ext)

    for p in TEST_NDARRAYS:
        saver = SaveImage(
            output_dir=self.test_dir,
            output_ext=output_ext,
            resample=resample,
            separate_folder=False,
            writer=writer,
        )
        # Keep the wrapped tensor under its own name. Rebinding ``test_data``
        # here would feed the previous iteration's MetaTensor into the next
        # constructor instead of the original array.
        meta_tensor = MetaTensor(
            p(test_data),
            meta={
                "filename_or_obj": f"{filepath}{output_ext}",
                "spatial_shape": test_data.shape,
            },
        )
        saver(meta_tensor)
        loader = LoadImage(image_only=True, reader=reader)
        data = loader(saved_path)
        assert_allclose(data, torch.as_tensor(meta_tensor))
def __call__(
    self,
    img: NdarrayTensor,
    meta_data: Optional[Dict] = None,
    prefix: Optional[str] = None,
    data_type: Optional[bool] = None,
    data_shape: Optional[bool] = None,
    value_range: Optional[bool] = None,
    data_value: Optional[bool] = None,
    additional_info: Optional[Callable] = None,
) -> NdarrayTensor:
    """Run the parent transform on ``img`` and optionally save the result.

    Args:
        img: data forwarded to the parent ``__call__``.
        meta_data: metadata handed to ``SaveImage`` together with the result
            when ``self.save_data`` is set.
        prefix: forwarded unchanged to the parent ``__call__``.
        data_type: forwarded unchanged to the parent ``__call__``.
        data_shape: forwarded unchanged to the parent ``__call__``.
        value_range: forwarded unchanged to the parent ``__call__``.
        data_value: forwarded unchanged to the parent ``__call__``.
        additional_info: forwarded unchanged to the parent ``__call__``.

    Returns:
        Whatever the parent ``__call__`` returns for ``img``.
    """
    # The parent's __call__ must be used here, not __init__: __init__ always
    # returns None, which would drop the image, and it would treat ``img`` as
    # its first constructor argument.
    img = super().__call__(img, prefix, data_type, data_shape, value_range, data_value, additional_info)

    if self.save_data:
        # Saved without resampling, using the instance prefix as file postfix.
        saver = SaveImage(
            output_dir=self.save_data_dir,
            output_postfix=self.prefix,
            output_ext=".nii.gz",
            resample=False,
        )
        saver(img, meta_data)

    return img
def __init__(
    self,
    output_dir: str = "./",
    output_postfix: str = "seg",
    output_ext: str = ".nii.gz",
    resample: bool = True,
    mode: Union[GridSampleMode, InterpolateMode, str] = "nearest",
    padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
    scale: Optional[int] = None,
    dtype: DtypeLike = np.float64,
    output_dtype: DtypeLike = np.float32,
    squeeze_end_dims: bool = True,
    data_root_dir: str = "",
    batch_transform: Callable = lambda x: x,
    output_transform: Callable = lambda x: x,
    name: Optional[str] = None,
) -> None:
    """Set up the handler and the batch-saving ``SaveImage`` it delegates to.

    Args:
        output_dir: directory the output images are written to.
        output_postfix: string appended to every output file name, ``seg`` by default.
        output_ext: extension of the output files; available extensions are
            ``.nii.gz``, ``.nii`` and ``.png``.
        resample: whether the array is resampled before it is saved. PNG
            output is resampled based on ``spatial_shape`` from the metadata,
            NIfTI output based on ``original_affine`` from the metadata.
        mode: interpolation mode, only used when ``resample = True``;
            ``"nearest"`` by default.

            - NIfTI files: {``"bilinear"``, ``"nearest"``}, see
              https://pytorch.org/docs/stable/nn.functional.html#grid-sample
            - PNG files: {``"nearest"``, ``"linear"``, ``"bilinear"``,
              ``"bicubic"``, ``"trilinear"``, ``"area"``}, see
              https://pytorch.org/docs/stable/nn.functional.html#interpolate
        padding_mode: padding for values outside the grid, only used when
            ``resample = True``; ``"border"`` by default.

            - NIfTI files: {``"zeros"``, ``"border"``, ``"reflection"``}, see
              https://pytorch.org/docs/stable/nn.functional.html#grid-sample
            - PNG files: the option is ignored.
        scale: {``255``, ``65535``} clip the data to [0, 1] and scale it to
            [0, 255] (uint8) or [0, 65535] (uint16). ``None`` (the default)
            disables scaling. PNG format only.
        dtype: data type of the resampling computation, ``np.float64`` by
            default for best precision; ``None`` uses the data type of the
            input. NIfTI format only.
        output_dtype: data type the data is saved with, ``np.float32`` by
            default. NIfTI format only.
        squeeze_end_dims: when true, trailing singleton dimensions are removed
            after the channel has been moved to the end: (C,H,W,D) becomes
            (H,W,D,C), is saved as (H,W,D) if C==1 and as (H,W) if D==1 too.
            When false the image is always saved as (H,W,D,C). NIfTI format only.
        data_root_dir: if not empty, the leading part of the input file's
            absolute path. It is used to compute ``input_file_rel_path``, the
            path of the file relative to ``data_root_dir``, so the folder
            structure is preserved when files in different folders share a
            name. For example, with

            input_file_name: /foo/bar/test1/image.nii,
            output_postfix: seg
            output_ext: nii.gz
            output_dir: /output,
            data_root_dir: /foo/bar,

            the output is /output/test1/image/image_seg.nii.gz
        batch_transform: callable that turns ``ignite.engine.batch`` into the
            form the metadata dictionary is extracted from.
        output_transform: callable that turns ``ignite.engine.output`` into
            the expected image data. The first dimension of its result is
            treated as the batch dimension and every item is saved individually.
        name: identifier handed to ``logging.getLogger``. The original
            documentation describes the default as ``engine.logger``; that
            fallback is not performed in this constructor.
    """
    # Options forwarded verbatim to the underlying transform; batches are
    # always saved as a whole.
    saver_options = dict(
        output_dir=output_dir,
        output_postfix=output_postfix,
        output_ext=output_ext,
        resample=resample,
        mode=mode,
        padding_mode=padding_mode,
        scale=scale,
        dtype=dtype,
        output_dtype=output_dtype,
        squeeze_end_dims=squeeze_end_dims,
        data_root_dir=data_root_dir,
        save_batch=True,
    )
    self._saver = SaveImage(**saver_options)

    self.batch_transform, self.output_transform = batch_transform, output_transform
    self.logger = logging.getLogger(name)
    self._name = name
def __init__(
    self,
    output_dir: str = "./",
    output_postfix: str = "seg",
    output_ext: str = ".nii.gz",
    resample: bool = True,
    mode: Union[GridSampleMode, InterpolateMode, str] = "nearest",
    padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
    scale: Optional[int] = None,
    dtype: DtypeLike = np.float64,
    output_dtype: DtypeLike = np.float32,
    batch_transform: Callable = lambda x: x,
    output_transform: Callable = lambda x: x,
    name: Optional[str] = None,
) -> None:
    """Set up the handler and the batch-saving ``SaveImage`` it delegates to.

    Args:
        output_dir: directory the output images are written to.
        output_postfix: string appended to every output file name, ``seg`` by default.
        output_ext: extension of the output files; available extensions are
            ``.nii.gz``, ``.nii`` and ``.png``.
        resample: whether the array is resampled before it is saved. PNG
            output is resampled based on ``spatial_shape`` from the metadata,
            NIfTI output based on ``original_affine`` from the metadata.
        mode: interpolation mode, only used when ``resample = True``;
            ``"nearest"`` by default.

            - NIfTI files: {``"bilinear"``, ``"nearest"``}, see
              https://pytorch.org/docs/stable/nn.functional.html#grid-sample
            - PNG files: {``"nearest"``, ``"linear"``, ``"bilinear"``,
              ``"bicubic"``, ``"trilinear"``, ``"area"``}, see
              https://pytorch.org/docs/stable/nn.functional.html#interpolate
        padding_mode: padding for values outside the grid, only used when
            ``resample = True``; ``"border"`` by default.

            - NIfTI files: {``"zeros"``, ``"border"``, ``"reflection"``}, see
              https://pytorch.org/docs/stable/nn.functional.html#grid-sample
            - PNG files: the option is ignored.
        scale: {``255``, ``65535``} clip the data to [0, 1] and scale it to
            [0, 255] (uint8) or [0, 65535] (uint16). ``None`` (the default)
            disables scaling. PNG format only.
        dtype: data type of the resampling computation, ``np.float64`` by
            default for best precision; ``None`` uses the data type of the
            input. NIfTI format only.
        output_dtype: data type the data is saved with, ``np.float32`` by
            default. NIfTI format only.
        batch_transform: callable that turns ``ignite.engine.batch`` into the
            form the metadata dictionary is extracted from.
        output_transform: callable that turns ``ignite.engine.output`` into
            the expected image data. The first dimension of its result is
            treated as the batch dimension and every item is saved individually.
        name: identifier handed to ``logging.getLogger``. The original
            documentation describes the default as ``engine.logger``; that
            fallback is not performed in this constructor.
    """
    # Options forwarded verbatim to the underlying transform; batches are
    # always saved as a whole.
    saver_options = dict(
        output_dir=output_dir,
        output_postfix=output_postfix,
        output_ext=output_ext,
        resample=resample,
        mode=mode,
        padding_mode=padding_mode,
        scale=scale,
        dtype=dtype,
        output_dtype=output_dtype,
        save_batch=True,
    )
    self._saver = SaveImage(**saver_options)

    self.batch_transform, self.output_transform = batch_transform, output_transform
    self.logger = logging.getLogger(name)
    self._name = name
def main(tempdir):
    """Evaluate a 3D UNet on synthetic NIfTI volumes with an ignite ``Engine``.

    Five image/segmentation pairs are generated into ``tempdir``, the weights
    are restored from ``./runs_dict/net_checkpoint_50.pt`` by a
    ``CheckpointLoader`` handler, every prediction is saved and the final
    engine state is printed.

    Args:
        tempdir: directory that receives the generated ``im*.nii.gz`` and
            ``seg*.nii.gz`` files.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for index in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        nib.save(nib.Nifti1Image(im, np.eye(4)), os.path.join(tempdir, f"im{index:d}.nii.gz"))
        nib.save(nib.Nifti1Image(seg, np.eye(4)), os.path.join(tempdir, f"seg{index:d}.nii.gz"))

    image_paths = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    label_paths = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    val_files = []
    for image_path, label_path in zip(image_paths, label_paths):
        val_files.append({"img": image_path, "seg": label_path})

    # Pre-processing applied to the image and its segmentation.
    preprocessing = [
        LoadImaged(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys="img"),
        EnsureTyped(keys=["img", "seg"]),
    ]
    val_ds = monai.data.Dataset(data=val_files, transform=Compose(preprocessing))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    # Window size and number of windows evaluated per forward pass.
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    # NOTE(review): output_dir is the literal string "tempdir", not the
    # ``tempdir`` argument -- confirm whether that is intended.
    save_image = SaveImage(output_dir="tempdir", output_ext=".nii.gz", output_postfix="seg")

    def _sliding_window_processor(engine, batch):
        net.eval()
        with torch.no_grad():
            val_images = batch["img"].to(device)
            val_labels = batch["seg"].to(device)
            raw_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
            seg_probs = [post_trans(item) for item in decollate_batch(raw_probs)]
            meta_items = decollate_batch(batch["img_meta_dict"])
            for seg_prob, meta in zip(seg_probs, meta_items):
                save_image(seg_prob, meta)
            return seg_probs, val_labels

    evaluator = Engine(_sliding_window_processor)

    # Evaluation metric attached to the evaluator engine.
    MeanDice().attach(evaluator, "Mean_Dice")

    # StatsHandler prints the loss every iteration and the metrics every
    # epoch; the evaluator has no loss to report, so the per-iteration output
    # is disabled and only the metrics are printed.
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x: None,
    )
    val_stats_handler.attach(evaluator)

    # The model was trained by the "unet_training_dict" example.
    checkpoint_loader = CheckpointLoader(load_path="./runs_dict/net_checkpoint_50.pt", load_dict={"net": net})
    checkpoint_loader.attach(evaluator)

    # Sliding-window inference handles one image per iteration.
    val_loader = DataLoader(
        val_ds,
        batch_size=1,
        num_workers=4,
        collate_fn=list_data_collate,
        pin_memory=torch.cuda.is_available(),
    )
    state = evaluator.run(val_loader)
    print(state)
def main(tempdir):
    """Evaluate a pretrained 3D UNet on freshly generated synthetic NIfTI volumes.

    Five image/segmentation pairs are written to ``tempdir``, the checkpoint
    ``best_metric_model_segmentation3d_dict.pth`` is loaded from the current
    working directory, every prediction is saved under ``./output`` and the
    mean Dice metric is printed.

    Args:
        tempdir: directory that receives the generated ``im*.nii.gz`` and
            ``seg*.nii.gz`` files.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference need to input 1 image in every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    saver = SaveImage(output_dir="./output", output_ext=".nii.gz", output_postfix="seg")

    # try to use all the available GPUs
    devices = [torch.device("cuda" if torch.cuda.is_available() else "cpu")]
    # devices = get_devices_spec(None)
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(devices[0])

    # map_location remaps the stored tensors onto the selected device, so a
    # checkpoint saved on a GPU can still be loaded when only the CPU is available.
    state_dict = torch.load("best_metric_model_segmentation3d_dict.pth", map_location=devices[0])
    model.load_state_dict(state_dict)

    # if we have multiple GPUs, set data parallel to execute sliding window inference
    if len(devices) > 1:
        model = torch.nn.DataParallel(model, device_ids=devices)

    # define sliding window size and batch size for windows inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(devices[0]), val_data["seg"].to(devices[0])
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            val_labels = decollate_batch(val_labels)
            meta_data = decollate_batch(val_data["img_meta_dict"])
            # compute metric for current iteration
            dice_metric(y_pred=val_outputs, y=val_labels)
            for val_output, data in zip(val_outputs, meta_data):
                saver(val_output, data)
        # aggregate the final mean dice result
        print("evaluation metric:", dice_metric.aggregate().item())
        # reset the status
        dice_metric.reset()
def run_inference_test(root_dir, model_file, device="cuda:0", amp=False, num_workers=4):
    """Run a ``SupervisedEvaluator`` over the NIfTI pairs found in ``root_dir``.

    Predictions are saved twice into ``root_dir``: by the ``SaveImaged``
    post-processing transform (postfix ``seg_transform``) and by an
    ``ITERATION_COMPLETED`` handler (postfix ``seg_handler``).

    Args:
        root_dir: directory holding ``im*.nii.gz`` / ``seg*.nii.gz`` and
            receiving the saved predictions.
        model_file: checkpoint path handed to ``CheckpointLoader``.
        device: device the network is moved to and the evaluator runs on.
        amp: converted with ``bool`` and forwarded to the evaluator.
        num_workers: number of workers of the validation data loader.

    Returns:
        ``evaluator.state.best_metric`` after the run.
    """
    image_paths = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    label_paths = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = []
    for image_path, label_path in zip(image_paths, label_paths):
        val_files.append({"image": image_path, "label": label_path})

    # Pre-processing applied to the image and its label.
    preprocessing = [
        LoadImaged(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys=["image", "label"]),
    ]

    # Validation dataset and loader.
    val_ds = monai.data.Dataset(data=val_files, transform=Compose(preprocessing))
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=num_workers)

    # Network under evaluation.
    net = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    postprocessing_steps = [
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold=0.5),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        # Covers the case where `pred` is in `engine.state.output` while
        # `image_meta_dict` is in `engine.state.batch`.
        SaveImaged(keys="pred", output_dir=root_dir, output_postfix="seg_transform"),
    ]
    val_postprocessing = Compose(postprocessing_steps)

    val_handlers = [
        StatsHandler(iteration_log=False),
        CheckpointLoader(load_path=f"{model_file}", load_dict={"net": net}),
    ]

    saver = SaveImage(output_dir=root_dir, output_postfix="seg_handler")

    def save_func(engine):
        # Save every prediction of the current iteration.
        predictions = from_engine("pred")(engine.state.output)
        for prediction in predictions:
            saver(prediction)

    key_val_metric = {
        "val_mean_dice": MeanDice(include_background=True, output_transform=from_engine(["pred", "label"]))
    }
    additional_metrics = {"val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))}

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        postprocessing=val_postprocessing,
        key_val_metric=key_val_metric,
        additional_metrics=additional_metrics,
        val_handlers=val_handlers,
        amp=bool(amp),
    )
    evaluator.add_event_handler(Events.ITERATION_COMPLETED, save_func)
    evaluator.run()

    return evaluator.state.best_metric