def input_output_table():
    """Build and run a Flink job that pipes a CSV table through ``train``.

    The source is ``src/test/resources/input.csv`` and the sink is
    ``src/test/resources/output.csv`` (opened with ``WriteMode.OVERWRITE``);
    both paths, as well as the Python script path, are built relative to the
    current working directory.  The table returned by ``train`` is inserted
    into the registered sink through a statement set.  If executing the
    statement set yields a job client, ``.result()`` is called on its job
    execution result.
    """
    stream_env = StreamExecutionEnvironment.get_execution_environment()
    table_env = StreamTableEnvironment.create(stream_env)
    statement_set = table_env.create_statement_set()

    column_names = ["a", "b", "c", "d", "e"]

    def column_types():
        # Fresh type objects on every call: source, schema and sink each
        # receive their own list, exactly as if written out three times.
        return [
            DataTypes.INT(),
            DataTypes.BIGINT(),
            DataTypes.FLOAT(),
            DataTypes.DOUBLE(),
            DataTypes.STRING(),
        ]

    # Job parameters handed to TFConfig.
    work_num = 2
    ps_num = 1
    func = "map_func"
    env_path = None
    cwd = os.getcwd()
    python_file = cwd + "/../../src/test/python/input_output.py"

    # The same CSV coding class and column type list are used for both
    # encoding and decoding.
    csv_types = ",".join(["INT_32", "INT_64", "FLOAT_32", "FLOAT_64", "STRING"])
    prop = {
        MLCONSTANTS.ENCODING_CLASS:
            "org.flinkextended.flink.ml.operator.coding.RowCSVCoding",
        MLCONSTANTS.DECODING_CLASS:
            "org.flinkextended.flink.ml.operator.coding.RowCSVCoding",
        "sys:csv_encode_types": csv_types,
        "sys:csv_decode_types": csv_types,
        MLCONSTANTS.PYTHON_VERSION: "3.7",
    }

    source_file = cwd + "/../../src/test/resources/input.csv"
    sink_file = cwd + "/../../src/test/resources/output.csv"

    table_env.register_table_source(
        "source",
        CsvTableSource(source_file, list(column_names), column_types()))
    input_tb = table_env.from_path("source")

    output_schema = TableSchema(list(column_names), column_types())

    table_env.register_table_sink(
        "table_row_sink",
        CsvTableSink(list(column_names), column_types(), sink_file,
                     write_mode=WriteMode.OVERWRITE))

    tf_config = TFConfig(work_num, ps_num, prop, python_file, func, env_path)
    # NOTE: the original snippet kept a commented-out alternative that calls
    # inference(...) with the same arguments instead of train(...).
    output_table = train(stream_env, table_env, statement_set, input_tb,
                         tf_config, output_schema)

    statement_set.add_insert("table_row_sink", output_table)
    job_client = statement_set.execute().get_job_client()
    if job_client is None:
        return
    job_client.get_job_execution_result(user_class_loader=None).result()
Пример #2
0
    def test_fields(self):
        """Check ``Schema.fields`` with an ordered mapping of eleven columns.

        The resulting properties must list every column as an indexed
        ``schema.N.name`` / ``schema.N.data-type`` pair.  On Python 3.5 and
        older, passing a plain ``dict`` is expected to raise ``TypeError``.
        """
        ordered_fields = collections.OrderedDict([
            ("int_field", DataTypes.INT()),
            ("long_field", DataTypes.BIGINT()),
            ("string_field", DataTypes.STRING()),
            ("timestamp_field", DataTypes.TIMESTAMP(3)),
            ("time_field", DataTypes.TIME()),
            ("date_field", DataTypes.DATE()),
            ("double_field", DataTypes.DOUBLE()),
            ("float_field", DataTypes.FLOAT()),
            ("byte_field", DataTypes.TINYINT()),
            ("short_field", DataTypes.SMALLINT()),
            ("boolean_field", DataTypes.BOOLEAN()),
        ])

        properties = Schema().fields(ordered_fields).to_properties()

        expected = {
            'schema.0.name': 'int_field', 'schema.0.data-type': 'INT',
            'schema.1.name': 'long_field', 'schema.1.data-type': 'BIGINT',
            'schema.2.name': 'string_field', 'schema.2.data-type': 'VARCHAR(2147483647)',
            'schema.3.name': 'timestamp_field', 'schema.3.data-type': 'TIMESTAMP(3)',
            'schema.4.name': 'time_field', 'schema.4.data-type': 'TIME(0)',
            'schema.5.name': 'date_field', 'schema.5.data-type': 'DATE',
            'schema.6.name': 'double_field', 'schema.6.data-type': 'DOUBLE',
            'schema.7.name': 'float_field', 'schema.7.data-type': 'FLOAT',
            'schema.8.name': 'byte_field', 'schema.8.data-type': 'TINYINT',
            'schema.9.name': 'short_field', 'schema.9.data-type': 'SMALLINT',
            'schema.10.name': 'boolean_field', 'schema.10.data-type': 'BOOLEAN',
        }
        self.assertEqual(expected, properties)

        if not sys.version_info[:2] <= (3, 5):
            return

        # Only reached on Python <= 3.5: a plain dict is expected to be
        # rejected by Schema.fields with a TypeError.
        plain_fields = {
            "int_field": DataTypes.INT(),
            "long_field": DataTypes.BIGINT(),
            "string_field": DataTypes.STRING(),
            "timestamp_field": DataTypes.TIMESTAMP(3),
            "time_field": DataTypes.TIME(),
            "date_field": DataTypes.DATE(),
            "double_field": DataTypes.DOUBLE(),
            "float_field": DataTypes.FLOAT(),
            "byte_field": DataTypes.TINYINT(),
            "short_field": DataTypes.SMALLINT(),
            "boolean_field": DataTypes.BOOLEAN(),
        }
        # Resolve the bound method outside the assertion so that only the
        # call itself is covered by assertRaises.
        fields_method = Schema().fields
        with self.assertRaises(TypeError):
            fields_method(plain_fields)
Пример #3
0
 def setUpClass(cls):
     """Prepare the shared fixtures: two records, their ROW type and ``cls.pdf``.

     ``cls.data`` holds two 15-element tuples matching the fields ``f1`` to
     ``f15`` of ``cls.data_type``; ``cls.pdf`` is whatever
     ``cls.create_pandas_data_frame()`` returns.
     """
     super(PandasConversionTestBase, cls).setUpClass()

     def nested_row():
         # Value for f15; built fresh for every record.
         return Row(a=1, b='hello', c=datetime.datetime(1970, 1, 1, 0, 0, 0, 123000),
                    d=[1, 2])

     def record(number, flag, single, double, text, raw, dec):
         # f1 is 1 in both records; f2..f4 all carry the same ``number``.
         return (1, number, number, number, flag, single, double, text, bytearray(raw),
                 decimal.Decimal(dec), datetime.date(2014, 9, 13),
                 datetime.time(hour=1, minute=0, second=1),
                 datetime.datetime(1970, 1, 1, 0, 0, 0, 123000), ['hello', '中文'],
                 nested_row())

     cls.data = [
         record(1, True, 1.1, 1.2, 'hello', b"aaa", '1000000000000000000.01'),
         record(2, False, 2.1, 2.2, 'world', b"bbb", '1000000000000000000.02'),
     ]

     nested_type = DataTypes.ROW(
         [DataTypes.FIELD("a", DataTypes.INT()),
          DataTypes.FIELD("b", DataTypes.STRING()),
          DataTypes.FIELD("c", DataTypes.TIMESTAMP(3)),
          DataTypes.FIELD("d", DataTypes.ARRAY(DataTypes.INT()))])
     field_types = [
         ("f1", DataTypes.TINYINT()),
         ("f2", DataTypes.SMALLINT()),
         ("f3", DataTypes.INT()),
         ("f4", DataTypes.BIGINT()),
         ("f5", DataTypes.BOOLEAN()),
         ("f6", DataTypes.FLOAT()),
         ("f7", DataTypes.DOUBLE()),
         ("f8", DataTypes.STRING()),
         ("f9", DataTypes.BYTES()),
         ("f10", DataTypes.DECIMAL(38, 18)),
         ("f11", DataTypes.DATE()),
         ("f12", DataTypes.TIME()),
         ("f13", DataTypes.TIMESTAMP(3)),
         ("f14", DataTypes.ARRAY(DataTypes.STRING())),
         ("f15", nested_type),
     ]
     # The second positional argument of ROW is passed as False, as in the
     # original (presumably ``nullable`` - confirm against DataTypes.ROW).
     cls.data_type = DataTypes.ROW(
         [DataTypes.FIELD(name, data_type) for name, data_type in field_types], False)
     cls.pdf = cls.create_pandas_data_frame()
Пример #4
0
 def test_csv_primitive_column(self):
     """Write one CSV line of primitive values, run the job and check each column."""
     numeric_columns = [
         ('tinyint', DataTypes.TINYINT()),
         ('smallint', DataTypes.SMALLINT()),
         ('int', DataTypes.INT()),
         ('bigint', DataTypes.BIGINT()),
         ('float', DataTypes.FLOAT()),
         ('double', DataTypes.DOUBLE()),
         ('decimal', DataTypes.DECIMAL(2, 0)),
     ]
     builder = CsvSchema.builder()
     for column_name, column_type in numeric_columns:
         builder = builder.add_number_column(column_name, column_type)
     schema = builder.add_boolean_column('boolean').add_string_column('string').build()

     # One record; the values appear in the same order as the schema columns.
     raw_values = [
         '127',
         '-32767',
         '2147483647',
         '-9223372036854775808',
         '3e38',
         '2e-308',
         '1.5',
         'true',
         'string',
     ]
     with open(self.csv_file_name, 'w') as csv_file:
         csv_file.write(','.join(raw_values) + '\n')

     self._build_csv_job(schema)
     self.env.execute('test_csv_primitive_column')

     first_row = self.test_sink.get_results(True, False)[0]
     self.assertEqual(first_row['tinyint'], 127)
     self.assertEqual(first_row['smallint'], -32767)
     self.assertEqual(first_row['int'], 2147483647)
     self.assertEqual(first_row['bigint'], -9223372036854775808)
     self.assertAlmostEqual(first_row['float'], 3e38, delta=1e31)
     self.assertAlmostEqual(first_row['double'], 2e-308, delta=2e-301)
     # '1.5' was written for the DECIMAL(2, 0) column; the test expects 2 back.
     self.assertAlmostEqual(first_row['decimal'], 2)
     self.assertEqual(first_row['boolean'], True)
     self.assertEqual(first_row['string'], 'string')
Пример #5
0
def _create_orc_basic_row_and_data() -> Tuple[RowType, RowTypeInfo, List[Row]]:
    """Return a ten-column row type, its matching type info and one sample row.

    The three return values describe the same columns: a ``DataTypes.ROW``,
    a ``Types.ROW_NAMED`` with identical field names, and a single-element
    list holding a ``Row`` with one value per column.
    """
    # (field name, table data type, type information)
    columns = [
        ('char', DataTypes.CHAR(10), Types.STRING()),
        ('varchar', DataTypes.VARCHAR(10), Types.STRING()),
        ('bytes', DataTypes.BYTES(), Types.PRIMITIVE_ARRAY(Types.BYTE())),
        ('boolean', DataTypes.BOOLEAN(), Types.BOOLEAN()),
        ('decimal', DataTypes.DECIMAL(2, 0), Types.BIG_DEC()),
        ('int', DataTypes.INT(), Types.INT()),
        ('bigint', DataTypes.BIGINT(), Types.LONG()),
        ('double', DataTypes.DOUBLE(), Types.DOUBLE()),
        ('date', DataTypes.DATE().bridged_to('java.sql.Date'), Types.SQL_DATE()),
        ('timestamp', DataTypes.TIMESTAMP(3).bridged_to('java.sql.Timestamp'),
         Types.SQL_TIMESTAMP()),
    ]

    row_type = DataTypes.ROW(
        [DataTypes.FIELD(name, data_type) for name, data_type, _ in columns])
    row_type_info = Types.ROW_NAMED(
        [name for name, _, _ in columns],
        [type_info for _, _, type_info in columns],
    )

    sample = Row(
        char='char',
        varchar='varchar',
        bytes=b'varbinary',
        boolean=True,
        decimal=Decimal(1.5),
        int=2147483647,
        bigint=-9223372036854775808,
        double=2e-308,
        date=date(1970, 1, 1),
        timestamp=datetime(1970, 1, 2, 3, 4, 5, 600000),
    )
    return row_type, row_type_info, [sample]
 def inputOutputTable():
     """Run ``train`` over the CSV test input using the RowCSVCoding codecs.

     Registers ``src/test/resources/input.csv`` (resolved relative to the
     current working directory) as the table source ``"source"``, scans it,
     and passes the resulting table to ``train`` together with the job
     settings and an output schema that has the same five columns.
     """
     stream_env = StreamExecutionEnvironment.get_execution_environment()
     table_env = StreamTableEnvironment.create(stream_env)

     # Job settings forwarded positionally to train().
     work_num = 2
     ps_num = 1
     func = "map_func"
     env_path = None
     zk_conn = None
     zk_base_path = None
     cwd = os.getcwd()
     python_file = cwd + "/../../src/test/python/input_output.py"

     # Named ``props`` rather than ``property`` so the builtin is not shadowed.
     csv_types = ",".join(["INT_32", "INT_64", "FLOAT_32", "FLOAT_64", "STRING"])
     props = {
         MLCONSTANTS.ENCODING_CLASS:
             "com.alibaba.flink.ml.operator.coding.RowCSVCoding",
         MLCONSTANTS.DECODING_CLASS:
             "com.alibaba.flink.ml.operator.coding.RowCSVCoding",
         "SYS:csv_encode_types": csv_types,
         "SYS:csv_decode_types": csv_types,
     }

     column_names = ["a", "b", "c", "d", "e"]

     def column_types():
         # Fresh type objects per call; the source and the output schema each
         # get their own list.
         # NOTE(review): column "b" is INT here although the CSV coding types
         # declare INT_64 for it - confirm whether BIGINT was intended.
         return [
             DataTypes.INT(),
             DataTypes.INT(),
             DataTypes.FLOAT(),
             DataTypes.DOUBLE(),
             DataTypes.STRING(),
         ]

     source_file = cwd + "/../../src/test/resources/input.csv"
     table_source = CsvTableSource(source_file, list(column_names), column_types())
     table_env.register_table_source("source", table_source)
     input_tb = table_env.scan("source")

     output_schema = TableSchema(list(column_names), column_types())
     train(work_num, ps_num, python_file, func, props, env_path, zk_conn,
           zk_base_path, stream_env, table_env, input_tb, output_schema)
Пример #7
0
def _create_parquet_basic_row_and_data() -> Tuple[RowType, RowTypeInfo, List[Row]]:
    """Return a thirteen-column row type, its matching type info and one sample row.

    The three return values describe the same columns: a ``DataTypes.ROW``,
    a ``Types.ROW_NAMED`` with identical field names, and a single-element
    list holding a ``Row`` with one value per column.
    """
    # (field name, table data type, type information)
    columns = [
        ('char', DataTypes.CHAR(10), Types.STRING()),
        ('varchar', DataTypes.VARCHAR(10), Types.STRING()),
        ('binary', DataTypes.BINARY(10), Types.PRIMITIVE_ARRAY(Types.BYTE())),
        ('varbinary', DataTypes.VARBINARY(10), Types.PRIMITIVE_ARRAY(Types.BYTE())),
        ('boolean', DataTypes.BOOLEAN(), Types.BOOLEAN()),
        ('decimal', DataTypes.DECIMAL(2, 0), Types.BIG_DEC()),
        ('int', DataTypes.INT(), Types.INT()),
        ('bigint', DataTypes.BIGINT(), Types.LONG()),
        ('double', DataTypes.DOUBLE(), Types.DOUBLE()),
        ('date', DataTypes.DATE().bridged_to('java.sql.Date'), Types.SQL_DATE()),
        ('time', DataTypes.TIME().bridged_to('java.sql.Time'), Types.SQL_TIME()),
        ('timestamp', DataTypes.TIMESTAMP(3).bridged_to('java.sql.Timestamp'),
         Types.SQL_TIMESTAMP()),
        ('timestamp_ltz', DataTypes.TIMESTAMP_LTZ(3), Types.INSTANT()),
    ]

    row_type = DataTypes.ROW(
        [DataTypes.FIELD(name, data_type) for name, data_type, _ in columns])
    row_type_info = Types.ROW_NAMED(
        [name for name, _, _ in columns],
        [type_info for _, _, type_info in columns],
    )

    # Epoch milliseconds for the timestamp_ltz value: the UTC seconds of the
    # datetime plus calendar.timegm(time.localtime(0)), i.e. the local
    # wall-clock time at the epoch reinterpreted as UTC seconds.
    datetime_ltz = datetime.datetime(1970, 2, 3, 4, 5, 6, 700000, tzinfo=pytz.timezone('UTC'))
    utc_seconds = calendar.timegm(datetime_ltz.utctimetuple())
    local_epoch_seconds = calendar.timegm(time.localtime(0))
    epoch_millis = (utc_seconds + local_epoch_seconds) * 1000 + datetime_ltz.microsecond // 1000
    timestamp_ltz = Instant.of_epoch_milli(epoch_millis)

    sample = Row(
        char='char',
        varchar='varchar',
        binary=b'binary',
        varbinary=b'varbinary',
        boolean=True,
        decimal=Decimal(1.5),
        int=2147483647,
        bigint=-9223372036854775808,
        double=2e-308,
        date=datetime.date(1970, 1, 1),
        time=datetime.time(1, 1, 1),
        timestamp=datetime.datetime(1970, 1, 2, 3, 4, 5, 600000),
        timestamp_ltz=timestamp_ltz
    )
    return row_type, row_type_info, [sample]
Пример #8
0
    def test_basic_type(self):
        """Round-trip basic data types through ``_to_java_type`` / ``_from_java_type``.

        Each type converted to its Java counterpart and back must compare
        equal to the original.
        """
        originals = [
            DataTypes.STRING(),
            DataTypes.BOOLEAN(),
            DataTypes.BYTES(),
            DataTypes.TINYINT(),
            DataTypes.SMALLINT(),
            DataTypes.INT(),
            DataTypes.BIGINT(),
            DataTypes.FLOAT(),
            DataTypes.DOUBLE(),
            DataTypes.DATE(),
            DataTypes.TIME(),
            DataTypes.TIMESTAMP(3),
        ]

        round_tripped = [
            _from_java_type(_to_java_type(data_type)) for data_type in originals
        ]

        self.assertEqual(originals, round_tripped)
Пример #9
0
    def test_field(self):
        """Add eleven columns one at a time and compare the resulting properties.

        The expected properties use ``schema.N.name`` / ``schema.N.type`` keys
        with unparameterised type names.
        """
        columns = [
            ("int_field", DataTypes.INT()),
            ("long_field", DataTypes.BIGINT()),
            ("string_field", DataTypes.STRING()),
            ("timestamp_field", DataTypes.TIMESTAMP()),
            ("time_field", DataTypes.TIME()),
            ("date_field", DataTypes.DATE()),
            ("double_field", DataTypes.DOUBLE()),
            ("float_field", DataTypes.FLOAT()),
            ("byte_field", DataTypes.TINYINT()),
            ("short_field", DataTypes.SMALLINT()),
            ("boolean_field", DataTypes.BOOLEAN()),
        ]

        schema = Schema()
        for field_name, field_type in columns:
            # Continue from the object returned by the previous field() call.
            schema = schema.field(field_name, field_type)

        properties = schema.to_properties()
        expected = {
            'schema.0.name': 'int_field', 'schema.0.type': 'INT',
            'schema.1.name': 'long_field', 'schema.1.type': 'BIGINT',
            'schema.2.name': 'string_field', 'schema.2.type': 'VARCHAR',
            'schema.3.name': 'timestamp_field', 'schema.3.type': 'TIMESTAMP',
            'schema.4.name': 'time_field', 'schema.4.type': 'TIME',
            'schema.5.name': 'date_field', 'schema.5.type': 'DATE',
            'schema.6.name': 'double_field', 'schema.6.type': 'DOUBLE',
            'schema.7.name': 'float_field', 'schema.7.type': 'FLOAT',
            'schema.8.name': 'byte_field', 'schema.8.type': 'TINYINT',
            'schema.9.name': 'short_field', 'schema.9.type': 'SMALLINT',
            'schema.10.name': 'boolean_field', 'schema.10.type': 'BOOLEAN',
        }
        assert properties == expected
Пример #10
0
    def test_field(self):
        """Add eleven columns one at a time and compare the resulting properties.

        The expected properties use ``schema.N.name`` / ``schema.N.data-type``
        keys with fully parameterised type strings.
        """
        columns = (
            ("int_field", DataTypes.INT()),
            ("long_field", DataTypes.BIGINT()),
            ("string_field", DataTypes.STRING()),
            ("timestamp_field", DataTypes.TIMESTAMP(3)),
            ("time_field", DataTypes.TIME()),
            ("date_field", DataTypes.DATE()),
            ("double_field", DataTypes.DOUBLE()),
            ("float_field", DataTypes.FLOAT()),
            ("byte_field", DataTypes.TINYINT()),
            ("short_field", DataTypes.SMALLINT()),
            ("boolean_field", DataTypes.BOOLEAN()),
        )

        schema = Schema()
        for field_name, field_type in columns:
            # Continue from the object returned by the previous field() call.
            schema = schema.field(field_name, field_type)

        expected = {
            'schema.0.name': 'int_field', 'schema.0.data-type': 'INT',
            'schema.1.name': 'long_field', 'schema.1.data-type': 'BIGINT',
            'schema.2.name': 'string_field', 'schema.2.data-type': 'VARCHAR(2147483647)',
            'schema.3.name': 'timestamp_field', 'schema.3.data-type': 'TIMESTAMP(3)',
            'schema.4.name': 'time_field', 'schema.4.data-type': 'TIME(0)',
            'schema.5.name': 'date_field', 'schema.5.data-type': 'DATE',
            'schema.6.name': 'double_field', 'schema.6.data-type': 'DOUBLE',
            'schema.7.name': 'float_field', 'schema.7.data-type': 'FLOAT',
            'schema.8.name': 'byte_field', 'schema.8.data-type': 'TINYINT',
            'schema.9.name': 'short_field', 'schema.9.data-type': 'SMALLINT',
            'schema.10.name': 'boolean_field', 'schema.10.data-type': 'BOOLEAN',
        }
        properties = schema.to_properties()
        self.assertEqual(expected, properties)
Пример #11
0
 def test_parquet_columnar_basic(self):
     """Write the basic Avro records to Parquet and read them back as columns.

     The row type carries an extra ``unknown`` column; it must come back as
     ``None`` in the first two results.
     """
     parquet_file_name = tempfile.mktemp(suffix='.parquet', dir=self.tempdir)
     schema, records = _create_basic_avro_schema_and_records()
     FileSourceParquetAvroFormatTests._create_parquet_avro_file(
         parquet_file_name, schema, records)

     field_types = [
         ('null', DataTypes.STRING()),  # DataTypes.NULL cannot be serialized
         ('boolean', DataTypes.BOOLEAN()),
         ('int', DataTypes.INT()),
         ('long', DataTypes.BIGINT()),
         ('float', DataTypes.FLOAT()),
         ('double', DataTypes.DOUBLE()),
         ('string', DataTypes.STRING()),
         ('unknown', DataTypes.STRING()),
     ]
     row_type = DataTypes.ROW(
         [DataTypes.FIELD(name, data_type) for name, data_type in field_types])

     self._build_parquet_columnar_job(row_type, parquet_file_name)
     self.env.execute('test_parquet_columnar_basic')

     results = self.test_sink.get_results(True, False)
     _check_basic_avro_schema_results(self, results)
     # Index explicitly so a missing record still raises IndexError.
     for index in (0, 1):
         self.assertIsNone(results[index]['unknown'])
Пример #12
0
 def sql_type(cls):
     """Return the SQL data type: an ARRAY whose elements are ``DOUBLE(False)``."""
     # The positional False is forwarded unchanged (presumably ``nullable`` -
     # confirm against DataTypes.DOUBLE).
     element_type = DataTypes.DOUBLE(False)
     return DataTypes.ARRAY(element_type)
Пример #13
0
    def test_verify_type_not_nullable(self):
        """Run the type verifier over accepted and rejected (value, type) pairs.

        Every data type is made non-nullable with ``not_null()`` before the
        verifier is created.  Entries in ``success_spec`` must verify without
        ``TypeError`` / ``ValueError``; entries in ``failure_spec`` must raise
        the exception class listed with them.
        """
        import array
        import datetime
        import decimal

        schema = DataTypes.ROW([
            DataTypes.FIELD('s', DataTypes.STRING(nullable=False)),
            DataTypes.FIELD('i', DataTypes.INT(True)),
        ])

        class MyObj:
            """Plain object exposing the given keyword arguments as attributes."""

            def __init__(self, **kwargs):
                for attr_name, attr_value in kwargs.items():
                    setattr(self, attr_name, attr_value)

        # (value, data_type) pairs that must be accepted.
        success_spec = [
            # String
            ("", DataTypes.STRING()),
            (u"", DataTypes.STRING()),

            # UDT
            (ExamplePoint(1.0, 2.0), ExamplePointUDT()),

            # Boolean
            (True, DataTypes.BOOLEAN()),

            # TinyInt
            (-(2**7), DataTypes.TINYINT()),
            (2**7 - 1, DataTypes.TINYINT()),

            # SmallInt
            (-(2**15), DataTypes.SMALLINT()),
            (2**15 - 1, DataTypes.SMALLINT()),

            # Int
            (-(2**31), DataTypes.INT()),
            (2**31 - 1, DataTypes.INT()),

            # BigInt
            (2**64, DataTypes.BIGINT()),

            # Float & Double
            (1.0, DataTypes.FLOAT()),
            (1.0, DataTypes.DOUBLE()),

            # Decimal
            (decimal.Decimal("1.0"), DataTypes.DECIMAL(10, 0)),

            # Binary
            (bytearray([1]), DataTypes.BINARY(1)),

            # Date/Time/Timestamp
            (datetime.date(2000, 1, 2), DataTypes.DATE()),
            (datetime.datetime(2000, 1, 2, 3, 4), DataTypes.DATE()),
            (datetime.time(1, 1, 2), DataTypes.TIME()),
            (datetime.datetime(2000, 1, 2, 3, 4), DataTypes.TIMESTAMP()),

            # Array
            ([], DataTypes.ARRAY(DataTypes.INT())),
            (["1", None], DataTypes.ARRAY(DataTypes.STRING(nullable=True))),
            ([1, 2], DataTypes.ARRAY(DataTypes.INT())),
            ((1, 2), DataTypes.ARRAY(DataTypes.INT())),
            (array.array('h', [1, 2]), DataTypes.ARRAY(DataTypes.INT())),

            # Map
            ({}, DataTypes.MAP(DataTypes.STRING(), DataTypes.INT())),
            ({"a": 1}, DataTypes.MAP(DataTypes.STRING(), DataTypes.INT())),
            ({"a": None},
             DataTypes.MAP(DataTypes.STRING(nullable=False), DataTypes.INT(True))),

            # Struct
            ({"s": "a", "i": 1}, schema),
            ({"s": "a", "i": None}, schema),
            ({"s": "a"}, schema),
            ({"s": "a", "f": 1.0}, schema),
            (Row(s="a", i=1), schema),
            (Row(s="a", i=None), schema),
            (Row(s="a", i=1, f=1.0), schema),
            (["a", 1], schema),
            (["a", None], schema),
            (("a", 1), schema),
            (MyObj(s="a", i=1), schema),
            (MyObj(s="a", i=None), schema),
            (MyObj(s="a"), schema),
        ]

        # (value, data_type, expected exception class) triples that must be rejected.
        failure_spec = [
            # Char/VarChar (match anything but None)
            (None, DataTypes.VARCHAR(1), ValueError),
            (None, DataTypes.CHAR(1), ValueError),

            # VarChar (length exceeds maximum length)
            ("abc", DataTypes.VARCHAR(1), ValueError),
            # Char (length exceeds length)
            ("abc", DataTypes.CHAR(1), ValueError),

            # UDT
            (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError),

            # Boolean
            (1, DataTypes.BOOLEAN(), TypeError),
            ("True", DataTypes.BOOLEAN(), TypeError),
            ([1], DataTypes.BOOLEAN(), TypeError),

            # TinyInt
            (-(2**7) - 1, DataTypes.TINYINT(), ValueError),
            (2**7, DataTypes.TINYINT(), ValueError),
            ("1", DataTypes.TINYINT(), TypeError),
            (1.0, DataTypes.TINYINT(), TypeError),

            # SmallInt
            (-(2**15) - 1, DataTypes.SMALLINT(), ValueError),
            (2**15, DataTypes.SMALLINT(), ValueError),

            # Int
            (-(2**31) - 1, DataTypes.INT(), ValueError),
            (2**31, DataTypes.INT(), ValueError),

            # Float & Double
            (1, DataTypes.FLOAT(), TypeError),
            (1, DataTypes.DOUBLE(), TypeError),

            # Decimal
            (1.0, DataTypes.DECIMAL(10, 0), TypeError),
            (1, DataTypes.DECIMAL(10, 0), TypeError),
            ("1.0", DataTypes.DECIMAL(10, 0), TypeError),

            # Binary
            (1, DataTypes.BINARY(1), TypeError),
            # VarBinary (length exceeds maximum length)
            (bytearray([1, 2]), DataTypes.VARBINARY(1), ValueError),
            # Binary (length exceeds length)
            (bytearray([1, 2]), DataTypes.BINARY(1), ValueError),

            # Date/Time/Timestamp
            ("2000-01-02", DataTypes.DATE(), TypeError),
            ("10:01:02", DataTypes.TIME(), TypeError),
            (946811040, DataTypes.TIMESTAMP(), TypeError),

            # Array
            (["1", None],
             DataTypes.ARRAY(DataTypes.VARCHAR(1, nullable=False)),
             ValueError),
            ([1, "2"], DataTypes.ARRAY(DataTypes.INT()), TypeError),

            # Map
            ({"a": 1}, DataTypes.MAP(DataTypes.INT(), DataTypes.INT()), TypeError),
            ({"a": "1"}, DataTypes.MAP(DataTypes.VARCHAR(1), DataTypes.INT()), TypeError),
            ({"a": None},
             DataTypes.MAP(DataTypes.VARCHAR(1), DataTypes.INT(False)),
             ValueError),

            # Struct
            ({"s": "a", "i": "1"}, schema, TypeError),
            (Row(s="a"), schema, ValueError),  # Row can't have missing field
            (Row(s="a", i="1"), schema, TypeError),
            (["a"], schema, ValueError),
            (["a", "1"], schema, TypeError),
            (MyObj(s="a", i="1"), schema, TypeError),
            (MyObj(s=None, i="1"), schema, ValueError),
        ]

        # Accepted values: creating the verifier and running it are both
        # inside the try, so a TypeError/ValueError from either fails the test.
        for value, data_type in success_spec:
            try:
                _create_type_verifier(data_type.not_null())(value)
            except (TypeError, ValueError):
                self.fail("verify_type({!s}, {!s}, nullable=False)".format(
                    value, data_type))

        # Rejected values: the listed exception class must be raised.
        for value, data_type, expected_exc in failure_spec:
            msg = "verify_type({!s}, {!s}, nullable=False) == {!s}".format(
                value, data_type, expected_exc)
            with self.assertRaises(expected_exc, msg=msg):
                _create_type_verifier(data_type.not_null())(value)
Пример #14
0
    def test_merge_type(self):
        """Check ``_merge_type`` on matching and conflicting type pairs.

        Identical types (and BIGINT merged with NULL in either order) must
        merge to that type; pairs that differ in an element, key, value or
        field type must raise ``TypeError``.  Atomic, ARRAY, MAP and ROW
        types are covered, including nested combinations.
        """
        # Short aliases for the type factories; each call builds a new object.
        bigint = DataTypes.BIGINT
        double = DataTypes.DOUBLE
        string = DataTypes.STRING
        null = DataTypes.NULL
        array_of = DataTypes.ARRAY
        map_of = DataTypes.MAP

        def row(*named_types):
            # Build a ROW from (field name, data type) pairs.
            return DataTypes.ROW(
                [DataTypes.FIELD(name, data_type) for name, data_type in named_types])

        # --- atomic types ---
        self.assertEqual(_merge_type(bigint(), null()), bigint())
        self.assertEqual(_merge_type(null(), bigint()), bigint())
        self.assertEqual(_merge_type(bigint(), bigint()), bigint())

        # --- arrays ---
        self.assertEqual(
            _merge_type(array_of(bigint()), array_of(bigint())),
            array_of(bigint()))
        with self.assertRaises(TypeError):
            _merge_type(array_of(bigint()), array_of(double()))

        # --- maps: same types, then a differing key type, then a differing value type ---
        self.assertEqual(
            _merge_type(map_of(string(), bigint()), map_of(string(), bigint())),
            map_of(string(), bigint()))
        with self.assertRaises(TypeError):
            _merge_type(map_of(string(), bigint()), map_of(double(), bigint()))
        with self.assertRaises(TypeError):
            _merge_type(map_of(string(), bigint()), map_of(string(), double()))

        # --- rows of atomic fields ---
        self.assertEqual(
            _merge_type(
                row(('f1', bigint()), ('f2', string())),
                row(('f1', bigint()), ('f2', string()))),
            row(('f1', bigint()), ('f2', string())))
        with self.assertRaises(TypeError):
            _merge_type(
                row(('f1', bigint()), ('f2', string())),
                row(('f1', double()), ('f2', string())))

        # --- row nested in a row ---
        self.assertEqual(
            _merge_type(
                row(('f1', row(('f2', bigint())))),
                row(('f1', row(('f2', bigint()))))),
            row(('f1', row(('f2', bigint())))))
        with self.assertRaises(TypeError):
            _merge_type(
                row(('f1', row(('f2', bigint())))),
                row(('f1', row(('f2', string())))))

        # --- array field in a row ---
        self.assertEqual(
            _merge_type(
                row(('f1', array_of(bigint())), ('f2', string())),
                row(('f1', array_of(bigint())), ('f2', string()))),
            row(('f1', array_of(bigint())), ('f2', string())))
        with self.assertRaises(TypeError):
            _merge_type(
                row(('f1', array_of(bigint())), ('f2', string())),
                row(('f1', array_of(double())), ('f2', string())))

        # --- map field in a row ---
        self.assertEqual(
            _merge_type(
                row(('f1', map_of(string(), bigint())), ('f2', string())),
                row(('f1', map_of(string(), bigint())), ('f2', string()))),
            row(('f1', map_of(string(), bigint())), ('f2', string())))
        with self.assertRaises(TypeError):
            _merge_type(
                row(('f1', map_of(string(), bigint())), ('f2', string())),
                row(('f1', map_of(string(), double())), ('f2', string())))

        # --- array of maps in a row ---
        self.assertEqual(
            _merge_type(
                row(('f1', array_of(map_of(string(), bigint())))),
                row(('f1', array_of(map_of(string(), bigint()))))),
            row(('f1', array_of(map_of(string(), bigint())))))
        with self.assertRaises(TypeError):
            _merge_type(
                row(('f1', array_of(map_of(string(), bigint())))),
                row(('f1', array_of(map_of(double(), bigint())))))