Exemplo n.º 1
0
def get_positive_negative_stats(
    true_cat: FullCatalog,
    est_tile_cat: TileCatalog,
    mag_max: float = np.inf,
    n_jobs: int = 10,
):
    """Compute detection statistics over a sweep of probability thresholds.

    The true catalog is first restricted with ``apply_mag_bin(-inf, mag_max)``.
    ``stats_for_threshold`` is then evaluated in parallel (joblib) for each of
    99 thresholds evenly spaced over ``[0.01, 0.99]``. The per-threshold
    results are stacked key by key.

    Args:
        true_cat: Ground-truth catalog.
        est_tile_cat: Estimated tile catalog. A copy of it is what the
            parallel workers receive.
        mag_max: Upper magnitude cut applied to the true catalog.
        n_jobs: Number of joblib workers used for the threshold sweep.
            The default of 10 matches the previously hard-coded value.

    Returns:
        Dict mapping each key produced by ``stats_for_threshold`` to a
        ``torch.stack`` of its values along a new leading (threshold) axis.
        This assumes each value is a tensor. The dict also has an ``"n_obj"``
        entry equal to ``true_cat.plocs.shape[1]``.
    """
    true_cat = true_cat.apply_mag_bin(-np.inf, mag_max)
    thresholds = np.linspace(0.01, 0.99, 99)
    est_tile_cat = est_tile_cat.copy()

    true_plocs = true_cat.plocs
    res = Parallel(n_jobs=n_jobs)(
        delayed(stats_for_threshold)(true_plocs, est_tile_cat, t)
        for t in tqdm(thresholds)
    )

    # Every threshold returns the same keys; the keys of the first result are used.
    out: Dict[str, Union[int, Tensor]] = {
        k: torch.stack([r[k] for r in res]) for k in res[0]
    }
    # NOTE(review): shape[1] is presumably the (padded) per-image source
    # dimension of plocs. Confirm it is the intended object count when the
    # batch size is > 1.
    out["n_obj"] = true_plocs.shape[1]
    return out
Exemplo n.º 2
0
def scene_metrics(
    true_params: FullCatalog,
    est_params: FullCatalog,
    mag_min: float = -np.inf,
    mag_max: float = np.inf,
    slack: float = 1.0,
) -> dict:
    """Return detection and classification metrics based on a given ground truth.

    These metrics are computed as a function of magnitude based on the specified
    bin `(mag_min, mag_max)` but are designed to be independent of the estimated magnitude.
    Hence, precision is computed by taking a cut in the estimated parameters based on the magnitude
    bin and matching them with *any* true objects. Similarly, recall is computed by taking a cut
    on the true parameters and matching them with *any* predicted objects.

    Args:
        true_params: True parameters of each source in the scene (e.g. from coadd catalog)
        est_params: Predictions on scene obtained from predict_on_scene function.
        mag_min: Discard all objects with magnitude lower than this.
        mag_max: Discard all objects with magnitude higher than this.
        slack: Pixel L-infinity distance slack when doing matching for metrics.

    Returns:
        Dictionary with these entries:
        - "precision", "recall", "f1" and "n_galaxies_detected" as Python
          numbers. f1 is 0.0 when precision and recall are both 0.
        - every entry returned by ClassificationMetrics.compute().
        - "counts": per-bin true/estimated galaxy and star counts and
          matched counts.
    """
    detection_metrics = DetectionMetrics(slack)
    classification_metrics = ClassificationMetrics(slack)

    # Apply the magnitude cut once per catalog and reuse the results below.
    tparams = true_params.apply_mag_bin(mag_min, mag_max)
    eparams = est_params.apply_mag_bin(mag_min, mag_max)

    # precision: binned estimates matched against *all* true objects.
    detection_metrics.update(true_params, eparams)
    precision = detection_metrics.compute()["precision"].item()
    detection_metrics.reset()  # reset global state since recall and precision use different cuts.

    # recall: binned true objects matched against *all* estimates.
    detection_metrics.update(tparams, est_params)
    recall_metrics = detection_metrics.compute()
    recall = recall_metrics["recall"].item()
    n_galaxies_detected = recall_metrics["n_galaxies_detected"].item()
    detection_metrics.reset()

    # f1-score. When precision and recall are both 0, the plain formula would
    # be 0 / 0, so report 0.0. NaN inputs still propagate because nan != 0.
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom != 0 else 0.0
    detection_result = {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "n_galaxies_detected": n_galaxies_detected,
    }

    # classification: binned true objects matched against all estimates.
    classification_metrics.update(tparams, est_params)
    classification_result = classification_metrics.compute()

    # report counts on each bin
    tcount = tparams.n_sources.int().item()
    tgcount = tparams["galaxy_bools"].sum().int().item()
    ecount = eparams.n_sources.int().item()
    egcount = eparams["galaxy_bools"].sum().int().item()

    n_matches = classification_result["n_matches"].item()
    n_matches_gal_coadd = classification_result["n_matches_gal_coadd"].item()

    counts = {
        "tgcount": tgcount,
        "tscount": tcount - tgcount,
        "egcount": egcount,
        "escount": ecount - egcount,
        "n_matches_coadd_gal": n_matches_gal_coadd,
        "n_matches_coadd_star": n_matches - n_matches_gal_coadd,
    }

    # compute and return results
    return {**detection_result, **classification_result, "counts": counts}
Exemplo n.º 3
0
def calc_scene_metrics_by_mag(est_cat: FullCatalog, true_cat: FullCatalog,
                              mag_start: int, mag_end: int, loc_slack: float):
    """Tabulate detection and classification metrics per magnitude bin.

    Bins are built as follows:
    - one-unit-wide bins with edges ``(m - 1, m)`` for each integer ``m``
      from ``mag_start`` to ``mag_end`` inclusive;
    - a final "overall" bin with edges ``(-inf, mag_end)``.

    Within each bin:
    - the precision pass matches the binned estimates against all true
      sources;
    - the recall and classification passes match the binned true sources
      against all estimates.

    Args:
        est_cat: Estimated catalog.
        true_cat: Ground-truth catalog.
        mag_start: Upper edge of the first one-unit bin.
        mag_end: Upper edge of the last one-unit bin and of the overall bin.
        loc_slack: Slack passed to reporting.DetectionMetrics /
            reporting.ClassificationMetrics.

    Returns:
        DataFrame with one column per measure. Rows are labelled
        ``str(int(mag_max))`` for the one-unit bins and ``"overall"`` for the
        final bin.
    """
    scene_metrics_by_mag: Dict[str, Dict[str, Number]] = {}
    mag_bins = [(float(m - 1), float(m)) for m in range(mag_start, mag_end + 1)]
    mag_bins.append((-np.inf, float(mag_end)))

    for mag_min, mag_max in mag_bins:
        # Apply each magnitude cut once and reuse it for every pass below.
        true_cat_binned = true_cat.apply_mag_bin(mag_min, mag_max)
        est_cat_binned = est_cat.apply_mag_bin(mag_min, mag_max)

        # report counts on each bin
        tcount = true_cat_binned.n_sources.int().item()
        tgcount = true_cat_binned["galaxy_bools"].sum().int().item()
        tscount = tcount - tgcount

        # Fresh metric objects per bin so no accumulated state carries over.
        detection_metrics = reporting.DetectionMetrics(loc_slack)
        classification_metrics = reporting.ClassificationMetrics(loc_slack)

        # precision pass: binned estimates vs. all true sources (only fp is used).
        detection_metrics.update(true_cat, est_cat_binned)
        fp = detection_metrics.compute()["fp"].item()
        # reset global state since recall and precision use different cuts.
        detection_metrics.reset()

        # recall pass: binned true sources vs. all estimates.
        detection_metrics.update(true_cat_binned, est_cat)
        detection_dict = detection_metrics.compute()
        tp = detection_dict["tp"].item()
        tp_gal = detection_dict["n_galaxies_detected"].item()
        tp_star = tp - tp_gal
        detection_metrics.reset()

        # classification
        classification_metrics.update(true_cat_binned, est_cat)
        classification_result = classification_metrics.compute()
        n_matches = classification_result["n_matches"].item()

        conf_matrix = classification_result["conf_matrix"]
        # NOTE(review): if a row of conf_matrix sums to 0, these divide by
        # zero (likely nan for tensors). Unlike the ratios below, they are
        # not guarded.
        galaxy_acc = conf_matrix[0, 0] / (conf_matrix[0, 0] + conf_matrix[0, 1])
        star_acc = conf_matrix[1, 1] / (conf_matrix[1, 1] + conf_matrix[1, 0])

        mag = "overall" if np.isinf(mag_min) else str(int(mag_max))

        scene_metrics_by_mag[mag] = {
            "tcount": tcount,
            "tgcount": tgcount,
            "tp": tp,
            "fp": fp,
            "recall": tp / tcount if tcount > 0 else 0.0,
            # NOTE(review): tp comes from the recall pass and fp from the
            # precision pass. Confirm this mix is the intended precision
            # definition.
            "precision": tp / (tp + fp) if (tp + fp) > 0 else 1.0,
            "tp_gal": tp_gal,
            "recall_gal": tp_gal / tgcount if tgcount > 0 else 0.0,
            "tp_star": tp_star,
            "recall_star": tp_star / tscount if tscount > 0 else 0.0,
            "classif_n_matches": n_matches,
            "classif_acc": classification_result["class_acc"].item(),
            "classif_galaxy_acc": galaxy_acc.item(),
            "classif_star_acc": star_acc.item(),
        }

    # Pivot {bin: {measure: value}} into {measure: {bin: value}} so the
    # measures become DataFrame columns and the bins the row index.
    d: DefaultDict[str, Dict[str, Number]] = defaultdict(dict)
    for mag, scene_metrics_mag in scene_metrics_by_mag.items():
        for measure, value in scene_metrics_mag.items():
            d[measure][mag] = value
    return pd.DataFrame(d)