def datastream_inference_input():
    """Run MNIST inference as a Flink DataStream job over TFRecord test data.

    Prepares the test data set, starts an embedded ZooKeeper server,
    trains/exports a model, then runs a 1-worker distributed inference job
    whose (label_org, predict_label) rows are emitted through ``LogSink``.
    """
    prep_data()
    testing_server = start_zk_server()
    try:
        generate_model()
        stream_env = StreamExecutionEnvironment.get_execution_environment()
        test_data_path = "file://" + get_root_path() + "/examples/target/data/test/"
        paths = [test_data_path + "0.tfrecords", test_data_path + "1.tfrecords"]
        src_row_type = RowType([StringType(), IntegerType()],
                               ["image_raw", "label"])
        input_ds = stream_env.add_source(source_func=TFRSourceFunc(
            paths=paths, epochs=1, out_row_type=src_row_type,
            converters=[ScalarConverter.FIRST, ScalarConverter.ONE_HOT]))
        # One source subtask per input TFRecord file.
        input_ds.set_parallelism(len(paths))
        out_row_type = RowType([IntegerType(), IntegerType()],
                               ['label_org', 'predict_label'])
        output_ds = tensorflow_on_flink_datastream.inference(
            num_worker=1, func=mnist_table_inference.map_fun,
            properties=build_props('0'), stream_env=stream_env,
            input_ds=input_ds, output_row_type=out_row_type)
        output_ds.add_sink(LogSink())
        stream_env.execute()
    finally:
        # Fix: originally stop() was unreachable if the job raised,
        # leaking the embedded ZooKeeper server between test runs.
        testing_server.stop()
def table_with_input():
    """Train MNIST as a Flink Table job fed from TFRecord training data.

    Prepares the training data set, starts an embedded ZooKeeper server,
    registers a TFRecord table source and runs a 1-worker/1-PS distributed
    training job on it.
    """
    prep_data()
    testing_server = start_zk_server()
    try:
        stream_env = StreamExecutionEnvironment.get_execution_environment()
        table_env = TableEnvironment.get_table_environment(stream_env)
        stream_env.set_parallelism(2)
        train_data_path = "file://" + get_root_path() + "/examples/target/data/train/"
        paths = [train_data_path + "0.tfrecords", train_data_path + "1.tfrecords"]
        out_row_type = RowType([StringType(), IntegerType()],
                               ["image_raw", "label"])
        table_src = TFRTableSource(
            paths=paths, epochs=1, out_row_type=out_row_type,
            converters=[ScalarConverter.FIRST, ScalarConverter.ONE_HOT])
        input_table = table_src.register_table(table_env=table_env)
        tensorflow_on_flink_table.train(
            num_worker=1, num_ps=1, func=mnist_dist_with_input.map_fun,
            properties=build_props(), stream_env=stream_env,
            table_env=table_env, input_table=input_table)
        table_env.generate_stream_graph()
        stream_env.execute()
    finally:
        # Fix: stop the ZooKeeper server even when the job fails; the
        # original leaked it on any exception before this line.
        testing_server.stop()
def table_inference():
    """Run MNIST inference as a Flink Table job over TFRecord test data.

    Prepares the test data set, starts an embedded ZooKeeper server,
    trains/exports a model, then runs a 2-worker distributed inference job
    whose (label_org, predict_label) table is written to
    ``LogTableStreamSink``.
    """
    prep_data()
    testing_server = start_zk_server()
    try:
        generate_model()
        stream_env = StreamExecutionEnvironment.get_execution_environment()
        table_env = TableEnvironment.get_table_environment(stream_env)
        stream_env.set_parallelism(2)
        test_data_path = "file://" + get_root_path() + "/examples/target/data/test/"
        paths = [test_data_path + "0.tfrecords", test_data_path + "1.tfrecords"]
        src_row_type = RowType([StringType(), IntegerType()],
                               ["image_raw", "label"])
        table_src = TFRTableSource(
            paths=paths, epochs=1, out_row_type=src_row_type,
            converters=[ScalarConverter.FIRST, ScalarConverter.ONE_HOT])
        input_table = table_src.register_table(table_env=table_env)
        # Output schema: the original label plus the model's prediction.
        builder = TableSchema.Builder()
        builder.column(name='label_org', data_type=IntegerType()).column(
            name='predict_label', data_type=IntegerType())
        output_schema = builder.build()
        output_table = tensorflow_on_flink_table.inference(
            num_worker=2, func=mnist_table_inference.map_fun,
            properties=build_props('0'), stream_env=stream_env,
            table_env=table_env, input_table=input_table,
            output_schema=output_schema)
        output_table.write_to_sink(LogTableStreamSink())
        table_env.generate_stream_graph()
        stream_env.execute()
    finally:
        # Fix: originally stop() was skipped whenever the job raised,
        # leaking the embedded ZooKeeper server.
        testing_server.stop()
def table_java_inference():
    """Run Java-side MNIST inference as a Flink Table job.

    Prepares the test data set, starts an embedded ZooKeeper server and
    trains/exports a model. A SQL lateral-table UDTF extracts (image,
    org_label) from the TFRecord source, then a 2-worker inference job
    driven purely by TF_INFERENCE_* properties (no Python ``func``) writes
    its accuracy to ``LogInferAccSink``.
    """
    prep_data()
    testing_server = start_zk_server()
    try:
        generate_model()
        stream_env = StreamExecutionEnvironment.get_execution_environment()
        table_env = TableEnvironment.get_table_environment(stream_env)
        stream_env.set_parallelism(2)
        # Fix: was misleadingly named `train_data_path` although it points at
        # the test set, matching the sibling table_inference() naming.
        test_data_path = "file://" + get_root_path() + "/examples/target/data/test/"
        paths = [test_data_path + "0.tfrecords", test_data_path + "1.tfrecords"]
        src_row_type = RowType([StringType(), IntegerType()],
                               ["image_raw", "label"])
        table_src = TFRTableSource(
            paths=paths, epochs=1, out_row_type=src_row_type,
            converters=[ScalarConverter.FIRST, ScalarConverter.ONE_HOT])
        tfr_tbl_name = "tfr_input_table"
        table_env.register_table_source(tfr_tbl_name, table_src)
        ext_func_name = "tfr_extract"
        table_env.register_function(ext_func_name, java_inference_extract_func())
        out_cols = 'image,org_label'
        in_cols = ','.join(src_row_type.fields_names)
        extracted = table_env.sql_query(
            'select {} from {}, LATERAL TABLE({}({})) as T({})'.format(
                out_cols, tfr_tbl_name, ext_func_name, in_cols, out_cols))
        builder = TableSchema.Builder()
        builder.column(name='real_label', data_type=LongType()).column(
            name='predicted_label', data_type=LongType())
        output_schema = builder.build()
        # Configure the Java inference path via properties; `export_path` is
        # presumably a module-level constant for the saved model — confirm.
        props = build_props('0')
        props[TF_INFERENCE_EXPORT_PATH] = export_path
        props[TF_INFERENCE_INPUT_TENSOR_NAMES] = 'image'
        props[TF_INFERENCE_OUTPUT_TENSOR_NAMES] = 'prediction'
        props[TF_INFERENCE_OUTPUT_ROW_FIELDS] = ','.join(
            ['org_label', 'prediction'])
        output_table = tensorflow_on_flink_table.inference(
            num_worker=2, properties=props, stream_env=stream_env,
            table_env=table_env, input_table=extracted,
            output_schema=output_schema)
        output_table.write_to_sink(LogInferAccSink())
        table_env.generate_stream_graph()
        stream_env.execute()
    finally:
        # Fix: release the embedded ZooKeeper server even on job failure.
        testing_server.stop()
def datastream_with_input():
    """Train MNIST as a Flink DataStream job fed from TFRecord training data.

    Prepares the training data set, starts an embedded ZooKeeper server,
    wires a TFRecord source function into the stream and runs a
    1-worker/1-PS distributed training job.
    """
    prep_data()
    testing_server = start_zk_server()
    try:
        stream_env = StreamExecutionEnvironment.get_execution_environment()
        train_data_path = "file://" + get_root_path() + "/examples/target/data/train/"
        paths = [train_data_path + "0.tfrecords", train_data_path + "1.tfrecords"]
        src_row_type = RowType([StringType(), IntegerType()],
                               ["image_raw", "label"])
        input_ds = stream_env.add_source(
            TFRSourceFunc(
                paths=paths, epochs=1, out_row_type=src_row_type,
                converters=[ScalarConverter.FIRST, ScalarConverter.ONE_HOT]))
        # One source subtask per input TFRecord file.
        input_ds.set_parallelism(len(paths))
        tensorflow_on_flink_datastream.train(
            num_worker=1, num_ps=1, func=mnist_dist_with_input.map_fun,
            properties=build_props(), input_ds=input_ds,
            stream_env=stream_env)
        stream_env.execute()
    finally:
        # Fix: originally stop() was unreachable when execute() raised,
        # leaking the embedded ZooKeeper server.
        testing_server.stop()