Example #1
  def test_log_no_root_dir(self):
    logger = logging.Logger()

    logger.log(
        writer_name='train', metric_name='loss', metric_value=4., round_num=0)

    self.assertEmpty(logger._summary_writers)
Example #2
  def test_log_root_dir(self):
    root_dir = self.create_tempdir()
    logger = logging.Logger(root_dir)

    logger.log(
        writer_name='train', metric_name='loss', metric_value=4.1, round_num=0)
    logger.log(
        writer_name='eval', metric_name='loss', metric_value=5.3, round_num=0)

    self.assertCountEqual(['train', 'eval'], os.listdir(root_dir))
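
Note: taken together, Examples #1 and #2 suggest the usage below. Without a root directory the Logger keeps no summary writers; with one, it lazily creates a subdirectory per writer name ('train', 'eval', ...). This is a hypothetical sketch: the import path follows the `fedjax_logging` alias used in Example #3 and is assumed, and the path and metric values are made up.

from fedjax.training import logging as fedjax_logging

logger = fedjax_logging.Logger('/tmp/my_experiment')

for round_num in range(1, 4):
    # Each distinct writer_name gets its own subdirectory under root_dir,
    # which is why Example #2 expects ['train', 'eval'] in os.listdir(root_dir).
    logger.log(
        writer_name='train', metric_name='loss',
        metric_value=1.0 / round_num, round_num=round_num)
    logger.log(
        writer_name='eval', metric_name='loss',
        metric_value=1.5 / round_num, round_num=round_num)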
Example #3
def run_federated_experiment(
    algorithm: federated_algorithm.FederatedAlgorithm,
    init_state: federated_algorithm.ServerState,
    client_sampler: client_samplers.ClientSampler,
    config: FederatedExperimentConfig,
    periodic_eval_fn_map: Optional[Mapping[str, Any]] = None,
    final_eval_fn_map: Optional[Mapping[str, EvaluationFn]] = None
) -> federated_algorithm.ServerState:
    """Runs the training loop of a federated algorithm experiment.

  Args:
    algorithm: Federated algorithm to use.
    init_state: Initial server state.
    client_sampler: Sampler for training clients.
    config: FederatedExperimentConfig configurations.
    periodic_eval_fn_map: Mapping of name to evaluation functions that are run
      repeatedly over multiple federated training rounds. The frequency is
      defined in `_FederatedExperimentConfig.eval_frequency`.
    final_eval_fn_map: Mapping of name to evaluation functions that are run at
      the very end of federated training. Typically, full test evaluation
      functions will be set here.

  Returns:
    Final state of the input federated algortihm after training.
  """
    if config.root_dir:
        tf.io.gfile.makedirs(config.root_dir)

    if periodic_eval_fn_map is None:
        periodic_eval_fn_map = {}
    if final_eval_fn_map is None:
        final_eval_fn_map = {}

    logger = fedjax_logging.Logger(config.root_dir)

    latest = checkpoint.load_latest_checkpoint(config.root_dir)
    if latest:
        state, last_round_num = latest
        start_round_num = last_round_num + 1
    else:
        state = init_state
        start_round_num = 1
    client_sampler.set_round_num(start_round_num)

    start = time.time()
    for round_num in range(start_round_num, config.num_rounds + 1):
        # Randomly sample clients for this round.
        clients = client_sampler.sample()
        client_ids = [i[0] for i in clients]
        logging.info('round_num %d: client_ids = %s', round_num, client_ids)

        # Run one round of the algorithm, where the bulk of the work happens.
        state, _ = algorithm.apply(state, clients)

        # Save checkpoint.
        should_save_checkpoint = config.checkpoint_frequency and (
            round_num == start_round_num
            or round_num % config.checkpoint_frequency == 0)
        if should_save_checkpoint:
            checkpoint.save_checkpoint(config.root_dir, state, round_num,
                                       config.num_checkpoints_to_keep)

        # Run evaluation.
        should_run_eval = config.eval_frequency and (
            round_num == start_round_num
            or round_num % config.eval_frequency == 0)
        if should_run_eval:
            start_periodic_eval = time.time()
            for eval_name, eval_fn in periodic_eval_fn_map.items():
                if isinstance(eval_fn, EvaluationFn):
                    metrics = eval_fn(state, round_num)
                elif isinstance(eval_fn, TrainClientsEvaluationFn):
                    metrics = eval_fn(state, round_num, clients)
                else:
                    raise ValueError(f'Invalid eval_fn type {type(eval_fn)}')
                if metrics:
                    for metric_name, metric_value in metrics.items():
                        logger.log(eval_name, metric_name, metric_value,
                                   round_num)
            logger.log('.', 'periodic_eval_duration_sec',
                       time.time() - start_periodic_eval, round_num)

        # Log the time it takes per round. This is a rough approximation since
        # we're not calling DeviceArray.block_until_ready().
        logger.log('.', 'mean_round_duration_sec',
                   (time.time() - start) / (round_num + 1 - start_round_num),
                   round_num)

    # Block until previous work has finished.
    jnp.zeros([]).block_until_ready()

    # Log the overall time training took.
    num_rounds = config.num_rounds - start_round_num + 1
    mean_round_duration = ((time.time() - start) /
                           num_rounds if num_rounds > 0 else 0)

    # Final evaluation.
    final_eval_start = time.time()
    for eval_name, eval_fn in final_eval_fn_map.items():
        metrics = eval_fn(state, round_num)
        if metrics:
            metrics_path = os.path.join(config.root_dir, f'{eval_name}.tsv')
            with tf.io.gfile.GFile(metrics_path, 'w') as f:
                f.write('\t'.join(metrics.keys()) + '\n')
                f.write('\t'.join([str(v) for v in metrics.values()]))
    # DeviceArray.block_until_ready() isn't needed here since we write to file.
    final_eval_duration = time.time() - final_eval_start
    logging.info('mean_round_duration = %f sec.', mean_round_duration)
    logging.info('final_eval_duration = %f sec.', final_eval_duration)
    return state
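
Note: both versions gate periodic work with the same cadence check: run at the first round after a (possible) restart and then every `checkpoint_frequency` / `eval_frequency` rounds. The standalone sketch below re-implements only that predicate to make the schedule explicit; `should_run` is a hypothetical helper name, not part of the code above.

def should_run(round_num, start_round_num, frequency):
    # Mirrors should_save_checkpoint / should_run_eval above: a falsy
    # frequency disables the work entirely.
    return bool(frequency) and (
        round_num == start_round_num or round_num % frequency == 0)

# With start_round_num=1 and frequency=5, the work runs at rounds 1, 5, 10, 15.
assert [r for r in range(1, 16) if should_run(r, 1, 5)] == [1, 5, 10, 15]
# When resuming at round 7, it also runs immediately at round 7.
assert [r for r in range(7, 16) if should_run(r, 7, 5)] == [7, 10, 15]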
Example #4
def run_federated_experiment(
    config: FederatedExperimentConfig,
    federated_algorithm: core.FederatedAlgorithm[T],
    periodic_eval_fn_map: Optional[Mapping[str, Callable[
        [T, int], core.MetricResults]]] = None,
    final_eval_fn_map: Optional[Mapping[str,
                                        Callable[[T, int],
                                                 core.MetricResults]]] = None
) -> T:
    """Runs federated algorithm experiment and auxiliary processes.

  Args:
    config: FederatedExperimentConfig configurations.
    federated_algorithm: FederatedAlgorithm to be run over multiple rounds.
    periodic_eval_fn_map: Mapping of name to evaluation functions that are run
      repeatedly over multiple federated training rounds. The frequency is
      defined in `_FederatedExperimentConfig.eval_frequency`.
    final_eval_fn_map: Mapping of name to evaluation functions that are run at
      the very end of federated training. Typically, full test evaluation
      functions will be set here.

  Returns:
    Final state of the input federated algortihm after training.
  """
    if config.root_dir:
        tf.io.gfile.makedirs(config.root_dir)

    if periodic_eval_fn_map is None:
        periodic_eval_fn_map = {}
    if final_eval_fn_map is None:
        final_eval_fn_map = {}

    logger = fedjax_logging.Logger(config.root_dir)

    latest = checkpoint.load_latest_checkpoint(config.root_dir)
    if latest:
        state, start_round_num = latest
    else:
        state = federated_algorithm.init_state()
        start_round_num = 1

    start = time.time()
    for round_num in range(start_round_num, config.num_rounds + 1):
        # Get a random state and randomly sample clients.
        random_state = get_pseudo_random_state(
            round_num, config.sample_client_random_seed)
        client_ids = list(
            random_state.choice(federated_algorithm.federated_data.client_ids,
                                size=config.num_clients_per_round,
                                replace=False))
        logging.info('round_num %d: client_ids = %s', round_num, client_ids)

        # Run one round of the algorithm, where the bulk of the work happens.
        state = federated_algorithm.run_round(state, client_ids)

        # Save checkpoint.
        should_save_checkpoint = config.checkpoint_frequency and (
            round_num == start_round_num
            or round_num % config.checkpoint_frequency == 0)
        if should_save_checkpoint:
            checkpoint.save_checkpoint(config.root_dir, state, round_num,
                                       config.num_checkpoints_to_keep)

        # Run evaluation.
        should_run_eval = config.eval_frequency and (
            round_num == start_round_num
            or round_num % config.eval_frequency == 0)
        if should_run_eval:
            start_periodic_eval = time.time()
            for eval_name, eval_fn in periodic_eval_fn_map.items():
                metrics = eval_fn(state, round_num)
                if metrics:
                    for metric_name, metric_value in metrics.items():
                        logger.log(eval_name, metric_name, metric_value,
                                   round_num)
            logger.log('.', 'periodic_eval_duration_sec',
                       time.time() - start_periodic_eval, round_num)

        # Log the time it takes per round. This is a rough approximation since
        # we're not calling DeviceArray.block_until_ready().
        logger.log('.', 'mean_round_duration_sec',
                   (time.time() - start) / (round_num + 1 - start_round_num),
                   round_num)

    # Log the overall time training took.
    # DeviceArray.block_until_ready() is needed for accurate timing due to
    # https://jax.readthedocs.io/en/latest/async_dispatch.html.
    _block_until_ready_state(state)
    num_rounds = config.num_rounds - start_round_num
    mean_round_duration = ((time.time() - start) /
                           num_rounds if num_rounds > 0 else 0)

    # Final evaluation.
    final_eval_start = time.time()
    for eval_name, eval_fn in final_eval_fn_map.items():
        metrics = eval_fn(state, round_num)
        if metrics:
            metrics_path = os.path.join(config.root_dir, f'{eval_name}.tsv')
            with tf.io.gfile.GFile(metrics_path, 'w') as f:
                f.write('\t'.join(metrics.keys()) + '\n')
                f.write('\t'.join([str(v) for v in metrics.values()]))
    # DeviceArray.block_until_ready() isn't needed here since we write to file.
    final_eval_duration = time.time() - final_eval_start
    logging.info('mean_round_duration = %f sec.', mean_round_duration)
    logging.info('final_eval_duration = %f sec.', final_eval_duration)
    return state
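
Note: `get_pseudo_random_state` is not shown here. A plausible reading, consistent with how it is called, is that it builds a NumPy RandomState from the round number and an optional seed so that client sampling is reproducible across restarts. The sketch below is an assumption about that behavior, not the library's implementation; the client pool, seeding scheme, and `num_clients_per_round` value are illustrative.

import numpy as np


def get_pseudo_random_state(round_num, seed=None):
    # Assumed behavior: a fresh RandomState derived from (round_num, seed),
    # so resuming at any round re-samples the same clients for that round.
    base = 0 if seed is None else seed
    return np.random.RandomState(base + round_num)


client_ids = [f'client_{i}' for i in range(100)]  # Illustrative client pool.
num_clients_per_round = 5

round_7_a = get_pseudo_random_state(7, seed=3).choice(
    client_ids, size=num_clients_per_round, replace=False)
round_7_b = get_pseudo_random_state(7, seed=3).choice(
    client_ids, size=num_clients_per_round, replace=False)
# Per-round determinism is what lets checkpoint resumption in Example #4
# follow the same sampling schedule as an uninterrupted run.
assert list(round_7_a) == list(round_7_b)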