def test_partition_slices_with_metric_sub_key(self):
  """Verifies partition_slices when 'accuracy' is keyed with a sub key.

  The class_id=0 sub key is merged into every metric key of the fixture, and
  the same sub key is part of the MetricKey handed to partition_slices. The
  slice keys found in result[0] and result[1] are checked for both the
  'LOWER' and the 'HIGHER' comparison type.
  """

  def slice_keys(comparison_results):
    """Returns the slice_key of every entry, in order."""
    return [entry.slice_key for entry in comparison_results]

  age_1_6 = (('age', '[1.0, 6.0)'),)
  age_6_12 = (('age', '[6.0, 12.0)'),)
  age_12_18 = (('age', '[12.0, 18.0)'),)
  usa = (('country', 'USA'),)
  usa_age_12_18 = (('country', 'USA'), ('age', '[12.0, 18.0)'))

  slice_metrics = self._get_metrics()
  # Tag every metric key with the class_id=0 sub key so that it matches the
  # key queried below.
  for slice_metric in slice_metrics:
    for key_and_value in slice_metric.metric_keys_and_values:
      key_and_value.key.sub_key.MergeFrom(
          metric_types.SubKey(class_id=0).to_proto())

  accuracy_key = metric_types.MetricKey(
      name='accuracy', sub_key=metric_types.SubKey(class_id=0))

  lower_result = auto_slicing_util.partition_slices(
      slice_metrics, metric_key=accuracy_key, comparison_type='LOWER')
  self.assertCountEqual(slice_keys(lower_result[0]), [age_1_6])
  self.assertCountEqual(
      slice_keys(lower_result[1]),
      [age_6_12, age_12_18, usa, usa_age_12_18])

  higher_result = auto_slicing_util.partition_slices(
      slice_metrics, metric_key=accuracy_key, comparison_type='HIGHER')
  self.assertCountEqual(
      slice_keys(higher_result[0]), [age_12_18, usa, usa_age_12_18])
  self.assertCountEqual(slice_keys(higher_result[1]), [age_1_6, age_6_12])
def test_partition_slices_without_metric_sub_key(self):
  """Verifies partition_slices for a plain 'accuracy' key without sub key.

  The fixture from self._get_metrics() is partitioned once per comparison
  type ('LOWER', then 'HIGHER'), and the slice keys in result[0] and
  result[1] are compared, ignoring order, with the expected ones.
  """
  age_1_6 = (('age', '[1.0, 6.0)'),)
  age_6_12 = (('age', '[6.0, 12.0)'),)
  age_12_18 = (('age', '[12.0, 18.0)'),)
  usa = (('country', 'USA'),)
  usa_age_12_18 = (('country', 'USA'), ('age', '[12.0, 18.0)'))

  # (comparison_type, expected keys in result[0], expected keys in result[1])
  cases = (
      ('LOWER', [age_1_6], [age_6_12, age_12_18, usa, usa_age_12_18]),
      ('HIGHER', [age_12_18, usa, usa_age_12_18], [age_1_6, age_6_12]),
  )

  metrics = self._get_metrics()
  for comparison_type, expected_first, expected_second in cases:
    result = auto_slicing_util.partition_slices(
        metrics,
        metric_key=metric_types.MetricKey(name='accuracy'),
        comparison_type=comparison_type)
    actual_first = [item.slice_key for item in result[0]]
    actual_second = [item.slice_key for item in result[1]]
    self.assertCountEqual(actual_first, expected_first)
    self.assertCountEqual(actual_second, expected_second)
def test_find_significant_slices(self):
  """Partitions hand-built slice metrics on 'accuracy' in both directions.

  Despite its name, this test exercises auto_slicing_util.partition_slices.
  The fixture holds an overall (empty slice key) entry plus five slices, each
  reporting 'accuracy' and 'example_count' together with their
  t-distribution statistics. The metric is addressed with a
  metric_types.MetricKey, the same way the other partition_slices tests in
  this file address it, instead of a bare string.
  """

  def _metrics_for_slice(slice_key, accuracy, lower, upper, example_count):
    """Builds one MetricsForSlice proto from the values that vary per slice.

    Args:
      slice_key: Tuple of (column, value) pairs; empty for the overall slice.
      accuracy: Accuracy value; also used as sample mean and unsampled value.
      lower: Lower bound of the accuracy bounded value / confidence interval.
      upper: Upper bound of the accuracy bounded value / confidence interval.
      example_count: Example count; used as value, both bounds, sample mean
        and unsampled value of the 'example_count' metric.

    Returns:
      The parsed metrics_for_slice_pb2.MetricsForSlice message.
    """
    single_slice_keys = ' '.join(
        "single_slice_keys { column: '%s' bytes_value: '%s' }" % pair
        for pair in slice_key)
    # %-formatting is used on purpose: the proto text is full of braces,
    # which str.format would require escaping.
    proto_text = """
        slice_key { %(single_slice_keys)s }
        metric_keys_and_values {
          key { name: "accuracy" }
          value {
            bounded_value {
              value { value: %(accuracy)s }
              lower_bound { value: %(lower)s }
              upper_bound { value: %(upper)s }
              methodology: POISSON_BOOTSTRAP
            }
            confidence_interval {
              lower_bound { value: %(lower)s }
              upper_bound { value: %(upper)s }
              t_distribution_value {
                sample_mean { value: %(accuracy)s }
                sample_standard_deviation { value: 0.1 }
                sample_degrees_of_freedom { value: 9 }
                unsampled_value { value: %(accuracy)s }
              }
            }
          }
        }
        metric_keys_and_values {
          key { name: "example_count" }
          value {
            bounded_value {
              value { value: %(example_count)s }
              lower_bound { value: %(example_count)s }
              upper_bound { value: %(example_count)s }
              methodology: POISSON_BOOTSTRAP
            }
            confidence_interval {
              lower_bound { value: %(example_count)s }
              upper_bound { value: %(example_count)s }
              t_distribution_value {
                sample_mean { value: %(example_count)s }
                sample_standard_deviation { value: 0 }
                sample_degrees_of_freedom { value: 9 }
                unsampled_value { value: %(example_count)s }
              }
            }
          }
        }
        """ % {
            'single_slice_keys': single_slice_keys,
            'accuracy': accuracy,
            'lower': lower,
            'upper': upper,
            'example_count': example_count,
        }
    return text_format.Parse(proto_text,
                             metrics_for_slice_pb2.MetricsForSlice())

  def _slice_keys(comparison_results):
    """Returns the slice_key of every entry, in order."""
    return [entry.slice_key for entry in comparison_results]

  # Slice keys are shared between the fixture and the expectations below.
  overall = ()
  age_1_6 = (('age', '[1.0, 6.0)'),)
  age_6_12 = (('age', '[6.0, 12.0)'),)
  age_12_18 = (('age', '[12.0, 18.0)'),)
  usa = (('country', 'USA'),)
  usa_age_12_18 = (('country', 'USA'), ('age', '[12.0, 18.0)'))

  metrics = [
      _metrics_for_slice(overall, 0.8, 0.5737843, 1.0262157, 1500),
      _metrics_for_slice(age_1_6, 0.4, 0.3737843, 0.6262157, 500),
      _metrics_for_slice(age_6_12, 0.79, 0.5737843, 1.0262157, 500),
      _metrics_for_slice(age_12_18, 0.9, 0.5737843, 1.0262157, 500),
      _metrics_for_slice(usa, 0.9, 0.5737843, 1.0262157, 500),
      _metrics_for_slice(usa_age_12_18, 0.9, 0.5737843, 1.0262157, 500),
  ]
  accuracy_key = metric_types.MetricKey(name='accuracy')

  result = auto_slicing_util.partition_slices(
      metrics, metric_key=accuracy_key, comparison_type='LOWER')
  self.assertCountEqual(_slice_keys(result[0]), [age_1_6])
  self.assertCountEqual(
      _slice_keys(result[1]), [age_6_12, age_12_18, usa, usa_age_12_18])

  result = auto_slicing_util.partition_slices(
      metrics, metric_key=accuracy_key, comparison_type='HIGHER')
  self.assertCountEqual(
      _slice_keys(result[0]), [age_12_18, usa, usa_age_12_18])
  self.assertCountEqual(_slice_keys(result[1]), [age_1_6, age_6_12])