def test_typing_survives_proto_roundtrip(self):
    # Verify that every supported Python typing maps to a schema proto and
    # back to the identical typing, threading a fresh SchemaTypeRegistry
    # through both directions so no cached state can mask a failure.
    all_nonoptional_primitives = [
        np.int8,
        np.int16,
        np.int32,
        np.int64,
        np.float32,
        np.float64,
        bool,
        bytes,
        str,
    ]

    all_optional_primitives = [
        Optional[typ] for typ in all_nonoptional_primitives
    ]

    all_primitives = all_nonoptional_primitives + all_optional_primitives

    basic_array_types = [Sequence[typ] for typ in all_primitives]

    # Every (key, value) combination of primitives.
    basic_map_types = [
        Mapping[key_type, value_type] for key_type,
        value_type in itertools.product(all_primitives, all_primitives)
    ]

    selected_schemas = [
        NamedTuple(
            'AllPrimitives',
            [('field%d' % i, typ) for i, typ in enumerate(all_primitives)]),
        NamedTuple(
            'ComplexSchema',
            [
                ('id', np.int64),
                ('name', str),
                ('optional_map', Optional[Mapping[str,
                                                  Optional[np.float64]]]),
                ('optional_array', Optional[Sequence[np.float32]]),
                ('array_optional', Sequence[Optional[bool]]),
                ('timestamp', Timestamp),
            ])
    ]

    test_cases = all_primitives + \
        basic_array_types + \
        basic_map_types

    for test_case in test_cases:
        self.assertEqual(
            test_case,
            typing_from_runner_api(
                typing_to_runner_api(
                    test_case, schema_registry=SchemaTypeRegistry()),
                schema_registry=SchemaTypeRegistry()))

    # Break out NamedTuple types since they require special verification
    for test_case in selected_schemas:
        self.assert_namedtuple_equivalent(
            test_case,
            typing_from_runner_api(
                typing_to_runner_api(
                    test_case, schema_registry=SchemaTypeRegistry()),
                schema_registry=SchemaTypeRegistry()))
def test_python_callable_maps_to_logical_type(self):
    """PythonCallableWithSource round-trips as a logical type over STRING."""
    from apache_beam.utils.python_callable import PythonCallableWithSource

    # The expected proto: a logical type identified by the python_callable
    # urn, represented on the wire as a plain string field type.
    expected_field_type = schema_pb2.FieldType(
        logical_type=schema_pb2.LogicalType(
            urn=common_urns.python_callable.urn,
            representation=typing_to_runner_api(str)))

    # Python type -> proto.
    self.assertEqual(
        expected_field_type, typing_to_runner_api(PythonCallableWithSource))
    # Proto -> Python type.
    self.assertEqual(
        typing_from_runner_api(expected_field_type),
        PythonCallableWithSource)
def test_row_coder_nested_struct(self):
    """A row whose fields are themselves rows must survive encode/decode."""
    Pair = typing.NamedTuple('Pair', [('left', Person), ('right', Person)])
    pair = Pair(self.PEOPLE[0], self.PEOPLE[1])
    pair_coder = RowCoder(typing_to_runner_api(Pair).row_type.schema)
    encoded = pair_coder.encode(pair)
    self.assertEqual(pair, pair_coder.decode(encoded))
def __init__(
    self,
    table_name,
    driver_class_name,
    jdbc_url,
    username,
    password,
    statement=None,
    connection_properties=None,
    connection_init_sqls=None,
    expansion_service=None,
    classpath=None,
):
    """
    Initializes a write operation to Jdbc.

    :param table_name: name of the table to write to
    :param driver_class_name: name of the jdbc driver class
    :param jdbc_url: full jdbc url to the database.
    :param username: database username
    :param password: database password
    :param statement: sql statement to be executed
    :param connection_properties: properties of the jdbc connection
                                  passed as string with format
                                  [propertyName=property;]*
    :param connection_init_sqls: required only for MySql and MariaDB.
                                 passed as list of strings
    :param expansion_service: The address (host:port) of the ExpansionService.
    :param classpath: A list of JARs or Java packages to include in the
                      classpath for the expansion service. This option is
                      usually needed for `jdbc` to include extra JDBC driver
                      packages. The packages can be in these three formats:
                      (1) A local file, (2) A URL, (3) A gradle-style
                      identifier of a Maven package (e.g.
                      "org.postgresql:postgresql:42.3.1"). By default, this
                      argument includes a Postgres SQL JDBC driver.
    """
    # Fall back to the bundled Postgres driver when the caller supplies none.
    classpath = classpath or DEFAULT_JDBC_CLASSPATH

    # Pack all connection/statement settings into the shared Config row;
    # read-only settings are explicitly None for a write.
    config = Config(
        driver_class_name=driver_class_name,
        jdbc_url=jdbc_url,
        username=username,
        password=password,
        connection_properties=connection_properties,
        connection_init_sqls=connection_init_sqls,
        write_statement=statement,
        read_query=None,
        fetch_size=None,
        output_parallelization=None,
    )
    config_coder = RowCoder(typing_to_runner_api(Config).row_type.schema)
    payload = JdbcConfigSchema(
        location=table_name, config=config_coder.encode(config))

    super().__init__(
        self.URN,
        NamedTupleBasedPayloadBuilder(payload),
        expansion_service or default_io_expansion_service(classpath),
    )
def test_typing_survives_proto_roundtrip(self): all_nonoptional_primitives = [ np.int8, np.int16, np.int32, np.int64, np.float32, np.float64, unicode, bool, ] # The bytes type cannot survive a roundtrip to/from proto in Python 2. # In order to use BYTES a user type has to use typing.ByteString (because # bytes == str, and we map str to STRING). # TODO(BEAM-7372) if IS_PYTHON_3: all_nonoptional_primitives.extend([bytes]) all_optional_primitives = [ Optional[typ] for typ in all_nonoptional_primitives ] all_primitives = all_nonoptional_primitives + all_optional_primitives basic_array_types = [Sequence[typ] for typ in all_primitives] basic_map_types = [ Mapping[key_type, value_type] for key_type, value_type in itertools.product(all_primitives, all_primitives) ] selected_schemas = [ NamedTuple( 'AllPrimitives', [('field%d' % i, typ) for i, typ in enumerate(all_primitives)]), NamedTuple( 'ComplexSchema', [ ('id', np.int64), ('name', unicode), ( 'optional_map', Optional[Mapping[unicode, Optional[np.float64]]]), ('optional_array', Optional[Sequence[np.float32]]), ('array_optional', Sequence[Optional[bool]]), ('timestamp', Timestamp), ]) ] test_cases = all_primitives + \ basic_array_types + \ basic_map_types + \ selected_schemas for test_case in test_cases: self.assertEqual( test_case, typing_from_runner_api(typing_to_runner_api(test_case)))
def test_create_row_coder_from_named_tuple(self):
    """The coder registry must yield a coder equivalent to a RowCoder built
    directly from the Person schema."""
    expected_coder = RowCoder(typing_to_runner_api(Person).row_type.schema)
    real_coder = coders_registry.get_coder(Person)

    for person in self.PEOPLE:
        # Identical wire format...
        self.assertEqual(
            expected_coder.encode(person), real_coder.encode(person))
        # ...and a lossless round trip.
        self.assertEqual(
            person, real_coder.decode(real_coder.encode(person)))
def test_row_accepts_trailing_zeros_truncated(self):
    """Decoding must accept a payload whose header was rebuilt by hand."""
    coder = RowCoder(
        typing_to_runner_api(NullablePerson).row_type.schema)
    person = NullablePerson(
        None,
        np.int32(25),
        "Westeros", ["Mother of Dragons"],
        False,
        None, {"dragons": 3},
        None,
        "NotNull")
    encoded = coder.encode(person)
    # 9 fields, 1 null byte, field 0, 5, 7 are null; splice this hand-built
    # header onto the original encoding with its first 4 bytes stripped.
    rebuilt = bytes([9, 1, 1 | 1 << 5 | 1 << 7]) + encoded[4:]
    self.assertEqual(person, coder.decode(rebuilt))
def test_typing_survives_proto_roundtrip(self):
    # Verify that each supported Python typing maps to a schema proto and
    # back to the identical typing.
    all_nonoptional_primitives = [
        np.int8,
        np.int16,
        np.int32,
        np.int64,
        np.float32,
        np.float64,
        unicode,
        bool,
        bytes,
        str,
    ]

    all_optional_primitives = [
        Optional[typ] for typ in all_nonoptional_primitives
    ]

    all_primitives = all_nonoptional_primitives + all_optional_primitives

    basic_array_types = [Sequence[typ] for typ in all_primitives]

    # Every (key, value) combination of primitives.
    basic_map_types = [
        Mapping[key_type, value_type] for key_type,
        value_type in itertools.product(all_primitives, all_primitives)
    ]

    selected_schemas = [
        NamedTuple(
            'AllPrimitives',
            [('field%d' % i, typ) for i, typ in enumerate(all_primitives)]),
        NamedTuple(
            'ComplexSchema',
            [
                ('id', np.int64),
                ('name', unicode),
                (
                    'optional_map',
                    Optional[Mapping[unicode, Optional[np.float64]]]),
                ('optional_array', Optional[Sequence[np.float32]]),
                ('array_optional', Sequence[Optional[bool]]),
                ('timestamp', Timestamp),
            ])
    ]

    test_cases = all_primitives + \
        basic_array_types + \
        basic_map_types + \
        selected_schemas

    for test_case in test_cases:
        self.assertEqual(
            test_case,
            typing_from_runner_api(typing_to_runner_api(test_case)))
def __init__(
    self,
    table_name,
    driver_class_name,
    jdbc_url,
    username,
    password,
    query=None,
    output_parallelization=None,
    fetch_size=None,
    connection_properties=None,
    connection_init_sqls=None,
    expansion_service=None,
    classpath=None,
):
    """
    Initializes a read operation from Jdbc.

    :param table_name: name of the table to read from
    :param driver_class_name: name of the jdbc driver class
    :param jdbc_url: full jdbc url to the database.
    :param username: database username
    :param password: database password
    :param query: sql query to be executed
    :param output_parallelization: is output parallelization on
    :param fetch_size: how many rows to fetch
    :param connection_properties: properties of the jdbc connection
                                  passed as string with format
                                  [propertyName=property;]*
    :param connection_init_sqls: required only for MySql and MariaDB.
                                 passed as list of strings
    :param expansion_service: The address (host:port) of the ExpansionService.
    :param classpath: A list of JARs or Java packages to include in the
                      classpath for the expansion service. This option is
                      usually needed for `jdbc` to include extra JDBC driver
                      packages. The packages can be in these three formats:
                      (1) A local file, (2) A URL, (3) A gradle-style
                      identifier of a Maven package (e.g.
                      "org.postgresql:postgresql:42.3.1"). By default, this
                      argument includes a Postgres SQL JDBC driver.
    """
    # CONSISTENCY FIX: the write transform in this file accepts a
    # ``classpath`` argument and forwards it to the expansion service; the
    # read transform previously hard-coded the default, making it impossible
    # to supply a non-Postgres JDBC driver. The new parameter defaults to
    # None -> DEFAULT_JDBC_CLASSPATH, so existing callers are unaffected.
    classpath = classpath or DEFAULT_JDBC_CLASSPATH
    super().__init__(
        self.URN,
        NamedTupleBasedPayloadBuilder(
            JdbcConfigSchema(
                location=table_name,
                config=RowCoder(
                    typing_to_runner_api(Config).row_type.schema).encode(
                        Config(
                            driver_class_name=driver_class_name,
                            jdbc_url=jdbc_url,
                            username=username,
                            password=password,
                            connection_properties=connection_properties,
                            connection_init_sqls=connection_init_sqls,
                            # Write-only setting is unused for a read.
                            write_statement=None,
                            read_query=query,
                            fetch_size=fetch_size,
                            output_parallelization=output_parallelization,
                        ))),
        ),
        expansion_service or default_io_expansion_service(classpath),
    )
def test_trivial_example(self):
    # A hand-built schema proto pinning down the expected field-by-field
    # mapping for a simple NamedTuple.
    MyCuteClass = NamedTuple(
        'MyCuteClass',
        [
            ('name', unicode),
            ('age', Optional[int]),
            ('interests', List[unicode]),
            ('height', float),
            ('blob', ByteString),
        ])

    expected = schema_pb2.FieldType(
        row_type=schema_pb2.RowType(
            schema=schema_pb2.Schema(
                fields=[
                    schema_pb2.Field(
                        name='name',
                        type=schema_pb2.FieldType(
                            atomic_type=schema_pb2.STRING),
                    ),
                    # Optional[...] maps to a nullable field.
                    schema_pb2.Field(
                        name='age',
                        type=schema_pb2.FieldType(
                            nullable=True, atomic_type=schema_pb2.INT64)),
                    schema_pb2.Field(
                        name='interests',
                        type=schema_pb2.FieldType(
                            array_type=schema_pb2.ArrayType(
                                element_type=schema_pb2.FieldType(
                                    atomic_type=schema_pb2.STRING)))),
                    schema_pb2.Field(
                        name='height',
                        type=schema_pb2.FieldType(
                            atomic_type=schema_pb2.DOUBLE)),
                    schema_pb2.Field(
                        name='blob',
                        type=schema_pb2.FieldType(
                            atomic_type=schema_pb2.BYTES)),
                ])))

    # Only test that the fields are equal. If we attempt to test the entire
    # type or the entire schema, the generated id will break equality.
    self.assertEqual(
        expected.row_type.schema.fields,
        typing_to_runner_api(MyCuteClass).row_type.schema.fields)
def test_float_maps_to_float64(self):
    """Python's float is a C double, so it maps to the DOUBLE atomic type."""
    expected = schema_pb2.FieldType(atomic_type=schema_pb2.DOUBLE)
    self.assertEqual(expected, typing_to_runner_api(float))
def test_int_maps_to_int64(self):
    """Python's unbounded int maps to the widest atomic integer, INT64."""
    expected = schema_pb2.FieldType(atomic_type=schema_pb2.INT64)
    self.assertEqual(expected, typing_to_runner_api(int))
def test_str_raises_error_py2(self):
    """On Python 2 ``str`` is ambiguous (str == bytes), so mapping it to a
    schema type must fail rather than silently pick an encoding.
    """
    # BUG FIX: assertRaises was called with only a lambda, which unittest
    # interprets as the *expected exception class* -- the lambda was never
    # invoked, so these assertions could never fail. Pass the exception
    # type explicitly so the callable actually runs.
    # NOTE(review): ValueError chosen to match
    # test_unknown_primitive_raise_valueerror in this file -- confirm
    # against typing_to_runner_api's Python 2 behavior.
    self.assertRaises(ValueError, lambda: typing_to_runner_api(str))
    self.assertRaises(
        ValueError,
        lambda: typing_to_runner_api(
            NamedTuple('Test', [('int', int), ('str', str)])))
def test_unknown_primitive_raise_valueerror(self):
    """Types with no schema mapping (e.g. np.uint32) must be rejected."""
    with self.assertRaises(ValueError):
        typing_to_runner_api(np.uint32)
def test_proto_survives_typing_roundtrip(self):
    # Verify that each schema proto maps to a Python typing and back to an
    # equal proto (Python 2/3 straddling variant).
    all_nonoptional_primitives = [
        schema_pb2.FieldType(atomic_type=typ)
        for typ in schema_pb2.AtomicType.values()
        if typ is not schema_pb2.UNSPECIFIED
    ]

    # The bytes type cannot survive a roundtrip to/from proto in Python 2.
    # In order to use BYTES a user type has to use typing.ByteString (because
    # bytes == str, and we map str to STRING).
    if not IS_PYTHON_3:
        all_nonoptional_primitives.remove(
            schema_pb2.FieldType(atomic_type=schema_pb2.BYTES))

    all_optional_primitives = [
        schema_pb2.FieldType(nullable=True, atomic_type=typ)
        for typ in schema_pb2.AtomicType.values()
        if typ is not schema_pb2.UNSPECIFIED
    ]

    all_primitives = all_nonoptional_primitives + all_optional_primitives

    basic_array_types = [
        schema_pb2.FieldType(array_type=schema_pb2.ArrayType(
            element_type=typ)) for typ in all_primitives
    ]

    # Every (key, value) combination of primitives.
    basic_map_types = [
        schema_pb2.FieldType(map_type=schema_pb2.MapType(
            key_type=key_type, value_type=value_type)) for key_type,
        value_type in itertools.product(all_primitives, all_primitives)
    ]

    selected_schemas = [
        schema_pb2.FieldType(row_type=schema_pb2.RowType(
            schema=schema_pb2.Schema(
                id='32497414-85e8-46b7-9c90-9a9cc62fe390',
                fields=[
                    schema_pb2.Field(name='field%d' % i, type=typ)
                    for i, typ in enumerate(all_primitives)
                ]))),
        schema_pb2.FieldType(row_type=schema_pb2.RowType(
            schema=schema_pb2.Schema(
                id='dead1637-3204-4bcb-acf8-99675f338600',
                fields=[
                    schema_pb2.Field(
                        name='id',
                        type=schema_pb2.FieldType(
                            atomic_type=schema_pb2.INT64)),
                    schema_pb2.Field(
                        name='name',
                        type=schema_pb2.FieldType(
                            atomic_type=schema_pb2.STRING)),
                    schema_pb2.Field(
                        name='optional_map',
                        type=schema_pb2.FieldType(
                            nullable=True,
                            map_type=schema_pb2.MapType(
                                key_type=schema_pb2.FieldType(
                                    atomic_type=schema_pb2.STRING),
                                value_type=schema_pb2.FieldType(
                                    atomic_type=schema_pb2.DOUBLE)))),
                    schema_pb2.Field(
                        name='optional_array',
                        type=schema_pb2.FieldType(
                            nullable=True,
                            array_type=schema_pb2.ArrayType(
                                element_type=schema_pb2.FieldType(
                                    atomic_type=schema_pb2.FLOAT)))),
                    schema_pb2.Field(
                        name='array_optional',
                        type=schema_pb2.FieldType(
                            array_type=schema_pb2.ArrayType(
                                element_type=schema_pb2.FieldType(
                                    nullable=True,
                                    atomic_type=schema_pb2.BYTES)))),
                ]))),
    ]

    test_cases = all_primitives + \
        basic_array_types + \
        basic_map_types + \
        selected_schemas

    for test_case in test_cases:
        self.assertEqual(
            test_case,
            typing_to_runner_api(typing_from_runner_api(test_case)))
def test_unknown_primitive_maps_to_any(self):
    """Unmappable types degrade to the nullable pythonsdk_any logical type."""
    expected = schema_pb2.FieldType(
        logical_type=schema_pb2.LogicalType(
            urn="beam:logical:pythonsdk_any:v1"),
        nullable=True)
    self.assertEqual(typing_to_runner_api(np.uint32), expected)
def test_proto_survives_typing_roundtrip(self):
    # Verify that each schema proto maps to a Python typing and back to an
    # equal proto, threading a fresh SchemaTypeRegistry through both
    # directions so no cached state can mask a failure.
    all_nonoptional_primitives = [
        schema_pb2.FieldType(atomic_type=typ)
        for typ in schema_pb2.AtomicType.values()
        if typ is not schema_pb2.UNSPECIFIED
    ]

    all_optional_primitives = [
        schema_pb2.FieldType(nullable=True, atomic_type=typ)
        for typ in schema_pb2.AtomicType.values()
        if typ is not schema_pb2.UNSPECIFIED
    ]

    all_primitives = all_nonoptional_primitives + all_optional_primitives

    basic_array_types = [
        schema_pb2.FieldType(array_type=schema_pb2.ArrayType(
            element_type=typ)) for typ in all_primitives
    ]

    # Every (key, value) combination of primitives.
    basic_map_types = [
        schema_pb2.FieldType(map_type=schema_pb2.MapType(
            key_type=key_type, value_type=value_type)) for key_type,
        value_type in itertools.product(all_primitives, all_primitives)
    ]

    selected_schemas = [
        schema_pb2.FieldType(row_type=schema_pb2.RowType(
            schema=schema_pb2.Schema(
                id='32497414-85e8-46b7-9c90-9a9cc62fe390',
                fields=[
                    schema_pb2.Field(name='field%d' % i, type=typ)
                    for i, typ in enumerate(all_primitives)
                ]))),
        schema_pb2.FieldType(row_type=schema_pb2.RowType(
            schema=schema_pb2.Schema(
                id='dead1637-3204-4bcb-acf8-99675f338600',
                fields=[
                    schema_pb2.Field(
                        name='id',
                        type=schema_pb2.FieldType(
                            atomic_type=schema_pb2.INT64)),
                    schema_pb2.Field(
                        name='name',
                        type=schema_pb2.FieldType(
                            atomic_type=schema_pb2.STRING)),
                    schema_pb2.Field(
                        name='optional_map',
                        type=schema_pb2.FieldType(
                            nullable=True,
                            map_type=schema_pb2.MapType(
                                key_type=schema_pb2.FieldType(
                                    atomic_type=schema_pb2.STRING),
                                value_type=schema_pb2.FieldType(
                                    atomic_type=schema_pb2.DOUBLE)))),
                    schema_pb2.Field(
                        name='optional_array',
                        type=schema_pb2.FieldType(
                            nullable=True,
                            array_type=schema_pb2.ArrayType(
                                element_type=schema_pb2.FieldType(
                                    atomic_type=schema_pb2.FLOAT)))),
                    schema_pb2.Field(
                        name='array_optional',
                        type=schema_pb2.FieldType(
                            array_type=schema_pb2.ArrayType(
                                element_type=schema_pb2.FieldType(
                                    nullable=True,
                                    atomic_type=schema_pb2.BYTES)))),
                ]))),
    ]

    test_cases = all_primitives + \
        basic_array_types + \
        basic_map_types + \
        selected_schemas

    for test_case in test_cases:
        self.assertEqual(
            test_case,
            typing_to_runner_api(
                typing_from_runner_api(
                    test_case, schema_registry=SchemaTypeRegistry()),
                schema_registry=SchemaTypeRegistry()))