def run(iterative_process: tff.templates.IterativeProcess,
        client_datasets_fn: Callable[..., Any],
        validation_fn: Callable[[Any], Dict[str, float]],
        total_rounds: int,
        experiment_name: str,
        train_eval_fn: Optional[Callable[[Any], Dict[str, float]]] = None,
        test_fn: Optional[Callable[[Any], Dict[str, float]]] = None,
        root_output_dir: Optional[str] = '/tmp/fed_opt',
        hparam_dict: Optional[Dict[str, Any]] = None,
        write_metrics_with_bz2: Optional[bool] = True,
        rounds_per_eval: Optional[int] = 1,
        rounds_per_checkpoint: Optional[int] = 50,
        rounds_per_train_eval: Optional[int] = 100,
        rounds_per_profile: Optional[int] = 0,
        clients_per_round: Optional[float] = 1.0):
  """Runs federated training for a given `tff.templates.IterativeProcess`.

  We assume that the iterative process has the following functional type
  signatures:

    *   `initialize`: `( -> S@SERVER)` where `S` represents the server state.
    *   `next`: called by this loop as
        `next(state, federated_train_data, client_weights)`, where
        `client_weights` is a python list built from the weights returned by
        `client_datasets_fn`. It returns the new server state and a python
        `Mapping` of round metrics.

  Moreover, the server state must have an attribute `model` that can be passed
  to `validation_fn`, `train_eval_fn`, and `test_fn` (if given).

  The `r_vec` returned by `client_datasets_fn` in each round is saved to
  `<root_output_dir>/results/<experiment_name>/r_vec/r_vec<round>.npy`. When
  resuming from a checkpoint, the file written for the previous round is
  reloaded and passed back into `client_datasets_fn`.

  Args:
    iterative_process: A `tff.templates.IterativeProcess` instance to run.
    client_datasets_fn: Called as `client_datasets_fn(0)` for round 0 and as
      `client_datasets_fn(round_num, r_vec)` for every later round. Must
      return a 5-tuple `(federated_train_data, federated_weights, r_vec,
      idx_ids, avail)`:
        *   `federated_train_data`: the client datasets for the round.
        *   `federated_weights`: the client weights; must support
            `.numpy().tolist()`.
        *   `r_vec`: the updated r_vec; must support `.numpy()`.
        *   `idx_ids`: selected client ids; only `idx_ids[0]` is logged.
        *   `avail`: an iterable whose `sum` is logged as the number of
            available clients.
    validation_fn: A callable accepting the `model` attribute of the iterative
      process state and returning a dict of evaluation metrics. Used to compute
      validation metrics throughout the training process.
    total_rounds: The number of federated training rounds to perform.
    experiment_name: The name of the experiment being run. This will be appended
      to the `root_output_dir` for purposes of writing outputs.
    train_eval_fn: An optional callable accepting the `model` attribute of the
      iterative process state and returning a dict of evaluation metrics. Used
      to compute training metrics over the entire training dataset throughout
      the course of the iterative process. If set to `None`, no such evaluation
      is done.
    test_fn: An optional callable accepting the `model` attribute of the
      iterative process state and returning a dict of test metrics. Called at
      the end of training. Also called during training every
      `rounds_per_train_eval` rounds.
    root_output_dir: The name of the root output directory for writing
      experiment outputs.
    hparam_dict: An optional dictionary specifying hyperparameters of the
      experiment. If provided, the hyperparameters will be written to CSV.
    write_metrics_with_bz2: Whether to use `bz2` compression when writing
      metrics to CSV.
    rounds_per_eval: How often to compute validation metrics.
    rounds_per_checkpoint: How often to checkpoint the iterative process state.
      If you expect the job to restart frequently, this should be small. If no
      interruptions are expected, this can be made larger.
    rounds_per_train_eval: How often to compute metrics over the entire training
      dataset (only if `train_eval_fn` is supplied). Also controls how often
      test metrics are computed during training (only if `test_fn` is
      supplied).
    rounds_per_profile: Experimental setting. If set to a value greater than 0,
      this dictates how often a TensorFlow profiler is run.
    clients_per_round: Currently unused by this function.

  Returns:
    The final `state` of the iterative process after training.

  Raises:
    TypeError: If `iterative_process` has the wrong type, a supplied function
      is not callable, or the initial server state has no `model` attribute.
  """
  if not isinstance(iterative_process, tff.templates.IterativeProcess):
    raise TypeError('iterative_process should be type '
                    '`tff.templates.IterativeProcess`.')
  if not callable(client_datasets_fn):
    raise TypeError('client_datasets_fn should be callable.')
  if not callable(validation_fn):
    raise TypeError('validation_fn should be callable.')
  if train_eval_fn is not None and not callable(train_eval_fn):
    raise TypeError('train_eval_fn should be callable.')
  if test_fn is not None and not callable(test_fn):
    raise TypeError('test_fn should be callable.')

  logging.info('Starting iterative_process training loop...')
  initial_state = iterative_process.initialize()
  logging.info(' Initialized!  keys: \n')
  # Not every state container has a `__dict__` (e.g. slotted or tuple-based
  # structures). Skip the key dump rather than raising AttributeError here.
  for k in getattr(initial_state, '__dict__', {}):
    logging.info(f'             {k}')

  if not hasattr(initial_state, 'model'):
    raise TypeError('The server state must have a model attribute.')

  checkpoint_mngr, metrics_mngr, summary_writer, profiler = _setup_outputs(
      root_output_dir, experiment_name, hparam_dict, write_metrics_with_bz2,
      rounds_per_profile)

  results_r_vec_dir = os.path.join(root_output_dir, 'results', experiment_name,
                                   'r_vec')
  create_if_not_exists(results_r_vec_dir)

  logging.info('Asking checkpoint manager to load checkpoint.')
  state, round_num = checkpoint_mngr.load_latest_checkpoint(initial_state)

  if state is None:
    logging.info('Initializing experiment from scratch.')
    state = initial_state
    round_num = 0
    metrics_mngr.clear_all_rounds()
  else:
    logging.info('Restarted from checkpoint round %d', round_num)
    round_num += 1  # Increment to avoid overwriting current checkpoint
    metrics_mngr.clear_rounds_after(last_valid_round_num=round_num - 1)

  loop_start_time = time.time()
  # Record the first round run by this process, so the per-round average
  # below stays correct after resuming from a checkpoint.
  loop_start_round = round_num
  r_vec = None
  while round_num < total_rounds:
    data_prep_start_time = time.time()
    if round_num == 0:
      (federated_train_data, federated_weights, r_vec, idx_ids,
       avail) = client_datasets_fn(round_num)
    else:
      if r_vec is None:
        # Resumed from a checkpoint: reload the r_vec that the previous round
        # saved to disk (see the np.save call below).
        r_vec_filename = os.path.join(results_r_vec_dir,
                                      f'r_vec{round_num - 1}.npy')
        r_vec_numpy = np.load(r_vec_filename, allow_pickle=True)
        r_vec = tf.Variable(r_vec_numpy, dtype=tf.float32)
      # NOTE(review): if the training step below fails and this round is
      # retried, this call receives the r_vec returned by the failed attempt,
      # not the one the round started from. Confirm this is intended.
      (federated_train_data, federated_weights, r_vec, idx_ids,
       avail) = client_datasets_fn(round_num, r_vec)
    train_metrics = {
        'prepare_datasets_secs': time.time() - data_prep_start_time
    }
    np.save(os.path.join(results_r_vec_dir, f'r_vec{round_num}.npy'),
            r_vec.numpy())
    training_start_time = time.time()
    prev_model = state.model
    # TODO(b/145604851): This try/except is used to circumvent ambiguous TF
    # errors during training, and should be removed once the root cause is
    # determined (and possibly fixed).
    try:
      with profiler(round_num):
        state, round_metrics = iterative_process.next(
            state, federated_train_data, federated_weights.numpy().tolist())
    except (tf.errors.FailedPreconditionError, tf.errors.NotFoundError,
            tf.errors.InternalError) as e:
      logging.warning('Caught %s exception while running round %d:\n\t%s',
                      type(e), round_num, e)
      continue  # restart the loop without incrementing the round number

    train_metrics['training_secs'] = time.time() - training_start_time
    train_metrics['model_delta_l2_norm'] = _compute_numpy_l2_difference(
        state.model, prev_model)
    if hasattr(state, 'num_participants'):
      train_metrics['num_participants'] = state.num_participants
    if hasattr(state, 'threshold'):
      train_metrics['threshold'] = state.threshold

    train_metrics.update(round_metrics)

    logging.info(f'Number of available clients: {sum(avail)} -  participant 1:  {idx_ids[0]}')

    # Average over the rounds run by *this* process, not since round 0.
    rounds_completed = round_num - loop_start_round + 1
    logging.info('Round {:2d}, {:.2f}s per round in average.'.format(
        round_num, (time.time() - loop_start_time) / rounds_completed))

    if (round_num % rounds_per_checkpoint == 0 or
        round_num == total_rounds - 1):
      save_checkpoint_start_time = time.time()
      checkpoint_mngr.save_checkpoint(state, round_num)
      train_metrics['save_checkpoint_secs'] = (
          time.time() - save_checkpoint_start_time)

    metrics = {'train': train_metrics}

    if round_num % rounds_per_eval == 0:
      # Compute validation metrics
      evaluate_start_time = time.time()
      validation_metrics = validation_fn(state.model)
      validation_metrics['evaluate_secs'] = time.time() - evaluate_start_time
      metrics['eval'] = validation_metrics

    if train_eval_fn and round_num % rounds_per_train_eval == 0:
      # Compute metrics over the entire training dataset
      train_eval_start = time.time()
      train_eval_metrics = train_eval_fn(state.model)
      train_eval_metrics['evaluate_secs'] = time.time() - train_eval_start
      metrics['train_eval'] = train_eval_metrics
    # Test metrics share the train-eval cadence while training is running.
    if test_fn and round_num % rounds_per_train_eval == 0:
      test_start_time = time.time()
      test_metrics = test_fn(state.model)
      test_metrics['evaluate_secs'] = time.time() - test_start_time
      metrics['test'] = test_metrics

    _write_metrics(metrics_mngr, summary_writer, metrics, round_num)
    round_num += 1

  # Final metrics evaluation once the training has completed
  metrics = {}

  # Validation metrics
  evaluate_start_time = time.time()
  validation_metrics = validation_fn(state.model)
  validation_metrics['evaluate_secs'] = time.time() - evaluate_start_time
  metrics['eval'] = validation_metrics

  # Training set metrics
  if train_eval_fn:
    train_eval_start = time.time()
    train_eval_metrics = train_eval_fn(state.model)
    train_eval_metrics['evaluate_secs'] = time.time() - train_eval_start
    metrics['train_eval'] = train_eval_metrics

  # Test set metrics
  if test_fn:
    test_start_time = time.time()
    test_metrics = test_fn(state.model)
    test_metrics['evaluate_secs'] = time.time() - test_start_time
    metrics['test'] = test_metrics
  _write_metrics(metrics_mngr, summary_writer, metrics, total_rounds)

  return state
예제 #2
0
def run(iterative_process: tff.templates.IterativeProcess,
        client_datasets_fn: Callable[[int], List[tf.data.Dataset]],
        validation_fn: Callable[[Any, int], Dict[str, float]],
        total_rounds: int,
        experiment_name: str,
        test_fn: Optional[Callable[[Any], Dict[str, float]]] = None,
        root_output_dir: Optional[str] = '/tmp/fed_opt',
        rounds_per_eval: Optional[int] = 1,
        rounds_per_checkpoint: Optional[int] = 50,
        rounds_per_profile: Optional[int] = 0):
    """Drives federated training of a `tff.templates.IterativeProcess`.

    The iterative process is expected to provide:

      *   `initialize`: `( -> S@SERVER)`, producing the server state `S`.
      *   `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>`, where
          `{B*}` are the client datasets and `T` is a python `Mapping` of
          round metrics.
      *   `get_model_weights`: a callable mapping a server state to a
          `tff.learning.ModelWeights` object.

    Training resumes from the latest checkpoint when one is found. Every
    round records data-preparation time, training time and the L2 change of
    the model weights. Checkpoints are written every `rounds_per_checkpoint`
    rounds and on the final round. Validation runs every `rounds_per_eval`
    rounds. After the loop, one more validation pass (plus the optional test
    pass) is written under round `total_rounds`.

    Args:
      iterative_process: The `tff.templates.IterativeProcess` to run.
      client_datasets_fn: Maps a round number to the list of client datasets
        used as federated data for that round.
      validation_fn: Takes a `tff.learning.ModelWeights` and the current round
        number and returns a dict of evaluation metrics.
      total_rounds: Number of federated training rounds to perform.
      experiment_name: Experiment name, appended to `root_output_dir` when
        writing outputs.
      test_fn: Optional callable taking a `tff.learning.ModelWeights` and
        returning a dict of test metrics; only used after training finishes.
      root_output_dir: Root directory for experiment outputs.
      rounds_per_eval: Validation frequency, in rounds.
      rounds_per_checkpoint: Checkpoint frequency, in rounds. Use a small value
        when frequent restarts are expected.
      rounds_per_profile: Experimental. When greater than 0, how often a
        TensorFlow profiler is run.

    Returns:
      The server state after the last round.

    Raises:
      TypeError: If `client_datasets_fn` or `validation_fn` is not callable,
        or `test_fn` is given but not callable.
    """
    _check_iterative_process_compatibility(iterative_process)
    if not callable(client_datasets_fn):
        raise TypeError('client_datasets_fn should be callable.')
    if not callable(validation_fn):
        raise TypeError('validation_fn should be callable.')
    if not (test_fn is None or callable(test_fn)):
        raise TypeError('test_fn should be callable.')

    logging.info('Starting iterative_process training loop...')
    initial_state = iterative_process.initialize()

    ckpt_manager, metrics_manager, board_manager, profiler = _setup_outputs(
        root_output_dir, experiment_name, rounds_per_profile)

    logging.info('Asking checkpoint manager to load checkpoint.')
    state, last_round = ckpt_manager.load_latest_checkpoint(initial_state)

    if state is not None:
        logging.info('Restarted from checkpoint round %d', last_round)
        # Begin one past the restored round so its checkpoint is kept intact.
        round_idx = last_round + 1
    else:
        logging.info('Initializing experiment from scratch.')
        state, round_idx = initial_state, 0
    metrics_manager.clear_metrics(round_idx)

    model_weights = iterative_process.get_model_weights(state)

    retryable_errors = (tf.errors.FailedPreconditionError,
                        tf.errors.NotFoundError, tf.errors.InternalError)
    loop_started_at = time.time()
    first_round = round_idx
    while round_idx < total_rounds:
        prep_started_at = time.time()
        round_data = client_datasets_fn(round_idx)
        train_metrics = {
            'prepare_datasets_secs': time.time() - prep_started_at
        }

        train_started_at = time.time()
        weights_before = model_weights

        # TODO(b/145604851): This try/except is used to circumvent ambiguous TF
        # errors during training, and should be removed once the root cause is
        # determined (and possibly fixed).
        try:
            with profiler(round_idx):
                state, round_metrics = iterative_process.next(state, round_data)
        except retryable_errors as err:
            logging.warning(
                'Caught %s exception while running round %d:\n\t%s', type(err),
                round_idx, err)
            continue  # retry the same round; round_idx is not advanced

        model_weights = iterative_process.get_model_weights(state)
        train_metrics['training_secs'] = time.time() - train_started_at
        train_metrics['model_delta_l2_norm'] = _compute_numpy_l2_difference(
            model_weights, weights_before)
        train_metrics.update(round_metrics)

        # Average only over rounds executed by this process.
        rounds_done = round_idx - first_round + 1
        avg_round_secs = (time.time() - loop_started_at) / rounds_done
        logging.info('Round {:2d}, {:.2f}s per round in average.'.format(
            round_idx, avg_round_secs))

        is_final_round = round_idx == total_rounds - 1
        if round_idx % rounds_per_checkpoint == 0 or is_final_round:
            ckpt_started_at = time.time()
            ckpt_manager.save_checkpoint(state, round_idx)
            train_metrics['save_checkpoint_secs'] = (
                time.time() - ckpt_started_at)

        metrics = {'train': train_metrics}

        if round_idx % rounds_per_eval == 0:
            eval_started_at = time.time()
            eval_metrics = validation_fn(model_weights, round_idx)
            eval_metrics['evaluate_secs'] = time.time() - eval_started_at
            metrics['eval'] = eval_metrics

        _write_metrics(metrics_manager, board_manager, metrics, round_idx)
        round_idx += 1

    # One last evaluation pass once training has finished.
    eval_started_at = time.time()
    eval_metrics = validation_fn(model_weights, round_idx)
    eval_metrics['evaluate_secs'] = time.time() - eval_started_at
    final_metrics = {'eval': eval_metrics}

    if test_fn:
        test_started_at = time.time()
        test_metrics = test_fn(model_weights)
        test_metrics['evaluate_secs'] = time.time() - test_started_at
        final_metrics['test'] = test_metrics
    _write_metrics(metrics_manager, board_manager, final_metrics, total_rounds)

    return state
예제 #3
0
def run(iterative_process: tff.templates.IterativeProcess,
        client_datasets_fn: Callable[[int], List[tf.data.Dataset]],
        validation_fn: Callable[[Any], Dict[str, float]],
        train_eval_fn: Optional[Callable[[Any], Dict[str, float]]] = None,
        test_fn: Optional[Callable[[Any], Dict[str, float]]] = None):
    """Runs federated training for the given TFF `IterativeProcess` instance.

    Run settings are read from the module-level `FLAGS` object:
    `total_rounds`, `root_output_dir`, `experiment_name`,
    `rounds_per_checkpoint`, `rounds_per_eval` and `rounds_per_train_eval`.
    The flags named by `utils_impl.get_hparam_flags()` are collected into a
    hyperparameter dict that is passed to `_setup_outputs`.

    Args:
      iterative_process: A `tff.templates.IterativeProcess` instance to run.
        Its initial state must have a `model` attribute.
      client_datasets_fn: Function accepting an integer argument (the round
        number) and returning the client datasets to use as federated data
        for that round. The return value is passed unchanged to
        `iterative_process.next`.
      validation_fn: Callable accepting the `model` attribute of the server
        state and returning a dict of evaluation metrics. Called every
        `FLAGS.rounds_per_eval` rounds and once more after training.
      train_eval_fn: An optional callable accepting the `model` attribute of
        the server state and returning a dict of evaluation metrics over the
        entire training dataset. Called every `FLAGS.rounds_per_train_eval`
        rounds and once more after training.
      test_fn: An optional callable accepting the `model` attribute of the
        server state and returning a dict of test metrics. Called only once,
        after training.

    Returns:
      The server state after the final training round.

    Raises:
      TypeError: If `iterative_process` has the wrong type, a supplied
        function is not callable, or the server state has no `model`
        attribute.
    """
    if not isinstance(iterative_process, tff.templates.IterativeProcess):
        raise TypeError('iterative_process should be type '
                        '`tff.templates.IterativeProcess`.')
    if not callable(client_datasets_fn):
        raise TypeError('client_datasets_fn should be callable.')
    if not callable(validation_fn):
        raise TypeError('validation_fn should be callable.')
    if train_eval_fn is not None and not callable(train_eval_fn):
        raise TypeError('train_eval_fn should be callable.')
    if test_fn is not None and not callable(test_fn):
        raise TypeError('test_fn should be callable.')
    total_rounds = FLAGS.total_rounds

    logging.info('Starting iterative_process training loop...')
    initial_state = iterative_process.initialize()

    if not hasattr(initial_state, 'model'):
        raise TypeError('The server state must have a model attribute.')

    hparam_flags = utils_impl.get_hparam_flags()
    hparam_dict = collections.OrderedDict([(name, FLAGS[name].value)
                                           for name in hparam_flags])

    checkpoint_mngr, metrics_mngr, summary_writer, profiler = _setup_outputs(
        FLAGS.root_output_dir, FLAGS.experiment_name, hparam_dict)

    logging.info('Asking checkpoint manager to load checkpoint.')
    state, round_num = checkpoint_mngr.load_latest_checkpoint(initial_state)

    if state is None:
        logging.info('Initializing experiment from scratch.')
        state = initial_state
        round_num = 0
        metrics_mngr.clear_all_rounds()
    else:
        logging.info('Restarted from checkpoint round %d', round_num)
        round_num += 1  # Increment to avoid overwriting current checkpoint
        metrics_mngr.clear_rounds_after(last_valid_round_num=round_num - 1)

    loop_start_time = time.time()
    # Record the first round run by this process, so the per-round average
    # stays correct after resuming from a checkpoint.
    loop_start_round = round_num
    while round_num < total_rounds:
        data_prep_start_time = time.time()
        federated_train_data = client_datasets_fn(round_num)
        train_metrics = {
            'prepare_datasets_secs': time.time() - data_prep_start_time
        }

        training_start_time = time.time()
        prev_model = state.model
        # TODO(b/145604851): This try/except is used to circumvent ambiguous TF
        # errors during training, and should be removed once the root cause is
        # determined (and possibly fixed).
        try:
            with profiler(round_num):
                state, round_metrics = iterative_process.next(
                    state, federated_train_data)
        except (tf.errors.FailedPreconditionError, tf.errors.NotFoundError,
                tf.errors.InternalError) as e:
            logging.warning(
                'Caught %s exception while running round %d:\n\t%s', type(e),
                round_num, e)
            continue  # restart the loop without incrementing the round number

        train_metrics['training_secs'] = time.time() - training_start_time
        train_metrics['model_delta_l2_norm'] = _compute_numpy_l2_difference(
            state.model, prev_model)
        train_metrics.update(round_metrics)

        # Average over the rounds run by *this* process, not since round 0.
        rounds_completed = round_num - loop_start_round + 1
        logging.info('Round {:2d}, {:.2f}s per round in average.'.format(
            round_num, (time.time() - loop_start_time) / rounds_completed))

        if (round_num % FLAGS.rounds_per_checkpoint == 0
                or round_num == total_rounds - 1):
            save_checkpoint_start_time = time.time()
            checkpoint_mngr.save_checkpoint(state, round_num)
            train_metrics['save_checkpoint_secs'] = (
                time.time() - save_checkpoint_start_time)

        metrics = {'train': train_metrics}

        if round_num % FLAGS.rounds_per_eval == 0:
            # Compute validation metrics
            evaluate_start_time = time.time()
            validation_metrics = validation_fn(state.model)
            validation_metrics['evaluate_secs'] = time.time(
            ) - evaluate_start_time
            metrics['eval'] = validation_metrics

        if train_eval_fn and round_num % FLAGS.rounds_per_train_eval == 0:
            # Compute metrics over the entire training dataset
            train_eval_start = time.time()
            train_eval_metrics = train_eval_fn(state.model)
            train_eval_metrics['evaluate_secs'] = time.time(
            ) - train_eval_start
            metrics['train_eval'] = train_eval_metrics

        _write_metrics(metrics_mngr, summary_writer, metrics, round_num)
        round_num += 1

    # Final metrics evaluation once the training has completed
    metrics = {}

    # Validation metrics
    evaluate_start_time = time.time()
    validation_metrics = validation_fn(state.model)
    validation_metrics['evaluate_secs'] = time.time() - evaluate_start_time
    metrics['eval'] = validation_metrics

    # Training set metrics
    if train_eval_fn:
        train_eval_start = time.time()
        train_eval_metrics = train_eval_fn(state.model)
        train_eval_metrics['evaluate_secs'] = time.time() - train_eval_start
        metrics['train_eval'] = train_eval_metrics

    # Test set metrics
    if test_fn:
        test_start_time = time.time()
        test_metrics = test_fn(state.model)
        test_metrics['evaluate_secs'] = time.time() - test_start_time
        metrics['test'] = test_metrics
    _write_metrics(metrics_mngr, summary_writer, metrics, total_rounds)

    return state
예제 #4
0
def run(iterative_process: tff.templates.IterativeProcess,
        train_client_datasets_fn: Callable[[int], List[tf.data.Dataset]],
        evaluation_fn: Callable[[Any, int], Dict[str, float]],
        total_rounds: int,
        experiment_name: str,
        test_fn: Optional[Callable[[Any], Dict[str, float]]] = None,
        root_output_dir: Optional[str] = '/tmp/fedopt_guide',
        hparam_dict: Optional[Dict[str, Any]] = None,
        rounds_per_eval: Optional[int] = 10,
        rounds_per_checkpoint: Optional[int] = 50):
    """Runs federated training for a given `tff.templates.IterativeProcess`.

    We assume that the iterative process has the following functional type
    signatures:

      *   `initialize`: `( -> S@SERVER)` where `S` represents the server state.
      *   `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S`
          represents the server state, `{B*}` represents the client datasets,
          and `T` represents a python `Mapping` object.

    The iterative process must also have a callable attribute
    `get_model_weights`. It takes the state of the iterative process and
    returns a nested structure of tensors that can be passed to
    `evaluation_fn` and `test_fn` (if provided). The initial server state must
    additionally have a `model` attribute.

    Args:
      iterative_process: A `tff.templates.IterativeProcess` instance to run.
      train_client_datasets_fn: Function accepting an integer argument (the
        round number) and returning a list of train client datasets to use as
        federated data for that training round.
      evaluation_fn: A callable accepting the output of the `get_model_weights`
        attribute of the iterative process and a `round_num`, and returning a
        dictionary of evaluation metrics. Called every `rounds_per_eval`
        rounds and once more after training.
      total_rounds: The number of federated training rounds to perform.
      experiment_name: The name of the experiment being run. This will be
        appended to the `root_output_dir` for purposes of writing outputs.
      test_fn: An optional callable accepting the output of the
        `get_model_weights` attribute of the iterative process and returning a
        dictionary of test metrics. Used to compute test metrics at the end of
        the training process.
      root_output_dir: The name of the root output directory for writing
        experiment outputs.
      hparam_dict: An optional dictionary specifying hyperparameters of the
        experiment. If provided, the hyperparameters will be written to CSV.
      rounds_per_eval: How often to compute validation metrics.
      rounds_per_checkpoint: How often to checkpoint the iterative process
        state. If you expect the job to restart frequently, this should be
        small. If no interruptions are expected, this can be made larger.

    Returns:
      The final `state` of the iterative process after training.

    Raises:
      TypeError: If a supplied function is not callable, or the initial server
        state has no `model` attribute.
    """
    _check_iterative_process_compatibility(iterative_process)
    if not callable(train_client_datasets_fn):
        raise TypeError('train_client_datasets_fn should be callable.')
    if not callable(evaluation_fn):
        raise TypeError('evaluation_fn should be callable.')
    if test_fn is not None and not callable(test_fn):
        raise TypeError('test_fn should be callable.')

    logging.info('Starting iterative_process training loop...')
    initial_state = iterative_process.initialize()

    if not hasattr(initial_state, 'model'):
        raise TypeError('The server state must have a model attribute.')

    checkpoint_mngr, metrics_mngr, tensorboard_mngr = _setup_outputs(
        root_output_dir, experiment_name, hparam_dict)

    logging.info('Asking checkpoint manager to load checkpoint.')
    state, round_num = checkpoint_mngr.load_latest_checkpoint(initial_state)

    if state is None:
        logging.info('Initializing experiment from scratch.')
        state = initial_state
        round_num = 0
    else:
        logging.info('Restarted from checkpoint round %d', round_num)
        round_num += 1  # Increment to avoid overwriting current checkpoint
    metrics_mngr.clear_metrics(round_num)
    current_model = iterative_process.get_model_weights(state)

    loop_start_time = time.time()
    # Record the first round run by this process, so the per-round average
    # stays correct after resuming from a checkpoint.
    loop_start_round = round_num
    while round_num < total_rounds:
        data_prep_start_time = time.time()
        federated_train_data = train_client_datasets_fn(round_num)
        train_metrics = {
            'prepare_datasets_secs': time.time() - data_prep_start_time
        }

        training_start_time = time.time()
        # `current_model` already holds the weights of the current `state`, so
        # there is no need to call `get_model_weights` again here.
        prev_model = current_model
        state, round_metrics = iterative_process.next(state,
                                                      federated_train_data)
        current_model = iterative_process.get_model_weights(state)

        train_metrics['training_secs'] = time.time() - training_start_time
        train_metrics['model_delta_l2_norm'] = _compute_numpy_l2_difference(
            current_model, prev_model)
        train_metrics.update(round_metrics)

        # Average over the rounds run by *this* process, not since round 0.
        rounds_completed = round_num - loop_start_round + 1
        logging.info('Round {:2d}, {:.2f}s per round in average.'.format(
            round_num, (time.time() - loop_start_time) / rounds_completed))

        if (round_num % rounds_per_checkpoint == 0
                or round_num == total_rounds - 1):
            save_checkpoint_start_time = time.time()
            checkpoint_mngr.save_checkpoint(state, round_num)
            train_metrics['save_checkpoint_secs'] = (
                time.time() - save_checkpoint_start_time)

        metrics = {'train': train_metrics}

        if round_num % rounds_per_eval == 0:
            # Compute evaluation metrics
            evaluate_start_time = time.time()
            validation_metrics = evaluation_fn(current_model, round_num)
            validation_metrics['evaluate_secs'] = time.time(
            ) - evaluate_start_time
            metrics['eval'] = validation_metrics

        _write_metrics(metrics_mngr, tensorboard_mngr, metrics, round_num)
        round_num += 1

    # Final metrics evaluation once the training has completed
    metrics = {}

    # Evaluation metrics
    evaluate_start_time = time.time()
    validation_metrics = evaluation_fn(current_model, round_num)
    validation_metrics['evaluate_secs'] = time.time() - evaluate_start_time
    metrics['eval'] = validation_metrics

    # Test set metrics
    if test_fn:
        test_start_time = time.time()
        test_metrics = test_fn(current_model)
        test_metrics['evaluate_secs'] = time.time() - test_start_time
        metrics['test'] = test_metrics
    _write_metrics(metrics_mngr, tensorboard_mngr, metrics, total_rounds)

    return state
예제 #5
0
def run(iterative_process: tff.templates.IterativeProcess,
        client_datasets_fn: Callable[[int], List[tf.data.Dataset]],
        validation_fn: Callable[[Any, int], Dict[str, float]],
        total_rounds: int,
        experiment_name: str,
        test_fn: Optional[Callable[[Any], Dict[str, float]]] = None,
        root_output_dir: Optional[str] = '/tmp/fed_opt',
        rounds_per_eval: Optional[int] = 1,
        rounds_per_checkpoint: Optional[int] = 50):
    """Runs federated training for a given `tff.templates.IterativeProcess`.

    The iterative process is expected to expose:

      *   `initialize`: `( -> S@SERVER)`, producing the server state `S`.
      *   `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>`, taking
          the server state and the client datasets `{B*}` and returning the
          new state plus a python `Mapping` `T` of round metrics.

    Training resumes from the newest checkpoint when one exists. In that case
    the round after the restored one is run first, so the restored checkpoint
    is never overwritten.

    Args:
      iterative_process: The `tff.templates.IterativeProcess` to run.
      client_datasets_fn: Callable taking the round number and returning the
        list of client datasets used as federated data for that round.
      validation_fn: Callable taking the current iterative-process state and
        the current round number, and returning a dict of validation metrics.
        It is invoked every `rounds_per_eval` rounds and once more after
        training finishes.
      total_rounds: Number of federated training rounds to perform.
      experiment_name: Experiment name. It is combined with `root_output_dir`
        to decide where outputs are written.
      test_fn: Optional callable taking the current iterative-process state
        and returning a dict of test metrics. It is invoked once, after
        training finishes.
      root_output_dir: Root directory for experiment outputs.
      rounds_per_eval: Validation metrics are computed on rounds divisible by
        this value.
      rounds_per_checkpoint: The state is checkpointed on rounds divisible by
        this value and on the last round. Use a small value when restarts are
        frequent; a larger one is fine otherwise.

    Returns:
      The `state` of the iterative process once training has finished.

    Raises:
      TypeError: If `iterative_process` is not a
        `tff.templates.IterativeProcess`, or if any function argument that was
        supplied is not callable.
    """
    if not isinstance(iterative_process, tff.templates.IterativeProcess):
        raise TypeError(
            'iterative_process must be a `tff.templates.IterativeProcess`.')
    mandatory_callables = (
        (client_datasets_fn, 'client_datasets_fn should be callable.'),
        (validation_fn, 'validation_fn should be callable.'),
    )
    for candidate, complaint in mandatory_callables:
        if not callable(candidate):
            raise TypeError(complaint)
    if test_fn is not None and not callable(test_fn):
        raise TypeError('test_fn should be callable.')

    def timed(fn, *args):
        """Calls `fn(*args)` and returns `(result, elapsed_wall_seconds)`."""
        began = time.time()
        outcome = fn(*args)
        return outcome, time.time() - began

    logging.info('Starting iterative_process training loop...')
    fresh_state = iterative_process.initialize()

    checkpoint_mngr, metrics_mngr, tb_mngr = _setup_outputs(
        root_output_dir, experiment_name)

    logging.info('Asking checkpoint manager to load checkpoint.')
    state, round_num = checkpoint_mngr.load_latest_checkpoint(fresh_state)

    if state is not None:
        logging.info('Restarted from checkpoint round %d', round_num)
        # Move past the restored round so its checkpoint is not overwritten.
        round_num += 1
    else:
        logging.info('Initializing experiment from scratch.')
        state, round_num = fresh_state, 0
    # Drop any metrics recorded at or beyond the round we are about to run.
    metrics_mngr.clear_metrics(round_num)

    first_round = round_num
    loop_began = time.time()
    while round_num < total_rounds:
        federated_train_data, prep_secs = timed(client_datasets_fn, round_num)
        round_report = {'prepare_datasets_secs': prep_secs}

        (state, round_metrics), train_secs = timed(
            iterative_process.next, state, federated_train_data)
        round_report['training_secs'] = train_secs
        round_report.update(round_metrics)

        # Average over rounds run in *this* process, so restarts stay correct.
        rounds_so_far = round_num - first_round + 1
        logging.info('Round {:2d}, {:.2f}s per round in average.'.format(
            round_num, (time.time() - loop_began) / rounds_so_far))

        final_round = round_num == total_rounds - 1
        if round_num % rounds_per_checkpoint == 0 or final_round:
            _, ckpt_secs = timed(checkpoint_mngr.save_checkpoint, state,
                                 round_num)
            round_report['save_checkpoint_secs'] = ckpt_secs

        round_output = {'train': round_report}

        if round_num % rounds_per_eval == 0:
            eval_report, eval_secs = timed(validation_fn, state, round_num)
            eval_report['evaluate_secs'] = eval_secs
            round_output['eval'] = eval_report

        _write_metrics(metrics_mngr, tb_mngr, round_output, round_num)
        round_num += 1

    # One last evaluation (plus the optional test pass) after training ends.
    closing_output = {}

    eval_report, eval_secs = timed(validation_fn, state, round_num)
    eval_report['evaluate_secs'] = eval_secs
    closing_output['eval'] = eval_report

    if test_fn:
        test_report, test_secs = timed(test_fn, state)
        test_report['evaluate_secs'] = test_secs
        closing_output['test'] = test_report
    _write_metrics(metrics_mngr, tb_mngr, closing_output, total_rounds)

    return state
Example #6
0
def run(
        iterative_process: tff.templates.IterativeProcess,
        client_datasets_fn: Callable[[int, int], Tuple[List, int]],  # pylint: disable=g-bare-generic
        validation_fn: Callable[[Any], Dict[str, float]],
        total_epochs: int,
        total_rounds: int,
        experiment_name: str,
        train_eval_fn: Optional[Callable[[Any], Dict[str, float]]] = None,
        test_fn: Optional[Callable[[Any], Dict[str, float]]] = None,
        root_output_dir: Optional[str] = '/tmp/fed_opt',
        hparam_dict: Optional[Dict[str, Any]] = None,
        rounds_per_eval: Optional[int] = 1,
        rounds_per_checkpoint: Optional[int] = 50,
        rounds_per_train_eval: Optional[int] = 100):
    """Runs federated training for a given `tff.templates.IterativeProcess`.

    We assume that the iterative process has the following functional type
    signatures:

      *   `initialize`: `( -> S@SERVER)` where `S` represents the server state.
      *   `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S`
          represents the server state and `{B*}` represents the client
          datasets. The second output `T` is recorded in the per-round
          training metrics under the key `'loss'`.

    Checkpoint restarts are only honored when `total_epochs <= 0`. When
    shuffling clients by epoch (`total_epochs > 0`), training always starts
    from scratch at round 0 (see TODO(b/172867399) in the body).

    Args:
      iterative_process: A `tff.templates.IterativeProcess` instance to run.
      client_datasets_fn: Function that accepts integer arguments (the round
        number and the epoch). It returns a tuple of the list of client
        datasets to use as federated data for that round and the updated
        epoch index.
      validation_fn: A callable accepting the model weights extracted from the
        iterative process state and returning a dict of evaluation metrics.
        Used to compute validation metrics throughout the training process.
      total_epochs: Number of total epochs if using `ClientIDShuffler` to
        shuffle clients. Use 0 when sampling clients, in which case training
        is controlled by `total_rounds`.
      total_rounds: The number of federated training rounds to perform. If
        `ClientIDShuffler` is used for `client_datasets_fn`, the total rounds
        will take the minimum of `total_rounds` and
        rounds_per_epoch*`total_epochs`.
      experiment_name: The name of the experiment being run. This will be
        appended to the `root_output_dir` for purposes of writing outputs.
      train_eval_fn: An optional callable accepting the model weights extracted
        from the iterative process state and returning a dict of evaluation
        metrics. Used to compute training metrics over the entire training
        dataset throughout the course of the iterative process. If set to
        `None`, no such evaluation is done.
      test_fn: An optional callable accepting the model weights extracted from
        the iterative process state and returning a dict of test metrics. Used
        to compute test metrics at the end of the training process.
      root_output_dir: The name of the root output directory for writing
        experiment outputs.
      hparam_dict: An optional dictionary specifying hyperparameters of the
        experiment. If provided, the hyperparameters will be written to CSV.
      rounds_per_eval: How often to compute validation metrics.
      rounds_per_checkpoint: How often to checkpoint the iterative process
        state. If you expect the job to restart frequently, this should be
        small. If no interruptions are expected, this can be made larger.
      rounds_per_train_eval: How often to compute metrics over the entire
        training dataset. Note that this is only done if a `train_eval_fn`
        argument is supplied.

    Returns:
      The final `state` of the iterative process after training.

    Raises:
      TypeError: If `iterative_process` has the wrong type or any supplied
        function argument is not callable.
    """
    if not isinstance(iterative_process, tff.templates.IterativeProcess):
        raise TypeError('iterative_process should be type '
                        '`tff.templates.IterativeProcess`.')
    if not callable(client_datasets_fn):
        raise TypeError('client_datasets_fn should be callable.')
    if not callable(validation_fn):
        raise TypeError('validation_fn should be callable.')
    if train_eval_fn is not None and not callable(train_eval_fn):
        raise TypeError('train_eval_fn should be callable.')
    if test_fn is not None and not callable(test_fn):
        raise TypeError('test_fn should be callable.')

    logging.info('Starting iterative_process training loop...')
    initial_state = iterative_process.initialize()

    checkpoint_mngr, metrics_mngr, tensorboard_mngr = _setup_outputs(
        root_output_dir, experiment_name, hparam_dict)

    logging.info('Asking checkpoint manager to load checkpoint.')
    state, round_num = checkpoint_mngr.load_latest_checkpoint(initial_state)

    # TODO(b/172867399): we disable restarting from checkpoint when shuffling
    # client IDs by epochs. Non-trivial amount of change has to be made to make
    # sure disjoint clients are used cross rounds when restarts. A better design
    # of client dataset generator with random seed instead of `client_datasets_fn`
    # accepting `epoch` as argument, can help.
    # With total_epochs <= 0, epoch starts at -1 so `epoch < total_epochs`
    # holds as long as client_datasets_fn leaves it unchanged (presumably what
    # the sampling path does; verify against the caller).
    epoch = 0 if total_epochs > 0 else -1
    if state is None or total_epochs > 0:
        state = initial_state
        round_num = 0
        logging.info('Initializing experiment from scratch at round %d.',
                     round_num)
    else:
        logging.info('Restarted from checkpoint round %d', round_num)
        round_num += 1  # Increment to avoid overwriting current checkpoint
    metrics_mngr.clear_metrics(round_num)

    loop_start_time = time.time()
    # Rounds run by this process; used so the average is correct after restart.
    loop_start_round = round_num
    while epoch < total_epochs and round_num < total_rounds:
        # TODO(b/172867399): add restarts functionality for FTRLM when total_epochs
        # is larger than one.
        data_prep_start_time = time.time()
        federated_train_data, epoch = client_datasets_fn(round_num, epoch)
        train_metrics = {
            'prepare_datasets_secs': time.time() - data_prep_start_time
        }

        training_start_time = time.time()
        prev_model = _get_model_weights(state)
        state, loss = iterative_process.next(state, federated_train_data)
        train_metrics['training_secs'] = time.time() - training_start_time

        # Extract the post-round weights once and reuse them below.
        current_model = _get_model_weights(state)
        train_metrics['model_delta_l2_norm'] = _compute_numpy_l2_difference(
            current_model, prev_model)
        train_metrics['loss'] = loss

        loop_rounds = round_num - loop_start_round + 1
        logging.info('Round %2d, %.2fs per round in average.', round_num,
                     (time.time() - loop_start_time) / loop_rounds)

        if (round_num % rounds_per_checkpoint == 0
                or round_num == total_rounds - 1):
            save_checkpoint_start_time = time.time()
            checkpoint_mngr.save_checkpoint(state, round_num)
            train_metrics['save_checkpoint_secs'] = (
                time.time() - save_checkpoint_start_time)

        metrics = {'train': train_metrics}

        if train_eval_fn and round_num % rounds_per_train_eval == 0:
            # Compute metrics over the entire training dataset
            train_eval_start = time.time()
            train_eval_metrics = train_eval_fn(current_model)
            train_eval_metrics['evaluate_secs'] = time.time(
            ) - train_eval_start
            metrics['train_eval'] = train_eval_metrics

        if round_num % rounds_per_eval == 0:
            # Compute validation metrics
            evaluate_start_time = time.time()
            validation_metrics = validation_fn(current_model)
            validation_metrics['evaluate_secs'] = time.time(
            ) - evaluate_start_time
            metrics['eval'] = validation_metrics

        # Write every round (not only eval rounds) so training, checkpoint and
        # train_eval metrics from non-eval rounds are not silently dropped.
        _write_metrics(metrics_mngr, tensorboard_mngr, metrics, round_num)

        round_num += 1

    # Final metrics evaluation once the training has completed
    metrics = {}
    final_model = _get_model_weights(state)

    # Validation metrics
    evaluate_start_time = time.time()
    validation_metrics = validation_fn(final_model)
    validation_metrics['evaluate_secs'] = time.time() - evaluate_start_time
    metrics['eval'] = validation_metrics

    # Training set metrics
    if train_eval_fn:
        train_eval_start = time.time()
        train_eval_metrics = train_eval_fn(final_model)
        train_eval_metrics['evaluate_secs'] = time.time() - train_eval_start
        metrics['train_eval'] = train_eval_metrics

    # Test set metrics
    if test_fn:
        test_start_time = time.time()
        test_metrics = test_fn(final_model)
        test_metrics['evaluate_secs'] = time.time() - test_start_time
        metrics['test'] = test_metrics
    # Logged at the round after the last one run. This may be < total_rounds
    # when the epoch budget ran out first.
    _write_metrics(metrics_mngr, tensorboard_mngr, metrics, round_num)

    return state