def testFairnessIndicatorsMultiHead(self):
  """Fairness indicators are computed independently for each model head."""
  temp_eval_export_dir = self._getEvalExportDir()
  _, eval_export_dir = multi_head.simple_multi_head(None, temp_eval_export_dir)

  # (age, language, english_label, chinese_label) per example;
  # other_label is 0.0 throughout.
  example_specs = [
      (3.0, 'english', 1.0, 0.0),
      (3.0, 'chinese', 0.0, 1.0),
      (4.0, 'english', 1.0, 0.0),
      (5.0, 'chinese', 0.0, 1.0),
      (6.0, 'chinese', 0.0, 1.0),
  ]
  examples = [
      self._makeExample(
          age=age,
          language=language,
          english_label=english,
          chinese_label=chinese,
          other_label=0.0)
      for age, language, english, chinese in example_specs
  ]

  fairness_english = post_export_metrics.fairness_indicators(
      target_prediction_keys=['english_head/logistic'],
      labels_key='english_head')
  fairness_chinese = post_export_metrics.fairness_indicators(
      target_prediction_keys=['chinese_head/logistic'],
      labels_key='chinese_head')

  def check_metric_result(got):
    # Expect one unsliced result whose true-positive rate is perfect for
    # both heads at the default 0.5 threshold.
    try:
      self.assertEqual(1, len(got), 'got: %s' % got)
      slice_key, value = got[0]
      self.assertEqual((), slice_key)
      expected = {
          metric_keys.base_key('english_head/logistic/[email protected]'):
              1.0,
          metric_keys.base_key('chinese_head/logistic/[email protected]'):
              1.0,
      }
      self.assertDictElementsAlmostEqual(value, expected)
    except AssertionError as err:
      raise util.BeamAssertException(err)

  self._runTestWithCustomCheck(
      examples,
      eval_export_dir, [fairness_english, fairness_chinese],
      custom_metrics_check=check_metric_result)
def testFairnessIndicatorsCounters(self):
  """The fairness-indicators Beam counter is bumped once per metric added."""
  temp_eval_export_dir = self._getEvalExportDir()
  _, eval_export_dir = multi_head.simple_multi_head(None, temp_eval_export_dir)

  # (age, language, english_label, chinese_label) per example;
  # other_label is 0.0 throughout.
  example_specs = [
      (3.0, 'english', 1.0, 0.0),
      (3.0, 'chinese', 0.0, 1.0),
      (4.0, 'english', 1.0, 0.0),
      (5.0, 'chinese', 0.0, 1.0),
      (6.0, 'chinese', 0.0, 1.0),
  ]
  examples = [
      self._makeExample(
          age=age,
          language=language,
          english_label=english,
          chinese_label=chinese,
          other_label=0.0)
      for age, language, english, chinese in example_specs
  ]

  fairness_english = post_export_metrics.fairness_indicators(
      target_prediction_keys=['english_head/logistic'],
      labels_key='english_head')
  fairness_chinese = post_export_metrics.fairness_indicators(
      target_prediction_keys=['chinese_head/logistic'],
      labels_key='chinese_head')

  def check_metric_counter(result):
    # Two fairness_indicators callbacks were registered, so the counter's
    # committed value must be exactly 2.
    counter_filter = beam.metrics.metric.MetricsFilter().with_name(
        'metric_computed_fairness_indicators')
    counters = result.metrics().query(filter=counter_filter)['counters']
    self.assertEqual(counters[0].committed, 2)

  self._runTestWithCustomCheck(
      examples,
      eval_export_dir, [fairness_english, fairness_chinese],
      custom_result_check=check_metric_counter)
def testFairnessIndicatorsZeroes(self):
  """A single-example input yields a zero positive rate, not an error."""
  temp_eval_export_dir = self._getEvalExportDir()
  _, eval_export_dir = (
      fixed_prediction_estimator_extra_fields
      .simple_fixed_prediction_estimator_extra_fields(
          None, temp_eval_export_dir))
  # Only the first confusion-matrix example is used.
  examples = self.makeConfusionMatrixExamples()[:1]
  fairness_metrics = post_export_metrics.fairness_indicators()

  def check_result(got):  # pylint: disable=invalid-name
    try:
      self.assertEqual(1, len(got), 'got: %s' % got)
      slice_key, value = got[0]
      self.assertEqual((), slice_key)
      expected = {
          metric_keys.base_key('[email protected]'): 0.0,
      }
      self.assertDictElementsAlmostEqual(value, expected)
    except AssertionError as err:
      raise util.BeamAssertException(err)

  self._runTestWithCustomCheck(
      examples,
      eval_export_dir, [fairness_metrics],
      custom_metrics_check=check_result)
def testFairnessIndicatorsDigitsKey(self):
  """Thresholds with varying decimal precision each get distinct keys."""
  temp_eval_export_dir = self._getEvalExportDir()
  _, eval_export_dir = (
      fixed_prediction_estimator_extra_fields
      .simple_fixed_prediction_estimator_extra_fields(
          None, temp_eval_export_dir))
  examples = self.makeConfusionMatrixExamples()
  fairness_metrics = post_export_metrics.fairness_indicators(
      example_weight_key='fixed_float',
      thresholds=[0.5, 0.59, 0.599, 0.5999])

  def check_result(got):  # pylint: disable=invalid-name
    try:
      self.assertEqual(1, len(got), 'got: %s' % got)
      slice_key, value = got[0]
      self.assertEqual((), slice_key)
      # Every threshold must be rendered with its full digit count so the
      # metric keys do not collide.
      for key_name in (
          '[email protected]',
          '[email protected]',
          '[email protected]',
          '[email protected]',
          '[email protected]',
          '[email protected]',
          '[email protected]',
          '[email protected]',
      ):
        self.assertIn(metric_keys.base_key(key_name), value)
    except AssertionError as err:
      raise util.BeamAssertException(err)

  self._runTestWithCustomCheck(
      examples,
      eval_export_dir, [fairness_metrics],
      custom_metrics_check=check_result)
def testFairnessIndicatorsAtThresholdsWeightedWithUncertainty(self):
  """Weighted fairness indicators at explicit thresholds, with CIs enabled."""
  # Turning on confidence intervals makes the metric values t-distributions,
  # hence the assertDictElementsWithTDistributionAlmostEqual check below.
  self.compute_confidence_intervals = True
  temp_eval_export_dir = self._getEvalExportDir()
  _, eval_export_dir = (
      fixed_prediction_estimator_extra_fields
      .simple_fixed_prediction_estimator_extra_fields(
          None, temp_eval_export_dir))
  examples = self.makeConfusionMatrixExamples()
  fairness_metrics = post_export_metrics.fairness_indicators(
      example_weight_key='fixed_float',
      thresholds=[0.0, 0.7, 0.8, 1.0])

  def check_result(got):  # pylint: disable=invalid-name
    try:
      self.assertEqual(1, len(got), 'got: %s' % got)
      slice_key, value = got[0]
      self.assertEqual((), slice_key)
      expected = {
          # Threshold 0.00.
          metric_keys.base_key('[email protected]'):
              6.0 / 7.0,
          metric_keys.base_key('[email protected]'):
              2.0 / 3.0,
          metric_keys.base_key('[email protected]'):
              0.8,
          metric_keys.base_key('[email protected]'):
              1.0 / 3.0,
          metric_keys.base_key('[email protected]'):
              1.0 / 7.0,
          metric_keys.base_key('[email protected]'):
              2.0 / 10.0,
          metric_keys.base_key('[email protected]'):
              2.0 / 8.0,
          metric_keys.base_key('[email protected]'):
              1.0 / 2.0,
          # Threshold 0.70.
          metric_keys.base_key('[email protected]'):
              3.0 / 7.0,
          metric_keys.base_key('[email protected]'):
              2.0 / 3.0,
          metric_keys.base_key('[email protected]'):
              0.5,
          metric_keys.base_key('[email protected]'):
              1.0 / 3.0,
          metric_keys.base_key('[email protected]'):
              4.0 / 7.0,
          metric_keys.base_key('[email protected]'):
              5.0 / 10.0,
          metric_keys.base_key('[email protected]'):
              2.0 / 5.0,
          metric_keys.base_key('[email protected]'):
              4.0 / 5.0,
          # Threshold 0.80.
          metric_keys.base_key('[email protected]'):
              3.0 / 7.0,
          metric_keys.base_key('[email protected]'):
              0,
          metric_keys.base_key('[email protected]'):
              0.3,
          metric_keys.base_key('[email protected]'):
              3.0 / 3.0,
          metric_keys.base_key('[email protected]'):
              4.0 / 7.0,
          metric_keys.base_key('[email protected]'):
              7.0 / 10.0,
          metric_keys.base_key('[email protected]'):
              0,
          metric_keys.base_key('[email protected]'):
              4.0 / 7.0,
          # Threshold 1.00.
          metric_keys.base_key('[email protected]'):
              0,
          metric_keys.base_key('[email protected]'):
              0,
          metric_keys.base_key('[email protected]'):
              0,
          metric_keys.base_key('[email protected]'):
              1,
          metric_keys.base_key('[email protected]'):
              1,
          metric_keys.base_key('[email protected]'):
              1,
          metric_keys.base_key('[email protected]'):
              0,
          metric_keys.base_key('[email protected]'):
              7.0 / 10.0,
      }
      self.assertDictElementsWithTDistributionAlmostEqual(value, expected)
    except AssertionError as err:
      raise util.BeamAssertException(err)

  self._runTestWithCustomCheck(
      examples,
      eval_export_dir, [fairness_metrics],
      custom_metrics_check=check_result)