def input_output_table():
    stream_env = StreamExecutionEnvironment.get_execution_environment()
    table_env = StreamTableEnvironment.create(stream_env)
    statement_set = table_env.create_statement_set()
    work_num = 2
    ps_num = 1
    python_file = os.getcwd() + "/../../src/test/python/input_output.py"
    prop = {}
    func = "map_func"
    env_path = None
    prop[MLCONSTANTS.ENCODING_CLASS] = \
        "org.flinkextended.flink.ml.operator.coding.RowCSVCoding"
    prop[MLCONSTANTS.DECODING_CLASS] = \
        "org.flinkextended.flink.ml.operator.coding.RowCSVCoding"
    # One coding type per source column (a..e).
    input_sb = ",".join(["INT_32", "INT_64", "FLOAT_32", "FLOAT_64", "STRING"])
    prop["sys:csv_encode_types"] = input_sb
    prop["sys:csv_decode_types"] = input_sb
    prop[MLCONSTANTS.PYTHON_VERSION] = "3.7"
    source_file = os.getcwd() + "/../../src/test/resources/input.csv"
    sink_file = os.getcwd() + "/../../src/test/resources/output.csv"
    table_source = CsvTableSource(
        source_file,
        ["a", "b", "c", "d", "e"],
        [DataTypes.INT(), DataTypes.BIGINT(), DataTypes.FLOAT(),
         DataTypes.DOUBLE(), DataTypes.STRING()])
    table_env.register_table_source("source", table_source)
    input_tb = table_env.from_path("source")
    output_schema = TableSchema(
        ["a", "b", "c", "d", "e"],
        [DataTypes.INT(), DataTypes.BIGINT(), DataTypes.FLOAT(),
         DataTypes.DOUBLE(), DataTypes.STRING()])
    sink = CsvTableSink(
        ["a", "b", "c", "d", "e"],
        [DataTypes.INT(), DataTypes.BIGINT(), DataTypes.FLOAT(),
         DataTypes.DOUBLE(), DataTypes.STRING()],
        sink_file,
        write_mode=WriteMode.OVERWRITE)
    table_env.register_table_sink("table_row_sink", sink)
    tf_config = TFConfig(work_num, ps_num, prop, python_file, func, env_path)
    output_table = train(stream_env, table_env, statement_set, input_tb,
                         tf_config, output_schema)
    # output_table = inference(stream_env, table_env, statement_set, input_tb,
    #                          tf_config, output_schema)
    statement_set.add_insert("table_row_sink", output_table)
    job_client = statement_set.execute().get_job_client()
    if job_client is not None:
        job_client.get_job_execution_result(user_class_loader=None).result()
def test_fields(self):
    fields = collections.OrderedDict([
        ("int_field", DataTypes.INT()),
        ("long_field", DataTypes.BIGINT()),
        ("string_field", DataTypes.STRING()),
        ("timestamp_field", DataTypes.TIMESTAMP(3)),
        ("time_field", DataTypes.TIME()),
        ("date_field", DataTypes.DATE()),
        ("double_field", DataTypes.DOUBLE()),
        ("float_field", DataTypes.FLOAT()),
        ("byte_field", DataTypes.TINYINT()),
        ("short_field", DataTypes.SMALLINT()),
        ("boolean_field", DataTypes.BOOLEAN())
    ])
    schema = Schema().fields(fields)
    properties = schema.to_properties()
    expected = {
        'schema.0.name': 'int_field',
        'schema.0.data-type': 'INT',
        'schema.1.name': 'long_field',
        'schema.1.data-type': 'BIGINT',
        'schema.2.name': 'string_field',
        'schema.2.data-type': 'VARCHAR(2147483647)',
        'schema.3.name': 'timestamp_field',
        'schema.3.data-type': 'TIMESTAMP(3)',
        'schema.4.name': 'time_field',
        'schema.4.data-type': 'TIME(0)',
        'schema.5.name': 'date_field',
        'schema.5.data-type': 'DATE',
        'schema.6.name': 'double_field',
        'schema.6.data-type': 'DOUBLE',
        'schema.7.name': 'float_field',
        'schema.7.data-type': 'FLOAT',
        'schema.8.name': 'byte_field',
        'schema.8.data-type': 'TINYINT',
        'schema.9.name': 'short_field',
        'schema.9.data-type': 'SMALLINT',
        'schema.10.name': 'boolean_field',
        'schema.10.data-type': 'BOOLEAN'
    }
    self.assertEqual(expected, properties)

    if sys.version_info[:2] <= (3, 5):
        # Plain dicts do not preserve insertion order before Python 3.6,
        # so fields() must reject them with a TypeError there.
        fields = {
            "int_field": DataTypes.INT(),
            "long_field": DataTypes.BIGINT(),
            "string_field": DataTypes.STRING(),
            "timestamp_field": DataTypes.TIMESTAMP(3),
            "time_field": DataTypes.TIME(),
            "date_field": DataTypes.DATE(),
            "double_field": DataTypes.DOUBLE(),
            "float_field": DataTypes.FLOAT(),
            "byte_field": DataTypes.TINYINT(),
            "short_field": DataTypes.SMALLINT(),
            "boolean_field": DataTypes.BOOLEAN()
        }
        self.assertRaises(TypeError, Schema().fields, fields)
@classmethod
def setUpClass(cls):
    super(PandasConversionTestBase, cls).setUpClass()
    cls.data = [
        (1, 1, 1, 1, True, 1.1, 1.2, 'hello', bytearray(b"aaa"),
         decimal.Decimal('1000000000000000000.01'),
         datetime.date(2014, 9, 13),
         datetime.time(hour=1, minute=0, second=1),
         datetime.datetime(1970, 1, 1, 0, 0, 0, 123000),
         ['hello', '中文'],
         Row(a=1, b='hello',
             c=datetime.datetime(1970, 1, 1, 0, 0, 0, 123000),
             d=[1, 2])),
        (1, 2, 2, 2, False, 2.1, 2.2, 'world', bytearray(b"bbb"),
         decimal.Decimal('1000000000000000000.02'),
         datetime.date(2014, 9, 13),
         datetime.time(hour=1, minute=0, second=1),
         datetime.datetime(1970, 1, 1, 0, 0, 0, 123000),
         ['hello', '中文'],
         Row(a=1, b='hello',
             c=datetime.datetime(1970, 1, 1, 0, 0, 0, 123000),
             d=[1, 2]))]
    cls.data_type = DataTypes.ROW(
        [DataTypes.FIELD("f1", DataTypes.TINYINT()),
         DataTypes.FIELD("f2", DataTypes.SMALLINT()),
         DataTypes.FIELD("f3", DataTypes.INT()),
         DataTypes.FIELD("f4", DataTypes.BIGINT()),
         DataTypes.FIELD("f5", DataTypes.BOOLEAN()),
         DataTypes.FIELD("f6", DataTypes.FLOAT()),
         DataTypes.FIELD("f7", DataTypes.DOUBLE()),
         DataTypes.FIELD("f8", DataTypes.STRING()),
         DataTypes.FIELD("f9", DataTypes.BYTES()),
         DataTypes.FIELD("f10", DataTypes.DECIMAL(38, 18)),
         DataTypes.FIELD("f11", DataTypes.DATE()),
         DataTypes.FIELD("f12", DataTypes.TIME()),
         DataTypes.FIELD("f13", DataTypes.TIMESTAMP(3)),
         DataTypes.FIELD("f14", DataTypes.ARRAY(DataTypes.STRING())),
         DataTypes.FIELD("f15", DataTypes.ROW(
             [DataTypes.FIELD("a", DataTypes.INT()),
              DataTypes.FIELD("b", DataTypes.STRING()),
              DataTypes.FIELD("c", DataTypes.TIMESTAMP(3)),
              DataTypes.FIELD("d", DataTypes.ARRAY(DataTypes.INT()))]))],
        False)
    cls.pdf = cls.create_pandas_data_frame()
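# create_pandas_data_frame() is defined elsewhere in the test base; below is a
# minimal sketch (an assumption, not the project's actual helper) of what it
# could look like: zip cls.data into a pandas DataFrame using the field names
# declared in cls.data_type.
import pandas as pd

@classmethod
def create_pandas_data_frame(cls):
    # RowType.field_names() returns ['f1', ..., 'f15'] in declaration order.
    return pd.DataFrame.from_records(cls.data,
                                     columns=cls.data_type.field_names())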
def test_csv_primitive_column(self):
    schema = CsvSchema.builder() \
        .add_number_column('tinyint', DataTypes.TINYINT()) \
        .add_number_column('smallint', DataTypes.SMALLINT()) \
        .add_number_column('int', DataTypes.INT()) \
        .add_number_column('bigint', DataTypes.BIGINT()) \
        .add_number_column('float', DataTypes.FLOAT()) \
        .add_number_column('double', DataTypes.DOUBLE()) \
        .add_number_column('decimal', DataTypes.DECIMAL(2, 0)) \
        .add_boolean_column('boolean') \
        .add_string_column('string') \
        .build()
    with open(self.csv_file_name, 'w') as f:
        f.write('127,')
        f.write('-32767,')
        f.write('2147483647,')
        f.write('-9223372036854775808,')
        f.write('3e38,')
        f.write('2e-308,')
        f.write('1.5,')
        f.write('true,')
        f.write('string\n')
    self._build_csv_job(schema)
    self.env.execute('test_csv_primitive_column')
    row = self.test_sink.get_results(True, False)[0]
    self.assertEqual(row['tinyint'], 127)
    self.assertEqual(row['smallint'], -32767)
    self.assertEqual(row['int'], 2147483647)
    self.assertEqual(row['bigint'], -9223372036854775808)
    self.assertAlmostEqual(row['float'], 3e38, delta=1e31)
    self.assertAlmostEqual(row['double'], 2e-308, delta=2e-301)
    self.assertAlmostEqual(row['decimal'], 2)
    self.assertEqual(row['boolean'], True)
    self.assertEqual(row['string'], 'string')
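# _build_csv_job is defined elsewhere in this test class; a hedged sketch of
# what such a helper might look like using the FileSource/CsvReaderFormat API
# (available in PyFlink 1.16+; the helper name and exact wiring here are
# assumptions, not the project's actual implementation):
from pyflink.common import WatermarkStrategy
from pyflink.datastream.connectors.file_system import FileSource
from pyflink.datastream.formats.csv import CsvReaderFormat

def _build_csv_job_sketch(env, schema, csv_file_name):
    # Read the CSV file with the schema built above and print each row.
    source = FileSource.for_record_stream_format(
        CsvReaderFormat.for_schema(schema), csv_file_name).build()
    ds = env.from_source(source, WatermarkStrategy.no_watermarks(), 'csv-source')
    ds.print()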
def _create_orc_basic_row_and_data() -> Tuple[RowType, RowTypeInfo, List[Row]]:
    row_type = DataTypes.ROW([
        DataTypes.FIELD('char', DataTypes.CHAR(10)),
        DataTypes.FIELD('varchar', DataTypes.VARCHAR(10)),
        DataTypes.FIELD('bytes', DataTypes.BYTES()),
        DataTypes.FIELD('boolean', DataTypes.BOOLEAN()),
        DataTypes.FIELD('decimal', DataTypes.DECIMAL(2, 0)),
        DataTypes.FIELD('int', DataTypes.INT()),
        DataTypes.FIELD('bigint', DataTypes.BIGINT()),
        DataTypes.FIELD('double', DataTypes.DOUBLE()),
        DataTypes.FIELD('date', DataTypes.DATE().bridged_to('java.sql.Date')),
        DataTypes.FIELD('timestamp',
                        DataTypes.TIMESTAMP(3).bridged_to('java.sql.Timestamp')),
    ])
    row_type_info = Types.ROW_NAMED(
        ['char', 'varchar', 'bytes', 'boolean', 'decimal', 'int', 'bigint',
         'double', 'date', 'timestamp'],
        [Types.STRING(), Types.STRING(), Types.PRIMITIVE_ARRAY(Types.BYTE()),
         Types.BOOLEAN(), Types.BIG_DEC(), Types.INT(), Types.LONG(),
         Types.DOUBLE(), Types.SQL_DATE(), Types.SQL_TIMESTAMP()])
    data = [Row(
        char='char',
        varchar='varchar',
        bytes=b'varbinary',
        boolean=True,
        decimal=Decimal(1.5),
        int=2147483647,
        bigint=-9223372036854775808,
        double=2e-308,
        date=date(1970, 1, 1),
        timestamp=datetime(1970, 1, 2, 3, 4, 5, 600000),
    )]
    return row_type, row_type_info, data
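# A hedged write-side sketch showing how the triple returned above could feed
# a bulk ORC sink. Assumes PyFlink >= 1.16, where FileSink and OrcBulkWriters
# live under these module paths; the function name is illustrative only.
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.datastream.connectors.file_system import FileSink
from pyflink.datastream.formats.orc import OrcBulkWriters

def _write_orc_example(output_dir: str):
    env = StreamExecutionEnvironment.get_execution_environment()
    row_type, row_type_info, data = _create_orc_basic_row_and_data()
    sink = FileSink.for_bulk_format(
        output_dir, OrcBulkWriters.for_row_type(row_type)).build()
    env.from_collection(data, type_info=row_type_info).sink_to(sink)
    env.execute('write_orc_example')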
def inputOutputTable():
    stream_env = StreamExecutionEnvironment.get_execution_environment()
    table_env = StreamTableEnvironment.create(stream_env)
    work_num = 2
    ps_num = 1
    python_file = os.getcwd() + "/../../src/test/python/input_output.py"
    # Named 'prop' rather than 'property' so the builtin is not shadowed.
    prop = {}
    func = "map_func"
    env_path = None
    zk_conn = None
    zk_base_path = None
    prop[MLCONSTANTS.ENCODING_CLASS] = \
        "com.alibaba.flink.ml.operator.coding.RowCSVCoding"
    prop[MLCONSTANTS.DECODING_CLASS] = \
        "com.alibaba.flink.ml.operator.coding.RowCSVCoding"
    # One coding type per source column (a..e).
    input_sb = ",".join(["INT_32", "INT_64", "FLOAT_32", "FLOAT_64", "STRING"])
    prop["SYS:csv_encode_types"] = input_sb
    prop["SYS:csv_decode_types"] = input_sb
    source_file = os.getcwd() + "/../../src/test/resources/input.csv"
    table_source = CsvTableSource(
        source_file,
        ["a", "b", "c", "d", "e"],
        [DataTypes.INT(), DataTypes.INT(), DataTypes.FLOAT(),
         DataTypes.DOUBLE(), DataTypes.STRING()])
    table_env.register_table_source("source", table_source)
    input_tb = table_env.scan("source")
    output_schema = TableSchema(
        ["a", "b", "c", "d", "e"],
        [DataTypes.INT(), DataTypes.INT(), DataTypes.FLOAT(),
         DataTypes.DOUBLE(), DataTypes.STRING()])
    train(work_num, ps_num, python_file, func, prop, env_path, zk_conn,
          zk_base_path, stream_env, table_env, input_tb, output_schema)
def _create_parquet_basic_row_and_data() -> Tuple[RowType, RowTypeInfo, List[Row]]:
    row_type = DataTypes.ROW([
        DataTypes.FIELD('char', DataTypes.CHAR(10)),
        DataTypes.FIELD('varchar', DataTypes.VARCHAR(10)),
        DataTypes.FIELD('binary', DataTypes.BINARY(10)),
        DataTypes.FIELD('varbinary', DataTypes.VARBINARY(10)),
        DataTypes.FIELD('boolean', DataTypes.BOOLEAN()),
        DataTypes.FIELD('decimal', DataTypes.DECIMAL(2, 0)),
        DataTypes.FIELD('int', DataTypes.INT()),
        DataTypes.FIELD('bigint', DataTypes.BIGINT()),
        DataTypes.FIELD('double', DataTypes.DOUBLE()),
        DataTypes.FIELD('date', DataTypes.DATE().bridged_to('java.sql.Date')),
        DataTypes.FIELD('time', DataTypes.TIME().bridged_to('java.sql.Time')),
        DataTypes.FIELD('timestamp',
                        DataTypes.TIMESTAMP(3).bridged_to('java.sql.Timestamp')),
        DataTypes.FIELD('timestamp_ltz', DataTypes.TIMESTAMP_LTZ(3)),
    ])
    row_type_info = Types.ROW_NAMED(
        ['char', 'varchar', 'binary', 'varbinary', 'boolean', 'decimal', 'int',
         'bigint', 'double', 'date', 'time', 'timestamp', 'timestamp_ltz'],
        [Types.STRING(), Types.STRING(), Types.PRIMITIVE_ARRAY(Types.BYTE()),
         Types.PRIMITIVE_ARRAY(Types.BYTE()), Types.BOOLEAN(), Types.BIG_DEC(),
         Types.INT(), Types.LONG(), Types.DOUBLE(), Types.SQL_DATE(),
         Types.SQL_TIME(), Types.SQL_TIMESTAMP(), Types.INSTANT()])
    datetime_ltz = datetime.datetime(1970, 2, 3, 4, 5, 6, 700000,
                                     tzinfo=pytz.timezone('UTC'))
    # Epoch milliseconds for the aware datetime; the timegm(localtime(0)) term
    # compensates for the local-time-zone shift applied when bridging to Instant.
    timestamp_ltz = Instant.of_epoch_milli(
        (calendar.timegm(datetime_ltz.utctimetuple())
         + calendar.timegm(time.localtime(0))) * 1000
        + datetime_ltz.microsecond // 1000)
    data = [Row(
        char='char',
        varchar='varchar',
        binary=b'binary',
        varbinary=b'varbinary',
        boolean=True,
        decimal=Decimal(1.5),
        int=2147483647,
        bigint=-9223372036854775808,
        double=2e-308,
        date=datetime.date(1970, 1, 1),
        time=datetime.time(1, 1, 1),
        timestamp=datetime.datetime(1970, 1, 2, 3, 4, 5, 600000),
        timestamp_ltz=timestamp_ltz,
    )]
    return row_type, row_type_info, data
def test_basic_type(self):
    test_types = [DataTypes.STRING(), DataTypes.BOOLEAN(), DataTypes.BYTES(),
                  DataTypes.TINYINT(), DataTypes.SMALLINT(), DataTypes.INT(),
                  DataTypes.BIGINT(), DataTypes.FLOAT(), DataTypes.DOUBLE(),
                  DataTypes.DATE(), DataTypes.TIME(), DataTypes.TIMESTAMP(3)]
    java_types = [_to_java_type(item) for item in test_types]
    converted_python_types = [_from_java_type(item) for item in java_types]
    self.assertEqual(test_types, converted_python_types)
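# A hedged extension of the round trip above to composite types, assuming
# _to_java_type/_from_java_type preserve nested ROW/ARRAY/MAP structure the
# same way they do for the basic types:
composite = DataTypes.ROW([
    DataTypes.FIELD('id', DataTypes.BIGINT()),
    DataTypes.FIELD('tags', DataTypes.ARRAY(DataTypes.STRING())),
    DataTypes.FIELD('props', DataTypes.MAP(DataTypes.STRING(), DataTypes.INT()))])
assert _from_java_type(_to_java_type(composite)) == composite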
def test_field(self):
    schema = Schema()
    schema = schema \
        .field("int_field", DataTypes.INT()) \
        .field("long_field", DataTypes.BIGINT()) \
        .field("string_field", DataTypes.STRING()) \
        .field("timestamp_field", DataTypes.TIMESTAMP()) \
        .field("time_field", DataTypes.TIME()) \
        .field("date_field", DataTypes.DATE()) \
        .field("double_field", DataTypes.DOUBLE()) \
        .field("float_field", DataTypes.FLOAT()) \
        .field("byte_field", DataTypes.TINYINT()) \
        .field("short_field", DataTypes.SMALLINT()) \
        .field("boolean_field", DataTypes.BOOLEAN())
    properties = schema.to_properties()
    # Legacy descriptor properties serialize types under 'schema.N.type';
    # the newer variant (see the next test) writes 'schema.N.data-type'.
    expected = {
        'schema.0.name': 'int_field',
        'schema.0.type': 'INT',
        'schema.1.name': 'long_field',
        'schema.1.type': 'BIGINT',
        'schema.2.name': 'string_field',
        'schema.2.type': 'VARCHAR',
        'schema.3.name': 'timestamp_field',
        'schema.3.type': 'TIMESTAMP',
        'schema.4.name': 'time_field',
        'schema.4.type': 'TIME',
        'schema.5.name': 'date_field',
        'schema.5.type': 'DATE',
        'schema.6.name': 'double_field',
        'schema.6.type': 'DOUBLE',
        'schema.7.name': 'float_field',
        'schema.7.type': 'FLOAT',
        'schema.8.name': 'byte_field',
        'schema.8.type': 'TINYINT',
        'schema.9.name': 'short_field',
        'schema.9.type': 'SMALLINT',
        'schema.10.name': 'boolean_field',
        'schema.10.type': 'BOOLEAN'
    }
    self.assertEqual(expected, properties)
def test_field(self):
    schema = Schema() \
        .field("int_field", DataTypes.INT()) \
        .field("long_field", DataTypes.BIGINT()) \
        .field("string_field", DataTypes.STRING()) \
        .field("timestamp_field", DataTypes.TIMESTAMP(3)) \
        .field("time_field", DataTypes.TIME()) \
        .field("date_field", DataTypes.DATE()) \
        .field("double_field", DataTypes.DOUBLE()) \
        .field("float_field", DataTypes.FLOAT()) \
        .field("byte_field", DataTypes.TINYINT()) \
        .field("short_field", DataTypes.SMALLINT()) \
        .field("boolean_field", DataTypes.BOOLEAN())
    properties = schema.to_properties()
    expected = {
        'schema.0.name': 'int_field',
        'schema.0.data-type': 'INT',
        'schema.1.name': 'long_field',
        'schema.1.data-type': 'BIGINT',
        'schema.2.name': 'string_field',
        'schema.2.data-type': 'VARCHAR(2147483647)',
        'schema.3.name': 'timestamp_field',
        'schema.3.data-type': 'TIMESTAMP(3)',
        'schema.4.name': 'time_field',
        'schema.4.data-type': 'TIME(0)',
        'schema.5.name': 'date_field',
        'schema.5.data-type': 'DATE',
        'schema.6.name': 'double_field',
        'schema.6.data-type': 'DOUBLE',
        'schema.7.name': 'float_field',
        'schema.7.data-type': 'FLOAT',
        'schema.8.name': 'byte_field',
        'schema.8.data-type': 'TINYINT',
        'schema.9.name': 'short_field',
        'schema.9.data-type': 'SMALLINT',
        'schema.10.name': 'boolean_field',
        'schema.10.data-type': 'BOOLEAN'
    }
    self.assertEqual(expected, properties)
def test_parquet_columnar_basic(self):
    parquet_file_name = tempfile.mktemp(suffix='.parquet', dir=self.tempdir)
    schema, records = _create_basic_avro_schema_and_records()
    FileSourceParquetAvroFormatTests._create_parquet_avro_file(
        parquet_file_name, schema, records)
    row_type = DataTypes.ROW([
        # DataTypes.NULL cannot be serialized, so the 'null' column is
        # declared as STRING.
        DataTypes.FIELD('null', DataTypes.STRING()),
        DataTypes.FIELD('boolean', DataTypes.BOOLEAN()),
        DataTypes.FIELD('int', DataTypes.INT()),
        DataTypes.FIELD('long', DataTypes.BIGINT()),
        DataTypes.FIELD('float', DataTypes.FLOAT()),
        DataTypes.FIELD('double', DataTypes.DOUBLE()),
        DataTypes.FIELD('string', DataTypes.STRING()),
        DataTypes.FIELD('unknown', DataTypes.STRING())
    ])
    self._build_parquet_columnar_job(row_type, parquet_file_name)
    self.env.execute('test_parquet_columnar_basic')
    results = self.test_sink.get_results(True, False)
    _check_basic_avro_schema_results(self, results)
    self.assertIsNone(results[0]['unknown'])
    self.assertIsNone(results[1]['unknown'])
@classmethod
def sql_type(cls):
    # The UDT is stored as ARRAY<DOUBLE NOT NULL>; the positional False
    # argument to DOUBLE() sets nullable=False.
    return DataTypes.ARRAY(DataTypes.DOUBLE(False))
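# A hedged usage sketch: the ARRAY<DOUBLE NOT NULL> sql_type above can be
# checked with the same _create_type_verifier used in the verifier tests.
# Element nullability is enforced per element.
verifier = _create_type_verifier(DataTypes.ARRAY(DataTypes.DOUBLE(False)))
verifier([1.0, 2.0])    # accepted
# verifier([1.0, None]) # would raise ValueError: None element in a NOT NULL array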
def test_verify_type_not_nullable(self):
    import array
    import datetime
    import decimal

    schema = DataTypes.ROW([
        DataTypes.FIELD('s', DataTypes.STRING(nullable=False)),
        DataTypes.FIELD('i', DataTypes.INT(True))])

    class MyObj:
        def __init__(self, **kwargs):
            for k, v in kwargs.items():
                setattr(self, k, v)

    # obj, data_type
    success_spec = [
        # String
        ("", DataTypes.STRING()),
        (u"", DataTypes.STRING()),
        # UDT
        (ExamplePoint(1.0, 2.0), ExamplePointUDT()),
        # Boolean
        (True, DataTypes.BOOLEAN()),
        # TinyInt
        (-(2**7), DataTypes.TINYINT()),
        (2**7 - 1, DataTypes.TINYINT()),
        # SmallInt
        (-(2**15), DataTypes.SMALLINT()),
        (2**15 - 1, DataTypes.SMALLINT()),
        # Int
        (-(2**31), DataTypes.INT()),
        (2**31 - 1, DataTypes.INT()),
        # BigInt
        (2**64, DataTypes.BIGINT()),
        # Float & Double
        (1.0, DataTypes.FLOAT()),
        (1.0, DataTypes.DOUBLE()),
        # Decimal
        (decimal.Decimal("1.0"), DataTypes.DECIMAL(10, 0)),
        # Binary
        (bytearray([1]), DataTypes.BINARY(1)),
        # Date/Time/Timestamp
        (datetime.date(2000, 1, 2), DataTypes.DATE()),
        (datetime.datetime(2000, 1, 2, 3, 4), DataTypes.DATE()),
        (datetime.time(1, 1, 2), DataTypes.TIME()),
        (datetime.datetime(2000, 1, 2, 3, 4), DataTypes.TIMESTAMP()),
        # Array
        ([], DataTypes.ARRAY(DataTypes.INT())),
        (["1", None], DataTypes.ARRAY(DataTypes.STRING(nullable=True))),
        ([1, 2], DataTypes.ARRAY(DataTypes.INT())),
        ((1, 2), DataTypes.ARRAY(DataTypes.INT())),
        (array.array('h', [1, 2]), DataTypes.ARRAY(DataTypes.INT())),
        # Map
        ({}, DataTypes.MAP(DataTypes.STRING(), DataTypes.INT())),
        ({"a": 1}, DataTypes.MAP(DataTypes.STRING(), DataTypes.INT())),
        ({"a": None},
         DataTypes.MAP(DataTypes.STRING(nullable=False), DataTypes.INT(True))),
        # Struct
        ({"s": "a", "i": 1}, schema),
        ({"s": "a", "i": None}, schema),
        ({"s": "a"}, schema),
        ({"s": "a", "f": 1.0}, schema),
        (Row(s="a", i=1), schema),
        (Row(s="a", i=None), schema),
        (Row(s="a", i=1, f=1.0), schema),
        (["a", 1], schema),
        (["a", None], schema),
        (("a", 1), schema),
        (MyObj(s="a", i=1), schema),
        (MyObj(s="a", i=None), schema),
        (MyObj(s="a"), schema),
    ]

    # obj, data_type, exception class
    failure_spec = [
        # Char/VarChar (match anything but None)
        (None, DataTypes.VARCHAR(1), ValueError),
        (None, DataTypes.CHAR(1), ValueError),
        # VarChar (length exceeds maximum length)
        ("abc", DataTypes.VARCHAR(1), ValueError),
        # Char (length exceeds length)
        ("abc", DataTypes.CHAR(1), ValueError),
        # UDT
        (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError),
        # Boolean
        (1, DataTypes.BOOLEAN(), TypeError),
        ("True", DataTypes.BOOLEAN(), TypeError),
        ([1], DataTypes.BOOLEAN(), TypeError),
        # TinyInt
        (-(2**7) - 1, DataTypes.TINYINT(), ValueError),
        (2**7, DataTypes.TINYINT(), ValueError),
        ("1", DataTypes.TINYINT(), TypeError),
        (1.0, DataTypes.TINYINT(), TypeError),
        # SmallInt
        (-(2**15) - 1, DataTypes.SMALLINT(), ValueError),
        (2**15, DataTypes.SMALLINT(), ValueError),
        # Int
        (-(2**31) - 1, DataTypes.INT(), ValueError),
        (2**31, DataTypes.INT(), ValueError),
        # Float & Double
        (1, DataTypes.FLOAT(), TypeError),
        (1, DataTypes.DOUBLE(), TypeError),
        # Decimal
        (1.0, DataTypes.DECIMAL(10, 0), TypeError),
        (1, DataTypes.DECIMAL(10, 0), TypeError),
        ("1.0", DataTypes.DECIMAL(10, 0), TypeError),
        # Binary
        (1, DataTypes.BINARY(1), TypeError),
        # VarBinary (length exceeds maximum length)
        (bytearray([1, 2]), DataTypes.VARBINARY(1), ValueError),
        # Binary (length exceeds length)
        (bytearray([1, 2]), DataTypes.BINARY(1), ValueError),
        # Date/Time/Timestamp
        ("2000-01-02", DataTypes.DATE(), TypeError),
        ("10:01:02", DataTypes.TIME(), TypeError),
        (946811040, DataTypes.TIMESTAMP(), TypeError),
        # Array
        (["1", None],
         DataTypes.ARRAY(DataTypes.VARCHAR(1, nullable=False)), ValueError),
        ([1, "2"], DataTypes.ARRAY(DataTypes.INT()), TypeError),
        # Map
        ({"a": 1}, DataTypes.MAP(DataTypes.INT(), DataTypes.INT()), TypeError),
        ({"a": "1"}, DataTypes.MAP(DataTypes.VARCHAR(1), DataTypes.INT()),
         TypeError),
        ({"a": None},
         DataTypes.MAP(DataTypes.VARCHAR(1), DataTypes.INT(False)), ValueError),
        # Struct
        ({"s": "a", "i": "1"}, schema, TypeError),
        (Row(s="a"), schema, ValueError),  # Row can't have missing field
        (Row(s="a", i="1"), schema, TypeError),
        (["a"], schema, ValueError),
        (["a", "1"], schema, TypeError),
        (MyObj(s="a", i="1"), schema, TypeError),
        (MyObj(s=None, i="1"), schema, ValueError),
    ]

    # Check success cases
    for obj, data_type in success_spec:
        try:
            _create_type_verifier(data_type.not_null())(obj)
        except (TypeError, ValueError):
            self.fail("verify_type(%s, %s, nullable=False)" % (obj, data_type))

    # Check failure cases
    for obj, data_type, exp in failure_spec:
        msg = "verify_type(%s, %s, nullable=False) == %s" % (
            obj, data_type, exp)
        with self.assertRaises(exp, msg=msg):
            _create_type_verifier(data_type.not_null())(obj)
def test_merge_type(self):
    self.assertEqual(_merge_type(DataTypes.BIGINT(), DataTypes.NULL()),
                     DataTypes.BIGINT())
    self.assertEqual(_merge_type(DataTypes.NULL(), DataTypes.BIGINT()),
                     DataTypes.BIGINT())
    self.assertEqual(_merge_type(DataTypes.BIGINT(), DataTypes.BIGINT()),
                     DataTypes.BIGINT())

    self.assertEqual(
        _merge_type(DataTypes.ARRAY(DataTypes.BIGINT()),
                    DataTypes.ARRAY(DataTypes.BIGINT())),
        DataTypes.ARRAY(DataTypes.BIGINT()))
    with self.assertRaises(TypeError):
        _merge_type(DataTypes.ARRAY(DataTypes.BIGINT()),
                    DataTypes.ARRAY(DataTypes.DOUBLE()))

    self.assertEqual(
        _merge_type(DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT()),
                    DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT())),
        DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT()))
    with self.assertRaises(TypeError):
        _merge_type(DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT()),
                    DataTypes.MAP(DataTypes.DOUBLE(), DataTypes.BIGINT()))
    with self.assertRaises(TypeError):
        _merge_type(DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT()),
                    DataTypes.MAP(DataTypes.STRING(), DataTypes.DOUBLE()))

    self.assertEqual(
        _merge_type(
            DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.BIGINT()),
                           DataTypes.FIELD('f2', DataTypes.STRING())]),
            DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.BIGINT()),
                           DataTypes.FIELD('f2', DataTypes.STRING())])),
        DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.BIGINT()),
                       DataTypes.FIELD('f2', DataTypes.STRING())]))
    with self.assertRaises(TypeError):
        _merge_type(
            DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.BIGINT()),
                           DataTypes.FIELD('f2', DataTypes.STRING())]),
            DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.DOUBLE()),
                           DataTypes.FIELD('f2', DataTypes.STRING())]))

    self.assertEqual(
        _merge_type(
            DataTypes.ROW([DataTypes.FIELD(
                'f1', DataTypes.ROW([DataTypes.FIELD('f2', DataTypes.BIGINT())]))]),
            DataTypes.ROW([DataTypes.FIELD(
                'f1', DataTypes.ROW([DataTypes.FIELD('f2', DataTypes.BIGINT())]))])),
        DataTypes.ROW([DataTypes.FIELD(
            'f1', DataTypes.ROW([DataTypes.FIELD('f2', DataTypes.BIGINT())]))]))
    with self.assertRaises(TypeError):
        _merge_type(
            DataTypes.ROW([DataTypes.FIELD(
                'f1', DataTypes.ROW([DataTypes.FIELD('f2', DataTypes.BIGINT())]))]),
            DataTypes.ROW([DataTypes.FIELD(
                'f1', DataTypes.ROW([DataTypes.FIELD('f2', DataTypes.STRING())]))]))

    self.assertEqual(
        _merge_type(
            DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.ARRAY(DataTypes.BIGINT())),
                           DataTypes.FIELD('f2', DataTypes.STRING())]),
            DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.ARRAY(DataTypes.BIGINT())),
                           DataTypes.FIELD('f2', DataTypes.STRING())])),
        DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.ARRAY(DataTypes.BIGINT())),
                       DataTypes.FIELD('f2', DataTypes.STRING())]))
    with self.assertRaises(TypeError):
        _merge_type(
            DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.ARRAY(DataTypes.BIGINT())),
                           DataTypes.FIELD('f2', DataTypes.STRING())]),
            DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.ARRAY(DataTypes.DOUBLE())),
                           DataTypes.FIELD('f2', DataTypes.STRING())]))

    self.assertEqual(
        _merge_type(
            DataTypes.ROW([
                DataTypes.FIELD('f1', DataTypes.MAP(DataTypes.STRING(),
                                                    DataTypes.BIGINT())),
                DataTypes.FIELD('f2', DataTypes.STRING())]),
            DataTypes.ROW([
                DataTypes.FIELD('f1', DataTypes.MAP(DataTypes.STRING(),
                                                    DataTypes.BIGINT())),
                DataTypes.FIELD('f2', DataTypes.STRING())])),
        DataTypes.ROW([
            DataTypes.FIELD('f1', DataTypes.MAP(DataTypes.STRING(),
                                                DataTypes.BIGINT())),
            DataTypes.FIELD('f2', DataTypes.STRING())]))
    with self.assertRaises(TypeError):
        _merge_type(
            DataTypes.ROW([
                DataTypes.FIELD('f1', DataTypes.MAP(DataTypes.STRING(),
                                                    DataTypes.BIGINT())),
                DataTypes.FIELD('f2', DataTypes.STRING())]),
            DataTypes.ROW([
                DataTypes.FIELD('f1', DataTypes.MAP(DataTypes.STRING(),
                                                    DataTypes.DOUBLE())),
                DataTypes.FIELD('f2', DataTypes.STRING())]))

    self.assertEqual(
        _merge_type(
            DataTypes.ROW([DataTypes.FIELD(
                'f1', DataTypes.ARRAY(DataTypes.MAP(DataTypes.STRING(),
                                                    DataTypes.BIGINT())))]),
            DataTypes.ROW([DataTypes.FIELD(
                'f1', DataTypes.ARRAY(DataTypes.MAP(DataTypes.STRING(),
                                                    DataTypes.BIGINT())))])),
        DataTypes.ROW([DataTypes.FIELD(
            'f1', DataTypes.ARRAY(DataTypes.MAP(DataTypes.STRING(),
                                                DataTypes.BIGINT())))]))
    with self.assertRaises(TypeError):
        _merge_type(
            DataTypes.ROW([DataTypes.FIELD(
                'f1', DataTypes.ARRAY(DataTypes.MAP(DataTypes.STRING(),
                                                    DataTypes.BIGINT())))]),
            DataTypes.ROW([DataTypes.FIELD(
                'f1', DataTypes.ARRAY(DataTypes.MAP(DataTypes.DOUBLE(),
                                                    DataTypes.BIGINT())))]))
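# A hedged sketch of how _merge_type is typically applied: fold it over the
# row types inferred from individual records to get one common schema; a
# NULL field is promoted to the concrete type seen in another row, matching
# the BIGINT/NULL cases asserted above.
from functools import reduce

inferred = [
    DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.NULL()),
                   DataTypes.FIELD('f2', DataTypes.STRING())]),
    DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.BIGINT()),
                   DataTypes.FIELD('f2', DataTypes.STRING())]),
]
merged = reduce(_merge_type, inferred)
assert merged == DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.BIGINT()),
                                DataTypes.FIELD('f2', DataTypes.STRING())])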