def input_output_table():
    """Run the input_output.py TF script through the Flink ML `train` entry.

    Reads input.csv through a registered CSV source, routes the rows
    through the TensorFlow job configured by TFConfig, and writes the
    resulting table back out to output.csv via a registered CSV sink.
    Blocks until the submitted Flink job finishes.
    """
    stream_env = StreamExecutionEnvironment.get_execution_environment()
    table_env = StreamTableEnvironment.create(stream_env)
    statement_set = table_env.create_statement_set()

    # Encode/decode both directions with the same CSV type signature.
    csv_types = ",".join(["INT_32", "INT_64", "FLOAT_32", "FLOAT_64", "STRING"])
    prop = {
        MLCONSTANTS.ENCODING_CLASS:
            "org.flinkextended.flink.ml.operator.coding.RowCSVCoding",
        MLCONSTANTS.DECODING_CLASS:
            "org.flinkextended.flink.ml.operator.coding.RowCSVCoding",
        "sys:csv_encode_types": csv_types,
        "sys:csv_decode_types": csv_types,
        MLCONSTANTS.PYTHON_VERSION: "3.7",
    }

    # Source, sink and output schema all share the same 5-column layout.
    field_names = ["a", "b", "c", "d", "e"]
    field_types = [
        DataTypes.INT(),
        DataTypes.BIGINT(),
        DataTypes.FLOAT(),
        DataTypes.DOUBLE(),
        DataTypes.STRING(),
    ]

    source_file = os.getcwd() + "/../../src/test/resources/input.csv"
    sink_file = os.getcwd() + "/../../src/test/resources/output.csv"

    table_env.register_table_source(
        "source", CsvTableSource(source_file, field_names, field_types))
    input_tb = table_env.from_path("source")

    output_schema = TableSchema(field_names, field_types)

    table_env.register_table_sink(
        "table_row_sink",
        CsvTableSink(field_names, field_types, sink_file,
                     write_mode=WriteMode.OVERWRITE))

    python_file = os.getcwd() + "/../../src/test/python/input_output.py"
    tf_config = TFConfig(2,                 # worker count
                         1,                 # parameter-server count
                         prop,
                         python_file,
                         "map_func",        # entry function in python_file
                         None)              # env_path: no virtualenv archive

    # `inference(...)` takes the same arguments and could be swapped in here.
    output_table = train(stream_env, table_env, statement_set, input_tb,
                         tf_config, output_schema)

    statement_set.add_insert("table_row_sink", output_table)
    job_client = statement_set.execute().get_job_client()
    if job_client is not None:
        job_client.get_job_execution_result(
            user_class_loader=None).result()
# Ejemplo n.º 2 — 0  (scrape artifact: example separator and vote count)
 def prepare_csv_source(cls, path, data, data_types, fields):
     """Write *data* to *path* as a CSV file and return a source over it.

     Args:
         path: destination file; any pre-existing file is removed first.
         data: iterable of rows. A list/tuple item becomes one
             comma-joined line; any other item is written via str()
             on a line of its own.
         data_types: Flink DataTypes describing the source columns.
         fields: column names for the source.

     Returns:
         A CsvTableSource reading the freshly written file.
     """
     # Start from a clean file so stale rows never leak into a test run.
     if os.path.isfile(path):
         os.remove(path)
     lines = []
     for item in data:
         if isinstance(item, (list, tuple)):
             lines.append(",".join(str(element) for element in item))
         else:
             lines.append(str(item))
     # Join once instead of quadratic `csv_data += ...` concatenation;
     # the `with` block already closes the file (explicit close() removed).
     with open(path, 'w') as f:
         f.write("".join(line + "\n" for line in lines))
     return CsvTableSource(path, fields, data_types)
 def inputOutputTable():
     """Run the input_output.py TF script through the Flink ML `train` API.

     Registers input.csv as a table source, configures CSV row coding for
     both directions, and hands the table to `train` with a 2-worker /
     1-PS layout. Output columns mirror the input schema.
     """
     stream_env = StreamExecutionEnvironment.get_execution_environment()
     table_env = StreamTableEnvironment.create(stream_env)
     work_num = 2
     ps_num = 1
     python_file = os.getcwd() + "/../../src/test/python/input_output.py"
     # Renamed from `property`, which shadowed the `property` builtin.
     props = {}
     func = "map_func"
     env_path = None
     zk_conn = None
     zk_base_path = None
     props[
         MLCONSTANTS.
         ENCODING_CLASS] = "com.alibaba.flink.ml.operator.coding.RowCSVCoding"
     props[
         MLCONSTANTS.
         DECODING_CLASS] = "com.alibaba.flink.ml.operator.coding.RowCSVCoding"
     # Same CSV type signature for encoding and decoding.
     inputSb = "INT_32" + "," + "INT_64" + "," + "FLOAT_32" + "," + "FLOAT_64" + "," + "STRING"
     props["SYS:csv_encode_types"] = inputSb
     props["SYS:csv_decode_types"] = inputSb
     source_file = os.getcwd() + "/../../src/test/resources/input.csv"
     table_source = CsvTableSource(source_file, ["a", "b", "c", "d", "e"], [
         DataTypes.INT(),
         DataTypes.INT(),
         DataTypes.FLOAT(),
         DataTypes.DOUBLE(),
         DataTypes.STRING()
     ])
     table_env.register_table_source("source", table_source)
     input_tb = table_env.scan("source")
     output_schema = TableSchema(["a", "b", "c", "d", "e"], [
         DataTypes.INT(),
         DataTypes.INT(),
         DataTypes.FLOAT(),
         DataTypes.DOUBLE(),
         DataTypes.STRING()
     ])
     train(work_num, ps_num, python_file, func, props, env_path, zk_conn,
           zk_base_path, stream_env, table_env, input_tb, output_schema)