def testValidateMetricsMetricValueAndThresholdIgnoreUnmatchedSlice(
         self, slicing_specs, slice_key):
     """Expects validation to pass although 1.5 is above the upper bound of 1.

     The bound is attached only through per_slice_thresholds. Judging by the
     test name, the parameterization presumably supplies a slice_key that
     none of the slicing_specs cover, so the bound is never applied (the
     parameter values are defined outside this method -- confirm there).
     """
     bound = config.GenericValueThreshold(upper_bound={'value': 1})
     metric_threshold = config.MetricThreshold(value_threshold=bound)
     scoped_threshold = config.PerSliceMetricThreshold(
         slicing_specs=slicing_specs, threshold=metric_threshold)
     weighted_count = config.MetricConfig(
         class_name='WeightedExampleCount',
         per_slice_thresholds=[scoped_threshold])
     eval_config = config.EvalConfig(
         model_specs=[config.ModelSpec()],
         slicing_specs=slicing_specs,
         metrics_specs=[
             config.MetricsSpec(metrics=[weighted_count], model_names=['']),
         ])
     # 1.5 would violate the bound if the threshold were applied to this slice.
     metric_values = {
         metric_types.MetricKey(name='weighted_example_count'): 1.5,
     }
     outcome = metrics_validator.validate_metrics((slice_key, metric_values),
                                                  eval_config)
     self.assertTrue(outcome.validation_ok)
 def testValidateMetricsCrossSliceThresholdFail(self, cross_slicing_specs,
                                                slice_key):
     """Expects validation to fail because 1.5 exceeds the upper bound of 1.

     The bound is always registered under cross_slice_thresholds; it is also
     used as the plain metric threshold when cross_slicing_specs is None.
     """
     bound = config.GenericValueThreshold(upper_bound={'value': 1})
     metric_threshold = config.MetricThreshold(value_threshold=bound)
     if cross_slicing_specs is None:
         plain_threshold = metric_threshold
     else:
         plain_threshold = None
     cross_slice_threshold = config.CrossSliceMetricThreshold(
         cross_slicing_specs=cross_slicing_specs, threshold=metric_threshold)
     weighted_count = config.MetricConfig(
         class_name='WeightedExampleCount',
         threshold=plain_threshold,
         cross_slice_thresholds=[cross_slice_threshold])
     eval_config = config.EvalConfig(
         model_specs=[config.ModelSpec()],
         cross_slicing_specs=cross_slicing_specs,
         metrics_specs=[
             config.MetricsSpec(metrics=[weighted_count], model_names=['']),
         ])
     metric_values = {
         metric_types.MetricKey(name='weighted_example_count'): 1.5,
     }
     outcome = metrics_validator.validate_metrics((slice_key, metric_values),
                                                  eval_config)
     self.assertFalse(outcome.validation_ok)
 def testValidateMetricsValueThresholdLowerBoundPass(
         self, slicing_specs, slice_key):
     """Expects validation to pass because 2 is above the lower bound of 1.

     The bound is always registered under per_slice_thresholds; it is also
     used as the plain metric threshold when slicing_specs is None.
     """
     bound = config.GenericValueThreshold(lower_bound={'value': 1})
     metric_threshold = config.MetricThreshold(value_threshold=bound)
     if slicing_specs is None:
         plain_threshold = metric_threshold
     else:
         plain_threshold = None
     scoped_threshold = config.PerSliceMetricThreshold(
         slicing_specs=slicing_specs, threshold=metric_threshold)
     weighted_count = config.MetricConfig(
         class_name='WeightedExampleCount',
         threshold=plain_threshold,
         per_slice_thresholds=[scoped_threshold])
     eval_config = config.EvalConfig(
         model_specs=[config.ModelSpec()],
         slicing_specs=slicing_specs,
         metrics_specs=[
             config.MetricsSpec(metrics=[weighted_count], model_names=['']),
         ])
     metric_values = {
         metric_types.MetricKey(name='weighted_example_count'): 2,
     }
     outcome = metrics_validator.validate_metrics((slice_key, metric_values),
                                                  eval_config)
     self.assertTrue(outcome.validation_ok)
 def testValidateMetricsValueThresholdLowerBoundPass(self):
     """Expects validation to pass because 2 is above the lower bound of 1."""
     lower_bound_threshold = config.MetricThreshold(
         value_threshold=config.GenericValueThreshold(
             lower_bound={'value': 1}))
     weighted_count = config.MetricConfig(
         class_name='WeightedExampleCount', threshold=lower_bound_threshold)
     eval_config = config.EvalConfig(
         model_specs=[config.ModelSpec()],
         slicing_specs=[config.SlicingSpec()],
         metrics_specs=[
             config.MetricsSpec(metrics=[weighted_count], model_names=['']),
         ])
     # The metrics are reported under the empty-tuple slice key.
     empty_slice_key = ()
     metric_values = {
         metric_types.MetricKey(name='weighted_example_count'): 2,
     }
     outcome = metrics_validator.validate_metrics(
         (empty_slice_key, metric_values), eval_config)
     self.assertTrue(outcome.validation_ok)
 def testValidateMetricsMetricTDistributionValueAndThreshold(
         self, slicing_specs, slice_key):
     """Expects an AUC given as a t-distribution value to fail validation.

     The metric carries sample_mean=0.91 and unsampled_value=0.8 while the
     lower bound is 0.9; the expected failure records 0.8 as the metric value.
     The expected result also lists slicing details for every slicing spec
     that applies to slice_key (or a default spec when none are given).
     """
     lower_bound = config.GenericValueThreshold(lower_bound={'value': 0.9})
     threshold = config.MetricThreshold(value_threshold=lower_bound)
     auc_config = config.MetricConfig(
         class_name='AUC',
         threshold=threshold if slicing_specs is None else None,
         per_slice_thresholds=[
             config.PerSliceMetricThreshold(
                 slicing_specs=slicing_specs, threshold=threshold)
         ])
     eval_config = config.EvalConfig(
         model_specs=[config.ModelSpec()],
         slicing_specs=slicing_specs,
         metrics_specs=[
             config.MetricsSpec(metrics=[auc_config], model_names=['']),
         ])
     auc_value = types.ValueWithTDistribution(
         sample_mean=0.91, unsampled_value=0.8)
     metric_values = {metric_types.MetricKey(name='auc'): auc_value}
     result = metrics_validator.validate_metrics((slice_key, metric_values),
                                                 eval_config)
     self.assertFalse(result.validation_ok)

     # Build the expected proto: static parts from text, the rest from inputs.
     expected = text_format.Parse(
         """
     metric_validations_per_slice {
       failures {
         metric_key {
           name: "auc"
         }
         metric_value {
           double_value {
             value: 0.8
           }
         }
       }
     }""", validation_result_pb2.ValidationResult())
     slice_validation = expected.metric_validations_per_slice[0]
     slice_validation.failures[0].metric_threshold.CopyFrom(threshold)
     slice_validation.slice_key.CopyFrom(
         slicer.serialize_slice_key(slice_key))
     for candidate in slicing_specs or [None]:
         applies = (candidate is None or slicer.SingleSliceSpec(
             spec=candidate).is_slice_applicable(slice_key))
         if not applies:
             continue
         details = expected.validation_details.slicing_details.add()
         # With no slicing specs configured, a default SlicingSpec is recorded.
         details.slicing_spec.CopyFrom(
             candidate if candidate is not None else config.SlicingSpec())
         details.num_matching_slices = 1
     self.assertEqual(result, expected)
# Example #6
# 0
 def testMetricKeysToSkipForConfidenceIntervals(self):
   """Expects exactly one key, 'example_count', to be skipped for CIs.

   The specs combine three configured metrics (ExampleCount, MeanLabel and
   MeanSquaredError) with specs derived from a Keras MeanSquaredError metric.
   """
   value_only = config.MetricThreshold(
       value_threshold=config.GenericValueThreshold())
   change_only = config.MetricThreshold(
       change_threshold=config.GenericChangeThreshold())
   configured_metrics = [
       config.MetricConfig(
           class_name='ExampleCount',
           config=json.dumps({'name': 'example_count'}),
           threshold=value_only),
       config.MetricConfig(
           class_name='MeanLabel',
           config=json.dumps({'name': 'mean_label'}),
           threshold=change_only),
       config.MetricConfig(
           class_name='MeanSquaredError',
           config=json.dumps({'name': 'mse'}),
           threshold=change_only),
   ]
   # Model names and output_names should be ignored for ExampleCount because
   # it is model independent; the expected key below carries neither.
   metrics_specs = [
       config.MetricsSpec(
           metrics=configured_metrics,
           model_names=['model_name1', 'model_name2'],
           output_names=['output_name1', 'output_name2']),
   ]
   metrics_specs.extend(
       metric_specs.specs_from_metrics(
           [tf.keras.metrics.MeanSquaredError('mse')]))
   skipped_keys = metric_specs.metric_keys_to_skip_for_confidence_intervals(
       metrics_specs)
   self.assertLen(skipped_keys, 1)
   self.assertIn(metric_types.MetricKey(name='example_count'), skipped_keys)
    def testGetMissingSlices(self):
        """Checks get_missing_slices when one of three specs was validated.

        Metrics are supplied only for the feature1=value1 slice, so the
        overall spec and the feature2=value2 spec are expected to be
        reported as missing, in that order.
        """
        overall_spec = config.SlicingSpec()
        feature1_spec = config.SlicingSpec(
            feature_values={'feature1': 'value1'})
        feature2_spec = config.SlicingSpec(
            feature_values={'feature2': 'value2'})
        slicing_specs = [overall_spec, feature1_spec, feature2_spec]

        upper_bound_threshold = config.MetricThreshold(
            value_threshold=config.GenericValueThreshold(
                upper_bound={'value': 1}))
        weighted_count = config.MetricConfig(
            class_name='WeightedExampleCount',
            per_slice_thresholds=[
                config.PerSliceMetricThreshold(
                    slicing_specs=slicing_specs,
                    threshold=upper_bound_threshold)
            ])
        eval_config = config.EvalConfig(
            model_specs=[config.ModelSpec()],
            slicing_specs=slicing_specs,
            metrics_specs=[
                config.MetricsSpec(
                    metrics=[weighted_count], model_names=['']),
            ])

        # A value of 0 is within the upper bound of 1, so this slice passes.
        feature1_slice_key = (('feature1', 'value1'),)
        metric_values = {
            metric_types.MetricKey(name='weighted_example_count'): 0,
        }
        result = metrics_validator.validate_metrics(
            (feature1_slice_key, metric_values), eval_config)

        expected_checks = text_format.Parse(
            """
        validation_ok: true
        validation_details {
          slicing_details {
            slicing_spec {
              feature_values {
                key: "feature1"
                value: "value1"
              }
            }
            num_matching_slices: 1
          }
        }""", validation_result_pb2.ValidationResult())
        self.assertProtoEquals(expected_checks, result)

        missing = metrics_validator.get_missing_slices(
            result.validation_details.slicing_details, eval_config)
        self.assertLen(missing, 2)
        first_missing, second_missing = missing[0], missing[1]
        self.assertProtoEquals(first_missing, overall_spec)
        self.assertProtoEquals(second_missing, feature2_spec)
 def testValidateMetricsMetricValueAndThreshold(self):
     """Expects 1.5 to fail an upper bound of 1 and checks the full result.

     The expected result holds one failure for the empty slice key with the
     metric key, the configured threshold and the offending value.
     """
     upper_bound_threshold = config.MetricThreshold(
         value_threshold=config.GenericValueThreshold(
             upper_bound={'value': 1}))
     weighted_count = config.MetricConfig(
         class_name='WeightedExampleCount', threshold=upper_bound_threshold)
     eval_config = config.EvalConfig(
         model_specs=[config.ModelSpec()],
         slicing_specs=[config.SlicingSpec()],
         metrics_specs=[
             config.MetricsSpec(metrics=[weighted_count], model_names=['']),
         ])
     # The metrics are reported under the empty-tuple slice key.
     empty_slice_key = ()
     metric_values = {
         metric_types.MetricKey(name='weighted_example_count'): 1.5,
     }
     result = metrics_validator.validate_metrics(
         (empty_slice_key, metric_values), eval_config)
     self.assertFalse(result.validation_ok)
     expected = text_format.Parse(
         """
     metric_validations_per_slice {
       slice_key {
       }
       failures {
         metric_key {
           name: "weighted_example_count"
         }
         metric_threshold {
           value_threshold {
             upper_bound {
               value: 1.0
             }
           }
         }
         metric_value {
           double_value {
             value: 1.5
           }
         }
       }
     }""", validation_result_pb2.ValidationResult())
     self.assertEqual(result, expected)
 def testValidateMetricsMetricValueAndThreshold(self, slicing_specs,
                                                slice_key):
   """Expects 1.5 to fail an upper bound of 1 for the given slice.

   The bound is always registered under per_slice_thresholds; it is also used
   as the plain metric threshold when slicing_specs is None. The expected
   result is completed with the threshold and the serialized slice_key.
   """
   threshold = config.MetricThreshold(
       value_threshold=config.GenericValueThreshold(upper_bound={'value': 1}))
   if slicing_specs is None:
     plain_threshold = threshold
   else:
     plain_threshold = None
   weighted_count = config.MetricConfig(
       class_name='WeightedExampleCount',
       threshold=plain_threshold,
       per_slice_thresholds=[
           config.PerSliceMetricThreshold(
               slicing_specs=slicing_specs, threshold=threshold)
       ])
   eval_config = config.EvalConfig(
       model_specs=[config.ModelSpec()],
       slicing_specs=slicing_specs,
       metrics_specs=[
           config.MetricsSpec(metrics=[weighted_count], model_names=['']),
       ])
   metric_values = {
       metric_types.MetricKey(name='weighted_example_count'): 1.5,
   }
   result = metrics_validator.validate_metrics((slice_key, metric_values),
                                               eval_config)
   self.assertFalse(result.validation_ok)
   expected = text_format.Parse(
       """
       metric_validations_per_slice {
         failures {
           metric_key {
             name: "weighted_example_count"
           }
           metric_value {
             double_value {
               value: 1.5
             }
           }
         }
       }""", validation_result_pb2.ValidationResult())
   slice_validation = expected.metric_validations_per_slice[0]
   slice_validation.failures[0].metric_threshold.CopyFrom(threshold)
   slice_validation.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))
   self.assertEqual(result, expected)
 def testValidateMetricsInvalidThreshold(self):
     """Expects a threshold on a metric that was not reported to fail.

     The threshold is keyed 'invalid_threshold' while the only metric value
     supplied is 'weighted_example_count'; the expected failure carries the
     message 'Metric not found.'.
     """
     unmatched_threshold = config.MetricThreshold(
         value_threshold=config.GenericValueThreshold(
             lower_bound={'value': 0.2}))
     eval_config = config.EvalConfig(
         model_specs=[config.ModelSpec()],
         slicing_specs=[config.SlicingSpec()],
         metrics_specs=[
             config.MetricsSpec(
                 thresholds={'invalid_threshold': unmatched_threshold}),
         ])
     # The metrics are reported under the empty-tuple slice key.
     empty_slice_key = ()
     metric_values = {
         metric_types.MetricKey(name='weighted_example_count'): 1.5,
     }
     result = metrics_validator.validate_metrics(
         (empty_slice_key, metric_values), eval_config)
     self.assertFalse(result.validation_ok)
     expected = text_format.Parse(
         """
     metric_validations_per_slice {
       slice_key {
       }
       failures {
         metric_key {
           name: "invalid_threshold"
         }
         metric_threshold {
           value_threshold {
             lower_bound {
               value: 0.2
             }
           }
         }
         message: 'Metric not found.'
       }
     }""", validation_result_pb2.ValidationResult())
     self.assertProtoEquals(expected, result)
# Example #11
# 0
 def testMetricThresholdsFromMetricsSpecs(self):
     """Checks the keys produced by metric_thresholds_from_metrics_specs.

     Four MetricsSpecs are combined: one that sets thresholds by metric name,
     one for ExampleCount, one for WeightedExampleCount across two models and
     two outputs, and one with binarization and macro averaging. Exactly 14
     threshold keys are expected.
     """
     value_only = config.MetricThreshold(
         value_threshold=config.GenericValueThreshold())
     change_only = config.MetricThreshold(
         change_threshold=config.GenericChangeThreshold())
     value_and_change = config.MetricThreshold(
         value_threshold=config.GenericValueThreshold(),
         change_threshold=config.GenericChangeThreshold())

     def named_metric(class_name, name, threshold):
         # Builds a MetricConfig whose JSON config only sets the metric name.
         return config.MetricConfig(
             class_name=class_name,
             config=json.dumps({'name': name}),
             threshold=threshold)

     two_models = ['model_name1', 'model_name2']
     two_outputs = ['output_name1', 'output_name2']
     metrics_specs = [
         # Thresholds given directly by metric name.
         config.MetricsSpec(
             thresholds={
                 'auc': value_only,
                 'mean/label': value_and_change,
                 # The mse metric will be overridden by MetricConfig below.
                 'mse': change_only,
             },
             model_names=['model_name'],
             output_names=['output_name']),
         # Model names and output_names should be ignored because
         # ExampleCount is model independent.
         config.MetricsSpec(
             metrics=[
                 named_metric('ExampleCount', 'example_count', value_only),
             ],
             model_names=two_models,
             output_names=two_outputs),
         config.MetricsSpec(
             metrics=[
                 named_metric('WeightedExampleCount',
                              'weighted_example_count', value_only),
             ],
             model_names=two_models,
             output_names=two_outputs),
         config.MetricsSpec(
             metrics=[
                 named_metric('MeanSquaredError', 'mse', change_only),
                 named_metric('MeanLabel', 'mean_label', change_only),
             ],
             model_names=['model_name'],
             output_names=['output_name'],
             binarize=config.BinarizationOptions(
                 class_ids={'values': [0, 1]}),
             aggregate=config.AggregationOptions(macro_average=True)),
     ]
     thresholds = metric_specs.metric_thresholds_from_metrics_specs(
         metrics_specs)
     self.assertLen(thresholds, 14)

     def diff_key(name, class_id=None):
         # Diff key for the single model/output; sub_key only when binarized.
         extra = {}
         if class_id is not None:
             extra['sub_key'] = metric_types.SubKey(class_id=class_id)
         return metric_types.MetricKey(
             name=name,
             model_name='model_name',
             output_name='output_name',
             is_diff=True,
             **extra)

     expected_keys = [
         metric_types.MetricKey(name='auc',
                                model_name='model_name',
                                output_name='output_name'),
         diff_key('mean/label'),
         metric_types.MetricKey(name='mean/label',
                                model_name='model_name',
                                output_name='output_name',
                                is_diff=False),
         metric_types.MetricKey(name='example_count'),
     ]
     # One weighted_example_count key per (model, output) combination.
     expected_keys.extend(
         metric_types.MetricKey(name='weighted_example_count',
                                model_name=model_name,
                                output_name=output_name)
         for model_name in two_models
         for output_name in two_outputs)
     # Per-class keys followed by the aggregate key, for each diff metric.
     for diff_metric in ('mse', 'mean_label'):
         expected_keys.append(diff_key(diff_metric, class_id=0))
         expected_keys.append(diff_key(diff_metric, class_id=1))
         expected_keys.append(diff_key(diff_metric))

     for expected_key in expected_keys:
         self.assertIn(expected_key, thresholds)
  def testWriteValidationResults(self):
    """Runs an end-to-end pipeline and checks the written validation result.

    Two examples are evaluated against a 'candidate' and a 'baseline' Keras
    model, the validation result is written to a temp directory, loaded back,
    and expected to contain exactly three threshold failures (for
    weighted_example_count, example_count and the mean_label diff) in a
    single slice.
    """
    # The two models differ only in the `mul` argument passed to the builder
    # (its meaning is defined by _build_keras_model, outside this method).
    model_dir, baseline_dir = self._getExportDir(), self._getBaselineDir()
    eval_shared_model = self._build_keras_model(model_dir, mul=0)
    baseline_eval_shared_model = self._build_keras_model(baseline_dir, mul=1)
    validations_file = os.path.join(self._getTempDir(),
                                    constants.VALIDATIONS_KEY)
    examples = [
        self._makeExample(
            input=0.0,
            label=1.0,
            example_weight=1.0,
            extra_feature='non_model_feature'),
        self._makeExample(
            input=1.0,
            label=0.0,
            example_weight=0.5,
            extra_feature='non_model_feature'),
    ]

    eval_config = config.EvalConfig(
        model_specs=[
            config.ModelSpec(
                name='candidate',
                label_key='label',
                example_weight_key='example_weight'),
            config.ModelSpec(
                name='baseline',
                label_key='label',
                example_weight_key='example_weight',
                is_baseline=True)
        ],
        slicing_specs=[config.SlicingSpec()],
        metrics_specs=[
            config.MetricsSpec(
                metrics=[
                    config.MetricConfig(
                        class_name='WeightedExampleCount',
                        # 1.5 < 1 does not hold: NOT OK.
                        threshold=config.MetricThreshold(
                            value_threshold=config.GenericValueThreshold(
                                upper_bound={'value': 1}))),
                    config.MetricConfig(
                        class_name='ExampleCount',
                        # 2 > 10 does not hold: NOT OK.
                        threshold=config.MetricThreshold(
                            value_threshold=config.GenericValueThreshold(
                                lower_bound={'value': 10}))),
                    config.MetricConfig(
                        class_name='MeanLabel',
                        # 0 > 0 and 0 > 0%?: NOT OK.
                        threshold=config.MetricThreshold(
                            change_threshold=config.GenericChangeThreshold(
                                direction=config.MetricDirection
                                .HIGHER_IS_BETTER,
                                relative={'value': 0},
                                absolute={'value': 0}))),
                    config.MetricConfig(
                        # MeanPrediction = (0+0)/(1+0.5) = 0
                        class_name='MeanPrediction',
                        # -.01 < 0 < .01, OK.
                        # Diff% = -.333/.333 = -100% < -99%, OK.
                        # Diff = 0 - .333 = -.333 < 0, OK.
                        threshold=config.MetricThreshold(
                            value_threshold=config.GenericValueThreshold(
                                upper_bound={'value': .01},
                                lower_bound={'value': -.01}),
                            change_threshold=config.GenericChangeThreshold(
                                direction=config.MetricDirection
                                .LOWER_IS_BETTER,
                                relative={'value': -.99},
                                absolute={'value': 0})))
                ],
                model_names=['candidate', 'baseline']),
        ],
        options=config.Options(
            disabled_outputs={'values': ['eval_config.json']}),
    )
    slice_spec = [
        slicer.SingleSliceSpec(spec=s) for s in eval_config.slicing_specs
    ]
    # Both models keyed by the names used in model_specs above.
    eval_shared_models = {
        'candidate': eval_shared_model,
        'baseline': baseline_eval_shared_model
    }
    extractors = [
        input_extractor.InputExtractor(eval_config),
        predict_extractor_v2.PredictExtractor(
            eval_shared_model=eval_shared_models, eval_config=eval_config),
        slice_key_extractor.SliceKeyExtractor(slice_spec=slice_spec)
    ]
    evaluators = [
        metrics_and_plots_evaluator_v2.MetricsAndPlotsEvaluator(
            eval_config=eval_config, eval_shared_model=eval_shared_models)
    ]
    # Only the validations output is given a path.
    output_paths = {
        constants.VALIDATIONS_KEY: validations_file,
    }
    writers = [
        metrics_plots_and_validations_writer.MetricsPlotsAndValidationsWriter(
            output_paths, add_metrics_callbacks=[])
    ]

    with beam.Pipeline() as pipeline:

      # pylint: disable=no-value-for-parameter
      _ = (
          pipeline
          | 'Create' >> beam.Create([e.SerializeToString() for e in examples])
          | 'ExtractEvaluateAndWriteResults' >>
          model_eval_lib.ExtractEvaluateAndWriteResults(
              eval_config=eval_config,
              eval_shared_model=eval_shared_model,
              extractors=extractors,
              evaluators=evaluators,
              writers=writers))
      # pylint: enable=no-value-for-parameter

    # The result is loaded from the directory that contains validations_file.
    validation_result = model_eval_lib.load_validation_result(
        os.path.dirname(validations_file))

    # One expected failure per threshold marked NOT OK above; MeanPrediction
    # is expected to pass and therefore has no entry.
    expected_validations = [
        text_format.Parse(
            """
            metric_key {
              name: "weighted_example_count"
              model_name: "candidate"
            }
            metric_threshold {
              value_threshold {
                upper_bound {
                  value: 1.0
                }
              }
            }
            metric_value {
              double_value {
                value: 1.5
              }
            }
            """, validation_result_pb2.ValidationFailure()),
        text_format.Parse(
            """
            metric_key {
              name: "example_count"
            }
            metric_threshold {
              value_threshold {
                lower_bound {
                  value: 10.0
                }
              }
            }
            metric_value {
              double_value {
                value: 2.0
              }
            }
            """, validation_result_pb2.ValidationFailure()),
        text_format.Parse(
            """
            metric_key {
              name: "mean_label"
              model_name: "candidate"
              is_diff: true
            }
            metric_threshold {
              change_threshold {
                absolute {
                  value: 0.0
                }
                relative {
                  value: 0.0
                }
                direction: HIGHER_IS_BETTER
              }
            }
            metric_value {
              double_value {
                value: 0.0
              }
            }
            """, validation_result_pb2.ValidationFailure()),
    ]
    self.assertFalse(validation_result.validation_ok)
    self.assertLen(validation_result.metric_validations_per_slice, 1)
    # Order of the failures is not significant, only their contents.
    self.assertCountEqual(
        expected_validations,
        validation_result.metric_validations_per_slice[0].failures)
# Example #13
# 0
  def testWriteValidationResults(self, output_file_format):
    model_dir, baseline_dir = self._getExportDir(), self._getBaselineDir()
    eval_shared_model = self._build_keras_model(model_dir, mul=0)
    baseline_eval_shared_model = self._build_keras_model(baseline_dir, mul=1)
    validations_file = os.path.join(self._getTempDir(),
                                    constants.VALIDATIONS_KEY)
    schema = text_format.Parse(
        """
        tensor_representation_group {
          key: ""
          value {
            tensor_representation {
              key: "input"
              value {
                dense_tensor {
                  column_name: "input"
                  shape { dim { size: 1 } }
                }
              }
            }
          }
        }
        feature {
          name: "input"
          type: FLOAT
        }
        feature {
          name: "label"
          type: FLOAT
        }
        feature {
          name: "example_weight"
          type: FLOAT
        }
        feature {
          name: "extra_feature"
          type: BYTES
        }
        """, schema_pb2.Schema())
    tfx_io = test_util.InMemoryTFExampleRecord(
        schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN)
    tensor_adapter_config = tensor_adapter.TensorAdapterConfig(
        arrow_schema=tfx_io.ArrowSchema(),
        tensor_representations=tfx_io.TensorRepresentations())
    examples = [
        self._makeExample(
            input=0.0,
            label=1.0,
            example_weight=1.0,
            extra_feature='non_model_feature'),
        self._makeExample(
            input=1.0,
            label=0.0,
            example_weight=0.5,
            extra_feature='non_model_feature'),
    ]

    eval_config = config.EvalConfig(
        model_specs=[
            config.ModelSpec(
                name='candidate',
                label_key='label',
                example_weight_key='example_weight'),
            config.ModelSpec(
                name='baseline',
                label_key='label',
                example_weight_key='example_weight',
                is_baseline=True)
        ],
        slicing_specs=[config.SlicingSpec()],
        metrics_specs=[
            config.MetricsSpec(
                metrics=[
                    config.MetricConfig(
                        class_name='WeightedExampleCount',
                        # 1.5 < 1, NOT OK.
                        threshold=config.MetricThreshold(
                            value_threshold=config.GenericValueThreshold(
                                upper_bound={'value': 1}))),
                    config.MetricConfig(
                        class_name='ExampleCount',
                        # 2 > 10, NOT OK.
                        threshold=config.MetricThreshold(
                            value_threshold=config.GenericValueThreshold(
                                lower_bound={'value': 10}))),
                    config.MetricConfig(
                        class_name='MeanLabel',
                        # 0 > 0 and 0 > 0%?: NOT OK.
                        threshold=config.MetricThreshold(
                            change_threshold=config.GenericChangeThreshold(
                                direction=config.MetricDirection
                                .HIGHER_IS_BETTER,
                                relative={'value': 0},
                                absolute={'value': 0}))),
                    config.MetricConfig(
                        # MeanPrediction = (0+0)/(1+0.5) = 0
                        class_name='MeanPrediction',
                        # -.01 < 0 < .01, OK.
                        # Diff% = -.333/.333 = -100% < -99%, OK.
                        # Diff = 0 - .333 = -.333 < 0, OK.
                        threshold=config.MetricThreshold(
                            value_threshold=config.GenericValueThreshold(
                                upper_bound={'value': .01},
                                lower_bound={'value': -.01}),
                            change_threshold=config.GenericChangeThreshold(
                                direction=config.MetricDirection
                                .LOWER_IS_BETTER,
                                relative={'value': -.99},
                                absolute={'value': 0})))
                ],
                model_names=['candidate', 'baseline']),
        ],
        options=config.Options(
            disabled_outputs={'values': ['eval_config.json']}),
    )
    slice_spec = [
        slicer.SingleSliceSpec(spec=s) for s in eval_config.slicing_specs
    ]
    eval_shared_models = {
        'candidate': eval_shared_model,
        'baseline': baseline_eval_shared_model
    }
    extractors = [
        batched_input_extractor.BatchedInputExtractor(eval_config),
        batched_predict_extractor_v2.BatchedPredictExtractor(
            eval_shared_model=eval_shared_models,
            eval_config=eval_config,
            tensor_adapter_config=tensor_adapter_config),
        unbatch_extractor.UnbatchExtractor(),
        slice_key_extractor.SliceKeyExtractor(slice_spec=slice_spec)
    ]
    evaluators = [
        metrics_and_plots_evaluator_v2.MetricsAndPlotsEvaluator(
            eval_config=eval_config, eval_shared_model=eval_shared_models)
    ]
    output_paths = {
        constants.VALIDATIONS_KEY: validations_file,
    }
    writers = [
        metrics_plots_and_validations_writer.MetricsPlotsAndValidationsWriter(
            output_paths,
            add_metrics_callbacks=[],
            output_file_format=output_file_format)
    ]

    with beam.Pipeline() as pipeline:
      # pylint: disable=no-value-for-parameter
      _ = (
          pipeline
          | 'Create' >> beam.Create([e.SerializeToString() for e in examples])
          | 'BatchExamples' >> tfx_io.BeamSource()
          | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts()
          | 'ExtractEvaluate' >> model_eval_lib.ExtractAndEvaluate(
              extractors=extractors, evaluators=evaluators)
          | 'WriteResults' >> model_eval_lib.WriteResults(writers=writers))
      # pylint: enable=no-value-for-parameter

    validation_result = (
        metrics_plots_and_validations_writer
        .load_and_deserialize_validation_result(
            os.path.dirname(validations_file)))

    expected_validations = [
        text_format.Parse(
            """
            metric_key {
              name: "weighted_example_count"
              model_name: "candidate"
            }
            metric_threshold {
              value_threshold {
                upper_bound {
                  value: 1.0
                }
              }
            }
            metric_value {
              double_value {
                value: 1.5
              }
            }
            """, validation_result_pb2.ValidationFailure()),
        text_format.Parse(
            """
            metric_key {
              name: "example_count"
              model_name: "candidate"
            }
            metric_threshold {
              value_threshold {
                lower_bound {
                  value: 10.0
                }
              }
            }
            metric_value {
              double_value {
                value: 2.0
              }
            }
            """, validation_result_pb2.ValidationFailure()),
        text_format.Parse(
            """
            metric_key {
              name: "mean_label"
              model_name: "candidate"
              is_diff: true
            }
            metric_threshold {
              change_threshold {
                absolute {
                  value: 0.0
                }
                relative {
                  value: 0.0
                }
                direction: HIGHER_IS_BETTER
              }
            }
            metric_value {
              double_value {
                value: 0.0
              }
            }
            """, validation_result_pb2.ValidationFailure()),
    ]
    self.assertFalse(validation_result.validation_ok)
    self.assertLen(validation_result.metric_validations_per_slice, 1)
    self.assertCountEqual(
        expected_validations,
        validation_result.metric_validations_per_slice[0].failures)
# Example #14
0
    def testMetricThresholdsFromMetricsSpecs(self):
        """Checks the number of thresholds gathered under each metric key.

        Four MetricsSpecs are assembled: one that declares thresholds through
        the spec-level `thresholds`, `per_slice_thresholds` and
        `cross_slice_thresholds` maps, two that attach a value threshold to a
        count metric for two models and two outputs, and one binarized /
        macro-averaged spec with thresholds on its MetricConfigs. The specs
        are passed to metric_specs.metric_thresholds_from_metrics_specs and
        the result is compared against a table of expected keys and the
        expected length of the entry stored under each key.
        """

        # Threshold factories: each call yields a freshly built message.
        def value_only():
            return config.MetricThreshold(
                value_threshold=config.GenericValueThreshold())

        def change_only():
            return config.MetricThreshold(
                change_threshold=config.GenericChangeThreshold())

        def value_and_change():
            return config.MetricThreshold(
                value_threshold=config.GenericValueThreshold(),
                change_threshold=config.GenericChangeThreshold())

        sliced_by = [
            config.SlicingSpec(feature_keys=['feature1']),
            config.SlicingSpec(feature_values={'feature2': 'value1'})
        ]

        # Baseline side of the cross-slice comparisons.
        cross_baseline = config.SlicingSpec(feature_keys=['feature3'])

        def per_slice(threshold):
            return config.PerSliceMetricThreshold(
                slicing_specs=sliced_by, threshold=threshold)

        def cross_slice(threshold):
            return config.CrossSliceMetricThreshold(
                cross_slicing_specs=[
                    config.CrossSlicingSpec(
                        baseline_spec=cross_baseline,
                        slicing_specs=sliced_by)
                ],
                threshold=threshold)

        # Spec 1: thresholds declared through the name-keyed maps.
        map_based_spec = config.MetricsSpec(
            thresholds={
                'auc': value_only(),
                'mean/label': value_and_change(),
                'mse': change_only()
            },
            per_slice_thresholds={
                'auc':
                config.PerSliceMetricThresholds(
                    thresholds=[per_slice(value_only())]),
                'mean/label':
                config.PerSliceMetricThresholds(
                    thresholds=[per_slice(value_and_change())])
            },
            cross_slice_thresholds={
                'auc':
                config.CrossSliceMetricThresholds(
                    thresholds=[cross_slice(value_and_change())]),
                'mse':
                config.CrossSliceMetricThresholds(thresholds=[
                    cross_slice(change_only()),
                    # Second entry repeats the same cross_slicing_spec to
                    # cover the duplicate case.
                    cross_slice(value_only())
                ])
            },
            model_names=['model_name'],
            output_names=['output_name'])

        multi_models = ['model_name1', 'model_name2']
        multi_outputs = ['output_name1', 'output_name2']

        # Specs 2 and 3: a single value-thresholded metric spread over two
        # models and two outputs.
        def count_spec(class_name, metric_name):
            return config.MetricsSpec(
                metrics=[
                    config.MetricConfig(
                        class_name=class_name,
                        config=json.dumps({'name': metric_name}),
                        threshold=value_only())
                ],
                model_names=multi_models,
                output_names=multi_outputs)

        # Spec 4: change thresholds under binarization and macro averaging.
        binarized_spec = config.MetricsSpec(
            metrics=[
                config.MetricConfig(
                    class_name='MeanSquaredError',
                    config=json.dumps({'name': 'mse'}),
                    threshold=change_only()),
                config.MetricConfig(
                    class_name='MeanLabel',
                    config=json.dumps({'name': 'mean_label'}),
                    threshold=change_only(),
                    per_slice_thresholds=[per_slice(change_only())],
                    cross_slice_thresholds=[cross_slice(change_only())]),
            ],
            model_names=['model_name'],
            output_names=['output_name'],
            binarize=config.BinarizationOptions(
                class_ids={'values': [0, 1]}),
            aggregate=config.AggregationOptions(
                macro_average=True,
                class_weights={
                    0: 1.0,
                    1: 1.0
                }))

        specs_under_test = [
            map_based_spec,
            count_spec('ExampleCount', 'example_count'),
            count_spec('WeightedExampleCount', 'weighted_example_count'),
            binarized_spec,
        ]

        thresholds = metric_specs.metric_thresholds_from_metrics_specs(
            specs_under_test)

        def single_model_key(name, **extra):
            # Key for the specs that use 'model_name' / 'output_name'.
            return metric_types.MetricKey(
                name=name,
                model_name='model_name',
                output_name='output_name',
                **extra)

        # Expected number of thresholds per key.
        # NOTE(review): the counts look like one entry per (threshold,
        # slicing spec) pair - confirm against metric_specs.
        expected_counts = {
            single_model_key('auc', is_diff=False): 4,
            single_model_key('auc', is_diff=True): 1,
            single_model_key('mean/label', is_diff=True): 3,
            single_model_key('mean/label', is_diff=False): 3,
        }
        for metric_name in ('example_count', 'weighted_example_count'):
            for model_name in multi_models:
                for output_name in multi_outputs:
                    count_key = metric_types.MetricKey(
                        name=metric_name,
                        model_name=model_name,
                        output_name=output_name)
                    expected_counts[count_key] = 1
        expected_counts.update({
            single_model_key(
                'mse',
                sub_key=metric_types.SubKey(class_id=0),
                is_diff=True):
            1,
            single_model_key(
                'mse',
                sub_key=metric_types.SubKey(class_id=1),
                is_diff=True):
            1,
            single_model_key('mse', is_diff=True):
            2,
            single_model_key('mse', is_diff=False):
            1,
            single_model_key(
                'mse',
                aggregation_type=metric_types.AggregationType(
                    macro_average=True),
                is_diff=True):
            1,
            single_model_key(
                'mean_label',
                sub_key=metric_types.SubKey(class_id=0),
                is_diff=True):
            4,
            single_model_key(
                'mean_label',
                sub_key=metric_types.SubKey(class_id=1),
                is_diff=True):
            4,
            single_model_key(
                'mean_label',
                aggregation_type=metric_types.AggregationType(
                    macro_average=True),
                is_diff=True):
            4,
        })

        # No extra keys may appear, and every expected key must carry the
        # expected number of thresholds.
        self.assertLen(thresholds, len(expected_counts))
        for metric_key, expected_count in expected_counts.items():
            self.assertIn(metric_key, thresholds)
            self.assertLen(thresholds[metric_key], expected_count,
                           'failed for key {}'.format(metric_key))
# Example #15
0
 def testValidateMetricsChangeThresholdEqualPass(self, slicing_specs,
                                                 slice_key):
   """Validation succeeds when metric values equal their thresholds.

   Two change thresholds (one per MetricDirection) and two value thresholds
   (a lower and an upper bound) are attached to four metrics of the
   'candidate' model. The supplied metric values match the configured
   bounds exactly, and the test asserts that
   metrics_validator.validate_metrics reports validation_ok.

   Args:
     slicing_specs: Slicing specs for the EvalConfig and the per-slice
       thresholds; when None the threshold is also set as the overall
       metric threshold. Presumably supplied by a parameterized decorator
       that is not visible here.
     slice_key: Slice key paired with the metric values.
   """

   def change_threshold(direction):
     # Absolute and relative change both pinned at -.333.
     return config.MetricThreshold(
         change_threshold=config.GenericChangeThreshold(
             direction=direction,
             absolute={'value': -.333},
             relative={'value': -.333}))

   higher_is_better = change_threshold(
       config.MetricDirection.HIGHER_IS_BETTER)
   lower_is_better = change_threshold(config.MetricDirection.LOWER_IS_BETTER)
   at_least_one = config.MetricThreshold(
       value_threshold=config.GenericValueThreshold(lower_bound={'value': 1}))
   at_most_one = config.MetricThreshold(
       value_threshold=config.GenericValueThreshold(upper_bound={'value': 1}))

   def metric_with(class_name, threshold):
     # The overall threshold is only set when no slicing specs are given;
     # the per-slice threshold is always attached.
     if slicing_specs is None:
       overall = threshold
     else:
       overall = None
     return config.MetricConfig(
         class_name=class_name,
         threshold=overall,
         per_slice_thresholds=[
             config.PerSliceMetricThreshold(
                 slicing_specs=slicing_specs, threshold=threshold)
         ])

   eval_config = config.EvalConfig(
       model_specs=[
           config.ModelSpec(name='candidate'),
           config.ModelSpec(name='baseline', is_baseline=True)
       ],
       slicing_specs=slicing_specs,
       metrics_specs=[
           config.MetricsSpec(
               metrics=[
                   # Diff = -.333 == -.333, OK.
                   metric_with('MeanPrediction', higher_is_better),
                   # Diff = -.333 == -.333, OK.
                   metric_with('MeanLabel', lower_is_better),
                   # 1 == 1, OK.
                   metric_with('ExampleCount', at_least_one),
                   # 1 == 1, OK.
                   metric_with('WeightedExampleCount', at_most_one),
               ],
               model_names=['candidate']),
       ],
   )

   metric_values = {}
   for metric_name in ('mean_prediction', 'mean_label'):
     # NOTE(review): 0.677 - 1 is -0.323, not -0.333; the diff value is
     # supplied explicitly rather than derived from these two entries.
     metric_values[metric_types.MetricKey(
         name=metric_name, model_name='candidate')] = 0.677
     metric_values[metric_types.MetricKey(
         name=metric_name, model_name='baseline')] = 1
     metric_values[metric_types.MetricKey(
         name=metric_name, is_diff=True, model_name='candidate')] = -0.333
   metric_values[metric_types.MetricKey(
       name='example_count', model_name='candidate')] = 1
   metric_values[metric_types.MetricKey(
       name='weighted_example_count', model_name='candidate')] = 1

   sliced_metrics = (slice_key, metric_values)
   outcome = metrics_validator.validate_metrics(sliced_metrics, eval_config)
   self.assertTrue(outcome.validation_ok)