Example #1
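# Assumed context (imports not shown in this snippet): `tf` is the TF 1.x API
# (e.g. `tensorflow.compat.v1`), and `inputs` is the detection codebase's
# serving-input module providing `build_serving_input`.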
def serving_input_fn(batch_size,
                     desired_image_size,
                     stride,
                     input_type,
                     input_name='input'):
    """Input function for SavedModels and TF serving.

    Args:
      batch_size: The batch size.
      desired_image_size: A tuple/list of two integers specifying the desired
        image size.
      stride: An integer, the stride of the backbone network. The processed
        image is (internally) padded so that each side is a multiple of this
        number.
      input_type: A string, one of 'image_tensor', 'image_bytes' or
        'tf_example', specifying which type of input is used in serving.
      input_name: A string specifying the name of the input signature.

    Returns:
      A `tf.estimator.export.ServingInputReceiver` for a SavedModel.
    """
    placeholder, features = inputs.build_serving_input(input_type, batch_size,
                                                       desired_image_size,
                                                       stride)
    return tf.estimator.export.ServingInputReceiver(
        features=features,
        receiver_tensors={input_name: placeholder})
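
When this helper is passed to an Estimator export, the receiver function must take no arguments, so its parameters are bound first. Below is a minimal sketch of that wiring, assuming a TF 1.x `tf.estimator.Estimator` instance named `estimator` already exists; the image size, stride, and export directory are placeholders.

import functools

# Bind the serving-input arguments so the resulting callable takes no
# arguments, as `Estimator.export_saved_model` expects.
serving_input_receiver_fn = functools.partial(
    serving_input_fn,
    batch_size=1,
    desired_image_size=(640, 640),  # placeholder image size
    stride=2**6,                    # placeholder backbone stride
    input_type='image_bytes')

# `estimator` is an assumed, already-constructed tf.estimator.Estimator.
export_path = estimator.export_saved_model(
    export_dir_base='/tmp/detection_saved_model',
    serving_input_receiver_fn=serving_input_receiver_fn)
print('SavedModel written to', export_path)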
Example #2
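# Assumed context (imports not shown in this snippet): absl `FLAGS`, `tf` as the
# TF 1.x API, TensorFlow's `graph_util` (tensorflow.python.framework.graph_util),
# and the detection codebase's `factory`, `params_dict`, `inputs`, and
# `detection` modules.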
def main(argv):
    del argv  # Unused.

    params = factory.config_generator(FLAGS.model)
    if FLAGS.config_file:
        params = params_dict.override_params_dict(params,
                                                  FLAGS.config_file,
                                                  is_strict=True)
    # Use `is_strict=False` to load params_override with run_time variables like
    # `train.num_shards`.
    params = params_dict.override_params_dict(params,
                                              FLAGS.params_override,
                                              is_strict=False)
    params.validate()
    params.lock()

    image_size = [int(x) for x in FLAGS.input_image_size.split(',')]

    g = tf.Graph()
    with g.as_default():
        # Build the input.
        _, features = inputs.build_serving_input(
            input_type=FLAGS.input_type,
            batch_size=FLAGS.batch_size,
            desired_image_size=image_size,
            stride=(2**params.anchor.max_level))

        # Build the model.
        print(' - Building the graph...')
        if FLAGS.model in ['retinanet', 'mask_rcnn', 'shapemask']:
            graph_fn = detection.serving_model_graph_builder(
                FLAGS.output_image_info, FLAGS.output_normalized_coordinates,
                FLAGS.cast_num_detections_to_float)
        else:
            raise ValueError('The model type `{}` is not supported.'.format(
                FLAGS.model))

        predictions = graph_fn(features, params)

        # Add a saver for checkpoint loading.
        tf.train.Saver()

        inference_graph_def = g.as_graph_def()
        optimized_graph_def = inference_graph_def

        if FLAGS.optimize_graph:
            print(' - Optimizing the graph...')
            # Trim the unused nodes in the graph.
            output_nodes = [
                output_node.op.name for output_node in predictions.values()
            ]
            # TODO(pengchong): Consider using `strip_unused_lib.strip_unused` and/or
            # `optimize_for_inference_lib.optimize_for_inference` to trim the graph.
            # Use `optimize_for_inference` if we decide to export the frozen graph
            # (graph + checkpoint) and want to explicitly fold in batchnorm variables.
            optimized_graph_def = graph_util.remove_training_nodes(
                optimized_graph_def, output_nodes)

    print(' - Saving the graph...')
    tf.train.write_graph(optimized_graph_def, FLAGS.export_dir,
                         'inference_graph.pbtxt')
    print(' - Done!')
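
The TODO above mentions exporting a frozen graph (graph plus checkpoint) instead of a bare GraphDef. Below is a minimal sketch of that path, placed at the end of main() and reusing `g`, `predictions`, `graph_util`, and `FLAGS` from the snippet; the checkpoint path is a placeholder.

    with g.as_default():
        saver = tf.train.Saver()
        with tf.Session() as sess:
            # Restore trained variables, then fold them into constants so the
            # resulting GraphDef is self-contained.
            saver.restore(sess, '/path/to/model.ckpt')  # placeholder checkpoint
            output_nodes = [t.op.name for t in predictions.values()]
            frozen_graph_def = graph_util.convert_variables_to_constants(
                sess, g.as_graph_def(), output_nodes)

    tf.train.write_graph(frozen_graph_def, FLAGS.export_dir,
                         'frozen_inference_graph.pb', as_text=False)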