示例#1
0
def get_timestamped_export_dir(export_dir_base):
  """Returns the path of a new, timestamp-named subdirectory of a base dir.

  Every export goes into its own subdirectory whose name is derived from the
  current time, expressed as seconds since the UTC epoch. This keeps version
  numbers monotonically increasing, even across separate runs of the
  pipeline.

  Args:
    export_dir_base: String path of the directory under which the exported
      graph and checkpoints will be written.

  Returns:
    The full path of the new subdirectory. The directory itself is not
    created by this call.

  Raises:
    RuntimeError: If repeated attempts fail to produce a unique timestamped
      directory name.
  """
  # All of the work (timestamping and uniqueness retries) is delegated.
  timestamped_dir = util.get_timestamped_dir(export_dir_base)
  return timestamped_dir
示例#2
0
def get_timestamped_export_dir(export_dir_base):
  """Computes a fresh timestamp-named export path under `export_dir_base`.

  Each export is placed in a separate subdirectory named after the current
  time, measured in seconds since the UTC epoch. Version numbers therefore
  grow monotonically, including across multiple runs of the pipeline.

  Args:
    export_dir_base: String path of the directory that will hold the
      exported graph and checkpoints.

  Returns:
    Full path of the new subdirectory. Note that it is not created here.

  Raises:
    RuntimeError: If repeated attempts cannot yield a unique timestamped
      directory name.
  """
  new_export_dir = util.get_timestamped_dir(export_dir_base)
  return new_export_dir
示例#3
0
def export_eval_savedmodel(
    estimator,
    export_dir_base,
    eval_input_receiver_fn,
    checkpoint_path = None):
  """Export an EvalSavedModel for the given estimator.

  Builds a new graph in the following steps:
    1. Adds the eval input receiver and the estimator's model_fn run in
       EVAL mode.
    2. Records in named graph collections: the TFMA version string, the
       metric ops, the prediction nodes, the input-example placeholder, the
       label nodes and the feature nodes.
    3. Restores variables from a checkpoint.
    4. Writes the graph as a SavedModel tagged with
       `constants.EVAL_SAVED_MODEL_TAG`.

  No serving signatures are exported. The SavedModel is first written to a
  temporary directory and then renamed to a new timestamped subdirectory of
  `export_dir_base`.

  Args:
    estimator: Estimator to export the graph for. Either a core
      `tf.estimator.Estimator`, or a (deprecated) contrib Estimator whose
      model_fn returns an object with `predictions`, `eval_metric_ops` and
      `scaffold` attributes.
    export_dir_base: Base path for export. Graph will be exported into a
      subdirectory of this base path.
    eval_input_receiver_fn: Eval input receiver function. It is called with
      no arguments. The returned receiver must provide:
        - `features`: a dict mapping feature names to nodes;
        - `labels`: a single node or a dict of nodes;
        - `receiver_tensors`: a mapping containing an 'examples' entry.
    checkpoint_path: Path to a specific checkpoint to export. If falsy
      (e.g. None), exports the latest checkpoint found in
      `estimator.model_dir`.

  Returns:
    Path to the directory where the eval graph was exported.

  Raises:
    ValueError: Could not find a checkpoint to export.
  """
  # Everything below is built into a fresh graph so that only the eval
  # graph ends up in the export.
  with tf.Graph().as_default() as g:
    eval_input_receiver = eval_input_receiver_fn()
    # NOTE(review): the global step is presumably created because model_fn
    # implementations and/or the checkpoint expect one to exist -- confirm.
    tf.train.create_global_step(g)
    tf.set_random_seed(estimator.config.tf_random_seed)

    # Workaround for TensorFlow issue #17568. Note that we pass the
    # identity-wrapped features and labels to model_fn, but we have to feed
    # the non-identity wrapped Tensors during evaluation.
    #
    # Also note that we can't wrap predictions, so metrics that have control
    # dependencies on predictions will cause the predictions to be recomputed
    # during their evaluation.
    wrapped_features = util.wrap_tensor_or_dict_of_tensors_in_identity(
        eval_input_receiver.features)
    wrapped_labels = util.wrap_tensor_or_dict_of_tensors_in_identity(
        eval_input_receiver.labels)

    if isinstance(estimator, tf.estimator.Estimator):
      # This is a core Estimator.
      estimator_spec = estimator.model_fn(
          features=wrapped_features,
          labels=wrapped_labels,
          mode=tf.estimator.ModeKeys.EVAL,
          config=estimator.config)
    else:
      # This is a contrib Estimator. Note that contrib Estimators are
      # deprecated.
      model_fn_ops = estimator.model_fn(
          features=wrapped_features,
          labels=wrapped_labels,
          mode=tf.estimator.ModeKeys.EVAL,
          config=estimator.config)
      # "Convert" model_fn_ops into EstimatorSpec,
      # populating only the fields we need.
      estimator_spec = tf.estimator.EstimatorSpec(
          # Dummy loss. The loss is never written to any collection below.
          # NOTE(review): presumably EstimatorSpec requires a loss in EVAL
          # mode, hence the constant -- confirm.
          loss=tf.constant(0.0),
          mode=tf.estimator.ModeKeys.EVAL,
          predictions=model_fn_ops.predictions,
          eval_metric_ops=model_fn_ops.eval_metric_ops,
          scaffold=model_fn_ops.scaffold)

    # Write out exporter version.
    tf.add_to_collection(encoding.TFMA_VERSION_COLLECTION,
                         version.VERSION_STRING)

    # Save metric using eval_metric_ops.
    # Each metric appends its key, value op and update op to three separate
    # collections in the same loop iteration. Entries at the same index
    # across those three collections therefore describe the same metric, so
    # the order of these calls matters.
    for user_metric_key, (value_op, update_op) in (
        estimator_spec.eval_metric_ops.items()):
      tf.add_to_collection('%s/%s' % (encoding.METRICS_COLLECTION,
                                      encoding.KEY_SUFFIX),
                           encoding.encode_key(user_metric_key))
      tf.add_to_collection('%s/%s' % (encoding.METRICS_COLLECTION,
                                      encoding.VALUE_OP_SUFFIX),
                           encoding.encode_tensor_node(value_op))
      tf.add_to_collection('%s/%s' % (encoding.METRICS_COLLECTION,
                                      encoding.UPDATE_OP_SUFFIX),
                           encoding.encode_tensor_node(update_op))

    # Save all prediction nodes.
    # Predictions can either be a Tensor, or a dict of Tensors.
    # A lone Tensor is stored under DEFAULT_PREDICTIONS_DICT_KEY.
    predictions = estimator_spec.predictions
    if not isinstance(predictions, dict):
      predictions = {encoding.DEFAULT_PREDICTIONS_DICT_KEY: predictions}

    for prediction_key, prediction_node in predictions.items():
      _encode_and_add_to_node_collection(encoding.PREDICTIONS_COLLECTION,
                                         prediction_key, prediction_node)

    ############################################################
    ## Features, label (and weight) graph

    # Placeholder for input example to label graph.
    tf.add_to_collection(encoding.INPUT_EXAMPLE_COLLECTION,
                         encoding.encode_tensor_node(
                             eval_input_receiver.receiver_tensors['examples']))

    # Save all label nodes.
    # Labels can either be a Tensor, or a dict of Tensors.
    # The un-wrapped receiver labels (not wrapped_labels) are recorded here,
    # matching the note above that the non-identity Tensors are the ones
    # fed during evaluation.
    labels = eval_input_receiver.labels
    if not isinstance(labels, dict):
      labels = {encoding.DEFAULT_LABELS_DICT_KEY: labels}

    for label_key, label_node in labels.items():
      _encode_and_add_to_node_collection(encoding.LABELS_COLLECTION, label_key,
                                         label_node)

    # Save features.
    # Unlike labels, features are assumed to already be a dict (.items()).
    for feature_name, feature_node in eval_input_receiver.features.items():
      _encode_and_add_to_node_collection(encoding.FEATURES_COLLECTION,
                                         feature_name, feature_node)

    ############################################################
    ## Export as normal

    # Resolve which checkpoint to restore from: an explicit path wins,
    # otherwise fall back to the latest checkpoint in model_dir.
    if not checkpoint_path:
      checkpoint_path = tf.train.latest_checkpoint(estimator.model_dir)
      if not checkpoint_path:
        raise ValueError(
            'Could not find trained model at %s.' % estimator.model_dir)

    # The model is written to a temporary directory and only renamed to the
    # final timestamped path once builder.save() has returned.
    export_dir = estimator_util.get_timestamped_dir(export_dir_base)
    temp_export_dir = _get_temp_export_dir(export_dir)

    # Default to soft placement when the estimator has no session config.
    if estimator.config.session_config is None:
      session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    else:
      session_config = estimator.config.session_config

    with tf.Session(config=session_config) as session:
      # Prefer the model's own Saver (from its Scaffold) if it provided one;
      # otherwise restore with a sharded default Saver.
      if estimator_spec.scaffold and estimator_spec.scaffold.saver:
        saver_for_restore = estimator_spec.scaffold.saver
      else:
        saver_for_restore = tf.train.Saver(sharded=True)
      saver_for_restore.restore(session, checkpoint_path)

      if estimator_spec.scaffold and estimator_spec.scaffold.local_init_op:
        local_init_op = estimator_spec.scaffold.local_init_op
      else:
        # Compatibility shim: use the public default_local_init_op when
        # this TF version has it, else the private _default_local_init_op
        # (presumably the only name available in older TF releases).
        if hasattr(tf.train.Scaffold, 'default_local_init_op'):
          local_init_op = tf.train.Scaffold.default_local_init_op()
        else:
          local_init_op = tf.train.Scaffold._default_local_init_op()  # pylint: disable=protected-access

      # Perform the export
      builder = tf.saved_model.builder.SavedModelBuilder(temp_export_dir)
      builder.add_meta_graph_and_variables(
          session,
          [constants.EVAL_SAVED_MODEL_TAG],
          # Don't export any signatures, since this graph is not actually
          # meant for serving.
          signature_def_map=None,
          assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS),
          legacy_init_op=local_init_op)
      # Positional argument is SavedModelBuilder.save's `as_text`; False
      # writes the SavedModel protobuf in binary form.
      builder.save(False)

      # Move the completed export from the temporary location to its final
      # timestamped directory.
      gfile.Rename(temp_export_dir, export_dir)
      return export_dir