def create_3dshallow_dataset(depth=5):
    """
    Create 3D shallow image stacks of the given depth and save them to allow fast shuffled loading during training.

    :param depth: int: slice thickness to create for images
    """
    assert (
        depth > 1 and depth % 2 == 1
    ), f"depth must be an odd number greater than 1 but got {depth}"

    reader = NIIReader()
    base_folder = Path("/media/y4tsu/4B172BDA26AB3054/cmr_folds")
    folds = [x for x in base_folder.iterdir() if x.is_dir()]

    for fold in folds:
        data_3d = fold / "3D"
        data_3d_shallow = fold / "3DShallow"
        data_3d_shallow.mkdir(parents=False, exist_ok=True)

        train_data_3d_shallow = data_3d_shallow / "train"
        val_data_3d_shallow = data_3d_shallow / "val"
        train_data_3d_shallow.mkdir(parents=False, exist_ok=True)
        val_data_3d_shallow.mkdir(parents=False, exist_ok=True)

        transverse_train_2d = train_data_3d_shallow / "transverse"
        transverse_val_2d = val_data_3d_shallow / "transverse"
        transverse_train_2d.mkdir(parents=False, exist_ok=True)
        transverse_val_2d.mkdir(parents=False, exist_ok=True)

        for g in ["train", "val"]:
            for x in tqdm((data_3d / g).iterdir()):
                image_fname = x / f"{x.stem}_SAX.nii.gz"
                label_fname = x / f"{x.stem}_SAX_mask2.nii.gz"
                image = reader.read(image_fname)
                label = reader.read(label_fname)

                curr_dir = data_3d_shallow / g / "transverse" / x.stem
                curr_dir.mkdir(parents=False, exist_ok=True)

                # Save each overlapping stack of `depth` transverse slices as a .npy pair
                for j in range(image.shape[2] - depth + 1):
                    np.save(
                        str(curr_dir / f"{x.stem}_{j:03}_image"),
                        image[:, :, j : j + depth],
                    )
                    np.save(
                        str(curr_dir / f"{x.stem}_{j:03}_label"),
                        label[:, :, j : j + depth],
                    )
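# A minimal, self-contained sketch of the sliding-window extraction used above: a volume
# with D transverse slices yields D - depth + 1 overlapping stacks of `depth` slices each.
# The zero-filled array and its shape are illustrative assumptions, not project data.
def _sliding_window_demo():
    import numpy as np

    volume = np.zeros((192, 192, 10), dtype=np.float32)  # H x W x D stand-in volume
    depth = 5
    stacks = [
        volume[:, :, j : j + depth] for j in range(volume.shape[2] - depth + 1)
    ]
    assert len(stacks) == volume.shape[2] - depth + 1  # 6 stacks of shape (192, 192, 5)
    return stacks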
def find_widths_distributions(root):
    """Get the widths of all images (square) in the dataset and display as a distribution."""
    widths = []
    reader = NIIReader()

    for g in ["train", "val"]:
        for x in tqdm(sorted(os.listdir(os.path.join(root, "3D", g)))):
            image_fname = os.path.join(root, "3D", g, x, f"{x}_SAX.nii.gz")
            image = reader.read(image_fname)
            widths += [image.shape[0]]

    # Centre each integer-valued width on its own histogram bin
    plt.hist(
        widths,
        bins=np.arange(min(widths), max(widths) + 2) - 0.5,
        rwidth=1,
        color="cadetblue",
    )
    plt.xlabel("Width of image")
    plt.ylabel("Frequency")
    plt.title("Frequency of image widths")

    # Save and show the figure
    plt.savefig(
        "/home/y4tsu/Desktop/diss img/width_frequencies.png",
        bbox_inches="tight",
        dpi=300,
    )
    plt.savefig(
        "/home/y4tsu/Desktop/diss img/width_frequencies.pdf", bbox_inches="tight"
    )
    plt.show()
def _get_manual_bboxes(self):
    """Load the manual segmentation masks from 3D images to get bounding boxes on the fly."""
    print(
        f"Loading bboxes from manual segmentations for {self.dataset} data ... "
    )
    reader = NIIReader()
    base_folder = os.path.join(self.data_path, "3D", self.dataset)
    roots = sorted(os.listdir(base_folder))
    out = {}

    for root in tqdm(roots):
        label = np.squeeze(
            reader.read(os.path.join(base_folder, root, f"{root}_SAX_mask2.nii.gz"))
        )
        # Collapse all foreground classes into a single binary mask before finding the bbox
        binary_label = np.where(np.equal(label, 0), 0, 1)
        bbox = self.__find_bbox(binary_label)
        out[root] = {
            "top": bbox[0],
            "left": bbox[1],
            "bottom": bbox[2],
            "right": bbox[3],
        }

    return out
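# `self.__find_bbox` is defined elsewhere in this class; the helper below is only a
# hedged sketch of the kind of logic it needs (top/left/bottom/right bounds of the
# non-zero region of a mask), not the project's implementation. It assumes a non-empty
# binary mask.
def _find_bbox_sketch(binary_label):
    import numpy as np

    # Collapse any trailing slice axis so only the in-plane extent matters
    flat = binary_label if binary_label.ndim == 2 else binary_label.max(axis=-1)
    rows = np.any(flat, axis=1)
    cols = np.any(flat, axis=0)
    top, bottom = np.where(rows)[0][[0, -1]]
    left, right = np.where(cols)[0][[0, -1]]
    return int(top), int(left), int(bottom), int(right)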
def create_2d_dataset():
    """
    Create the 2D image and label slices so that they can be loaded quickly in a shuffled order during training.
    """
    reader = NIIReader()
    base_folder = Path("/media/y4tsu/4B172BDA26AB3054/cmr_folds")
    folds = [x for x in base_folder.iterdir() if x.is_dir()]

    for fold in folds:
        data_3d = fold / "3D"
        data_2d = fold / "2D"
        data_2d.mkdir(parents=False, exist_ok=True)

        train_data_2d = data_2d / "train"
        val_data_2d = data_2d / "val"
        train_data_2d.mkdir(parents=False, exist_ok=True)
        val_data_2d.mkdir(parents=False, exist_ok=True)

        transverse_train_2d = train_data_2d / "transverse"
        transverse_val_2d = val_data_2d / "transverse"
        transverse_train_2d.mkdir(parents=False, exist_ok=True)
        transverse_val_2d.mkdir(parents=False, exist_ok=True)

        for g in ["train", "val"]:
            for x in tqdm((data_3d / g).iterdir()):
                image_fname = x / f"{x.stem}_SAX.nii.gz"
                label_fname = x / f"{x.stem}_SAX_mask2.nii.gz"
                image = reader.read(image_fname)
                label = reader.read(label_fname)

                curr_dir = data_2d / g / "transverse" / x.stem
                curr_dir.mkdir(parents=False, exist_ok=True)

                # Save each transverse slice as a separate .npy pair
                for j in range(image.shape[2]):
                    np.save(
                        str(curr_dir / f"{x.stem}_{j:03}_image"),
                        image[:, :, j],
                    )
                    np.save(
                        str(curr_dir / f"{x.stem}_{j:03}_label"),
                        label[:, :, j],
                    )
def search_incorrect_orientations():
    """
    Look through all the available images to find those which are malrotated with respect to the majority so that
    they can be fixed.
    """
    reader = NIIReader()
    base_folder = Path("/media/y4tsu/4B172BDA26AB3054/cmr_clean")
    roots = sorted([x for x in base_folder.iterdir() if x.is_dir()])
    print(f"{len(roots)=}")

    for root in roots:
        print(f"Reading root {root}")
        file_num = root.stem[-3:]
        image = np.squeeze(reader.read(root / f"20CA015_N{file_num}_SAX.nii.gz"))
        print(f"{image[:20, ...].sum()=}")
        reader.scroll_view(image)
def calculate_label_weights(self):
    """Calculate beta pixel weighting and its inverse as the label weights for weighted loss functions."""
    print(
        f"Calculating label weightings across {len(self.train_gen.image_fnames)} label images for use in loss"
        f" function, may take a while ... "
    )
    if self.combine_labels:
        sums = np.zeros(len(self.combine_labels), dtype=np.float32)
    else:
        sums = np.zeros(len(self.labels), dtype=np.float32)

    for i in tqdm(range(len(self.train_gen.image_fnames) // self.batch_size)):
        _, label_img = self.train_gen.__getitem__(i, weight_mode=True)
        # Get the number of labelled voxels of each class for each label image
        if self.quality_weighted_mode:
            sums += [
                label_img["m"][..., j].sum() for j in range(label_img["m"].shape[-1])
            ]
        else:
            sums += [label_img[..., j].sum() for j in range(label_img.shape[-1])]

    # Get the total number of voxels in the dataset to normalize the beta
    if self.model != "UNet3DFrozenDepth":
        total_voxels = np.prod(
            np.array([*self.image_size, len(self.train_gen.image_fnames)])
        )
    else:
        # Frozen-depth models have a variable slice count, so count voxels per volume
        total_voxels = 0
        reader = NIIReader()
        for fname in self.train_gen.image_fnames:
            img = reader.read(fname)
            total_voxels += self.image_size[0] * self.image_size[1] * img.shape[-1]

    beta = sums / total_voxels
    print(1.0 / beta)

    # Return weightings: 1 / beta
    return 1.0 / beta
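# Hedged, standalone sketch of the beta weighting above: beta_c is the fraction of all
# voxels belonging to class c, and the returned weights are 1 / beta_c. The one-hot
# array name and its (N, H, W, C) shape are illustrative assumptions.
def _label_weights_sketch(one_hot_labels):
    import numpy as np

    class_sums = one_hot_labels.reshape(-1, one_hot_labels.shape[-1]).sum(axis=0)
    total_voxels = one_hot_labels[..., 0].size  # voxels in one class channel == total voxels
    beta = class_sums / total_voxels
    return 1.0 / beta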
def __init__(
    self,
    model_save_path,
    generic_data_path,
    data_path,
    plane,
    batch_size,
    image_size,
    labels,
    dataset,
    shuffle=True,
    augmenter=None,
    use_cropper=False,
    combine_labels=None,
    cascade=None,
    quality_weighting_scores=None,
):
    super().__init__(
        model_save_path,
        generic_data_path,
        data_path,
        plane,
        batch_size,
        image_size,
        labels,
        dataset,
        shuffle,
        augmenter,
        use_cropper,
        combine_labels,
        cascade,
        quality_weighting_scores,
    )

    if cascade:
        self.reader = NPYReader()
        self.image_fnames = [
            os.path.join(model_save_path, "mask", dataset, f"{x}_SAX.nii.gz")
            for x in sorted(
                os.listdir(os.path.join(model_save_path, "mask", dataset))
            )
        ]
    else:
        self.reader = NIIReader()

    if self.quality_weighting_scores:
        self.resizer = NPYReader()
def rotate_incorrect_orientations():
    """
    Rotate a list of manually specified images by 90 degrees anti-clockwise so that all images are in the same
    orientation to simplify the learning task.
    """
    reader = NIIReader()
    base_folder = Path("/media/y4tsu/4B172BDA26AB3054/cmr_clean")
    roots = sorted([x for x in base_folder.iterdir() if x.is_dir()])
    print(f"{len(roots)=}")

    non_squares = []

    for root in roots:
        file_num = root.stem[-3:]
        image = np.squeeze(reader.read(root / f"20CA015_N{file_num}_SAX.nii.gz"))
        shape = image.shape
        label = np.squeeze(reader.read(root / f"20CA015_N{file_num}_SAX_mask2.nii.gz"))

        if shape[0] != shape[1]:
            print(f"{root} is non-square")
            non_squares += [root]

        if file_num in [
            "008", "014", "024", "030", "062", "064", "083", "089", "135", "138",
            "141", "144", "156", "159", "168", "174", "181", "192", "213", "215",
            "227", "262", "278", "294", "304", "307", "319", "330", "347", "348",
            "353", "355", "375",
        ]:
            # Show the original image
            # plt.imshow(label[:, :, 5], cmap="gray")
            # plt.show()

            # Rotate the bad images and labels: cubic interpolation (order=3) for image
            # intensities, nearest-neighbour (order=0) to keep label values integral
            rot_image = rotate(image, axes=(0, 1), angle=-90.0, reshape=False, order=3)
            rot_label = rotate(label, axes=(0, 1), angle=-90.0, reshape=False, order=0)

            # Set them as nifti type images
            new_img = nib.Nifti1Image(rot_image, np.eye(4))
            new_label = nib.Nifti1Image(rot_label, np.eye(4))

            # Save them, overwriting the original files
            nib.save(new_img, root / f"20CA015_N{file_num}_SAX.nii.gz")
            nib.save(new_label, root / f"20CA015_N{file_num}_SAX_mask2.nii.gz")

            # Show the newly rotated image
            # plt.imshow(rot_label[:, :, 5], cmap="gray")
            # plt.show()

    print(f"{len(non_squares)=}")
class Predictor3D(__Predictor):
    def __init__(self, data_path, dataset, model_path, train_config, post_process):
        super().__init__(data_path, dataset, train_config, post_process)
        self.full_data_path = os.path.join(data_path, "3D", self.dataset)
        self.reader = NIIReader() if not self.cascade else NPYReader()

        # TODO: test this all works ok
        self.image_fnames = [
            os.path.join(self.full_data_path, x)
            for x in sorted(os.listdir(self.full_data_path))
        ]
        self.label_fnames = [
            os.path.join(self.full_data_path, x)
            for x in sorted(os.listdir(self.full_data_path))
        ]

        self.model = self.load_model(model_path)
        self.dimensionality = "3D"

    def load_image_label(self, fname):
        """Loads the image and label files."""
        image_folder, label_folder, suffix, fname = self._get_folder_paths(fname)

        # Load image and label
        image = self.reader.read(os.path.join(image_folder, f"{suffix}_SAX.nii.gz"))
        label = self.reader.read(os.path.join(label_folder, f"{suffix}_SAX_mask2.nii.gz"))

        image, label = self._prepare_image_label(image, label, suffix)

        # Set to the correct rank: add batch and channel axes
        image = image[np.newaxis, ..., np.newaxis]

        return image, label, fname

    def predict(self, fname=None, display=False, apply_combine=True, return_fname=False):
        image, label, fname = self.load_image_label(fname)

        if self.quality_weighted_mode:
            pred_label = self.model.predict((image, np.array([1.0], dtype=np.float32)))[1]
        else:
            pred_label = self.model.predict(image)

        # Collapse the softmax channel axis into integer class labels
        pred_label = np.squeeze(np.argmax(pred_label, axis=-1))

        if self.combine_labels and apply_combine:
            label = self.apply_label_combine(label)

        if self.post_process:
            pred_label = self.post_process_label(pred_label)

        if display:
            print(self.calculate_dice(label, pred_label))
            self.display(image, label, pred_label)

        if return_fname:
            return image, label, pred_label, fname
        else:
            return image, label, pred_label
def setup_predictors(data_path, dataset, path_tr, path_sag, path_cor):
    # Transverse predictor
    with open(os.path.join(path_tr, "train_config.json"), "r") as f:
        p1 = Predictor2D(data_path, dataset, path_tr, json.load(f))

    # Sagittal predictor
    with open(os.path.join(path_sag, "train_config.json"), "r") as f:
        p2 = Predictor2D(data_path, dataset, path_sag, json.load(f))

    # Coronal predictor
    with open(os.path.join(path_cor, "train_config.json"), "r") as f:
        p3 = Predictor2D(data_path, dataset, path_cor, json.load(f))

    return p1, p2, p3


if __name__ == "__main__":
    nii_reader = NIIReader()
    data_path = "/media/y4tsu/ml_data/cmr"
    dataset = "val"

    p_tr, p_sag, p_cor = setup_predictors(
        data_path,
        dataset,
        "/home/y4tsu/PycharmProjects/3d_unet/checkpoint/2D_tr",
        "/home/y4tsu/PycharmProjects/3d_unet/checkpoint/2D_sag",
        "/home/y4tsu/PycharmProjects/3d_unet/checkpoint/2D_cor",
    )

    roots = sorted(os.listdir(os.path.join(data_path, "3D", dataset)))
    dices = 0
    class_wise_dices = np.zeros([len(p_tr.labels_dict)])
    start = time.time()
def _get_auto_bboxes(self, model_path):
    """Run the segmentation model to get bounding boxes for all the images in the dataset."""
    from predict import load_predictor

    print(
        f"Predicting bboxes using automatic cropper model at {model_path} for {self.dataset} data ... "
    )
    reader = NIIReader()
    base_folder = os.path.join(self.data_path, "3D", self.dataset)
    roots = sorted(os.listdir(base_folder))
    out = {}

    # Set up the correct config for the predictor model
    predict_config = {
        "model_path": model_path,
        "data_path": self.data_path,
        "dataset": self.dataset,
        "post_process": False,
    }

    # Load the cropper model as a Predictor object
    p = load_predictor(predict_config)

    # Iterate over all the images, getting predicted labels
    for root in tqdm(roots):
        # Search for the correct full filename
        fname = None
        for x in p.image_fnames:
            if root in x:
                fname = x
        if fname is None:
            raise ValueError(f"Unable to find the correct path for image {root}")

        image_size = reader.read(
            os.path.join(self.data_path, "3D", self.dataset, f"{root}/{root}_SAX.nii.gz")
        ).shape

        _, label, pred_label = p.predict(fname, display=False)

        # Remove noise and small islands from the prediction
        pred_label = self.__clean_prediction(pred_label)

        # In the case where an image has different dimensions to the model input, it needs to be re-scaled
        new_pred_label = np.empty(image_size, dtype=np.int8)
        if image_size[0] != p.image_size[0]:
            for j in range(image_size[-1]):
                curr = pred_label[..., j]
                # cv2.resize expects the output size as (width, height)
                new_pred_label[..., j] = cv2.resize(
                    curr,
                    tuple(reversed(image_size[:2])),
                    interpolation=cv2.INTER_NEAREST,
                )
            pred_label = new_pred_label

        # Now find the bounding box around the segmentation mask
        bbox = self.__find_bbox(pred_label)
        out[root] = {
            "top": bbox[0],
            "left": bbox[1],
            "bottom": bbox[2],
            "right": bbox[3],
        }

    return out
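# Hedged sketch of the per-slice rescaling above in isolation: cv2.resize takes its
# output size as (width, height), hence the reversed spatial dims, and INTER_NEAREST
# keeps mask values integral. It assumes the prediction already has the target slice
# count; the shapes and uint8 cast are illustrative assumptions.
def _resize_label_volume_sketch(pred_label, target_shape):
    import cv2
    import numpy as np

    out = np.empty(target_shape, dtype=np.uint8)
    for j in range(target_shape[-1]):
        out[..., j] = cv2.resize(
            pred_label[..., j].astype(np.uint8),
            (target_shape[1], target_shape[0]),  # (width, height)
            interpolation=cv2.INTER_NEAREST,
        )
    return out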
def main():
    start = time.time()

    with open("predict_config.json", "r") as f:
        predict_config = json.load(f)

    # Load the correct Predictor class for the given model type
    p = load_predictor(predict_config)
    plane = p.plane
    cropper = p.cropper

    if p.model_name in ["UNet3D", "VNet", "UNet3DFrozenDepth"]:
        dims = "3D"
        plane = ""
    elif p.model_name in ["UNet3DShallow", "VNetShallow"]:
        dims = "3DShallow"
    else:
        dims = "2D"

    # Get the names of all the scans we are interested in
    roots = sorted(
        os.listdir(
            os.path.join(
                predict_config["data_path"], dims, predict_config["dataset"], plane
            )
        )
    )
    roots = [
        os.path.join(
            predict_config["data_path"], dims, predict_config["dataset"], plane, root
        )
        for root in roots
    ]
    print(f"{roots=}")

    # Get the correct metadata
    headers = {}
    shapes = {}
    affines = {}
    reader = NIIReader()

    # Get roots from the 3D folder
    real_roots = sorted(
        os.listdir(
            os.path.join(predict_config["data_path"], "3D", predict_config["dataset"])
        )
    )
    real_roots = [
        os.path.join(predict_config["data_path"], "3D", predict_config["dataset"], x)
        for x in real_roots
    ]

    # Load important information from the original .nii.gz files for these images
    for real_root in tqdm(real_roots):
        name_end = real_root.split("/")[-1]
        img = nib.load(
            os.path.join(real_root, f"{name_end}_SAX_mask2.nii.gz"), mmap=False
        )
        # Headers = headers from the original NIfTI file
        headers[name_end] = img.header
        # Affine = transformation which maps points to the 3D space of the MRI image
        affines[name_end] = img.affine
        # Shape of each image
        shapes[name_end] = np.squeeze(img.get_fdata()).shape

    # Get a prediction for each image and save it in NIfTI format for loading in ITK-SNAP
    for root in tqdm(roots):
        # Get the prediction
        image, label, pred_label = p.predict(fname=root, display=False)

        # Get the key to find the correct header and affine for the output
        name_end = root.split("/")[-1]
        save_img = np.zeros(shapes[name_end])

        # Work out how to undo the cropper's pre-processing for this image
        reverse_size, cut_dims = p.cropper.reverse_crop(shapes[name_end], name_end)

        # Resize the prediction to undo the pre-processing
        pred_label = reader.resize(pred_label, reverse_size, interpolation_order=0)

        # Get the correct portion of the predicted label to output
        try:
            save_img[
                cut_dims["top"] : cut_dims["bottom"],
                cut_dims["left"] : cut_dims["right"],
                ...,
            ] = pred_label
        except ValueError:
            print(
                f"Unable to undo pre-processing for image {name_end}! {save_img.shape=}, {pred_label.shape=}"
            )
            continue

        # Save the image in NIfTI format, keeping the original affine and header
        save_img = save_img.astype(np.uint16)
        save_img = nib.Nifti1Image(
            save_img,
            affine=affines[name_end],
            header=headers[name_end],
        )
        nib.save(
            save_img,
            os.path.join(
                "/home/y4tsu/PycharmProjects/3d_unet/saved_preds",
                f"{name_end}_prediction.nii.gz",
            ),
        )

    print(f"Finished! Process took {time.time() - start:.2f} seconds.")
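# Hedged sketch of the save step above in isolation: writing a predicted label volume
# back to NIfTI with the affine and header taken from the source scan so that it overlays
# correctly in ITK-SNAP. The function name, paths, and arguments are illustrative
# assumptions, not part of the project's API.
def _save_prediction_sketch(pred_label, source_nii_path, out_path):
    import nibabel as nib
    import numpy as np

    src = nib.load(source_nii_path)
    out = nib.Nifti1Image(
        pred_label.astype(np.uint16), affine=src.affine, header=src.header
    )
    nib.save(out, out_path)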