def export_estimator(estimator,
                     export_dir,
                     input_fn=_default_input_fn,
                     signature_fn=_generic_signature_fn,
                     default_batch_size=1,
                     exports_to_keep=None):
  """Exports the inference graph of `estimator` into `export_dir`.

  Args:
    estimator: Estimator to export.
    export_dir: A string containing a directory to write the exported graph
      and checkpoints.
    input_fn: Function that, given a `Tensor` of `Example` strings, parses it
      into features that are then passed to the model.
    signature_fn: Function that, given a `Tensor` of `Example` strings, a
      `dict` of `Tensor`s for features and a `dict` of `Tensor`s for
      predictions, returns default and named exporting signatures.
    default_batch_size: Default batch size of the `Example` placeholder.
    exports_to_keep: Number of exports to keep. When `None` it is forwarded
      unchanged to `_export_graph`.
  """
  latest_checkpoint = tf_saver.latest_checkpoint(estimator._model_dir)
  with ops.Graph().as_default() as graph:
    contrib_variables.create_global_step(graph)
    serialized_examples = array_ops.placeholder(
        dtype=dtypes.string,
        shape=[default_batch_size],
        name='input_example_tensor')
    parsed_features = input_fn(estimator, serialized_examples)
    model_predictions = estimator._get_predict_ops(parsed_features)
    signatures = signature_fn(serialized_examples, parsed_features,
                              model_predictions)
    default_signature, named_signatures = signatures

    # Translate the integer retention count into a gc path filter.
    keep_filter = None
    if exports_to_keep is not None:
      keep_filter = gc.largest_export_versions(exports_to_keep)
    _export_graph(graph, _get_saver(), latest_checkpoint, export_dir,
                  default_graph_signature=default_signature,
                  named_graph_signatures=named_signatures,
                  exports_to_keep=keep_filter)
def _export_estimator(estimator, export_dir, signature_fn, input_fn,
                      default_batch_size, exports_to_keep):
  """Exports the inference graph of `estimator` into `export_dir`.

  Args:
    estimator: Estimator to export. The latest checkpoint in its
      `_model_dir` is exported.
    export_dir: Directory to write the exported graph and checkpoints to.
    signature_fn: Optional function mapping `(examples, features,
      predictions)` to a `(default_signature, named_graph_signatures)` pair.
      When falsy, a standard signature function is chosen from the
      estimator's target column problem type. If the estimator has no usable
      target column (`AttributeError`), generic signatures are used instead.
    input_fn: Optional function `(estimator, examples) -> features`. Falls
      back to `_default_input_fn` when falsy.
    default_batch_size: Batch size of the `Example` string placeholder.
    exports_to_keep: Number of exports to keep, or `None`.

  Raises:
    ValueError: If `signature_fn` is not given and the target column's
      problem type has no standard export signature.
  """
  input_fn = input_fn or _default_input_fn
  checkpoint_path = tf_saver.latest_checkpoint(estimator._model_dir)
  with ops.Graph().as_default() as g:
    contrib_variables.create_global_step(g)
    examples = array_ops.placeholder(dtype=dtypes.string,
                                     shape=[default_batch_size],
                                     name='input_example_tensor')
    features = input_fn(estimator, examples)
    predictions = estimator._get_predict_ops(features)

    # Explicit signature_fn takes priority
    if signature_fn:
      default_signature, named_graph_signatures = signature_fn(
          examples, features, predictions)
    else:
      try:
        # Some estimators provide a target_column of known type
        target_column = estimator._get_target_column()
        problem_type = target_column.problem_type

        if problem_type == layers.ProblemType.CLASSIFICATION:
          signature_fn = classification_signature_fn
        elif problem_type == layers.ProblemType.LINEAR_REGRESSION:
          signature_fn = regression_signature_fn
        elif problem_type == layers.ProblemType.LOGISTIC_REGRESSION:
          signature_fn = logistic_regression_signature_fn
        else:
          raise ValueError(
              'signature_fn must be provided because the TargetColumn is a %s, '
              'which does not have a standard problem type and so cannot use a '
              'standard export signature.' % type(target_column).__name__)

        default_signature, named_graph_signatures = (
            signature_fn(examples, features, predictions))
      except AttributeError:
        # NOTE(review): this also catches AttributeErrors raised inside the
        # selected signature_fn, which then silently fall back to generic
        # signatures.
        # The trailing space after 'after' is required: adjacent literals are
        # concatenated with no separator.
        logging.warn(
            'Change warning: `signature_fn` will be required after '
            '2016-08-01.\n'
            'Using generic signatures for now. To maintain this behavior, '
            'pass:\n'
            ' signature_fn=export.generic_signature_fn\n'
            'Also consider passing a regression or classification signature; '
            'see cl/126430915 for an example.')
        default_signature, named_graph_signatures = generic_signature_fn(
            examples, features, predictions)
    if exports_to_keep is not None:
      exports_to_keep = gc.largest_export_versions(exports_to_keep)
    _export_graph(g, _get_saver(), checkpoint_path, export_dir,
                  default_graph_signature=default_signature,
                  named_graph_signatures=named_graph_signatures,
                  exports_to_keep=exports_to_keep)
def testUnion(self):
  """gc.union keeps every path selected by either of its filters."""
  # `range` works on both Python 2 and 3; bare `xrange` is a NameError on 3.
  paths = [gc.Path("/foo", i) for i in range(10)]
  f = gc.union(gc.largest_export_versions(3), gc.mod_export_version(3))
  # `assertEqual` replaces the deprecated `assertEquals` alias (removed in
  # Python 3.12).
  self.assertEqual(
      f(paths), [gc.Path("/foo", 0), gc.Path("/foo", 3),
                 gc.Path("/foo", 6), gc.Path("/foo", 7),
                 gc.Path("/foo", 8), gc.Path("/foo", 9)])
def _export_estimator(estimator, export_dir, signature_fn, input_fn,
                      default_batch_size, exports_to_keep):
  """Exports the inference graph of `estimator` into `export_dir`.

  Args:
    estimator: Estimator to export. The latest checkpoint in its
      `_model_dir` is exported.
    export_dir: Directory to write the exported graph and checkpoints to.
    signature_fn: Optional function mapping `(examples, features,
      predictions)` to a `(default_signature, named_graph_signatures)` pair.
      When falsy, a standard signature function is chosen from the
      estimator's target column problem type. If the estimator has no usable
      target column (`AttributeError`), generic signatures are used instead.
    input_fn: Optional function `(estimator, examples) -> features`. Falls
      back to `_default_input_fn` when falsy.
    default_batch_size: Batch size of the `Example` string placeholder.
    exports_to_keep: Number of exports to keep, or `None`.

  Raises:
    ValueError: If `signature_fn` is not given and the target column's
      problem type has no standard export signature.
  """
  input_fn = input_fn or _default_input_fn
  checkpoint_path = tf_saver.latest_checkpoint(estimator._model_dir)
  with ops.Graph().as_default() as g:
    contrib_variables.create_global_step(g)
    examples = array_ops.placeholder(dtype=dtypes.string,
                                     shape=[default_batch_size],
                                     name='input_example_tensor')
    features = input_fn(estimator, examples)
    predictions = estimator._get_predict_ops(features)

    # Explicit signature_fn takes priority
    if signature_fn:
      default_signature, named_graph_signatures = signature_fn(
          examples, features, predictions)
    else:
      try:
        # Some estimators provide a target_column of known type
        target_column = estimator._get_target_column()
        problem_type = target_column.problem_type

        if problem_type == layers.ProblemType.CLASSIFICATION:
          signature_fn = classification_signature_fn
        elif problem_type == layers.ProblemType.LINEAR_REGRESSION:
          signature_fn = regression_signature_fn
        elif problem_type == layers.ProblemType.LOGISTIC_REGRESSION:
          signature_fn = logistic_regression_signature_fn
        else:
          raise ValueError(
              'signature_fn must be provided because the TargetColumn is a %s, '
              'which does not have a standard problem type and so cannot use a '
              'standard export signature.' % type(target_column).__name__)

        default_signature, named_graph_signatures = (signature_fn(
            examples, features, predictions))
      except AttributeError:
        # NOTE(review): this also catches AttributeErrors raised inside the
        # selected signature_fn, which then silently fall back to generic
        # signatures.
        # The trailing space after 'after' is required: adjacent literals are
        # concatenated with no separator.
        logging.warn(
            'Change warning: `signature_fn` will be required after '
            '2016-08-01.\n'
            'Using generic signatures for now. To maintain this behavior, '
            'pass:\n'
            ' signature_fn=export.generic_signature_fn\n'
            'Also consider passing a regression or classification signature; '
            'see cl/126430915 for an example.')
        default_signature, named_graph_signatures = generic_signature_fn(
            examples, features, predictions)
    if exports_to_keep is not None:
      exports_to_keep = gc.largest_export_versions(exports_to_keep)
    _export_graph(g, _get_saver(), checkpoint_path, export_dir,
                  default_graph_signature=default_signature,
                  named_graph_signatures=named_graph_signatures,
                  exports_to_keep=exports_to_keep)
def testUnion(self):
  """gc.union keeps every path selected by either of its filters."""
  # `range` works on both Python 2 and 3; bare `xrange` is a NameError on 3.
  paths = [gc.Path("/foo", i) for i in range(10)]
  f = gc.union(gc.largest_export_versions(3), gc.mod_export_version(3))
  # `assertEqual` replaces the deprecated `assertEquals` alias (removed in
  # Python 3.12).
  self.assertEqual(
      f(paths), [gc.Path("/foo", 0), gc.Path("/foo", 3),
                 gc.Path("/foo", 6), gc.Path("/foo", 7),
                 gc.Path("/foo", 8), gc.Path("/foo", 9)])
def export_estimator(estimator,
                     export_dir,
                     signature_fn=None,
                     input_fn=_default_input_fn,
                     default_batch_size=1,
                     exports_to_keep=None):
  """Exports the inference graph of `estimator` into `export_dir`.

  Args:
    estimator: Estimator to export.
    export_dir: A string containing a directory to write the exported graph
      and checkpoints.
    signature_fn: Function that, given a `Tensor` of `Example` strings, a
      `dict` of `Tensor`s for features and a `dict` of `Tensor`s for
      predictions, returns default and named exporting signatures. If it is
      not provided, a change warning is logged and `generic_signature_fn` is
      used.
    input_fn: Function that, given a `Tensor` of `Example` strings, parses it
      into features that are then passed to the model.
    default_batch_size: Default batch size of the `Example` placeholder.
    exports_to_keep: Number of exports to keep, or `None`.
  """
  latest_checkpoint = tf_saver.latest_checkpoint(estimator._model_dir)
  with ops.Graph().as_default() as graph:
    contrib_variables.create_global_step(graph)
    serialized_examples = array_ops.placeholder(
        dtype=dtypes.string,
        shape=[default_batch_size],
        name='input_example_tensor')
    parsed_features = input_fn(estimator, serialized_examples)
    model_predictions = estimator._get_predict_ops(parsed_features)

    if not signature_fn:
      # Deprecated path: callers are asked to pass signature_fn explicitly.
      logging.warn(
          'Change warning: `signature_fn` will be required after 2016-08-01.\n'
          'Using generic signatures for now. To maintain this behavior, '
          'pass:\n'
          ' signature_fn=export.generic_signature_fn\n'
          'Also consider passing a regression or classification signature; see '
          'cl/126430915 for an example.')
      signature_fn = generic_signature_fn
    default_signature, named_signatures = signature_fn(
        serialized_examples, parsed_features, model_predictions)

    keep_filter = None
    if exports_to_keep is not None:
      keep_filter = gc.largest_export_versions(exports_to_keep)
    _export_graph(graph, _get_saver(), latest_checkpoint, export_dir,
                  default_graph_signature=default_signature,
                  named_graph_signatures=named_signatures,
                  exports_to_keep=keep_filter)
def export_estimator(estimator,
                     export_dir,
                     signature_fn=None,
                     input_fn=_default_input_fn,
                     default_batch_size=1,
                     exports_to_keep=None):
  """Exports the inference graph of `estimator` into `export_dir`.

  Args:
    estimator: Estimator to export.
    export_dir: A string containing a directory to write the exported graph
      and checkpoints.
    signature_fn: Function that, given a `Tensor` of `Example` strings, a
      `dict` of `Tensor`s for features and a `dict` of `Tensor`s for
      predictions, returns default and named exporting signatures. If it is
      not provided, a change warning is logged and `generic_signature_fn` is
      used.
    input_fn: Function that, given a `Tensor` of `Example` strings, parses it
      into features that are then passed to the model.
    default_batch_size: Default batch size of the `Example` placeholder.
    exports_to_keep: Number of exports to keep, or `None`.
  """
  ckpt = tf_saver.latest_checkpoint(estimator._model_dir)
  with ops.Graph().as_default() as export_graph:
    contrib_variables.create_global_step(export_graph)
    example_strings = array_ops.placeholder(dtype=dtypes.string,
                                            shape=[default_batch_size],
                                            name='input_example_tensor')
    feature_tensors = input_fn(estimator, example_strings)
    prediction_tensors = estimator._get_predict_ops(feature_tensors)

    if signature_fn:
      chosen_signature_fn = signature_fn
    else:
      logging.warn(
          'Change warning: `signature_fn` will be required after 2016-08-01.\n'
          'Using generic signatures for now. To maintain this behavior, '
          'pass:\n'
          ' signature_fn=export.generic_signature_fn\n'
          'Also consider passing a regression or classification signature; see '
          'cl/126430915 for an example.')
      chosen_signature_fn = generic_signature_fn
    default_sig, named_sigs = chosen_signature_fn(
        example_strings, feature_tensors, prediction_tensors)

    gc_filter = (None if exports_to_keep is None
                 else gc.largest_export_versions(exports_to_keep))
    _export_graph(export_graph, _get_saver(), ckpt, export_dir,
                  default_graph_signature=default_sig,
                  named_graph_signatures=named_sigs,
                  exports_to_keep=gc_filter)
def testLargestExportVersions(self):
  """largest_export_versions(2) keeps only the two highest versions."""
  paths = [gc.Path("/foo", 8), gc.Path("/foo", 9), gc.Path("/foo", 10)]
  newest = gc.largest_export_versions(2)
  n = newest(paths)
  # `assertEqual` replaces the deprecated `assertEquals` alias (removed in
  # Python 3.12).
  self.assertEqual(n, [gc.Path("/foo", 9), gc.Path("/foo", 10)])
def testLargestExportVersionsDoesNotDeleteZeroFolder(self):
  """An export with version 0 is retained when within the kept count."""
  paths = [gc.Path("/foo", 0), gc.Path("/foo", 3)]
  newest = gc.largest_export_versions(2)
  n = newest(paths)
  # `assertEqual` replaces the deprecated `assertEquals` alias (removed in
  # Python 3.12).
  self.assertEqual(n, [gc.Path("/foo", 0), gc.Path("/foo", 3)])
def _export_estimator(estimator,
                      export_dir,
                      signature_fn,
                      input_fn,
                      default_batch_size,
                      exports_to_keep,
                      input_feature_key=None,
                      use_deprecated_input_fn=True,
                      prediction_key=None,
                      checkpoint_path=None):
  """Exports the inference graph of `estimator` into `export_dir`.

  Args:
    estimator: Estimator to export.
    export_dir: Directory to write the exported graph and checkpoints to.
    signature_fn: Optional function mapping `(examples, features,
      predictions)` to a `(default_signature, named_graph_signatures)` pair.
      When falsy, `estimator._create_signature_fn()` is tried. On
      `AttributeError`, generic signatures are used instead.
    input_fn: With `use_deprecated_input_fn`, a function
      `(estimator, examples) -> features`, defaulting to `_default_input_fn`.
      Otherwise, a required no-argument function returning a
      `(features, _)` pair whose second element is ignored.
    default_batch_size: Batch size of the `Example` string placeholder. Only
      used when `use_deprecated_input_fn` is True.
    exports_to_keep: Number of exports to keep, or `None`.
    input_feature_key: Only used when `use_deprecated_input_fn` is False. If
      set, this key is popped from the features dict and its value is used
      as `examples`.
    use_deprecated_input_fn: Selects the `input_fn` calling convention
      described above.
    prediction_key: If set, only `predictions[prediction_key]` is exported.
    checkpoint_path: Checkpoint to export. Defaults to the latest checkpoint
      in the estimator's model directory.

  Returns:
    Whatever `_export_graph` returns.

  Raises:
    ValueError: If `input_fn` is `None` while `use_deprecated_input_fn` is
      False, or if neither features nor examples are available.
  """
  if use_deprecated_input_fn:
    input_fn = input_fn or _default_input_fn
  elif input_fn is None:
    raise ValueError('input_fn must be defined.')

  # If checkpoint_path is specified, use the specified checkpoint path.
  checkpoint_path = (checkpoint_path or
                     tf_saver.latest_checkpoint(estimator._model_dir))
  with ops.Graph().as_default() as g:
    contrib_variables.create_global_step(g)

    if use_deprecated_input_fn:
      examples = array_ops.placeholder(dtype=dtypes.string,
                                       shape=[default_batch_size],
                                       name='input_example_tensor')
      features = input_fn(estimator, examples)
    else:
      features, _ = input_fn()
      examples = None
      if input_feature_key is not None:
        examples = features.pop(input_feature_key)

    if (not features) and (examples is None):
      raise ValueError('Either features or examples must be defined.')

    predictions = estimator._get_predict_ops(features).predictions

    if prediction_key is not None:
      predictions = predictions[prediction_key]

    # Explicit signature_fn takes priority
    if signature_fn:
      default_signature, named_graph_signatures = signature_fn(
          examples, features, predictions)
    else:
      try:
        # Some estimators provide a signature function.
        # TODO(zakaria): check if the estimator has this function,
        #   raise helpful error if not
        signature_fn = estimator._create_signature_fn()

        default_signature, named_graph_signatures = (
            signature_fn(examples, features, predictions))
      except AttributeError:
        # NOTE(review): this also catches AttributeErrors raised while
        # building or running the estimator-provided signature function,
        # which then silently fall back to generic signatures.
        # The trailing space after 'after' is required: adjacent literals are
        # concatenated with no separator.
        logging.warn(
            'Change warning: `signature_fn` will be required after '
            '2016-08-01.\n'
            'Using generic signatures for now. To maintain this behavior, '
            'pass:\n'
            ' signature_fn=export.generic_signature_fn\n'
            'Also consider passing a regression or classification signature; '
            'see cl/126430915 for an example.')
        default_signature, named_graph_signatures = generic_signature_fn(
            examples, features, predictions)
    if exports_to_keep is not None:
      exports_to_keep = gc.largest_export_versions(exports_to_keep)
    return _export_graph(
        g,
        _get_saver(),
        checkpoint_path,
        export_dir,
        default_graph_signature=default_signature,
        named_graph_signatures=named_graph_signatures,
        exports_to_keep=exports_to_keep)
def _export_estimator(estimator,
                      export_dir,
                      signature_fn,
                      input_fn,
                      default_batch_size,
                      exports_to_keep,
                      input_feature_key=None,
                      use_deprecated_input_fn=True,
                      prediction_key=None,
                      checkpoint_path=None):
  """Exports the inference graph of `estimator` into `export_dir`.

  Args:
    estimator: Estimator to export.
    export_dir: Directory to write the exported graph and checkpoints to.
    signature_fn: Optional function mapping `(examples, features,
      predictions)` to a `(default_signature, named_graph_signatures)` pair.
      When falsy, `estimator._create_signature_fn()` is tried. On
      `AttributeError`, generic signatures are used instead.
    input_fn: With `use_deprecated_input_fn`, a function
      `(estimator, examples) -> features`, defaulting to `_default_input_fn`.
      Otherwise, a required no-argument function returning a
      `(features, _)` pair whose second element is ignored.
    default_batch_size: Batch size of the `Example` string placeholder. Only
      used when `use_deprecated_input_fn` is True.
    exports_to_keep: Number of exports to keep, or `None`.
    input_feature_key: Only used when `use_deprecated_input_fn` is False. If
      set, this key is popped from the features dict and its value is used
      as `examples`.
    use_deprecated_input_fn: Selects the `input_fn` calling convention
      described above.
    prediction_key: If set, only `predictions[prediction_key]` is exported.
    checkpoint_path: Checkpoint to export. Defaults to the latest checkpoint
      in the estimator's model directory.

  Returns:
    Whatever `_export_graph` returns.

  Raises:
    ValueError: If `input_fn` is `None` while `use_deprecated_input_fn` is
      False, or if neither features nor examples are available.
  """
  if use_deprecated_input_fn:
    input_fn = input_fn or _default_input_fn
  elif input_fn is None:
    raise ValueError('input_fn must be defined.')

  # If checkpoint_path is specified, use the specified checkpoint path.
  checkpoint_path = (checkpoint_path or
                     tf_saver.latest_checkpoint(estimator._model_dir))
  with ops.Graph().as_default() as g:
    training_util.create_global_step(g)

    if use_deprecated_input_fn:
      examples = array_ops.placeholder(dtype=dtypes.string,
                                       shape=[default_batch_size],
                                       name='input_example_tensor')
      features = input_fn(estimator, examples)
    else:
      features, _ = input_fn()
      examples = None
      if input_feature_key is not None:
        examples = features.pop(input_feature_key)

    if (not features) and (examples is None):
      raise ValueError('Either features or examples must be defined.')

    predictions = estimator._get_predict_ops(features).predictions

    if prediction_key is not None:
      predictions = predictions[prediction_key]

    # Explicit signature_fn takes priority
    if signature_fn:
      default_signature, named_graph_signatures = signature_fn(
          examples, features, predictions)
    else:
      try:
        # Some estimators provide a signature function.
        # TODO(zakaria): check if the estimator has this function,
        #   raise helpful error if not
        signature_fn = estimator._create_signature_fn()

        default_signature, named_graph_signatures = (signature_fn(
            examples, features, predictions))
      except AttributeError:
        # NOTE(review): this also catches AttributeErrors raised while
        # building or running the estimator-provided signature function,
        # which then silently fall back to generic signatures.
        # The trailing space after 'after' is required: adjacent literals are
        # concatenated with no separator.
        logging.warn(
            'Change warning: `signature_fn` will be required after '
            '2016-08-01.\n'
            'Using generic signatures for now. To maintain this behavior, '
            'pass:\n'
            ' signature_fn=export.generic_signature_fn\n'
            'Also consider passing a regression or classification signature; '
            'see cl/126430915 for an example.')
        default_signature, named_graph_signatures = generic_signature_fn(
            examples, features, predictions)
    if exports_to_keep is not None:
      exports_to_keep = gc.largest_export_versions(exports_to_keep)
    return _export_graph(g, _get_saver(), checkpoint_path, export_dir,
                         default_graph_signature=default_signature,
                         named_graph_signatures=named_graph_signatures,
                         exports_to_keep=exports_to_keep)
def testLargestExportVersionsDoesNotDeleteZeroFolder(self):
  """An export with version 0 is retained when within the kept count."""
  paths = [gc.Path("/foo", 0), gc.Path("/foo", 3)]
  newest = gc.largest_export_versions(2)
  n = newest(paths)
  # `assertEqual` replaces the deprecated `assertEquals` alias (removed in
  # Python 3.12).
  self.assertEqual(n, [gc.Path("/foo", 0), gc.Path("/foo", 3)])
def _export_estimator(estimator,
                      export_dir,
                      signature_fn,
                      input_fn,
                      default_batch_size,
                      exports_to_keep,
                      input_feature_key=None,
                      use_deprecated_input_fn=True,
                      prediction_key=None):
  """Exports the inference graph of `estimator` into `export_dir`.

  Args:
    estimator: Estimator to export. The latest checkpoint in its
      `_model_dir` is exported.
    export_dir: Directory to write the exported graph and checkpoints to.
    signature_fn: Optional function mapping `(examples, features,
      predictions)` to a `(default_signature, named_graph_signatures)` pair.
      When falsy, `estimator._create_signature_fn()` is tried. On
      `AttributeError`, generic signatures are used instead.
    input_fn: With `use_deprecated_input_fn`, a function
      `(estimator, examples) -> features`, defaulting to `_default_input_fn`.
      Otherwise, a required no-argument function returning a
      `(features, _)` pair whose second element is ignored.
    default_batch_size: Batch size of the `Example` string placeholder. Only
      used when `use_deprecated_input_fn` is True.
    exports_to_keep: Number of exports to keep, or `None`.
    input_feature_key: Only used when `use_deprecated_input_fn` is False. If
      set, this key is popped from the features dict and its value is used
      as `examples`.
    use_deprecated_input_fn: Selects the `input_fn` calling convention
      described above.
    prediction_key: If set, only `predictions[prediction_key]` is exported.

  Returns:
    Whatever `_export_graph` returns.

  Raises:
    ValueError: If `input_fn` is `None` while `use_deprecated_input_fn` is
      False, or if neither features nor examples are available.
  """
  if use_deprecated_input_fn:
    input_fn = input_fn or _default_input_fn
  elif input_fn is None:
    raise ValueError('input_fn must be defined.')

  checkpoint_path = tf_saver.latest_checkpoint(estimator._model_dir)
  with ops.Graph().as_default() as g:
    contrib_variables.create_global_step(g)

    if use_deprecated_input_fn:
      examples = array_ops.placeholder(dtype=dtypes.string,
                                       shape=[default_batch_size],
                                       name='input_example_tensor')
      features = input_fn(estimator, examples)
    else:
      features, _ = input_fn()
      examples = None
      if input_feature_key is not None:
        examples = features.pop(input_feature_key)

    if (not features) and (examples is None):
      raise ValueError('Either features or examples must be defined.')

    # The default return type of _get_predict_ops is ModelFnOps. But there are
    # some subclasses of tf.contrib.learn.Estimator which override this
    # method and use the legacy signature, namely _get_predict_ops returns a
    # `predictions` Tensor or dict of Tensors. The following else-statement
    # code covers these cases, but will soon be deleted after the subclasses
    # are updated.
    # TODO(b/32664904): Update subclasses and delete the else-statement.
    infer_ops = estimator._get_predict_ops(features)
    if isinstance(infer_ops, model_fn.ModelFnOps):  # Default signature
      predictions = infer_ops.predictions
    else:  # Legacy signature
      predictions = infer_ops

    if prediction_key is not None:
      predictions = predictions[prediction_key]

    # Explicit signature_fn takes priority
    if signature_fn:
      default_signature, named_graph_signatures = signature_fn(
          examples, features, predictions)
    else:
      try:
        # Some estimators provide a signature function.
        # TODO(zakaria): check if the estimator has this function,
        #   raise helpful error if not
        signature_fn = estimator._create_signature_fn()

        default_signature, named_graph_signatures = (
            signature_fn(examples, features, predictions))
      except AttributeError:
        # NOTE(review): this also catches AttributeErrors raised while
        # building or running the estimator-provided signature function,
        # which then silently fall back to generic signatures.
        # The trailing space after 'after' is required: adjacent literals are
        # concatenated with no separator.
        logging.warn(
            'Change warning: `signature_fn` will be required after '
            '2016-08-01.\n'
            'Using generic signatures for now. To maintain this behavior, '
            'pass:\n'
            ' signature_fn=export.generic_signature_fn\n'
            'Also consider passing a regression or classification signature; '
            'see cl/126430915 for an example.')
        default_signature, named_graph_signatures = generic_signature_fn(
            examples, features, predictions)
    if exports_to_keep is not None:
      exports_to_keep = gc.largest_export_versions(exports_to_keep)
    return _export_graph(
        g,
        _get_saver(),
        checkpoint_path,
        export_dir,
        default_graph_signature=default_signature,
        named_graph_signatures=named_graph_signatures,
        exports_to_keep=exports_to_keep)
def export_estimator(estimator,
                     export_dir,
                     signature_fn=None,
                     input_fn=_default_input_fn,
                     default_batch_size=1,
                     exports_to_keep=None):
  """Exports inference graph into given dir.

  Args:
    estimator: Estimator to export
    export_dir: A string containing a directory to write the exported graph
      and checkpoints.
    signature_fn: Function that returns a default signature and a named
      signature map, given `Tensor` of `Example` strings, `dict` of `Tensor`s
      for features and `Tensor` or `dict` of `Tensor`s for predictions. If
      not provided, a standard signature function is chosen from the
      estimator's target column problem type. If the estimator has no usable
      target column (`AttributeError`), generic signatures are used instead.
    input_fn: Function that given `Tensor` of `Example` strings, parses it
      into features that are then passed to the model.
    default_batch_size: Default batch size of the `Example` placeholder.
    exports_to_keep: Number of exports to keep.

  Raises:
    ValueError: If `signature_fn` is not given and the target column's
      problem type has no standard export signature.
  """
  checkpoint_path = tf_saver.latest_checkpoint(estimator._model_dir)
  with ops.Graph().as_default() as g:
    contrib_variables.create_global_step(g)
    examples = array_ops.placeholder(dtype=dtypes.string,
                                     shape=[default_batch_size],
                                     name='input_example_tensor')
    features = input_fn(estimator, examples)
    predictions = estimator._get_predict_ops(features)

    # Explicit signature_fn takes priority
    if signature_fn:
      default_signature, named_graph_signatures = signature_fn(
          examples, features, predictions)
    else:
      try:
        # Some estimators provide a target_column of known type
        target_column = estimator._get_target_column()
        problem_type = target_column.problem_type

        if problem_type == layers.ProblemType.CLASSIFICATION:
          signature_fn = classification_signature_fn
        elif problem_type == layers.ProblemType.LINEAR_REGRESSION:
          signature_fn = regression_signature_fn
        elif problem_type == layers.ProblemType.LOGISTIC_REGRESSION:
          signature_fn = logistic_regression_signature_fn
        else:
          raise ValueError(
              'signature_fn must be provided because the TargetColumn is a %s, '
              'which does not have a standard problem type and so cannot use a '
              'standard export signature.' % type(target_column).__name__)

        default_signature, named_graph_signatures = (signature_fn(
            examples, features, predictions))
      except AttributeError:
        # NOTE(review): this also catches AttributeErrors raised inside the
        # selected signature_fn, which then silently fall back to generic
        # signatures.
        # The trailing space after 'after' is required: adjacent literals are
        # concatenated with no separator.
        logging.warn(
            'Change warning: `signature_fn` will be required after '
            '2016-08-01.\n'
            'Using generic signatures for now. To maintain this behavior, '
            'pass:\n'
            ' signature_fn=export.generic_signature_fn\n'
            'Also consider passing a regression or classification signature; '
            'see cl/126430915 for an example.')
        default_signature, named_graph_signatures = generic_signature_fn(
            examples, features, predictions)
    if exports_to_keep is not None:
      exports_to_keep = gc.largest_export_versions(exports_to_keep)
    _export_graph(g, _get_saver(), checkpoint_path, export_dir,
                  default_graph_signature=default_signature,
                  named_graph_signatures=named_graph_signatures,
                  exports_to_keep=exports_to_keep)
def _export_estimator(estimator,
                      export_dir,
                      signature_fn,
                      input_fn,
                      default_batch_size,
                      exports_to_keep,
                      input_feature_key=None,
                      use_deprecated_input_fn=True,
                      prediction_key=None):
  """Exports the inference graph of `estimator` into `export_dir`.

  Args:
    estimator: Estimator to export. The latest checkpoint in its
      `_model_dir` is exported.
    export_dir: Directory to write the exported graph and checkpoints to.
    signature_fn: Optional function mapping `(examples, features,
      predictions)` to a `(default_signature, named_graph_signatures)` pair.
      When falsy, `estimator._create_signature_fn()` is tried. On
      `AttributeError`, generic signatures are used instead.
    input_fn: With `use_deprecated_input_fn`, a function
      `(estimator, examples) -> features`, defaulting to `_default_input_fn`.
      Otherwise, a required no-argument function returning a
      `(features, _)` pair whose second element is ignored.
    default_batch_size: Batch size of the `Example` string placeholder. Only
      used when `use_deprecated_input_fn` is True.
    exports_to_keep: Number of exports to keep, or `None`.
    input_feature_key: Only used when `use_deprecated_input_fn` is False. If
      set, this key is popped from the features dict and its value is used
      as `examples`.
    use_deprecated_input_fn: Selects the `input_fn` calling convention
      described above.
    prediction_key: If set, only `predictions[prediction_key]` is exported.

  Returns:
    Whatever `_export_graph` returns.

  Raises:
    ValueError: If `input_fn` is `None` while `use_deprecated_input_fn` is
      False, or if neither features nor examples are available.
  """
  if use_deprecated_input_fn:
    input_fn = input_fn or _default_input_fn
  elif input_fn is None:
    raise ValueError('input_fn must be defined.')

  checkpoint_path = tf_saver.latest_checkpoint(estimator._model_dir)
  with ops.Graph().as_default() as g:
    contrib_variables.create_global_step(g)

    if use_deprecated_input_fn:
      examples = array_ops.placeholder(dtype=dtypes.string,
                                       shape=[default_batch_size],
                                       name='input_example_tensor')
      features = input_fn(estimator, examples)
    else:
      features, _ = input_fn()
      examples = None
      if input_feature_key is not None:
        examples = features.pop(input_feature_key)

    if (not features) and (examples is None):
      raise ValueError('Either features or examples must be defined.')

    # The default return type of _get_predict_ops is ModelFnOps. But there are
    # some subclasses of tf.contrib.learn.Estimator which override this
    # method and use the legacy signature, namely _get_predict_ops returns a
    # `predictions` Tensor or dict of Tensors. The following else-statement
    # code covers these cases, but will soon be deleted after the subclasses
    # are updated.
    # TODO(b/32664904): Update subclasses and delete the else-statement.
    infer_ops = estimator._get_predict_ops(features)
    if isinstance(infer_ops, model_fn.ModelFnOps):  # Default signature
      predictions = infer_ops.predictions
    else:  # Legacy signature
      predictions = infer_ops

    if prediction_key is not None:
      predictions = predictions[prediction_key]

    # Explicit signature_fn takes priority
    if signature_fn:
      default_signature, named_graph_signatures = signature_fn(
          examples, features, predictions)
    else:
      try:
        # Some estimators provide a signature function.
        # TODO(zakaria): check if the estimator has this function,
        #   raise helpful error if not
        signature_fn = estimator._create_signature_fn()

        default_signature, named_graph_signatures = (signature_fn(
            examples, features, predictions))
      except AttributeError:
        # NOTE(review): this also catches AttributeErrors raised while
        # building or running the estimator-provided signature function,
        # which then silently fall back to generic signatures.
        # The trailing space after 'after' is required: adjacent literals are
        # concatenated with no separator.
        logging.warn(
            'Change warning: `signature_fn` will be required after '
            '2016-08-01.\n'
            'Using generic signatures for now. To maintain this behavior, '
            'pass:\n'
            ' signature_fn=export.generic_signature_fn\n'
            'Also consider passing a regression or classification signature; '
            'see cl/126430915 for an example.')
        default_signature, named_graph_signatures = generic_signature_fn(
            examples, features, predictions)
    if exports_to_keep is not None:
      exports_to_keep = gc.largest_export_versions(exports_to_keep)
    return _export_graph(g, _get_saver(), checkpoint_path, export_dir,
                         default_graph_signature=default_signature,
                         named_graph_signatures=named_graph_signatures,
                         exports_to_keep=exports_to_keep)
def doBasicsOneExportPath(self,
                          export_path,
                          clear_devices=False,
                          global_step=GLOBAL_STEP,
                          sharded=True):
  """Exports a small two-device graph to `export_path` and validates it.

  Builds variables v0/v1 on two CPU devices plus a v2 populated by an init
  op, then exports them with classification, regression and generic
  signatures and one asset file. It then re-imports the export and checks
  the graph, init op, signatures, assets and restored variable values.

  Args:
    export_path: Base directory for the export.
    clear_devices: Passed to `Exporter.init`. When True, device fields are
      also cleared on the reference graph before comparison.
    global_step: Value of the global step variable. It selects the version
      subdirectory that is read back.
    sharded: Whether the `Saver` is sharded. It selects which variables
      filename is restored.
  """
  # Build a graph with 2 parameter nodes on different devices.
  tf.reset_default_graph()
  with tf.Session(
      target="",
      config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
    # v2 is an unsaved variable derived from v0 and v1. It is used to
    # exercise the ability to run an init op when restoring a graph.
    with sess.graph.device("/cpu:0"):
      v0 = tf.Variable(10, name="v0")
    with sess.graph.device("/cpu:1"):
      v1 = tf.Variable(20, name="v1")
    v2 = tf.Variable(1, name="v2", trainable=False, collections=[])
    assign_v2 = tf.assign(v2, tf.add(v0, v1))
    init_op = tf.group(assign_v2, name="init_op")

    tf.add_to_collection("v", v0)
    tf.add_to_collection("v", v1)
    tf.add_to_collection("v", v2)

    global_step_tensor = tf.Variable(global_step, name="global_step")
    named_tensor_bindings = {"logical_input_A": v0, "logical_input_B": v1}
    signatures = {
        "foo": exporter.regression_signature(input_tensor=v0,
                                             output_tensor=v1),
        "generic": exporter.generic_signature(named_tensor_bindings)
    }

    asset_filepath_orig = os.path.join(tf.test.get_temp_dir(), "hello42.txt")
    asset_file = tf.constant(asset_filepath_orig, name="filename42")
    tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, asset_file)

    with gfile.FastGFile(asset_filepath_orig, "w") as f:
      f.write("your data here")
    assets_collection = tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS)

    # Written to the temp dir but not registered as an asset; the export
    # must not copy it (asserted further down).
    ignored_asset = os.path.join(tf.test.get_temp_dir(), "ignored.txt")
    with gfile.FastGFile(ignored_asset, "w") as f:
      f.write("additional data here")

    tf.initialize_all_variables().run()

    # Run an export.
    save = tf.train.Saver({"v0": v0, "v1": v1},
                          restore_sequentially=True,
                          sharded=sharded)
    export = exporter.Exporter(save)
    export.init(sess.graph.as_graph_def(),
                init_op=init_op,
                clear_devices=clear_devices,
                default_graph_signature=exporter.classification_signature(
                    input_tensor=v0),
                named_graph_signatures=signatures,
                assets_collection=assets_collection)
    export.export(export_path,
                  global_step_tensor,
                  sess,
                  exports_to_keep=gc.largest_export_versions(2))

  # Restore graph.
  compare_def = tf.get_default_graph().as_graph_def()
  tf.reset_default_graph()
  with tf.Session(
      target="",
      config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
    save = tf.train.import_meta_graph(
        os.path.join(export_path,
                     constants.VERSION_FORMAT_SPECIFIER % global_step,
                     constants.META_GRAPH_DEF_FILENAME))
    self.assertIsNotNone(save)
    meta_graph_def = save.export_meta_graph()
    collection_def = meta_graph_def.collection_def

    # Validate custom graph_def.
    graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value
    self.assertEqual(len(graph_def_any), 1)
    graph_def = tf.GraphDef()
    graph_def_any[0].Unpack(graph_def)
    if clear_devices:
      for node in compare_def.node:
        node.device = ""
    self.assertProtoEquals(compare_def, graph_def)

    # Validate init_op.
    init_ops = collection_def[constants.INIT_OP_KEY].node_list.value
    self.assertEqual(len(init_ops), 1)
    self.assertEqual(init_ops[0], "init_op")

    # Validate signatures.
    signatures_any = collection_def[constants.SIGNATURES_KEY].any_list.value
    self.assertEqual(len(signatures_any), 1)
    signatures = manifest_pb2.Signatures()
    signatures_any[0].Unpack(signatures)
    default_signature = signatures.default_signature
    self.assertEqual(
        default_signature.classification_signature.input.tensor_name, "v0:0")
    bindings = signatures.named_signatures["generic"].generic_signature.map
    self.assertEqual(bindings["logical_input_A"].tensor_name, "v0:0")
    self.assertEqual(bindings["logical_input_B"].tensor_name, "v1:0")
    read_foo_signature = (
        signatures.named_signatures["foo"].regression_signature)
    self.assertEqual(read_foo_signature.input.tensor_name, "v0:0")
    self.assertEqual(read_foo_signature.output.tensor_name, "v1:0")

    # Validate the assets.
    assets_any = collection_def[constants.ASSETS_KEY].any_list.value
    self.assertEqual(len(assets_any), 1)
    asset = manifest_pb2.AssetFile()
    assets_any[0].Unpack(asset)
    assets_path = os.path.join(
        export_path, constants.VERSION_FORMAT_SPECIFIER % global_step,
        constants.ASSETS_DIRECTORY, "hello42.txt")
    # Read inside `with` so the file handle is closed deterministically.
    with gfile.GFile(assets_path) as f:
      asset_contents = f.read()
    self.assertEqual(asset_contents, "your data here")
    self.assertEqual("hello42.txt", asset.filename)
    self.assertEqual("filename42:0", asset.tensor_binding.tensor_name)
    ignored_asset_path = os.path.join(
        export_path, constants.VERSION_FORMAT_SPECIFIER % global_step,
        constants.ASSETS_DIRECTORY, "ignored.txt")
    self.assertFalse(gfile.Exists(ignored_asset_path))

    # Validate graph restoration.
    if sharded:
      save.restore(sess,
                   os.path.join(
                       export_path,
                       constants.VERSION_FORMAT_SPECIFIER % global_step,
                       constants.VARIABLES_FILENAME_PATTERN))
    else:
      save.restore(sess,
                   os.path.join(
                       export_path,
                       constants.VERSION_FORMAT_SPECIFIER % global_step,
                       constants.VARIABLES_FILENAME))
    self.assertEqual(10, tf.get_collection("v")[0].eval())
    self.assertEqual(20, tf.get_collection("v")[1].eval())
    tf.get_collection(constants.INIT_OP_KEY)[0].run()
    self.assertEqual(30, tf.get_collection("v")[2].eval())
def doBasicsOneExportPath(self,
                          export_path,
                          clear_devices=False,
                          global_step=GLOBAL_STEP,
                          sharded=True):
  """Exports a small two-device graph to `export_path` and validates it.

  Builds variables v0/v1 on two CPU devices plus a v2 populated by an init
  op, then exports them with classification, regression and generic
  signatures and one asset file. It then re-imports the export and checks
  the graph, init op, signatures, assets and restored variable values.

  Args:
    export_path: Base directory for the export.
    clear_devices: Passed to `Exporter.init`. When True, device fields are
      also cleared on the reference graph before comparison.
    global_step: Value of the global step variable. It selects the version
      subdirectory that is read back.
    sharded: Whether the `Saver` is sharded. It selects which variables
      filename is restored.
  """
  # Build a graph with 2 parameter nodes on different devices.
  tf.reset_default_graph()
  with tf.Session(
      target="",
      config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
    # v2 is an unsaved variable derived from v0 and v1. It is used to
    # exercise the ability to run an init op when restoring a graph.
    with sess.graph.device("/cpu:0"):
      v0 = tf.Variable(10, name="v0")
    with sess.graph.device("/cpu:1"):
      v1 = tf.Variable(20, name="v1")
    v2 = tf.Variable(1, name="v2", trainable=False, collections=[])
    assign_v2 = tf.assign(v2, tf.add(v0, v1))
    init_op = tf.group(assign_v2, name="init_op")

    tf.add_to_collection("v", v0)
    tf.add_to_collection("v", v1)
    tf.add_to_collection("v", v2)

    global_step_tensor = tf.Variable(global_step, name="global_step")
    named_tensor_bindings = {"logical_input_A": v0, "logical_input_B": v1}
    signatures = {
        "foo": exporter.regression_signature(input_tensor=v0,
                                             output_tensor=v1),
        "generic": exporter.generic_signature(named_tensor_bindings)
    }

    asset_filepath_orig = os.path.join(tf.test.get_temp_dir(), "hello42.txt")
    asset_file = tf.constant(asset_filepath_orig, name="filename42")
    tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, asset_file)

    with gfile.FastGFile(asset_filepath_orig, "w") as f:
      f.write("your data here")
    assets_collection = tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS)

    # Written to the temp dir but not registered as an asset; the export
    # must not copy it (asserted further down).
    ignored_asset = os.path.join(tf.test.get_temp_dir(), "ignored.txt")
    with gfile.FastGFile(ignored_asset, "w") as f:
      f.write("additional data here")

    tf.initialize_all_variables().run()

    # Run an export.
    save = tf.train.Saver({"v0": v0, "v1": v1},
                          restore_sequentially=True,
                          sharded=sharded)
    export = exporter.Exporter(save)
    export.init(sess.graph.as_graph_def(),
                init_op=init_op,
                clear_devices=clear_devices,
                default_graph_signature=exporter.classification_signature(
                    input_tensor=v0),
                named_graph_signatures=signatures,
                assets_collection=assets_collection)
    export.export(export_path,
                  global_step_tensor,
                  sess,
                  exports_to_keep=gc.largest_export_versions(2))

  # Restore graph.
  compare_def = tf.get_default_graph().as_graph_def()
  tf.reset_default_graph()
  with tf.Session(
      target="",
      config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
    save = tf.train.import_meta_graph(
        os.path.join(export_path,
                     constants.VERSION_FORMAT_SPECIFIER % global_step,
                     constants.META_GRAPH_DEF_FILENAME))
    self.assertIsNotNone(save)
    meta_graph_def = save.export_meta_graph()
    collection_def = meta_graph_def.collection_def

    # Validate custom graph_def.
    graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value
    self.assertEqual(len(graph_def_any), 1)
    graph_def = tf.GraphDef()
    graph_def_any[0].Unpack(graph_def)
    if clear_devices:
      for node in compare_def.node:
        node.device = ""
    self.assertProtoEquals(compare_def, graph_def)

    # Validate init_op.
    init_ops = collection_def[constants.INIT_OP_KEY].node_list.value
    self.assertEqual(len(init_ops), 1)
    self.assertEqual(init_ops[0], "init_op")

    # Validate signatures.
    signatures_any = collection_def[constants.SIGNATURES_KEY].any_list.value
    self.assertEqual(len(signatures_any), 1)
    signatures = manifest_pb2.Signatures()
    signatures_any[0].Unpack(signatures)
    default_signature = signatures.default_signature
    self.assertEqual(
        default_signature.classification_signature.input.tensor_name, "v0:0")
    bindings = signatures.named_signatures["generic"].generic_signature.map
    self.assertEqual(bindings["logical_input_A"].tensor_name, "v0:0")
    self.assertEqual(bindings["logical_input_B"].tensor_name, "v1:0")
    read_foo_signature = (
        signatures.named_signatures["foo"].regression_signature)
    self.assertEqual(read_foo_signature.input.tensor_name, "v0:0")
    self.assertEqual(read_foo_signature.output.tensor_name, "v1:0")

    # Validate the assets.
    assets_any = collection_def[constants.ASSETS_KEY].any_list.value
    self.assertEqual(len(assets_any), 1)
    asset = manifest_pb2.AssetFile()
    assets_any[0].Unpack(asset)
    assets_path = os.path.join(
        export_path, constants.VERSION_FORMAT_SPECIFIER % global_step,
        constants.ASSETS_DIRECTORY, "hello42.txt")
    # Read inside `with` so the file handle is closed deterministically.
    with gfile.GFile(assets_path) as f:
      asset_contents = f.read()
    self.assertEqual(asset_contents, "your data here")
    self.assertEqual("hello42.txt", asset.filename)
    self.assertEqual("filename42:0", asset.tensor_binding.tensor_name)
    ignored_asset_path = os.path.join(
        export_path, constants.VERSION_FORMAT_SPECIFIER % global_step,
        constants.ASSETS_DIRECTORY, "ignored.txt")
    self.assertFalse(gfile.Exists(ignored_asset_path))

    # Validate graph restoration.
    if sharded:
      save.restore(sess,
                   os.path.join(
                       export_path,
                       constants.VERSION_FORMAT_SPECIFIER % global_step,
                       constants.VARIABLES_FILENAME_PATTERN))
    else:
      save.restore(sess,
                   os.path.join(
                       export_path,
                       constants.VERSION_FORMAT_SPECIFIER % global_step,
                       constants.VARIABLES_FILENAME))
    self.assertEqual(10, tf.get_collection("v")[0].eval())
    self.assertEqual(20, tf.get_collection("v")[1].eval())
    tf.get_collection(constants.INIT_OP_KEY)[0].run()
    self.assertEqual(30, tf.get_collection("v")[2].eval())
def testLargestExportVersions(self):
  """largest_export_versions(2) keeps only the two highest versions."""
  paths = [gc.Path("/foo", 8), gc.Path("/foo", 9), gc.Path("/foo", 10)]
  newest = gc.largest_export_versions(2)
  n = newest(paths)
  # `assertEqual` replaces the deprecated `assertEquals` alias (removed in
  # Python 3.12).
  self.assertEqual(n, [gc.Path("/foo", 9), gc.Path("/foo", 10)])