def save_segmentation_examples(self, nr_cubes=3, inference_full_image=True):
    """Predict masks for a few sampled test/train cubes and save visual examples.

    Up to ``nr_cubes`` cubes are sampled from the cubes used for testing and up to
    ``nr_cubes`` from those used for training. For each cube a full-cube prediction is
    produced by one of three routes:

    * patch-wise inference with ``Patcher``;
    * whole-volume inference (3D models);
    * slice-by-slice inference along the last axis (2D models).

    The results are written to
    ``<save_dir>/<dataset_name>/{testing,training}_examples_full/<cube>[_with_patcher]/``:

    * ``<cube>_logits_mask.nii.gz`` and ``<cube>_binary_mask.nii.gz`` (raw / binarised prediction);
    * ``<cube>_cube.nii.gz`` with the input volume (only when the input path contains ``.npy``);
    * ``slices/`` with per-slice overlays (binary, raw, ground truth) and raw-value histograms;
    * ``dice.json`` with soft Dice, binary Dice and Jaccard against the ground-truth label.

    If whole-volume inference raises a CUDA out-of-memory (or unsupported cuDNN) error,
    ``self.tried`` is set to True and the whole method restarts with patching.

    Args:
        nr_cubes: number of cubes sampled from each of the test and train splits.
        inference_full_image: if False, force patch-wise inference. The choice is then
            refined per cube:
            * 3D cubes whose dimensions sum to more than 550 are always patched;
            * for a few listed datasets, full-image inference is used while
              ``self.tried`` is False, and patching afterwards.

    Returns:
        None. Datasets whose name contains "lidc" are skipped entirely.

    Raises:
        RuntimeError: any RuntimeError from whole-volume inference that is not an
            out-of-memory / unsupported-cuDNN error is propagated instead of swallowed.
    """
    # Models whose input is padded before the forward pass (and whose output is unpadded after).
    padded_models = ("vnet_mg", "unet_3d", "unet_acs", "unet_acs_axis_aware_decoder", "unet_acs_with_cls")
    # Datasets that try whole-volume inference first, regardless of cube size.
    full_image_first_datasets = ("task04_sup", "task01_sup", "cellari_heart_sup_10_192", "cellari_heart_sup")

    def _free_gpu_memory():
        # Same cleanup sequence the rest of this module uses between GPU-heavy steps.
        dump_tensors()
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()
        dump_tensors()

    def _save_overlay(background, overlay, out_path, **overlay_kwargs):
        # Draw `overlay` over a greyscale `background` slice and save it as an image.
        fig = plt.figure(figsize=(10, 5))
        plt.imshow(background, cmap=cm.Greys_r)
        plt.imshow(overlay, **overlay_kwargs)
        plt.axis("off")
        fig.savefig(out_path, bbox_inches="tight", dpi=150)
        plt.close(fig=fig)

    if "lidc" in self.dataset_name:
        return

    # Drop the current trainer/model and build a fresh one; presumably this is meant to
    # release GPU memory held by a previous run before inference starts.
    torch.cuda.ipc_collect()
    torch.cuda.empty_cache()
    dump_tensors()
    if hasattr(self.trainer, "model"):
        del self.trainer.model
    del self.trainer
    sleep(15)
    self.trainer = Trainer(config=self.config, dataset=None)
    _free_gpu_memory()
    sleep(5)

    if inference_full_image is False:
        print("PATCHING Will be Done")
    _free_gpu_memory()

    self.trainer.load_model(from_path=True, path=self.model_path, phase="sup", ensure_sup_is_completed=True)

    test_cubes = list(self.sample_k_full_cubes_which_were_used_for_testing(nr_cubes))
    train_cubes = list(self.sample_k_full_cubes_which_were_used_for_training(nr_cubes))
    cubes_to_use = test_cubes + train_cubes

    for cube_idx, cube_name in enumerate(cubes_to_use):
        cube_path = os.path.join(self.dataset_dir, cube_name)
        label_path = os.path.join(self.dataset_labels_dir, cube_name)

        np_array = self._load_cube_to_np_array(cube_path)  # (x,y,z)
        self.original_cube_dimensions = np_array.shape

        # Decide full-image vs. patching for THIS cube only, so that one large cube does
        # not force patching (and "_with_patcher" naming) on every cube after it.
        use_full_image = inference_full_image
        if sum(np_array.shape) > 550 and self.two_dim is False:
            use_full_image = False
        if self.dataset_name.lower() in full_image_first_datasets:
            use_full_image = self.tried is False

        fall_back_to_patching = False
        with torch.no_grad():
            self.trainer.model.eval()
            if use_full_image is False:
                print("CUBE TOO BIG, PATCHING")
                patcher = Patcher(np_array, two_dim=self.two_dim)
                for patch_idx, patch in patcher:
                    patch = torch.unsqueeze(patch, 0)  # (1,C,H,W or 1) -> (1,1,C,H,W or 1)
                    pad_tuple = None
                    if self.config.model.lower() in padded_models:
                        patch, pad_tuple = pad_if_necessary_one_array(patch, return_pad_tuple=True)

                    pred = self.trainer.model(patch)
                    assert pred.shape == patch.shape, "{} vs {}".format(pred.shape, patch.shape)
                    if self.two_dim is True:
                        raise RuntimeError("SHOULD NOT BE USED HERE")

                    # Undo the padding (if any) so the patch fits back into the full cube.
                    if pad_tuple is not None:
                        pred = self._unpad_3d_array(pred, pad_tuple)
                    pred = torch.squeeze(pred, dim=0)  # (1, 1, C,H,W) -> (1,C,H,W)
                    # Update the array from which the patcher reconstructs the full-cube prediction.
                    patcher.predicitons_to_reconstruct_from[:, patch_idx] = pred
                    del pred
                _free_gpu_memory()
                pred_mask_full_cube = patcher.get_pred_mask_full_cube()
            else:
                full_cube_tensor = torch.Tensor(np_array)
                full_cube_tensor = full_cube_tensor.unsqueeze(0).unsqueeze(0)  # (C,H,W) -> (1,1,C,H,W)
                if self.two_dim is False:
                    pad_tuple = None
                    if self.config.model.lower() in padded_models:
                        full_cube_tensor, pad_tuple = pad_if_necessary_one_array(
                            full_cube_tensor, return_pad_tuple=True
                        )
                    try:
                        # `.to()` is not in-place: keep its result so the GPU copy can be freed.
                        pred = self.trainer.model(full_cube_tensor).to("cpu")
                    except RuntimeError as e:
                        message = str(e)
                        if "out of memory" not in message and "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED" not in message:
                            raise
                        fall_back_to_patching = True
                    else:
                        _free_gpu_memory()
                        if pad_tuple is not None:
                            pred = self._unpad_3d_array(pred, pad_tuple)
                        pred_mask_full_cube = pred.squeeze(0).squeeze(0)  # (1,1,C,H,W) -> (C,H,W)
                else:
                    # 2D model: predict every slice along the last axis independently.
                    pred_mask_full_cube = torch.zeros(self.original_cube_dimensions)
                    for z_idx in range(full_cube_tensor.shape[-1]):
                        tensor_slice = full_cube_tensor[..., z_idx]  # SLICE : (1,1,C,H,W) -> (1,1,C,H)
                        assert tensor_slice.shape == (
                            1,
                            1,
                            self.original_cube_dimensions[0],
                            self.original_cube_dimensions[1],
                        )
                        pred = self.trainer.model(tensor_slice)
                        pred_mask_full_cube[..., z_idx] = pred.squeeze(0).squeeze(0)  # (1,1,C,H) -> (C,H)

        if fall_back_to_patching:
            # Recover outside the `except` clause: the exception's traceback keeps the
            # failing frames (and the tensors they reference) alive until it is left.
            print("TOO BIG FOR MEMORY, DEFAULTING TO PATCHING")
            _free_gpu_memory()
            self.tried = True
            self.save_segmentation_examples(nr_cubes=nr_cubes, inference_full_image=False)
            return

        print(cube_idx)

        # The first len(test_cubes) entries come from the test split; the rest from the train split.
        split_dir = "testing_examples_full/" if cube_idx < len(test_cubes) else "training_examples_full/"
        stem = cube_name[:-4]  # drops a 4-character extension such as ".npy"
        example_dir = stem if use_full_image is True else stem + "_with_patcher"
        save_dir = os.path.join(self.save_dir, self.dataset_name, split_dir, example_dir)
        make_dir(save_dir)

        # Save the raw and binarised predictions as NIfTI volumes.
        pred_mask_full_cube = pred_mask_full_cube.cpu()  # raw ("logits") mask
        pred_mask_full_cube_binary = self._make_pred_mask_from_pred(pred_mask_full_cube)  # binary mask
        nibabel.save(
            nibabel.Nifti1Image(np.array(pred_mask_full_cube).astype(np.float32), np.eye(4)),
            os.path.join(save_dir, stem + "_logits_mask.nii.gz"),
        )
        nibabel.save(
            nibabel.Nifti1Image(np.array(pred_mask_full_cube_binary).astype(np.float32), np.eye(4)),
            os.path.join(save_dir, stem + "_binary_mask.nii.gz"),
        )

        # Also save the input cube as .nii.gz when it came from a .npy file.
        if ".npy" in cube_path:
            nibabel.save(
                nibabel.Nifti1Image(np_array.astype(np.float32), np.eye(4)),
                os.path.join(save_dir, stem + "_cube.nii.gz"),
            )

        label_tensor_of_cube = torch.Tensor(self._load_cube_to_np_array(label_path))
        label_tensor_of_cube = self.adjust_label_cube_acording_to_dataset(label_tensor_of_cube)

        # Mask out background values so only the foreground is drawn over the image.
        label_array = np.array(label_tensor_of_cube)
        label_tensor_of_cube_masked = np.ma.masked_where(label_array < 0.5, label_array)
        binary_array = np.array(pred_mask_full_cube_binary)
        pred_mask_full_cube_binary_masked = np.ma.masked_where(binary_array < 0.5, binary_array)
        logits_array = np.array(pred_mask_full_cube)
        pred_mask_full_cube_logits_masked = np.ma.masked_where(logits_array < 0.3, logits_array)

        slices_dir = os.path.join(save_dir, "slices/")
        make_dir(slices_dir)
        for z_idx in range(pred_mask_full_cube.shape[-1]):
            background = np_array[:, :, z_idx]
            slice_no = z_idx + 1

            _save_overlay(
                background,
                pred_mask_full_cube_binary_masked[:, :, z_idx],
                os.path.join(slices_dir, "slice_{}_binary.jpg".format(slice_no)),
                cmap="Accent",
            )
            _save_overlay(
                background,
                pred_mask_full_cube_logits_masked[:, :, z_idx],
                os.path.join(slices_dir, "slice_{}_logits.jpg".format(slice_no)),
                cmap="Blues",
                alpha=0.5,
            )

            # Histogram of the raw prediction values in this slice (bin width 0.05).
            distribution_logits = np.array(pred_mask_full_cube[:, :, z_idx].contiguous().view(-1))
            fig = plt.figure(figsize=(10, 5))
            plt.hist(
                distribution_logits,
                bins=np.arange(distribution_logits.min(), distribution_logits.max() + 0.05, 0.05),
            )
            fig.savefig(
                os.path.join(slices_dir, "slice_{}_logits_histogram.jpg".format(slice_no)),
                bbox_inches="tight",
                dpi=150,
            )
            plt.close(fig=fig)

            # Ground truth overlaid on the original slice.
            _save_overlay(
                background,
                label_tensor_of_cube_masked[:, :, z_idx],
                os.path.join(slices_dir, "slice_{}_gt.jpg".format(slice_no)),
                cmap="jet",
            )

        dice_score_soft = float(DiceLoss.dice_loss(pred_mask_full_cube, label_tensor_of_cube, return_loss=False))
        dice_score_binary = float(
            DiceLoss.dice_loss(pred_mask_full_cube_binary, label_tensor_of_cube, return_loss=False)
        )
        # Jaccard must compare the binary prediction against the ground truth, not against itself.
        x_flat = pred_mask_full_cube_binary.contiguous().view(-1).cpu()
        y_flat = label_tensor_of_cube.contiguous().view(-1).cpu()
        jaccard_scr = jaccard_score(y_flat, x_flat)

        metrics = {"dice_logits": dice_score_soft, "dice_binary": dice_score_binary, "jaccard": jaccard_scr}
        with open(os.path.join(save_dir, "dice.json"), "w") as f:
            json.dump(metrics, f)

        _free_gpu_memory()

    _free_gpu_memory()
    sleep(10)
def compute_metrics_for_all_cubes(self, inference_full_image=True):
    """Evaluate the supervised model on every test and train cube and average the metrics.

    For each cube returned by ``get_all_cubes_which_were_used_for_testing`` and
    ``get_all_cubes_which_were_used_for_training``, a full-cube prediction is produced
    (patch-wise, whole-volume, or slice-by-slice for 2D models). It is then compared with
    the ground-truth label, which is adjusted by ``adjust_label_cube_acording_to_dataset``.
    The following metrics are computed:

    * soft Dice on the raw prediction;
    * binary Dice on the prediction binarised with ``self._set_threshold``;
    * Hausdorff distance;
    * Jaccard score.

    If whole-volume inference raises a CUDA out-of-memory (or unsupported cuDNN) error,
    ``self.tried`` is set to True and the whole evaluation restarts with patching.

    Args:
        inference_full_image: if False, force patch-wise inference. The choice is then
            refined per cube:
            * 3D cubes whose dimensions sum to more than 550 are always patched;
            * for a few listed datasets, full-image inference is used while
              ``self.tried`` is False, and patching afterwards.

    Returns:
        None for datasets whose name contains "lidc". Otherwise a dict with the keys
        "dice_test_soft", "dice_test_binary", "dice_train_soft", "dice_train_binary",
        "jaccard_test", "jaccard_train", "hausdorff_test" and "hausdorff_train". Each value
        is the mean over the cubes of that split, or NaN if the split has no cubes.

    Raises:
        RuntimeError: any RuntimeError from whole-volume inference that is not an
            out-of-memory / unsupported-cuDNN error is propagated instead of being
            swallowed (which previously left a stale prediction in use).
    """
    # Models whose input is padded before the forward pass (and whose output is unpadded after).
    padded_models = ("vnet_mg", "unet_3d", "unet_acs", "unet_acs_axis_aware_decoder", "unet_acs_with_cls")
    # Datasets that try whole-volume inference first, regardless of cube size.
    full_image_first_datasets = ("task04_sup", "task01_sup", "cellari_heart_sup_10_192", "cellari_heart_sup")

    def _free_gpu_memory():
        # Same cleanup sequence the rest of this module uses between GPU-heavy steps.
        dump_tensors()
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()
        dump_tensors()

    def _mean(values):
        # NaN instead of ZeroDivisionError when a split contributed no cubes.
        return sum(values) / len(values) if values else float("nan")

    _free_gpu_memory()

    if "lidc" in self.dataset_name:
        return

    # Drop the current trainer/model and build a fresh one; presumably this is meant to
    # release GPU memory held by a previous run before inference starts.
    if hasattr(self.trainer, "model"):
        del self.trainer.model
    del self.trainer
    sleep(20)
    self.trainer = Trainer(config=self.config, dataset=None)
    _free_gpu_memory()

    self.trainer.load_model(from_path=True, path=self.model_path, phase="sup", ensure_sup_is_completed=True)
    if inference_full_image is False:
        print("PATCHING Will be Done")

    full_cubes_used_for_testing = self.get_all_cubes_which_were_used_for_testing()
    full_cubes_used_for_training = self.get_all_cubes_which_were_used_for_training()
    cubes_to_use = [*full_cubes_used_for_testing, *full_cubes_used_for_training]
    # The first n_test entries of cubes_to_use belong to the test split.
    n_test = len(full_cubes_used_for_testing)

    dice_logits_test, dice_logits_train = [], []
    dice_binary_test, dice_binary_train = [], []
    jaccard_test, jaccard_train = [], []
    hausdorff_test, hausdorff_train = [], []

    for idx, cube_name in enumerate(cubes_to_use):
        cube_path = os.path.join(self.dataset_dir, cube_name)
        label_path = os.path.join(self.dataset_labels_dir, cube_name)

        np_array = self._load_cube_to_np_array(cube_path)  # (x,y,z)
        self.original_cube_dimensions = np_array.shape

        # Decide full-image vs. patching for THIS cube only, so that one large cube does
        # not force patching on every cube evaluated after it.
        use_full_image = inference_full_image
        if sum(np_array.shape) > 550 and self.two_dim is False:
            use_full_image = False
        if self.dataset_name.lower() in full_image_first_datasets:
            use_full_image = self.tried is False

        fall_back_to_patching = False
        with torch.no_grad():
            self.trainer.model.eval()
            if use_full_image is False:
                print("CUBE TOO BIG, PATCHING")
                patcher = Patcher(np_array, two_dim=self.two_dim)
                for patch_idx, patch in patcher:
                    patch = torch.unsqueeze(patch, 0)  # (1,C,H,W or 1) -> (1,1,C,H,W or 1)
                    pad_tuple = None
                    if self.config.model.lower() in padded_models:
                        patch, pad_tuple = pad_if_necessary_one_array(patch, return_pad_tuple=True)

                    pred = self.trainer.model(patch)
                    assert pred.shape == patch.shape, "{} vs {}".format(pred.shape, patch.shape)
                    if self.two_dim is True:
                        raise RuntimeError("SHOULD NOT BE USED HERE")

                    # Undo the padding (if any) so the patch fits back into the full cube.
                    if pad_tuple is not None:
                        pred = self._unpad_3d_array(pred, pad_tuple)
                    pred = torch.squeeze(pred, dim=0)  # (1, 1, C,H,W) -> (1,C,H,W)
                    # Update the array from which the patcher reconstructs the full-cube prediction.
                    patcher.predicitons_to_reconstruct_from[:, patch_idx] = pred
                    del pred
                _free_gpu_memory()
                pred_mask_full_cube = patcher.get_pred_mask_full_cube()
            else:
                full_cube_tensor = torch.Tensor(np_array)
                full_cube_tensor = full_cube_tensor.unsqueeze(0).unsqueeze(0)  # (C,H,W) -> (1,1,C,H,W)
                if self.two_dim is False:
                    pad_tuple = None
                    if self.config.model.lower() in padded_models:
                        full_cube_tensor, pad_tuple = pad_if_necessary_one_array(
                            full_cube_tensor, return_pad_tuple=True
                        )
                    try:
                        # `.to()` is not in-place: keep its result so the GPU copy can be freed.
                        pred = self.trainer.model(full_cube_tensor).to("cpu")
                    except RuntimeError as e:
                        message = str(e)
                        if "out of memory" not in message and "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED" not in message:
                            raise
                        fall_back_to_patching = True
                    else:
                        _free_gpu_memory()
                        if pad_tuple is not None:
                            pred = self._unpad_3d_array(pred, pad_tuple)
                        pred_mask_full_cube = pred.squeeze(0).squeeze(0)  # (1,1,C,H,W) -> (C,H,W)
                else:
                    # 2D model: predict every slice along the last axis independently.
                    pred_mask_full_cube = torch.zeros(self.original_cube_dimensions)
                    for z_idx in range(full_cube_tensor.shape[-1]):
                        tensor_slice = full_cube_tensor[..., z_idx]  # SLICE : (1,1,C,H,W) -> (1,1,C,H)
                        assert tensor_slice.shape == (
                            1,
                            1,
                            self.original_cube_dimensions[0],
                            self.original_cube_dimensions[1],
                        )
                        pred = self.trainer.model(tensor_slice)
                        pred_mask_full_cube[..., z_idx] = pred.squeeze(0).squeeze(0)  # (1,1,C,H) -> (C,H)

        if fall_back_to_patching:
            # Recover outside the `except` clause: the exception's traceback keeps the
            # failing frames (and the tensors they reference) alive until it is left.
            print("TOO BIG FOR MEMORY, DEFAULTING TO PATCHING")
            _free_gpu_memory()
            self.tried = True
            return self.compute_metrics_for_all_cubes(inference_full_image=False)

        full_cube_label_tensor = torch.Tensor(self._load_cube_to_np_array(label_path))
        full_cube_label_tensor = self.adjust_label_cube_acording_to_dataset(full_cube_label_tensor)

        pred_mask_full_cube = pred_mask_full_cube.to("cpu")
        threshold = self._set_threshold(pred_mask_full_cube, full_cube_label_tensor)
        pred_mask_full_cube_binary = self._make_pred_mask_from_pred(pred_mask_full_cube, threshold=threshold)

        dice_score_soft = float(DiceLoss.dice_loss(pred_mask_full_cube, full_cube_label_tensor, return_loss=False))
        dice_score_binary = float(
            DiceLoss.dice_loss(pred_mask_full_cube_binary, full_cube_label_tensor, return_loss=False)
        )
        hausdorff = hausdorff_distance(np.array(pred_mask_full_cube_binary), np.array(full_cube_label_tensor))
        x_flat = pred_mask_full_cube_binary.contiguous().view(-1).cpu()
        y_flat = full_cube_label_tensor.contiguous().view(-1).cpu()
        jac_score = jaccard_score(y_flat, x_flat)

        if idx < n_test:
            dice_logits_test.append(dice_score_soft)
            dice_binary_test.append(dice_score_binary)
            jaccard_test.append(jac_score)
            hausdorff_test.append(hausdorff)
        else:
            dice_logits_train.append(dice_score_soft)
            dice_binary_train.append(dice_score_binary)
            jaccard_train.append(jac_score)
            hausdorff_train.append(hausdorff)

        _free_gpu_memory()
        sleep(10)
        print(idx)

    metric_dict = dict()
    metric_dict["dice_test_soft"] = _mean(dice_logits_test)
    metric_dict["dice_test_binary"] = _mean(dice_binary_test)
    metric_dict["dice_train_soft"] = _mean(dice_logits_train)
    metric_dict["dice_train_binary"] = _mean(dice_binary_train)
    metric_dict["jaccard_test"] = _mean(jaccard_test)
    metric_dict["jaccard_train"] = _mean(jaccard_train)
    metric_dict["hausdorff_test"] = _mean(hausdorff_test)
    metric_dict["hausdorff_train"] = _mean(hausdorff_train)
    return metric_dict