def testServingGraphAlsoExportedIfSpecified(self):
  # Most of the example trainers also pass serving_input_receiver_fn to
  # export_eval_savedmodel, so the serving graph should be included.
  temp_eval_export_dir = self._getEvalExportDir()
  _, eval_export_dir = (
      fixed_prediction_estimator.simple_fixed_prediction_estimator(
          None, temp_eval_export_dir))

  # Check the eval graph.
  eval_saved_model = load.EvalSavedModel(eval_export_dir)
  example1 = self._makeExample(prediction=0.9, label=0.0)
  eval_saved_model.metrics_reset_update_get(example1.SerializeToString())

  metric_values = eval_saved_model.get_metric_values()
  # average_loss for the single example is (0.0 - 0.9)**2 = 0.81.
  self.assertDictElementsAlmostEqual(metric_values, {'average_loss': 0.81})

  # Check the serving graph.
  # TODO(b/124466113): Remove tf.compat.v2 once TF 2.0 is the default.
  # Note: hasattr does not resolve dotted paths, so the check must walk
  # the attributes one level at a time.
  if hasattr(tf, 'compat') and hasattr(tf.compat, 'v2'):
    imported = tf.compat.v2.saved_model.load(
        eval_export_dir, tags=tf.saved_model.SERVING)
    predictions = imported.signatures[
        tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY](
            inputs=tf.constant([example1.SerializeToString()]))
    self.assertAllClose(predictions['outputs'], np.array([[0.9]]))
def testServingGraphAlsoExportedIfSpecified(self):
  # Most of the example trainers also pass serving_input_receiver_fn to
  # export_eval_savedmodel, so the serving graph should be included.
  temp_eval_export_dir = self._getEvalExportDir()
  _, eval_export_dir = (
      fixed_prediction_estimator.simple_fixed_prediction_estimator(
          None, temp_eval_export_dir))

  # Check the eval graph.
  eval_saved_model = load.EvalSavedModel(eval_export_dir)
  example1 = self._makeExample(prediction=0.9, label=0.0)
  features_predictions_labels = self.predict_injective_single_example(
      eval_saved_model, example1.SerializeToString())
  eval_saved_model.perform_metrics_update(features_predictions_labels)

  metric_values = eval_saved_model.get_metric_values()
  self.assertDictElementsAlmostEqual(metric_values, {'average_loss': 0.81})

  # Check the serving graph.
  estimator = tf.contrib.estimator.SavedModelEstimator(eval_export_dir)

  def predict_input_fn():
    return {'inputs': tf.constant([example1.SerializeToString()])}

  predictions = next(estimator.predict(predict_input_fn))
  self.assertAllClose(predictions['outputs'], np.array([0.9]))
def testAucUnweighted(self):
  temp_eval_export_dir = self._getEvalExportDir()
  _, eval_export_dir = (
      fixed_prediction_estimator.simple_fixed_prediction_estimator(
          None, temp_eval_export_dir))
  examples = [
      self._makeExample(prediction=0.0000, label=0.0000),
      self._makeExample(prediction=0.0000, label=1.0000),
      self._makeExample(prediction=0.7000, label=1.0000),
      self._makeExample(prediction=0.8000, label=0.0000),
      self._makeExample(prediction=1.0000, label=1.0000),
  ]
  expected_values_dict = {
      metric_keys.AUC: 0.58333,
      metric_keys.lower_bound(metric_keys.AUC): 0.5,
      metric_keys.upper_bound(metric_keys.AUC): 0.66667,
      metric_keys.AUPRC: 0.74075,
      metric_keys.lower_bound(metric_keys.AUPRC): 0.70000,
      metric_keys.upper_bound(metric_keys.AUPRC): 0.77778,
  }
  self._runTest(
      examples, eval_export_dir,
      [post_export_metrics.auc(),
       post_export_metrics.auc(curve='PR')], expected_values_dict)
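# The expected AUC above can be cross-checked by hand: the positive
# predictions are {0.0, 0.7, 1.0} and the negative predictions are
# {0.0, 0.8}, giving six positive/negative pairs. Three pairs rank the
# positive higher and one pair ties, so the rank-based (Mann-Whitney)
# AUC is (3 + 0.5) / 6 = 0.58333. A quick sketch of the same check,
# assuming scikit-learn is available (the test itself does not use it):
#
#   from sklearn.metrics import roc_auc_score
#   roc_auc_score([0, 1, 1, 0, 1], [0.0, 0.0, 0.7, 0.8, 1.0])  # ~0.58333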
def testCalibrationPlotAndPredictionHistogramUnweighted(self):
  temp_eval_export_dir = self._getEvalExportDir()
  _, eval_export_dir = (
      fixed_prediction_estimator.simple_fixed_prediction_estimator(
          None, temp_eval_export_dir))
  examples = [
      # For each example, we set label to prediction + 1.
      # These two go in bucket 0: (-inf, 0)
      self._makeExample(prediction=-10.0, label=-9.0),
      self._makeExample(prediction=-9.0, label=-8.0),
      # This goes in bucket 1: [0, 0.00010)
      self._makeExample(prediction=0.00000, label=1.00000),
      # These three go in bucket 11: [0.00100, 0.00110)
      self._makeExample(prediction=0.00100, label=1.00100),
      self._makeExample(prediction=0.00101, label=1.00101),
      self._makeExample(prediction=0.00102, label=1.00102),
      # These two go in bucket 10000: [0.99990, 1.00000)
      self._makeExample(prediction=0.99998, label=1.99998),
      self._makeExample(prediction=0.99999, label=1.99999),
      # These four go in bucket 10001: [1.0000, +inf)
      self._makeExample(prediction=1.0, label=2.0),
      self._makeExample(prediction=8.0, label=9.0),
      self._makeExample(prediction=9.0, label=10.0),
      self._makeExample(prediction=10.0, label=11.0),
  ]

  def check_result(got):  # pylint: disable=invalid-name
    try:
      self.assertEqual(1, len(got), 'got: %s' % got)
      (slice_key, value) = got[0]
      self.assertEqual((), slice_key)
      self.assertIn(metric_keys.CALIBRATION_PLOT_MATRICES, value)
      buckets = value[metric_keys.CALIBRATION_PLOT_MATRICES]
      self.assertSequenceAlmostEqual(buckets[0], [-19.0, -17.0, 2.0])
      self.assertSequenceAlmostEqual(buckets[1], [0.0, 1.0, 1.0])
      self.assertSequenceAlmostEqual(buckets[11], [0.00303, 3.00303, 3.0])
      self.assertSequenceAlmostEqual(buckets[10000], [1.99997, 3.99997, 2.0])
      self.assertSequenceAlmostEqual(buckets[10001], [28.0, 32.0, 4.0])
      self.assertIn(metric_keys.CALIBRATION_PLOT_BOUNDARIES, value)
      boundaries = value[metric_keys.CALIBRATION_PLOT_BOUNDARIES]
      self.assertAlmostEqual(0.0, boundaries[0])
      self.assertAlmostEqual(0.001, boundaries[10])
      self.assertAlmostEqual(0.005, boundaries[50])
      self.assertAlmostEqual(0.010, boundaries[100])
      self.assertAlmostEqual(0.100, boundaries[1000])
      self.assertAlmostEqual(0.800, boundaries[8000])
      self.assertAlmostEqual(1.000, boundaries[10000])
    except AssertionError as err:
      raise util.BeamAssertException(err)

  self._runTestWithCustomCheck(
      examples,
      eval_export_dir,
      [post_export_metrics.calibration_plot_and_prediction_histogram()],
      custom_plots_check=check_result)
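# How the bucket indexing above works (inferred from the assertions, not
# from the TFMA internals): there are 10000 equal-width buckets over
# [0, 1), so bucket k for k in 1..10000 covers [(k - 1) / 10000,
# k / 10000), with bucket 0 catching predictions below 0 and bucket 10001
# catching predictions >= 1. Each matrix row is
# [sum of predictions, sum of labels, example count]; e.g. bucket 0 holds
# predictions -10 and -9 with labels -9 and -8, hence [-19.0, -17.0, 2.0].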
def testEvaluateWithPlots(self):
  temp_eval_export_dir = self._getEvalExportDir()
  _, eval_export_dir = (
      fixed_prediction_estimator.simple_fixed_prediction_estimator(
          None, temp_eval_export_dir))

  with beam.Pipeline() as pipeline:
    example1 = self._makeExample(prediction=0.0, label=1.0)
    example2 = self._makeExample(prediction=0.7, label=0.0)
    example3 = self._makeExample(prediction=0.8, label=1.0)
    example4 = self._makeExample(prediction=1.0, label=1.0)

    metrics, plots = (
        pipeline
        | beam.Create([
            example1.SerializeToString(),
            example2.SerializeToString(),
            example3.SerializeToString(),
            example4.SerializeToString()
        ])
        | evaluate.Evaluate(
            eval_saved_model_path=eval_export_dir,
            add_metrics_callbacks=[
                post_export_metrics.example_count(),
                post_export_metrics.auc_plots()
            ]))

    def check_metrics(got):
      try:
        self.assertEqual(1, len(got), 'got: %s' % got)
        (slice_key, value) = got[0]
        self.assertEqual((), slice_key)
        self.assertDictElementsAlmostEqual(
            got_values_dict=value,
            expected_values_dict={
                metric_keys.EXAMPLE_COUNT: 4.0,
            })
      except AssertionError as err:
        raise util.BeamAssertException(err)

    util.assert_that(metrics, check_metrics, label='metrics')

    def check_plots(got):
      try:
        self.assertEqual(1, len(got), 'got: %s' % got)
        (slice_key, value) = got[0]
        self.assertEqual((), slice_key)
        self.assertDictMatrixRowsAlmostEqual(
            got_values_dict=value,
            expected_values_dict={
                metric_keys.AUC_PLOTS_MATRICES:
                    [(8001, [2, 1, 0, 1, 1.0 / 1.0, 1.0 / 3.0])],
            })
      except AssertionError as err:
        raise util.BeamAssertException(err)

    util.assert_that(plots, check_plots, label='plots')
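# A note on the plot assertion above (derived from the threshold layout
# exercised in testAucPlotsUnweighted): matrix index 8001 corresponds to
# threshold 0.8, and each row is [FN, TN, FP, TP, precision, recall]. At
# threshold 0.8, only prediction 1.0 counts as positive (TP = 1);
# predictions 0.0 and 0.8 with label 1.0 are false negatives (FN = 2), and
# prediction 0.7 with label 0.0 is a true negative (TN = 1), giving
# precision 1/1 and recall 1/3.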
def testAucUnweightedSerialization(self):
  temp_eval_export_dir = self._getEvalExportDir()
  _, eval_export_dir = (
      fixed_prediction_estimator.simple_fixed_prediction_estimator(
          None, temp_eval_export_dir))
  examples = [
      self._makeExample(prediction=0.0000, label=0.0000),
      self._makeExample(prediction=0.0000, label=1.0000),
      self._makeExample(prediction=0.7000, label=1.0000),
      self._makeExample(prediction=0.8000, label=0.0000),
      self._makeExample(prediction=1.0000, label=1.0000),
  ]
  expected_values_dict = {
      metric_keys.AUPRC: 0.74075,
      metric_keys.lower_bound(metric_keys.AUPRC): 0.70000,
      metric_keys.upper_bound(metric_keys.AUPRC): 0.77778,
  }
  auc_metric = post_export_metrics.auc(curve='PR')

  def check_result(got):  # pylint: disable=invalid-name
    try:
      self.assertEqual(1, len(got), 'got: %s' % got)
      (slice_key, value) = got[0]
      self.assertEqual((), slice_key)
      self.assertDictElementsAlmostEqual(value, expected_values_dict)

      # Check serialization too.
      # Note that we can't just make this a dict, since proto maps
      # allow uninitialized key access, i.e. they act like defaultdicts.
      output_metrics = metrics_for_slice_pb2.MetricsForSlice().metrics
      auc_metric.populate_stats_and_pop(value, output_metrics)
      self.assertProtoEquals(
          """
          bounded_value {
            lower_bound {
              value: 0.6999999
            }
            upper_bound {
              value: 0.7777776
            }
            value {
              value: 0.7407472
            }
          }
          """, output_metrics[metric_keys.AUPRC])
    except AssertionError as err:
      raise util.BeamAssertException(err)

  self._runTestWithCustomCheck(
      examples, eval_export_dir, [auc_metric],
      custom_metrics_check=check_result)
def testAucPlotsUnweighted(self):
  temp_eval_export_dir = self._getEvalExportDir()
  _, eval_export_dir = (
      fixed_prediction_estimator.simple_fixed_prediction_estimator(
          None, temp_eval_export_dir))
  examples = [
      self._makeExample(prediction=0.0000, label=0.0000),
      self._makeExample(prediction=0.0000, label=1.0000),
      self._makeExample(prediction=0.7000, label=1.0000),
      self._makeExample(prediction=0.8000, label=0.0000),
      self._makeExample(prediction=1.0000, label=1.0000),
  ]

  def check_result(got):  # pylint: disable=invalid-name
    try:
      self.assertEqual(1, len(got), 'got: %s' % got)
      (slice_key, value) = got[0]
      self.assertEqual((), slice_key)
      self.assertIn(metric_keys.AUC_PLOTS_MATRICES, value)
      matrices = value[metric_keys.AUC_PLOTS_MATRICES]
      #            |      | --------- Threshold -----------
      # true label | pred | -1e-6 | 0.0 | 0.7 | 0.8 | 1.0
      #     -      | 0.0  |  FP   | TN  | TN  | TN  | TN
      #     +      | 0.0  |  TP   | FN  | FN  | FN  | FN
      #     +      | 0.7  |  TP   | TP  | FN  | FN  | FN
      #     -      | 0.8  |  FP   | FP  | FP  | TN  | TN
      #     +      | 1.0  |  TP   | TP  | TP  | TP  | FN
      self.assertSequenceAlmostEqual(matrices[0],
                                     [0, 0, 2, 3, 3.0 / 5.0, 1.0])
      self.assertSequenceAlmostEqual(matrices[1],
                                     [1, 1, 1, 2, 2.0 / 3.0, 2.0 / 3.0])
      self.assertSequenceAlmostEqual(matrices[7001],
                                     [2, 1, 1, 1, 1.0 / 2.0, 1.0 / 3.0])
      self.assertSequenceAlmostEqual(matrices[8001],
                                     [2, 2, 0, 1, 1.0 / 1.0, 1.0 / 3.0])
      self.assertSequenceAlmostEqual(matrices[10001],
                                     [3, 2, 0, 0, float('nan'), 0.0])
      self.assertIn(metric_keys.AUC_PLOTS_THRESHOLDS, value)
      thresholds = value[metric_keys.AUC_PLOTS_THRESHOLDS]
      self.assertAlmostEqual(0.0, thresholds[1])
      self.assertAlmostEqual(0.001, thresholds[11])
      self.assertAlmostEqual(0.005, thresholds[51])
      self.assertAlmostEqual(0.010, thresholds[101])
      self.assertAlmostEqual(0.100, thresholds[1001])
      self.assertAlmostEqual(0.800, thresholds[8001])
      self.assertAlmostEqual(1.000, thresholds[10001])
    except AssertionError as err:
      raise util.BeamAssertException(err)

  self._runTestWithCustomCheck(
      examples,
      eval_export_dir, [post_export_metrics.auc_plots()],
      custom_plots_check=check_result)
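# The row layout of AUC_PLOTS_MATRICES, as the assertions above imply, is
# [FN, TN, FP, TP, precision, recall]. For example, at threshold 0.0
# (matrix index 1), predictions strictly above the threshold count as
# positive: TP = {0.7, 1.0}, FP = {0.8}, FN = {0.0 with label 1},
# TN = {0.0 with label 0}, giving [1, 1, 1, 2, 2/3, 2/3]. Index i maps to
# threshold (i - 1) / 10000, with index 0 reserved for the -1e-6
# catch-all threshold.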
def testAssertMetricsComputedWithBeamAre(self):
  temp_eval_export_dir = self._getEvalExportDir()
  _, eval_export_dir = (
      fixed_prediction_estimator.simple_fixed_prediction_estimator(
          None, temp_eval_export_dir))
  examples = [
      self.makeExample(prediction=0.0, label=1.0),
      self.makeExample(prediction=0.7, label=0.0),
      self.makeExample(prediction=0.8, label=1.0),
      self.makeExample(prediction=1.0, label=1.0)
  ]
  self.assertMetricsComputedWithBeamAre(
      eval_saved_model_path=eval_export_dir,
      serialized_examples=examples,
      expected_metrics={'average_loss': (1.0 + 0.49 + 0.04 + 0.00) / 4.0})
def testExampleCountNoStandardKeys(self):
  # Test ExampleCount with a custom Estimator that doesn't have any of the
  # standard PredictionKeys.
  temp_eval_export_dir = self._getEvalExportDir()
  _, eval_export_dir = (
      fixed_prediction_estimator.simple_fixed_prediction_estimator(
          None, temp_eval_export_dir, output_prediction_key='non_standard'))
  examples = [
      self._makeExample(prediction=5.0, label=5.0),
      self._makeExample(prediction=6.0, label=6.0),
      self._makeExample(prediction=7.0, label=7.0),
  ]
  expected_values_dict = {
      metric_keys.EXAMPLE_COUNT: 3.0,
  }
  self._runTest(examples, eval_export_dir, [
      post_export_metrics.example_count(),
  ], expected_values_dict)
def testExampleCountEmptyPredictionsDict(self):
  # Test ExampleCount with a custom Estimator that has an empty predictions
  # dict. This is possible if the Estimator doesn't return the predictions
  # dict in EVAL mode, but computes predictions and feeds them into the
  # metrics internally.
  temp_eval_export_dir = self._getEvalExportDir()
  _, eval_export_dir = (
      fixed_prediction_estimator.simple_fixed_prediction_estimator(
          None, temp_eval_export_dir, output_prediction_key=None))
  examples = [
      self._makeExample(prediction=5.0, label=5.0),
      self._makeExample(prediction=6.0, label=6.0),
      self._makeExample(prediction=7.0, label=7.0),
  ]
  expected_values_dict = {
      metric_keys.EXAMPLE_COUNT: 3.0,
  }
  self._runTest(examples, eval_export_dir, [
      post_export_metrics.example_count(),
  ], expected_values_dict)
def testBoundedValueChecks(self):
  temp_eval_export_dir = self._getEvalExportDir()
  _, eval_export_dir = (
      fixed_prediction_estimator.simple_fixed_prediction_estimator(
          None, temp_eval_export_dir))
  examples = [
      self.makeExample(prediction=0.8, label=1.0),
  ]

  self.assertMetricsComputedWithBeamAre(
      eval_saved_model_path=eval_export_dir,
      serialized_examples=examples,
      expected_metrics={'average_loss': 0.04})

  self.assertMetricsComputedWithoutBeamAre(
      eval_saved_model_path=eval_export_dir,
      serialized_examples=examples,
      expected_metrics={
          'average_loss':
              tfma_unit.BoundedValue(lower_bound=0.03, upper_bound=0.05)
      })

  with self.assertRaisesRegexp(
      AssertionError, 'expecting key average_loss to have value between'):
    self.assertMetricsComputedWithoutBeamAre(
        eval_saved_model_path=eval_export_dir,
        serialized_examples=examples,
        expected_metrics={
            'average_loss': tfma_unit.BoundedValue(upper_bound=0.01)
        })

  with self.assertRaisesRegexp(
      AssertionError, 'expecting key average_loss to have value between'):
    self.assertMetricsComputedWithoutBeamAre(
        eval_saved_model_path=eval_export_dir,
        serialized_examples=examples,
        expected_metrics={
            'average_loss': tfma_unit.BoundedValue(lower_bound=0.10)
        })
def testEvaluateWithPlots(self):
  temp_eval_export_dir = self._getEvalExportDir()
  _, eval_export_dir = (
      fixed_prediction_estimator.simple_fixed_prediction_estimator(
          None, temp_eval_export_dir))
  eval_shared_model = self.createTestEvalSharedModel(
      eval_saved_model_path=eval_export_dir,
      add_metrics_callbacks=[
          post_export_metrics.example_count(),
          post_export_metrics.auc_plots()
      ])
  extractors = [
      predict_extractor.PredictExtractor(eval_shared_model),
      slice_key_extractor.SliceKeyExtractor()
  ]

  with beam.Pipeline() as pipeline:
    example1 = self._makeExample(prediction=0.0, label=1.0)
    example2 = self._makeExample(prediction=0.7, label=0.0)
    example3 = self._makeExample(prediction=0.8, label=1.0)
    example4 = self._makeExample(prediction=1.0, label=1.0)

    metrics, plots = (
        pipeline
        | 'Create' >> beam.Create([
            example1.SerializeToString(),
            example2.SerializeToString(),
            example3.SerializeToString(),
            example4.SerializeToString()
        ])
        | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
        | 'Extract' >> tfma_unit.Extract(extractors=extractors)  # pylint: disable=no-value-for-parameter
        | 'ComputeMetricsAndPlots' >> metrics_and_plots_evaluator
        .ComputeMetricsAndPlots(eval_shared_model=eval_shared_model))

    def check_metrics(got):
      try:
        self.assertEqual(1, len(got), 'got: %s' % got)
        (slice_key, value) = got[0]
        self.assertEqual((), slice_key)
        self.assertDictElementsAlmostEqual(
            got_values_dict=value,
            expected_values_dict={
                metric_keys.EXAMPLE_COUNT: 4.0,
            })
      except AssertionError as err:
        raise util.BeamAssertException(err)

    util.assert_that(metrics, check_metrics, label='metrics')

    def check_plots(got):
      try:
        self.assertEqual(1, len(got), 'got: %s' % got)
        (slice_key, value) = got[0]
        self.assertEqual((), slice_key)
        self.assertDictMatrixRowsAlmostEqual(
            got_values_dict=value,
            expected_values_dict={
                _full_key(metric_keys.AUC_PLOTS_MATRICES):
                    [(8001, [2, 1, 0, 1, 1.0 / 1.0, 1.0 / 3.0])],
            })
      except AssertionError as err:
        raise util.BeamAssertException(err)

    util.assert_that(plots, check_plots, label='plots')
def testWriteMetricsAndPlots(self):
  metrics_file = os.path.join(self._getTempDir(), 'metrics')
  plots_file = os.path.join(self._getTempDir(), 'plots')
  temp_eval_export_dir = os.path.join(self._getTempDir(), 'eval_export_dir')

  _, eval_export_dir = (
      fixed_prediction_estimator.simple_fixed_prediction_estimator(
          None, temp_eval_export_dir))
  eval_config = config.EvalConfig(
      model_specs=[config.ModelSpec()],
      options=config.Options(
          disabled_outputs={'values': ['eval_config.json']}))
  eval_shared_model = self.createTestEvalSharedModel(
      eval_saved_model_path=eval_export_dir,
      add_metrics_callbacks=[
          post_export_metrics.example_count(),
          post_export_metrics.calibration_plot_and_prediction_histogram(
              num_buckets=2)
      ])
  extractors = [
      predict_extractor.PredictExtractor(eval_shared_model),
      slice_key_extractor.SliceKeyExtractor()
  ]
  evaluators = [
      metrics_and_plots_evaluator.MetricsAndPlotsEvaluator(eval_shared_model)
  ]
  output_paths = {
      constants.METRICS_KEY: metrics_file,
      constants.PLOTS_KEY: plots_file
  }
  writers = [
      metrics_plots_and_validations_writer.MetricsPlotsAndValidationsWriter(
          output_paths, eval_shared_model.add_metrics_callbacks)
  ]

  with beam.Pipeline() as pipeline:
    example1 = self._makeExample(prediction=0.0, label=1.0)
    example2 = self._makeExample(prediction=1.0, label=1.0)

    # pylint: disable=no-value-for-parameter
    _ = (
        pipeline
        | 'Create' >> beam.Create([
            example1.SerializeToString(),
            example2.SerializeToString(),
        ])
        | 'ExtractEvaluateAndWriteResults' >>
        model_eval_lib.ExtractEvaluateAndWriteResults(
            eval_config=eval_config,
            eval_shared_model=eval_shared_model,
            extractors=extractors,
            evaluators=evaluators,
            writers=writers))
    # pylint: enable=no-value-for-parameter

  expected_metrics_for_slice = text_format.Parse(
      """
      slice_key {}
      metrics {
        key: "average_loss"
        value {
          double_value {
            value: 0.5
          }
        }
      }
      metrics {
        key: "post_export_metrics/example_count"
        value {
          double_value {
            value: 2.0
          }
        }
      }
      """, metrics_for_slice_pb2.MetricsForSlice())

  metric_records = []
  for record in tf.compat.v1.python_io.tf_record_iterator(metrics_file):
    metric_records.append(
        metrics_for_slice_pb2.MetricsForSlice.FromString(record))
  self.assertEqual(1, len(metric_records), 'metrics: %s' % metric_records)
  self.assertProtoEquals(expected_metrics_for_slice, metric_records[0])

  expected_plots_for_slice = text_format.Parse(
      """
      slice_key {}
      plots {
        key: "post_export_metrics"
        value {
          calibration_histogram_buckets {
            buckets {
              lower_threshold_inclusive: -inf
              num_weighted_examples {}
              total_weighted_label {}
              total_weighted_refined_prediction {}
            }
            buckets {
              upper_threshold_exclusive: 0.5
              num_weighted_examples { value: 1.0 }
              total_weighted_label { value: 1.0 }
              total_weighted_refined_prediction {}
            }
            buckets {
              lower_threshold_inclusive: 0.5
              upper_threshold_exclusive: 1.0
              num_weighted_examples {}
              total_weighted_label {}
              total_weighted_refined_prediction {}
            }
            buckets {
              lower_threshold_inclusive: 1.0
              upper_threshold_exclusive: inf
              num_weighted_examples { value: 1.0 }
              total_weighted_label { value: 1.0 }
              total_weighted_refined_prediction { value: 1.0 }
            }
          }
        }
      }
      """, metrics_for_slice_pb2.PlotsForSlice())

  plot_records = []
  for record in tf.compat.v1.python_io.tf_record_iterator(plots_file):
    plot_records.append(
        metrics_for_slice_pb2.PlotsForSlice.FromString(record))
  self.assertEqual(1, len(plot_records), 'plots: %s' % plot_records)
  self.assertProtoEquals(expected_plots_for_slice, plot_records[0])
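# A note on the expected values above: average_loss is the mean squared
# error over the two examples, ((1.0 - 0.0)**2 + (1.0 - 1.0)**2) / 2 = 0.5.
# With num_buckets=2 the calibration histogram has two buckets over [0, 1)
# split at 0.5, plus an underflow bucket for predictions below 0 and an
# overflow bucket for predictions >= 1.0; prediction 0.0 (label 1.0) lands
# in the bucket below 0.5 and prediction 1.0 (label 1.0) lands in
# [1.0, inf), which is exactly what the serialized plot proto asserts.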
def testConfusionMatrixAtThresholdsSerialization(self):
  temp_eval_export_dir = self._getEvalExportDir()
  _, eval_export_dir = (
      fixed_prediction_estimator.simple_fixed_prediction_estimator(
          None, temp_eval_export_dir))
  examples = [
      self._makeExample(prediction=0.0000, label=0.0000),
      self._makeExample(prediction=0.5000, label=1.0000),
      self._makeExample(prediction=1.0000, label=1.0000),
  ]
  confusion_matrix_at_thresholds_metric = (
      post_export_metrics.confusion_matrix_at_thresholds(
          thresholds=[0.25, 0.75, 1.00]))

  def check_result(got):  # pylint: disable=invalid-name
    try:
      self.assertEqual(1, len(got), 'got: %s' % got)
      (slice_key, value) = got[0]
      self.assertEqual((), slice_key)
      self.assertIn(metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_MATRICES,
                    value)
      matrices = value[metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_MATRICES]
      #            |      | ---- Threshold ----
      # true label | pred | 0.25 | 0.75 | 1.00
      #     -      | 0.0  |  TN  |  TN  |  TN
      #     +      | 0.5  |  TP  |  FN  |  FN
      #     +      | 1.0  |  TP  |  TP  |  FN
      self.assertSequenceAlmostEqual(matrices[0],
                                     [0.0, 1.0, 0.0, 2.0, 1.0, 1.0])
      self.assertSequenceAlmostEqual(matrices[1],
                                     [1.0, 1.0, 0.0, 1.0, 1.0, 0.5])
      self.assertSequenceAlmostEqual(
          matrices[2], [2.0, 1.0, 0.0, 0.0, float('nan'), 0.0])
      self.assertIn(metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_THRESHOLDS,
                    value)
      thresholds = value[
          metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_THRESHOLDS]
      self.assertAlmostEqual(0.25, thresholds[0])
      self.assertAlmostEqual(0.75, thresholds[1])
      self.assertAlmostEqual(1.00, thresholds[2])

      # Check serialization too.
      # Note that we can't just make this a dict, since proto maps
      # allow uninitialized key access, i.e. they act like defaultdicts.
      output_metrics = metrics_for_slice_pb2.MetricsForSlice().metrics
      confusion_matrix_at_thresholds_metric.populate_stats_and_pop(
          value, output_metrics)
      self.assertProtoEquals(
          """
          confusion_matrix_at_thresholds {
            matrices {
              threshold: 0.25
              false_negatives: 0.0
              true_negatives: 1.0
              false_positives: 0.0
              true_positives: 2.0
              precision: 1.0
              recall: 1.0
            }
            matrices {
              threshold: 0.75
              false_negatives: 1.0
              true_negatives: 1.0
              false_positives: 0.0
              true_positives: 1.0
              precision: 1.0
              recall: 0.5
            }
            matrices {
              threshold: 1.00
              false_negatives: 2.0
              true_negatives: 1.0
              false_positives: 0.0
              true_positives: 0.0
              precision: nan
              recall: 0.0
            }
          }
          """, output_metrics[metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS])
    except AssertionError as err:
      raise util.BeamAssertException(err)

  self._runTestWithCustomCheck(
      examples,
      eval_export_dir, [confusion_matrix_at_thresholds_metric],
      custom_metrics_check=check_result)
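# The raw matrix rows above use the same field order that the serialized
# proto spells out: [false_negatives, true_negatives, false_positives,
# true_positives, precision, recall]. At threshold 1.0 no example is
# predicted positive, so precision is 0/0 = nan and recall is 0.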
def _write_tfma(self,
                tfma_path: str,
                output_file_format: str,
                store: Optional[mlmd.MetadataStore] = None):
  _, eval_saved_model_path = (
      fixed_prediction_estimator.simple_fixed_prediction_estimator(
          export_path=None,
          eval_export_path=os.path.join(self.tmpdir, 'eval_export_dir')))
  eval_config = tfma.EvalConfig(model_specs=[tfma.ModelSpec()])
  eval_shared_model = self.createTestEvalSharedModel(
      eval_saved_model_path=eval_saved_model_path,
      add_metrics_callbacks=[
          tfma.post_export_metrics.example_count(),
          tfma.post_export_metrics
          .calibration_plot_and_prediction_histogram(num_buckets=2)
      ])
  extractors = [
      tfma.extractors.legacy_predict_extractor.PredictExtractor(
          eval_shared_model, eval_config=eval_config),
      tfma.extractors.unbatch_extractor.UnbatchExtractor(),
      tfma.extractors.slice_key_extractor.SliceKeyExtractor()
  ]
  evaluators = [
      tfma.evaluators.legacy_metrics_and_plots_evaluator
      .MetricsAndPlotsEvaluator(eval_shared_model)
  ]
  writers = [
      tfma.writers.MetricsPlotsAndValidationsWriter(
          output_paths={
              'metrics': os.path.join(tfma_path, 'metrics'),
              'plots': os.path.join(tfma_path, 'plots')
          },
          output_file_format=output_file_format,
          eval_config=eval_config,
          add_metrics_callbacks=eval_shared_model.add_metrics_callbacks)
  ]

  tfx_io = raw_tf_record.RawBeamRecordTFXIO(
      physical_format='inmemory',
      raw_record_column_name='__raw_record__',
      telemetry_descriptors=['TFMATest'])
  with beam.Pipeline() as pipeline:
    example1 = self._makeExample(prediction=0.0, label=1.0)
    example2 = self._makeExample(prediction=1.0, label=1.0)
    _ = (
        pipeline
        | 'Create' >> beam.Create([
            example1.SerializeToString(),
            example2.SerializeToString(),
        ])
        | 'BatchExamples' >> tfx_io.BeamSource()
        | 'ExtractEvaluateAndWriteResults' >>
        tfma.ExtractEvaluateAndWriteResults(
            eval_config=eval_config,
            eval_shared_model=eval_shared_model,
            extractors=extractors,
            evaluators=evaluators,
            writers=writers))

  if store:
    eval_type = metadata_store_pb2.ArtifactType()
    eval_type.name = standard_artifacts.ModelEvaluation.TYPE_NAME
    eval_type_id = store.put_artifact_type(eval_type)
    artifact = metadata_store_pb2.Artifact()
    artifact.uri = tfma_path
    artifact.type_id = eval_type_id
    store.put_artifacts([artifact])