def save_segmentation_examples(self, nr_cubes=3, inference_full_image=True):
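        """Save qualitative segmentation examples for a few sampled cubes.

        For `nr_cubes` cubes used in testing and `nr_cubes` used in training, run
        inference (full volume if it fits in memory, otherwise patch-wise), then save
        the logits mask and binary mask as .nii.gz files, per-slice overlay figures,
        and a small JSON with Dice and Jaccard scores.
        """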

        # handle recursion when we default to patching (see self.tried and the OOM fallback below)

        if "lidc" in self.dataset_name:
            return

        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()
        dump_tensors()

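        # re-create the trainer so the previously loaded model releases its GPU memory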
        if hasattr(self.trainer, "model"):
            del self.trainer.model
            del self.trainer
            sleep(15)
            self.trainer = Trainer(config=self.config, dataset=None)
            dump_tensors()
            torch.cuda.ipc_collect()
            torch.cuda.empty_cache()
            dump_tensors()
            sleep(5)
        if inference_full_image is False:
            print("PATCHING Will be Done")

        dump_tensors()
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()
        dump_tensors()

        self.trainer.load_model(from_path=True, path=self.model_path, phase="sup", ensure_sup_is_completed=True)

        cubes_to_use = []
        cubes_to_use.extend(self.sample_k_full_cubes_which_were_used_for_testing(nr_cubes))
        cubes_to_use.extend(self.sample_k_full_cubes_which_were_used_for_training(nr_cubes))

        cubes_to_use_path = [os.path.join(self.dataset_dir, i) for i in cubes_to_use]
        label_cubes_of_cubes_to_use_path = [os.path.join(self.dataset_labels_dir, i) for i in cubes_to_use]

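        # run inference cube by cube: on the full volume where it fits in GPU memory, patch-wise otherwise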
        for cube_idx, cube_path in enumerate(cubes_to_use_path):
            np_array = self._load_cube_to_np_array(cube_path)  # (x,y,z)
            self.original_cube_dimensions = np_array.shape
            if sum(np_array.shape) > 550 and self.two_dim is False:
                inference_full_image = False

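            # for these datasets, try full-volume inference first and only fall back to patching after an OOM retry (self.tried)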
            if self.dataset_name.lower() in ("task04_sup", "task01_sup", "cellari_heart_sup_10_192", "cellari_heart_sup"):
                if self.tried is False:
                    inference_full_image = True
                else:
                    inference_full_image = False

            if inference_full_image is False:
                print("CUBE TOO BIG, PATCHING")
                patcher = Patcher(np_array, two_dim=self.two_dim)

                with torch.no_grad():
                    self.trainer.model.eval()
                    for idx, patch in patcher:

                        patch = torch.unsqueeze(patch, 0)  # (1,C,H,W or 1) -> (1,1,C,H,W or 1)
                        if self.config.model.lower() in (
                            "vnet_mg",
                            "unet_3d",
                            "unet_acs",
                            "unet_acs_axis_aware_decoder",
                            "unet_acs_with_cls",
                        ):
                            patch, pad_tuple = pad_if_necessary_one_array(patch, return_pad_tuple=True)

                        pred = self.trainer.model(patch)
                        assert pred.shape == patch.shape, "{} vs {}".format(pred.shape, patch.shape)
                        # need to then unpad to reconstruct
                        if self.two_dim is True:
                            raise RuntimeError("SHOULD  NOT BE USED HERE")

                        pred = self._unpad_3d_array(pred, pad_tuple)
                        pred = torch.squeeze(pred, dim=0)  # (1, 1, C,H,W) -> (1,C,H,W)
                        pred_mask = pred  # self._make_pred_mask_from_pred(pred)
                        del pred

                        patcher.predicitons_to_reconstruct_from[
                            :, idx
                        ] = pred_mask  # update array in patcher that will construct full cube predicted mask

                        dump_tensors()
                        torch.cuda.ipc_collect()
                        torch.cuda.empty_cache()
                        dump_tensors()

                pred_mask_full_cube = patcher.get_pred_mask_full_cube()
            else:

                full_cube_tensor = torch.Tensor(np_array)
                full_cube_tensor = torch.unsqueeze(full_cube_tensor, 0)  # (C,H,W) -> (1,C,H,W)
                full_cube_tensor = torch.unsqueeze(full_cube_tensor, 0)  # (1,C,H,W) -> (1,1,C,H,W)

                with torch.no_grad():
                    self.trainer.model.eval()
                    if self.two_dim is False:
                        if self.config.model.lower() in (
                            "vnet_mg",
                            "unet_3d",
                            "unet_acs",
                            "unet_acs_axis_aware_decoder",
                            "unet_acs_with_cls",
                        ):
                            full_cube_tensor, pad_tuple = pad_if_necessary_one_array(full_cube_tensor, return_pad_tuple=True)
                            try:
                                p = self.trainer.model(full_cube_tensor)
                                p.to("cpu")
                                pred = p
                                del p
                                dump_tensors()
                                torch.cuda.ipc_collect()
                                torch.cuda.empty_cache()
                                dump_tensors()
                                torch.cuda.empty_cache()
                                pred = self._unpad_3d_array(pred, pad_tuple)
                                pred = torch.squeeze(pred, dim=0)  # (1, 1, C,H,W) -> (1,C,H,W)
                                pred = torch.squeeze(pred, dim=0)
                                pred_mask_full_cube = pred  # self._make_pred_mask_from_pred(pred)
                                torch.cuda.ipc_collect()
                                torch.cuda.empty_cache()
                                del pred

                            except RuntimeError as e:
                                if "out of memory" in str(e) or "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED" in str(e):
                                    print("TOO BIG FOR MEMORY, DEFAULTING TO PATCHING")
                                    dump_tensors()
                                    torch.cuda.ipc_collect()
                                    torch.cuda.empty_cache()
                                    dump_tensors()
                                    self.tried = True
                                    self.save_segmentation_examples(nr_cubes=nr_cubes, inference_full_image=False)
                                    return
                                raise  # unrelated runtime errors should not be silently swallowed

                    else:
                        pred_mask_full_cube = torch.zeros(self.original_cube_dimensions)
                        for z_idx in range(full_cube_tensor.size()[-1]):
                            tensor_slice = full_cube_tensor[..., z_idx]  # SLICE : (1,1,C,H,W) -> (1,1,C,H)
                            assert tensor_slice.shape == (1, 1, self.original_cube_dimensions[0], self.original_cube_dimensions[1])
                            pred = self.trainer.model(tensor_slice)
                            pred = torch.squeeze(pred, dim=0)  # (1, 1, C,H) -> (1,C,H)
                            pred = torch.squeeze(pred, dim=0)  # (1,C,H) -> (C,H)
                            pred_mask_slice = pred  # self._make_pred_mask_from_pred(pred)
                            pred_mask_full_cube[..., z_idx] = pred_mask_slice

            print("Finished inference for cube {}/{}".format(cube_idx + 1, len(cubes_to_use_path)))

            if cube_idx < nr_cubes:
                if inference_full_image is True:
                    save_dir = os.path.join(self.save_dir, self.dataset_name, "testing_examples_full/", cubes_to_use[cube_idx][:-4])
                else:
                    save_dir = os.path.join(
                        self.save_dir, self.dataset_name, "testing_examples_full/", cubes_to_use[cube_idx][:-4] + "_with_patcher"
                    )
            else:
                if inference_full_image is True:
                    save_dir = os.path.join(self.save_dir, self.dataset_name, "training_examples_full/", cubes_to_use[cube_idx][:-4])
                else:
                    save_dir = os.path.join(
                        self.save_dir, self.dataset_name, "training_examples_full/", cubes_to_use[cube_idx][:-4] + "_with_patcher"
                    )

            make_dir(save_dir)

            # save nii of segmentation
            pred_mask_full_cube = pred_mask_full_cube.cpu()  # logits mask
            pred_mask_full_cube_binary = self._make_pred_mask_from_pred(pred_mask_full_cube)  # binary mask

            nifty_img = nibabel.Nifti1Image(np.array(pred_mask_full_cube).astype(np.float32), np.eye(4))
            nibabel.save(nifty_img, os.path.join(save_dir, cubes_to_use[cube_idx][:-4] + "_logits_mask.nii.gz"))

            nifty_img = nibabel.Nifti1Image(np.array(pred_mask_full_cube_binary).astype(np.float32), np.eye(4))
            nibabel.save(nifty_img, os.path.join(save_dir, cubes_to_use[cube_idx][:-4] + "_binary_mask.nii.gz"))

            # save .nii.gz of cube if is npy original full cube file
            if ".npy" in cube_path:
                nifty_img = nibabel.Nifti1Image(np_array.astype(np.float32), np.eye(4))
                nibabel.save(nifty_img, os.path.join(save_dir, cubes_to_use[cube_idx][:-4] + "_cube.nii.gz"))

            # self.save_3d_plot(np.array(pred_mask_full_cube), os.path.join(save_dir, "{}_plt3d.png".format(cubes_to_use[idx])))

            label_tensor_of_cube = torch.Tensor(self._load_cube_to_np_array(label_cubes_of_cubes_to_use_path[cube_idx]))
            label_tensor_of_cube = self.adjust_label_cube_acording_to_dataset(label_tensor_of_cube)
            label_tensor_of_cube_masked = np.array(label_tensor_of_cube)
            label_tensor_of_cube_masked = np.ma.masked_where(
                label_tensor_of_cube_masked < 0.5, label_tensor_of_cube_masked
            )  # it's binary anyway

            pred_mask_full_cube_binary_masked = np.array(pred_mask_full_cube_binary)
            pred_mask_full_cube_binary_masked = np.ma.masked_where(
                pred_mask_full_cube_binary_masked < 0.5, pred_mask_full_cube_binary_masked
            )  # it's binary anyway

            pred_mask_full_cube_logits_masked = np.array(pred_mask_full_cube)
            pred_mask_full_cube_logits_masked = np.ma.masked_where(
                pred_mask_full_cube_logits_masked < 0.3, pred_mask_full_cube_logits_masked
            )  # mask out low logits so only confident regions show in the overlay

            make_dir(os.path.join(save_dir, "slices/"))
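            # per-slice figures: binary-mask overlay, logits overlay, logits histogram, and ground-truth overlay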

            for z_idx in range(pred_mask_full_cube.shape[-1]):

                # binary
                fig = plt.figure(figsize=(10, 5))
                plt.imshow(np_array[:, :, z_idx], cmap=cm.Greys_r)
                plt.imshow(pred_mask_full_cube_binary_masked[:, :, z_idx], cmap="Accent")
                plt.axis("off")
                fig.savefig(
                    os.path.join(save_dir, "slices/", "slice_{}_binary.jpg".format(z_idx + 1)),
                    bbox_inches="tight",
                    dpi=150,
                )
                plt.close(fig=fig)

                # logits
                fig = plt.figure(figsize=(10, 5))
                plt.imshow(np_array[:, :, z_idx], cmap=cm.Greys_r)
                plt.imshow(pred_mask_full_cube_logits_masked[:, :, z_idx], cmap="Blues", alpha=0.5)
                plt.axis("off")
                fig.savefig(
                    os.path.join(save_dir, "slices/", "slice_{}_logits.jpg".format(z_idx + 1)),
                    bbox_inches="tight",
                    dpi=150,
                )
                plt.close(fig=fig)

                # dist of logits histogram
                distribution_logits = np.array(pred_mask_full_cube[:, :, z_idx].contiguous().view(-1))
                fig = plt.figure(figsize=(10, 5))
                plt.hist(distribution_logits, bins=np.arange(min(distribution_logits), max(distribution_logits) + 0.05, 0.05))
                fig.savefig(
                    os.path.join(save_dir, "slices/", "slice_{}_logits_histogram.jpg".format(z_idx + 1)),
                    bbox_inches="tight",
                    dpi=150,
                )
                plt.close(fig=fig)

                # save ground truth as well, overlaid on the original
                fig = plt.figure(figsize=(10, 5))
                plt.imshow(np_array[:, :, z_idx], cmap=cm.Greys_r)
                plt.imshow(label_tensor_of_cube_masked[:, :, z_idx], cmap="jet")
                plt.axis("off")
                fig.savefig(
                    os.path.join(save_dir, "slices/", "slice_{}_gt.jpg".format(z_idx + 1)),
                    bbox_inches="tight",
                    dpi=150,
                )
                plt.close(fig=fig)

            dice_score_soft = float(DiceLoss.dice_loss(pred_mask_full_cube, label_tensor_of_cube, return_loss=False))
            dice_score_binary = float(DiceLoss.dice_loss(pred_mask_full_cube_binary, label_tensor_of_cube, return_loss=False))
            x_flat = pred_mask_full_cube_binary.contiguous().view(-1)
            y_flat = label_tensor_of_cube.contiguous().view(-1)
            x_flat = x_flat.cpu()
            y_flat = y_flat.cpu()
            jaccard_scr = jaccard_score(y_flat, x_flat)
            metrics = {"dice_logits": dice_score_soft, "dice_binary": dice_score_binary, "jaccard": jaccard_scr}
            with open(os.path.join(save_dir, "dice.json"), "w") as f:
                json.dump(metrics, f)

            dump_tensors()
            torch.cuda.ipc_collect()
            torch.cuda.empty_cache()
            dump_tensors()
            sleep(10)

    def compute_metrics_for_all_cubes(self, inference_full_image=True):
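        """Run inference on every full cube used for testing and training and return
        average soft/binary Dice, Jaccard and Hausdorff metrics as a dict.
        """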

        cubes_to_use = []

        dump_tensors()
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()
        dump_tensors()

        if "lidc" in self.dataset_name:
            return

        if hasattr(self.trainer, "model"):
            del self.trainer.model
            del self.trainer
            sleep(20)
            self.trainer = Trainer(config=self.config, dataset=None)
            dump_tensors()
            torch.cuda.ipc_collect()
            torch.cuda.empty_cache()
            dump_tensors()

        dump_tensors()
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()
        dump_tensors()

        self.trainer.load_model(from_path=True, path=self.model_path, phase="sup", ensure_sup_is_completed=True)

        if inference_full_image is False:
            print("PATCHING Will be Done")

        full_cubes_used_for_testing = self.get_all_cubes_which_were_used_for_testing()
        full_cubes_used_for_training = self.get_all_cubes_which_were_used_for_training()
        cubes_to_use.extend(full_cubes_used_for_testing)
        cubes_to_use.extend(full_cubes_used_for_training)

        cubes_to_use_path = [os.path.join(self.dataset_dir, i) for i in cubes_to_use]
        label_cubes_of_cubes_to_use_path = [os.path.join(self.dataset_labels_dir, i) for i in cubes_to_use]

        metric_dict = dict()

        (
            dice_logits_test,
            dice_logits_train,
            dice_binary_test,
            dice_binary_train,
            jaccard_test,
            jaccard_train,
            hausdorff_test,
            hausdorff_train,
        ) = ([], [], [], [], [], [], [], [])

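        # run inference on every cube and accumulate per-cube metrics for the test and train splits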
        for idx, cube_path in enumerate(cubes_to_use_path):
            np_array = self._load_cube_to_np_array(cube_path)  # (x,y,z)
            self.original_cube_dimensions = np_array.shape
            if sum(np_array.shape) > 550 and self.two_dim is False:
                inference_full_image = False

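            # same dataset-specific behavior as in save_segmentation_examples: patch only after an OOM retry (self.tried)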
            if self.dataset_name.lower() in ("task04_sup", "task01_sup", "cellari_heart_sup_10_192", "cellari_heart_sup"):
                if self.tried is False:
                    inference_full_image = True
                else:
                    inference_full_image = False

            if inference_full_image is False:
                print("CUBE TOO BIG, PATCHING")

                patcher = Patcher(np_array, two_dim=self.two_dim)

                with torch.no_grad():
                    self.trainer.model.eval()
                    for patch_idx, patch in patcher:

                        patch = torch.unsqueeze(patch, 0)  # (1,C,H,W or 1) -> (1,1,C,H,W or 1)
                        if self.config.model.lower() in (
                            "vnet_mg",
                            "unet_3d",
                            "unet_acs",
                            "unet_acs_axis_aware_decoder",
                            "unet_acs_with_cls",
                        ):
                            patch, pad_tuple = pad_if_necessary_one_array(patch, return_pad_tuple=True)

                        pred = self.trainer.model(patch)
                        assert pred.shape == patch.shape, "{} vs {}".format(pred.shape, patch.shape)
                        # need to then unpad to reconstruct
                        if self.two_dim is True:
                            raise RuntimeError("SHOULD  NOT BE USED HERE")

                        pred = self._unpad_3d_array(pred, pad_tuple)
                        pred = torch.squeeze(pred, dim=0)  # (1, 1, C,H,W) -> (1,C,H,W)
                        # pred_mask = self._make_pred_mask_from_pred(pred)
                        patcher.predicitons_to_reconstruct_from[
                            :, patch_idx
                        ] = pred  # update array in patcher that will construct full cube predicted mask
                        del pred
                        dump_tensors()
                        torch.cuda.ipc_collect()
                        torch.cuda.empty_cache()
                        dump_tensors()

                pred_mask_full_cube = patcher.get_pred_mask_full_cube()

            else:

                full_cube_tensor = torch.Tensor(np_array)
                full_cube_tensor = torch.unsqueeze(full_cube_tensor, 0)  # (C,H,W) -> (1,C,H,W)
                full_cube_tensor = torch.unsqueeze(full_cube_tensor, 0)  # (1,C,H,W) -> (1,1,C,H,W)

                with torch.no_grad():
                    self.trainer.model.eval()
                    if self.two_dim is False:
                        if self.config.model.lower() in (
                            "vnet_mg",
                            "unet_3d",
                            "unet_acs",
                            "unet_acs_axis_aware_decoder",
                            "unet_acs_with_cls",
                        ):
                            full_cube_tensor, pad_tuple = pad_if_necessary_one_array(full_cube_tensor, return_pad_tuple=True)
                            try:
                                p = self.trainer.model(full_cube_tensor)
                                p.to("cpu")
                                pred = p
                                del p
                                dump_tensors()
                                torch.cuda.ipc_collect()
                                torch.cuda.empty_cache()
                                dump_tensors()
                                torch.cuda.empty_cache()
                                pred = self._unpad_3d_array(pred, pad_tuple)
                                pred = torch.squeeze(pred, dim=0)  # (1, 1, C,H,W) -> (1,C,H,W)
                                pred = torch.squeeze(pred, dim=0)
                                pred_mask_full_cube = pred  # self._make_pred_mask_from_pred(pred)
                                torch.cuda.ipc_collect()
                                torch.cuda.empty_cache()
                                del pred

                            except RuntimeError as e:
                                if "out of memory" in str(e) or "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED" in str(e):
                                    print("TOO BIG FOR MEMORY, DEFAULTING TO PATCHING")
                                    dump_tensors()
                                    torch.cuda.ipc_collect()
                                    torch.cuda.empty_cache()
                                    dump_tensors()
                                    self.tried = True
                                    return self.compute_metrics_for_all_cubes(inference_full_image=False)
                                raise  # unrelated runtime errors should not be silently swallowed

                    else:
                        pred_mask_full_cube = torch.zeros(self.original_cube_dimensions)
                        for z_idx in range(full_cube_tensor.size()[-1]):
                            tensor_slice = full_cube_tensor[..., z_idx]  # SLICE : (1,1,C,H,W) -> (1,1,C,H)
                            assert tensor_slice.shape == (1, 1, self.original_cube_dimensions[0], self.original_cube_dimensions[1])
                            pred = self.trainer.model(tensor_slice)
                            pred = torch.squeeze(pred, dim=0)  # (1, 1, C,H) -> (1,C,H)
                            pred = torch.squeeze(pred, dim=0)  # (1,C,H) -> (C,H)
                            pred_mask_slice = pred  # self._make_pred_mask_from_pred(pred)
                            pred_mask_full_cube[..., z_idx] = pred_mask_slice

            full_cube_label_tensor = torch.Tensor(self._load_cube_to_np_array(label_cubes_of_cubes_to_use_path[idx]))
            full_cube_label_tensor = self.adjust_label_cube_acording_to_dataset(full_cube_label_tensor)

            pred_mask_full_cube = pred_mask_full_cube.to("cpu")
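            # choose the binarization threshold for this cube from the logits and the label before computing binary metrics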
            threshold = self._set_threshold(pred_mask_full_cube, full_cube_label_tensor)
            pred_mask_full_cube_binary = self._make_pred_mask_from_pred(pred_mask_full_cube, threshold=threshold)

            dice_score_soft = float(DiceLoss.dice_loss(pred_mask_full_cube, full_cube_label_tensor, return_loss=False))
            dice_score_binary = float(DiceLoss.dice_loss(pred_mask_full_cube_binary, full_cube_label_tensor, return_loss=False))
            hausdorff = hausdorff_distance(np.array(pred_mask_full_cube_binary), np.array(full_cube_label_tensor))
            x_flat = pred_mask_full_cube_binary.contiguous().view(-1)
            y_flat = full_cube_label_tensor.contiguous().view(-1)
            x_flat = x_flat.cpu()
            y_flat = y_flat.cpu()
            jac_score = jaccard_score(y_flat, x_flat)

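            # testing cubes were appended to cubes_to_use first, so the index determines which split this cube belongs to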
            if idx < len(full_cubes_used_for_testing):
                dice_logits_test.append(dice_score_soft)
                dice_binary_test.append(dice_score_binary)
                jaccard_test.append(jac_score)
                hausdorff_test.append(hausdorff)
            else:
                dice_logits_train.append(dice_score_soft)
                dice_binary_train.append(dice_score_binary)
                jaccard_train.append(jac_score)
                hausdorff_train.append(hausdorff)

            dump_tensors()
            torch.cuda.ipc_collect()
            torch.cuda.empty_cache()
            dump_tensors()
            sleep(10)
            print("Finished cube {}/{}".format(idx + 1, len(cubes_to_use_path)))

        avg_jaccard_test = sum(jaccard_test) / len(jaccard_test)
        avg_jaccard_train = sum(jaccard_train) / len(jaccard_train)

        avg_dice_test_soft = sum(dice_logits_test) / len(dice_logits_test)
        avg_dice_test_binary = sum(dice_binary_test) / len(dice_binary_test)

        avg_dice_train_soft = sum(dice_logits_train) / len(dice_logits_train)
        avg_dice_train_binary = sum(dice_binary_train) / len(dice_binary_train)

        avg_hausdorff_train = sum(hausdorff_train) / len(hausdorff_train)
        avg_hausdorff_test = sum(hausdorff_test) / len(hausdorff_test)

        metric_dict["dice_test_soft"] = avg_dice_test_soft
        metric_dict["dice_test_binary"] = avg_dice_test_binary
        metric_dict["dice_train_soft"] = avg_dice_train_soft
        metric_dict["dice_train_binary"] = avg_dice_train_binary
        metric_dict["jaccard_test"] = avg_jaccard_test
        metric_dict["jaccard_train"] = avg_jaccard_train
        metric_dict["hausdorff_test"] = avg_hausdorff_test
        metric_dict["hausdorff_train"] = avg_hausdorff_train

        return metric_dict