def run_experiment(self):
  """Runs the Stackelberg simulation and returns its metrics as JSON.

  Builds the environment and agent via `self.build_scenario()`, attaches five
  metrics (per-group and overall social burden, per-group and overall
  accuracy, and the final decision threshold), and runs the simulation for
  `self.num_steps` steps with `self.seed`.

  Returns:
    The output of `core.to_json` for a dict whose single key,
    'metric_results', maps each metric name to the corresponding entry of
    the simulation's results.
  """
  env, agent = self.build_scenario()

  def _single_stratum(unused_history_item):
    # Gives every applicant the same label, so the "overall" metrics are
    # computed over one stratum instead of being split by group. The
    # applicant count is read from env at call time, not at definition time.
    return [1 for _ in range(env.initial_params.num_applicants)]

  # Names and metrics are kept side by side so the two cannot drift apart.
  # The order here is both the construction order and the order in which the
  # metrics are handed to the simulation.
  named_metrics = [
      ('social_burden',
       value_tracking_metrics.AggregatorMetric(
           env=env,
           selection_fn=self.selection_fn_social_burden_eligible_auditor,
           modifier_fn=None,
           stratify_fn=self.stratify_by_group,
           realign_fn=self.realign_history,
           calc_mean=True)),
      ('accuracy',
       error_metrics.AccuracyMetric(
           env=env,
           numerator_fn=self.accuracy_nr_fn,
           denominator_fn=None,
           stratify_fn=self.stratify_by_group,
           realign_fn=self.realign_history)),
      ('overall_accuracy',
       error_metrics.AccuracyMetric(
           env=env,
           numerator_fn=self.accuracy_nr_fn,
           denominator_fn=None,
           stratify_fn=_single_stratum,
           realign_fn=self.realign_history)),
      ('overall_social_burden',
       value_tracking_metrics.AggregatorMetric(
           env=env,
           selection_fn=self.selection_fn_social_burden_eligible_auditor,
           modifier_fn=None,
           stratify_fn=_single_stratum,
           realign_fn=self.realign_history,
           calc_mean=True)),
      ('final_threshold',
       value_tracking_metrics.FinalValueMetric(
           env=env,
           state_var='decision_threshold',
           realign_fn=self.realign_history)),
  ]
  metric_names = [name for name, _ in named_metrics]
  metrics = [metric for _, metric in named_metrics]

  metric_results = run_util.run_stackelberg_simulation(
      env, agent, metrics, self.num_steps, self.seed)

  report = {'metric_results': dict(zip(metric_names, metric_results))}
  return core.to_json(report)
def test_accuracy_metric_can_interact_with_dummy(self):
  """Smoke test: an AccuracyMetric can be run against a DummyEnv.

  There are no assertions; the test passes as long as the simulation runs
  without raising.
  """

  def _action_is_zero(history_item):
    # Each history item unpacks into two parts; only the second (the action)
    # matters here. The result is 1 when the action equals 0, otherwise 0.
    unused_state, action = history_item
    return int(action == 0)

  dummy_env = test_util.DummyEnv()
  dummy_env.set_scalar_reward(rewards.NullReward())

  accuracy_metric = error_metrics.AccuracyMetric(
      env=dummy_env, numerator_fn=_action_is_zero)

  test_util.run_test_simulation(env=dummy_env, metric=accuracy_metric)
def test_stratified_accuracy_metric_correct_sequence_prediction(self):
  """Checks stratified accuracy when strata coincide with (wrong, right).

  The numerator marks an entry of `state.x` as a hit exactly when it equals
  1, and the same `state.x` values are used as the stratum labels. The
  measurement is therefore expected to be 0 for stratum 0 and 1 for
  stratum 1.
  """

  def _hits_where_x_is_one(history_item):
    # One boolean per entry of state.x: True where the entry equals 1.
    return [value == 1 for value in history_item.state.x]

  def _stratum_is_x(history_item):
    # The raw state.x values double as the stratum labels.
    return history_item.state.x

  params = test_util.DummyParams(dim=10)
  env = test_util.DeterministicDummyEnv(params)
  env.set_scalar_reward(rewards.NullReward())

  metric = error_metrics.AccuracyMetric(
      env=env,
      numerator_fn=_hits_where_x_is_one,
      stratify_fn=_stratum_is_x)

  measurement = test_util.run_test_simulation(
      env=env, agent=None, metric=metric)
  logging.info('Measurement: %s.', measurement)

  self.assertEqual(measurement[0], 0)
  self.assertEqual(measurement[1], 1)