Example #1
0
def table_with_input():
    """Run the ``mnist_dist_with_input`` training job fed from a Flink table.

    Prepares the data and starts a test server via ``start_zk_server()``.
    It then builds a stream execution environment (parallelism 2) and a
    table environment on top of it. A ``TFRTableSource`` is registered over
    two TFRecord files, and training is launched with one worker and one
    parameter server reading from that table. Finally the stream graph is
    generated and the job executed.

    The test server is stopped even if setup, training or job execution
    raises, so a failing run does not leave it running.
    """
    prep_data()
    testing_server = start_zk_server()
    try:
        stream_env = StreamExecutionEnvironment.get_execution_environment()
        table_env = TableEnvironment.get_table_environment(stream_env)
        stream_env.set_parallelism(2)
        train_data_path = ("file://" + get_root_path()
                           + "/examples/target/data/train/")
        paths = [train_data_path + "0.tfrecords",
                 train_data_path + "1.tfrecords"]
        # Column names/types of the rows the table source emits.
        out_row_type = RowType([StringType(), IntegerType()],
                               ["image_raw", "label"])
        table_src = TFRTableSource(
            paths=paths,
            epochs=1,
            out_row_type=out_row_type,
            converters=[ScalarConverter.FIRST, ScalarConverter.ONE_HOT])
        input_table = table_src.register_table(table_env=table_env)
        tensorflow_on_flink_table.train(num_worker=1,
                                        num_ps=1,
                                        func=mnist_dist_with_input.map_fun,
                                        properties=build_props(),
                                        stream_env=stream_env,
                                        table_env=table_env,
                                        input_table=input_table)
        # NOTE(review): presumably train() only wires the job into the
        # supplied environments here, hence the explicit graph generation
        # and execute() below -- confirm against train()'s docs.
        table_env.generate_stream_graph()
        stream_env.execute()
    finally:
        # Always release the test server, even when the job fails.
        testing_server.stop()
Example #2
0
def table_no_input():
    """Run the ``mnist_dist`` training job without a table input.

    Prepares the data and starts a test server via ``start_zk_server()``.
    It then launches training with one worker and one parameter server,
    connecting to ``localhost:2181``.

    The test server is stopped even if training raises, so a failing run
    does not leave it running.
    """
    prep_data()
    testing_server = start_zk_server()
    try:
        tensorflow_on_flink_table.train(num_worker=1,
                                        num_ps=1,
                                        func=mnist_dist.map_fun,
                                        properties=build_props(),
                                        zk_conn="localhost:2181")
    finally:
        # Always release the test server, even when training fails.
        testing_server.stop()
Example #3
0
def generate_model():
    """Run ``mnist_dist`` training unless ``export_path`` already exists.

    The training job uses one worker and one parameter server, with
    properties from ``build_props('0')``. When ``export_path`` is already
    present on disk, no training is started.
    """
    already_exported = os.path.exists(export_path)
    if not already_exported:
        tensorflow_on_flink_table.train(
            num_worker=1,
            num_ps=1,
            func=mnist_dist.map_fun,
            properties=build_props('0'),
        )