def test_create_row_coder_from_schema(self):
  """Build a RowCoder from a hand-written schema proto and round-trip PEOPLE.

  The "person" schema is assembled directly from ``schema_pb2`` messages.
  Each entry of ``self.PEOPLE`` is encoded and then decoded with the
  resulting coder, and the decoded value must equal the original.
  """
  def atomic(kind, **options):
    # Fresh FieldType on every call; options carries e.g. nullable=True.
    return schema_pb2.FieldType(atomic_type=kind, **options)

  person_fields = [
      schema_pb2.Field(name="name", type=atomic(schema_pb2.STRING)),
      schema_pb2.Field(name="age", type=atomic(schema_pb2.INT32)),
      schema_pb2.Field(
          name="address", type=atomic(schema_pb2.STRING, nullable=True)),
      schema_pb2.Field(
          name="aliases",
          type=schema_pb2.FieldType(
              array_type=schema_pb2.ArrayType(
                  element_type=atomic(schema_pb2.STRING)))),
      schema_pb2.Field(
          name="knows_javascript", type=atomic(schema_pb2.BOOLEAN)),
      schema_pb2.Field(
          name="payload", type=atomic(schema_pb2.BYTES, nullable=True)),
  ]
  person_schema = schema_pb2.Schema(id="person", fields=person_fields)
  coder = RowCoder(person_schema)

  for person in self.PEOPLE:
    roundtripped = coder.decode(coder.encode(person))
    self.assertEqual(person, roundtripped)
def test_create_row_coder_from_schema(self):
  """Build a RowCoder from a hand-written schema proto and round-trip PEOPLE.

  The "person" schema covers atomic, nullable, array, map and logical-type
  fields. Each entry of ``self.PEOPLE`` is encoded and then decoded with the
  resulting coder, and the decoded value must equal the original.
  """
  def atomic(kind, **options):
    # Fresh FieldType on every call; options carries e.g. nullable=True.
    return schema_pb2.FieldType(atomic_type=kind, **options)

  # Array of (non-nullable) strings.
  aliases_type = schema_pb2.FieldType(
      array_type=schema_pb2.ArrayType(
          element_type=atomic(schema_pb2.STRING)))

  # Map from string keys to INT64 values.
  custom_metadata_type = schema_pb2.FieldType(
      map_type=schema_pb2.MapType(
          key_type=atomic(schema_pb2.STRING),
          value_type=atomic(schema_pb2.INT64),
      ))

  # The micros_instant logical type is represented as a row holding two
  # INT64 fields: "seconds" and "micros".
  micros_instant_schema = schema_pb2.Schema(
      id="micros_instant",
      fields=[
          schema_pb2.Field(name="seconds", type=atomic(schema_pb2.INT64)),
          schema_pb2.Field(name="micros", type=atomic(schema_pb2.INT64)),
      ])
  favorite_time_type = schema_pb2.FieldType(
      logical_type=schema_pb2.LogicalType(
          urn="beam:logical_type:micros_instant:v1",
          representation=schema_pb2.FieldType(
              row_type=schema_pb2.RowType(schema=micros_instant_schema))))

  person_fields = [
      schema_pb2.Field(name="name", type=atomic(schema_pb2.STRING)),
      schema_pb2.Field(name="age", type=atomic(schema_pb2.INT32)),
      schema_pb2.Field(
          name="address", type=atomic(schema_pb2.STRING, nullable=True)),
      schema_pb2.Field(name="aliases", type=aliases_type),
      schema_pb2.Field(
          name="knows_javascript", type=atomic(schema_pb2.BOOLEAN)),
      schema_pb2.Field(
          name="payload", type=atomic(schema_pb2.BYTES, nullable=True)),
      schema_pb2.Field(name="custom_metadata", type=custom_metadata_type),
      schema_pb2.Field(name="favorite_time", type=favorite_time_type),
  ]
  person_schema = schema_pb2.Schema(id="person", fields=person_fields)
  coder = RowCoder(person_schema)

  for person in self.PEOPLE:
    roundtripped = coder.decode(coder.encode(person))
    self.assertEqual(person, roundtripped)
def typing_to_runner_api(type_):
  """Convert a Python type hint into a ``schema_pb2.FieldType`` proto.

  Conversion rules, tried in order:

  * NamedTuple classes become a row type. The schema is looked up in
    ``SCHEMA_REGISTRY`` via the id stored on the class; when none is found a
    new schema with a random UUID id is created, the id is stored on the
    class and the schema is added to the registry.
  * Types present in ``PRIMITIVE_TO_ATOMIC_TYPE`` become atomic types.
  * Mapping types become map types, Optional[T] becomes the conversion of T
    with ``nullable`` set, and Sequence types become array types.
  * Anything else is resolved through ``LogicalType.from_typing``; if that
    raises ``ValueError`` the type is encoded as the ``PYTHON_ANY_URN``
    logical type.

  Args:
    type_: the type hint to convert.

  Returns:
    A ``schema_pb2.FieldType`` describing ``type_``.
  """
  if match_is_named_tuple(type_):
    schema = None
    if hasattr(type_, _BEAM_SCHEMA_ID):
      schema = SCHEMA_REGISTRY.get_schema_by_id(getattr(type_, _BEAM_SCHEMA_ID))
    if schema is None:
      # typing.NamedTuple dropped the private `_field_types` attribute in
      # Python 3.9 in favour of `__annotations__`; prefer the latter and only
      # fall back to `_field_types` when annotations are unavailable.
      field_types = getattr(type_, '__annotations__', None)
      if field_types is None:
        field_types = type_._field_types
      fields = [
          schema_pb2.Field(
              name=name, type=typing_to_runner_api(field_types[name]))
          for name in type_._fields
      ]
      type_id = str(uuid4())
      schema = schema_pb2.Schema(fields=fields, id=type_id)
      # Remember the generated id on the class so later conversions of the
      # same class can find the registered schema.
      setattr(type_, _BEAM_SCHEMA_ID, type_id)
      SCHEMA_REGISTRY.add(type_, schema)

    return schema_pb2.FieldType(row_type=schema_pb2.RowType(schema=schema))

  # All concrete types (other than NamedTuple sub-classes) should map to
  # a supported primitive type.
  elif type_ in PRIMITIVE_TO_ATOMIC_TYPE:
    return schema_pb2.FieldType(atomic_type=PRIMITIVE_TO_ATOMIC_TYPE[type_])

  elif _match_is_exactly_mapping(type_):
    key_type, value_type = map(typing_to_runner_api, _get_args(type_))
    return schema_pb2.FieldType(
        map_type=schema_pb2.MapType(key_type=key_type, value_type=value_type))

  elif _match_is_optional(type_):
    # It's possible that a user passes us Optional[Optional[T]], but in python
    # typing this is indistinguishable from Optional[T] - both resolve to
    # Union[T, None] - so there's no need to check for that case here.
    result = typing_to_runner_api(extract_optional_type(type_))
    result.nullable = True
    return result

  elif _safe_issubclass(type_, Sequence):
    element_type = typing_to_runner_api(_get_args(type_)[0])
    return schema_pb2.FieldType(
        array_type=schema_pb2.ArrayType(element_type=element_type))

  elif _safe_issubclass(type_, Mapping):
    key_type, value_type = map(typing_to_runner_api, _get_args(type_))
    return schema_pb2.FieldType(
        map_type=schema_pb2.MapType(key_type=key_type, value_type=value_type))

  try:
    logical_type = LogicalType.from_typing(type_)
  except ValueError:
    # Unknown type, just treat it like Any
    return schema_pb2.FieldType(
        logical_type=schema_pb2.LogicalType(urn=PYTHON_ANY_URN))
  else:
    # TODO(bhulette): Add support for logical types that require arguments
    return schema_pb2.FieldType(
        logical_type=schema_pb2.LogicalType(
            urn=logical_type.urn(),
            representation=typing_to_runner_api(
                logical_type.representation_type())))
def typing_to_runner_api(type_):
  """Convert a Python type hint into a ``schema_pb2.FieldType`` proto.

  Conversion rules, tried in order:

  * NamedTuple classes become a row type. The schema is looked up in
    ``SCHEMA_REGISTRY`` via the id stored on the class; when none is found a
    new schema with a random UUID id is created, the id is stored on the
    class and the schema is added to the registry.
  * Types present in ``PRIMITIVE_TO_ATOMIC_TYPE`` become atomic types.
  * Mapping types become map types, Optional[T] becomes the conversion of T
    with ``nullable`` set, and Sequence types become array types.

  Args:
    type_: the type hint to convert.

  Returns:
    A ``schema_pb2.FieldType`` describing ``type_``.

  Raises:
    ValueError: if ``type_`` is ``str`` under Python 2, or if none of the
      rules above applies.
  """
  if _match_is_named_tuple(type_):
    schema = None
    if hasattr(type_, _BEAM_SCHEMA_ID):
      schema = SCHEMA_REGISTRY.get_schema_by_id(
          getattr(type_, _BEAM_SCHEMA_ID))
    if schema is None:
      # typing.NamedTuple dropped the private `_field_types` attribute in
      # Python 3.9 in favour of `__annotations__`; prefer the latter and only
      # fall back to `_field_types` when annotations are unavailable.
      field_types = getattr(type_, '__annotations__', None)
      if field_types is None:
        field_types = type_._field_types
      fields = [
          schema_pb2.Field(
              name=name, type=typing_to_runner_api(field_types[name]))
          for name in type_._fields
      ]
      type_id = str(uuid4())
      schema = schema_pb2.Schema(fields=fields, id=type_id)
      # Remember the generated id on the class so later conversions of the
      # same class can find the registered schema.
      setattr(type_, _BEAM_SCHEMA_ID, type_id)
      SCHEMA_REGISTRY.add(type_, schema)

    return schema_pb2.FieldType(row_type=schema_pb2.RowType(schema=schema))

  # All concrete types (other than NamedTuple sub-classes) should map to
  # a supported primitive type.
  elif type_ in PRIMITIVE_TO_ATOMIC_TYPE:
    return schema_pb2.FieldType(
        atomic_type=PRIMITIVE_TO_ATOMIC_TYPE[type_])

  elif sys.version_info.major == 2 and type_ == str:
    raise ValueError(
        "type 'str' is not supported in python 2. Please use 'unicode' or "
        "'typing.ByteString' instead to unambiguously indicate if this is a "
        "UTF-8 string or a byte array.")

  elif _match_is_exactly_mapping(type_):
    key_type, value_type = map(typing_to_runner_api, _get_args(type_))
    return schema_pb2.FieldType(
        map_type=schema_pb2.MapType(key_type=key_type, value_type=value_type))

  elif _match_is_optional(type_):
    # It's possible that a user passes us Optional[Optional[T]], but in python
    # typing this is indistinguishable from Optional[T] - both resolve to
    # Union[T, None] - so there's no need to check for that case here.
    result = typing_to_runner_api(extract_optional_type(type_))
    result.nullable = True
    return result

  elif _safe_issubclass(type_, Sequence):
    element_type = typing_to_runner_api(_get_args(type_)[0])
    return schema_pb2.FieldType(
        array_type=schema_pb2.ArrayType(element_type=element_type))

  elif _safe_issubclass(type_, Mapping):
    key_type, value_type = map(typing_to_runner_api, _get_args(type_))
    return schema_pb2.FieldType(
        map_type=schema_pb2.MapType(key_type=key_type, value_type=value_type))

  raise ValueError("Unsupported type: %s" % type_)
def test_trivial_example(self):
  """A flat NamedTuple converts to a row type with the expected fields."""
  MyCuteClass = NamedTuple(
      'MyCuteClass',
      [
          ('name', unicode),
          ('age', Optional[int]),
          ('interests', List[unicode]),
          ('height', float),
          ('blob', ByteString),
      ])

  def atomic(kind, **options):
    # Fresh FieldType on every call; options carries e.g. nullable=True.
    return schema_pb2.FieldType(atomic_type=kind, **options)

  expected_fields = [
      schema_pb2.Field(name='name', type=atomic(schema_pb2.STRING)),
      schema_pb2.Field(
          name='age', type=atomic(schema_pb2.INT64, nullable=True)),
      schema_pb2.Field(
          name='interests',
          type=schema_pb2.FieldType(
              array_type=schema_pb2.ArrayType(
                  element_type=atomic(schema_pb2.STRING)))),
      schema_pb2.Field(name='height', type=atomic(schema_pb2.DOUBLE)),
      schema_pb2.Field(name='blob', type=atomic(schema_pb2.BYTES)),
  ]
  expected = schema_pb2.FieldType(
      row_type=schema_pb2.RowType(
          schema=schema_pb2.Schema(fields=expected_fields)))

  actual = typing_to_runner_api(MyCuteClass)

  # Compare the field lists only: the schema id is generated during
  # conversion, so comparing whole types or whole schemas would never match.
  self.assertEqual(
      expected.row_type.schema.fields, actual.row_type.schema.fields)
def test_proto_survives_typing_roundtrip(self):
  """FieldType protos must survive proto -> typing -> proto unchanged.

  Covers every atomic type (nullable and non-nullable), arrays of each
  primitive, maps over every primitive key/value pair, and two hand-picked
  row schemas.
  """
  # Enum values are plain integers, so compare by value (`!=`) rather than
  # by identity; identity only happens to work through int caching.
  atomic_types = [
      typ for typ in schema_pb2.AtomicType.values()
      if typ != schema_pb2.UNSPECIFIED
  ]

  all_nonoptional_primitives = [
      schema_pb2.FieldType(atomic_type=typ) for typ in atomic_types
  ]

  # The bytes type cannot survive a roundtrip to/from proto in Python 2.
  # In order to use BYTES a user type has to use typing.ByteString (because
  # bytes == str, and we map str to STRING).
  if not IS_PYTHON_3:
    all_nonoptional_primitives.remove(
        schema_pb2.FieldType(atomic_type=schema_pb2.BYTES))

  all_optional_primitives = [
      schema_pb2.FieldType(nullable=True, atomic_type=typ)
      for typ in atomic_types
  ]

  all_primitives = all_nonoptional_primitives + all_optional_primitives

  basic_array_types = [
      schema_pb2.FieldType(
          array_type=schema_pb2.ArrayType(element_type=typ))
      for typ in all_primitives
  ]

  basic_map_types = [
      schema_pb2.FieldType(
          map_type=schema_pb2.MapType(
              key_type=key_type, value_type=value_type))
      for key_type, value_type in itertools.product(
          all_primitives, all_primitives)
  ]

  selected_schemas = [
      # One field per primitive type.
      schema_pb2.FieldType(
          row_type=schema_pb2.RowType(
              schema=schema_pb2.Schema(
                  id='32497414-85e8-46b7-9c90-9a9cc62fe390',
                  fields=[
                      schema_pb2.Field(name='field%d' % i, type=typ)
                      for i, typ in enumerate(all_primitives)
                  ]))),
      # Mix of atomic fields, nullable containers and a container of
      # nullable elements.
      schema_pb2.FieldType(
          row_type=schema_pb2.RowType(
              schema=schema_pb2.Schema(
                  id='dead1637-3204-4bcb-acf8-99675f338600',
                  fields=[
                      schema_pb2.Field(
                          name='id',
                          type=schema_pb2.FieldType(
                              atomic_type=schema_pb2.INT64)),
                      schema_pb2.Field(
                          name='name',
                          type=schema_pb2.FieldType(
                              atomic_type=schema_pb2.STRING)),
                      schema_pb2.Field(
                          name='optional_map',
                          type=schema_pb2.FieldType(
                              nullable=True,
                              map_type=schema_pb2.MapType(
                                  key_type=schema_pb2.FieldType(
                                      atomic_type=schema_pb2.STRING),
                                  value_type=schema_pb2.FieldType(
                                      atomic_type=schema_pb2.DOUBLE)))),
                      schema_pb2.Field(
                          name='optional_array',
                          type=schema_pb2.FieldType(
                              nullable=True,
                              array_type=schema_pb2.ArrayType(
                                  element_type=schema_pb2.FieldType(
                                      atomic_type=schema_pb2.FLOAT)))),
                      schema_pb2.Field(
                          name='array_optional',
                          type=schema_pb2.FieldType(
                              array_type=schema_pb2.ArrayType(
                                  element_type=schema_pb2.FieldType(
                                      nullable=True,
                                      atomic_type=schema_pb2.BYTES)))),
                  ]))),
  ]

  test_cases = (
      all_primitives + basic_array_types + basic_map_types +
      selected_schemas)

  for test_case in test_cases:
    self.assertEqual(
        test_case, typing_to_runner_api(typing_from_runner_api(test_case)))
def test_proto_survives_typing_roundtrip(self):
  """FieldType protos must survive proto -> typing -> proto unchanged.

  Covers every atomic type (nullable and non-nullable), arrays of each
  primitive, maps over every primitive key/value pair, and two hand-picked
  row schemas. Each conversion direction is given its own newly constructed
  ``SchemaTypeRegistry``.
  """
  # Enum values are plain integers, so compare by value (`!=`) rather than
  # by identity; identity only happens to work through int caching.
  atomic_types = [
      typ for typ in schema_pb2.AtomicType.values()
      if typ != schema_pb2.UNSPECIFIED
  ]

  all_nonoptional_primitives = [
      schema_pb2.FieldType(atomic_type=typ) for typ in atomic_types
  ]

  all_optional_primitives = [
      schema_pb2.FieldType(nullable=True, atomic_type=typ)
      for typ in atomic_types
  ]

  all_primitives = all_nonoptional_primitives + all_optional_primitives

  basic_array_types = [
      schema_pb2.FieldType(
          array_type=schema_pb2.ArrayType(element_type=typ))
      for typ in all_primitives
  ]

  basic_map_types = [
      schema_pb2.FieldType(
          map_type=schema_pb2.MapType(
              key_type=key_type, value_type=value_type))
      for key_type, value_type in itertools.product(
          all_primitives, all_primitives)
  ]

  selected_schemas = [
      # One field per primitive type.
      schema_pb2.FieldType(
          row_type=schema_pb2.RowType(
              schema=schema_pb2.Schema(
                  id='32497414-85e8-46b7-9c90-9a9cc62fe390',
                  fields=[
                      schema_pb2.Field(name='field%d' % i, type=typ)
                      for i, typ in enumerate(all_primitives)
                  ]))),
      # Mix of atomic fields, nullable containers and a container of
      # nullable elements.
      schema_pb2.FieldType(
          row_type=schema_pb2.RowType(
              schema=schema_pb2.Schema(
                  id='dead1637-3204-4bcb-acf8-99675f338600',
                  fields=[
                      schema_pb2.Field(
                          name='id',
                          type=schema_pb2.FieldType(
                              atomic_type=schema_pb2.INT64)),
                      schema_pb2.Field(
                          name='name',
                          type=schema_pb2.FieldType(
                              atomic_type=schema_pb2.STRING)),
                      schema_pb2.Field(
                          name='optional_map',
                          type=schema_pb2.FieldType(
                              nullable=True,
                              map_type=schema_pb2.MapType(
                                  key_type=schema_pb2.FieldType(
                                      atomic_type=schema_pb2.STRING),
                                  value_type=schema_pb2.FieldType(
                                      atomic_type=schema_pb2.DOUBLE)))),
                      schema_pb2.Field(
                          name='optional_array',
                          type=schema_pb2.FieldType(
                              nullable=True,
                              array_type=schema_pb2.ArrayType(
                                  element_type=schema_pb2.FieldType(
                                      atomic_type=schema_pb2.FLOAT)))),
                      schema_pb2.Field(
                          name='array_optional',
                          type=schema_pb2.FieldType(
                              array_type=schema_pb2.ArrayType(
                                  element_type=schema_pb2.FieldType(
                                      nullable=True,
                                      atomic_type=schema_pb2.BYTES)))),
                  ]))),
  ]

  test_cases = (
      all_primitives + basic_array_types + basic_map_types +
      selected_schemas)

  for test_case in test_cases:
    # NOTE(review): separate registries per direction presumably keep the
    # round trip from being satisfied by a cached schema - confirm.
    self.assertEqual(
        test_case,
        typing_to_runner_api(
            typing_from_runner_api(
                test_case, schema_registry=SchemaTypeRegistry()),
            schema_registry=SchemaTypeRegistry()))