def run(self):
    """Runs a single lending experiment end to end.

    Builds the environment and agent via `self.scenario_builder()` and attaches
    these metrics:
      * credit distributions at the first and last step;
      * per-group recall and precision of the agent's lending decisions;
      * the change in the `bank_cash` state variable;
      * when `self.include_cumulative_loans` is truthy, cumulative loans and
        cumulative recall as well.

    It then simulates for `self.num_steps` steps using `self.seed` and gathers
    the environment, agent, experiment parameters and metric results into a
    single report.

    Returns:
      `core.to_json(report, indent=4)` when `self.return_json` is truthy;
      otherwise the report dict itself.
    """
    env, agent = self.scenario_builder()

    # Recall and precision are computed from the same three callbacks.
    # - The prediction is the agent's action.
    # - The positive ground-truth label is an applicant who will NOT default.
    # - Results are stratified by the applicant's group id, as a string.
    def _predicted(history_item):
        return history_item.action

    def _will_repay(history_item):
        return not history_item.state.will_default

    def _group_key(history_item):
        return str(history_item.state.group_id)

    stratified_error_kwargs = dict(
        prediction_fn=_predicted,
        ground_truth_fn=_will_repay,
        stratify_fn=_group_key,
    )

    # Keys are inserted (and metrics constructed) in a fixed order, which
    # determines the order of the downstream results.
    metrics = {}
    metrics['initial_credit_distribution'] = (
        lending_metrics.CreditDistribution(env, step=0))
    metrics['final_credit_distributions'] = (
        lending_metrics.CreditDistribution(env, step=-1))
    metrics['recall'] = error_metrics.RecallMetric(
        env, **stratified_error_kwargs)
    metrics['precision'] = error_metrics.PrecisionMetric(
        env, **stratified_error_kwargs)
    metrics['profit rate'] = value_tracking_metrics.ValueChange(
        env, state_var='bank_cash')

    if self.include_cumulative_loans:
        metrics['cumulative_loans'] = lending_metrics.CumulativeLoans(env)
        metrics['cumulative_recall'] = lending_metrics.CumulativeRecall(env)

    metric_results = run_util.run_simulation(
        env, agent, metrics, self.num_steps, self.seed)

    # Summaries are built after the simulation so that histories and the
    # agent's debug string reflect the completed run.
    environment_summary = {
        'name': env.__class__.__name__,
        'params': env.initial_params,
        'history': env.history,
        'env': env,
    }
    agent_summary = {
        'name': agent.__class__.__name__,
        'params': agent.params,
        'debug_string': agent.debug_string(),
        'threshold_history': agent.group_specific_threshold_history,
        'tpr_targets': agent.target_recall_history,
    }
    report = {
        'environment': environment_summary,
        'agent': agent_summary,
        'experiment_params': self,
        'metric_results': metric_results,
    }

    if not self.return_json:
        return report
    return core.to_json(report, indent=4)
def test_precision_with_zero_denominator(self):
    """Precision is measured as 0 when the prediction is never positive."""

    def _first_state_coordinate(history_item):
        # Each history item unpacks into (state, action); the ground truth
        # is the first coordinate of the state's `x`.
        state, _ = history_item
        return state.x[0]

    params = test_util.DummyParams(dim=1)
    env = test_util.DeterministicDummyEnv(params)
    env.set_scalar_reward(rewards.NullReward())

    # The prediction is constantly 0, so precision has a zero denominator.
    # Every step is assigned to the single stratum `1`.
    metric = error_metrics.PrecisionMetric(
        env=env,
        prediction_fn=lambda x: 0,
        ground_truth_fn=_first_state_coordinate,
        stratify_fn=lambda x: 1,
    )

    expected = {1: 0}
    measurement = test_util.run_test_simulation(
        env=env, agent=None, metric=metric, num_steps=50)
    self.assertEqual(expected, measurement)