Ejemplo n.º 1
0
def _encode_and_add_to_node_collection(collection_prefix, key, node):
    """Adds an encoded key and its encoded node to a pair of collections.

    The encoded key goes into the collection named by `collection_prefix`
    combined with `encoding.KEY_SUFFIX`; the encoded tensor node goes into
    the collection named by `collection_prefix` combined with
    `encoding.NODE_SUFFIX`.

    Args:
      collection_prefix: Prefix shared by the key and node collection names.
      key: Key to encode with `encoding.encode_key`.
      node: Node to encode with `encoding.encode_tensor_node`.
    """
    key_collection = encoding.with_suffix(
        collection_prefix, encoding.KEY_SUFFIX)
    encoded_key = encoding.encode_key(key)
    tf.add_to_collection(key_collection, encoded_key)

    node_collection = encoding.with_suffix(
        collection_prefix, encoding.NODE_SUFFIX)
    encoded_node = encoding.encode_tensor_node(node)
    tf.add_to_collection(node_collection, encoded_node)
Ejemplo n.º 2
0
def _encode_and_add_to_node_collection(collection_prefix: Text,
                                       key: types.FPLKeyType,
                                       node: types.TensorType) -> None:
  """Registers an encoded (key, node) pair under `collection_prefix`.

  Two collections are written to, in this order: the one whose name is
  `collection_prefix` with `encoding.KEY_SUFFIX` receives the encoded key,
  and the one whose name is `collection_prefix` with `encoding.NODE_SUFFIX`
  receives the encoded tensor node.

  Args:
    collection_prefix: Prefix shared by the key and node collection names.
    key: Key to encode with `encoding.encode_key`.
    node: Node to encode with `encoding.encode_tensor_node`.
  """
  add_to_collection = tf.compat.v1.add_to_collection

  keys_name = encoding.with_suffix(collection_prefix, encoding.KEY_SUFFIX)
  add_to_collection(keys_name, encoding.encode_key(key))

  nodes_name = encoding.with_suffix(collection_prefix, encoding.NODE_SUFFIX)
  add_to_collection(nodes_name, encoding.encode_tensor_node(node))
Ejemplo n.º 3
0
 def testEncodeDecodeKey(self):
   # Every key must come back unchanged after encode_key followed by
   # decode_key. Covers plain strings, strings containing '$', tuples of
   # strings, and unicode strings / tuples.
   #
   # NOTE(review): ('a') is the plain string 'a', not a one-element tuple
   # (that would be spelled ('a',)) -- confirm whether a 1-tuple case was
   # intended here.
   round_trip_keys = (
       'a',
       'simple',
       'dollar$',
       '$dollar',
       '$do$ll$ar$',
       ('a'),
       ('a', 'simple'),
       ('dollar$', 'simple'),
       ('do$llar', 'sim$ple', 'str$'),
       ('many', 'many', 'elements', 'in', 'the', 'tuple'),
       u'unicode\u1234',
       u'uni\u1234code\u2345',
       ('mixed', u'uni\u1234', u'\u2345\u1234'),
       (u'\u1234\u2345', u'\u3456\u2345'),
   )
   for original_key in round_trip_keys:
     encoded = encoding.encode_key(original_key)
     decoded = encoding.decode_key(encoded)
     self.assertEqual(original_key, decoded)
Ejemplo n.º 4
0
def export_eval_savedmodel(
    estimator,
    export_dir_base,
    eval_input_receiver_fn,
    checkpoint_path=None):
  """Exports an EvalSavedModel for `estimator`.

  Builds the estimator's EVAL-mode graph on top of the tensors returned by
  `eval_input_receiver_fn`, records the metric, prediction, input-example,
  label and feature nodes in graph collections, restores variables from a
  checkpoint and writes everything out as a SavedModel.

  Args:
    estimator: Estimator whose evaluation graph is exported.
    export_dir_base: Base path for the export. The graph is written into a
      subdirectory of this path.
    eval_input_receiver_fn: Eval input receiver function. It is called with
      no arguments; the result is read through its `features`, `labels` and
      `receiver_tensors['examples']` members.
    checkpoint_path: Path of a specific checkpoint to export. When None (or
      otherwise falsy), the latest checkpoint in `estimator.model_dir` is
      exported.

  Returns:
    Path to the directory the eval graph was exported to.

  Raises:
    ValueError: No checkpoint to export could be found.
  """
  graph = tf.Graph()
  with graph.as_default():
    receiver = eval_input_receiver_fn()
    tf.train.create_global_step(graph)
    tf.set_random_seed(estimator.config.tf_random_seed)

    # Workaround for TensorFlow issue #17568: model_fn receives
    # identity-wrapped features and labels, whereas the unwrapped tensors
    # are the ones recorded in the collections below and fed during
    # evaluation.
    #
    # Predictions cannot be wrapped, so metrics with control dependencies on
    # predictions will cause the predictions to be recomputed while those
    # metrics are evaluated.
    model_features = util.wrap_tensor_or_dict_of_tensors_in_identity(
        receiver.features)
    model_labels = util.wrap_tensor_or_dict_of_tensors_in_identity(
        receiver.labels)

    eval_mode = tf.estimator.ModeKeys.EVAL
    if isinstance(estimator, tf.estimator.Estimator):
      # This is a core estimator.
      spec = estimator.model_fn(
          features=model_features,
          labels=model_labels,
          mode=eval_mode,
          config=estimator.config)
    else:
      # This is a contrib estimator.
      spec = estimator._call_model_fn(  # pylint: disable=protected-access
          features=model_features,
          labels=model_labels,
          mode=eval_mode)

    # Only these three members of the model_fn result are used from here on.
    predictions = spec.predictions
    eval_metric_ops = spec.eval_metric_ops
    scaffold = spec.scaffold

    # Save metrics: for each entry of eval_metric_ops, the encoded key, value
    # op and update op are appended to three parallel collections.
    metric_keys_collection = '%s/%s' % (encoding.METRICS_COLLECTION,
                                        encoding.KEY_SUFFIX)
    metric_values_collection = '%s/%s' % (encoding.METRICS_COLLECTION,
                                          encoding.VALUE_OP_SUFFIX)
    metric_updates_collection = '%s/%s' % (encoding.METRICS_COLLECTION,
                                           encoding.UPDATE_OP_SUFFIX)
    for metric_key, metric_ops in eval_metric_ops.items():
      value_op, update_op = metric_ops
      tf.add_to_collection(metric_keys_collection,
                           encoding.encode_key(metric_key))
      tf.add_to_collection(metric_values_collection,
                           encoding.encode_tensor_node(value_op))
      tf.add_to_collection(metric_updates_collection,
                           encoding.encode_tensor_node(update_op))

    # Save all prediction nodes. Predictions are either a dict of Tensors or
    # a single Tensor; the latter is stored under the default key.
    if not isinstance(predictions, dict):
      predictions = {encoding.DEFAULT_PREDICTIONS_DICT_KEY: predictions}
    for prediction_key, prediction_node in predictions.items():
      _encode_and_add_to_node_collection(
          encoding.PREDICTIONS_COLLECTION, prediction_key, prediction_node)

    ############################################################
    ## Features, label (and weight) graph

    # Placeholder for the input example to the label graph.
    examples_node = receiver.receiver_tensors['examples']
    tf.add_to_collection(encoding.INPUT_EXAMPLE_COLLECTION,
                         encoding.encode_tensor_node(examples_node))

    # Save all label nodes. Labels are either a dict of Tensors or a single
    # Tensor; the latter is stored under the default key.
    labels = receiver.labels
    if not isinstance(labels, dict):
      labels = {encoding.DEFAULT_LABELS_DICT_KEY: labels}
    for label_key, label_node in labels.items():
      _encode_and_add_to_node_collection(
          encoding.LABELS_COLLECTION, label_key, label_node)

    # Save features.
    for feature_name, feature_node in receiver.features.items():
      _encode_and_add_to_node_collection(
          encoding.FEATURES_COLLECTION, feature_name, feature_node)

    ############################################################
    ## Export as normal

    # Fall back to the latest checkpoint when none was given explicitly.
    checkpoint_path = (
        checkpoint_path or tf.train.latest_checkpoint(estimator.model_dir))
    if not checkpoint_path:
      raise ValueError(
          'Could not find trained model at %s.' % estimator.model_dir)

    export_dir = _get_timestamped_export_dir(export_dir_base)
    temp_export_dir = _get_temp_export_dir(export_dir)

    session_config = estimator.config.session_config
    if session_config is None:
      session_config = config_pb2.ConfigProto(allow_soft_placement=True)

    with tf.Session(config=session_config) as sess:
      # Prefer the scaffold's saver; otherwise build a sharded default one.
      restore_saver = (
          scaffold.saver
          if scaffold and scaffold.saver else tf.train.Saver(sharded=True))
      restore_saver.restore(sess, checkpoint_path)

      # Prefer the scaffold's local init op; otherwise use Scaffold's default.
      if scaffold and scaffold.local_init_op:
        init_op = scaffold.local_init_op
      else:
        init_op = tf.train.Scaffold._default_local_init_op()  # pylint: disable=protected-access

      # Perform the export into the temporary directory.
      saved_model_builder = tf.saved_model.builder.SavedModelBuilder(
          temp_export_dir)
      saved_model_builder.add_meta_graph_and_variables(
          sess,
          [tf.saved_model.tag_constants.SERVING],
          # Don't export any signatures, since this graph is not actually
          # meant for serving.
          signature_def_map=None,
          assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS),
          legacy_init_op=init_op)
      saved_model_builder.save(as_text=False)

      # Move the finished export from the temporary directory to its final
      # location.
      gfile.Rename(temp_export_dir, export_dir)
      return export_dir