def extract_weights(graph_def, output_graph):
    """Move the constant weights of a GraphDef into external weight files.

    Every `Const` node is evaluated in a fresh session, its value is collected
    for `write_weights.write_weights`, and its inline `tensor_content` is
    cleared before the graph is serialized to `output_graph`.

    Args:
      graph_def: tf.GraphDef proto object which represents the model topology.
        It is modified in place (constant inputs and tensor contents are
        removed).
      output_graph: Path of the graph file to write. Its directory is the
        location handed to `write_weights.write_weights`.
    """
    const_nodes = [n for n in graph_def.node if n.op == 'Const']

    # Strip the conditional inputs from the constants before importing.
    for node in const_nodes:
        del node.input[:]

    print('Writing weight file ' + output_graph + '...')
    weight_dir = os.path.dirname(output_graph)
    manifest_entries = []

    import_graph = tf.Graph()
    with tf.Session(graph=import_graph) as sess:
        tf.import_graph_def(graph_def, name='')
        for node in const_nodes:
            tensor = import_graph.get_tensor_by_name(node.name + ':0')
            value = tensor.eval(session=sess)
            if not isinstance(value, np.ndarray):
                value = np.array(value)
            manifest_entries.append({'name': node.name, 'data': value})

            # The value is kept in the manifest, so remove the binary array
            # from the proto; it is saved to the external file instead.
            node.attr["value"].tensor.ClearField('tensor_content')

    write_weights.write_weights([manifest_entries], weight_dir)

    file_io.atomic_write_string_to_file(
        os.path.abspath(output_graph), graph_def.SerializeToString())
def _build_meta_graph(obj, export_dir, signatures, options,
                      meta_graph_def=None):
    """Creates a MetaGraph containing the resources and functions of an object.

    Args:
      obj: The object to export; must be an instance of `base.Trackable`.
      export_dir: Export directory. Only used here to locate the debug
        directory when `options.save_debug_info` is set.
      signatures: Signatures to export, or None to look them up on `obj` via
        `signature_serialization.find_function_to_export`.
      options: Options object; this function reads `namespace_whitelist`,
        `function_aliases` and `save_debug_info` from it.
      meta_graph_def: Optional MetaGraphDef proto to fill in. A new
        `meta_graph_pb2.MetaGraphDef` is created when it is falsy.

    Returns:
      A tuple `(meta_graph_def, exported_graph, object_saver, asset_info)`.

    Raises:
      AssertionError: If called while `ops.inside_function()` is true.
      ValueError: If `obj` is not a `base.Trackable`.
    """
    if ops.inside_function():
        raise AssertionError(
            "tf.saved_model.save is not supported inside a traced "
            "@tf.function. Move the call to the outer eagerly-executed "
            "context.")
    # pylint: enable=line-too-long
    if not isinstance(obj, base.Trackable):
        raise ValueError(
            "Expected a Trackable object for export, got {}.".format(obj))
    meta_graph_def = meta_graph_def or meta_graph_pb2.MetaGraphDef()
    checkpoint_graph_view = _AugmentedGraphView(obj)
    if signatures is None:
        signatures = signature_serialization.find_function_to_export(
            checkpoint_graph_view)

    signatures, wrapped_functions = (
        signature_serialization.canonicalize_signatures(signatures))
    signature_serialization.validate_saveable_view(checkpoint_graph_view)
    signature_map = signature_serialization.create_signature_map(signatures)
    # Attach the signature map to the root of the graph view under the
    # reserved signature attribute name.
    checkpoint_graph_view.add_object(
        parent_node=checkpoint_graph_view.root,
        name_in_parent=signature_serialization.SIGNATURE_ATTRIBUTE_NAME,
        subgraph_root=signature_map)

    # Use _SaveableView to provide a frozen listing of properties and functions.
    # Note we run this twice since, while constructing the view the first time
    # there can be side effects of creating variables.
    _ = _SaveableView(checkpoint_graph_view)
    saveable_view = _SaveableView(checkpoint_graph_view, wrapped_functions)
    object_saver = util.TrackableSaver(checkpoint_graph_view)
    asset_info, exported_graph = _fill_meta_graph_def(
        meta_graph_def, saveable_view, signatures, options.namespace_whitelist)
    if options.function_aliases:
        # Record, for every cached function of each aliased function object,
        # a mapping from that function's name to the alias.
        function_aliases = meta_graph_def.meta_info_def.function_aliases
        for alias, func in options.function_aliases.items():
            for fdef in func._stateful_fn._function_cache.all_values():  # pylint: disable=protected-access
                function_aliases[fdef.name] = alias
            for fdef in func._stateless_fn._function_cache.all_values():  # pylint: disable=protected-access
                function_aliases[fdef.name] = alias

    object_graph_proto = _serialize_object_graph(saveable_view,
                                                 asset_info.asset_index)
    meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)

    # Save debug info, if requested.
    if options.save_debug_info:
        graph_debug_info = _export_debug_info(exported_graph)
        file_io.atomic_write_string_to_file(
            os.path.join(
                utils_impl.get_or_create_debug_dir(export_dir),
                constants.DEBUG_INFO_FILENAME_PB),
            graph_debug_info.SerializeToString(deterministic=True))

    return meta_graph_def, exported_graph, object_saver, asset_info
def update_checkpoint_state_internal(save_dir,
                                     model_checkpoint_path,
                                     all_model_checkpoint_paths=None,
                                     latest_filename=None,
                                     save_relative_paths=False):
    """Updates the content of the 'checkpoint' file.

    This updates the checkpoint file containing a CheckpointState proto.

    Args:
      save_dir: Directory where the model was saved.
      model_checkpoint_path: The checkpoint file.
      all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
        checkpoints, sorted from oldest to newest. If this is a non-empty list,
        the last element must be equal to model_checkpoint_path. These paths
        are also saved in the CheckpointState proto. May be None, in which
        case None is forwarded to `generate_checkpoint_state_proto`.
      latest_filename: Optional name of the checkpoint file. Default to
        'checkpoint'.
      save_relative_paths: If `True`, will write relative paths to the
        checkpoint state file.

    Raises:
      RuntimeError: If the model checkpoint path recorded in the state proto
        is the same as the path of the file containing CheckpointState.
    """
    # Writes the "checkpoint" file for the coordinator for later restoration.
    coord_checkpoint_filename = _GetCheckpointFilename(save_dir,
                                                       latest_filename)

    if save_relative_paths:

        def _relative_to_save_dir(path):
            # Absolute paths are rewritten relative to save_dir; paths that
            # are already relative are kept unchanged.
            if os.path.isabs(path):
                return os.path.relpath(path, save_dir)
            return path

        path_to_record = _relative_to_save_dir(model_checkpoint_path)
        # all_model_checkpoint_paths defaults to None. Iterating over it
        # unconditionally raised TypeError, so forward None untouched, exactly
        # as the non-relative branch does.
        if all_model_checkpoint_paths is None:
            all_paths_to_record = None
        else:
            all_paths_to_record = [
                _relative_to_save_dir(p) for p in all_model_checkpoint_paths
            ]
    else:
        path_to_record = model_checkpoint_path
        all_paths_to_record = all_model_checkpoint_paths

    ckpt = generate_checkpoint_state_proto(
        save_dir,
        path_to_record,
        all_model_checkpoint_paths=all_paths_to_record)

    if coord_checkpoint_filename == ckpt.model_checkpoint_path:
        raise RuntimeError("Save path '%s' conflicts with path used for "
                           "checkpoint state. Please use a different save path."
                           % model_checkpoint_path)

    # Preventing potential read/write race condition by *atomically* writing
    # to a file.
    file_io.atomic_write_string_to_file(coord_checkpoint_filename,
                                        text_format.MessageToString(ckpt))
def upload_schema(self):
    # type: () -> None
    """Write the serialized schema to `self.schema_snapshot_path`.

    Raises:
      ValueError: If `self.schema` is falsy, i.e. there is nothing to upload.
    """
    if self.schema:
        destination = self.schema_snapshot_path
        payload = self.schema.SerializeToString()
        file_io.atomic_write_string_to_file(destination, payload)
        return

    raise ValueError(
        "Cannot upload a schema since no schema_path was provided. Either provide one, or "
        "use write_stats_and_schema so that a schema can be inferred first."
    )
def set_data(self, key, data):
    """Write `data` to the path generated for `key`.

    The parent directory of the target path is created first when it does not
    exist yet.

    Returns:
      True once the data has been written; False when creating the parent
      directory failed (the failure is logged as a warning).
    """
    target_path = self._generate_path(key)
    parent_dir = os.path.dirname(target_path)

    parent_exists = gfile.Exists(parent_dir)
    if not parent_exists:
        try:
            gfile.MakeDirs(parent_dir)
        except tf.errors.OpError as err:  # pylint: disable=broad-except
            fl_logging.warning("create directory %s failed,"
                               " reason: %s", parent_dir, str(err))
            return False

    file_io.atomic_write_string_to_file(target_path, data)
    return True
def write_metadata(metadata, path):
    """Write the schema of `metadata` as a text-format proto under `path`.

    The schema is written to `<path>/schema.pbtxt`, replacing any existing
    file. The directory is created first if it does not exist.

    Args:
      metadata: A `DatasetMetadata` to write; only its `schema` is used here.
      path: a path to a directory where metadata should be written.
    """
    directory_missing = not file_io.file_exists(path)
    if directory_missing:
        file_io.recursive_create_dir(path)

    destination = os.path.join(path, 'schema.pbtxt')
    schema_text = text_format.MessageToString(metadata.schema)
    file_io.atomic_write_string_to_file(
        destination, schema_text, overwrite=True)
def export_meta_graph(obj, filename, signatures=None, options=None):
    """Build the MetaGraph proto for `obj` and write it to `filename`.

    Args:
      obj: Object to export; forwarded to `_build_meta_graph`.
      filename: Destination file for the serialized MetaGraphDef. Its
        directory is passed to `_build_meta_graph` as the export directory.
      signatures: Optional signatures, forwarded to `_build_meta_graph`.
      options: Optional save options; a default `save_options.SaveOptions()`
        is used when falsy.
    """
    if not options:
        options = save_options.SaveOptions()

    target_dir = os.path.dirname(filename)
    meta_graph_def, exported_graph, _unused_saver, _unused_assets = (
        _build_meta_graph(obj, target_dir, signatures, options))

    serialized = meta_graph_def.SerializeToString(deterministic=True)
    file_io.atomic_write_string_to_file(filename, serialized)

    # References to captured constants in the saved graph had to be kept alive
    # until the write above. Now break the reference cycles so that repeated
    # export()s don't make work for the garbage collector.
    ops.dismantle_graph(exported_graph)
def write_graph(graph_or_graph_def, logdir, name, as_text=True):
    """Writes a graph proto to a file.

    The proto is written in text format by default and in binary format when
    `as_text` is `False`.

    ```python
    v = tf.Variable(0, name='my_variable')
    sess = tf.compat.v1.Session()
    tf.io.write_graph(sess.graph_def, '/tmp/my-model', 'train.pbtxt')
    ```

    or

    ```python
    v = tf.Variable(0, name='my_variable')
    sess = tf.compat.v1.Session()
    tf.io.write_graph(sess.graph, '/tmp/my-model', 'train.pbtxt')
    ```

    Args:
      graph_or_graph_def: A `Graph` or a `GraphDef` protocol buffer.
      logdir: Directory where to write the graph. This can refer to remote
        filesystems, such as Google Cloud Storage (GCS).
      name: Filename for the graph.
      as_text: If `True`, writes the graph as an ASCII proto.

    Returns:
      The path of the output proto file.
    """
    graph_def = (graph_or_graph_def.as_graph_def()
                 if isinstance(graph_or_graph_def, ops.Graph)
                 else graph_or_graph_def)

    # gcs does not have the concept of directory at the moment.
    targets_gcs = logdir.startswith('gs:')
    if not targets_gcs:
        file_io.recursive_create_dir(logdir)

    path = os.path.join(logdir, name)
    if as_text:
        payload = text_format.MessageToString(graph_def, float_format='')
    else:
        payload = graph_def.SerializeToString(deterministic=True)
    file_io.atomic_write_string_to_file(path, payload)
    return path
def write_graph(graph_or_graph_def, logdir, name, as_text=True):
    """Writes a graph proto to a file.

    A text proto is produced unless `as_text` is `False`, in which case the
    binary serialization is written.

    ```python
    v = tf.Variable(0, name='my_variable')
    sess = tf.compat.v1.Session()
    tf.io.write_graph(sess.graph_def, '/tmp/my-model', 'train.pbtxt')
    ```

    or

    ```python
    v = tf.Variable(0, name='my_variable')
    sess = tf.compat.v1.Session()
    tf.io.write_graph(sess.graph, '/tmp/my-model', 'train.pbtxt')
    ```

    Args:
      graph_or_graph_def: A `Graph` or a `GraphDef` protocol buffer.
      logdir: Directory where to write the graph. This can refer to remote
        filesystems, such as Google Cloud Storage (GCS).
      name: Filename for the graph.
      as_text: If `True`, writes the graph as an ASCII proto.

    Returns:
      The path of the output proto file.
    """
    if isinstance(graph_or_graph_def, ops.Graph):
        proto = graph_or_graph_def.as_graph_def()
    else:
        proto = graph_or_graph_def

    # gcs does not have the concept of directory at the moment, so the
    # directory is only created for non-GCS paths that do not exist yet.
    if not (file_io.file_exists(logdir) or logdir.startswith('gs:')):
        file_io.recursive_create_dir(logdir)

    path = os.path.join(logdir, name)
    payload = (text_format.MessageToString(proto, float_format='')
               if as_text else proto.SerializeToString())
    file_io.atomic_write_string_to_file(path, payload)
    return path
def extract_weights(graph_def, output_graph, quantization_dtype=None):
    """Move the constant weights of a GraphDef into external weight files.

    Each `Const` node is evaluated, its value is collected for
    `write_weights.write_weights`, and its inline `tensor_content` is cleared
    before the graph is serialized to `output_graph`.

    Args:
      graph_def: tf.GraphDef TensorFlow GraphDef proto object, which represents
        the model topology. Modified in place.
      output_graph: Path of the graph file to write. Its directory is the
        location handed to `write_weights.write_weights`.
      quantization_dtype: An optional numpy dtype to quantize weights to for
        compression. Only np.uint8 and np.uint16 are supported.
    """
    const_nodes = [n for n in graph_def.node if n.op == 'Const']

    # Remember the conditional inputs of every constant, then strip them from
    # the node; they are put back after the constant has been evaluated.
    saved_inputs = {}
    for node in const_nodes:
        saved_inputs[node.name] = node.input[:]
        del node.input[:]

    print('Writing weight file ' + output_graph + '...')
    weight_dir = os.path.dirname(output_graph)
    manifest_entries = []

    import_graph = tf.Graph()
    with tf.Session(graph=import_graph) as sess:
        tf.import_graph_def(graph_def, name='')
        for node in const_nodes:
            tensor = import_graph.get_tensor_by_name(node.name + ':0')
            evaluated = tensor.eval(session=sess)
            value = (evaluated if isinstance(evaluated, np.ndarray)
                     else np.array(evaluated))
            manifest_entries.append({'name': node.name, 'data': value})

            # Restore the conditional inputs.
            node.input[:] = saved_inputs[node.name]

            # Remove the binary array from tensor and save it to the external
            # file.
            node.attr["value"].tensor.ClearField('tensor_content')

    write_weights.write_weights(
        [manifest_entries], weight_dir, quantization_dtype=quantization_dtype)

    file_io.atomic_write_string_to_file(
        os.path.abspath(output_graph), graph_def.SerializeToString())
def extract_weights(graph_def, output_graph, quantization_dtype=None):
    """Move the constant weights of a GraphDef into external weight files.

    Each `Const` node is evaluated, its value is collected for
    `write_weights.write_weights`, and every field listed in
    `CLEARED_TENSOR_FIELDS` is cleared on its tensor proto before the graph is
    serialized to `output_graph`.

    Args:
      graph_def: tf.GraphDef TensorFlow GraphDef proto object, which represents
        the model topology. Modified in place.
      output_graph: Path of the graph file to write. Its directory is the
        location handed to `write_weights.write_weights`.
      quantization_dtype: An optional numpy dtype to quantize weights to for
        compression. Only np.uint8 and np.uint16 are supported.
    """
    const_nodes = [n for n in graph_def.node if n.op == 'Const']

    # Snapshot the conditional inputs of the constants, then remove them; they
    # are restored once each constant has been evaluated.
    original_inputs = {}
    for node in const_nodes:
        original_inputs[node.name] = node.input[:]
        del node.input[:]

    print('Writing weight file ' + output_graph + '...')
    weights = []
    weight_dir = os.path.dirname(output_graph)

    scratch_graph = tf.Graph()
    with tf.Session(graph=scratch_graph) as sess:
        tf.import_graph_def(graph_def, name='')
        for node in const_nodes:
            tensor_name = node.name + ':0'
            value = scratch_graph.get_tensor_by_name(tensor_name).eval(
                session=sess)
            if not isinstance(value, np.ndarray):
                value = np.array(value)
            weights.append({'name': node.name, 'data': value})

            # Restore the conditional inputs.
            node.input[:] = original_inputs[node.name]

            # Remove the binary array from tensor and save it to the external
            # file.
            for field_name in CLEARED_TENSOR_FIELDS:
                node.attr["value"].tensor.ClearField(field_name)

    write_weights.write_weights(
        [weights], weight_dir, quantization_dtype=quantization_dtype)

    serialized_graph_path = os.path.abspath(output_graph)
    file_io.atomic_write_string_to_file(
        serialized_graph_path, graph_def.SerializeToString())
def testAtomicWriteStringToFileOverwriteFalse(self):
    """Checks `atomic_write_string_to_file` with `overwrite=False`.

    A second write to an existing file must raise `AlreadyExistsError` and
    leave the old contents in place; after deleting the file the write must
    succeed again.
    """
    target = os.path.join(self._base_dir, "temp_file")

    file_io.atomic_write_string_to_file(target, "old", overwrite=False)

    # Writing again without overwrite must fail and keep the old contents.
    self.assertRaises(
        errors.AlreadyExistsError,
        file_io.atomic_write_string_to_file,
        target, "new", overwrite=False)
    self.assertEqual("old", file_io.read_file_to_string(target))

    # Once the file is gone, the same call is expected to write "new".
    file_io.delete_file(target)
    file_io.atomic_write_string_to_file(target, "new", overwrite=False)
    self.assertEqual("new", file_io.read_file_to_string(target))
def upload_anomalies(self):
    # type: () -> None
    """Write the serialized anomalies to `self.anomalies_path`.

    Nothing is written when `self.anomalies.anomaly_info` is empty.
    """
    if not self.anomalies.anomaly_info:
        return

    destination = self.anomalies_path
    payload = self.anomalies.SerializeToString()
    file_io.atomic_write_string_to_file(destination, payload)
def save(obj, export_dir, signatures=None, options=None):
    # pylint: disable=line-too-long
    """Exports the Trackable object `obj` to [SavedModel format](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md).

    Example usage:

    ```python
    class Adder(tf.Module):

      @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
      def add(self, x):
        return x + x + 1.

    to_export = Adder()
    tf.saved_model.save(to_export, '/tmp/adder')
    ```

    The resulting SavedModel is then servable with an input named "x", its value
    having any shape and dtype float32.

    The optional `signatures` argument controls which methods in `obj` will be
    available to programs which consume `SavedModel`s, for example, serving
    APIs. Python functions may be decorated with
    `@tf.function(input_signature=...)` and passed as signatures directly, or
    lazily with a call to `get_concrete_function` on the method decorated with
    `@tf.function`.

    If the `signatures` argument is omitted, `obj` will be searched for
    `@tf.function`-decorated methods. If exactly one `@tf.function` is found,
    that method will be used as the default signature for the SavedModel. This
    behavior is expected to change in the future, when a corresponding
    `tf.saved_model.load` symbol is added. At that point signatures will be
    completely optional, and any `@tf.function` attached to `obj` or its
    dependencies will be exported for use with `load`.

    When invoking a signature in an exported SavedModel, `Tensor` arguments are
    identified by name. These names will come from the Python function's
    argument names by default. They may be overridden by specifying a
    `name=...` argument in the corresponding `tf.TensorSpec` object. Explicit
    naming is required if multiple `Tensor`s are passed through a single
    argument to the Python function.

    The outputs of functions used as `signatures` must either be flat lists, in
    which case outputs will be numbered, or a dictionary mapping string keys to
    `Tensor`, in which case the keys will be used to name outputs.

    Signatures are available in objects returned by `tf.saved_model.load` as a
    `.signatures` attribute. This is a reserved attribute: `tf.saved_model.save`
    on an object with a custom `.signatures` attribute will raise an exception.

    Since `tf.keras.Model` objects are also Trackable, this function can be
    used to export Keras models. For example, exporting with a signature
    specified:

    ```python
    class Model(tf.keras.Model):

      @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
      def serve(self, serialized):
        ...

    m = Model()
    tf.saved_model.save(m, '/tmp/saved_model/')
    ```

    Exporting from a function without a fixed signature:

    ```python
    class Model(tf.keras.Model):

      @tf.function
      def call(self, x):
        ...

    m = Model()
    tf.saved_model.save(
        m, '/tmp/saved_model/',
        signatures=m.call.get_concrete_function(
            tf.TensorSpec(shape=[None, 3], dtype=tf.float32, name="inp")))
    ```

    `tf.keras.Model` instances constructed from inputs and outputs already have
    a signature and so do not require a `@tf.function` decorator or a
    `signatures` argument. If neither are specified, the model's forward pass
    is exported.

    ```python
    x = input_layer.Input((4,), name="x")
    y = core.Dense(5, name="out")(x)
    model = training.Model(x, y)
    tf.saved_model.save(model, '/tmp/saved_model/')
    # The exported SavedModel takes "x" with shape [None, 4] and returns "out"
    # with shape [None, 5]
    ```

    Variables must be tracked by assigning them to an attribute of a tracked
    object or to an attribute of `obj` directly. TensorFlow objects (e.g.
    layers from `tf.keras.layers`, optimizers from `tf.train`) track their
    variables automatically. This is the same tracking scheme that
    `tf.train.Checkpoint` uses, and an exported `Checkpoint` object may be
    restored as a training checkpoint by pointing `tf.train.Checkpoint.restore`
    to the SavedModel's "variables/" subdirectory. Currently, variables are the
    only stateful objects supported by `tf.saved_model.save`, but others (e.g.
    tables) will be supported in the future.

    `tf.function` does not hard-code device annotations from outside the
    function body, instead using the calling context's device. This means for
    example that exporting a model that runs on a GPU and serving it on a CPU
    will generally work, with some exceptions. `tf.device` annotations inside
    the body of the function will be hard-coded in the exported model; this
    type of annotation is discouraged. Device-specific operations, e.g. with
    "cuDNN" in the name or with device-specific layouts, may cause issues.
    Currently a `DistributionStrategy` is another exception: active
    distribution strategies will cause device placements to be hard-coded in a
    function. Exporting a single-device computation and importing under a
    `DistributionStrategy` is not currently supported, but may be in the
    future.

    SavedModels exported with `tf.saved_model.save` [strip default-valued
    attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes)
    automatically, which removes one source of incompatibilities when the
    consumer of a SavedModel is running an older TensorFlow version than the
    producer. There are however other sources of incompatibilities which are
    not handled automatically, such as when the exported model contains
    operations which the consumer does not have definitions for.

    A single tf.function can generate many ConcreteFunctions. If a downstream
    tool wants to refer to all concrete functions generated by a single
    tf.function you can use the `function_aliases` argument to store a map from
    the alias name to all concrete function names.
    E.g.

    ```python
    class MyModel:
      @tf.function
      def func():
        ...

      @tf.function
      def serve():
        ...
        func()

    model = MyModel()
    signatures = {
        'serving_default': model.serve.get_concrete_function(),
    }
    options = tf.saved_model.SaveOptions(function_aliases={
        'my_func': model.func,
    })
    tf.saved_model.save(model, export_dir, signatures, options)
    ```

    Args:
      obj: A trackable object to export.
      export_dir: A directory in which to write the SavedModel.
      signatures: Optional, either a `tf.function` with an input signature
        specified or the result of `f.get_concrete_function` on a
        `@tf.function`-decorated function `f`, in which case `f` will be used
        to generate a signature for the SavedModel under the default serving
        signature key. `signatures` may also be a dictionary, in which case it
        maps from signature keys to either `tf.function` instances with input
        signatures or concrete functions. The keys of such a dictionary may be
        arbitrary strings, but will typically be from the
        `tf.saved_model.signature_constants` module.
      options: Optional, `tf.saved_model.SaveOptions` object that specifies
        options for saving.

    Raises:
      ValueError: If `obj` is not trackable.
      FileNotFoundError: If waiting for pending save operations fails with a
        `NotFoundError` while executing eagerly.

    @compatibility(eager)
    Not well supported when graph building. From TensorFlow 1.x,
    `tf.compat.v1.enable_eager_execution()` should run first. Calling
    tf.saved_model.save in a loop when graph building from TensorFlow 1.x will
    add new save operations to the default graph each iteration.

    May not be called from within a function body.
    @end_compatibility
    """
    options = options or save_options.SaveOptions()
    # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
    # compatible (no sessions) and share it with this export API rather than
    # making a SavedModel proto and writing it directly.
    saved_model = saved_model_pb2.SavedModel()
    # The MetaGraphDef is added to the SavedModel proto first and then filled
    # in place by _build_meta_graph, so its first return value is not needed.
    meta_graph_def = saved_model.meta_graphs.add()

    _, exported_graph, object_saver, asset_info = _build_meta_graph(
        obj, export_dir, signatures, options, meta_graph_def)
    saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION

    # Write the checkpoint, copy assets into the assets directory, and write out
    # the SavedModel proto itself.
    utils_impl.get_or_create_variables_dir(export_dir)
    ckpt_options = checkpoint_options.CheckpointOptions(
        experimental_io_device=options.experimental_io_device)
    object_saver.save(utils_impl.get_variables_path(export_dir),
                      options=ckpt_options)
    builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
                                                export_dir)
    # Note that this needs to be the last file operation when saving the
    # SavedModel. Users rely on checking saved_model_dir/saved_model.pb as an
    # indication that the SavedModel is completely written.
    if context.executing_eagerly():
        try:
            context.async_wait()  # Ensure save operations have completed.
        except errors.NotFoundError as err:
            # Re-raise as the builtin FileNotFoundError, appending a hint
            # about the `experimental_io_device` option to the message.
            raise FileNotFoundError(
                str(err) + "\n If trying to save on a different device from the "
                "computational device, consider using setting the "
                "`experimental_io_device` option on tf.saved_model.SaveOptions "
                "to the io_device such as '/job:localhost'.")

    path = os.path.join(compat.as_str(export_dir),
                        compat.as_str(constants.SAVED_MODEL_FILENAME_PB))
    file_io.atomic_write_string_to_file(
        path, saved_model.SerializeToString(deterministic=True))

    # Clean reference cycles so repeated export()s don't make work for the garbage
    # collector. Before this point, we need to keep references to captured
    # constants in the saved graph.
    ops.dismantle_graph(exported_graph)
def main(unused_argv):
    """Entry point: builds a TPUEstimator and runs the mode in `FLAGS.mode`.

    The modes handled below are:
      * 'eval': evaluate every new checkpoint yielded by
        `evaluation.checkpoints_iterator` until a checkpoint at or beyond
        `params['train_steps']` has been evaluated.
      * 'eval_igt': evaluate all 'model.ckpt-*' checkpoints found in the model
        directory in increasing step order, persisting the index of the next
        checkpoint to process so an interrupted run can resume.
      * anything else: expected to be 'train' or 'train_and_eval'; trains
        (optionally alternating with evaluation) and then exports a SavedModel
        when `FLAGS.export_dir` is set.

    Args:
      unused_argv: Ignored.
    """
    params = hyperparameters.get_hyperparameters(
        FLAGS.default_hparams_file, FLAGS.hparams_file, FLAGS, FLAGS.hparams)
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu if (FLAGS.tpu or params['use_tpu']) else '',
        zone=FLAGS.tpu_zone,
        project=FLAGS.gcp_project)

    # With async checkpointing the estimator's own periodic saving is turned
    # off here; the train branch below adds an AsyncCheckpointSaverHook with
    # the same step interval instead.
    if params['use_async_checkpointing']:
        save_checkpoints_steps = None
    else:
        save_checkpoints_steps = max(2500, params['iterations_per_loop'])
    config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=get_model_dir(params),
        save_checkpoints_steps=save_checkpoints_steps,
        keep_checkpoint_max=None,  # Keep all checkpoints.
        log_step_count_steps=FLAGS.log_step_count_steps,
        session_config=tf.ConfigProto(
            graph_options=tf.GraphOptions(
                rewrite_options=rewriter_config_pb2.RewriterConfig(
                    disable_meta_optimizer=True))),
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=params['iterations_per_loop'],
            num_shards=params['num_cores'],
            # copybara:strip_begin
            tpu_job_name=FLAGS.tpu_job_name,
            # copybara:strip_end
            per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig
            .PER_HOST_V2))  # pylint: disable=line-too-long

    resnet_classifier = tf.contrib.tpu.TPUEstimator(
        use_tpu=params['use_tpu'],
        model_fn=resnet_model_fn,
        config=config,
        params=params,
        train_batch_size=params['train_batch_size'],
        eval_batch_size=params['eval_batch_size'],
        export_to_tpu=FLAGS.export_to_tpu)

    # copybara:strip_begin
    # Same estimator, but with the model function wrapped by
    # xla.estimator_model_fn.
    if FLAGS.xla_compile:
        resnet_classifier = tf.contrib.tpu.TPUEstimator(
            use_tpu=params['use_tpu'],
            model_fn=xla.estimator_model_fn(resnet_model_fn),
            config=config,
            params=params,
            train_batch_size=params['train_batch_size'],
            eval_batch_size=params['eval_batch_size'],
            export_to_tpu=FLAGS.export_to_tpu)
    # copybara:strip_end

    assert (params['precision'] == 'bfloat16' or
            params['precision'] == 'float32'), (
                'Invalid value for precision parameter; '
                'must be bfloat16 or float32.')
    tf.logging.info('Precision: %s', params['precision'])
    use_bfloat16 = params['precision'] == 'bfloat16'

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    if FLAGS.bigtable_instance:
        tf.logging.info('Using Bigtable dataset, table %s', FLAGS.bigtable_table)
        select_train, select_eval = _select_tables_from_flags()
        imagenet_train = imagenet_input.ImageNetBigtableInput(
            is_training=True,
            use_bfloat16=use_bfloat16,
            transpose_input=params['transpose_input'],
            selection=select_train)
        imagenet_eval = imagenet_input.ImageNetBigtableInput(
            is_training=False,
            use_bfloat16=use_bfloat16,
            transpose_input=params['transpose_input'],
            selection=select_eval)
    else:
        if FLAGS.data_dir == FAKE_DATA_DIR:
            tf.logging.info('Using fake dataset.')
        else:
            tf.logging.info('Using dataset: %s', FLAGS.data_dir)
        # One input object per phase; caching is only enabled for training.
        imagenet_train, imagenet_eval = [
            imagenet_input.ImageNetInput(
                is_training=is_training,
                data_dir=FLAGS.data_dir,
                transpose_input=params['transpose_input'],
                cache=params['use_cache'] and is_training,
                image_size=params['image_size'],
                num_parallel_calls=params['num_parallel_calls'],
                use_bfloat16=use_bfloat16) for is_training in [True, False]
        ]

    steps_per_epoch = params['num_train_images'] // params['train_batch_size']
    eval_steps = params['num_eval_images'] // params['eval_batch_size']

    if FLAGS.mode == 'eval':

        # Run evaluation when there's a new checkpoint
        for ckpt in evaluation.checkpoints_iterator(
                get_model_dir(params), timeout=FLAGS.eval_timeout):
            tf.logging.info('Starting to evaluate.')
            try:
                start_timestamp = time.time(
                )  # This time will include compilation time
                eval_results = resnet_classifier.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=eval_steps,
                    checkpoint_path=ckpt)
                elapsed_time = int(time.time() - start_timestamp)
                tf.logging.info('Eval results: %s. Elapsed seconds: %d',
                                eval_results, elapsed_time)

                # Terminate eval job when final checkpoint is reached
                # NOTE(review): assumes the checkpoint basename has the form
                # '<prefix>-<step>' with no other '-' before the step; confirm.
                current_step = int(os.path.basename(ckpt).split('-')[1])
                if current_step >= params['train_steps']:
                    tf.logging.info(
                        'Evaluation finished after training step %d',
                        current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                tf.logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint', ckpt)

    elif FLAGS.mode == 'eval_igt':
        # IGT evaluation mode. Evaluate metrics for the desired parameters
        # (true or shifted) on the desired dataset (train or eval). Note that
        # train is still with data augmentation.

        # Get checkpoint file names.
        index_files = tf.gfile.Glob(
            os.path.join(get_model_dir(params), 'model.ckpt-*.index'))
        checkpoints = [fn[:-len('.index')] for fn in index_files]
        # Need to sort them to get proper tensorboard plotting (increasing event
        # timestamps correspond to increasing steps).
        checkpoint_steps = []
        for ckpt in checkpoints:
            tf.logging.info(ckpt)
            step_match = re.match(r'.*model.ckpt-([0-9]*)', ckpt)
            checkpoint_steps.append(int(step_match.group(1)))
        checkpoints = [
            ckpt for _, ckpt in sorted(zip(checkpoint_steps, checkpoints))
        ]
        tf.logging.info('There are {} checkpoints'.format(len(checkpoints)))
        tf.logging.info(', '.join(checkpoints))

        # Keep track of the last processed checkpoint (fault tolerance).
        analysis_state_path = os.path.join(
            get_model_dir(params),
            'analysis_state_' + FLAGS.igt_eval_set + '_' + FLAGS.igt_eval_mode)
        next_analysis_index = 0
        # Resume from the index stored by a previous run, if there is one.
        if tf.gfile.Exists(analysis_state_path):
            with tf.gfile.Open(analysis_state_path) as fd:
                next_analysis_index = int(fd.read())

        # Process each checkpoint.
        while next_analysis_index < len(checkpoints):
            tf.logging.info(
                'Next analysis index: {}'.format(next_analysis_index))
            ckpt_path = checkpoints[next_analysis_index]
            tf.logging.info('Starting to evaluate: {}.'.format(ckpt_path))
            start_timestamp = time.time(
            )  # This time will include compilation time

            if FLAGS.igt_eval_set == 'train':
                the_input_fn = imagenet_train.input_fn
                the_steps = steps_per_epoch
            elif FLAGS.igt_eval_set == 'eval':
                the_input_fn = imagenet_eval.input_fn
                the_steps = eval_steps
            else:
                raise ValueError('Unsupported igt_eval_set')

            eval_results = resnet_classifier.evaluate(
                input_fn=the_input_fn,
                steps=the_steps,
                checkpoint_path=ckpt_path,
                name=FLAGS.igt_eval_set + '_' + FLAGS.igt_eval_mode)
            elapsed_time = int(time.time() - start_timestamp)
            tf.logging.info('Eval results: %s. Elapsed seconds: %d',
                            eval_results, elapsed_time)

            # Persist progress after every checkpoint so a restart continues
            # with the next unprocessed one.
            next_analysis_index += 1
            file_io.atomic_write_string_to_file(analysis_state_path,
                                                str(next_analysis_index))

    else:  # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
        current_step = estimator._load_global_step_from_checkpoint_dir(
            get_model_dir(params))  # pylint:disable=protected-access,g-line-too-long
        steps_per_epoch = params['num_train_images'] // params[
            'train_batch_size']
        tf.logging.info(
            'Training for %d steps (%.2f epochs in total). Current'
            ' step %d.', params['train_steps'],
            params['train_steps'] / steps_per_epoch, current_step)

        start_timestamp = time.time(
        )  # This time will include compilation time

        if FLAGS.mode == 'train':
            hooks = []
            if params['use_async_checkpointing']:
                hooks.append(
                    async_checkpoint.AsyncCheckpointSaverHook(
                        checkpoint_dir=get_model_dir(params),
                        save_steps=max(2500, params['iterations_per_loop'])))
            resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                    max_steps=params['train_steps'],
                                    hooks=hooks)

        else:
            assert FLAGS.mode == 'train_and_eval'
            while current_step < params['train_steps']:
                # Train for up to steps_per_eval number of steps.
                # At the end of training, a checkpoint will be written to --model_dir.
                next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                                      params['train_steps'])
                resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                        max_steps=next_checkpoint)
                current_step = next_checkpoint

                tf.logging.info(
                    'Finished training up to step %d. Elapsed seconds %d.',
                    next_checkpoint, int(time.time() - start_timestamp))

                # Evaluate the model on the most recent model in --model_dir.
                # Since evaluation happens in batches of --eval_batch_size, some images
                # may be excluded modulo the batch size. As long as the batch size is
                # consistent, the evaluated images are also consistent.
                tf.logging.info('Starting to evaluate.')
                eval_results = resnet_classifier.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=params['num_eval_images'] // params['eval_batch_size'])
                tf.logging.info('Eval results at step %d: %s',
                                next_checkpoint, eval_results)

            elapsed_time = int(time.time() - start_timestamp)
            tf.logging.info(
                'Finished training up to step %d. Elapsed seconds %d.',
                params['train_steps'], elapsed_time)

        if FLAGS.export_dir is not None:
            # The guide to serve a exported TensorFlow model is at:
            #    https://www.tensorflow.org/serving/serving_basic
            tf.logging.info('Starting to export model.')
            unused_export_path = resnet_classifier.export_saved_model(
                export_dir_base=FLAGS.export_dir,
                serving_input_receiver_fn=imagenet_input.image_serving_input_fn
            )
def testAtomicWriteStringToFile(self):
    """Checks that `atomic_write_string_to_file` creates the file with the given contents."""
    target = os.path.join(self._base_dir, "temp_file")

    file_io.atomic_write_string_to_file(target, "testing")

    self.assertTrue(file_io.file_exists(target))
    self.assertEqual("testing", file_io.read_file_to_string(target))
def safe_json_dump(filepath, obj, **kwargs):
  """Serializes `obj` with `safe_json_dumps` and atomically writes it out.

  Args:
    filepath: Destination path passed to
      `file_io.atomic_write_string_to_file`.
    obj: Object to serialize.
    **kwargs: Forwarded unchanged to `safe_json_dumps`.
  """
  file_io.atomic_write_string_to_file(filepath, safe_json_dumps(obj, **kwargs))
def write(self, model: Model, path: str) -> None:
  """Exports `model` as a graph and writes its serialized GraphDef to `path`.

  Args:
    model: Model handed to `Exporter.export_graph`.
    path: Destination file for the serialized graph definition.
  """
  exported: tf.Graph = Exporter.export_graph(model)
  # Shapes are included in the GraphDef (add_shapes=True).
  serialized = exported.as_graph_def(add_shapes=True).SerializeToString()
  file_io.atomic_write_string_to_file(path, serialized)
def upload_schema(self):
  # type: () -> None
  """Atomically writes the serialized schema to `self.schema_snapshot_path`."""
  snapshot_path = self.schema_snapshot_path
  serialized_schema = self.schema.SerializeToString()
  file_io.atomic_write_string_to_file(snapshot_path, serialized_schema)
def save(obj, export_dir, signatures=None, options=None):
  # pylint: disable=line-too-long
  """Exports the Trackable object `obj` to [SavedModel format](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md).

  Example usage:

  ```python
  class Adder(tf.Module):

    @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
    def add(self, x):
      return x + x + 1.

  to_export = Adder()
  tf.saved_model.save(to_export, '/tmp/adder')
  ```

  The resulting SavedModel is then servable with an input named "x", its value
  having any shape and dtype float32.

  The optional `signatures` argument controls which methods in `obj` will be
  available to programs which consume `SavedModel`s, for example serving
  APIs. Python functions may be decorated with
  `@tf.function(input_signature=...)` and passed as signatures directly, or
  lazily with a call to `get_concrete_function` on the method decorated with
  `@tf.function`.

  If the `signatures` argument is omitted, `obj` will be searched for
  `@tf.function`-decorated methods. If exactly one `@tf.function` is found, that
  method will be used as the default signature for the SavedModel. This behavior
  is expected to change in the future, when a corresponding
  `tf.saved_model.load` symbol is added. At that point signatures will be
  completely optional, and any `@tf.function` attached to `obj` or its
  dependencies will be exported for use with `load`.

  When invoking a signature in an exported SavedModel, `Tensor` arguments are
  identified by name. These names will come from the Python function's argument
  names by default. They may be overridden by specifying a `name=...` argument
  in the corresponding `tf.TensorSpec` object. Explicit naming is required if
  multiple `Tensor`s are passed through a single argument to the Python
  function.

  The outputs of functions used as `signatures` must either be flat lists, in
  which case outputs will be numbered, or a dictionary mapping string keys to
  `Tensor`, in which case the keys will be used to name outputs.

  Signatures are available in objects returned by `tf.saved_model.load` as a
  `.signatures` attribute. This is a reserved attribute: `tf.saved_model.save`
  on an object with a custom `.signatures` attribute will raise an exception.

  Since `tf.keras.Model` objects are also Trackable, this function can be
  used to export Keras models. For example, exporting with a signature
  specified:

  ```python
  class Model(tf.keras.Model):

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
    def serve(self, serialized):
      ...

  m = Model()
  tf.saved_model.save(m, '/tmp/saved_model/')
  ```

  Exporting from a function without a fixed signature:

  ```python
  class Model(tf.keras.Model):

    @tf.function
    def call(self, x):
      ...

  m = Model()
  tf.saved_model.save(
      m, '/tmp/saved_model/',
      signatures=m.call.get_concrete_function(
          tf.TensorSpec(shape=[None, 3], dtype=tf.float32, name="inp")))
  ```

  `tf.keras.Model` instances constructed from inputs and outputs already have a
  signature and so do not require a `@tf.function` decorator or a `signatures`
  argument. If neither are specified, the model's forward pass is exported.

  ```python
  x = input_layer.Input((4,), name="x")
  y = core.Dense(5, name="out")(x)
  model = training.Model(x, y)
  tf.saved_model.save(model, '/tmp/saved_model/')
  # The exported SavedModel takes "x" with shape [None, 4] and returns "out"
  # with shape [None, 5]
  ```

  Variables must be tracked by assigning them to an attribute of a tracked
  object or to an attribute of `obj` directly. TensorFlow objects (e.g. layers
  from `tf.keras.layers`, optimizers from `tf.train`) track their variables
  automatically. This is the same tracking scheme that `tf.train.Checkpoint`
  uses, and an exported `Checkpoint` object may be restored as a training
  checkpoint by pointing `tf.train.Checkpoint.restore` to the SavedModel's
  "variables/" subdirectory. Currently variables are the only stateful objects
  supported by `tf.saved_model.save`, but others (e.g. tables) will be supported
  in the future.

  `tf.function` does not hard-code device annotations from outside the function
  body, instead using the calling context's device. This means for example that
  exporting a model which runs on a GPU and serving it on a CPU will generally
  work, with some exceptions. `tf.device` annotations inside the body of the
  function will be hard-coded in the exported model; this type of annotation is
  discouraged. Device-specific operations, e.g. with "cuDNN" in the name or with
  device-specific layouts, may cause issues. Currently a `DistributionStrategy`
  is another exception: active distribution strategies will cause device
  placements to be hard-coded in a function. Exporting a single-device
  computation and importing under a `DistributionStrategy` is not currently
  supported, but may be in the future.

  SavedModels exported with `tf.saved_model.save` [strip default-valued
  attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes)
  automatically, which removes one source of incompatibilities when the consumer
  of a SavedModel is running an older TensorFlow version than the
  producer. There are however other sources of incompatibilities which are not
  handled automatically, such as when the exported model contains operations
  which the consumer does not have definitions for.

  Args:
    obj: A trackable object to export.
    export_dir: A directory in which to write the SavedModel.
    signatures: Optional, either a `tf.function` with an input signature
      specified or the result of `f.get_concrete_function` on a
      `@tf.function`-decorated function `f`, in which case `f` will be used to
      generate a signature for the SavedModel under the default serving
      signature key. `signatures` may also be a dictionary, in which case it
      maps from signature keys to either `tf.function` instances with input
      signatures or concrete functions. The keys of such a dictionary may be
      arbitrary strings, but will typically be from the
      `tf.saved_model.signature_constants` module.
    options: Optional, `tf.saved_model.SaveOptions` object that specifies
      options for saving. When omitted (or falsy), a default
      `save_options.SaveOptions()` is used.

  Raises:
    ValueError: If `obj` is not trackable.
    AssertionError: If called while inside a traced `@tf.function`
      (`ops.inside_function()` is true).

  @compatibility(eager)
  Not well supported when graph building. From TensorFlow 1.x,
  `tf.compat.v1.enable_eager_execution()` should run first. Calling
  tf.saved_model.save in a loop when graph building from TensorFlow 1.x will
  add new save operations to the default graph each iteration.

  May not be called from within a function body.
  @end_compatibility
  """
  # Reject calls made while tracing; this check runs before any other
  # validation or I/O.
  if ops.inside_function():
    raise AssertionError(
        "tf.saved_model.save is not supported inside a traced "
        "@tf.function. Move the call to the outer eagerly-executed "
        "context.")
  # pylint: enable=line-too-long
  if not isinstance(obj, base.Trackable):
    raise ValueError(
        "Expected a Trackable object for export, got {}.".format(obj))
  options = options or save_options.SaveOptions()

  checkpoint_graph_view = _AugmentedGraphView(obj)
  # With no explicit signatures, look one up from the object graph.
  if signatures is None:
    signatures = signature_serialization.find_function_to_export(
        checkpoint_graph_view)

  signatures = signature_serialization.canonicalize_signatures(signatures)
  signature_serialization.validate_saveable_view(checkpoint_graph_view)
  signature_map = signature_serialization.create_signature_map(signatures)
  # Attach the signature map to the root under the reserved attribute name.
  checkpoint_graph_view.add_object(
      parent_node=checkpoint_graph_view.root,
      name_in_parent=signature_serialization.SIGNATURE_ATTRIBUTE_NAME,
      subgraph_root=signature_map)

  # Use _SaveableView to provide a frozen listing of properties and functions.
  # Note we run this twice since, while constructing the view the first time
  # there can be side effects of creating variables.
  _ = _SaveableView(checkpoint_graph_view)
  saveable_view = _SaveableView(checkpoint_graph_view)

  # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
  # compatible (no sessions) and share it with this export API rather than
  # making a SavedModel proto and writing it directly.
  saved_model = saved_model_pb2.SavedModel()
  meta_graph_def = saved_model.meta_graphs.add()
  object_saver = util.TrackableSaver(checkpoint_graph_view)
  asset_info, exported_graph = _fill_meta_graph_def(
      meta_graph_def, saveable_view, signatures, options.namespace_whitelist)
  saved_model.saved_model_schema_version = (
      constants.SAVED_MODEL_SCHEMA_VERSION)
  # So far we've just been generating protocol buffers with no I/O. Now we write
  # the checkpoint, copy assets into the assets directory, and write out the
  # SavedModel proto itself.
  utils_impl.get_or_create_variables_dir(export_dir)
  object_saver.save(utils_impl.get_variables_path(export_dir))
  builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
                                              export_dir)
  # Destination of the SavedModel proto; it is only written at the very end.
  path = os.path.join(
      compat.as_str(export_dir),
      compat.as_str(constants.SAVED_MODEL_FILENAME_PB))
  object_graph_proto = _serialize_object_graph(
      saveable_view, asset_info.asset_index)
  meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)

  # Save debug info, if requested.
  if options.save_debug_info:
    graph_debug_info = _export_debug_info(exported_graph)
    file_io.atomic_write_string_to_file(
        os.path.join(
            utils_impl.get_or_create_debug_dir(export_dir),
            constants.DEBUG_INFO_FILENAME_PB),
        graph_debug_info.SerializeToString(deterministic=True))

  # Note that this needs to be the last file operation when saving the
  # SavedModel. Users rely on checking saved_model_dir/saved_model.pb as an
  # indication that the SavedModel is completely written.
  file_io.atomic_write_string_to_file(
      path, saved_model.SerializeToString(deterministic=True))

  # Clean reference cycles so repeated export()s don't make work for the garbage
  # collector. Before this point we need to keep references to captured
  # constants in the saved graph.
  ops.dismantle_graph(exported_graph)
def update_checkpoint_state_internal(save_dir,
                                     model_checkpoint_path,
                                     all_model_checkpoint_paths=None,
                                     latest_filename=None,
                                     save_relative_paths=False,
                                     all_model_checkpoint_timestamps=None,
                                     last_preserved_timestamp=None):
  """Updates the content of the 'checkpoint' file.

  This updates the checkpoint file containing a CheckpointState proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
      the last element must be equal to model_checkpoint_path.  These paths
      are also saved in the CheckpointState proto.  May be `None`, in which
      case `None` is forwarded to `generate_checkpoint_state_proto` whether or
      not `save_relative_paths` is set.
    latest_filename: Optional name of the checkpoint file.  Default to
      'checkpoint'.
    save_relative_paths: If `True`, will write relative paths to the checkpoint
      state file.  Absolute paths are rewritten relative to `save_dir`; paths
      that are already relative are kept as given.
    all_model_checkpoint_timestamps: Optional list of timestamps (floats,
      seconds since the Epoch) indicating when the checkpoints in
      `all_model_checkpoint_paths` were created.
    last_preserved_timestamp: A float, indicating the number of seconds since
      the Epoch when the last preserved checkpoint was written, e.g. due to a
      `keep_checkpoint_every_n_hours` parameter (see
      `tf.contrib.checkpoint.CheckpointManager` for an implementation).

  Raises:
    RuntimeError: If the model checkpoint path recorded in the CheckpointState
      proto is the same as the path of the file containing CheckpointState.
  """

  def _relative_to_save_dir(path):
    """Returns `path` relative to `save_dir` if absolute, else unchanged."""
    if os.path.isabs(path):
      return os.path.relpath(path, save_dir)
    return path

  # Writes the "checkpoint" file for the coordinator for later restoration.
  coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)

  proto_model_checkpoint_path = model_checkpoint_path
  proto_all_model_checkpoint_paths = all_model_checkpoint_paths
  if save_relative_paths:
    proto_model_checkpoint_path = _relative_to_save_dir(model_checkpoint_path)
    # `all_model_checkpoint_paths` defaults to None; iterating it
    # unconditionally raised TypeError.  Keep None as None so that both modes
    # hand the same "not provided" value to the proto builder.
    if all_model_checkpoint_paths is not None:
      proto_all_model_checkpoint_paths = [
          _relative_to_save_dir(p) for p in all_model_checkpoint_paths
      ]

  ckpt = generate_checkpoint_state_proto(
      save_dir,
      proto_model_checkpoint_path,
      all_model_checkpoint_paths=proto_all_model_checkpoint_paths,
      all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
      last_preserved_timestamp=last_preserved_timestamp)

  if coord_checkpoint_filename == ckpt.model_checkpoint_path:
    raise RuntimeError("Save path '%s' conflicts with path used for "
                       "checkpoint state.  Please use a different save path." %
                       model_checkpoint_path)

  # Preventing potential read/write race condition by *atomically* writing to a
  # file.
  file_io.atomic_write_string_to_file(coord_checkpoint_filename,
                                      text_format.MessageToString(ckpt))
def _SetState(self, state):
  """Atomically writes the text-format encoding of `state` to the state file.

  Args:
    state: Message serialized with `text_format.MessageToString`.
  """
  state_path = self._state_file
  serialized_state = text_format.MessageToString(state)
  file_io.atomic_write_string_to_file(state_path, serialized_state)
def update_checkpoint_state_internal(save_dir,
                                     model_checkpoint_path,
                                     all_model_checkpoint_paths=None,
                                     latest_filename=None,
                                     save_relative_paths=False):
  """Updates the content of the 'checkpoint' file.

  This updates the checkpoint file containing a CheckpointState proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
      the last element must be equal to model_checkpoint_path.  These paths
      are also saved in the CheckpointState proto.  May be `None`, in which
      case `None` is forwarded to `generate_checkpoint_state_proto` whether or
      not `save_relative_paths` is set.
    latest_filename: Optional name of the checkpoint file.  Default to
      'checkpoint'.
    save_relative_paths: If `True`, will write relative paths to the checkpoint
      state file.  Absolute paths are rewritten relative to `save_dir`; paths
      that are already relative are kept as given.

  Raises:
    RuntimeError: If the model checkpoint path recorded in the CheckpointState
      proto is the same as the path of the file containing CheckpointState.
  """

  def _relative_to_save_dir(path):
    """Returns `path` relative to `save_dir` if absolute, else unchanged."""
    if os.path.isabs(path):
      return os.path.relpath(path, save_dir)
    return path

  # Writes the "checkpoint" file for the coordinator for later restoration.
  coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)

  proto_model_checkpoint_path = model_checkpoint_path
  proto_all_model_checkpoint_paths = all_model_checkpoint_paths
  if save_relative_paths:
    proto_model_checkpoint_path = _relative_to_save_dir(model_checkpoint_path)
    # `all_model_checkpoint_paths` defaults to None; iterating it
    # unconditionally raised TypeError.  Keep None as None so that both modes
    # hand the same "not provided" value to the proto builder.
    if all_model_checkpoint_paths is not None:
      proto_all_model_checkpoint_paths = [
          _relative_to_save_dir(p) for p in all_model_checkpoint_paths
      ]

  ckpt = generate_checkpoint_state_proto(
      save_dir,
      proto_model_checkpoint_path,
      all_model_checkpoint_paths=proto_all_model_checkpoint_paths)

  if coord_checkpoint_filename == ckpt.model_checkpoint_path:
    raise RuntimeError(
        "Save path '%s' conflicts with path used for "
        "checkpoint state.  Please use a different save path." %
        model_checkpoint_path)

  # Preventing potential read/write race condition by *atomically* writing to a
  # file.
  file_io.atomic_write_string_to_file(coord_checkpoint_filename,
                                      text_format.MessageToString(ckpt))