コード例 #1
0
def table_inference():
    """Run the Python-function MNIST inference example over a Flink table.

    Prepares the data, starts a testing server via ``start_zk_server()``
    (presumably ZooKeeper, judging by the name -- confirm), generates the
    model, and then streams the two test TFRecord shards through
    ``tensorflow_on_flink_table.inference`` using
    ``mnist_table_inference.map_fun`` on 2 workers. Result rows
    ``(label_org, predict_label)`` are written to a ``LogTableStreamSink``.

    The testing server is stopped in a ``finally`` clause so it is shut down
    even when model generation or the Flink job fails.
    """
    prep_data()
    testing_server = start_zk_server()
    try:
        generate_model()
        stream_env = StreamExecutionEnvironment.get_execution_environment()
        table_env = TableEnvironment.get_table_environment(stream_env)
        stream_env.set_parallelism(2)

        # Two TFRecord shards from the test split.
        test_data_path = (
            "file://" + get_root_path() + "/examples/target/data/test/")
        paths = [test_data_path + "0.tfrecords",
                 test_data_path + "1.tfrecords"]
        # Source rows: (image_raw: String, label: Integer). The converter list
        # has one entry per column -- presumably applied positionally; confirm.
        src_row_type = RowType([StringType(), IntegerType()],
                               ["image_raw", "label"])
        table_src = TFRTableSource(
            paths=paths,
            epochs=1,
            out_row_type=src_row_type,
            converters=[ScalarConverter.FIRST, ScalarConverter.ONE_HOT])
        input_table = table_src.register_table(table_env=table_env)

        # Schema of the rows produced by the inference function.
        builder = TableSchema.Builder()
        builder.column(name='label_org',
                       data_type=IntegerType()).column(name='predict_label',
                                                       data_type=IntegerType())
        output_schema = builder.build()

        output_table = tensorflow_on_flink_table.inference(
            num_worker=2,
            func=mnist_table_inference.map_fun,
            properties=build_props('0'),
            stream_env=stream_env,
            table_env=table_env,
            input_table=input_table,
            output_schema=output_schema)
        output_table.write_to_sink(LogTableStreamSink())
        table_env.generate_stream_graph()
        stream_env.execute()
    finally:
        # Always release the testing server, even if the job above failed.
        testing_server.stop()
コード例 #2
0
def table_java_inference():
    """Run the exported-model inference example over a Flink table.

    Prepares the data, starts a testing server via ``start_zk_server()``
    (presumably ZooKeeper, judging by the name -- confirm), and generates the
    model. It then does the following:

    1. Registers the two test TFRecord shards as table source
       ``tfr_input_table`` with columns ``(image_raw: String, label: Integer)``.
    2. Registers ``java_inference_extract_func()`` as table function
       ``tfr_extract``. A ``LATERAL TABLE`` SQL query uses it to derive the
       columns ``image, org_label``.
    3. Runs ``tensorflow_on_flink_table.inference`` on 2 workers without a
       Python ``func``. It is configured through the ``TF_INFERENCE_*``
       properties: export path ``export_path``, input tensor name ``image``,
       output tensor name ``prediction``, and output row fields
       ``org_label,prediction``.
    4. Writes ``(real_label, predicted_label)`` rows to ``LogInferAccSink``.

    The testing server is stopped in a ``finally`` clause so it is shut down
    even when model generation or the Flink job fails.
    """
    prep_data()
    testing_server = start_zk_server()
    try:
        generate_model()
        stream_env = StreamExecutionEnvironment.get_execution_environment()
        table_env = TableEnvironment.get_table_environment(stream_env)
        stream_env.set_parallelism(2)

        # The shards come from the test split (this was misleadingly named
        # ``train_data_path`` before).
        test_data_path = (
            "file://" + get_root_path() + "/examples/target/data/test/")
        paths = [test_data_path + "0.tfrecords",
                 test_data_path + "1.tfrecords"]
        src_row_type = RowType([StringType(), IntegerType()],
                               ["image_raw", "label"])
        table_src = TFRTableSource(
            paths=paths,
            epochs=1,
            out_row_type=src_row_type,
            converters=[ScalarConverter.FIRST, ScalarConverter.ONE_HOT])
        tfr_tbl_name = "tfr_input_table"
        table_env.register_table_source(tfr_tbl_name, table_src)

        ext_func_name = "tfr_extract"
        table_env.register_function(ext_func_name,
                                    java_inference_extract_func())
        # Resulting query:
        #   select image,org_label from tfr_input_table,
        #     LATERAL TABLE(tfr_extract(<source columns>)) as T(image,org_label)
        # where <source columns> comes from src_row_type.fields_names
        # (presumably image_raw,label -- confirm).
        out_cols = 'image,org_label'
        in_cols = ','.join(src_row_type.fields_names)
        extracted = table_env.sql_query(
            'select {} from {}, LATERAL TABLE({}({})) as T({})'.format(
                out_cols, tfr_tbl_name, ext_func_name, in_cols, out_cols))

        # Schema of the rows produced by inference.
        builder = TableSchema.Builder()
        builder.column(name='real_label',
                       data_type=LongType()).column(name='predicted_label',
                                                    data_type=LongType())
        output_schema = builder.build()

        props = build_props('0')
        # export_path is not defined in this function; it is resolved from
        # the enclosing module scope.
        props[TF_INFERENCE_EXPORT_PATH] = export_path
        props[TF_INFERENCE_INPUT_TENSOR_NAMES] = 'image'
        props[TF_INFERENCE_OUTPUT_TENSOR_NAMES] = 'prediction'
        props[TF_INFERENCE_OUTPUT_ROW_FIELDS] = ','.join(
            ['org_label', 'prediction'])
        output_table = tensorflow_on_flink_table.inference(
            num_worker=2,
            properties=props,
            stream_env=stream_env,
            table_env=table_env,
            input_table=extracted,
            output_schema=output_schema)
        output_table.write_to_sink(LogInferAccSink())
        table_env.generate_stream_graph()
        stream_env.execute()
    finally:
        # Always release the testing server, even if the job above failed.
        testing_server.stop()