コード例 #1
0
    def test_typing_survives_proto_roundtrip(self):
        """Check that typing hints survive a trip through the schema proto.

        Primitive, optional, sequence and mapping hints must come back equal
        to the hint that went in. NamedTuple hints are compared with
        ``assert_namedtuple_equivalent`` instead of plain equality.
        """
        def roundtrip(type_hint):
            # Each direction of the conversion gets its own fresh registry.
            as_proto = typing_to_runner_api(
                type_hint, schema_registry=SchemaTypeRegistry())
            return typing_from_runner_api(
                as_proto, schema_registry=SchemaTypeRegistry())

        required = [
            np.int8,
            np.int16,
            np.int32,
            np.int64,
            np.float32,
            np.float64,
            bool,
            bytes,
            str,
        ]
        primitives = required + [Optional[hint] for hint in required]

        sequences = [Sequence[hint] for hint in primitives]
        mappings = [
            Mapping[key_hint, value_hint]
            for key_hint, value_hint in itertools.product(
                primitives, repeat=2)
        ]

        all_primitives_schema = NamedTuple(
            'AllPrimitives',
            [('field%d' % index, hint)
             for index, hint in enumerate(primitives)])
        complex_schema = NamedTuple('ComplexSchema', [
            ('id', np.int64),
            ('name', str),
            ('optional_map', Optional[Mapping[str, Optional[np.float64]]]),
            ('optional_array', Optional[Sequence[np.float32]]),
            ('array_optional', Sequence[Optional[bool]]),
            ('timestamp', Timestamp),
        ])

        for hint in itertools.chain(primitives, sequences, mappings):
            self.assertEqual(hint, roundtrip(hint))

        # NamedTuple types are checked separately because they need the
        # dedicated equivalence helper rather than assertEqual.
        for schema_hint in (all_primitives_schema, complex_schema):
            self.assert_namedtuple_equivalent(
                schema_hint, roundtrip(schema_hint))
コード例 #2
0
 def test_python_callable_maps_to_logical_type(self):
     """PythonCallableWithSource converts to and from the callable logical type.

     The expected proto is a logical type carrying the python_callable URN
     whose representation is whatever ``str`` converts to.
     """
     from apache_beam.utils.python_callable import PythonCallableWithSource

     callable_proto = schema_pb2.FieldType(
         logical_type=schema_pb2.LogicalType(
             urn=common_urns.python_callable.urn,
             representation=typing_to_runner_api(str)))

     # Python type -> proto.
     self.assertEqual(
         callable_proto, typing_to_runner_api(PythonCallableWithSource))
     # Proto -> Python type.
     self.assertEqual(
         typing_from_runner_api(callable_proto), PythonCallableWithSource)
コード例 #3
0
ファイル: row_coder_test.py プロジェクト: tapanu/beam
    def test_row_coder_nested_struct(self):
        """A NamedTuple whose fields are ``Person`` values survives encode/decode."""
        Pair = typing.NamedTuple('Pair', [('left', Person), ('right', Person)])

        left_person, right_person = self.PEOPLE[0], self.PEOPLE[1]
        pair = Pair(left_person, right_person)

        pair_schema = typing_to_runner_api(Pair).row_type.schema
        pair_coder = RowCoder(pair_schema)

        decoded = pair_coder.decode(pair_coder.encode(pair))
        self.assertEqual(pair, decoded)
コード例 #4
0
ファイル: jdbc.py プロジェクト: scwhittle/beam
    def __init__(
        self,
        table_name,
        driver_class_name,
        jdbc_url,
        username,
        password,
        statement=None,
        connection_properties=None,
        connection_init_sqls=None,
        expansion_service=None,
        classpath=None,
    ):
        """Set up a write operation to Jdbc.

        :param table_name: forwarded as the ``location`` of the config schema
        :param driver_class_name: name of the jdbc driver class
        :param jdbc_url: full jdbc url to the database.
        :param username: database username
        :param password: database password
        :param statement: sql statement to be executed
        :param connection_properties: properties of the jdbc connection
                                      passed as string with format
                                      [propertyName=property;]*
        :param connection_init_sqls: required only for MySql and MariaDB.
                                     passed as list of strings
        :param expansion_service: The address (host:port) of the
                                  ExpansionService.
        :param classpath: A list of JARs or Java packages to include in the
                          classpath for the expansion service. This option is
                          usually needed for `jdbc` to include extra JDBC
                          driver packages.
                          The packages can be in these three formats: (1) A
                          local file, (2) A URL, (3) A gradle-style identifier
                          of a Maven package
                          (e.g. "org.postgresql:postgresql:42.3.1").
                          By default, this argument includes a Postgres SQL
                          JDBC driver.
        """
        classpath = classpath or DEFAULT_JDBC_CLASSPATH

        config_coder = RowCoder(typing_to_runner_api(Config).row_type.schema)
        # The read-only settings are explicitly left unset for a write.
        write_config = Config(
            driver_class_name=driver_class_name,
            jdbc_url=jdbc_url,
            username=username,
            password=password,
            connection_properties=connection_properties,
            connection_init_sqls=connection_init_sqls,
            write_statement=statement,
            read_query=None,
            fetch_size=None,
            output_parallelization=None,
        )
        payload_builder = NamedTupleBasedPayloadBuilder(
            JdbcConfigSchema(
                location=table_name,
                config=config_coder.encode(write_config),
            ))

        # The default expansion service is only created when the caller did
        # not supply one.
        service = expansion_service or default_io_expansion_service(classpath)
        super().__init__(self.URN, payload_builder, service)
コード例 #5
0
  def test_typing_survives_proto_roundtrip(self):
    """Typing hints must compare equal after a to-proto/from-proto roundtrip.

    Covers primitives, their Optional forms, Sequence and Mapping
    combinations of those, and two NamedTuple schemas.
    """
    def roundtrip(type_hint):
      return typing_from_runner_api(typing_to_runner_api(type_hint))

    required = [
        np.int8,
        np.int16,
        np.int32,
        np.int64,
        np.float32,
        np.float64,
        unicode,
        bool,
    ]

    # The bytes type cannot survive a roundtrip to/from proto in Python 2.
    # In order to use BYTES a user type has to use typing.ByteString (because
    # bytes == str, and we map str to STRING).
    # TODO(BEAM-7372)
    if IS_PYTHON_3:
      required.append(bytes)

    primitives = required + [Optional[hint] for hint in required]

    sequences = [Sequence[hint] for hint in primitives]
    mappings = [
        Mapping[key_hint, value_hint]
        for key_hint, value_hint in itertools.product(primitives, repeat=2)
    ]

    all_primitives_schema = NamedTuple(
        'AllPrimitives',
        [('field%d' % index, hint) for index, hint in enumerate(primitives)])
    complex_schema = NamedTuple(
        'ComplexSchema',
        [
            ('id', np.int64),
            ('name', unicode),
            ('optional_map', Optional[Mapping[unicode, Optional[np.float64]]]),
            ('optional_array', Optional[Sequence[np.float32]]),
            ('array_optional', Sequence[Optional[bool]]),
            ('timestamp', Timestamp),
        ])
    schemas = [all_primitives_schema, complex_schema]

    for hint in itertools.chain(primitives, sequences, mappings, schemas):
      self.assertEqual(hint, roundtrip(hint))
コード例 #6
0
ファイル: row_coder_test.py プロジェクト: tapanu/beam
    def test_create_row_coder_from_named_tuple(self):
        """The registry's coder for ``Person`` matches an explicit RowCoder.

        For every entry in ``self.PEOPLE`` the two coders must produce the
        same encoding, and the registry coder must decode its own output back
        to the original value.
        """
        person_schema = typing_to_runner_api(Person).row_type.schema
        expected_coder = RowCoder(person_schema)
        real_coder = coders_registry.get_coder(Person)

        for person in self.PEOPLE:
            expected_bytes = expected_coder.encode(person)
            actual_bytes = real_coder.encode(person)
            self.assertEqual(expected_bytes, actual_bytes)

            roundtripped = real_coder.decode(real_coder.encode(person))
            self.assertEqual(person, roundtripped)
コード例 #7
0
ファイル: row_coder_test.py プロジェクト: melap/beam
 def test_row_accepts_trailing_zeros_truncated(self):
     """Decoding works when the encoded row carries a single null-mask byte.

     The first four bytes of the coder's own output are swapped for a
     three-byte header, and the result must still decode to the same value.
     """
     coder = RowCoder(typing_to_runner_api(NullablePerson).row_type.schema)
     person = NullablePerson(
         None,
         np.int32(25),
         "Westeros",
         ["Mother of Dragons"],
         False,
         None,
         {"dragons": 3},
         None,
         "NotNull")
     encoded = coder.encode(person)

     # Replacement header: 9 fields, 1 null byte, and a mask marking
     # fields 0, 5 and 7 as null.
     null_mask = (1 << 0) | (1 << 5) | (1 << 7)
     short_header = bytes([9, 1, null_mask])
     # Presumably the dropped four bytes are the original header with a
     # second, all-zero mask byte -- TODO confirm against the coder.
     truncated_payload = short_header + encoded[4:]

     decoded = coder.decode(truncated_payload)
     self.assertEqual(person, decoded)
コード例 #8
0
  def test_typing_survives_proto_roundtrip(self):
    """Typing hints must compare equal after a to-proto/from-proto roundtrip.

    Covers primitives, their Optional forms, Sequence and Mapping
    combinations of those, and two NamedTuple schemas.
    """
    def roundtrip(type_hint):
      return typing_from_runner_api(typing_to_runner_api(type_hint))

    required = [
        np.int8,
        np.int16,
        np.int32,
        np.int64,
        np.float32,
        np.float64,
        unicode,
        bool,
        bytes,
        str,
    ]
    primitives = required + [Optional[hint] for hint in required]

    sequences = [Sequence[hint] for hint in primitives]
    mappings = [
        Mapping[key_hint, value_hint]
        for key_hint, value_hint in itertools.product(primitives, repeat=2)
    ]

    all_primitives_schema = NamedTuple(
        'AllPrimitives',
        [('field%d' % index, hint) for index, hint in enumerate(primitives)])
    complex_schema = NamedTuple(
        'ComplexSchema',
        [
            ('id', np.int64),
            ('name', unicode),
            ('optional_map', Optional[Mapping[unicode, Optional[np.float64]]]),
            ('optional_array', Optional[Sequence[np.float32]]),
            ('array_optional', Sequence[Optional[bool]]),
            ('timestamp', Timestamp),
        ])
    schemas = [all_primitives_schema, complex_schema]

    for hint in itertools.chain(primitives, sequences, mappings, schemas):
      self.assertEqual(hint, roundtrip(hint))
コード例 #9
0
    def __init__(
        self,
        table_name,
        driver_class_name,
        jdbc_url,
        username,
        password,
        query=None,
        output_parallelization=None,
        fetch_size=None,
        connection_properties=None,
        connection_init_sqls=None,
        expansion_service=None,
    ):
        """Set up a read operation from Jdbc.

        :param table_name: forwarded as the ``location`` of the config schema
        :param driver_class_name: name of the jdbc driver class
        :param jdbc_url: full jdbc url to the database.
        :param username: database username
        :param password: database password
        :param query: sql query to be executed
        :param output_parallelization: is output parallelization on
        :param fetch_size: how many rows to fetch
        :param connection_properties: properties of the jdbc connection
                                      passed as string with format
                                      [propertyName=property;]*
        :param connection_init_sqls: required only for MySql and MariaDB.
                                     passed as list of strings
        :param expansion_service: The address (host:port) of the
                                  ExpansionService.
        """
        config_coder = RowCoder(typing_to_runner_api(Config).row_type.schema)
        # The write statement is explicitly left unset for a read.
        read_config = Config(
            driver_class_name=driver_class_name,
            jdbc_url=jdbc_url,
            username=username,
            password=password,
            connection_properties=connection_properties,
            connection_init_sqls=connection_init_sqls,
            write_statement=None,
            read_query=query,
            fetch_size=fetch_size,
            output_parallelization=output_parallelization,
        )
        payload_builder = NamedTupleBasedPayloadBuilder(
            JdbcConfigSchema(
                location=table_name,
                config=config_coder.encode(read_config),
            ))

        # The default expansion service is only created when the caller did
        # not supply one.
        service = expansion_service or default_io_expansion_service()
        super().__init__(self.URN, payload_builder, service)
コード例 #10
0
  def test_trivial_example(self):
    """A small NamedTuple converts to the expected list of schema fields."""
    MyCuteClass = NamedTuple(
        'MyCuteClass',
        [
            ('name', unicode),
            ('age', Optional[int]),
            ('interests', List[unicode]),
            ('height', float),
            ('blob', ByteString),
        ])

    string_type = schema_pb2.FieldType(atomic_type=schema_pb2.STRING)
    expected_fields = [
        schema_pb2.Field(name='name', type=string_type),
        schema_pb2.Field(
            name='age',
            type=schema_pb2.FieldType(
                nullable=True, atomic_type=schema_pb2.INT64)),
        schema_pb2.Field(
            name='interests',
            type=schema_pb2.FieldType(
                array_type=schema_pb2.ArrayType(element_type=string_type))),
        schema_pb2.Field(
            name='height',
            type=schema_pb2.FieldType(atomic_type=schema_pb2.DOUBLE)),
        schema_pb2.Field(
            name='blob',
            type=schema_pb2.FieldType(atomic_type=schema_pb2.BYTES)),
    ]
    expected = schema_pb2.FieldType(
        row_type=schema_pb2.RowType(
            schema=schema_pb2.Schema(fields=expected_fields)))

    actual = typing_to_runner_api(MyCuteClass)

    # Only the fields are compared. Comparing the entire type or the entire
    # schema would fail because of the generated id.
    self.assertEqual(
        expected.row_type.schema.fields, actual.row_type.schema.fields)
コード例 #11
0
 def test_float_maps_to_float64(self):
     """The builtin ``float`` converts to the DOUBLE atomic type."""
     double_type = schema_pb2.FieldType(atomic_type=schema_pb2.DOUBLE)
     converted = typing_to_runner_api(float)
     self.assertEqual(double_type, converted)
コード例 #12
0
 def test_int_maps_to_int64(self):
     """The builtin ``int`` converts to the INT64 atomic type."""
     int64_type = schema_pb2.FieldType(atomic_type=schema_pb2.INT64)
     converted = typing_to_runner_api(int)
     self.assertEqual(int64_type, converted)
コード例 #13
0
 def test_str_raises_error_py2(self):
     """Converting ``str`` (alone or as a NamedTuple field) must raise.

     The previous form passed the lambda as the only argument to
     ``assertRaises``. That argument is the expected exception type, so the
     lambda was never invoked and nothing was actually checked.
     """
     # NOTE(review): ValueError is assumed to be the error raised by
     # typing_to_runner_api for an unsupported type -- confirm against the
     # converter. The test name suggests this only applies on Python 2,
     # where any skip condition would live outside this method.
     with self.assertRaises(ValueError):
         typing_to_runner_api(str)
     with self.assertRaises(ValueError):
         typing_to_runner_api(
             NamedTuple('Test', [('int', int), ('str', str)]))
コード例 #14
0
 def test_unknown_primitive_raise_valueerror(self):
     """Converting ``np.uint32`` is expected to raise ValueError."""
     with self.assertRaises(ValueError):
         typing_to_runner_api(np.uint32)
コード例 #15
0
    def test_proto_survives_typing_roundtrip(self):
        """Schema protos must compare equal after a trip through typing hints.

        Covers every atomic type except UNSPECIFIED (plain and nullable),
        arrays and maps built from those, and two hand-written row schemas
        with fixed ids.
        """
        FieldType = schema_pb2.FieldType

        def array_of(element_type, **kwargs):
            # Build an array FieldType; kwargs pass through (e.g. nullable).
            return FieldType(
                array_type=schema_pb2.ArrayType(element_type=element_type),
                **kwargs)

        def map_of(key_type, value_type, **kwargs):
            # Build a map FieldType; kwargs pass through (e.g. nullable).
            return FieldType(
                map_type=schema_pb2.MapType(
                    key_type=key_type, value_type=value_type),
                **kwargs)

        def row_of(schema_id, fields):
            # Build a row FieldType around a schema with the given id.
            return FieldType(
                row_type=schema_pb2.RowType(
                    schema=schema_pb2.Schema(id=schema_id, fields=fields)))

        # Enum values are plain ints, so compare by value (!=) rather than
        # by identity (is not).
        atomic_types = [
            typ for typ in schema_pb2.AtomicType.values()
            if typ != schema_pb2.UNSPECIFIED
        ]

        all_nonoptional_primitives = [
            FieldType(atomic_type=typ) for typ in atomic_types
        ]

        # The bytes type cannot survive a roundtrip to/from proto in Python 2.
        # In order to use BYTES a user type has to use typing.ByteString
        # (because bytes == str, and we map str to STRING).
        if not IS_PYTHON_3:
            all_nonoptional_primitives.remove(
                FieldType(atomic_type=schema_pb2.BYTES))

        all_optional_primitives = [
            FieldType(nullable=True, atomic_type=typ) for typ in atomic_types
        ]

        all_primitives = all_nonoptional_primitives + all_optional_primitives

        basic_array_types = [array_of(typ) for typ in all_primitives]

        basic_map_types = [
            map_of(key_type, value_type)
            for key_type, value_type in itertools.product(
                all_primitives, all_primitives)
        ]

        selected_schemas = [
            row_of(
                '32497414-85e8-46b7-9c90-9a9cc62fe390',
                [
                    schema_pb2.Field(name='field%d' % i, type=typ)
                    for i, typ in enumerate(all_primitives)
                ]),
            row_of(
                'dead1637-3204-4bcb-acf8-99675f338600',
                [
                    schema_pb2.Field(
                        name='id',
                        type=FieldType(atomic_type=schema_pb2.INT64)),
                    schema_pb2.Field(
                        name='name',
                        type=FieldType(atomic_type=schema_pb2.STRING)),
                    schema_pb2.Field(
                        name='optional_map',
                        type=map_of(
                            FieldType(atomic_type=schema_pb2.STRING),
                            FieldType(atomic_type=schema_pb2.DOUBLE),
                            nullable=True)),
                    schema_pb2.Field(
                        name='optional_array',
                        type=array_of(
                            FieldType(atomic_type=schema_pb2.FLOAT),
                            nullable=True)),
                    schema_pb2.Field(
                        name='array_optional',
                        type=array_of(
                            FieldType(
                                nullable=True,
                                atomic_type=schema_pb2.BYTES))),
                ]),
        ]

        test_cases = (
            all_primitives + basic_array_types + basic_map_types +
            selected_schemas)

        for test_case in test_cases:
            self.assertEqual(
                test_case,
                typing_to_runner_api(typing_from_runner_api(test_case)))
コード例 #16
0
 def test_unknown_primitive_maps_to_any(self):
     """``np.uint32`` converts to the nullable pythonsdk_any logical type."""
     converted = typing_to_runner_api(np.uint32)
     any_logical_type = schema_pb2.LogicalType(
         urn="beam:logical:pythonsdk_any:v1")
     expected = schema_pb2.FieldType(
         logical_type=any_logical_type, nullable=True)
     self.assertEqual(converted, expected)
コード例 #17
0
    def test_proto_survives_typing_roundtrip(self):
        """Schema protos must compare equal after a trip through typing hints.

        Covers every atomic type except UNSPECIFIED (plain and nullable),
        arrays and maps built from those, and two hand-written row schemas
        with fixed ids. A fresh ``SchemaTypeRegistry`` is passed to every
        conversion.
        """
        FieldType = schema_pb2.FieldType

        def array_of(element_type, **kwargs):
            # Build an array FieldType; kwargs pass through (e.g. nullable).
            return FieldType(
                array_type=schema_pb2.ArrayType(element_type=element_type),
                **kwargs)

        def map_of(key_type, value_type, **kwargs):
            # Build a map FieldType; kwargs pass through (e.g. nullable).
            return FieldType(
                map_type=schema_pb2.MapType(
                    key_type=key_type, value_type=value_type),
                **kwargs)

        def row_of(schema_id, fields):
            # Build a row FieldType around a schema with the given id.
            return FieldType(
                row_type=schema_pb2.RowType(
                    schema=schema_pb2.Schema(id=schema_id, fields=fields)))

        def roundtrip(field_type):
            # Each direction of the conversion gets its own fresh registry.
            as_typing = typing_from_runner_api(
                field_type, schema_registry=SchemaTypeRegistry())
            return typing_to_runner_api(
                as_typing, schema_registry=SchemaTypeRegistry())

        # Enum values are plain ints, so compare by value (!=) rather than
        # by identity (is not).
        atomic_types = [
            typ for typ in schema_pb2.AtomicType.values()
            if typ != schema_pb2.UNSPECIFIED
        ]

        all_nonoptional_primitives = [
            FieldType(atomic_type=typ) for typ in atomic_types
        ]
        all_optional_primitives = [
            FieldType(nullable=True, atomic_type=typ) for typ in atomic_types
        ]
        all_primitives = all_nonoptional_primitives + all_optional_primitives

        basic_array_types = [array_of(typ) for typ in all_primitives]

        basic_map_types = [
            map_of(key_type, value_type)
            for key_type, value_type in itertools.product(
                all_primitives, all_primitives)
        ]

        selected_schemas = [
            row_of(
                '32497414-85e8-46b7-9c90-9a9cc62fe390',
                [
                    schema_pb2.Field(name='field%d' % i, type=typ)
                    for i, typ in enumerate(all_primitives)
                ]),
            row_of(
                'dead1637-3204-4bcb-acf8-99675f338600',
                [
                    schema_pb2.Field(
                        name='id',
                        type=FieldType(atomic_type=schema_pb2.INT64)),
                    schema_pb2.Field(
                        name='name',
                        type=FieldType(atomic_type=schema_pb2.STRING)),
                    schema_pb2.Field(
                        name='optional_map',
                        type=map_of(
                            FieldType(atomic_type=schema_pb2.STRING),
                            FieldType(atomic_type=schema_pb2.DOUBLE),
                            nullable=True)),
                    schema_pb2.Field(
                        name='optional_array',
                        type=array_of(
                            FieldType(atomic_type=schema_pb2.FLOAT),
                            nullable=True)),
                    schema_pb2.Field(
                        name='array_optional',
                        type=array_of(
                            FieldType(
                                nullable=True,
                                atomic_type=schema_pb2.BYTES))),
                ]),
        ]

        test_cases = (
            all_primitives + basic_array_types + basic_map_types +
            selected_schemas)

        for test_case in test_cases:
            self.assertEqual(test_case, roundtrip(test_case))