def main(tempdir):
    """Evaluate a 2D UNet on synthetic data using dictionary-based transforms.

    Writes five synthetic image/segmentation PNG pairs into ``tempdir``, runs
    sliding-window inference with a pretrained checkpoint, reports the mean
    Dice metric, and saves thresholded predictions under ``./output``.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for idx in range(5):
        image_arr, seg_arr = create_test_image_2d(128, 128, num_seg_classes=1)
        Image.fromarray(image_arr.astype("uint8")).save(os.path.join(tempdir, f"img{idx:d}.png"))
        Image.fromarray(seg_arr.astype("uint8")).save(os.path.join(tempdir, f"seg{idx:d}.png"))

    image_paths = sorted(glob(os.path.join(tempdir, "img*.png")))
    seg_paths = sorted(glob(os.path.join(tempdir, "seg*.png")))
    val_files = [{"img": im_path, "seg": seg_path} for im_path, seg_path in zip(image_paths, seg_paths)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AddChanneld(keys=["img", "seg"]),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference need to input 1 image in every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    model.load_state_dict(torch.load("best_metric_model_segmentation2d_dict.pth"))

    model.eval()
    with torch.no_grad():
        metric_sum, metric_count = 0.0, 0
        saver = PNGSaver(output_dir="./output")
        for batch in val_loader:
            images = batch["img"].to(device)
            labels = batch["seg"].to(device)
            # window size and window-batch size for sliding-window inference
            roi_size, sw_batch_size = (96, 96), 4
            outputs = sliding_window_inference(images, roi_size, sw_batch_size, model)
            # metric is computed on raw logits (DiceMetric applies sigmoid itself)
            value = dice_metric(y_pred=outputs, y=labels)
            metric_count += len(value)
            metric_sum += value.item() * len(value)
            # threshold probabilities to a binary mask before saving
            saver.save_batch(outputs.sigmoid() >= 0.5)
        metric = metric_sum / metric_count
        print("evaluation metric:", metric)
def test_saved_content_three_channel(self):
    """Each item of an 8x3x2x2 batch is saved as <name>/<name>_seg.png when scale=255."""
    with tempfile.TemporaryDirectory() as out_dir:
        png_saver = PNGSaver(output_dir=out_dir, output_postfix="seg", output_ext=".png", scale=255)
        meta = {"filename_or_obj": [f"testfile{idx}.jpg" for idx in range(8)]}
        png_saver.save_batch(torch.randint(1, 200, (8, 3, 2, 2)), meta)
        for idx in range(8):
            expected = os.path.join(out_dir, f"testfile{idx}", f"testfile{idx}_seg.png")
            self.assertTrue(os.path.exists(expected))
def test_saved_content_three_channel(self):
    """Verify PNGSaver writes one three-channel PNG per batch item into a relative output dir.

    Fix: the original removed ``./tempdir`` only after the assertions, so a
    failing assertion leaked the directory; cleanup now runs in ``finally``.
    """
    default_dir = os.path.join(".", "tempdir")
    # remove any leftovers from a previous aborted run
    shutil.rmtree(default_dir, ignore_errors=True)
    try:
        saver = PNGSaver(output_dir=default_dir, output_postfix="seg", output_ext=".png")
        meta_data = {"filename_or_obj": ["testfile" + str(i) for i in range(8)]}
        saver.save_batch(torch.randint(1, 200, (8, 3, 2, 2)), meta_data)
        for i in range(8):
            filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.png")
            self.assertTrue(os.path.exists(os.path.join(default_dir, filepath)))
    finally:
        # always clean up, even when an assertion above fails
        shutil.rmtree(default_dir, ignore_errors=True)
def test_saved_specified_root(self):
    """With data_root_dir set, the root prefix is stripped from the output sub-path."""
    with tempfile.TemporaryDirectory() as out_dir:
        png_saver = PNGSaver(
            output_dir=out_dir, output_postfix="seg", output_ext=".png", scale=255, data_root_dir="test"
        )
        meta = {
            "filename_or_obj": [os.path.join("test", f"testfile{idx}", "image.jpg") for idx in range(8)]
        }
        png_saver.save_batch(torch.randint(1, 200, (8, 1, 2, 2)), meta)
        for idx in range(8):
            expected = os.path.join(out_dir, f"testfile{idx}", "image", "image_seg.png")
            self.assertTrue(os.path.exists(expected))
class SegmentationSaver:
    """
    Event handler triggered on completing every iteration to save the segmentation predictions into files.
    """

    def __init__(
        self,
        output_dir: str = "./",
        output_postfix: str = "seg",
        output_ext: str = ".nii.gz",
        resample: bool = True,
        mode: Union[GridSampleMode, InterpolateMode, str] = "nearest",
        padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
        scale: Optional[int] = None,
        dtype: Optional[np.dtype] = np.float64,
        output_dtype: Optional[np.dtype] = np.float32,
        batch_transform: Callable = lambda x: x,
        output_transform: Callable = lambda x: x,
        name: Optional[str] = None,
    ) -> None:
        """
        Args:
            output_dir: output image directory.
            output_postfix: a string appended to all output file names.
            output_ext: output file extension name.
            resample: whether to resample before saving the data array.
            mode: This option is used when ``resample = True``. Defaults to ``"nearest"``.

                - NIfTI files {``"bilinear"``, ``"nearest"``}
                    Interpolation mode to calculate output values.
                    See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
                - PNG files {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}
                    The interpolation mode.
                    See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate

            padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``.

                - NIfTI files {``"zeros"``, ``"border"``, ``"reflection"``}
                    Padding mode for outside grid values.
                    See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
                - PNG files
                    This option is ignored.

            scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling
                [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling.
                It's used for PNG format only.
            dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
                If None, use the data type of input data.
                To be compatible with other modules, the output data type is always ``np.float32``,
                it's used for Nifti format only.
            output_dtype: data type for saving data. Defaults to ``np.float32``, it's used
                for Nifti format only.
            batch_transform: a callable that is used to transform the
                ignite.engine.batch into expected format to extract the meta_data dictionary.
            output_transform: a callable that is used to transform the
                ignite.engine.output into the form expected image data.
                The first dimension of this transform's output will be treated as the batch dimension.
                Each item in the batch will be saved individually.
            name: identifier of logging.logger to use, defaulting to `engine.logger`.

        Raises:
            ValueError: when ``output_ext`` is none of [".nii.gz", ".nii", ".png"].
        """
        self.saver: Union[NiftiSaver, PNGSaver]
        if output_ext in (".nii.gz", ".nii"):
            self.saver = NiftiSaver(
                output_dir=output_dir,
                output_postfix=output_postfix,
                output_ext=output_ext,
                resample=resample,
                mode=GridSampleMode(mode),
                padding_mode=padding_mode,
                dtype=dtype,
                output_dtype=output_dtype,
            )
        elif output_ext == ".png":
            self.saver = PNGSaver(
                output_dir=output_dir,
                output_postfix=output_postfix,
                output_ext=output_ext,
                resample=resample,
                mode=InterpolateMode(mode),
                scale=scale,
            )
        else:
            # fail fast at construction time: previously an unsupported extension
            # left ``self.saver`` unset and only surfaced later as an AttributeError.
            raise ValueError(f"unsupported output extension: {output_ext}, available: ['.nii.gz', '.nii', '.png'].")
        self.batch_transform = batch_transform
        self.output_transform = output_transform
        self.logger = logging.getLogger(name)
        self._name = name

    def attach(self, engine: Engine) -> None:
        """
        Register this handler on the engine's ITERATION_COMPLETED event.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """
        if self._name is None:
            self.logger = engine.logger
        if not engine.has_event_handler(self, Events.ITERATION_COMPLETED):
            engine.add_event_handler(Events.ITERATION_COMPLETED, self)

    def __call__(self, engine: Engine) -> None:
        """
        This method assumes self.batch_transform will extract metadata from the input batch.
        Output file datatype is determined from ``engine.state.output.dtype``.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """
        meta_data = self.batch_transform(engine.state.batch)
        engine_output = self.output_transform(engine.state.output)
        self.saver.save_batch(engine_output, meta_data)
        self.logger.info("saved all the model outputs into files.")
class SegmentationSaver:
    """
    Event handler triggered on completing every iteration to save the segmentation predictions into files.
    """

    def __init__(
        self,
        output_dir: str = "./",
        output_postfix: str = "seg",
        output_ext: str = ".nii.gz",
        resample: bool = True,
        interp_order: str = "nearest",
        mode: str = "border",
        scale=None,
        dtype: Optional[np.dtype] = None,
        batch_transform: Callable = lambda x: x,
        output_transform: Callable = lambda x: x,
        name: Optional[str] = None,
    ):
        """
        Args:
            output_dir: output image directory.
            output_postfix: a string appended to all output file names.
            output_ext: output file extension name.
            resample: whether to resample before saving the data array.
            interp_order: The interpolation mode. Defaults to "nearest". This option is used when `resample = True`.
                When saving NIfTI files, the available options are "nearest", "bilinear"
                See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample.
                When saving PNG files, the available options are "nearest", "bilinear", "bicubic", "area".
                See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate.
            mode: The mode parameter determines how the input array is extended beyond its boundaries.
                This option is used when `resample = True`.
                When saving NIfTI files, the options are "zeros", "border", "reflection". Default is "border".
                When saving PNG files, the options is ignored.
            scale (255, 65535): postprocess data by clipping to [0, 1] and scaling
                [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling.
                It's used for PNG format only.
            dtype (np.dtype, optional): convert the image data to save to this data type.
                If None, keep the original type of data. It's used for Nifti format only.
            batch_transform: a callable that is used to transform the
                ignite.engine.batch into expected format to extract the meta_data dictionary.
            output_transform: a callable that is used to transform the
                ignite.engine.output into the form expected image data.
                The first dimension of this transform's output will be treated as the batch dimension.
                Each item in the batch will be saved individually.
            name: identifier of logging.logger to use, defaulting to `engine.logger`.

        Raises:
            ValueError: when ``output_ext`` is none of [".nii.gz", ".nii", ".png"].
        """
        # validate up front: previously an unsupported extension left ``self.saver``
        # unset and only surfaced later as an AttributeError at save time.
        if output_ext not in (".nii.gz", ".nii", ".png"):
            raise ValueError(f"unsupported output extension: {output_ext}, available: ['.nii.gz', '.nii', '.png'].")
        self.saver: Union[NiftiSaver, PNGSaver]
        if output_ext in (".nii.gz", ".nii"):
            self.saver = NiftiSaver(
                output_dir=output_dir,
                output_postfix=output_postfix,
                output_ext=output_ext,
                resample=resample,
                interp_order=interp_order,
                mode=mode,
                dtype=dtype,
            )
        else:  # output_ext == ".png"
            self.saver = PNGSaver(
                output_dir=output_dir,
                output_postfix=output_postfix,
                output_ext=output_ext,
                resample=resample,
                interp_order=interp_order,
                scale=scale,
            )
        self.batch_transform = batch_transform
        self.output_transform = output_transform
        # always hold a usable logger so __call__ works even before attach();
        # attach() still switches to engine.logger when no name was given.
        self.logger = logging.getLogger(name)
        self._name = name

    def attach(self, engine: Engine):
        """Register this handler on the engine's ITERATION_COMPLETED event."""
        if self._name is None:
            self.logger = engine.logger
        if not engine.has_event_handler(self, Events.ITERATION_COMPLETED):
            engine.add_event_handler(Events.ITERATION_COMPLETED, self)

    def __call__(self, engine):
        """
        This method assumes self.batch_transform will extract metadata from the input batch.
        Output file datatype is determined from ``engine.state.output.dtype``.
        """
        meta_data = self.batch_transform(engine.state.batch)
        engine_output = self.output_transform(engine.state.output)
        self.saver.save_batch(engine_output, meta_data)
        self.logger.info("saved all the model outputs into files.")
class SegmentationSaver:
    """
    Event handler triggered on completing every iteration to save the segmentation predictions into files.
    """

    def __init__(
        self,
        output_dir="./",
        output_postfix="seg",
        output_ext=".nii.gz",
        resample=True,
        interp_order=0,
        mode="constant",
        cval=0,
        scale=False,
        dtype=None,
        batch_transform=lambda x: x,
        output_transform=lambda x: x,
        name=None,
    ):
        """
        Args:
            output_dir (str): output image directory.
            output_postfix (str): a string appended to all output file names.
            output_ext (str): output file extension name.
            resample (bool): whether to resample before saving the data array.
            interp_order (int): the order of the spline interpolation, default is 0. The order has to be in the
                range 0 - 5. https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.affine_transform.html
                this option is used when `resample = True`.
            mode (`reflect|constant|nearest|mirror|wrap`): The mode parameter determines how the input array is
                extended beyond its boundaries. this option is used when `resample = True`.
            cval (scalar): Value to fill past edges of input if mode is "constant". Default is 0.0.
                this option is used when `resample = True`.
            scale (bool): whether to scale data with 255 and convert to uint8 for data in range [0, 1].
                it's used for PNG format only.
            dtype (np.dtype, optional): convert the image data to save to this data type.
                If None, keep the original type of data. it's used for Nifti format only.
            batch_transform (Callable): a callable that is used to transform the
                ignite.engine.batch into expected format to extract the meta_data dictionary.
            output_transform (Callable): a callable that is used to transform the
                ignite.engine.output into the form expected image data.
                The first dimension of this transform's output will be treated as the batch dimension.
                Each item in the batch will be saved individually.
            name (str): identifier of logging.logger to use, defaulting to `engine.logger`.

        Raises:
            ValueError: when ``output_ext`` is none of [".nii.gz", ".nii", ".png"].
        """
        if output_ext in (".nii.gz", ".nii"):
            self.saver = NiftiSaver(output_dir, output_postfix, output_ext, resample, interp_order, mode, cval, dtype)
        elif output_ext == ".png":
            self.saver = PNGSaver(output_dir, output_postfix, output_ext, resample, interp_order, mode, cval, scale)
        else:
            # fail fast: previously an unsupported extension left ``self.saver``
            # unset and only surfaced later as an AttributeError in __call__.
            raise ValueError(f"unsupported output extension: {output_ext}, available: ['.nii.gz', '.nii', '.png'].")
        self.batch_transform = batch_transform
        self.output_transform = output_transform
        self.logger = None if name is None else logging.getLogger(name)

    def attach(self, engine):
        """Register this handler on the engine's ITERATION_COMPLETED event."""
        # no explicit logger name given: fall back to the engine's logger
        if self.logger is None:
            self.logger = engine.logger
        if not engine.has_event_handler(self, Events.ITERATION_COMPLETED):
            engine.add_event_handler(Events.ITERATION_COMPLETED, self)

    def __call__(self, engine):
        """
        This method assumes self.batch_transform will extract metadata from the input batch.
        output file datatype is determined from ``engine.state.output.dtype``.
        """
        meta_data = self.batch_transform(engine.state.batch)
        engine_output = self.output_transform(engine.state.output)
        self.saver.save_batch(engine_output, meta_data)
        self.logger.info("saved all the model outputs into files.")
def main(tempdir):
    """Evaluate a 2D UNet on synthetic data using array-based transforms.

    Writes five synthetic image/segmentation PNG pairs into ``tempdir``, runs
    sliding-window inference with a pretrained checkpoint, reports the mean
    Dice metric, and saves post-processed predictions under ``./output``.
    """
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for idx in range(5):
        image_arr, seg_arr = create_test_image_2d(128, 128, num_seg_classes=1)
        Image.fromarray(image_arr.astype("uint8")).save(os.path.join(tempdir, f"img{idx:d}.png"))
        Image.fromarray(seg_arr.astype("uint8")).save(os.path.join(tempdir, f"seg{idx:d}.png"))

    image_files = sorted(glob(os.path.join(tempdir, "img*.png")))
    seg_files = sorted(glob(os.path.join(tempdir, "seg*.png")))

    # define transforms for image and segmentation
    img_transform = Compose([LoadImage(image_only=True), ScaleIntensity(), AddChannel(), ToTensor()])
    seg_transform = Compose([LoadImage(image_only=True), AddChannel(), ToTensor()])
    val_ds = ArrayDataset(image_files, img_transform, seg_files, seg_transform)
    # sliding window inference for one image at every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    model.load_state_dict(torch.load("best_metric_model_segmentation2d_array.pth"))

    model.eval()
    with torch.no_grad():
        metric_sum, metric_count = 0.0, 0
        saver = PNGSaver(output_dir="./output")
        for batch in val_loader:
            images = batch[0].to(device)
            labels = batch[1].to(device)
            # window size and window-batch size for sliding-window inference
            roi_size, sw_batch_size = (96, 96), 4
            preds = sliding_window_inference(images, roi_size, sw_batch_size, model)
            preds = post_trans(preds)
            value, _ = dice_metric(y_pred=preds, y=labels)
            metric_count += len(value)
            metric_sum += value.item() * len(value)
            saver.save_batch(preds)
        metric = metric_sum / metric_count
        print("evaluation metric:", metric)