def test_set_domain_invalid_global_domain(self):
  """set_domain rejects a string domain that names no global string domain.

  The schema declares only the global string domain 'domain1', so asking
  for 'domain2' must raise a ValueError.
  """
  schema = schema_pb2.Schema()
  schema.feature.add(name='feature')
  schema.string_domain.add(name='domain1', value=['a', 'b'])
  # assertRaisesRegexp is a deprecated alias removed in Python 3.12.
  with self.assertRaisesRegex(ValueError, 'Invalid global string domain'):
    schema_util.set_domain(schema, 'feature', 'domain2')
def test_set_domain(self, input_schema_proto_text, feature_name, domain,
                    output_schema_proto_text):
  """Applies set_domain to a text-format schema and compares to the expectation.

  Args:
    input_schema_proto_text: Text-format Schema proto to start from.
    feature_name: Name of the feature whose domain is set.
    domain: Domain argument forwarded unchanged to schema_util.set_domain.
    output_schema_proto_text: Text-format Schema proto expected afterwards.
  """
  # text_format.Merge fills the given message and returns that same message.
  schema_under_test = text_format.Merge(input_schema_proto_text,
                                        schema_pb2.Schema())
  schema_util.set_domain(schema_under_test, feature_name, domain)
  wanted_schema = text_format.Merge(output_schema_proto_text,
                                    schema_pb2.Schema())
  self.assertEqual(schema_under_test, wanted_schema)
def test_set_domain_invalid_domain(self):
  """set_domain raises TypeError when `domain` is of an unsupported type (a dict)."""
  # assertRaisesRegexp is a deprecated alias removed in Python 3.12.
  with self.assertRaisesRegex(TypeError, 'domain is of type'):
    schema_util.set_domain(schema_pb2.Schema(), 'feature', {})
def test_set_domain_invalid_schema(self):
  """set_domain raises TypeError when `schema` is not a Schema proto (here a dict)."""
  # assertRaisesRegexp is a deprecated alias removed in Python 3.12.
  with self.assertRaisesRegex(TypeError, 'should be a Schema proto'):
    schema_util.set_domain({}, 'feature', schema_pb2.IntDomain())