Exemple #1
0
def datastream_inference_input():
    """Run the DataStream inference example over the test TFRecord files.

    Reads ``0.tfrecords`` and ``1.tfrecords`` from
    ``<root>/examples/target/data/test/`` as rows of
    ``(image_raw: String, label: Integer)``, feeds them to
    ``tensorflow_on_flink_datastream.inference`` with one worker running
    ``mnist_table_inference.map_fun``, and sends the resulting
    ``(label_org, predict_label)`` rows to a ``LogSink``.

    The server returned by ``start_zk_server()`` is stopped in a ``finally``
    clause, so it is released even if model generation or the job fails.
    """
    prep_data()
    testing_server = start_zk_server()
    try:
        # NOTE(review): generate_model() is kept after the server start, as in
        # the original ordering -- presumably it needs the server running.
        generate_model()
        stream_env = StreamExecutionEnvironment.get_execution_environment()

        test_data_path = "file://" + get_root_path(
        ) + "/examples/target/data/test/"
        paths = [
            test_data_path + "0.tfrecords",
            test_data_path + "1.tfrecords",
        ]

        src_row_type = RowType([StringType(), IntegerType()],
                               ["image_raw", "label"])
        source = TFRSourceFunc(
            paths=paths,
            epochs=1,
            out_row_type=src_row_type,
            converters=[ScalarConverter.FIRST, ScalarConverter.ONE_HOT])
        input_ds = stream_env.add_source(source_func=source)
        # One source subtask per input file.
        input_ds.set_parallelism(len(paths))

        out_row_type = RowType([IntegerType(), IntegerType()],
                               ['label_org', 'predict_label'])
        output_ds = tensorflow_on_flink_datastream.inference(
            num_worker=1,
            func=mnist_table_inference.map_fun,
            properties=build_props('0'),
            stream_env=stream_env,
            input_ds=input_ds,
            output_row_type=out_row_type)
        output_ds.add_sink(LogSink())
        stream_env.execute()
    finally:
        # Always release the testing server, including on failure.
        testing_server.stop()
Exemple #2
0
def table_with_input():
    """Run the Table API training example over the train TFRecord files.

    Registers a ``TFRTableSource`` reading ``0.tfrecords`` and
    ``1.tfrecords`` from ``<root>/examples/target/data/train/`` as rows of
    ``(image_raw: String, label: Integer)`` and passes the resulting table to
    ``tensorflow_on_flink_table.train`` with one worker and one parameter
    server running ``mnist_dist_with_input.map_fun``.

    The server returned by ``start_zk_server()`` is stopped in a ``finally``
    clause, so it is released even if building or executing the job fails.
    """
    prep_data()
    testing_server = start_zk_server()
    try:
        stream_env = StreamExecutionEnvironment.get_execution_environment()
        table_env = TableEnvironment.get_table_environment(stream_env)
        stream_env.set_parallelism(2)

        train_data_path = "file://" + get_root_path(
        ) + "/examples/target/data/train/"
        paths = [
            train_data_path + "0.tfrecords",
            train_data_path + "1.tfrecords",
        ]

        out_row_type = RowType([StringType(), IntegerType()],
                               ["image_raw", "label"])
        table_src = TFRTableSource(
            paths=paths,
            epochs=1,
            out_row_type=out_row_type,
            converters=[ScalarConverter.FIRST, ScalarConverter.ONE_HOT])
        input_table = table_src.register_table(table_env=table_env)

        tensorflow_on_flink_table.train(
            num_worker=1,
            num_ps=1,
            func=mnist_dist_with_input.map_fun,
            properties=build_props(),
            stream_env=stream_env,
            table_env=table_env,
            input_table=input_table)

        table_env.generate_stream_graph()
        stream_env.execute()
    finally:
        # Always release the testing server, including on failure.
        testing_server.stop()
Exemple #3
0
def table_inference():
    """Run the Table API inference example over the test TFRecord files.

    Registers a ``TFRTableSource`` reading ``0.tfrecords`` and
    ``1.tfrecords`` from ``<root>/examples/target/data/test/`` as rows of
    ``(image_raw: String, label: Integer)``, runs
    ``tensorflow_on_flink_table.inference`` with two workers executing
    ``mnist_table_inference.map_fun``, and writes the
    ``(label_org, predict_label)`` output table to a ``LogTableStreamSink``.

    The server returned by ``start_zk_server()`` is stopped in a ``finally``
    clause, so it is released even if model generation or the job fails.
    """
    prep_data()
    testing_server = start_zk_server()
    try:
        # NOTE(review): generate_model() is kept after the server start, as in
        # the original ordering -- presumably it needs the server running.
        generate_model()
        stream_env = StreamExecutionEnvironment.get_execution_environment()
        table_env = TableEnvironment.get_table_environment(stream_env)
        stream_env.set_parallelism(2)

        test_data_path = "file://" + get_root_path(
        ) + "/examples/target/data/test/"
        paths = [
            test_data_path + "0.tfrecords",
            test_data_path + "1.tfrecords",
        ]

        src_row_type = RowType([StringType(), IntegerType()],
                               ["image_raw", "label"])
        table_src = TFRTableSource(
            paths=paths,
            epochs=1,
            out_row_type=src_row_type,
            converters=[ScalarConverter.FIRST, ScalarConverter.ONE_HOT])
        input_table = table_src.register_table(table_env=table_env)

        # Output schema: two integer columns.
        builder = TableSchema.Builder()
        builder.column(name='label_org',
                       data_type=IntegerType()).column(name='predict_label',
                                                       data_type=IntegerType())
        output_schema = builder.build()

        output_table = tensorflow_on_flink_table.inference(
            num_worker=2,
            func=mnist_table_inference.map_fun,
            properties=build_props('0'),
            stream_env=stream_env,
            table_env=table_env,
            input_table=input_table,
            output_schema=output_schema)
        output_table.write_to_sink(LogTableStreamSink())

        table_env.generate_stream_graph()
        stream_env.execute()
    finally:
        # Always release the testing server, including on failure.
        testing_server.stop()
Exemple #4
0
def table_java_inference():
    """Run the Table API inference example configured purely via properties.

    Registers a ``TFRTableSource`` over ``0.tfrecords`` and ``1.tfrecords``
    from ``<root>/examples/target/data/test/``, applies the registered
    ``tfr_extract`` table function through a ``LATERAL TABLE`` SQL query to
    obtain the ``image`` and ``org_label`` columns, and passes that table to
    ``tensorflow_on_flink_table.inference`` with two workers. No ``func`` is
    supplied; instead the export path and the input/output tensor names are
    set in the properties. Results are written to a ``LogInferAccSink``.

    ``export_path`` is read from the enclosing module scope.

    The server returned by ``start_zk_server()`` is stopped in a ``finally``
    clause, so it is released even if model generation or the job fails.
    """
    prep_data()
    testing_server = start_zk_server()
    try:
        # NOTE(review): generate_model() is kept after the server start, as in
        # the original ordering -- presumably it needs the server running.
        generate_model()
        stream_env = StreamExecutionEnvironment.get_execution_environment()
        table_env = TableEnvironment.get_table_environment(stream_env)
        stream_env.set_parallelism(2)

        # The input lives under data/test/, so the local is named accordingly.
        test_data_path = "file://" + get_root_path(
        ) + "/examples/target/data/test/"
        paths = [
            test_data_path + "0.tfrecords",
            test_data_path + "1.tfrecords",
        ]

        src_row_type = RowType([StringType(), IntegerType()],
                               ["image_raw", "label"])
        table_src = TFRTableSource(
            paths=paths,
            epochs=1,
            out_row_type=src_row_type,
            converters=[ScalarConverter.FIRST, ScalarConverter.ONE_HOT])

        tfr_tbl_name = "tfr_input_table"
        table_env.register_table_source(tfr_tbl_name, table_src)
        ext_func_name = "tfr_extract"
        table_env.register_function(ext_func_name,
                                    java_inference_extract_func())

        # Feed every source column to the extract function and select the
        # columns it produces.
        out_cols = 'image,org_label'
        in_cols = ','.join(src_row_type.fields_names)
        extracted = table_env.sql_query(
            'select {} from {}, LATERAL TABLE({}({})) as T({})'.format(
                out_cols, tfr_tbl_name, ext_func_name, in_cols, out_cols))

        # Output schema: two long columns.
        builder = TableSchema.Builder()
        builder.column(name='real_label',
                       data_type=LongType()).column(name='predicted_label',
                                                    data_type=LongType())
        output_schema = builder.build()

        props = build_props('0')
        props[TF_INFERENCE_EXPORT_PATH] = export_path
        props[TF_INFERENCE_INPUT_TENSOR_NAMES] = 'image'
        props[TF_INFERENCE_OUTPUT_TENSOR_NAMES] = 'prediction'
        props[TF_INFERENCE_OUTPUT_ROW_FIELDS] = ','.join(
            ['org_label', 'prediction'])

        output_table = tensorflow_on_flink_table.inference(
            num_worker=2,
            properties=props,
            stream_env=stream_env,
            table_env=table_env,
            input_table=extracted,
            output_schema=output_schema)
        output_table.write_to_sink(LogInferAccSink())

        table_env.generate_stream_graph()
        stream_env.execute()
    finally:
        # Always release the testing server, including on failure.
        testing_server.stop()
Exemple #5
0
def datastream_with_input():
    """Run the DataStream training example over the train TFRecord files.

    Reads ``0.tfrecords`` and ``1.tfrecords`` from
    ``<root>/examples/target/data/train/`` as rows of
    ``(image_raw: String, label: Integer)`` and passes the stream to
    ``tensorflow_on_flink_datastream.train`` with one worker and one
    parameter server running ``mnist_dist_with_input.map_fun``.

    The server returned by ``start_zk_server()`` is stopped in a ``finally``
    clause, so it is released even if building or executing the job fails.
    """
    prep_data()
    testing_server = start_zk_server()
    try:
        stream_env = StreamExecutionEnvironment.get_execution_environment()

        train_data_path = "file://" + get_root_path(
        ) + "/examples/target/data/train/"
        paths = [
            train_data_path + "0.tfrecords",
            train_data_path + "1.tfrecords",
        ]

        src_row_type = RowType([StringType(), IntegerType()],
                               ["image_raw", "label"])
        source = TFRSourceFunc(
            paths=paths,
            epochs=1,
            out_row_type=src_row_type,
            converters=[ScalarConverter.FIRST, ScalarConverter.ONE_HOT])
        input_ds = stream_env.add_source(source)
        # One source subtask per input file.
        input_ds.set_parallelism(len(paths))

        tensorflow_on_flink_datastream.train(
            num_worker=1,
            num_ps=1,
            func=mnist_dist_with_input.map_fun,
            properties=build_props(),
            input_ds=input_ds,
            stream_env=stream_env)
        stream_env.execute()
    finally:
        # Always release the testing server, including on failure.
        testing_server.stop()