def testCopy(self):
    """Checks `gfile.Copy`: overwrite handling, a plain copy, and a bad target."""
    src = self.tmp + "dir1/file1"
    existing_dst = self.tmp + "dir2/file2"
    fresh_dst = self.tmp + "dir2/file1"

    gfile.MkDir(self.tmp + "dir1")
    gfile.MkDir(self.tmp + "dir2")
    with gfile.GFile(src, "w"):
        pass  # Create file
    with gfile.GFile(existing_dst, "w"):
        pass  # Create file

    # Dest file already exists, overwrite=False (default).
    self.assertRaises(OSError, lambda: gfile.Copy(src, existing_dst))
    # Overwrite succeeds.
    gfile.Copy(src, existing_dst, overwrite=True)
    self.assertTrue(gfile.Exists(existing_dst))

    # Normal copy: the destination appears and the source is left in place.
    # Copy (not Rename) is used so that this test exercises the function it
    # is named after.
    gfile.Copy(src, fresh_dst)
    self.assertTrue(gfile.Exists(fresh_dst))
    self.assertTrue(gfile.Exists(src))

    # Normal copy to non-existent dir. The source still exists here, so the
    # failure can only come from the missing destination directory.
    # NOTE(review): assumes Copy reports a missing target dir as OSError,
    # like the "already exists" case above -- confirm for this gfile.
    self.assertRaises(
        OSError, lambda: gfile.Copy(src, self.tmp + "newdir/file1"))
def export_fn(estimator, export_dir_base, checkpoint_path=None):
    """Exports the given Estimator as a SavedModel and invokes post_export_fn.

    The base export and the post-processed export are both staged in local
    temporary directories; the post-processed result is then renamed into
    `export_dir_base`. The temporary directories are removed before this
    function returns, whether or not the export succeeded.

    Args:
      estimator: the Estimator to export.
      export_dir_base: A string containing a directory to write the exported
        graphs and checkpoint.
      checkpoint_path: The checkpoint path to export. If None (the default),
        the most recent checkpoint found within the model directory is chosen.

    Returns:
      The string path to the SavedModel indicated by post_export_fn.

    Raises:
      ValueError: If `estimator` is a `tf.estimator.Estimator` instance and
        `default_output_alternative_key` was specified or if post_export_fn
        does not return a valid directory.
    """
    tmp_dirs = []
    try:
        tmp_base_export_dir = tempfile.mkdtemp()
        tmp_dirs.append(tmp_base_export_dir)
        tmp_base_export = base_export_strategy.export(
            estimator, tmp_base_export_dir, checkpoint_path)

        tmp_post_export_dir = tempfile.mkdtemp()
        tmp_dirs.append(tmp_post_export_dir)
        tmp_post_export = post_export_fn(tmp_base_export, tmp_post_export_dir)

        if not tmp_post_export.startswith(tmp_post_export_dir):
            raise ValueError(
                'post_export_fn must return a sub-directory of {}'.format(
                    tmp_post_export_dir))

        export_relpath = os.path.relpath(tmp_post_export, tmp_post_export_dir)
        final_export = os.path.join(export_dir_base, export_relpath)
        gfile.Rename(os.path.join(tmp_post_export_dir, export_relpath),
                     final_export)
        return final_export
    finally:
        # mkdtemp never cleans up after itself; drop whatever staging
        # directories were created so repeated exports do not pile up.
        for tmp_dir in tmp_dirs:
            if gfile.Exists(tmp_dir):
                gfile.DeleteRecursively(tmp_dir)
def shuffle_records(fname):
    """Shuffle records in a single file.

    The file is first renamed to `fname + ".unshuffled"`, all of its records
    are read into memory and shuffled, and the shuffled records are written
    back under the original name. The renamed copy is removed at the end.

    Args:
      fname: path of the TFRecord file to shuffle in place.
    """
    print("Shuffling records in file %s" % fname)

    # Rename file prior to shuffling
    tmp_fname = fname + ".unshuffled"
    gfile.Rename(fname, tmp_fname)

    reader = python_io.tf_record_iterator(tmp_fname)
    records = []
    for record in reader:
        records.append(record)
        if len(records) % 100000 == 0:
            # Format the count into the message; passing it as a second
            # print() argument would emit the literal "%d".
            print("\tRead: %d" % len(records))

    random.shuffle(records)

    # Write shuffled records to original file name
    with python_io.TFRecordWriter(fname) as w:
        for count, record in enumerate(records):
            w.write(record)
            if count > 0 and count % 100000 == 0:
                print("\tWriting record: %d" % count)

    gfile.Remove(tmp_fname)
def test_file_operations(self):
    """Test file operations: write, read, stat, rename and copy."""
    f = get_oss_path("test_file_operations")
    self.assertFalse(gfile.Exists(f))

    content = "file content"
    with gfile.Open(f, mode="w") as fh:
        fh.write(content)
    self.assertTrue(gfile.Exists(f))

    # Read back inside a context manager so the handle is closed before the
    # file is renamed below.
    with gfile.Open(f) as fh:
        self.assertEqual(fh.read(), content)
    self.assertEqual(gfile.Stat(f).length, len(content))

    f2 = get_oss_path("test_file_2")
    gfile.Rename(f, f2)
    self.assertFalse(gfile.Exists(f))
    self.assertTrue(gfile.Exists(f2))

    f3 = get_oss_path("test_file_3")
    gfile.Copy(f2, f3, overwrite=True)
    self.assertTrue(gfile.Exists(f3))
def finish_example_id_dumper(self):
    """Close the record writer and finalize (or discard) the dumped file.

    Returns:
      An `ExampleIdMeta` built from the start index, end index and final
      file path when at least one example was dumped; otherwise `None`,
      after the temporary file has been removed.
    """
    self._tf_record_writer.close()
    self._tf_record_writer = None

    has_examples = self.dumped_example_number() > 0
    if not has_examples:
        # Nothing was dumped: the index range must be empty too.
        assert self._start_index == self._end_index
        gfile.Remove(self._tmp_fpath)
        return None

    final_fpath = self._get_dumped_fpath()
    gfile.Rename(self._tmp_fpath, final_fpath)
    return ExampleIdMeta(self._start_index, self._end_index, final_fpath)
def finish_data_block(self):
    """Close the record writer and publish the data block with its meta file.

    When the block holds no examples the temporary data file is removed and
    nothing is published. Otherwise the temporary data file is renamed to
    its final name, the meta fields are filled in, and the serialized meta
    is written to a temporary file that is then renamed next to the data.
    """
    assert self._example_num == len(self._data_block_meta.example_ids)
    self._tf_record_writer.close()
    self._tf_record_writer = None

    if not len(self._data_block_meta.example_ids) > 0:
        gfile.Remove(self._tmp_fpath)
        return

    block_id = self._generate_data_block_id()
    block_fpath = os.path.join(self._get_data_block_dir(),
                               block_id + DataBlockSuffix)
    gfile.Rename(self._tmp_fpath, block_fpath)

    self._data_block_meta.start_time = self._start_time
    self._data_block_meta.end_time = self._end_time
    self._data_block_meta.block_id = block_id

    # Stage the meta record in a temporary file, then move it into place.
    staged_meta_fpath = self._get_tmp_fpath()
    with tf.io.TFRecordWriter(staged_meta_fpath) as meta_writer:
        meta_writer.write(self._data_block_meta.SerializeToString())
    final_meta_fpath = os.path.join(self._get_data_block_dir(),
                                    block_id + DataBlockMetaSuffix)
    gfile.Rename(staged_meta_fpath, final_meta_fpath)
def testRename(self):
    """Checks `gfile.Rename` on files and directories, with and without overwrite."""

    def in_tmp(name):
        return self.tmp + name

    for dir_name in ("dir1", "dir2"):
        gfile.MkDir(in_tmp(dir_name))
    for file_name in ("file1", "file2"):
        with gfile.GFile(in_tmp(file_name), "w"):
            pass  # Create file

    # Dest file already exists, overwrite=False (default).
    with self.assertRaises(OSError):
        gfile.Rename(in_tmp("file1"), in_tmp("file2"))

    # Overwriting an existing file removes the source.
    gfile.Rename(in_tmp("file1"), in_tmp("file2"), overwrite=True)
    self.assertFalse(gfile.Exists(in_tmp("file1")))

    # Renaming to an unused name.
    gfile.Rename(in_tmp("file2"), in_tmp("newfile"))
    self.assertTrue(gfile.Exists(in_tmp("newfile")))

    # Directory onto an existing directory.
    gfile.Rename(in_tmp("dir1"), in_tmp("dir2"))
    self.assertFalse(gfile.Exists(in_tmp("dir1")))

    # Directory to an unused name.
    gfile.Rename(in_tmp("dir2"), in_tmp("newdir"))
    self.assertTrue(gfile.Exists(in_tmp("newdir")))
def test_dir_operations(self):
    """Test directory operations: create, list (with/without '/'), rename."""
    d = get_oss_path("d1/d2/d3/d4")
    gfile.MakeDirs(d)
    self.assertTrue(gfile.Stat(d).is_directory)

    # Bucket root URL with the credentials embedded, shared by the listings
    # below.
    bucket_url = "oss://%s\x01id=%s\x02key=%s\x02host=%s" % (
        bucket, access_id, access_key, host)

    # Test listing bucket directory with and without trailing '/'
    content = gfile.ListDirectory(bucket_url)
    content_s = gfile.ListDirectory(bucket_url + "/")
    self.assertEqual(content, content_s)
    self.assertIn("oss_fs_test", content)
    self.assertIn("oss_fs_test/d1", content)
    self.assertIn("oss_fs_test/d1/d2", content)

    # Test listing test directory with and without trailing '/'
    content = gfile.ListDirectory(bucket_url + "/oss_fs_test")
    content_s = gfile.ListDirectory(bucket_url + "/oss_fs_test/")
    self.assertEqual(content, content_s)
    self.assertIn("d1", content)
    self.assertIn("d1/d2", content)

    # Test listing sub directories.
    content = gfile.ListDirectory(get_oss_path("d1"))
    content_s = gfile.ListDirectory(get_oss_path("d1/"))
    self.assertEqual(content, content_s)
    self.assertIn("d2", content)

    # The leaf directory is empty; compare the plain path with the
    # trailing-'/' form, as every other listing check here does.
    content = gfile.ListDirectory(get_oss_path("d1/d2/d3/d4"))
    content_s = gfile.ListDirectory(get_oss_path("d1/d2/d3/d4/"))
    self.assertEqual(content, content_s)
    self.assertEqual([], content)

    # Test Rename directories
    self.assertTrue(gfile.Exists(get_oss_path("d1")))
    gfile.Rename(get_oss_path("d1"), get_oss_path("rename_d1"),
                 overwrite=True)
    self.assertTrue(gfile.Exists(get_oss_path("rename_d1")))
    self.assertFalse(gfile.Exists(get_oss_path("d1")))

    content = gfile.ListDirectory(get_oss_path("rename_d1"))
    content_s = gfile.ListDirectory(get_oss_path("rename_d1/"))
    self.assertEqual(content, content_s)
    self.assertIn("d2", content)
def encode_and_save_files(subtokenizer, data_dir, raw_files, tag,
                          total_shards):
    """Save data from files as encoded Examples in TFrecord format.

    Shards are first written under a ".incomplete" suffix and renamed to
    their final names only after every writer has been closed.

    Args:
      subtokenizer: Subtokenizer object that will be used to encode the
        strings.
      data_dir: The directory in which to write the examples
      raw_files: A tuple of (input, target) data files. Each line in the
        input and the corresponding line in target file will be saved in a
        tf.Example.
      tag: String that will be added onto the file names.
      total_shards: Number of files to divide the data into.

    Returns:
      List of all files produced.
    """
    # Create a file for each shard.
    filepaths = [
        shard_filename(data_dir, tag, n + 1, total_shards)
        for n in range(total_shards)
    ]

    if all_exist(filepaths):
        print("Files with tag %s already exist." % tag)
        return filepaths

    print("Saving files with tag %s." % tag)
    input_file = raw_files[0]
    target_file = raw_files[1]

    # Write examples to each shard in round robin order.
    tmp_filepaths = [fname + ".incomplete" for fname in filepaths]
    writers = [python_io.TFRecordWriter(fname) for fname in tmp_filepaths]
    num_examples = 0
    try:
        for counter, (input_line, target_line) in enumerate(
                zip(txt_line_iterator(input_file),
                    txt_line_iterator(target_file))):
            if counter > 0 and counter % 100000 == 0:
                print("\tSaving case %d." % counter)
            example = dict_to_example({
                "inputs": subtokenizer.encode(input_line, add_eos=True),
                "targets": subtokenizer.encode(target_line, add_eos=True),
            })
            writers[counter % total_shards].write(example.SerializeToString())
            # `counter` is a zero-based index, so the count is one higher.
            num_examples = counter + 1
    finally:
        # Close every writer even if encoding fails part-way through.
        for writer in writers:
            writer.close()

    for tmp_name, final_name in zip(tmp_filepaths, filepaths):
        gfile.Rename(tmp_name, final_name)

    print("Saved %d Examples" % num_examples)
    return filepaths
def export_fn(estimator, export_dir_base, checkpoint_path=None):
    """Exports the given Estimator as a SavedModel and invokes post_export_fn.

    Both staging directories are created inside `export_dir_base`, with
    names derived from the current time. The post-processed export is
    renamed to its final location and the staging directories are deleted.

    Args:
      estimator: the Estimator to export.
      export_dir_base: A string containing a directory to write the exported
        graphs and checkpoint.
      checkpoint_path: The checkpoint path to export. If None (the default),
        the most recent checkpoint found within the model directory is chosen.

    Returns:
      The string path to the SavedModel indicated by post_export_fn.

    Raises:
      ValueError: If `estimator` is a ${tf.estimator.Estimator} instance
        and `default_output_alternative_key` was specified or if
        post_export_fn does not return a valid directory.
      RuntimeError: If unable to create temporary or final export directory.
    """

    def _make_staging_dir(prefix, failure_message):
        # Create a fresh, time-stamped directory under export_dir_base.
        staging_dir = os.path.join(export_dir_base,
                                   prefix + str(int(time.time())))
        if gfile.Exists(staging_dir):
            raise RuntimeError(failure_message)
        gfile.MakeDirs(staging_dir)
        return staging_dir

    tmp_base_export_dir = _make_staging_dir(
        'temp-base-export-', 'Failed to obtain base export directory')
    tmp_base_export = base_export_strategy.export(
        estimator, tmp_base_export_dir, checkpoint_path)

    tmp_post_export_dir = _make_staging_dir(
        'temp-post-export-', 'Failed to obtain temp export directory')
    tmp_post_export = post_export_fn(tmp_base_export, tmp_post_export_dir)

    if not tmp_post_export.startswith(tmp_post_export_dir):
        raise ValueError(
            'post_export_fn must return a sub-directory of {}'.format(
                tmp_post_export_dir))

    post_export = os.path.join(
        export_dir_base,
        os.path.relpath(tmp_post_export, tmp_post_export_dir))
    if gfile.Exists(post_export):
        raise RuntimeError('Failed to obtain final export directory')
    gfile.Rename(tmp_post_export, post_export)

    for staging_dir in (tmp_base_export_dir, tmp_post_export_dir):
        gfile.DeleteRecursively(staging_dir)

    return post_export
def test_rename_dir(self):
    """Renaming a directory makes it reachable only under the new name."""
    # Setup and check preconditions.
    old_dir, new_dir = (
        "igfs:///test_rename_dir/1",
        "igfs:///test_rename_dir/2",
    )
    gfile.MkDir(old_dir)

    # Rename directory.
    gfile.Rename(old_dir, new_dir)

    # Check that only new name of directory is available.
    self.assertFalse(gfile.Exists(old_dir))
    self.assertTrue(gfile.Exists(new_dir))
    self.assertTrue(gfile.IsDirectory(new_dir))
def test_rename_file(self):
    """Renaming a file moves its content and removes the old name."""
    # Setup and check preconditions.
    old_name, new_name = (
        "igfs:///test_rename_file/1",
        "igfs:///test_rename_file/2",
    )
    with gfile.Open(old_name, mode="w") as writer:
        writer.write("42")
    self.assertTrue(gfile.Exists(old_name))

    # Rename file.
    gfile.Rename(old_name, new_name)

    # Check that only new name of file is available.
    self.assertFalse(gfile.Exists(old_name))
    self.assertTrue(gfile.Exists(new_name))

    # The payload must have travelled with the file.
    with gfile.Open(new_name, mode="r") as reader:
        payload = reader.read()
    self.assertEqual("42", payload)
def export_saved_model(estimator, export_dir_base, checkpoint_path,
                       serving_input_receiver_fn, as_text=False):
    """Export `estimator` as a SavedModel below `export_dir_base`.

    The model is built in graph mode into a temporary export directory and
    then renamed to the final export directory as the last step.

    Args:
      estimator: the estimator to export.
      export_dir_base: base directory; the final export directory is obtained
        from `export_helpers.get_timestamped_export_dir`.
      checkpoint_path: checkpoint path handed to `_add_meta_graph_for_mode`.
      serving_input_receiver_fn: input receiver function handed to
        `_add_meta_graph_for_mode`.
      as_text: forwarded to `builder.save`.

    Returns:
      The path of the final export directory.
    """
    with context.graph_mode():
        export_dir = export_helpers.get_timestamped_export_dir(export_dir_base)
        temp_export_dir = export_helpers.get_temp_export_dir(export_dir)

        builder = saved_model_builder.SavedModelBuilder(temp_export_dir)

        # A single meta graph is added, so it always saves the variables.
        # (A trailing "no valid modes" check was dropped: the flag was
        # cleared unconditionally right before it, so it could never fire.)
        save_variables = True
        _add_meta_graph_for_mode(estimator, builder,
                                 serving_input_receiver_fn, checkpoint_path,
                                 save_variables)

        builder.save(as_text)

        gfile.Rename(temp_export_dir, export_dir)
        return export_dir
def download_from_url(path, url):
    """Download content from a url.

    The download is skipped when `find_file` already locates a file with the
    url's last path component in `path`. Otherwise the content is fetched
    into a ".incomplete" file which is renamed once the transfer is done.

    Args:
      path: string directory where file will be downloaded
      url: string url

    Returns:
      Full path to downloaded file
    """
    basename = url.rsplit("/", 1)[-1]
    found_file = find_file(path, basename, max_depth=0)
    if found_file is not None:
        print("Already downloaded: %s (at %s)." % (url, found_file))
        return found_file

    destination = os.path.join(path, basename)
    print("Downloading from %s to %s." % (url, destination))
    partial_path, _ = urlretrieve(
        url, destination + ".incomplete", reporthook=download_report_hook)
    # Print newline to clear the carriage return from the download progress.
    print()
    gfile.Rename(partial_path, destination)
    return destination
def export(self, export_dir_base, global_step_tensor, sess=None,
           exports_to_keep=None):
    """Exports the model.

    The export is written to a "-tmp" sibling directory first and renamed to
    its final, version-numbered directory once complete. Afterwards, older
    exports rejected by `exports_to_keep` are deleted.

    Args:
      export_dir_base: A string path to the base export dir.
      global_step_tensor: A Tensor or tensor name providing the global step
        counter to append to the export directory path and set in the
        manifest version.
      sess: A Session to use to save the parameters.
      exports_to_keep: a gc.Path filter function used to determine the set of
        exports to keep. If set to None, all versions will be kept.

    Raises:
      RuntimeError: if init is not called.
      RuntimeError: if the export would overwrite an existing directory.
    """
    if not self._has_init:
        raise RuntimeError("init must be called first")

    global_step = training_util.global_step(sess, global_step_tensor)
    export_dir = os.path.join(export_dir_base,
                              VERSION_FORMAT_SPECIFIER % global_step)

    # Prevent overwriting on existing exports which could lead to bad/corrupt
    # storage and loading of models. This is an important check that must be
    # done before any output files or directories are created.
    if gfile.Exists(export_dir):
        raise RuntimeError("Overwriting exports can cause corruption and are "
                           "not allowed. Duplicate export dir: %s" % export_dir)

    # Output to a temporary directory which is atomically renamed to the final
    # directory when complete.
    tmp_export_dir = export_dir + "-tmp"
    gfile.MakeDirs(tmp_export_dir)

    self._saver.save(sess, os.path.join(tmp_export_dir, EXPORT_BASE_NAME),
                     meta_graph_suffix=EXPORT_SUFFIX_NAME)

    # Run the asset callback.
    if self._assets_callback:
        assets_dir = os.path.join(tmp_export_dir, ASSETS_DIRECTORY)
        gfile.MakeDirs(assets_dir)
        self._assets_callback(assets_dir)

    # TODO(b/27794910): Delete *checkpoint* file before rename.
    gfile.Rename(tmp_export_dir, export_dir)

    if exports_to_keep:
        # The base dir is matched literally: escape it so that characters
        # such as '.' or '+' in the path are not read as regex operators.
        version_dir_re = re.compile(
            "^" + re.escape(export_dir_base) + "/(\\d{8})$")

        # create a simple parser that pulls the export_version from the
        # directory.
        def parser(path):
            match = version_dir_re.match(path.path)
            if not match:
                return None
            return path._replace(export_version=int(match.group(1)))

        paths_to_delete = gc.negation(exports_to_keep)
        for p in paths_to_delete(gc.get_paths(export_dir_base, parser=parser)):
            gfile.DeleteRecursively(p.path)
def export_savedmodel(
        self, export_dir_base, serving_input_receiver_fn,
        assets_extra=None,
        as_text=False,
        checkpoint_path=None):
    """Exports inference graph as a SavedModel into given dir.

    A new graph is built by calling `serving_input_receiver_fn` to obtain
    the feature `Tensor`s and then this `Estimator`'s model_fn in PREDICT
    mode. The given checkpoint (or, lacking that, the most recent
    checkpoint) is restored into that graph in a fresh session. The
    `SavedModel`, holding a single `MetaGraphDef` tagged for serving, is
    written to a temporary directory which is finally renamed to a
    timestamped export directory below `export_dir_base`.

    The exported `MetaGraphDef` provides one `SignatureDef` for each element
    of the export_outputs dict returned from the model_fn, named using the
    same keys. One of these keys is always
    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which
    signature will be served when a serving request does not specify one.
    For each signature, the outputs are provided by the corresponding
    `ExportOutput`s, and the inputs are always the input receivers provided
    by the serving_input_receiver_fn.

    Extra assets may be written into the SavedModel via the `assets_extra`
    argument. This should be a dict, where each key gives a destination path
    (including the filename) relative to the assets.extra directory. The
    corresponding value gives the full path of the source file to be copied.
    For example, the simple case of copying a single file without renaming
    it is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.

    Args:
      export_dir_base: A string containing a directory in which to create
        timestamped subdirectories containing exported SavedModels.
      serving_input_receiver_fn: A function that takes no argument and
        returns a `ServingInputReceiver`.
      assets_extra: A dict specifying how to populate the assets.extra
        directory within the exported SavedModel, or `None` if no extra
        assets are needed.
      as_text: whether to write the SavedModel proto in text format.
      checkpoint_path: The checkpoint path to export. If `None` (the
        default), the most recent checkpoint found within the model
        directory is chosen.

    Returns:
      The string path to the exported directory.

    Raises:
      ValueError: if no serving_input_receiver_fn is provided, no
        export_outputs are provided, or no checkpoint can be found.
    """
    if serving_input_receiver_fn is None:
        raise ValueError('serving_input_receiver_fn must be defined.')

    with ops.Graph().as_default() as g:
        self._create_and_assert_global_step(g)
        random_seed.set_random_seed(self._config.tf_random_seed)
        receiver = serving_input_receiver_fn()

        # Call the model_fn and collect the export_outputs.
        estimator_spec = self._call_model_fn(
            features=receiver.features,
            labels=None,
            mode=model_fn_lib.ModeKeys.PREDICT,
            config=self.config)

        # Build the SignatureDefs from receivers and all outputs
        signature_def_map = build_all_signature_defs(
            receiver.receiver_tensors,
            estimator_spec.export_outputs,
            receiver.receiver_tensors_alternatives)

        # Fall back to the latest checkpoint when none was requested.
        checkpoint_path = (
            checkpoint_path or saver.latest_checkpoint(self._model_dir))
        if not checkpoint_path:
            raise ValueError(
                "Couldn't find trained model at %s." % self._model_dir)

        export_dir = get_timestamped_export_dir(export_dir_base)
        temp_export_dir = get_temp_export_dir(export_dir)

        # TODO(soergel): Consider whether MonitoredSession makes sense here
        with tf_session.Session() as session:
            saver_for_restore = estimator_spec.scaffold.saver
            if not saver_for_restore:
                saver_for_restore = saver.Saver(sharded=True)
            saver_for_restore.restore(session, checkpoint_path)

            # TODO(b/36111876): replace legacy_init_op with main_op mechanism
            local_init_op = estimator_spec.scaffold.local_init_op
            if not local_init_op:
                # pylint: disable=protected-access
                local_init_op = (
                    monitored_session.Scaffold._default_local_init_op())
                # pylint: enable=protected-access

            # Perform the export
            builder = saved_model_builder.SavedModelBuilder(temp_export_dir)
            builder.add_meta_graph_and_variables(
                session, [tag_constants.SERVING],
                signature_def_map=signature_def_map,
                assets_collection=ops.get_collection(
                    ops.GraphKeys.ASSET_FILEPATHS),
                legacy_init_op=local_init_op)
            builder.save(as_text)

        # Add the extra assets
        if assets_extra:
            assets_extra_path = os.path.join(
                compat.as_bytes(temp_export_dir),
                compat.as_bytes('assets.extra'))
            for dest_relative, source in assets_extra.items():
                dest_absolute = os.path.join(
                    compat.as_bytes(assets_extra_path),
                    compat.as_bytes(dest_relative))
                gfile.MakeDirs(os.path.dirname(dest_absolute))
                gfile.Copy(source, dest_absolute)

        gfile.Rename(temp_export_dir, export_dir)
        return export_dir
def export_eval_savedmodel(
        estimator,
        export_dir_base,
        eval_input_receiver_fn,
        checkpoint_path=None):
    """Export a EvalSavedModel for the given estimator.

    The eval graph is built in a fresh `tf.Graph`; metric ops, predictions,
    the input example placeholder, labels and features are recorded in graph
    collections; and the result is written as a SavedModel to a temporary
    directory that is renamed to the final export directory at the end.

    Args:
      estimator: Estimator to export the graph for.
      export_dir_base: Base path for export. Graph will be exported into a
        subdirectory of this base path.
      eval_input_receiver_fn: Eval input receiver function.
      checkpoint_path: Path to a specific checkpoint to export. If set to
        None, exports the latest checkpoint.

    Returns:
      Path to the directory where the eval graph was exported.

    Raises:
      ValueError: Could not find a checkpoint to export.
    """
    with tf.Graph().as_default() as g:
        eval_input_receiver = eval_input_receiver_fn()
        tf.train.create_global_step(g)
        tf.set_random_seed(estimator.config.tf_random_seed)

        # Workaround for TensorFlow issue #17568. Note that we pass the
        # identity-wrapped features and labels to model_fn, but we have to
        # feed the non-identity wrapped Tensors during evaluation.
        #
        # Also note that we can't wrap predictions, so metrics that have
        # control dependencies on predictions will cause the predictions to
        # be recomputed during their evaluation.
        wrap = util.wrap_tensor_or_dict_of_tensors_in_identity
        wrapped_features = wrap(eval_input_receiver.features)
        wrapped_labels = wrap(eval_input_receiver.labels)

        if isinstance(estimator, tf.estimator.Estimator):
            # This is a core estimator
            estimator_spec = estimator.model_fn(
                features=wrapped_features,
                labels=wrapped_labels,
                mode=tf.estimator.ModeKeys.EVAL,
                config=estimator.config)
        else:
            # This is a contrib estimator
            model_fn_ops = estimator._call_model_fn(  # pylint: disable=protected-access
                features=wrapped_features,
                labels=wrapped_labels,
                mode=tf.estimator.ModeKeys.EVAL)
            # A function object is used as a plain attribute holder that
            # mimics the fields read from an estimator spec below.
            estimator_spec = lambda x: None
            estimator_spec.predictions = model_fn_ops.predictions
            estimator_spec.eval_metric_ops = model_fn_ops.eval_metric_ops
            estimator_spec.scaffold = model_fn_ops.scaffold

        # Save metric using eval_metric_ops.
        for user_metric_key, (value_op, update_op) in (
                estimator_spec.eval_metric_ops.items()):
            for suffix, encode, item in (
                    (encoding.KEY_SUFFIX, encoding.encode_key,
                     user_metric_key),
                    (encoding.VALUE_OP_SUFFIX, encoding.encode_tensor_node,
                     value_op),
                    (encoding.UPDATE_OP_SUFFIX, encoding.encode_tensor_node,
                     update_op)):
                tf.add_to_collection(
                    '%s/%s' % (encoding.METRICS_COLLECTION, suffix),
                    encode(item))

        # Save all prediction nodes.
        # Predictions can either be a Tensor, or a dict of Tensors.
        predictions = estimator_spec.predictions
        if not isinstance(predictions, dict):
            predictions = {encoding.DEFAULT_PREDICTIONS_DICT_KEY: predictions}
        for prediction_key, prediction_node in predictions.items():
            _encode_and_add_to_node_collection(
                encoding.PREDICTIONS_COLLECTION, prediction_key,
                prediction_node)

        ############################################################
        ## Features, label (and weight) graph

        # Placeholder for input example to label graph.
        examples_node = eval_input_receiver.receiver_tensors['examples']
        tf.add_to_collection(encoding.INPUT_EXAMPLE_COLLECTION,
                             encoding.encode_tensor_node(examples_node))

        # Save all label nodes.
        # Labels can either be a Tensor, or a dict of Tensors.
        labels = eval_input_receiver.labels
        if not isinstance(labels, dict):
            labels = {encoding.DEFAULT_LABELS_DICT_KEY: labels}
        for label_key, label_node in labels.items():
            _encode_and_add_to_node_collection(
                encoding.LABELS_COLLECTION, label_key, label_node)

        # Save features.
        for feature_name, feature_node in (
                eval_input_receiver.features.items()):
            _encode_and_add_to_node_collection(
                encoding.FEATURES_COLLECTION, feature_name, feature_node)

        ############################################################
        ## Export as normal

        checkpoint_path = (
            checkpoint_path or tf.train.latest_checkpoint(estimator.model_dir))
        if not checkpoint_path:
            raise ValueError(
                'Could not find trained model at %s.' % estimator.model_dir)

        export_dir = _get_timestamped_export_dir(export_dir_base)
        temp_export_dir = _get_temp_export_dir(export_dir)

        session_config = estimator.config.session_config
        if session_config is None:
            session_config = config_pb2.ConfigProto(allow_soft_placement=True)

        with tf.Session(config=session_config) as session:
            scaffold = estimator_spec.scaffold
            saver_for_restore = (
                scaffold.saver if scaffold and scaffold.saver
                else tf.train.Saver(sharded=True))
            saver_for_restore.restore(session, checkpoint_path)

            # pylint: disable=protected-access
            local_init_op = (
                scaffold.local_init_op
                if scaffold and scaffold.local_init_op
                else tf.train.Scaffold._default_local_init_op())
            # pylint: enable=protected-access

            # Perform the export
            builder = tf.saved_model.builder.SavedModelBuilder(temp_export_dir)
            builder.add_meta_graph_and_variables(
                session, [tf.saved_model.tag_constants.SERVING],
                # Don't export any signatures, since this graph is not
                # actually meant for serving.
                signature_def_map=None,
                assets_collection=tf.get_collection(
                    tf.GraphKeys.ASSET_FILEPATHS),
                legacy_init_op=local_init_op)
            builder.save(False)

        gfile.Rename(temp_export_dir, export_dir)
        return export_dir
def save_keras_model(model, saved_model_path, custom_objects=None,
                     as_text=None):
    """Save a `tf.keras.Model` into Tensorflow SavedModel format.

    This function generates new files/folders under the `saved_model_path`
    folder:
    1) an asset folder containing the json string of the model's
       configuration (topology).
    2) a checkpoint containing the model weights.
    3) a saved_model.pb file containing the model's MetaGraphs. The
       prediction graph is always exported. The evaluation and training
       graphs are exported if the following conditions are met:
       - Evaluation: model loss is defined.
       - Training: model is compiled with an optimizer defined under
         `tf.train`. This is because `tf.keras.optimizers.Optimizer`
         instances cannot be saved to checkpoints.

    Model Requirements:
    - Model must be a sequential model or functional model. Subclassed
      models can not be saved via this function, unless you provide an
      implementation for get_config() and from_config().
    - All variables must be saveable by the model. In general, this
      condition is met through the use of layers defined in the keras
      library. However, there is currently a bug with variables created in
      Lambda layer functions not being saved correctly (see
      https://github.com/keras-team/keras/issues/9740).

    Note that each mode is exported in separate graphs, so different modes
    do not share variables. To use the train graph with evaluation or
    prediction graphs, create a new checkpoint if variable values have been
    updated.

    Args:
      model: A `tf.keras.Model` to be saved.
      saved_model_path: a string specifying the path to the SavedModel
        directory. The SavedModel will be saved to a timestamped folder
        created within this directory.
      custom_objects: Optional dictionary mapping string names to custom
        classes or functions (e.g. custom loss functions).
      as_text: whether to write the `SavedModel` proto in text format.

    Returns:
      String path to the SavedModel folder, a subdirectory of
      `saved_model_path`.

    Raises:
      NotImplementedError: If the model is a subclassed model.
      ValueError: If a Sequential model does not have input shapes defined
        by the user, and is not built.
    """
    if not model._is_graph_network:
        if not isinstance(model, sequential.Sequential):
            raise NotImplementedError(
                'Exporting subclassed models is not yet supported.')
        # If input shape is not directly set in the model, the exported
        # model will assume that the inputs have the same shape as the shape
        # the model was built with.
        if not model.built:
            raise ValueError(
                'Sequential model must be built before it can be exported.')

    export_dir = export_helpers.get_timestamped_export_dir(saved_model_path)
    temp_export_dir = export_helpers.get_temp_export_dir(export_dir)

    builder = saved_model_builder.SavedModelBuilder(temp_export_dir)

    # Manually save variables to export them in an object-based checkpoint.
    # This skips the `builder.add_meta_graph_and_variables()` step, which
    # saves a named-based checkpoint.
    # TODO(b/113134168): Add fn to Builder to save with object-based saver.
    # TODO(b/113178242): This should only export the model json structure.
    # Only one save is needed once the weights can be copied from the model
    # to clone.
    checkpoint_path = _export_model_json_and_variables(model, temp_export_dir)

    # Export each mode. Use ModeKeys enums defined for `Estimator` to ensure
    # that Keras models and `Estimator`s are exported with the same format.
    # Every time a mode is exported, the code checks to see if new variables
    # have been created (e.g. optimizer slot variables). If that is the
    # case, the checkpoint is re-saved to include the new variables.
    def _export(mode, vars_already_saved):
        _export_mode(mode, vars_already_saved,
                     builder=builder,
                     model=model,
                     custom_objects=custom_objects,
                     checkpoint_path=checkpoint_path)

    has_saved_vars = False
    if model.optimizer:
        if isinstance(model.optimizer, optimizers.TFOptimizer):
            _export(model_fn_lib.ModeKeys.TRAIN, has_saved_vars)
            has_saved_vars = True
            _export(model_fn_lib.ModeKeys.EVAL, has_saved_vars)
        else:
            logging.warning(
                'Model was compiled with an optimizer, but the optimizer is '
                'not from `tf.train` (e.g. `tf.train.AdagradOptimizer`). '
                'Only the serving graph was exported. The train and evaluate '
                'graphs were not added to the SavedModel.')
    _export(model_fn_lib.ModeKeys.PREDICT, has_saved_vars)

    builder.save(as_text)

    gfile.Rename(temp_export_dir, export_dir)
    return export_dir
def _write_with_backup(filename, content):
    """Write `content` to `filename`, keeping any previous file as `.old`.

    An existing `filename` is first renamed to `filename + '.old'`,
    overwriting an earlier backup, and then the new content is written.

    Args:
      filename: path of the file to (re)write.
      content: data passed to the opened file's `write`.
    """
    needs_backup = gfile.Exists(filename)
    if needs_backup:
        backup_name = filename + '.old'
        gfile.Rename(filename, backup_name, overwrite=True)

    with gfile.Open(filename, 'w') as out:
        out.write(content)
def test_dir_operations(self):
    """Test directory operations: create, list (with/without '/'), rename."""
    d = get_oss_path("d1/d2/d3/d4")
    gfile.MakeDirs(d)
    self.assertTrue(gfile.Stat(d).is_directory)

    # Bucket root URL with the credentials embedded, shared by the listings
    # below.
    bucket_url = "oss://%s\x01id=%s\x02key=%s\x02host=%s" % (
        bucket, access_id, access_key, host)

    # Test listing bucket directory with and without trailing '/'
    content = gfile.ListDirectory(bucket_url)
    content_s = gfile.ListDirectory(bucket_url + "/")
    self.assertEqual(content, content_s)
    self.assertIn("oss_fs_test", content)
    self.assertIn("oss_fs_test/d1", content)
    self.assertIn("oss_fs_test/d1/d2", content)

    # Test listing test directory with and without trailing '/'
    content = gfile.ListDirectory(bucket_url + "/oss_fs_test")
    content_s = gfile.ListDirectory(bucket_url + "/oss_fs_test/")
    self.assertEqual(content, content_s)
    self.assertIn("d1", content)
    self.assertIn("d1/d2", content)

    # Test listing sub directories.
    content = gfile.ListDirectory(get_oss_path("d1"))
    content_s = gfile.ListDirectory(get_oss_path("d1/"))
    self.assertEqual(content, content_s)
    self.assertIn("d2", content)

    # The leaf directory is empty; compare the plain path with the
    # trailing-'/' form, as every other listing check here does.
    content = gfile.ListDirectory(get_oss_path("d1/d2/d3/d4"))
    content_s = gfile.ListDirectory(get_oss_path("d1/d2/d3/d4/"))
    self.assertEqual(content, content_s)
    self.assertEqual([], content)

    # Test Rename directories
    self.assertTrue(gfile.Exists(get_oss_path("d1")))
    gfile.Rename(get_oss_path("d1"), get_oss_path("rename_d1"),
                 overwrite=True)
    self.assertTrue(gfile.Exists(get_oss_path("rename_d1")))
    self.assertFalse(gfile.Exists(get_oss_path("d1")))

    content = gfile.ListDirectory(get_oss_path("rename_d1"))
    content_s = gfile.ListDirectory(get_oss_path("rename_d1/"))
    self.assertEqual(content, content_s)
    self.assertIn("d2", content)

    # Test Rename non-empty directories
    not_empty_dir = get_oss_path("not_empty_dir/")
    rename_not_empty_dir = get_oss_path("rename_not_empty_dir/")
    gfile.MakeDirs(not_empty_dir)

    not_empty_file = get_oss_path("not_empty_dir/not_empty_file")
    rename_not_empty_file = get_oss_path("rename_not_empty_dir/not_empty_file")
    with gfile.Open(not_empty_file, mode="w") as fh:
        content = "file content"
        fh.write(content)
    self.assertTrue(gfile.Exists(not_empty_dir))
    self.assertTrue(gfile.Exists(not_empty_file))

    # The file inside the directory must move together with it.
    gfile.Rename(not_empty_dir, rename_not_empty_dir, overwrite=True)
    self.assertFalse(gfile.Exists(not_empty_dir))
    self.assertFalse(gfile.Exists(not_empty_file))
    self.assertTrue(gfile.Exists(rename_not_empty_dir))
    self.assertTrue(gfile.Exists(rename_not_empty_file))