def test_latest_checkpoint(self):
    prefix = 'chkpnt_'
    # With no checkpoints on disk, nothing should be found.
    self.assertIsNone(
        checkpoint_utils.latest_checkpoint(self.get_temp_dir(), prefix))

    # Write five checkpoints; after each save, the most recently created
    # path must be reported as the latest.
    state = build_fake_state()
    for step in range(5):
      checkpoint_utils.save(
          state,
          os.path.join(self.get_temp_dir(), '{}{:03d}'.format(prefix, step)))
      found = checkpoint_utils.latest_checkpoint(self.get_temp_dir(), prefix)
      self.assertEndsWith(found, '{:03d}'.format(step), msg=found)

    # Remove checkpoints newest-first; the reported latest checkpoint must
    # step back by one each time.
    for step in range(4, 1, -1):
      tf.io.gfile.rmtree(
          os.path.join(self.get_temp_dir(), '{}{:03d}'.format(prefix, step)))
      found = checkpoint_utils.latest_checkpoint(self.get_temp_dir(), prefix)
      self.assertEndsWith(found, '{:03d}'.format(step - 1), msg=found)
  def test_save_and_load_roundtrip(self):
    # Saving a state and loading it back must reproduce it exactly.
    original = build_fake_state()
    target_dir = os.path.join(self.get_temp_dir(), 'ckpt_1')
    checkpoint_utils.save(original, target_dir)
    restored = checkpoint_utils.load(target_dir, original)
    self.assertEqual(original, restored)
# Exemplo n.º 3
    def test_load_latest_state(self):
        # Build a simple FedAvg process to obtain a realistic server state.
        server_optimizer_fn = functools.partial(
            tf.keras.optimizers.SGD, learning_rate=0.1, momentum=0.9)
        iterative_process = tff.learning.build_federated_averaging_process(
            models.model_fn, server_optimizer_fn=server_optimizer_fn)
        server_state = iterative_process.initialize()

        # Save two checkpoints with increasing round numbers.
        # TODO(b/130724878): These conversions should not be needed.
        obj_1 = Obj.from_anon_tuple(server_state, 1)
        checkpoint_utils.save(
            obj_1, os.path.join(self.get_temp_dir(), 'ckpt_1'))
        # TODO(b/130724878): These conversions should not be needed.
        obj_2 = Obj.from_anon_tuple(server_state, 2)
        checkpoint_utils.save(
            obj_2, os.path.join(self.get_temp_dir(), 'ckpt_2'))

        # The latest checkpoint must be the second one, and loading it must
        # round-trip the values saved in obj_2.
        latest_dir = checkpoint_utils.latest_checkpoint(self.get_temp_dir())
        loaded = checkpoint_utils.load(latest_dir, obj_1)
        self.assertEqual(os.path.join(self.get_temp_dir(), 'ckpt_2'),
                         latest_dir)
        self.assertAllClose(tf.nest.flatten(obj_2), tf.nest.flatten(loaded))
# Exemplo n.º 4
def write_checkpoint(root_checkpoint_dir, server_state, round_num):
  """Write the current experiment state to disk."""
  # Checkpointing is optional: a None directory disables it entirely.
  if root_checkpoint_dir is None:
    return

  if tf.io.gfile.exists(root_checkpoint_dir):
    # Prune intermediate checkpoints, keeping only the first (which captures
    # the random model initialization) and the most recent one.
    existing = sorted(
        tf.io.gfile.glob(
            os.path.join(root_checkpoint_dir, CHECKPOINT_PREFIX + '*')))
    to_remove = existing[1:-1]
    logging.info('Cleaning up %s', to_remove)
    for stale in to_remove:
      tf.io.gfile.rmtree(stale)

  # Persist the new state under a zero-padded round-numbered directory so
  # lexicographic sorting of checkpoint names matches round order.
  state = ExperimentState(round_num, server_state)
  checkpoint_utils.save(
      state,
      os.path.join(root_checkpoint_dir,
                   '{}{:04d}'.format(CHECKPOINT_PREFIX, round_num)))
# Exemplo n.º 5
def write_checkpoint(checkpoint_dir, server_state, metrics_dataframe,
                     round_num):
    """Write the current experiment state to disk."""
    # Prune older checkpoints, always keeping the first (initialization)
    # checkpoint and the three most recent ones.
    existing = sorted(
        tf.io.gfile.glob(os.path.join(checkpoint_dir, 'ckpt_*')))
    for stale in existing[1:-3]:
        tf.io.gfile.rmtree(stale)

    # Flatten the pd.DataFrame to a single CSV string; otherwise the nested
    # structure (how many rounds have passed) would be unknown when unpacking
    # in `checkpoint_utils.load()` during `read_checkpoint`.
    buffer = io.StringIO()
    metrics_dataframe.to_csv(buffer, header=True)

    experiment_state = ExperimentState(
        round_num=round_num,
        metrics_csv_string=buffer.getvalue(),
        server_state=server_state)
    checkpoint_utils.save(
        experiment_state,
        os.path.join(checkpoint_dir, 'ckpt_{:03d}'.format(round_num)))