Example #1
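Both methods read like excerpts from the TensorFlow Model Analysis (TFMA) test suite and are not runnable standalone: they assume a test base class along the lines of TFMA's testutil.TensorflowModelAnalysisTest (which supplies helpers such as createTestEvalSharedModel, _makeExample and _writeTFExamplesToTFRecords) plus module-level imports. A plausible import preamble, with module paths assumed from TFMA's source layout (these move between releases), would be:

import apache_beam as beam
from apache_beam.testing import util

from tensorflow_model_analysis import constants
from tensorflow_model_analysis.api import model_eval_lib
from tensorflow_model_analysis.api import tfma_unit
from tensorflow_model_analysis.eval_saved_model.example_trainers import (
    fixed_prediction_estimator_extra_fields)
from tensorflow_model_analysis.eval_saved_model.example_trainers import (
    linear_classifier)
from tensorflow_model_analysis.evaluators import metrics_and_plots_evaluator
from tensorflow_model_analysis.evaluators import query_based_metrics_evaluator
from tensorflow_model_analysis.evaluators.query_metrics import min_label_position
from tensorflow_model_analysis.evaluators.query_metrics import ndcg
from tensorflow_model_analysis.evaluators.query_metrics import query_statistics
from tensorflow_model_analysis.extractors import legacy_predict_extractor
from tensorflow_model_analysis.extractors import slice_key_extractor
from tensorflow_model_analysis.slicer import slicer_lib as slicer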
    def testEvaluateQueryBasedMetrics(self):
        temp_eval_export_dir = self._getEvalExportDir()
        _, eval_export_dir = (fixed_prediction_estimator_extra_fields.
                              simple_fixed_prediction_estimator_extra_fields(
                                  None, temp_eval_export_dir))
        eval_shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=eval_export_dir)
        # Extract model predictions and slice keys before evaluating the
        # query-based metrics.
        extractors = [
            legacy_predict_extractor.PredictExtractor(eval_shared_model),
            slice_key_extractor.SliceKeyExtractor()
        ]

        with beam.Pipeline() as pipeline:
            metrics = (
                pipeline
                | 'Create' >> beam.Create(self._get_examples())
                | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
                | 'Extract' >> tfma_unit.Extract(extractors=extractors)  # pylint: disable=no-value-for-parameter
                | 'EvaluateQueryBasedMetrics' >>
                query_based_metrics_evaluator.EvaluateQueryBasedMetrics(
                    prediction_key='',
                    query_id='fixed_string',
                    combine_fns=[
                        query_statistics.QueryStatisticsCombineFn(),
                        ndcg.NdcgMetricCombineFn(at_vals=[1, 2],
                                                 gain_key='fixed_float',
                                                 weight_key='fixed_int'),
                        min_label_position.MinLabelPositionCombineFn(
                            label_key='', weight_key='fixed_int'),
                    ]))

            # SliceKeyExtractor() with no slice spec yields only the overall
            # slice, keyed by the empty tuple.
            def check_metrics(got):
                try:
                    self.assertEqual(1, len(got), 'got: %s' % got)
                    got_slice_key, got_metrics = got[0]
                    self.assertEqual(got_slice_key, ())
                    self.assertDictElementsAlmostEqual(
                        got_metrics, {
                            'post_export_metrics/total_queries':
                            3.0,
                            'post_export_metrics/total_documents':
                            6.0,
                            'post_export_metrics/min_documents':
                            1.0,
                            'post_export_metrics/max_documents':
                            3.0,
                            'post_export_metrics/ndcg@1':
                            0.9166667,
                            'post_export_metrics/ndcg@2':
                            0.9766198,
                            'post_export_metrics/average_min_label_position/__labels':
                            0.6666667,
                        })

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(metrics[constants.METRICS_KEY],
                             check_metrics,
                             label='metrics')
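The ndcg@k assertions above follow the standard NDCG definition: per query, the discounted cumulative gain of the top-k documents in model-ranked order, divided by the DCG of the ideal (gain-sorted) ordering, then averaged across queries (weighted per query when a weight_key is set). A minimal sketch of that formula, independent of TFMA's own implementation:

import math

def ndcg_at_k(ranked_gains, k):
    # ranked_gains: the gain of each of one query's documents, in the
    # order the model ranked them.
    def dcg(gains):
        return sum(g / math.log2(i + 2) for i, g in enumerate(gains[:k]))
    ideal = dcg(sorted(ranked_gains, reverse=True))
    return dcg(ranked_gains) / ideal if ideal > 0 else 0.0

With binary gains, ndcg_at_k(..., 1) is simply 1 when the top-ranked document is a positive one and 0 otherwise, which is why the ndcg@1 of 0.5 in the next example is consistent with one of its two queries ranking a positive document first and the other not.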
Example #2

  def testRunModelAnalysisWithQueryExtractor(self):
   model_location = self._exportEvalSavedModel(
       linear_classifier.simple_linear_classifier)
   examples = [
       self._makeExample(age=3.0, language='english', label=1.0),
       self._makeExample(age=3.0, language='chinese', label=0.0),
       self._makeExample(age=4.0, language='english', label=0.0),
       self._makeExample(age=5.0, language='chinese', label=1.0)
   ]
   data_location = self._writeTFExamplesToTFRecords(examples)
   slice_spec = [slicer.SingleSliceSpec()]
   eval_shared_model = model_eval_lib.default_eval_shared_model(
       eval_saved_model_path=model_location, example_weight_key='age')
   eval_result = model_eval_lib.run_model_analysis(
       eval_shared_model=eval_shared_model,
       data_location=data_location,
       slice_spec=slice_spec,
       evaluators=[
           metrics_and_plots_evaluator.MetricsAndPlotsEvaluator(
               eval_shared_model),
           query_based_metrics_evaluator.QueryBasedMetricsEvaluator(
               query_id='language',
               prediction_key='logistic',
               combine_fns=[
                   query_statistics.QueryStatisticsCombineFn(),
                   ndcg.NdcgMetricCombineFn(
                       at_vals=[1], gain_key='label', weight_key='')
               ]),
       ])
   # We only check some of the metrics to ensure that the end-to-end
   # pipeline works.
   expected = {
       (): {
           'post_export_metrics/total_queries': {
               'doubleValue': 2.0
           },
           'post_export_metrics/min_documents': {
               'doubleValue': 2.0
           },
           'post_export_metrics/max_documents': {
               'doubleValue': 2.0
           },
           'post_export_metrics/total_documents': {
               'doubleValue': 4.0
           },
           'post_export_metrics/ndcg@1': {
               'doubleValue': 0.5
           },
           'post_export_metrics/example_weight': {
               'doubleValue': 15.0
           },
           'post_export_metrics/example_count': {
               'doubleValue': 4.0
           },
       }
   }
   self.assertEqual(eval_result.config.model_location, model_location)
   self.assertEqual(eval_result.config.data_location, data_location)
   self.assertEqual(eval_result.config.slice_spec, slice_spec)
   self.assertMetricsAlmostEqual(eval_result.slicing_metrics, expected)
   self.assertFalse(eval_result.plots)
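The query statistics and example-level metrics in expected can be recomputed by hand from the four inputs: grouping by language yields two queries of two documents each, and example_weight is the sum of the age weights (3 + 3 + 4 + 5 = 15). A quick standalone check, with plain dicts standing in for the tf.Examples:

from collections import Counter

examples = [
    dict(age=3.0, language='english', label=1.0),
    dict(age=3.0, language='chinese', label=0.0),
    dict(age=4.0, language='english', label=0.0),
    dict(age=5.0, language='chinese', label=1.0),
]
docs_per_query = Counter(ex['language'] for ex in examples)
assert len(docs_per_query) == 2                   # total_queries
assert sum(docs_per_query.values()) == 4          # total_documents
assert min(docs_per_query.values()) == 2          # min_documents
assert max(docs_per_query.values()) == 2          # max_documents
assert sum(ex['age'] for ex in examples) == 15.0  # example_weight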