예제 #1
0
def _get_checkpoint_filename(ckpt_dir_or_file):
  """Resolves a directory-or-file argument to a checkpoint path.

  Args:
    ckpt_dir_or_file: A string or `os.PathLike` naming either a directory that
      holds checkpoints or a specific checkpoint file/prefix.

  Returns:
    For a directory, whatever `checkpoint_management.latest_checkpoint`
    reports for it; otherwise the argument itself (passed through
    `os.fspath` first when it is path-like).
  """
  path = (os.fspath(ckpt_dir_or_file)
          if isinstance(ckpt_dir_or_file, os.PathLike) else ckpt_dir_or_file)
  if not gfile.IsDirectory(path):
    return path
  return checkpoint_management.latest_checkpoint(path)
예제 #2
0
def wait_for_new_checkpoint(checkpoint_dir,
                            last_checkpoint=None,
                            seconds_to_sleep=1,
                            timeout=None):
  """Polls `checkpoint_dir` until a checkpoint other than `last_checkpoint` shows up.

  Args:
    checkpoint_dir: Directory whose latest checkpoint is polled.
    last_checkpoint: Path of the checkpoint already seen, or `None` when no
      checkpoint has been seen yet. Only a different, non-`None` path counts
      as new.
    seconds_to_sleep: Seconds to sleep between consecutive polls.
    timeout: Maximum number of seconds to wait, or `None` to wait forever.

  Returns:
    The path of the new checkpoint, or `None` once sleeping again would
    overrun the timeout.
  """
  logging.info("Waiting for new checkpoint at %s", checkpoint_dir)
  deadline = None
  if timeout is not None:
    deadline = time.time() + timeout

  while True:
    candidate = checkpoint_management.latest_checkpoint(checkpoint_dir)
    if candidate is not None and candidate != last_checkpoint:
      logging.info("Found new checkpoint at %s", candidate)
      return candidate
    # Give up instead of sleeping past the deadline.
    if deadline is not None and time.time() + seconds_to_sleep > deadline:
      return None
    time.sleep(seconds_to_sleep)
  def test_spmd_model_checkpointing(self):
    """Checks save/restore of a model run under an SPMD-enabled TPU strategy.

    The model is saved while holding `w1`, then overwritten with `w2` and
    checked, then restored from the checkpoint and checked against `w1`.
    """

    # Minimal module computing matmul(x, w) with a single variable `w`.
    class LinearModel(module.Module):

      def __init__(self, w):
        super(LinearModel, self).__init__()
        self.w = variables.Variable(w)

      def __call__(self, x):
        return math_ops.matmul(x, self.w)

      def change_weights_op(self, w_new):
        return self.w.assign(w_new)

    batch_size = 32
    num_feature_in = 16
    num_feature_out = 8
    # Shapes: w1/w2 are (num_feature_in, num_feature_out) and x is
    # (batch_size, num_feature_in).
    # NOTE(review): these are unseeded random tensors. The comparisons below
    # are only consistent if values are fixed at creation (eager execution);
    # in graph mode each evaluation would presumably re-sample them — confirm
    # the execution mode this test runs under.
    w1 = random_ops.random_uniform((num_feature_in, num_feature_out),
                                   dtype=dtypes.float32)
    w2 = random_ops.random_uniform((num_feature_in, num_feature_out),
                                   dtype=dtypes.float32)
    x = random_ops.random_uniform((batch_size, num_feature_in),
                                  dtype=dtypes.float32)

    strategy, num_replicas = get_tpu_strategy(enable_spmd=True)
    with strategy.scope():
      model = LinearModel(w1)

    checkpoint_dir = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = util.Checkpoint(model=model)

    @def_function.function
    def step_fn(x):
      # Partition the input across logical devices before the forward pass.
      # [1, 2] presumably means "no split on dim 0, 2-way split on dim 1" —
      # confirm against experimental_split_to_logical_devices docs.
      x = strategy.experimental_split_to_logical_devices(x, [1, 2])
      return model(x)

    with self.cached_session() as sess:
      self.evaluate(variables.global_variables_initializer())
      # Checkpoint the model while it still holds w1.
      checkpoint.save(file_prefix=checkpoint_prefix)

      # Overwrite the weights with w2; the forward pass must now reflect w2.
      self.evaluate(model.change_weights_op(w2))
      result = strategy.run(step_fn, args=(x,))
      # Expected value is the single-device product times num_replicas,
      # presumably because every replica receives the same x and the outputs
      # are SUM-reduced. Loose tolerances presumably absorb TPU matmul
      # precision — confirm.
      self.assertAllClose(
          math_ops.matmul(x, w2) * num_replicas,
          self.evaluate(strategy.reduce("SUM", result, axis=None)),
          rtol=5e-3,
          atol=5e-3)

      # Restore the checkpoint written above; the forward pass must reflect
      # w1 again.
      status = checkpoint.restore(
          checkpoint_management.latest_checkpoint(checkpoint_dir))
      status.run_restore_ops(sess)  # must run restore op in non-eager mode.
      status.assert_consumed()
      status.assert_existing_objects_matched()
      result = strategy.run(step_fn, args=(x,))
      self.assertAllClose(
          math_ops.matmul(x, w1) * num_replicas,
          self.evaluate(strategy.reduce("SUM", result, axis=None)),
          rtol=5e-3,
          atol=5e-3)
예제 #4
0
    def _restore_or_save_initial_ckpt(self, session):
        """Restores the latest checkpoint, or writes an initial one if none exists.

        Ideally this would run in `after_create_session()`, but the order in
        which `SessionRunHooks` run cannot be enforced. The
        `_DatasetInitializerHook` could therefore run *after* this hook, which
        breaks things in two ways:

        1. If a checkpoint exists and is restored here, the initializer hook
           would then override it.
        2. If no checkpoint exists, this hook would try to save an
           uninitialized iterator, which raises.

        As a temporary fix, this hook and the `_DatasetInitializerHook` follow
        an implicit contract:

        1. `_DatasetInitializerHook` initializes the iterator in
           `after_create_session()`.
        2. This hook saves the iterator on the first call to `before_run()`,
           which is guaranteed to happen after every hook's
           `after_create_session()` has run.

        Args:
          session: The session to restore into or to save from.
        """
        # pylint: disable=protected-access
        saver_hook = self._checkpoint_saver_hook
        latest_checkpoint_path = checkpoint_management.latest_checkpoint(
            saver_hook._checkpoint_dir,
            latest_filename=self._latest_filename)

        if not latest_checkpoint_path:
            # Nothing to restore yet: persist the state at step "global_step".
            # The GraphDef / MetaGraphDef are intentionally not saved here.
            global_step = session.run(saver_hook._global_step_tensor)
            saver_hook._save(session, global_step)
            saver_hook._timer.update_last_triggered_step(global_step)
            return

        saver_hook._get_saver().restore(session, latest_checkpoint_path)
    def testLatestCheckpointFSpathDirectory(self):
        """`latest_checkpoint` accepts a `pathlib.Path` directory argument."""
        temp_path = pathlib.Path(self.get_temp_dir())
        manager = checkpoint_management.CheckpointManager(
            util.Checkpoint(),
            temp_path,
            max_to_keep=2,
            checkpoint_name="ckpt_name",
        )
        manager.save()

        found = checkpoint_management.latest_checkpoint(temp_path)
        expected = str(temp_path / "ckpt_name-1")
        self.assertEqual(expected, found)
예제 #6
0
 def _read_vars(self, model_dir):
   """Restores the latest checkpoint in `model_dir` and reads 'my_vars'.

   The meta graph stored next to the latest checkpoint is imported into a
   fresh graph. The checkpoint is then restored into a new session, and the
   values of the tensors in the 'my_vars' graph collection are fetched.

   Args:
     model_dir: Directory containing the checkpoint and its `.meta` file.

   Returns:
     A list with the values of the 'my_vars' collection, in collection order.
     This is presumably [global_step, latest_feature], if that is the order
     in which the producing model adds them to the collection; confirm
     against the producer.

   Raises:
     AssertionError: If `model_dir` contains no checkpoint.
   """
   with ops.Graph().as_default() as g:
     ckpt_path = checkpoint_management.latest_checkpoint(model_dir)
     # Fail with a readable message instead of the opaque TypeError from
     # `None + '.meta'` below.
     self.assertIsNotNone(ckpt_path, 'No checkpoint found in %s' % model_dir)
     meta_filename = ckpt_path + '.meta'
     saver_lib.import_meta_graph(meta_filename)
     saver = saver_lib.Saver()
     with self.session(graph=g) as sess:
       saver.restore(sess, ckpt_path)
       return sess.run(ops.get_collection('my_vars'))
예제 #7
0
    def test_paritioned_model_checkpointing(self):
        """Round-trips a model whose two variables live on different logical devices.

        The model is saved with v=2, w=3. It is then overwritten with v=1, w=4
        and the forward pass is checked. Finally the checkpoint is restored and
        the forward pass is checked again.
        """

        class PartitionedModel(module.Module):
            """Computes `w * (v * x)` with `v` on logical device 0 and `w` on 1."""

            def __init__(self, v, w):
                super(PartitionedModel, self).__init__()

                assert distribution_strategy_context.has_strategy()
                extended = distribution_strategy_context.get_strategy().extended

                with extended.experimental_logical_device(0):
                    self.v = variables.Variable(v)
                with extended.experimental_logical_device(1):
                    self.w = variables.Variable(w)

            def __call__(self, x):
                ctx = distribution_strategy_context.get_replica_context()
                with ctx.experimental_logical_device(0):
                    scaled = self.v * x
                with ctx.experimental_logical_device(1):
                    out = self.w * scaled
                return out

            def change_weights_op(self, v_new, w_new):
                assign_v = self.v.assign(v_new)
                assign_w = self.w.assign(w_new)
                return control_flow_ops.group([assign_v, assign_w])

        strategy, num_replicas = get_tpu_strategy()
        with strategy.scope():
            model = PartitionedModel(2., 3.)

        def replica_sum(per_replica):
            # Collapse per-replica outputs into one summed, evaluated value.
            return self.evaluate(
                strategy.reduce("SUM", per_replica, axis=None))

        ckpt_dir = self.get_temp_dir()
        ckpt_prefix = os.path.join(ckpt_dir, "ckpt")
        checkpoint = util.Checkpoint(model=model)

        with self.cached_session() as sess:
            self.evaluate(variables.global_variables_initializer())
            checkpoint.save(file_prefix=ckpt_prefix)

            # Overwritten weights: 4 * (1 * 5) = 20 per replica.
            self.evaluate(model.change_weights_op(1., 4.))
            outputs = strategy.run(def_function.function(model), args=(5.0, ))
            self.assertEqual(20. * num_replicas, replica_sum(outputs))

            # Restored weights: 3 * (2 * 5) = 30 per replica.
            status = checkpoint.restore(
                checkpoint_management.latest_checkpoint(ckpt_dir))
            status.run_restore_ops(sess)  # must run restore op in non-eager mode.
            status.assert_consumed()
            status.assert_existing_objects_matched()
            outputs = strategy.run(def_function.function(model), args=(5.0, ))
            self.assertEqual(30. * num_replicas, replica_sum(outputs))
예제 #8
0
 def testRestoreInReconstructedIteratorInitializable(self):
     """Iterator position survives repeated restore/consume/save rounds.

     Each round restores the latest checkpoint, reads two elements and saves
     again. Initialization happens when there is no checkpoint, presumably
     only on the first round. The range dataset must therefore be consumed
     strictly in order across rounds.
     """
     ckpt_dir = self.get_temp_dir()
     ckpt_prefix = os.path.join(ckpt_dir, "ckpt")
     dataset = dataset_ops.Dataset.range(10)
     iterator = iter(dataset)
     checkpoint = trackable_utils.Checkpoint(iterator=iterator)
     for round_idx in range(5):
         latest = checkpoint_management.latest_checkpoint(ckpt_dir)
         checkpoint.restore(latest).initialize_or_restore()
         for offset in range(2):
             expected = round_idx * 2 + offset
             self.assertEqual(expected, self.evaluate(iterator.get_next()))
         checkpoint.save(file_prefix=ckpt_prefix)
    def testNameCollision(self):
        """Saving to ".../checkpoint" fails only when the file name would collide.

        "checkpoint" is the default name of the checkpoint state file. An
        unsharded save without a global step must therefore raise. Adding a
        step suffix or sharding yields a different file name, so those saves
        succeed.
        """
        # Work inside a clean temp dir that is also the working directory.
        with self.tempDir() as tempdir, self.tempWorkingDir(tempdir):
            train_dir = "train"  # Relative on purpose.
            os.mkdir(train_dir)
            colliding_path = os.path.join(train_dir, "checkpoint")

            def assert_latest_exists():
                self.assertIsNotNone(
                    checkpoint_management.latest_checkpoint(train_dir))

            with self.cached_session() as sess:
                _ = variables.Variable(0.0)  # Gives the Saver something to save.
                self.evaluate(variables.global_variables_initializer())

                # Unsharded, no step: the name collides and saving must fail.
                saver = saver_module.Saver(sharded=False)
                with self.assertRaisesRegex(ValueError, "collides with"):
                    saver.save(sess, colliding_path)

                # Written as "checkpoint-<step>".
                saver.save(sess, colliding_path, global_step=1)
                assert_latest_exists()

                # Written as "checkpoint-<i>-of-<n>".
                saver = saver_module.Saver(sharded=True)
                saver.save(sess, colliding_path)
                assert_latest_exists()

                # Written as "checkpoint-<step>-<i>-of-<n>".
                saver = saver_module.Saver(sharded=True)
                saver.save(sess, colliding_path, global_step=1)
                assert_latest_exists()
    def testRelativePath(self):
        """Checkpoints written under a relative directory can be found and restored."""
        # Work inside a clean temp dir that is also the working directory.
        with self.tempDir() as tempdir, self.tempWorkingDir(tempdir):
            train_dir = "train"
            os.mkdir(train_dir)
            snapshot_path = os.path.join(train_dir, "snapshot")

            # Phase 1: save at global steps 0, 1 and 2, with v0 equal to the
            # step each time.
            with self.cached_session() as sess:
                counter = variables.Variable(0.0)
                increment = counter.assign_add(1.0)
                saver = saver_module.Saver({"v0": counter})
                self.evaluate(variables.global_variables_initializer())
                for step in range(3):
                    if step > 0:
                        self.evaluate(increment)
                    saver.save(sess, snapshot_path, global_step=step)

            # Phase 2: a second variable initialized to -1.0 must pick up the
            # last saved value via the checkpoint state file.
            with self.cached_session() as sess:
                restored = variables.Variable(-1.0)
                saver = saver_module.Saver({"v0": restored})
                self.evaluate(variables.global_variables_initializer())

                latest = checkpoint_management.latest_checkpoint(train_dir)
                self.assertIsNotNone(latest)

                saver.restore(sess, latest)
                self.assertEqual(restored.eval(), 2.0)
    def testCheckpointExists(self):
        """`checkpoint_exists` is False before a save and True after it.

        Every combination of sharded/unsharded and V2/V1 format is covered.
        Both the prefix returned by `save` and the one reported by
        `latest_checkpoint` are checked.
        """
        combos = [(sharded, version)
                  for sharded in (False, True)
                  for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1)]
        for sharded, version in combos:
            with self.session(graph=ops_lib.Graph()) as sess:
                _ = variables.Variable(1.0, name="v")
                self.evaluate(variables.global_variables_initializer())
                saver = saver_module.Saver(
                    sharded=sharded, write_version=version)

                save_path = os.path.join(
                    self._base_dir, "%s-%s" % (sharded, version))
                # Nothing has been written to this path yet.
                self.assertFalse(
                    checkpoint_management.checkpoint_exists(save_path))

                saved_prefix = saver.save(sess, save_path)
                self.assertTrue(
                    checkpoint_management.checkpoint_exists(saved_prefix))

                latest_prefix = checkpoint_management.latest_checkpoint(
                    self._base_dir)
                self.assertTrue(
                    checkpoint_management.checkpoint_exists(latest_prefix))
 def testCustomNumbering(self):
     """`CheckpointManager.save` honors an explicit `checkpoint_number`.

     Both a variable and a plain int are accepted. With `max_to_keep=2`, only
     the two most recent paths remain in the checkpoint state.
     """
     ckpt_dir = self.get_temp_dir()
     step = variables.Variable(0, dtype=dtypes.int64)
     checkpoint = util.Checkpoint(step=step)
     manager = checkpoint_management.CheckpointManager(
         checkpoint, ckpt_dir, max_to_keep=2)
     self.evaluate(step.initializer)

     # Number the checkpoints via the variable, which advances by 2 per save.
     for expected_number in range(0, 10, 2):
         numbered_path = manager.save(checkpoint_number=step)
         expected_suffix = "-%d" % (expected_number, )
         if not numbered_path.endswith(expected_suffix):
             self.fail("%s should have suffix %s" %
                       (numbered_path, expected_suffix))
         self.evaluate(step.assign_add(2))
     # save_counter counts saves, independent of the custom numbering.
     self.assertEqual(5, self.evaluate(checkpoint.save_counter))

     # A plain Python integer works too.
     last_path = manager.save(checkpoint_number=32)
     self.assertIn("-32", last_path)
     self.assertEqual(last_path, manager.latest_checkpoint)
     self.assertEqual(last_path,
                      checkpoint_management.latest_checkpoint(ckpt_dir))

     # Only the two newest checkpoints are retained.
     state = checkpoint_management.get_checkpoint_state(ckpt_dir)
     self.assertEqual([numbered_path, last_path],
                      state.all_model_checkpoint_paths)
예제 #13
0
 def _latest_ckpt(self):
     """Returns the newest checkpoint path in this test's temp dir, or None."""
     temp_dir = self.get_temp_dir()
     return checkpoint_management.latest_checkpoint(temp_dir)