Code example #1
File: base.py Project: MIC-DKFZ/nnDetection
def extract_results(
    source_dir: PathLike,
    target_dir: PathLike,
    ensembler_cls: Callable,
    restore: bool,
    **params,
) -> None:
    """
    Compute case result from ensembler and save it

    Args:
        source_dir: directory which contains the saved predictions/state from
            the ensembler class
        target_dir: directory to save results
        ensembler_cls: ensembler class for prediction
        restore: if true, the results are converted into the original image
            space
    """
    Path(target_dir).mkdir(parents=True, exist_ok=True)
    for case_id in maybe_verbose_iterable(
            ensembler_cls.get_case_ids(source_dir)):
        ensembler = ensembler_cls.from_checkpoint(base_dir=source_dir,
                                                  case_id=case_id)
        ensembler.update_parameters(**params)

        pred = to_numpy(ensembler.get_case_result(restore=restore))

        save_pickle(pred,
                    Path(target_dir) / f"{case_id}_{ensembler_cls.ID}.pkl")
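
Every example on this page drives its main loop through maybe_verbose_iterable from nndet.utils.info. As a point of reference, here is a minimal sketch of such a helper, assuming it simply wraps the iterable in a tqdm progress bar when verbose output is wanted; the actual nnDetection implementation and the environment switch used below may differ (both are assumptions, not the library API).

import os
from typing import Iterable, Optional

from tqdm import tqdm


def maybe_verbose_iterable(iterable: Iterable, desc: Optional[str] = None, **kwargs):
    # Sketch only: show a progress bar when verbose output is enabled,
    # otherwise return the iterable unchanged. The environment variable
    # name is illustrative, not part of nnDetection.
    if os.getenv("det_verbose", "1") == "1":
        return tqdm(iterable, desc=desc, **kwargs)
    return iterable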
Code example #2
File: prepare.py Project: MIC-DKFZ/nnDetection
def main():
    det_data_dir = Path(os.getenv('det_data'))
    task_data_dir = det_data_dir / "Task017_CADA"

    # setup raw paths
    source_data_dir = task_data_dir / "raw" / "train_dataset"
    if not source_data_dir.is_dir():
        raise RuntimeError(
            f"{source_data_dir} should contain the raw data but does not exist."
        )
    source_label_dir = task_data_dir / "raw" / "train_mask_images"
    if not source_label_dir.is_dir():
        raise RuntimeError(
            f"{source_label_dir} should contain the raw labels but does not exist."
        )

    # setup raw splitted dirs
    target_data_dir = task_data_dir / "raw_splitted" / "imagesTr"
    target_data_dir.mkdir(exist_ok=True, parents=True)
    target_label_dir = task_data_dir / "raw_splitted" / "labelsTr"
    target_label_dir.mkdir(exist_ok=True, parents=True)

    # prepare dataset info
    meta = {
        "name": "CADA",
        "task": "Task017_CADA",
        "target_class": None,
        "test_labels": False,
        "labels": {
            "0": "aneurysm"
        },
        "modalities": {
            "0": "CT"
        },
        "dim": 3,
    }
    save_json(meta, task_data_dir / "dataset.json")

    # prepare data & label
    case_ids = [(p.stem).rsplit('_', 1)[0]
                for p in source_data_dir.glob("*.nii.gz")]
    print(f"Found {len(case_ids)} case ids")
    for cid in maybe_verbose_iterable(case_ids):
        run_prep(
            source_data=source_data_dir / f"{cid}_orig.nii.gz",
            source_label=source_label_dir / f"{cid}_labeledMasks.nii.gz",
            target_data_dir=target_data_dir,
            target_label_dir=target_label_dir,
        )
Code example #3
File: bg_loader.py Project: MIC-DKFZ/nnDetection
    def build_cache(self) -> Dict[str, List]:
        """
        Build up cache for sampling

        Returns:
            Dict[str, List]: cache for sampling
                `case`: list with all case identifiers
                `instances`: list with tuple of (case_id, instance_id)
        """
        instance_cache = []

        logger.info("Building Sampling Cache for Dataloder")
        for case_id, item in maybe_verbose_iterable(self._data.items(), desc="Sampling Cache"):
            instances = load_pickle(item['boxes_file'])["instances"]
            if instances:
                for instance_id in instances:
                    instance_cache.append((case_id, instance_id))
        return {"case": list(self._data.keys()), "instances": instance_cache}
Code example #4
File: bg_loader.py Project: MIC-DKFZ/nnDetection
    def build_cache(self) -> Tuple[Dict[int, List[Tuple[str, int]]], List]:
        """
        Build up cache for sampling

        Returns:
            Dict[int, List[Tuple[str, int]]]: foreground cache which contains
                a list of (case_id, instance_id) tuples for each class
            List: background cache (all samples which do not have any
                foreground)
        """
        fg_cache = defaultdict(list)

        logger.info("Building Sampling Cache for Dataloder")
        for case_id, item in maybe_verbose_iterable(self._data.items(), desc="Sampling Cache"):
            candidates = load_pickle(item['boxes_file'])
            if candidates["instances"]:
                for instance_id, instance_class in zip(candidates["instances"], candidates["labels"]):
                    fg_cache[int(instance_class)].append((case_id, instance_id))
        return {"fg": fg_cache, "case": list(self._data.keys())}
Code example #5
File: prepare.py Project: MIC-DKFZ/nnDetection
def main():
    det_data_dir = Path(os.getenv('det_data'))
    task_data_dir = det_data_dir / "Task019FG_ADAM"
    
    # setup raw paths
    source_data_dir = task_data_dir / "raw" / "ADAM_release_subjs"
    if not source_data_dir.is_dir():
        raise RuntimeError(f"{source_data_dir} should contain the raw data but does not exist.")

    # setup raw splitted dirs
    target_data_dir = task_data_dir / "raw_splitted" / "imagesTr"
    target_data_dir.mkdir(exist_ok=True, parents=True)
    target_label_dir = task_data_dir / "raw_splitted" / "labelsTr"
    target_label_dir.mkdir(exist_ok=True, parents=True)

    # prepare dataset info
    meta = {
        "name": "ADAM",
        "task": "Task019FG_ADAM",
        "target_class": None,
        "test_labels": False,
        "labels": {"0": "Aneurysm"}, # since we are running FG vs BG this is not completely correct
        "modalities": {"0": "Structured", "1": "TOF"},
        "dim": 3,
    }
    save_json(meta, task_data_dir / "dataset.json")

    # prepare data
    case_ids = [p.stem for p in source_data_dir.iterdir() if p.is_dir()]
    print(f"Found {len(case_ids)} case ids")
    for cid in maybe_verbose_iterable(case_ids):
        run_prep_fg_v_bg(
            case_id=cid,
            source_data=source_data_dir,
            target_data_dir=target_data_dir,
            target_label_dir=target_label_dir,
            )
Code example #6
    assert len(image_paths) == len(label_paths)

    meta = {
        "name": "RibFracFG",
        "task": "Task020FG_RibFrac",
        "target_class": None,
        "test_labels": False,
        "labels": {
            "0": "fracture"
        },  # since we are running FG vs BG this is not completely correct
        "modalities": {
            "0": "CT"
        },
        "dim": 3,
    }
    save_json(meta, task_data_dir / "dataset.json")

    for ip, lp in maybe_verbose_iterable(list(zip(image_paths, label_paths))):
        create(
            image_source=ip,
            label_source=lp,
            image_target_dir=target_data_dir,
            label_target_dir=target_label_dir,
            df=df,
            fg_only=True,
        )


if __name__ == '__main__':
    main()
Code example #7
File: prepare.py Project: MIC-DKFZ/nnDetection
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'tasks',
        type=str,
        nargs='+',
        help="One or multiple of: Task003_Liver, Task007_Pancreas, "
        "Task008_HepaticVessel, Task010_Colon",
    )
    args = parser.parse_args()
    tasks = args.tasks

    decathlon_props = {
        "Task003_Liver": {
            "seg2det_stuff": [
                1,
            ],  # liver
            "seg2det_things": [
                2,
            ],  # cancer
            "min_size": 3.,
            "labels": {
                "0": "cancer"
            },
            "labels_stuff": {
                "1": "liver"
            },
        },
        "Task007_Pancreas": {
            "seg2det_stuff": [
                1,
            ],  # pancreas
            "seg2det_things": [
                2,
            ],
            "min_size": 3.,
            "labels": {
                "0": "cancer"
            },
            "labels_stuff": {
                "1": "pancreas"
            },
        },
        "Task008_HepaticVessel": {
            "seg2det_stuff": [
                1,
            ],  # vessel
            "seg2det_things": [
                2,
            ],
            "min_size": 3.,
            "labels": {
                "0": "tumour"
            },
            "labels_stuff": {
                "1": "vessel"
            },
        },
        "Task010_Colon": {
            "seg2det_stuff": [],
            "seg2det_things": [
                1,
            ],
            "min_size": 3.,
            "labels": {
                "0": "cancer"
            },
            "labels_stuff": {},
        },
    }

    basedir = Path(os.getenv('det_data'))
    for task in tasks:
        task_data_dir = basedir / task

        logger.remove()
        logger.add(sys.stdout, level="INFO")
        logger.add(task_data_dir / "prepare.log", level="DEBUG")
        logger.info(f"Preparing task: {task}")

        source_raw_dir = task_data_dir / "raw"
        source_data_dir = source_raw_dir / "imagesTr"
        source_labels_dir = source_raw_dir / "labelsTr"
        splitted_dir = task_data_dir / "raw_splitted"

        if not source_data_dir.is_dir():
            raise ValueError(f"Exptected training images at {source_data_dir}")
        if not source_labels_dir.is_dir():
            raise ValueError(
                f"Expected training labels at {source_labels_dir}")
        if not (p := source_raw_dir / "dataset.json").is_file():
            raise ValueError(f"Expected dataset json to be located at {p}")

        target_data_dir = splitted_dir / "imagesTr"
        target_label_dir = splitted_dir / "labelsTr"
        target_data_dir.mkdir(parents=True, exist_ok=True)
        target_label_dir.mkdir(parents=True, exist_ok=True)

        # prepare meta
        original_meta = load_json(source_raw_dir / "dataset.json")

        dataset_info = {
            "task": task,
            "name": original_meta["name"],
            "target_class": None,
            "test_labels": True,
            "modalities": original_meta["modality"],
            "dim": 3,
            "info": {
                "original_labels": original_meta["labels"],
                "original_numTraining": original_meta["numTraining"],
            },
        }
        dataset_info.update(decathlon_props[task])
        save_json(dataset_info, task_data_dir / "dataset.json")

        # prepare data and labels
        case_ids = get_case_ids_from_dir(source_data_dir,
                                         remove_modality=False)
        case_ids = sorted([c for c in case_ids if c])
        logger.info(f"Found {len(case_ids)} for preparation.")

        for cid in maybe_verbose_iterable(case_ids):
            process_case(
                cid,
                source_data_dir,
                source_labels_dir,
                target_data_dir,
                target_label_dir,
            )

        # with Pool(processes=6) as p:
        #     p.starmap(process_case, zip(case_ids,
        #                                 repeat(source_images),
        #                                 repeat(source_labels),
        #                                 repeat(target_images),
        #                                 repeat(target_labels),
        #                                 ))

        # create an artificial test split
        create_test_split(
            splitted_dir=splitted_dir,
            num_modalities=1,
            test_size=0.3,
            random_state=0,
            shuffle=True,
        )
Code example #8
File: analysis.py Project: MIC-DKFZ/nnDetection
def run_analysis_suite(prediction_dir: Path, gt_dir: Path, save_dir: Path):
    for iou, score in maybe_verbose_iterable(
            list(product([0.1, 0.5], [0.1, 0.5]))):
        _save_dir = save_dir / f"iou_{iou}_score_{score}"
        _save_dir.mkdir(parents=True, exist_ok=True)

        found_predictions = list(prediction_dir.glob("*_boxes.pkl"))
        logger.info(f"Found {len(found_predictions)} predictions for analysis")

        df, analysis_ids = collect_overview(
            prediction_dir,
            gt_dir,
            iou=iou,
            score=score,
            max_num_fp_per_image=5,
            top_n=10,
        )
        df.to_json(_save_dir / "analysis.json", indent=4, orient='index')
        df.to_csv(_save_dir / "analysis.csv")
        save_json(analysis_ids, _save_dir / "analysis_ids.json")

        all_pred, all_target, all_pred_ious, all_pred_scores = collect_score_iou(
            prediction_dir, gt_dir, iou=iou, score=score)
        confusion_ax = plot_confusion_matrix(all_pred,
                                             all_target,
                                             iou=iou,
                                             score=score)
        plt.savefig(_save_dir / "confusion_matrix.png")
        plt.close()

        iou_score_ax = plot_joint_iou_score(all_pred_ious, all_pred_scores)
        plt.savefig(_save_dir / "joint_iou_score.png")
        plt.close()

        all_pred, all_target, all_boxes = collect_boxes(prediction_dir,
                                                        gt_dir,
                                                        iou=iou,
                                                        score=score)
        sizes_fig, sizes_ax = plot_sizes(all_pred,
                                         all_target,
                                         all_boxes,
                                         iou=iou,
                                         score=score)
        plt.savefig(_save_dir / "sizes.png")
        with open(str(_save_dir / 'sizes.pkl'), "wb") as fp:
            pickle.dump(sizes_fig, fp, protocol=4)
        plt.close()

        sizes_fig, sizes_ax = plot_sizes_bar(all_pred,
                                             all_target,
                                             all_boxes,
                                             iou=iou,
                                             score=score)
        plt.savefig(_save_dir / "sizes_bar.png")
        with open(str(_save_dir / 'sizes_bar.pkl'), "wb") as fp:
            pickle.dump(sizes_fig, fp, protocol=4)
        plt.close()

        sizes_fig, sizes_ax = plot_sizes_bar(all_pred,
                                             all_target,
                                             all_boxes,
                                             iou=iou,
                                             score=score,
                                             max_bin=100)
        plt.savefig(_save_dir / "sizes_bar_100.png")
        with open(str(_save_dir / 'sizes_bar_100.pkl'), "wb") as fp:
            pickle.dump(sizes_fig, fp, protocol=4)
        plt.close()
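
An illustrative call of the analysis suite might look like the following; the directory names are placeholders and must point at existing *_boxes.pkl predictions and the matching ground truth.

from pathlib import Path

# Hypothetical paths; adjust to an actual prediction/ground-truth layout.
run_analysis_suite(
    prediction_dir=Path("val_predictions"),
    gt_dir=Path("labelsTr"),
    save_dir=Path("analysis"),
)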
Code example #9
File: prepare_mic.py Project: MIC-DKFZ/nnDetection
    task_data_dir = det_data_dir / "Task012_LIDC"
    source_data_dir = task_data_dir / "raw"

    if not (p := source_data_dir / "data_nrrd").is_dir():
        raise ValueError(f"Expted {p} to contain LIDC data")
    if not (p := source_data_dir / 'characteristics.csv').is_file():
        raise ValueError(f"Expted {p} to contain exist")

    target_dir = task_data_dir / "raw_splitted"
    target_data_dir = task_data_dir / "raw_splitted" / "imagesTr"
    target_data_dir.mkdir(exist_ok=True, parents=True)
    target_label_dir = task_data_dir / "raw_splitted" / "labelsTr"
    target_label_dir.mkdir(exist_ok=True, parents=True)

    logger.remove()
    logger.add(sys.stdout, level="INFO")
    logger.add(task_data_dir / "prepare.log", level="DEBUG")

    data_dir = source_data_dir / "data_nrrd"
    case_dirs = [x for x in data_dir.iterdir() if x.is_dir()]
    df = pd.read_csv(source_data_dir / 'characteristics.csv', sep=';')

    for cd in maybe_verbose_iterable(case_dirs):
        prepare_case(cd, target_dir, df)

    # TODO download custom split file


if __name__ == '__main__':
    main()
Code example #10
File: prepare.py Project: MIC-DKFZ/nnDetection
def main():
    """
    Does not use the KTrans Sequence of ProstateX
    This script only uses the provided T2 masks
    """
    det_data_dir = Path(os.getenv('det_data'))
    task_data_dir = det_data_dir / "Task021_ProstateX"

    # setup raw paths
    source_data_dir = task_data_dir / "raw"
    if not source_data_dir.is_dir():
        raise RuntimeError(f"{source_data_dir} should contain the raw data but does not exist.")

    source_data = source_data_dir / "PROSTATEx"
    source_masks = source_data_dir / "rcuocolo-PROSTATEx_masks-e344452"
    source_ktrans = source_data_dir / "ktrains"
    csv_labels = source_data_dir / "ProstateX-TrainingLesionInformationv2" / "ProstateX-Findings-Train.csv"
    csv_masks = source_data_dir / "rcuocolo-PROSTATEx_masks-e344452" / "Files" / "Image_list.csv"

    data_target = task_data_dir / "raw_splitted" / "imagesTr"
    data_target.mkdir(parents=True, exist_ok=True)
    label_target = task_data_dir / "raw_splitted" / "labelsTr"
    label_target.mkdir(parents=True, exist_ok=True)

    logger.remove()
    logger.add(sys.stdout, format="{level} {message}", level="INFO")
    logger.add(data_target.parent.parent / "prepare.log", level="DEBUG")

    base_masks = source_masks / "Files" / "Masks"
    t2_masks = base_masks / "T2"

    df_labels = pd.read_csv(csv_labels)
    df_masks = pd.read_csv(csv_masks)
    case_ids = [f.stem.split("-", 2)[:2] for f in t2_masks.glob("*nii.gz")]
    case_ids = list(set([f"{c[0]}-{c[1]}" for c in case_ids]))
    logger.info(f"Found {len(case_ids)} cases")

    # save meta
    logger.info("Saving dataset info")
    dataset_info = {
        "name": "ProstateX",
        "task": "Task021_ProstateX",

        "target_class": None,
        "test_labels": False,

        "labels": {
            "0": "clinically_significant",
            "1": "clinically_insignificant",
        },
        "modalities": {
            "0": "T2",
            "1": "ADC",
            "2": "PD-W",
            "3": "Ktrans"
        },
        "dim": 3,
        "info": "Ground Truth: T2 Masks; \n"
                "Modalities: T2, ADC, PD-W, Ktrans \n;"
                "Classes: clinically significant = 1, insignificant = 0 \n"
                "Keep: ProstateX-0025 '10-28-2011-MR prostaat kanker detectie WDSmc MCAPRODETW-19047'\n"
                "Masks\n"
                "https://github.com/rcuocolo/PROSTATEx_masks\n"
                "Github hash: e3444521e70cd5e8d405f4e9a6bc08312df8afe7"
    }
    save_json(dataset_info, task_data_dir / "dataset.json")

    # prepare labels and data
    for cid in maybe_verbose_iterable(case_ids):
        prepare_case(cid,
                     data_dirs=source_data,
                     ktrans_dirs=source_ktrans,
                     t2_masks=t2_masks,
                     df_labels=df_labels,
                     df_masks=df_masks,
                     data_target=data_target,
                     label_target=label_target,
                     )

    # with Pool(processes=6) as p:
    #     p.starmap(prepare_case, zip(case_ids,
    #                                 repeat(source_data),
    #                                 repeat(source_ktrans),
    #                                 repeat(t2_masks),
    #                                 repeat(df_labels),
    #                                 repeat(df_masks),
    #                                 repeat(data_target),
    #                                 repeat(label_target),
    #                                 ))

    # create test split
    create_test_split(task_data_dir / "raw_splitted",
                      num_modalities=len(dataset_info["modalities"]),
                      test_size=0.3,
                      random_state=0,
                      shuffle=True,
                      )
Code example #11
def main():
    det_data_dir = Path(os.getenv('det_data'))
    task_data_dir = det_data_dir / "Task011_Kits"
    source_data_dir = task_data_dir / "raw"

    if not source_data_dir.is_dir():
        raise RuntimeError(
            f"{source_data_dir} should contain the raw data but does not exist."
        )

    splitted_dir = task_data_dir / "raw_splitted"
    target_data_dir = task_data_dir / "raw_splitted" / "imagesTr"
    target_data_dir.mkdir(exist_ok=True, parents=True)
    target_label_dir = task_data_dir / "raw_splitted" / "labelsTr"
    target_label_dir.mkdir(exist_ok=True, parents=True)

    logger.remove()
    logger.add(sys.stdout, level="INFO")
    logger.add(task_data_dir / "prepare.log", level="DEBUG")

    # save meta info
    dataset_info = {
        "name": "Kits",
        "task": "Task011_Kits",
        "target_class": None,
        "test_labels": True,
        "seg2det_stuff": [
            1,
        ],  # define stuff classes: kidney
        "seg2det_things": [
            2,
        ],  # define things classes: tumor
        "min_size": 3.,
        "labels": {
            "0": "lesion"
        },
        "labels_stuff": {
            "1": "kidney"
        },
        "modalities": {
            "0": "CT"
        },
        "dim": 3,
    }
    save_json(dataset_info, task_data_dir / "dataset.json")

    # prepare cases
    cases = [str(c.name) for c in source_data_dir.iterdir() if c.is_dir()]
    for c in maybe_verbose_iterable(cases):
        logger.info(f"Copy case {c}")
        case_id = int(c.split("_")[-1])
        if case_id < 210:
            shutil.copy(source_data_dir / c / "imaging.nii.gz",
                        target_data_dir / f"{c}_0000.nii.gz")
            shutil.copy(source_data_dir / c / "segmentation.nii.gz",
                        target_label_dir / f"{c}.nii.gz")

    # create an artificial test split
    create_test_split(
        splitted_dir=splitted_dir,
        num_modalities=1,
        test_size=0.3,
        random_state=0,
        shuffle=True,
    )
Code example #12
def main():
    det_data_dir = Path(os.getenv("det_data"))
    task_data_dir = det_data_dir / "Task025_LymphNodes"
    source_data_base = task_data_dir / "raw"
    if not source_data_base.is_dir():
        raise RuntimeError(
            f"{source_data_base} should contain the raw data but does not exist."
        )

    raw_splitted_dir = task_data_dir / "raw_splitted"
    (raw_splitted_dir / "imagesTr").mkdir(parents=True, exist_ok=True)
    (raw_splitted_dir / "labelsTr").mkdir(parents=True, exist_ok=True)
    (raw_splitted_dir / "imagesTs").mkdir(parents=True, exist_ok=True)
    (raw_splitted_dir / "labelsTs").mkdir(parents=True, exist_ok=True)

    logger.remove()
    logger.add(sys.stdout, format="{level} {message}", level="DEBUG")
    logger.add(raw_splitted_dir.parent / "prepare.log", level="DEBUG")

    meta = {
        "name": "Lymph Node TCIA",
        "task": "Task025_LymphNodes",
        "target_class": None,
        "test_labels": True,
        "labels": {
            "0": "LymphNode",
        },
        "modalities": {
            "0": "CT",
        },
        "dim": 3,
    }

    save_json(meta, raw_splitted_dir.parent / "dataset.json")

    base_dir = source_data_base / "CT Lymph Nodes"
    mask_dir = source_data_base / "MED_ABD_LYMPH_MASKS"

    case_ids = sorted([p.name for p in base_dir.iterdir() if p.is_dir()])
    logger.info(f"Found {len(case_ids)} cases in {base_dir}")

    for cid in maybe_verbose_iterable(case_ids):
        prepare_image(
            case_id=cid,
            base_dir=base_dir,
            mask_dir=mask_dir,
            raw_splitted_dir=raw_splitted_dir,
        )

    # with Pool(processes=6) as p:
    #     p.starmap(
    #         prepare_image,
    #         zip(
    #             case_ids,
    #             repeat(base_dir),
    #             repeat(mask_dir),
    #             repeat(raw_splitted_dir)
    #         )
    #     )

    create_test_split(
        raw_splitted_dir,
        num_modalities=len(meta["modalities"]),
        test_size=0.3,
        random_state=0,
        shuffle=True,
    )
Code example #13
def boxes2nii():
    import os
    import argparse
    from pathlib import Path

    import numpy as np
    import SimpleITK as sitk
    from loguru import logger

    from nndet.io import save_json, load_pickle
    from nndet.io.paths import get_task, get_training_dir
    from nndet.utils.info import maybe_verbose_iterable

    parser = argparse.ArgumentParser()
    parser.add_argument('task',
                        type=str,
                        help="Task id e.g. Task12_LIDC OR 12 OR LIDC")
    parser.add_argument('model',
                        type=str,
                        help="model name, e.g. RetinaUNetV0")
    parser.add_argument('-f',
                        '--fold',
                        type=int,
                        help="fold to sweep.",
                        default=0,
                        required=False)
    parser.add_argument('-o',
                        '--overwrites',
                        type=str,
                        nargs='+',
                        help="overwrites for config file",
                        required=False)
    parser.add_argument(
        '--threshold',
        type=float,
        help="Minimum probability of predictions",
        required=False,
        default=0.5,
    )
    parser.add_argument('--test', action='store_true')

    args = parser.parse_args()
    model = args.model
    fold = args.fold
    task = args.task
    overwrites = args.overwrites
    test = args.test
    threshold = args.threshold

    task_name = get_task(task, name=True, models=True)
    task_dir = Path(os.getenv("det_models")) / task_name

    training_dir = get_training_dir(task_dir / model, fold)

    overwrites = overwrites if overwrites is not None else []
    overwrites.append("host.parent_data=${env:det_data}")
    overwrites.append("host.parent_results=${env:det_models}")

    prediction_dir = training_dir / "test_predictions" \
        if test else training_dir / "val_predictions"
    save_dir = training_dir / "test_predictions_nii" \
        if test else training_dir / "val_predictions_nii"
    save_dir.mkdir(exist_ok=True)

    case_ids = [
        p.stem.rsplit('_', 1)[0] for p in prediction_dir.glob("*_boxes.pkl")
    ]
    for cid in maybe_verbose_iterable(case_ids):
        res = load_pickle(prediction_dir / f"{cid}_boxes.pkl")

        instance_mask = np.zeros(res["original_size_of_raw_data"],
                                 dtype=np.uint8)

        boxes = res["pred_boxes"]
        scores = res["pred_scores"]
        labels = res["pred_labels"]

        _mask = scores >= threshold
        boxes = boxes[_mask]
        labels = labels[_mask]
        scores = scores[_mask]

        idx = np.argsort(scores)
        scores = scores[idx]
        boxes = boxes[idx]
        labels = labels[idx]

        prediction_meta = {}
        for instance_id, (pbox, pscore,
                          plabel) in enumerate(zip(boxes, scores, labels),
                                               start=1):
            mask_slicing = [
                slice(int(pbox[0]), int(pbox[2])),
                slice(int(pbox[1]), int(pbox[3])),
            ]
            if instance_mask.ndim == 3:
                mask_slicing.append(slice(int(pbox[4]), int(pbox[5])))
            instance_mask[tuple(mask_slicing)] = instance_id

            prediction_meta[int(instance_id)] = {
                "score": float(pscore),
                "label": int(plabel),
                "box": list(map(int, pbox))
            }

        logger.info(
            f"Created instance mask with {instance_mask.max()} instances.")

        instance_mask_itk = sitk.GetImageFromArray(instance_mask)
        instance_mask_itk.SetOrigin(res["itk_origin"])
        instance_mask_itk.SetDirection(res["itk_direction"])
        instance_mask_itk.SetSpacing(res["itk_spacing"])

        sitk.WriteImage(instance_mask_itk,
                        str(save_dir / f"{cid}_boxes.nii.gz"))
        save_json(prediction_meta, save_dir / f"{cid}_boxes.json")
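
The exported files can be inspected with standard tooling. For example (file names below are placeholders for the {case_id}_boxes.* outputs written above), the instance mask and its metadata could be read back like this:

import json

import SimpleITK as sitk

# Placeholder file names; substitute an actual case id.
mask = sitk.GetArrayFromImage(sitk.ReadImage("case_000_boxes.nii.gz"))
with open("case_000_boxes.json") as fp:
    meta = json.load(fp)
for instance_id, info in meta.items():
    print(instance_id, info["score"], info["label"], info["box"])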
Code example #14
def seg2nii():
    import os
    import argparse
    from pathlib import Path

    import SimpleITK as sitk

    from nndet.io import load_pickle
    from nndet.io.paths import get_task, get_training_dir
    from nndet.utils.info import maybe_verbose_iterable

    parser = argparse.ArgumentParser()
    parser.add_argument('task',
                        type=str,
                        help="Task id e.g. Task12_LIDC OR 12 OR LIDC")
    parser.add_argument('model',
                        type=str,
                        help="model name, e.g. RetinaUNetV0")
    parser.add_argument('-f',
                        '--fold',
                        type=int,
                        help="fold to sweep.",
                        default=0,
                        required=False)
    parser.add_argument('-o',
                        '--overwrites',
                        type=str,
                        nargs='+',
                        help="overwrites for config file",
                        required=False)
    parser.add_argument('--test', action='store_true')

    args = parser.parse_args()
    model = args.model
    fold = args.fold
    task = args.task
    overwrites = args.overwrites
    test = args.test

    task_name = get_task(task, name=True, models=True)
    task_dir = Path(os.getenv("det_models")) / task_name

    training_dir = get_training_dir(task_dir / model, fold)

    overwrites = overwrites if overwrites is not None else []
    overwrites.append("host.parent_data=${env:det_data}")
    overwrites.append("host.parent_results=${env:det_models}")

    prediction_dir = training_dir / "test_predictions" \
        if test else training_dir / "val_predictions"
    save_dir = training_dir / "test_predictions_nii" \
        if test else training_dir / "val_predictions_nii"
    save_dir.mkdir(exist_ok=True)

    case_ids = [
        p.stem.rsplit('_', 1)[0] for p in prediction_dir.glob("*_seg.pkl")
    ]
    for cid in maybe_verbose_iterable(case_ids):
        res = load_pickle(prediction_dir / f"{cid}_seg.pkl")

        seg_itk = sitk.GetImageFromArray(res["pred_seg"])
        seg_itk.SetOrigin(res["itk_origin"])
        seg_itk.SetDirection(res["itk_direction"])
        seg_itk.SetSpacing(res["itk_spacing"])

        sitk.WriteImage(seg_itk, str(save_dir / f"{cid}_seg.nii.gz"))
Code example #15
def check_data_and_label_splitted(
    task_name: str,
    test: bool = False,
    labels: bool = True,
    full_check: bool = True,
    ):
    """
    Perform checks of data and label in raw splitted format

    Args:
        task_name: name of task to check
        test: check test data
        labels: check labels
        full_check: By default, a full check is performed, which needs to
            load all files. If this is disabled, a computationally light
            check is performed instead.

    Raises:
        ValueError: if not all raw splitted files were found
        ValueError: missing label info file
        ValueError: instances in label info file need to start at 1
        ValueError: instances in label info file need to be consecutive
    """
    print("Start data and label check.")
    cfg = load_dataset_info(get_task(task_name))

    splitted_paths = get_paths_from_splitted_dir(
        num_modalities=len(cfg["modalities"]),
        splitted_4d_output_dir=Path(os.getenv('det_data')) / task_name / "raw_splitted",
        labels=labels,
        test=test,
    )

    for case_paths in maybe_verbose_iterable(splitted_paths):
        # check all files exist
        for cp in case_paths:
            if not Path(cp).is_file():
                raise ValueError(f"Expected {cp} to be a raw splitted "
                                 "data path but it does not exist.")

        if labels:
            # check label info (json files)
            mask_path = case_paths[-1]
            mask_info_path = mask_path.parent / f"{mask_path.stem.split('.')[0]}.json"
            if not Path(mask_info_path).is_file():
                raise ValueError(f"Expected {mask_info_path} to be a raw splitted "
                                "mask info path but it does not exist.")
            mask_info = load_json(mask_info_path)
            if mask_info["instances"]:
                mask_info_instances = list(map(int, mask_info["instances"].keys()))

                if (j := min(mask_info_instances)) != 1:
                    raise ValueError(f"Instance IDs need to start at 1, found {j} in {mask_info_path}")

                for i in range(1, len(mask_info_instances) + 1):
                    if i not in mask_info_instances:
                        raise ValueError(f"Exptected {i} to be an Instance ID in "
                                        f"{mask_info_path} but only found {mask_info_instances}")
        else:
            mask_info_path = None

        if full_check:
            _full_check(case_paths, mask_info_path)
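
For reference, a label info file that passes the instance checks above could look like the following sketch: instance ids start at 1 and are consecutive; the values they map to are assumed here to be class indices, and the file name is a placeholder.

import json
from pathlib import Path

# Minimal illustrative label info file: keys are instance ids (starting at 1,
# consecutive), values are assumed class indices.
mask_info = {"instances": {"1": 0, "2": 0, "3": 1}}
Path("case_000.json").write_text(json.dumps(mask_info, indent=2))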