def test_multiple_sources_one_tile():
    """Two sources in the same tile: strict conversion raises, lenient keeps the first."""
    catalog_data = {
        "n_sources": torch.tensor([2]),
        "plocs": torch.tensor([[0.5, 0.5], [0.6, 0.6]]).reshape(1, 2, 2),
        "galaxy_bools": torch.tensor([1, 1]).reshape(1, 2, 1),
    }
    full_cat = FullCatalog(2, 2, catalog_data)

    # Strict conversion must refuse to silently drop the second source.
    with pytest.raises(ValueError) as error_info:
        full_cat.to_tile_params(1, 1, ignore_extra_sources=False)
    expected_message = "Number of sources per tile exceeds `max_sources_per_tile`."
    assert error_info.value.args[0] == expected_message

    # Lenient conversion should return only the first source, in the first tile.
    tile_cat = full_cat.to_tile_params(1, 1, ignore_extra_sources=True)

    expected_n_sources = torch.tensor([[1, 0], [0, 0]]).reshape(1, 2, 2)
    expected_locs = torch.tensor(
        [[[0.5, 0.5], [0, 0]], [[0, 0], [0, 0]]]
    ).reshape(1, 2, 2, 1, 2)
    expected_galaxy_bools = torch.tensor(
        [[[1.0], [0.0]], [[0.0], [0.0]]]
    ).reshape(1, 2, 2, 1, 1)

    assert torch.equal(tile_cat.n_sources, expected_n_sources)
    assert torch.equal(tile_cat.locs, expected_locs)
    assert torch.equal(tile_cat["galaxy_bools"], expected_galaxy_bools)
def test_catalog_conversion():
    """Converting a one-source catalog to tiles and back reproduces the original."""
    catalog_data = {
        "plocs": torch.tensor([[[51.8, 49.6]]]).float(),
        "galaxy_bools": torch.tensor([[[1]]]).float(),
        "n_sources": torch.tensor([1]).long(),
        "mags": torch.tensor([[[23.0]]]).float(),
    }
    true_params = FullCatalog(100, 100, catalog_data)

    tiled = true_params.to_tile_params(tile_slen=4, max_sources_per_tile=1)
    round_tripped = tiled.to_full_params()

    assert true_params.equals(round_tripped)
def test_metrics():
    """Check detection/classification metrics on a hand-built two-image batch."""
    slen = 50
    slack = 1.0
    detect = DetectionMetrics(slack)
    classify = ClassificationMetrics(slack)

    true_locs = torch.tensor(
        [[[0.5, 0.5], [0.0, 0.0]], [[0.2, 0.2], [0.1, 0.1]]]
    ).reshape(2, 2, 2)
    est_locs = torch.tensor(
        [[[0.49, 0.49], [0.1, 0.1]], [[0.19, 0.19], [0.01, 0.01]]]
    ).reshape(2, 2, 2)
    true_galaxy_bools = torch.tensor([[1, 0], [1, 1]]).reshape(2, 2, 1)
    est_galaxy_bools = torch.tensor([[0, 1], [1, 0]]).reshape(2, 2, 1)

    def make_catalog(n_sources, locs, galaxy_bools):
        # Locations above are in units of the image side; FullCatalog takes pixels.
        return FullCatalog(
            slen,
            slen,
            {
                "n_sources": n_sources,
                "plocs": locs * slen,
                "galaxy_bools": galaxy_bools,
            },
        )

    true_params = make_catalog(torch.tensor([1, 2]), true_locs, true_galaxy_bools)
    est_params = make_catalog(torch.tensor([2, 2]), est_locs, est_galaxy_bools)

    detection = detect(true_params, est_params)
    classification = classify(true_params, est_params)

    assert detection["precision"] == 2 / (2 + 2)
    assert detection["recall"] == 2 / 3
    assert classification["class_acc"] == 1 / 2
    assert classification["conf_matrix"].eq(torch.tensor([[1, 1], [0, 0]])).all()
    assert detection["avg_distance"].item() == 50 * (0.01 + (0.01 + 0.09) / 2) / 2
def test_scene_metrics():
    """scene_metrics on a catalog compared with itself must report perfect detection.

    Previously this was a pure smoke test with no assertions, so a regression in the
    returned values would go unnoticed.
    """
    true_params = FullCatalog(
        100,
        100,
        {
            "plocs": torch.tensor([[[50.0, 50.0]]]).float(),
            "galaxy_bools": torch.tensor([[[1]]]).float(),
            "n_sources": torch.tensor([1]).long(),
            "mags": torch.tensor([[[23.0]]]).float(),
        },
    )
    results = reporting.scene_metrics(true_params, true_params, mag_max=25, slack=1.0)

    # The single galaxy (mag 23 < mag_max 25) matches itself exactly: no misses,
    # no spurious detections.
    assert results["precision"] == 1.0
    assert results["recall"] == 1.0
    assert results["f1"] == 1.0

    counts = results["counts"]
    assert counts["tgcount"] == 1
    assert counts["tscount"] == 0
    assert counts["egcount"] == 1
    assert counts["escount"] == 0
def get_positive_negative_stats(
    true_cat: FullCatalog,
    est_tile_cat: TileCatalog,
    mag_max: float = np.inf,
    n_jobs: int = 10,
):
    """Run `stats_for_threshold` over a sweep of 99 thresholds in [0.01, 0.99].

    Args:
        true_cat: Ground-truth catalog; restricted to `apply_mag_bin(-inf, mag_max)`
            before any statistics are computed.
        est_tile_cat: Estimated tile catalog. A copy is handed to the workers,
            presumably so the caller's object is never touched.
        mag_max: Upper magnitude bound applied to `true_cat`.
        n_jobs: Number of joblib workers (previously hard-coded to 10).

    Returns:
        Dict mapping every key returned by `stats_for_threshold` to a tensor stacked
        along a new leading axis (one entry per threshold), plus `"n_obj"`, which is
        `true_cat.plocs.shape[1]` after the magnitude cut.
    """
    true_cat = true_cat.apply_mag_bin(-np.inf, mag_max)
    thresholds = np.linspace(0.01, 0.99, 99)
    est_tile_cat = est_tile_cat.copy()

    res = Parallel(n_jobs=n_jobs)(
        delayed(stats_for_threshold)(true_cat.plocs, est_tile_cat, t)
        for t in tqdm(thresholds)
    )

    # Stack per-threshold results key by key (keys taken from the first result).
    out: Dict[str, Union[int, Tensor]] = {
        k: torch.stack([r[k] for r in res]) for k in res[0]
    }
    out["n_obj"] = true_cat.plocs.shape[1]
    return out
def compute_metrics(truth: FullCatalog, pred: FullCatalog):
    """Collect magnitude-binned metrics and per-match arrays comparing `pred` to `truth`.

    Returns:
        Dict with keys:
          "mag_cuts": upper edges 18, 18.25, ..., 24.25 of the cumulative cuts (-inf, m).
          "mag_bins": upper edges 18, ..., 24 of the unit-width bins (m - 1, m).
          "cuts_data" / "bins_data": `compute_mag_bin_metrics` output for those ranges.
          "full_metrics": matched true/estimated galaxy flags, estimated galaxy
              probability and true/estimated magnitudes, over all magnitudes.
    """
    # Cumulative cuts: each row is (-inf, upper).
    cut_upper = torch.arange(18, 24.5, 0.25)
    cut_lower = torch.full_like(cut_upper, fill_value=-np.inf)
    mag_cuts = torch.column_stack((cut_lower, cut_upper))

    # Unit-width bins: each row is (upper - 1, upper).
    bin_upper = torch.arange(18, 25, 1.0)
    mag_bins = torch.column_stack((bin_upper - 1, bin_upper))

    cuts_data = compute_mag_bin_metrics(mag_cuts, truth, pred)
    bins_data = compute_mag_bin_metrics(mag_bins, truth, pred)

    # Match all sources regardless of magnitude (used for misclassification scatter plots).
    tindx, eindx, dkeep, _ = reporting.match_by_locs(
        truth.plocs.reshape(-1, 2), pred.plocs.reshape(-1, 2), slack=1.0
    )

    def matched_true(key):
        # Flatten a truth column and keep only entries that were matched.
        return truth[key].reshape(-1)[tindx][dkeep]

    def matched_pred(key):
        # Flatten a prediction column and keep only entries that were matched.
        return pred[key].reshape(-1)[eindx][dkeep]

    egbool = matched_pred("galaxy_bools")
    # Catalogs without galaxy probabilities (e.g. PHOTO) fall back to the hard labels.
    galaxy_probs = pred.get("galaxy_probs", None)
    if galaxy_probs is None:
        egprob = egbool
    else:
        egprob = galaxy_probs.reshape(-1)[eindx][dkeep]

    full_metrics = {
        "tgbool": matched_true("galaxy_bools"),
        "egbool": egbool,
        "egprob": egprob,
        "tmag": matched_true("mags"),
        "emag": matched_pred("mags"),
    }

    return {
        "mag_cuts": cut_upper,
        "mag_bins": bin_upper,
        "cuts_data": cuts_data,
        "bins_data": bins_data,
        "full_metrics": full_metrics,
    }
def _sample_full_catalog(self):
    """Draw one sample from `self.prior` and wrap it in a FullCatalog.

    The sampled `"locs"` entry is replaced by `"plocs"` (locs scaled by `self.slen`),
    and every entry gets a leading axis of size 1 via `unsqueeze(0)` (presumably the
    batch axis).
    """
    sample = self.prior.sample()
    sample["plocs"] = sample.pop("locs") * self.slen
    batched = {name: value.unsqueeze(0) for name, value in sample.items()}
    return FullCatalog(self.slen, self.slen, batched)
def scene_metrics(
    true_params: FullCatalog,
    est_params: FullCatalog,
    mag_min: float = -np.inf,
    mag_max: float = np.inf,
    slack: float = 1.0,
) -> dict:
    """Return detection and classification metrics based on a given ground truth.

    These metrics are computed as a function of magnitude based on the specified bin
    `(mag_min, mag_max)` but are designed to be independent of the estimated magnitude.
    Hence, precision is computed by taking a cut in the estimated parameters based on the
    magnitude bin and matching them with *any* true objects. Similarly, recall is computed
    by taking a cut on the true parameters and matching them with *any* predicted objects.

    Args:
        true_params: True parameters of each source in the scene (e.g. from coadd catalog)
        est_params: Predictions on scene obtained from predict_on_scene function.
        mag_min: Discard all objects with magnitude lower than this.
        mag_max: Discard all objects with magnitude higher than this.
        slack: Pixel L-infinity distance slack when doing matching for metrics.

    Returns:
        Dictionary with "precision", "recall", "f1" and "n_galaxies_detected"
        (from DetectionMetrics), every key of the ClassificationMetrics output, and
        "counts" (true/estimated galaxy and star counts in the bin plus matched
        coadd galaxy/star counts). "f1" is 0.0 when precision and recall are both 0.
        Counts use `.item()` on `n_sources`, so a single scene is assumed.
    """
    detection_metrics = DetectionMetrics(slack)
    classification_metrics = ClassificationMetrics(slack)

    # Bin each catalog once; the binned copies are reused by every metric below.
    tparams = true_params.apply_mag_bin(mag_min, mag_max)
    eparams = est_params.apply_mag_bin(mag_min, mag_max)

    # precision: binned estimates matched against *all* true objects
    detection_metrics.update(true_params, eparams)
    precision = detection_metrics.compute()["precision"]
    detection_metrics.reset()  # reset global state since recall and precision use different cuts.

    # recall: binned true objects matched against *all* estimates
    detection_metrics.update(tparams, est_params)
    recall_result = detection_metrics.compute()  # compute once, read both entries
    recall = recall_result["recall"]
    n_galaxies_detected = recall_result["n_galaxies_detected"]
    detection_metrics.reset()

    # f1-score; define it as 0 when precision + recall == 0 instead of producing nan (0/0).
    denom = precision + recall
    f1 = (2 * precision * recall / denom).item() if denom > 0 else 0.0

    detection_result = {
        "precision": precision.item(),
        "recall": recall.item(),
        "f1": f1,
        "n_galaxies_detected": n_galaxies_detected.item(),
    }

    # classification of binned true objects against all estimates
    classification_metrics.update(tparams, est_params)
    classification_result = classification_metrics.compute()

    # report counts on each bin
    tcount = tparams.n_sources.int().item()
    tgcount = tparams["galaxy_bools"].sum().int().item()
    tscount = tcount - tgcount

    ecount = eparams.n_sources.int().item()
    egcount = eparams["galaxy_bools"].sum().int().item()
    escount = ecount - egcount

    n_matches = classification_result["n_matches"]
    n_matches_gal_coadd = classification_result["n_matches_gal_coadd"]

    counts = {
        "tgcount": tgcount,
        "tscount": tscount,
        "egcount": egcount,
        "escount": escount,
        "n_matches_coadd_gal": n_matches_gal_coadd.item(),
        "n_matches_coadd_star": n_matches.item() - n_matches_gal_coadd.item(),
    }

    # compute and return results
    return {**detection_result, **classification_result, "counts": counts}
def calc_scene_metrics_by_mag(
    est_cat: FullCatalog,
    true_cat: FullCatalog,
    mag_start: int,
    mag_end: int,
    loc_slack: float,
) -> pd.DataFrame:
    """Tabulate detection and classification metrics per magnitude bin.

    Bins are `(m - 1, m)` for every integer m in `[mag_start, mag_end]` (row label
    `str(m)`), plus an "overall" row covering `(-inf, mag_end)`.

    Precision-side false positives come from binned estimates matched against all true
    objects. True positives, per-class recall and classification come from binned true
    objects matched against all estimates.

    Args:
        est_cat: Estimated catalog.
        true_cat: Ground-truth catalog.
        mag_start: First bin's upper magnitude edge.
        mag_end: Last bin's upper magnitude edge (also the upper edge of "overall").
        loc_slack: Matching slack passed to DetectionMetrics / ClassificationMetrics.

    Returns:
        DataFrame whose rows are the bin labels and whose columns are the metric names.
        Counts use `.item()` on `n_sources`, so a single scene is assumed.
    """
    scene_metrics_by_mag: Dict[str, Dict[str, Number]] = {}
    mag_mins = [float(m - 1) for m in range(mag_start, mag_end + 1)] + [-np.inf]
    mag_maxes = [float(m) for m in range(mag_start, mag_end + 1)] + [mag_end]

    for mag_min, mag_max in zip(mag_mins, mag_maxes):
        # Bin each catalog once; previously both were recomputed and the first
        # est_cat_binned was a dead store.
        true_cat_binned = true_cat.apply_mag_bin(mag_min, mag_max)
        est_cat_binned = est_cat.apply_mag_bin(mag_min, mag_max)

        # report true counts on each bin
        tcount = true_cat_binned.n_sources.int().item()
        tgcount = true_cat_binned["galaxy_bools"].sum().int().item()
        tscount = tcount - tgcount

        detection_metrics = reporting.DetectionMetrics(loc_slack)
        classification_metrics = reporting.ClassificationMetrics(loc_slack)

        # precision: binned estimates matched against all true objects
        detection_metrics.update(true_cat, est_cat_binned)
        fp = detection_metrics.compute()["fp"].item()
        # reset global state since recall and precision use different cuts.
        detection_metrics.reset()

        # recall: binned true objects matched against all estimates
        detection_metrics.update(true_cat_binned, est_cat)
        detection_dict = detection_metrics.compute()
        tp = detection_dict["tp"].item()
        tp_gal = detection_dict["n_galaxies_detected"].item()
        tp_star = tp - tp_gal
        detection_metrics.reset()

        # classification of binned true objects against all estimates
        classification_metrics.update(true_cat_binned, est_cat)
        classification_result = classification_metrics.compute()
        n_matches = classification_result["n_matches"].item()
        conf_matrix = classification_result["conf_matrix"]
        # NOTE(review): these divisions are not guarded; a row of zeros presumably
        # yields nan (tensor 0/0) unlike the 0.0 fallbacks used for recall below.
        galaxy_acc = conf_matrix[0, 0] / (conf_matrix[0, 0] + conf_matrix[0, 1])
        star_acc = conf_matrix[1, 1] / (conf_matrix[1, 1] + conf_matrix[1, 0])

        mag = "overall" if np.isinf(mag_min) else str(int(mag_max))

        scene_metrics_by_mag[mag] = {
            "tcount": tcount,
            "tgcount": tgcount,
            "tp": tp,
            "fp": fp,
            "recall": tp / tcount if tcount > 0 else 0.0,
            "precision": tp / (tp + fp) if (tp + fp) > 0 else 1.0,
            "tp_gal": tp_gal,
            "recall_gal": tp_gal / tgcount if tgcount > 0 else 0.0,
            "tp_star": tp_star,
            "recall_star": tp_star / tscount if tscount > 0 else 0.0,
            "classif_n_matches": n_matches,
            "classif_acc": classification_result["class_acc"].item(),
            "classif_galaxy_acc": galaxy_acc.item(),
            "classif_star_acc": star_acc.item(),
        }

    # Pivot {bin: {metric: value}} into {metric: {bin: value}} so metrics become columns.
    d: DefaultDict[str, Dict[str, Number]] = defaultdict(dict)
    for mag, scene_metrics_mag in scene_metrics_by_mag.items():
        for measure, value in scene_metrics_mag.items():
            d[measure][mag] = value
    return pd.DataFrame(d)