Example #1
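    # The first row contains only empty/None values, so every field type below
    # must be inferred from the second row.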
    def test_infer_schema_nulltype(self):
        elements = [
            Row(c1=[], c2={}, c3=None),
            Row(c1=[Row(a=1, b='s')], c2={"key": Row(c=1.0, d="2")}, c3="")
        ]
        schema = _infer_schema_from_data(elements)
        self.assertTrue(isinstance(schema, RowType))
        self.assertEqual(3, len(schema.fields))

        # first column is array
        self.assertTrue(isinstance(schema.fields[0].data_type, ArrayType))

        # element type of first column is struct
        self.assertTrue(
            isinstance(schema.fields[0].data_type.element_type, RowType))

        self.assertTrue(
            isinstance(
                schema.fields[0].data_type.element_type.fields[0].data_type,
                BigIntType))
        self.assertTrue(
            isinstance(
                schema.fields[0].data_type.element_type.fields[1].data_type,
                VarCharType))

        # second column is map
        self.assertTrue(isinstance(schema.fields[1].data_type, MapType))
        self.assertTrue(
            isinstance(schema.fields[1].data_type.key_type, VarCharType))
        self.assertTrue(
            isinstance(schema.fields[1].data_type.value_type, RowType))

        # third column is varchar
        self.assertTrue(isinstance(schema.fields[2].data_type, VarCharType))
Example #2
    def test_from_data_stream_with_schema(self):
        from pyflink.table import Schema

        ds = self.env.from_collection(
            [(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')],
            type_info=Types.ROW_NAMED(
                ["a", "b", "c"],
                [Types.INT(), Types.STRING(),
                 Types.STRING()]))

        table = self.t_env.from_data_stream(
            ds,
            Schema.new_builder()
                  .column("a", DataTypes.INT())
                  .column("b", DataTypes.STRING())
                  .column("c", DataTypes.STRING())
                  .build())
        result = table.execute()
        with result.collect() as results:
            collected_result = [str(item) for item in results]
            expected_result = [str(item) for item in
                               [Row(1, 'Hi', 'Hello'), Row(2, 'Hello', 'Hi')]]
            expected_result.sort()
            collected_result.sort()
            self.assertEqual(expected_result, collected_result)
Example #3
 @classmethod
 def setUpClass(cls):
     super(PandasConversionTestBase, cls).setUpClass()
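     # Two rows covering primitive, temporal, decimal, binary, array and nested
     # Row values; fields f1-f15 of cls.data_type mirror these positions.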
     cls.data = [(1, 1, 1, 1, True, 1.1, 1.2, 'hello', bytearray(b"aaa"),
                  decimal.Decimal('1000000000000000000.01'), datetime.date(2014, 9, 13),
                  datetime.time(hour=1, minute=0, second=1),
                  datetime.datetime(1970, 1, 1, 0, 0, 0, 123000), ['hello', '中文'],
                  Row(a=1, b='hello', c=datetime.datetime(1970, 1, 1, 0, 0, 0, 123000),
                      d=[1, 2])),
                 (1, 2, 2, 2, False, 2.1, 2.2, 'world', bytearray(b"bbb"),
                  decimal.Decimal('1000000000000000000.02'), datetime.date(2014, 9, 13),
                  datetime.time(hour=1, minute=0, second=1),
                  datetime.datetime(1970, 1, 1, 0, 0, 0, 123000), ['hello', '中文'],
                  Row(a=1, b='hello', c=datetime.datetime(1970, 1, 1, 0, 0, 0, 123000),
                      d=[1, 2]))]
     cls.data_type = DataTypes.ROW(
         [DataTypes.FIELD("f1", DataTypes.TINYINT()),
          DataTypes.FIELD("f2", DataTypes.SMALLINT()),
          DataTypes.FIELD("f3", DataTypes.INT()),
          DataTypes.FIELD("f4", DataTypes.BIGINT()),
          DataTypes.FIELD("f5", DataTypes.BOOLEAN()),
          DataTypes.FIELD("f6", DataTypes.FLOAT()),
          DataTypes.FIELD("f7", DataTypes.DOUBLE()),
          DataTypes.FIELD("f8", DataTypes.STRING()),
          DataTypes.FIELD("f9", DataTypes.BYTES()),
          DataTypes.FIELD("f10", DataTypes.DECIMAL(38, 18)),
          DataTypes.FIELD("f11", DataTypes.DATE()),
          DataTypes.FIELD("f12", DataTypes.TIME()),
          DataTypes.FIELD("f13", DataTypes.TIMESTAMP(3)),
          DataTypes.FIELD("f14", DataTypes.ARRAY(DataTypes.STRING())),
          DataTypes.FIELD("f15", DataTypes.ROW(
              [DataTypes.FIELD("a", DataTypes.INT()),
               DataTypes.FIELD("b", DataTypes.STRING()),
               DataTypes.FIELD("c", DataTypes.TIMESTAMP(3)),
               DataTypes.FIELD("d", DataTypes.ARRAY(DataTypes.INT()))]))])
     cls.pdf = cls.create_pandas_data_frame()
Example #4
    def test_infer_schema(self):
        from decimal import Decimal

        class A(object):
            def __init__(self):
                self.a = 1

        from collections import namedtuple
        Point = namedtuple('Point', 'x y')

        data = [
            True,
            1,
            "a",
            u"a",
            datetime.date(1970, 1, 1),
            datetime.time(0, 0, 0),
            datetime.datetime(1970, 1, 1, 0, 0),
            1.0,
            array.array("d", [1]),
            [1],
            (1,),
            Point(1.0, 5.0),
            {"a": 1},
            bytearray(1),
            Decimal(1),
            Row(a=1),
            Row("a")(1),
            A(),
        ]

        expected = [
            'BooleanType(true)',
            'BigIntType(true)',
            'VarCharType(2147483647, true)',
            'VarCharType(2147483647, true)',
            'DateType(true)',
            'TimeType(0, true)',
            'LocalZonedTimestampType(6, true)',
            'DoubleType(true)',
            "ArrayType(DoubleType(false), true)",
            "ArrayType(BigIntType(true), true)",
            'RowType(RowField(_1, BigIntType(true), ...))',
            'RowType(RowField(x, DoubleType(true), ...),RowField(y, DoubleType(true), ...))',
            'MapType(VarCharType(2147483647, false), BigIntType(true), true)',
            'VarBinaryType(2147483647, true)',
            'DecimalType(38, 18, true)',
            'RowType(RowField(a, BigIntType(true), ...))',
            'RowType(RowField(a, BigIntType(true), ...))',
            'RowType(RowField(a, BigIntType(true), ...))',
        ]

        schema = _infer_schema_from_data([data])
        self.assertEqual(expected, [repr(f.data_type) for f in schema.fields])
Example #5
    def test_from_and_to_data_stream_event_time(self):
        from pyflink.table import Schema

        ds = self.env.from_collection([(1, 42, "a"), (2, 5, "a"), (3, 1000, "c"), (100, 1000, "c")],
                                      Types.ROW_NAMED(
                                          ["a", "b", "c"],
                                          [Types.LONG(), Types.INT(), Types.STRING()]))
        ds = ds.assign_timestamps_and_watermarks(
            WatermarkStrategy.for_monotonous_timestamps()
            .with_timestamp_assigner(MyTimestampAssigner()))

        table = self.t_env.from_data_stream(ds,
                                            Schema.new_builder()
                                                  .column_by_metadata("rowtime", "TIMESTAMP_LTZ(3)")
                                                  .watermark("rowtime", "SOURCE_WATERMARK()")
                                                  .build())
        self.assertEqual("""(
  `a` BIGINT,
  `b` INT,
  `c` STRING,
  `rowtime` TIMESTAMP_LTZ(3) *ROWTIME* METADATA,
  WATERMARK FOR `rowtime`: TIMESTAMP_LTZ(3) AS SOURCE_WATERMARK()
)""",
                         table._j_table.getResolvedSchema().toString())
        self.t_env.create_temporary_view("t",
                                         ds,
                                         Schema.new_builder()
                                         .column_by_metadata("rowtime", "TIMESTAMP_LTZ(3)")
                                         .watermark("rowtime", "SOURCE_WATERMARK()")
                                         .build())

        result = self.t_env.execute_sql("SELECT "
                                        "c, SUM(b) "
                                        "FROM t "
                                        "GROUP BY c, TUMBLE(rowtime, INTERVAL '0.005' SECOND)")
        with result.collect() as results:
            collected_result = [str(item) for item in results]
            expected_result = [str(item) for item in
                               [Row('a', 47), Row('c', 1000), Row('c', 1000)]]
            expected_result.sort()
            collected_result.sort()
            self.assertEqual(expected_result, collected_result)

        ds = self.t_env.to_data_stream(table)
        ds.key_by(lambda k: k.c, key_type=Types.STRING()) \
            .window(MyTumblingEventTimeWindow()) \
            .apply(SumWindowFunction(), Types.TUPLE([Types.STRING(), Types.INT()])) \
            .add_sink(self.test_sink)
        self.env.execute()
        expected_results = ['(a,47)', '(c,1000)', '(c,1000)']
        actual_results = self.test_sink.get_results(False)
        expected_results.sort()
        actual_results.sort()
        self.assertEqual(expected_results, actual_results)
Example #6
    def test_from_data_stream_atomic(self):
        data_stream = self.env.from_collection([(1,), (2,), (3,), (4,), (5,)])
        result = self.t_env.from_data_stream(data_stream).execute()
        self.assertEqual("""(
  `f0` RAW('[B', '...')
)""",
                         result._j_table_result.getResolvedSchema().toString())
        with result.collect() as results:
            collected_result = [str(item) for item in results]
            expected_result = [str(item) for item in
                               [Row(1), Row(2), Row(3), Row(4), Row(5)]]
            expected_result.sort()
            collected_result.sort()
            self.assertEqual(expected_result, collected_result)
Example #7
    def test_infer_nested_schema(self):
        NestedRow = Row("f1", "f2")
        data1 = [
            NestedRow([1, 2], {"row1": 1.0}),
            NestedRow([2, 3], {"row2": 2.0})
        ]
        schema1 = _infer_schema_from_data(data1)
        expected1 = [
            'ArrayType(BigIntType(true), true)',
            'MapType(VarCharType(2147483647, false), DoubleType(true), true)'
        ]
        self.assertEqual(expected1,
                         [repr(f.data_type) for f in schema1.fields])

        data2 = [
            NestedRow([[1, 2], [2, 3]], [1, 2]),
            NestedRow([[2, 3], [3, 4]], [2, 3])
        ]
        schema2 = _infer_schema_from_data(data2)
        expected2 = [
            'ArrayType(ArrayType(BigIntType(true), true), true)',
            'ArrayType(BigIntType(true), true)'
        ]
        self.assertEqual(expected2,
                         [repr(f.data_type) for f in schema2.fields])
Example #8
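 # None values in the nullable INT column should come back unchanged from collect().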
 def test_collect_null_value_result(self):
     element_data = [(1, None, 'a'),
                     (3, 4, 'b'),
                     (5, None, 'a'),
                     (7, 8, 'b')]
     source = self.t_env.from_elements(element_data,
                                       DataTypes.ROW([DataTypes.FIELD('a', DataTypes.INT()),
                                                      DataTypes.FIELD('b', DataTypes.INT()),
                                                      DataTypes.FIELD('c', DataTypes.STRING())]))
     table_result = source.execute()
     expected_result = [Row(1, None, 'a'), Row(3, 4, 'b'), Row(5, None, 'a'),
                        Row(7, 8, 'b')]
     with table_result.collect() as results:
         collected_result = []
         for result in results:
             collected_result.append(result)
         self.assertEqual(collected_result, expected_result)
Example #9
import ast
import datetime
import pickle

# assumed module paths: these classes live in pyflink.common / pyflink.table.types
from pyflink.common import Row, RowKind
from pyflink.table.types import (ArrayType, DataType, DateType, FloatType,
                                 MapType, RawType, RowType, TimestampType,
                                 TimeType)


def pickled_bytes_to_python_converter(data, field_type: DataType):
    # Convert pickled bytes produced on the Java side back into Python values,
    # dispatching on the declared field type.
    if isinstance(field_type, RowType):
        row_kind = RowKind(
            int.from_bytes(data[0], byteorder='big', signed=False))
        data = zip(list(data[1:]), field_type.field_types())
        fields = []
        for d, d_type in data:
            fields.append(pickled_bytes_to_python_converter(d, d_type))
        # unpack the converted field values into a multi-field Row
        result_row = Row(*fields)
        result_row.set_row_kind(row_kind)
        return result_row
    else:
        data = pickle.loads(data)
        if isinstance(field_type, TimeType):
            seconds, microseconds = divmod(data, 10**6)
            minutes, seconds = divmod(seconds, 60)
            hours, minutes = divmod(minutes, 60)
            return datetime.time(hours, minutes, seconds, microseconds)
        elif isinstance(field_type, DateType):
            return field_type.from_sql_type(data)
        elif isinstance(field_type, TimestampType):
            return field_type.from_sql_type(int(data.timestamp() * 10**6))
        elif isinstance(field_type, MapType):
            key_type = field_type.key_type
            value_type = field_type.value_type
            zip_kv = zip(data[0], data[1])
            return dict((pickled_bytes_to_python_converter(k, key_type),
                         pickled_bytes_to_python_converter(v, value_type))
                        for k, v in zip_kv)
        elif isinstance(field_type, FloatType):
            return field_type.from_sql_type(ast.literal_eval(data))
        elif isinstance(field_type, ArrayType):
            element_type = field_type.element_type
            elements = []
            for element_bytes in data:
                elements.append(
                    pickled_bytes_to_python_converter(element_bytes,
                                                      element_type))
            return elements
        else:
            # RawType and all remaining types share this conversion path
            return field_type.from_sql_type(data)
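For reference, a minimal round-trip sketch under the encoding that the TimeType branch above assumes (a pickled count of microseconds since midnight); pickled_bytes_to_python_converter is the function defined in this example:

import datetime
import pickle

from pyflink.table import DataTypes

# 01:02:03.000123 encoded as microseconds since midnight, then pickled --
# the representation the TimeType branch expects (an assumption of this sketch).
micros = (1 * 3600 + 2 * 60 + 3) * 10**6 + 123
decoded = pickled_bytes_to_python_converter(pickle.dumps(micros), DataTypes.TIME())
assert decoded == datetime.time(1, 2, 3, 123)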
Example #10
 def test_infer_bigint_type(self):
     longrow = [Row(f1='a', f2=100000000000000)]
     schema = _infer_schema_from_data(longrow)
     self.assertEqual(DataTypes.BIGINT(), schema.fields[1].data_type)
     self.assertEqual(DataTypes.BIGINT(), _infer_type(1))
     self.assertEqual(DataTypes.BIGINT(), _infer_type(2**10))
     self.assertEqual(DataTypes.BIGINT(), _infer_type(2**20))
     self.assertEqual(DataTypes.BIGINT(), _infer_type(2**31 - 1))
     self.assertEqual(DataTypes.BIGINT(), _infer_type(2**31))
     self.assertEqual(DataTypes.BIGINT(), _infer_type(2**61))
     self.assertEqual(DataTypes.BIGINT(), _infer_type(2**71))
Example #11
    def test_collect_with_retract(self):
        expected_row_kinds = [
            RowKind.INSERT, RowKind.UPDATE_BEFORE, RowKind.UPDATE_AFTER,
            RowKind.INSERT, RowKind.UPDATE_BEFORE, RowKind.UPDATE_AFTER
        ]
        element_data = [(1, 2, 'a'), (3, 4, 'b'), (5, 6, 'a'), (7, 8, 'b')]
        field_names = ['a', 'b', 'c']
        source = self.t_env.from_elements(element_data, field_names)
        table_result = self.t_env.execute_sql(
            "SELECT SUM(a), c FROM %s GROUP BY c" % source)
        with table_result.collect() as result:
            collected_result = []
            for i in result:
                collected_result.append(i)

            collected_result = [
                str(result) + ',' + str(result.get_row_kind())
                for result in collected_result
            ]
            expected_result = [
                Row(1, 'a'),
                Row(1, 'a'),
                Row(6, 'a'),
                Row(3, 'b'),
                Row(3, 'b'),
                Row(10, 'b')
            ]
            for i in range(len(expected_result)):
                expected_result[i] = str(expected_result[i]) + ',' + str(
                    expected_row_kinds[i])
            expected_result.sort()
            collected_result.sort()
            self.assertEqual(expected_result, collected_result)
Example #12
    def test_from_and_to_changelog_stream_event_time(self):
        from pyflink.table import Schema

        self.env.set_parallelism(1)
        ds = self.env.from_collection(
            [(1, 42, "a"), (2, 5, "a"), (3, 1000, "c"), (100, 1000, "c")],
            Types.ROW([Types.LONG(), Types.INT(), Types.STRING()]))
        ds = ds.assign_timestamps_and_watermarks(
            WatermarkStrategy.for_monotonous_timestamps()
            .with_timestamp_assigner(MyTimestampAssigner()))

        changelog_stream = ds.map(lambda t: Row(t.f1, t.f2),
                                  Types.ROW([Types.INT(), Types.STRING()]))

        # derive physical columns and add a rowtime
        table = self.t_env.from_changelog_stream(
            changelog_stream,
            Schema.new_builder()
                  .column_by_metadata("rowtime", DataTypes.TIMESTAMP_LTZ(3))
                  .column_by_expression("computed", str(col("f1").upper_case))
                  .watermark("rowtime", str(source_watermark()))
                  .build())

        self.t_env.create_temporary_view("t", table)

        # access and reorder columns
        reordered = self.t_env.sql_query("SELECT computed, rowtime, f0 FROM t")

        # write out the rowtime column with fully declared schema
        result = self.t_env.to_changelog_stream(
            reordered,
            Schema.new_builder()
                  .column("f1", DataTypes.STRING())
                  .column_by_metadata("rowtime", DataTypes.TIMESTAMP_LTZ(3))
                  .column_by_expression("ignored", str(col("f1").upper_case))
                  .column("f0", DataTypes.INT())
                  .build())

        # test event time window and field access
        result.key_by(lambda k: k.f1) \
            .window(MyTumblingEventTimeWindow()) \
            .apply(SumWindowFunction(), Types.TUPLE([Types.STRING(), Types.INT()])) \
            .add_sink(self.test_sink)
        self.env.execute()
        expected_results = ['(A,47)', '(C,1000)', '(C,1000)']
        actual_results = self.test_sink.get_results(False)
        expected_results.sort()
        actual_results.sort()
        self.assertEqual(expected_results, actual_results)
Example #13
 def test_convert_row_to_dict(self):
     row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
     self.assertEqual(1, row.as_dict()['l'][0].a)
     self.assertEqual(1.0, row.as_dict()['d']['key'].c)
Example #14
 def test_collect_for_all_data_types(self):
     expected_result = [
         Row(1, None, 1, True, 32767, -2147483648, 1.23, 1.98932,
             bytearray(b'pyflink'), 'pyflink', datetime.date(2014, 9, 13),
             datetime.time(12, 0, 0, 123000),
             datetime.datetime(2018, 3, 11, 3, 0, 0, 123000),
             [Row(['[pyflink]']),
              Row(['[pyflink]']),
              Row(['[pyflink]'])], {
                  1: Row(['[flink]']),
                  2: Row(['[pyflink]'])
              }, decimal.Decimal('1000000000000000000.050000000000000000'),
             decimal.Decimal('1000000000000000000.059999999999999999'))
     ]
     source = self.t_env.from_elements(
         [(1, None, 1, True, 32767, -2147483648, 1.23, 1.98932,
           bytearray(b'pyflink'), 'pyflink', datetime.date(2014, 9, 13),
           datetime.time(hour=12, minute=0, second=0, microsecond=123000),
           datetime.datetime(2018, 3, 11, 3, 0, 0, 123000),
           [Row(['pyflink']),
            Row(['pyflink']),
            Row(['pyflink'])], {
                1: Row(['flink']),
                2: Row(['pyflink'])
            }, decimal.Decimal('1000000000000000000.05'),
           decimal.Decimal(
               '1000000000000000000.05999999999999999899999999999'))],
         DataTypes.ROW([
             DataTypes.FIELD("a", DataTypes.BIGINT()),
             DataTypes.FIELD("b", DataTypes.BIGINT()),
             DataTypes.FIELD("c", DataTypes.TINYINT()),
             DataTypes.FIELD("d", DataTypes.BOOLEAN()),
             DataTypes.FIELD("e", DataTypes.SMALLINT()),
             DataTypes.FIELD("f", DataTypes.INT()),
             DataTypes.FIELD("g", DataTypes.FLOAT()),
             DataTypes.FIELD("h", DataTypes.DOUBLE()),
             DataTypes.FIELD("i", DataTypes.BYTES()),
             DataTypes.FIELD("j", DataTypes.STRING()),
             DataTypes.FIELD("k", DataTypes.DATE()),
             DataTypes.FIELD("l", DataTypes.TIME()),
             DataTypes.FIELD("m", DataTypes.TIMESTAMP(3)),
             DataTypes.FIELD(
                 "n",
                 DataTypes.ARRAY(
                     DataTypes.ROW(
                         [DataTypes.FIELD('ss2', DataTypes.STRING())]))),
             DataTypes.FIELD(
                 "o",
                 DataTypes.MAP(
                     DataTypes.BIGINT(),
                     DataTypes.ROW(
                         [DataTypes.FIELD('ss', DataTypes.STRING())]))),
             DataTypes.FIELD("p", DataTypes.DECIMAL(38, 18)),
             DataTypes.FIELD("q", DataTypes.DECIMAL(38, 18))
         ]))
     table_result = source.execute()
     with table_result.collect() as result:
         collected_result = []
         for i in result:
             collected_result.append(i)
         self.assertEqual(expected_result, collected_result)
Example #15
    def test_verify_type_not_nullable(self):
        import array
        import datetime
        import decimal

        schema = DataTypes.ROW([
            DataTypes.FIELD('s', DataTypes.STRING(nullable=False)),
            DataTypes.FIELD('i', DataTypes.INT(True))
        ])

        class MyObj:
            def __init__(self, **kwargs):
                for k, v in kwargs.items():
                    setattr(self, k, v)

        # obj, data_type
        success_spec = [
            # String
            ("", DataTypes.STRING()),
            (u"", DataTypes.STRING()),

            # UDT
            (ExamplePoint(1.0, 2.0), ExamplePointUDT()),

            # Boolean
            (True, DataTypes.BOOLEAN()),

            # TinyInt
            (-(2**7), DataTypes.TINYINT()),
            (2**7 - 1, DataTypes.TINYINT()),

            # SmallInt
            (-(2**15), DataTypes.SMALLINT()),
            (2**15 - 1, DataTypes.SMALLINT()),

            # Int
            (-(2**31), DataTypes.INT()),
            (2**31 - 1, DataTypes.INT()),

            # BigInt
            (2**64, DataTypes.BIGINT()),

            # Float & Double
            (1.0, DataTypes.FLOAT()),
            (1.0, DataTypes.DOUBLE()),

            # Decimal
            (decimal.Decimal("1.0"), DataTypes.DECIMAL(10, 0)),

            # Binary
            (bytearray([1]), DataTypes.BINARY(1)),

            # Date/Time/Timestamp
            (datetime.date(2000, 1, 2), DataTypes.DATE()),
            (datetime.datetime(2000, 1, 2, 3, 4), DataTypes.DATE()),
            (datetime.time(1, 1, 2), DataTypes.TIME()),
            (datetime.datetime(2000, 1, 2, 3, 4), DataTypes.TIMESTAMP()),

            # Array
            ([], DataTypes.ARRAY(DataTypes.INT())),
            (["1", None], DataTypes.ARRAY(DataTypes.STRING(nullable=True))),
            ([1, 2], DataTypes.ARRAY(DataTypes.INT())),
            ((1, 2), DataTypes.ARRAY(DataTypes.INT())),
            (array.array('h', [1, 2]), DataTypes.ARRAY(DataTypes.INT())),

            # Map
            ({}, DataTypes.MAP(DataTypes.STRING(), DataTypes.INT())),
            ({
                "a": 1
            }, DataTypes.MAP(DataTypes.STRING(), DataTypes.INT())),
            ({
                "a": None
            },
             DataTypes.MAP(DataTypes.STRING(nullable=False),
                           DataTypes.INT(True))),

            # Struct
            ({
                "s": "a",
                "i": 1
            }, schema),
            ({
                "s": "a",
                "i": None
            }, schema),
            ({
                "s": "a"
            }, schema),
            ({
                "s": "a",
                "f": 1.0
            }, schema),
            (Row(s="a", i=1), schema),
            (Row(s="a", i=None), schema),
            (Row(s="a", i=1, f=1.0), schema),
            (["a", 1], schema),
            (["a", None], schema),
            (("a", 1), schema),
            (MyObj(s="a", i=1), schema),
            (MyObj(s="a", i=None), schema),
            (MyObj(s="a"), schema),
        ]

        # obj, data_type, exception class
        failure_spec = [
            # Char/VarChar (match anything but None)
            (None, DataTypes.VARCHAR(1), ValueError),
            (None, DataTypes.CHAR(1), ValueError),

            # VarChar (length exceeds maximum length)
            ("abc", DataTypes.VARCHAR(1), ValueError),
            # Char (length exceeds declared length)
            ("abc", DataTypes.CHAR(1), ValueError),

            # UDT
            (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError),

            # Boolean
            (1, DataTypes.BOOLEAN(), TypeError),
            ("True", DataTypes.BOOLEAN(), TypeError),
            ([1], DataTypes.BOOLEAN(), TypeError),

            # TinyInt
            (-(2**7) - 1, DataTypes.TINYINT(), ValueError),
            (2**7, DataTypes.TINYINT(), ValueError),
            ("1", DataTypes.TINYINT(), TypeError),
            (1.0, DataTypes.TINYINT(), TypeError),

            # SmallInt
            (-(2**15) - 1, DataTypes.SMALLINT(), ValueError),
            (2**15, DataTypes.SMALLINT(), ValueError),

            # Int
            (-(2**31) - 1, DataTypes.INT(), ValueError),
            (2**31, DataTypes.INT(), ValueError),

            # Float & Double
            (1, DataTypes.FLOAT(), TypeError),
            (1, DataTypes.DOUBLE(), TypeError),

            # Decimal
            (1.0, DataTypes.DECIMAL(10, 0), TypeError),
            (1, DataTypes.DECIMAL(10, 0), TypeError),
            ("1.0", DataTypes.DECIMAL(10, 0), TypeError),

            # Binary
            (1, DataTypes.BINARY(1), TypeError),
            # VarBinary (length exceeds maximum length)
            (bytearray([1, 2]), DataTypes.VARBINARY(1), ValueError),
            # Binary (length exceeds declared length)
            (bytearray([1, 2]), DataTypes.BINARY(1), ValueError),

            # Date/Time/Timestamp
            ("2000-01-02", DataTypes.DATE(), TypeError),
            ("10:01:02", DataTypes.TIME(), TypeError),
            (946811040, DataTypes.TIMESTAMP(), TypeError),

            # Array
            (["1", None], DataTypes.ARRAY(DataTypes.VARCHAR(1,
                                                            nullable=False)),
             ValueError),
            ([1, "2"], DataTypes.ARRAY(DataTypes.INT()), TypeError),

            # Map
            ({
                "a": 1
            }, DataTypes.MAP(DataTypes.INT(), DataTypes.INT()), TypeError),
            ({
                "a": "1"
            }, DataTypes.MAP(DataTypes.VARCHAR(1),
                             DataTypes.INT()), TypeError),
            ({
                "a": None
            }, DataTypes.MAP(DataTypes.VARCHAR(1),
                             DataTypes.INT(False)), ValueError),

            # Struct
            ({
                "s": "a",
                "i": "1"
            }, schema, TypeError),
            (Row(s="a"), schema, ValueError),  # Row can't have missing field
            (Row(s="a", i="1"), schema, TypeError),
            (["a"], schema, ValueError),
            (["a", "1"], schema, TypeError),
            (MyObj(s="a", i="1"), schema, TypeError),
            (MyObj(s=None, i="1"), schema, ValueError),
        ]

        # Check success cases
        for obj, data_type in success_spec:
            try:
                _create_type_verifier(data_type.not_null())(obj)
            except (TypeError, ValueError):
                self.fail("verify_type(%s, %s, nullable=False)" %
                          (obj, data_type))

        # Check failure cases
        for obj, data_type, exp in failure_spec:
            msg = "verify_type(%s, %s, nullable=False) == %s" % (
                obj, data_type, exp)
            with self.assertRaises(exp, msg=msg):
                _create_type_verifier(data_type.not_null())(obj)
Example #16
 def test_invalid_create_row(self):
     row_class = Row("c1", "c2")
     self.assertRaises(ValueError, lambda: row_class(1, 2, 3))
Example #17
 def test_empty_row(self):
     row = Row()
     self.assertEqual(len(row), 0)
Example #18
    def test_array_types(self):
        # This test needs to make sure that the selected Scala type is at least
        # as large as Python's type. This is necessary because Python's array
        # types depend on the C implementation on the machine, so there is no
        # machine-independent correspondence between Python's array types and
        # Scala types.
        # See: https://docs.python.org/2/library/array.html

        def assert_collect_success(typecode, value, element_type):
            self.assertEqual(
                element_type,
                str(_infer_type(array.array(typecode, [value])).element_type))

        # supported string types
        #
        # String types in Python's array are "u" for Py_UNICODE and "c" for char.
        # "u" will be removed in Python 4, and "c" is not supported in Python 3.
        supported_string_types = []
        if sys.version_info[0] < 4:
            supported_string_types += ['u']
            # test unicode
            assert_collect_success('u', u'a', 'CHAR')

        # supported float and double
        #
        # Test max, min, and precision for float and double, assuming IEEE 754
        # floating-point format.
        supported_fractional_types = ['f', 'd']
        assert_collect_success('f', ctypes.c_float(1e+38).value, 'FLOAT')
        assert_collect_success('f', ctypes.c_float(1e-38).value, 'FLOAT')
        assert_collect_success('f', ctypes.c_float(1.123456).value, 'FLOAT')
        assert_collect_success('d', sys.float_info.max, 'DOUBLE')
        assert_collect_success('d', sys.float_info.min, 'DOUBLE')
        assert_collect_success('d', sys.float_info.epsilon, 'DOUBLE')

        def get_int_data_type(size):
            if size <= 8:
                return "TINYINT"
            if size <= 16:
                return "SMALLINT"
            if size <= 32:
                return "INT"
            if size <= 64:
                return "BIGINT"

        # supported signed int types
        #
        # The size of C types varies across implementations, so we need to make
        # sure that there is no overflow error on the platform running this test.
        supported_signed_int_types = list(
            set(_array_signed_int_typecode_ctype_mappings.keys()).intersection(
                set(_array_type_mappings.keys())))
        for t in supported_signed_int_types:
            ctype = _array_signed_int_typecode_ctype_mappings[t]
            max_val = 2**(ctypes.sizeof(ctype) * 8 - 1)
            assert_collect_success(t, max_val - 1,
                                   get_int_data_type(ctypes.sizeof(ctype) * 8))
            assert_collect_success(t, -max_val,
                                   get_int_data_type(ctypes.sizeof(ctype) * 8))

        # supported unsigned int types
        #
        # The JVM does not have unsigned types, so an n-bit unsigned C type
        # needs a signed type wider than n bits to hold its maximum value.
        supported_unsigned_int_types = list(
            set(_array_unsigned_int_typecode_ctype_mappings.keys()).
            intersection(set(_array_type_mappings.keys())))
        for t in supported_unsigned_int_types:
            ctype = _array_unsigned_int_typecode_ctype_mappings[t]
            max_val = 2**(ctypes.sizeof(ctype) * 8 - 1)
            assert_collect_success(
                t, max_val, get_int_data_type(ctypes.sizeof(ctype) * 8 + 1))

        # all supported types
        #
        # Make sure the types tested above:
        # 1. are all supported types
        # 2. cover all supported types
        supported_types = (supported_string_types +
                           supported_fractional_types +
                           supported_signed_int_types +
                           supported_unsigned_int_types)
        self.assertEqual(set(supported_types),
                         set(_array_type_mappings.keys()))

        # all unsupported types
        #
        # The keys of _array_type_mappings form the complete list of supported
        # typecodes; any typecode not in _array_type_mappings is unsupported.
        all_types = set(array.typecodes)
        unsupported_types = all_types - set(supported_types)
        # test unsupported types
        for t in unsupported_types:
            with self.assertRaises(TypeError):
                _infer_schema_from_data([Row(myarray=array.array(t))])
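To make the extra signed bit concrete, a short sketch (assuming a platform where C unsigned short is 16 bits, and that _infer_type is importable from pyflink.table.types, as the tests above suggest):

import array
import ctypes

from pyflink.table.types import _infer_type

# 'H' (c_ushort) holds up to 2**16 - 1 = 65535, which overflows a signed
# 16-bit SMALLINT (max 32767), so inference widens the element type to INT.
assert ctypes.sizeof(ctypes.c_ushort) * 8 == 16
inferred = _infer_type(array.array('H', [2**16 - 1]))
print(str(inferred.element_type))  # INT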
Example #19
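 # Reads one encoded row from the input stream and unpacks its field values
 # into a Row.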
 def decode_from_stream(self, in_stream, nested):
     return Row(*self._decode_one_row_from_stream(in_stream, nested))