Example 1
def _setup_outputs(root_output_dir, experiment_name, hparam_dict):
  """Set up directories for experiment loops, write hyperparameters to disk."""

  if not experiment_name:
    raise ValueError('experiment_name must be specified.')

  create_if_not_exists(root_output_dir)

  checkpoint_dir = os.path.join(root_output_dir, 'checkpoints', experiment_name)
  create_if_not_exists(checkpoint_dir)
  checkpoint_mngr = checkpoint_manager.FileCheckpointManager(checkpoint_dir)

  results_dir = os.path.join(root_output_dir, 'results', experiment_name)
  create_if_not_exists(results_dir)
  csv_file = os.path.join(results_dir, 'experiment.metrics.csv')
  metrics_mngr = csv_manager.CSVMetricsManager(csv_file)

  summary_logdir = os.path.join(root_output_dir, 'logdir', experiment_name)
  create_if_not_exists(summary_logdir)
  summary_writer = tf.summary.create_file_writer(summary_logdir)

  if hparam_dict:
    hparam_dict['metrics_file'] = metrics_mngr.metrics_filename
    hparams_file = os.path.join(results_dir, 'hparams.csv')
    utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)
    with summary_writer.as_default():
      hp.hparams({k: v for k, v in hparam_dict.items() if v is not None})

  logging.info('Writing...')
  logging.info('    checkpoints to: %s', checkpoint_dir)
  logging.info('    metrics csv to: %s', metrics_mngr.metrics_filename)
  logging.info('    summaries to: %s', summary_logdir)

  return checkpoint_mngr, metrics_mngr, summary_writer
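A minimal driver sketch showing how the three return values might be consumed. Everything here other than _setup_outputs is a hypothetical stand-in: iterative_process, client_datasets_fn, and total_rounds are assumed to exist elsewhere, and the metrics-manager method name is an assumption rather than a documented API.

def _example_driver_loop(iterative_process, client_datasets_fn, total_rounds):
  """Hypothetical round loop wired to the managers from _setup_outputs."""
  checkpoint_mngr, metrics_mngr, summary_writer = _setup_outputs(
      '/tmp/exp', 'example_run', hparam_dict={'learning_rate': 0.1})

  initial_state = iterative_process.initialize()
  state, round_num = checkpoint_mngr.load_latest_checkpoint(initial_state)
  if state is None:
    state, round_num = initial_state, 0

  while round_num < total_rounds:
    round_num += 1
    state, metrics = iterative_process.next(state,
                                            client_datasets_fn(round_num))
    metrics_mngr.update_metrics(round_num, metrics)  # assumed method name
    with summary_writer.as_default():
      for name, value in metrics.items():  # assumes flat scalar metrics
        tf.summary.scalar(name, value, step=round_num)
    checkpoint_mngr.save_checkpoint(state, round_num)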
    def test_raises_file_not_found_error_with_no_checkpoint(self):
        temp_dir = self.get_temp_dir()
        checkpoint_mngr = checkpoint_manager.FileCheckpointManager(temp_dir)
        structure = _create_dummy_state()

        with self.assertRaises(FileNotFoundError):
            _ = checkpoint_mngr.load_checkpoint(structure, 0)
    def test_saves_one_checkpoint(self):
        temp_dir = self.get_temp_dir()
        checkpoint_mngr = checkpoint_manager.FileCheckpointManager(temp_dir)

        dummy_state_1 = _create_dummy_state(1)
        checkpoint_mngr.save_checkpoint(dummy_state_1, 1)

        self.assertCountEqual(os.listdir(temp_dir), ['ckpt_1'])
    def test_raises_value_error_with_bad_structure(self):
        temp_dir = self.get_temp_dir()
        checkpoint_mngr = checkpoint_manager.FileCheckpointManager(temp_dir)
        dummy_state_1 = _create_dummy_state(1)
        checkpoint_mngr.save_checkpoint(dummy_state_1, 1)
        structure = None

        with self.assertRaises(ValueError):
            checkpoint_mngr.load_latest_checkpoint(structure)
    def test_returns_none_and_zero_with_no_checkpoints(self):
        temp_dir = self.get_temp_dir()
        checkpoint_mngr = checkpoint_manager.FileCheckpointManager(temp_dir)
        structure = _create_dummy_state()

        state, round_num = checkpoint_mngr.load_latest_checkpoint(structure)

        self.assertIsNone(state)
        self.assertEqual(round_num, 0)
    def test_raises_already_exists_error_with_existing_round_number(self):
        temp_dir = self.get_temp_dir()
        checkpoint_mngr = checkpoint_manager.FileCheckpointManager(temp_dir)

        dummy_state_1 = _create_dummy_state(1)
        checkpoint_mngr.save_checkpoint(dummy_state_1, 1)

        with self.assertRaises(tf.errors.AlreadyExistsError):
            checkpoint_mngr.save_checkpoint(dummy_state_1, 1)
    def test_returns_state_with_one_checkpoint(self):
        temp_dir = self.get_temp_dir()
        checkpoint_mngr = checkpoint_manager.FileCheckpointManager(temp_dir)
        dummy_state_1 = _create_dummy_state(1)
        checkpoint_mngr.save_checkpoint(dummy_state_1, 1)
        structure = _create_dummy_state()

        state = checkpoint_mngr.load_checkpoint(structure, 1)

        self.assertEqual(state, dummy_state_1)
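The tests above depend on a _create_dummy_state helper that is not shown. A plausible minimal definition, assuming only that FileCheckpointManager round-trips an arbitrary nested structure of tensors; the field names are illustrative.

import collections

import tensorflow as tf

def _create_dummy_state(value=0):
    # Any nested structure of tensors serves as checkpointable state; the
    # integer argument makes states saved at different rounds distinguishable.
    return collections.OrderedDict(model_weights=tf.constant(value),
                                   counter=tf.constant(value))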
Example 8
    def test_saves_and_returns_structure_and_zero_with_no_checkpoints(self):
        temp_dir = self.get_temp_dir()
        checkpoint_mngr = checkpoint_manager.FileCheckpointManager(temp_dir)
        structure = _create_dummy_state()

        state, round_num = checkpoint_mngr.load_latest_checkpoint_or_default(
            structure)

        self.assertEqual(state, structure)
        self.assertEqual(round_num, 0)
        self.assertCountEqual(os.listdir(temp_dir), ['ckpt_0'])
    def test_saves_three_checkpoints(self):
        temp_dir = self.get_temp_dir()
        checkpoint_mngr = checkpoint_manager.FileCheckpointManager(temp_dir)

        dummy_state_1 = _create_dummy_state(1)
        checkpoint_mngr.save_checkpoint(dummy_state_1, 1)
        dummy_state_2 = _create_dummy_state(2)
        checkpoint_mngr.save_checkpoint(dummy_state_2, 2)
        dummy_state_3 = _create_dummy_state(3)
        checkpoint_mngr.save_checkpoint(dummy_state_3, 3)

        self.assertCountEqual(os.listdir(temp_dir),
                              ['ckpt_1', 'ckpt_2', 'ckpt_3'])
    def test_returns_state_with_three_checkpoints_for_first_round(self):
        temp_dir = self.get_temp_dir()
        checkpoint_mngr = checkpoint_manager.FileCheckpointManager(temp_dir)
        dummy_state_1 = _create_dummy_state(1)
        checkpoint_mngr.save_checkpoint(dummy_state_1, 1)
        dummy_state_2 = _create_dummy_state(2)
        checkpoint_mngr.save_checkpoint(dummy_state_2, 2)
        dummy_state_3 = _create_dummy_state(3)
        checkpoint_mngr.save_checkpoint(dummy_state_3, 3)
        structure = _create_dummy_state()

        state = checkpoint_mngr.load_checkpoint(structure, 1)

        self.assertEqual(state, dummy_state_1)
    def test_removes_oldest_with_keep_first_false(self):
        temp_dir = self.get_temp_dir()
        checkpoint_mngr = checkpoint_manager.FileCheckpointManager(
            temp_dir, keep_total=3, keep_first=False)

        dummy_state_1 = _create_dummy_state(1)
        checkpoint_mngr.save_checkpoint(dummy_state_1, 1)
        dummy_state_2 = _create_dummy_state(2)
        checkpoint_mngr.save_checkpoint(dummy_state_2, 2)
        dummy_state_3 = _create_dummy_state(3)
        checkpoint_mngr.save_checkpoint(dummy_state_3, 3)
        dummy_state_4 = _create_dummy_state(4)
        checkpoint_mngr.save_checkpoint(dummy_state_4, 4)

        self.assertCountEqual(os.listdir(temp_dir),
                              ['ckpt_2', 'ckpt_3', 'ckpt_4'])
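For contrast with keep_first=False above, a companion test sketch. The retained set assumes keep_first=True pins the earliest checkpoint while pruning continues among the rest; this is an assumption about the manager's behavior, not a documented guarantee.

    def test_keeps_first_with_keep_first_true(self):
        temp_dir = self.get_temp_dir()
        checkpoint_mngr = checkpoint_manager.FileCheckpointManager(
            temp_dir, keep_total=3, keep_first=True)

        for round_num in range(1, 5):
            checkpoint_mngr.save_checkpoint(_create_dummy_state(round_num),
                                            round_num)

        # Assumed behavior: 'ckpt_1' is pinned, so 'ckpt_2' is the one pruned.
        self.assertCountEqual(os.listdir(temp_dir),
                              ['ckpt_1', 'ckpt_3', 'ckpt_4'])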
Example 12
def _setup_outputs(root_output_dir,
                   experiment_name,
                   hparam_dict,
                   rounds_per_profile=0):
    """Set up directories for experiment loops, write hyperparameters to disk."""

    if not experiment_name:
        raise ValueError('experiment_name must be specified.')

    create_if_not_exists(root_output_dir)

    checkpoint_dir = os.path.join(root_output_dir, 'checkpoints',
                                  experiment_name)
    create_if_not_exists(checkpoint_dir)
    checkpoint_mngr = checkpoint_manager.FileCheckpointManager(checkpoint_dir)

    results_dir = os.path.join(root_output_dir, 'results', experiment_name)
    create_if_not_exists(results_dir)
    csv_file = os.path.join(results_dir, 'experiment.metrics.csv')
    metrics_mngr = csv_manager.CSVMetricsManager(csv_file)

    summary_logdir = os.path.join(root_output_dir, 'logdir', experiment_name)
    tb_mngr = tensorboard_manager.TensorBoardManager(
        summary_dir=summary_logdir)

    if hparam_dict:
        hparam_dict['metrics_file'] = metrics_mngr.metrics_filename
        hparams_file = os.path.join(results_dir, 'hparams.csv')
        utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)
        tb_mngr.update_hparams(
            {k: v
             for k, v in hparam_dict.items() if v is not None})

    logging.info('Writing...')
    logging.info('    checkpoints to: %s', checkpoint_dir)
    logging.info('    metrics csv to: %s', metrics_mngr.metrics_filename)
    logging.info('    summaries to: %s', summary_logdir)

    @contextlib.contextmanager
    def profiler(round_num):
        if (rounds_per_profile > 0 and round_num % rounds_per_profile == 0):
            with tf.profiler.experimental.Profile(summary_logdir):
                yield
        else:
            yield

    return checkpoint_mngr, metrics_mngr, tb_mngr, profiler
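A sketch of how the returned profiler factory might be used in the driver loop. Apart from the four values returned by _setup_outputs, every name here is a hypothetical stand-in.

def _example_profiled_loop(iterative_process, client_datasets_fn, total_rounds):
    """Hypothetical loop tracing every 10th round via the profiler factory."""
    checkpoint_mngr, metrics_mngr, tb_mngr, profiler = _setup_outputs(
        '/tmp/exp', 'example_run', hparam_dict={}, rounds_per_profile=10)

    state = iterative_process.initialize()
    for round_num in range(1, total_rounds + 1):
        # Rounds 10, 20, ... emit a TensorFlow profiler trace; other rounds
        # fall through the else branch and run untraced.
        with profiler(round_num):
            state, metrics = iterative_process.next(
                state, client_datasets_fn(round_num))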
Example 13
def _setup_outputs(root_output_dir,
                   experiment_name,
                   hparam_dict,
                   write_metrics_with_bz2=True,
                   rounds_per_profile=0):
    """Set up directories for experiment loops, write hyperparameters to disk."""

    if not experiment_name:
        raise ValueError('experiment_name must be specified.')

    create_if_not_exists(root_output_dir)

    checkpoint_dir = os.path.join(root_output_dir, 'checkpoints',
                                  experiment_name)
    create_if_not_exists(checkpoint_dir)
    checkpoint_mngr = checkpoint_manager.FileCheckpointManager(checkpoint_dir)

    results_dir = os.path.join(root_output_dir, 'results', experiment_name)
    create_if_not_exists(results_dir)
    metrics_mngr = metrics_manager.ScalarMetricsManager(
        results_dir, use_bz2=write_metrics_with_bz2)

    summary_logdir = os.path.join(root_output_dir, 'logdir', experiment_name)
    create_if_not_exists(summary_logdir)
    summary_writer = tf.summary.create_file_writer(summary_logdir)

    if hparam_dict:
        hparam_dict['metrics_file'] = metrics_mngr.metrics_filename
        hparams_file = os.path.join(results_dir, 'hparams.csv')
        utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

    logging.info('Writing...')
    logging.info('    checkpoints to: %s', checkpoint_dir)
    logging.info('    metrics csv to: %s', metrics_mngr.metrics_filename)
    logging.info('    summaries to: %s', summary_logdir)

    @contextlib.contextmanager
    def profiler(round_num):
        if (rounds_per_profile > 0 and round_num % rounds_per_profile == 0):
            with tf.profiler.experimental.Profile(summary_logdir):
                yield
        else:
            yield

    return checkpoint_mngr, metrics_mngr, summary_writer, profiler
    def test_checkpoint_manager_saves_state(self):
        experiment_name = 'checkpoint_manager_saves_state'
        iterative_process = _build_federated_averaging_process()
        federated_data = [[_batch_fn()]]

        def client_datasets_fn(round_num):
            del round_num
            return federated_data

        def validation_fn(model):
            del model
            return {}

        root_output_dir = self.get_temp_dir()
        final_state = training_loop.run(iterative_process=iterative_process,
                                        client_datasets_fn=client_datasets_fn,
                                        validation_fn=validation_fn,
                                        total_rounds=1,
                                        experiment_name=experiment_name,
                                        root_output_dir=root_output_dir)
        final_model = iterative_process.get_model_weights(final_state)

        ckpt_manager = checkpoint_manager.FileCheckpointManager(
            os.path.join(root_output_dir, 'checkpoints', experiment_name))
        restored_state, restored_round = ckpt_manager.load_latest_checkpoint(
            final_state)

        self.assertEqual(restored_round, 0)

        keras_model = tff.simulation.models.mnist.create_keras_model(
            compile_model=True)
        restored_model = iterative_process.get_model_weights(restored_state)

        restored_model.assign_weights_to(keras_model)
        restored_loss = keras_model.test_on_batch(federated_data[0][0].x,
                                                  federated_data[0][0].y)

        final_model.assign_weights_to(keras_model)
        final_loss = keras_model.test_on_batch(federated_data[0][0].x,
                                               federated_data[0][0].y)
        self.assertEqual(final_loss, restored_loss)
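This test leans on _build_federated_averaging_process and _batch_fn helpers that are not shown. A plausible reconstruction, assuming the MNIST Keras model from tff.simulation.models.mnist and the tff.learning APIs of the same vintage; the batch shapes and optimizer settings are illustrative.

import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

Batch = collections.namedtuple('Batch', ['x', 'y'])

def _batch_fn():
    # A single synthetic batch shaped like flattened 28x28 MNIST images.
    return Batch(x=np.ones([1, 784], dtype=np.float32),
                 y=np.ones([1, 1], dtype=np.int64))

def _build_federated_averaging_process():
    def model_fn():
        keras_model = tff.simulation.models.mnist.create_keras_model(
            compile_model=False)
        return tff.learning.from_keras_model(
            keras_model,
            input_spec=Batch(x=tf.TensorSpec([None, 784], tf.float32),
                             y=tf.TensorSpec([None, 1], tf.int64)),
            loss=tf.keras.losses.SparseCategoricalCrossentropy())
    return tff.learning.build_federated_averaging_process(
        model_fn,
        client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1))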
Example 15
    def test_checkpoint_manager_saves_state(self):
        experiment_name = 'checkpoint_manager_saves_state'
        iterative_process = _build_federated_averaging_process()

        def client_datasets_fn(round_num):
            del round_num
            return _federated_data()

        def evaluation_fn(model, round_num):
            del model, round_num
            return {}

        root_output_dir = self.get_temp_dir()
        final_state = training_loop.run(
            iterative_process=iterative_process,
            train_client_datasets_fn=client_datasets_fn,
            evaluation_fn=evaluation_fn,
            total_rounds=1,
            experiment_name=experiment_name,
            root_output_dir=root_output_dir)
        final_model = iterative_process.get_model_weights(final_state)

        ckpt_manager = checkpoint_manager.FileCheckpointManager(
            os.path.join(root_output_dir, 'checkpoints', experiment_name))
        restored_state, restored_round = ckpt_manager.load_latest_checkpoint(
            final_state)

        self.assertEqual(restored_round, 0)

        keras_model = _compiled_keras_model_builder()
        restored_model = iterative_process.get_model_weights(restored_state)

        restored_model.assign_weights_to(keras_model)
        batch = next(iter(_create_tf_dataset_for_client(5)))
        restored_loss = keras_model.test_on_batch(batch['x'], batch['y'])

        final_model.assign_weights_to(keras_model)
        final_loss = keras_model.test_on_batch(batch['x'], batch['y'])
        self.assertEqual(final_loss, restored_loss)
Example 16
def _federated_averaging_training_loop(model_fn,
                                       client_optimizer_fn,
                                       server_optimizer_fn,
                                       client_datasets_fn,
                                       evaluate_fn,
                                       total_rounds=500,
                                       rounds_per_eval=1,
                                       metrics_hook=None):
  """A simple example of training loop for the Federated Averaging algorithm.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`.
    client_datasets_fn: A function that takes the round number and returns a
      list of `tf.data.Dataset` objects, one per client.
    evaluate_fn: A function that takes state, performs evaluation, and returns
      evaluation metrics.
    total_rounds: Number of rounds to train.
    rounds_per_eval: How often to call the `metrics_hook` function.
    metrics_hook: A function taking arguments (training metrics, evaluation
      metrics, and round number). Optional.

  Returns:
    Final `ServerState`.
  """
  logging.info('Starting federated training loop')

  checkpoint_dir = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name)
  checkpoint_manager_obj = checkpoint_manager.FileCheckpointManager(
      checkpoint_dir)

  if FLAGS.server_optimizer != 'flars':
    raise ValueError(
        'Unsupported server_optimizer: {}'.format(FLAGS.server_optimizer))
  iterative_process = flars_fedavg.build_federated_averaging_process(
      model_fn,
      client_optimizer_fn=client_optimizer_fn,
      server_optimizer_fn=server_optimizer_fn)

  # Construct an initial state to act as a checkpoint template.
  initial_state = iterative_process.initialize()

  logging.info('Looking for checkpoints in \'%s\'', checkpoint_dir)
  state, round_num = checkpoint_manager_obj.load_latest_checkpoint(
      initial_state)

  if state is None:
    logging.info('No previous checkpoints, initializing experiment')
    state = initial_state
    round_num = 0
    if metrics_hook is not None:
      eval_metrics = evaluate_fn(state)
      metrics_hook({}, eval_metrics, round_num)
    checkpoint_manager_obj.save_checkpoint(state, 0)
  else:
    logging.info('Restarted from checkpoint round %d', round_num)

  while round_num < total_rounds:
    round_num += 1
    train_metrics = {}

    # Reset the executor to clear the cache, and clear the default graph to
    # garbage collect tf.Functions that will no longer be used.
    tff.backends.native.set_local_execution_context(max_fanout=25)
    tf.compat.v1.reset_default_graph()

    round_start_time = time.time()
    data_prep_start_time = time.time()
    train_data = client_datasets_fn(round_num)
    train_metrics['prepare_datasets_secs'] = time.time() - data_prep_start_time

    training_start_time = time.time()
    state, tff_train_metrics = iterative_process.next(state, train_data)
    tff_train_metrics = tff_train_metrics._asdict(recursive=True)
    train_metrics.update(tff_train_metrics)
    train_metrics['training_secs'] = time.time() - training_start_time

    logging.info('Round {:2d} elapsed time: {:.2f}s.'.format(
        round_num, (time.time() - round_start_time)))
    train_metrics['total_round_secs'] = time.time() - round_start_time

    if (round_num % FLAGS.rounds_per_checkpoint == 0 or
        round_num == total_rounds):
      save_checkpoint_start_time = time.time()
      checkpoint_manager_obj.save_checkpoint(state, round_num)
      train_metrics['save_checkpoint_secs'] = (
          time.time() - save_checkpoint_start_time)

    if round_num % rounds_per_eval == 0 or round_num == total_rounds:
      if metrics_hook is not None:
        eval_metrics = evaluate_fn(state)
        metrics_hook(train_metrics, eval_metrics, round_num)
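The loop above leaves metrics_hook to the caller. A minimal sketch of one, assuming only the (training metrics, evaluation metrics, round number) calling convention from the docstring; the function name is hypothetical.

def _logging_metrics_hook(train_metrics, eval_metrics, round_num):
  """A trivial metrics_hook: log both metric dicts for the round."""
  logging.info('Round %d train metrics: %s', round_num, train_metrics)
  logging.info('Round %d eval metrics: %s', round_num, eval_metrics)

# Passed to the loop as:
#   _federated_averaging_training_loop(..., metrics_hook=_logging_metrics_hook)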