    def inference_epoch_end(self, outputs: EPOCH_OUTPUT,
                            stage: Stage) -> Dict[str, float]:
        targets_all = aggregate_over_epoch(outputs=outputs, metric="targets")
        subgroup_inf_all = aggregate_over_epoch(outputs=outputs,
                                                metric="subgroup_inf")
        logits_y_all = aggregate_over_epoch(outputs=outputs, metric="logits_y")

        preds_y_all = hard_prediction(logits_y_all)

        dt = em.DataTuple(
            x=pd.DataFrame(
                # x is only a placeholder here; rand_like needs an explicit float
                # dtype because the subgroup tensor is integer-typed.
                torch.rand_like(subgroup_inf_all, dtype=torch.float).detach().cpu().numpy(),
                columns=["x0"],
            ),
            s=pd.DataFrame(subgroup_inf_all.detach().cpu().numpy(),
                           columns=["s"]),
            y=pd.DataFrame(targets_all.detach().cpu().numpy(), columns=["y"]),
        )

        return em.run_metrics(
            predictions=em.Prediction(
                hard=pd.Series(preds_y_all.detach().cpu().numpy())),
            actual=dt,
            metrics=[em.Accuracy(),
                     em.RenyiCorrelation(),
                     em.Yanovich()],
            per_sens_metrics=[em.Accuracy(),
                              em.ProbPos(),
                              em.TPR()],
        )
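
The helpers aggregate_over_epoch and hard_prediction come from the surrounding project rather than from EthicML; a minimal sketch of what they might look like, assuming each step output is a dict of tensors (names and behaviour are assumptions):

# Hypothetical sketch of the helpers used above; the real project may differ.
from typing import Dict, List

import torch
from torch import Tensor


def aggregate_over_epoch(outputs: List[Dict[str, Tensor]], metric: str) -> Tensor:
    """Concatenate one field of the per-step outputs along the batch dimension."""
    return torch.cat([step[metric] for step in outputs], dim=0)


def hard_prediction(logits: Tensor) -> Tensor:
    """Turn logits into hard class predictions (binary if 1-D, otherwise argmax)."""
    return (logits > 0).long() if logits.ndim == 1 else logits.argmax(dim=-1)
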
    def _inference_epoch_end(self, output_results: List[Dict[str, Tensor]],
                             stage: str) -> None:
        all_y = torch.cat([_r["y"] for _r in output_results], 0)
        all_s = torch.cat([_r["s"] for _r in output_results], 0)
        all_preds = torch.cat([_r["preds"] for _r in output_results], 0)

        dt = em.DataTuple(
            x=pd.DataFrame(torch.rand_like(all_s,
                                           dtype=float).detach().cpu().numpy(),
                           columns=["x0"]),
            s=pd.DataFrame(all_s.detach().cpu().numpy(), columns=["s"]),
            y=pd.DataFrame(all_y.detach().cpu().numpy(), columns=["y"]),
        )

        results = em.run_metrics(
            predictions=em.Prediction(
                hard=pd.Series(all_preds.detach().cpu().numpy())),
            actual=dt,
            metrics=[em.Accuracy(),
                     em.RenyiCorrelation(),
                     em.Yanovich()],
            per_sens_metrics=[em.Accuracy(),
                              em.ProbPos(),
                              em.TPR()],
        )

        tm_acc = self.val_acc if stage == "val" else self.test_acc
        acc = tm_acc.compute().item()
        results_dict = {f"{stage}/acc": acc}
        results_dict.update(
            {f"{stage}/{self.target}_{k}": v
             for k, v in results.items()})

        self.log_dict(results_dict)
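
For context, a hook like this is usually invoked from the LightningModule's epoch-end callbacks; a hypothetical wiring, assuming the pre-2.0 PyTorch Lightning hooks:

    # Hypothetical wiring inside the same LightningModule (PyTorch Lightning < 2.0 hooks).
    def validation_epoch_end(self, outputs: List[Dict[str, Tensor]]) -> None:
        self._inference_epoch_end(outputs, stage="val")

    def test_epoch_end(self, outputs: List[Dict[str, Tensor]]) -> None:
        self._inference_epoch_end(outputs, stage="test")
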
    def inference_epoch_end(self, outputs: EPOCH_OUTPUT, stage: Stage) -> Dict[str, float]:
        targets_all = aggregate_over_epoch(outputs=outputs, metric="targets")
        subgroup_inf_all = aggregate_over_epoch(outputs=outputs, metric="subgroup_inf")
        preds_all = aggregate_over_epoch(outputs=outputs, metric="preds")

        mean_preds = preds_all.mean(-1)
        mean_preds_s0 = preds_all[subgroup_inf_all == 0].mean(-1)
        mean_preds_s1 = preds_all[subgroup_inf_all == 1].mean(-1)

        dt = em.DataTuple(
            x=pd.DataFrame(
                torch.rand_like(subgroup_inf_all, dtype=torch.float).detach().cpu().numpy(),
                columns=["x0"],
            ),
            s=pd.DataFrame(subgroup_inf_all.detach().cpu().numpy(), columns=["s"]),
            y=pd.DataFrame(targets_all.detach().cpu().numpy(), columns=["y"]),
        )

        results_dict = em.run_metrics(
            predictions=em.Prediction(hard=pd.Series((preds_all > 0).detach().cpu().numpy())),
            actual=dt,
            metrics=[em.Accuracy(), em.RenyiCorrelation(), em.Yanovich()],
            per_sens_metrics=[em.Accuracy(), em.ProbPos(), em.TPR()],
        )

        results_dict.update(
            {
                "DP_Gap": float((mean_preds_s0 - mean_preds_s1).abs().item()),
                "mean_pred": float(mean_preds.item()),
            }
        )
        return results_dict
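
The DP_Gap entry added at the end is the absolute difference between the mean predictions of the two subgroups, i.e. a demographic-parity gap. A small self-contained illustration with toy predictions (the tensors are made up for the example):

# Toy illustration of the DP_Gap computation above; the data is invented.
import torch

preds = torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0, 0.0])
s = torch.tensor([0, 0, 0, 1, 1, 1])

mean_s0 = preds[s == 0].mean()             # positive rate in group s=0 -> 2/3
mean_s1 = preds[s == 1].mean()             # positive rate in group s=1 -> 1/3
dp_gap = (mean_s0 - mean_s1).abs().item()  # -> 1/3
print(dp_gap)
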
Example #4
def test_get_info(toy_train_val: TrainValPair):
    """Test get info."""
    train, test = toy_train_val
    model: LRCV = LRCV()
    predictions: Prediction = model.run(train, test)
    results = em.run_metrics(predictions, test, [], [])
    assert results["C"] == approx(166.810, abs=0.001)
Example #5
def test_run_metrics(toy_train_val: TrainValPair):
    """Test run metrics."""
    train, test = toy_train_val
    model: InAlgorithm = SVM()
    predictions: Prediction = model.run(train, test)
    results = em.run_metrics(predictions, test, [CV()], [TPR()])
    assert len(results) == 5
    assert results["TPR_sensitive-attr_0"] == approx(0.923, abs=0.001)
    assert results["TPR_sensitive-attr_1"] == approx(1.0, abs=0.001)
    assert results["TPR_sensitive-attr_0-sensitive-attr_1"] == approx(
        0.077, abs=0.001)
    assert results["TPR_sensitive-attr_0/sensitive-attr_1"] == approx(
        0.923, abs=0.001)
    assert results["CV"] == approx(0.630, abs=0.001)
Example #6
def compute_metrics(
    predictions: em.Prediction,
    actual: em.DataTuple,
    s_dim: int,
) -> dict[str, float]:
    """Compute accuracy and fairness metrics and log them.

    Args:
        args: args object
        predictions: predictions in a format that is compatible with EthicML
        actual: labels for the predictions
        model_name: name of the model used
        step: step of training (needed for logging to W&B)
        s_dim: dimension of s
        exp_name: name of the experiment
        save_summary: if True, a summary will be saved to wandb
        use_wandb: whether to use wandb at all
        additional_entries: entries that should go with in the summary
    Returns:
        dictionary with the computed metrics
    """

    predictions._info = {}
    metrics = em.run_metrics(
        predictions,
        actual,
        metrics=[em.Accuracy(),
                 em.TPR(),
                 em.TNR(),
                 em.RenyiCorrelation()],
        per_sens_metrics=[em.Accuracy(),
                          em.ProbPos(),
                          em.TPR(),
                          em.TNR()],
        diffs_and_ratios=s_dim < 4,  # this just gets too much with higher s dim
    )
    # replace the slash; it's causing problems
    metrics = {k.replace("/", "÷"): v for k, v in metrics.items()}
    print_metrics(metrics)
    return metrics
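
print_metrics is a project helper, not part of EthicML; a hypothetical stand-in that would work with the dictionary returned here:

# Hypothetical stand-in for the print_metrics helper used above.
def print_metrics(metrics: dict[str, float]) -> None:
    """Pretty-print a flat metrics dictionary, one metric per line."""
    width = max((len(name) for name in metrics), default=0)
    for name, value in metrics.items():
        print(f"{name:<{width}}  {value:.4f}")
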
Example #7
def compute_metrics(
    cfg: BaseArgs,
    predictions: em.Prediction,
    actual: em.DataTuple,
    exp_name: str,
    model_name: str,
    step: int,
    save_to_csv: Optional[Path] = None,
    results_csv: str = "",
    use_wandb: bool = False,
    additional_entries: Optional[Mapping[str, float]] = None,
) -> Dict[str, float]:
    """Compute accuracy and fairness metrics and log them.

    Args:
        cfg: configuration object
        predictions: predictions in a format that is compatible with EthicML
        actual: labels for the predictions
        exp_name: name of the experiment
        model_name: name of the model used
        step: step of training (needed for logging to W&B)
        save_to_csv: if a path is given, the results are saved to a CSV file
        results_csv: name of the CSV file
        use_wandb: whether to log the metrics to W&B
        additional_entries: extra entries to add to the CSV row and the W&B summary

    Returns:
        dictionary with the computed metrics
    """

    predictions._info = {}
    metrics = em.run_metrics(
        predictions,
        actual,
        metrics=[em.Accuracy(),
                 em.TPR(),
                 em.TNR(),
                 em.RenyiCorrelation()],
        per_sens_metrics=[em.Accuracy(),
                          em.ProbPos(),
                          em.TPR(),
                          em.TNR()],
        diffs_and_ratios=cfg.misc._s_dim < 4,  # this just gets too much with higher s dim
    )
    # replace the slash; it's causing problems
    metrics = {k.replace("/", "÷"): v for k, v in metrics.items()}

    if use_wandb:
        wandb_log(cfg.misc,
                  {f"{k} ({model_name})": v
                   for k, v in metrics.items()},
                  step=step)

    if save_to_csv is not None:
        # full_name = f"{args.dataset}_{exp_name}"
        # exp_name += "_s" if pred_s else "_y"
        # if hasattr(args, "eval_on_recon"):
        #     exp_name += "_on_recons" if args.eval_on_recon else "_on_encodings"

        manual_entries = {
            "seed": str(getattr(cfg.misc, "seed", cfg.misc.data_split_seed)),
            "data": exp_name,
            "method": f'"{model_name}"',
            "wandb_url": (
                str(wandb.run.get_url()) if use_wandb and cfg.misc.use_wandb else "(None)"
            ),
        }

        external = additional_entries or {}

        if results_csv:
            assert isinstance(save_to_csv, Path)
            save_to_csv.mkdir(exist_ok=True, parents=True)
            results = {**metrics, **external}

            results_path = save_to_csv / f"{cfg.data.dataset.name}_{model_name}_{results_csv}"
            values = ",".join(
                list(manual_entries.values()) +
                [str(v) for v in results.values()])
            if not results_path.is_file():
                with results_path.open("w") as f:
                    # ========= header =========
                    f.write(",".join(
                        list(manual_entries) + [str(k)
                                                for k in results]) + "\n")
                    f.write(values + "\n")
            else:
                with results_path.open("a") as f:  # append to existing file
                    f.write(values + "\n")
            log.info(f"Results have been written to {results_path.resolve()}")
        if use_wandb:
            for metric_name, value in metrics.items():
                wandb.run.summary[f"{model_name}_{metric_name}"] = value
            # external metrics are without prefix
            for metric_name, value in external.items():
                wandb.run.summary[metric_name] = value

    log.info(f"Results for {exp_name} ({model_name}):")
    print_metrics({f"{k} ({model_name})": v for k, v in metrics.items()})
    log.info("")  # empty line
    return metrics
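
A hypothetical call site for the function above; every name other than compute_metrics and the EthicML types is an assumption:

# Hypothetical call site; cfg, preds and test_data are assumed to exist elsewhere.
from pathlib import Path

metrics = compute_metrics(
    cfg,
    predictions=preds,            # em.Prediction produced by some model
    actual=test_data,             # em.DataTuple holding the ground-truth s and y
    exp_name="adult",
    model_name="svm",
    step=0,
    save_to_csv=Path("results"),
    results_csv="metrics.csv",
    use_wandb=False,
)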