Пример #1
0
def table_inference():
    """End-to-end MNIST inference example on the Flink Table API.

    Prepares TFRecord test data, exports a model, runs distributed
    inference over a TFRTableSource and logs predictions to a sink.
    A local ZooKeeper test server coordinates the TF cluster.
    """
    prep_data()
    zk_server = start_zk_server()
    generate_model()

    env = StreamExecutionEnvironment.get_execution_environment()
    t_env = TableEnvironment.get_table_environment(env)
    env.set_parallelism(2)

    # Two TFRecord shards under the examples build directory.
    data_dir = "file://" + get_root_path() + "/examples/target/data/test/"
    record_files = [data_dir + "0.tfrecords", data_dir + "1.tfrecords"]

    source_row_type = RowType([StringType(), IntegerType()],
                              ["image_raw", "label"])
    source = TFRTableSource(
        paths=record_files,
        epochs=1,
        out_row_type=source_row_type,
        converters=[ScalarConverter.FIRST, ScalarConverter.ONE_HOT])
    input_table = source.register_table(table_env=t_env)

    # Output columns: the original label and the model's prediction.
    schema_builder = TableSchema.Builder()
    schema_builder.column(name='label_org', data_type=IntegerType())
    schema_builder.column(name='predict_label', data_type=IntegerType())
    output_schema = schema_builder.build()

    result = tensorflow_on_flink_table.inference(
        num_worker=2,
        func=mnist_table_inference.map_fun,
        properties=build_props('0'),
        stream_env=env,
        table_env=t_env,
        input_table=input_table,
        output_schema=output_schema)
    result.write_to_sink(LogTableStreamSink())

    t_env.generate_stream_graph()
    env.execute()
    zk_server.stop()
Пример #2
0
def datastream_inference_input():
    """MNIST inference example on the DataStream API with a TFRecord source.

    Reads two TFRecord shards via TFRSourceFunc, runs inference with a
    single worker, and logs the resulting rows. Uses a local ZooKeeper
    test server for TF cluster coordination.
    """
    prep_data()
    zk_server = start_zk_server()
    generate_model()

    env = StreamExecutionEnvironment.get_execution_environment()

    data_dir = "file://" + get_root_path() + "/examples/target/data/test/"
    record_files = [data_dir + "0.tfrecords", data_dir + "1.tfrecords"]

    source_row_type = RowType([StringType(), IntegerType()],
                              ["image_raw", "label"])
    source_stream = env.add_source(source_func=TFRSourceFunc(
        paths=record_files,
        epochs=1,
        out_row_type=source_row_type,
        converters=[ScalarConverter.FIRST, ScalarConverter.ONE_HOT]))
    # One source subtask per shard.
    source_stream.set_parallelism(len(record_files))

    result_row_type = RowType([IntegerType(), IntegerType()],
                              ['label_org', 'predict_label'])
    result_stream = tensorflow_on_flink_datastream.inference(
        num_worker=1,
        func=mnist_table_inference.map_fun,
        properties=build_props('0'),
        stream_env=env,
        input_ds=source_stream,
        output_row_type=result_row_type)
    result_stream.add_sink(LogSink())

    env.execute()
    zk_server.stop()
Пример #3
0
def table_with_input():
    """MNIST training example on the Table API with TFRecord input.

    Streams two training shards through a TFRTableSource and runs
    distributed training (1 worker + 1 PS) against a local ZooKeeper
    test server.
    """
    prep_data()
    zk_server = start_zk_server()

    env = StreamExecutionEnvironment.get_execution_environment()
    t_env = TableEnvironment.get_table_environment(env)
    env.set_parallelism(2)

    data_dir = "file://" + get_root_path() + "/examples/target/data/train/"
    record_files = [data_dir + "0.tfrecords", data_dir + "1.tfrecords"]

    source_row_type = RowType([StringType(), IntegerType()],
                              ["image_raw", "label"])
    source = TFRTableSource(
        paths=record_files,
        epochs=1,
        out_row_type=source_row_type,
        converters=[ScalarConverter.FIRST, ScalarConverter.ONE_HOT])
    input_table = source.register_table(table_env=t_env)

    tensorflow_on_flink_table.train(num_worker=1,
                                    num_ps=1,
                                    func=mnist_dist_with_input.map_fun,
                                    properties=build_props(),
                                    stream_env=env,
                                    table_env=t_env,
                                    input_table=input_table)

    t_env.generate_stream_graph()
    env.execute()
    zk_server.stop()
Пример #4
0
def table_java_inference():
    """MNIST inference example driven by the Java inference path.

    Registers the TFRecord source as a named table, extracts features
    with a Java UDTF via SQL, then runs saved-model inference configured
    entirely through TF_INFERENCE_* properties (no Python map function).

    NOTE(review): relies on a module-level ``export_path`` for the saved
    model location — not defined in this snippet; verify at module scope.
    """
    prep_data()
    zk_server = start_zk_server()
    generate_model()

    env = StreamExecutionEnvironment.get_execution_environment()
    t_env = TableEnvironment.get_table_environment(env)
    env.set_parallelism(2)

    data_dir = "file://" + get_root_path() + "/examples/target/data/test/"
    record_files = [data_dir + "0.tfrecords", data_dir + "1.tfrecords"]

    source_row_type = RowType([StringType(), IntegerType()],
                              ["image_raw", "label"])
    source = TFRTableSource(
        paths=record_files,
        epochs=1,
        out_row_type=source_row_type,
        converters=[ScalarConverter.FIRST, ScalarConverter.ONE_HOT])

    tfr_tbl_name = "tfr_input_table"
    t_env.register_table_source(tfr_tbl_name, source)

    # Java UDTF that turns raw records into (image, org_label) rows.
    ext_func_name = "tfr_extract"
    t_env.register_function(ext_func_name, java_inference_extract_func())
    out_cols = 'image,org_label'
    in_cols = ','.join(source_row_type.fields_names)
    extracted = t_env.sql_query(
        'select {} from {}, LATERAL TABLE({}({})) as T({})'.format(
            out_cols, tfr_tbl_name, ext_func_name, in_cols, out_cols))

    schema_builder = TableSchema.Builder()
    schema_builder.column(name='real_label', data_type=LongType())
    schema_builder.column(name='predicted_label', data_type=LongType())
    output_schema = schema_builder.build()

    # Configure the Java-side saved-model inference via properties.
    props = build_props('0')
    props[TF_INFERENCE_EXPORT_PATH] = export_path
    props[TF_INFERENCE_INPUT_TENSOR_NAMES] = 'image'
    props[TF_INFERENCE_OUTPUT_TENSOR_NAMES] = 'prediction'
    props[TF_INFERENCE_OUTPUT_ROW_FIELDS] = ','.join(
        ['org_label', 'prediction'])

    result = tensorflow_on_flink_table.inference(
        num_worker=2,
        properties=props,
        stream_env=env,
        table_env=t_env,
        input_table=extracted,
        output_schema=output_schema)
    result.write_to_sink(LogInferAccSink())

    t_env.generate_stream_graph()
    env.execute()
    zk_server.stop()
def train(num_worker,
          num_ps,
          func,
          properties=None,
          env_path=None,
          zk_conn=None,
          zk_base_path=None,
          stream_env=None,
          table_env=None,
          input_table=None,
          output_schema=None):
    """Run TensorFlow training as a Flink Table job.

    Thin py4j bridge to the Java-side ``TFUtils.train``: packs the
    Python-side configuration into a TFConfig, unwraps the Python
    wrappers to their underlying Java objects, and wraps the Java
    result table back.

    :param num_worker: Number of workers
    :param num_ps: Number of PS
    :param func: The user-defined function that runs TF training
    :param properties: User-defined properties
    :param env_path: Path to the virtual env
    :param zk_conn: The Zookeeper connection string
    :param zk_base_path: The Zookeeper base path
    :param stream_env: The StreamExecutionEnvironment. If it's None, this
                       method creates one and executes the job at the end;
                       otherwise the caller triggers execution
    :param table_env: The TableEnvironment
    :param input_table: The input Table
    :param output_schema: The TableSchema of the output Table. If it's
                          None, a dummy sink is added to the output Table;
                          otherwise the caller adds a sink before executing
    :return: The output Table
    """
    config = TFConfig(num_worker, num_ps, func, properties, env_path,
                      zk_conn, zk_base_path)
    # Only execute here when we created the environment ourselves.
    run_now = stream_env is None
    if run_now:
        stream_env = StreamExecutionEnvironment.get_execution_environment()
    if table_env is None:
        table_env = TableEnvironment.get_table_environment(stream_env)

    # Unwrap the Python wrappers before crossing into the JVM.
    j_input = None if input_table is None else input_table._java_table
    j_schema = None if output_schema is None else output_schema._j_table_schema

    j_output = get_gateway(
    ).jvm.com.alibaba.flink.tensorflow.client.TFUtils.train(
        stream_env._j_env, table_env._j_tenv, j_input,
        config.java_config(), j_schema)
    if run_now:
        table_env.execute()
    return Table(java_table=j_output)
def train(num_worker,
          num_ps,
          func,
          properties=None,
          env_path=None,
          zk_conn=None,
          zk_base_path=None,
          stream_env=None,
          input_ds=None,
          output_row_type=None):
    """Run TensorFlow training as a Flink DataStream job.

    Thin py4j bridge to the Java-side ``TFUtils.train``: packs the
    Python-side configuration into a TFConfig, unwraps the input
    stream to its Java object, and wraps the Java result stream back.

    :param num_worker: Number of workers
    :param num_ps: Number of PS
    :param func: The user-defined function that runs TF training
    :param properties: User-defined properties
    :param env_path: Path to the virtual env
    :param zk_conn: The Zookeeper connection string
    :param zk_base_path: The Zookeeper base path
    :param stream_env: The StreamExecutionEnvironment. If it's None, this
                       method creates one and executes the job at the end;
                       otherwise the caller triggers execution
    :param input_ds: The input DataStream
    :param output_row_type: The RowType for the output DataStream. If it's
                            None, a dummy sink is added to the output
                            DataStream; otherwise the caller adds a sink
                            before executing the job
    :return: The output DataStream. Currently it's always of type Row.
    """
    config = TFConfig(num_worker, num_ps, func, properties, env_path,
                      zk_conn, zk_base_path)
    # Only execute here when we created the environment ourselves.
    run_now = stream_env is None
    if run_now:
        stream_env = StreamExecutionEnvironment.get_execution_environment()

    # A DataStreamSource wraps a different Java field than a DataStream.
    j_input = None
    if input_ds is not None:
        if isinstance(input_ds, DataStreamSource):
            j_input = input_ds._j_datastream_source
        else:
            j_input = input_ds._j_datastream

    j_output = get_gateway(
    ).jvm.com.alibaba.flink.tensorflow.client.TFUtils.train(
        stream_env._j_env, j_input, config.java_config(),
        to_row_type_info(output_row_type))
    if run_now:
        stream_env.execute()
    return DataStream(j_output)
Пример #7
0
def datastream_with_input():
    """MNIST training example on the DataStream API with TFRecord input.

    Streams two training shards through a TFRSourceFunc and runs
    distributed training (1 worker + 1 PS) against a local ZooKeeper
    test server.
    """
    prep_data()
    zk_server = start_zk_server()

    env = StreamExecutionEnvironment.get_execution_environment()

    data_dir = "file://" + get_root_path() + "/examples/target/data/train/"
    record_files = [data_dir + "0.tfrecords", data_dir + "1.tfrecords"]

    source_row_type = RowType([StringType(), IntegerType()],
                              ["image_raw", "label"])
    source_stream = env.add_source(
        TFRSourceFunc(
            paths=record_files,
            epochs=1,
            out_row_type=source_row_type,
            converters=[ScalarConverter.FIRST, ScalarConverter.ONE_HOT]))
    # One source subtask per shard.
    source_stream.set_parallelism(len(record_files))

    tensorflow_on_flink_datastream.train(num_worker=1,
                                         num_ps=1,
                                         func=mnist_dist_with_input.map_fun,
                                         properties=build_props(),
                                         input_ds=input_ds if False else source_stream,
                                         stream_env=env)

    env.execute()
    zk_server.stop()