def label_dist(labels: List[Label],
               completion_min: int = 6,
               dataset: Optional[str] = None,
               gt_version: str = "v0003",
               save: Optional[str] = None) -> Dict[str, Dict[int, int]]:
    """
    Compute label distribution.

    Args:
        labels: List of labels to compute distribution for.
        completion_min: Minimal completion status for a crop from the database to be included in the distribution.
        dataset: Dataset for which to calculate label distribution. If None, calculate across all datasets.
        gt_version: Version of groundtruth for which to accumulate distribution.
        save: File to which to save distributions as json. If None, results won't be saved.

    Returns:
        Dictionary with distributions per label with counts for "positives", "negatives" and the sum of both ("sums").
    """
    db = MongoCosemDB(gt_version=gt_version)
    collection = db.access("crops", db.gt_version)
    db_filter = {"completion": {"$gte": completion_min}}
    if dataset is not None:
        db_filter["dataset_id"] = dataset
    skip = {
        "_id": 0,
        "number": 1,
        "labels": 1,
        "parent": 1,
        "dimensions": 1,
        "dataset_id": 1
    }
    positives = dict()
    negatives = dict()
    for ll in labels:
        positives[int(ll.labelid[0])] = 0
        negatives[int(ll.labelid[0])] = 0
    for crop in collection.find(db_filter, skip):
        pos, neg = one_crop(crop, labels, db.gt_version)
        for ll, c in pos.items():
            positives[ll] += int(c)
        for ll, c in neg.items():
            negatives[ll] += int(c)
    # positives and negatives share the same keys, so sum over those (robust even if no crop matched the filter)
    sums = dict()
    for ll in positives.keys():
        sums[ll] = negatives[ll] + positives[ll]
    stats = dict()
    stats["positives"] = positives
    stats["negatives"] = negatives
    stats["sums"] = sums
    if save is not None:
        if not save.endswith(".json"):
            save += ".json"
        with open(save, "w") as f:
            json.dump(stats, f)
    return stats
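# Minimal usage sketch for label_dist: compute the distribution for a single dataset and
# write it to disk. The label choices ("mito", "er"), the dataset id "jrc_hela-2" and the
# availability of the `hierarchy` mapping in this scope are illustrative assumptions.
def _example_label_dist() -> Dict[str, Dict[int, int]]:
    example_labels = [hierarchy["mito"], hierarchy["er"]]  # assumes the package's hierarchy dict is importable here
    stats = label_dist(example_labels,
                       completion_min=6,
                       dataset="jrc_hela-2",
                       save="label_dist_jrc_hela-2.json")
    for labelid, count in stats["sums"].items():
        print(labelid, count)
    return stats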
def transfer(training_version="v0003.2"): db = MongoCosemDB(training_version=training_version) eval_col = db.access("evaluation", db.training_version) eval_results_csv_folder = os.path.join( config_loader.get_config()["organelles"]["evaluation_path"], training_version, "evaluation_results") csv_d = CosemCSV(eval_results_csv_folder) for l in hierarchy.keys(): csv_d.erase(l) for db_entry in eval_col.find(): csv_d.write_evaluation_result(db_entry)
def get_refined_comparisons( db: cosem_db.MongoCosemDB, cropno: Union[None, str, int, Sequence[Union[str, int]]] = None ) -> List[Dict[str, Any]]: """ Get list of queries for predictions that have been refined (as read from csv file) Args: db: Database with crop information. cropno: Specific crop number or list of crop numbers that should be included in queries. Returns: List of queries for which refined predictions exist. """ csv_folder_refined = os.path.join( config_loader.get_config()["organelles"]["evaluation_path"], db.training_version, "refined") # get list of csv files relevant for crops if cropno is None: csv_result_files = os.listdir(csv_folder_refined) else: if isinstance(cropno, str) or isinstance(cropno, int): cropno = [cropno] csv_result_files = [] for cno in cropno: crop = db.get_crop_by_number(cno) csv_result_files.append( os.path.join(csv_folder_refined, crop["dataset_id"] + "_setup.csv")) # collect entries from those csv files queries = [] for csv_f in csv_result_files: f = open(os.path.join(csv_folder_refined, csv_f), "r") fieldnames = ["setup", "labelname", "iteration", "raw_dataset"] cell_id = re.split("_setup.csv", csv_f)[0] crop = db.get_validation_crop_by_cell_id(cell_id) reader = csv.DictReader(f, fieldnames) for row in reader: # only consider results that we can evaluate automatically (actually contained in the crop) if any(lbl in get_label_ids_by_category(crop, "present_annotated") for lbl in hierarchy[row["labelname"]].labelid): query = { "label": row["labelname"], "raw_dataset": row["raw_dataset"], "setup": row["setup"], "crop": crop["number"], "iteration": int(row["iteration"]) } queries.append(query) return queries
def compare_generalization(
        db: cosem_db.MongoCosemDB,
        metric: str,
        crops: Optional[Sequence[Union[str, int]]] = None,
        tol_distance: int = 40,
        clip_distance: int = 200,
        threshold: int = 200,
        raw_ds: Union[None, str, Sequence[str]] = "volumes/raw/s0",
) -> List[List[Optional[Dict[str, Any]]]]:
    """
    Evaluate generalization experiments for er, mito, nucleus and plasma_membrane.

    Args:
        db: Database with crop information and evaluation results.
        metric: Metric to use for comparison.
        crops: List of crops to run comparison on. If None, all validation crops are used.
        tol_distance: Tolerance distance when using a metric with tolerance distance, otherwise not used.
        clip_distance: Clip distance when using a metric with clip distance, otherwise not used.
        threshold: Threshold to have been applied on top of raw predictions.
        raw_ds: Raw dataset to run prediction on.

    Returns:
        List of best results for setups involved in generalization experiments. Each result is a list with just
        one dictionary.
    """
    setups = ["setup03", "setup61", "setup62", "setup63", "setup64"]
    labels = ["er", "mito", "nucleus", "plasma_membrane"]
    if crops is None:
        crops = [c["number"] for c in db.get_all_validation_crops()]
    results = []
    for lbl in labels:
        for setup in setups:
            for cropno in crops:
                if crop_utils.check_label_in_crop(
                        hierarchy.hierarchy[lbl], db.get_crop_by_number(cropno)):
                    results.append([
                        analyze_evals.best_result(db,
                                                  lbl,
                                                  [setup],
                                                  cropno,
                                                  metric,
                                                  raw_ds=raw_ds,
                                                  tol_distance=tol_distance,
                                                  clip_distance=clip_distance,
                                                  threshold=threshold,
                                                  test=False)
                    ])
    return results
def _get_iteration_queries(cropno: Sequence[Union[int, str]], db: cosem_db.MongoCosemDB) -> List[Dict[str, str]]: csv_folder_manual = os.path.join( config_loader.get_config()["organelles"]["evaluation_path"], db.training_version, "manual") csv_result_files = _get_csv_files(csv_folder_manual, "iteration", cropno, db) iteration_queries = [] for csv_f in csv_result_files: f = open(os.path.join(csv_folder_manual, csv_f), "r") fieldnames = ["setup", "labelname", "iteration", "raw_dataset"] cell_id = re.split("_(setup|iteration).csv", csv_f)[0] crop = db.get_validation_crop_by_cell_id(cell_id) reader = csv.DictReader(f, fieldnames) for row in reader: if any(lbl in get_label_ids_by_category(crop, "present_annotated") for lbl in hierarchy[row["labelname"]].labelid): query = { "label": row["labelname"], "raw_dataset": row["raw_dataset"], "setups": [row["setup"]], "crop": crop["number"] } iteration_queries.append(query) return iteration_queries
def plot_val_all_labels(db: MongoCosemDB, setup: str, path: str, threshold: int = 127, filetype: str = "pdf"): """ Plot validation graphs for all labels corresponding to a specific setup. Will be saved to files in `path`. Args: db: Database with crop information and evaluation results. setup: Setup to plot validation results for. path: Path in which to save all the plots. threshold: Threshold to be applied on top of raw predictions to generate binary segmentations for evaluation. filetype: Filetype for saving plots. """ valcrops = db.get_all_validation_crops() labels = get_unet_setup(setup).labels for lbl in labels: in_crop = [check_label_in_crop(lbl, crop) for crop in valcrops] if any(in_crop): file = os.path.join( path, "{label:}_{setup:}.{filetype:}".format(label=lbl.labelname, setup=setup, filetype=filetype)) plot_val(db, setup, lbl.labelname, file, threshold=threshold)
def max_evaluated_iteration(query: Dict[str, Any],
                            db: cosem_db.MongoCosemDB) -> int:
    """
    Find the maximum iteration that is found in the database for the given query.

    Args:
        query: Dictionary defining the query for which to find the maximum iteration.
        db: Database with evaluation results.

    Returns:
        Maximum iteration for `query`.
    """
    col = db.access("evaluation", db.training_version)
    max_it = col.aggregate([{
        "$match": query
    }, {
        "$sort": {
            "iteration": -1
        }
    }, {
        "$limit": 1
    }, {
        "$project": {
            "iteration": 1,
            "_id": 0
        }
    }])
    max_it = list(max_it)[0]
    return max_it["iteration"]
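# Illustrative sketch of a typical query passed to max_evaluated_iteration. The concrete
# setup, crop, label and raw_dataset values below are assumptions for demonstration only;
# any fields present in the evaluation documents can be used to restrict the match.
def _example_max_evaluated_iteration(db: cosem_db.MongoCosemDB) -> int:
    query = {
        "setup": "setup01",            # hypothetical training setup
        "crop": "113",                 # hypothetical validation crop number
        "label": "mito",
        "raw_dataset": "volumes/raw/s0",
        "metric": "dice",
        "refined": False
    }
    return max_evaluated_iteration(query, db)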
def _best_manual( db: cosem_db.MongoCosemDB, label: str, setups: Sequence[str], cropno: Union[int, str], raw_ds: Optional[Sequence[str]] = None ) -> Optional[Dict[str, Union[str, int, bool]]]: # read csv file containing results of manual evaluation, first for best iteration c = db.get_crop_by_number(str(cropno)) csv_folder_manual = os.path.join( config_loader.get_config()["organelles"]["evaluation_path"], db.training_version, "manual") csv_file_iterations = open( os.path.join(csv_folder_manual, c["dataset_id"] + "_iteration.csv"), "r") fieldnames = ["setup", "labelname", "iteration", "raw_dataset"] reader = csv.DictReader(csv_file_iterations, fieldnames) # look for all possible matches with the given query best_manuals = [] for row in reader: if row["labelname"] == label and row["setup"] in setups: if raw_ds is None or row["raw_dataset"] in raw_ds: manual_result = { "setup": row["setup"], "label": row["labelname"], "iteration": int(row["iteration"]), "raw_dataset": row["raw_dataset"], "crop": str(cropno), "metric": "manual" } best_manuals.append(manual_result) if len( best_manuals ) == 0: # no manual evaluations with the given constraints were done return None elif len(best_manuals ) == 1: # if there's only one match it has to be the best one return best_manuals[0] else: # if there's several matches check the setup results for overall best # read csv file containing results of manual evaluations, now for best setup per label/crop csv_file_setups = open( os.path.join(csv_folder_manual, c["dataset_id"] + "_setup.csv"), "r") reader = csv.DictReader(csv_file_setups, fieldnames) for row in reader: if row["labelname"] == label and row["setup"] in setups: if raw_ds is None or row["raw_dataset"] in raw_ds: manual_result_best = { "setup": row["setup"], "label": row["labelname"], "iteration": int(row["iteration"]), "raw_dataset": row["raw_dataset"], "crop": str(cropno), "metric": "manual", "refined": False } return manual_result_best return None
def get_diff(db: cosem_db.MongoCosemDB, label: str, setups: Union[str, Sequence[str]], cropno: str, metric_best: str, metric_compare: str, raw_ds: Optional[str] = None, tol_distance: int = 40, clip_distance: int = 200, threshold: int = 127, test: bool = False) -> Dict[str, Any]: """ Compare two metrics by measuring performance using `metric_compare` but picking the best configuration using metric `metric_best`. Args: db: Database with evaluation results and crop information label: label for which to complete this comparison. setups: training setup for which or training setups across which to determine best configuration. cropno: crops to analyze to determine best configuration and measure performance metric_best: Metric to use for finding best configuration (iteration/iteration+setup) metric_compare: Metric to use for reporting performance using the best configuration determined by `metric_best`. raw_ds: raw datasets from which predictions could be pulled for the evaluations considered here tol_distance: tolerance distance when using a metric with tolerance distance, otherwise not used clip_distance: clip distance when using a metric with clip distance, otherwise not used threshold: threshold applied on top of distance predictions to generate binary segmentation test: whether to run in test mode Returns: Dictionary with evaluation result measured by `metric_compare` but optimized using `metric_best`. """ best_config = best_result(db, label, setups, cropno, metric_best, raw_ds=raw_ds, tol_distance=tol_distance, clip_distance=clip_distance, threshold=threshold, test=test) query_metric2 = best_config.copy() query_metric2["metric"] = metric_compare query_metric2["metric_params"] = filter_params( { "clip_distance": clip_distance, "tol_distance": tol_distance }, metric_compare) if best_config["metric"] != "manual": try: query_metric2.pop("value") query_metric2.pop("_id") except KeyError: query_metric2["value"] = None return query_metric2 compare_setup = db.find(query_metric2)[0] return compare_setup
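# Sketch of a metric-consistency check using get_diff: pick the configuration that is best
# according to one metric (here dice), then report its score under another metric (here
# mean_false_distance). The label, setup and crop values are illustrative assumptions.
def _example_get_diff(db: cosem_db.MongoCosemDB) -> Dict[str, Any]:
    return get_diff(db,
                    label="er",
                    setups=["setup03"],
                    cropno="111",
                    metric_best="dice",
                    metric_compare="mean_false_distance")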
def compare_refined( db: cosem_db.MongoCosemDB, metric: str, queries: Sequence[Dict[str, Any]], tol_distance: int = 40, clip_distance: int = 200, threshold: int = 127) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]: """ For given queries read corresponding refined and unrefined results from the database for the given metric. Args: db: Database with crop information and evaluation results. metric: Metrics to use for comparison. queries: List of queries for which to compare results for refinements. tol_distance: tolerance distance when using a metric with tolerance distance, otherwise not used clip_distance: clip distance when using a metric with clip distance, otherwise not used threshold: threshold applied on top of distance predictions to generate binary segmentation Returns: List of tuples with evaluation results. The first entry will be the result before refinements, the second after refinements. """ comparisons = [] for qu in queries: qu["metric"] = metric qu["metric_params"] = filter_params( { "clip_distance": clip_distance, "tol_distance": tol_distance }, metric) qu["refined"] = True refined = db.find(qu) assert len(refined) == 1 refined = refined[0] qu["refined"] = False qu["threshold"] = threshold not_refined = db.find(qu) if len(not_refined) != 1: print([x for x in not_refined]) assert len(not_refined) == 1 not_refined = not_refined[0] comparisons.append((not_refined, refined)) return comparisons
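# Sketch showing how the refinement comparison is typically assembled: collect the queries
# for which refined predictions exist, then pair each refined result with its unrefined
# counterpart for a given metric. Assumes get_refined_comparisons and compare_refined are
# importable in the same scope and that `db` is a MongoCosemDB instance.
def _example_compare_refined(db: cosem_db.MongoCosemDB):
    queries = get_refined_comparisons(db)
    pairs = compare_refined(db, "dice", queries, threshold=127)
    for not_refined, refined in pairs:
        print(not_refined["label"], not_refined["value"], "->", refined["value"])
    return pairs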
def _get_csv_files(csv_folder_manual: str, domain: str, cropno: Sequence[Union[int, str]], db: cosem_db.MongoCosemDB) -> List[str]: if cropno is None: csv_result_files = os.listdir(csv_folder_manual) csv_result_files = [ fn for fn in csv_result_files if fn.endswith("_{0:}.csv".format(domain)) ] else: csv_result_files = [] for cno in cropno: crop = db.get_crop_by_number(cno) csv_result_files.append( os.path.join(csv_folder_manual, crop["dataset_id"] + "_{0:}.csv".format(domain))) return csv_result_files
def _get_setup_queries(
        cropno: Sequence[Union[int, str]],
        db: cosem_db.MongoCosemDB) -> List[Dict[str, Union[str, Sequence[str]]]]:
    csv_folder_manual = os.path.join(
        config_loader.get_config()["organelles"]["evaluation_path"],
        db.training_version, "manual")
    csv_result_files = _get_csv_files(csv_folder_manual, "setup", cropno, db)
    setup_queries = []
    for csv_f in csv_result_files:
        f = open(os.path.join(csv_folder_manual, csv_f), "r")
        fieldnames = ["setup", "labelname", "iteration", "raw_dataset"]
        cell_id = re.split("_(setup|iteration).csv", csv_f)[0]
        crop = db.get_validation_crop_by_cell_id(cell_id)
        reader = csv.DictReader(f, fieldnames)
        for row in reader:
            if any(lbl in get_label_ids_by_category(crop, "present_annotated")
                   for lbl in hierarchy[row["labelname"]].labelid):
                # find the csv file with the list of setups compared for each label (4nm or 8nm)
                if row["raw_dataset"] == "volumes/raw/s0":
                    ff = open(
                        os.path.join(csv_folder_manual,
                                     "compared_4nm_setups.csv"), "r")
                elif row["raw_dataset"] == "volumes/subsampled/raw/0" or row[
                        "raw_dataset"] == "volumes/raw/s1":
                    ff = open(
                        os.path.join(csv_folder_manual,
                                     "compared_8nm_setups.csv"), "r")
                else:
                    raise ValueError(
                        "The raw_dataset {0:} is not recognized.".format(
                            row["raw_dataset"]))
                # get that list of compared setups from the csv file
                compare_reader = csv.reader(ff)
                for compare_row in compare_reader:
                    if compare_row[0] == row["labelname"]:
                        setups = compare_row[1:]
                        break
                ff.close()
                # collect result
                query = {
                    "label": row["labelname"],
                    "raw_dataset": row["raw_dataset"],
                    "setups": setups,
                    "crop": crop["number"]
                }
                setup_queries.append(query)
        f.close()
    return setup_queries
def above_threshold(db: cosem_db.MongoCosemDB, query: Dict[str, Any], by: int = 500000) -> bool: """ Check whether predictions for the given `query` are ever above threshold in the validation crop by iteration `by`. Args: db: Database with evaluation results. query: Dictionary defining query for which to check for whether they're above threshold. by: Only check evaluation results up to this iteration (inclusive). Returns: True if there are any results above threshold by iteration `by`. False otherwise. """ qy = query.copy() qy["metric"] = "mean_false_distance" qy["value"] = {"$gt": 0} qy["iteration"] = {"$mod": [25000, 0], "$lte": by} eval_col = db.access("evaluation", db.training_version) return not (eval_col.find_one(qy) is None)
def _best_automatic(db: cosem_db.MongoCosemDB, label: str, setups: Sequence[str], cropno: Union[Sequence[str], Sequence[int]], metric: str, raw_ds: Optional[Sequence[str]] = None, tol_distance: int = 40, clip_distance: int = 200, threshold: int = 127, test: bool = False) -> Dict[str, Any]: metric_params = dict() metric_params["clip_distance"] = clip_distance metric_params["tol_distance"] = tol_distance filtered_params = filter_params(metric_params, metric) setups = [ setup for setup in setups if label in [lbl.labelname for lbl in autodiscover_labels(setup)] ] # in test mode the remaining validation crops are used for determining best configuration if test: cropnos_query = [ crop["number"] for crop in db.get_all_validation_crops() ] for cno in cropno: cropnos_query.pop(cropnos_query.index(str(cno))) cropnos_query = [ cno for cno in cropnos_query if check_label_in_crop( hierarchy[label], db.get_crop_by_number(cno)) ] else: cropnos_query = cropno if len(cropnos_query) == 0: # if no crops remain return without result final = { "value": None, "iteration": None, "label": label, "metric": metric, "metric_params": filtered_params, "refined": False, "threshold": threshold, "setup": setups[0] if len(setups) == 1 else None, "crop": cropno[0] if len(cropno) == 1 else { "$in": cropno } } if raw_ds is not None: final["raw_dataset"] = raw_ds[0] if len(raw_ds) == 1 else { "$in": raw_ds } return final # find max iterations and put corresponding conditions in query conditions = [] for setup in setups: # several setups if both iteration and setup are being optimized ("across_setups") max_its = [] for cno in cropnos_query: maxit_query = { "label": label, "crop": str(cno), "threshold": threshold, "refined": False, "setup": setup } if raw_ds is not None: maxit_query["raw_dataset"] = {"$in": raw_ds} maxit, valid = max_iteration_for_analysis(maxit_query, db) max_its.append(maxit) conditions.append({ "setup": setup, "iteration": { "$lte": max(max_its) } }) if len(conditions) > 1: match_query = {"$or": conditions} else: match_query = conditions[0] # prepare aggregation of best configuration on the database aggregator = [] # match match_query.update({ "crop": { "$in": cropnos_query }, "label": label, "metric": metric, "metric_params": filtered_params, "threshold": threshold, "value": { "$ne": np.nan }, "refined": False }) if raw_ds is not None: match_query["raw_dataset"] = {"$in": raw_ds} aggregator.append({"$match": match_query}) # for each combination of setup and iteration, and raw_dataset if relevant, average across the matched results crossval_group = { "_id": { "setup": "$setup", "iteration": "$iteration" }, "score": { "$avg": "$value" } } if raw_ds is not None: crossval_group["_id"]["raw_dataset"] = "$raw_dataset" aggregator.append({"$group": crossval_group}) # sort (descending/ascending determined by metric) by averaged score aggregator.append( {"$sort": { "score": sorting(metric), "_id.iteration": 1 }}) # only need max so limit results to one (mongodb can take advantage of this for sort) aggregator.append({"$limit": 1}) # extract setup and iteration, and raw_dataset if relevant, in the end projection = { "setup": "$_id.setup", "iteration": "$_id.iteration", "_id": 0 } if raw_ds is not None: projection["raw_dataset"] = "$_id.raw_dataset" aggregator.append({"$project": projection}) # run the aggregation on the evaluation database col = db.access("evaluation", db.training_version) best_config = list(col.aggregate(aggregator)) if len(best_config) == 0: # if no results are found, return at this point final = 
match_query.copy() # final result should have actual cropno if len(cropno) == 1: final["crop"] = cropno[0] else: final["crop"] = {"$in": cropno} final.update({"setup": None, "value": None, "iteration": None}) return final else: best_config = best_config[0] all_best = [] for cno in cropno: query_best = { "label": label, "crop": str(cno), "metric": metric, "setup": best_config["setup"], "metric_params": filtered_params, "threshold": threshold, "iteration": best_config["iteration"], "refined": False } if raw_ds is not None: query_best["raw_dataset"] = best_config["raw_dataset"] best_this = db.find(query_best) if len(best_this) != 1: print("query:", query_best) print("results:", list(best_this)) assert len(best_this) == 1, "Got more than one result for best" all_best.append(best_this[0]) # average results for the case of several crops final = dict() final["value"] = np.mean([ab["value"] for ab in all_best]) # assemble all entries that are shared by the best result for each crop all_keys = set( all_best[0].keys()).intersection(*(d.keys() for d in all_best)) - {"value"} for k in all_keys: if all([ab[k] == all_best[0][k] for ab in all_best]): final[k] = all_best[0][k] return final
def convergence_iteration( query: Dict[str, Any], db: cosem_db.MongoCosemDB, check_evals_complete: bool = True) -> Tuple[int, int]: """ Find the first iteration that meets the convergence criterion in 25k intervals. Convergence criterion is that both mean_false_distance and dice score indicate a decreasing performance for two consecutive evaluation points. If predictions don't produce above threshold segmentations by 500k iterations no higher iterations are considered. Args: query: Dictionary specifying which set of configuration to consider for the maximum iteration. This will typically contain keys for setups, label and crop. db: Database containing the evaluation results. check_evals_complete: Whether to first check whether the considered evaluations are consistent across the queries (i.e. same for all crops/labels/raw_datasets within one setup, at least to 500k, if above threshold by 500k at least to 700k). Should generally be set to True unless this has already been checked. Returns: The converged or maximum evaluated iteration and a flag indicating whether this represents a converged training. 0 for not converged training, 1 for converged trainings, 2 for trainings that have not reached above threshold predictions by 500k iterations, 3 for trainings that have not reached the convergence criterion but 2,000,000 iterations. Raises: ValueError if no evaluations are found for given query. """ query["iteration"] = {"$mod": [25000, 0]} metrics = ("dice", "mean_false_distance") col = db.access("evaluation", db.training_version) # check whether anything has been evaluated for this query type query_any = query.copy() query_any["metric"] = {"$in": metrics} if col.find_one(query_any) is None: raise ValueError("No evaluations found for query {0:}".format(query)) if check_evals_complete: if not check_completeness(db, spec_query=query.copy()): return max_evaluated_iteration(query, db), 0 if not above_threshold(db, query): return 500000, 2 # get results and sort by iteration results = [] for met in metrics: qy = query.copy() qy["metric"] = met qy["iteration"]["$lte"] = 2000000 results.append( list(col.aggregate([{ "$match": qy }, { "$sort": { "iteration": 1 } }]))) # check for convergence criterion for k in range(2, len(results[0])): this_one = [ False, ] * len(metrics) for m_no, met in enumerate(metrics): if np.isnan(results[m_no][k]["value"]) and np.isnan( results[m_no][k - 1]["value"]) and not np.isnan( results[m_no][k - 2]["value"]): this_one[m_no] = True else: if sorting(EvaluationMetrics[met]) == -1: if results[m_no][k]["value"] <= results[m_no][ k - 1]["value"] < results[m_no][k - 2]["value"]: this_one[m_no] = True elif (np.isnan(results[m_no][k]["value"]) and not np.isnan(results[m_no][k-1]["value"]) and not \ np.isnan(results[m_no][k-2]["value"]) and results[m_no][k-1]["value"] < results[m_no][ k-2]["value"]): this_one[m_no] = True else: if results[m_no][k]["value"] >= results[m_no][ k - 1]["value"] > results[m_no][k - 2]["value"]: this_one[m_no] = True elif (np.isnan(results[m_no][k]["value"]) and not np.isnan(results[m_no][k-1]["value"]) and not \ np.isnan(results[m_no][k-2]["value"]) and results[m_no][k-1]["value"] > results[m_no][ k-2]["value"]): this_one[m_no] = True if all(this_one): return results[0][k]["iteration"], 1 if max_evaluated_iteration(query, db) >= 2000000: return 2000000, 3 return results[0][-1]["iteration"], 0
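# Sketch of how the convergence result is typically consumed: the returned flag encodes why
# the reported iteration is the last one considered. The query values (setup, crop, label,
# raw_dataset) are illustrative assumptions; the fields mirror those used elsewhere in this
# module when calling convergence_iteration.
def _example_convergence_iteration(db: cosem_db.MongoCosemDB) -> None:
    query = {
        "setup": "setup01",            # hypothetical training setup
        "crop": "113",                 # hypothetical validation crop number
        "label": "mito",
        "raw_dataset": "volumes/raw/s0",
        "threshold": 127,
        "refined": False
    }
    iteration, flag = convergence_iteration(query, db)
    reasons = {
        0: "not converged",
        1: "converged",
        2: "not above threshold by 500k iterations",
        3: "reached 2,000,000 iterations without meeting the criterion"
    }
    print(iteration, reasons[flag])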
def check_completeness(db: cosem_db.MongoCosemDB, setup: Optional[str] = None, metric_params: Optional[Dict[str, int]] = None, threshold: int = None, spec_query: Optional[Dict[str, Any]] = None) -> bool: """ Check whether for the given configuration each relevant label/raw_dataset is evaluated for all metrics for the same iterations (in 25k intervals, starting from 25k). Args: db: Database with crop information and evaluation results. setup: Network setup to check. metric_params: Dictionary with metric parameters. threshold: Value at which predictions were thresholded for evaluation. spec_query: Alternatively to specifying these arguments they can be fed in as a dictionary. Returns: False if any evaluations are missing/inconsistent. Otherwise, True. """ if spec_query is None: spec_query = dict() if "refined" in spec_query: if spec_query["refined"]: logging.info("Check for refined predictions not necessary") return True if setup is None: try: setup = spec_query["setup"] except KeyError: raise ValueError( "setup needs to be specified as kwarg or in spec_query") if threshold is None: try: threshold = spec_query["threshold"] except KeyError: raise ValueError( "threshold needs to be specified as kwarg or in spec_query") if metric_params is None: try: metric_params = spec_query["metric_params"] except KeyError: logging.warning( "metric_params not specified as kwarg or spec_query, defaulting to tol_distance=40," "clip_distance=200") metric_params = {"clip_distance": 200, "tol_distance": 40} if "raw_dataset" in spec_query: if isinstance(spec_query["raw_dataset"], dict): try: raw_datasets = spec_query["raw_dataset"]["$in"] except KeyError: raise NotImplementedError( "don't know how to do check with query {0:} for raw_dataset" .format(spec_query["raw_dataset"])) else: raw_datasets = [ spec_query["raw_dataset"], ] else: raw_datasets = autodiscover_raw_datasets(setup) if "crop" in spec_query: if not (isinstance(spec_query["crop"], int) or isinstance(spec_query["crop"], str)): raise NotImplementedError( "can't check query with complicated query for crop") else: spec_query["crop"] = str(spec_query["crop"]) if "label" in spec_query and "crop" in spec_query: label_to_cropnos = { spec_query["label"]: [ spec_query["crop"], ] } else: label_to_cropnos = autodiscover_label_to_crops(setup, db) if "label" in spec_query: label_to_cropnos = dict((k, v) for k, v in label_to_cropnos.items() if k in spec_query["label"]) if "crop" in spec_query: for k, v in label_to_cropnos: new_v = [vv for vv in v if vv in [spec_query["crop"]]] if len(new_v) > 0: label_to_cropnos[k] = new_v else: del label_to_cropnos[k] if len(label_to_cropnos) == 0: return True eval_col = db.access("evaluation", db.training_version) will_return = True iterations_col = [] for met in segmentation_metrics.EvaluationMetrics: met_specific_params_nested = segmentation_metrics.filter_params( metric_params, met) if met_specific_params_nested: for k, v in met_specific_params_nested.items(): if not isinstance(v, collections.Iterable): met_specific_params_nested[k] = [v] met_specific_params_it = list( itertools.chain( *[[{ k: vv } for vv in v] for k, v in met_specific_params_nested.items()])) else: met_specific_params_it = [ dict(), ] for met_specific_params in met_specific_params_it: for lblname, cropnos in label_to_cropnos.items(): for cropno in cropnos: for raw_ds in raw_datasets: query = { "setup": setup, "raw_dataset": raw_ds, "crop": cropno, "label": lblname, "threshold": threshold, "refined": False, "iteration": { "$mod": [25000, 0] }, "metric": met, 
"metric_params": met_specific_params } iterations = list( eval_col.aggregate([{ "$match": query }, { "$sort": { "iteration": 1 } }, { "$project": { "iteration": True, "_id": False } }])) iterations_col.append( [it["iteration"] for it in iterations]) if len(iterations_col) > 1: if iterations_col[-1] != iterations_col[-2]: print( "Results for query {0:} not matching: {1:}" .format(query, iterations_col[-1])) will_return = False if not iterations_col[-1] == list( range(25000, iterations_col[-1][-1] + 1, 25000)): print("Missing checkpoints, found: {0:}".format(iterations_col[-1])) will_return = False if not iterations_col[-1][-1] >= 500000: print("Not evaluated to 500000 iterations.") will_return = False # check till 700k if pos results until 500k if will_return: for lblname, cropnos in label_to_cropnos.items(): for cropno in cropnos: for raw_ds in raw_datasets: query = { "setup": setup, "raw_dataset": raw_ds, "crop": cropno, "label": lblname, "threshold": threshold, "refined": False } # if above threshold results are found by 500k iterations, network should be evaluated until at # least 700k iterations if above_threshold(db, query): query[ "iteration"] = 700000 # if 700k exists the ones in between exist (checked abvoe) if eval_col.find_one(query) is None: print( "For query {0:}, above threshold results are found by 500k iterations but network " "isn't evaluated to at least 700k".format( query)) will_return = False return will_return
def main() -> None:
    main_parser = argparse.ArgumentParser("Plot validation graphs")
    parser = main_parser.add_subparsers(dest="script", help="")

    all_parser = parser.add_parser(
        "all_setups", help="Validation graphs for all default setups.")
    all_parser.add_argument(
        "--threshold",
        type=int,
        default=127,
        help=("Threshold to be applied on top of raw predictions to generate "
              "binary segmentations for evaluation."))
    all_parser.add_argument("--path",
                            type=str,
                            default='.',
                            help="Path to save validation graphs to.")
    all_parser.add_argument("--filetype",
                            type=str,
                            default='pdf',
                            help="Filetype for validation plots.")
    all_parser.add_argument(
        "--training_version",
        type=str,
        default="v0003.2",
        help="Version of training for which to plot validation graphs.")
    all_parser.add_argument("--gt_version",
                            type=str,
                            default="v0003",
                            help="Version of groundtruth to consider.")

    setup_parser = parser.add_parser(
        "all_labels",
        help="Validation graphs for all labels of a specific setup.")
    setup_parser.add_argument("setup",
                              type=str,
                              help="Setup to make validation graphs for.")
    setup_parser.add_argument("--path",
                              type=str,
                              default='.',
                              help="Path to save validation graphs to.")
    setup_parser.add_argument("--filetype",
                              type=str,
                              default="pdf",
                              help="Filetype for validation plots.")
    setup_parser.add_argument(
        "--threshold",
        type=int,
        default=127,
        help=("Threshold to be applied on top of raw predictions to generate "
              "binary segmentations for evaluation."))
    setup_parser.add_argument(
        "--training_version",
        type=str,
        default="v0003.2",
        help="Version of training for which to plot validation graphs.")
    setup_parser.add_argument("--gt_version",
                              type=str,
                              default="v0003",
                              help="Version of groundtruth to consider.")

    single_parser = parser.add_parser(
        "single", help="Validation graph for a specific setup and label.")
    single_parser.add_argument("setup",
                               type=str,
                               help="Setup to make validation graph for.")
    single_parser.add_argument("label",
                               type=str,
                               help="Label to make validation graph for.")
    single_parser.add_argument(
        "--file",
        type=str,
        help="File to save validation graph to. (Full path)")
    single_parser.add_argument(
        "--threshold",
        type=int,
        default=127,
        help=("Threshold to be applied on top of raw predictions to generate "
              "binary segmentations for evaluation."))
    single_parser.add_argument(
        "--training_version",
        type=str,
        default="v0003.2",
        help="Version of training for which to plot validation graphs.")
    single_parser.add_argument("--gt_version",
                               type=str,
                               default="v0003",
                               help="Version of groundtruth to consider.")

    args = main_parser.parse_args()
    db = MongoCosemDB(training_version=args.training_version,
                      gt_version=args.gt_version)
    if args.script == "all_setups":
        plot_val_all_setups(db,
                            args.path,
                            threshold=args.threshold,
                            filetype=args.filetype)
    elif args.script == "all_labels":
        plot_val_all_labels(db,
                            args.setup,
                            args.path,
                            threshold=args.threshold,
                            filetype=args.filetype)
    else:
        plot_val(db,
                 args.setup,
                 args.label,
                 args.file,
                 threshold=args.threshold)
def plot_val(db: MongoCosemDB, setup: str, labelname: str, file: Optional[str], threshold: int = 127) -> None: """ Plot validation graph for a specific setup and label. Can be saved by specifying a `file`. Args: db: Database with crop information and evaluation results. setup: Setup to plot validation results for. labelname: Label to plot validation results for. file: File to save validation plot to, if None show plot instead. threshold: Threshold to be applied on top of raw predictions to generate binary segmentations for evaluation. Returns: None. """ print(setup, labelname) valcrops = db.get_all_validation_crops() if len(valcrops) != 4: raise NotImplementedError( "Number of validation crops has changed so plotting layout has to be updated" ) if detect_8nm(setup): raw_datasets = ["volumes/subsampled/raw/0", "volumes/raw/s1"] else: raw_datasets = ["volumes/raw/s0"] col = db.access("evaluation", db.training_version) # query all relevant results query = { "setup": setup, "label": labelname, "refined": False, "iteration": { "$mod": [25000, 0] }, "threshold": threshold } results = dict() max_its = dict() for crop in valcrops: query["crop"] = crop["number"] if col.find_one(query) is None: continue results[crop["number"]] = dict() max_its[crop["number"]] = dict() for raw_ds in raw_datasets: query["raw_dataset"] = raw_ds results[crop["number"]][raw_ds] = dict() max_its[crop["number"]][raw_ds] = dict() max_it_actual = convergence_iteration(query, db) max_it_min700 = max_iteration_for_analysis(query, db, conv_it=max_it_actual) max_its[crop["number"]][raw_ds]["actual"] = max_it_actual max_its[crop["number"]][raw_ds]["min700"] = max_it_min700 for metric in ["dice", "mean_false_distance"]: query["metric"] = metric col = db.access("evaluation", db.training_version) scores = list( col.aggregate([{ "$match": query }, { "$sort": { "iteration": 1 } }, { "$project": { "iteration": 1, "_id": 0, "value": 1 } }])) results[crop["number"]][raw_ds][metric] = scores colors = {"dice": "tab:green", "mean_false_distance": "tab:blue"} fig, axs = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(30, 15)) if len(raw_datasets) > 1: plt.plot([], [], marker='.', ms=1.2, linestyle="-", color="k", label="subsampled") plt.plot([], [], marker='.', ms=1.2, linestyle="--", color="k", label="averaged") plt.plot([], [], linestyle="-", color="tab:red", label="max iteration (min 700k)") plt.plot([], [], linestyle="-", color="tab:pink", label="max iteration (no min)") fig.legend(loc='upper right', frameon=False, prop={"size": 18}) plt.suptitle("{setup:} - {label:}".format(setup=setup, label=labelname), fontsize=22) for crop, ax in zip(valcrops, axs.flatten()): try: crop_res = results[crop["number"]] except KeyError: continue ax2 = ax.twinx() for raw_ds, ls in zip(raw_datasets, ["-", "--"]): x_vs_dice = [r["iteration"] for r in crop_res[raw_ds]["dice"]] y_vs_dice = [1 - r["value"] for r in crop_res[raw_ds]["dice"]] x_vs_mfd = [ r["iteration"] for r in crop_res[raw_ds]["mean_false_distance"] ] y_vs_mfd = [ r["value"] for r in crop_res[raw_ds]["mean_false_distance"] ] ax.plot(x_vs_mfd, y_vs_mfd, linestyle=ls, color=colors["mean_false_distance"], marker='o', ms=3) ax2.plot(x_vs_dice, y_vs_dice, linestyle=ls, color=colors["dice"], marker='o', ms=3) if max_its[crop["number"]][raw_ds]["min700"][1]: ax.axvline(max_its[crop["number"]][raw_ds]["min700"][0], linestyle=ls, color="tab:red") if max_its[crop["number"]][raw_ds]["min700"][0] != max_its[ crop["number"]][raw_ds]["actual"][0]: 
ax.axvline(max_its[crop["number"]][raw_ds]["actual"][0], linestyle=ls, color="tab:pink") ax.set_xlabel("iteration", fontsize=18) ax.set_title(crop["number"], fontsize=18) ax.xaxis.set_major_formatter(ticker.EngFormatter()) ax.set_ylabel("MFD", color=colors["mean_false_distance"], fontsize=18) ax.tick_params(axis="y", labelcolor=colors["mean_false_distance"]) ax.set_ylim(bottom=0) ax2.set_ylabel("1 - dice", color=colors["dice"], fontsize=18) ax2.tick_params(axis="y", labelcolor=colors["dice"]) ax2.set_ylim([0, 1]) ax.tick_params(axis="both", which="major", labelsize=18) ax2.tick_params(axis="both", which="major", labelsize=18) if file is None: plt.show() else: plt.savefig(file) plt.close()
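# Minimal sketch: write the validation curves for one setup/label combination to a file,
# or pass file=None to display the figure interactively. The setup name and output path
# are assumptions for illustration.
def _example_plot_val(db: MongoCosemDB) -> None:
    plot_val(db, "setup01", "mito", "setup01_mito_validation.pdf", threshold=127)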
def compare_evaluation_methods( db: cosem_db.MongoCosemDB, metric_compare: str, # may not be manual metric_bestby: str, # may be manual queries: List[Union[Dict[str, str], Dict[str, Union[str, Sequence[str]]]]], tol_distance: int = 40, clip_distance: int = 200, threshold: int = 127, test: bool = False) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]: """ Compare different metrics for evaluation by picking one metric (`metric_compare`) to report results and optimizing the configuration (iteration/iteration+setup) with that metric on the one hand and the metric `metric_bestby` on the other hand. Args: db: Database with crop information and evaluation results. metric_compare: Metric to use for reporting performance - using the best configuration determined by `metric_bestby` compared to this metric. metric_bestby: Metric to use for finding best configuration (iteration/iteration+setup) queries: List of queries for which to compare metrics. tol_distance: tolerance distance when using a metric with tolerance distance, otherwise not used clip_distance: clip distance when using a metric with clip distance, otherwise not used threshold: threshold applied on top of distance predictions to generate binary segmentation test: whether to run in test mode Returns: List of Tuples with evaluation result (reported via `metric_compare`). The first entry will be optimized directly for `metric_compare`, the second entry will be optimized for `metric_bestby`. """ comparisons = [] for qu in queries: for setup in qu["setups"]: test_query = { "setup": setup, "crop": qu["crop"], "label": qu["label"], "raw_dataset": qu["raw_dataset"], "metric": { "$in": [metric_compare, metric_bestby] } } if len(db.find(test_query)) == 0: raise RuntimeError( "No results found in database for {0:}".format(test_query)) best_setup = best_result(db, qu["label"], qu["setups"], qu["crop"], metric_compare, raw_ds=qu["raw_dataset"], tol_distance=tol_distance, clip_distance=clip_distance, threshold=threshold, test=test) compare_setup = get_diff(db, qu["label"], qu["setups"], qu["crop"], metric_bestby, metric_compare, raw_ds=qu["raw_dataset"], tol_distance=tol_distance, clip_distance=clip_distance, threshold=threshold, test=test) comparisons.append((best_setup, compare_setup)) return comparisons
def best_8nm(db: cosem_db.MongoCosemDB, metric: str, crops: Optional[Sequence[Union[str, int]]], tol_distance: int = 40, clip_distance: int = 200, threshold: int = 200, mode: str = "across-setups", raw_ds: Union[None, str, Sequence[str]] = "volumes/subsampled/raw/0", test: bool = False) -> List[List[Dict[str, Any]]]: """ Get the best results for the 8nm setups. Args: db: Database with crop information and evaluation result. metric: Metric to report and use for optimiation of iteration/setup. crops: List of crops to run comparison on. If None will use all validation crops. tol_distance: tolerance distance when using a metric with tolerance distance, otherwise not used. clip_distance: clip distance when using a metric with clip distance, otherwise not used. threshold: Threshold to have been applied on top of raw predictions. mode: "across-setups" to optimize both setup+iteration or "per-setup" to optimize iteration for a fixed setup. raw_ds: raw dataset to run prediction on. test: whether to run in test mode. Returns: List of best results. Each result is a list with just one dictionary. """ if mode == "across-setups": setups = [ "setup04", "setup26.1", "setup28", "setup32", "setup36", "setup46", "setup48" ] labels = [ "ecs", "plasma_membrane", "mito", "mito_membrane", "mito_DNA", "vesicle", "vesicle_membrane", "MVB", "MVB_membrane", "lysosome", "lysosome_membrane", "er", "er_membrane", "ERES", "nucleus", "microtubules", "microtubules_out" ] elif mode == "per-setup": setups = [ "setup04", "setup04", "setup04", "setup04", "setup04", "setup04", "setup04", "setup04", "setup04", "setup04", "setup04", "setup04", "setup04", "setup04", "setup26.1", "setup26.1", "setup26.1", "setup28", "setup28", "setup32", "setup32", "setup36", "setup46", "setup46", "setup48", "setup48", "setup48", "setup48" ] labels = [ "ecs", "plasma_membrane", "mito", "mito_membrane", "vesicle", "vesicle_membrane", "MVB", "MVB_membrane", "er", "er_membrane", "ERES", "nucleus", "microtubules", "microtubules_out", "mito", "mito_membrane", "mito_DNA", "er", "er_membrane", "microtubules", "microtubules_out", "nucleus", "ecs", "plasma_membrane", "MVB", "MVB_membrane", "lysosome", "lysosome_membrane" ] else: raise ValueError("unknown mode {0:}".format(mode)) results = [] if crops is None: crops = [c["number"] for c in db.get_all_validation_crops()] for cropno in crops: if mode == "across-setups": for lbl in labels: if crop_utils.check_label_in_crop( hierarchy.hierarchy[lbl], db.get_crop_by_number(cropno)): results.append([ analyze_evals.best_result(db, lbl, setups, cropno, metric, raw_ds=raw_ds, tol_distance=tol_distance, clip_distance=clip_distance, threshold=threshold, test=test) ]) elif mode == "per-setup": for setup, lbl in zip(setups, labels): if crop_utils.check_label_in_crop( hierarchy.hierarchy[lbl], db.get_crop_by_number(cropno)): results.append([ analyze_evals.best_result(db, lbl, [setup], cropno, metric, raw_ds=raw_ds, tol_distance=tol_distance, clip_distance=clip_distance, threshold=threshold, test=test) ]) return results
def compare_setups(db: cosem_db.MongoCosemDB,
                   setups_compare: Sequence[Sequence[str]],
                   labels: Sequence[str],
                   metric: str,
                   raw_ds: Optional[Sequence[str]] = None,
                   crops: Optional[Sequence[Union[str, int]]] = None,
                   tol_distance: int = 40,
                   clip_distance: int = 200,
                   threshold: int = 127,
                   mode: str = "across_setups",
                   test: bool = False) -> List[List[Optional[Dict[str, Any]]]]:
    """
    Flexibly query comparisons from the evaluation database. `setups_compare` and optionally `raw_ds` define sets
    of settings that should be compared.

    Args:
        db: Database with crop information and evaluation results.
        setups_compare: List of lists of setups to compare.
        labels: List of labels. In a `mode` = "per_setup" evaluation, these are paired with the entries of each
            list in `setups_compare`.
        metric: Metric to evaluate for comparing the setups.
        raw_ds: List of raw datasets to consider for querying pulled predictions; can be None if it doesn't matter.
        crops: List of crop numbers to evaluate for. If None, all validation crops are used.
        tol_distance: Tolerance distance when using a metric with tolerance distance, otherwise not used.
        clip_distance: Clip distance when using a metric with clip distance, otherwise not used.
        threshold: Threshold applied on top of distance predictions to generate binary segmentation.
        mode: "across_setups" or "per_setup", depending on whether the configuration that should be optimized is
            both the setup and the iteration ("across_setups") or just the iteration for a given setup
            ("per_setup").
        test: Whether to run in test mode.

    Returns:
        List of comparisons. Each entry corresponds to a crop number and label and is itself a list with entries
        corresponding to each list in `setups_compare` and optionally `raw_ds`.
    """
    comparisons = []
    if crops is None:
        crops = [c["number"] for c in db.get_all_validation_crops()]
    if mode == "across_setups":  # for one label find best result across setups
        for cropno in crops:
            for lbl in labels:
                comp = []
                for k, setups in enumerate(setups_compare):
                    if raw_ds is None:
                        rd = None
                    else:
                        rd = raw_ds[k]
                    comp.append(
                        best_result(db,
                                    lbl,
                                    setups,
                                    cropno,
                                    metric,
                                    raw_ds=rd,
                                    tol_distance=tol_distance,
                                    clip_distance=clip_distance,
                                    threshold=threshold,
                                    test=test))
                comparisons.append(comp)
    elif mode == "per_setup":  # find best result for each combination of setup and label
        for cropno in crops:
            comps = [[] for _ in labels]
            for k, setups in enumerate(setups_compare):
                if raw_ds is None:
                    rd = None
                else:
                    rd = raw_ds[k]
                for kk, (lbl, setup) in enumerate(zip(labels, setups)):
                    comps[kk].append(
                        best_result(db,
                                    lbl,
                                    setup,
                                    cropno,
                                    metric,
                                    raw_ds=rd,
                                    tol_distance=tol_distance,
                                    clip_distance=clip_distance,
                                    threshold=threshold,
                                    test=test))
            comparisons.extend(comps)
    return comparisons
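# Sketch of an "across_setups" comparison: for each label and validation crop, find the best
# (setup, iteration) within each group of setups and line the results up side by side. The
# setup groups and labels are illustrative assumptions; raw_ds is left as None here so that
# the raw dataset is not constrained.
def _example_compare_setups(db: cosem_db.MongoCosemDB):
    comparisons = compare_setups(
        db,
        setups_compare=[["setup01", "setup03"], ["setup04"]],  # hypothetical setup groups
        labels=["mito", "er"],
        metric="dice",
        mode="across_setups")
    for comp in comparisons:
        print([None if res is None else (res.get("setup"), res.get("value")) for res in comp])
    return comparisons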