def test_get_bool_domain_feature_level_domain(self):
    """get_domain returns a bool_domain declared inline on the feature.

    Also checks that looking the domain up twice yields the identical
    message object rather than a copy.
    """
    schema_text = """ feature { name: "feature1" bool_domain { name: "domain1" } } """
    schema = text_format.Parse(schema_text, schema_pb2.Schema())

    found = schema_util.get_domain(schema, 'feature1')
    self.assertIsInstance(found, schema_pb2.BoolDomain)
    self.assertEqual(found.name, 'domain1')

    # A second lookup must hand back the very same object, not a copy.
    self.assertIs(found, schema_util.get_domain(schema, 'feature1'))
def test_get_string_domain_schema_level_domain(self):
    """get_domain follows a feature's `domain` name to a top-level string_domain.

    Two top-level string domains exist; feature1 refers to "domain2" by
    name, so that one (and not "domain1") must be returned, and repeated
    lookups must return the identical message object.
    """
    schema_text = """ string_domain { name: "domain1" } string_domain { name: "domain2" } feature { name: "feature1" domain: "domain2" } """
    schema = text_format.Parse(schema_text, schema_pb2.Schema())

    resolved = schema_util.get_domain(schema, 'feature1')
    self.assertIsInstance(resolved, schema_pb2.StringDomain)
    self.assertEqual(resolved.name, 'domain2')

    # Identity check: the lookup must not return a copy of the domain.
    self.assertIs(resolved, schema_util.get_domain(schema, 'feature1'))
def test_get_domain_not_present(self):
    """get_domain raises ValueError when the feature has no domain.

    The schema declares a top-level string_domain, but feature1 neither
    references it by name nor carries an inline domain of its own.
    """
    schema = text_format.Parse(
        """ string_domain { name: "domain1" } feature { name: "feature1" } """,
        schema_pb2.Schema())
    # assertRaisesRegexp was a deprecated alias and is gone in Python 3.12+.
    with self.assertRaisesRegex(ValueError, 'has no domain associated'):
        _ = schema_util.get_domain(schema, 'feature1')
def test_raise_on_get_struct_domain(self):
    """get_domain raises ValueError for a STRUCT feature's struct_domain.

    A struct_domain is rejected as an unsupported domain kind when it is
    looked up through a FeaturePath.
    """
    schema = text_format.Parse(
        """ feature { name: "feature1" type: STRUCT struct_domain { feature { name: "sub_feature1" } } } """,
        schema_pb2.Schema())
    # assertRaisesRegexp was a deprecated alias and is gone in Python 3.12+.
    with self.assertRaisesRegex(ValueError, 'has an unsupported domain'):
        _ = schema_util.get_domain(schema, types.FeaturePath(['feature1']))
def test_get_domain_using_path(self):
    """get_domain resolves a multi-step FeaturePath into a nested feature.

    The result must be the exact bool_domain message stored on
    feature1.sub_feature1 inside the parsed schema (identity, not equality).
    """
    schema_text = """ feature { name: "feature1" type: STRUCT struct_domain { feature { name: "sub_feature1" bool_domain { name: "domain1" } } } } """
    schema = text_format.Parse(schema_text, schema_pb2.Schema())

    nested_path = types.FeaturePath(['feature1', 'sub_feature1'])
    found = schema_util.get_domain(schema, nested_path)

    stored = schema.feature[0].struct_domain.feature[0].bool_domain
    self.assertIs(found, stored)
def test_get_domain_invalid_schema_input(self):
    """get_domain raises TypeError when its first argument is not a Schema."""
    # assertRaisesRegexp was a deprecated alias and is gone in Python 3.12+.
    # assertRaisesRegex matches with re.search, so no surrounding '.*' is needed.
    with self.assertRaisesRegex(TypeError, 'should be a Schema proto'):
        _ = schema_util.get_domain({}, 'feature')
def test_update_schema(self):
    """update_schema adds an unexpected enum value seen in statistics.

    The statistics report the string value "D" for annotated_enum, which
    the MyAloneEnum string_domain does not list. The test checks that:
      1. validation against the original schema reports that anomaly;
      2. update_schema returns a schema whose domain additionally has "D";
      3. validation against the updated schema reports no anomalies.
    """
    schema = text_format.Parse(
        """ string_domain { name: "MyAloneEnum" value: "A" value: "B" value: "C" } feature { name: "annotated_enum" value_count { min:1 max:1 } presence { min_count: 1 } type: BYTES domain: "MyAloneEnum" } feature { name: "ignore_this" lifecycle_stage: DEPRECATED value_count { min:1 } presence { min_count: 1 } type: BYTES } """,
        schema_pb2.Schema())
    statistics = text_format.Parse(
        """ datasets{ num_examples: 10 features { name: 'annotated_enum' type: STRING string_stats { common_stats { num_missing: 3 num_non_missing: 7 min_num_values: 1 max_num_values: 1 } unique: 3 rank_histogram { buckets { label: "D" sample_count: 1 } } } } } """,
        statistics_pb2.DatasetFeatureStatisticsList())
    expected_anomalies = {
        'annotated_enum':
            text_format.Parse(
                """ description: "Examples contain values missing from the schema: D (?). " severity: ERROR short_description: "Unexpected string values" reason { type: ENUM_TYPE_UNEXPECTED_STRING_VALUES short_description: "Unexpected string values" description: "Examples contain values missing from the schema: D (?). " } """,
                anomalies_pb2.AnomalyInfo())
    }

    # Validate the stats.
    anomalies = validation_api.validate_statistics(statistics, schema)
    self._assert_equal_anomalies(anomalies, expected_anomalies)

    # Build the expected schema from an independent copy taken *before*
    # update_schema runs. Aliasing `schema` directly (and mutating it) would
    # make the equality check vacuous if update_schema ever modified and
    # returned its input, and it also needlessly mutates the test's input.
    expected_updated_schema = schema_pb2.Schema()
    expected_updated_schema.CopyFrom(schema)
    schema_util.get_domain(expected_updated_schema,
                           'annotated_enum').value.append('D')

    # Verify the updated schema.
    actual_updated_schema = validation_api.update_schema(schema, statistics)
    self.assertEqual(actual_updated_schema, expected_updated_schema)

    # Verify that there are no anomalies with the updated schema.
    actual_updated_anomalies = validation_api.validate_statistics(
        statistics, actual_updated_schema)
    self._assert_equal_anomalies(actual_updated_anomalies, {})