# Imports assumed by these test fragments (reconstructed; module paths follow
# TFMA's source layout and may differ across versions).
import json

import apache_beam as beam
from apache_beam.testing import util
import numpy as np
from tensorflow_model_analysis import config
from tensorflow_model_analysis.addons.fairness.metrics import lift
from tensorflow_model_analysis.metrics import metric_specs
from tensorflow_model_analysis.metrics import metric_types
from tensorflow_model_analysis.metrics import metric_util
from tensorflow_model_analysis.slicer import slicer_lib as slicer


def _assert_test(self,
                 num_buckets,
                 baseline_examples,
                 comparison_examples,
                 lift_metric_value,
                 ignore_out_of_bound_examples=False):
  eval_config = config.EvalConfig(
      cross_slicing_specs=[config.CrossSlicingSpec()])
  computations = lift.Lift(
      num_buckets=num_buckets,
      ignore_out_of_bound_examples=ignore_out_of_bound_examples).computations(
          eval_config=eval_config)
  histogram = computations[0]
  lift_metrics = computations[1]
  with beam.Pipeline() as pipeline:
    # pylint: disable=no-value-for-parameter
    baseline_result = (
        pipeline
        | 'CreateB' >> beam.Create(baseline_examples)
        | 'ProcessB' >> beam.Map(metric_util.to_standard_metric_inputs)
        | 'AddSliceB' >> beam.Map(lambda x: ((), x))
        | 'ComputeHistogramB' >> beam.CombinePerKey(histogram.combiner)
    )  # pyformat: ignore
    comparison_result = (
        pipeline
        | 'CreateC' >> beam.Create(comparison_examples)
        | 'ProcessC' >> beam.Map(metric_util.to_standard_metric_inputs)
        | 'AddSliceC' >> beam.Map(lambda x: (('slice'), x))
        | 'ComputeHistogramC' >> beam.CombinePerKey(histogram.combiner)
    )  # pyformat: ignore
    # pylint: enable=no-value-for-parameter
    merged_result = ((baseline_result, comparison_result)
                     | 'MergePCollections' >> beam.Flatten())

    def check_result(got):
      try:
        self.assertLen(got, 2)
        slice_1, metric_1 = got[0]
        slice_2, metric_2 = got[1]
        lift_value = None
        if not slice_1:
          lift_value = lift_metrics.cross_slice_comparison(metric_1, metric_2)
        else:
          lift_value = lift_metrics.cross_slice_comparison(metric_2, metric_1)

        self.assertDictElementsAlmostEqual(
            lift_value, {
                metric_types.MetricKey(name=f'lift@{num_buckets}'):
                    lift_metric_value,
            })
      except AssertionError as err:
        raise util.BeamAssertException(err)

    util.assert_that(merged_result, check_result, label='result')
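
# Hypothetical usage sketch (not from the original suite): `_assert_test` is
# driven by per-example dicts in TFMA's standard metric-input layout. The
# expected lift of 0.0 below is an assumed placeholder for identical baseline
# and comparison distributions, not a value verified against the Lift
# implementation.
def testLiftSketch(self):
  examples = [{
      'labels': np.array([0.0]),
      'predictions': np.array([0.1]),
      'example_weights': np.array([1.0]),
  }, {
      'labels': np.array([1.0]),
      'predictions': np.array([0.9]),
      'example_weights': np.array([1.0]),
  }]
  self._assert_test(
      num_buckets=2,
      baseline_examples=examples,
      comparison_examples=examples,
      lift_metric_value=0.0)  # assumed placeholder, see comment above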

    config.SlicingSpec(feature_values={'feature1': 'value1'}),
    config.SlicingSpec(feature_values={'feature2': 'value2'})
], (('feature1', 'value1'),))

_UNMATCHED_SINGLE_SLICE_TEST = ('single_slice',
                                [config.SlicingSpec(feature_keys=['feature1'])],
                                (('unmatched_feature', 'unmatched_value'),))

_UNMATCHED_MULTIPLE_SLICES_TEST = ('multiple_slices', [
    config.SlicingSpec(feature_values={'feature1': 'value1'}),
    config.SlicingSpec(feature_values={'feature2': 'value2'})
], (('unmatched_feature', 'unmatched_value'),))

# Cross slice tests: (<test_name>, <cross_slice_config>, <cross_slice_key>)
_CROSS_SLICE_GLOBAL_TEST = ('global_slice', [
    config.CrossSlicingSpec(
        baseline_spec=config.SlicingSpec(),
        slicing_specs=[
            config.SlicingSpec(feature_values={'feature2': 'value2'})
        ])
], ((), (('feature2', 'value2'),)))

_SINGLE_CROSS_SLICE_TEST = ('single_slice', [
    config.CrossSlicingSpec(
        baseline_spec=config.SlicingSpec(feature_keys=['feature1']),
        slicing_specs=[
            config.SlicingSpec(feature_values={'feature2': 'value2'})
        ])
], ((('feature1', 'value1'),), (('feature2', 'value2'),)))

_MULTIPLE_CROSS_SLICE_TEST = ('multiple_slice', [
    config.CrossSlicingSpec(
        baseline_spec=config.SlicingSpec(feature_keys=['feature1']),
        slicing_specs=[
            config.SlicingSpec(feature_values={'feature2': 'value2'})

def testIsCrossSliceApplicable(self):
  test_cases = [
      (True, 'overall pass', ((), (('b', 2),)),
       config.CrossSlicingSpec(
           baseline_spec=config.SlicingSpec(),
           slicing_specs=[config.SlicingSpec(feature_values={'b': '2'})])),
      (True, 'value pass', ((('a', 1),), (('b', 2),)),
       config.CrossSlicingSpec(
           baseline_spec=config.SlicingSpec(feature_values={'a': '1'}),
           slicing_specs=[config.SlicingSpec(feature_values={'b': '2'})])),
      (True, 'baseline key pass', ((('a', 1),), (('b', 2),)),
       config.CrossSlicingSpec(
           baseline_spec=config.SlicingSpec(feature_keys=['a']),
           slicing_specs=[config.SlicingSpec(feature_values={'b': '2'})])),
      (True, 'comparison key pass', ((('a', 1),), (('b', 2),)),
       config.CrossSlicingSpec(
           baseline_spec=config.SlicingSpec(feature_values={'a': '1'}),
           slicing_specs=[config.SlicingSpec(feature_keys=['b'])])),
      (True, 'comparison multiple key pass', ((('a', 1),), (('c', 3),)),
       config.CrossSlicingSpec(
           baseline_spec=config.SlicingSpec(feature_values={'a': '1'}),
           slicing_specs=[
               config.SlicingSpec(feature_keys=['b']),
               config.SlicingSpec(feature_keys=['c'])
           ])),
      (False, 'overall fail', ((('a', 1),), (('b', 2),)),
       config.CrossSlicingSpec(
           baseline_spec=config.SlicingSpec(),
           slicing_specs=[config.SlicingSpec(feature_values={'b': '2'})])),
      (False, 'value fail', ((('a', 1),), (('b', 3),)),
       config.CrossSlicingSpec(
           baseline_spec=config.SlicingSpec(feature_values={'a': '1'}),
           slicing_specs=[config.SlicingSpec(feature_values={'b': '2'})])),
      (False, 'baseline key fail', ((('c', 1),), (('b', 2),)),
       config.CrossSlicingSpec(
           baseline_spec=config.SlicingSpec(feature_keys=['a']),
           slicing_specs=[config.SlicingSpec(feature_values={'b': '2'})])),
      (False, 'comparison key fail', ((('a', 1),), (('c', 3),)),
       config.CrossSlicingSpec(
           baseline_spec=config.SlicingSpec(feature_values={'a': '1'}),
           slicing_specs=[config.SlicingSpec(feature_keys=['b'])])),
      (False, 'comparison multiple key fail', ((('a', 1),), (('d', 3),)),
       config.CrossSlicingSpec(
           baseline_spec=config.SlicingSpec(feature_values={'a': '1'}),
           slicing_specs=[
               config.SlicingSpec(feature_keys=['b']),
               config.SlicingSpec(feature_keys=['c'])
           ])),
  ]  # pyformat: disable

  for (expected_result, name, sliced_key, slicing_spec) in test_cases:
    self.assertEqual(
        expected_result,
        slicer.is_cross_slice_applicable(
            cross_slice_key=sliced_key, cross_slicing_spec=slicing_spec),
        msg=name)
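
# Companion sketch (hypothetical, mirroring the cases above): a cross slice
# key is a (baseline_key, comparison_key) pair, and it is applicable to a
# CrossSlicingSpec only when the baseline key satisfies baseline_spec and the
# comparison key satisfies at least one entry in slicing_specs.
def testIsCrossSliceApplicableSketch(self):
  spec = config.CrossSlicingSpec(
      baseline_spec=config.SlicingSpec(feature_keys=['a']),
      slicing_specs=[config.SlicingSpec(feature_values={'b': '2'})])
  # Baseline key has feature 'a'; comparison key matches b=2: applicable.
  self.assertTrue(
      slicer.is_cross_slice_applicable(
          cross_slice_key=((('a', 1),), (('b', 2),)),
          cross_slicing_spec=spec))
  # Swapped keys fail: the baseline side no longer matches feature_keys=['a'].
  self.assertFalse(
      slicer.is_cross_slice_applicable(
          cross_slice_key=((('b', 2),), (('a', 1),)),
          cross_slicing_spec=spec))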

def testMetricThresholdsFromMetricsSpecs(self):
  slice_specs = [
      config.SlicingSpec(feature_keys=['feature1']),
      config.SlicingSpec(feature_values={'feature2': 'value1'})
  ]

  # For cross slice tests.
  baseline_slice_spec = config.SlicingSpec(feature_keys=['feature3'])

  metrics_specs = [
      config.MetricsSpec(
          thresholds={
              'auc':
                  config.MetricThreshold(
                      value_threshold=config.GenericValueThreshold()),
              'mean/label':
                  config.MetricThreshold(
                      value_threshold=config.GenericValueThreshold(),
                      change_threshold=config.GenericChangeThreshold()),
              'mse':
                  config.MetricThreshold(
                      change_threshold=config.GenericChangeThreshold())
          },
          per_slice_thresholds={
              'auc':
                  config.PerSliceMetricThresholds(thresholds=[
                      config.PerSliceMetricThreshold(
                          slicing_specs=slice_specs,
                          threshold=config.MetricThreshold(
                              value_threshold=config.GenericValueThreshold()))
                  ]),
              'mean/label':
                  config.PerSliceMetricThresholds(thresholds=[
                      config.PerSliceMetricThreshold(
                          slicing_specs=slice_specs,
                          threshold=config.MetricThreshold(
                              value_threshold=config.GenericValueThreshold(),
                              change_threshold=config.GenericChangeThreshold()))
                  ])
          },
          cross_slice_thresholds={
              'auc':
                  config.CrossSliceMetricThresholds(thresholds=[
                      config.CrossSliceMetricThreshold(
                          cross_slicing_specs=[
                              config.CrossSlicingSpec(
                                  baseline_spec=baseline_slice_spec,
                                  slicing_specs=slice_specs)
                          ],
                          threshold=config.MetricThreshold(
                              value_threshold=config.GenericValueThreshold(),
                              change_threshold=config.GenericChangeThreshold()))
                  ]),
              'mse':
                  config.CrossSliceMetricThresholds(thresholds=[
                      config.CrossSliceMetricThreshold(
                          cross_slicing_specs=[
                              config.CrossSlicingSpec(
                                  baseline_spec=baseline_slice_spec,
                                  slicing_specs=slice_specs)
                          ],
                          threshold=config.MetricThreshold(
                              change_threshold=config.GenericChangeThreshold())),
                      # Test for duplicate cross_slicing_spec.
                      config.CrossSliceMetricThreshold(
                          cross_slicing_specs=[
                              config.CrossSlicingSpec(
                                  baseline_spec=baseline_slice_spec,
                                  slicing_specs=slice_specs)
                          ],
                          threshold=config.MetricThreshold(
                              value_threshold=config.GenericValueThreshold()))
                  ])
          },
          model_names=['model_name'],
          output_names=['output_name']),
      config.MetricsSpec(
          metrics=[
              config.MetricConfig(
                  class_name='ExampleCount',
                  config=json.dumps({'name': 'example_count'}),
                  threshold=config.MetricThreshold(
                      value_threshold=config.GenericValueThreshold()))
          ],
          model_names=['model_name1', 'model_name2'],
          output_names=['output_name1', 'output_name2']),
      config.MetricsSpec(
          metrics=[
              config.MetricConfig(
                  class_name='WeightedExampleCount',
                  config=json.dumps({'name': 'weighted_example_count'}),
                  threshold=config.MetricThreshold(
                      value_threshold=config.GenericValueThreshold()))
          ],
          model_names=['model_name1', 'model_name2'],
          output_names=['output_name1', 'output_name2']),
      config.MetricsSpec(
          metrics=[
              config.MetricConfig(
                  class_name='MeanSquaredError',
                  config=json.dumps({'name': 'mse'}),
                  threshold=config.MetricThreshold(
                      change_threshold=config.GenericChangeThreshold())),
              config.MetricConfig(
                  class_name='MeanLabel',
                  config=json.dumps({'name': 'mean_label'}),
                  threshold=config.MetricThreshold(
                      change_threshold=config.GenericChangeThreshold()),
                  per_slice_thresholds=[
                      config.PerSliceMetricThreshold(
                          slicing_specs=slice_specs,
                          threshold=config.MetricThreshold(
                              change_threshold=config.GenericChangeThreshold())),
                  ],
                  cross_slice_thresholds=[
                      config.CrossSliceMetricThreshold(
                          cross_slicing_specs=[
                              config.CrossSlicingSpec(
                                  baseline_spec=baseline_slice_spec,
                                  slicing_specs=slice_specs)
                          ],
                          threshold=config.MetricThreshold(
                              change_threshold=config.GenericChangeThreshold()))
                  ]),
          ],
          model_names=['model_name'],
          output_names=['output_name'],
          binarize=config.BinarizationOptions(class_ids={'values': [0, 1]}),
          aggregate=config.AggregationOptions(
              macro_average=True, class_weights={
                  0: 1.0,
                  1: 1.0
              }))
  ]

  thresholds = metric_specs.metric_thresholds_from_metrics_specs(metrics_specs)

  expected_keys_and_threshold_counts = {
      metric_types.MetricKey(
          name='auc',
          model_name='model_name',
          output_name='output_name',
          is_diff=False): 4,
      metric_types.MetricKey(
          name='auc',
          model_name='model_name',
          output_name='output_name',
          is_diff=True): 1,
      metric_types.MetricKey(
          name='mean/label',
          model_name='model_name',
          output_name='output_name',
          is_diff=True): 3,
      metric_types.MetricKey(
          name='mean/label',
          model_name='model_name',
          output_name='output_name',
          is_diff=False): 3,
      metric_types.MetricKey(
          name='example_count',
          model_name='model_name1',
          output_name='output_name1'): 1,
      metric_types.MetricKey(
          name='example_count',
          model_name='model_name1',
          output_name='output_name2'): 1,
      metric_types.MetricKey(
          name='example_count',
          model_name='model_name2',
          output_name='output_name1'): 1,
      metric_types.MetricKey(
          name='example_count',
          model_name='model_name2',
          output_name='output_name2'): 1,
      metric_types.MetricKey(
          name='weighted_example_count',
          model_name='model_name1',
          output_name='output_name1'): 1,
      metric_types.MetricKey(
          name='weighted_example_count',
          model_name='model_name1',
          output_name='output_name2'): 1,
      metric_types.MetricKey(
          name='weighted_example_count',
          model_name='model_name2',
          output_name='output_name1'): 1,
      metric_types.MetricKey(
          name='weighted_example_count',
          model_name='model_name2',
          output_name='output_name2'): 1,
      metric_types.MetricKey(
          name='mse',
          model_name='model_name',
          output_name='output_name',
          sub_key=metric_types.SubKey(class_id=0),
          is_diff=True): 1,
      metric_types.MetricKey(
          name='mse',
          model_name='model_name',
          output_name='output_name',
          sub_key=metric_types.SubKey(class_id=1),
          is_diff=True): 1,
      metric_types.MetricKey(
          name='mse',
          model_name='model_name',
          output_name='output_name',
          is_diff=True): 2,
      metric_types.MetricKey(
          name='mse',
          model_name='model_name',
          output_name='output_name',
          is_diff=False): 1,
      metric_types.MetricKey(
          name='mse',
          model_name='model_name',
          output_name='output_name',
          aggregation_type=metric_types.AggregationType(macro_average=True),
          is_diff=True): 1,
      metric_types.MetricKey(
          name='mean_label',
          model_name='model_name',
          output_name='output_name',
          sub_key=metric_types.SubKey(class_id=0),
          is_diff=True): 4,
      metric_types.MetricKey(
          name='mean_label',
          model_name='model_name',
          output_name='output_name',
          sub_key=metric_types.SubKey(class_id=1),
          is_diff=True): 4,
      metric_types.MetricKey(
          name='mean_label',
          model_name='model_name',
          output_name='output_name',
          aggregation_type=metric_types.AggregationType(macro_average=True),
          is_diff=True): 4
  }
  self.assertLen(thresholds, len(expected_keys_and_threshold_counts))
  for key, count in expected_keys_and_threshold_counts.items():
    self.assertIn(key, thresholds)
    self.assertLen(thresholds[key], count, 'failed for key {}'.format(key))
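
# Minimal standalone sketch (hypothetical, mirroring the test above): a single
# value_threshold attached under `thresholds` should fan out to exactly one
# non-diff MetricKey carrying one threshold.
def testSingleThresholdSketch(self):
  metrics_specs = [
      config.MetricsSpec(
          thresholds={
              'auc':
                  config.MetricThreshold(
                      value_threshold=config.GenericValueThreshold())
          },
          model_names=['model_name'],
          output_names=['output_name'])
  ]
  thresholds = metric_specs.metric_thresholds_from_metrics_specs(metrics_specs)
  key = metric_types.MetricKey(
      name='auc',
      model_name='model_name',
      output_name='output_name',
      is_diff=False)
  self.assertIn(key, thresholds)
  self.assertLen(thresholds[key], 1)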