  def test_infer_schema_with_infer_shape(self):
    statistics = text_format.Parse(
        """
        datasets {
          num_examples: 7
          features: {
            name: 'feature1'
            type: STRING
            string_stats: {
              common_stats: {
                num_missing: 0
                num_non_missing: 7
                min_num_values: 1
                max_num_values: 1
              }
              unique: 3
            }
          }
          features: {
            name: 'feature2'
            type: STRING
            string_stats: {
              common_stats: {
                num_missing: 0
                num_non_missing: 7
                min_num_values: 3
                max_num_values: 3
              }
              unique: 5
            }
          }
        }
        """, statistics_pb2.DatasetFeatureStatisticsList())

    expected_schema = text_format.Parse(
        """
        feature {
          name: "feature1"
          shape { dim { size: 1 } }
          presence: {
            min_fraction: 1.0
            min_count: 1
          }
          type: BYTES
        }
        feature {
          name: "feature2"
          value_count: { min: 1 }
          presence: {
            min_fraction: 1.0
            min_count: 1
          }
          type: BYTES
        }
        """, schema_pb2.Schema())

    # Infer the schema from the stats.
    actual_schema = validation_api.infer_schema(
        statistics, infer_feature_shape=True)
    self.assertEqual(actual_schema, expected_schema)

  def test_infer_schema_without_string_domain(self):
    statistics = text_format.Parse(
        """
        datasets {
          num_examples: 7
          features: {
            name: 'feature1'
            type: STRING
            string_stats: {
              common_stats: {
                num_missing: 3
                num_non_missing: 4
                min_num_values: 1
                max_num_values: 1
              }
              unique: 3
              rank_histogram {
                buckets {
                  low_rank: 0
                  high_rank: 0
                  label: "a"
                  sample_count: 2.0
                }
                buckets {
                  low_rank: 1
                  high_rank: 1
                  label: "b"
                  sample_count: 1.0
                }
                buckets {
                  low_rank: 2
                  high_rank: 2
                  label: "c"
                  sample_count: 1.0
                }
              }
            }
          }
        }
        """, statistics_pb2.DatasetFeatureStatisticsList())

    expected_schema = text_format.Parse(
        """
        feature {
          name: "feature1"
          value_count: { min: 1 max: 1 }
          presence: {
            min_count: 1
          }
          type: BYTES
        }
        """, schema_pb2.Schema())

    # Infer the schema from the stats.
    actual_schema = validation_api.infer_schema(
        statistics, max_string_domain_size=2)
    self.assertEqual(actual_schema, expected_schema)

  def test_infer_schema_invalid_multiple_datasets(self):
    statistics = statistics_pb2.DatasetFeatureStatisticsList()
    statistics.datasets.extend([
        statistics_pb2.DatasetFeatureStatistics(),
        statistics_pb2.DatasetFeatureStatistics()
    ])
    with self.assertRaisesRegexp(
        ValueError, '.*statistics proto with one dataset.*'):
      _ = validation_api.infer_schema(statistics)

  def test_infer_schema(self):
    statistics = text_format.Parse(
        """
        datasets {
          num_examples: 7
          features: {
            name: 'feature1'
            type: STRING
            string_stats: {
              common_stats: {
                num_non_missing: 7
                min_num_values: 1
                max_num_values: 1
              }
              unique: 3
            }
          }
        }
        """, statistics_pb2.DatasetFeatureStatisticsList())

    expected_schema = text_format.Parse(
        """
        feature {
          name: "feature1"
          value_count: { min: 1 max: 1 }
          presence: {
            min_fraction: 1.0
            min_count: 1
          }
          type: BYTES
        }
        """, schema_pb2.Schema())
    validation_api._may_be_set_legacy_flag(expected_schema)

    # Infer the schema from the stats.
    actual_schema = validation_api.infer_schema(
        statistics, infer_feature_shape=False)
    self.assertEqual(actual_schema, expected_schema)

  def test_infer_schema_invalid_statistics_input(self):
    with self.assertRaisesRegexp(
        TypeError, '.*should be a DatasetFeatureStatisticsList proto.*'):
      _ = validation_api.infer_schema({})