예제 #1
0
def create_3dshallow_dataset(depth=5):
    """
    Cut every 3D volume into stacks of ``depth`` consecutive slices and save each stack with ``np.save``.

    For each fold directory under the hard-coded base folder, the volumes in ``3D/train`` and ``3D/val`` are read and
    written to ``3DShallow/<split>/transverse/<case>/`` as ``<case>_<index>_image`` and ``<case>_<index>_label``,
    where ``<index>`` is the zero-padded index of the first slice in the stack.
    :param depth: int: slice thickness to create for images; must be odd and greater than 1
    """
    assert (
        depth > 1 and depth % 2 == 1
    ), f"depth must be an odd number greater than 1 but got {depth}"
    reader = NIIReader()

    base_folder = Path("/media/y4tsu/4B172BDA26AB3054/cmr_folds")
    fold_dirs = [entry for entry in base_folder.iterdir() if entry.is_dir()]
    splits = ("train", "val")

    for fold in fold_dirs:
        source_root = fold / "3D"
        shallow_root = fold / "3DShallow"

        # Build the output skeleton level by level; parent folders are never created implicitly
        shallow_root.mkdir(parents=False, exist_ok=True)
        for split in splits:
            (shallow_root / split).mkdir(parents=False, exist_ok=True)
        for split in splits:
            (shallow_root / split / "transverse").mkdir(parents=False, exist_ok=True)

        for split in splits:
            for case_dir in tqdm((source_root / split).iterdir()):
                case = case_dir.stem
                image = reader.read(case_dir / f"{case}_SAX.nii.gz")
                label = reader.read(case_dir / f"{case}_SAX_mask2.nii.gz")

                out_dir = shallow_root / split / "transverse" / case
                out_dir.mkdir(parents=False, exist_ok=True)

                # Slide a window of `depth` slices along the third axis, one slice at a time
                n_stacks = image.shape[2] - depth + 1
                for first in range(n_stacks):
                    window = slice(first, first + depth)
                    np.save(
                        str(out_dir / f"{case}_{first:03}_image"),
                        image[:, :, window],
                    )
                    np.save(
                        str(out_dir / f"{case}_{first:03}_label"),
                        label[:, :, window],
                    )
예제 #2
0
def find_widths_distributions(root):
    """Histogram the first-axis size of every train/val image under ``root`` (images are taken to be square)."""
    reader = NIIReader()
    widths = []

    for split in ["train", "val"]:
        split_dir = os.path.join(root, "3D", split)
        for case in tqdm(sorted(os.listdir(split_dir))):
            volume = reader.read(os.path.join(split_dir, case, f"{case}_SAX.nii.gz"))
            widths.append(volume.shape[0])

    # One unit-wide bin centred on each integer between the smallest and largest observed width
    bin_edges = np.arange(min(widths), max(widths) + 2) - 0.5
    plt.hist(widths, bins=bin_edges, rwidth=1, color="cadetblue")
    plt.xlabel("Width of image")
    plt.ylabel("Frequency")
    plt.title("Frequency of image widths")

    # Save a PNG and a PDF copy of the figure, then display it
    plt.savefig(
        "/home/y4tsu/Desktop/diss img/width_frequencies.png",
        bbox_inches="tight",
        dpi=300,
    )
    plt.savefig(
        "/home/y4tsu/Desktop/diss img/width_frequencies.pdf", bbox_inches="tight"
    )
    plt.show()
예제 #3
0
    def _get_manual_bboxes(self):
        """Build a ``{case: {"top", "left", "bottom", "right"}}`` dict from the manual masks in the 3D data folder."""
        print(
            f"Loading bboxes from manual segmentations for {self.dataset} data ... "
        )
        reader = NIIReader()
        base_folder = os.path.join(self.data_path, "3D", self.dataset)
        sides = ("top", "left", "bottom", "right")
        bboxes = {}

        for case in tqdm(sorted(os.listdir(base_folder))):
            mask_path = os.path.join(base_folder, case, f"{case}_SAX_mask2.nii.gz")
            mask = np.squeeze(reader.read(mask_path))
            # Collapse every non-zero value into a single foreground value
            foreground = np.where(np.equal(mask, 0), 0, 1)
            bbox = self.__find_bbox(foreground)

            # The first four entries of the bbox are stored in the order given by `sides`
            bboxes[case] = {side: bbox[pos] for pos, side in enumerate(sides)}

        return bboxes
예제 #4
0
def create_2d_dataset():
    """
    Split every 3D volume into single slices along its third axis and save each slice with ``np.save``.

    For each fold directory under the hard-coded base folder, the volumes in ``3D/train`` and ``3D/val`` are read and
    written to ``2D/<split>/transverse/<case>/`` as ``<case>_<index>_image`` and ``<case>_<index>_label``, where
    ``<index>`` is the zero-padded slice index.
    """
    reader = NIIReader()
    base_folder = Path("/media/y4tsu/4B172BDA26AB3054/cmr_folds")
    fold_dirs = [entry for entry in base_folder.iterdir() if entry.is_dir()]
    splits = ("train", "val")

    for fold in fold_dirs:
        source_root = fold / "3D"
        slices_root = fold / "2D"

        # Build the output skeleton level by level; parent folders are never created implicitly
        slices_root.mkdir(parents=False, exist_ok=True)
        for split in splits:
            (slices_root / split).mkdir(parents=False, exist_ok=True)
        for split in splits:
            (slices_root / split / "transverse").mkdir(parents=False, exist_ok=True)

        for split in splits:
            for case_dir in tqdm((source_root / split).iterdir()):
                case = case_dir.stem
                image = reader.read(case_dir / f"{case}_SAX.nii.gz")
                label = reader.read(case_dir / f"{case}_SAX_mask2.nii.gz")

                out_dir = slices_root / split / "transverse" / case
                out_dir.mkdir(parents=False, exist_ok=True)

                for idx in range(image.shape[2]):
                    np.save(str(out_dir / f"{case}_{idx:03}_image"), image[:, :, idx])
                    np.save(str(out_dir / f"{case}_{idx:03}_label"), label[:, :, idx])
예제 #5
0
def search_incorrect_orientations():
    """
    Walk over every case folder so that volumes oriented differently from the majority can be spotted and fixed.

    For each case the sum over the first 20 indices of the first axis is printed and the volume is handed to the
    reader's ``scroll_view``.
    """
    reader = NIIReader()
    base_folder = Path("/media/y4tsu/4B172BDA26AB3054/cmr_clean")
    roots = sorted(entry for entry in base_folder.iterdir() if entry.is_dir())
    print(f"{len(roots)=}")

    for root in roots:
        print(f"Reading root {root}")
        # The last three characters of the folder name are used as the case number
        file_num = root.stem[-3:]
        sax_path = root / f"20CA015_N{file_num}_SAX.nii.gz"
        image = np.squeeze(reader.read(sax_path))
        print(f"{image[:20, ...].sum()=}")
        reader.scroll_view(image)
예제 #6
0
    def __init__(self, data_path, dataset, model_path, train_config,
                 post_process):
        """Initialise the base predictor, pick a reader, list the 3D dataset folder and load the model."""
        super().__init__(data_path, dataset, train_config, post_process)
        self.full_data_path = os.path.join(data_path, "3D", self.dataset)

        # Cascade mode reads .npy inputs; otherwise NIfTI files are read
        if self.cascade:
            self.reader = NPYReader()
        else:
            self.reader = NIIReader()

        # TODO: test this all works ok
        # Image and label entries come from the same folder, so both lists hold the same paths
        entry_names = sorted(os.listdir(self.full_data_path))
        self.image_fnames = [
            os.path.join(self.full_data_path, name) for name in entry_names
        ]
        self.label_fnames = [
            os.path.join(self.full_data_path, name) for name in entry_names
        ]

        self.model = self.load_model(model_path)
        self.dimensionality = "3D"
예제 #7
0
 def calculate_label_weights(self):
     """Return ``1 / beta``, where beta is the per-class sum of the label batches divided by the total voxel count."""
     train_gen = self.train_gen
     print(
         f"Calculating label weightings across {len(train_gen.image_fnames)} label images for use in loss"
         f" function, may take a while ... ")

     # One accumulator slot per (combined) label
     class_spec = self.combine_labels if self.combine_labels else self.labels
     sums = np.zeros(len(class_spec), dtype=np.float32)

     n_batches = len(train_gen.image_fnames) // self.batch_size
     for batch_idx in tqdm(range(n_batches)):
         _, label_img = train_gen.__getitem__(batch_idx, weight_mode=True)
         # In quality-weighted mode the label array is stored under the "m" key
         mask = label_img["m"] if self.quality_weighted_mode else label_img
         # Add this batch's total for each index of the last axis
         sums += [mask[..., j].sum() for j in range(mask.shape[-1])]

     # Get the total number of voxels in the dataset to normalize the beta
     if self.model == "UNet3DFrozenDepth":
         # Here the last-axis length is read from every file instead of taken from image_size
         total_voxels = 0
         reader = NIIReader()
         for fname in train_gen.image_fnames:
             img = reader.read(fname)
             total_voxels += (
                 self.image_size[0] * self.image_size[1] * img.shape[-1])
     else:
         total_voxels = np.prod(
             np.array([*self.image_size, len(train_gen.image_fnames)]))

     weights = 1.0 / (sums / total_voxels)
     print(weights)
     return weights
예제 #8
0
 def __init__(
     self,
     model_save_path,
     generic_data_path,
     data_path,
     plane,
     batch_size,
     image_size,
     labels,
     dataset,
     shuffle=True,
     augmenter=None,
     use_cropper=False,
     combine_labels=None,
     cascade=None,
     quality_weighting_scores=None,
 ):
     """Initialise the base generator, then choose the file reader (and, in cascade mode, the input file list)."""
     super().__init__(
         model_save_path,
         generic_data_path,
         data_path,
         plane,
         batch_size,
         image_size,
         labels,
         dataset,
         shuffle,
         augmenter,
         use_cropper,
         combine_labels,
         cascade,
         quality_weighting_scores,
     )
     if not cascade:
         self.reader = NIIReader()
     else:
         self.reader = NPYReader()
         # In cascade mode the inputs are listed from <model_save_path>/mask/<dataset>
         mask_dir = os.path.join(model_save_path, "mask", dataset)
         self.image_fnames = [
             os.path.join(mask_dir, f"{entry}_SAX.nii.gz")
             for entry in sorted(os.listdir(mask_dir))
         ]
     if self.quality_weighting_scores:
         self.resizer = NPYReader()
예제 #9
0
def rotate_incorrect_orientations():
    """
    Rotate a manually specified list of cases in the plane of their first two axes (``angle=-90.0``) and overwrite
    the original image and mask files, so that all images share one orientation to simplify the learning task.
    Cases whose first two dimensions differ are reported as non-square.
    """
    reader = NIIReader()
    base_folder = Path("/media/y4tsu/4B172BDA26AB3054/cmr_clean")
    roots = sorted(entry for entry in base_folder.iterdir() if entry.is_dir())
    print(f"{len(roots)=}")

    # Case numbers (last three characters of the folder name) that were picked by hand for rotation
    to_rotate = (
        "008", "014", "024", "030", "062", "064", "083", "089", "135", "138", "141",
        "144", "156", "159", "168", "174", "181", "192", "213", "215", "227", "262",
        "278", "294", "304", "307", "319", "330", "347", "348", "353", "355", "375",
    )
    non_squares = []

    for root in roots:
        file_num = root.stem[-3:]
        image_path = root / f"20CA015_N{file_num}_SAX.nii.gz"
        label_path = root / f"20CA015_N{file_num}_SAX_mask2.nii.gz"

        image = np.squeeze(reader.read(image_path))
        shape = image.shape
        label = np.squeeze(reader.read(label_path))

        if shape[0] != shape[1]:
            print(f"{root} is non-square")
            non_squares.append(root)

        if file_num not in to_rotate:
            continue

        # Cubic interpolation for the image; order 0 for the mask so no new label values are interpolated
        rot_image = rotate(image, axes=(0, 1), angle=-90.0, reshape=False, order=3)
        rot_label = rotate(label, axes=(0, 1), angle=-90.0, reshape=False, order=0)

        # Wrap both arrays as NIfTI images with an identity affine
        new_img = nib.Nifti1Image(rot_image, np.eye(4))
        new_label = nib.Nifti1Image(rot_label, np.eye(4))

        # Save them, overwriting the original files
        nib.save(new_img, image_path)
        nib.save(new_label, label_path)

    print(f"{len(non_squares)=}")
예제 #10
0
class Predictor3D(__Predictor):
    """Predictor that reads whole volumes from the ``3D`` data folder and runs the loaded model on them."""

    def __init__(self, data_path, dataset, model_path, train_config,
                 post_process):
        """Initialise the base predictor, pick a reader, list the 3D dataset folder and load the model."""
        super().__init__(data_path, dataset, train_config, post_process)
        self.full_data_path = os.path.join(data_path, "3D", self.dataset)

        # Cascade mode reads .npy inputs; otherwise NIfTI files are read
        if self.cascade:
            self.reader = NPYReader()
        else:
            self.reader = NIIReader()

        # TODO: test this all works ok
        # Image and label entries come from the same folder, so both lists hold the same paths
        entry_names = sorted(os.listdir(self.full_data_path))
        self.image_fnames = [
            os.path.join(self.full_data_path, name) for name in entry_names
        ]
        self.label_fnames = [
            os.path.join(self.full_data_path, name) for name in entry_names
        ]

        self.model = self.load_model(model_path)
        self.dimensionality = "3D"

    def load_image_label(self, fname):
        """Read the image and its mask for ``fname`` and return ``(image, label, fname)`` ready for the model."""
        image_folder, label_folder, suffix, fname = self._get_folder_paths(
            fname)

        image_path = os.path.join(image_folder, f"{suffix}_SAX.nii.gz")
        label_path = os.path.join(label_folder, f"{suffix}_SAX_mask2.nii.gz")
        image = self.reader.read(image_path)
        label = self.reader.read(label_path)

        image, label = self._prepare_image_label(image, label, suffix)

        # Add a leading and a trailing singleton axis to the image
        image = image[np.newaxis, ..., np.newaxis]

        return image, label, fname

    def predict(self,
                fname=None,
                display=False,
                apply_combine=True,
                return_fname=False):
        """
        Run the model on one image and return ``(image, label, pred_label)``, plus ``fname`` when requested.

        :param fname: file to predict on; passed unchanged to ``load_image_label``
        :param display: bool: print the dice score and show image, label and prediction
        :param apply_combine: bool: apply the label combination to the ground truth when ``combine_labels`` is set
        :param return_fname: bool: append the resolved file name to the returned tuple
        """
        image, label, fname = self.load_image_label(fname)

        if self.quality_weighted_mode:
            # The model takes an extra input of 1.0 here; the second output is used as the prediction
            model_input = (image, np.array([1.0], dtype=np.float32))
            raw_pred = self.model.predict(model_input)[1]
        else:
            raw_pred = self.model.predict(image)

        # Class index with the highest score along the last axis
        pred_label = np.squeeze(np.argmax(raw_pred, axis=-1))

        if self.combine_labels and apply_combine:
            label = self.apply_label_combine(label)

        if self.post_process:
            pred_label = self.post_process_label(pred_label)

        if display:
            print(self.calculate_dice(label, pred_label))
            self.display(image, label, pred_label)

        if not return_fname:
            return image, label, pred_label
        return image, label, pred_label, fname
예제 #11
0
    with open(os.path.join(path_tr, "train_config.json"), "r") as f:
        p1 = Predictor2D(data_path, dataset, path_tr, json.load(f))

    # Sagittal predictor
    with open(os.path.join(path_sag, "train_config.json"), "r") as f:
        p2 = Predictor2D(data_path, dataset, path_sag, json.load(f))

    # Coronal predictor
    with open(os.path.join(path_cor, "train_config.json"), "r") as f:
        p3 = Predictor2D(data_path, dataset, path_cor, json.load(f))

    return p1, p2, p3


if __name__ == "__main__":
    # Script entry point: build the three 2D predictors for the "val" split.
    nii_reader = NIIReader()
    data_path = "/media/y4tsu/ml_data/cmr"
    dataset = "val"

    # The three checkpoint folders are passed positionally to setup_predictors,
    # which returns three predictors.
    p_tr, p_sag, p_cor = setup_predictors(
        data_path,
        dataset,
        "/home/y4tsu/PycharmProjects/3d_unet/checkpoint/2D_tr",
        "/home/y4tsu/PycharmProjects/3d_unet/checkpoint/2D_sag",
        "/home/y4tsu/PycharmProjects/3d_unet/checkpoint/2D_cor",
    )

    # Sorted entries of <data_path>/3D/<dataset>
    roots = sorted(os.listdir(os.path.join(data_path, "3D", dataset)))
    # Accumulators (one per-class slot for each entry of p_tr.labels_dict) and a wall-clock start time.
    # NOTE(review): nothing in this excerpt uses them afterwards - confirm against the full file.
    dices = 0
    class_wise_dices = np.zeros([len(p_tr.labels_dict)])
    start = time.time()
예제 #12
0
    def _get_auto_bboxes(self, model_path):
        """Predict a mask for every case with the cropper model and return one bounding box dict per case."""
        from predict import load_predictor

        print(
            f"Predicting bboxes using automatic cropper model at {model_path} for {self.dataset} data ... "
        )
        reader = NIIReader()
        base_folder = os.path.join(self.data_path, "3D", self.dataset)
        sides = ("top", "left", "bottom", "right")
        bboxes = {}

        # Load the cropper model as a Predictor object, with post-processing switched off
        p = load_predictor({
            "model_path": model_path,
            "data_path": self.data_path,
            "dataset": self.dataset,
            "post_process": False,
        })

        for case in tqdm(sorted(os.listdir(base_folder))):
            # The predictor file whose path contains the case name; the last match wins
            matches = [cand for cand in p.image_fnames if case in cand]
            if not matches:
                raise ValueError(
                    f"Unable to find the correct path for image {case}")
            fname = matches[-1]

            image_size = reader.read(
                os.path.join(base_folder, f"{case}/{case}_SAX.nii.gz")).shape

            _, _, pred_label = p.predict(fname, display=False)

            # Remove noise and small islands from the prediction
            pred_label = self.__clean_prediction(pred_label)

            # When the original image size differs from the model input, rescale the prediction slice by slice
            rescaled = np.empty(image_size, dtype=np.int8)
            if image_size[0] != p.image_size[0]:
                # cv2.resize takes the target size as (width, height), hence the reversed first two dims
                target_size = tuple(image_size[:2][::-1])
                for j in range(image_size[-1]):
                    rescaled[..., j] = cv2.resize(
                        pred_label[..., j],
                        target_size,
                        interpolation=cv2.INTER_NEAREST,
                    )
                pred_label = rescaled

            # Now find the bounding box around the segmentation mask
            bbox = self.__find_bbox(pred_label)
            bboxes[case] = {side: bbox[pos] for pos, side in enumerate(sides)}

        return bboxes
예제 #13
0
def main(
    config_path="predict_config.json",
    output_dir="/home/y4tsu/PycharmProjects/3d_unet/saved_preds",
):
    """
    Predict a label volume for every scan of the configured dataset and save each one in NIfTI format.

    The prediction is resized and pasted back into an array with the shape of the original mask file, then written
    with that file's header and affine as ``<name>_prediction.nii.gz``.
    :param config_path: str: JSON config; its "data_path" and "dataset" keys are read here and the whole dict is
        passed to ``load_predictor``
    :param output_dir: str: folder the prediction files are written to; created if it does not exist
    """
    start = time.time()

    with open(config_path, "r") as f:
        predict_config = json.load(f)

    # Make sure the target folder exists before any prediction is computed
    os.makedirs(output_dir, exist_ok=True)

    # Load the correct Predictor class for the given model type
    p = load_predictor(predict_config)
    plane = p.plane

    if p.model_name in ["UNet3D", "VNet", "UNet3DFrozenDepth"]:
        dims = "3D"
        # Full 3D data has no per-plane sub-folder
        plane = ""
    elif p.model_name in ["UNet3DShallow", "VNetShallow"]:
        dims = "3DShallow"
    else:
        dims = "2D"

    data_path = predict_config["data_path"]
    dataset = predict_config["dataset"]

    # Get the names of all the scans we are interested in
    scan_dir = os.path.join(data_path, dims, dataset, plane)
    roots = [os.path.join(scan_dir, entry) for entry in sorted(os.listdir(scan_dir))]
    print(f"{roots=}")

    # Metadata of the original files, keyed by scan name
    headers = {}
    shapes = {}
    affines = {}
    reader = NIIReader()

    # Load important information from the original .nii.gz files in the 3D folder
    real_dir = os.path.join(data_path, "3D", dataset)
    for scan_name in tqdm(sorted(os.listdir(real_dir))):
        img = nib.load(
            os.path.join(real_dir, scan_name, f"{scan_name}_SAX_mask2.nii.gz"),
            mmap=False,
        )
        # Headers = headers from the original NiFTi file
        headers[scan_name] = img.header
        # Affine = transformation which maps points to 3D space of the MRI image
        affines[scan_name] = img.affine
        # Shape of each image
        shapes[scan_name] = np.squeeze(img.get_fdata()).shape

    # Get a prediction for each image and save it in NiFTi format for loading in ITK-SNAP
    for root in tqdm(roots):
        _, _, pred_label = p.predict(fname=root, display=False)
        # Key for the header/affine lookup; basename follows the separator os.path.join used to build the path
        name_end = os.path.basename(root)
        save_img = np.zeros(shapes[name_end])
        reverse_size, cut_dims = p.cropper.reverse_crop(shapes[name_end], name_end)
        # Resize the prediction to undo the pre-processing
        pred_label = reader.resize(pred_label, reverse_size, interpolation_order=0)
        # Get the correct portion of the predicted label to output
        try:
            save_img[
                cut_dims["top"] : cut_dims["bottom"],
                cut_dims["left"] : cut_dims["right"],
                ...,
            ] = pred_label
        except ValueError:
            # Shapes do not line up: report the scan and carry on with the next one
            print(
                f"Unable to undo pre-processing for image {name_end}! {save_img.shape=}, {pred_label.shape=}"
            )
            continue

        # Save the image in NiFTi format
        nifti_img = nib.Nifti1Image(
            save_img.astype(np.uint16),
            affine=affines[name_end],
            header=headers[name_end],
        )
        nib.save(
            nifti_img,
            os.path.join(output_dir, f"{name_end}_prediction.nii.gz"),
        )

    print(f"Finished! Process took {time.time() - start:.2f} seconds.")