def serving_input_fn(batch_size,
                     desired_image_size,
                     stride,
                     input_type,
                     input_name='input'):
  """Input function for SavedModels and TF serving.

  Args:
    batch_size: The batch size.
    desired_image_size: The tuple/list of two integers, specifying the desired
      image size.
    stride: an integer, the stride of the backbone network. The processed
      image will be (internally) padded such that each side is a multiple of
      this number.
    input_type: a string of 'image_tensor', 'image_bytes' or 'tf_example',
      specifying which type of input will be used in serving.
    input_name: a string to specify the name of the input signature.

  Returns:
    a `tf.estimator.export.ServingInputReceiver` for a SavedModel.
  """
  placeholder, features = inputs.build_serving_input(
      input_type, batch_size, desired_image_size, stride)
  return tf.estimator.export.ServingInputReceiver(
      features=features,
      receiver_tensors={
          input_name: placeholder,
      })
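
# Illustrative sketch (not used by this script): one way `serving_input_fn`
# could be wired into a `tf.estimator.Estimator` SavedModel export. The
# estimator, export directory, and the image size / stride / input type
# values below are placeholder assumptions, not values defined by this module.
def _example_export_saved_model(estimator, export_dir):
  """Exports a SavedModel via `serving_input_fn` (illustrative only)."""

  def _serving_input_receiver_fn():
    # Build a single-image serving signature that accepts encoded image bytes.
    return serving_input_fn(
        batch_size=1,
        desired_image_size=(640, 640),
        stride=2**6,  # e.g. the stride implied by `max_level` = 6.
        input_type='image_bytes')

  # `export_saved_model` calls the receiver fn to build the serving graph and
  # writes a timestamped SavedModel under `export_dir`.
  return estimator.export_saved_model(export_dir, _serving_input_receiver_fn)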
def main(argv):
  del argv  # Unused.

  params = factory.config_generator(FLAGS.model)
  if FLAGS.config_file:
    params = params_dict.override_params_dict(
        params, FLAGS.config_file, is_strict=True)
  # Use `is_strict=False` to load params_override with run-time variables like
  # `train.num_shards`.
  params = params_dict.override_params_dict(
      params, FLAGS.params_override, is_strict=False)
  params.validate()
  params.lock()

  image_size = [int(x) for x in FLAGS.input_image_size.split(',')]

  g = tf.Graph()
  with g.as_default():
    # Build the input.
    _, features = inputs.build_serving_input(
        input_type=FLAGS.input_type,
        batch_size=FLAGS.batch_size,
        desired_image_size=image_size,
        stride=(2**params.anchor.max_level))

    # Build the model.
    print(' - Building the graph...')
    if FLAGS.model in ['retinanet', 'mask_rcnn', 'shapemask']:
      graph_fn = detection.serving_model_graph_builder(
          FLAGS.output_image_info,
          FLAGS.output_normalized_coordinates,
          FLAGS.cast_num_detections_to_float)
    else:
      raise ValueError('The model type `{}` is not supported.'.format(
          FLAGS.model))
    predictions = graph_fn(features, params)

    # Add a saver for checkpoint loading.
    tf.train.Saver()

    inference_graph_def = g.as_graph_def()
    optimized_graph_def = inference_graph_def

    if FLAGS.optimize_graph:
      print(' - Optimizing the graph...')
      # Trim the unused nodes in the graph.
      output_nodes = [
          output_node.op.name for output_node in predictions.values()
      ]
      # TODO(pengchong): Consider using `strip_unused_lib.strip_unused` and/or
      # `optimize_for_inference_lib.optimize_for_inference` to trim the graph.
      # Use `optimize_for_inference` if we decide to export the frozen graph
      # (graph + checkpoint) and want to explicitly fold in batch norm
      # variables.
      optimized_graph_def = graph_util.remove_training_nodes(
          optimized_graph_def, output_nodes)

  print(' - Saving the graph...')
  tf.train.write_graph(
      optimized_graph_def, FLAGS.export_dir, 'inference_graph.pbtxt')
  print(' - Done!')
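
# Hypothetical invocation (illustrative only; the script filename and flag
# values are assumptions, but the flag names match the `FLAGS.*` references
# in `main` above):
#
#   python export_inference_graph.py \
#     --model=retinanet \
#     --config_file=/path/to/config.yaml \
#     --input_image_size=640,640 \
#     --input_type=image_bytes \
#     --batch_size=1 \
#     --export_dir=/tmp/exported_graph \
#     --optimize_graph=true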