Example no. 1
def main() -> None:
    parser = argparse.ArgumentParser(
        description="Check whether trainings are complete and converged.")
    parser.add_argument(
        "type",
        type=str,
        choices=["completeness", "convergence"],
        help="Pick whether to check convergence or just completeness.")
    parser.add_argument("setup", type=str, help="Network setup to check.")
    parser.add_argument(
        "--tol_distance",
        type=int,
        default=40,
        nargs="+",
        help=
        "Tolerance distance to check for with metrics using tolerance distance."
    )
    parser.add_argument(
        "--clip_distance",
        type=int,
        default=200,
        nargs="+",
        help="Clip distance to check for which metrics using clip distance.")
    parser.add_argument(
        "--threshold",
        type=int,
        default=127,
        help="Threshold to have been applied on top of raw predictions.")
    parser.add_argument("--training_version",
                        type=str,
                        default="v0003.2",
                        help="Version of training")
    parser.add_argument("--gt_version",
                        type=str,
                        default="v0003",
                        help="Version of groundtruth")
    parser.add_argument("--check_private_db", action="store_true")
    args = parser.parse_args()
    db = cosem_db.MongoCosemDB(training_version=args.training_version,
                               gt_version=args.gt_version,
                               write_access=args.check_private_db)
    metric_params = {
        "tol_distance": args.tol_distance,
        "clip_distance": args.clip_distance
    }
    if args.type == "completeness":
        print(
            check_completeness(db,
                               args.setup,
                               metric_params,
                               threshold=args.threshold))
    else:
        print(check_convergence(args.setup, args.threshold, db))
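
A minimal usage sketch, assuming the snippet above is saved as a script (the filename check_trainings.py is made up) and that setup01 is a valid setup name; it also needs access to the evaluation database to actually run:

import sys

# Hypothetical invocation, equivalent to
#   python check_trainings.py completeness setup01 --tol_distance 40 --clip_distance 200
sys.argv = ["check_trainings.py", "completeness", "setup01",
            "--tol_distance", "40", "--clip_distance", "200"]
main()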
Example no. 2
def query_score(cropno,
                labelname,
                threshold=127,
                setup=None,
                s1=False,
                clip_distance=200,
                tol_distance=40,
                training_version="v0003.2",
                gt_version="v0003"):
    db = cosem_db.MongoCosemDB(training_version=training_version,
                               gt_version=gt_version)
    c = db.get_crop_by_number(cropno)
    labelname, setup, iteration, s1 = get_best_manual(
        c["dataset_id"],
        labelname,
        setup=setup,
        s1=s1,
        training_version=db.training_version)
    path = construct_pred_path(setup,
                               iteration,
                               c,
                               s1,
                               training_version=db.training_version)
    metric_params = dict()
    metric_params["clip_distance"] = clip_distance
    metric_params["tol_distance"] = tol_distance
    scores = dict()
    for metric in EvaluationMetrics:
        specific_params = filter_params(metric_params, metric)
        query = {
            "path": path,
            "dataset": labelname,
            "setup": setup,
            "iteration": iteration,
            "crop": str(cropno),
            "threshold": threshold,
            "metric": metric,
            "metric_params": specific_params
        }
        doc = db.read_evaluation_result(query)
        scores[metric.name] = doc["value"]

    return scores
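
A hedged usage sketch for query_score: the crop number 111 and the label "mito" are illustrative placeholders, and the call only succeeds if matching evaluation results already exist in the database.

# Hypothetical call; crop 111 and label "mito" are assumptions for illustration.
scores = query_score(111, "mito", threshold=127, clip_distance=200, tol_distance=40)
for metric_name, value in scores.items():
    print(metric_name, value)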
Example no. 3
def main(alt_args=None):
    parser = argparse.ArgumentParser("Evaluate predictions")
    parser.add_argument("--setup", type=str, nargs='+', default=None,
                        help="network setup from which to evaluate a prediction, e.g. setup01")
    parser.add_argument("--iteration", type=int, nargs='+', default=None,
                        help="network iteration from which to evaluate prediction, e.g. 725000")
    parser.add_argument("--label", type=str, nargs='+', default=None,
                        help="label for which to evaluate prediction, choices: " + ", ".join(list(hierarchy.keys())))
    parser.add_argument("--crop", type=int, nargs='+', default=None,
                        help="number of crop with annotated groundtruth, e.g. 110")
    parser.add_argument("--threshold", type=int, default=127, nargs='+',
                        help="threshold to apply on distances")
    parser.add_argument("--pred_path", type=str, default=None, nargs='+',
                        help="path of n5 file containing predictions")
    parser.add_argument("--pred_ds", type=str, default=None, nargs='+',
                        help="dataset of the n5 file containing predictions")
    parser.add_argument("--metric", type=str, default=None, help="metric to evaluate",
                        choices=list(em.value for em in EvaluationMetrics), nargs="+")
    parser.add_argument("--clip_distance", type=int, default=200,
                        help="Parameter used for clipped false distances. False distances larger than the value of "
                             "this parameter are reduced to this value.")
    parser.add_argument("--tol_distance", type=int, default=40,
                        help="Parameter used for counting false negatives/positives with a tolerance. Only false "
                             "predictions that are farther than this value from the closest pixel where they would be "
                             "correct are counted.")
    parser.add_argument("--training_version", type=str, default="v0003.2", help="Version of training from which to "
                                                                                "evaluate setup.")
    parser.add_argument("--gt_version", type=str, default="v0003", help="Version of groundtruth to use for evaluation.")
    parser.add_argument("--save", action='store_true',
                        help="save to database and csv file")
    parser.add_argument("--overwrite", action='store_true',
                        help="overwrite existing entries in database and csv")
    parser.add_argument("--s1", action='store_true', help="use s1 standard directory")
    parser.add_argument("--refined", action='store_true', help="use refined predictions")
    parser.add_argument("--dry-run", action='store_true',
                        help="show list of evaluations that would be run with given arguments without compute anything")

    args = parser.parse_args(alt_args)
    db = cosem_db.MongoCosemDB(write_access=True, training_version=args.training_version, gt_version=args.gt_version)
    eval_results_csv_folder = os.path.join(config_loader.get_config()["organelles"]["evaluation_path"],
                                           db.training_version, "evaluation_results")
    csvhandler = cosem_db.CosemCSV(eval_results_csv_folder)
    if args.overwrite and not args.save:
        raise ValueError("Overwriting should only be set if save is also set")
    if args.crop is None:
        crops = db.get_all_validation_crops()
    else:
        crops = []
        for cno in args.crop:
            c = db.get_crop_by_number(cno)
            if c is None:
                raise ValueError("Did not find crop {0:} in database".format(cno))
            crops.append(c)
    if args.metric is None:
        metric = list(em.value for em in EvaluationMetrics)
    else:
        metric = list(always_iterable(args.metric))

    metric_params = dict()
    metric_params['clip_distance'] = args.clip_distance
    metric_params['tol_distance'] = args.tol_distance
    if args.refined:
        assert args.setup is None
        assert args.iteration is None
    else:
        assert args.setup is not None

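    # setup/iteration/label/pred_path/pred_ds/threshold may each be given once or once
    # per validation; repeat_last is expected to keep repeating the last element of the
    # shorter argument lists so that they broadcast to num_validations entries, and each
    # resulting combination is paired with every crop.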
    num_validations = max(len(list(always_iterable(args.setup))), len(list(always_iterable(args.iteration))),
                          len(list(always_iterable(args.label))), len(list(always_iterable(args.pred_path))),
                          len(list(always_iterable(args.pred_ds))), len(list(always_iterable(args.threshold))), 1)
    iterator = itertools.product(zip(range(num_validations), repeat_last(always_iterable(args.setup)),
                                     repeat_last(always_iterable(args.iteration)),
                                     repeat_last(always_iterable(args.label)),
                                     repeat_last(always_iterable(args.pred_path)),
                                     repeat_last(always_iterable(args.pred_ds)),
                                     repeat_last(always_iterable(args.threshold))), always_iterable(crops))

    print("\nWill run the following validations:\n")
    validations = []
    for (valno, setup, iteration, label, pred_path, pred_ds, thr), crop in iterator:
        if pred_ds is not None and label is None:
            raise ValueError("If pred_ds is specified, label can't be autodetected")

        if pred_path is None:
            if iteration is None and not args.refined:
                raise ValueError("Either pred_path or iteration must be specified")
            if args.refined:
                pred_path = construct_refined_path(crop)
            else:
                pred_path = construct_pred_path(setup, iteration, crop, args.s1, training_version=args.training_version)
        if not os.path.exists(pred_path):
            raise ValueError("{0:} not found".format(pred_path))
        if not os.path.exists(os.path.join(pred_path, 'attributes.json')):
            raise RuntimeError("N5 is incompatible with zarr due to missing attributes files. Consider running"
                               " `add_missing_n5_attributes {0:}`".format(pred_path))
        if label is None:
            labels = autodetect_labelnames(pred_path, crop)
        else:
            labels = list(always_iterable(label))
            for ll in labels:
                if ll not in crop_utils.get_all_annotated_labelnames(crop):
                    raise ValueError("Label {0:} not annotated in crop {1:}".format(ll, crop['number']))
        for ll in labels:
            if pred_ds is None:
                ds = ll
            else:
                ds = pred_ds

            if not os.path.exists(os.path.join(pred_path, ds)):
                raise ValueError('{0:} not found'.format(os.path.join(pred_path, ds)))
            if iteration is not None:
                iter = autodetect_iteration(pred_path, ds)
                if iter is not None:
                    if iteration != iter:
                        raise ValueError(
                            "You specified pred_path as well as iteration. The iteration does not match the "
                            "iteration in the attributes of the prediction."
                        )
                else:
                    iter = iteration
            else:
                iter = autodetect_iteration(pred_path, ds)
                if iter is None:
                    raise ValueError(
                        "Please sepcify iteration, it is not contained in the prediction metadata."
                    )
            if setup is None:
                this_setup = autodetect_setup(pred_path, ds)
                if this_setup is None:
                    raise ValueError(
                        "Please specify setup, it is not contained in the prediction metadata."
                    )
            else:
                this_setup = autodetect_setup(pred_path, ds)
                if this_setup is not None:
                    if this_setup != setup:
                        raise ValueError(
                            "The specified setup does not match the setup in the attributes of the prediction."
                        )
                else:
                    this_setup = setup
            if not args.refined and pred_path != construct_pred_path(this_setup, iter, crop, args.s1,
                                                                     training_version=args.training_version):
                warnings.warn(
                    "You specified pred_path as well as setup and the pred_path does not match the standard "
                    "location."
                )
            if args.refined and pred_path != construct_refined_path(crop):
                warnings.warn(
                    "The pred_path you specified does not match the standard location."
                )
            if not os.path.exists(os.path.join(pred_path, ds)):
                raise ValueError('{0:} not found'.format(os.path.join(pred_path, ds)))
            n5 = zarr.open(pred_path, mode="r")
            raw_ds = n5[ds].attrs["raw_ds"]
            parent_path = n5[ds].attrs["raw_data_path"]
            parent_dataset_id = n5[ds].attrs["parent_dataset_id"]
            validations.append([pred_path, ds, this_setup, iter, hierarchy[ll], crop, raw_ds, parent_path,
                                parent_dataset_id, thr])

    tabs = [(pp, d, s, i, ll.labelname, c['number'], r_ds, parent, p_id, t, m, filter_params(metric_params, m)) for
            (pp, d, s, i, ll, c, r_ds, parent, p_id, t), m in itertools.product(validations, metric)]
    print(tabulate.tabulate(tabs, ["Path", "Dataset", "Setup", "Iteration", "Label", "Crop", "Raw Dataset",
                                   "Parent Path", "Parent Id", "Threshold", "Metric", "Metric Params"]))

    if not args.dry_run:
        print("\nRunning Evaluations:")
        for val_params in validations:
            pp, d, s, i, ll, c, r_ds, parent, p_id, t = val_params
            results = run_validation(pp, d, s, i, ll, c, t, metric, metric_params, db, csvhandler, args.save,
                                     args.overwrite, args.refined, gt_version=args.gt_version)
            val_params.append(results)
        print("\nResults Summary:")
        tabs = [(pp, d, s, i, ll.labelname, c['number'], r_ds, parent, p_id, t, m, filter_params(metric_params, m), v) for
                (pp, d, s, i, ll, c, r_ds, parent, p_id, t, r) in validations for m, v in r.items()]
        print(tabulate.tabulate(tabs, ["Path", "Dataset", "Setup", "Iteration", "Label", "Crop",
                                       "Raw Dataset", "Parent Path", "Parent Id", "Threshold", "Metric",
                                       "Metric Params", "Value"]))
Example no. 4
def get_differences(cropno,
                    metrics,
                    domain="setup",
                    threshold=127,
                    clip_distance=200,
                    tol_distance=40,
                    training_version="v0003.2",
                    gt_version="v0003"):
    if not isinstance(metrics, (tuple, list)):
        metrics = [metrics]
    db = cosem_db.MongoCosemDB(training_version=training_version,
                               gt_version=gt_version)

    c = db.get_crop_by_number(str(cropno))
    metric_params = dict()
    metric_params["clip_distance"] = clip_distance
    metric_params["tol_distance"] = tol_distance

    csv_folder = os.path.join(
        config_loader.get_config()["organelles"]["evaluation_path"],
        db.training_version, "manual")

    if domain == "setup":
        csv_file = os.path.join(csv_folder, c["dataset_id"] + "_setup.csv")
        f = open(csv_file, "r")
        fieldnames = ["labelname", "setup", "iteration", "s1"]
    elif domain == "iteration":
        csv_file = os.path.join(csv_folder, c["dataset_id"] + "_iteration.csv")
        f = open(csv_file, "r")
        fieldnames = ["setup", "labelname", "iteration", "s1"]
    else:
        raise ValueError("unknown domain")

    reader = csv.DictReader(f, fieldnames)
    all_manual = []
    for row in reader:
        manual_result = dict()
        manual_result["setup"] = row["setup"]
        manual_result["labelname"] = row["labelname"]
        manual_result["iteration"] = int(row["iteration"])
        manual_result["s1"] = bool(int(row["s1"]))
        all_manual.append(manual_result)
    f.close()
    for manual_result in all_manual:
        if domain == "setup":
            query_setup = None
        else:
            query_setup = manual_result["setup"]

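        # For each evaluation metric, look up its score both for the manually chosen
        # result ("manual") and for the result that is best according to each automatic
        # metric; entries are keyed "<eval_metric>_by_<best_by_metric>".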
        for best_by_metric in metrics + ["manual"]:

            for eval_metric in metrics:

                entry = "{eval_metric:}_by_{best_by_metric:}".format(
                    eval_metric=eval_metric, best_by_metric=best_by_metric)
                if best_by_metric == "manual":
                    manual_result[entry] = query_score(
                        cropno,
                        manual_result["labelname"],
                        setup=query_setup,
                        s1=manual_result["s1"],
                        threshold=threshold,
                        clip_distance=clip_distance,
                        tol_distance=tol_distance,
                        training_version=training_version,
                        gt_version=gt_version)[eval_metric]
                else:
                    specific_params = filter_params(metric_params,
                                                    best_by_metric)
                    best_result = get_best_automatic(
                        db,
                        cropno,
                        manual_result["labelname"],
                        best_by_metric,
                        specific_params,
                        query_setup,
                        threshold=threshold,
                        s1=manual_result["s1"])

                    specific_params = filter_params(metric_params, eval_metric)
                    query = {
                        "path": best_result["path"],
                        "dataset": best_result["dataset"],
                        "setup": best_result["setup"],
                        "iteration": best_result["iteration"],
                        "label": best_result["label"],
                        "crop": best_result["crop"],
                        "threshold": best_result["threshold"],
                        "metric": eval_metric,
                        "metric_params": specific_params
                    }
                    doc = db.read_evaluation_result(query)
                    manual_result[entry] = doc["value"]

    return all_manual
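
A usage sketch for get_differences; the crop number 111 and the metric name "dice" are assumptions for illustration (valid names come from EvaluationMetrics), and the result keys follow the "<eval_metric>_by_<best_by_metric>" pattern built above:

# Hypothetical call comparing manual and automatic model selection for one metric.
rows = get_differences(111, ["dice"], domain="setup")
for row in rows:
    print(row["labelname"], row["dice_by_manual"], row["dice_by_dice"])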
Example no. 5
                        "label": best_result["label"],
                        "crop": best_result["crop"],
                        "threshold": best_result["threshold"],
                        "metric": eval_metric,
                        "metric_params": specific_params
                    }
                    doc = db.read_evaluation_result(query)
                    manual_result[entry] = doc["value"]

    return all_manual


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("label", type=str)
    parser.add_argument("crop", type=int)
    parser.add_argument("metric", type=str)
    parser.add_argument("--setup", type=str, default=None)
    parser.add_argument("--training_version", type=str, default="v0003.2")
    parser.add_argument("--gt_version", type=str, default="v0003")
    args = parser.parse_args()

    db = cosem_db.MongoCosemDB(training_version=args.training_version,
                               gt_version=args.gt_version)
    print(
        get_best_automatic(db,
                           args.crop,
                           args.label,
                           args.metric, {},
                           setup=args.setup))
Example no. 6
def main() -> None:
    parser = argparse.ArgumentParser(
        description="Run a pre-defined comparison for evaluation results in the database.")
    parser.add_argument("comparison",
                        type=str,
                        help="Type of comparison to run",
                        choices=[
                            "4nm-vs-8nm", "s1-vs-sub", "raw-vs-refined",
                            "all-vs-common-vs-single", "metrics",
                            "generalization", "best-4nm", "best-8nm"
                        ])
    parser.add_argument(
        "--metric",
        nargs="+",
        type=str,
        choices=list(em.value
                     for em in segmentation_metrics.EvaluationMetrics) +
        ["manual"],
        help=
        "Metric to use for evaluation. For the metrics comparison, the first one is used "
        "for comparison and the second one is the alternative metric by which to pick the "
        "best result."
    )
    parser.add_argument("--crops",
                        type=int,
                        nargs="*",
                        default=None,
                        help="Crops on which .")
    parser.add_argument("--threshold",
                        type=int,
                        default=127,
                        help="threshold applied on distances for evaluation")
    parser.add_argument("--clip_distance",
                        type=int,
                        default=200,
                        help="Parameter used for clipped false distances "
                        "for relevant metrics.")
    parser.add_argument("--tol_distance",
                        type=int,
                        default=40,
                        help="Parameter used for tolerated false distances "
                        "for relevant metrics.")
    parser.add_argument(
        "--mode",
        type=str,
        choices=["across-setups", "per-setup", "all"],
        help=
        "Mode for some of the comparisons: whether to compare across setups "
        "(`across-setups`) or only between equivalent setups (`per-setup`).",
        default="across-setups")
    parser.add_argument("--test",
                        action="store_true",
                        help="use cross validation for automatic evaluations")
    parser.add_argument("--raw_ds",
                        type=str,
                        help="filter for raw dataset",
                        default="volumes/raw/s0")
    parser.add_argument("--training_version",
                        type=str,
                        default="v0003.2",
                        help="Version of training")
    parser.add_argument("--gt_version",
                        type=str,
                        default="v0003",
                        help="Version of groundtruth")
    parser.add_argument(
        "--save",
        type=functools.partial(type_or_bool, type=str),
        default="False",
        help=
        ("True to save to a standardized csv file, or path to a custom csv file. False for not "
         "saving."))
    parser.add_argument("--check_private_db", action="store_true")
    args = parser.parse_args()
    if args.mode == "all" and args.comparison != "metrics":
        raise ValueError("Mode all can only be used for metrics comparison.")
    db = cosem_db.MongoCosemDB(training_version=args.training_version,
                               gt_version=args.gt_version,
                               write_access=args.check_private_db)
    compare(args.comparison,
            db,
            args.metric,
            crops=args.crops,
            save=args.save,
            tol_distance=args.tol_distance,
            clip_distance=args.clip_distance,
            threshold=args.threshold,
            mode=args.mode,
            raw_ds=args.raw_ds,
            test=args.test)
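
As in Example no. 1, the parser reads sys.argv, so an invocation can be sketched by setting it explicitly; the script filename and the metric names "dice" and "mean_false_distance" are assumptions:

import sys

# Hypothetical invocation: compare two metrics against each other on all
# validation crops and save the result to the standardized csv location.
sys.argv = ["compare_evaluations.py", "metrics",
            "--metric", "dice", "mean_false_distance",
            "--mode", "all", "--save", "True"]
main()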