Example 1
 def run_loop(self):
     self._sv.saver.save(self._sess, self._sv.save_path, global_step=self._sv.global_step)
     if self._sv.summary_writer and self._sv.global_step is not None:
         current_step = training_util.global_step(self._sess, self._sv.global_step)
         self._sv.summary_writer.add_session_log(
             SessionLog(status=SessionLog.CHECKPOINT, checkpoint_path=self._sv.save_path), current_step
         )
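
All of these snippets revolve around the same helper: `training_util.global_step(sess, global_step_tensor)` evaluates the global step tensor in the given session and returns its value as a Python int. Below is a minimal sketch of that behaviour, assuming the public `tf.compat.v1` aliases; the variable names are illustrative, not taken from the example above.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Create (or fetch) the graph's global step variable and an op to advance it.
global_step = tf.train.get_or_create_global_step()
increment_step = tf.assign_add(global_step, 1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(increment_step)
    # Equivalent to the training_util.global_step(...) calls in these examples:
    # run the tensor and return its value as a plain Python integer.
    print(tf.train.global_step(sess, global_step))  # -> 1
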
Example 2
 def end(self, session):
   if self._summary_op is not None:
     global_step = training_util.global_step(session, self._global_step)
     summary_str = session.run(self._summary_op, self._feed_dict)
     if self._summary_writer:
       self._summary_writer.add_summary(summary_str, global_step)
   if self._summary_writer:
     self._summary_writer.flush()
Example 3
  def save(self, session=None, checkpoint_number=None):
    """Creates a new checkpoint and manages it.

    Args:
      session: The session to evaluate variables in. Ignored when executing
        eagerly. If not provided when graph building, the default session is
        used.
      checkpoint_number: An optional integer, or an integer-dtype `Variable` or
        `Tensor`, used to number the checkpoint. If `None` (default),
        checkpoints are numbered using `checkpoint.save_counter`. Even if
        `checkpoint_number` is provided, `save_counter` is still incremented. A
        user-provided `checkpoint_number` is not incremented even if it is a
        `Variable`.

    Returns:
      The path to the new checkpoint. It is also recorded in the `checkpoints`
      and `latest_checkpoint` properties.
    """
    # Save counter logic duplicated from tf.train.Checkpoint, soon to diverge
    # slightly with a custom numbering option.
    if context.executing_eagerly():
      save_counter = self._checkpoint.save_counter
      save_counter.assign_add(1)
    else:
      if session is None:
        session = ops.get_default_session()

      def _initializing_creator(next_creator, **kwargs):
        """Initialize the save counter if it has been newly created."""
        v = next_creator(**kwargs)
        session.run(v.initializer)
        return v

      with variable_scope.variable_creator_scope(_initializing_creator):
        save_counter = self._checkpoint.save_counter
      if self._save_counter_assign is None:
        self._save_counter_assign = save_counter.assign_add(1, read_value=False)
      session.run(self._save_counter_assign)
    if checkpoint_number is None:
      checkpoint_number = save_counter
    if not isinstance(checkpoint_number, compat.integral_types):
      checkpoint_number = training_util.global_step(
          sess=session, global_step_tensor=checkpoint_number)
    prefix = "%s-%d" % (self._prefix, checkpoint_number)
    save_path = self._checkpoint.write(prefix)
    timestamp = time.time()
    # If this is an overwritten checkpoint we were previously tracking, delete
    # and reinsert it to make sure it goes to the end of the queue.
    if save_path in self._maybe_delete:
      del self._maybe_delete[save_path]
    self._maybe_delete[save_path] = timestamp
    self._latest_checkpoint = save_path
    self._sweep()
    self._record_state()
    return save_path
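
For reference, here is a hedged TF2-style usage sketch of the public API that the method above implements, `tf.train.CheckpointManager.save()`; the directory path and variable names below are illustrative.

import tensorflow as tf

step = tf.Variable(0, dtype=tf.int64)
ckpt = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(ckpt, directory="/tmp/ckpts", max_to_keep=3)

for _ in range(3):
    step.assign_add(1)
    # By default checkpoints are numbered with checkpoint.save_counter;
    # pass checkpoint_number=step to number them by training step instead.
    path = manager.save()
    print(path, manager.latest_checkpoint)
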
Example 4
  def save(self, sess, save_path, global_step=None, latest_filename=None):
    """Saves variables.

    This method runs the ops added by the constructor for saving variables.
    It requires a session in which the graph was launched.  The variables to
    save must also have been initialized.

    The method returns the path of the newly created checkpoint file.  This
    path can be passed directly to a call to `restore()`.

    Args:
      sess: A Session to use to save the variables.
      save_path: String.  Path to the checkpoint filename.  If the saver is
        `sharded`, this is the prefix of the sharded checkpoint filename.
      global_step: If provided the global step number is appended to
        `save_path` to create the checkpoint filename. The optional argument
        can be a `Tensor`, a `Tensor` name or an integer.
      latest_filename: Optional name for the protocol buffer file that will
        contain the list of most recent checkpoint filenames.  That file,
        kept in the same directory as the checkpoint files, is automatically
        managed by the saver to keep track of recent checkpoints.  Defaults to
        'checkpoint'.

    Returns:
      A string: path at which the variables were saved.  If the saver is
        sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
        is the number of shards created.

    Raises:
      TypeError: If `sess` is not a `Session`.
      ValueError: If `latest_filename` contains path components.
    """
    if latest_filename is None:
      latest_filename = "checkpoint"

    if os.path.split(latest_filename)[0]:
      raise ValueError("'latest_filename' must not contain path components")

    if global_step is not None:
      if not isinstance(global_step, compat.integral_types):
        global_step = training_util.global_step(sess, global_step)
      checkpoint_file = "%s-%d" % (save_path, global_step)
    else:
      checkpoint_file = save_path
    save_path = os.path.dirname(save_path)
    if not isinstance(sess, session.SessionInterface):
      raise TypeError("'sess' must be a Session; %s" % sess)

    model_checkpoint_path = sess.run(
        self._save_tensor_name, {self._filename_tensor_name: checkpoint_file})
    model_checkpoint_path = compat.as_str(model_checkpoint_path)
    self._MaybeDeleteOldCheckpoints(model_checkpoint_path)
    update_checkpoint_state(save_path, model_checkpoint_path,
                            self.last_checkpoints, latest_filename)
    return model_checkpoint_path
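
A minimal usage sketch of `Saver.save()` with a `global_step` argument (TF1-style graph mode via `tf.compat.v1` is assumed; the checkpoint path is illustrative): passing `global_step` appends the step number to the checkpoint filename.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

weights = tf.Variable(tf.zeros([2, 2]), name="weights")
global_step = tf.train.get_or_create_global_step()
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.assign(global_step, 100))
    # The step value is appended to the path, e.g. /tmp/model.ckpt-100.
    print(saver.save(sess, "/tmp/model.ckpt", global_step=global_step))
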
Example 5
  def start_standard_services(self, sess):
    """Start the standard services for 'sess'.

    This starts services in the background.  The services started depend
    on the parameters to the constructor and may include:

      - A Summary thread computing summaries every save_summaries_secs.
      - A Checkpoint thread saving the model every save_model_secs.
      - A StepCounter thread measuring step time.

    Args:
      sess: A Session.

    Returns:
      A list of threads that are running the standard services.  You can use
      the Supervisor's Coordinator to join these threads with:
        sv.coord.Join(<list of threads>)

    Raises:
      RuntimeError: If called with a non-chief Supervisor.
      ValueError: If no `logdir` was passed to the constructor, as the
        services need a log directory.
    """
    if not self._is_chief:
      raise RuntimeError("Only chief supervisor can start standard services. "
                         "Because only chief supervisors can write events.")

    if not self._logdir:
      logging.warning("Standard services need a 'logdir' "
                      "passed to the SessionManager")
      return

    if self._global_step is not None and self._summary_writer:
      # Only add the session log if we keep track of global step.
      # TensorBoard cannot use START message for purging expired events
      # if there is no step value.
      current_step = training_util.global_step(sess, self._global_step)
      self._summary_writer.add_session_log(
          SessionLog(status=SessionLog.START),
          current_step)

    threads = []
    if self._save_summaries_secs and self._summary_writer:
      if self._summary_op is not None:
        threads.append(SVSummaryThread(self, sess))
      if self._global_step is not None:
        threads.append(SVStepCounterThread(self, sess))
    if self.saver and self._save_model_secs:
      threads.append(SVTimerCheckpointThread(self, sess))
    for t in threads:
      t.start()
    self._started_threads.extend(threads)

    return threads
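
For context, a hedged sketch of how these standard services are normally started (the now-deprecated `tf.compat.v1.train.Supervisor` API is assumed; the log directory and training loop are illustrative): `managed_session()` on the chief starts the summary, checkpoint and step-counter threads described above whenever a `logdir` is set.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

global_step = tf.train.get_or_create_global_step()
train_op = tf.assign_add(global_step, 1)

sv = tf.train.Supervisor(logdir="/tmp/sv_logs",
                         save_model_secs=60,
                         save_summaries_secs=30)
with sv.managed_session() as sess:
    while not sv.should_stop():
        if sess.run(train_op) >= 10:
            sv.request_stop()
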
Example 6
def _wait_for_step(sess, global_step, step):
    """Wait till the global step has reached at least 'step'.

    Args:
      sess: A session.
      global_step: A Tensor.
      step: Int.  The global step to reach.
    """
    while True:
        if training_util.global_step(sess, global_step) >= step:
            break
        time.sleep(1.0)
Example 7
 def run_loop(self):
     # Count the steps.
     current_step = training_util.global_step(self._sess, self._sv.global_step)
     added_steps = current_step - self._last_step
     self._last_step = current_step
     # Measure the elapsed time.
     current_time = time.time()
     elapsed_time = current_time - self._last_time
     self._last_time = current_time
     # Reports the number of steps done per second
     steps_per_sec = added_steps / elapsed_time
     summary = Summary(value=[Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec)])
     if self._sv.summary_writer:
         self._sv.summary_writer.add_summary(summary, current_step)
     logging.log_first_n(logging.INFO, "%s: %g", 10, self._summary_tag, steps_per_sec)
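
The same steps-per-second bookkeeping, written as a standalone helper for clarity (a hedged sketch; the function name and arguments are illustrative, and `writer` is assumed to be a `tf.compat.v1.summary.FileWriter`):

import time

from tensorflow.core.framework.summary_pb2 import Summary

def log_steps_per_sec(writer, current_step, last_step, last_time,
                      tag="global_step/sec"):
    """Write a steps-per-second scalar summary and return the new bookmarks."""
    now = time.time()
    steps_per_sec = (current_step - last_step) / max(now - last_time, 1e-6)
    writer.add_summary(
        Summary(value=[Summary.Value(tag=tag, simple_value=steps_per_sec)]),
        current_step)
    return current_step, now
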
Example 8
  def summary_computed(self, sess, summary, global_step=None):
    """Indicate that a summary was computed.

    Args:
      sess: A `Session` object.
      summary: A Summary proto, or a string holding a serialized summary proto.
      global_step: Int. The global step this summary is associated with. If
        `None`, it will try to fetch the current step.

    Raises:
      TypeError: if 'summary' is not a Summary proto or a string.
      RuntimeError: if the Supervisor was created without a `logdir`.
    """
    if not self._summary_writer:
      raise RuntimeError("Writing a summary requires a summary writer.")
    if global_step is None and self.global_step is not None:
      global_step = training_util.global_step(sess, self.global_step)
    self._summary_writer.add_summary(summary, global_step)
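
A hedged usage sketch of `summary_computed()` (the deprecated `tf.compat.v1.train.Supervisor` API is assumed; the loss tensor and log directory are illustrative): when the training loop fetches a summary itself, it hands the serialized proto to the supervisor, which attaches the current global step before writing it.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

global_step = tf.train.get_or_create_global_step()
train_op = tf.assign_add(global_step, 1)
loss = tf.random.uniform([])
summary_op = tf.summary.scalar("loss", loss)

# summary_op=None disables the supervisor's own summary thread, so the loop
# below is the only writer of summaries.
sv = tf.train.Supervisor(logdir="/tmp/sv_logs", summary_op=None)
with sv.managed_session() as sess:
    for _ in range(5):
        _, serialized = sess.run([train_op, summary_op])
        sv.summary_computed(sess, serialized)
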
Example 9
  def export(self,
             export_dir_base,
             global_step_tensor,
             sess=None,
             exports_to_keep=None):
    """Exports the model.

    Args:
      export_dir_base: A string path to the base export dir.
      global_step_tensor: A `Tensor` or tensor name providing the
        global step counter to append to the export directory path and set
        in the manifest version.
      sess: A Session to use to save the parameters.
      exports_to_keep: A `gc.Path` filter function used to determine the set of
        exports to keep. If set to `None`, all versions will be kept.

    Raises:
      RuntimeError: if init is not called.
      RuntimeError: if the export would overwrite an existing directory.
    """
    if not self._has_init:
      raise RuntimeError("init must be called first")

    global_step = training_util.global_step(sess, global_step_tensor)
    export_dir = os.path.join(export_dir_base,
                              VERSION_FORMAT_SPECIFIER % global_step)

    # Prevent overwriting on existing exports which could lead to bad/corrupt
    # storage and loading of models. This is an important check that must be
    # done before any output files or directories are created.
    if gfile.Exists(export_dir):
      raise RuntimeError("Overwriting exports can cause corruption and are "
                         "not allowed. Duplicate export dir: %s" % export_dir)

    # Output to a temporary directory which is atomically renamed to the final
    # directory when complete.
    tmp_export_dir = export_dir + "-tmp"
    gfile.MakeDirs(tmp_export_dir)

    self._saver.save(sess,
                     os.path.join(tmp_export_dir, EXPORT_BASE_NAME),
                     meta_graph_suffix=EXPORT_SUFFIX_NAME)

    # Run the asset callback.
    if self._assets_callback:
      assets_dir = os.path.join(tmp_export_dir, ASSETS_DIRECTORY)
      gfile.MakeDirs(assets_dir)
      self._assets_callback(assets_dir)

    # TODO(b/27794910): Delete *checkpoint* file before rename.
    gfile.Rename(tmp_export_dir, export_dir)

    if exports_to_keep:
      # create a simple parser that pulls the export_version from the directory.
      def parser(path):
        match = re.match("^" + export_dir_base + "/(\\d{8})$", path.path)
        if not match:
          return None
        return path._replace(export_version=int(match.group(1)))

      paths_to_delete = gc.negation(exports_to_keep)
      for p in paths_to_delete(gc.get_paths(export_dir_base, parser=parser)):
        gfile.DeleteRecursively(p.path)
Example 10
def evaluation(sess,
               num_evals=1,
               initial_op=None,
               initial_op_feed_dict=None,
               eval_op=None,
               eval_op_feed_dict=None,
               final_op=None,
               final_op_feed_dict=None,
               summary_op=None,
               summary_op_feed_dict=None,
               summary_writer=None,
               global_step=None):
  """Performs a single evaluation run.

  A single evaluation consists of several steps run in the following order:
  (1) an initialization op, (2) an evaluation op which is executed `num_evals`
  times (3) a finalization op and (4) the execution of a summary op which is
  written out using a summary writer.

  Args:
    sess: The current TensorFlow `Session`.
    num_evals: The number of times to execute `eval_op`.
    initial_op: An operation run at the beginning of evaluation.
    initial_op_feed_dict: A feed dictionary to use when executing `initial_op`.
    eval_op: An operation run `num_evals` times.
    eval_op_feed_dict: The feed dictionary to use when executing the `eval_op`.
    final_op: An operation to execute after all of the `eval_op` executions. The
      value of `final_op` is returned.
    final_op_feed_dict: A feed dictionary to use when executing `final_op`.
    summary_op: A summary op executed after `eval_op` and `final_op`.
    summary_op_feed_dict: An optional feed dictionary to use when executing the
      `summary_op`.
    summary_writer: The summary writer used if `summary_op` is provided.
    global_step: The global step variable. If left as `None`, then
      `slim.variables.global_step()` is used.

  Returns:
    The value of `final_op` or `None` if `final_op` is `None`.

  Raises:
    ValueError: if `summary_op` is provided but `global_step` is `None`.
  """
  if initial_op is not None:
    logging.info('Executing initial eval op')
    sess.run(initial_op, initial_op_feed_dict)

  if eval_op is not None:
    logging.info('Executing eval ops')
    for i in range(int(num_evals)):
      logging.info('Executing eval_op %d/%d', i + 1, num_evals)
      sess.run(eval_op, eval_op_feed_dict)

  if final_op is not None:
    logging.info('Executing final op')
    final_op_value = sess.run(final_op, final_op_feed_dict)
  else:
    final_op_value = None

  if summary_op is not None:
    logging.info('Executing summary op')
    if global_step is None:
      global_step = variables.get_or_create_global_step()

    global_step = training_util.global_step(sess, global_step)
    summary = sess.run(summary_op, summary_op_feed_dict)
    summary_writer.add_summary(summary, global_step)
    summary_writer.flush()

  return final_op_value
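
A hedged usage sketch of the `evaluation()` helper above (it assumes that function is importable in scope; the metric, summary tag, and log directory are illustrative): the summary written at the end is tagged with the current global step, so repeated evaluation runs line up on the TensorBoard x-axis.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

labels = tf.constant([1, 0, 1, 1])
predictions = tf.constant([1, 0, 0, 1])
accuracy, update_op = tf.metrics.accuracy(labels, predictions)
summary_op = tf.summary.scalar("eval/accuracy", accuracy)
global_step = tf.train.get_or_create_global_step()
writer = tf.summary.FileWriter("/tmp/eval_logs")

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(),
              tf.local_variables_initializer()])
    result = evaluation(sess,
                        num_evals=1,
                        eval_op=update_op,
                        final_op=accuracy,
                        summary_op=summary_op,
                        summary_writer=writer,
                        global_step=global_step)
    print('final accuracy:', result)
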
Example 11
 def start_loop(self):
   self._last_time = time.time()
   self._last_step = training_util.global_step(
       self._sess, self._sv.global_step)
Example 12
  def export(self,
             export_dir_base,
             global_step_tensor,
             sess=None,
             exports_to_keep=None):
    """Exports the model.

    Args:
      export_dir_base: A string path to the base export dir.
      global_step_tensor: A `Tensor` or tensor name providing the
        global step counter to append to the export directory path and set
        in the manifest version.
      sess: A Session to use to save the parameters.
      exports_to_keep: A `gc.Path` filter function used to determine the set of
        exports to keep. If set to `None`, all versions will be kept.

    Returns:
      The string path to the exported directory.

    Raises:
      RuntimeError: if init is not called.
      RuntimeError: if the export would overwrite an existing directory.
    """
    if not self._has_init:
      raise RuntimeError("init must be called first")

    # Export dir must not end with / or it will break exports to keep. Strip /.
    if export_dir_base.endswith("/"):
      export_dir_base = export_dir_base[:-1]

    global_step = training_util.global_step(sess, global_step_tensor)
    export_dir = os.path.join(
        compat.as_bytes(export_dir_base),
        compat.as_bytes(constants.VERSION_FORMAT_SPECIFIER % global_step))

    # Prevent overwriting on existing exports which could lead to bad/corrupt
    # storage and loading of models. This is an important check that must be
    # done before any output files or directories are created.
    if gfile.Exists(export_dir):
      raise RuntimeError("Overwriting exports can cause corruption and are "
                         "not allowed. Duplicate export dir: %s" % export_dir)

    # Output to a temporary directory which is atomically renamed to the final
    # directory when complete.
    tmp_export_dir = compat.as_text(export_dir) + "-tmp"
    gfile.MakeDirs(tmp_export_dir)

    self._saver.save(sess,
                     os.path.join(
                         compat.as_text(tmp_export_dir),
                         compat.as_text(constants.EXPORT_BASE_NAME)),
                     meta_graph_suffix=constants.EXPORT_SUFFIX_NAME)

    # Run the asset callback.
    if self._assets_callback and self._assets_to_copy:
      assets_dir = os.path.join(
          compat.as_bytes(tmp_export_dir),
          compat.as_bytes(constants.ASSETS_DIRECTORY))
      gfile.MakeDirs(assets_dir)
      self._assets_callback(self._assets_to_copy, assets_dir)

    # TODO(b/27794910): Delete *checkpoint* file before rename.
    gfile.Rename(tmp_export_dir, export_dir)

    if exports_to_keep:
      # create a simple parser that pulls the export_version from the directory.
      def parser(path):
        match = re.match("^" + export_dir_base + "/(\\d{8})$", path.path)
        if not match:
          return None
        return path._replace(export_version=int(match.group(1)))

      paths_to_delete = gc.negation(exports_to_keep)
      for p in paths_to_delete(gc.get_paths(export_dir_base, parser=parser)):
        gfile.DeleteRecursively(p.path)

    return export_dir
Example 13
def evaluation(sess,
               num_evals=1,
               initial_op=None,
               initial_op_feed_dict=None,
               eval_op=None,
               eval_op_feed_dict=None,
               final_op=None,
               final_op_feed_dict=None,
               summary_op=None,
               summary_op_feed_dict=None,
               summary_writer=None,
               global_step=None,
               cm=None):
    """Performs a single evaluation run.
  A single evaluation consists of several steps run in the following order:
  (1) an initialization op, (2) an evaluation op which is executed `num_evals`
  times (3) a finalization op and (4) the execution of a summary op which is
  written out using a summary writer.
  Args:
    sess: The current TensorFlow `Session`.
    num_evals: The number of times to execute `eval_op`.
    initial_op: An operation run at the beginning of evaluation.
    initial_op_feed_dict: A feed dictionary to use when executing `initial_op`.
    eval_op: An operation run `num_evals` times.
    eval_op_feed_dict: The feed dictionary to use when executing the `eval_op`.
    final_op: An operation to execute after all of the `eval_op` executions. The
      value of `final_op` is returned.
    final_op_feed_dict: A feed dictionary to use when executing `final_op`.
    summary_op: A summary op executed after `eval_op` and `final_op`.
    summary_op_feed_dict: An optional feed dictionary to use when executing the
      `summary_op`.
    summary_writer: The summary writer used if `summary_op` is provided.
    global_step: The global step variable. If left as `None`, then
      `slim.variables.global_step()` is used.
  Returns:
    The value of `final_op` or `None` if `final_op` is `None`.
  Raises:
    ValueError: if `summary_op` is provided but `global_step` is `None`.
  """

    #big_confusion_matrix = np.zeros([10,10])
    accurancy = 0

    if initial_op is not None:
        logging.info('Executing initial eval op')
        sess.run(initial_op, initial_op_feed_dict)

    if eval_op is not None:
        logging.info('Executing eval ops')
        for i in range(int(num_evals)):
            logging.info('Executing eval_op %d/%d', i + 1, num_evals)
            accurancy += float(sess.run(eval_op, eval_op_feed_dict)[1])
            #confusion_matrix = cm[0]
            #big_confusion_matrix = big_confusion_matrix + np.array(sess.run(confusion_matrix))


#      big_labels.append(np.array(sess.run(cm[1])))
#      big_predictions.append(np.array(sess.run(cm[2])))
#      big_logits.append(np.array(sess.run(cm[3])))
#logging.info(sess.run(cm))

# with open("confusion_matrix.txt", "a") as myfile:
#  myfile.write(str(big_confusion_matrix.flatten()))
#   myfile.write('\n')
#  myfile.close()

#  with open("labels.txt", "a") as myfile:
#    myfile.write(str(big_labels))
#    myfile.write('\n')
#  myfile.close()

#  with open("predictions.txt", "a") as myfile:
#    myfile.write(str(big_predictions))
#    myfile.write('\n')
#  myfile.close()

#  with open("logits.txt", "a") as myfile:
#    myfile.write(str(big_logits))
#    myfile.write('\n')
#  myfile.close()

    if final_op is not None:
        logging.info('Executing final op')
        final_op_value = sess.run(final_op, final_op_feed_dict)
    else:
        final_op_value = None

    if summary_op is not None:
        logging.info('Executing summary op')
        if global_step is None:
            global_step = variables.get_or_create_global_step()

        global_step = training_util.global_step(sess, global_step)
        summary_str = sess.run(summary_op, summary_op_feed_dict)
        accurancy = accurancy / float(int(num_evals))
        accfile = open('accurancies.txt', 'a')
        accfile.write(str(accurancy))
        accfile.write(',')
        accfile.close()
        #np.save('confusion_matrix',big_confusion_matrix)
        summary_writer.add_summary(summary_str, global_step)
        summary_writer.flush()

    return final_op_value
Example 14
 def end(self, session):
     global_step = training_util.global_step(session, self._global_step)
     summary_str = session.run(self._summary_op, self._feed_dict)
     self._summary_writer.add_summary(summary_str, global_step)
     self._summary_writer.flush()
Example 15
def evaluation(sess,
               num_evals=1,
               init_op=None,
               init_op_feed_dict=None,
               eval_op=None,
               eval_op_feed_dict=None,
               final_op=None,
               final_op_feed_dict=None,
               summary_op=None,
               summary_op_feed_dict=None,
               summary_writer=None,
               global_step=None):
  """Performs a single evaluation run.

  A single evaluation consists of several steps run in the following order:
  (1) an initialization op, (2) an evaluation op which is executed `num_evals`
  times (3) a finalization op and (4) the execution of a summary op which is
  written out using a summary writer.

  Args:
    sess: The current TensorFlow `Session`.
    num_evals: The number of times to execute `eval_op`.
    init_op: An operation run at the beginning of evaluation.
    init_op_feed_dict: A feed dictionary to use when executing `init_op`.
    eval_op: An operation run `num_evals` times.
    eval_op_feed_dict: The feed dictionary to use when executing the `eval_op`.
    final_op: An operation to execute after all of the `eval_op` executions. The
      value of `final_op` is returned.
    final_op_feed_dict: A feed dictionary to use when executing `final_op`.
    summary_op: A summary op executed after `eval_op` and `final_op`.
    summary_op_feed_dict: An optional feed dictionary to use when executing the
      `summary_op`.
    summary_writer: The summary writer used if `summary_op` is provided.
    global_step: The global step variable. If left as `None`, then
      `slim.variables.global_step()` is used.

  Returns:
    The value of `final_op` or `None` if `final_op` is `None`.

  Raises:
    ValueError: if `summary_op` is provided but `global_step` is `None`.
  """
  if init_op is not None:
    logging.info('Executing init op')
    sess.run(init_op, init_op_feed_dict)

  if eval_op is not None:
    logging.info('Executing eval ops')
    for i in range(int(num_evals)):
      logging.info('Executing eval_op %d/%d', i + 1, num_evals)
      sess.run(eval_op, eval_op_feed_dict)

  if final_op is not None:
    logging.info('Executing final op')
    final_op_value = sess.run(final_op, final_op_feed_dict)
  else:
    final_op_value = None

  if summary_op is not None:
    logging.info('Executing summary op')
    if global_step is None:
      global_step = variables.get_or_create_global_step()

    global_step = training_util.global_step(sess, global_step)
    summary = sess.run(summary_op, summary_op_feed_dict)
    summary_writer.add_summary(summary, global_step)
    summary_writer.flush()

  return final_op_value
Example 16
 def start_loop(self):
     self._last_time = time.time()
     self._last_step = training_util.global_step(self._sess,
                                                 self._step_counter)
Example 17
    def save(self, checkpoint_number=None, check_interval=True, options=None):
        """Creates a new checkpoint and manages it.

    Args:
      checkpoint_number: An optional integer, or an integer-dtype `Variable` or
        `Tensor`, used to number the checkpoint. If `None` (default),
        checkpoints are numbered using `checkpoint.save_counter`. Even if
        `checkpoint_number` is provided, `save_counter` is still incremented. A
        user-provided `checkpoint_number` is not incremented even if it is a
        `Variable`.
      check_interval: An optional boolean. The argument is only effective when
        `checkpoint_interval` is passed into the manager. If `True`, the manager
        will only save the checkpoint if the interval between checkpoints is
        larger than `checkpoint_interval`. Otherwise it will always save the
        checkpoint unless a checkpoint has already been saved for the current
        step.
      options: Optional `tf.train.CheckpointOptions` object. This argument only
        works with TF2 checkpoint objects. For example, options =
        tf.train.CheckpointOptions(experimental_io_device='/job:localhost')

    Returns:
      The path to the new checkpoint. It is also recorded in the `checkpoints`
      and `latest_checkpoint` properties. `None` if no checkpoint is saved.
    """
        if self._checkpoint_interval is not None:
            current_step = _evaluate(self._step_counter)
            if self._last_checkpoint_step is not None:
                if current_step == self._last_checkpoint_step:
                    return None
                if check_interval and current_step < (
                        self._last_checkpoint_step +
                        self._checkpoint_interval):
                    return None
            self._last_checkpoint_step = current_step

        # Save counter logic duplicated from tf.train.Checkpoint, soon to diverge
        # slightly with a custom numbering option.
        if context.executing_eagerly():
            save_counter = self._checkpoint.save_counter
            save_counter.assign_add(1)
            session = None
        else:
            session = ops.get_default_session()

            def _initializing_creator(next_creator, **kwargs):
                """Initialize the save counter if it has been newly created."""
                v = next_creator(**kwargs)
                session.run(v.initializer)
                return v

            with variable_scope.variable_creator_scope(_initializing_creator):
                save_counter = self._checkpoint.save_counter
            if self._save_counter_assign is None:
                self._save_counter_assign = save_counter.assign_add(
                    1, read_value=False)
            session.run(self._save_counter_assign)
        if checkpoint_number is None:
            checkpoint_number = save_counter
        if not isinstance(checkpoint_number, compat.integral_types):
            checkpoint_number = training_util.global_step(
                sess=session, global_step_tensor=checkpoint_number)
        prefix = "%s-%d" % (self._prefix, checkpoint_number)

        def _record_and_sweep_state(save_path):
            timestamp = time.time()
            # If this is an overwritten checkpoint we were previously tracking, delete
            # and reinsert it to make sure it goes to the end of the queue.
            if save_path in self._maybe_delete:
                del self._maybe_delete[save_path]
            self._maybe_delete[save_path] = timestamp
            self._latest_checkpoint = save_path
            # Before deleting anything we update the Checkpoint proto with the new
            # checkpoint. We'll go back and correct it after cleaning up old files,
            # but a preemption while deleting will be more likely to see the new
            # checkpoint this way.
            self._record_state()
            self._sweep()
            # Write out the Checkpoint proto a second time, now without the deleted
            # checkpoints.
            self._record_state()

        if options is None:
            save_path = self._checkpoint._write(  # pylint: disable=protected-access
                prefix,
                write_done_callback=_record_and_sweep_state)
        else:
            save_path = self._checkpoint._write(  # pylint: disable=protected-access
                prefix,
                options=options,
                write_done_callback=_record_and_sweep_state)

        return save_path
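
A hedged TF2 usage sketch of the `checkpoint_interval` behaviour implemented above (the directory and variable names are illustrative): when `checkpoint_interval` is set on the manager, `save()` returns `None` unless the step counter has advanced by at least that many steps since the last saved checkpoint, or `check_interval=False` is passed.

import tensorflow as tf

step = tf.Variable(0, dtype=tf.int64)
ckpt = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(
    ckpt, directory="/tmp/interval_ckpts", max_to_keep=2,
    step_counter=step, checkpoint_interval=5)

for _ in range(12):
    step.assign_add(1)
    path = manager.save(checkpoint_number=step)
    if path is not None:  # None whenever fewer than 5 steps have elapsed
        print(int(step.numpy()), path)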