def _encode_and_add_to_node_collection(collection_prefix, key, node):
  """Encode a (key, node) pair and store both in graph collections.

  The key goes into the `<collection_prefix>/<KEY_SUFFIX>` collection and the
  encoded tensor node into `<collection_prefix>/<NODE_SUFFIX>`, at matching
  positions, so readers can zip the two collections back together.

  NOTE(review): this file appears to contain a second, annotated definition of
  this same function; the later definition wins at import time. Consider
  deleting one of them.

  Args:
    collection_prefix: Prefix for the graph-collection names.
    key: Key to encode via encoding.encode_key.
    node: Tensor (or compatible node) to encode via encoding.encode_tensor_node.
  """
  # Use the tf.compat.v1 endpoint for consistency with the rest of the file;
  # the bare tf.add_to_collection alias does not exist in TF2.
  tf.compat.v1.add_to_collection(
      encoding.with_suffix(collection_prefix, encoding.KEY_SUFFIX),
      encoding.encode_key(key))
  tf.compat.v1.add_to_collection(
      encoding.with_suffix(collection_prefix, encoding.NODE_SUFFIX),
      encoding.encode_tensor_node(node))
def _encode_and_add_to_node_collection(collection_prefix: Text,
                                       key: types.FPLKeyType,
                                       node: types.TensorType) -> None:
  """Store an encoded key and its encoded tensor node in graph collections.

  The encoded key is appended to the `<collection_prefix>/<KEY_SUFFIX>`
  collection and the encoded node to `<collection_prefix>/<NODE_SUFFIX>`,
  so corresponding entries share the same index in both collections.

  Args:
    collection_prefix: Prefix for the names of the two target collections.
    key: Key to serialize with encoding.encode_key.
    node: Tensor node to serialize with encoding.encode_tensor_node.
  """
  key_collection_name = encoding.with_suffix(collection_prefix,
                                             encoding.KEY_SUFFIX)
  node_collection_name = encoding.with_suffix(collection_prefix,
                                              encoding.NODE_SUFFIX)
  tf.compat.v1.add_to_collection(key_collection_name,
                                 encoding.encode_key(key))
  tf.compat.v1.add_to_collection(node_collection_name,
                                 encoding.encode_tensor_node(node))
def testEncodeDecodeKey(self):
  """Round-trip check: decode_key(encode_key(k)) == k for string/tuple keys."""
  test_cases = [
      'a',
      'simple',
      'dollar$',
      '$dollar',
      '$do$ll$ar$',
      # BUG FIX: the original had ('a'), which is just the string 'a' —
      # parentheses without a trailing comma do not create a tuple — so the
      # singleton-tuple case was never actually tested.
      ('a',),
      ('a', 'simple'),
      ('dollar$', 'simple'),
      ('do$llar', 'sim$ple', 'str$'),
      ('many', 'many', 'elements', 'in', 'the', 'tuple'),
      u'unicode\u1234',
      u'uni\u1234code\u2345',
      ('mixed', u'uni\u1234', u'\u2345\u1234'),
      (u'\u1234\u2345', u'\u3456\u2345'),
  ]
  for key in test_cases:
    self.assertEqual(key, encoding.decode_key(encoding.encode_key(key)))
def export_eval_savedmodel(
    estimator,
    export_dir_base,
    eval_input_receiver_fn,
    checkpoint_path=None):
  """Export a EvalSavedModel for the given estimator.

  Builds the estimator's EVAL graph, records metrics / predictions / labels /
  features in dedicated graph collections (so the eval graph can be reloaded
  and fed later), restores the checkpoint into a fresh session, and writes a
  SavedModel into a timestamped subdirectory of export_dir_base.

  Args:
    estimator: Estimator to export the graph for.
    export_dir_base: Base path for export. Graph will be exported into a
      subdirectory of this base path.
    eval_input_receiver_fn: Eval input receiver function.
    checkpoint_path: Path to a specific checkpoint to export. If set to None,
      exports the latest checkpoint.

  Returns:
    Path to the directory where the eval graph was exported.

  Raises:
    ValueError: Could not find a checkpoint to export.
  """
  with tf.Graph().as_default() as g:
    eval_input_receiver = eval_input_receiver_fn()
    tf.train.create_global_step(g)
    tf.set_random_seed(estimator.config.tf_random_seed)

    # Workaround for TensorFlow issue #17568. Note that we pass the
    # identity-wrapped features and labels to model_fn, but we have to feed
    # the non-identity wrapped Tensors during evaluation.
    #
    # Also note that we can't wrap predictions, so metrics that have control
    # dependencies on predictions will cause the predictions to be recomputed
    # during their evaluation.
    wrapped_features = util.wrap_tensor_or_dict_of_tensors_in_identity(
        eval_input_receiver.features)
    wrapped_labels = util.wrap_tensor_or_dict_of_tensors_in_identity(
        eval_input_receiver.labels)

    if isinstance(estimator, tf.estimator.Estimator):
      # This is a core estimator: calling model_fn directly yields an
      # EstimatorSpec with predictions/eval_metric_ops/scaffold attributes.
      estimator_spec = estimator.model_fn(
          features=wrapped_features,
          labels=wrapped_labels,
          mode=tf.estimator.ModeKeys.EVAL,
          config=estimator.config)
    else:
      # This is a contrib estimator: it returns ModelFnOps instead of an
      # EstimatorSpec, so adapt it to the same attribute interface below.
      model_fn_ops = estimator._call_model_fn(  # pylint: disable=protected-access
          features=wrapped_features,
          labels=wrapped_labels,
          mode=tf.estimator.ModeKeys.EVAL)
      # HACK: a lambda is used purely as a mutable object to hang the three
      # attributes on, mimicking the EstimatorSpec shape used above.
      estimator_spec = lambda x: None
      estimator_spec.predictions = model_fn_ops.predictions
      estimator_spec.eval_metric_ops = model_fn_ops.eval_metric_ops
      estimator_spec.scaffold = model_fn_ops.scaffold

    # Save metric using eval_metric_ops.
    # For each metric, three parallel collections are appended to, so entries
    # at the same index across KEY/VALUE_OP/UPDATE_OP belong together.
    for user_metric_key, (value_op, update_op) in (
        estimator_spec.eval_metric_ops.items()):
      tf.add_to_collection('%s/%s' % (encoding.METRICS_COLLECTION,
                                      encoding.KEY_SUFFIX),
                           encoding.encode_key(user_metric_key))
      tf.add_to_collection('%s/%s' % (encoding.METRICS_COLLECTION,
                                      encoding.VALUE_OP_SUFFIX),
                           encoding.encode_tensor_node(value_op))
      tf.add_to_collection('%s/%s' % (encoding.METRICS_COLLECTION,
                                      encoding.UPDATE_OP_SUFFIX),
                           encoding.encode_tensor_node(update_op))

    # Save all prediction nodes.
    # Predictions can either be a Tensor, or a dict of Tensors.
    predictions = estimator_spec.predictions
    if not isinstance(predictions, dict):
      # Normalize the bare-Tensor case into a single-entry dict.
      predictions = {encoding.DEFAULT_PREDICTIONS_DICT_KEY: predictions}

    for prediction_key, prediction_node in predictions.items():
      _encode_and_add_to_node_collection(encoding.PREDICTIONS_COLLECTION,
                                         prediction_key, prediction_node)

    ############################################################
    ## Features, label (and weight) graph

    # Placeholder for input example to label graph.
    # NOTE(review): assumes receiver_tensors always has an 'examples' key —
    # a KeyError here means the input receiver fn didn't provide it.
    tf.add_to_collection(encoding.INPUT_EXAMPLE_COLLECTION,
                         encoding.encode_tensor_node(
                             eval_input_receiver.receiver_tensors['examples']))

    # Save all label nodes.
    # Labels can either be a Tensor, or a dict of Tensors.
    labels = eval_input_receiver.labels
    if not isinstance(labels, dict):
      labels = {encoding.DEFAULT_LABELS_DICT_KEY: labels}

    for label_key, label_node in labels.items():
      _encode_and_add_to_node_collection(encoding.LABELS_COLLECTION,
                                         label_key, label_node)

    # Save features.
    for feature_name, feature_node in eval_input_receiver.features.items():
      _encode_and_add_to_node_collection(encoding.FEATURES_COLLECTION,
                                         feature_name, feature_node)

    ############################################################
    ## Export as normal

    if not checkpoint_path:
      checkpoint_path = tf.train.latest_checkpoint(estimator.model_dir)
    if not checkpoint_path:
      raise ValueError(
          'Could not find trained model at %s.' % estimator.model_dir)

    # Export to a temporary directory first, then rename, so a half-written
    # export is never visible at the final timestamped path.
    export_dir = _get_timestamped_export_dir(export_dir_base)
    temp_export_dir = _get_temp_export_dir(export_dir)

    if estimator.config.session_config is None:
      session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    else:
      session_config = estimator.config.session_config

    with tf.Session(config=session_config) as session:
      # Prefer the model's own saver (it may track custom variables);
      # otherwise fall back to a default sharded saver.
      if estimator_spec.scaffold and estimator_spec.scaffold.saver:
        saver_for_restore = estimator_spec.scaffold.saver
      else:
        saver_for_restore = tf.train.Saver(sharded=True)
      saver_for_restore.restore(session, checkpoint_path)

      # Likewise prefer the scaffold's local_init_op over the default one.
      if estimator_spec.scaffold and estimator_spec.scaffold.local_init_op:
        local_init_op = estimator_spec.scaffold.local_init_op
      else:
        local_init_op = tf.train.Scaffold._default_local_init_op()
        # pylint: enable=protected-access

      # Perform the export
      builder = tf.saved_model.builder.SavedModelBuilder(temp_export_dir)
      builder.add_meta_graph_and_variables(
          session, [tf.saved_model.tag_constants.SERVING],
          # Don't export any signatures, since this graph is not actually
          # meant for serving.
          signature_def_map=None,
          assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS),
          legacy_init_op=local_init_op)
      # as_text=False: write the MetaGraphDef as a binary protobuf.
      builder.save(False)

    gfile.Rename(temp_export_dir, export_dir)

    return export_dir