def test_create_row_coder_from_schema(self): schema = schema_pb2.Schema( id="person", fields=[ schema_pb2.Field( name="name", type=schema_pb2.FieldType(atomic_type=schema_pb2.STRING)), schema_pb2.Field( name="age", type=schema_pb2.FieldType(atomic_type=schema_pb2.INT32)), schema_pb2.Field( name="address", type=schema_pb2.FieldType( atomic_type=schema_pb2.STRING, nullable=True)), schema_pb2.Field( name="aliases", type=schema_pb2.FieldType( array_type=schema_pb2.ArrayType( element_type=schema_pb2.FieldType( atomic_type=schema_pb2.STRING)))), schema_pb2.Field( name="knows_javascript", type=schema_pb2.FieldType(atomic_type=schema_pb2.BOOLEAN)), schema_pb2.Field( name="payload", type=schema_pb2.FieldType( atomic_type=schema_pb2.BYTES, nullable=True)), schema_pb2.Field( name="custom_metadata", type=schema_pb2.FieldType( map_type=schema_pb2.MapType( key_type=schema_pb2.FieldType( atomic_type=schema_pb2.STRING), value_type=schema_pb2.FieldType( atomic_type=schema_pb2.INT64), ))), schema_pb2.Field( name="favorite_time", type=schema_pb2.FieldType( logical_type=schema_pb2.LogicalType( urn="beam:logical_type:micros_instant:v1", representation=schema_pb2.FieldType( row_type=schema_pb2.RowType( schema=schema_pb2.Schema( id="micros_instant", fields=[ schema_pb2.Field( name="seconds", type=schema_pb2.FieldType( atomic_type=schema_pb2.INT64)), schema_pb2.Field( name="micros", type=schema_pb2.FieldType( atomic_type=schema_pb2.INT64)), ])))))), ]) coder = RowCoder(schema) for test_case in self.PEOPLE: self.assertEqual(test_case, coder.decode(coder.encode(test_case)))
def test_encoding_position_reorder_fields(self):
  schema1 = schema_pb2.Schema(
      id="reorder_test_schema1",
      fields=[
          schema_pb2.Field(
              name="f_int32",
              type=schema_pb2.FieldType(atomic_type=schema_pb2.INT32),
          ),
          schema_pb2.Field(
              name="f_str",
              type=schema_pb2.FieldType(atomic_type=schema_pb2.STRING),
          ),
      ])

  schema2 = schema_pb2.Schema(
      id="reorder_test_schema2",
      encoding_positions_set=True,
      fields=[
          schema_pb2.Field(
              name="f_str",
              type=schema_pb2.FieldType(atomic_type=schema_pb2.STRING),
              encoding_position=1,
          ),
          schema_pb2.Field(
              name="f_int32",
              type=schema_pb2.FieldType(atomic_type=schema_pb2.INT32),
              encoding_position=0,
          ),
      ])

  RowSchema1 = named_tuple_from_schema(schema1)
  RowSchema2 = named_tuple_from_schema(schema2)

  roundtripped = RowCoder(schema2).decode(
      RowCoder(schema1).encode(RowSchema1(42, "Hello World!")))

  self.assertEqual(RowSchema2(f_int32=42, f_str="Hello World!"), roundtripped)
def named_fields_to_schema(names_and_types):
  return schema_pb2.Schema(
      fields=[
          schema_pb2.Field(name=name, type=typing_to_runner_api(type))
          for (name, type) in names_and_types
      ],
      id=str(uuid4()))
def named_fields_to_schema(names_and_types):
  # type: (Sequence[Tuple[str, type]]) -> schema_pb2.Schema
  return schema_pb2.Schema(
      fields=[
          schema_pb2.Field(name=name, type=typing_to_runner_api(type))
          for (name, type) in names_and_types
      ],
      id=str(uuid4()))
def test_create_row_coder_from_schema(self): schema = schema_pb2.Schema( id="person", fields=[ schema_pb2.Field( name="name", type=schema_pb2.FieldType(atomic_type=schema_pb2.STRING)), schema_pb2.Field( name="age", type=schema_pb2.FieldType(atomic_type=schema_pb2.INT32)), schema_pb2.Field(name="address", type=schema_pb2.FieldType( atomic_type=schema_pb2.STRING, nullable=True)), schema_pb2.Field( name="aliases", type=schema_pb2.FieldType(array_type=schema_pb2.ArrayType( element_type=schema_pb2.FieldType( atomic_type=schema_pb2.STRING)))), schema_pb2.Field( name="knows_javascript", type=schema_pb2.FieldType(atomic_type=schema_pb2.BOOLEAN)), schema_pb2.Field(name="payload", type=schema_pb2.FieldType( atomic_type=schema_pb2.BYTES, nullable=True)), ]) coder = RowCoder(schema) for test_case in self.PEOPLE: self.assertEqual(test_case, coder.decode(coder.encode(test_case)))
def test_encoding_position_reorder_fields(self): fields = [("field1", str), ("field2", int), ("field3", int)] expected = typing.NamedTuple('expected', fields) reorder = schema_pb2.Schema( id="new_order", fields=[ schema_pb2.Field( name="field3", type=schema_pb2.FieldType(atomic_type=schema_pb2.STRING), encoding_position=2), schema_pb2.Field( name="field2", type=schema_pb2.FieldType(atomic_type=schema_pb2.INT32), encoding_position=1), schema_pb2.Field( name="field1", type=schema_pb2.FieldType(atomic_type=schema_pb2.INT32), encoding_position=0) ]) old_coder = RowCoder.from_type_hint(expected, None) new_coder = RowCoder(reorder) encode_expected = old_coder.encode(expected("foo", 7, 12)) encode_reorder = new_coder.encode(expected(12, 7, "foo")) self.assertEqual(encode_expected, encode_reorder)
def typing_to_runner_api(type_):
  if match_is_named_tuple(type_):
    schema = None
    if hasattr(type_, _BEAM_SCHEMA_ID):
      schema = SCHEMA_REGISTRY.get_schema_by_id(
          getattr(type_, _BEAM_SCHEMA_ID))
    if schema is None:
      fields = [
          schema_pb2.Field(
              name=name, type=typing_to_runner_api(type_._field_types[name]))
          for name in type_._fields
      ]
      type_id = str(uuid4())
      schema = schema_pb2.Schema(fields=fields, id=type_id)
      setattr(type_, _BEAM_SCHEMA_ID, type_id)
      SCHEMA_REGISTRY.add(type_, schema)

    return schema_pb2.FieldType(row_type=schema_pb2.RowType(schema=schema))

  # All concrete types (other than NamedTuple sub-classes) should map to
  # a supported primitive type.
  elif type_ in PRIMITIVE_TO_ATOMIC_TYPE:
    return schema_pb2.FieldType(atomic_type=PRIMITIVE_TO_ATOMIC_TYPE[type_])

  elif _match_is_exactly_mapping(type_):
    key_type, value_type = map(typing_to_runner_api, _get_args(type_))
    return schema_pb2.FieldType(
        map_type=schema_pb2.MapType(key_type=key_type, value_type=value_type))

  elif _match_is_optional(type_):
    # It's possible that a user passes us Optional[Optional[T]], but in python
    # typing this is indistinguishable from Optional[T] - both resolve to
    # Union[T, None] - so there's no need to check for that case here.
    result = typing_to_runner_api(extract_optional_type(type_))
    result.nullable = True
    return result

  elif _safe_issubclass(type_, Sequence):
    element_type = typing_to_runner_api(_get_args(type_)[0])
    return schema_pb2.FieldType(
        array_type=schema_pb2.ArrayType(element_type=element_type))

  elif _safe_issubclass(type_, Mapping):
    key_type, value_type = map(typing_to_runner_api, _get_args(type_))
    return schema_pb2.FieldType(
        map_type=schema_pb2.MapType(key_type=key_type, value_type=value_type))

  try:
    logical_type = LogicalType.from_typing(type_)
  except ValueError:
    # Unknown type, just treat it like Any
    return schema_pb2.FieldType(
        logical_type=schema_pb2.LogicalType(urn=PYTHON_ANY_URN))
  else:
    # TODO(bhulette): Add support for logical types that require arguments
    return schema_pb2.FieldType(
        logical_type=schema_pb2.LogicalType(
            urn=logical_type.urn(),
            representation=typing_to_runner_api(
                logical_type.representation_type())))
def test_schema_with_bad_field_raises_helpful_error(self):
  schema_proto = schema_pb2.Schema(
      fields=[
          schema_pb2.Field(
              name="type_with_no_typeinfo", type=schema_pb2.FieldType())
      ])

  # Should raise an exception referencing the problem field
  self.assertRaisesRegex(
      ValueError,
      "type_with_no_typeinfo",
      lambda: named_tuple_from_schema(schema_proto))
def named_fields_to_schema(names_and_types):
  # type: (Union[Dict[str, type], Sequence[Tuple[str, type]]]) -> schema_pb2.Schema
  if isinstance(names_and_types, dict):
    names_and_types = names_and_types.items()
  return schema_pb2.Schema(
      fields=[
          schema_pb2.Field(name=name, type=typing_to_runner_api(type))
          for (name, type) in names_and_types
      ],
      id=str(uuid4()))
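# Illustrative usage sketch, not part of the original module. It assumes
# Python 3.7+ (so str maps to STRING and dict preserves insertion order); the
# names below are hypothetical. Both accepted input shapes should yield the
# same fields; only the freshly generated schema ids differ.
def _named_fields_to_schema_sketch():
  from_dict = named_fields_to_schema({'title': str, 'year': int})
  from_pairs = named_fields_to_schema([('title', str), ('year', int)])
  # Each call generates its own uuid for the id, so compare only the fields.
  assert list(from_dict.fields) == list(from_pairs.fields)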
def test_row_coder_fail_early_bad_schema(self):
  schema_proto = schema_pb2.Schema(
      fields=[
          schema_pb2.Field(
              name="type_with_no_typeinfo", type=schema_pb2.FieldType())
      ])

  # Should raise an exception referencing the problem field
  self.assertRaisesRegex(
      ValueError, "type_with_no_typeinfo", lambda: RowCoder(schema_proto))
def test_row_coder_cloud_object_schema(self):
  schema_proto = schema_pb2.Schema()
  schema_proto_json = json_format.MessageToJson(schema_proto).encode('utf-8')

  coder = RowCoder(schema_proto)

  cloud_object = coder.as_cloud_object()

  self.assertEqual(schema_proto_json, cloud_object['schema'])
def typing_to_runner_api(type_):
  if _match_is_named_tuple(type_):
    schema = None
    if hasattr(type_, _BEAM_SCHEMA_ID):
      schema = SCHEMA_REGISTRY.get_schema_by_id(
          getattr(type_, _BEAM_SCHEMA_ID))
    if schema is None:
      fields = [
          schema_pb2.Field(
              name=name, type=typing_to_runner_api(type_._field_types[name]))
          for name in type_._fields
      ]
      type_id = str(uuid4())
      schema = schema_pb2.Schema(fields=fields, id=type_id)
      setattr(type_, _BEAM_SCHEMA_ID, type_id)
      SCHEMA_REGISTRY.add(type_, schema)

    return schema_pb2.FieldType(row_type=schema_pb2.RowType(schema=schema))

  # All concrete types (other than NamedTuple sub-classes) should map to
  # a supported primitive type.
  elif type_ in PRIMITIVE_TO_ATOMIC_TYPE:
    return schema_pb2.FieldType(atomic_type=PRIMITIVE_TO_ATOMIC_TYPE[type_])

  elif sys.version_info.major == 2 and type_ == str:
    raise ValueError(
        "type 'str' is not supported in python 2. Please use 'unicode' or "
        "'typing.ByteString' instead to unambiguously indicate if this is a "
        "UTF-8 string or a byte array.")

  elif _match_is_exactly_mapping(type_):
    key_type, value_type = map(typing_to_runner_api, _get_args(type_))
    return schema_pb2.FieldType(
        map_type=schema_pb2.MapType(key_type=key_type, value_type=value_type))

  elif _match_is_optional(type_):
    # It's possible that a user passes us Optional[Optional[T]], but in python
    # typing this is indistinguishable from Optional[T] - both resolve to
    # Union[T, None] - so there's no need to check for that case here.
    result = typing_to_runner_api(extract_optional_type(type_))
    result.nullable = True
    return result

  elif _safe_issubclass(type_, Sequence):
    element_type = typing_to_runner_api(_get_args(type_)[0])
    return schema_pb2.FieldType(
        array_type=schema_pb2.ArrayType(element_type=element_type))

  elif _safe_issubclass(type_, Mapping):
    key_type, value_type = map(typing_to_runner_api, _get_args(type_))
    return schema_pb2.FieldType(
        map_type=schema_pb2.MapType(key_type=key_type, value_type=value_type))

  raise ValueError("Unsupported type: %s" % type_)
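# Illustrative sketch, not part of the original module; it assumes Python 3,
# where str maps to STRING and int maps to INT64, and the 'Pet' type is
# hypothetical. It shows the branches above in action: a NamedTuple becomes a
# row type, typing.Sequence an array type, and typing.Optional a nullable field.
def _typing_to_runner_api_sketch():
  import typing
  Pet = typing.NamedTuple(
      'Pet',
      [('name', str),
       ('nicknames', typing.List[str]),
       ('age', typing.Optional[int])])
  pet_type = typing_to_runner_api(Pet)
  fields = pet_type.row_type.schema.fields
  assert fields[0].type.atomic_type == schema_pb2.STRING
  assert fields[1].type.array_type.element_type.atomic_type == schema_pb2.STRING
  assert fields[2].type.nullable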
def test_encoding_position_add_fields_and_reorder(self):
  old_schema = schema_pb2.Schema(
      id="add_test_old",
      fields=[
          schema_pb2.Field(
              name="f_int32",
              type=schema_pb2.FieldType(atomic_type=schema_pb2.INT32),
          ),
          schema_pb2.Field(
              name="f_str",
              type=schema_pb2.FieldType(atomic_type=schema_pb2.STRING),
          ),
      ])

  new_schema = schema_pb2.Schema(
      encoding_positions_set=True,
      id="add_test_new",
      fields=[
          schema_pb2.Field(
              name="f_new_str",
              type=schema_pb2.FieldType(
                  atomic_type=schema_pb2.STRING, nullable=True),
              encoding_position=2,
          ),
          schema_pb2.Field(
              name="f_int32",
              type=schema_pb2.FieldType(atomic_type=schema_pb2.INT32),
              encoding_position=0,
          ),
          schema_pb2.Field(
              name="f_str",
              type=schema_pb2.FieldType(atomic_type=schema_pb2.STRING),
              encoding_position=1,
          ),
      ])

  Old = named_tuple_from_schema(old_schema)
  New = named_tuple_from_schema(new_schema)

  roundtripped = RowCoder(new_schema).decode(
      RowCoder(old_schema).encode(Old(42, "Hello World!")))

  self.assertEqual(
      New(f_new_str=None, f_int32=42, f_str="Hello World!"), roundtripped)
def test_generated_class_pickle(self):
  schema = schema_pb2.Schema(
      id="some-uuid",
      fields=[
          schema_pb2.Field(
              name='name',
              type=schema_pb2.FieldType(atomic_type=schema_pb2.STRING),
          )
      ])
  user_type = named_tuple_from_schema(schema)
  instance = user_type(name="test")

  self.assertEqual(instance, pickle.loads(pickle.dumps(instance)))
def test_schema_with_bad_field_raises_helpful_error(self):
  schema_proto = schema_pb2.Schema(
      fields=[
          schema_pb2.Field(
              name="type_with_no_typeinfo", type=schema_pb2.FieldType())
      ],
      id="helpful-error-uuid",
  )

  # Should raise an exception referencing the problem field
  self.assertRaisesRegex(
      ValueError,
      "type_with_no_typeinfo",
      lambda: named_tuple_from_schema(
          schema_proto,
          # bypass schema cache
          schema_registry=SchemaTypeRegistry()))
def test_trivial_example(self):
  MyCuteClass = NamedTuple(
      'MyCuteClass',
      [
          ('name', unicode),
          ('age', Optional[int]),
          ('interests', List[unicode]),
          ('height', float),
          ('blob', ByteString),
      ])

  expected = schema_pb2.FieldType(
      row_type=schema_pb2.RowType(
          schema=schema_pb2.Schema(
              fields=[
                  schema_pb2.Field(
                      name='name',
                      type=schema_pb2.FieldType(
                          atomic_type=schema_pb2.STRING),
                  ),
                  schema_pb2.Field(
                      name='age',
                      type=schema_pb2.FieldType(
                          nullable=True, atomic_type=schema_pb2.INT64)),
                  schema_pb2.Field(
                      name='interests',
                      type=schema_pb2.FieldType(
                          array_type=schema_pb2.ArrayType(
                              element_type=schema_pb2.FieldType(
                                  atomic_type=schema_pb2.STRING)))),
                  schema_pb2.Field(
                      name='height',
                      type=schema_pb2.FieldType(
                          atomic_type=schema_pb2.DOUBLE)),
                  schema_pb2.Field(
                      name='blob',
                      type=schema_pb2.FieldType(
                          atomic_type=schema_pb2.BYTES)),
              ])))

  # Only test that the fields are equal. If we attempt to test the entire type
  # or the entire schema, the generated id will break equality.
  self.assertEqual(
      expected.row_type.schema.fields,
      typing_to_runner_api(MyCuteClass).row_type.schema.fields)
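# Illustrative sketch, not part of the original test file: typing_from_runner_api
# is the inverse mapping and is exercised more thoroughly in
# test_proto_survives_typing_roundtrip below. The 'PointSketch' type is
# hypothetical; the check compares field names of the reconstructed type.
def test_named_tuple_roundtrip_sketch(self):
  PointSketch = NamedTuple('PointSketch', [('x', float), ('y', float)])
  reconstructed = typing_from_runner_api(typing_to_runner_api(PointSketch))
  self.assertEqual(('x', 'y'), reconstructed._fields)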
def test_proto_survives_typing_roundtrip(self):
  all_nonoptional_primitives = [
      schema_pb2.FieldType(atomic_type=typ)
      for typ in schema_pb2.AtomicType.values()
      if typ is not schema_pb2.UNSPECIFIED
  ]

  # The bytes type cannot survive a roundtrip to/from proto in Python 2.
  # In order to use BYTES a user type has to use typing.ByteString (because
  # bytes == str, and we map str to STRING).
  if not IS_PYTHON_3:
    all_nonoptional_primitives.remove(
        schema_pb2.FieldType(atomic_type=schema_pb2.BYTES))

  all_optional_primitives = [
      schema_pb2.FieldType(nullable=True, atomic_type=typ)
      for typ in schema_pb2.AtomicType.values()
      if typ is not schema_pb2.UNSPECIFIED
  ]

  all_primitives = all_nonoptional_primitives + all_optional_primitives

  basic_array_types = [
      schema_pb2.FieldType(array_type=schema_pb2.ArrayType(element_type=typ))
      for typ in all_primitives
  ]

  basic_map_types = [
      schema_pb2.FieldType(
          map_type=schema_pb2.MapType(
              key_type=key_type, value_type=value_type))
      for key_type, value_type in itertools.product(
          all_primitives, all_primitives)
  ]

  selected_schemas = [
      schema_pb2.FieldType(
          row_type=schema_pb2.RowType(
              schema=schema_pb2.Schema(
                  id='32497414-85e8-46b7-9c90-9a9cc62fe390',
                  fields=[
                      schema_pb2.Field(name='field%d' % i, type=typ)
                      for i, typ in enumerate(all_primitives)
                  ]))),
      schema_pb2.FieldType(
          row_type=schema_pb2.RowType(
              schema=schema_pb2.Schema(
                  id='dead1637-3204-4bcb-acf8-99675f338600',
                  fields=[
                      schema_pb2.Field(
                          name='id',
                          type=schema_pb2.FieldType(
                              atomic_type=schema_pb2.INT64)),
                      schema_pb2.Field(
                          name='name',
                          type=schema_pb2.FieldType(
                              atomic_type=schema_pb2.STRING)),
                      schema_pb2.Field(
                          name='optional_map',
                          type=schema_pb2.FieldType(
                              nullable=True,
                              map_type=schema_pb2.MapType(
                                  key_type=schema_pb2.FieldType(
                                      atomic_type=schema_pb2.STRING),
                                  value_type=schema_pb2.FieldType(
                                      atomic_type=schema_pb2.DOUBLE)))),
                      schema_pb2.Field(
                          name='optional_array',
                          type=schema_pb2.FieldType(
                              nullable=True,
                              array_type=schema_pb2.ArrayType(
                                  element_type=schema_pb2.FieldType(
                                      atomic_type=schema_pb2.FLOAT)))),
                      schema_pb2.Field(
                          name='array_optional',
                          type=schema_pb2.FieldType(
                              array_type=schema_pb2.ArrayType(
                                  element_type=schema_pb2.FieldType(
                                      nullable=True,
                                      atomic_type=schema_pb2.BYTES)))),
                  ]))),
  ]

  test_cases = all_primitives + \
      basic_array_types + \
      basic_map_types + \
      selected_schemas

  for test_case in test_cases:
    self.assertEqual(
        test_case, typing_to_runner_api(typing_from_runner_api(test_case)))
def test_proto_survives_typing_roundtrip(self):
  all_nonoptional_primitives = [
      schema_pb2.FieldType(atomic_type=typ)
      for typ in schema_pb2.AtomicType.values()
      if typ is not schema_pb2.UNSPECIFIED
  ]

  all_optional_primitives = [
      schema_pb2.FieldType(nullable=True, atomic_type=typ)
      for typ in schema_pb2.AtomicType.values()
      if typ is not schema_pb2.UNSPECIFIED
  ]

  all_primitives = all_nonoptional_primitives + all_optional_primitives

  basic_array_types = [
      schema_pb2.FieldType(array_type=schema_pb2.ArrayType(element_type=typ))
      for typ in all_primitives
  ]

  basic_map_types = [
      schema_pb2.FieldType(
          map_type=schema_pb2.MapType(
              key_type=key_type, value_type=value_type))
      for key_type, value_type in itertools.product(
          all_primitives, all_primitives)
  ]

  selected_schemas = [
      schema_pb2.FieldType(
          row_type=schema_pb2.RowType(
              schema=schema_pb2.Schema(
                  id='32497414-85e8-46b7-9c90-9a9cc62fe390',
                  fields=[
                      schema_pb2.Field(name='field%d' % i, type=typ)
                      for i, typ in enumerate(all_primitives)
                  ]))),
      schema_pb2.FieldType(
          row_type=schema_pb2.RowType(
              schema=schema_pb2.Schema(
                  id='dead1637-3204-4bcb-acf8-99675f338600',
                  fields=[
                      schema_pb2.Field(
                          name='id',
                          type=schema_pb2.FieldType(
                              atomic_type=schema_pb2.INT64)),
                      schema_pb2.Field(
                          name='name',
                          type=schema_pb2.FieldType(
                              atomic_type=schema_pb2.STRING)),
                      schema_pb2.Field(
                          name='optional_map',
                          type=schema_pb2.FieldType(
                              nullable=True,
                              map_type=schema_pb2.MapType(
                                  key_type=schema_pb2.FieldType(
                                      atomic_type=schema_pb2.STRING),
                                  value_type=schema_pb2.FieldType(
                                      atomic_type=schema_pb2.DOUBLE)))),
                      schema_pb2.Field(
                          name='optional_array',
                          type=schema_pb2.FieldType(
                              nullable=True,
                              array_type=schema_pb2.ArrayType(
                                  element_type=schema_pb2.FieldType(
                                      atomic_type=schema_pb2.FLOAT)))),
                      schema_pb2.Field(
                          name='array_optional',
                          type=schema_pb2.FieldType(
                              array_type=schema_pb2.ArrayType(
                                  element_type=schema_pb2.FieldType(
                                      nullable=True,
                                      atomic_type=schema_pb2.BYTES)))),
                  ]))),
  ]

  test_cases = all_primitives + \
      basic_array_types + \
      basic_map_types + \
      selected_schemas

  for test_case in test_cases:
    self.assertEqual(
        test_case,
        typing_to_runner_api(
            typing_from_runner_api(
                test_case, schema_registry=SchemaTypeRegistry()),
            schema_registry=SchemaTypeRegistry()))