# Example #1
# 0
    def _load_images(self, patient):
        """Load the CT and CBCT volumes of ``patient`` plus their segmentations.

        The CT image path and its segmentation paths come from
        ``self.data_ct[patient]``. The CBCT image path, an (unused here) affine
        file name and the CBCT segmentation paths come from
        ``self.data_cbct[patient]``. The CBCT is read with ``affine_matrix=True``,
        ``ref_fn`` set to the CT image path and linear interpolation; its
        metadata is stored on ``self.meta``.

        Returns:
            Tuple ``(cbct, ct, segmentation_cbct, segmentation_ct)``. ``cbct`` and
            ``ct`` get a leading "channels" axis when they are 3-D.
        """
        ct_path, ct_seg_paths = self.data_ct[patient]

        # Planning CT and its label volume.
        ct, _ct_meta = read_image(ct_path)
        segmentation_ct = self._get_segmentation_ct(ct_seg_paths)
        logger.debug(
            f"Image and segmentation shapes:\nCT: {ct.shape}\nCT seg: {segmentation_ct.shape}"
        )

        # CBCT, read with the CT image as reference.
        cbct_path, _affine_fn, cbct_seg_paths = self.data_cbct[patient]
        cbct, cbct_meta = read_image(
            cbct_path,
            affine_matrix=True,
            ref_fn=ct_path,
            interpolator="linear",
        )
        self.meta = cbct_meta
        segmentation_cbct = self._get_segmentation_cbct(cbct_seg_paths,
                                                        ct_seg_paths)
        logger.debug(f"CBCT Shape: {cbct.shape}")
        logger.debug(
            f"Image and segmentation shapes:\nCBCT: {cbct.shape}\nCBCT seg: {segmentation_cbct.shape}"
        )

        # Give 3-D volumes a leading "channels" axis; other ranks pass through.
        ct, cbct = (
            np.expand_dims(volume, axis=0) if len(volume.shape) == 3 else volume
            for volume in (ct, cbct)
        )
        return cbct, ct, segmentation_cbct, segmentation_ct
# Example #2
# 0
 def _get_segmentation_ct(self, segmentations):
     if len(segmentations) == 2:
         seg_bladder = read_image(segmentations[0], no_meta=True)
         seg_cervix_uterus = read_image(segmentations[1], no_meta=True)
     start = int((seg_bladder.shape[2] - 512) / 2)
     seg_bladder = crop_to_bbox(
         seg_bladder, (0, start, start, seg_bladder.shape[0], 512, 512))
     seg_cervix_uterus = crop_to_bbox(
         seg_cervix_uterus,
         (0, start, start, seg_cervix_uterus.shape[0], 512, 512))
     all_segs = seg_bladder + seg_cervix_uterus
     other = all_segs < 1
     segs = [seg_bladder, seg_cervix_uterus, other]
     segmentation = np.stack(segs).astype(int)
     return segmentation
def main_CBCT():
    """Export every CBCT slice under /data/cervix/patients as a JPEG.

    For each non-empty patient directory, every ``X*.nii`` scan is read with
    ``ref_fn`` set to that patient's ``CT1.nii`` and ``affine_matrix=True``.
    Each scan is then passed through ``ClipAndNormalize(250, 1750)`` and written
    slice by slice to ``/data/cyclegan/cbct/test_CBCT`` as
    ``{running_index}_{patient}_{scan}.jpg``. When the module-level
    ``temp_save`` flag is truthy, the last slice of each scan is also saved
    under ``TEMP/``.
    """
    target_dir = Path("/data/cyclegan/cbct")
    out_dir = target_dir / "test_CBCT"
    out_dir.mkdir(parents=True, exist_ok=True)

    # Sorted so the running image index does not depend on the order in
    # which the filesystem happens to list entries.
    files = sorted(f for f in Path("/data/cervix/patients").iterdir()
                   if any(f.iterdir()))

    transform_CT = transforms.Compose([ClipAndNormalize(250, 1750)])

    image_index = 0
    for p in tqdm(files):
        for f in sorted(p.glob("X*.nii")):
            # str() to match how ref_fn is passed to read_image elsewhere.
            image, _meta = read_image(str(f),
                                      ref_fn=str(p / "CT1.nii"),
                                      affine_matrix=True)
            image = transform_CT({"image": image})["image"]

            im = None
            for X in image:
                image_index += 1
                im = Image.fromarray(np.uint8(X.squeeze() * 255), 'L')
                im.save(
                    str(out_dir /
                        f"{image_index}_{f.parent.stem}_{f.stem}.jpg"))
            # Quick visual check: keep only the last slice of each scan.
            # Guarded so an empty scan neither crashes nor re-saves a stale
            # slice from the previous scan.
            if temp_save and im is not None:
                im.save(f"TEMP/{image_index}_{f.parent.stem}_{f.stem}.jpg")
# Example #4
# 0
def get_shapes_extra(root_dir):
    """Collect the shapes of the ``full`` and ``empty`` CT volumes per patient.

    For every entry ``patient`` in ``root_dir`` this reads
    ``<root_dir>/<patient>/full/CT.nrrd`` and then
    ``<root_dir>/<patient>/empty/CT.nrrd``, printing each shape.

    Args:
        root_dir: Directory whose entries are patient folders.

    Returns:
        Dict mapping ``"<patient>/full"`` and ``"<patient>/empty"`` to the
        corresponding image shape. When a read fails, the error is printed and
        the rest of that patient is skipped. A ``full`` shape recorded before
        the failure is kept.
    """
    shapes = {}
    for patient in os.listdir(root_dir):
        try:
            for scan in ("full", "empty"):
                img = read_image(os.path.join(root_dir, patient, scan,
                                              "CT.nrrd"),
                                 no_meta=True)
                print(patient, img.shape)
                shapes[f"{patient}/{scan}"] = img.shape
        except Exception as err:  # best-effort survey: report and move on
            print(patient, "failed", err)
    return shapes
def save_scans(args, slices):
    """Reassemble generated 2-D slices into 3-D volumes and write them as NRRD.

    Args:
        args: Namespace providing ``output_dir`` (a ``Path``).
        slices: Mapping whose ``"fake"`` entry maps
            ``patient -> {scan: [tuple, ...]}``. In each tuple the first item
            is the sort key (presumably the slice index) and the second is the
            path of a slice image.

    Only ``(patient, scan)`` pairs whose ``bladder_<scan>.nii`` and
    ``cervix_uterus_<scan>.nii`` both exist under
    ``/data/cervix/patients/<patient>`` are handled. For each, the slices are
    sorted, converted to grayscale, scaled to [0, 1], stacked and mapped to
    [250, 1750]. The volume is written to ``<output_dir>/<patient>/<scan>.nrrd``
    with the metadata of the patient's ``CT1.nii``. The two label volumes,
    read with ``CT1.nii`` as reference, are written next to it.
    """
    source_dir = Path("/data/cervix/patients")
    for patient, scans in tqdm(slices["fake"].items()):
        patient_dir = source_dir / patient
        ct_fn = str(patient_dir / "CT1.nii")
        for scan, s in scans.items():
            scan_key = scan.lower()
            bladder_fn = patient_dir / f"bladder_{scan_key}.nii"
            cervix_fn = patient_dir / f"cervix_uterus_{scan_key}.nii"
            # Check before decoding any slice. Both label files are needed
            # below, and a missing cervix file used to crash read_image.
            if not (bladder_fn.exists() and cervix_fn.exists()):
                continue

            image = []
            for tup in sorted(s, key=lambda tup: tup[0]):
                # Close the file handle; convert() returns an in-memory copy.
                with Image.open(tup[1]) as frame:
                    gray = frame.convert('L')
                # reshape keeps the original hard check for 512x512 slices.
                image.append(
                    np.asarray(gray, dtype=np.float64).reshape((512, 512)) /
                    255.0)

            img = np.stack(image)
            # Map [0, 1] back to [250, 1750] (presumably undoing
            # ClipAndNormalize(250, 1750) used when exporting slices).
            img = img * 1500 + 250

            _, meta = read_image(ct_fn)
            bladder = read_image(str(bladder_fn),
                                 no_meta=True,
                                 ref_fn=ct_fn,
                                 affine_matrix=True)
            cervix = read_image(str(cervix_fn),
                                no_meta=True,
                                ref_fn=ct_fn,
                                affine_matrix=True)

            out_dir = args.output_dir / patient
            out_dir.mkdir(parents=True, exist_ok=True)
            write_image(img, str(out_dir / (scan + ".nrrd")), metadata=meta)
            write_image(bladder,
                        str(out_dir / f"bladder_{scan_key}.nrrd"),
                        metadata=meta)
            write_image(cervix,
                        str(out_dir / f"cervix_uterus_{scan_key}.nrrd"),
                        metadata=meta)
# Example #6
# 0
def get_shapes(root_dir):
    """Return ``{patient: shape}`` for each ``<root_dir>/<patient>/CT1.nii``.

    Each shape is printed as it is read. Patients whose image cannot be read
    are reported together with the error and left out of the result.
    """
    shapes = {}
    for patient in os.listdir(root_dir):
        try:
            img = read_image(os.path.join(root_dir, patient, "CT1.nii"),
                             no_meta=True)
        except Exception as err:  # best-effort survey: report and move on
            print(patient, "failed", err)
            continue
        print(patient, img.shape)
        shapes[patient] = img.shape
    return shapes
# Example #7
# 0
 def _get_segmentation(self, segmentations):
     if len(segmentations) == 2:
         seg_bladder = read_image(segmentations[0], no_meta=True)
         seg_cervix_uterus = read_image(segmentations[1], no_meta=True)
         all_segs = seg_bladder + seg_cervix_uterus
     # Combine cervix and uterus segmentation
     elif len(segmentations) == 3:
         seg_bladder = read_image(segmentations[0], no_meta=True)
         seg_cervix = read_image(segmentations[1], no_meta=True)
         seg_uterus = read_image(segmentations[2], no_meta=True)
         seg_cervix_uterus = (seg_cervix | seg_uterus)
         all_segs = seg_bladder + seg_cervix + seg_uterus
     start = int((all_segs.shape[1] - 512) / 2)
     seg_bladder = crop_to_bbox(
         seg_bladder, (0, start, start, seg_bladder.shape[0], 512, 512))
     seg_cervix_uterus = crop_to_bbox(
         seg_cervix_uterus,
         (0, start, start, seg_cervix_uterus.shape[0], 512, 512))
     all_segs = crop_to_bbox(all_segs,
                             (0, start, start, all_segs.shape[0], 512, 512))
     other = all_segs < 1
     segs = [seg_bladder, seg_cervix_uterus, other]
     segmentation = np.stack(segs).astype(int)
     return segmentation
# Example #8
# 0
    def _load_image(self, patient):
        """Return ``(image, segmentation)`` for ``patient``.

        If both ``<cachedir>/<patient>_CT1`` and ``<cachedir>/<patient>_CT1_seg``
        exist, they are loaded with ``read_object``. Otherwise the image is
        read from ``self.data[patient]`` and the segmentation is built with
        ``self._get_segmentation``. In both cases a 3-D image gets a leading
        "channels" axis, so callers see the same rank whatever the source.

        NOTE(review): nothing here writes the cache; presumably it is
        populated elsewhere — confirm.
        """
        cache_fn = self.cachedir / f"{patient}_CT1"
        cache_fn_seg = self.cachedir / f"{patient}_CT1_seg"
        if cache_fn.exists() and cache_fn_seg.exists():
            image = read_object(cache_fn)
            segmentation = read_object(cache_fn_seg)
        else:
            image_path, segmentation_paths = self.data[patient]
            image = read_image(image_path, no_meta=True)
            segmentation = self._get_segmentation(segmentation_paths)

        # Done for cached images too: previously only freshly read images were
        # expanded, so a 3-D cached array came back without a channel axis.
        if len(image.shape) == 3:
            # add "channels" dimension if it is not present
            image = np.expand_dims(image, axis=0)

        return image, segmentation
# Example #9
# 0
def get_shapes_cbct(root_dir):
    """Collect the shapes of CBCT scans that have at least one segmentation.

    For each patient folder in ``root_dir``, every ``X*.nii`` file is read.
    Its scan id (``X`` followed by digits, taken from the file name) is
    lower-cased and used to look for ``*_<id>.nii`` files in the same folder.

    Returns:
        Dict mapping ``patient + "\\" + scan_id`` to the image shape, for
        scans with one or more matching segmentation files. Failures are
        printed with their error and skipped.
    """
    scan_id_pattern = re.compile(r"X[0-9]+")
    shapes = {}

    for patient in os.listdir(root_dir):
        images = glob.glob(os.path.join(root_dir, patient, "X*.nii"))
        for cbct in images:
            try:
                # Search only the file name. Matching the full path could pick
                # up an "X<digits>" run from a parent directory instead.
                n = scan_id_pattern.search(os.path.basename(cbct)).group(0)
                img = read_image(cbct, no_meta=True)
                segmentations = glob.glob(
                    os.path.join(root_dir, patient,
                                 "*_{}.nii".format(n.lower())))
                print(patient + "\\" + n, img.shape, len(segmentations))
                if len(segmentations) > 0:
                    shapes[patient + "\\" + n] = img.shape
            except Exception as err:  # best-effort survey: report and move on
                print(patient, "failed", err)
    return shapes
def _ct_mode(values):
    """Return the most frequent entry of ``values``, the smallest on ties.

    Matches ``scipy.stats.mode``'s tie-breaking without depending on its
    return shape. SciPy 1.11 changed that shape, so ``stats.mode(x)[0][0]``
    now raises.
    """
    uniques, counts = np.unique(values, return_counts=True)
    # np.unique sorts, and argmax returns the first maximum -> smallest value.
    return uniques[counts.argmax()]


def _ct_dump_no_table(image, transform_CT, f):
    """Save slices [10:-10] of a CT with no detected table to TEMP/NO_TABLE_*."""
    debug_index = 0  # separate counter so the split's numbering is untouched
    for X in transform_CT({"image": image})["image"][10:-10]:
        debug_index += 1
        im = Image.fromarray(np.uint8(X.squeeze() * 255), 'L')
        im.save(f"TEMP/NO_TABLE_{debug_index}_{f.parent.stem}.jpg")


def _ct_export_split(files, transform_CT, out_dir, tag, clip_val):
    """Mask the table in each CT of ``files`` and export slices [10:-10].

    Slices are written as ``<out_dir>/<running index>_<patient>.jpg``, with the
    index starting at 0 for every split. Scans where ``get_table`` returns no
    boxes are dumped to ``TEMP/NO_TABLE_*`` for inspection and skipped. When
    ``temp_save`` is truthy, the last slice of each scan is copied to
    ``TEMP/<tag>_*``.
    """
    image_index = 0
    for f in tqdm(files):
        image, _meta = read_image(str(f))

        bboxes = np.array(
            get_table(image, clip_val=clip_val, margin=2500, show_imgs=False))
        if len(bboxes) == 0:
            # Previously this fell through to bboxes[:, 0] and crashed.
            print("Table not found in image!")
            _ct_dump_no_table(image, transform_CT, f)
            continue

        # Consensus box: per-column mode over the boxes returned by get_table.
        bbox = tuple(_ct_mode(bboxes[:, k]) for k in range(4))
        image = np.clip(image, 0, image.max())
        image[:, bbox[0]:512, bbox[1]:bbox[3]] = 0

        image = transform_CT({"image": image})["image"]
        im = None
        for X in image[10:-10]:
            image_index += 1
            im = Image.fromarray(np.uint8(X.squeeze() * 255), 'L')
            im.save(str(out_dir / f"{image_index}_{f.parent.stem}.jpg"))
        if temp_save and im is not None:
            im.save(f"TEMP/{tag}_{image_index}_{f.parent.stem}.jpg")


def process_CT(source_dir, target_dir):
    """Split the ``CT.nrrd`` scans under ``source_dir`` 80/10/10 and export them.

    Each scan has the table region found by ``get_table`` zeroed and is
    normalised with ``ClipAndNormalize(250, 1750)``. Its slices [10:-10] are
    written as JPEGs to ``target_dir/trainA``, ``valA`` and ``testA``.
    """
    # Sorted so the split does not depend on filesystem enumeration order.
    files = sorted(source_dir.rglob("CT.nrrd"))
    train_end = int(0.8 * len(files))
    val_end = int(0.9 * len(files))
    transform_CT = transforms.Compose([ClipAndNormalize(250, 1750)])

    # A: CT, B: CBCT
    # NOTE(review): the train split uses clip_val=300 while val/test use 200;
    # kept as-is — confirm whether this difference is intended.
    _ct_export_split(files[:train_end], transform_CT, target_dir / "trainA",
                     "TRAIN_A", clip_val=300)
    _ct_export_split(files[train_end:val_end], transform_CT,
                     target_dir / "valA", "VAL_A", clip_val=200)
    _ct_export_split(files[val_end:], transform_CT, target_dir / "testA",
                     "TEST_A", clip_val=200)
def _cbct_export_split(patients, transform_CT, out_dir, tag, skip_stems=(),
                       tag_scan_id=True):
    """Export slices [10:-10] of up to two CBCT scans per patient as JPEGs.

    Scans are the patient's ``X*.nrrd`` files in sorted order, excluding any
    whose stem is in ``skip_stems``. Slices go to
    ``<out_dir>/<running index>_<patient>_<scan>.jpg``, with the index starting
    at 0 for every split. When ``temp_save`` is truthy, the last slice of each
    scan is copied to ``TEMP/<tag>_*``, suffixed with a per-split scan counter
    if ``tag_scan_id`` is set.
    """
    image_index = 0
    scan_id = 0
    for p in tqdm(patients):
        # Sorted: glob order is filesystem dependent, and only the first two
        # eligible scans of each patient are used.
        scans = [f for f in sorted(p.glob("X*.nrrd"))
                 if f.stem not in skip_stems][:2]
        for f in scans:
            image, _meta = read_image(str(f))
            image = transform_CT({"image": image})["image"]
            scan_id += 1
            im = None
            for X in image[10:-10]:
                image_index += 1
                im = Image.fromarray(np.uint8(X.squeeze() * 255), 'L')
                im.save(
                    str(out_dir /
                        f"{image_index}_{f.parent.stem}_{f.stem}.jpg"))
            if temp_save and im is not None:
                suffix = f"_{scan_id}" if tag_scan_id else ""
                im.save(
                    f"TEMP/{tag}_{image_index}_{f.parent.stem}_{f.stem}{suffix}.jpg"
                )


def process_CBCT(source_dir, target_dir):
    """Split CBCT patients under ``source_dir`` 80/10/10 and export their slices.

    Only patient folders containing ``X*.nrrd`` files are used. Scans are
    normalised with ``ClipAndNormalize(250, 1750)`` and written to
    ``target_dir/trainB``, ``valB`` and ``testB``. The validation split never
    uses the ``X01`` scan.
    """
    # Sorted so the split does not depend on filesystem enumeration order.
    files = sorted(f for f in source_dir.iterdir()
                   if any(f.glob("X*.nrrd")))

    train_end = int(0.8 * len(files))
    val_end = int(0.9 * len(files))
    transform_CT = transforms.Compose([ClipAndNormalize(250, 1750)])

    # A: CT, B: CBCT
    _cbct_export_split(files[:train_end], transform_CT, target_dir / "trainB",
                       "TRAIN_B")
    _cbct_export_split(files[train_end:val_end], transform_CT,
                       target_dir / "valB", "VAL_B",
                       skip_stems=("X01",), tag_scan_id=False)
    _cbct_export_split(files[val_end:], transform_CT, target_dir / "testB",
                       "TEST_B")
# Example #12
# 0
def _keep_largest_component(mask):
    """Return ``mask`` reduced to its largest connected component, as 0/1.

    Components come from ``measure.label``. The one with the largest ``area``
    is kept (the lowest label wins ties) and every other voxel is set to 0.
    Selecting by label works for any number of dimensions. Indexing with only
    the first two columns of ``regionprops().coords`` would, on a 3-D volume,
    clear whole lines along the last axis.
    """
    labels_mask = measure.label(mask)
    regions = measure.regionprops(labels_mask)
    if not regions:
        return np.zeros_like(labels_mask)
    largest = max(regions, key=lambda rg: rg.area)
    return (labels_mask == largest.label).astype(labels_mask.dtype)


def test(args, dl, writer, model, image_shapes):
    """Evaluate ``model`` on ``dl``, log losses and optionally save 3-D results.

    Args:
        args: Namespace with ``threshold``, ``save_3d``, ``post_process``,
            ``test_folder`` and ``print_every``.
        dl: Loader yielding ``(X, Y)`` batches. ``Y`` is indexed as
            ``Y[:, channel, :, :, :]``, so it is presumably 5-D — confirm.
        writer: Summary writer passed to ``_log_images`` and flushed at the end.
        model: Network whose output must have the same shape as ``Y``.
        image_shapes: Sequence of ``(patient_id, shape, ct_path)`` entries in
            loader order. ``shape[0]`` is the number of loader iterations that
            make up one patient, and ``ct_path`` is a ``Path`` whose metadata
            is used when writing. The caller's sequence is not modified.

    When ``args.save_3d`` is set, each patient's thresholded bladder and
    cervix/uterus predictions are written to ``<test_folder>/<patient>_seg_*.nrrd``
    and scored with ``calculate_metrics``. All metrics are pickled to
    ``<test_folder>/metrics.p``.
    """
    from collections import deque

    device = "cuda"  # Run on GPU

    criterion = get_loss_func("NLL")
    softmax = nn.LogSoftmax(1)

    logger.info("Start Testing...")
    tmp_losses = []
    metrics = {"bladder": {}, "cervix": {}}

    segmentations = {0: [], 1: [], "y_bladder": [], "y_cervix": []}

    # Work on a copy so the caller's list is not consumed. The trailing None
    # marks the end of the queue.
    pending_shapes = deque(image_shapes)
    pending_shapes.append(None)

    img_i = 0
    temp = pending_shapes.popleft()
    img_shape = temp[1][0]
    patient = temp[0].replace("_", "/")

    _, metadata = read_image(str(temp[2]))

    logger.debug(patient.replace("/", "_"))

    model.eval()
    # Inference only: building autograd graphs would just waste GPU memory.
    with torch.no_grad():
        for i, (X, Y) in enumerate(dl):
            X, Y = X.to(device).float(), Y.to(device).float()

            torch.cuda.empty_cache()
            Y_hat = model(X)
            assert Y_hat.shape == Y.shape, "output and classification must be same shape, {}, {}".format(
                Y_hat.shape, Y.shape)

            Y_hat = softmax(Y_hat)
            loss = criterion(Y_hat, Y.argmax(1))
            tmp_losses.append(loss.detach().cpu().item())

            # Per-iteration masks are only needed to assemble 3-D volumes.
            # Collecting them otherwise would grow without bound, since the
            # lists are only reset inside the save_3d branch below.
            if args.save_3d:
                segmentations["y_bladder"].append(
                    Y[:, 0, :, :, :].squeeze().detach().cpu())
                segmentations["y_cervix"].append(
                    Y[:, 1, :, :, :].squeeze().detach().cpu())
                probs = Y_hat.exp()
                segmentations[0].append(
                    probs[:, 0, :, :, :].squeeze().detach().cpu() >
                    args.threshold)
                segmentations[1].append(
                    probs[:, 1, :, :, :].squeeze().detach().cpu() >
                    args.threshold)

            img_i += 1

            if img_i >= img_shape and args.save_3d:
                img_i = 0
                logger.info(f"Saving image {patient}")

                y_bladder = torch.stack(
                    segmentations["y_bladder"]).detach().cpu().numpy()
                y_cervix = torch.stack(
                    segmentations["y_cervix"]).detach().cpu().numpy()
                seg_bladder = torch.stack(
                    segmentations[0]).int().detach().cpu().numpy()
                seg_cervix = torch.stack(
                    segmentations[1]).int().detach().cpu().numpy()

                if args.post_process:
                    seg_bladder = _keep_largest_component(seg_bladder)
                    seg_cervix = _keep_largest_component(seg_cervix)

                patient_tag = patient.replace("/", "_")
                write_image(seg_bladder.astype(np.uint8),
                            "{}/{}_seg_bladder.nrrd".format(
                                args.test_folder, patient_tag),
                            metadata=metadata)
                write_image(seg_cervix.astype(np.uint8),
                            "{}/{}_seg_cervix_uterus.nrrd".format(
                                args.test_folder, patient_tag),
                            metadata=metadata)

                metrics_bladder = calculate_metrics(y_bladder.astype(bool),
                                                    seg_bladder.astype(bool),
                                                    metadata["spacing"], 25.0,
                                                    [0.5, 1.0, 1.5, 3.0])
                metrics_cervix = calculate_metrics(y_cervix.astype(bool),
                                                   seg_cervix.astype(bool),
                                                   metadata["spacing"], 25.0,
                                                   [0.5, 1.0, 1.5, 3.0])

                metrics["bladder"][patient] = metrics_bladder
                metrics["cervix"][patient] = metrics_cervix

                for m, v in metrics_bladder.items():
                    logger.info(f"{m} bladder: {v}")
                for m, v in metrics_cervix.items():
                    logger.info(f"{m} cervix: {v}")

                segmentations = {0: [], 1: [], "y_bladder": [], "y_cervix": []}
                temp = pending_shapes.popleft()
                if temp is not None:
                    img_shape = temp[1][0]
                    patient = temp[0].replace("_", "/")
                    CT_path = temp[2]
                    if CT_path.exists():
                        _, metadata = read_image(str(CT_path))
                    else:
                        # Previously this fell back silently.
                        logger.warning(
                            f"{CT_path} not found; reusing previous metadata")

            torch.cuda.empty_cache()

            _log_images(X, Y, Y_hat, i, writer, tag="test")

            if img_i % args.print_every == 0:
                logger.info("Iteration: {}/{} Loss: {}".format(
                    i, len(dl),
                    sum(tmp_losses) / len(tmp_losses)))
                tmp_losses = []

    with open("{}/metrics.p".format(args.test_folder), 'wb') as fh:
        pickle.dump(metrics, fh)

    writer.flush()
    logger.info("End testing")
'''
Author: Tessa Wagenaar

Calculate the data statistics
'''
import glob
import os
from pathlib import Path
from utils.image_readers import read_image
import pickle
import numpy as np


# Per-scan CBCT intensity range, keyed by image path, pickled to
# CBCT_statistics.p.
# NOTE(review): pickle can execute code on load; this assumes files_CBCT.p is
# a locally generated, trusted file.
with open("files_CBCT.p", 'rb') as fh:
    files_CBCT = pickle.load(fh)
statistics = {}

print("calculate image data")
for (patient, _shape, image_fn, _segmentations) in files_CBCT:
    image, _metadata = read_image(str(image_fn))
    image_min, image_max = image.min(), image.max()
    print(f"Min: {image_min} Max: {image_max}")
    # Record the values so the dump below actually contains statistics.
    statistics[str(image_fn)] = {
        "patient": patient,
        "min": float(image_min),
        "max": float(image_max),
    }

with open("CBCT_statistics.p", 'wb') as fh:
    pickle.dump(statistics, fh)