def table_with_input():
    """End-to-end example: train the MNIST model from a TFRecord-backed Table.

    Prepares the example data, spins up a local ZooKeeper test server,
    registers a TFRecord table source, runs distributed TF training via
    ``tensorflow_on_flink_table.train``, and executes the Flink job.
    """
    prep_data()
    zk_server = start_zk_server()

    exec_env = StreamExecutionEnvironment.get_execution_environment()
    tbl_env = TableEnvironment.get_table_environment(exec_env)
    exec_env.set_parallelism(2)

    # TFRecord files produced under the examples build directory.
    data_dir = "file://" + get_root_path() + "/examples/target/data/train/"
    tfr_files = [data_dir + name for name in ("0.tfrecords", "1.tfrecords")]

    # Two-column schema: raw image bytes plus the integer label.
    record_type = RowType([StringType(), IntegerType()],
                          ["image_raw", "label"])
    source = TFRTableSource(paths=tfr_files,
                            epochs=1,
                            out_row_type=record_type,
                            converters=[ScalarConverter.FIRST,
                                        ScalarConverter.ONE_HOT])
    src_table = source.register_table(table_env=tbl_env)

    tensorflow_on_flink_table.train(num_worker=1,
                                    num_ps=1,
                                    func=mnist_dist_with_input.map_fun,
                                    properties=build_props(),
                                    stream_env=exec_env,
                                    table_env=tbl_env,
                                    input_table=src_table)

    tbl_env.generate_stream_graph()
    exec_env.execute()
    zk_server.stop()
def table_inference():
    """End-to-end example: run Python-side inference over a TFRecord Table.

    Prepares data and a trained model, reads the test TFRecords through a
    table source, invokes ``tensorflow_on_flink_table.inference``, and logs
    the (original label, predicted label) output rows via a log sink.
    """
    prep_data()
    zk_server = start_zk_server()
    generate_model()

    exec_env = StreamExecutionEnvironment.get_execution_environment()
    tbl_env = TableEnvironment.get_table_environment(exec_env)
    exec_env.set_parallelism(2)

    data_dir = "file://" + get_root_path() + "/examples/target/data/test/"
    tfr_files = [data_dir + name for name in ("0.tfrecords", "1.tfrecords")]

    record_type = RowType([StringType(), IntegerType()],
                          ["image_raw", "label"])
    source = TFRTableSource(paths=tfr_files,
                            epochs=1,
                            out_row_type=record_type,
                            converters=[ScalarConverter.FIRST,
                                        ScalarConverter.ONE_HOT])
    src_table = source.register_table(table_env=tbl_env)

    # Output schema: the true label alongside the model's prediction.
    schema_builder = TableSchema.Builder()
    schema_builder.column(name='label_org', data_type=IntegerType()) \
                  .column(name='predict_label', data_type=IntegerType())
    result_schema = schema_builder.build()

    result_table = tensorflow_on_flink_table.inference(
        num_worker=2,
        func=mnist_table_inference.map_fun,
        properties=build_props('0'),
        stream_env=exec_env,
        table_env=tbl_env,
        input_table=src_table,
        output_schema=result_schema)
    result_table.write_to_sink(LogTableStreamSink())

    tbl_env.generate_stream_graph()
    exec_env.execute()
    zk_server.stop()
def table_java_inference():
    """End-to-end example: Java-side inference driven entirely by properties.

    Registers the TFRecord source and a Java extraction UDTF, builds the
    feature table with a SQL LATERAL TABLE query, then configures inference
    through TF_INFERENCE_* properties (saved-model export path plus input /
    output tensor names) instead of a Python map function. Accuracy is
    checked by the ``LogInferAccSink``.
    """
    prep_data()
    zk_server = start_zk_server()
    generate_model()

    exec_env = StreamExecutionEnvironment.get_execution_environment()
    tbl_env = TableEnvironment.get_table_environment(exec_env)
    exec_env.set_parallelism(2)

    data_dir = "file://" + get_root_path() + "/examples/target/data/test/"
    tfr_files = [data_dir + name for name in ("0.tfrecords", "1.tfrecords")]

    record_type = RowType([StringType(), IntegerType()],
                          ["image_raw", "label"])
    source = TFRTableSource(paths=tfr_files,
                            epochs=1,
                            out_row_type=record_type,
                            converters=[ScalarConverter.FIRST,
                                        ScalarConverter.ONE_HOT])
    tfr_tbl_name = "tfr_input_table"
    tbl_env.register_table_source(tfr_tbl_name, source)

    # Java UDTF that decodes the raw record into (image, org_label).
    ext_func_name = "tfr_extract"
    tbl_env.register_function(ext_func_name, java_inference_extract_func())
    out_cols = 'image,org_label'
    in_cols = ','.join(record_type.fields_names)
    extracted = tbl_env.sql_query(
        'select {} from {}, LATERAL TABLE({}({})) as T({})'.format(
            out_cols, tfr_tbl_name, ext_func_name, in_cols, out_cols))

    schema_builder = TableSchema.Builder()
    schema_builder.column(name='real_label', data_type=LongType()) \
                  .column(name='predicted_label', data_type=LongType())
    result_schema = schema_builder.build()

    # No Python func: the Java workers load the exported saved model and
    # feed/fetch the named tensors directly.
    props = build_props('0')
    props[TF_INFERENCE_EXPORT_PATH] = export_path
    props[TF_INFERENCE_INPUT_TENSOR_NAMES] = 'image'
    props[TF_INFERENCE_OUTPUT_TENSOR_NAMES] = 'prediction'
    props[TF_INFERENCE_OUTPUT_ROW_FIELDS] = ','.join(
        ['org_label', 'prediction'])

    result_table = tensorflow_on_flink_table.inference(
        num_worker=2,
        properties=props,
        stream_env=exec_env,
        table_env=tbl_env,
        input_table=extracted,
        output_schema=result_schema)
    result_table.write_to_sink(LogInferAccSink())

    tbl_env.generate_stream_graph()
    exec_env.execute()
    zk_server.stop()
def train(num_worker, num_ps, func, properties=None, env_path=None,
          zk_conn=None, zk_base_path=None, stream_env=None, table_env=None,
          input_table=None, output_schema=None):
    """Run TensorFlow training on a Flink Table.

    :param num_worker: Number of workers
    :param num_ps: Number of PS
    :param func: The user-defined function that runs TF training
    :param properties: User-defined properties
    :param env_path: Path to the virtual env
    :param zk_conn: The Zookeeper connection string
    :param zk_base_path: The Zookeeper base path
    :param stream_env: The StreamExecutionEnvironment. If it's None, this
    method will create one and execute the job at the end. Otherwise, caller
    is responsible to trigger the job execution
    :param table_env: The TableEnvironment
    :param input_table: The input Table
    :param output_schema: The TableSchema of the output Table. If it's None,
    a dummy sink will be added to the output Table. Otherwise, caller is
    responsible to add sink before executing the job.
    :return: The output Table
    """
    tf_config = TFConfig(num_worker, num_ps, func, properties, env_path,
                         zk_conn, zk_base_path)

    # When the caller did not supply an environment, this function both
    # creates it and triggers execution before returning.
    run_here = stream_env is None
    if run_here:
        stream_env = StreamExecutionEnvironment.get_execution_environment()
    if table_env is None:
        table_env = TableEnvironment.get_table_environment(stream_env)

    # Unwrap the Python facades into their underlying Java handles; None is
    # passed through untouched.
    j_input = None if input_table is None else input_table._java_table
    j_schema = None if output_schema is None else output_schema._j_table_schema

    j_output = get_gateway(
    ).jvm.com.alibaba.flink.tensorflow.client.TFUtils.train(
        stream_env._j_env, table_env._j_tenv, j_input,
        tf_config.java_config(), j_schema)

    if run_here:
        table_env.execute()
    return Table(java_table=j_output)