def testEvaluateQueryBasedMetrics(self):
  """Runs the query-based metrics evaluator end-to-end on a small pipeline.

  Exports a fixed-prediction estimator, runs predict and slice-key
  extraction over the test examples, applies EvaluateQueryBasedMetrics
  grouped by the 'fixed_string' field, and checks the resulting query
  statistics, NDCG, and min-label-position metrics on the overall slice.
  """
  temp_eval_export_dir = self._getEvalExportDir()
  _, eval_export_dir = (
      fixed_prediction_estimator_extra_fields
      .simple_fixed_prediction_estimator_extra_fields(None,
                                                      temp_eval_export_dir))
  eval_shared_model = self.createTestEvalSharedModel(
      eval_saved_model_path=eval_export_dir)
  extractors = [
      legacy_predict_extractor.PredictExtractor(eval_shared_model),
      slice_key_extractor.SliceKeyExtractor(),
  ]
  # Combiners under test: query counts, NDCG at two cutoffs, and the
  # weighted average minimum label position.
  combine_fns = [
      query_statistics.QueryStatisticsCombineFn(),
      ndcg.NdcgMetricCombineFn(
          at_vals=[1, 2], gain_key='fixed_float', weight_key='fixed_int'),
      min_label_position.MinLabelPositionCombineFn(
          label_key='', weight_key='fixed_int'),
  ]
  expected_metrics = {
      'post_export_metrics/total_queries': 3.0,
      'post_export_metrics/total_documents': 6.0,
      'post_export_metrics/min_documents': 1.0,
      'post_export_metrics/max_documents': 3.0,
      'post_export_metrics/ndcg@1': 0.9166667,
      'post_export_metrics/ndcg@2': 0.9766198,
      'post_export_metrics/average_min_label_position/__labels': 0.6666667,
  }

  def check_metrics(got):
    # Beam delivers results asynchronously; re-raise assertion failures as
    # BeamAssertException so the pipeline run reports them.
    try:
      self.assertEqual(1, len(got), 'got: %s' % got)
      got_slice_key, got_metrics = got[0]
      self.assertEqual(got_slice_key, ())
      self.assertDictElementsAlmostEqual(got_metrics, expected_metrics)
    except AssertionError as err:
      raise util.BeamAssertException(err)

  with beam.Pipeline() as pipeline:
    metrics = (
        pipeline
        | 'Create' >> beam.Create(self._get_examples())
        | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
        | 'Extract' >> tfma_unit.Extract(extractors=extractors)  # pylint: disable=no-value-for-parameter
        | 'EvaluateQueryBasedMetrics' >>
        query_based_metrics_evaluator.EvaluateQueryBasedMetrics(
            prediction_key='',
            query_id='fixed_string',
            combine_fns=combine_fns))

    util.assert_that(
        metrics[constants.METRICS_KEY], check_metrics, label='metrics')
def testRunModelAnalysisWithQueryExtractor(self):
  """Runs full model analysis with a query-based evaluator added.

  Exports a simple linear classifier, evaluates four TF examples with both
  the standard metrics/plots evaluator and a query-based evaluator grouped
  by the 'language' feature, then spot-checks the overall-slice metrics.
  """
  model_location = self._exportEvalSavedModel(
      linear_classifier.simple_linear_classifier)
  examples = [
      self._makeExample(age=3.0, language='english', label=1.0),
      self._makeExample(age=3.0, language='chinese', label=0.0),
      self._makeExample(age=4.0, language='english', label=0.0),
      self._makeExample(age=5.0, language='chinese', label=1.0),
  ]
  data_location = self._writeTFExamplesToTFRecords(examples)
  slice_spec = [slicer.SingleSliceSpec()]
  eval_shared_model = model_eval_lib.default_eval_shared_model(
      eval_saved_model_path=model_location, example_weight_key='age')
  evaluators = [
      metrics_and_plots_evaluator.MetricsAndPlotsEvaluator(eval_shared_model),
      query_based_metrics_evaluator.QueryBasedMetricsEvaluator(
          query_id='language',
          prediction_key='logistic',
          combine_fns=[
              query_statistics.QueryStatisticsCombineFn(),
              ndcg.NdcgMetricCombineFn(
                  at_vals=[1], gain_key='label', weight_key='')
          ]),
  ]
  eval_result = model_eval_lib.run_model_analysis(
      eval_shared_model=eval_shared_model,
      data_location=data_location,
      slice_spec=slice_spec,
      evaluators=evaluators)

  # We only check some of the metrics to ensure that the end-to-end
  # pipeline works.
  expected = {
      (): {
          name: {'doubleValue': value} for name, value in [
              ('post_export_metrics/total_queries', 2.0),
              ('post_export_metrics/min_documents', 2.0),
              ('post_export_metrics/max_documents', 2.0),
              ('post_export_metrics/total_documents', 4.0),
              ('post_export_metrics/ndcg@1', 0.5),
              ('post_export_metrics/example_weight', 15.0),
              ('post_export_metrics/example_count', 4.0),
          ]
      }
  }
  self.assertEqual(eval_result.config.model_location, model_location)
  self.assertEqual(eval_result.config.data_location, data_location)
  self.assertEqual(eval_result.config.slice_spec, slice_spec)
  self.assertMetricsAlmostEqual(eval_result.slicing_metrics, expected)
  self.assertFalse(eval_result.plots)