Example #1
0
def _encode_and_add_to_node_collection(collection_prefix, key, node):
    """Append an encoded key and its encoded node to parallel collections.

    The key (encoded with `encoding.encode_key`) is added to the collection
    named `encoding.with_suffix(collection_prefix, encoding.KEY_SUFFIX)`.
    The node (encoded with `encoding.encode_tensor_node`) is added to
    `encoding.with_suffix(collection_prefix, encoding.NODE_SUFFIX)`.
    Both go through `tf.add_to_collection`, i.e. the default graph.

    Args:
      collection_prefix: Prefix shared by the key and node collection names.
      key: Key to encode and record.
      node: Tensor node to encode and record.
    """
    key_collection = encoding.with_suffix(collection_prefix,
                                          encoding.KEY_SUFFIX)
    tf.add_to_collection(key_collection, encoding.encode_key(key))

    node_collection = encoding.with_suffix(collection_prefix,
                                           encoding.NODE_SUFFIX)
    tf.add_to_collection(node_collection, encoding.encode_tensor_node(node))
Example #2
0
    def testEncodeDecodeTensorNode(self):
        """Round-trips dense and sparse tensors through the node encoding.

        Every tensor is encoded with `encoding.encode_tensor_node` and then
        decoded against the graph it was built in. Dense tensors must come
        back equal. For a SparseTensor, each of its indices / values /
        dense_shape components must come back equal.
        """
        graph = tf.Graph()
        with graph.as_default():
            serialized = tf.placeholder(tf.string, name='example')
            parsed = tf.parse_example(
                serialized, {
                    'age': tf.FixedLenFeature(
                        [], dtype=tf.int64, default_value=-1),
                    'gender': tf.FixedLenFeature([], dtype=tf.string),
                    'varstr': tf.VarLenFeature(tf.string),
                    'varint': tf.VarLenFeature(tf.int64),
                    'varfloat': tf.VarLenFeature(tf.float32),
                    u'unicode\u1234': tf.FixedLenFeature([], dtype=tf.string),
                })
            scalar = tf.constant(1.0)
            sparse_input = tf.SparseTensor(
                indices=tf.placeholder(tf.int64),
                values=tf.placeholder(tf.int64),
                dense_shape=tf.placeholder(tf.int64))

        parsed_keys = ['age', 'gender', 'varstr', 'varint', 'varfloat',
                       u'unicode\u1234']
        candidates = ([serialized] + [parsed[k] for k in parsed_keys] +
                      [scalar, sparse_input])
        for original in candidates:
            decoded = encoding.decode_tensor_node(
                graph, encoding.encode_tensor_node(original))
            if not isinstance(original, tf.SparseTensor):
                self.assertEqual(original, decoded)
                continue
            # SparseTensors are compared component by component.
            for component in ('indices', 'values', 'dense_shape'):
                self.assertEqual(getattr(original, component),
                                 getattr(decoded, component))
Example #3
0
def _encode_and_add_to_node_collection(collection_prefix: Text,
                                       key: types.FPLKeyType,
                                       node: types.TensorType) -> None:
  """Records an encoded (key, node) pair in two parallel graph collections.

  The encoded key goes to `<collection_prefix>` + KEY_SUFFIX and the encoded
  node goes to `<collection_prefix>` + NODE_SUFFIX. Names are built with
  `encoding.with_suffix`, and entries are added through
  `tf.compat.v1.add_to_collection`, i.e. the default graph.

  Args:
    collection_prefix: Prefix shared by the key and node collection names.
    key: Key to encode with `encoding.encode_key`.
    node: Tensor node to encode with `encoding.encode_tensor_node`.
  """
  key_collection = encoding.with_suffix(collection_prefix, encoding.KEY_SUFFIX)
  tf.compat.v1.add_to_collection(key_collection, encoding.encode_key(key))

  node_collection = encoding.with_suffix(collection_prefix,
                                         encoding.NODE_SUFFIX)
  tf.compat.v1.add_to_collection(node_collection,
                                 encoding.encode_tensor_node(node))
Example #4
0
  def testGetNodeInGraph(self):
    """Looks up an encoded tensor stored in a MetaGraphDef collection."""
    graph = tf.Graph()
    with graph.as_default():
      apple = tf.constant(1.0)

    meta_graph_def = meta_graph_pb2.MetaGraphDef()
    fruit_collection = meta_graph_def.collection_def['fruit_node']
    fruit_collection.any_list.value.extend(
        [encoding.encode_tensor_node(apple)])

    found = graph_ref.get_node_in_graph(meta_graph_def, 'fruit_node', graph)
    self.assertEqual(apple, found)
Example #5
0
def _add_tfma_collections(features: types.TensorTypeMaybeDict,
                          labels: Optional[types.TensorTypeMaybeDict],
                          input_refs: types.TensorType):
  """Add extra collections for features, labels, input_refs, version.

  This should be called within the Graph that will be saved, typically from
  the input_receiver_fn once features and labels have been parsed.

  The feature, label, example-ref and version collections are emptied first.
  Calling this more than once therefore keeps only the last call's entries.

  Args:
    features: dict of strings to tensors representing features (`.items()`
      is called on it, so it must be dict-like even though the annotation
      also admits a single tensor).
    labels: dict of strings to tensors, a single tensor, or None (in which
      case no label entries are recorded).
    input_refs: See EvalInputReceiver().
  """
  # Clear existing collections first, in case the EvalInputReceiver was
  # called multiple times.
  stale_collections = [
      encoding.with_suffix(encoding.FEATURES_COLLECTION, encoding.KEY_SUFFIX),
      encoding.with_suffix(encoding.FEATURES_COLLECTION, encoding.NODE_SUFFIX),
      encoding.with_suffix(encoding.LABELS_COLLECTION, encoding.KEY_SUFFIX),
      encoding.with_suffix(encoding.LABELS_COLLECTION, encoding.NODE_SUFFIX),
      encoding.EXAMPLE_REF_COLLECTION,
      encoding.TFMA_VERSION_COLLECTION,
  ]
  for collection_name in stale_collections:
    # Slice-delete empties the live collection list in place.
    del tf.compat.v1.get_collection_ref(collection_name)[:]

  for name, node in features.items():
    _encode_and_add_to_node_collection(encoding.FEATURES_COLLECTION, name,
                                       node)

  if labels is not None:
    # Labels can either be a Tensor, or a dict of Tensors.
    label_dict = (
        labels if isinstance(labels, dict) else
        {util.default_dict_key(constants.LABELS_NAME): labels})
    for name, node in label_dict.items():
      _encode_and_add_to_node_collection(encoding.LABELS_COLLECTION, name,
                                         node)

  # Previously input_refs was called example_ref. This code is being
  # deprecated so it was not renamed.
  tf.compat.v1.get_collection_ref(encoding.EXAMPLE_REF_COLLECTION).append(
      encoding.encode_tensor_node(input_refs))

  tf.compat.v1.add_to_collection(encoding.TFMA_VERSION_COLLECTION,
                                 version.VERSION)
Example #6
0
def _add_tfma_collections(features, labels, example_ref):
  """Add extra collections for features, labels, example_ref, version.

  Intended to run inside the Graph that will be saved, typically from the
  input_receiver_fn once features and labels have been parsed.

  The feature, label, example-ref and version collections are emptied first,
  so calling this more than once keeps only the entries from the last call.

  Args:
    features: dict of strings to tensors representing features
    labels: dict of strings to tensors or a single tensor (None records no
      label entries)
    example_ref: See EvalInputReceiver().
  """
  # Clear existing collections first, in case the EvalInputReceiver was
  # called multiple times.
  for prefix in (encoding.FEATURES_COLLECTION, encoding.LABELS_COLLECTION):
    for suffix in (encoding.KEY_SUFFIX, encoding.NODE_SUFFIX):
      del tf.get_collection_ref(encoding.with_suffix(prefix, suffix))[:]
  for name in (encoding.EXAMPLE_REF_COLLECTION,
               encoding.TFMA_VERSION_COLLECTION):
    del tf.get_collection_ref(name)[:]

  for feature_key, feature_tensor in features.items():
    _encode_and_add_to_node_collection(encoding.FEATURES_COLLECTION,
                                       feature_key, feature_tensor)

  if labels is not None:
    # Labels can either be a Tensor, or a dict of Tensors.
    if isinstance(labels, dict):
      label_map = labels
    else:
      label_map = {encoding.DEFAULT_LABELS_DICT_KEY: labels}
    for label_key, label_tensor in label_map.items():
      _encode_and_add_to_node_collection(encoding.LABELS_COLLECTION,
                                         label_key, label_tensor)

  tf.get_collection_ref(encoding.EXAMPLE_REF_COLLECTION).append(
      encoding.encode_tensor_node(example_ref))

  tf.add_to_collection(encoding.TFMA_VERSION_COLLECTION, version.VERSION_STRING)
Example #7
0
File: export.py  Project: hakanhp/cahnel
def export_eval_savedmodel(
    estimator,
    export_dir_base,
    eval_input_receiver_fn,
    checkpoint_path=None):
  """Export a EvalSavedModel for the given estimator.

  Builds a fresh graph from eval_input_receiver_fn and the estimator's model
  in EVAL mode. It records the metric, prediction, input-example, label and
  feature nodes in encoded graph collections, then restores variables from a
  checkpoint. The result is written as a SavedModel with the SERVING tag and
  no signatures.

  The SavedModel is first written to a temporary directory (from
  _get_temp_export_dir) and then renamed to the export directory (from
  _get_timestamped_export_dir).

  Args:
    estimator: Estimator to export the graph for. Either a core
      tf.estimator.Estimator or a contrib estimator exposing _call_model_fn.
    export_dir_base: Base path for export. Graph will be exported into a
      subdirectory of this base path.
    eval_input_receiver_fn: Eval input receiver function.
    checkpoint_path: Path to a specific checkpoint to export. If set to None,
      exports the latest checkpoint.

  Returns:
    Path to the directory where the eval graph was exported.

  Raises:
    ValueError: Could not find a checkpoint to export.
  """
  with tf.Graph().as_default() as g:
    eval_input_receiver = eval_input_receiver_fn()
    tf.train.create_global_step(g)
    tf.set_random_seed(estimator.config.tf_random_seed)

    # Workaround for TensorFlow issue #17568. Note that we pass the
    # identity-wrapped features and labels to model_fn, but we have to feed
    # the non-identity wrapped Tensors during evaluation.
    #
    # Also note that we can't wrap predictions, so metrics that have control
    # dependencies on predictions will cause the predictions to be recomputed
    # during their evaluation.
    wrapped_features = util.wrap_tensor_or_dict_of_tensors_in_identity(
        eval_input_receiver.features)
    wrapped_labels = util.wrap_tensor_or_dict_of_tensors_in_identity(
        eval_input_receiver.labels)

    if isinstance(estimator, tf.estimator.Estimator):
      # This is a core estimator; model_fn returns an EstimatorSpec.
      model_output = estimator.model_fn(
          features=wrapped_features,
          labels=wrapped_labels,
          mode=tf.estimator.ModeKeys.EVAL,
          config=estimator.config)
    else:
      # This is a contrib estimator. The ModelFnOps it returns exposes the
      # same predictions / eval_metric_ops / scaffold attributes used below.
      model_output = estimator._call_model_fn(  # pylint: disable=protected-access
          features=wrapped_features,
          labels=wrapped_labels,
          mode=tf.estimator.ModeKeys.EVAL)
    # Only these three attributes are needed. Reading them directly avoids
    # faking an EstimatorSpec (previously a lambda used as an attribute bag).
    predictions = model_output.predictions
    eval_metric_ops = model_output.eval_metric_ops
    scaffold = model_output.scaffold

    # Save metric using eval_metric_ops. The collection names do not depend
    # on the metric, so build them once.
    metric_key_collection = '%s/%s' % (encoding.METRICS_COLLECTION,
                                       encoding.KEY_SUFFIX)
    metric_value_op_collection = '%s/%s' % (encoding.METRICS_COLLECTION,
                                            encoding.VALUE_OP_SUFFIX)
    metric_update_op_collection = '%s/%s' % (encoding.METRICS_COLLECTION,
                                             encoding.UPDATE_OP_SUFFIX)
    for user_metric_key, (value_op, update_op) in eval_metric_ops.items():
      tf.add_to_collection(metric_key_collection,
                           encoding.encode_key(user_metric_key))
      tf.add_to_collection(metric_value_op_collection,
                           encoding.encode_tensor_node(value_op))
      tf.add_to_collection(metric_update_op_collection,
                           encoding.encode_tensor_node(update_op))

    # Save all prediction nodes.
    # Predictions can either be a Tensor, or a dict of Tensors.
    if not isinstance(predictions, dict):
      predictions = {encoding.DEFAULT_PREDICTIONS_DICT_KEY: predictions}

    for prediction_key, prediction_node in predictions.items():
      _encode_and_add_to_node_collection(encoding.PREDICTIONS_COLLECTION,
                                         prediction_key, prediction_node)

    ############################################################
    ## Features, label (and weight) graph

    # Placeholder for input example to label graph.
    tf.add_to_collection(encoding.INPUT_EXAMPLE_COLLECTION,
                         encoding.encode_tensor_node(
                             eval_input_receiver.receiver_tensors['examples']))

    # Save all label nodes, using the unwrapped tensors (see workaround above).
    # Labels can either be a Tensor, or a dict of Tensors. When there are no
    # labels, nothing is recorded rather than encoding None.
    labels = eval_input_receiver.labels
    if labels is not None:
      if not isinstance(labels, dict):
        labels = {encoding.DEFAULT_LABELS_DICT_KEY: labels}

      for label_key, label_node in labels.items():
        _encode_and_add_to_node_collection(encoding.LABELS_COLLECTION,
                                           label_key, label_node)

    # Save features (unwrapped, see workaround above).
    for feature_name, feature_node in eval_input_receiver.features.items():
      _encode_and_add_to_node_collection(encoding.FEATURES_COLLECTION,
                                         feature_name, feature_node)

    ############################################################
    ## Export as normal

    if not checkpoint_path:
      checkpoint_path = tf.train.latest_checkpoint(estimator.model_dir)
      if not checkpoint_path:
        raise ValueError(
            'Could not find trained model at %s.' % estimator.model_dir)

    export_dir = _get_timestamped_export_dir(export_dir_base)
    temp_export_dir = _get_temp_export_dir(export_dir)

    if estimator.config.session_config is None:
      session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    else:
      session_config = estimator.config.session_config

    with tf.Session(config=session_config) as session:
      if scaffold and scaffold.saver:
        saver_for_restore = scaffold.saver
      else:
        saver_for_restore = tf.train.Saver(sharded=True)
      saver_for_restore.restore(session, checkpoint_path)

      if scaffold and scaffold.local_init_op:
        local_init_op = scaffold.local_init_op
      else:
        # pylint: disable=protected-access
        local_init_op = tf.train.Scaffold._default_local_init_op()
        # pylint: enable=protected-access

      # Perform the export
      builder = tf.saved_model.builder.SavedModelBuilder(temp_export_dir)
      builder.add_meta_graph_and_variables(
          session,
          [tf.saved_model.tag_constants.SERVING],
          # Don't export any signatures, since this graph is not actually
          # meant for serving.
          signature_def_map=None,
          assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS),
          legacy_init_op=local_init_op)
      builder.save(False)

      # Publish the finished export by renaming the temp directory into place.
      gfile.Rename(temp_export_dir, export_dir)
      return export_dir