Ejemplo n.º 1
0
    def test_squeeze_end_dims(self):
        """Check that ``squeeze_end_dims`` controls the rank of the saved image.

        A channel-first 2D image of shape ``(1, 2, 2)`` is written once with
        ``squeeze_end_dims=False`` and once with ``squeeze_end_dims=True``.
        The file is then read back and its number of dimensions is compared
        with the expected rank: 4 without squeezing, 2 with squeezing.
        """
        with tempfile.TemporaryDirectory() as tempdir:
            for squeeze_end_dims in [False, True]:
                # subTest reports which setting failed and still runs the other one.
                with self.subTest(squeeze_end_dims=squeeze_end_dims):
                    saver = NiftiSaver(
                        output_dir=tempdir,
                        output_postfix="",
                        output_ext=".nii.gz",
                        dtype=np.float32,
                        squeeze_end_dims=squeeze_end_dims,
                    )

                    fname = "testfile_squeeze"
                    meta_data = {"filename_or_obj": fname}

                    # 2d image w channel
                    saver.save(torch.randint(0, 255, (1, 2, 2)), meta_data)

                    im = LoadImage(image_only=True)(
                        os.path.join(tempdir, fname, fname + ".nii.gz"))

                    # Compare against an explicit expected rank. Writing
                    # `assertTrue(im.ndim == 2 if flag else 4)` would evaluate to
                    # `assertTrue(4)` for the False case and could never fail.
                    expected_ndim = 2 if squeeze_end_dims else 4
                    self.assertEqual(im.ndim, expected_ndim)
Ejemplo n.º 2
0
    def run_inference(self, model, data_loader):
        """Evaluate ``model`` on every batch of ``data_loader`` and save results.

        For each batch this runs sliding-window inference, computes a Dice
        score against ``data["label"]``, optionally writes the argmax
        segmentation to a NIfTI file, and saves a three-panel figure (image,
        label, prediction) of one slice. After the loop it saves a histogram
        of all Dice scores and logs their mean and standard deviation.

        Args:
            model: network used as the sliding-window predictor; it is put
                into evaluation mode before use.
            data_loader: sized iterable of dict batches providing "image" and
                "label" and, when exporting, "label_meta_dict".
                NOTE(review): the ``[0]`` indexing and the single score stored
                per batch suggest batch size 1 is expected -- confirm.

        Returns:
            None. Results go to the logger and to files on disk.
        """
        log = self.logger
        log.info('Running inference...')

        model.eval()  # switch the network to evaluation mode
        scores = np.zeros(len(data_loader))

        if self.model == "UNet2d5_spvPA":
            # For this architecture only element 0 of the model's output is
            # forwarded to the sliding-window inferer.
            def predict(*args, **kwargs):
                return model(*args, **kwargs)[0]
        else:
            predict = model

        with torch.no_grad():  # gradients are not needed during inference
            for idx, batch in enumerate(data_loader):
                log.info(f"starting image {idx}")

                net_out = sliding_window_inference(
                    inputs=batch["image"].to(self.device),
                    roi_size=self.sliding_window_inferer_roi_size,
                    sw_batch_size=1,
                    predictor=predict,
                    mode="gaussian",
                )

                target = batch["label"].to(self.device)
                score = self.compute_dice_score(net_out, target)
                score_value = score.item()
                scores[idx] = score_value

                log.info(f"dice_score = {score_value}")

                # Optional NIfTI export of the predicted class map.
                if self.export_inferred_segmentations:
                    log.info("export to nifti...")

                    # Drop all singleton axes, then re-add one leading axis.
                    class_map = torch.argmax(net_out, dim=1, keepdim=True)
                    seg_volume = np.squeeze(class_map)[None, :]

                    # The meta dict is updated in place: the first filename is
                    # unwrapped and singleton axes are removed from the affines.
                    meta = batch['label_meta_dict']
                    label_path = meta['filename_or_obj'][0]
                    meta['filename_or_obj'] = label_path
                    meta['affine'] = np.squeeze(meta['affine'])
                    meta['original_affine'] = np.squeeze(
                        meta['original_affine'])

                    # Output goes into a sub-folder named after the label
                    # file's parent directory.
                    case_dir = os.path.basename(os.path.dirname(label_path))
                    export_dir = os.path.join(self.results_folder_path,
                                              'inferred_segmentations_nifti',
                                              case_dir)
                    writer = NiftiSaver(output_dir=export_dir,
                                        output_postfix='')
                    writer.save(seg_volume, meta_data=meta)

                # Figure: image, label and prediction at the slice returned by
                # get_center_of_mass_slice for the first item's label.
                label_volume = torch.squeeze(batch["label"][0, 0, :, :, :])
                centre_slice = self.get_center_of_mass_slice(label_volume)

                plt.figure("check", (18, 6))
                plt.clf()  # the same named figure is reused for every batch

                plt.subplot(1, 3, 1)
                plt.title(f"image {idx}, slice = " + str(centre_slice))
                plt.imshow(batch["image"][0, 0, :, :, centre_slice],
                           cmap="gray",
                           interpolation="none")

                plt.subplot(1, 3, 2)
                plt.title(f"label {idx}")
                plt.imshow(batch["label"][0, 0, :, :, centre_slice],
                           interpolation="none")

                plt.subplot(1, 3, 3)
                plt.title(f"output {idx}, dice = {score_value:.4}")
                predicted = torch.argmax(net_out, dim=1).detach().cpu()
                plt.imshow(predicted[0, :, :, centre_slice],
                           interpolation="none")

                figure_name = f"best_model_output_val{idx}.png"
                plt.savefig(os.path.join(self.figures_path, figure_name))

        # Summary over all batches: histogram with 0.01-wide bins on [0, 1].
        plt.figure("dice score histogram")
        plt.hist(scores, bins=np.arange(0, 1.01, 0.01))
        histogram_path = os.path.join(
            self.figures_path, "best_model_output_dice_score_histogram.png")
        plt.savefig(histogram_path)

        log.info(f"all_dice_scores = {scores}")
        log.info(f"mean_dice_score = {scores.mean()} +- {scores.std()}")