def _serialize_plots(
    plots: Tuple[slicer.SliceKeyType, Dict[Text, Any]],
    post_export_metrics: List[types.AddMetricsCallbackType]) -> bytes:
  """Converts the given slice plots into serialized proto PlotsForSlice.

  Args:
    plots: The slice plots.
    post_export_metrics: A list of metric callbacks. This should be the same
      list as the one passed to tfma.Evaluate().

  Returns:
    The serialized proto PlotsForSlice.
  """
  result = metrics_for_slice_pb2.PlotsForSlice()
  slice_key, slice_plots = plots

  if metric_keys.ERROR_METRIC in slice_plots:
    tf.compat.v1.logging.warning(
        'Error for slice: %s with error message: %s ', slice_key,
        slice_plots[metric_keys.ERROR_METRIC])
    # Only the error message is reported for this slice.
    result.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))
    result.plots[metric_keys.ERROR_METRIC].debug_message = slice_plots[
        metric_keys.ERROR_METRIC]
    return result.SerializeToString()

  # Convert the slice key.
  result.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))

  # Convert the slice plots.
  _convert_slice_plots(slice_plots, post_export_metrics, result.plots)  # pytype: disable=wrong-arg-types

  return result.SerializeToString()
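# Hedged usage sketch (not part of the original source): exercises the error
# path of _serialize_plots above and round-trips the serialized bytes back
# into a PlotsForSlice proto. The slice key tuple (('fruit', 'apple'),) and
# the error string are illustrative values modeled on the tests in this
# section; metric_keys, metrics_for_slice_pb2 and _serialize_plots are assumed
# to be imported as in the surrounding code.
slice_key = (('fruit', 'apple'),)
serialized = _serialize_plots(
    (slice_key, {metric_keys.ERROR_METRIC: 'plot computation failed'}), [])
plots_proto = metrics_for_slice_pb2.PlotsForSlice.FromString(serialized)
assert plots_proto.plots[
    metric_keys.ERROR_METRIC].debug_message == 'plot computation failed'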
def testConvertSlicePlotsToProtoEmptyPlot(self):
  slice_key = _make_slice_key('fruit', 'apple')
  tfma_plots = {metric_keys.ERROR_METRIC: 'error_message'}
  actual_plot = metrics_plots_and_validations_writer.convert_slice_plots_to_proto(
      (slice_key, tfma_plots), [])
  expected_plot = metrics_for_slice_pb2.PlotsForSlice()
  expected_plot.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))
  expected_plot.plots[
      metric_keys.ERROR_METRIC].debug_message = 'error_message'
  self.assertProtoEquals(expected_plot, actual_plot)
def testSerializePlots_emptyPlot(self):
  slice_key = _make_slice_key('fruit', 'apple')
  tfma_plots = {metric_keys.ERROR_METRIC: 'error_message'}
  actual_plot = metrics_and_plots_serialization._serialize_plots(
      (slice_key, tfma_plots), [])
  expected_plot = metrics_for_slice_pb2.PlotsForSlice()
  expected_plot.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))
  expected_plot.plots[
      metric_keys.ERROR_METRIC].debug_message = 'error_message'
  self.assertProtoEquals(
      expected_plot,
      metrics_for_slice_pb2.PlotsForSlice.FromString(actual_plot))
def _serialize_plots(plots, post_export_metrics):
  """Converts the given slice plots into serialized proto PlotsForSlice.

  Args:
    plots: The slice plots.
    post_export_metrics: A list of metric callbacks. This should be the same
      list as the one passed to tfma.Evaluate().

  Returns:
    The serialized proto PlotsForSlice.
  """
  result = metrics_for_slice_pb2.PlotsForSlice()
  slice_key, slice_plots = plots

  # Convert the slice key.
  result.slice_key.CopyFrom(_convert_slice_key(slice_key))

  # Convert the slice plots.
  _convert_slice_plots(slice_plots, post_export_metrics, result.plot_data)  # pytype: disable=wrong-arg-types

  return result.SerializeToString()
def _serialize_plots(
    plots: Tuple[slicer.SliceKeyType, Dict[Text, Any]],
    post_export_metrics: List[types.AddMetricsCallbackType]) -> bytes:
  """Converts the given slice plots into serialized proto PlotsForSlice.

  Args:
    plots: The slice plots.
    post_export_metrics: A list of metric callbacks. This should be the same
      list as the one passed to tfma.Evaluate().

  Returns:
    The serialized proto PlotsForSlice.
  """
  result = metrics_for_slice_pb2.PlotsForSlice()
  slice_key, slice_plots = plots

  # Convert the slice key.
  result.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))

  # Convert the slice plots.
  _convert_slice_plots(slice_plots, post_export_metrics, result.plots)  # pytype: disable=wrong-arg-types

  return result.SerializeToString()
def convert_slice_plots_to_proto(
    plots: Tuple[slicer.SliceKeyType, Dict[Any, Any]],
    add_metrics_callbacks: List[types.AddMetricsCallbackType]
) -> metrics_for_slice_pb2.PlotsForSlice:
  """Converts the given slice plots into PlotsForSlice proto.

  Args:
    plots: The slice plots.
    add_metrics_callbacks: A list of metric callbacks. This should be the same
      list as the one passed to tfma.Evaluate().

  Returns:
    The PlotsForSlice proto.
  """
  result = metrics_for_slice_pb2.PlotsForSlice()
  slice_key, slice_plots = plots

  result.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))

  slice_plots = slice_plots.copy()

  if metric_keys.ERROR_METRIC in slice_plots:
    logging.warning('Error for slice: %s with error message: %s ', slice_key,
                    slice_plots[metric_keys.ERROR_METRIC])
    error_metric = slice_plots.pop(metric_keys.ERROR_METRIC)
    result.plots[metric_keys.ERROR_METRIC].debug_message = error_metric
    return result

  if add_metrics_callbacks and (not any(
      isinstance(k, metric_types.MetricKey) for k in slice_plots.keys())):
    for add_metrics_callback in add_metrics_callbacks:
      if hasattr(add_metrics_callback, 'populate_plots_and_pop'):
        add_metrics_callback.populate_plots_and_pop(slice_plots, result.plots)

  plots_by_key = {}
  for key in sorted(slice_plots.keys()):
    value = slice_plots[key]
    # Remove plot name from key (multiple plots are combined into a single
    # proto).
    if isinstance(key, metric_types.MetricKey):
      parent_key = key._replace(name=None)
    else:
      continue
    if parent_key not in plots_by_key:
      key_and_value = result.plot_keys_and_values.add()
      key_and_value.key.CopyFrom(parent_key.to_proto())
      plots_by_key[parent_key] = key_and_value.value

    if isinstance(value, metrics_for_slice_pb2.CalibrationHistogramBuckets):
      plots_by_key[parent_key].calibration_histogram_buckets.CopyFrom(value)
      slice_plots.pop(key)
    elif isinstance(value, metrics_for_slice_pb2.ConfusionMatrixAtThresholds):
      plots_by_key[parent_key].confusion_matrix_at_thresholds.CopyFrom(value)
      slice_plots.pop(key)
    elif isinstance(
        value, metrics_for_slice_pb2.MultiClassConfusionMatrixAtThresholds):
      plots_by_key[
          parent_key].multi_class_confusion_matrix_at_thresholds.CopyFrom(
              value)
      slice_plots.pop(key)
    elif isinstance(
        value, metrics_for_slice_pb2.MultiLabelConfusionMatrixAtThresholds):
      plots_by_key[
          parent_key].multi_label_confusion_matrix_at_thresholds.CopyFrom(
              value)
      slice_plots.pop(key)

  if slice_plots:
    if add_metrics_callbacks is None:
      add_metrics_callbacks = []
    raise NotImplementedError(
        'some plots were not converted or popped. keys: %s. '
        'add_metrics_callbacks were: %s' % (
            slice_plots.keys(),
            [
                x.name for x in add_metrics_callbacks  # pytype: disable=attribute-error
            ]))

  return result
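# Hedged usage sketch (not from the original source): shows how
# convert_slice_plots_to_proto above is expected to fold a keyed plot into
# plot_keys_and_values, dropping the plot name from the key. The PlotKey
# values and the single-bucket histogram are illustrative; metric_types,
# metrics_for_slice_pb2 and text_format are assumed to be imported as
# elsewhere in this section.
plot_key = metric_types.PlotKey(
    name='calibration_plot', output_name='output_name')
histogram = text_format.Parse(
    'buckets { num_weighted_examples { value: 1.0 } }',
    metrics_for_slice_pb2.CalibrationHistogramBuckets())
result = convert_slice_plots_to_proto(
    ((('fruit', 'apple'),), {plot_key: histogram}), None)
# Only output_name survives in the serialized key; the histogram lands in the
# calibration_histogram_buckets field of the combined plot value.
assert result.plot_keys_and_values[0].key.output_name == 'output_name'
assert result.plot_keys_and_values[0].value.HasField(
    'calibration_histogram_buckets')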
def testWriteMetricsAndPlots(self):
  metrics_file = os.path.join(self._getTempDir(), 'metrics')
  plots_file = os.path.join(self._getTempDir(), 'plots')
  temp_eval_export_dir = os.path.join(self._getTempDir(), 'eval_export_dir')

  _, eval_export_dir = (
      fixed_prediction_estimator.simple_fixed_prediction_estimator(
          None, temp_eval_export_dir))
  eval_config = config.EvalConfig(
      model_specs=[config.ModelSpec()],
      options=config.Options(
          disabled_outputs={'values': ['eval_config.json']}))
  eval_shared_model = self.createTestEvalSharedModel(
      eval_saved_model_path=eval_export_dir,
      add_metrics_callbacks=[
          post_export_metrics.example_count(),
          post_export_metrics.calibration_plot_and_prediction_histogram(
              num_buckets=2)
      ])
  extractors = [
      predict_extractor.PredictExtractor(eval_shared_model),
      slice_key_extractor.SliceKeyExtractor()
  ]
  evaluators = [
      metrics_and_plots_evaluator.MetricsAndPlotsEvaluator(eval_shared_model)
  ]
  output_paths = {
      constants.METRICS_KEY: metrics_file,
      constants.PLOTS_KEY: plots_file
  }
  writers = [
      metrics_plots_and_validations_writer.MetricsPlotsAndValidationsWriter(
          output_paths, eval_shared_model.add_metrics_callbacks)
  ]

  with beam.Pipeline() as pipeline:
    example1 = self._makeExample(prediction=0.0, label=1.0)
    example2 = self._makeExample(prediction=1.0, label=1.0)

    # pylint: disable=no-value-for-parameter
    _ = (
        pipeline
        | 'Create' >> beam.Create([
            example1.SerializeToString(),
            example2.SerializeToString(),
        ])
        | 'ExtractEvaluateAndWriteResults' >>
        model_eval_lib.ExtractEvaluateAndWriteResults(
            eval_config=eval_config,
            eval_shared_model=eval_shared_model,
            extractors=extractors,
            evaluators=evaluators,
            writers=writers))
    # pylint: enable=no-value-for-parameter

  expected_metrics_for_slice = text_format.Parse(
      """
      slice_key {}
      metrics {
        key: "average_loss"
        value {
          double_value {
            value: 0.5
          }
        }
      }
      metrics {
        key: "post_export_metrics/example_count"
        value {
          double_value {
            value: 2.0
          }
        }
      }
      """, metrics_for_slice_pb2.MetricsForSlice())

  metric_records = []
  for record in tf.compat.v1.python_io.tf_record_iterator(metrics_file):
    metric_records.append(
        metrics_for_slice_pb2.MetricsForSlice.FromString(record))
  self.assertEqual(1, len(metric_records), 'metrics: %s' % metric_records)
  self.assertProtoEquals(expected_metrics_for_slice, metric_records[0])

  expected_plots_for_slice = text_format.Parse(
      """
      slice_key {}
      plots {
        key: "post_export_metrics"
        value {
          calibration_histogram_buckets {
            buckets {
              lower_threshold_inclusive: -inf
              num_weighted_examples {}
              total_weighted_label {}
              total_weighted_refined_prediction {}
            }
            buckets {
              upper_threshold_exclusive: 0.5
              num_weighted_examples {
                value: 1.0
              }
              total_weighted_label {
                value: 1.0
              }
              total_weighted_refined_prediction {}
            }
            buckets {
              lower_threshold_inclusive: 0.5
              upper_threshold_exclusive: 1.0
              num_weighted_examples {}
              total_weighted_label {}
              total_weighted_refined_prediction {}
            }
            buckets {
              lower_threshold_inclusive: 1.0
              upper_threshold_exclusive: inf
              num_weighted_examples {
                value: 1.0
              }
              total_weighted_label {
                value: 1.0
              }
              total_weighted_refined_prediction {
                value: 1.0
              }
            }
          }
        }
      }
      """, metrics_for_slice_pb2.PlotsForSlice())

  plot_records = []
  for record in tf.compat.v1.python_io.tf_record_iterator(plots_file):
    plot_records.append(
        metrics_for_slice_pb2.PlotsForSlice.FromString(record))
  self.assertEqual(1, len(plot_records), 'plots: %s' % plot_records)
  self.assertProtoEquals(expected_plots_for_slice, plot_records[0])
def testSerializeDeserializeToFile(self):
  metrics_slice_key = _make_slice_key(b'fruit', b'pear', b'animal', b'duck')
  metrics_for_slice = text_format.Parse(
      """
      slice_key {
        single_slice_keys {
          column: "fruit"
          bytes_value: "pear"
        }
        single_slice_keys {
          column: "animal"
          bytes_value: "duck"
        }
      }
      metrics {
        key: "accuracy"
        value {
          double_value {
            value: 0.8
          }
        }
      }
      metrics {
        key: "example_weight"
        value {
          double_value {
            value: 10.0
          }
        }
      }
      metrics {
        key: "auc"
        value {
          bounded_value {
            lower_bound {
              value: 0.1
            }
            upper_bound {
              value: 0.3
            }
            value {
              value: 0.2
            }
          }
        }
      }
      metrics {
        key: "auprc"
        value {
          bounded_value {
            lower_bound {
              value: 0.05
            }
            upper_bound {
              value: 0.17
            }
            value {
              value: 0.1
            }
          }
        }
      }""", metrics_for_slice_pb2.MetricsForSlice())
  plots_for_slice = text_format.Parse(
      """
      slice_key {
        single_slice_keys {
          column: "fruit"
          bytes_value: "peach"
        }
        single_slice_keys {
          column: "animal"
          bytes_value: "cow"
        }
      }
      plots {
        key: ''
        value {
          calibration_histogram_buckets {
            buckets {
              lower_threshold_inclusive: -inf
              upper_threshold_exclusive: 0.0
              num_weighted_examples { value: 0.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.0 }
            }
            buckets {
              lower_threshold_inclusive: 0.0
              upper_threshold_exclusive: 0.5
              num_weighted_examples { value: 1.0 }
              total_weighted_label { value: 1.0 }
              total_weighted_refined_prediction { value: 0.3 }
            }
            buckets {
              lower_threshold_inclusive: 0.5
              upper_threshold_exclusive: 1.0
              num_weighted_examples { value: 1.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.7 }
            }
            buckets {
              lower_threshold_inclusive: 1.0
              upper_threshold_exclusive: inf
              num_weighted_examples { value: 0.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.0 }
            }
          }
        }
      }""", metrics_for_slice_pb2.PlotsForSlice())
  plots_slice_key = _make_slice_key(b'fruit', b'peach', b'animal', b'cow')

  eval_config = model_eval_lib.EvalConfig(
      model_location='/path/to/model',
      data_location='/path/to/data',
      slice_spec=[
          slicer.SingleSliceSpec(
              features=[('age', 5), ('gender', 'f')], columns=['country']),
          slicer.SingleSliceSpec(
              features=[('age', 6), ('gender', 'm')], columns=['interest'])
      ],
      example_weight_metric_key='key')

  output_path = self._getTempDir()
  with beam.Pipeline() as pipeline:
    metrics = (
        pipeline
        | 'CreateMetrics' >> beam.Create(
            [metrics_for_slice.SerializeToString()]))
    plots = (
        pipeline
        | 'CreatePlots' >> beam.Create(
            [plots_for_slice.SerializeToString()]))
    evaluation = {
        constants.METRICS_KEY: metrics,
        constants.PLOTS_KEY: plots
    }
    _ = (
        evaluation
        | 'WriteResults' >> model_eval_lib.WriteResults(
            writers=model_eval_lib.default_writers(output_path=output_path)))
    _ = pipeline | model_eval_lib.WriteEvalConfig(eval_config, output_path)

  metrics = metrics_and_plots_evaluator.load_and_deserialize_metrics(
      path=os.path.join(output_path, model_eval_lib._METRICS_OUTPUT_FILE))
  plots = metrics_and_plots_evaluator.load_and_deserialize_plots(
      path=os.path.join(output_path, model_eval_lib._PLOTS_OUTPUT_FILE))
  self.assertSliceMetricsListEqual(
      [(metrics_slice_key, metrics_for_slice.metrics)], metrics)
  self.assertSlicePlotsListEqual(
      [(plots_slice_key, plots_for_slice.plots)], plots)
  got_eval_config = model_eval_lib.load_eval_config(output_path)
  self.assertEqual(eval_config, got_eval_config)
def testSerializePlots(self):
  slice_key = _make_slice_key('fruit', 'apple')
  plot_key = metric_types.PlotKey(
      name='calibration_plot', output_name='output_name')
  calibration_plot = text_format.Parse(
      """
      buckets {
        lower_threshold_inclusive: -inf
        upper_threshold_exclusive: 0.0
        num_weighted_examples { value: 0.0 }
        total_weighted_label { value: 0.0 }
        total_weighted_refined_prediction { value: 0.0 }
      }
      buckets {
        lower_threshold_inclusive: 0.0
        upper_threshold_exclusive: 0.5
        num_weighted_examples { value: 1.0 }
        total_weighted_label { value: 1.0 }
        total_weighted_refined_prediction { value: 0.3 }
      }
      buckets {
        lower_threshold_inclusive: 0.5
        upper_threshold_exclusive: 1.0
        num_weighted_examples { value: 1.0 }
        total_weighted_label { value: 0.0 }
        total_weighted_refined_prediction { value: 0.7 }
      }
      buckets {
        lower_threshold_inclusive: 1.0
        upper_threshold_exclusive: inf
        num_weighted_examples { value: 0.0 }
        total_weighted_label { value: 0.0 }
        total_weighted_refined_prediction { value: 0.0 }
      }
      """, metrics_for_slice_pb2.CalibrationHistogramBuckets())
  expected_plots_for_slice = text_format.Parse(
      """
      slice_key {
        single_slice_keys {
          column: 'fruit'
          bytes_value: 'apple'
        }
      }
      plot_keys_and_values {
        key {
          output_name: "output_name"
        }
        value {
          calibration_histogram_buckets {
            buckets {
              lower_threshold_inclusive: -inf
              upper_threshold_exclusive: 0.0
              num_weighted_examples { value: 0.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.0 }
            }
            buckets {
              lower_threshold_inclusive: 0.0
              upper_threshold_exclusive: 0.5
              num_weighted_examples { value: 1.0 }
              total_weighted_label { value: 1.0 }
              total_weighted_refined_prediction { value: 0.3 }
            }
            buckets {
              lower_threshold_inclusive: 0.5
              upper_threshold_exclusive: 1.0
              num_weighted_examples { value: 1.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.7 }
            }
            buckets {
              lower_threshold_inclusive: 1.0
              upper_threshold_exclusive: inf
              num_weighted_examples { value: 0.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.0 }
            }
          }
        }
      }
      """, metrics_for_slice_pb2.PlotsForSlice())
  got = metrics_and_plots_serialization._serialize_plots(
      (slice_key, {plot_key: calibration_plot}), None)
  self.assertProtoEquals(
      expected_plots_for_slice,
      metrics_for_slice_pb2.PlotsForSlice.FromString(got))
def testSerializeDeserializeToFile(self):
  metrics_slice_key = _make_slice_key('fruit', 'pear', 'animal', 'duck')
  metrics_for_slice = text_format.Parse(
      """
      slice_key {
        single_slice_keys {
          column: "fruit"
          bytes_value: "pear"
        }
        single_slice_keys {
          column: "animal"
          bytes_value: "duck"
        }
      }
      metrics {
        key: "accuracy"
        value {
          double_value {
            value: 0.8
          }
        }
      }
      metrics {
        key: "example_weight"
        value {
          double_value {
            value: 10.0
          }
        }
      }
      metrics {
        key: "auc"
        value {
          bounded_value {
            lower_bound {
              value: 0.1
            }
            upper_bound {
              value: 0.3
            }
            value {
              value: 0.2
            }
          }
        }
      }
      metrics {
        key: "auprc"
        value {
          bounded_value {
            lower_bound {
              value: 0.05
            }
            upper_bound {
              value: 0.17
            }
            value {
              value: 0.1
            }
          }
        }
      }""", metrics_for_slice_pb2.MetricsForSlice())
  plots_for_slice = text_format.Parse(
      """
      slice_key {
        single_slice_keys {
          column: "fruit"
          bytes_value: "peach"
        }
        single_slice_keys {
          column: "animal"
          bytes_value: "cow"
        }
      }
      plot_data {
        calibration_histogram_buckets {
          buckets {
            lower_threshold_inclusive: -inf
            upper_threshold_exclusive: 0.0
            num_weighted_examples { value: 0.0 }
            total_weighted_label { value: 0.0 }
            total_weighted_refined_prediction { value: 0.0 }
          }
          buckets {
            lower_threshold_inclusive: 0.0
            upper_threshold_exclusive: 0.5
            num_weighted_examples { value: 1.0 }
            total_weighted_label { value: 1.0 }
            total_weighted_refined_prediction { value: 0.3 }
          }
          buckets {
            lower_threshold_inclusive: 0.5
            upper_threshold_exclusive: 1.0
            num_weighted_examples { value: 1.0 }
            total_weighted_label { value: 0.0 }
            total_weighted_refined_prediction { value: 0.7 }
          }
          buckets {
            lower_threshold_inclusive: 1.0
            upper_threshold_exclusive: inf
            num_weighted_examples { value: 0.0 }
            total_weighted_label { value: 0.0 }
            total_weighted_refined_prediction { value: 0.0 }
          }
        }
      }""", metrics_for_slice_pb2.PlotsForSlice())
  plots_slice_key = _make_slice_key('fruit', 'peach', 'animal', 'cow')

  eval_config = api_types.EvalConfig(
      model_location='/path/to/model',
      data_location='/path/to/data',
      slice_spec=[
          slicer.SingleSliceSpec(
              features=[('age', 5), ('gender', 'f')], columns=['country']),
          slicer.SingleSliceSpec(
              features=[('age', 6), ('gender', 'm')], columns=['interest'])
      ],
      example_weight_metric_key='key')

  output_path = self._getTempDir()
  with beam.Pipeline() as pipeline:
    metrics = (
        pipeline
        | 'CreateMetrics' >> beam.Create(
            [metrics_for_slice.SerializeToString()]))
    plots = (
        pipeline
        | 'CreatePlots' >> beam.Create(
            [plots_for_slice.SerializeToString()]))
    _ = ((metrics, plots)
         | 'WriteMetricsPlotsAndConfig' >>
         serialization.WriteMetricsPlotsAndConfig(
             output_path=output_path, eval_config=eval_config))

  metrics, plots = serialization.load_plots_and_metrics(output_path)
  self.assertSliceMetricsListEqual(
      [(metrics_slice_key, metrics_for_slice.metrics)], metrics)
  self.assertSlicePlotsListEqual(
      [(plots_slice_key, plots_for_slice.plot_data)], plots)
  got_eval_config = serialization.load_eval_config(output_path)
  self.assertEqual(eval_config, got_eval_config)