コード例 #1
0
  def test_raises_file_not_found_error_with_no_checkpoint(self):
    """`load_checkpoint` raises `FileNotFoundError` when nothing was saved."""
    mngr = checkpoint_manager.FileCheckpointManager(self.get_temp_dir())
    template = _create_dummy_state()

    # Round 0 was never saved, so there is no file to restore from.
    with self.assertRaises(FileNotFoundError):
      mngr.load_checkpoint(template, 0)
コード例 #2
0
def _setup_outputs(root_output_dir, experiment_name, hparam_dict):
  """Set up directories for experiment loops, write hyperparameters to disk."""

  if not experiment_name:
    raise ValueError('experiment_name must be specified.')

  create_if_not_exists(root_output_dir)

  checkpoint_dir = os.path.join(root_output_dir, 'checkpoints', experiment_name)
  create_if_not_exists(checkpoint_dir)
  checkpoint_mngr = checkpoint_manager.FileCheckpointManager(checkpoint_dir)

  results_dir = os.path.join(root_output_dir, 'results', experiment_name)
  create_if_not_exists(results_dir)
  metrics_mngr = metrics_manager.ScalarMetricsManager(results_dir)

  summary_logdir = os.path.join(root_output_dir, 'logdir', experiment_name)
  create_if_not_exists(summary_logdir)
  summary_writer = tf.compat.v2.summary.create_file_writer(summary_logdir)

  hparam_dict['metrics_file'] = metrics_mngr.metrics_filename
  hparams_file = os.path.join(results_dir, 'hparams.csv')
  utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

  logging.info('Writing...')
  logging.info('    checkpoints to: %s', checkpoint_dir)
  logging.info('    metrics csv to: %s', metrics_mngr.metrics_filename)
  logging.info('    summaries to: %s', summary_logdir)

  return checkpoint_mngr, metrics_mngr, summary_writer
コード例 #3
0
  def test_saves_one_checkpoint(self):
    """Saving round 1 leaves exactly one `ckpt_1` entry in the directory."""
    ckpt_root = self.get_temp_dir()
    mngr = checkpoint_manager.FileCheckpointManager(ckpt_root)

    mngr.save_checkpoint(_create_dummy_state(1), 1)

    self.assertCountEqual(os.listdir(ckpt_root), ['ckpt_1'])
コード例 #4
0
  def test_saves_one_checkpoint(self):
    """A single save for round 1 produces exactly the `ckpt_1` entry."""
    ckpt_root = self.get_temp_dir()
    mngr = checkpoint_manager.FileCheckpointManager(ckpt_root)
    dummy = _create_dummy_state()

    mngr.save_checkpoint(dummy, 1)

    written = set(os.listdir(ckpt_root))
    self.assertEqual(written, {'ckpt_1'})
コード例 #5
0
  def test_returns_none_and_zero_with_no_checkpoints(self):
    """With nothing saved, `load_latest_checkpoint` yields `(None, 0)`."""
    mngr = checkpoint_manager.FileCheckpointManager(self.get_temp_dir())
    template = _create_dummy_structure()

    loaded_state, loaded_round = mngr.load_latest_checkpoint(template)

    self.assertIsNone(loaded_state)
    self.assertEqual(loaded_round, 0)
コード例 #6
0
  def test_raises_already_exists_error_with_existing_round_number(self):
    """Saving a second checkpoint for the same round raises AlreadyExists."""
    mngr = checkpoint_manager.FileCheckpointManager(self.get_temp_dir())
    dummy = _create_dummy_state()
    repeated_round = 1

    mngr.save_checkpoint(dummy, repeated_round)

    with self.assertRaises(tf.errors.AlreadyExistsError):
      mngr.save_checkpoint(dummy, repeated_round)
コード例 #7
0
  def test_raises_value_error_with_bad_structure(self):
    """Loading the latest checkpoint with a `None` structure raises ValueError.

    The result is deliberately not unpacked. Unpacking a return value of the
    wrong length would itself raise `ValueError`, which would let this test
    pass even if `load_latest_checkpoint` never rejected the structure.
    """
    temp_dir = self.get_temp_dir()
    checkpoint_mngr = checkpoint_manager.FileCheckpointManager(temp_dir)
    state = _create_dummy_state()
    checkpoint_mngr.save_checkpoint(state, 1)
    structure = None

    with self.assertRaises(ValueError):
      checkpoint_mngr.load_latest_checkpoint(structure)
コード例 #8
0
  def test_returns_state_with_one_checkpoint(self):
    """`load_checkpoint` for round 1 returns the state saved for round 1."""
    mngr = checkpoint_manager.FileCheckpointManager(self.get_temp_dir())
    saved = _create_dummy_state(1)
    mngr.save_checkpoint(saved, 1)

    loaded = mngr.load_checkpoint(_create_dummy_state(), 1)

    self.assertEqual(loaded, saved)
コード例 #9
0
  def test_saves_and_returns_structure_and_zero_with_no_checkpoints(self):
    """With nothing saved, the default is returned at round 0 and saved."""
    ckpt_root = self.get_temp_dir()
    mngr = checkpoint_manager.FileCheckpointManager(ckpt_root)
    default_state = _create_dummy_state()

    loaded_state, loaded_round = mngr.load_latest_checkpoint_or_default(
        default_state)

    self.assertEqual(loaded_state, default_state)
    self.assertEqual(loaded_round, 0)
    # The default must have been written out as the round-0 checkpoint.
    self.assertCountEqual(os.listdir(ckpt_root), ['ckpt_0'])
コード例 #10
0
  def test_returns_state_and_round_num_with_one_checkpoint(self):
    """The only checkpoint is loaded together with its round number (1)."""
    mngr = checkpoint_manager.FileCheckpointManager(self.get_temp_dir())
    mngr.save_checkpoint(_create_dummy_state(), 1)

    loaded_state, loaded_round = mngr.load_latest_checkpoint(
        _create_dummy_structure())

    self.assertEqual(loaded_state, _create_dummy_state())
    self.assertEqual(loaded_round, 1)
コード例 #11
0
  def test_returns_default_and_zero_with_no_checkpoints_and_also_saves(self):
    """Without checkpoints, the default comes back at round 0 as `ckpt_0`."""
    ckpt_root = self.get_temp_dir()
    mngr = checkpoint_manager.FileCheckpointManager(ckpt_root)
    fallback = _create_dummy_structure()

    loaded_state, loaded_round = mngr.load_latest_checkpoint_or_default(
        fallback)

    self.assertEqual(loaded_state, fallback)
    self.assertEqual(loaded_round, 0)
    self.assertEqual(set(os.listdir(ckpt_root)), {'ckpt_0'})
コード例 #12
0
  def test_saves_three_checkpoints(self):
    """Saving rounds 1-3 leaves one checkpoint entry per round."""
    ckpt_root = self.get_temp_dir()
    mngr = checkpoint_manager.FileCheckpointManager(ckpt_root)

    for round_num in (1, 2, 3):
      mngr.save_checkpoint(_create_dummy_state(round_num), round_num)

    self.assertCountEqual(
        os.listdir(ckpt_root), ['ckpt_1', 'ckpt_2', 'ckpt_3'])
コード例 #13
0
  def test_removes_oldest_with_keep_first_false(self):
    """After four saves with `keep_total=3`, round 1 has been removed."""
    ckpt_root = self.get_temp_dir()
    mngr = checkpoint_manager.FileCheckpointManager(
        ckpt_root, keep_total=3, keep_first=False)
    dummy = _create_dummy_state()

    for round_num in range(1, 5):
      mngr.save_checkpoint(dummy, round_num)

    self.assertEqual(
        set(os.listdir(ckpt_root)), {'ckpt_2', 'ckpt_3', 'ckpt_4'})
コード例 #14
0
  def test_returns_state_with_three_checkpoint_for_third_round(self):
    """With rounds 1-3 saved, loading round 3 returns round 3's state."""
    mngr = checkpoint_manager.FileCheckpointManager(self.get_temp_dir())
    saved_by_round = {}
    for round_num in (1, 2, 3):
      saved_by_round[round_num] = _create_dummy_state(round_num)
      mngr.save_checkpoint(saved_by_round[round_num], round_num)

    loaded = mngr.load_checkpoint(_create_dummy_state(), 3)

    self.assertEqual(loaded, saved_by_round[3])
コード例 #15
0
  def test_returns_state_and_round_num_with_special_characters(self):
    """Save/load works when the checkpoint path contains ( ) = * ? and dots."""
    special_char_dir = 'subdir(subdir_number=1*)/sub.dir(subdir_number=2?)'
    mngr = checkpoint_manager.FileCheckpointManager(
        os.path.join(self.get_temp_dir(), special_char_dir))
    mngr.save_checkpoint(_create_dummy_state(), 1)

    loaded_state, loaded_round = mngr.load_latest_checkpoint(
        _create_dummy_structure())

    self.assertEqual(loaded_state, _create_dummy_state())
    self.assertEqual(loaded_round, 1)
コード例 #16
0
  def test_removes_oldest_with_keep_first_false(self):
    """After saving rounds 1-4 with `keep_total=3`, only rounds 2-4 remain."""
    ckpt_root = self.get_temp_dir()
    mngr = checkpoint_manager.FileCheckpointManager(
        ckpt_root, keep_total=3, keep_first=False)

    for round_num in range(1, 5):
      mngr.save_checkpoint(_create_dummy_state(round_num), round_num)

    self.assertCountEqual(
        os.listdir(ckpt_root), ['ckpt_2', 'ckpt_3', 'ckpt_4'])
コード例 #17
0
ファイル: training_loop.py プロジェクト: oodunsi1/federated-1
def _setup_outputs(root_output_dir,
                   experiment_name,
                   hparam_dict,
                   write_metrics_with_bz2=True,
                   rounds_per_profile=0):
    """Set up directories for experiment loops, write hyperparameters to disk."""

    if not experiment_name:
        raise ValueError('experiment_name must be specified.')

    create_if_not_exists(root_output_dir)

    checkpoint_dir = os.path.join(root_output_dir, 'checkpoints',
                                  experiment_name)
    create_if_not_exists(checkpoint_dir)
    checkpoint_mngr = checkpoint_manager.FileCheckpointManager(checkpoint_dir)

    results_dir = os.path.join(root_output_dir, 'results', experiment_name)
    create_if_not_exists(results_dir)
    metrics_mngr = metrics_manager.ScalarMetricsManager(
        results_dir, use_bz2=write_metrics_with_bz2)

    summary_logdir = os.path.join(root_output_dir, 'logdir', experiment_name)
    create_if_not_exists(summary_logdir)
    summary_writer = tf.summary.create_file_writer(summary_logdir)

    if hparam_dict:
        hparam_dict['metrics_file'] = metrics_mngr.metrics_filename
        hparams_file = os.path.join(results_dir, 'hparams.csv')
        utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

    logging.info('Writing...')
    logging.info('    checkpoints to: %s', checkpoint_dir)
    logging.info('    metrics csv to: %s', metrics_mngr.metrics_filename)
    logging.info('    summaries to: %s', summary_logdir)

    @contextlib.contextmanager
    def profiler(round_num):
        if (rounds_per_profile > 0 and round_num % rounds_per_profile == 0):
            with tf.profiler.experimental.Profile(summary_logdir):
                yield
        else:
            yield

    return checkpoint_mngr, metrics_mngr, summary_writer, profiler
コード例 #18
0
    def test_checkpoint_manager_saves_state(self):
        """A checkpoint written by `training_loop.run` restores the final state.

        With `FLAGS.total_rounds = 1` the latest checkpoint is expected at
        round 0. The restored weights must give the same Keras loss on the
        training batch as the state returned by `run`.
        """
        FLAGS.total_rounds = 1
        FLAGS.experiment_name = 'checkpoint_manager_saves_state'
        iterative_process = _build_federated_averaging_process()
        federated_data = [[_batch_fn()]]

        def client_datasets_fn(round_num):
            del round_num  # Every round sees the same single-client data.
            return federated_data

        def evaluate_fn(model):
            del model  # Evaluation is irrelevant to this test.
            return {}

        root_dir = self.get_temp_dir()
        FLAGS.root_output_dir = root_dir
        final_state = training_loop.run(iterative_process, client_datasets_fn,
                                        evaluate_fn)

        ckpt_dir = os.path.join(root_dir, 'checkpoints', FLAGS.experiment_name)
        reader = checkpoint_manager.FileCheckpointManager(ckpt_dir)
        restored_state, restored_round = reader.load_latest_checkpoint(
            final_state)

        self.assertEqual(restored_round, 0)

        keras_model = tff.simulation.models.mnist.create_keras_model(
            compile_model=True)
        batch = federated_data[0][0]

        def loss_with_weights_of(server_state):
            # Load the server model weights into the shared Keras model.
            server_state.model.assign_weights_to(keras_model)
            return keras_model.test_on_batch(batch.x, batch.y)

        restored_loss = loss_with_weights_of(restored_state)
        final_loss = loss_with_weights_of(final_state)
        self.assertEqual(final_loss, restored_loss)
コード例 #19
0
def _federated_averaging_training_loop(model_fn,
                                       client_optimizer_fn,
                                       server_optimizer_fn,
                                       client_datasets_fn,
                                       evaluate_fn,
                                       total_rounds=500,
                                       rounds_per_eval=1,
                                       metrics_hook=None):
    """A simple example of training loop for the Federated Averaging algorithm.

    Looks for checkpoints in `os.path.join(FLAGS.root_output_dir,
    FLAGS.exp_name)` and resumes from the latest one.
    - If none exists, training starts from `iterative_process.initialize()`,
      which is saved as round 0.
    - A checkpoint is saved every `FLAGS.rounds_per_checkpoint` rounds and
      at the final round.

    Args:
        model_fn: A no-arg function that returns a `tff.learning.Model`.
        client_optimizer_fn: A no-arg function that returns a
            `tf.keras.optimizers.Optimizer`.
        server_optimizer_fn: A no-arg function that returns a
            `tf.keras.optimizers.Optimizer`.
        client_datasets_fn: A function that takes the round number, and
            returns a list of `tf.data.Dataset`, one per client.
        evaluate_fn: A function that takes state, performs evaluation, and
            returns evaluation metrics.
        total_rounds: Number of rounds to train.
        rounds_per_eval: How often to call the `metrics_hook` function.
        metrics_hook: A function taking arguments (training metrics,
            evaluation metrics, and round number). Optional.

    Returns:
        Final `ServerState`.

    Raises:
        ValueError: If `FLAGS.server_optimizer` is not `'flars'`.
    """
    # Fail fast: only the FLARS variant is implemented, and nothing below can
    # run without an `iterative_process`.
    if FLAGS.server_optimizer != 'flars':
        raise ValueError(
            'Unsupported server_optimizer: {}'.format(FLAGS.server_optimizer))

    logging.info('Starting federated training loop')

    checkpoint_dir = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name)
    checkpoint_manager_obj = checkpoint_manager.FileCheckpointManager(
        checkpoint_dir)

    iterative_process = flars_fedavg.build_federated_averaging_process(
        model_fn,
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn)
    ServerState = flars_fedavg.ServerState  # pylint: disable=invalid-name

    # Construct an initial state here to act as a checkpoint template.
    initial_state = ServerState.from_tff_result(iterative_process.initialize())

    logging.info('Looking for checkpoints in \'%s\'', checkpoint_dir)
    state, round_num = checkpoint_manager_obj.load_latest_checkpoint(
        initial_state)

    if state is None:
        logging.info('No previous checkpoints, initializing experiment')
        state = initial_state
        round_num = 0
        if metrics_hook is not None:
            eval_metrics = evaluate_fn(state)
            metrics_hook({}, eval_metrics, round_num)
        checkpoint_manager_obj.save_checkpoint(state, 0)
    else:
        logging.info('Restarted from checkpoint round %d', round_num)

    while round_num < total_rounds:
        round_num += 1
        train_metrics = {}

        # Reset the executor to clear the cache, and clear the default graph to
        # garbage collect tf.Functions that will no longer be used.
        tff.framework.set_default_executor(
            tff.framework.local_executor_factory(max_fanout=25))
        tf.compat.v1.reset_default_graph()

        round_start_time = time.time()
        data_prep_start_time = time.time()
        train_data = client_datasets_fn(round_num)
        train_metrics['prepare_datasets_secs'] = (
            time.time() - data_prep_start_time)

        training_start_time = time.time()
        state, tff_train_metrics = iterative_process.next(state, train_data)
        state = ServerState.from_tff_result(state)
        train_metrics.update(tff_train_metrics._asdict(recursive=True))
        train_metrics['training_secs'] = time.time() - training_start_time

        # Measure once so the logged and recorded round times agree.
        round_secs = time.time() - round_start_time
        logging.info('Round %2d elapsed time: %.2fs .', round_num, round_secs)
        train_metrics['total_round_secs'] = round_secs

        if (round_num % FLAGS.rounds_per_checkpoint == 0
                or round_num == total_rounds):
            save_checkpoint_start_time = time.time()
            checkpoint_manager_obj.save_checkpoint(state, round_num)
            train_metrics['save_checkpoint_secs'] = (
                time.time() - save_checkpoint_start_time)

        if round_num % rounds_per_eval == 0 or round_num == total_rounds:
            if metrics_hook is not None:
                eval_metrics = evaluate_fn(state)
                metrics_hook(train_metrics, eval_metrics, round_num)

    return state
コード例 #20
0
 def setUp(self):
   """Gives every test a `FileCheckpointManager` rooted in its temp dir."""
   super(FileCheckpointManagerLoadCheckpointTest, self).setUp()
   temp_dir = self.get_temp_dir()
   self._checkpoint_mngr = checkpoint_manager.FileCheckpointManager(temp_dir)