Code Example #1
  def testFairnessIndicatorsMultiHead(self):
    temp_eval_export_dir = self._getEvalExportDir()
    _, eval_export_dir = (
        multi_head.simple_multi_head(None, temp_eval_export_dir))

    examples = [
        self._makeExample(
            age=3.0,
            language='english',
            english_label=1.0,
            chinese_label=0.0,
            other_label=0.0),
        self._makeExample(
            age=3.0,
            language='chinese',
            english_label=0.0,
            chinese_label=1.0,
            other_label=0.0),
        self._makeExample(
            age=4.0,
            language='english',
            english_label=1.0,
            chinese_label=0.0,
            other_label=0.0),
        self._makeExample(
            age=5.0,
            language='chinese',
            english_label=0.0,
            chinese_label=1.0,
            other_label=0.0),
        self._makeExample(
            age=6.0,
            language='chinese',
            english_label=0.0,
            chinese_label=1.0,
            other_label=0.0),
    ]
    fairness_english = post_export_metrics.fairness_indicators(
        target_prediction_keys=['english_head/logistic'],
        labels_key='english_head')
    fairness_chinese = post_export_metrics.fairness_indicators(
        target_prediction_keys=['chinese_head/logistic'],
        labels_key='chinese_head')

    def check_metric_result(got):
      try:
        self.assertEqual(1, len(got), 'got: %s' % got)
        (slice_key, value) = got[0]
        self.assertEqual((), slice_key)
        expected_values_dict = {
            metric_keys.base_key('english_head/logistic/[email protected]'): 1.0,
            metric_keys.base_key('chinese_head/logistic/[email protected]'): 1.0,
        }
        self.assertDictElementsAlmostEqual(value, expected_values_dict)
      except AssertionError as err:
        raise util.BeamAssertException(err)

    self._runTestWithCustomCheck(
        examples,
        eval_export_dir, [
            fairness_english,
            fairness_chinese,
        ],
        custom_metrics_check=check_metric_result)
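
Outside the test harness, the same per-head callbacks plug into an ordinary TFMA run. A minimal sketch, assuming the classic EvalSavedModel-era API (tfma.default_eval_shared_model / tfma.run_model_analysis) and hypothetical export and data paths:

import tensorflow_model_analysis as tfma
from tensorflow_model_analysis.post_export_metrics import post_export_metrics

# Hypothetical locations; substitute your own export and data paths.
EVAL_SAVED_MODEL_PATH = '/tmp/multi_head_eval_export'
DATA_LOCATION = '/tmp/examples.tfrecord'

# One fairness_indicators callback per head, mirroring the test above.
eval_shared_model = tfma.default_eval_shared_model(
    eval_saved_model_path=EVAL_SAVED_MODEL_PATH,
    add_metrics_callbacks=[
        post_export_metrics.fairness_indicators(
            target_prediction_keys=['english_head/logistic'],
            labels_key='english_head'),
        post_export_metrics.fairness_indicators(
            target_prediction_keys=['chinese_head/logistic'],
            labels_key='chinese_head'),
    ])

eval_result = tfma.run_model_analysis(
    eval_shared_model=eval_shared_model, data_location=DATA_LOCATION)
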
Code Example #2
  def testFairnessIndicatorsCounters(self):
    temp_eval_export_dir = self._getEvalExportDir()
    _, eval_export_dir = (
        multi_head.simple_multi_head(None, temp_eval_export_dir))

    examples = [
        self._makeExample(
            age=3.0,
            language='english',
            english_label=1.0,
            chinese_label=0.0,
            other_label=0.0),
        self._makeExample(
            age=3.0,
            language='chinese',
            english_label=0.0,
            chinese_label=1.0,
            other_label=0.0),
        self._makeExample(
            age=4.0,
            language='english',
            english_label=1.0,
            chinese_label=0.0,
            other_label=0.0),
        self._makeExample(
            age=5.0,
            language='chinese',
            english_label=0.0,
            chinese_label=1.0,
            other_label=0.0),
        self._makeExample(
            age=6.0,
            language='chinese',
            english_label=0.0,
            chinese_label=1.0,
            other_label=0.0),
    ]
    fairness_english = post_export_metrics.fairness_indicators(
        target_prediction_keys=['english_head/logistic'],
        labels_key='english_head')
    fairness_chinese = post_export_metrics.fairness_indicators(
        target_prediction_keys=['chinese_head/logistic'],
        labels_key='chinese_head')

    def check_metric_counter(result):
      metric_filter = beam.metrics.metric.MetricsFilter().with_name(
          'metric_computed_fairness_indicators')
      actual_metrics_count = result.metrics().query(
          filter=metric_filter)['counters'][0].committed
      self.assertEqual(actual_metrics_count, 2)

    self._runTestWithCustomCheck(
        examples,
        eval_export_dir, [
            fairness_english,
            fairness_chinese,
        ],
        custom_result_check=check_metric_counter)
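
The counter check itself is plain Apache Beam: any pipeline result can be queried for committed counters by name. A self-contained sketch of the same pattern on the DirectRunner, using a hypothetical counter name (elements_seen):

import apache_beam as beam
from apache_beam.metrics.metric import Metrics, MetricsFilter

class CountElements(beam.DoFn):
  """Counts every element it processes via a user-defined Beam counter."""

  def __init__(self):
    self.elements = Metrics.counter(self.__class__, 'elements_seen')

  def process(self, element):
    self.elements.inc()
    yield element

pipeline = beam.Pipeline()
_ = pipeline | beam.Create([1, 2, 3]) | beam.ParDo(CountElements())
result = pipeline.run()
result.wait_until_finish()

# Query the committed value, just as check_metric_counter does above.
metric_filter = MetricsFilter().with_name('elements_seen')
counters = result.metrics().query(filter=metric_filter)['counters']
assert counters[0].committed == 3
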
Code Example #3
  def testFairnessIndicatorsZeroes(self):
    temp_eval_export_dir = self._getEvalExportDir()
    _, eval_export_dir = (
        fixed_prediction_estimator_extra_fields
        .simple_fixed_prediction_estimator_extra_fields(None,
                                                        temp_eval_export_dir))
    examples = self.makeConfusionMatrixExamples()[0:1]
    fairness_metrics = post_export_metrics.fairness_indicators()

    def check_result(got):  # pylint: disable=invalid-name
      try:
        self.assertEqual(1, len(got), 'got: %s' % got)
        (slice_key, value) = got[0]
        self.assertEqual((), slice_key)
        expected_values_dict = {
            metric_keys.base_key('[email protected]'): 0.0,
        }
        self.assertDictElementsAlmostEqual(value, expected_values_dict)
      except AssertionError as err:
        raise util.BeamAssertException(err)

    self._runTestWithCustomCheck(
        examples,
        eval_export_dir, [fairness_metrics],
        custom_metrics_check=check_result)
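
The try/except wrapper seen in every check function is the standard idiom for custom matchers in apache_beam.testing.util: Beam's assert_that machinery signals failure with BeamAssertException, so AssertionErrors raised by unittest helpers are re-raised as that type. A stripped-down sketch of the idiom on its own:

import apache_beam as beam
from apache_beam.testing import util

def check_single_element(got):
  try:
    # Any unittest-style assertion can go here ...
    assert len(got) == 1, 'got: %s' % got
  except AssertionError as err:
    # ... provided failures are re-raised in the exception type that
    # Beam's assert_that expects.
    raise util.BeamAssertException(err)

with beam.Pipeline() as pipeline:
  pcoll = pipeline | beam.Create(['only element'])
  util.assert_that(pcoll, check_single_element)
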
Code Example #4
  def testFairnessIndicatorsDigitsKey(self):
    temp_eval_export_dir = self._getEvalExportDir()
    _, eval_export_dir = (
        fixed_prediction_estimator_extra_fields
        .simple_fixed_prediction_estimator_extra_fields(None,
                                                        temp_eval_export_dir))
    examples = self.makeConfusionMatrixExamples()
    fairness_metrics = post_export_metrics.fairness_indicators(
        example_weight_key='fixed_float', thresholds=[0.5, 0.59, 0.599, 0.5999])

    def check_result(got):  # pylint: disable=invalid-name
      try:
        self.assertEqual(1, len(got), 'got: %s' % got)
        (slice_key, value) = got[0]
        self.assertEqual((), slice_key)
        self.assertIn(metric_keys.base_key('[email protected]'), value)
        self.assertIn(metric_keys.base_key('[email protected]'), value)
        self.assertIn(metric_keys.base_key('[email protected]'), value)
        self.assertIn(metric_keys.base_key('[email protected]'), value)
        self.assertIn(metric_keys.base_key('[email protected]'), value)
        self.assertIn(metric_keys.base_key('[email protected]'), value)
        self.assertIn(metric_keys.base_key('[email protected]'), value)
        self.assertIn(metric_keys.base_key('[email protected]'), value)
      except AssertionError as err:
        raise util.BeamAssertException(err)

    self._runTestWithCustomCheck(
        examples,
        eval_export_dir, [fairness_metrics],
        custom_metrics_check=check_result)
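
When example_weight_key is set, each confusion-matrix entry is a sum of the per-example 'fixed_float' weights rather than a raw count. A hand-rolled NumPy illustration of the arithmetic behind a weighted rate at one threshold (illustrative only, not TFMA's implementation):

import numpy as np

def weighted_positive_rate(predictions, weights, threshold):
  # Fraction of the total example weight whose prediction exceeds
  # the threshold.
  predictions = np.asarray(predictions, dtype=float)
  weights = np.asarray(weights, dtype=float)
  return weights[predictions > threshold].sum() / weights.sum()

# Three examples with weights 1.0, 2.0 and 3.0: only the last two
# predictions clear the 0.5 threshold, so the rate is 5/6.
print(weighted_positive_rate([0.2, 0.6, 0.9], [1.0, 2.0, 3.0], 0.5))
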
Code Example #5
  def testFairnessIndicatorsAtThresholdsWeightedWithUncertainty(self):
    self.compute_confidence_intervals = True
    temp_eval_export_dir = self._getEvalExportDir()
    _, eval_export_dir = (
        fixed_prediction_estimator_extra_fields
        .simple_fixed_prediction_estimator_extra_fields(None,
                                                        temp_eval_export_dir))
    examples = self.makeConfusionMatrixExamples()
    fairness_metrics = post_export_metrics.fairness_indicators(
        example_weight_key='fixed_float', thresholds=[0.0, 0.7, 0.8, 1.0])

    def check_result(got):  # pylint: disable=invalid-name
      try:
        self.assertEqual(1, len(got), 'got: %s' % got)
        (slice_key, value) = got[0]
        self.assertEqual((), slice_key)
        expected_values_dict = {
            metric_keys.base_key('[email protected]'): 6.0 / 7.0,
            metric_keys.base_key('[email protected]'): 2.0 / 3.0,
            metric_keys.base_key('[email protected]'): 0.8,
            metric_keys.base_key('[email protected]'): 1.0 / 3.0,
            metric_keys.base_key('[email protected]'): 1.0 / 7.0,
            metric_keys.base_key('[email protected]'): 2.0 / 10.0,
            metric_keys.base_key('[email protected]'): 2.0 / 8.0,
            metric_keys.base_key('[email protected]'): 1.0 / 2.0,
            metric_keys.base_key('[email protected]'): 3.0 / 7.0,
            metric_keys.base_key('[email protected]'): 2.0 / 3.0,
            metric_keys.base_key('[email protected]'): 0.5,
            metric_keys.base_key('[email protected]'): 1.0 / 3.0,
            metric_keys.base_key('[email protected]'): 4.0 / 7.0,
            metric_keys.base_key('[email protected]'): 5.0 / 10.0,
            metric_keys.base_key('[email protected]'): 2.0 / 5.0,
            metric_keys.base_key('[email protected]'): 4.0 / 5.0,
            metric_keys.base_key('[email protected]'): 3.0 / 7.0,
            metric_keys.base_key('[email protected]'): 0,
            metric_keys.base_key('[email protected]'): 0.3,
            metric_keys.base_key('[email protected]'): 3.0 / 3.0,
            metric_keys.base_key('[email protected]'): 4.0 / 7.0,
            metric_keys.base_key('[email protected]'): 7.0 / 10.0,
            metric_keys.base_key('[email protected]'): 0,
            metric_keys.base_key('[email protected]'): 4.0 / 7.0,
            metric_keys.base_key('[email protected]'): 0,
            metric_keys.base_key('[email protected]'): 0,
            metric_keys.base_key('[email protected]'): 0,
            metric_keys.base_key('[email protected]'): 1,
            metric_keys.base_key('[email protected]'): 1,
            metric_keys.base_key('[email protected]'): 1,
            metric_keys.base_key('[email protected]'): 0,
            metric_keys.base_key('[email protected]'): 7.0 / 10.0,
        }
        self.assertDictElementsWithTDistributionAlmostEqual(
            value, expected_values_dict)
      except AssertionError as err:
        raise util.BeamAssertException(err)

    self._runTestWithCustomCheck(
        examples,
        eval_export_dir, [fairness_metrics],
        custom_metrics_check=check_result)
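
Setting self.compute_confidence_intervals = True makes the harness evaluate each metric over bootstrap resamples, and assertDictElementsWithTDistributionAlmostEqual then compares the point estimates of the resulting t-distribution values. A minimal SciPy sketch of that kind of interval (an assumption about the statistics involved, not TFMA's actual code):

import numpy as np
from scipy import stats

def t_confidence_interval(samples, confidence=0.95):
  # Two-sided t-distribution confidence interval for the mean of a
  # set of bootstrap replicates of a metric.
  samples = np.asarray(samples, dtype=float)
  mean = samples.mean()
  sem = stats.sem(samples)    # standard error of the mean
  dof = len(samples) - 1      # degrees of freedom
  half_width = sem * stats.t.ppf((1.0 + confidence) / 2.0, dof)
  return mean - half_width, mean + half_width

# e.g. twenty noisy bootstrap replicates of a rate near 6/7.
replicates = np.random.default_rng(0).normal(6.0 / 7.0, 0.02, size=20)
print(t_confidence_interval(replicates))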