Example #1
    def _assert_test(self,
                     num_buckets,
                     baseline_examples,
                     comparison_examples,
                     lift_metric_value,
                     ignore_out_of_bound_examples=False):
        eval_config = config.EvalConfig(
            cross_slicing_specs=[config.CrossSlicingSpec()])
        computations = lift.Lift(
            num_buckets=num_buckets,
            ignore_out_of_bound_examples=ignore_out_of_bound_examples
        ).computations(eval_config=eval_config)
        histogram = computations[0]
        lift_metrics = computations[1]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            baseline_result = (
                pipeline
                | 'CreateB' >> beam.Create(baseline_examples)
                | 'ProcessB' >> beam.Map(metric_util.to_standard_metric_inputs)
                | 'AddSliceB' >> beam.Map(lambda x: ((), x))
                | 'ComputeHistogramB' >> beam.CombinePerKey(histogram.combiner)
            )  # pyformat: ignore

            comparison_result = (
                pipeline
                | 'CreateC' >> beam.Create(comparison_examples)
                | 'ProcessC' >> beam.Map(metric_util.to_standard_metric_inputs)
                | 'AddSliceC' >> beam.Map(lambda x: (('slice',), x))
                | 'ComputeHistogramC' >> beam.CombinePerKey(histogram.combiner)
            )  # pyformat: ignore

            # pylint: enable=no-value-for-parameter

            merged_result = ((baseline_result, comparison_result)
                             | 'MergePCollections' >> beam.Flatten())

            def check_result(got):
                try:
                    self.assertLen(got, 2)
                    slice_1, metric_1 = got[0]
                    slice_2, metric_2 = got[1]
                    lift_value = None
                    if not slice_1:
                        lift_value = lift_metrics.cross_slice_comparison(
                            metric_1, metric_2)
                    else:
                        lift_value = lift_metrics.cross_slice_comparison(
                            metric_2, metric_1)

                    self.assertDictElementsAlmostEqual(
                        lift_value, {
                            metric_types.MetricKey(name=f'lift@{num_buckets}'):
                            lift_metric_value,
                        })
                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(merged_result, check_result, label='result')
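
A minimal sketch of how this helper might be invoked from a test case in the same class. The example dicts follow the standard extracts format consumed by metric_util.to_standard_metric_inputs; the input values and the expected lift value are hypothetical, not taken from the original test.

import numpy as np

    def testLift(self):
        baseline_examples = [
            {'labels': np.array([0.0]), 'predictions': np.array([0.1]),
             'example_weights': np.array([1.0])},
            {'labels': np.array([1.0]), 'predictions': np.array([0.9]),
             'example_weights': np.array([1.0])},
        ]
        comparison_examples = [
            {'labels': np.array([1.0]), 'predictions': np.array([0.8]),
             'example_weights': np.array([1.0])},
            {'labels': np.array([0.0]), 'predictions': np.array([0.3]),
             'example_weights': np.array([1.0])},
        ]
        self._assert_test(
            num_buckets=3,
            baseline_examples=baseline_examples,
            comparison_examples=comparison_examples,
            lift_metric_value=0.0)  # hypothetical expected value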
Example #2
# Single slice tests: (<test_name>, <slicing_config>, <slice_key>)
_MATCHED_MULTIPLE_SLICES_TEST = ('multiple_slices', [
    config.SlicingSpec(feature_values={'feature1': 'value1'}),
    config.SlicingSpec(feature_values={'feature2': 'value2'})
], (('feature1', 'value1'), ))
_UNMATCHED_SINGLE_SLICE_TEST = ('single_slice',
                                [config.SlicingSpec(feature_keys='feature1')],
                                (('unmatched_feature', 'unmatched_value'), ))
_UNMATCHED_MULTIPLE_SLICES_TEST = ('multiple_slices', [
    config.SlicingSpec(feature_values={'feature1': 'value1'}),
    config.SlicingSpec(feature_values={'feature2': 'value2'})
], (('unmatched_feature', 'unmatched_value'), ))

# Cross slice tests: (<test_name>, <cross_slice_config>, <cross_slice_key>)
_CROSS_SLICE_GLOBAL_TEST = ('global_slice', [
    config.CrossSlicingSpec(
        baseline_spec=config.SlicingSpec(),
        slicing_specs=[
            config.SlicingSpec(feature_values={'feature2': 'value2'})
        ])
], ((), (('feature2', 'value2'), )))
_SINGLE_CROSS_SLICE_TEST = ('single_slice', [
    config.CrossSlicingSpec(
        baseline_spec=config.SlicingSpec(feature_keys=['feature1']),
        slicing_specs=[
            config.SlicingSpec(feature_values={'feature2': 'value2'})
        ])
], ((('feature1', 'value1'), ), (('feature2', 'value2'), )))
_MULTIPLE_CROSS_SLICE_TEST = ('multiple_slice', [
    config.CrossSlicingSpec(
        baseline_spec=config.SlicingSpec(feature_keys=['feature1']),
        slicing_specs=[
            config.SlicingSpec(feature_values={'feature2': 'value2'})
        ])
], ((('feature1', 'value1'), ), (('feature2', 'value2'), )))
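
These (<test_name>, <specs>, <key>) tuples are shaped for absl's parameterized test decorator. A sketch of how they are typically consumed, assuming the same config import as the snippet above; the class name and test body here are illustrative, not from the original file:

from absl.testing import parameterized

class CrossSliceTest(parameterized.TestCase):

    @parameterized.named_parameters(
        _CROSS_SLICE_GLOBAL_TEST,
        _SINGLE_CROSS_SLICE_TEST,
        _MULTIPLE_CROSS_SLICE_TEST,
    )
    def testCrossSliceSpecs(self, cross_slicing_specs, cross_slice_key):
        # named_parameters strips the leading name; the remaining two
        # tuple elements arrive as positional arguments.
        for spec in cross_slicing_specs:
            self.assertIsInstance(spec, config.CrossSlicingSpec)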
Example #3
 def testIsCrossSliceApplicable(self):
     test_cases = [
         (True, 'overall pass', ((), (('b', 2), )),
          config.CrossSlicingSpec(
              baseline_spec=config.SlicingSpec(),
              slicing_specs=[config.SlicingSpec(feature_values={'b': '2'})
                             ])),
         (True, 'value pass', ((('a', 1), ), (('b', 2), )),
          config.CrossSlicingSpec(
              baseline_spec=config.SlicingSpec(feature_values={'a': '1'}),
              slicing_specs=[config.SlicingSpec(feature_values={'b': '2'})
                             ])),
         (True, 'baseline key pass', ((('a', 1), ), (('b', 2), )),
          config.CrossSlicingSpec(
              baseline_spec=config.SlicingSpec(feature_keys=['a']),
              slicing_specs=[config.SlicingSpec(feature_values={'b': '2'})
                             ])),
         (True, 'comparison key pass', ((('a', 1), ), (('b', 2), )),
          config.CrossSlicingSpec(
              baseline_spec=config.SlicingSpec(feature_values={'a': '1'}),
              slicing_specs=[config.SlicingSpec(feature_keys=['b'])])),
         (True, 'comparison multiple key pass', ((('a', 1), ), (('c', 3), )),
          config.CrossSlicingSpec(
              baseline_spec=config.SlicingSpec(feature_values={'a': '1'}),
              slicing_specs=[
                  config.SlicingSpec(feature_keys=['b']),
                  config.SlicingSpec(feature_keys=['c'])
              ])),
         (False, 'overall fail', ((('a', 1), ), (('b', 2), )),
          config.CrossSlicingSpec(
              baseline_spec=config.SlicingSpec(),
              slicing_specs=[config.SlicingSpec(feature_values={'b': '2'})
                             ])),
         (False, 'value fail', ((('a', 1), ), (('b', 3), )),
          config.CrossSlicingSpec(
              baseline_spec=config.SlicingSpec(feature_values={'a': '1'}),
              slicing_specs=[config.SlicingSpec(feature_values={'b': '2'})
                             ])),
         (False, 'baseline key fail', ((('c', 1), ), (('b', 2), )),
          config.CrossSlicingSpec(
              baseline_spec=config.SlicingSpec(feature_keys=['a']),
              slicing_specs=[config.SlicingSpec(feature_values={'b': '2'})
                             ])),
         (False, 'comparison key fail', ((('a', 1), ), (('c', 3), )),
          config.CrossSlicingSpec(
              baseline_spec=config.SlicingSpec(feature_values={'a': '1'}),
              slicing_specs=[config.SlicingSpec(feature_keys=['b'])])),
         (False, 'comparison multiple key fail', ((('a', 1), ), (('d', 3), )),
          config.CrossSlicingSpec(
              baseline_spec=config.SlicingSpec(feature_values={'a': '1'}),
              slicing_specs=[
                  config.SlicingSpec(feature_keys=['b']),
                  config.SlicingSpec(feature_keys=['c'])
              ])),
     ]  # pyformat: disable
     for (expected_result, name, sliced_key, slicing_spec) in test_cases:
         self.assertEqual(expected_result,
                          slicer.is_cross_slice_applicable(
                              cross_slice_key=sliced_key,
                              cross_slicing_spec=slicing_spec),
                          msg=name)
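
For reference, a standalone sketch of calling the function exercised above, grounded in the first 'overall pass' case. The import paths are assumptions (recent TFMA releases expose these modules as tensorflow_model_analysis.config and tensorflow_model_analysis.slicer.slicer_lib):

from tensorflow_model_analysis import config
from tensorflow_model_analysis.slicer import slicer_lib as slicer

spec = config.CrossSlicingSpec(
    baseline_spec=config.SlicingSpec(),
    slicing_specs=[config.SlicingSpec(feature_values={'b': '2'})])

# The empty baseline key () matches the empty baseline_spec, and
# (('b', 2),) matches the comparison spec, so the pair applies.
print(slicer.is_cross_slice_applicable(
    cross_slice_key=((), (('b', 2), )),
    cross_slicing_spec=spec))  # -> True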
Example #4
    def testMetricThresholdsFromMetricsSpecs(self):
        slice_specs = [
            config.SlicingSpec(feature_keys=['feature1']),
            config.SlicingSpec(feature_values={'feature2': 'value1'})
        ]

        # For cross slice tests.
        baseline_slice_spec = config.SlicingSpec(feature_keys=['feature3'])

        metrics_specs = [
            config.MetricsSpec(
                thresholds={
                    'auc':
                    config.MetricThreshold(
                        value_threshold=config.GenericValueThreshold()),
                    'mean/label':
                    config.MetricThreshold(
                        value_threshold=config.GenericValueThreshold(),
                        change_threshold=config.GenericChangeThreshold()),
                    'mse':
                    config.MetricThreshold(
                        change_threshold=config.GenericChangeThreshold())
                },
                per_slice_thresholds={
                    'auc':
                    config.PerSliceMetricThresholds(thresholds=[
                        config.PerSliceMetricThreshold(
                            slicing_specs=slice_specs,
                            threshold=config.MetricThreshold(
                                value_threshold=config.GenericValueThreshold(
                                )))
                    ]),
                    'mean/label':
                    config.PerSliceMetricThresholds(thresholds=[
                        config.PerSliceMetricThreshold(
                            slicing_specs=slice_specs,
                            threshold=config.MetricThreshold(
                                value_threshold=config.GenericValueThreshold(),
                                change_threshold=config.GenericChangeThreshold(
                                )))
                    ])
                },
                cross_slice_thresholds={
                    'auc':
                    config.CrossSliceMetricThresholds(thresholds=[
                        config.CrossSliceMetricThreshold(
                            cross_slicing_specs=[
                                config.CrossSlicingSpec(
                                    baseline_spec=baseline_slice_spec,
                                    slicing_specs=slice_specs)
                            ],
                            threshold=config.MetricThreshold(
                                value_threshold=config.GenericValueThreshold(),
                                change_threshold=config.GenericChangeThreshold(
                                )))
                    ]),
                    'mse':
                    config.CrossSliceMetricThresholds(thresholds=[
                        config.CrossSliceMetricThreshold(
                            cross_slicing_specs=[
                                config.CrossSlicingSpec(
                                    baseline_spec=baseline_slice_spec,
                                    slicing_specs=slice_specs)
                            ],
                            threshold=config.MetricThreshold(
                                change_threshold=config.GenericChangeThreshold(
                                ))),
                        # Test for duplicate cross_slicing_spec.
                        config.CrossSliceMetricThreshold(
                            cross_slicing_specs=[
                                config.CrossSlicingSpec(
                                    baseline_spec=baseline_slice_spec,
                                    slicing_specs=slice_specs)
                            ],
                            threshold=config.MetricThreshold(
                                value_threshold=config.GenericValueThreshold())
                        )
                    ])
                },
                model_names=['model_name'],
                output_names=['output_name']),
            config.MetricsSpec(metrics=[
                config.MetricConfig(
                    class_name='ExampleCount',
                    config=json.dumps({'name': 'example_count'}),
                    threshold=config.MetricThreshold(
                        value_threshold=config.GenericValueThreshold()))
            ],
                               model_names=['model_name1', 'model_name2'],
                               output_names=['output_name1', 'output_name2']),
            config.MetricsSpec(metrics=[
                config.MetricConfig(
                    class_name='WeightedExampleCount',
                    config=json.dumps({'name': 'weighted_example_count'}),
                    threshold=config.MetricThreshold(
                        value_threshold=config.GenericValueThreshold()))
            ],
                               model_names=['model_name1', 'model_name2'],
                               output_names=['output_name1', 'output_name2']),
            config.MetricsSpec(metrics=[
                config.MetricConfig(
                    class_name='MeanSquaredError',
                    config=json.dumps({'name': 'mse'}),
                    threshold=config.MetricThreshold(
                        change_threshold=config.GenericChangeThreshold())),
                config.MetricConfig(
                    class_name='MeanLabel',
                    config=json.dumps({'name': 'mean_label'}),
                    threshold=config.MetricThreshold(
                        change_threshold=config.GenericChangeThreshold()),
                    per_slice_thresholds=[
                        config.PerSliceMetricThreshold(
                            slicing_specs=slice_specs,
                            threshold=config.MetricThreshold(
                                change_threshold=config.GenericChangeThreshold(
                                ))),
                    ],
                    cross_slice_thresholds=[
                        config.CrossSliceMetricThreshold(
                            cross_slicing_specs=[
                                config.CrossSlicingSpec(
                                    baseline_spec=baseline_slice_spec,
                                    slicing_specs=slice_specs)
                            ],
                            threshold=config.MetricThreshold(
                                change_threshold=config.GenericChangeThreshold(
                                )))
                    ]),
            ],
                               model_names=['model_name'],
                               output_names=['output_name'],
                               binarize=config.BinarizationOptions(
                                   class_ids={'values': [0, 1]}),
                               aggregate=config.AggregationOptions(
                                   macro_average=True,
                                   class_weights={
                                       0: 1.0,
                                       1: 1.0
                                   }))
        ]

        thresholds = metric_specs.metric_thresholds_from_metrics_specs(
            metrics_specs)

        expected_keys_and_threshold_counts = {
            metric_types.MetricKey(name='auc',
                                   model_name='model_name',
                                   output_name='output_name',
                                   is_diff=False):
            4,
            metric_types.MetricKey(name='auc',
                                   model_name='model_name',
                                   output_name='output_name',
                                   is_diff=True):
            1,
            metric_types.MetricKey(name='mean/label',
                                   model_name='model_name',
                                   output_name='output_name',
                                   is_diff=True):
            3,
            metric_types.MetricKey(name='mean/label',
                                   model_name='model_name',
                                   output_name='output_name',
                                   is_diff=False):
            3,
            metric_types.MetricKey(name='example_count',
                                   model_name='model_name1',
                                   output_name='output_name1'):
            1,
            metric_types.MetricKey(name='example_count',
                                   model_name='model_name1',
                                   output_name='output_name2'):
            1,
            metric_types.MetricKey(name='example_count',
                                   model_name='model_name2',
                                   output_name='output_name1'):
            1,
            metric_types.MetricKey(name='example_count',
                                   model_name='model_name2',
                                   output_name='output_name2'):
            1,
            metric_types.MetricKey(name='weighted_example_count',
                                   model_name='model_name1',
                                   output_name='output_name1'):
            1,
            metric_types.MetricKey(name='weighted_example_count',
                                   model_name='model_name1',
                                   output_name='output_name2'):
            1,
            metric_types.MetricKey(name='weighted_example_count',
                                   model_name='model_name2',
                                   output_name='output_name1'):
            1,
            metric_types.MetricKey(name='weighted_example_count',
                                   model_name='model_name2',
                                   output_name='output_name2'):
            1,
            metric_types.MetricKey(name='mse',
                                   model_name='model_name',
                                   output_name='output_name',
                                   sub_key=metric_types.SubKey(class_id=0),
                                   is_diff=True):
            1,
            metric_types.MetricKey(name='mse',
                                   model_name='model_name',
                                   output_name='output_name',
                                   sub_key=metric_types.SubKey(class_id=1),
                                   is_diff=True):
            1,
            metric_types.MetricKey(name='mse',
                                   model_name='model_name',
                                   output_name='output_name',
                                   is_diff=True):
            2,
            metric_types.MetricKey(name='mse',
                                   model_name='model_name',
                                   output_name='output_name',
                                   is_diff=False):
            1,
            metric_types.MetricKey(name='mse',
                                   model_name='model_name',
                                   output_name='output_name',
                                   aggregation_type=metric_types.AggregationType(
                                       macro_average=True),
                                   is_diff=True):
            1,
            metric_types.MetricKey(name='mean_label',
                                   model_name='model_name',
                                   output_name='output_name',
                                   sub_key=metric_types.SubKey(class_id=0),
                                   is_diff=True):
            4,
            metric_types.MetricKey(name='mean_label',
                                   model_name='model_name',
                                   output_name='output_name',
                                   sub_key=metric_types.SubKey(class_id=1),
                                   is_diff=True):
            4,
            metric_types.MetricKey(name='mean_label',
                                   model_name='model_name',
                                   output_name='output_name',
                                   aggregation_type=metric_types.AggregationType(
                                       macro_average=True),
                                   is_diff=True):
            4
        }
        self.assertLen(thresholds, len(expected_keys_and_threshold_counts))
        for key, count in expected_keys_and_threshold_counts.items():
            self.assertIn(key, thresholds)
            self.assertLen(thresholds[key], count,
                           'failed for key {}'.format(key))
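
Outside the test harness, the returned mapping can be inspected directly. A minimal sketch, reusing the names above; the test only asserts the length of each value, so the element structure is not shown here:

thresholds = metric_specs.metric_thresholds_from_metrics_specs(metrics_specs)
for metric_key, bound_thresholds in thresholds.items():
    # One entry per threshold binding counted by the assertions above.
    print(metric_key, len(bound_thresholds))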