def test_latest_checkpoint(self):
    prefix = 'chkpnt_'
    latest_checkpoint = checkpoint_utils.latest_checkpoint(
        self.get_temp_dir(), prefix)
    self.assertIsNone(latest_checkpoint)

    # Create checkpoints and ensure that the latest checkpoint found is
    # always the most recently created path.
    state = build_fake_state()
    for round_num in range(5):
      export_dir = os.path.join(self.get_temp_dir(),
                                '{}{:03d}'.format(prefix, round_num))
      checkpoint_utils.save(state, export_dir)
      latest_checkpoint_path = checkpoint_utils.latest_checkpoint(
          self.get_temp_dir(), prefix)
      self.assertEndsWith(
          latest_checkpoint_path,
          '{:03d}'.format(round_num),
          msg=latest_checkpoint_path)

    # Delete the checkpoints in reverse order and ensure the latest checkpoint
    # decreases.
    for round_num in reversed(range(2, 5)):
      export_dir = os.path.join(self.get_temp_dir(),
                                '{}{:03d}'.format(prefix, round_num))
      tf.io.gfile.rmtree(export_dir)
      latest_checkpoint_path = checkpoint_utils.latest_checkpoint(
          self.get_temp_dir(), prefix)
      self.assertEndsWith(
          latest_checkpoint_path,
          '{:03d}'.format(round_num - 1),
          msg=latest_checkpoint_path)
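
# The test above relies on a `build_fake_state` helper that is not shown in
# this listing. Below is a minimal sketch of such a helper, assuming only that
# `checkpoint_utils.save` accepts an arbitrary nested structure of tensors.
# The `FakeState` name and its fields are illustrative, not the actual test
# fixture; `tf` is imported as in the examples above.
import collections

FakeState = collections.namedtuple('FakeState', ['model_weights', 'round_num'])


def build_fake_state():
  """Returns a small nested structure of tensors suitable for checkpointing."""
  return FakeState(
      model_weights=[tf.zeros([3, 2]), tf.zeros([2])],
      round_num=tf.constant(0, dtype=tf.int32))
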
# Example 2
    def test_load_latest_state(self):
        server_optimizer_fn = functools.partial(tf.keras.optimizers.SGD,
                                                learning_rate=0.1,
                                                momentum=0.9)

        iterative_process = tff.learning.build_federated_averaging_process(
            models.model_fn, server_optimizer_fn=server_optimizer_fn)
        server_state = iterative_process.initialize()
        # TODO(b/130724878): These conversions should not be needed.
        obj_1 = Obj.from_anon_tuple(server_state, 1)
        export_dir = os.path.join(self.get_temp_dir(), 'ckpt_1')
        checkpoint_utils.save(obj_1, export_dir)

        # TODO(b/130724878): These conversions should not be needed.
        obj_2 = Obj.from_anon_tuple(server_state, 2)
        export_dir = os.path.join(self.get_temp_dir(), 'ckpt_2')
        checkpoint_utils.save(obj_2, export_dir)

        export_dir = checkpoint_utils.latest_checkpoint(self.get_temp_dir())

        loaded_obj = checkpoint_utils.load(export_dir, obj_1)

        self.assertEqual(os.path.join(self.get_temp_dir(), 'ckpt_2'),
                         export_dir)
        self.assertAllClose(tf.nest.flatten(obj_2),
                            tf.nest.flatten(loaded_obj))
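
# The `Obj` container and its `from_anon_tuple` conversion used above are not
# shown in this listing. A rough sketch of what such a wrapper might look like,
# assuming the server state exposes `model` and `optimizer_state` attributes;
# the field names and conversion details are assumptions, not the actual TFF
# test code.
import attr


@attr.s(frozen=True)
class Obj(object):
  """Plain container for the state that is written to a checkpoint."""
  model = attr.ib()
  optimizer_state = attr.ib()
  round_num = attr.ib()

  @classmethod
  def from_anon_tuple(cls, anon_tuple, round_num):
    # Repackage the TFF-internal structure into a plain Python container so
    # that tf.nest and the checkpoint utilities can traverse it.
    return cls(
        model=anon_tuple.model,
        optimizer_state=anon_tuple.optimizer_state,
        round_num=round_num)
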
# Example 3
def maybe_read_latest_checkpoint(root_checkpoint_dir, server_state):
  """Returns server_state, round_num, possibly from a recent checkpoint."""
  if root_checkpoint_dir is None:
    latest_checkpoint_dir = None
  else:
    latest_checkpoint_dir = checkpoint_utils.latest_checkpoint(
        root_checkpoint_dir, CHECKPOINT_PREFIX)
    logging.info('Looking for checkpoints in [%s/%s].', root_checkpoint_dir,
                 CHECKPOINT_PREFIX)
  if latest_checkpoint_dir is None:
    write_checkpoint(root_checkpoint_dir, server_state, 0)
    logging.info('No previous checkpoints, initializing experiment.')
    return server_state, 0
  else:
    server_state, round_num = read_checkpoint(latest_checkpoint_dir,
                                              server_state)
    round_num = int(round_num.numpy())
    logging.info('Restarting from checkpoint round %d.', round_num)
    return server_state, round_num
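
# `write_checkpoint` and `read_checkpoint` are referenced here (and in the
# training loop below) but defined elsewhere. A minimal sketch of how they
# could be layered on the `checkpoint_utils.save`/`load`/`latest_checkpoint`
# calls from the earlier examples; the per-round directory naming and the
# (state, round_num) layout are assumptions, and the training-loop variant
# below additionally stores `metrics_hook.results`. Assumes `os`, `tf`, and
# `checkpoint_utils` are imported as above.
CHECKPOINT_PREFIX = 'ckpt_'


def write_checkpoint(root_checkpoint_dir, server_state, round_num):
  """Saves `server_state` and `round_num` in a per-round subdirectory."""
  if root_checkpoint_dir is None:
    return
  export_dir = os.path.join(
      root_checkpoint_dir, '{}{:03d}'.format(CHECKPOINT_PREFIX, round_num))
  checkpoint_utils.save((server_state, tf.constant(round_num)), export_dir)


def read_checkpoint(checkpoint_dir, example_server_state):
  """Returns (server_state, round_num) loaded from `checkpoint_dir`."""
  example_state = (example_server_state, tf.constant(0))
  return checkpoint_utils.load(checkpoint_dir, example_state)
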
# Example 4
def federated_averaging_training_loop(model_fn,
                                      server_optimizer_fn,
                                      client_datasets_fn,
                                      total_rounds=500,
                                      rounds_per_eval=1,
                                      metrics_hook=lambda *args: None):
    """A simple example of training loop for the Federated Averaging algorithm.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`.
    client_datasets_fn: A function that takes the round number, and returns a
      list of `tf.data.Datset`, one per client.
    total_rounds: Number of rounds to train.
    rounds_per_eval: How often to call the  `metrics_hook` function.
    metrics_hook: A function taking arguments (server_state, train_metrics,
      round_num) and performs evaluation. Optional.

  Returns:
    Final `ServerState`.
  """
    logging.info('Starting federated_averaging_training_loop')
    checkpoint_dir = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name)

    if FLAGS.server_optimizer != 'flars':
        iterative_process = simple_fedavg.build_federated_averaging_process(
            model_fn, server_optimizer_fn=server_optimizer_fn)
        ServerState = simple_fedavg.ServerState  # pylint: disable=invalid-name
    else:
        iterative_process = flars_fedavg.build_federated_averaging_process(
            model_fn, server_optimizer_fn=server_optimizer_fn)
        ServerState = flars_fedavg.ServerState  # pylint: disable=invalid-name

    server_state = ServerState.from_anon_tuple(iterative_process.initialize())
    round_num = None
    train_metrics = {}

    latest_checkpoint_dir = checkpoint_utils.latest_checkpoint(
        checkpoint_dir, CHECKPOINT_PREFIX)
    logging.info('Looking for checkpoints in [%s/%s]', checkpoint_dir,
                 CHECKPOINT_PREFIX)
    while latest_checkpoint_dir is not None:
        # Restart from a previous round.
        logging.info('Loading a previous checkpoint')
        try:
            server_state, metrics_hook.results, round_num = read_checkpoint(
                latest_checkpoint_dir, server_state)
            break
        except OSError as e:
            # Likely a corrupted checkpoint, possibly because the job died
            # while writing it. Delete the checkpoint directory and try again.
            logging.error('Exception [%s]', e)
            logging.warning('Deleting likely corrupted checkpoint at [%s]',
                            latest_checkpoint_dir)
            tf.io.gfile.rmtree(latest_checkpoint_dir)
            latest_checkpoint_dir = checkpoint_utils.latest_checkpoint(
                checkpoint_dir, CHECKPOINT_PREFIX)

    if round_num is not None:
        logging.info('Restarted from checkpoint round %d', round_num)
    else:
        # Write the initial checkpoint
        logging.info('No previous checkpoints, initializing experiment')
        round_num = 0
        metrics_hook(server_state, train_metrics, round_num)
        write_checkpoint(checkpoint_dir, server_state, metrics_hook.results,
                         round_num)

    while round_num < total_rounds:
        round_num += 1
        # Reset the executor to clear the cache, and clear the default graph to
        # garbage collect tf.Functions that will no longer be used.
        tff.framework.set_default_executor(
            tff.framework.create_local_executor(max_fanout=25))
        tf.compat.v1.reset_default_graph()

        round_start_time = time.time()
        data_prep_start_time = time.time()
        federated_train_data = client_datasets_fn(round_num)
        train_metrics['prepare_datasets_secs'] = (
            time.time() - data_prep_start_time)

        training_start_time = time.time()
        anon_tuple_server_state, tff_train_metrics = iterative_process.next(
            server_state, federated_train_data)
        server_state = ServerState.from_anon_tuple(anon_tuple_server_state)
        train_metrics.update(tff_train_metrics._asdict(recursive=True))
        train_metrics['training_secs'] = time.time() - training_start_time

        logging.info('Round {:2d} elapsed time: {:.2f}s .'.format(
            round_num, (time.time() - round_start_time)))
        train_metrics['total_round_secs'] = time.time() - round_start_time

        if round_num % FLAGS.rounds_per_checkpoint == 0:
            write_checkpoint_start_time = time.time()
            write_checkpoint(checkpoint_dir, server_state,
                             metrics_hook.results, round_num)
            train_metrics['write_checkpoint_secs'] = (
                time.time() - write_checkpoint_start_time)

        if round_num % rounds_per_eval == 0:
            metrics_hook(server_state, train_metrics, round_num)

    metrics_hook(server_state, train_metrics, total_rounds)
    write_checkpoint(checkpoint_dir, server_state, metrics_hook.results,
                     round_num)

    return server_state
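
# A sketch of how the training loop above might be driven. `my_model_fn` is a
# hypothetical no-arg constructor for a `tff.learning.Model`, and the dataset
# returned by `client_datasets_fn` is a trivial placeholder. Note that the loop
# also reads several absl flags (FLAGS.root_output_dir, FLAGS.exp_name,
# FLAGS.server_optimizer, FLAGS.rounds_per_checkpoint), which must be defined
# and parsed before it is called.
def client_datasets_fn(round_num):
  # Placeholder sampling: in practice, select the clients participating in
  # this round and return one `tf.data.Dataset` per selected client.
  del round_num  # Unused in this sketch.
  return [tf.data.Dataset.from_tensor_slices([0.0]).batch(1)]


final_state = federated_averaging_training_loop(
    model_fn=my_model_fn,
    server_optimizer_fn=functools.partial(
        tf.keras.optimizers.SGD, learning_rate=1.0),
    client_datasets_fn=client_datasets_fn,
    total_rounds=10,
    rounds_per_eval=2)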