Exemplo n.º 1
0
 def setUp(self):
     """Build a batch-mode TableEnvironment configured for the remote-mode tests."""
     super(PyFlinkBatchTableTestCase, self).setUp()
     self.t_env = TableEnvironment.create(EnvironmentSettings.in_batch_mode())
     # Configure once through a single handle instead of re-fetching it per call.
     conf = self.t_env.get_config().get_configuration()
     conf.set_string("parallelism.default", "2")
     conf.set_string("python.fn-execution.bundle.size", "1")
     self.t_env._remote_mode = True
Exemplo n.º 2
0
def table_inference():
    """End-to-end inference example: read TFRecord test data into a table,
    run the MNIST inference function on TF workers, and log the results."""
    prep_data()
    zk_server = start_zk_server()
    generate_model()
    exec_env = StreamExecutionEnvironment.get_execution_environment()
    t_env = TableEnvironment.get_table_environment(exec_env)
    exec_env.set_parallelism(2)
    data_dir = "file://" + get_root_path() + "/examples/target/data/test/"
    record_paths = [data_dir + "0.tfrecords", data_dir + "1.tfrecords"]
    row_type = RowType([StringType(), IntegerType()],
                       ["image_raw", "label"])
    source = TFRTableSource(
        paths=record_paths,
        epochs=1,
        out_row_type=row_type,
        converters=[ScalarConverter.FIRST, ScalarConverter.ONE_HOT])
    source_table = source.register_table(table_env=t_env)
    # Output schema: original label alongside the model's prediction.
    schema_builder = TableSchema.Builder()
    schema_builder.column(name='label_org', data_type=IntegerType())
    schema_builder.column(name='predict_label', data_type=IntegerType())
    result = tensorflow_on_flink_table.inference(
        num_worker=2,
        func=mnist_table_inference.map_fun,
        properties=build_props('0'),
        stream_env=exec_env,
        table_env=t_env,
        input_table=source_table,
        output_schema=schema_builder.build())
    result.write_to_sink(LogTableStreamSink())
    t_env.generate_stream_graph()
    exec_env.execute()
    zk_server.stop()
Exemplo n.º 3
0
 def setUp(self):
     """Build a blink-planner batch TableEnvironment for the tests."""
     super(PyFlinkBlinkBatchTableTestCase, self).setUp()
     settings = (EnvironmentSettings.new_instance()
                 .in_batch_mode()
                 .use_blink_planner()
                 .build())
     self.t_env = TableEnvironment.create(settings)
     conf = self.t_env.get_config().get_configuration()
     conf.set_string("parallelism.default", "2")
     conf.set_string("python.fn-execution.bundle.size", "1")
Exemplo n.º 4
0
def table_with_input():
    """Training example: feed TFRecord training data to a distributed
    MNIST training job (1 worker, 1 PS) via the table API."""
    prep_data()
    zk_server = start_zk_server()
    exec_env = StreamExecutionEnvironment.get_execution_environment()
    t_env = TableEnvironment.get_table_environment(exec_env)
    exec_env.set_parallelism(2)
    data_dir = "file://" + get_root_path() + "/examples/target/data/train/"
    record_paths = [data_dir + "0.tfrecords", data_dir + "1.tfrecords"]
    row_type = RowType([StringType(), IntegerType()],
                       ["image_raw", "label"])
    source = TFRTableSource(
        paths=record_paths,
        epochs=1,
        out_row_type=row_type,
        converters=[ScalarConverter.FIRST, ScalarConverter.ONE_HOT])
    source_table = source.register_table(table_env=t_env)
    tensorflow_on_flink_table.train(num_worker=1,
                                    num_ps=1,
                                    func=mnist_dist_with_input.map_fun,
                                    properties=build_props(),
                                    stream_env=exec_env,
                                    table_env=t_env,
                                    input_table=source_table)
    t_env.generate_stream_graph()
    exec_env.execute()
    zk_server.stop()
Exemplo n.º 5
0
 def setUp(self):
     """Build a streaming TableEnvironment running Python UDFs in process mode."""
     super(PyFlinkStreamTableTestCase, self).setUp()
     settings = EnvironmentSettings.in_streaming_mode()
     self.t_env = TableEnvironment.create(settings)
     conf = self.t_env.get_config().get_configuration()
     conf.set_string("parallelism.default", "2")
     conf.set_string("python.fn-execution.bundle.size", "1")
     self.t_env._execution_mode = "process"
Exemplo n.º 6
0
def table_java_inference():
    """Inference example driven by the Java TFUtils path: extract features
    with a registered table function, then run inference configured purely
    through properties (no Python map function).

    Fix: the local previously named ``train_data_path`` actually points at
    the *test* data directory (``.../data/test/``) — renamed to
    ``test_data_path`` to stop it misleading readers (copy-paste from
    ``table_with_input``). Behavior is unchanged.
    """
    prep_data()
    testing_server = start_zk_server()
    generate_model()
    stream_env = StreamExecutionEnvironment.get_execution_environment()
    table_env = TableEnvironment.get_table_environment(stream_env)
    stream_env.set_parallelism(2)
    test_data_path = "file://" + get_root_path(
    ) + "/examples/target/data/test/"
    paths = [test_data_path + "0.tfrecords", test_data_path + "1.tfrecords"]
    src_row_type = RowType([StringType(), IntegerType()],
                           ["image_raw", "label"])
    table_src = TFRTableSource(
        paths=paths,
        epochs=1,
        out_row_type=src_row_type,
        converters=[ScalarConverter.FIRST, ScalarConverter.ONE_HOT])
    tfr_tbl_name = "tfr_input_table"
    table_env.register_table_source(tfr_tbl_name, table_src)
    # Register the Java extraction UDTF and apply it via a lateral join.
    ext_func_name = "tfr_extract"
    table_env.register_function(ext_func_name, java_inference_extract_func())
    out_cols = 'image,org_label'
    in_cols = ','.join(src_row_type.fields_names)
    extracted = table_env.sql_query(
        'select {} from {}, LATERAL TABLE({}({})) as T({})'.format(
            out_cols, tfr_tbl_name, ext_func_name, in_cols, out_cols))
    builder = TableSchema.Builder()
    builder.column(name='real_label',
                   data_type=LongType()).column(name='predicted_label',
                                                data_type=LongType())
    output_schema = builder.build()
    # Inference is configured entirely through properties: the saved-model
    # path plus the tensor names to feed and fetch.
    props = build_props('0')
    props[TF_INFERENCE_EXPORT_PATH] = export_path
    props[TF_INFERENCE_INPUT_TENSOR_NAMES] = 'image'
    props[TF_INFERENCE_OUTPUT_TENSOR_NAMES] = 'prediction'
    props[TF_INFERENCE_OUTPUT_ROW_FIELDS] = ','.join(
        ['org_label', 'prediction'])
    output_table = tensorflow_on_flink_table.inference(
        num_worker=2,
        properties=props,
        stream_env=stream_env,
        table_env=table_env,
        input_table=extracted,
        output_schema=output_schema)
    output_table.write_to_sink(LogInferAccSink())
    table_env.generate_stream_graph()
    stream_env.execute()
    testing_server.stop()
Exemplo n.º 7
0
    def test_create_table_environment(self):
        """A batch TableEnvironment should expose the TableConfig values it was built with."""
        builder = TableConfig.Builder()
        builder.set_parallelism(2)
        builder.set_max_generated_code_length(32000)
        builder.set_null_check(False)
        builder.set_timezone("Asia/Shanghai")
        builder.as_batch_execution()

        t_env = TableEnvironment.create(builder.build())
        cfg = t_env.get_config()

        assert cfg.parallelism() == 2
        assert cfg.null_check() is False
        assert cfg.max_generated_code_length() == 32000
        assert cfg.timezone() == "Asia/Shanghai"
        assert cfg.is_stream() is False
Exemplo n.º 8
0
    def test_create_table_environment(self):
        """Verify a batch-mode TableConfig round-trips through TableEnvironment.create."""
        config = (TableConfig.Builder()
                  .set_parallelism(2)
                  .set_max_generated_code_length(32000)
                  .set_null_check(False)
                  .set_timezone("Asia/Shanghai")
                  .as_batch_execution()
                  .build())

        env = TableEnvironment.create(config)

        actual = env.get_config()
        assert actual.parallelism() == 2
        assert actual.null_check() is False
        assert actual.max_generated_code_length() == 32000
        assert actual.timezone() == "Asia/Shanghai"
        assert actual.is_stream() is False
def train(num_worker,
          num_ps,
          func,
          properties=None,
          env_path=None,
          zk_conn=None,
          zk_base_path=None,
          stream_env=None,
          table_env=None,
          input_table=None,
          output_schema=None):
    """Run TensorFlow training over a Flink Table.

    :param num_worker: Number of TF workers
    :param num_ps: Number of TF parameter servers
    :param func: The user-defined function that runs TF training
    :param properties: User-defined properties
    :param env_path: Path to the virtual env
    :param zk_conn: The Zookeeper connection string
    :param zk_base_path: The Zookeeper base path
    :param stream_env: The StreamExecutionEnvironment. If None, one is created
                       here and the job is executed at the end of this call;
                       otherwise the caller triggers job execution.
    :param table_env: The TableEnvironment (created from stream_env if None)
    :param input_table: The input Table, or None
    :param output_schema: TableSchema of the output Table. If None, a dummy
                          sink is added to the output Table; otherwise the
                          caller must add a sink before executing the job.
    :return: The output Table
    """
    tf_config = TFConfig(num_worker, num_ps, func, properties, env_path,
                         zk_conn, zk_base_path)
    # Only execute here when we own the environment (i.e. we created it).
    run_job = stream_env is None
    if run_job:
        stream_env = StreamExecutionEnvironment.get_execution_environment()
    if table_env is None:
        table_env = TableEnvironment.get_table_environment(stream_env)
    # Unwrap the Python wrappers into their Java counterparts for the gateway call.
    j_input = None if input_table is None else input_table._java_table
    j_schema = None if output_schema is None else output_schema._j_table_schema
    jvm = get_gateway().jvm
    j_output = jvm.com.alibaba.flink.tensorflow.client.TFUtils.train(
        stream_env._j_env, table_env._j_tenv, j_input,
        tf_config.java_config(), j_schema)
    if run_job:
        table_env.execute()
    return Table(java_table=j_output)
Exemplo n.º 10
0
    def test_create_table_environment(self):
        """A streaming TableEnvironment should expose the TableConfig values it was built with."""
        builder = TableConfig.Builder()
        builder.set_parallelism(2)
        builder.set_max_generated_code_length(32000)
        builder.set_null_check(False)
        builder.set_timezone("Asia/Shanghai")
        builder.as_streaming_execution()

        t_env = TableEnvironment.create(builder.build())
        cfg = t_env.get_config()

        self.assertTrue(cfg.is_stream())
        self.assertEqual(2, cfg.parallelism())
        self.assertFalse(cfg.null_check())
        self.assertEqual(32000, cfg.max_generated_code_length())
        self.assertEqual("Asia/Shanghai", cfg.timezone())