# Example #1
# 0
def export_inference_graph(input_type,
                           pipeline_config,
                           trained_checkpoint_prefix,
                           output_directory,
                           input_shape=None,
                           output_collection_name='inference_op',
                           additional_output_tensor_names=None,
                           write_inference_graph=False,
                           use_side_inputs=False,
                           side_input_shapes=None,
                           side_input_names=None,
                           side_input_types=None):
    """Exports inference graph for the model specified in the pipeline config.

    Builds the detection model (and, if configured, a graph rewriter) from
    `pipeline_config`, delegates the export to `_export_inference_graph`, and
    then writes a copy of the pipeline config to `output_directory` with
    `eval_config.use_moving_averages` set to False.

    The caller's `pipeline_config` is NOT modified; only the saved copy has
    moving averages disabled.

    Args:
      input_type: Type of input for the graph. Can be one of ['image_tensor',
        'encoded_image_string_tensor', 'tf_example'].
      pipeline_config: pipeline_pb2.TrainEvalPipelineConfig proto. Left
        unchanged by this function.
      trained_checkpoint_prefix: Path prefix of the trained checkpoint.
      output_directory: Path to write outputs.
      input_shape: Sets a fixed shape for an `image_tensor` input. If not
        specified, will default to [None, None, None, 3].
      output_collection_name: Name of collection to add output tensors to.
        If None, does not add output tensors to a collection.
      additional_output_tensor_names: list of additional output
        tensors to include in the frozen graph.
      write_inference_graph: If true, writes inference graph to disk.
      use_side_inputs: If True, the model requires side_inputs.
      side_input_shapes: List of shapes of the side input tensors,
        required if use_side_inputs is True.
      side_input_names: List of names of the side input tensors,
        required if use_side_inputs is True.
      side_input_types: List of types of the side input tensors,
        required if use_side_inputs is True.
    """
    detection_model = model_builder.build(pipeline_config.model,
                                          is_training=False)
    graph_rewriter_fn = None
    if pipeline_config.HasField('graph_rewriter'):
        graph_rewriter_fn = graph_rewriter_builder.build(
            pipeline_config.graph_rewriter, is_training=False)

    # NOTE(review): `additional_output_tensor_names` and `input_shape` are
    # passed positionally in the opposite order from this function's own
    # signature; presumably this matches `_export_inference_graph`'s
    # parameter order -- confirm before touching.
    _export_inference_graph(
        input_type,
        detection_model,
        pipeline_config.eval_config.use_moving_averages,
        trained_checkpoint_prefix,
        output_directory,
        additional_output_tensor_names,
        input_shape,
        output_collection_name,
        graph_hook_fn=graph_rewriter_fn,
        write_inference_graph=write_inference_graph,
        use_side_inputs=use_side_inputs,
        side_input_shapes=side_input_shapes,
        side_input_names=side_input_names,
        side_input_types=side_input_types)

    # Save a copy rather than editing the caller's proto in place: the flag is
    # read above to drive the export, so mutating the argument would make a
    # later export with the same config object silently skip moving averages.
    # NOTE(review): the flag is presumably cleared because the exported
    # weights already incorporate the moving averages -- confirm in
    # `_export_inference_graph`.
    config_to_save = type(pipeline_config)()
    config_to_save.CopyFrom(pipeline_config)
    config_to_save.eval_config.use_moving_averages = False
    config_util.save_pipeline_config(config_to_save, output_directory)
    def test_save_pipeline_config(self):
        """Round-trips a pipeline config through disk and checks equality.

        A config with a few fields populated is written out with
        `config_util.save_pipeline_config`, read back from `pipeline.config`
        in the same directory, reassembled into a single proto, and compared
        against the original.

        NOTE(review): as laid out in this file, this `def` is indented inside
        the preceding `export_inference_graph`, making it a nested function
        that is never called; it looks like it belongs on a test-case class.
        """
        # get_temp_dir() returns the same folder on repeated calls within a
        # test, so one lookup serves both the write and the read.
        temp_dir = self.get_temp_dir()

        expected = pipeline_pb2.TrainEvalPipelineConfig()
        expected.model.faster_rcnn.num_classes = 10
        expected.train_config.batch_size = 32
        expected.train_input_reader.label_map_path = "path/to/label_map"
        expected.eval_config.num_examples = 20
        expected.eval_input_reader.queue_capacity = 100

        config_util.save_pipeline_config(expected, temp_dir)

        saved_path = os.path.join(temp_dir, "pipeline.config")
        loaded_configs = config_util.get_configs_from_pipeline_file(saved_path)
        roundtripped = config_util.create_pipeline_proto_from_configs(
            loaded_configs)

        self.assertEqual(expected, roundtripped)
# Example #3
# 0
def create_estimator_and_inputs(run_config,
                                hparams,
                                pipeline_config_path,
                                config_override=None,
                                train_steps=None,
                                sample_1_of_n_eval_examples=1,
                                sample_1_of_n_eval_on_train_examples=1,
                                model_fn_creator=create_model_fn,
                                use_tpu_estimator=False,
                                use_tpu=False,
                                num_shards=1,
                                params=None,
                                override_eval_num_epochs=True,
                                save_final_config=False,
                                **kwargs):
    """Creates `Estimator`, input functions, and steps.

    Args:
      run_config: A `RunConfig`.
      hparams: A `HParams`.
      pipeline_config_path: A path to a pipeline config file.
      config_override: A pipeline_pb2.TrainEvalPipelineConfig text proto to
        override the config from `pipeline_config_path`.
      train_steps: Number of training steps. If None, the number of training
        steps is set from the `TrainConfig` proto.
      sample_1_of_n_eval_examples: Integer representing how often an eval
        example should be sampled. If 1, will sample all examples.
      sample_1_of_n_eval_on_train_examples: Similar to
        `sample_1_of_n_eval_examples`, except controls the sampling of
        training data for evaluation.
      model_fn_creator: A function that creates a `model_fn` for `Estimator`.
        Follows the signature:

        * Args:
          * `detection_model_fn`: Function that returns `DetectionModel`
            instance.
          * `configs`: Dictionary of pipeline config objects.
          * `hparams`: `HParams` object.
        * Returns:
          `model_fn` for `Estimator`.

      use_tpu_estimator: Whether a `TPUEstimator` should be returned. If
        False, an `Estimator` will be returned.
      use_tpu: Boolean, whether training and evaluation should run on TPU.
        Only used if `use_tpu_estimator` is True.
      num_shards: Number of shards (TPU cores). Only used if
        `use_tpu_estimator` is True.
      params: Parameter dictionary passed from the estimator. Only used if
        `use_tpu_estimator` is True.
      override_eval_num_epochs: Whether to overwrite the number of epochs to
        1 for eval_input.
      save_final_config: Whether to save final config (obtained after
        applying overrides) to `estimator.model_dir`.
      **kwargs: Additional keyword arguments for configuration override.

    Returns:
      A dictionary with the following fields:
      'estimator': An `Estimator` or `TPUEstimator`.
      'train_input_fn': A training input function.
      'eval_input_fns': A list of all evaluation input functions.
      'eval_input_names': A list of names for each evaluation input.
      'eval_on_train_input_fn': An evaluation-on-train input function.
      'predict_input_fn': A prediction input function (built from the first
        eval input config).
      'train_steps': Number of training steps. Either directly from input or
        from configuration.

    Raises:
      ValueError: if the merged configs contain no eval input configs; at
        least one is required because the first one also configures
        `predict_input_fn`.
    """
    get_configs_from_pipeline_file = MODEL_BUILD_UTIL_MAP[
        'get_configs_from_pipeline_file']
    merge_external_params_with_configs = MODEL_BUILD_UTIL_MAP[
        'merge_external_params_with_configs']
    create_pipeline_proto_from_configs = MODEL_BUILD_UTIL_MAP[
        'create_pipeline_proto_from_configs']
    create_train_input_fn = MODEL_BUILD_UTIL_MAP['create_train_input_fn']
    create_eval_input_fn = MODEL_BUILD_UTIL_MAP['create_eval_input_fn']
    create_predict_input_fn = MODEL_BUILD_UTIL_MAP['create_predict_input_fn']

    configs = get_configs_from_pipeline_file(pipeline_config_path,
                                             config_override=config_override)
    # Fold the explicit arguments into the override dict so that they are
    # merged into the configs together with `hparams` and any extra kwargs.
    kwargs.update({
        'train_steps': train_steps,
        'sample_1_of_n_eval_examples': sample_1_of_n_eval_examples
    })
    if override_eval_num_epochs:
        kwargs.update({'eval_num_epochs': 1})
        tf.logging.warning(
            'Forced number of epochs for all eval validations to be 1.')
    configs = merge_external_params_with_configs(configs,
                                                 hparams,
                                                 kwargs_dict=kwargs)
    model_config = configs['model']
    train_config = configs['train_config']
    train_input_config = configs['train_input_config']
    eval_config = configs['eval_config']
    eval_input_configs = configs['eval_input_configs']
    if not eval_input_configs:
        # `eval_input_configs[0]` is needed for the predict input function
        # below; fail early with a clear message instead of an IndexError.
        raise ValueError(
            '`eval_input_configs` is empty: at least one eval input config is '
            'required to build the eval and predict input functions.')

    # Deep copy so the eval-on-train tweaks below never leak into the
    # training input config.
    eval_on_train_input_config = copy.deepcopy(train_input_config)
    eval_on_train_input_config.sample_1_of_n_examples = (
        sample_1_of_n_eval_on_train_examples)
    if override_eval_num_epochs and eval_on_train_input_config.num_epochs != 1:
        tf.logging.warning('Expected number of evaluation epochs is 1, but '
                           'instead encountered `eval_on_train_input_config'
                           '.num_epochs` = '
                           '{}. Overwriting `num_epochs` to 1.'.format(
                               eval_on_train_input_config.num_epochs))
        eval_on_train_input_config.num_epochs = 1

    # update train_steps from config but only when non-zero value is provided
    if train_steps is None and train_config.num_steps != 0:
        train_steps = train_config.num_steps

    detection_model_fn = functools.partial(model_builder.build,
                                           model_config=model_config)

    # Create the input functions for TRAIN/EVAL/PREDICT.
    train_input_fn = create_train_input_fn(
        train_config=train_config,
        train_input_config=train_input_config,
        model_config=model_config)
    eval_input_fns = [
        create_eval_input_fn(eval_config=eval_config,
                             eval_input_config=eval_input_config,
                             model_config=model_config)
        for eval_input_config in eval_input_configs
    ]
    eval_input_names = [
        eval_input_config.name for eval_input_config in eval_input_configs
    ]
    eval_on_train_input_fn = create_eval_input_fn(
        eval_config=eval_config,
        eval_input_config=eval_on_train_input_config,
        model_config=model_config)
    predict_input_fn = create_predict_input_fn(
        model_config=model_config, predict_input_config=eval_input_configs[0])

    # `export_to_tpu` is only read here for logging.
    export_to_tpu = hparams.get('export_to_tpu', False)
    tf.logging.info(
        'create_estimator_and_inputs: use_tpu %s, export_to_tpu %s', use_tpu,
        export_to_tpu)
    model_fn = model_fn_creator(detection_model_fn, configs, hparams, use_tpu)
    if use_tpu_estimator:
        estimator = tf.contrib.tpu.TPUEstimator(
            model_fn=model_fn,
            train_batch_size=train_config.batch_size,
            # For each core, only batch size 1 is supported for eval.
            eval_batch_size=num_shards * 1 if use_tpu else 1,
            use_tpu=use_tpu,
            config=run_config,
            # TODO(lzc): Remove conditional after CMLE moves to TF 1.9
            params=params if params else {})
    else:
        estimator = tf.estimator.Estimator(model_fn=model_fn,
                                           config=run_config)

    # Write the as-run pipeline config to disk (chief worker only).
    if run_config.is_chief and save_final_config:
        pipeline_config_final = create_pipeline_proto_from_configs(configs)
        config_util.save_pipeline_config(pipeline_config_final,
                                         estimator.model_dir)

    return dict(estimator=estimator,
                train_input_fn=train_input_fn,
                eval_input_fns=eval_input_fns,
                eval_input_names=eval_input_names,
                eval_on_train_input_fn=eval_on_train_input_fn,
                predict_input_fn=predict_input_fn,
                train_steps=train_steps)
def export_inference_graph(input_type,
                           pipeline_config,
                           trained_checkpoint_dir,
                           output_directory,
                           use_side_inputs=False,
                           side_input_shapes='',
                           side_input_types='',
                           side_input_names=''):
    """Exports inference graph for the model specified in the pipeline config.

    This function creates `output_directory` if it does not already exist,
    which will hold a copy of the pipeline config with filename
    `pipeline.config`, and two subdirectories named `checkpoint` and
    `saved_model` (containing the exported checkpoint and SavedModel
    respectively).

    Args:
      input_type: Type of input for the graph. Can be one of ['image_tensor',
        'encoded_image_string_tensor', 'tf_example'].
      pipeline_config: pipeline_pb2.TrainEvalPipelineConfig proto.
      trained_checkpoint_dir: Directory containing the trained checkpoint; its
        latest checkpoint is restored.
      output_directory: Path to write outputs.
      use_side_inputs: boolean that determines whether side inputs should be
        included in the input signature. Only supported for 'image_tensor'.
      side_input_shapes: forward-slash-separated list of comma-separated lists
        describing input shapes.
      side_input_types: comma-separated list of the types of the inputs.
      side_input_names: comma-separated list of the names of the inputs.

    Raises:
      ValueError: if input_type is invalid, or if side inputs are requested
        for an input type other than 'image_tensor'.
      AssertionError: from `assert_existing_objects_matched()` when the
        restored checkpoint does not match the traced model's variables
        (presumably also when `trained_checkpoint_dir` holds no checkpoint --
        confirm against the TF version in use).
    """
    # Validate the cheap arguments first so bad input fails before the model
    # is built and the checkpoint is read.
    if input_type not in DETECTION_MODULE_MAP:
        raise ValueError('Unrecognized `input_type`')
    if use_side_inputs and input_type != 'image_tensor':
        raise ValueError(
            'Side inputs supported for image_tensor input type only.')

    output_checkpoint_directory = os.path.join(output_directory, 'checkpoint')
    output_saved_model_directory = os.path.join(output_directory,
                                                'saved_model')

    detection_model = model_builder.build(pipeline_config.model,
                                          is_training=False)

    ckpt = tf.train.Checkpoint(model=detection_model)
    manager = tf.train.CheckpointManager(ckpt,
                                         trained_checkpoint_dir,
                                         max_to_keep=1)
    # Only the model is tracked here, so parts of the training checkpoint
    # (presumably optimizer state) are expected to go unused; expect_partial()
    # silences the warnings about that.
    status = ckpt.restore(manager.latest_checkpoint).expect_partial()

    zipped_side_inputs = []
    if use_side_inputs:
        zipped_side_inputs = _combine_side_inputs(side_input_shapes,
                                                  side_input_types,
                                                  side_input_names)

    detection_module = DETECTION_MODULE_MAP[input_type](
        detection_model, use_side_inputs, list(zipped_side_inputs))
    # Getting the concrete function traces the graph and forces variables to
    # be constructed --- only after this can we save the checkpoint and
    # saved model.
    concrete_function = detection_module.__call__.get_concrete_function()
    status.assert_existing_objects_matched()

    exported_checkpoint_manager = tf.train.CheckpointManager(
        ckpt, output_checkpoint_directory, max_to_keep=1)
    exported_checkpoint_manager.save(checkpoint_number=0)

    tf.saved_model.save(detection_module,
                        output_saved_model_directory,
                        signatures=concrete_function)

    config_util.save_pipeline_config(pipeline_config, output_directory)