def testPathsWithParse(self):
    """Checks that gc.get_paths keeps only the paths the parser accepts.

    Creates directories "0", "1" and "2" plus an "ignore" directory under a
    base dir that must not exist yet, then expects get_paths to return only
    the numeric directories, each carrying its number as export_version.
    """
    base_dir = os.path.join(tf.test.get_temp_dir(), "paths_parse")
    self.assertFalse(gfile.Exists(base_dir))
    for p in range(3):
        gfile.MakeDirs(os.path.join(base_dir, "%d" % p))
    # add a base_directory to ignore
    gfile.MakeDirs(os.path.join(base_dir, "ignore"))

    # Build the pattern from the escaped base dir (with its trailing
    # separator) so that regex metacharacters or backslashes in the temp
    # path are matched literally and the separator agrees with the
    # os.path.join calls used for the expected paths below.
    version_pattern = re.compile(
        "^" + re.escape(os.path.join(base_dir, "")) + r"(\d+)$")

    # create a simple parser that pulls the export_version from the
    # directory.
    def parser(path):
        match = version_pattern.match(path.path)
        if not match:
            return None
        return path._replace(export_version=int(match.group(1)))

    self.assertEqual(gc.get_paths(base_dir, parser=parser), [
        gc.Path(os.path.join(base_dir, "0"), 0),
        gc.Path(os.path.join(base_dir, "1"), 1),
        gc.Path(os.path.join(base_dir, "2"), 2),
    ])
def testPathsWithParse(self):
    """Verifies gc.get_paths filters and annotates paths via the parser.

    Three numeric directories and one "ignore" directory are created under
    a base dir asserted not to exist beforehand. The parser rejects
    anything that is not a purely numeric child of the base dir, so only
    the numeric directories are expected back, each with its
    export_version set.
    """
    base_dir = os.path.join(tf.test.get_temp_dir(), "paths_parse")
    self.assertFalse(gfile.Exists(base_dir))
    for p in range(3):
        gfile.MakeDirs(os.path.join(base_dir, "%d" % p))
    # add a base_directory to ignore
    gfile.MakeDirs(os.path.join(base_dir, "ignore"))

    # Escape the base dir so characters that are special in a regex
    # (".", "+", backslashes, ...) are matched literally, and take the
    # separator from os.path.join rather than hard-coding "/".
    base_prefix = re.escape(os.path.join(base_dir, ""))
    version_pattern = re.compile("^" + base_prefix + r"(\d+)$")

    # create a simple parser that pulls the export_version from the
    # directory.
    def parser(path):
        match = version_pattern.match(path.path)
        if not match:
            return None
        return path._replace(export_version=int(match.group(1)))

    self.assertEqual(
        gc.get_paths(base_dir, parser=parser),
        [gc.Path(os.path.join(base_dir, "0"), 0),
         gc.Path(os.path.join(base_dir, "1"), 1),
         gc.Path(os.path.join(base_dir, "2"), 2)])
def export(self, export_dir_base, global_step_tensor, sess=None,
           exports_to_keep=None):
    """Exports the model.

    The export is written to a "-tmp" sibling directory first and then
    renamed to its final, version-named directory. When `exports_to_keep`
    is given, version directories it does not select are deleted
    afterwards.

    Args:
      export_dir_base: A string path to the base export dir.
      global_step_tensor: A Tensor or tensor name providing the
        global step counter to append to the export directory path and set
        in the manifest version.
      sess: A Session to use to save the parameters.
      exports_to_keep: a gc.Path filter function used to determine the set
        of exports to keep. If set to None, all versions will be kept.

    Raises:
      RuntimeError: if init is not called.
      RuntimeError: if the export would overwrite an existing directory.
    """
    if not self._has_init:
        raise RuntimeError("init must be called first")

    global_step = training_util.global_step(sess, global_step_tensor)
    export_dir = os.path.join(export_dir_base,
                              VERSION_FORMAT_SPECIFIER % global_step)

    # Prevent overwriting on existing exports which could lead to
    # bad/corrupt storage and loading of models. This is an important check
    # that must be done before any output files or directories are created.
    if gfile.Exists(export_dir):
        raise RuntimeError("Overwriting exports can cause corruption and are "
                           "not allowed. Duplicate export dir: %s" % export_dir)

    # Output to a temporary directory which is atomically renamed to the
    # final directory when complete.
    tmp_export_dir = export_dir + "-tmp"
    gfile.MakeDirs(tmp_export_dir)

    self._saver.save(sess,
                     os.path.join(tmp_export_dir, EXPORT_BASE_NAME),
                     meta_graph_suffix=EXPORT_SUFFIX_NAME)

    # Run the asset callback.
    if self._assets_callback:
        assets_dir = os.path.join(tmp_export_dir, ASSETS_DIRECTORY)
        gfile.MakeDirs(assets_dir)
        self._assets_callback(assets_dir)

    # TODO(b/27794910): Delete *checkpoint* file before rename.
    gfile.Rename(tmp_export_dir, export_dir)

    if exports_to_keep:
        # Match "<export_dir_base><sep><8 digits>" literally. The base dir
        # is escaped so regex metacharacters in it cannot change the
        # pattern, and os.path.join(base, "") supplies exactly one
        # separator even when the caller passed a trailing slash.
        # Previously such a base dir never matched and old exports were
        # silently kept forever.
        # NOTE(review): assumes gc.get_paths builds each candidate as
        # os.path.join(export_dir_base, name) - confirm against gc.
        version_pattern = re.compile(
            "^" + re.escape(os.path.join(export_dir_base, "")) + r"(\d{8})$")

        # create a simple parser that pulls the export_version from the
        # directory.
        def parser(path):
            match = version_pattern.match(path.path)
            if not match:
                return None
            return path._replace(export_version=int(match.group(1)))

        paths_to_delete = gc.negation(exports_to_keep)
        for p in paths_to_delete(gc.get_paths(export_dir_base, parser=parser)):
            gfile.DeleteRecursively(p.path)