def _serialize_metrics(metrics: Tuple[slicer.SliceKeyType, Dict[Text, Any]],
                       post_export_metrics: List[types.AddMetricsCallbackType]
                      ) -> bytes:
  """Converts the given slice metrics into serialized proto MetricsForSlice.

  Args:
    metrics: The slice metrics as a (slice key, metrics dict) pair.
    post_export_metrics: A list of metric callbacks. This should be the same
      list as the one passed to tfma.Evaluate().

  Returns:
    The serialized proto MetricsForSlice.

  Raises:
    TypeError: If the type of the feature value in slice key cannot be
      recognized.
  """
  result = metrics_for_slice_pb2.MetricsForSlice()
  slice_key, slice_metrics = metrics

  # Convert the slice key. This is needed on both the error path and the
  # normal path, so do it once up front (the original code built a second,
  # throwaway proto for the error path and shadowed the `metrics` parameter).
  result.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))

  if metric_keys.ERROR_METRIC in slice_metrics:
    # An upstream error was recorded for this slice: emit only the debug
    # message instead of the regular metrics.
    tf.logging.warning('Error for slice: %s with error message: %s ',
                       slice_key, slice_metrics[metric_keys.ERROR_METRIC])
    result.metrics[metric_keys.ERROR_METRIC].debug_message = slice_metrics[
        metric_keys.ERROR_METRIC]
    return result.SerializeToString()

  # Convert the slice metrics.
  convert_slice_metrics(slice_metrics, post_export_metrics, result)
  return result.SerializeToString()
def check_result(got):  # pylint: disable=invalid-name
  """Beam assert callback verifying the AUPRC metric and its serialization.

  Relies on `self`, `expected_values_dict`, `auc_metric` and `util` from the
  enclosing test scope.
  """
  try:
    # Exactly one result (the overall, empty slice) is expected.
    self.assertEqual(1, len(got), 'got: %s' % got)
    (slice_key, value) = got[0]
    self.assertEqual((), slice_key)
    self.assertDictElementsAlmostEqual(value, expected_values_dict)

    # Check serialization too.
    # Note that we can't just make this a dict, since proto maps
    # allow uninitialized key access, i.e. they act like defaultdicts.
    output_metrics = metrics_for_slice_pb2.MetricsForSlice().metrics
    auc_metric.populate_stats_and_pop(value, output_metrics)
    self.assertProtoEquals(
        """
        bounded_value {
          lower_bound { value: 0.6999999 }
          upper_bound { value: 0.7777776 }
          value { value: 0.7407472 }
        }
        """, output_metrics[metric_keys.AUPRC])
  except AssertionError as err:
    # Re-raise through Beam's testing machinery so pipeline asserts fail.
    raise util.BeamAssertException(err)
def revert_slice_keys_for_transformed_features(
    metrics: List[metrics_for_slice_pb2.MetricsForSlice],
    statistics: statistics_pb2.DatasetFeatureStatisticsList):
  """Rewrites slice keys so they refer to raw rather than transformed features.

  Args:
    metrics: List of slice metrics protos.
    statistics: Data statistics used to configure AutoSliceKeyExtractor.

  Returns:
    List of slice metrics protos where transformed features are mapped back to
    raw features in the slice keys.
  """
  prefix = auto_slice_key_extractor.TRANSFORMED_FEATURE_PREFIX
  boundaries = auto_slice_key_extractor._get_quantile_boundaries(statistics)  # pylint: disable=protected-access
  reverted = []
  for original in metrics:
    updated = metrics_for_slice_pb2.MetricsForSlice()
    updated.CopyFrom(original)
    for key in updated.slice_key.single_slice_keys:
      # Only keys produced by the auto slicer carry the transformed prefix.
      if not key.column.startswith(prefix):
        continue
      raw_feature = key.column[len(prefix):]
      key.column = raw_feature
      # Map the bucket index back to its (start, end) quantile boundary and
      # render it as a human-readable range string.
      (start, end) = auto_slice_key_extractor._get_bucket_boundary(  # pylint: disable=protected-access
          getattr(key, key.WhichOneof('kind')), boundaries[raw_feature])
      key.bytes_value = _format_boundary(start, end)
    reverted.append(updated)
  return reverted
def revert_slice_keys_for_transformed_features(
    metrics: List[metrics_for_slice_pb2.MetricsForSlice],
    statistics: statistics_pb2.DatasetFeatureStatisticsList):
  """Maps transformed-feature slice keys back to their raw features.

  Args:
    metrics: List of slice metrics protos.
    statistics: Data statistics used to configure AutoSliceKeyExtractor.

  Returns:
    List of slice metrics protos where transformed features are mapped back to
    raw features in the slice keys.
  """
  boundaries = auto_slice_key_extractor.get_quantile_boundaries(statistics)
  reverted = []
  for original in metrics:
    updated = metrics_for_slice_pb2.MetricsForSlice()
    updated.CopyFrom(original)
    for key in updated.slice_key.single_slice_keys:
      # Resolve the raw feature name/value from whichever oneof field is set.
      transformed_value = getattr(key, key.WhichOneof('kind'))
      name, value = get_raw_feature(key.column, transformed_value, boundaries)
      key.column = name
      key.bytes_value = value
    reverted.append(updated)
  return reverted
def _serialize_metrics(metrics, post_export_metrics):
  """Serializes slice metrics into a MetricsForSlice proto byte string.

  Args:
    metrics: The slice metrics as a (slice key, metrics dict) pair.
    post_export_metrics: A list of metric callbacks. This should be the same
      list as the one passed to tfma.Evaluate().

  Returns:
    The serialized proto MetricsForSlice.

  Raises:
    TypeError: If the type of the feature value in slice key cannot be
      recognized.
  """
  slice_key, slice_metrics = metrics
  result = metrics_for_slice_pb2.MetricsForSlice()
  # Fill in the slice key, then delegate the metric conversion.
  result.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))
  _convert_slice_metrics(slice_metrics, post_export_metrics, result)
  return result.SerializeToString()
def testConvertMetricsProto(self):
  """convert_metrics_proto_to_dict flattens a MetricsForSlice to a dict."""
  metrics_proto = text_format.Parse(
      """
      slice_key {}
      metric_keys_and_values {
        key { name: "metric_name" }
        value: { double_value { value: 1.0 } }
        confidence_interval {
          lower_bound: { double_value: { value: 0.5 } }
          upper_bound: { double_value: { value: 1.5 } }
        }
      }""", metrics_for_slice_pb2.MetricsForSlice())

  got = util.convert_metrics_proto_to_dict(metrics_proto)

  # Result is ((slice key), {model_name: {output_name: {metric: value}}});
  # model and output names are empty strings here.
  expected_metric = {
      'boundedValue': {
          'lowerBound': 0.5,
          'upperBound': 1.5,
          'value': 1.0
      }
  }
  self.assertEqual(got, ((), {'': {'': {'metric_name': expected_metric}}}))
def check_result(got):
  """Beam assert callback checking the fairness AUC metrics per slice.

  Relies on `self`, `fairness_auc` and `util` from the enclosing test scope.
  """
  try:
    # Three slices are expected; each carries the same metric values.
    self.assertEqual(3, len(got), 'got: %s' % got)
    for _, value in got:
      expected_value = {
          # Subgroup
          'post_export_metrics/fairness/auc/subgroup_auc/fixed_int':
              0.5,
          'post_export_metrics/fairness/auc/subgroup_auc/fixed_int/lower_bound':
              0.25,
          'post_export_metrics/fairness/auc/subgroup_auc/fixed_int/upper_bound':
              0.75,
          # BNSP
          'post_export_metrics/fairness/auc/bnsp_auc/fixed_int':
              0.5,
          'post_export_metrics/fairness/auc/bnsp_auc/fixed_int/lower_bound':
              0.25,
          'post_export_metrics/fairness/auc/bnsp_auc/fixed_int/upper_bound':
              0.75,
          # BPSN
          'post_export_metrics/fairness/auc/bpsn_auc/fixed_int':
              0.5,
          'post_export_metrics/fairness/auc/bpsn_auc/fixed_int/lower_bound':
              0.25,
          'post_export_metrics/fairness/auc/bpsn_auc/fixed_int/upper_bound':
              0.75,
          'average_loss':
              0.5,
      }
      self.assertDictElementsAlmostEqual(value, expected_value)

    # Check serialization too.
    output_metrics = metrics_for_slice_pb2.MetricsForSlice().metrics
    for slice_key, value in got:
      fairness_auc.populate_stats_and_pop(slice_key, value, output_metrics)
    # Each AUC variant should serialize to the same bounded value.
    for key in (
        metric_keys.FAIRNESS_AUC + '/subgroup_auc/fixed_int',
        metric_keys.FAIRNESS_AUC + '/bpsn_auc/fixed_int',
        metric_keys.FAIRNESS_AUC + '/bnsp_auc/fixed_int',
    ):
      self.assertProtoEquals(
          """
          bounded_value {
            lower_bound { value: 0.2500001 }
            upper_bound { value: 0.7499999 }
            value { value: 0.5 }
            methodology: RIEMANN_SUM
          }
          """, output_metrics[key])
  except AssertionError as err:
    # Re-raise through Beam's testing machinery so pipeline asserts fail.
    raise util.BeamAssertException(err)
def testUncertaintyValuedMetrics(self):
  """ValueWithConfidenceInterval metrics serialize as bounded_value protos."""
  slice_key = _make_slice_key()
  slice_metrics = {
      'one_dim':
          types.ValueWithConfidenceInterval(2.0, 1.0, 3.0),
      # NaN values must round-trip through serialization unchanged.
      'nans':
          types.ValueWithConfidenceInterval(
              float('nan'), float('nan'), float('nan')),
  }
  expected_metrics_for_slice = text_format.Parse(
      """
      slice_key {}
      metrics {
        key: "one_dim"
        value {
          bounded_value {
            value { value: 2.0 }
            lower_bound { value: 1.0 }
            upper_bound { value: 3.0 }
            methodology: POISSON_BOOTSTRAP
          }
        }
      }
      metrics {
        key: "nans"
        value {
          bounded_value {
            value { value: nan }
            lower_bound { value: nan }
            upper_bound { value: nan }
            methodology: POISSON_BOOTSTRAP
          }
        }
      }
      """, metrics_for_slice_pb2.MetricsForSlice())
  got = metrics_and_plots_evaluator._serialize_metrics(
      (slice_key, slice_metrics), [])
  self.assertProtoEquals(
      expected_metrics_for_slice,
      metrics_for_slice_pb2.MetricsForSlice.FromString(got))
def testTensorValuedMetrics(self):
  """Numpy-array metrics serialize as array_value protos (flattened data)."""
  slice_key = _make_slice_key()
  slice_metrics = {
      'one_dim':
          np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32),
      'two_dims':
          np.array([['two', 'dims', 'test'], ['TWO', 'DIMS', 'TEST']]),
      'three_dims':
          np.array([[[100, 200, 300]], [[500, 600, 700]]], dtype=np.int64),
  }
  # Each array is expected to serialize with its dtype, shape, and values
  # flattened in row-major order.
  expected_metrics_for_slice = text_format.Parse(
      """
      slice_key {}
      metrics {
        key: "one_dim"
        value {
          array_value {
            data_type: FLOAT32
            shape: 4
            float32_values: [1.0, 2.0, 3.0, 4.0]
          }
        }
      }
      metrics {
        key: "two_dims"
        value {
          array_value {
            data_type: BYTES
            shape: [2, 3]
            bytes_values: ["two", "dims", "test", "TWO", "DIMS", "TEST"]
          }
        }
      }
      metrics {
        key: "three_dims"
        value {
          array_value {
            data_type: INT64
            shape: [2, 1, 3]
            int64_values: [100, 200, 300, 500, 600, 700]
          }
        }
      }
      """, metrics_for_slice_pb2.MetricsForSlice())
  got = metrics_and_plots_evaluator._serialize_metrics(
      (slice_key, slice_metrics), [])
  self.assertProtoEquals(
      expected_metrics_for_slice,
      metrics_for_slice_pb2.MetricsForSlice.FromString(got))
def testConvertSliceMetricsToProtoEmptyMetrics(self):
  """An ERROR_METRIC entry converts to a proto with only a debug message."""
  key = _make_slice_key('age', 5, 'language', 'english', 'price', 0.3)
  metrics_dict = {metric_keys.ERROR_METRIC: 'error_message'}

  # The expected proto carries the slice key plus the error debug message;
  # no regular metrics are converted.
  expected = metrics_for_slice_pb2.MetricsForSlice()
  expected.slice_key.CopyFrom(slicer.serialize_slice_key(key))
  expected.metrics[metric_keys.ERROR_METRIC].debug_message = 'error_message'

  callbacks = [post_export_metrics.auc(), post_export_metrics.auc(curve='PR')]
  actual = (
      metrics_plots_and_validations_writer.convert_slice_metrics_to_proto(
          (key, metrics_dict), callbacks))
  self.assertProtoEquals(expected, actual)
def testSerializeMetrics_emptyMetrics(self):
  """A slice carrying only an ERROR_METRIC serializes to a debug message."""
  key = _make_slice_key('age', 5, 'language', 'english', 'price', 0.3)
  metrics_dict = {metric_keys.ERROR_METRIC: 'error_message'}

  # Build the expected proto: slice key plus the error debug message.
  expected = metrics_for_slice_pb2.MetricsForSlice()
  expected.slice_key.CopyFrom(slicer.serialize_slice_key(key))
  expected.metrics[metric_keys.ERROR_METRIC].debug_message = 'error_message'

  callbacks = [post_export_metrics.auc(), post_export_metrics.auc(curve='PR')]
  serialized = metrics_and_plots_serialization._serialize_metrics(
      (key, metrics_dict), callbacks)
  self.assertProtoEquals(
      expected, metrics_for_slice_pb2.MetricsForSlice.FromString(serialized))
def testConvertSliceMetricsToProtoStringMetrics(self):
  """Bytes-valued metrics are stored verbatim as bytes_value entries."""
  metric_values = {
      'valid_ascii': b'test string',
      'valid_unicode': b'\xF0\x9F\x90\x84',  # U+1F404, Cow
      'invalid_unicode': b'\xE2\x28\xA1',
  }

  # Every entry — including invalid UTF-8 — should pass through unchanged.
  expected = metrics_for_slice_pb2.MetricsForSlice()
  expected.slice_key.SetInParent()
  for name, raw in metric_values.items():
    expected.metrics[name].bytes_value = raw

  got = metrics_plots_and_validations_writer.convert_slice_metrics_to_proto(
      (_make_slice_key(), metric_values), [])
  self.assertProtoEquals(expected, got)
def testStringMetrics(self):
  """Bytes-valued metrics round-trip through _serialize_metrics."""
  metric_values = {
      'valid_ascii': b'test string',
      'valid_unicode': b'\xF0\x9F\x90\x84',  # U+1F404, Cow
      'invalid_unicode': b'\xE2\x28\xA1',
  }

  # Every entry — including invalid UTF-8 — should pass through unchanged.
  expected = metrics_for_slice_pb2.MetricsForSlice()
  expected.slice_key.SetInParent()
  for name, raw in metric_values.items():
    expected.metrics[name].bytes_value = raw

  serialized = serialization._serialize_metrics(
      (_make_slice_key(), metric_values), [])
  self.assertProtoEquals(
      expected, metrics_for_slice_pb2.MetricsForSlice.FromString(serialized))
def testSerializeMetrics(self):
  """MetricKey-keyed metrics serialize into metric_keys_and_values entries."""
  slice_key = _make_slice_key('age', 5, 'language', 'english', 'price', 0.3)
  slice_metrics = {
      metric_types.MetricKey(name='accuracy', output_name='output_name'): 0.8
  }
  expected_metrics_for_slice = text_format.Parse(
      """
      slice_key {
        single_slice_keys { column: 'age' int64_value: 5 }
        single_slice_keys { column: 'language' bytes_value: 'english' }
        single_slice_keys { column: 'price' float_value: 0.3 }
      }
      metric_keys_and_values {
        key { name: "accuracy" output_name: "output_name" }
        value { double_value { value: 0.8 } }
      }""", metrics_for_slice_pb2.MetricsForSlice())
  # No post-export metric callbacks are involved, hence None.
  got = metrics_and_plots_serialization._serialize_metrics(
      (slice_key, slice_metrics), None)
  self.assertProtoEquals(
      expected_metrics_for_slice,
      metrics_for_slice_pb2.MetricsForSlice.FromString(got))
def check_result(got):  # pylint: disable=invalid-name
  """Beam assert callback verifying precision/recall@k values + serialization.

  Relies on `self`, `precision_recall_metric` and `util` from the enclosing
  test scope.
  """
  try:
    self.assertEqual(1, len(got), 'got: %s' % got)
    (slice_key, value) = got[0]
    self.assertEqual((), slice_key)
    self.assertIn(metric_keys.PRECISION_RECALL_AT_K, value)
    table = value[metric_keys.PRECISION_RECALL_AT_K]
    # Table columns are: cutoff, precision@cutoff, recall@cutoff.
    cutoffs = table[:, 0].tolist()
    precision = table[:, 1].tolist()
    recall = table[:, 2].tolist()

    self.assertEqual(cutoffs, [0, 1, 2, 3, 5])
    self.assertSequenceAlmostEqual(
        precision, [4.0 / 9.0, 2.0 / 3.0, 2.0 / 6.0, 4.0 / 9.0, 4.0 / 9.0])
    self.assertSequenceAlmostEqual(
        recall, [4.0 / 4.0, 2.0 / 4.0, 2.0 / 4.0, 4.0 / 4.0, 4.0 / 4.0])

    # Check serialization too.
    # Note that we can't just make this a dict, since proto maps
    # allow uninitialized key access, i.e. they act like defaultdicts.
    output_metrics = metrics_for_slice_pb2.MetricsForSlice().metrics
    precision_recall_metric.populate_stats_and_pop(value, output_metrics)
    self.assertProtoEquals(
        """
        value_at_cutoffs { values { cutoff: 0 value: 0.44444444 } }
        value_at_cutoffs { values { cutoff: 1 value: 0.66666666 } }
        value_at_cutoffs { values { cutoff: 2 value: 0.33333333 } }
        value_at_cutoffs { values { cutoff: 3 value: 0.44444444 } }
        value_at_cutoffs { values { cutoff: 5 value: 0.44444444 } }
        """, output_metrics[metric_keys.PRECISION_AT_K])
    self.assertProtoEquals(
        """
        value_at_cutoffs { values { cutoff: 0 value: 1.0 } }
        value_at_cutoffs { values { cutoff: 1 value: 0.5 } }
        value_at_cutoffs { values { cutoff: 2 value: 0.5 } }
        value_at_cutoffs { values { cutoff: 3 value: 1.0 } }
        value_at_cutoffs { values { cutoff: 5 value: 1.0 } }
        """, output_metrics[metric_keys.RECALL_AT_K])
  except AssertionError as err:
    # Re-raise through Beam's testing machinery so pipeline asserts fail.
    raise util.BeamAssertException(err)
def testSerializeMetrics(self):
  """AUC/AUPRC values plus their bounds serialize as bounded_value protos."""
  slice_key = _make_slice_key('age', 5, 'language', 'english', 'price', 0.3)
  # The lower/upper bound entries should be folded into the bounded_value of
  # the corresponding base metric rather than appearing as separate metrics.
  slice_metrics = {
      'accuracy': 0.8,
      _full_key(metric_keys.AUPRC): 0.1,
      _full_key(metric_keys.lower_bound(metric_keys.AUPRC)): 0.05,
      _full_key(metric_keys.upper_bound(metric_keys.AUPRC)): 0.17,
      _full_key(metric_keys.AUC): 0.2,
      _full_key(metric_keys.lower_bound(metric_keys.AUC)): 0.1,
      _full_key(metric_keys.upper_bound(metric_keys.AUC)): 0.3
  }
  # string.Template is used because the metric key names are computed.
  expected_metrics_for_slice = text_format.Parse(
      string.Template("""
      slice_key {
        single_slice_keys { column: 'age' int64_value: 5 }
        single_slice_keys { column: 'language' bytes_value: 'english' }
        single_slice_keys { column: 'price' float_value: 0.3 }
      }
      metrics {
        key: "accuracy"
        value { double_value { value: 0.8 } }
      }
      metrics {
        key: "$auc"
        value {
          bounded_value {
            lower_bound { value: 0.1 }
            upper_bound { value: 0.3 }
            value { value: 0.2 }
            methodology: RIEMANN_SUM
          }
        }
      }
      metrics {
        key: "$auprc"
        value {
          bounded_value {
            lower_bound { value: 0.05 }
            upper_bound { value: 0.17 }
            value { value: 0.1 }
            methodology: RIEMANN_SUM
          }
        }
      }""").substitute(
          auc=_full_key(metric_keys.AUC), auprc=_full_key(metric_keys.AUPRC)),
      metrics_for_slice_pb2.MetricsForSlice())
  got = metrics_and_plots_evaluator._serialize_metrics(
      (slice_key, slice_metrics),
      [post_export_metrics.auc(),
       post_export_metrics.auc(curve='PR')])
  self.assertProtoEquals(
      expected_metrics_for_slice,
      metrics_for_slice_pb2.MetricsForSlice.FromString(got))
def testSerializeConfusionMatrices(self):
  """Matrix/threshold pairs serialize as confusion_matrix_at_thresholds."""
  slice_key = _make_slice_key()
  thresholds = [0.25, 0.75, 1.00]
  # Each row is: [fn, tn, fp, tp, precision, recall] for one threshold.
  matrices = [[0.0, 1.0, 0.0, 2.0, 1.0, 1.0],
              [1.0, 1.0, 0.0, 1.0, 1.0, 0.5],
              [2.0, 1.0, 0.0, 0.0, float('nan'), 0.0]]
  slice_metrics = {
      _full_key(metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_MATRICES):
          matrices,
      _full_key(metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_THRESHOLDS):
          thresholds,
  }
  expected_metrics_for_slice = text_format.Parse(
      """
      slice_key {}
      metrics {
        key: "post_export_metrics/confusion_matrix_at_thresholds"
        value {
          confusion_matrix_at_thresholds {
            matrices {
              threshold: 0.25
              false_negatives: 0.0
              true_negatives: 1.0
              false_positives: 0.0
              true_positives: 2.0
              precision: 1.0
              recall: 1.0
              bounded_false_negatives { value { value: 0.0 } }
              bounded_true_negatives { value { value: 1.0 } }
              bounded_true_positives { value { value: 2.0 } }
              bounded_false_positives { value { value: 0.0 } }
              bounded_precision { value { value: 1.0 } }
              bounded_recall { value { value: 1.0 } }
            }
            matrices {
              threshold: 0.75
              false_negatives: 1.0
              true_negatives: 1.0
              false_positives: 0.0
              true_positives: 1.0
              precision: 1.0
              recall: 0.5
              bounded_false_negatives { value { value: 1.0 } }
              bounded_true_negatives { value { value: 1.0 } }
              bounded_true_positives { value { value: 1.0 } }
              bounded_false_positives { value { value: 0.0 } }
              bounded_precision { value { value: 1.0 } }
              bounded_recall { value { value: 0.5 } }
            }
            matrices {
              threshold: 1.00
              false_negatives: 2.0
              true_negatives: 1.0
              false_positives: 0.0
              true_positives: 0.0
              precision: nan
              recall: 0.0
              bounded_false_negatives { value { value: 2.0 } }
              bounded_true_negatives { value { value: 1.0 } }
              bounded_true_positives { value { value: 0.0 } }
              bounded_false_positives { value { value: 0.0 } }
              bounded_precision { value { value: nan } }
              bounded_recall { value { value: 0.0 } }
            }
          }
        }
      }
      """, metrics_for_slice_pb2.MetricsForSlice())
  got = metrics_and_plots_evaluator._serialize_metrics(
      (slice_key, slice_metrics),
      [post_export_metrics.confusion_matrix_at_thresholds(thresholds)])
  self.assertProtoEquals(
      expected_metrics_for_slice,
      metrics_for_slice_pb2.MetricsForSlice.FromString(got))
def testLoadMetricsAsDataframe_DoubleValueOnly(self):
  """load_metrics_as_dataframe turns a metrics tfrecord into a DataFrame."""
  metrics_for_slice = text_format.Parse(
      """
      slice_key {
        single_slice_keys { column: "age" float_value: 38.0 }
        single_slice_keys { column: "sex" bytes_value: "Female" }
      }
      metric_keys_and_values {
        key { name: "mean_absolute_error" example_weighted { } }
        value { double_value { value: 0.1 } }
      }
      metric_keys_and_values {
        key { name: "mean_squared_logarithmic_error" example_weighted { } }
        value { double_value { value: 0.02 } }
      }
      """, metrics_for_slice_pb2.MetricsForSlice())
  # Write the proto out as a TFRecord file so the loader can read it back.
  path = os.path.join(absltest.get_default_test_tmpdir(), 'metrics.tfrecord')
  with tf.io.TFRecordWriter(path) as writer:
    writer.write(metrics_for_slice.SerializeToString())
  df = experimental.load_metrics_as_dataframe(path)
  # One row per metric; all-empty columns are dropped by default.
  expected = pd.DataFrame({
      'slice': [
          'age = 38.0; sex = b\'Female\'', 'age = 38.0; sex = b\'Female\''
      ],
      'name': ['mean_absolute_error', 'mean_squared_logarithmic_error'],
      'model_name': ['', ''],
      'output_name': ['', ''],
      'example_weighted': [False, False],
      'is_diff': [False, False],
      'display_value': [str(0.1), str(0.02)],
      'metric_value': [
          metrics_for_slice_pb2.MetricValue(double_value={'value': 0.1}),
          metrics_for_slice_pb2.MetricValue(double_value={'value': 0.02})
      ],
  })
  pd.testing.assert_frame_equal(expected, df)

  # Include empty column.
  df = experimental.load_metrics_as_dataframe(
      path, include_empty_columns=True)
  expected = pd.DataFrame({
      'slice': [
          'age = 38.0; sex = b\'Female\'', 'age = 38.0; sex = b\'Female\''
      ],
      'name': ['mean_absolute_error', 'mean_squared_logarithmic_error'],
      'model_name': ['', ''],
      'output_name': ['', ''],
      'sub_key': [None, None],
      'aggregation_type': [None, None],
      'example_weighted': [False, False],
      'is_diff': [False, False],
      'display_value': [str(0.1), str(0.02)],
      'metric_value': [
          metrics_for_slice_pb2.MetricValue(double_value={'value': 0.1}),
          metrics_for_slice_pb2.MetricValue(double_value={'value': 0.02})
      ],
      'confidence_interval': [None, None],
  })
  pd.testing.assert_frame_equal(expected, df)
def testWriteMetricsAndPlots(self):
  """End-to-end: evaluate two examples and check written metrics/plots files."""
  metrics_file = os.path.join(self._getTempDir(), 'metrics')
  plots_file = os.path.join(self._getTempDir(), 'plots')
  temp_eval_export_dir = os.path.join(self._getTempDir(), 'eval_export_dir')

  # Export a trivial fixed-prediction estimator to evaluate against.
  _, eval_export_dir = (
      fixed_prediction_estimator.simple_fixed_prediction_estimator(
          None, temp_eval_export_dir))
  eval_config = config.EvalConfig(
      model_specs=[config.ModelSpec()],
      options=config.Options(
          disabled_outputs={'values': ['eval_config.json']}))
  eval_shared_model = self.createTestEvalSharedModel(
      eval_saved_model_path=eval_export_dir,
      add_metrics_callbacks=[
          post_export_metrics.example_count(),
          post_export_metrics.calibration_plot_and_prediction_histogram(
              num_buckets=2)
      ])
  extractors = [
      predict_extractor.PredictExtractor(eval_shared_model),
      slice_key_extractor.SliceKeyExtractor()
  ]
  evaluators = [
      metrics_and_plots_evaluator.MetricsAndPlotsEvaluator(eval_shared_model)
  ]
  output_paths = {
      constants.METRICS_KEY: metrics_file,
      constants.PLOTS_KEY: plots_file
  }
  writers = [
      metrics_plots_and_validations_writer.MetricsPlotsAndValidationsWriter(
          output_paths, eval_shared_model.add_metrics_callbacks)
  ]

  with beam.Pipeline() as pipeline:
    example1 = self._makeExample(prediction=0.0, label=1.0)
    example2 = self._makeExample(prediction=1.0, label=1.0)

    # pylint: disable=no-value-for-parameter
    _ = (
        pipeline
        | 'Create' >> beam.Create([
            example1.SerializeToString(),
            example2.SerializeToString(),
        ])
        | 'ExtractEvaluateAndWriteResults' >>
        model_eval_lib.ExtractEvaluateAndWriteResults(
            eval_config=eval_config,
            eval_shared_model=eval_shared_model,
            extractors=extractors,
            evaluators=evaluators,
            writers=writers))
    # pylint: enable=no-value-for-parameter

  expected_metrics_for_slice = text_format.Parse(
      """
      slice_key {}
      metrics {
        key: "average_loss"
        value { double_value { value: 0.5 } }
      }
      metrics {
        key: "post_export_metrics/example_count"
        value { double_value { value: 2.0 } }
      }
      """, metrics_for_slice_pb2.MetricsForSlice())

  # Read the metrics file back and verify its single record.
  metric_records = []
  for record in tf.compat.v1.python_io.tf_record_iterator(metrics_file):
    metric_records.append(
        metrics_for_slice_pb2.MetricsForSlice.FromString(record))
  self.assertEqual(1, len(metric_records), 'metrics: %s' % metric_records)
  self.assertProtoEquals(expected_metrics_for_slice, metric_records[0])

  expected_plots_for_slice = text_format.Parse(
      """
      slice_key {}
      plots {
        key: "post_export_metrics"
        value {
          calibration_histogram_buckets {
            buckets {
              lower_threshold_inclusive: -inf
              num_weighted_examples {}
              total_weighted_label {}
              total_weighted_refined_prediction {}
            }
            buckets {
              upper_threshold_exclusive: 0.5
              num_weighted_examples { value: 1.0 }
              total_weighted_label { value: 1.0 }
              total_weighted_refined_prediction {}
            }
            buckets {
              lower_threshold_inclusive: 0.5
              upper_threshold_exclusive: 1.0
              num_weighted_examples { }
              total_weighted_label {}
              total_weighted_refined_prediction {}
            }
            buckets {
              lower_threshold_inclusive: 1.0
              upper_threshold_exclusive: inf
              num_weighted_examples { value: 1.0 }
              total_weighted_label { value: 1.0 }
              total_weighted_refined_prediction { value: 1.0 }
            }
          }
        }
      }
      """, metrics_for_slice_pb2.PlotsForSlice())

  # Read the plots file back and verify its single record.
  plot_records = []
  for record in tf.compat.v1.python_io.tf_record_iterator(plots_file):
    plot_records.append(
        metrics_for_slice_pb2.PlotsForSlice.FromString(record))
  self.assertEqual(1, len(plot_records), 'plots: %s' % plot_records)
  self.assertProtoEquals(expected_plots_for_slice, plot_records[0])
def test_find_top_slices(self):
  """find_top_slices picks significant slices for LOWER/HIGHER comparisons."""
  # Dataset statistics: 1500 examples, a string 'country' feature and an
  # integer 'age' feature with three equal-sized quantile buckets.
  statistics = text_format.Parse(
      """
      datasets{
        num_examples: 1500
        features {
          path { step: 'country' }
          type: STRING
          string_stats { unique: 10 }
        }
        features {
          path { step: 'age' }
          type: INT
          num_stats {
            common_stats {
              num_non_missing: 1500
              min_num_values: 1
              max_num_values: 1
            }
            min: 1
            max: 18
            histograms {
              buckets { low_value: 1 high_value: 6.0 sample_count: 500 }
              buckets { low_value: 6.0 high_value: 12.0 sample_count: 500 }
              buckets { low_value: 12.0 high_value: 18.0 sample_count: 500 }
              type: QUANTILES
            }
          }
        }
      }
      """, statistics_pb2.DatasetFeatureStatisticsList())
  # Slice metrics: the overall slice (baseline, accuracy 0.8) followed by
  # three transformed_age buckets (0.4, 0.79, 0.9) and country:USA (0.9).
  metrics = [
      text_format.Parse(
          """
          slice_key { }
          metric_keys_and_values {
            key { name: "accuracy" }
            value {
              bounded_value {
                value { value: 0.8 }
                lower_bound { value: 0.5737843 }
                upper_bound { value: 1.0262157 }
                methodology: POISSON_BOOTSTRAP
              }
              confidence_interval {
                lower_bound { value: 0.5737843 }
                upper_bound { value: 1.0262157 }
                t_distribution_value {
                  sample_mean { value: 0.8 }
                  sample_standard_deviation { value: 0.1 }
                  sample_degrees_of_freedom { value: 9 }
                  unsampled_value { value: 0.8 }
                }
              }
            }
          }
          metric_keys_and_values {
            key { name: "example_count" }
            value {
              bounded_value {
                value { value: 1500 }
                lower_bound { value: 1500 }
                upper_bound { value: 1500 }
                methodology: POISSON_BOOTSTRAP
              }
              confidence_interval {
                lower_bound { value: 1500 }
                upper_bound { value: 1500 }
                t_distribution_value {
                  sample_mean { value: 1500 }
                  sample_standard_deviation { value: 0 }
                  sample_degrees_of_freedom { value: 9 }
                  unsampled_value { value: 1500 }
                }
              }
            }
          }
          """, metrics_for_slice_pb2.MetricsForSlice()),
      text_format.Parse(
          """
          slice_key {
            single_slice_keys { column: 'transformed_age' int64_value: 1 }
          }
          metric_keys_and_values {
            key { name: "accuracy" }
            value {
              bounded_value {
                value { value: 0.4 }
                lower_bound { value: 0.3737843 }
                upper_bound { value: 0.6262157 }
                methodology: POISSON_BOOTSTRAP
              }
              confidence_interval {
                lower_bound { value: 0.3737843 }
                upper_bound { value: 0.6262157 }
                t_distribution_value {
                  sample_mean { value: 0.4 }
                  sample_standard_deviation { value: 0.1 }
                  sample_degrees_of_freedom { value: 9 }
                  unsampled_value { value: 0.4 }
                }
              }
            }
          }
          metric_keys_and_values {
            key { name: "example_count" }
            value {
              bounded_value {
                value { value: 500 }
                lower_bound { value: 500 }
                upper_bound { value: 500 }
                methodology: POISSON_BOOTSTRAP
              }
              confidence_interval {
                lower_bound { value: 500 }
                upper_bound { value: 500 }
                t_distribution_value {
                  sample_mean { value: 500 }
                  sample_standard_deviation { value: 0 }
                  sample_degrees_of_freedom { value: 9 }
                  unsampled_value { value: 500 }
                }
              }
            }
          }
          """, metrics_for_slice_pb2.MetricsForSlice()),
      text_format.Parse(
          """
          slice_key {
            single_slice_keys { column: 'transformed_age' int64_value: 2 }
          }
          metric_keys_and_values {
            key { name: "accuracy" }
            value {
              bounded_value {
                value { value: 0.79 }
                lower_bound { value: 0.5737843 }
                upper_bound { value: 1.0262157 }
                methodology: POISSON_BOOTSTRAP
              }
              confidence_interval {
                lower_bound { value: 0.5737843 }
                upper_bound { value: 1.0262157 }
                t_distribution_value {
                  sample_mean { value: 0.79 }
                  sample_standard_deviation { value: 0.1 }
                  sample_degrees_of_freedom { value: 9 }
                  unsampled_value { value: 0.79 }
                }
              }
            }
          }
          metric_keys_and_values {
            key { name: "example_count" }
            value {
              bounded_value {
                value { value: 500 }
                lower_bound { value: 500 }
                upper_bound { value: 500 }
                methodology: POISSON_BOOTSTRAP
              }
              confidence_interval {
                lower_bound { value: 500 }
                upper_bound { value: 500 }
                t_distribution_value {
                  sample_mean { value: 500 }
                  sample_standard_deviation { value: 0 }
                  sample_degrees_of_freedom { value: 9 }
                  unsampled_value { value: 500}
                }
              }
            }
          }
          """, metrics_for_slice_pb2.MetricsForSlice()),
      text_format.Parse(
          """
          slice_key {
            single_slice_keys { column: 'transformed_age' int64_value: 3 }
          }
          metric_keys_and_values {
            key { name: "accuracy" }
            value {
              bounded_value {
                value { value: 0.9 }
                lower_bound { value: 0.5737843 }
                upper_bound { value: 1.0262157 }
                methodology: POISSON_BOOTSTRAP
              }
              confidence_interval {
                lower_bound { value: 0.5737843 }
                upper_bound { value: 1.0262157 }
                t_distribution_value {
                  sample_mean { value: 0.9 }
                  sample_standard_deviation { value: 0.1 }
                  sample_degrees_of_freedom { value: 9 }
                  unsampled_value { value: 0.9 }
                }
              }
            }
          }
          metric_keys_and_values {
            key { name: "example_count" }
            value {
              bounded_value {
                value { value: 500 }
                lower_bound { value: 500 }
                upper_bound { value: 500 }
                methodology: POISSON_BOOTSTRAP
              }
              confidence_interval {
                lower_bound { value: 500 }
                upper_bound { value: 500 }
                t_distribution_value {
                  sample_mean { value: 500 }
                  sample_standard_deviation { value: 0 }
                  sample_degrees_of_freedom { value: 9 }
                  unsampled_value { value: 500}
                }
              }
            }
          }
          """, metrics_for_slice_pb2.MetricsForSlice()),
      text_format.Parse(
          """
          slice_key {
            single_slice_keys { column: 'country' bytes_value: 'USA' }
          }
          metric_keys_and_values {
            key { name: "accuracy" }
            value {
              bounded_value {
                value { value: 0.9 }
                lower_bound { value: 0.5737843 }
                upper_bound { value: 1.0262157 }
                methodology: POISSON_BOOTSTRAP
              }
              confidence_interval {
                lower_bound { value: 0.5737843 }
                upper_bound { value: 1.0262157 }
                t_distribution_value {
                  sample_mean { value: 0.9 }
                  sample_standard_deviation { value: 0.1 }
                  sample_degrees_of_freedom { value: 9 }
                  unsampled_value { value: 0.9 }
                }
              }
            }
          }
          metric_keys_and_values {
            key { name: "example_count" }
            value {
              bounded_value {
                value { value: 500 }
                lower_bound { value: 500 }
                upper_bound { value: 500 }
                methodology: POISSON_BOOTSTRAP
              }
              confidence_interval {
                lower_bound { value: 500 }
                upper_bound { value: 500 }
                t_distribution_value {
                  sample_mean { value: 500 }
                  sample_standard_deviation { value: 0 }
                  sample_degrees_of_freedom { value: 9 }
                  unsampled_value { value: 500}
                }
              }
            }
          }
          """, metrics_for_slice_pb2.MetricsForSlice())
  ]
  # LOWER: only the transformed_age=1 bucket (accuracy 0.4) is significantly
  # worse than the 0.8 baseline; its slice key is reverted to the raw
  # 'age' feature with quantile-boundary formatting.
  self.assertCountEqual(
      auto_slicing_util.find_top_slices(
          metrics,
          metric_key='accuracy',
          statistics=statistics,
          comparison_type='LOWER'), [
              auto_slicing_util.SliceComparisonResult(
                  slice_key=u'age:[1.0, 6.0]',
                  num_examples=500.0,
                  slice_metric=0.4,
                  base_metric=0.8,
                  pvalue=0.0,
                  effect_size=4.0)
          ])
  # HIGHER: both 0.9-accuracy slices are significantly better than baseline.
  self.assertCountEqual(
      auto_slicing_util.find_top_slices(
          metrics,
          metric_key='accuracy',
          statistics=statistics,
          comparison_type='HIGHER'), [
              auto_slicing_util.SliceComparisonResult(
                  slice_key=u'age:[12.0, 18.0]',
                  num_examples=500.0,
                  slice_metric=0.9,
                  base_metric=0.8,
                  pvalue=7.356017854191938e-70,
                  effect_size=0.9999999999999996),
              auto_slicing_util.SliceComparisonResult(
                  slice_key=u'country:USA',
                  num_examples=500.0,
                  slice_metric=0.9,
                  base_metric=0.8,
                  pvalue=7.356017854191938e-70,
                  effect_size=0.9999999999999996)
          ])
def testConvertSliceMetricsToProtoFromLegacyStrings(self):
  """Legacy string-keyed AUC/AUPRC metrics fold into bounded_value protos."""
  slice_key = _make_slice_key('age', 5, 'language', 'english', 'price', 0.3)
  # The lower/upper bound keys should be merged into the bounded_value of
  # their base metric rather than appearing as separate metric entries.
  slice_metrics = {
      'accuracy': 0.8,
      metric_keys.AUPRC: 0.1,
      metric_keys.lower_bound_key(metric_keys.AUPRC): 0.05,
      metric_keys.upper_bound_key(metric_keys.AUPRC): 0.17,
      metric_keys.AUC: 0.2,
      metric_keys.lower_bound_key(metric_keys.AUC): 0.1,
      metric_keys.upper_bound_key(metric_keys.AUC): 0.3
  }
  # string.Template is used because the metric key names are computed.
  expected_metrics_for_slice = text_format.Parse(
      string.Template("""
      slice_key {
        single_slice_keys { column: 'age' int64_value: 5 }
        single_slice_keys { column: 'language' bytes_value: 'english' }
        single_slice_keys { column: 'price' float_value: 0.3 }
      }
      metrics {
        key: "accuracy"
        value { double_value { value: 0.8 } }
      }
      metrics {
        key: "$auc"
        value {
          bounded_value {
            lower_bound { value: 0.1 }
            upper_bound { value: 0.3 }
            value { value: 0.2 }
            methodology: RIEMANN_SUM
          }
        }
      }
      metrics {
        key: "$auprc"
        value {
          bounded_value {
            lower_bound { value: 0.05 }
            upper_bound { value: 0.17 }
            value { value: 0.1 }
            methodology: RIEMANN_SUM
          }
        }
      }""").substitute(auc=metric_keys.AUC, auprc=metric_keys.AUPRC),
      metrics_for_slice_pb2.MetricsForSlice())
  got = metrics_plots_and_validations_writer.convert_slice_metrics_to_proto(
      (slice_key, slice_metrics),
      [post_export_metrics.auc(),
       post_export_metrics.auc(curve='PR')])
  self.assertProtoEquals(expected_metrics_for_slice, got)
def testSerializeDeserializeToFile(self):
  """Round-trips metrics, plots and the eval config through file writers.

  Writes serialized MetricsForSlice/PlotsForSlice protos and the eval
  config with the default writers inside a Beam pipeline, then reloads
  them from disk and checks equality with the originals.
  """
  metrics_slice_key = _make_slice_key(b'fruit', b'pear', b'animal', b'duck')
  metrics_for_slice = text_format.Parse(
      """
      slice_key {
        single_slice_keys { column: "fruit" bytes_value: "pear" }
        single_slice_keys { column: "animal" bytes_value: "duck" }
      }
      metrics { key: "accuracy" value { double_value { value: 0.8 } } }
      metrics { key: "example_weight" value { double_value { value: 10.0 } } }
      metrics {
        key: "auc"
        value {
          bounded_value {
            lower_bound { value: 0.1 }
            upper_bound { value: 0.3 }
            value { value: 0.2 }
          }
        }
      }
      metrics {
        key: "auprc"
        value {
          bounded_value {
            lower_bound { value: 0.05 }
            upper_bound { value: 0.17 }
            value { value: 0.1 }
          }
        }
      }""", metrics_for_slice_pb2.MetricsForSlice())
  plots_for_slice = text_format.Parse(
      """
      slice_key {
        single_slice_keys { column: "fruit" bytes_value: "peach" }
        single_slice_keys { column: "animal" bytes_value: "cow" }
      }
      plots {
        key: ''
        value {
          calibration_histogram_buckets {
            buckets {
              lower_threshold_inclusive: -inf
              upper_threshold_exclusive: 0.0
              num_weighted_examples { value: 0.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.0 }
            }
            buckets {
              lower_threshold_inclusive: 0.0
              upper_threshold_exclusive: 0.5
              num_weighted_examples { value: 1.0 }
              total_weighted_label { value: 1.0 }
              total_weighted_refined_prediction { value: 0.3 }
            }
            buckets {
              lower_threshold_inclusive: 0.5
              upper_threshold_exclusive: 1.0
              num_weighted_examples { value: 1.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.7 }
            }
            buckets {
              lower_threshold_inclusive: 1.0
              upper_threshold_exclusive: inf
              num_weighted_examples { value: 0.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.0 }
            }
          }
        }
      }""", metrics_for_slice_pb2.PlotsForSlice())
  plots_slice_key = _make_slice_key(b'fruit', b'peach', b'animal', b'cow')
  eval_config = model_eval_lib.EvalConfig(
      model_location='/path/to/model',
      data_location='/path/to/data',
      slice_spec=[
          slicer.SingleSliceSpec(
              features=[('age', 5), ('gender', 'f')], columns=['country']),
          slicer.SingleSliceSpec(
              features=[('age', 6), ('gender', 'm')], columns=['interest'])
      ],
      example_weight_metric_key='key')
  output_path = self._getTempDir()
  with beam.Pipeline() as pipeline:
    metrics = (
        pipeline
        | 'CreateMetrics' >> beam.Create(
            [metrics_for_slice.SerializeToString()]))
    plots = (
        pipeline
        | 'CreatePlots' >> beam.Create([plots_for_slice.SerializeToString()]))
    evaluation = {constants.METRICS_KEY: metrics, constants.PLOTS_KEY: plots}
    _ = (
        evaluation
        | 'WriteResults' >> model_eval_lib.WriteResults(
            writers=model_eval_lib.default_writers(output_path=output_path)))
    _ = pipeline | model_eval_lib.WriteEvalConfig(eval_config, output_path)
  # Reload what the pipeline wrote and compare with the in-memory protos.
  metrics = metrics_and_plots_evaluator.load_and_deserialize_metrics(
      path=os.path.join(output_path, model_eval_lib._METRICS_OUTPUT_FILE))
  plots = metrics_and_plots_evaluator.load_and_deserialize_plots(
      path=os.path.join(output_path, model_eval_lib._PLOTS_OUTPUT_FILE))
  self.assertSliceMetricsListEqual(
      [(metrics_slice_key, metrics_for_slice.metrics)], metrics)
  self.assertSlicePlotsListEqual(
      [(plots_slice_key, plots_for_slice.plots)], plots)
  got_eval_config = model_eval_lib.load_eval_config(output_path)
  self.assertEqual(eval_config, got_eval_config)
def testSerializeMetricsRanges(self):
  """Serializes metrics with T-distribution and bound-keyed values.

  A ValueWithTDistribution becomes a bounded_value with a bootstrap
  confidence interval; AUC/AUPRC bound keys fold into bounded_value
  entries with RIEMANN_SUM methodology.
  """
  slice_key = _make_slice_key('age', 5, 'language', 'english', 'price', 0.3)
  slice_metrics = {
      # (sample_mean, sample_standard_deviation, degrees_of_freedom,
      #  unsampled_value)
      'accuracy': types.ValueWithTDistribution(0.8, 0.1, 9, 0.8),
      metric_keys.AUPRC: 0.1,
      metric_keys.lower_bound_key(metric_keys.AUPRC): 0.05,
      metric_keys.upper_bound_key(metric_keys.AUPRC): 0.17,
      metric_keys.AUC: 0.2,
      metric_keys.lower_bound_key(metric_keys.AUC): 0.1,
      metric_keys.upper_bound_key(metric_keys.AUC): 0.3
  }
  expected_metrics_for_slice = text_format.Parse(
      string.Template("""
      slice_key {
        single_slice_keys { column: 'age' int64_value: 5 }
        single_slice_keys { column: 'language' bytes_value: 'english' }
        single_slice_keys { column: 'price' float_value: 0.3 }
      }
      metrics {
        key: "accuracy"
        value {
          bounded_value {
            value { value: 0.8 }
            lower_bound { value: 0.5737843 }
            upper_bound { value: 1.0262157 }
            methodology: POISSON_BOOTSTRAP
          }
        }
      }
      metrics {
        key: "$auc"
        value {
          bounded_value {
            lower_bound { value: 0.1 }
            upper_bound { value: 0.3 }
            value { value: 0.2 }
            methodology: RIEMANN_SUM
          }
        }
      }
      metrics {
        key: "$auprc"
        value {
          bounded_value {
            lower_bound { value: 0.05 }
            upper_bound { value: 0.17 }
            value { value: 0.1 }
            methodology: RIEMANN_SUM
          }
        }
      }""").substitute(auc=metric_keys.AUC, auprc=metric_keys.AUPRC),
      metrics_for_slice_pb2.MetricsForSlice())
  got = metrics_and_plots_serialization._serialize_metrics(
      (slice_key, slice_metrics),
      [post_export_metrics.auc(),
       post_export_metrics.auc(curve='PR')])
  # _serialize_metrics returns bytes; parse before comparing.
  self.assertProtoEquals(
      expected_metrics_for_slice,
      metrics_for_slice_pb2.MetricsForSlice.FromString(got))
def testUncertaintyValuedMetrics(self):
  """Converts ValueWithTDistribution metrics, including all-NaN values.

  Both bounded_value and confidence_interval should be populated; NaN
  samples must propagate to NaN bounds rather than raising.
  """
  slice_key = _make_slice_key()
  slice_metrics = {
      'one_dim': types.ValueWithTDistribution(2.0, 1.0, 3, 2.0),
      'nans': types.ValueWithTDistribution(
          float('nan'), float('nan'), -1, float('nan')),
  }
  expected_metrics_for_slice = text_format.Parse(
      """
      slice_key {}
      metrics {
        key: "one_dim"
        value {
          bounded_value {
            value { value: 2.0 }
            lower_bound { value: -1.1824463 }
            upper_bound { value: 5.1824463 }
            methodology: POISSON_BOOTSTRAP
          }
          confidence_interval {
            lower_bound { value: -1.1824463 }
            upper_bound { value: 5.1824463 }
            t_distribution_value {
              sample_mean { value: 2.0 }
              sample_standard_deviation { value: 1.0 }
              sample_degrees_of_freedom { value: 3 }
              unsampled_value { value: 2.0 }
            }
          }
        }
      }
      metrics {
        key: "nans"
        value {
          bounded_value {
            value { value: nan }
            lower_bound { value: nan }
            upper_bound { value: nan }
            methodology: POISSON_BOOTSTRAP
          }
          confidence_interval {
            lower_bound { value: nan }
            upper_bound { value: nan }
            t_distribution_value {
              sample_mean { value: nan }
              sample_standard_deviation { value: nan }
              sample_degrees_of_freedom { value: -1 }
              unsampled_value { value: nan }
            }
          }
        }
      }
      """, metrics_for_slice_pb2.MetricsForSlice())
  got = metrics_plots_and_validations_writer.convert_slice_metrics_to_proto(
      (slice_key, slice_metrics), [])
  self.assertProtoEquals(expected_metrics_for_slice, got)
def test_find_significant_slices(self):
  """Partitions slices into significant/insignificant vs the overall slice.

  The fixtures contain an overall slice (accuracy 0.8, 1500 examples)
  and five sub-slices of 500 examples each; partition_slices should put
  only statistically significant slices (per the comparison direction)
  into result[0] and the rest into result[1].
  """
  metrics = [
      # Overall slice: accuracy 0.8 over 1500 examples.
      text_format.Parse(
          """
          slice_key { }
          metric_keys_and_values {
            key { name: "accuracy" }
            value {
              bounded_value {
                value { value: 0.8 }
                lower_bound { value: 0.5737843 }
                upper_bound { value: 1.0262157 }
                methodology: POISSON_BOOTSTRAP
              }
              confidence_interval {
                lower_bound { value: 0.5737843 }
                upper_bound { value: 1.0262157 }
                t_distribution_value {
                  sample_mean { value: 0.8 }
                  sample_standard_deviation { value: 0.1 }
                  sample_degrees_of_freedom { value: 9 }
                  unsampled_value { value: 0.8 }
                }
              }
            }
          }
          metric_keys_and_values {
            key { name: "example_count" }
            value {
              bounded_value {
                value { value: 1500 }
                lower_bound { value: 1500 }
                upper_bound { value: 1500 }
                methodology: POISSON_BOOTSTRAP
              }
              confidence_interval {
                lower_bound { value: 1500 }
                upper_bound { value: 1500 }
                t_distribution_value {
                  sample_mean { value: 1500 }
                  sample_standard_deviation { value: 0 }
                  sample_degrees_of_freedom { value: 9 }
                  unsampled_value { value: 1500 }
                }
              }
            }
          }
          """, metrics_for_slice_pb2.MetricsForSlice()),
      # age [1.0, 6.0): significantly lower accuracy (0.4).
      text_format.Parse(
          """
          slice_key {
            single_slice_keys { column: 'age' bytes_value: '[1.0, 6.0)' }
          }
          metric_keys_and_values {
            key { name: "accuracy" }
            value {
              bounded_value {
                value { value: 0.4 }
                lower_bound { value: 0.3737843 }
                upper_bound { value: 0.6262157 }
                methodology: POISSON_BOOTSTRAP
              }
              confidence_interval {
                lower_bound { value: 0.3737843 }
                upper_bound { value: 0.6262157 }
                t_distribution_value {
                  sample_mean { value: 0.4 }
                  sample_standard_deviation { value: 0.1 }
                  sample_degrees_of_freedom { value: 9 }
                  unsampled_value { value: 0.4 }
                }
              }
            }
          }
          metric_keys_and_values {
            key { name: "example_count" }
            value {
              bounded_value {
                value { value: 500 }
                lower_bound { value: 500 }
                upper_bound { value: 500 }
                methodology: POISSON_BOOTSTRAP
              }
              confidence_interval {
                lower_bound { value: 500 }
                upper_bound { value: 500 }
                t_distribution_value {
                  sample_mean { value: 500 }
                  sample_standard_deviation { value: 0 }
                  sample_degrees_of_freedom { value: 9 }
                  unsampled_value { value: 500 }
                }
              }
            }
          }
          """, metrics_for_slice_pb2.MetricsForSlice()),
      # age [6.0, 12.0): accuracy 0.79, not significantly different.
      text_format.Parse(
          """
          slice_key {
            single_slice_keys { column: 'age' bytes_value: '[6.0, 12.0)' }
          }
          metric_keys_and_values {
            key { name: "accuracy" }
            value {
              bounded_value {
                value { value: 0.79 }
                lower_bound { value: 0.5737843 }
                upper_bound { value: 1.0262157 }
                methodology: POISSON_BOOTSTRAP
              }
              confidence_interval {
                lower_bound { value: 0.5737843 }
                upper_bound { value: 1.0262157 }
                t_distribution_value {
                  sample_mean { value: 0.79 }
                  sample_standard_deviation { value: 0.1 }
                  sample_degrees_of_freedom { value: 9 }
                  unsampled_value { value: 0.79 }
                }
              }
            }
          }
          metric_keys_and_values {
            key { name: "example_count" }
            value {
              bounded_value {
                value { value: 500 }
                lower_bound { value: 500 }
                upper_bound { value: 500 }
                methodology: POISSON_BOOTSTRAP
              }
              confidence_interval {
                lower_bound { value: 500 }
                upper_bound { value: 500 }
                t_distribution_value {
                  sample_mean { value: 500 }
                  sample_standard_deviation { value: 0 }
                  sample_degrees_of_freedom { value: 9 }
                  unsampled_value { value: 500}
                }
              }
            }
          }
          """, metrics_for_slice_pb2.MetricsForSlice()),
      # age [12.0, 18.0): significantly higher accuracy (0.9).
      text_format.Parse(
          """
          slice_key {
            single_slice_keys { column: 'age' bytes_value: '[12.0, 18.0)' }
          }
          metric_keys_and_values {
            key { name: "accuracy" }
            value {
              bounded_value {
                value { value: 0.9 }
                lower_bound { value: 0.5737843 }
                upper_bound { value: 1.0262157 }
                methodology: POISSON_BOOTSTRAP
              }
              confidence_interval {
                lower_bound { value: 0.5737843 }
                upper_bound { value: 1.0262157 }
                t_distribution_value {
                  sample_mean { value: 0.9 }
                  sample_standard_deviation { value: 0.1 }
                  sample_degrees_of_freedom { value: 9 }
                  unsampled_value { value: 0.9 }
                }
              }
            }
          }
          metric_keys_and_values {
            key { name: "example_count" }
            value {
              bounded_value {
                value { value: 500 }
                lower_bound { value: 500 }
                upper_bound { value: 500 }
                methodology: POISSON_BOOTSTRAP
              }
              confidence_interval {
                lower_bound { value: 500 }
                upper_bound { value: 500 }
                t_distribution_value {
                  sample_mean { value: 500 }
                  sample_standard_deviation { value: 0 }
                  sample_degrees_of_freedom { value: 9 }
                  unsampled_value { value: 500}
                }
              }
            }
          }
          """, metrics_for_slice_pb2.MetricsForSlice()),
      # country USA: significantly higher accuracy (0.9).
      text_format.Parse(
          """
          slice_key {
            single_slice_keys { column: 'country' bytes_value: 'USA' }
          }
          metric_keys_and_values {
            key { name: "accuracy" }
            value {
              bounded_value {
                value { value: 0.9 }
                lower_bound { value: 0.5737843 }
                upper_bound { value: 1.0262157 }
                methodology: POISSON_BOOTSTRAP
              }
              confidence_interval {
                lower_bound { value: 0.5737843 }
                upper_bound { value: 1.0262157 }
                t_distribution_value {
                  sample_mean { value: 0.9 }
                  sample_standard_deviation { value: 0.1 }
                  sample_degrees_of_freedom { value: 9 }
                  unsampled_value { value: 0.9 }
                }
              }
            }
          }
          metric_keys_and_values {
            key { name: "example_count" }
            value {
              bounded_value {
                value { value: 500 }
                lower_bound { value: 500 }
                upper_bound { value: 500 }
                methodology: POISSON_BOOTSTRAP
              }
              confidence_interval {
                lower_bound { value: 500 }
                upper_bound { value: 500 }
                t_distribution_value {
                  sample_mean { value: 500 }
                  sample_standard_deviation { value: 0 }
                  sample_degrees_of_freedom { value: 9 }
                  unsampled_value { value: 500}
                }
              }
            }
          }
          """, metrics_for_slice_pb2.MetricsForSlice()),
      # Cross slice country USA x age [12.0, 18.0): also higher (0.9).
      text_format.Parse(
          """
          slice_key {
            single_slice_keys { column: 'country' bytes_value: 'USA' }
            single_slice_keys { column: 'age' bytes_value: '[12.0, 18.0)' }
          }
          metric_keys_and_values {
            key { name: "accuracy" }
            value {
              bounded_value {
                value { value: 0.9 }
                lower_bound { value: 0.5737843 }
                upper_bound { value: 1.0262157 }
                methodology: POISSON_BOOTSTRAP
              }
              confidence_interval {
                lower_bound { value: 0.5737843 }
                upper_bound { value: 1.0262157 }
                t_distribution_value {
                  sample_mean { value: 0.9 }
                  sample_standard_deviation { value: 0.1 }
                  sample_degrees_of_freedom { value: 9 }
                  unsampled_value { value: 0.9 }
                }
              }
            }
          }
          metric_keys_and_values {
            key { name: "example_count" }
            value {
              bounded_value {
                value { value: 500 }
                lower_bound { value: 500 }
                upper_bound { value: 500 }
                methodology: POISSON_BOOTSTRAP
              }
              confidence_interval {
                lower_bound { value: 500 }
                upper_bound { value: 500 }
                t_distribution_value {
                  sample_mean { value: 500 }
                  sample_standard_deviation { value: 0 }
                  sample_degrees_of_freedom { value: 9 }
                  unsampled_value { value: 500}
                }
              }
            }
          }
          """, metrics_for_slice_pb2.MetricsForSlice())
  ]
  # LOWER: only age [1.0, 6.0) (accuracy 0.4) is significantly lower.
  result = auto_slicing_util.partition_slices(
      metrics, metric_key='accuracy', comparison_type='LOWER')
  self.assertCountEqual([s.slice_key for s in result[0]],
                        [(('age', '[1.0, 6.0)'),)])
  self.assertCountEqual([s.slice_key for s in result[1]],
                        [(('age', '[6.0, 12.0)'),),
                         (('age', '[12.0, 18.0)'),),
                         (('country', 'USA'),),
                         (('country', 'USA'), ('age', '[12.0, 18.0)'))])
  # HIGHER: the three 0.9-accuracy slices are significantly higher.
  result = auto_slicing_util.partition_slices(
      metrics, metric_key='accuracy', comparison_type='HIGHER')
  self.assertCountEqual([s.slice_key for s in result[0]],
                        [(('age', '[12.0, 18.0)'),),
                         (('country', 'USA'),),
                         (('country', 'USA'), ('age', '[12.0, 18.0)'))])
  self.assertCountEqual([s.slice_key for s in result[1]],
                        [(('age', '[1.0, 6.0)'),),
                         (('age', '[6.0, 12.0)'),)])
def testConvertSliceMetricsToProtoConfusionMatrices(self):
  """Converts confusion-matrix metrics into the proto representation.

  Each matrix row is [fn, tn, fp, tp, precision, recall]; the converter
  should emit plain, bounded_* and t_distribution_* fields per threshold,
  preserving NaN precision at threshold 1.0.
  """
  slice_key = _make_slice_key()
  thresholds = [0.25, 0.75, 1.00]
  matrices = [[0.0, 1.0, 0.0, 2.0, 1.0, 1.0],
              [1.0, 1.0, 0.0, 1.0, 1.0, 0.5],
              [2.0, 1.0, 0.0, 0.0, float('nan'), 0.0]]
  slice_metrics = {
      metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_MATRICES: matrices,
      metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_THRESHOLDS: thresholds,
  }
  expected_metrics_for_slice = text_format.Parse(
      """
      slice_key {}
      metrics {
        key: "post_export_metrics/confusion_matrix_at_thresholds"
        value {
          confusion_matrix_at_thresholds {
            matrices {
              threshold: 0.25
              false_negatives: 0.0
              true_negatives: 1.0
              false_positives: 0.0
              true_positives: 2.0
              precision: 1.0
              recall: 1.0
              bounded_false_negatives { value { value: 0.0 } }
              bounded_true_negatives { value { value: 1.0 } }
              bounded_true_positives { value { value: 2.0 } }
              bounded_false_positives { value { value: 0.0 } }
              bounded_precision { value { value: 1.0 } }
              bounded_recall { value { value: 1.0 } }
              t_distribution_false_negatives { unsampled_value { value: 0.0 } }
              t_distribution_true_negatives { unsampled_value { value: 1.0 } }
              t_distribution_true_positives { unsampled_value { value: 2.0 } }
              t_distribution_false_positives { unsampled_value { value: 0.0 } }
              t_distribution_precision { unsampled_value { value: 1.0 } }
              t_distribution_recall { unsampled_value { value: 1.0 } }
            }
            matrices {
              threshold: 0.75
              false_negatives: 1.0
              true_negatives: 1.0
              false_positives: 0.0
              true_positives: 1.0
              precision: 1.0
              recall: 0.5
              bounded_false_negatives { value { value: 1.0 } }
              bounded_true_negatives { value { value: 1.0 } }
              bounded_true_positives { value { value: 1.0 } }
              bounded_false_positives { value { value: 0.0 } }
              bounded_precision { value { value: 1.0 } }
              bounded_recall { value { value: 0.5 } }
              t_distribution_false_negatives { unsampled_value { value: 1.0 } }
              t_distribution_true_negatives { unsampled_value { value: 1.0 } }
              t_distribution_true_positives { unsampled_value { value: 1.0 } }
              t_distribution_false_positives { unsampled_value { value: 0.0 } }
              t_distribution_precision { unsampled_value { value: 1.0 } }
              t_distribution_recall { unsampled_value { value: 0.5 } }
            }
            matrices {
              threshold: 1.00
              false_negatives: 2.0
              true_negatives: 1.0
              false_positives: 0.0
              true_positives: 0.0
              precision: nan
              recall: 0.0
              bounded_false_negatives { value { value: 2.0 } }
              bounded_true_negatives { value { value: 1.0 } }
              bounded_true_positives { value { value: 0.0 } }
              bounded_false_positives { value { value: 0.0 } }
              bounded_precision { value { value: nan } }
              bounded_recall { value { value: 0.0 } }
              t_distribution_false_negatives { unsampled_value { value: 2.0 } }
              t_distribution_true_negatives { unsampled_value { value: 1.0 } }
              t_distribution_true_positives { unsampled_value { value: 0.0 } }
              t_distribution_false_positives { unsampled_value { value: 0.0 } }
              t_distribution_precision { unsampled_value { value: nan } }
              t_distribution_recall { unsampled_value { value: 0.0 } }
            }
          }
        }
      }
      """, metrics_for_slice_pb2.MetricsForSlice())
  got = metrics_plots_and_validations_writer.convert_slice_metrics_to_proto(
      (slice_key, slice_metrics),
      [post_export_metrics.confusion_matrix_at_thresholds(thresholds)])
  self.assertProtoEquals(expected_metrics_for_slice, got)
def convert_slice_metrics_to_proto(
    metrics: Tuple[slicer.SliceKeyOrCrossSliceKeyType, Dict[Any, Any]],
    add_metrics_callbacks: List[types.AddMetricsCallbackType]
) -> metrics_for_slice_pb2.MetricsForSlice:
  """Builds a MetricsForSlice proto from one slice's metrics.

  Args:
    metrics: A (slice key, metrics dict) pair for a single slice.
    add_metrics_callbacks: Metric callbacks; must be the same list that was
      passed to tfma.Evaluate().

  Returns:
    The MetricsForSlice proto.

  Raises:
    TypeError: If the type of a feature value in the slice key cannot be
      recognized.
  """
  slice_key, slice_metrics = metrics
  result = metrics_for_slice_pb2.MetricsForSlice()
  # Serialize the slice key, which may be a regular or a cross-slice key.
  if slicer.is_cross_slice_key(slice_key):
    result.cross_slice_key.CopyFrom(slicer.serialize_cross_slice_key(slice_key))
  else:
    result.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))

  # Work on a copy since callbacks below may pop entries.
  remaining = slice_metrics.copy()

  # An error metric short-circuits conversion: record the message and stop.
  if metric_keys.ERROR_METRIC in remaining:
    logging.warning('Error for slice: %s with error message: %s ', slice_key,
                    remaining[metric_keys.ERROR_METRIC])
    result.metrics[metric_keys.ERROR_METRIC].debug_message = remaining[
        metric_keys.ERROR_METRIC]
    return result

  # Let legacy callbacks convert their own metrics into structured output,
  # but only when no structured MetricKey entries are present already.
  has_structured_keys = any(
      isinstance(k, metric_types.MetricKey) for k in remaining.keys())
  if add_metrics_callbacks and not has_structured_keys:
    for callback in add_metrics_callbacks:
      if hasattr(callback, 'populate_stats_and_pop'):
        callback.populate_stats_and_pop(slice_key, remaining, result.metrics)

  for key in sorted(remaining.keys()):
    value = remaining[key]
    confidence_interval = None
    if isinstance(value, types.ValueWithTDistribution):
      point_estimate = value.unsampled_value
      _, lower_bound, upper_bound = (
          math_util.calculate_confidence_interval(value))
      confidence_interval = metrics_for_slice_pb2.ConfidenceInterval(
          lower_bound=convert_metric_value_to_proto(lower_bound),
          upper_bound=convert_metric_value_to_proto(upper_bound),
          standard_error=convert_metric_value_to_proto(
              value.sample_standard_deviation),
          degrees_of_freedom={'value': value.sample_degrees_of_freedom})
      metric_value = convert_metric_value_to_proto(point_estimate)
      # For backwards compatibility, mirror double-valued metrics into a
      # bounded_value as well.
      # TODO(b/188575688): remove this logic to stop populating bounded_value
      if metric_value.WhichOneof('type') == 'double_value':
        # Setting bounded_value clears double_value in the same oneof scope.
        metric_value.bounded_value.value.value = point_estimate
        metric_value.bounded_value.lower_bound.value = lower_bound
        metric_value.bounded_value.upper_bound.value = upper_bound
        metric_value.bounded_value.methodology = (
            metrics_for_slice_pb2.BoundedValue.POISSON_BOOTSTRAP)
    else:
      metric_value = convert_metric_value_to_proto(value)

    # Structured keys go into metric_keys_and_values; legacy string keys
    # into the metrics map.
    if isinstance(key, metric_types.MetricKey):
      result.metric_keys_and_values.add(
          key=key.to_proto(),
          value=metric_value,
          confidence_interval=confidence_interval)
    else:
      result.metrics[key].CopyFrom(metric_value)
  return result
def check_result(got):  # pylint: disable=invalid-name
  """Asserts confusion matrices/thresholds and their serialized form."""
  try:
    self.assertEqual(1, len(got), 'got: %s' % got)
    (slice_key, value) = got[0]
    self.assertEqual((), slice_key)
    self.assertIn(metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_MATRICES, value)
    matrices = value[metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_MATRICES]
    # Each row is [fn, tn, fp, tp, precision, recall] at one threshold.
    #            |      | ---- Threshold ----
    # true label | pred | 0.25 | 0.75 | 1.00
    #     -      | 0.0  |  TN  |  TN  |  TN
    #     +      | 0.5  |  TP  |  FN  |  FN
    #     +      | 1.0  |  TP  |  TP  |  FN
    self.assertSequenceAlmostEqual(matrices[0],
                                   [0.0, 1.0, 0.0, 2.0, 1.0, 1.0])
    self.assertSequenceAlmostEqual(matrices[1],
                                   [1.0, 1.0, 0.0, 1.0, 1.0, 0.5])
    # Precision is NaN at threshold 1.0 (no predicted positives).
    self.assertSequenceAlmostEqual(
        matrices[2], [2.0, 1.0, 0.0, 0.0, float('nan'), 0.0])
    self.assertIn(metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_THRESHOLDS,
                  value)
    thresholds = value[metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_THRESHOLDS]
    self.assertAlmostEqual(0.25, thresholds[0])
    self.assertAlmostEqual(0.75, thresholds[1])
    self.assertAlmostEqual(1.00, thresholds[2])
    # Check serialization too.
    # Note that we can't just make this a dict, since proto maps
    # allow uninitialized key access, i.e. they act like defaultdicts.
    output_metrics = metrics_for_slice_pb2.MetricsForSlice().metrics
    confusion_matrix_at_thresholds_metric.populate_stats_and_pop(
        value, output_metrics)
    self.assertProtoEquals(
        """
        confusion_matrix_at_thresholds {
          matrices {
            threshold: 0.25
            false_negatives: 0.0
            true_negatives: 1.0
            false_positives: 0.0
            true_positives: 2.0
            precision: 1.0
            recall: 1.0
          }
          matrices {
            threshold: 0.75
            false_negatives: 1.0
            true_negatives: 1.0
            false_positives: 0.0
            true_positives: 1.0
            precision: 1.0
            recall: 0.5
          }
          matrices {
            threshold: 1.00
            false_negatives: 2.0
            true_negatives: 1.0
            false_positives: 0.0
            true_positives: 0.0
            precision: nan
            recall: 0.0
          }
        }
        """, output_metrics[metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS])
  except AssertionError as err:
    # Re-raise so Beam surfaces the assertion failure from the worker.
    raise util.BeamAssertException(err)
def convert_slice_metrics_to_proto(
    metrics: Tuple[slicer.SliceKeyType, Dict[Any, Any]],
    add_metrics_callbacks: List[types.AddMetricsCallbackType]
) -> metrics_for_slice_pb2.MetricsForSlice:
  """Converts the given slice metrics into proto MetricsForSlice.

  Args:
    metrics: The slice metrics as a (slice key, metrics dict) pair.
    add_metrics_callbacks: A list of metric callbacks. This should be the
      same list as the one passed to tfma.Evaluate().

  Returns:
    The MetricsForSlice proto.

  Raises:
    TypeError: If the type of the feature value in slice key cannot be
      recognized.
  """
  result = metrics_for_slice_pb2.MetricsForSlice()
  slice_key, slice_metrics = metrics
  result.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))
  # Copy since the callbacks below may pop entries.
  slice_metrics = slice_metrics.copy()

  # An error metric short-circuits conversion: record the message and stop.
  if metric_keys.ERROR_METRIC in slice_metrics:
    logging.warning('Error for slice: %s with error message: %s ', slice_key,
                    slice_metrics[metric_keys.ERROR_METRIC])
    result.metrics[metric_keys.ERROR_METRIC].debug_message = slice_metrics[
        metric_keys.ERROR_METRIC]
    return result

  # Convert the metrics from add_metrics_callbacks to the structured output if
  # defined (only when no structured MetricKey entries are present already).
  if add_metrics_callbacks and (not any(
      isinstance(k, metric_types.MetricKey) for k in slice_metrics.keys())):
    for add_metrics_callback in add_metrics_callbacks:
      if hasattr(add_metrics_callback, 'populate_stats_and_pop'):
        add_metrics_callback.populate_stats_and_pop(slice_key, slice_metrics,
                                                    result.metrics)
  for key in sorted(slice_metrics.keys()):
    value = slice_metrics[key]
    metric_value = metrics_for_slice_pb2.MetricValue()
    if isinstance(value, metrics_for_slice_pb2.ConfusionMatrixAtThresholds):
      metric_value.confusion_matrix_at_thresholds.CopyFrom(value)
    elif isinstance(
        value, metrics_for_slice_pb2.MultiClassConfusionMatrixAtThresholds):
      metric_value.multi_class_confusion_matrix_at_thresholds.CopyFrom(value)
    elif isinstance(value, types.ValueWithTDistribution):
      # Currently we populate both bounded_value and confidence_interval.
      # Avoid populating bounded_value once the UI handles
      # confidence_interval.
      # Convert to a bounded value. 95% confidence level is computed here.
      _, lower_bound, upper_bound = (
          math_util.calculate_confidence_interval(value))
      metric_value.bounded_value.value.value = value.unsampled_value
      metric_value.bounded_value.lower_bound.value = lower_bound
      metric_value.bounded_value.upper_bound.value = upper_bound
      metric_value.bounded_value.methodology = (
          metrics_for_slice_pb2.BoundedValue.POISSON_BOOTSTRAP)
      # Populate confidence_interval
      metric_value.confidence_interval.lower_bound.value = lower_bound
      metric_value.confidence_interval.upper_bound.value = upper_bound
      t_dist_value = metrics_for_slice_pb2.TDistributionValue()
      t_dist_value.sample_mean.value = value.sample_mean
      t_dist_value.sample_standard_deviation.value = (
          value.sample_standard_deviation)
      t_dist_value.sample_degrees_of_freedom.value = (
          value.sample_degrees_of_freedom)
      # Once the UI handles confidence interval, we will avoid setting this
      # and instead use the double_value.
      t_dist_value.unsampled_value.value = value.unsampled_value
      metric_value.confidence_interval.t_distribution_value.CopyFrom(
          t_dist_value)
    elif isinstance(value, six.binary_type):
      # Convert textual types to string metrics.
      metric_value.bytes_value = value
    elif isinstance(value, six.text_type):
      # Convert textual types to string metrics.
      metric_value.bytes_value = value.encode('utf8')
    elif isinstance(value, np.ndarray):
      # Convert NumPy arrays to ArrayValue.
      metric_value.array_value.CopyFrom(_convert_to_array_value(value))
    else:
      # We try to convert to float values.
      try:
        metric_value.double_value.value = float(value)
      except (TypeError, ValueError) as e:
        metric_value.unknown_type.value = str(value)
        # Fix: exceptions have no `.message` attribute in Python 3, so the
        # previous `e.message` raised AttributeError and masked the original
        # conversion error. Use str(e) instead.
        metric_value.unknown_type.error = str(e)

    if isinstance(key, metric_types.MetricKey):
      key_and_value = result.metric_keys_and_values.add()
      key_and_value.key.CopyFrom(key.to_proto())
      key_and_value.value.CopyFrom(metric_value)
    else:
      result.metrics[key].CopyFrom(metric_value)
  return result
def testConvertSliceMetricsToProtoMetricsRanges(self):
  """Converts T-distribution and bound-keyed metrics to proto.

  Unlike the legacy serializer test, the writer's converter populates
  both bounded_value and confidence_interval for ValueWithTDistribution.
  """
  slice_key = _make_slice_key('age', 5, 'language', 'english', 'price', 0.3)
  slice_metrics = {
      # (sample_mean, sample_standard_deviation, degrees_of_freedom,
      #  unsampled_value)
      'accuracy': types.ValueWithTDistribution(0.8, 0.1, 9, 0.8),
      metric_keys.AUPRC: 0.1,
      metric_keys.lower_bound_key(metric_keys.AUPRC): 0.05,
      metric_keys.upper_bound_key(metric_keys.AUPRC): 0.17,
      metric_keys.AUC: 0.2,
      metric_keys.lower_bound_key(metric_keys.AUC): 0.1,
      metric_keys.upper_bound_key(metric_keys.AUC): 0.3
  }
  expected_metrics_for_slice = text_format.Parse(
      string.Template("""
      slice_key {
        single_slice_keys { column: 'age' int64_value: 5 }
        single_slice_keys { column: 'language' bytes_value: 'english' }
        single_slice_keys { column: 'price' float_value: 0.3 }
      }
      metrics {
        key: "accuracy"
        value {
          bounded_value {
            value { value: 0.8 }
            lower_bound { value: 0.5737843 }
            upper_bound { value: 1.0262157 }
            methodology: POISSON_BOOTSTRAP
          }
          confidence_interval {
            lower_bound { value: 0.5737843 }
            upper_bound { value: 1.0262157 }
            t_distribution_value {
              sample_mean { value: 0.8 }
              sample_standard_deviation { value: 0.1 }
              sample_degrees_of_freedom { value: 9 }
              unsampled_value { value: 0.8 }
            }
          }
        }
      }
      metrics {
        key: "$auc"
        value {
          bounded_value {
            lower_bound { value: 0.1 }
            upper_bound { value: 0.3 }
            value { value: 0.2 }
            methodology: RIEMANN_SUM
          }
        }
      }
      metrics {
        key: "$auprc"
        value {
          bounded_value {
            lower_bound { value: 0.05 }
            upper_bound { value: 0.17 }
            value { value: 0.1 }
            methodology: RIEMANN_SUM
          }
        }
      }""").substitute(auc=metric_keys.AUC, auprc=metric_keys.AUPRC),
      metrics_for_slice_pb2.MetricsForSlice())
  got = metrics_plots_and_validations_writer.convert_slice_metrics_to_proto(
      (slice_key, slice_metrics),
      [post_export_metrics.auc(),
       post_export_metrics.auc(curve='PR')])
  self.assertProtoEquals(expected_metrics_for_slice, got)