Example #1
0
    def testPathsWithParse(self):
        """get_paths should use the parser to assign export versions and skip
        directories the parser rejects (returns None for)."""
        base_dir = os.path.join(tf.test.get_temp_dir(), "paths_parse")
        self.assertFalse(gfile.Exists(base_dir))
        for p in xrange(3):
            gfile.MakeDirs(os.path.join(base_dir, "%d" % p))
        # Add a base_directory that the parser should ignore.
        gfile.MakeDirs(os.path.join(base_dir, "ignore"))

        # A simple parser that pulls the export_version from the directory name.
        def parser(path):
            # re.escape guards against regex metacharacters in the temp-dir
            # path (the original interpolated base_dir unescaped).
            match = re.match(r"^" + re.escape(base_dir) + r"/(\d+)$", path.path)
            if not match:
                return None
            return path._replace(export_version=int(match.group(1)))

        # assertEqual: assertEquals is a deprecated alias.
        self.assertEqual(gc.get_paths(base_dir, parser=parser), [
            gc.Path(os.path.join(base_dir, "0"), 0),
            gc.Path(os.path.join(base_dir, "1"), 1),
            gc.Path(os.path.join(base_dir, "2"), 2)
        ])
Example #2
0
  def testPathsWithParse(self):
    """get_paths should use the parser to assign export versions and skip
    directories the parser rejects (returns None for)."""
    base_dir = os.path.join(tf.test.get_temp_dir(), "paths_parse")
    self.assertFalse(gfile.Exists(base_dir))
    for p in xrange(3):
      gfile.MakeDirs(os.path.join(base_dir, "%d" % p))
    # Add a base_directory that the parser should ignore.
    gfile.MakeDirs(os.path.join(base_dir, "ignore"))

    # A simple parser that pulls the export_version from the directory name.
    def parser(path):
      # re.escape guards against regex metacharacters in the temp-dir path
      # (the original interpolated base_dir unescaped).
      match = re.match(r"^" + re.escape(base_dir) + r"/(\d+)$", path.path)
      if not match:
        return None
      return path._replace(export_version=int(match.group(1)))

    # assertEqual: assertEquals is a deprecated alias.
    self.assertEqual(
        gc.get_paths(base_dir, parser=parser),
        [gc.Path(os.path.join(base_dir, "0"), 0),
         gc.Path(os.path.join(base_dir, "1"), 1),
         gc.Path(os.path.join(base_dir, "2"), 2)])
Example #3
0
  def export(self,
             export_dir_base,
             global_step_tensor,
             sess=None,
             exports_to_keep=None):
    """Exports the model.

    Args:
      export_dir_base: A string path to the base export dir.
      global_step_tensor: A Tensor or tensor name providing the
        global step counter to append to the export directory path and set
        in the manifest version.
      sess: A Session to use to save the parameters.
      exports_to_keep: a gc.Path filter function used to determine the set of
        exports to keep. If set to None, all versions will be kept.

    Raises:
      RuntimeError: if init is not called.
      RuntimeError: if the export would overwrite an existing directory.
    """
    if not self._has_init:
      raise RuntimeError("init must be called first")

    global_step = training_util.global_step(sess, global_step_tensor)
    export_dir = os.path.join(export_dir_base,
                              VERSION_FORMAT_SPECIFIER % global_step)

    # Prevent overwriting on existing exports which could lead to bad/corrupt
    # storage and loading of models. This is an important check that must be
    # done before any output files or directories are created.
    if gfile.Exists(export_dir):
      raise RuntimeError("Overwriting exports can cause corruption and are "
                         "not allowed. Duplicate export dir: %s" % export_dir)

    # Output to a temporary directory which is atomically renamed to the final
    # directory when complete.
    tmp_export_dir = export_dir + "-tmp"
    gfile.MakeDirs(tmp_export_dir)

    self._saver.save(sess,
                     os.path.join(tmp_export_dir, EXPORT_BASE_NAME),
                     meta_graph_suffix=EXPORT_SUFFIX_NAME)

    # Run the asset callback.
    if self._assets_callback:
      assets_dir = os.path.join(tmp_export_dir, ASSETS_DIRECTORY)
      gfile.MakeDirs(assets_dir)
      self._assets_callback(assets_dir)

    # TODO(b/27794910): Delete *checkpoint* file before rename.
    gfile.Rename(tmp_export_dir, export_dir)

    if exports_to_keep:
      # A simple parser that pulls the export_version from the directory name.
      def parser(path):
        # re.escape guards against regex metacharacters in export_dir_base;
        # unescaped, the pattern could silently match nothing and skip
        # garbage collection of old exports.
        match = re.match(r"^" + re.escape(export_dir_base) + r"/(\d{8})$",
                         path.path)
        if not match:
          return None
        return path._replace(export_version=int(match.group(1)))

      # Delete every export directory that the keep-filter rejects.
      paths_to_delete = gc.negation(exports_to_keep)
      for p in paths_to_delete(gc.get_paths(export_dir_base, parser=parser)):
        gfile.DeleteRecursively(p.path)