Beispiel #1
0
    def __init__(
        self,
        output_dir: str = "./",
        output_postfix: str = "seg",
        output_ext: str = ".nii.gz",
        resample: bool = True,
        mode: Union[GridSampleMode, InterpolateMode, str] = "nearest",
        padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
        scale=None,
        dtype: Optional[np.dtype] = None,
        batch_transform: Callable = lambda x: x,
        output_transform: Callable = lambda x: x,
        name: Optional[str] = None,
    ):
        """
        Select and configure the file saver that matches ``output_ext`` and store the
        batch/output transforms and the optional logger.

        Args:
            output_dir: output image directory.
            output_postfix: a string appended to all output file names.
            output_ext: output file extension name. Supported values are ``".nii.gz"``,
                ``".nii"`` (a ``NiftiSaver`` is created) and ``".png"`` (a ``PNGSaver`` is created).
            resample: whether to resample before saving the data array.
            mode: This option is used when ``resample = True``. Defaults to ``"nearest"``.

                - NIfTI files {``"bilinear"``, ``"nearest"``}
                    Interpolation mode to calculate output values.
                    See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
                - PNG files {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}
                    The interpolation mode.
                    See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate

            padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``.

                - NIfTI files {``"zeros"``, ``"border"``, ``"reflection"``}
                    Padding mode for outside grid values.
                    See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
                - PNG files
                    This option is ignored (it is not forwarded to ``PNGSaver``).

            scale (255, 65535): postprocess data by clipping to [0, 1] and scaling to
                [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling.
                It's used for PNG format only.
            dtype (np.dtype, optional): convert the image data to save to this data type.
                If None, keep the original type of data. It's used for Nifti format only.
            batch_transform: a callable that is used to transform the
                ignite.engine.batch into expected format to extract the meta_data dictionary.
            output_transform: a callable that is used to transform the
                ignite.engine.output into the form expected image data.
                The first dimension of this transform's output will be treated as the
                batch dimension. Each item in the batch will be saved individually.
            name: identifier of logging.logger to use. When None, ``self.logger`` is left
                as None here (presumably replaced by `engine.logger` when the handler is
                attached; that code is outside this method).

        Raises:
            ValueError: When ``output_ext`` is not one of the supported extensions.

        """
        self.saver: Union[NiftiSaver, PNGSaver]
        if output_ext in (".nii.gz", ".nii"):
            self.saver = NiftiSaver(
                output_dir=output_dir,
                output_postfix=output_postfix,
                output_ext=output_ext,
                resample=resample,
                mode=mode,
                padding_mode=padding_mode,
                dtype=dtype,
            )
        elif output_ext == ".png":
            self.saver = PNGSaver(
                output_dir=output_dir,
                output_postfix=output_postfix,
                output_ext=output_ext,
                resample=resample,
                mode=mode,
                scale=scale,
            )
        else:
            # Fail fast: without this branch ``self.saver`` would never be assigned and the
            # problem would only surface later as an AttributeError far from its cause.
            raise ValueError(
                f"Unsupported output_ext: {output_ext!r}, available options are "
                "'.nii.gz', '.nii' and '.png'."
            )
        self.batch_transform = batch_transform
        self.output_transform = output_transform

        self.logger = None if name is None else logging.getLogger(name)
        self._name = name
    def __init__(
        self,
        output_dir: str = "./",
        output_postfix: str = "seg",
        output_ext: str = ".nii.gz",
        resample: bool = True,
        interp_order: str = "nearest",
        mode: str = "border",
        scale=None,
        dtype: Optional[np.dtype] = None,
        batch_transform: Callable = lambda x: x,
        output_transform: Callable = lambda x: x,
        name: Optional[str] = None,
    ):
        """
        Select and configure the file saver that matches ``output_ext`` and store the
        batch/output transforms and the optional logger.

        Args:
            output_dir: output image directory.
            output_postfix: a string appended to all output file names.
            output_ext: output file extension name. Supported values are ".nii.gz",
                ".nii" (a ``NiftiSaver`` is created) and ".png" (a ``PNGSaver`` is created).
            resample: whether to resample before saving the data array.
            interp_order:
                The interpolation mode. Defaults to "nearest". This option is used when `resample = True`.
                When saving NIfTI files, the available options are "nearest", "bilinear".
                See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample.
                When saving PNG files, the available options are "nearest", "bilinear", "bicubic", "area".
                See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate.
            mode: The mode parameter determines how the input array is extended beyond its boundaries.
                This option is used when `resample = True`.
                When saving NIfTI files, the options are "zeros", "border", "reflection". Default is "border".
                When saving PNG files, this option is ignored (it is not forwarded to ``PNGSaver``).
            scale (255, 65535): postprocess data by clipping to [0, 1] and scaling to
                [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling.
                It's used for PNG format only.
            dtype (np.dtype, optional): convert the image data to save to this data type.
                If None, keep the original type of data. It's used for Nifti format only.
            batch_transform: a callable that is used to transform the
                ignite.engine.batch into expected format to extract the meta_data dictionary.
            output_transform: a callable that is used to transform the
                ignite.engine.output into the form expected image data.
                The first dimension of this transform's output will be treated as the
                batch dimension. Each item in the batch will be saved individually.
            name: identifier of logging.logger to use. When None, ``self.logger`` is left
                as None here (presumably replaced by `engine.logger` when the handler is
                attached; that code is outside this method).

        Raises:
            ValueError: When ``output_ext`` is not one of the supported extensions.

        """
        self.saver: Union[NiftiSaver, PNGSaver]
        if output_ext in (".nii.gz", ".nii"):
            self.saver = NiftiSaver(
                output_dir=output_dir,
                output_postfix=output_postfix,
                output_ext=output_ext,
                resample=resample,
                interp_order=interp_order,
                mode=mode,
                dtype=dtype,
            )
        elif output_ext == ".png":
            self.saver = PNGSaver(
                output_dir=output_dir,
                output_postfix=output_postfix,
                output_ext=output_ext,
                resample=resample,
                interp_order=interp_order,
                scale=scale,
            )
        else:
            # Fail fast: without this branch ``self.saver`` would never be assigned and the
            # problem would only surface later as an AttributeError far from its cause.
            raise ValueError(
                f"Unsupported output_ext: {output_ext!r}, available options are "
                "'.nii.gz', '.nii' and '.png'."
            )
        self.batch_transform = batch_transform
        self.output_transform = output_transform

        self.logger = None if name is None else logging.getLogger(name)
        self._name = name
Beispiel #3
0
def main():
    """Evaluate a pre-trained 3D UNet on synthetic NIfTI volumes and save the predictions.

    Steps performed:

    1. Write five synthetic image/segmentation pairs (from ``create_test_image_3d``)
       as ``im{i}.nii.gz`` / ``seg{i}.nii.gz`` into a fresh temporary directory.
    2. Load them through ``NiftiDataset`` with a batch size of 1.
    3. Load weights from ``best_metric_model.pth`` (relative to the working directory)
       and run sliding-window inference with a 96x96x96 window.
    4. Accumulate the mean Dice value and save the thresholded (sigmoid >= 0.5)
       outputs through ``NiftiSaver`` into ``./output``.

    The temporary directory is removed even when one of the steps raises.
    The model runs on ``cuda:0`` when CUDA is available, otherwise on the CPU.
    """
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    tempdir = tempfile.mkdtemp()
    try:
        print(f"generating synthetic data to {tempdir} (this may take a while)")
        for i in range(5):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

        images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
        segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

        # define transforms for image and segmentation
        imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
        segtrans = Compose([AddChannel(), ToTensor()])
        val_ds = NiftiDataset(
            images,
            segs,
            transform=imtrans,
            seg_transform=segtrans,
            image_only=False,
        )
        # sliding window inference for one image at every iteration
        val_loader = DataLoader(
            val_ds,
            batch_size=1,
            num_workers=1,
            pin_memory=torch.cuda.is_available(),
        )

        # Use the GPU only when one is present, consistent with the pin_memory check above.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)

        # map_location lets a checkpoint saved from a GPU be restored on the selected device.
        model.load_state_dict(
            torch.load("best_metric_model.pth", map_location=device))
        model.eval()
        with torch.no_grad():
            metric_sum = 0.0
            metric_count = 0
            saver = NiftiSaver(output_dir="./output")
            for val_data in val_loader:
                val_images = val_data[0].to(device)
                val_labels = val_data[1].to(device)
                # define sliding window size and batch size for windows inference
                roi_size = (96, 96, 96)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(
                    val_images, roi_size, sw_batch_size, model)
                value = compute_meandice(
                    y_pred=val_outputs,
                    y=val_labels,
                    include_background=True,
                    to_onehot_y=False,
                    sigmoid=True,
                )
                metric_count += len(value)
                metric_sum += value.sum().item()
                # binarise the prediction before writing it to disk
                val_outputs = (val_outputs.sigmoid() >= 0.5).float()
                # the third element of the batch is passed as the saver's meta data
                saver.save_batch(val_outputs, val_data[2])
            metric = metric_sum / metric_count
            print("evaluation metric:", metric)
    finally:
        # Always remove the synthetic data, including when inference fails part-way.
        shutil.rmtree(tempdir)
Beispiel #4
0
def main():
    """
    Read input and configuration parameters, run UNet inference and save NIfTI outputs.

    The YAML file given by ``--config`` must provide ``device.cuda_device``,
    ``device.num_workers``, ``inference.batch_size_inference``,
    ``inference.probability_threshold``, ``inference.model_to_load``,
    ``data.data_root``, ``data.inference_list`` and ``output.out_dir``.

    For every input the first two spatial dimensions are resized to 96x96, the
    network is applied through sliding-window inference, the thresholded output is
    resized back to the input size, post-processed (binary closing followed by
    ``get_largest_component``) and written with ``NiftiSaver``.

    Raises:
        ValueError: if the configured inference batch size is not 1.
        IOError: if the configured model file does not exist.
    """
    parser = argparse.ArgumentParser(
        description='Run inference with basic UNet with MONAI.')
    # The config file is mandatory: without it there is nothing to open below.
    parser.add_argument('--config',
                        dest='config',
                        metavar='config',
                        type=str,
                        required=True,
                        help='config file')
    args = parser.parse_args()

    # NOTE(review): yaml.FullLoader is not meant for untrusted input; if config files
    # can come from outside the project, consider switching to yaml.safe_load.
    with open(args.config) as f:
        config_info = yaml.load(f, Loader=yaml.FullLoader)

    # print to log the parameter setups
    print(yaml.dump(config_info))

    # GPU params
    cuda_device = config_info['device']['cuda_device']
    num_workers = config_info['device']['num_workers']
    # inference params
    batch_size_inference = config_info['inference']['batch_size_inference']
    # temporary check as sliding window inference does not accept higher batch size;
    # raised explicitly so the check is not stripped when Python runs with -O
    if batch_size_inference != 1:
        raise ValueError(
            f'batch_size_inference must be 1, got {batch_size_inference!r}')
    prob_thr = config_info['inference']['probability_threshold']
    model_to_load = config_info['inference']['model_to_load']
    if not os.path.exists(model_to_load):
        raise IOError('Trained model not found')
    # data params
    data_root = config_info['data']['data_root']
    inference_list = config_info['data']['inference_list']
    # output saving
    out_dir = config_info['output']['out_dir']

    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    torch.cuda.set_device(cuda_device)

    # ---- Data preparation ----
    val_files = create_data_list(data_folder_list=data_root,
                                 subject_list=inference_list,
                                 img_postfix='_Image',
                                 is_inference=True)

    print(len(val_files))
    print(val_files[0])
    print(val_files[-1])

    # data preprocessing for inference:
    # - convert data to right format [batch, channel, dim, dim, dim]
    # - apply whitening
    # - NOTE: resizing needs to be applied afterwards, otherwise it cannot be remapped back to original size
    val_transforms = Compose([
        LoadNiftid(keys=['img']),
        AddChanneld(keys=['img']),
        NormalizeIntensityd(keys=['img']),
        ToTensord(keys=['img'])
    ])
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=batch_size_inference,
                            num_workers=num_workers)

    # ---- Network preparation ----
    # current_device() is the index selected by set_device() above
    device = torch.cuda.current_device()
    # NOTE(review): the network is built with dimensions=2 while the sliding window
    # below is (96, 96, 1) - confirm the predictor handles the trailing singleton axis.
    net = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    net.load_state_dict(torch.load(model_to_load))
    net.eval()

    # ---- Run inference ----
    with torch.no_grad():
        saver = NiftiSaver(output_dir=out_dir)
        for val_data in val_loader:
            val_images = val_data['img'].to(device)
            orig_size = list(val_images.shape)
            # shallow copy is enough: the list only holds ints
            resized_size = list(orig_size)
            resized_size[2] = 96
            resized_size[3] = 96
            val_images_resize = torch.nn.functional.interpolate(
                val_images, size=resized_size[2:], mode='trilinear')
            # define sliding window size and batch size for windows inference
            roi_size = (96, 96, 1)
            val_outputs = sliding_window_inference(val_images_resize, roi_size,
                                                   batch_size_inference, net)
            val_outputs = (val_outputs.sigmoid() >= prob_thr).float()
            # nearest-neighbour keeps the mask binary when mapping back to the input size
            val_outputs_resized = torch.nn.functional.interpolate(
                val_outputs, size=orig_size[2:], mode='nearest')
            # add post-processing
            val_outputs_resized = val_outputs_resized.detach().cpu().numpy()
            strt = ndimage.generate_binary_structure(3, 2)
            post = padded_binary_closing(np.squeeze(val_outputs_resized), strt)
            post = get_largest_component(post)
            # multiply the raw mask by the post-processed mask (broadcast over the leading axes)
            val_outputs_resized = val_outputs_resized * post

            saver.save_batch(
                val_outputs_resized, {
                    'filename_or_obj': val_data['img.filename_or_obj'],
                    'affine': val_data['img.affine']
                })
Beispiel #5
0
# Evaluation script fragment: builds a 3D UNet, restores weights from
# 'best_metric_model.pth' and accumulates the mean Dice value over `val_loader`
# (`val_loader` is defined outside this fragment).
device = torch.device('cuda:0')
model = UNet(
    dimensions=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)

# Restore the trained weights and switch the network to evaluation mode.
model.load_state_dict(torch.load('best_metric_model.pth'))
model.eval()
# Gradients are not needed for evaluation.
with torch.no_grad():
    # Running sum of the Dice values and the number of values accumulated.
    metric_sum = 0.
    metric_count = 0
    # Saver targeting './output'; it is created here but not used within this fragment.
    saver = NiftiSaver(output_dir='./output')
    for val_data in val_loader:
        # Each batch is indexed by the 'img' and 'seg' keys; both are moved to `device`.
        val_images, val_labels = val_data['img'].to(
            device), val_data['seg'].to(device)
        # define sliding window size and batch size for windows inference
        roi_size = (96, 96, 96)
        sw_batch_size = 4
        val_outputs = sliding_window_inference(val_images, roi_size,
                                               sw_batch_size, model)
        # The raw network outputs are passed in; add_sigmoid=True asks the metric to
        # apply the sigmoid itself (presumably before comparing with the labels).
        value = compute_meandice(y_pred=val_outputs,
                                 y=val_labels,
                                 include_background=True,
                                 to_onehot_y=False,
                                 add_sigmoid=True)
        metric_count += len(value)
        metric_sum += value.sum().item()
Beispiel #6
0
    def run_inference(self, model, data_loader):
        """Run sliding-window inference over ``data_loader`` and report Dice scores.

        For each item the method computes a Dice score via ``self.compute_dice_score``,
        optionally exports the argmax segmentation as NIfTI (when
        ``self.export_inferred_segmentations`` is truthy) and writes a three-panel
        figure (image, label, prediction) to ``self.figures_path``. After the loop a
        histogram of all scores is saved and the scores, their mean and standard
        deviation are logged.

        Args:
            model: network to evaluate; it is switched to evaluation mode. When
                ``self.model == "UNet2d5_spvPA"`` only element 0 of its return value
                is used as the segmentation output.
            data_loader: sized iterable of dict-like items with ``"image"`` and
                ``"label"`` entries (indexed here as 5-D arrays) and, when exporting,
                a ``"label_meta_dict"`` entry. One score is stored per item.

        Note:
            When exporting, the ``"label_meta_dict"`` of each item is modified in place.
        """
        log = self.logger
        log.info('Running inference...')

        model.eval()  # activate evaluation mode of model
        dice_scores = np.zeros(len(data_loader))

        if self.model == "UNet2d5_spvPA":
            # keep only the first element of what the network returns
            def model_segmentation(*args, **kwargs):
                return model(*args, **kwargs)[0]
        else:
            model_segmentation = model

        # turns off PyTorch's auto grad for better performance
        with torch.no_grad():
            for i, data in enumerate(data_loader):
                log.info("starting image {}".format(i))

                outputs = sliding_window_inference(
                    inputs=data["image"].to(self.device),
                    roi_size=self.sliding_window_inferer_roi_size,
                    sw_batch_size=1,
                    predictor=model_segmentation,
                    mode="gaussian",
                )

                target = data["label"].to(self.device)
                dice_score = self.compute_dice_score(outputs, target)
                dice_scores[i] = dice_score.item()

                log.info(f"dice_score = {dice_score.item()}")

                # export to nifti
                if self.export_inferred_segmentations:
                    log.info("export to nifti...")

                    class_map = torch.argmax(outputs, dim=1, keepdim=True)
                    nifti_data_matrix = np.squeeze(class_map)[None, :]

                    # strip the batch wrapping from the meta data (done in place)
                    meta = data['label_meta_dict']
                    meta['filename_or_obj'] = meta['filename_or_obj'][0]
                    meta['affine'] = np.squeeze(meta['affine'])
                    meta['original_affine'] = np.squeeze(
                        meta['original_affine'])

                    # results are grouped by the name of the label file's parent folder
                    folder_name = os.path.basename(
                        os.path.dirname(meta['filename_or_obj']))
                    export_dir = os.path.join(self.results_folder_path,
                                              'inferred_segmentations_nifti',
                                              folder_name)
                    saver = NiftiSaver(output_dir=export_dir,
                                       output_postfix='')
                    saver.save(nifti_data_matrix, meta_data=meta)

                # plot centre of mass slice of label
                label = torch.squeeze(data["label"][0, 0, :, :, :])
                # choose slice of selected validation set image volume for the figure
                slice_idx = self.get_center_of_mass_slice(label)

                plt.figure("check", (18, 6))
                plt.clf()

                plt.subplot(1, 3, 1)
                plt.title("image " + str(i) + ", slice = " + str(slice_idx))
                image_slice = data["image"][0, 0, :, :, slice_idx]
                plt.imshow(image_slice, cmap="gray", interpolation="none")

                plt.subplot(1, 3, 2)
                plt.title("label " + str(i))
                label_slice = data["label"][0, 0, :, :, slice_idx]
                plt.imshow(label_slice, interpolation="none")

                plt.subplot(1, 3, 3)
                plt.title("output " + str(i) +
                          f", dice = {dice_score.item():.4}")
                prediction = torch.argmax(outputs, dim=1).detach().cpu()
                plt.imshow(prediction[0, :, :, slice_idx],
                           interpolation="none")

                figure_file = os.path.join(
                    self.figures_path,
                    "best_model_output_val" + str(i) + ".png")
                plt.savefig(figure_file)

        # histogram of all scores with a bin width of 0.01
        plt.figure("dice score histogram")
        plt.hist(dice_scores, bins=np.arange(0, 1.01, 0.01))
        histogram_file = os.path.join(
            self.figures_path, "best_model_output_dice_score_histogram.png")
        plt.savefig(histogram_file)

        log.info(f"all_dice_scores = {dice_scores}")
        log.info(
            f"mean_dice_score = {dice_scores.mean()} +- {dice_scores.std()}")
def main():
    """Evaluate a pre-trained 3D UNet on synthetic NIfTI volumes using dictionary transforms.

    Steps performed:

    1. Write five synthetic image/segmentation pairs (from ``create_test_image_3d``,
       channel-last) as ``im{i}.nii.gz`` / ``seg{i}.nii.gz`` into a fresh temporary directory.
    2. Load them through a ``monai.data.Dataset`` of ``{"img", "seg"}`` dictionaries
       with a batch size of 1.
    3. Load weights from ``best_metric_model.pth`` (relative to the working directory);
       when ``get_devices_spec(None)`` returns more than one device the model is wrapped
       in ``torch.nn.DataParallel``.
    4. Run sliding-window inference with a 96x96x96 window, accumulate the Dice metric and
       save the thresholded (sigmoid >= 0.5) outputs through ``NiftiSaver`` into ``./output``.

    The temporary directory is removed even when one of the steps raises.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    tempdir = tempfile.mkdtemp()
    try:
        print(f"generating synthetic data to {tempdir} (this may take a while)")
        for i in range(5):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

        images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
        segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
        val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

        # define transforms for image and segmentation
        val_transforms = Compose(
            [
                LoadNiftid(keys=["img", "seg"]),
                AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
                ScaleIntensityd(keys=["img", "seg"]),
                ToTensord(keys=["img", "seg"]),
            ]
        )
        val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
        # sliding window inference need to input 1 image in every iteration
        val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
        dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")

        # try to use all the available GPUs
        devices = get_devices_spec(None)
        model = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(devices[0])

        model.load_state_dict(torch.load("best_metric_model.pth"))

        # if we have multiple GPUs, set data parallel to execute sliding window inference
        if len(devices) > 1:
            model = torch.nn.DataParallel(model, device_ids=devices)

        model.eval()
        with torch.no_grad():
            metric_sum = 0.0
            metric_count = 0
            saver = NiftiSaver(output_dir="./output")
            for val_data in val_loader:
                val_images, val_labels = val_data["img"].to(devices[0]), val_data["seg"].to(devices[0])
                # define sliding window size and batch size for windows inference
                roi_size = (96, 96, 96)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                value = dice_metric(y_pred=val_outputs, y=val_labels)
                # weight the reduced value by the number of entries it stands for
                metric_count += len(value)
                metric_sum += value.item() * len(value)
                # binarise the prediction before writing it to disk
                val_outputs = (val_outputs.sigmoid() >= 0.5).float()
                saver.save_batch(val_outputs, val_data["img_meta_dict"])
            metric = metric_sum / metric_count
            print("evaluation metric:", metric)
    finally:
        # Always remove the synthetic data, including when inference fails part-way.
        shutil.rmtree(tempdir)
Beispiel #8
0
        device)  #2nd argument: training_batch_size or 16
    with torch.no_grad():
        logits = forward(model, inputs)
    labels = logits.argmax(dim=CHANNELS_DIMENSION, keepdim=True)
    batch_mri = inputs
    batch_label = labels
    #slices = torch.cat((batch_mri, batch_label))
    #slices = torch.cat((batch_mri, batch_label),dim=1)
    #inf_path = 'inference.png'
    #save_image(slices, inf_path, nrow=training_batch_size//2, normalize=True, scale_each=True, padding=0)
    #display.Image(inf_path)

    #saver = NiftiSaver(output_dir="./niftinferece",output_postfix = str(i))
    #saver.save_batch(slices)

    saver1 = NiftiSaver(output_dir="./inputsnifti", output_postfix=str(i))
    saver2 = NiftiSaver(output_dir="./labelsnifti", output_postfix=str(i))
    saver1.save_batch(inputs)
    saver2.save_batch(labels)

    #Dice score for inference slide
    dice_score.append(
        get_dice_score(F.softmax(logits, dim=CHANNELS_DIMENSION), targets))
    #Dice loss for inference slide
    dice_losses.append(
        get_dice_loss(F.softmax(logits, dim=CHANNELS_DIMENSION), targets))

## experimental to output all dice scores -- WORKS!
dataset_iter = iter(validation_loader2)
for i in range(5):
    try:
Beispiel #9
0
def run_inference_test(root_dir, device="cuda:0"):
    """Run sliding-window inference on the volumes in ``root_dir`` and return the mean Dice.

    Args:
        root_dir: directory holding ``im*.nii.gz`` / ``seg*.nii.gz`` pairs (paired by
            sorted order) and the ``best_metric_model.pth`` weights; predictions are
            saved into its ``output`` sub-directory.
        device: device the model and the data are moved to.

    Returns:
        The value of ``dice_metric.aggregate().item()`` after all items were processed.
    """
    image_paths = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    label_paths = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    samples = [
        {"img": image_path, "seg": label_path}
        for image_path, label_path in zip(image_paths, label_paths)
    ]

    # define transforms for image and segmentation
    both = ["img", "seg"]
    preprocessing = Compose(
        [
            LoadImaged(keys=both),
            EnsureChannelFirstd(keys=both),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(
                keys=both,
                pixdim=[1.2, 0.8, 0.7],
                mode=["bilinear", "nearest"],
                dtype=np.float32,
            ),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=both),
        ]
    )
    dataset = monai.data.Dataset(data=samples, transform=preprocessing)
    # sliding window inference need to input 1 image in every iteration
    loader = monai.data.DataLoader(dataset, batch_size=1, num_workers=4)

    # post-processing applied to every single prediction: sigmoid followed by thresholding
    postprocessing = Compose(
        [ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)]
    )
    dice_metric = DiceMetric(
        include_background=True, reduction="mean", get_not_nans=False
    )

    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    weights_path = os.path.join(root_dir, "best_metric_model.pth")
    model.load_state_dict(torch.load(weights_path))

    # define sliding window size and batch size for windows inference
    sw_batch_size = 4
    roi_size = (96, 96, 96)

    with eval_mode(model):
        # resampling with align_corners=True or dtype=float64 will generate
        # slightly different results between PyTorch 1.5 and 1.6
        saver = NiftiSaver(
            output_dir=os.path.join(root_dir, "output"), dtype=np.float32
        )
        for batch in loader:
            inputs = batch["img"].to(device)
            targets = batch["seg"].to(device)
            raw_outputs = sliding_window_inference(
                inputs, roi_size, sw_batch_size, model
            )
            # decollate prediction into a list and execute post processing for every item
            predictions = [
                postprocessing(item) for item in decollate_batch(raw_outputs)
            ]
            # compute metrics
            dice_metric(y_pred=predictions, y=targets)
            saver.save_batch(predictions, batch["img_meta_dict"])

    return dice_metric.aggregate().item()