Exemplo n.º 1
0
    def test_confusion_metric_correct_for_atomic_prediction_rule(self):
        """Confusion counts are correct when the prediction is a single value."""

        def _label_from_history(history_item):
            current_state, _ = history_item
            return current_state.x[0]

        environment = test_util.DeterministicDummyEnv(
            test_util.DummyParams(dim=1))
        environment.set_scalar_reward(rewards.NullReward())
        # The constant prediction of 1 turns every ground-truth 0 into a false
        # positive and every ground-truth 1 into a true positive.
        metric = error_metrics.ConfusionMetric(
            env=environment,
            prediction_fn=lambda _unused: 1,
            ground_truth_fn=_label_from_history,
            stratify_fn=lambda _unused: 1,
        )

        measurement = test_util.run_test_simulation(
            env=environment, agent=None, metric=metric)

        logging.info("Measurement: %s.", measurement)

        # Stratification always returns 1, so every count lands in bucket 1
        # and bucket 0 never appears.
        self.assertEqual(measurement[1].fp, 5)
        self.assertEqual(measurement[1].tp, 5)
        self.assertNotIn(0, measurement)
Exemplo n.º 2
0
    def test_confusion_metric_correct_for_sequence_prediction_rule(self):
        """Confusion counts are correct when predictions are sequences."""
        dim = 10

        def _labels_from_history(history_item):
            current_state, _ = history_item
            return current_state.x

        environment = test_util.DeterministicDummyEnv(
            test_util.DummyParams(dim=dim))
        environment.set_scalar_reward(rewards.NullReward())
        # Predict a constant all-ones vector; stratify every element into
        # the same bucket (1).
        metric = error_metrics.ConfusionMetric(
            env=environment,
            prediction_fn=lambda _unused: [1 for _ in range(dim)],
            ground_truth_fn=_labels_from_history,
            stratify_fn=lambda _unused: [1 for _ in range(dim)],
        )

        measurement = test_util.run_test_simulation(
            env=environment, agent=None, metric=metric)

        logging.info("Measurement: %s.", measurement)

        # With dim=10 elements per step, counts are 10x the atomic case.
        self.assertEqual(measurement[1].fp, 50)
        self.assertEqual(measurement[1].tp, 50)
        self.assertNotIn(0, measurement)
Exemplo n.º 3
0
def _setup_test_simulation(dim=1, calc_mean=False, modifier_fn=_modifier_fn):
    """Builds a deterministic dummy env and an AggregatorMetric attached to it.

    Args:
      dim: Dimensionality passed through to DummyParams.
      calc_mean: Whether the aggregator reports a mean instead of a sum.
      modifier_fn: Value-modifier hook forwarded to the AggregatorMetric.

    Returns:
      A (env, metric) tuple ready for run_test_simulation.
    """
    environment = test_util.DeterministicDummyEnv(
        test_util.DummyParams(dim=dim))
    environment.set_scalar_reward(rewards.NullReward())
    aggregator = value_tracking_metrics.AggregatorMetric(
        env=environment,
        modifier_fn=modifier_fn,
        selection_fn=_selection_fn,
        stratify_fn=_stratify_fn,
        calc_mean=calc_mean,
    )
    return environment, aggregator
Exemplo n.º 4
0
    def test_summing_metric_give_correct_sum_dummy_env(self):
        """SummingMetric accumulates the expected total on the dummy env."""
        environment = test_util.DeterministicDummyEnv(
            test_util.DummyParams(dim=1))
        environment.set_scalar_reward(rewards.NullReward())

        summing_metric = value_tracking_metrics.SummingMetric(
            env=environment, selection_fn=_selection_fn)
        measurement = test_util.run_test_simulation(
            environment, agent=None, metric=summing_metric, seed=0)

        # The deterministic env yields a total of 5 over the default run.
        self.assertTrue(np.all(np.equal(measurement, [5])))
  def test_recall_with_zero_denominator(self):
    """Recall degrades gracefully to 0 when there are no positives."""
    environment = test_util.DeterministicDummyEnv(test_util.DummyParams(dim=1))
    environment.set_scalar_reward(rewards.NullReward())
    # Ground truth is constant 0, so tp + fn == 0 and recall's denominator
    # is zero; the metric should report 0 rather than fail.
    recall_metric = error_metrics.RecallMetric(
        env=environment,
        prediction_fn=lambda _unused: 0,
        ground_truth_fn=lambda _unused: 0,
        stratify_fn=lambda _unused: 1)

    measurement = test_util.run_test_simulation(
        env=environment, agent=None, metric=recall_metric, num_steps=50)
    self.assertEqual({1: 0}, measurement)
  def test_precision_with_zero_denominator(self):
    """Precision degrades gracefully to 0 when nothing is predicted positive."""

    def _label_from_history(history_item):
      current_state, _ = history_item
      return current_state.x[0]

    environment = test_util.DeterministicDummyEnv(test_util.DummyParams(dim=1))
    environment.set_scalar_reward(rewards.NullReward())
    # Predicting constant 0 means tp + fp == 0, so precision's denominator
    # is zero; the metric should report 0 rather than fail.
    precision_metric = error_metrics.PrecisionMetric(
        env=environment,
        prediction_fn=lambda _unused: 0,
        ground_truth_fn=_label_from_history,
        stratify_fn=lambda _unused: 1)

    measurement = test_util.run_test_simulation(
        env=environment, agent=None, metric=precision_metric, num_steps=50)

    self.assertEqual({1: 0}, measurement)
  def test_recall_metric_correct_for_atomic_prediction_rule(self):
    """Recall is 1 when a constant-1 prediction captures every positive."""

    def _label_from_history(history_item):
      current_state, _ = history_item
      return current_state.x[0]

    environment = test_util.DeterministicDummyEnv(test_util.DummyParams(dim=1))
    environment.set_scalar_reward(rewards.NullReward())
    # Always predicting 1 means no positive is ever missed (fn == 0).
    recall_metric = error_metrics.RecallMetric(
        env=environment,
        prediction_fn=lambda _unused: 1,
        ground_truth_fn=_label_from_history,
        stratify_fn=lambda _unused: 1)

    measurement = test_util.run_test_simulation(
        env=environment, agent=None, metric=recall_metric, num_steps=50)

    logging.info('Measurement: %s.', measurement)
    self.assertEqual({1: 1}, measurement)
  def test_stratified_accuracy_metric_correct_sequence_prediction(self):
    """Check correctness when stratifying into (wrong, right) bins."""

    def _correctness_per_element(history_item):
      # Element i is "correct" exactly when state.x[i] == 1.
      return [value == 1 for value in history_item.state.x]

    def _bucket_per_element(history_item):
      # Stratify each element by its own ground-truth value.
      return history_item.state.x

    environment = test_util.DeterministicDummyEnv(
        test_util.DummyParams(dim=10))
    environment.set_scalar_reward(rewards.NullReward())
    accuracy_metric = error_metrics.AccuracyMetric(
        env=environment,
        numerator_fn=_correctness_per_element,
        stratify_fn=_bucket_per_element)

    measurement = test_util.run_test_simulation(
        env=environment, agent=None, metric=accuracy_metric)

    logging.info('Measurement: %s.', measurement)

    # Bucket 0 holds only wrong elements; bucket 1 only right ones.
    self.assertEqual(measurement[0], 0)
    self.assertEqual(measurement[1], 1)
  def test_cost_metric_correct_for_atomic_prediction_rule(self):
    """Costed confusion totals the cost matrix over the confusion counts."""

    def _label_from_history(history_item):
      current_state, _ = history_item
      return current_state.x[0]

    environment = test_util.DeterministicDummyEnv(test_util.DummyParams(dim=1))
    environment.set_scalar_reward(rewards.NullReward())
    # Constant-1 predictions yield 5 tps (cost +1 each) and 5 fps
    # (cost -2 each), for a total of 5 - 10 = -5.
    cost_metric = error_metrics.CostedConfusionMetric(
        env=environment,
        prediction_fn=lambda _unused: 1,
        ground_truth_fn=_label_from_history,
        stratify_fn=lambda _unused: 1,
        cost_matrix=params.CostMatrix(tp=1, fp=-2, tn=-1, fn=-1))
    measurement = test_util.run_test_simulation(
        env=environment, agent=None, metric=cost_metric)

    logging.info('Cost measurement: %s.', measurement)

    self.assertEqual(measurement[1], -5)
    self.assertNotIn(0, measurement)
Exemplo n.º 10
0
    def test_stratified_accuracy_metric_correct_atomic_prediction(self):
        """Check correctness when stratifying into (wrong, right) bins."""

        def _is_correct(history_item):
            current_state, _ = history_item
            # 1 when the state's value is 1, otherwise 0.
            return int(current_state.x[0] == 1)

        def _bucket(history_item):
            current_state, _ = history_item
            return current_state.x[0]

        environment = test_util.DeterministicDummyEnv()
        environment.set_scalar_reward(rewards.NullReward())
        accuracy_metric = error_metrics.AccuracyMetric(
            env=environment,
            numerator_fn=_is_correct,
            stratify_fn=_bucket)

        measurement = test_util.run_test_simulation(
            env=environment, agent=None, metric=accuracy_metric)

        logging.info("Measurement: %s.", measurement)

        # Bucket 0 collects only incorrect steps; bucket 1 only correct ones.
        self.assertEqual(measurement[0], 0)
        self.assertEqual(measurement[1], 1)