Example #1
# imports needed to run this example standalone (nndet is the nnDetection package)
from pathlib import Path
from typing import Optional, Sequence

import numpy as np
from loguru import logger

from nndet.io import load_pickle, save_pickle
# instance_results_from_seg is an nndet-internal helper not shown in this example

def import_single_case(
    logits_source: Path,
    logits_target_dir: Optional[Path],
    aggregation: str,
    min_num_voxel: int,
    min_threshold: Optional[float],
    save_seg: bool = True,
    save_iseg: bool = True,
    stuff: Optional[Sequence[int]] = None,
):
    """
    Process a single case

    Args:
        logits_source: path to nnunet prediction
        logits_target_dir: path to dir where result should be saved
        aggregation: aggregation method for probabilities.
        save_seg: save semantic segmentation
        save_iseg: save instance segmentation
        stuff: stuff classes to remove
    """
    assert logits_source.is_file(), \
        f"Logits source needs to be a file, found {logits_source}"
    assert logits_target_dir.is_dir(), \
        f"Logits target dir needs to be a dir, found {logits_target_dir}"

    case_name = logits_source.stem
    logger.info(f"Processing {case_name}")
    properties_file = logits_source.parent / f"{case_name}.pkl"
    probs = np.load(str(logits_source))["softmax"]

    if properties_file.is_file():
        properties_dict = load_pickle(properties_file)
        bbox = properties_dict.get('crop_bbox')
        shape_original_before_cropping = properties_dict.get(
            'original_size_of_raw_data')

        if bbox is not None:
            # the prediction was produced on a cropped volume: embed the
            # probabilities back into an array of the original image size
            tmp = np.zeros((probs.shape[0], *shape_original_before_cropping))
            for c in range(3):
                # clip the upper bbox bound to the original image size
                bbox[c][1] = np.min((bbox[c][0] + probs.shape[c + 1],
                                     shape_original_before_cropping[c]))

            tmp[:, bbox[0][0]:bbox[0][1], bbox[1][0]:bbox[1][1],
                bbox[2][0]:bbox[2][1]] = probs
            probs = tmp

    res = instance_results_from_seg(
        probs,
        aggregation=aggregation,
        min_num_voxel=min_num_voxel,
        min_threshold=min_threshold,
        stuff=stuff,
    )

    detection_target = logits_target_dir / f"{case_name}_boxes.pkl"
    segmentation_target = logits_target_dir / f"{case_name}_segmentation.pkl"
    instances_target = logits_target_dir / f"{case_name}_instances.pkl"

    boxes = {
        key: res[key]
        for key in ["pred_boxes", "pred_labels", "pred_scores"]
    }
    save_pickle(boxes, detection_target)
    if save_iseg:
        instances = {
            key: res[key]
            for key in ["pred_instances", "pred_labels", "pred_scores"]
        }
        save_pickle(instances, instances_target)
    if save_seg:
        segmentation = {"pred_seg": np.argmax(probs, axis=0)}
        save_pickle(segmentation, segmentation_target)
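
A minimal usage sketch for the function above; the paths and parameter values are hypothetical, and aggregation must be one of the modes understood by instance_results_from_seg:

from pathlib import Path

import_single_case(
    logits_source=Path("/predictions/case_0001.npz"),   # hypothetical path
    logits_target_dir=Path("/predictions/converted"),   # hypothetical path
    aggregation="max",        # assumption: a mode supported by instance_results_from_seg
    min_num_voxel=10,
    min_threshold=0.5,
)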
Example #2
        nnunet_prediction_dir.mkdir(parents=True, exist_ok=True)

        if num_workers > 0:
            with Pool(processes=max(num_workers // 4, 1)) as p:
                p.starmap(
                    copy_and_ensemble_test,
                    zip(
                        case_ids,
                        repeat(nnunet_dirs),
                        repeat(nnunet_prediction_dir),
                    ))
        else:
            for cid in case_ids:
                copy_and_ensemble_test(cid, nnunet_dirs, nnunet_prediction_dir)

        postprocessing_settings = load_pickle(nndet_unet_dir /
                                              "postprocessing.pkl")
        target_dir = nndet_unet_dir / "test_predictions"

    logger.info(f"Creating final predictions")
    target_dir.mkdir(parents=True, exist_ok=True)
    import_dir(
        nnunet_prediction_dir=nnunet_prediction_dir,
        target_dir=target_dir,
        save_seg=save_seg,
        save_iseg=save_iseg,
        stuff=stuff,
        num_workers=num_workers,
        **postprocessing_settings,
    )
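
The fragment above fans the case ids out over a process pool, with itertools.repeat pinning the arguments shared by all cases. A self-contained sketch of the same pattern, using a hypothetical worker in place of copy_and_ensemble_test:

from itertools import repeat
from multiprocessing import Pool


def worker(case_id, source_dirs, target_dir):
    # stand-in for copy_and_ensemble_test
    print(f"{case_id}: {source_dirs} -> {target_dir}")


if __name__ == "__main__":
    case_ids = ["case_0000", "case_0001"]
    with Pool(processes=2) as p:
        # starmap unpacks each zipped tuple into the worker's arguments
        p.starmap(worker, zip(case_ids, repeat(["/src"]), repeat("/dst")))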
Example #3
def boxes2nii():
    import os
    import argparse
    from pathlib import Path

    import numpy as np
    import SimpleITK as sitk
    from loguru import logger

    from nndet.io import save_json, load_pickle
    from nndet.io.paths import get_task, get_training_dir
    from nndet.utils.info import maybe_verbose_iterable

    parser = argparse.ArgumentParser()
    parser.add_argument('task',
                        type=str,
                        help="Task id e.g. Task12_LIDC OR 12 OR LIDC")
    parser.add_argument('model',
                        type=str,
                        help="model name, e.g. RetinaUNetV0")
    parser.add_argument('-f',
                        '--fold',
                        type=int,
                        help="fold to sweep.",
                        default=0,
                        required=False)
    parser.add_argument('-o',
                        '--overwrites',
                        type=str,
                        nargs='+',
                        help="overwrites for config file",
                        required=False)
    parser.add_argument(
        '--threshold',
        type=float,
        help="Minimum probability of predictions",
        required=False,
        default=0.5,
    )
    parser.add_argument('--test', action='store_true')

    args = parser.parse_args()
    model = args.model
    fold = args.fold
    task = args.task
    overwrites = args.overwrites
    test = args.test
    threshold = args.threshold

    task_name = get_task(task, name=True, models=True)
    task_dir = Path(os.getenv("det_models")) / task_name

    training_dir = get_training_dir(task_dir / model, fold)

    overwrites = overwrites if overwrites is not None else []
    overwrites.append("host.parent_data=${env:det_data}")
    overwrites.append("host.parent_results=${env:det_models}")

    prediction_dir = training_dir / "test_predictions" \
        if test else training_dir / "val_predictions"
    save_dir = training_dir / "test_predictions_nii" \
        if test else training_dir / "val_predictions_nii"
    save_dir.mkdir(exist_ok=True)

    case_ids = [
        p.stem.rsplit('_', 1)[0] for p in prediction_dir.glob("*_boxes.pkl")
    ]
    for cid in maybe_verbose_iterable(case_ids):
        res = load_pickle(prediction_dir / f"{cid}_boxes.pkl")

        instance_mask = np.zeros(res["original_size_of_raw_data"],
                                 dtype=np.uint8)

        boxes = res["pred_boxes"]
        scores = res["pred_scores"]
        labels = res["pred_labels"]

        _mask = scores >= threshold
        boxes = boxes[_mask]
        labels = labels[_mask]
        scores = scores[_mask]

        # sort ascending so that higher-scoring boxes receive higher instance
        # ids and overwrite lower-scoring ones where boxes overlap
        idx = np.argsort(scores)
        scores = scores[idx]
        boxes = boxes[idx]
        labels = labels[idx]

        prediction_meta = {}
        for instance_id, (pbox, pscore,
                          plabel) in enumerate(zip(boxes, scores, labels),
                                               start=1):
            # box format: (x1, y1, x2, y2) in 2D, (x1, y1, x2, y2, z1, z2) in 3D
            mask_slicing = [
                slice(int(pbox[0]), int(pbox[2])),
                slice(int(pbox[1]), int(pbox[3])),
            ]
            if instance_mask.ndim == 3:
                mask_slicing.append(slice(int(pbox[4]), int(pbox[5])))
            instance_mask[tuple(mask_slicing)] = instance_id

            prediction_meta[int(instance_id)] = {
                "score": float(pscore),
                "label": int(plabel),
                "box": list(map(int, pbox))
            }

        logger.info(
            f"Created instance mask with {instance_mask.max()} instances.")

        instance_mask_itk = sitk.GetImageFromArray(instance_mask)
        instance_mask_itk.SetOrigin(res["itk_origin"])
        instance_mask_itk.SetDirection(res["itk_direction"])
        instance_mask_itk.SetSpacing(res["itk_spacing"])

        sitk.WriteImage(instance_mask_itk,
                        str(save_dir / f"{cid}_boxes.nii.gz"))
        save_json(prediction_meta, save_dir / f"{cid}_boxes.json")
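
The exported mask and its metadata can be read back for inspection. A minimal sketch, assuming a hypothetical case id and save directory:

import json
from pathlib import Path

import SimpleITK as sitk

save_dir = Path("/models/Task12_LIDC/RetinaUNetV0/fold0/val_predictions_nii")  # hypothetical
mask = sitk.GetArrayFromImage(sitk.ReadImage(str(save_dir / "case_0001_boxes.nii.gz")))
with open(save_dir / "case_0001_boxes.json") as f:
    meta = json.load(f)
# voxel value n in the mask corresponds to entry meta[str(n)] (score, label, box)
print(int(mask.max()), meta)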
Example #4
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'task',
        type=str,
        help="Task id e.g. Task12_LIDC OR 12 OR LIDC",
    )
    parser.add_argument(
        'model',
        type=str,
        help="model name, e.g. RetinaUNetV0",
    )
    parser.add_argument(
        '-o',
        '--overwrites',
        type=str,
        nargs='+',
        required=False,
        help="overwrites for config file. Only needed in case of box eval",
    )
    parser.add_argument(
        '-c',
        '--consolidate',
        type=str,
        default="export",
        required=False,
        help=("Determines how to consolidate predictions: 'export' or 'copy'. "
              "'copy' will copy the predictions of each fold into the "
              "directory for evaluation. 'export' will use the updated "
              "parameters after consolidation to update the predictions "
              "and export them. This is only supported if one of the "
              "sweep settings is active! Default: export"),
    )
    parser.add_argument(
        '--num_folds',
        type=int,
        default=5,
        required=False,
        help="Number of folds. Default: 5",
    )
    parser.add_argument(
        '--no_model',
        action="store_false",
        help="Deactivate if consolidating nnUNet results",
    )
    parser.add_argument(
        '--sweep_boxes',
        action="store_true",
        help="Sweep for best parameters for bounding box based models",
    )
    parser.add_argument(
        '--sweep_instances',
        action="store_true",
        help="Sweep for best parameters for instance segmentation based models",
    )
    parser.add_argument(
        '--ckpt',
        type=str,
        default="last",
        required=False,
        help="Define identifier of checkpoint for consolidation. "
        "Use this with care!")

    args = parser.parse_args()
    model = args.model
    task = args.task
    ov = args.overwrites

    consolidate = args.consolidate
    num_folds = args.num_folds
    do_model_consolidation = args.no_model

    sweep_boxes = args.sweep_boxes
    sweep_instances = args.sweep_instances
    ckpt = args.ckpt

    if consolidate == "export" and not (sweep_boxes or sweep_instances):
        raise ValueError(
            "Export needs new parameter sweep! Actiate one of the sweep "
            "arguments or change to copy mode")

    task_dir = Path(os.getenv("det_models")) / get_task(
        task, name=True, models=True)
    model_dir = task_dir / model
    if not model_dir.is_dir():
        raise ValueError(f"{model_dir} does not exist")
    target_dir = model_dir / "consolidated"

    logger.remove()
    logger.add(sys.stdout, format="{level} {message}", level="INFO")
    logger.add(Path(target_dir) / "consolidate.log", level="DEBUG")

    logger.info(f"looking for models in {model_dir}")
    training_dirs = [
        get_latest_model(model_dir, fold) for fold in range(num_folds)
    ]
    logger.info(f"Found training dirs: {training_dirs}")

    # model consolidation
    if do_model_consolidation:
        logger.info("Consolidate models")
        if ckpt != "last":
            logger.warning(
                f"Found ckpt overwrite {ckpt}, this is not the default, "
                "this can drastically influence the performance!")
        consolidate_models(training_dirs, target_dir, ckpt)

    # consolidate predictions
    logger.info("Consolidate predictions")
    consolidate_predictions(
        source_dirs=training_dirs,
        target_dir=target_dir,
        consolidate=consolidate,
    )

    shutil.copy2(training_dirs[0] / "plan.pkl", target_dir)
    shutil.copy2(training_dirs[0] / "config.yaml", target_dir)

    # invoke new parameter sweeps
    cfg = OmegaConf.load(str(target_dir / "config.yaml"))
    ov = ov if ov is not None else []
    ov.append("host.parent_data=${env:det_data}")
    ov.append("host.parent_results=${env:det_models}")
    cfg.merge_with_dotlist(ov)

    preprocessed_output_dir = Path(cfg["host"]["preprocessed_output_dir"])
    plan = load_pickle(target_dir / "plan.pkl")
    gt_dir = preprocessed_output_dir / plan["data_identifier"] / "labelsTr"

    if sweep_boxes:
        logger.info("Sweeping box predictions")
        module = MODULE_REGISTRY[cfg["module"]]
        ensembler_cls = module.get_ensembler_cls(
            key="boxes",
            dim=plan["network_dim"])  # TODO: make this configurable

        sweeper = BoxSweeper(
            classes=[item for _, item in cfg["data"]["labels"].items()],
            pred_dir=target_dir / "sweep_predictions",
            gt_dir=gt_dir,
            target_metric=cfg["trainer_cfg"].get(
                "eval_score_key", "mAP_IoU_0.10_0.50_0.05_MaxDet_100"),
            ensembler_cls=ensembler_cls,
            save_dir=target_dir / "sweep",
        )
        inference_plan = sweeper.run_postprocessing_sweep()
    elif sweep_instances:
        raise NotImplementedError

    plan = load_pickle(target_dir / "plan.pkl")
    if consolidate != 'copy':
        # in export mode a parameter sweep ran above, so inference_plan and
        # ensembler_cls are defined at this point
        plan["inference_plan"] = inference_plan
        save_pickle(plan, target_dir / "plan_inference.pkl")

        for restore in [True, False]:
            export_dir = target_dir / "val_predictions" if restore else \
                target_dir / "val_predictions_preprocessed"
            extract_results(
                source_dir=target_dir / "sweep_predictions",
                target_dir=export_dir,
                ensembler_cls=ensembler_cls,
                restore=restore,
                **inference_plan,
            )
    else:
        logger.warning("Plan used from fold 0, not updated with consolidation")
        save_pickle(plan, target_dir / "plan_inference.pkl")
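
A hypothetical direct invocation of this entry point (nnDetection registers it as the nndet_consolidate console script); the argument values are examples only:

import sys

# equivalent CLI call: nndet_consolidate Task12_LIDC RetinaUNetV0 --sweep_boxes
sys.argv = ["nndet_consolidate", "Task12_LIDC", "RetinaUNetV0", "--sweep_boxes"]
main()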
Example #5
def seg2nii():
    import os
    import argparse
    from pathlib import Path

    import SimpleITK as sitk

    from nndet.io import load_pickle
    from nndet.io.paths import get_task, get_training_dir
    from nndet.utils.info import maybe_verbose_iterable

    parser = argparse.ArgumentParser()
    parser.add_argument('task',
                        type=str,
                        help="Task id e.g. Task12_LIDC OR 12 OR LIDC")
    parser.add_argument('model',
                        type=str,
                        help="model name, e.g. RetinaUNetV0")
    parser.add_argument('-f',
                        '--fold',
                        type=int,
                        help="fold to sweep.",
                        default=0,
                        required=False)
    parser.add_argument('-o',
                        '--overwrites',
                        type=str,
                        nargs='+',
                        help="overwrites for config file",
                        required=False)
    parser.add_argument('--test', action='store_true')

    args = parser.parse_args()
    model = args.model
    fold = args.fold
    task = args.task
    overwrites = args.overwrites
    test = args.test

    task_name = get_task(task, name=True, models=True)
    task_dir = Path(os.getenv("det_models")) / task_name

    training_dir = get_training_dir(task_dir / model, fold)

    overwrites = overwrites if overwrites is not None else []
    overwrites.append("host.parent_data=${env:det_data}")
    overwrites.append("host.parent_results=${env:det_models}")

    prediction_dir = training_dir / "test_predictions" \
        if test else training_dir / "val_predictions"
    save_dir = training_dir / "test_predictions_nii" \
        if test else training_dir / "val_predictions_nii"
    save_dir.mkdir(exist_ok=True)

    case_ids = [
        p.stem.rsplit('_', 1)[0] for p in prediction_dir.glob("*_seg.pkl")
    ]
    for cid in maybe_verbose_iterable(case_ids):
        res = load_pickle(prediction_dir / f"{cid}_seg.pkl")

        seg_itk = sitk.GetImageFromArray(res["pred_seg"])
        seg_itk.SetOrigin(res["itk_origin"])
        seg_itk.SetDirection(res["itk_direction"])
        seg_itk.SetSpacing(res["itk_spacing"])

        sitk.WriteImage(seg_itk, str(save_dir / f"{cid}_seg.nii.gz"))
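
As with boxes2nii, the exported file can be loaded back with SimpleITK. A minimal check, assuming a hypothetical path:

import SimpleITK as sitk

seg = sitk.ReadImage("/path/to/val_predictions_nii/case_0001_seg.nii.gz")  # hypothetical
print(seg.GetSize(), seg.GetSpacing())  # geometry restored from the prediction pickle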