def _inference_epoch_end(self, output_results: List[Dict[str, Tensor]], stage: str) -> None:
    """Compute epoch-level metrics from per-batch outputs and log them.

    Args:
        output_results: One dict per batch; each must hold tensors under the
            keys ``"y"``, ``"s"`` and ``"preds"``. They are concatenated
            along dimension 0.
        stage: Prefix for the logged metric names. ``"val"`` selects
            ``self.val_acc``; any other value selects ``self.test_acc``.
    """

    def _gather(key: str) -> Tensor:
        # Stitch the per-batch tensors for ``key`` into one epoch-wide tensor.
        return torch.cat([batch[key] for batch in output_results], 0)

    def _as_numpy(tensor: Tensor):
        return tensor.detach().cpu().numpy()

    targets = _gather("y")
    sens = _gather("s")
    predictions = _gather("preds")

    # The "x0" column is filled with random values shaped like ``s``;
    # presumably a placeholder because the DataTuple needs features -- TODO confirm.
    features_df = pd.DataFrame(
        _as_numpy(torch.rand_like(sens, dtype=float)), columns=["x0"]
    )
    sens_df = pd.DataFrame(_as_numpy(sens), columns=["s"])
    targets_df = pd.DataFrame(_as_numpy(targets), columns=["y"])
    data_tuple = em.DataTuple(x=features_df, s=sens_df, y=targets_df)

    hard_preds = em.Prediction(hard=pd.Series(_as_numpy(predictions)))
    metric_values = em.run_metrics(
        predictions=hard_preds,
        actual=data_tuple,
        metrics=[em.Accuracy(), em.RenyiCorrelation(), em.Yanovich()],
        per_sens_metrics=[em.Accuracy(), em.ProbPos(), em.TPR()],
    )

    if stage == "val":
        accuracy_metric = self.val_acc
    else:
        accuracy_metric = self.test_acc
    accuracy = accuracy_metric.compute().item()

    # Accuracy goes first; the remaining metrics are namespaced by stage and target.
    to_log = {
        f"{stage}/acc": accuracy,
        **{f"{stage}/{self.target}_{name}": value for name, value in metric_values.items()},
    }
    self.log_dict(to_log)
def inference_epoch_end(self, outputs: EPOCH_OUTPUT, stage: Stage) -> Dict[str, float]:
    """Aggregate the epoch's outputs and compute evaluation metrics.

    Args:
        outputs: Per-step outputs for the epoch; the entries ``"targets"``,
            ``"subgroup_inf"`` and ``"logits_y"`` are aggregated from it.
        stage: Stage identifier. Not used by this implementation.

    Returns:
        The result of ``em.run_metrics`` on the hard predictions derived
        from the aggregated logits.
    """
    targets_all = aggregate_over_epoch(outputs=outputs, metric="targets")
    subgroup_inf_all = aggregate_over_epoch(outputs=outputs, metric="subgroup_inf")
    logits_y_all = aggregate_over_epoch(outputs=outputs, metric="logits_y")
    preds_y_all = hard_prediction(logits_y_all)

    # Request a float dtype explicitly: ``torch.rand_like`` cannot sample into
    # an integer/bool tensor, which is what subgroup labels are likely to be
    # (assumption -- confirm against the data pipeline).
    # The "x0" column holds random values shaped like the subgroup tensor;
    # presumably a placeholder because the DataTuple needs features -- TODO confirm.
    dummy_x = torch.rand_like(subgroup_inf_all, dtype=torch.float)

    dt = em.DataTuple(
        x=pd.DataFrame(dummy_x.detach().cpu().numpy(), columns=["x0"]),
        s=pd.DataFrame(subgroup_inf_all.detach().cpu().numpy(), columns=["s"]),
        y=pd.DataFrame(targets_all.detach().cpu().numpy(), columns=["y"]),
    )

    return em.run_metrics(
        predictions=em.Prediction(hard=pd.Series(preds_y_all.detach().cpu().numpy())),
        actual=dt,
        metrics=[em.Accuracy(), em.RenyiCorrelation(), em.Yanovich()],
        per_sens_metrics=[em.Accuracy(), em.ProbPos(), em.TPR()],
    )
def inference_epoch_end(self, outputs: EPOCH_OUTPUT, stage: Stage) -> Dict[str, float]:
    """Compute evaluation metrics plus prediction-mean statistics for an epoch.

    Args:
        outputs: Per-step outputs for the epoch; the entries ``"targets"``,
            ``"subgroup_inf"`` and ``"preds"`` are aggregated from it.
        stage: Stage identifier. Not used by this implementation.

    Returns:
        The dict returned by ``em.run_metrics`` with two extra entries:
        ``"DP_Gap"``, the absolute difference between the mean prediction
        where ``subgroup_inf == 0`` and where ``subgroup_inf == 1``, and
        ``"mean_pred"``, the mean prediction.
    """

    def _as_numpy(tensor):
        return tensor.detach().cpu().numpy()

    targets = aggregate_over_epoch(outputs=outputs, metric="targets")
    subgroups = aggregate_over_epoch(outputs=outputs, metric="subgroup_inf")
    raw_preds = aggregate_over_epoch(outputs=outputs, metric="preds")

    # Means over the last dimension, overall and per subgroup.
    overall_mean = raw_preds.mean(-1)
    group0_mean = raw_preds[subgroups == 0].mean(-1)
    group1_mean = raw_preds[subgroups == 1].mean(-1)

    # The "x0" column holds random values shaped like the subgroup tensor;
    # presumably a placeholder because the DataTuple needs features -- TODO confirm.
    random_features = torch.rand_like(subgroups, dtype=torch.float)
    data_tuple = em.DataTuple(
        x=pd.DataFrame(_as_numpy(random_features), columns=["x0"]),
        s=pd.DataFrame(_as_numpy(subgroups), columns=["s"]),
        y=pd.DataFrame(_as_numpy(targets), columns=["y"]),
    )

    # Hard predictions are obtained by thresholding the raw predictions at zero.
    hard_preds = em.Prediction(hard=pd.Series(_as_numpy(raw_preds > 0)))
    results_dict = em.run_metrics(
        predictions=hard_preds,
        actual=data_tuple,
        metrics=[em.Accuracy(), em.RenyiCorrelation(), em.Yanovich()],
        per_sens_metrics=[em.Accuracy(), em.ProbPos(), em.TPR()],
    )

    # Evaluate both scalars before touching the dict.
    dp_gap = float((group0_mean - group1_mean).abs().item())
    mean_pred = float(overall_mean.item())
    results_dict["DP_Gap"] = dp_gap
    results_dict["mean_pred"] = mean_pred
    return results_dict