def testNegation(self):
    """Checks the results of negating two `gc.mod_export_version` filters.

    Applies `gc.negation(gc.mod_export_version(n))` for n = 2 and n = 3 to
    "/foo" paths with export versions 4, 5, 6 and 9, and compares the returned
    lists with the expected paths.
    """
    paths = [
        gc.Path("/foo", 4),
        gc.Path("/foo", 5),
        gc.Path("/foo", 6),
        gc.Path("/foo", 9),
    ]

    # assertEqual is used instead of the deprecated assertEquals alias, which
    # no longer exists in unittest as of Python 3.12.
    mod = gc.negation(gc.mod_export_version(2))
    self.assertEqual(mod(paths), [gc.Path("/foo", 5), gc.Path("/foo", 9)])

    mod = gc.negation(gc.mod_export_version(3))
    self.assertEqual(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 5)])
def testNegation(self):
    """Verifies the output of negated `gc.mod_export_version` filters.

    The negated filters for moduli 2 and 3 are run over "/foo" paths with
    export versions 4, 5, 6 and 9; the test asserts they return versions
    [5, 9] and [4, 5] respectively.
    """
    paths = [
        gc.Path("/foo", 4),
        gc.Path("/foo", 5),
        gc.Path("/foo", 6),
        gc.Path("/foo", 9),
    ]

    # Use assertEqual: the assertEquals alias is deprecated and was removed
    # from unittest in Python 3.12.
    negate_mod_2 = gc.negation(gc.mod_export_version(2))
    self.assertEqual(
        negate_mod_2(paths), [gc.Path("/foo", 5), gc.Path("/foo", 9)])

    negate_mod_3 = gc.negation(gc.mod_export_version(3))
    self.assertEqual(
        negate_mod_3(paths), [gc.Path("/foo", 4), gc.Path("/foo", 5)])
def export(self,
           export_dir_base,
           global_step_tensor,
           sess=None,
           exports_to_keep=None):
    """Exports the model.

    The export is written to a "-tmp" sibling directory first and renamed to
    its final, version-named directory once saving and the optional assets
    callback have completed. When `exports_to_keep` is given, every export
    directory under `export_dir_base` that the filter does not select is
    deleted recursively afterwards.

    Args:
      export_dir_base: A string path to the base export dir.
      global_step_tensor: A Tensor or tensor name providing the global step
        counter to append to the export directory path and set in the
        manifest version.
      sess: A Session to use to save the parameters.
      exports_to_keep: a gc.Path filter function used to determine the set of
        exports to keep. If set to None, all versions will be kept.

    Raises:
      RuntimeError: if init is not called.
      RuntimeError: if the export would overwrite an existing directory.
    """
    if not self._has_init:
        raise RuntimeError("init must be called first")

    global_step = training_util.global_step(sess, global_step_tensor)
    export_dir = os.path.join(export_dir_base,
                              VERSION_FORMAT_SPECIFIER % global_step)

    # Prevent overwriting on existing exports which could lead to bad/corrupt
    # storage and loading of models. This is an important check that must be
    # done before any output files or directories are created.
    if gfile.Exists(export_dir):
        raise RuntimeError("Overwriting exports can cause corruption and are "
                           "not allowed. Duplicate export dir: %s" % export_dir)

    # Output to a temporary directory which is renamed to the final directory
    # when complete.
    tmp_export_dir = export_dir + "-tmp"
    gfile.MakeDirs(tmp_export_dir)

    self._saver.save(sess,
                     os.path.join(tmp_export_dir, EXPORT_BASE_NAME),
                     meta_graph_suffix=EXPORT_SUFFIX_NAME)

    # Run the asset callback.
    if self._assets_callback:
        assets_dir = os.path.join(tmp_export_dir, ASSETS_DIRECTORY)
        gfile.MakeDirs(assets_dir)
        self._assets_callback(assets_dir)

    # TODO(b/27794910): Delete *checkpoint* file before rename.
    gfile.Rename(tmp_export_dir, export_dir)

    if exports_to_keep:
        # The base directory is escaped so that regex metacharacters in the
        # path (".", "+", "(", ...) are matched literally rather than being
        # interpreted as part of the pattern. Compiled once, outside the
        # parser, since it does not change between candidate paths.
        version_dir_re = re.compile(
            "^" + re.escape(export_dir_base) + "/(\\d{8})$")

        # Simple parser that pulls the export_version from the directory name
        # (eight digits directly under export_dir_base); returns None for
        # anything else.
        def parser(path):
            match = version_dir_re.match(path.path)
            if not match:
                return None
            return path._replace(export_version=int(match.group(1)))

        paths_to_delete = gc.negation(exports_to_keep)
        for p in paths_to_delete(gc.get_paths(export_dir_base, parser=parser)):
            gfile.DeleteRecursively(p.path)