def _load_images(self, patient):
    image_path, segmentation_paths = self.data_ct[patient]
    ct, meta_ct = read_image(image_path)
    segmentation_ct = self._get_segmentation_ct(segmentation_paths)
    logger.debug(
        f"Image and segmentation shapes:\nCT: {ct.shape}\nCT seg: {segmentation_ct.shape}"
    )

    image_path, affine_fn, segmentation_paths = self.data_cbct[patient]
    ct_fn, ct_seg_fn = self.data_ct[patient]
    cbct, meta = read_image(image_path,
                            affine_matrix=True,
                            ref_fn=ct_fn,
                            interpolator="linear")
    self.meta = meta
    segmentation_cbct = self._get_segmentation_cbct(segmentation_paths, ct_seg_fn)
    logger.debug(f"CBCT Shape: {cbct.shape}")
    logger.debug(
        f"Image and segmentation shapes:\nCBCT: {cbct.shape}\nCBCT seg: {segmentation_cbct.shape}"
    )

    if len(ct.shape) == 3:  # add "channels" dimension if it is not present
        ct = np.expand_dims(ct, axis=0)
    if len(cbct.shape) == 3:  # add "channels" dimension if it is not present
        cbct = np.expand_dims(cbct, axis=0)

    return cbct, ct, segmentation_cbct, segmentation_ct

def _get_segmentation_ct(self, segmentations):
    if len(segmentations) == 2:
        seg_bladder = read_image(segmentations[0], no_meta=True)
        seg_cervix_uterus = read_image(segmentations[1], no_meta=True)

    # Center-crop both masks to 512 x 512 in-plane
    start = int((seg_bladder.shape[2] - 512) / 2)
    seg_bladder = crop_to_bbox(
        seg_bladder, (0, start, start, seg_bladder.shape[0], 512, 512))
    seg_cervix_uterus = crop_to_bbox(
        seg_cervix_uterus, (0, start, start, seg_cervix_uterus.shape[0], 512, 512))

    all_segs = seg_bladder + seg_cervix_uterus
    other = all_segs < 1  # background: voxels covered by neither structure

    segs = [seg_bladder, seg_cervix_uterus, other]
    segmentation = np.stack(segs).astype(int)
    return segmentation

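# crop_to_bbox comes from the project's utilities and is not shown here. From the
# calls above, the bbox appears to be (z_start, y_start, x_start, depth, height, width),
# i.e. a center crop to 512 x 512 in-plane. A minimal sketch under that assumption,
# for illustration only (the real helper may also pad when the bbox leaves the volume):
def _crop_to_bbox_sketch(volume, bbox):
    """Crop a (Z, Y, X) volume to bbox = (z0, y0, x0, dz, dy, dx)."""
    z0, y0, x0, dz, dy, dx = bbox
    return volume[z0:z0 + dz, y0:y0 + dy, x0:x0 + dx]
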
def main_CBCT():
    target_dir = Path("/data/cyclegan/cbct")
    files = list(Path("/data/cervix/patients").iterdir())
    files = [f for f in files if len(list(f.iterdir())) > 0]

    transform_CT = transforms.Compose([ClipAndNormalize(250, 1750)])

    j = 0
    image_index = 0
    for p in tqdm(files):
        for f in p.glob("X*.nii"):
            image, meta = read_image(str(f),
                                     ref_fn=(p / "CT1.nii"),
                                     affine_matrix=True)
            image = transform_CT({"image": image})["image"]
            for X in image:
                image_index += 1
                im = Image.fromarray(np.uint8(X.squeeze() * 255), 'L')
                im.save(
                    str(target_dir / "test_CBCT" /
                        f"{image_index}_{f.parent.stem}_{f.stem}.jpg"))
                if temp_save:
                    im.save(f"TEMP/{image_index}_{f.parent.stem}_{f.stem}.jpg")

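# ClipAndNormalize(250, 1750) is taken from the project's transforms module and is not
# shown here. Based on how save_scans inverts it (img * 1500 + 250), it presumably clips
# intensities to [250, 1750] and rescales them to [0, 1]. A minimal sketch under that
# assumption; the class name and dict-based interface simply mirror the usage above:
class _ClipAndNormalizeSketch:
    def __init__(self, vmin, vmax):
        self.vmin = vmin
        self.vmax = vmax

    def __call__(self, sample):
        # Clip to [vmin, vmax], then rescale to [0, 1]
        image = np.clip(sample["image"], self.vmin, self.vmax)
        sample["image"] = (image - self.vmin) / (self.vmax - self.vmin)
        return sample
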
def get_shapes_extra(root_dir):
    shapes = {}
    for patient in os.listdir(root_dir):
        try:
            img = read_image(os.path.join(root_dir, patient, "full", "CT.nrrd"),
                             no_meta=True)
            print(patient, img.shape)
            shapes[patient + "/full"] = img.shape

            img2 = read_image(os.path.join(root_dir, patient, "empty", "CT.nrrd"),
                              no_meta=True)
            print(patient, img2.shape)
            shapes[patient + "/empty"] = img2.shape
        except Exception:
            print(patient, "failed")
    return shapes

def save_scans(args, slices):
    source_dir = Path("/data/cervix/patients")
    for patient, scans in tqdm(slices["fake"].items()):
        for scan, s in scans.items():
            # Sort the slices by slice index before stacking them into a volume
            s = sorted(s, key=lambda tup: tup[0])
            image = []
            for tup in s:
                im_frame = Image.open(tup[1]).convert('L')
                np_array = np.array(im_frame.getdata()).reshape((512, 512)) / 255.0
                image.append(np_array)
            img = np.stack(image)
            # Undo ClipAndNormalize(250, 1750): map [0, 1] back to [250, 1750]
            img = img * 1500 + 250

            if not (source_dir / patient / f"bladder_{scan.lower()}.nii").exists():
                continue

            _, meta = read_image(str(source_dir / patient / "CT1.nii"))
            bladder = read_image(str(source_dir / patient /
                                     f"bladder_{scan.lower()}.nii"),
                                 no_meta=True,
                                 ref_fn=str(source_dir / patient / "CT1.nii"),
                                 affine_matrix=True)
            cervix = read_image(str(source_dir / patient /
                                    f"cervix_uterus_{scan.lower()}.nii"),
                                no_meta=True,
                                ref_fn=str(source_dir / patient / "CT1.nii"),
                                affine_matrix=True)

            if not (args.output_dir / patient).exists():
                (args.output_dir / patient).mkdir()

            write_image(img,
                        str(args.output_dir / patient / (scan + ".nrrd")),
                        metadata=meta)
            write_image(bladder,
                        str(args.output_dir / patient /
                            f"bladder_{scan.lower()}.nrrd"),
                        metadata=meta)
            write_image(cervix,
                        str(args.output_dir / patient /
                            f"cervix_uterus_{scan.lower()}.nrrd"),
                        metadata=meta)

def get_shapes(root_dir):
    shapes = {}
    for patient in os.listdir(root_dir):
        try:
            img = read_image(os.path.join(root_dir, patient, "CT1.nii"),
                             no_meta=True)
            print(patient, img.shape)
            shapes[patient] = img.shape
        except Exception:
            print(patient, "failed")
    return shapes

def _get_segmentation(self, segmentations):
    if len(segmentations) == 2:
        seg_bladder = read_image(segmentations[0], no_meta=True)
        seg_cervix_uterus = read_image(segmentations[1], no_meta=True)
        all_segs = seg_bladder + seg_cervix_uterus
    elif len(segmentations) == 3:
        seg_bladder = read_image(segmentations[0], no_meta=True)
        seg_cervix = read_image(segmentations[1], no_meta=True)
        seg_uterus = read_image(segmentations[2], no_meta=True)
        # Combine cervix and uterus segmentation
        seg_cervix_uterus = (seg_cervix | seg_uterus)
        all_segs = seg_bladder + seg_cervix + seg_uterus

    # Center-crop all masks to 512 x 512 in-plane
    start = int((all_segs.shape[1] - 512) / 2)
    seg_bladder = crop_to_bbox(
        seg_bladder, (0, start, start, seg_bladder.shape[0], 512, 512))
    seg_cervix_uterus = crop_to_bbox(
        seg_cervix_uterus, (0, start, start, seg_cervix_uterus.shape[0], 512, 512))
    all_segs = crop_to_bbox(all_segs,
                            (0, start, start, all_segs.shape[0], 512, 512))

    other = all_segs < 1  # background: voxels covered by neither structure
    segs = [seg_bladder, seg_cervix_uterus, other]
    segmentation = np.stack(segs).astype(int)
    return segmentation

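# Hypothetical helper, not part of the original class: the stack returned by
# _get_segmentation has shape (3, Z, H, W) with channels (bladder, cervix_uterus,
# background). Collapsing it with an argmax over the channel axis gives the integer
# label map that the NLL loss in test() works with (cf. Y.argmax(1) there, once a
# batch dimension is present).
def _segmentation_to_labels_sketch(segmentation):
    # 0 = bladder, 1 = cervix_uterus, 2 = background
    return segmentation.argmax(0)
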
def _load_image(self, patient):
    cache_fn = self.cachedir / f"{patient}_CT1"
    cache_fn_seg = self.cachedir / f"{patient}_CT1_seg"
    if cache_fn.exists() and cache_fn_seg.exists():
        image = read_object(cache_fn)
        segmentation = read_object(cache_fn_seg)
    else:
        image_path, segmentation_paths = self.data[patient]
        image = read_image(image_path, no_meta=True)
        segmentation = self._get_segmentation(segmentation_paths)

    if len(image.shape) == 3:  # add "channels" dimension if it is not present
        image = np.expand_dims(image, axis=0)

    return image, segmentation

def get_shapes_cbct(root_dir):
    shapes = {}
    for patient in os.listdir(root_dir):
        images = glob.glob(os.path.join(root_dir, patient, "X*.nii"))
        for cbct in images:
            try:
                # Extract the CBCT scan id (e.g. "X01") from the filename
                m = re.search("X[0-9]+", cbct)
                n = m.group(0)
                img = read_image(cbct, no_meta=True)
                segmentations = glob.glob(
                    os.path.join(root_dir, patient, "*_{}.nii".format(n.lower())))
                print(patient + "\\" + n, img.shape, len(segmentations))
                if len(segmentations) > 0:
                    shapes[patient + "\\" + n] = img.shape
            except Exception:
                print(patient, "failed")
    return shapes

def process_CT(source_dir, target_dir):
    files = list(source_dir.rglob("CT.nrrd"))
    # 80/10/10 split into train/val/test
    train = files[:int(0.8 * len(files))]
    val = files[int(0.8 * len(files)):int(0.9 * len(files))]
    test = files[int(0.9 * len(files)):]

    transform_CT = transforms.Compose([ClipAndNormalize(250, 1750)])

    image_index = 0
    # A: CT, B: CBCT
    for i, f in enumerate(tqdm(train)):
        image, meta = read_image(str(f))
        bboxes = np.array(
            get_table(image, clip_val=300, margin=2500, show_imgs=False))
        if len(bboxes) == 0:
            print("Table not found in image!")
            # Dump the slices for inspection and skip this scan
            image_index = 0
            image = transform_CT({"image": image})["image"]
            for X in image[10:-10]:
                image_index += 1
                im = Image.fromarray(np.uint8(X.squeeze() * 255), 'L')
                im.save(f"TEMP/NO_TABLE_{image_index}_{f.parent.stem}.jpg")
            continue
        # Take the most frequent table bounding box over all slices and mask it out
        bbox = (stats.mode(bboxes[:, 0])[0][0], stats.mode(bboxes[:, 1])[0][0],
                stats.mode(bboxes[:, 2])[0][0], stats.mode(bboxes[:, 3])[0][0])
        image = np.clip(image, 0, image.max())
        image[:, bbox[0]:512, bbox[1]:bbox[3]] = 0
        image = transform_CT({"image": image})["image"]
        for X in image[10:-10]:
            image_index += 1
            im = Image.fromarray(np.uint8(X.squeeze() * 255), 'L')
            im.save(
                str(target_dir / "trainA" / f"{image_index}_{f.parent.stem}.jpg"))
            if temp_save:
                im.save(f"TEMP/TRAIN_A_{image_index}_{f.parent.stem}.jpg")

    image_index = 0
    for f in tqdm(val):
        image, meta = read_image(str(f))
        bboxes = np.array(
            get_table(image, clip_val=200, margin=2500, show_imgs=False))
        if len(bboxes) == 0:
            print("Table not found in image!")
            continue
        bbox = (stats.mode(bboxes[:, 0])[0][0], stats.mode(bboxes[:, 1])[0][0],
                stats.mode(bboxes[:, 2])[0][0], stats.mode(bboxes[:, 3])[0][0])
        image = np.clip(image, 0, image.max())
        image[:, bbox[0]:512, bbox[1]:bbox[3]] = 0
        image = transform_CT({"image": image})["image"]
        for X in image[10:-10]:
            image_index += 1
            im = Image.fromarray(np.uint8(X.squeeze() * 255), 'L')
            im.save(
                str(target_dir / "valA" / f"{image_index}_{f.parent.stem}.jpg"))
            if temp_save:
                im.save(f"TEMP/VAL_A_{image_index}_{f.parent.stem}.jpg")

    image_index = 0
    for f in tqdm(test):
        image, meta = read_image(str(f))
        bboxes = np.array(
            get_table(image, clip_val=200, margin=2500, show_imgs=False))
        if len(bboxes) == 0:
            print("Table not found in image!")
            # Dump the slices for inspection and skip this scan
            image_index = 0
            image = transform_CT({"image": image})["image"]
            for X in image[10:-10]:
                image_index += 1
                im = Image.fromarray(np.uint8(X.squeeze() * 255), 'L')
                im.save(f"TEMP/NO_TABLE_{image_index}_{f.parent.stem}.jpg")
            continue
        bbox = (stats.mode(bboxes[:, 0])[0][0], stats.mode(bboxes[:, 1])[0][0],
                stats.mode(bboxes[:, 2])[0][0], stats.mode(bboxes[:, 3])[0][0])
        image = np.clip(image, 0, image.max())
        image[:, bbox[0]:512, bbox[1]:bbox[3]] = 0
        image = transform_CT({"image": image})["image"]
        for X in image[10:-10]:
            image_index += 1
            im = Image.fromarray(np.uint8(X.squeeze() * 255), 'L')
            im.save(
                str(target_dir / "testA" / f"{image_index}_{f.parent.stem}.jpg"))
            if temp_save:
                im.save(f"TEMP/TEST_A_{image_index}_{f.parent.stem}.jpg")

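# The stats.mode(...)[0][0] indexing above follows the older SciPy interface, where
# mode() returned arrays for both the mode and the count; recent SciPy releases return
# scalars by default, which breaks that indexing. A hypothetical, version-independent
# alternative using numpy only:
def _mode_value_sketch(values):
    # Most frequent value in the flattened input
    vals, counts = np.unique(values, return_counts=True)
    return vals[np.argmax(counts)]
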
def process_CBCT(source_dir, target_dir):
    files = list(source_dir.iterdir())
    files = [f for f in files if len(list(f.glob("X*.nrrd"))) > 0]
    # 80/10/10 split into train/val/test
    train = files[:int(0.8 * len(files))]
    val = files[int(0.8 * len(files)):int(0.9 * len(files))]
    test = files[int(0.9 * len(files)):]

    transform_CT = transforms.Compose([ClipAndNormalize(250, 1750)])

    scan_id = 0
    image_index = 0
    for p in tqdm(train):
        cbct_count = 1
        # Use at most two CBCT scans per patient
        for f in p.glob("X*.nrrd"):
            if cbct_count > 2:
                break
            image, meta = read_image(str(f))
            cbct_count += 1
            image = transform_CT({"image": image})["image"]
            scan_id += 1
            for X in image[10:-10]:
                image_index += 1
                im = Image.fromarray(np.uint8(X.squeeze() * 255), 'L')
                im.save(
                    str(target_dir / "trainB" /
                        f"{image_index}_{f.parent.stem}_{f.stem}.jpg"))
                if temp_save:
                    im.save(
                        f"TEMP/TRAIN_B_{image_index}_{f.parent.stem}_{f.stem}_{scan_id}.jpg"
                    )

    image_index = 0
    for p in tqdm(val):
        cbct_count = 1
        for f in p.glob("X*.nrrd"):
            if cbct_count > 2:
                break
            if f.stem == "X01":
                continue
            image, meta = read_image(str(f))
            cbct_count += 1
            image = transform_CT({"image": image})["image"]
            for X in image[10:-10]:
                image_index += 1
                im = Image.fromarray(np.uint8(X.squeeze() * 255), 'L')
                im.save(
                    str(target_dir / "valB" /
                        f"{image_index}_{f.parent.stem}_{f.stem}.jpg"))
                if temp_save:
                    im.save(
                        f"TEMP/VAL_B_{image_index}_{f.parent.stem}_{f.stem}.jpg")

    image_index = 0
    scan_id = 0
    # A: CT, B: CBCT
    for p in tqdm(test):
        cbct_count = 1
        for f in p.glob("X*.nrrd"):
            if cbct_count > 2:
                break
            cbct_count += 1
            scan_id += 1
            image, meta = read_image(str(f))
            image = transform_CT({"image": image})["image"]
            for X in image[10:-10]:
                image_index += 1
                im = Image.fromarray(np.uint8(X.squeeze() * 255), 'L')
                im.save(
                    str(target_dir / "testB" /
                        f"{image_index}_{f.parent.stem}_{f.stem}.jpg"))
                if temp_save:
                    im.save(
                        f"TEMP/TEST_B_{image_index}_{f.parent.stem}_{f.stem}_{scan_id}.jpg"
                    )

def test(args, dl, writer, model, image_shapes):
    device = "cuda"  # Run on GPU
    criterion = get_loss_func("NLL")
    softmax = nn.LogSoftmax(1)

    logger.info("Start Testing...")
    tmp_losses = []
    metrics = {"bladder": {}, "cervix": {}}
    segmentations = {0: [], 1: [], "y_bladder": [], "y_cervix": []}

    # Sentinel so the last patient is handled like the others
    image_shapes.append(None)
    img_i = 0
    temp = image_shapes.pop(0)
    img_shape = temp[1][0]
    patient = temp[0].replace("_", "/")
    _, metadata = read_image(str(temp[2]))
    logger.debug(patient.replace("/", "_"))

    all_zeros = 0
    seg_slices = 0

    model.eval()
    for i, (X, Y) in enumerate(dl):
        X, Y = X.to(device).float(), Y.to(device).float()
        torch.cuda.empty_cache()

        Y_hat = model(X)
        assert Y_hat.shape == Y.shape, "output and classification must be same shape, {}, {}".format(
            Y_hat.shape, Y.shape)

        if args.save_3d:
            segmentations["y_bladder"].append(
                Y[:, 0, :, :, :].squeeze().detach().cpu())
            segmentations["y_cervix"].append(
                Y[:, 1, :, :, :].squeeze().detach().cpu())

        Y_hat = softmax(Y_hat)
        loss = criterion(Y_hat, Y.argmax(1))
        tmp_losses.append(loss.detach().cpu().item())

        segmentations[0].append(
            Y_hat.exp()[:, 0, :, :, :].squeeze().detach().cpu() > args.threshold)
        segmentations[1].append(
            Y_hat.exp()[:, 1, :, :, :].squeeze().detach().cpu() > args.threshold)

        img_i += 1
        if img_i >= img_shape and args.save_3d:
            img_i = 0
            logger.info(f"Saving image {patient}")
            y_bladder = torch.stack(
                segmentations["y_bladder"]).detach().cpu().numpy()
            y_cervix = torch.stack(
                segmentations["y_cervix"]).detach().cpu().numpy()
            seg_bladder = torch.stack(
                segmentations[0]).int().detach().cpu().numpy()
            seg_cervix = torch.stack(
                segmentations[1]).int().detach().cpu().numpy()

            if args.post_process:
                # Keep only the largest connected component of each predicted mask
                labels_mask = measure.label(seg_bladder)
                regions = measure.regionprops(labels_mask)
                regions.sort(key=lambda x: x.area, reverse=True)
                if len(regions) > 1:
                    for rg in regions[1:]:
                        labels_mask[rg.coords[:, 0], rg.coords[:, 1]] = 0
                labels_mask[labels_mask != 0] = 1
                seg_bladder = labels_mask

                labels_mask = measure.label(seg_cervix)
                regions = measure.regionprops(labels_mask)
                regions.sort(key=lambda x: x.area, reverse=True)
                if len(regions) > 1:
                    for rg in regions[1:]:
                        labels_mask[rg.coords[:, 0], rg.coords[:, 1]] = 0
                labels_mask[labels_mask != 0] = 1
                seg_cervix = labels_mask

            write_image(seg_bladder.astype(np.uint8),
                        "{}/{}_seg_bladder.nrrd".format(
                            args.test_folder, patient.replace("/", "_")),
                        metadata=metadata)
            write_image(seg_cervix.astype(np.uint8),
                        "{}/{}_seg_cervix_uterus.nrrd".format(
                            args.test_folder, patient.replace("/", "_")),
                        metadata=metadata)

            metrics_bladder = calculate_metrics(y_bladder.astype(bool),
                                                seg_bladder.astype(bool),
                                                metadata["spacing"], 25.0,
                                                [0.5, 1.0, 1.5, 3.0])
            metrics_cervix = calculate_metrics(y_cervix.astype(bool),
                                               seg_cervix.astype(bool),
                                               metadata["spacing"], 25.0,
                                               [0.5, 1.0, 1.5, 3.0])
            metrics["bladder"][patient] = metrics_bladder
            metrics["cervix"][patient] = metrics_cervix
            for m, v in metrics_bladder.items():
                logger.info(f"{m} bladder: {v}")
            for m, v in metrics_cervix.items():
                logger.info(f"{m} cervix: {v}")

            # Move on to the next patient
            segmentations = {0: [], 1: [], "y_bladder": [], "y_cervix": []}
            temp = image_shapes.pop(0)
            if temp is not None:
                img_shape = temp[1][0]
                patient = temp[0].replace("_", "/")
                CT_path = temp[2]
                if CT_path.exists():
                    _, metadata = read_image(str(CT_path))

        torch.cuda.empty_cache()
        _log_images(X, Y, Y_hat, i, writer, tag="test")

        if img_i % args.print_every == 0:
            logger.info("Iteration: {}/{} Loss: {}".format(
                i, len(dl), sum(tmp_losses) / len(tmp_losses)))
            tmp_losses = []

    pickle.dump(metrics, open("{}/metrics.p".format(args.test_folder), 'wb'))
    writer.flush()
    logger.info("End testing")

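# The metrics pickle written by test() maps organ -> patient -> metric name -> value;
# the metric names come from calculate_metrics and are not listed here. A hypothetical
# helper (not part of the original test code) for averaging each metric over patients
# after a run, assuming the stored values are scalars:
def _summarize_metrics_sketch(metrics_path):
    metrics = pickle.load(open(metrics_path, "rb"))
    summary = {}
    for organ, per_patient in metrics.items():
        organ_summary = {}
        for patient_metrics in per_patient.values():
            for name, value in patient_metrics.items():
                organ_summary.setdefault(name, []).append(value)
        summary[organ] = {
            name: sum(values) / len(values)
            for name, values in organ_summary.items()
        }
    return summary
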
'''
Author: Tessa Wagenaar

Calculate the data statistics
'''

import glob
import os
from pathlib import Path
from utils.image_readers import read_image
import pickle
import numpy as np

files_CBCT = pickle.load(open("files_CBCT.p", 'rb'))

statistics = {}
print("calculate image data")
for (patient, shape, image_fn, segmentations) in files_CBCT:
    image, metadata = read_image(str(image_fn))
    print(f"Min: {image.min()} Max: {image.max()}")
    # Record the per-scan intensity range
    statistics[str(image_fn)] = {"min": float(image.min()), "max": float(image.max())}

pickle.dump(statistics, open("CBCT_statistics.p", 'wb'))