def table_with_input():
    """Run the distributed MNIST example with a TFRecord-backed input table.

    Prepares the example data and starts the testing server from
    ``start_zk_server()``. It then builds a Flink stream/table environment
    with parallelism 2 and registers the two training ``.tfrecords`` files
    as a ``TFRTableSource`` table. That table is passed as ``input_table``
    to ``tensorflow_on_flink_table.train`` with one worker and one ps.
    Finally it generates the stream graph and executes the job.

    The testing server is stopped even if any step after its start raises.
    """
    prep_data()
    testing_server = start_zk_server()
    # try/finally so the testing server is shut down on failure too; it
    # used to be stopped only when every step below succeeded.
    try:
        stream_env = StreamExecutionEnvironment.get_execution_environment()
        table_env = TableEnvironment.get_table_environment(stream_env)
        stream_env.set_parallelism(2)

        train_data_path = (
            "file://" + get_root_path() + "/examples/target/data/train/")
        paths = [train_data_path + "0.tfrecords",
                 train_data_path + "1.tfrecords"]
        # Two output columns: "image_raw" (StringType) and "label"
        # (IntegerType). Presumably one converter per column, in the same
        # order (FIRST for image_raw, ONE_HOT for label) -- confirm against
        # TFRTableSource.
        out_row_type = RowType([StringType(), IntegerType()],
                               ["image_raw", "label"])
        table_src = TFRTableSource(
            paths=paths,
            epochs=1,
            out_row_type=out_row_type,
            converters=[ScalarConverter.FIRST, ScalarConverter.ONE_HOT])
        input_table = table_src.register_table(table_env=table_env)

        tensorflow_on_flink_table.train(
            num_worker=1,
            num_ps=1,
            func=mnist_dist_with_input.map_fun,
            properties=build_props(),
            stream_env=stream_env,
            table_env=table_env,
            input_table=input_table)

        table_env.generate_stream_graph()
        stream_env.execute()
    finally:
        testing_server.stop()
def table_no_input():
    """Run the distributed MNIST example without an input table.

    Prepares the example data and starts the testing server from
    ``start_zk_server()``. It then calls ``tensorflow_on_flink_table.train``
    with one worker and one ps, pointing it at ``localhost:2181`` via
    ``zk_conn``. The testing server is stopped even if training raises.
    """
    prep_data()
    testing_server = start_zk_server()
    # try/finally so a failed training run does not leave the testing
    # server running; it used to be stopped only on success.
    try:
        tensorflow_on_flink_table.train(
            num_worker=1,
            num_ps=1,
            func=mnist_dist.map_fun,
            properties=build_props(),
            zk_conn="localhost:2181")
    finally:
        testing_server.stop()
def generate_model():
    """Run the MNIST training unless ``export_path`` already exists.

    When ``export_path`` is missing, calls ``tensorflow_on_flink_table.train``
    with one worker, one ps, ``mnist_dist.map_fun`` and
    ``build_props('0')``. Presumably that run writes the model to
    ``export_path`` -- confirm against ``mnist_dist``.
    """
    # Skip the (expensive) training run when its output path is present.
    if os.path.exists(export_path):
        return
    tensorflow_on_flink_table.train(
        num_worker=1,
        num_ps=1,
        func=mnist_dist.map_fun,
        properties=build_props('0'))