Пример #1
0
def test_multiple_sources_one_tile():
    """Two sources falling into the same 1x1 tile exceed a one-source-per-tile limit."""
    # Both sources sit in the top-left tile of a 2x2 image.
    catalog_params = {
        "n_sources": torch.tensor([2]),
        "plocs": torch.tensor([[0.5, 0.5], [0.6, 0.6]]).reshape(1, 2, 2),
        "galaxy_bools": torch.tensor([1, 1]).reshape(1, 2, 1),
    }
    full_cat = FullCatalog(2, 2, catalog_params)

    # When extra sources are not ignored, the overflow must raise.
    with pytest.raises(ValueError) as excinfo:
        full_cat.to_tile_params(1, 1, ignore_extra_sources=False)
    expected_msg = "Number of sources per tile exceeds `max_sources_per_tile`."
    assert excinfo.value.args[0] == expected_msg

    # When extra sources are ignored, only the first source survives in the first tile.
    tile_cat = full_cat.to_tile_params(1, 1, ignore_extra_sources=True)

    expected_n_sources = torch.tensor([[1, 0], [0, 0]]).reshape(1, 2, 2)
    assert torch.equal(tile_cat.n_sources, expected_n_sources)

    expected_locs = torch.tensor(
        [[[0.5, 0.5], [0, 0]], [[0, 0], [0, 0]]]
    ).reshape(1, 2, 2, 1, 2)
    assert torch.equal(tile_cat.locs, expected_locs)

    expected_galaxy_bools = torch.tensor(
        [[[1.0], [0.0]], [[0.0], [0.0]]]
    ).reshape(1, 2, 2, 1, 1)
    assert torch.equal(tile_cat["galaxy_bools"], expected_galaxy_bools)
Пример #2
0
def test_catalog_conversion():
    """Converting a full catalog to tiles and back must reproduce the original catalog."""
    # A single galaxy inside a 100x100 image.
    catalog_params = {
        "plocs": torch.tensor([[[51.8, 49.6]]]).float(),
        "galaxy_bools": torch.tensor([[[1]]]).float(),
        "n_sources": torch.tensor([1]).long(),
        "mags": torch.tensor([[[23.0]]]).float(),
    }
    original_cat = FullCatalog(100, 100, catalog_params)

    tiled_cat = original_cat.to_tile_params(tile_slen=4, max_sources_per_tile=1)
    round_tripped_cat = tiled_cat.to_full_params()

    assert original_cat.equals(round_tripped_cat)
Пример #3
0
def test_metrics():
    """Check detection and classification metrics on a hand-built two-image example."""
    slen = 50
    slack = 1.0
    detect = DetectionMetrics(slack)
    classify = ClassificationMetrics(slack)

    # `*_locs` are multiplied by `slen` below to produce the catalogs' `plocs`.
    true_locs = torch.tensor(
        [[[0.5, 0.5], [0.0, 0.0]], [[0.2, 0.2], [0.1, 0.1]]]
    ).reshape(2, 2, 2)
    est_locs = torch.tensor(
        [[[0.49, 0.49], [0.1, 0.1]], [[0.19, 0.19], [0.01, 0.01]]]
    ).reshape(2, 2, 2)
    true_galaxy_bools = torch.tensor([[1, 0], [1, 1]]).reshape(2, 2, 1)
    est_galaxy_bools = torch.tensor([[0, 1], [1, 0]]).reshape(2, 2, 1)

    def build_catalog(n_sources, locs, galaxy_bools):
        # Square slen x slen catalog with locations scaled by slen.
        catalog_data = {
            "n_sources": torch.tensor(n_sources),
            "plocs": locs * slen,
            "galaxy_bools": galaxy_bools,
        }
        return FullCatalog(slen, slen, catalog_data)

    # The truth declares 1 + 2 = 3 sources, the estimate 2 + 2 = 4.
    true_params = build_catalog([1, 2], true_locs, true_galaxy_bools)
    est_params = build_catalog([2, 2], est_locs, est_galaxy_bools)

    detection = detect(true_params, est_params)
    classification = classify(true_params, est_params)

    # Two matches out of 4 estimated and 3 true sources.
    assert detection["precision"] == 2 / (2 + 2)
    assert detection["recall"] == 2 / 3
    assert classification["class_acc"] == 1 / 2
    assert classification["conf_matrix"].eq(torch.tensor([[1, 1], [0, 0]])).all()
    assert detection["avg_distance"].item() == 50 * (0.01 + (0.01 + 0.09) / 2) / 2
Пример #4
0
def test_scene_metrics():
    """Scoring a catalog against itself must yield perfect detection scores.

    Previously this test only checked that ``scene_metrics`` did not raise;
    it now also pins the detection scores and the per-bin counts.
    """
    true_params = FullCatalog(
        100,
        100,
        {
            "plocs": torch.tensor([[[50.0, 50.0]]]).float(),
            "galaxy_bools": torch.tensor([[[1]]]).float(),
            "n_sources": torch.tensor([1]).long(),
            "mags": torch.tensor([[[23.0]]]).float(),
        },
    )
    results = reporting.scene_metrics(true_params, true_params, mag_max=25, slack=1.0)

    # One galaxy (mag 23, inside the mag_max=25 cut) matched against itself.
    assert results["precision"] == 1.0
    assert results["recall"] == 1.0
    assert results["f1"] == 1.0

    counts = results["counts"]
    assert counts["tgcount"] == 1
    assert counts["tscount"] == 0
    assert counts["egcount"] == 1
    assert counts["escount"] == 0
Пример #5
0
def get_positive_negative_stats(
    true_cat: FullCatalog,
    est_tile_cat: TileCatalog,
    mag_max: float = np.inf,
    n_jobs: int = 10,
) -> Dict[str, Union[int, Tensor]]:
    """Compute ``stats_for_threshold`` over a sweep of thresholds, in parallel.

    Args:
        true_cat: Ground-truth catalog; it is first restricted with
            ``apply_mag_bin(-inf, mag_max)``.
        est_tile_cat: Estimated tile catalog; a copy of it is handed to the workers.
        mag_max: Upper magnitude cut applied to ``true_cat``.
        n_jobs: Number of joblib workers used for the sweep (default 10).

    Returns:
        Dict mapping every key produced by ``stats_for_threshold`` to a tensor
        stacked over the 99 thresholds ``0.01, 0.02, ..., 0.99``, plus ``"n_obj"``
        (size of dimension 1 of ``true_cat.plocs`` after the magnitude cut).
    """
    true_cat = true_cat.apply_mag_bin(-np.inf, mag_max)
    thresholds = np.linspace(0.01, 0.99, 99)
    # NOTE(review): presumably copied so the caller's catalog is never shared with workers.
    est_tile_cat = est_tile_cat.copy()

    res = Parallel(n_jobs=n_jobs)(
        delayed(stats_for_threshold)(true_cat.plocs, est_tile_cat, t)
        for t in tqdm(thresholds))
    out: Dict[str, Union[int, Tensor]] = {}
    for k in res[0]:
        out[k] = torch.stack([r[k] for r in res])
    # NOTE(review): shape[1] is the per-image source dimension of plocs; it equals the
    # object count only for a single (unpadded) image — confirm for batched catalogs.
    out["n_obj"] = true_cat.plocs.shape[1]
    return out
Пример #6
0
    def compute_metrics(truth: FullCatalog, pred: FullCatalog):
        """Compare ``pred`` against ``truth`` over magnitude cuts, bins and matched pairs.

        Returns:
            Dict with ``mag_cuts`` / ``mag_bins`` (upper magnitude edges), the
            output of ``compute_mag_bin_metrics`` for each (``cuts_data`` /
            ``bins_data``), and ``full_metrics``: per-match true/estimated galaxy
            flags, estimated galaxy probabilities and magnitudes over all magnitudes.
        """
        # Cumulative cuts: pairs (-inf, m) for m = 18.0, 18.25, ..., 24.25.
        cut_upper = torch.arange(18, 24.5, 0.25)
        cut_lower = torch.full_like(cut_upper, fill_value=-np.inf)
        mag_cuts = torch.column_stack((cut_lower, cut_upper))

        # Unit-width bins: pairs (m - 1, m) for m = 18, ..., 24.
        bin_upper = torch.arange(18, 25, 1.0)
        mag_bins = torch.column_stack((bin_upper - 1, bin_upper))

        cuts_data = compute_mag_bin_metrics(mag_cuts, truth, pred)
        bins_data = compute_mag_bin_metrics(mag_bins, truth, pred)

        # Match sources over all magnitudes for the misclassification scatter plot.
        true_idx, est_idx, keep, _ = reporting.match_by_locs(
            truth.plocs.reshape(-1, 2), pred.plocs.reshape(-1, 2), slack=1.0
        )

        def matched(values, idx):
            # Flatten, reorder by the matching indices, keep only accepted matches.
            return values.reshape(-1)[idx][keep]

        est_gal_bools = matched(pred["galaxy_bools"], est_idx)
        # Catalogs without `galaxy_probs` (e.g. PHOTO) fall back to the hard labels.
        est_gal_probs = pred.get("galaxy_probs", None)
        if est_gal_probs is None:
            est_gal_probs = est_gal_bools
        else:
            est_gal_probs = matched(est_gal_probs, est_idx)

        full_metrics = {
            "tgbool": matched(truth["galaxy_bools"], true_idx),
            "egbool": est_gal_bools,
            "egprob": est_gal_probs,
            "tmag": matched(truth["mags"], true_idx),
            "emag": matched(pred["mags"], est_idx),
        }

        return {
            "mag_cuts": cut_upper,
            "mag_bins": bin_upper,
            "cuts_data": cuts_data,
            "bins_data": bins_data,
            "full_metrics": full_metrics,
        }
Пример #7
0
 def _sample_full_catalog(self):
     """Draw one sample from ``self.prior`` and wrap it as a single-image FullCatalog."""
     sample = self.prior.sample()
     # Replace `locs` by `plocs` = locs * slen (presumably pixel units — confirm).
     sample["plocs"] = sample.pop("locs") * self.slen
     # Add a leading batch dimension of size one to every tensor.
     batched = {name: tensor.unsqueeze(0) for name, tensor in sample.items()}
     return FullCatalog(self.slen, self.slen, batched)
Пример #8
0
def scene_metrics(
    true_params: FullCatalog,
    est_params: FullCatalog,
    mag_min: float = -np.inf,
    mag_max: float = np.inf,
    slack: float = 1.0,
) -> dict:
    """Return detection and classification metrics based on a given ground truth.

    These metrics are computed as a function of magnitude based on the specified
    bin `(mag_min, mag_max)` but are designed to be independent of the estimated magnitude.
    Hence, precision is computed by taking a cut in the estimated parameters based on the magnitude
    bin and matching them with *any* true objects. Similarly, recall is computed by taking a cut
    on the true parameters and matching them with *any* predicted objects.

    Args:
        true_params: True parameters of each source in the scene (e.g. from coadd catalog)
        est_params: Predictions on scene obtained from predict_on_scene function.
        mag_min: Discard all objects with magnitude lower than this.
        mag_max: Discard all objects with magnitude higher than this.
        slack: Pixel L-infinity distance slack when doing matching for metrics.

    Returns:
        Dictionary with ``precision``, ``recall``, ``f1`` and ``n_galaxies_detected``
        (Python numbers; ``f1`` is 0.0 when precision and recall are both zero),
        every entry of ``ClassificationMetrics.compute()``, and ``counts``: the
        numbers of true/estimated galaxies and stars in the bin, summed over all
        images (``tgcount``, ``tscount``, ``egcount``, ``escount``), plus
        ``n_matches_coadd_gal`` / ``n_matches_coadd_star`` derived from the
        classification ``n_matches`` and ``n_matches_gal_coadd`` entries.
    """
    detection_metrics = DetectionMetrics(slack)
    classification_metrics = ClassificationMetrics(slack)

    # Apply the magnitude cut once per catalog; both results are reused below.
    tparams = true_params.apply_mag_bin(mag_min, mag_max)
    eparams = est_params.apply_mag_bin(mag_min, mag_max)

    # precision: binned estimates matched against all true objects.
    detection_metrics.update(true_params, eparams)
    precision = detection_metrics.compute()["precision"]
    detection_metrics.reset()  # reset global state since recall and precision use different cuts.

    # recall: binned true objects matched against all estimates.
    detection_metrics.update(tparams, est_params)
    recall_result = detection_metrics.compute()
    recall = recall_result["recall"]
    n_galaxies_detected = recall_result["n_galaxies_detected"]
    detection_metrics.reset()

    # f1-score; define it as 0 when precision and recall are both 0 instead of 0/0 = NaN.
    pr_sum = precision + recall
    f1 = 0.0 if pr_sum == 0 else (2 * precision * recall / pr_sum).item()
    detection_result = {
        "precision": precision.item(),
        "recall": recall.item(),
        "f1": f1,
        "n_galaxies_detected": n_galaxies_detected.item(),
    }

    # classification
    classification_metrics.update(tparams, est_params)
    classification_result = classification_metrics.compute()

    # report counts on each bin; n_sources is summed so batches of several images work.
    tcount = tparams.n_sources.sum().int().item()
    tgcount = tparams["galaxy_bools"].sum().int().item()
    tscount = tcount - tgcount

    ecount = eparams.n_sources.sum().int().item()
    egcount = eparams["galaxy_bools"].sum().int().item()
    escount = ecount - egcount

    n_matches = classification_result["n_matches"]
    n_matches_gal_coadd = classification_result["n_matches_gal_coadd"]

    counts = {
        "tgcount": tgcount,
        "tscount": tscount,
        "egcount": egcount,
        "escount": escount,
        "n_matches_coadd_gal": n_matches_gal_coadd.item(),
        "n_matches_coadd_star": n_matches.item() - n_matches_gal_coadd.item(),
    }

    # compute and return results
    return {**detection_result, **classification_result, "counts": counts}
Пример #9
0
def calc_scene_metrics_by_mag(est_cat: FullCatalog, true_cat: FullCatalog,
                              mag_start: int, mag_end: int, loc_slack: float):
    """Tabulate detection and classification metrics per magnitude bin.

    The bins are ``(m - 1, m)`` for every integer ``m`` in ``[mag_start, mag_end]``
    (labelled ``str(m)``), plus an ``"overall"`` bin ``(-inf, mag_end)``.
    As in ``scene_metrics``, false positives come from binned estimates matched
    against all true objects, while true positives come from binned true
    objects matched against all estimates.

    Args:
        est_cat: Estimated catalog.
        true_cat: Ground-truth catalog.
        mag_start: Upper edge of the first magnitude bin.
        mag_end: Upper edge of the last magnitude bin (and of the overall bin).
        loc_slack: Location slack passed to the detection/classification metrics.

    Returns:
        pandas DataFrame indexed by bin label with one column per metric.
        ``classif_galaxy_acc`` / ``classif_star_acc`` are NaN for bins whose
        confusion-matrix row (row 0 for galaxies, row 1 for stars) is empty.
    """
    scene_metrics_by_mag: Dict[str, Dict[str, Number]] = {}
    mag_mins = [float(m - 1)
                for m in range(mag_start, mag_end + 1)] + [-np.inf]
    mag_maxes = [float(m) for m in range(mag_start, mag_end + 1)] + [mag_end]
    for mag_min, mag_max in zip(mag_mins, mag_maxes):
        # Apply the magnitude cut once per catalog; both are reused below.
        true_cat_binned = true_cat.apply_mag_bin(mag_min, mag_max)
        est_cat_binned = est_cat.apply_mag_bin(mag_min, mag_max)

        # report counts on each bin (n_sources summed over all images)
        tcount = true_cat_binned.n_sources.sum().int().item()
        tgcount = true_cat_binned["galaxy_bools"].sum().int().item()
        tscount = tcount - tgcount

        detection_metrics = reporting.DetectionMetrics(loc_slack)
        classification_metrics = reporting.ClassificationMetrics(loc_slack)

        # precision: binned estimates vs. all true objects (only fp is used).
        detection_metrics.update(true_cat, est_cat_binned)
        precision_metrics = detection_metrics.compute()
        fp = precision_metrics["fp"].item()
        # reset global state since recall and precision use different cuts.
        detection_metrics.reset()

        # recall: binned true objects vs. all estimates.
        detection_metrics.update(true_cat_binned, est_cat)
        detection_dict = detection_metrics.compute()
        tp = detection_dict["tp"].item()
        tp_gal = detection_dict["n_galaxies_detected"].item()
        tp_star = tp - tp_gal
        detection_metrics.reset()

        # classification
        classification_metrics.update(true_cat_binned, est_cat)
        classification_result = classification_metrics.compute()
        n_matches = classification_result["n_matches"].item()

        conf_matrix = classification_result["conf_matrix"]
        galaxy_acc = conf_matrix[0,
                                 0] / (conf_matrix[0, 0] + conf_matrix[0, 1])
        star_acc = conf_matrix[1, 1] / (conf_matrix[1, 1] + conf_matrix[1, 0])

        mag = "overall" if np.isinf(mag_min) else str(int(mag_max))

        scene_metrics_by_mag[mag] = {
            "tcount": tcount,
            "tgcount": tgcount,
            "tp": tp,
            "fp": fp,
            "recall": tp / tcount if tcount > 0 else 0.0,
            "precision": tp / (tp + fp) if (tp + fp) > 0 else 1.0,
            "tp_gal": tp_gal,
            "recall_gal": tp_gal / tgcount if tgcount > 0 else 0.0,
            "tp_star": tp_star,
            "recall_star": tp_star / tscount if tscount > 0 else 0.0,
            "classif_n_matches": n_matches,
            "classif_acc": classification_result["class_acc"].item(),
            "classif_galaxy_acc": galaxy_acc.item(),
            "classif_star_acc": star_acc.item(),
        }

    # Pivot {bin: {metric: value}} into {metric: {bin: value}} so metrics become columns.
    d: DefaultDict[str, Dict[str, Number]] = defaultdict(dict)
    for mag, scene_metrics_mag in scene_metrics_by_mag.items():
        for measure, value in scene_metrics_mag.items():
            d[measure][mag] = value
    return pd.DataFrame(d)