Example #1
    def __init__(
        self,
        preprocessed_output_dir: os.PathLike,
        save_dir: os.PathLike,
        network_cls: Optional[Type[AbstractModel]] = None,
        estimator: Optional[MemoryEstimator] = None,
        **kwargs,
    ):
        """
        Plan the architecture for training

        Args:
            min_feature_map_length (int): minimal size of feature map in bottleneck
        """
        super().__init__(**kwargs)

        self.preprocessed_output_dir = Path(preprocessed_output_dir)
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        self.network_cls = network_cls
        self.estimator = estimator

        self.dataset_properties = load_pickle(
            self.preprocessed_output_dir / "properties" / "dataset_properties.pkl")

        # parameters initialized from process properties
        self.all_boxes: Optional[np.ndarray] = None
        self.all_ious: Optional[np.ndarray] = None
        self.class_ious: Optional[Dict[str, np.ndarray]] = None
        self.num_instances: Optional[Dict[int, int]] = None
        self.dim: Optional[int] = None
        self.architecture_kwargs: dict = {}
        self.transpose_forward = None
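
A minimal instantiation sketch (the subclass name `Planner` is a hypothetical stand-in; only the constructor arguments are taken from the snippet above):

    planner = Planner(  # hypothetical concrete planner subclass
        preprocessed_output_dir="./preprocessed",
        save_dir="./plans",
    )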
Example #2
def collect_boxes(prediction_dir: Path, gt_dir: Path, iou: float,
                  score: float):
    all_pred = []
    all_target = []
    all_boxes = []

    i = 0
    for f in prediction_dir.glob("*_boxes.pkl"):
        case_id = f.stem.rsplit('_', 1)[0]

        gt_data = np.load(str(gt_dir / f"{case_id}_boxes_gt.npz"),
                          allow_pickle=True)
        gt_boxes = gt_data["boxes"]
        gt_classes = gt_data["classes"]
        gt_ignore = [
            np.zeros(gt_boxes_img.shape[0]).reshape(-1, 1)
            for gt_boxes_img in [gt_boxes]
        ]

        case_result = load_pickle(f)
        pred_boxes = case_result["pred_boxes"]
        pred_scores = case_result["pred_scores"]
        pred_labels = case_result["pred_labels"]

        keep = pred_scores > score
        pred_boxes = pred_boxes[keep]
        pred_scores = pred_scores[keep]
        pred_labels = pred_labels[keep]

        # computation starts here
        if gt_boxes.size == 0:
            all_pred.append(pred_labels)
            all_target.append(np.ones(len(pred_labels)) * -1)
            all_boxes.append(pred_boxes)
        elif pred_boxes.size == 0:
            all_pred.append(np.ones(len(gt_classes)) * -1)
            all_target.append(gt_classes)
            all_boxes.append(gt_boxes)
        else:
            match_quality_matrix = box_iou_np(gt_boxes, pred_boxes)

            # best matching gt box for each prediction; below the iou
            # threshold a prediction counts as unmatched (-1)
            matched_idxs = np.argmax(match_quality_matrix, axis=0)
            matched_vals = np.max(match_quality_matrix, axis=0)
            matched_idxs[matched_vals < iou] = -1

            target_labels = gt_classes[matched_idxs.clip(min=0)]
            target_labels[matched_idxs == -1] = -1

            all_pred.append(pred_labels)
            all_target.append(target_labels)
            all_boxes.append(pred_boxes)

            # gt boxes without a prediction above the iou threshold are
            # recorded as false negatives
            unmatched_gt = (match_quality_matrix.max(axis=1) < iou)
            false_negatives = unmatched_gt.sum()
            if false_negatives > 0:
                all_pred.append(np.ones(false_negatives) * -1)
                all_target.append(gt_classes[unmatched_gt])
                all_boxes.append(gt_boxes[unmatched_gt])
    return all_pred, all_target, all_boxes
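
A usage sketch (directories are illustrative). The per-case lists can be concatenated to inspect prediction/target pairs across a whole split; -1 in `all_pred` marks missed gt boxes, -1 in `all_target` marks unmatched predictions:

    all_pred, all_target, all_boxes = collect_boxes(
        prediction_dir=Path("./val_predictions"),
        gt_dir=Path("./labelsTr"),
        iou=0.1,
        score=0.5,
    )
    pred = np.concatenate(all_pred)
    target = np.concatenate(all_target)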
Example #3
def evaluate_box_dir(
    pred_dir: PathLike,
    gt_dir: PathLike,
    classes: Sequence[str],
    save_dir: Optional[Path] = None,
) -> Tuple[Dict, Dict]:
    """
    Run box evaluation inside a directory

    Args:
        pred_dir: path to dir with predictions
        gt_dir: path to dir with ground truth data
        classes: classes present in dataset
        save_dir: optional path to save plots

    Returns:
        Dict[str, float]: dictionary with scalar values for evaluation
        Dict[str, np.ndarray]: dictionary with arrays, e.g. for visualization of graphs
    
    See Also:
        :class:`nndet.evaluator.registry.BoxEvaluator`
    """
    pred_dir = Path(pred_dir)
    gt_dir = Path(gt_dir)
    if save_dir is not None:
        save_dir.mkdir(parents=True, exist_ok=True)
    case_ids = [
        p.stem.rsplit('_boxes', 1)[0] for p in pred_dir.iterdir()
        if p.is_file() and p.stem.endswith("_boxes")
    ]
    logger.info(f"Found {len(case_ids)} for box evaluation in {pred_dir}")

    evaluator = BoxEvaluator.create(
        classes=classes,
        fast=False,
        verbose=False,
        save_dir=save_dir,
    )

    for case_id in case_ids:
        gt = np.load(str(gt_dir / f"{case_id}_boxes_gt.npz"),
                     allow_pickle=True)
        pred = load_pickle(pred_dir / f"{case_id}_boxes.pkl")
        evaluator.run_online_evaluation(
            pred_boxes=[pred["pred_boxes"]],
            pred_classes=[pred["pred_labels"]],
            pred_scores=[pred["pred_scores"]],
            gt_boxes=[gt["boxes"]],
            gt_classes=[gt["classes"]],
            gt_ignore=None,
        )
    return evaluator.finish_online_evaluation()
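
A minimal invocation sketch (paths and class names are illustrative):

    scores, curves = evaluate_box_dir(
        pred_dir="./val_predictions",
        gt_dir="./labelsTr",
        classes=["lesion"],
        save_dir=Path("./val_results/boxes"),
    )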
Example #4
def check_case(case_npz: Path,
               case_pkl: Optional[Path] = None,
               remove: bool = False,
               keys: Sequence[str] = ("data", "seg"),
               ) -> Tuple[str, bool]:
    """
    Check if a single case is loadable

    Args:
        case_npz (Path): path to npz file
        case_pkl (Path, optional): path to pkl file. Defaults to None.
        remove (bool, optional): if loading fails, the npz and pkl files
            are removed automatically. Defaults to False.
        keys (Sequence[str], optional): keys to check inside the npz file.
            Defaults to ("data", "seg").

    Returns:
        str: case id
        bool: true if case was loaded correctly, false otherwise
    """
    logger.info(f"Checking {case_npz}")
    case_id = get_case_id_from_path(case_npz, remove_modality=False)
    try:
        case_dict = load_npz_looped(str(case_npz), keys=keys, num_tries=3)
        if "seg" in keys and case_pkl is not None:
            properties = load_pickle(case_pkl)
            seg = case_dict["seg"]
            seg_instances = np.unique(seg)  # automatically sorted
            seg_instances = seg_instances[seg_instances > 0]
            
            instances_properties = properties["instances"].keys()
            props_instances = np.sort(np.array(list(map(int, instances_properties))))
            
            if (len(seg_instances) != len(props_instances)) or any(seg_instances != props_instances):
                logger.warning(f"Inconsistent instances in {case_npz}: "
                               f"properties {props_instances} vs. seg {seg_instances}. "
                               f"Very small instances can get lost in resampling, "
                               f"but larger instances should not disappear!")
            for i in seg_instances:
                if str(i) not in instances_properties:
                    raise RuntimeError(f"Found instance {i} in segmentation "
                                       f"which is not in properties {instances_properties}. "
                                       f"Delete labels manually and rerun prepare label!")
    except Exception as e:
        logger.error(f"Failed to load {case_npz} with {e}")
        logger.error(f"{traceback.format_exc()}")
        if remove:
            os.remove(case_npz)
            if case_pkl is not None:
                os.remove(case_pkl)
        return case_id, False
    return case_id, True
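
A usage sketch (the directory layout is hypothetical) collecting all cases that fail to load:

    failed = []
    for npz in sorted(Path("./preprocessed/imagesTr").glob("*.npz")):
        case_id, ok = check_case(npz, case_pkl=npz.with_suffix(".pkl"))
        if not ok:
            failed.append(case_id)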
Example #5
    def load_candidates(self, case_id: str, fg_crop: bool) -> Union[Dict, None]:
        """
        Load candidates for sampling

        Args:
            case_id: case id to load candidates from
            fg_crop: True if foreground crop will be sampled, False if
                background will be sampled

        Returns:
            Union[Dict, None]: dict if fg, None if bg
        """
        if fg_crop:
            return load_pickle(self._data[case_id]['boxes_file'])
        else:
            return None
Example #6
    def build_cache(self) -> Dict[str, List]:
        """
        Build up cache for sampling

        Returns:
            Dict[str, List]: cache for sampling
                `case`: list with all case identifiers
                `instances`: list with tuple of (case_id, instance_id)
        """
        instance_cache = []

        logger.info("Building Sampling Cache for Dataloder")
        for case_id, item in maybe_verbose_iterable(self._data.items(), desc="Sampling Cache"):
            instances = load_pickle(item['boxes_file'])["instances"]
            if instances:
                for instance_id in instances:
                    instance_cache.append((case_id, instance_id))
        return {"case": list(self._data.keys()), "instances": instance_cache}
Example #7
def evaluate_case_dir(
    pred_dir: PathLike,
    gt_dir: PathLike,
    classes: Sequence[str],
    target_class: Optional[int] = None,
) -> Tuple[Dict, Dict]:
    """
    Run evaluation of case results inside a directory

    Args:
        pred_dir: path to dir with predictions
        gt_dir: path to dir with ground truth data
        classes: classes present in dataset
        target_class: in case of multiple classes, specify a target class
            to evaluate in a target class vs rest setting

    Returns:
        Dict[str, float]: dictionary with scalar values for evaluation
        Dict[str, np.ndarray]: dictionary with arrays, e.g. for visualization of graphs
    
    See Also:
        :class:`nndet.evaluator.registry.CaseEvaluator`
    """
    pred_dir = Path(pred_dir)
    gt_dir = Path(gt_dir)
    case_ids = [
        p.stem.rsplit('_boxes', 1)[0] for p in pred_dir.iterdir()
        if p.is_file() and p.stem.endswith("_boxes")
    ]
    logger.info(f"Found {len(case_ids)} for case evaluation in {pred_dir}")

    evaluator = CaseEvaluator.create(
        classes=classes,
        target_class=target_class,
    )

    for case_id in case_ids:
        gt = np.load(str(gt_dir / f"{case_id}_boxes_gt.npz"),
                     allow_pickle=True)
        pred = load_pickle(pred_dir / f"{case_id}_boxes.pkl")
        evaluator.run_online_evaluation(pred_classes=[pred["pred_labels"]],
                                        pred_scores=[pred["pred_scores"]],
                                        gt_classes=[gt["classes"]])
    return evaluator.finish_online_evaluation()
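
A minimal invocation sketch (class names are illustrative); `target_class` reduces a multi-class problem to a target vs. rest evaluation:

    scores, curves = evaluate_case_dir(
        pred_dir="./val_predictions",
        gt_dir="./labelsTr",
        classes=["benign", "malignant"],
        target_class=1,
    )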
Example #8
    def build_cache(self) -> Dict[str, Any]:
        """
        Build up cache for sampling

        Returns:
            Dict[str, Any]: cache for sampling
                `fg`: foreground cache mapping each class to a list of
                    (case_id, instance_id) tuples
                `case`: list with all case identifiers
        """
        fg_cache = defaultdict(list)

        logger.info("Building Sampling Cache for Dataloder")
        for case_id, item in maybe_verbose_iterable(self._data.items(), desc="Sampling Cache"):
            candidates = load_pickle(item['boxes_file'])
            if candidates["instances"]:
                for instance_id, instance_class in zip(candidates["instances"], candidates["labels"]):
                    fg_cache[int(instance_class)].append((case_id, instance_id))
        return {"fg": fg_cache, "case": list(self._data.keys())}
Example #9
def evaluate_seg_dir(
    pred_dir: PathLike,
    gt_dir: PathLike,
    classes: Sequence[str],
) -> Tuple[Dict, None]:
    """
    Compute dice metric across a directory

    Args:
        pred_dir: path to dir with predictions
        gt_dir: path to dir with ground truth data
        classes: classes present in dataset

    Returns:
        Dict[str, float]: dictionary with scalar values for evaluation
        None

    See Also:
        :class:`nndet.evaluator.registry.PerCaseSegmentationEvaluator`
    """
    pred_dir = Path(pred_dir)
    gt_dir = Path(gt_dir)
    case_ids = [
        p.stem.rsplit('_seg', 1)[0] for p in pred_dir.iterdir()
        if p.is_file() and p.stem.endswith("_seg")
    ]
    logger.info(f"Found {len(case_ids)} for seg evaluation in {pred_dir}")

    evaluator = PerCaseSegmentationEvaluator.create(classes=classes)

    for case_id in case_ids:
        gt = np.load(str(gt_dir / f"{case_id}_seg_gt.npz"),
                     allow_pickle=True)["seg"]  # 1, dims
        pred = load_pickle(pred_dir / f"{case_id}_seg.pkl")
        evaluator.run_online_evaluation(
            seg=pred[None],
            target=gt,
        )
    return evaluator.finish_online_evaluation()
Example #10
def collect_overview(
    prediction_dir: Path,
    gt_dir: Path,
    iou: float,
    score: float,
    max_num_fp_per_image: int = 5,
    top_n: int = 10,
):
    results = defaultdict(dict)

    for f in prediction_dir.glob("*_boxes.pkl"):
        case_id = f.stem.rsplit('_', 1)[0]

        gt_data = np.load(str(gt_dir / f"{case_id}_boxes_gt.npz"),
                          allow_pickle=True)
        gt_boxes = gt_data["boxes"]
        gt_classes = gt_data["classes"]
        gt_ignore = [
            np.zeros(gt_boxes_img.shape[0]).reshape(-1, 1)
            for gt_boxes_img in [gt_boxes]
        ]

        case_result = load_pickle(f)
        pred_boxes = case_result["pred_boxes"]
        pred_scores = case_result["pred_scores"]
        pred_labels = case_result["pred_labels"]
        keep = pred_scores > score
        pred_boxes = pred_boxes[keep]
        pred_scores = pred_scores[keep]
        pred_labels = pred_labels[keep]

        # if "properties" in case_data:
        #     results[case_id]["orig_spacing"] = case_data["properties"]["original_spacing"]
        #     results[case_id]["crop_shape"] = [c[1] for c in case_data["properties"]["crop_bbox"]]
        # else:
        #     results[case_id]["orig_spacing"] = None
        #     results[case_id]["crop_shape"] = None
        results[case_id]["num_gt"] = len(gt_classes)

        # computation starts here
        if gt_boxes.size == 0:
            # no gt boxes: every remaining prediction is a false positive
            idx = np.argsort(pred_scores)[::-1][:max_num_fp_per_image]
            results[case_id]["fp_score"] = pred_scores[idx]
            results[case_id]["fp_label"] = pred_labels[idx]
            results[case_id]["fp_true_label"] = np.ones(len(idx)) * -1
            results[case_id]["fp_type"] = ["fp_iou"] * len(idx)
            results[case_id]["num_fn"] = 0
        elif pred_boxes.size == 0:
            results[case_id]["num_fn"] = len(gt_classes)
            results[case_id]["fn_boxes"] = gt_boxes
        else:
            match_quality_matrix = box_iou_np(gt_boxes, pred_boxes)
            # best matching gt box per prediction; below the iou threshold
            # a prediction counts as unmatched (-1)
            matched_idxs = np.argmax(match_quality_matrix, axis=0)
            matched_vals = np.max(match_quality_matrix, axis=0)
            matched_idxs[matched_vals < iou] = -1

            target_labels = gt_classes[matched_idxs.clip(min=0)]
            target_labels[matched_idxs == -1] = -1

            # True positive analysis
            tp_keep = target_labels == pred_labels
            tp_boxes = pred_boxes[tp_keep]
            tp_scores = pred_scores[tp_keep]
            tp_labels = pred_labels[tp_keep]

            keep_high = tp_scores > 0.5
            tp_high_boxes = tp_boxes[keep_high]
            tp_high_scores = tp_scores[keep_high]
            tp_high_labels = tp_labels[keep_high]
            keep_low = tp_scores < 0.5
            tp_low_boxes = tp_boxes[keep_low]
            tp_low_scores = tp_scores[keep_low]
            tp_low_labels = tp_labels[keep_low]

            high_idx = np.argsort(tp_high_scores)[::-1][:3]
            low_idx = np.argsort(tp_low_scores)[:3]
            results[case_id]["iou_tp"] = int(tp_keep.sum())
            results[case_id]["tp_high_boxes"] = tp_high_boxes[high_idx]
            results[case_id]["tp_high_score"] = tp_high_scores[high_idx]
            results[case_id]["tp_high_label"] = tp_high_labels[high_idx]
            results[case_id]["tp_iou"] = matched_vals[tp_keep]

            if tp_low_boxes.size > 0:
                results[case_id]["tp_low_boxes"] = tp_low_boxes[low_idx]
                results[case_id]["tp_low_score"] = tp_low_scores[low_idx]
                results[case_id]["tp_low_label"] = tp_low_labels[low_idx]

            # False Positive Analysis
            fp_keep = (pred_labels != target_labels) * (pred_labels != -1)
            fp_boxes = pred_boxes[fp_keep]
            fp_scores = pred_scores[fp_keep]
            fp_labels = pred_labels[fp_keep]
            fp_target_labels = target_labels[fp_keep]
            idx = np.argsort(fp_scores)[::-1][:max_num_fp_per_image]
            # results[case_id]["fp_box"] = fp_boxes[idx]
            results[case_id]["fp_score"] = fp_scores[idx]
            results[case_id]["fp_label"] = fp_labels[idx]
            results[case_id]["fp_true_label"] = fp_target_labels[idx]
            results[case_id]["fp_type"] = [
                "fp_iou" if tl == -1 else "fp_cls" for tl in fp_target_labels
            ]

            # Misc
            unmatched_gt = (match_quality_matrix.max(axis=1) < iou)
            false_negatives = unmatched_gt.sum()
            results[case_id]["fn_boxes"] = gt_boxes[unmatched_gt]
            results[case_id]["num_fn"] = false_negatives

    df = pd.DataFrame.from_dict(results, orient='index')
    df = df.sort_index()

    analysis_ids = {}
    if "fp_score" in df.columns:
        tmp = df["fp_score"].apply(
            lambda x: np.max(x) if np.any(x) else 0).nlargest(top_n)
        analysis_ids["top_scoring_fp"] = tmp.index.values.tolist()
        tmp = df["fp_score"].apply(
            lambda x: len(x) if isinstance(x, (Sequence, np.ndarray)) else 0
        ).nlargest(top_n)
        analysis_ids["top_num_fp"] = tmp.index.values.tolist()
    if "num_fn" in df.columns:
        tmp = df["num_fn"].nlargest(top_n)
        analysis_ids["top_num_fn"] = tmp.index.values.tolist()
    return df, analysis_ids
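
A usage sketch (directories are illustrative): the DataFrame can be persisted and the collected ids used to pick cases for closer inspection:

    df, analysis_ids = collect_overview(
        prediction_dir=Path("./val_predictions"),
        gt_dir=Path("./labelsTr"),
        iou=0.1,
        score=0.5,
    )
    df.to_csv("analysis_overview.csv")
    print(analysis_ids.get("top_num_fn"))  # case ids with the most missed gt boxes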
Example #11
def _evaluate(
    task: str,
    model: str,
    fold: int,
    test: bool = False,
    do_case_eval: bool = False,
    do_boxes_eval: bool = False,
    do_seg_eval: bool = False,
    do_instances_eval: bool = False,
    do_analyze_boxes: bool = False,
):
    """
    This entrypoint runs the evaluation
    
    Args:
        task: current task
        model: full name of the model run to evaluate,
            e.g. RetinaUNetV001_D3V001_3d
        fold: current fold
        test: use test split
        do_case_eval: evaluate patient metrics
        do_boxes_eval: perform box evaluation
        do_seg_eval: perform semantic segmentation evaluation
        do_instances_eval: perform instance segmentation evaluation
        do_analyze_boxes: run analysis of box results
    """
    # prepare paths
    task = get_task(task, name=True)
    model_dir = Path(os.getenv("det_models")) / task / model
    training_dir = get_training_dir(model_dir, fold)

    data_dir_task = Path(os.getenv("det_data")) / task
    data_cfg = load_dataset_info(data_dir_task)

    prefix = "test" if test else "val"

    modes = [True] if test else [True, False]
    for restore in modes:
        if restore:
            pred_dir_name = f"{prefix}_predictions"
            gt_dir_name = "labelsTs" if test else "labelsTr"
            gt_dir = data_dir_task / "preprocessed" / gt_dir_name
        else:
            plan = load_pickle(training_dir / "plan.pkl")
            pred_dir_name = f"{prefix}_predictions_preprocessed"
            gt_dir = data_dir_task / "preprocessed" / plan[
                "data_identifier"] / "labelsTr"

        pred_dir = training_dir / pred_dir_name
        save_dir = training_dir / f"{prefix}_results" if restore else \
            training_dir / f"{prefix}_results_preprocessed"

        # compute metrics
        if do_boxes_eval:
            logger.info(f"Computing box metrics: restore {restore}")
            scores, curves = evaluate_box_dir(
                pred_dir=pred_dir,
                gt_dir=gt_dir,
                classes=list(data_cfg["labels"].keys()),
                save_dir=save_dir / "boxes",
            )
            save_metric_output(scores, curves, save_dir, "results_boxes")
        if do_case_eval:
            logger.info(f"Computing case metrics: restore {restore}")
            scores, curves = evaluate_case_dir(
                pred_dir=pred_dir,
                gt_dir=gt_dir,
                classes=list(data_cfg["labels"].keys()),
                target_class=data_cfg["target_class"],
            )
            save_metric_output(scores, curves, save_dir, "results_case")
        if do_seg_eval:
            logger.info(f"Computing seg metrics: restore {restore}")
            scores, curves = evaluate_seg_dir(
                pred_dir=pred_dir,
                gt_dir=gt_dir,
                classes=list(data_cfg["labels"].keys()),
            )
            save_metric_output(scores, curves, save_dir, "results_seg")
        if do_instances_eval:
            raise NotImplementedError

        # run analysis
        save_dir = training_dir / f"{prefix}_analysis" if restore else \
            training_dir / f"{prefix}_analysis_preprocessed"
        if do_analyze_boxes:
            logger.info(f"Analyze box predictions: restore {restore}")
            run_analysis_suite(
                prediction_dir=pred_dir,
                gt_dir=gt_dir,
                save_dir=save_dir / "boxes",
            )
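
An illustrative direct invocation (task and model names are examples reused from this file):

    _evaluate(
        task="Task016_Luna",
        model="RetinaUNetV001_D3V001_3d",
        fold=0,
        do_boxes_eval=True,
        do_analyze_boxes=True,
    )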
Example #12
def _sweep(
    task: str,
    model: str,
    fold: int,
):
    """
    Determine best postprocessing parameters for a trained model

    Args:
        task: current task
        model: full name of the model run to determine empirical parameters
            for, e.g. RetinaUNetV001_D3V001_3d
        fold: current fold
    """
    nndet_data_dir = Path(os.getenv("det_models"))
    task = get_task(task, name=True, models=True)
    train_dir = nndet_data_dir / task / model / f"fold{fold}"

    cfg = OmegaConf.load(str(train_dir / "config.yaml"))
    os.chdir(str(train_dir))

    logger.remove()
    logger.add(sys.stdout, format="{level} {message}", level="INFO")
    log_file = Path(os.getcwd()) / "sweep.log"
    logger.add(log_file, level="INFO")
    logger.info(f"Log file at {log_file}")

    plan = load_pickle(train_dir / "plan.pkl")
    data_dir = Path(cfg.host["preprocessed_output_dir"]
                    ) / plan["data_identifier"] / "imagesTr"

    module = MODULE_REGISTRY[cfg["module"]](
        model_cfg=OmegaConf.to_container(cfg["model_cfg"], resolve=True),
        trainer_cfg=OmegaConf.to_container(cfg["trainer_cfg"], resolve=True),
        plan=plan,
    )

    splits = load_pickle(train_dir / "splits.pkl")
    case_ids = splits[cfg["exp"]["fold"]]["val"]
    inference_plan = module.sweep(
        cfg=OmegaConf.to_container(cfg, resolve=True),
        save_dir=train_dir,
        train_data_dir=data_dir,
        case_ids=case_ids,
        run_prediction=True,  # TODO: add command line arg
    )

    plan["inference_plan"] = inference_plan
    save_pickle(plan, train_dir / "plan_inference.pkl")

    ensembler_cls = module.get_ensembler_cls(
        key="boxes", dim=plan["network_dim"])  # TODO: make this configurable
    for restore in [True, False]:
        target_dir = train_dir / "val_predictions" if restore else \
            train_dir / "val_predictions_preprocessed"
        extract_results(
            source_dir=train_dir / "sweep_predictions",
            target_dir=target_dir,
            ensembler_cls=ensembler_cls,
            restore=restore,
            **inference_plan,
        )

    _evaluate(
        task=cfg["task"],
        model=cfg["exp"]["id"],
        fold=cfg["exp"]["fold"],
        test=False,
        do_boxes_eval=True,  # TODO: make this configurable
        do_analyze_boxes=True,  # TODO: make this configurable
    )
Example #13
def _train(
    task: str,
    ov: List[str],
    do_sweep: bool,
):
    """
    Run training

    Args:
        task: task to run training for
        ov: overwrites for config manager
        do_sweep: determine best empirical parameters for run
    """
    print(f"Overwrites: {ov}")
    initialize_config_module(config_module="nndet.conf")
    cfg = compose(task, "config.yaml", overrides=ov if ov is not None else [])

    assert cfg.host.parent_data is not None, 'Parent data can not be None'
    assert cfg.host.parent_results is not None, 'Output dir can not be None'

    train_dir = init_train_dir(cfg)

    pl_logger = MLFlowLogger(
        experiment_name=cfg["task"],
        tags={
            "host": socket.gethostname(),
            "fold": cfg["exp"]["fold"],
            "task": cfg["task"],
            "job_id": os.getenv('LSB_JOBID', 'no_id'),
            "mlflow.runName": cfg["exp"]["id"],
        },
        save_dir=os.getenv("MLFLOW_TRACKING_URI", "./mlruns"),
    )
    pl_logger.log_hyperparams(
        flatten_mapping(
            {"model": OmegaConf.to_container(cfg["model_cfg"], resolve=True)}))
    pl_logger.log_hyperparams(
        flatten_mapping({
            "trainer":
            OmegaConf.to_container(cfg["trainer_cfg"], resolve=True)
        }))

    logger.remove()
    logger.add(sys.stdout, format="{level} {message}", level="INFO")
    log_file = Path(os.getcwd()) / "train.log"
    logger.add(log_file, level="INFO")
    logger.info(f"Log file at {log_file}")

    meta_data = {}
    meta_data["torch_version"] = str(torch.__version__)
    meta_data["date"] = str(datetime.now())
    meta_data["git"] = log_git(nndet.__path__[0], repo_name="nndet")
    save_json(meta_data, "./meta.json")
    try:
        write_requirements_to_file("requirements.txt")
    except Exception as e:
        logger.error(f"Could not log req: {e}")

    plan_path = Path(str(cfg.host["plan_path"]))
    plan = load_pickle(plan_path)
    save_json(create_debug_plan(plan), "./plan_debug.json")

    data_dir = Path(cfg.host["preprocessed_output_dir"]) / plan["data_identifier"] / "imagesTr"

    datamodule = Datamodule(
        augment_cfg=OmegaConf.to_container(cfg["augment_cfg"], resolve=True),
        plan=plan,
        data_dir=data_dir,
        fold=cfg["exp"]["fold"],
    )
    module = MODULE_REGISTRY[cfg["module"]](
        model_cfg=OmegaConf.to_container(cfg["model_cfg"], resolve=True),
        trainer_cfg=OmegaConf.to_container(cfg["trainer_cfg"], resolve=True),
        plan=plan,
    )
    callbacks = []
    checkpoint_cb = ModelCheckpoint(
        dirpath=train_dir,
        filename='model_best',
        save_last=True,
        save_top_k=1,
        monitor=cfg["trainer_cfg"]["monitor_key"],
        mode=cfg["trainer_cfg"]["monitor_mode"],
    )
    checkpoint_cb.CHECKPOINT_NAME_LAST = 'model_last'
    callbacks.append(checkpoint_cb)
    callbacks.append(LearningRateMonitor(logging_interval="epoch"))

    OmegaConf.save(cfg, str(Path(os.getcwd()) / "config.yaml"))
    OmegaConf.save(cfg,
                   str(Path(os.getcwd()) / "config_resolved.yaml"),
                   resolve=True)
    save_pickle(plan, train_dir / "plan.pkl")  # backup plan
    splits = load_pickle(
        Path(cfg.host.preprocessed_output_dir) / datamodule.splits_file)
    save_pickle(splits, train_dir / "splits.pkl")

    trainer_kwargs = {}
    if cfg["train"]["mode"].lower() == "resume":
        trainer_kwargs["resume_from_checkpoint"] = train_dir / "model_last.ckpt"

    num_gpus = cfg["trainer_cfg"]["gpus"]
    logger.info(f"Using {num_gpus} GPUs for training")
    plugins = cfg["trainer_cfg"].get("plugins", None)
    logger.info(f"Using {plugins} plugins for training")

    trainer = pl.Trainer(
        gpus=list(range(num_gpus)) if num_gpus > 1 else num_gpus,
        accelerator=cfg["trainer_cfg"]["accelerator"],
        precision=cfg["trainer_cfg"]["precision"],
        amp_backend=cfg["trainer_cfg"]["amp_backend"],
        amp_level=cfg["trainer_cfg"]["amp_level"],
        benchmark=cfg["trainer_cfg"]["benchmark"],
        deterministic=cfg["trainer_cfg"]["deterministic"],
        callbacks=callbacks,
        logger=pl_logger,
        max_epochs=module.max_epochs,
        progress_bar_refresh_rate=None if bool(int(os.getenv("det_verbose", 1))) else 0,
        reload_dataloaders_every_epoch=False,
        num_sanity_val_steps=10,
        weights_summary='full',
        plugins=plugins,
        terminate_on_nan=True,  # TODO: make modular
        move_metrics_to_cpu=True,
        **trainer_kwargs)
    trainer.fit(module, datamodule=datamodule)

    if do_sweep:
        case_ids = splits[cfg["exp"]["fold"]]["val"]
        if "debug" in cfg and "num_cases_val" in cfg["debug"]:
            case_ids = case_ids[:cfg["debug"]["num_cases_val"]]

        inference_plan = module.sweep(
            cfg=OmegaConf.to_container(cfg, resolve=True),
            save_dir=train_dir,
            train_data_dir=data_dir,
            case_ids=case_ids,
            run_prediction=True,
        )

        plan["inference_plan"] = inference_plan
        save_pickle(plan, train_dir / "plan_inference.pkl")

        ensembler_cls = module.get_ensembler_cls(
            key="boxes",
            dim=plan["network_dim"])  # TODO: make this configurable
        for restore in [True, False]:
            target_dir = train_dir / "val_predictions" if restore else \
                train_dir / "val_predictions_preprocessed"
            extract_results(
                source_dir=train_dir / "sweep_predictions",
                target_dir=target_dir,
                ensembler_cls=ensembler_cls,
                restore=restore,
                **inference_plan,
            )

        _evaluate(
            task=cfg["task"],
            model=cfg["exp"]["id"],
            fold=cfg["exp"]["fold"],
            test=False,
            do_boxes_eval=True,  # TODO: make this configurable
            do_analyze_boxes=True,  # TODO: make this configurable
        )
Example #14
    def generate_train_batch(self) -> Dict[str, Any]:
        """
        Generate a single batch

        Returns:
            Dict: batch dict
                `data` (np.ndarray): data
                `seg` (np.ndarray): unordered(!) instance segmentation.
                    Reordering needs to happen after the final crop
                `instance_mapping` (List[Sequence[int]]): class of each
                    instance in the case (the instances present in the final
                    crop are not known yet, so per-crop classes cannot be
                    extracted here)
                `properties` (List[Dict]): properties of each case
                `keys` (List[str]): case ids
        """
        data_batch = np.zeros(self.data_shape_batch, dtype=float)
        seg_batch = np.zeros(self.seg_shape_batch, dtype=float)
        instances_batch, properties_batch, case_ids_batch = [], [], []

        selected_cases, selected_instances = self.select()
        for batch_idx, (case_id, instance_id) in enumerate(zip(selected_cases, selected_instances)):
            # print(case_id, instance_id)
            case_data = np.load(self._data[case_id]['data_file'], self.memmap_mode, allow_pickle=True)
            case_seg = np.load(self._data[case_id]['seg_file'], self.memmap_mode, allow_pickle=True)
            properties = load_pickle(self._data[case_id]['properties_file'])

            if instance_id < 0:
                candidates = self.load_candidates(case_id=case_id, fg_crop=False)
                crop = self.get_bg_crop(
                    case_data=case_data,
                    case_seg=case_seg,
                    properties=properties,
                    case_id=case_id,
                    candidates=candidates,
                )
            else:
                candidates = self.load_candidates(case_id=case_id, fg_crop=True)
                crop = self.get_fg_crop(
                    case_data=case_data,
                    case_seg=case_seg,
                    properties=properties,
                    case_id=case_id,
                    instance_id=instance_id,
                    candidates=candidates,
                )

            data_batch[batch_idx] = save_get_crop(case_data,
                                                  crop=crop,
                                                  mode=self.pad_mode,
                                                  **self.pad_kwargs_data,
                                                  )[0]
            seg_batch[batch_idx] = save_get_crop(case_seg,
                                                 crop=crop,
                                                 mode='constant',
                                                 constant_values=-1,
                                                 )[0]
            case_ids_batch.append(case_id)
            instances_batch.append(properties.pop("instances"))
            properties_batch.append(properties)

        return {'data': data_batch,
                'seg': seg_batch,
                'properties': properties_batch,
                'instance_mapping': instances_batch,
                'keys': case_ids_batch,
                }
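
For illustration, with batch size 2, one modality, and a 3D patch size of (96, 96, 96) (all hypothetical values), the returned dict would roughly look like:

    batch = loader.generate_train_batch()  # `loader` is an instance of this class
    batch["data"].shape  # (2, 1, 96, 96, 96)
    batch["seg"].shape   # (2, 1, 96, 96, 96)
    batch["keys"]        # e.g. ["case_003", "case_017"]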
Example #15
def run_planning_and_process(
    splitted_4d_output_dir: Path,
    cropped_output_dir: Path,
    preprocessed_output_dir: Path,
    planner_name: str,
    dim: int,
    model_name: str,
    model_cfg: Dict,
    num_processes: int,
    run_preprocessing: bool = True,
    ):
    """
    Run planning and preprocessing

    Args:
        splitted_4d_output_dir: base dir of splitted data
        cropped_output_dir: base dir of cropped data
        preprocessed_output_dir: base dir of preprocessed data
        planner_name: planner name
        dim: number of spatial dimensions
        model_name: name of model to run planning for
        model_cfg: hyperparameters of model (used during planning to
            instantiate model)
        num_processes: number of processes to use for preprocessing
        run_preprocessing: Preprocess and check data. Defaults to True.
    """
    planner_cls = PLANNER_REGISTRY.get(planner_name)
    planner = planner_cls(
        preprocessed_output_dir=preprocessed_output_dir
    )
    plan_identifiers = planner.plan_experiment(
        model_name=model_name,
        model_cfg=model_cfg,
    )
    if run_preprocessing:
        for plan_id in plan_identifiers:
            plan = load_pickle(preprocessed_output_dir / plan_id)
            planner.run_preprocessing(
                cropped_data_dir=cropped_output_dir / "imagesTr",
                plan=plan,
                num_processes=num_processes,
                )
            case_ids_failed, result_check = run_check(
                data_dir=preprocessed_output_dir / plan["data_identifier"] / "imagesTr",
                remove=True,
                processes=num_processes
            )

            # delete and rerun corrupted cases
            if not result_check:
                logger.warning(f"{plan_id} check failed: There are corrupted files {case_ids_failed}!!!!"
                                f"Running preprocessing of those cases without multiprocessing.")
                planner.run_preprocessing(
                    cropped_data_dir=cropped_output_dir / "imagesTr",
                    plan=plan,
                    num_processes=0,
                )
                case_ids_failed, result_check = run_check(
                    data_dir=preprocessed_output_dir / plan["data_identifier"] / "imagesTr",
                    remove=False,
                    processes=0
                )
                if not result_check:
                    logger.error(f"Could not fix corrupted files {case_ids_failed}!")
                    raise RuntimeError("Found corrupted files, check logs!")
                else:
                    logger.info("Fixed corrupted files.")
            else:
                logger.info(f"{plan_id} check successful: Loading check completed")

    if run_preprocessing:
        create_labels(
            preprocessed_output_dir=preprocessed_output_dir,
            source_dir=splitted_4d_output_dir,
            num_processes=num_processes,
        )
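
# Note: the block below prepares LUNA16 (Task016_Luna) box predictions for
# CPM evaluation; it comes from a separate script and assumes `model_dir`
# (the directory of the trained model) is defined beforehand.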
    raw_splitted_images = Path(os.getenv("det_data")) / "Task016_Luna" / "raw_splitted" / "imagesTr"

    prediction_dir = model_dir / "consolidated" / "val_predictions"
    assert prediction_dir.is_dir()

    logger.remove()
    logger.add(sys.stdout, level="INFO")
    log_file = model_dir / "prepare_eval_cpm.log"
    logger.add(log_file, level="INFO")

    prediction_cache = defaultdict(list)
    prediction_paths = sorted([p for p in prediction_dir.iterdir() if p.is_file() and p.name.endswith("_boxes.pkl")])
    logger.info(f"Found {len(prediction_paths)} predictions for evaluation")
    for prediction_path in tqdm(prediction_paths):
    seriesuid = prediction_path.stem.rsplit("_", 1)[0].replace('_', ".")
        predictions = load_pickle(prediction_path)

        data_path = raw_splitted_images / f"{prediction_path.stem.rsplit('_', 1)[0]}_0000.nii.gz"
        image_itk = load_sitk(data_path)

        boxes = predictions["pred_boxes"]
        probs = predictions["pred_scores"]
        centers = box_center_np(boxes)
        assert predictions["restore"]

        for center, prob in zip(centers, probs):
            position_image = (float(center[2]), float(center[1]), float(center[0]))
            position_world = image_itk.TransformContinuousIndexToPhysicalPoint(position_image)

            prediction_cache["seriesuid"].append(seriusuid)
            prediction_cache["coordX"].append(float(position_world[0]))