def run(iterative_process: tff.templates.IterativeProcess,
        client_datasets_fn: Callable[[int], List[tf.data.Dataset]],
        validation_fn: Callable[[Any], Dict[str, float]],
        total_rounds: int,
        experiment_name: str,
        train_eval_fn: Optional[Callable[[Any], Dict[str, float]]] = None,
        test_fn: Optional[Callable[[Any], Dict[str, float]]] = None,
        root_output_dir: Optional[str] = '/tmp/fed_opt',
        hparam_dict: Optional[Dict[str, Any]] = None,
        write_metrics_with_bz2: Optional[bool] = True,
        rounds_per_eval: Optional[int] = 1,
        rounds_per_checkpoint: Optional[int] = 50,
        rounds_per_train_eval: Optional[int] = 100,
        rounds_per_profile: Optional[int] = 0,
        clients_per_round: Optional[float] = 1.0):
  """Runs federated training for a given `tff.templates.IterativeProcess`.

  We assume that the iterative process has the following functional type
  signatures:

    * `initialize`: `( -> S@SERVER)` where `S` represents the server state.
    * `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S`
      represents the server state, `{B*}` represents the client datasets, and
      `T` represents a python `Mapping` object.

  Moreover, the server state must have an attribute `model` that can be passed
  to `validation_fn`, `train_eval_fn`, and `test_fn` (if given).

  Args:
    iterative_process: A `tff.templates.IterativeProcess` instance to run.
    client_datasets_fn: Function accepting the round number (and, for rounds
      after the first, the current availability vector `r_vec`) and returning
      a tuple of: a list of client datasets to use as federated data for that
      round, per-client aggregation weights, the updated `r_vec`, the
      corresponding client ids, and a per-client availability indicator.
    validation_fn: A callable accepting the `model` attribute of the iterative
      process state and returning a dict of evaluation metrics. Used to compute
      validation metrics throughout the training process.
    total_rounds: The number of federated training rounds to perform.
    experiment_name: The name of the experiment being run. This will be
      appended to the `root_output_dir` for purposes of writing outputs.
    train_eval_fn: An optional callable accepting the `model` attribute of the
      iterative process state and returning a dict of evaluation metrics. Used
      to compute training metrics over the entire training dataset throughout
      the course of the iterative process. If set to `None`, no such evaluation
      is done.
    test_fn: An optional callable accepting the `model` attribute of the
      iterative process state and returning a dict of test metrics. Used to
      compute test metrics at the end of the training process.
    root_output_dir: The name of the root output directory for writing
      experiment outputs.
    hparam_dict: An optional dictionary specifying hyperparameters of the
      experiment. If provided, the hyperparameters will be written to CSV.
    write_metrics_with_bz2: Whether to use `bz2` compression when writing
      metrics to CSV.
    rounds_per_eval: How often to compute validation metrics.
    rounds_per_checkpoint: How often to checkpoint the iterative process state.
      If you expect the job to restart frequently, this should be small. If no
      interruptions are expected, this can be made larger.
    rounds_per_train_eval: How often to compute metrics over the entire
      training dataset. Note that this is only done if a `train_eval_fn`
      argument is supplied.
    rounds_per_profile: Experimental setting. If set to a value greater than 0,
      this dictates how often a TensorFlow profiler is run.
    clients_per_round: Client participation setting forwarded by the caller;
      not used directly inside this loop.

  Returns:
    The final `state` of the iterative process after training.
  """
  if not isinstance(iterative_process, tff.templates.IterativeProcess):
    raise TypeError('iterative_process should be type '
                    '`tff.templates.IterativeProcess`.')
  if not callable(client_datasets_fn):
    raise TypeError('client_datasets_fn should be callable.')
  if not callable(validation_fn):
    raise TypeError('validation_fn should be callable.')
  if train_eval_fn is not None and not callable(train_eval_fn):
    raise TypeError('train_eval_fn should be callable.')
  if test_fn is not None and not callable(test_fn):
    raise TypeError('test_fn should be callable.')

  logging.info('Starting iterative_process training loop...')
  initial_state = iterative_process.initialize()

  logging.info('Initialized! State keys:')
  for k in initial_state.__dict__.keys():
    logging.info('  %s', k)

  if not hasattr(initial_state, 'model'):
    raise TypeError('The server state must have a model attribute.')

  checkpoint_mngr, metrics_mngr, summary_writer, profiler = _setup_outputs(
      root_output_dir, experiment_name, hparam_dict, write_metrics_with_bz2,
      rounds_per_profile)

  results_r_vec_dir = os.path.join(root_output_dir, 'results', experiment_name,
                                   'r_vec')
  create_if_not_exists(results_r_vec_dir)

  logging.info('Asking checkpoint manager to load checkpoint.')
  state, round_num = checkpoint_mngr.load_latest_checkpoint(initial_state)

  if state is None:
    logging.info('Initializing experiment from scratch.')
    state = initial_state
    round_num = 0
    metrics_mngr.clear_all_rounds()
  else:
    logging.info('Restarted from checkpoint round %d', round_num)
    round_num += 1  # Increment to avoid overwriting current checkpoint
    metrics_mngr.clear_rounds_after(last_valid_round_num=round_num - 1)

  loop_start_time = time.time()
  r_vec = None
  while round_num < total_rounds:
    data_prep_start_time = time.time()
    if round_num == 0:
      (federated_train_data, federated_weights, r_vec, idx_ids,
       avail) = client_datasets_fn(round_num)
    else:
      if r_vec is None:
        # Recover the availability state of the previous round after a
        # restart from checkpoint.
        r_vec_filename = os.path.join(results_r_vec_dir,
                                      f'r_vec{round_num - 1}.npy')
        r_vec_numpy = np.load(r_vec_filename, allow_pickle=True)
        r_vec = tf.Variable(r_vec_numpy, dtype=tf.float32)
      (federated_train_data, federated_weights, r_vec, idx_ids,
       avail) = client_datasets_fn(round_num, r_vec)
    train_metrics = {
        'prepare_datasets_secs': time.time() - data_prep_start_time
    }
    np.save(
        os.path.join(results_r_vec_dir, f'r_vec{round_num}.npy'),
        r_vec.numpy())

    training_start_time = time.time()
    prev_model = state.model
    # TODO(b/145604851): This try/except is used to circumvent ambiguous TF
    # errors during training, and should be removed once the root cause is
    # determined (and possibly fixed).
    try:
      with profiler(round_num):
        state, round_metrics = iterative_process.next(
            state, federated_train_data, federated_weights.numpy().tolist())
    except (tf.errors.FailedPreconditionError, tf.errors.NotFoundError,
            tf.errors.InternalError) as e:
      logging.warning('Caught %s exception while running round %d:\n\t%s',
                      type(e), round_num, e)
      continue  # restart the loop without incrementing the round number

    train_metrics['training_secs'] = time.time() - training_start_time
    train_metrics['model_delta_l2_norm'] = _compute_numpy_l2_difference(
        state.model, prev_model)

    if hasattr(state, 'num_participants'):
      train_metrics['num_participants'] = state.num_participants
    if hasattr(state, 'threshold'):
      train_metrics['threshold'] = state.threshold

    train_metrics.update(round_metrics)

    logging.info(
        f'Number of available clients: {sum(avail)} - participant 1: {idx_ids[0]}')
    logging.info('Round {:2d}, {:.2f}s per round in average.'.format(
        round_num, (time.time() - loop_start_time) / (round_num + 1)))

    if (round_num % rounds_per_checkpoint == 0 or
        round_num == total_rounds - 1):
      save_checkpoint_start_time = time.time()
      checkpoint_mngr.save_checkpoint(state, round_num)
      train_metrics['save_checkpoint_secs'] = (
          time.time() - save_checkpoint_start_time)

    metrics = {'train': train_metrics}

    if round_num % rounds_per_eval == 0:
      # Compute validation metrics
      evaluate_start_time = time.time()
      validation_metrics = validation_fn(state.model)
      validation_metrics['evaluate_secs'] = time.time() - evaluate_start_time
      metrics['eval'] = validation_metrics

    if train_eval_fn and round_num % rounds_per_train_eval == 0:
      # Compute metrics over the entire training dataset
      train_eval_start = time.time()
      train_eval_metrics = train_eval_fn(state.model)
      train_eval_metrics['evaluate_secs'] = time.time() - train_eval_start
      metrics['train_eval'] = train_eval_metrics

    if test_fn and round_num % rounds_per_train_eval == 0:
      test_start_time = time.time()
      test_metrics = test_fn(state.model)
      test_metrics['evaluate_secs'] = time.time() - test_start_time
      metrics['test'] = test_metrics

    _write_metrics(metrics_mngr, summary_writer, metrics, round_num)
    round_num += 1

  # Final metrics evaluation once the training has completed
  metrics = {}

  # Validation metrics
  evaluate_start_time = time.time()
  validation_metrics = validation_fn(state.model)
  validation_metrics['evaluate_secs'] = time.time() - evaluate_start_time
  metrics['eval'] = validation_metrics

  # Training set metrics
  if train_eval_fn:
    train_eval_start = time.time()
    train_eval_metrics = train_eval_fn(state.model)
    train_eval_metrics['evaluate_secs'] = time.time() - train_eval_start
    metrics['train_eval'] = train_eval_metrics

  # Test set metrics
  if test_fn:
    test_start_time = time.time()
    test_metrics = test_fn(state.model)
    test_metrics['evaluate_secs'] = time.time() - test_start_time
    metrics['test'] = test_metrics

  _write_metrics(metrics_mngr, summary_writer, metrics, total_rounds)

  return state
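

# The loop above expects `client_datasets_fn` to return five values per round:
# the client datasets, per-client aggregation weights, the availability vector
# `r_vec`, the sampled client ids, and a per-client availability indicator.
# Below is a minimal sketch of such a function, assuming `client_data` behaves
# like a `tff.simulation.datasets.ClientData` and that `np`/`tf` are already
# imported in this module. The helper name, the Bernoulli availability model,
# and the `r_vec` update rule are hypothetical placeholders, not the original
# sampling logic.
def make_availability_client_datasets_fn(client_data, num_clients_per_round):
  """Returns a sampler tracking a per-client availability score in `r_vec`."""
  client_ids = list(client_data.client_ids)
  num_clients = len(client_ids)

  def client_datasets_fn(round_num, r_vec=None):
    del round_num  # Sampling below does not depend on the round number.
    if r_vec is None:
      # Round 0 (or a fresh restart): start from uniform availability scores.
      r_vec = tf.Variable(tf.ones([num_clients]), dtype=tf.float32)
    # Hypothetical availability draw: each client is online with prob. 0.8.
    avail = np.random.binomial(1, 0.8, size=num_clients)
    available_idx = np.flatnonzero(avail)
    chosen_idx = np.random.choice(
        available_idx,
        size=min(num_clients_per_round, len(available_idx)),
        replace=False)
    idx_ids = [client_ids[i] for i in chosen_idx]
    federated_train_data = [
        client_data.create_tf_dataset_for_client(cid) for cid in idx_ids
    ]
    # Uniform aggregation weights over the sampled clients.
    federated_weights = tf.ones([len(idx_ids)], dtype=tf.float32)
    # Placeholder update of the availability scores; the real rule is
    # experiment-specific.
    r_vec.assign(0.9 * r_vec + 0.1 * tf.constant(avail, dtype=tf.float32))
    return federated_train_data, federated_weights, r_vec, idx_ids, avail

  return client_datasets_fn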
def run(iterative_process: tff.templates.IterativeProcess,
        client_datasets_fn: Callable[[int], List[tf.data.Dataset]],
        validation_fn: Callable[[Any, int], Dict[str, float]],
        total_rounds: int,
        experiment_name: str,
        test_fn: Optional[Callable[[Any], Dict[str, float]]] = None,
        root_output_dir: Optional[str] = '/tmp/fed_opt',
        rounds_per_eval: Optional[int] = 1,
        rounds_per_checkpoint: Optional[int] = 50,
        rounds_per_profile: Optional[int] = 0):
  """Runs federated training for a given `tff.templates.IterativeProcess`.

  We assume that the iterative process has the following functional type
  signatures:

    * `initialize`: `( -> S@SERVER)` where `S` represents the server state.
    * `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S`
      represents the server state, `{B*}` represents the client datasets, and
      `T` represents a python `Mapping` object.

  The iterative process must also have a callable attribute `get_model_weights`
  that takes as input the state of the iterative process, and returns a
  `tff.learning.ModelWeights` object.

  Args:
    iterative_process: A `tff.templates.IterativeProcess` instance to run.
    client_datasets_fn: Function accepting an integer argument (the round
      number) and returning a list of client datasets to use as federated data
      for that round.
    validation_fn: A callable accepting a `tff.learning.ModelWeights` and the
      current round number, and returning a dict of evaluation metrics. Used
      to compute validation metrics throughout the training process.
    total_rounds: The number of federated training rounds to perform.
    experiment_name: The name of the experiment being run. This will be
      appended to the `root_output_dir` for purposes of writing outputs.
    test_fn: An optional callable accepting a `tff.learning.ModelWeights` and
      returning a dict of test set metrics. Used to compute test metrics at
      the end of the training process.
    root_output_dir: The name of the root output directory for writing
      experiment outputs.
    rounds_per_eval: How often to compute validation metrics.
    rounds_per_checkpoint: How often to checkpoint the iterative process state.
      If you expect the job to restart frequently, this should be small. If no
      interruptions are expected, this can be made larger.
    rounds_per_profile: Experimental setting. If set to a value greater than 0,
      this dictates how often a TensorFlow profiler is run.

  Returns:
    The final `state` of the iterative process after training.
  """
  _check_iterative_process_compatibility(iterative_process)
  if not callable(client_datasets_fn):
    raise TypeError('client_datasets_fn should be callable.')
  if not callable(validation_fn):
    raise TypeError('validation_fn should be callable.')
  if test_fn is not None and not callable(test_fn):
    raise TypeError('test_fn should be callable.')

  logging.info('Starting iterative_process training loop...')
  initial_state = iterative_process.initialize()

  checkpoint_mngr, metrics_mngr, tb_mngr, profiler = _setup_outputs(
      root_output_dir, experiment_name, rounds_per_profile)

  logging.info('Asking checkpoint manager to load checkpoint.')
  state, round_num = checkpoint_mngr.load_latest_checkpoint(initial_state)
  if state is None:
    logging.info('Initializing experiment from scratch.')
    state = initial_state
    round_num = 0
  else:
    logging.info('Restarted from checkpoint round %d', round_num)
    round_num += 1  # Increment to avoid overwriting current checkpoint
  metrics_mngr.clear_metrics(round_num)

  current_model = iterative_process.get_model_weights(state)

  loop_start_time = time.time()
  loop_start_round = round_num
  while round_num < total_rounds:
    data_prep_start_time = time.time()
    federated_train_data = client_datasets_fn(round_num)
    train_metrics = {
        'prepare_datasets_secs': time.time() - data_prep_start_time
    }

    training_start_time = time.time()
    prev_model = current_model
    # TODO(b/145604851): This try/except is used to circumvent ambiguous TF
    # errors during training, and should be removed once the root cause is
    # determined (and possibly fixed).
    try:
      with profiler(round_num):
        state, round_metrics = iterative_process.next(state,
                                                      federated_train_data)
    except (tf.errors.FailedPreconditionError, tf.errors.NotFoundError,
            tf.errors.InternalError) as e:
      logging.warning('Caught %s exception while running round %d:\n\t%s',
                      type(e), round_num, e)
      continue  # restart the loop without incrementing the round number

    current_model = iterative_process.get_model_weights(state)
    train_metrics['training_secs'] = time.time() - training_start_time
    train_metrics['model_delta_l2_norm'] = _compute_numpy_l2_difference(
        current_model, prev_model)
    train_metrics.update(round_metrics)

    loop_time = time.time() - loop_start_time
    loop_rounds = (round_num - loop_start_round + 1)
    logging.info('Round {:2d}, {:.2f}s per round in average.'.format(
        round_num, loop_time / loop_rounds))

    if (round_num % rounds_per_checkpoint == 0 or
        round_num == total_rounds - 1):
      save_checkpoint_start_time = time.time()
      checkpoint_mngr.save_checkpoint(state, round_num)
      train_metrics['save_checkpoint_secs'] = (
          time.time() - save_checkpoint_start_time)

    metrics = {'train': train_metrics}

    if round_num % rounds_per_eval == 0:
      # Compute validation metrics
      evaluate_start_time = time.time()
      validation_metrics = validation_fn(current_model, round_num)
      validation_metrics['evaluate_secs'] = time.time() - evaluate_start_time
      metrics['eval'] = validation_metrics

    _write_metrics(metrics_mngr, tb_mngr, metrics, round_num)
    round_num += 1

  # Final metrics evaluation once the training has completed
  metrics = {}

  # Validation metrics
  evaluate_start_time = time.time()
  validation_metrics = validation_fn(current_model, round_num)
  validation_metrics['evaluate_secs'] = time.time() - evaluate_start_time
  metrics['eval'] = validation_metrics

  # Test set metrics
  if test_fn:
    test_start_time = time.time()
    test_metrics = test_fn(current_model)
    test_metrics['evaluate_secs'] = time.time() - test_start_time
    metrics['test'] = test_metrics

  _write_metrics(metrics_mngr, tb_mngr, metrics, total_rounds)

  return state
def run(iterative_process: tff.templates.IterativeProcess,
        client_datasets_fn: Callable[[int], List[tf.data.Dataset]],
        validation_fn: Callable[[Any], Dict[str, float]],
        train_eval_fn: Optional[Callable[[Any], Dict[str, float]]] = None,
        test_fn: Optional[Callable[[Any], Dict[str, float]]] = None):
  """Runs federated training for the given TFF `IterativeProcess` instance.

  Args:
    iterative_process: A `tff.templates.IterativeProcess` instance to run.
    client_datasets_fn: Function accepting an integer argument (the round
      number) and returning a list of client datasets to use as federated data
      for that round, and a list of the corresponding client ids.
    validation_fn: Callable accepting the `model` attribute of an
      `IterationResult.state`, and returning a dict of evaluation metrics. Used
      to compute validation metrics throughout the training process.
    train_eval_fn: An optional callable accepting the `model` attribute of an
      `IterationResult.state` and returning a dict of evaluation metrics. Used
      to compute training metrics over the entire training dataset throughout
      the course of the iterative process.
    test_fn: An optional callable accepting the `model` attribute of an
      `IterationResult.state` and returning a dict of test metrics. Used to
      compute test metrics at the end of the training process.

  Returns:
    The `state` of the `IterationResult` representing the result of the
    training loop.
  """
  if not isinstance(iterative_process, tff.templates.IterativeProcess):
    raise TypeError('iterative_process should be type '
                    '`tff.templates.IterativeProcess`.')
  if not callable(client_datasets_fn):
    raise TypeError('client_datasets_fn should be callable.')
  if not callable(validation_fn):
    raise TypeError('validation_fn should be callable.')
  if train_eval_fn is not None and not callable(train_eval_fn):
    raise TypeError('train_eval_fn should be callable.')
  if test_fn is not None and not callable(test_fn):
    raise TypeError('test_fn should be callable.')

  total_rounds = FLAGS.total_rounds

  logging.info('Starting iterative_process training loop...')
  initial_state = iterative_process.initialize()

  if not hasattr(initial_state, 'model'):
    raise TypeError('The server state must have a model attribute.')

  hparam_flags = utils_impl.get_hparam_flags()
  hparam_dict = collections.OrderedDict([
      (name, FLAGS[name].value) for name in hparam_flags
  ])

  checkpoint_mngr, metrics_mngr, summary_writer, profiler = _setup_outputs(
      FLAGS.root_output_dir, FLAGS.experiment_name, hparam_dict)

  logging.info('Asking checkpoint manager to load checkpoint.')
  state, round_num = checkpoint_mngr.load_latest_checkpoint(initial_state)

  if state is None:
    logging.info('Initializing experiment from scratch.')
    state = initial_state
    round_num = 0
    metrics_mngr.clear_all_rounds()
  else:
    logging.info('Restarted from checkpoint round %d', round_num)
    round_num += 1  # Increment to avoid overwriting current checkpoint
    metrics_mngr.clear_rounds_after(last_valid_round_num=round_num - 1)

  loop_start_time = time.time()
  while round_num < total_rounds:
    data_prep_start_time = time.time()
    federated_train_data = client_datasets_fn(round_num)
    train_metrics = {
        'prepare_datasets_secs': time.time() - data_prep_start_time
    }

    training_start_time = time.time()
    prev_model = state.model
    # TODO(b/145604851): This try/except is used to circumvent ambiguous TF
    # errors during training, and should be removed once the root cause is
    # determined (and possibly fixed).
    try:
      with profiler(round_num):
        state, round_metrics = iterative_process.next(state,
                                                      federated_train_data)
    except (tf.errors.FailedPreconditionError, tf.errors.NotFoundError,
            tf.errors.InternalError) as e:
      logging.warning('Caught %s exception while running round %d:\n\t%s',
                      type(e), round_num, e)
      continue  # restart the loop without incrementing the round number

    train_metrics['training_secs'] = time.time() - training_start_time
    train_metrics['model_delta_l2_norm'] = _compute_numpy_l2_difference(
        state.model, prev_model)
    train_metrics.update(round_metrics)

    logging.info('Round {:2d}, {:.2f}s per round in average.'.format(
        round_num, (time.time() - loop_start_time) / (round_num + 1)))

    if (round_num % FLAGS.rounds_per_checkpoint == 0 or
        round_num == total_rounds - 1):
      save_checkpoint_start_time = time.time()
      checkpoint_mngr.save_checkpoint(state, round_num)
      train_metrics['save_checkpoint_secs'] = (
          time.time() - save_checkpoint_start_time)

    metrics = {'train': train_metrics}

    if round_num % FLAGS.rounds_per_eval == 0:
      # Compute validation metrics
      evaluate_start_time = time.time()
      validation_metrics = validation_fn(state.model)
      validation_metrics['evaluate_secs'] = time.time() - evaluate_start_time
      metrics['eval'] = validation_metrics

    if train_eval_fn and round_num % FLAGS.rounds_per_train_eval == 0:
      # Compute metrics over the entire training dataset
      train_eval_start = time.time()
      train_eval_metrics = train_eval_fn(state.model)
      train_eval_metrics['evaluate_secs'] = time.time() - train_eval_start
      metrics['train_eval'] = train_eval_metrics

    _write_metrics(metrics_mngr, summary_writer, metrics, round_num)
    round_num += 1

  # Final metrics evaluation once the training has completed
  metrics = {}

  # Validation metrics
  evaluate_start_time = time.time()
  validation_metrics = validation_fn(state.model)
  validation_metrics['evaluate_secs'] = time.time() - evaluate_start_time
  metrics['eval'] = validation_metrics

  # Training set metrics
  if train_eval_fn:
    train_eval_start = time.time()
    train_eval_metrics = train_eval_fn(state.model)
    train_eval_metrics['evaluate_secs'] = time.time() - train_eval_start
    metrics['train_eval'] = train_eval_metrics

  # Test set metrics
  if test_fn:
    test_start_time = time.time()
    test_metrics = test_fn(state.model)
    test_metrics['evaluate_secs'] = time.time() - test_start_time
    metrics['test'] = test_metrics

  _write_metrics(metrics_mngr, summary_writer, metrics, total_rounds)

  return state
def run(iterative_process: tff.templates.IterativeProcess,
        train_client_datasets_fn: Callable[[int], List[tf.data.Dataset]],
        evaluation_fn: Callable[[Any, int], Dict[str, float]],
        total_rounds: int,
        experiment_name: str,
        test_fn: Optional[Callable[[Any], Dict[str, float]]] = None,
        root_output_dir: Optional[str] = '/tmp/fedopt_guide',
        hparam_dict: Optional[Dict[str, Any]] = None,
        rounds_per_eval: Optional[int] = 10,
        rounds_per_checkpoint: Optional[int] = 50):
  """Runs federated training for a given `tff.templates.IterativeProcess`.

  We assume that the iterative process has the following functional type
  signatures:

    * `initialize`: `( -> S@SERVER)` where `S` represents the server state.
    * `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S`
      represents the server state, `{B*}` represents the client datasets, and
      `T` represents a python `Mapping` object.

  The iterative process must also have a callable attribute `get_model_weights`
  that takes as input the state of the iterative process, and returns a nested
  structure of tensors that can be input to the `evaluation_fn` and `test_fn`
  (if provided).

  Args:
    iterative_process: A `tff.templates.IterativeProcess` instance to run.
    train_client_datasets_fn: Function accepting an integer argument (the round
      number) and returning a list of train client datasets to use as federated
      data for that training round.
    evaluation_fn: A callable accepting the output of the `get_model_weights`
      attribute of the iterative process and a `round_num`, and returning a
      dictionary of evaluation metrics. Used to compute evaluation metrics
      throughout the training process.
    total_rounds: The number of federated training rounds to perform.
    experiment_name: The name of the experiment being run. This will be
      appended to the `root_output_dir` for purposes of writing outputs.
    test_fn: A callable accepting the output of the `get_model_weights`
      attribute of the iterative process and returning a dictionary of test
      metrics. Used to compute test metrics at the end of the training process.
    root_output_dir: The name of the root output directory for writing
      experiment outputs.
    hparam_dict: An optional dictionary specifying hyperparameters of the
      experiment. If provided, the hyperparameters will be written to CSV.
    rounds_per_eval: How often to compute validation metrics.
    rounds_per_checkpoint: How often to checkpoint the iterative process state.
      If you expect the job to restart frequently, this should be small. If no
      interruptions are expected, this can be made larger.

  Returns:
    The final `state` of the iterative process after training.
  """
  _check_iterative_process_compatibility(iterative_process)
  if not callable(train_client_datasets_fn):
    raise TypeError('train_client_datasets_fn should be callable.')
  if not callable(evaluation_fn):
    raise TypeError('evaluation_fn should be callable.')
  if test_fn is not None and not callable(test_fn):
    raise TypeError('test_fn should be callable.')

  logging.info('Starting iterative_process training loop...')
  initial_state = iterative_process.initialize()

  if not hasattr(initial_state, 'model'):
    raise TypeError('The server state must have a model attribute.')

  checkpoint_mngr, metrics_mngr, tensorboard_mngr = _setup_outputs(
      root_output_dir, experiment_name, hparam_dict)

  logging.info('Asking checkpoint manager to load checkpoint.')
  state, round_num = checkpoint_mngr.load_latest_checkpoint(initial_state)
  if state is None:
    logging.info('Initializing experiment from scratch.')
    state = initial_state
    round_num = 0
  else:
    logging.info('Restarted from checkpoint round %d', round_num)
    round_num += 1  # Increment to avoid overwriting current checkpoint
  metrics_mngr.clear_metrics(round_num)

  current_model = iterative_process.get_model_weights(state)

  loop_start_time = time.time()
  while round_num < total_rounds:
    data_prep_start_time = time.time()
    federated_train_data = train_client_datasets_fn(round_num)
    train_metrics = {
        'prepare_datasets_secs': time.time() - data_prep_start_time
    }

    training_start_time = time.time()
    prev_model = iterative_process.get_model_weights(state)
    state, round_metrics = iterative_process.next(state, federated_train_data)
    current_model = iterative_process.get_model_weights(state)

    train_metrics['training_secs'] = time.time() - training_start_time
    train_metrics['model_delta_l2_norm'] = _compute_numpy_l2_difference(
        current_model, prev_model)
    train_metrics.update(round_metrics)

    logging.info('Round {:2d}, {:.2f}s per round in average.'.format(
        round_num, (time.time() - loop_start_time) / (round_num + 1)))

    if (round_num % rounds_per_checkpoint == 0 or
        round_num == total_rounds - 1):
      save_checkpoint_start_time = time.time()
      checkpoint_mngr.save_checkpoint(state, round_num)
      train_metrics['save_checkpoint_secs'] = (
          time.time() - save_checkpoint_start_time)

    metrics = {'train': train_metrics}

    if round_num % rounds_per_eval == 0:
      # Compute evaluation metrics
      evaluate_start_time = time.time()
      validation_metrics = evaluation_fn(current_model, round_num)
      validation_metrics['evaluate_secs'] = time.time() - evaluate_start_time
      metrics['eval'] = validation_metrics

    _write_metrics(metrics_mngr, tensorboard_mngr, metrics, round_num)
    round_num += 1

  # Final metrics evaluation once the training has completed
  metrics = {}

  # Evaluation metrics
  evaluate_start_time = time.time()
  validation_metrics = evaluation_fn(current_model, round_num)
  validation_metrics['evaluate_secs'] = time.time() - evaluate_start_time
  metrics['eval'] = validation_metrics

  # Test set metrics
  if test_fn:
    test_start_time = time.time()
    test_metrics = test_fn(current_model)
    test_metrics['evaluate_secs'] = time.time() - test_start_time
    metrics['test'] = test_metrics

  _write_metrics(metrics_mngr, tensorboard_mngr, metrics, total_rounds)

  return state
def run(iterative_process: tff.templates.IterativeProcess,
        client_datasets_fn: Callable[[int], List[tf.data.Dataset]],
        validation_fn: Callable[[Any, int], Dict[str, float]],
        total_rounds: int,
        experiment_name: str,
        test_fn: Optional[Callable[[Any], Dict[str, float]]] = None,
        root_output_dir: Optional[str] = '/tmp/fed_opt',
        rounds_per_eval: Optional[int] = 1,
        rounds_per_checkpoint: Optional[int] = 50):
  """Runs federated training for a given `tff.templates.IterativeProcess`.

  We assume that the iterative process has the following functional type
  signatures:

    * `initialize`: `( -> S@SERVER)` where `S` represents the server state.
    * `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S`
      represents the server state, `{B*}` represents the client datasets, and
      `T` represents a python `Mapping` object.

  Args:
    iterative_process: A `tff.templates.IterativeProcess` instance to run.
    client_datasets_fn: Function accepting an integer argument (the round
      number) and returning a list of client datasets to use as federated data
      for that round.
    validation_fn: A callable accepting the current state of
      `iterative_process` and the current round number, and returning a dict
      of evaluation metrics. Used to compute validation metrics throughout the
      training process.
    total_rounds: The number of federated training rounds to perform.
    experiment_name: The name of the experiment being run. This will be
      appended to the `root_output_dir` for purposes of writing outputs.
    test_fn: An optional callable accepting the current state of
      `iterative_process` and returning a dict of test set metrics. Used to
      compute test metrics at the end of the training process.
    root_output_dir: The name of the root output directory for writing
      experiment outputs.
    rounds_per_eval: How often to compute validation metrics.
    rounds_per_checkpoint: How often to checkpoint the iterative process state.
      If you expect the job to restart frequently, this should be small. If no
      interruptions are expected, this can be made larger.

  Returns:
    The final `state` of the iterative process after training.
  """
  if not isinstance(iterative_process, tff.templates.IterativeProcess):
    raise TypeError(
        'iterative_process must be a `tff.templates.IterativeProcess`.')
  if not callable(client_datasets_fn):
    raise TypeError('client_datasets_fn should be callable.')
  if not callable(validation_fn):
    raise TypeError('validation_fn should be callable.')
  if test_fn is not None and not callable(test_fn):
    raise TypeError('test_fn should be callable.')

  logging.info('Starting iterative_process training loop...')
  initial_state = iterative_process.initialize()

  checkpoint_mngr, metrics_mngr, tb_mngr = _setup_outputs(
      root_output_dir, experiment_name)

  logging.info('Asking checkpoint manager to load checkpoint.')
  state, round_num = checkpoint_mngr.load_latest_checkpoint(initial_state)
  if state is None:
    logging.info('Initializing experiment from scratch.')
    state = initial_state
    round_num = 0
  else:
    logging.info('Restarted from checkpoint round %d', round_num)
    round_num += 1  # Increment to avoid overwriting current checkpoint
  metrics_mngr.clear_metrics(round_num)

  loop_start_time = time.time()
  loop_start_round = round_num
  while round_num < total_rounds:
    data_prep_start_time = time.time()
    federated_train_data = client_datasets_fn(round_num)
    train_metrics = {
        'prepare_datasets_secs': time.time() - data_prep_start_time
    }

    training_start_time = time.time()
    state, round_metrics = iterative_process.next(state, federated_train_data)
    train_metrics['training_secs'] = time.time() - training_start_time
    train_metrics.update(round_metrics)

    loop_time = time.time() - loop_start_time
    loop_rounds = (round_num - loop_start_round + 1)
    logging.info('Round {:2d}, {:.2f}s per round in average.'.format(
        round_num, loop_time / loop_rounds))

    if (round_num % rounds_per_checkpoint == 0 or
        round_num == total_rounds - 1):
      save_checkpoint_start_time = time.time()
      checkpoint_mngr.save_checkpoint(state, round_num)
      train_metrics['save_checkpoint_secs'] = (
          time.time() - save_checkpoint_start_time)

    metrics = {'train': train_metrics}

    if round_num % rounds_per_eval == 0:
      # Compute validation metrics
      evaluate_start_time = time.time()
      validation_metrics = validation_fn(state, round_num)
      validation_metrics['evaluate_secs'] = time.time() - evaluate_start_time
      metrics['eval'] = validation_metrics

    _write_metrics(metrics_mngr, tb_mngr, metrics, round_num)
    round_num += 1

  # Final metrics evaluation once the training has completed
  metrics = {}

  # Validation metrics
  evaluate_start_time = time.time()
  validation_metrics = validation_fn(state, round_num)
  validation_metrics['evaluate_secs'] = time.time() - evaluate_start_time
  metrics['eval'] = validation_metrics

  # Test set metrics
  if test_fn:
    test_start_time = time.time()
    test_metrics = test_fn(state)
    test_metrics['evaluate_secs'] = time.time() - test_start_time
    metrics['test'] = test_metrics

  _write_metrics(metrics_mngr, tb_mngr, metrics, total_rounds)

  return state
def run(
    iterative_process: tff.templates.IterativeProcess,
    client_datasets_fn: Callable[[int, int], Tuple[List, int]],  # pylint: disable=g-bare-generic
    validation_fn: Callable[[Any], Dict[str, float]],
    total_epochs: int,
    total_rounds: int,
    experiment_name: str,
    train_eval_fn: Optional[Callable[[Any], Dict[str, float]]] = None,
    test_fn: Optional[Callable[[Any], Dict[str, float]]] = None,
    root_output_dir: Optional[str] = '/tmp/fed_opt',
    hparam_dict: Optional[Dict[str, Any]] = None,
    rounds_per_eval: Optional[int] = 1,
    rounds_per_checkpoint: Optional[int] = 50,
    rounds_per_train_eval: Optional[int] = 100):
  """Runs federated training for a given `tff.templates.IterativeProcess`.

  We assume that the iterative process has the following functional type
  signatures:

    * `initialize`: `( -> S@SERVER)` where `S` represents the server state.
    * `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S`
      represents the server state, `{B*}` represents the client datasets, and
      `T` represents a python `Mapping` object.

  Args:
    iterative_process: A `tff.templates.IterativeProcess` instance to run.
    client_datasets_fn: Function accepting two integer arguments (the round
      number and the epoch) and returning a tuple of a list of client datasets
      to use as federated data for that round, and the updated epoch index.
    validation_fn: A callable accepting the `model` attribute of the iterative
      process state and returning a dict of evaluation metrics. Used to compute
      validation metrics throughout the training process.
    total_epochs: Number of total epochs if using `ClientIDShuffler` to shuffle
      clients. Use 0 when sampling clients and controlling the loop length by
      `total_rounds`.
    total_rounds: The number of federated training rounds to perform. If
      `ClientIDShuffler` is used for `client_datasets_fn`, the number of rounds
      actually run is the minimum of `total_rounds` and
      rounds_per_epoch * `total_epochs`.
    experiment_name: The name of the experiment being run. This will be
      appended to the `root_output_dir` for purposes of writing outputs.
    train_eval_fn: An optional callable accepting the `model` attribute of the
      iterative process state and returning a dict of evaluation metrics. Used
      to compute training metrics over the entire training dataset throughout
      the course of the iterative process. If set to `None`, no such evaluation
      is done.
    test_fn: An optional callable accepting the `model` attribute of the
      iterative process state and returning a dict of test metrics. Used to
      compute test metrics at the end of the training process.
    root_output_dir: The name of the root output directory for writing
      experiment outputs.
    hparam_dict: An optional dictionary specifying hyperparameters of the
      experiment. If provided, the hyperparameters will be written to CSV.
    rounds_per_eval: How often to compute validation metrics.
    rounds_per_checkpoint: How often to checkpoint the iterative process state.
      If you expect the job to restart frequently, this should be small. If no
      interruptions are expected, this can be made larger.
    rounds_per_train_eval: How often to compute metrics over the entire
      training dataset. Note that this is only done if a `train_eval_fn`
      argument is supplied.

  Returns:
    The final `state` of the iterative process after training.
  """
  if not isinstance(iterative_process, tff.templates.IterativeProcess):
    raise TypeError('iterative_process should be type '
                    '`tff.templates.IterativeProcess`.')
  if not callable(client_datasets_fn):
    raise TypeError('client_datasets_fn should be callable.')
  if not callable(validation_fn):
    raise TypeError('validation_fn should be callable.')
  if train_eval_fn is not None and not callable(train_eval_fn):
    raise TypeError('train_eval_fn should be callable.')
  if test_fn is not None and not callable(test_fn):
    raise TypeError('test_fn should be callable.')

  logging.info('Starting iterative_process training loop...')
  initial_state = iterative_process.initialize()

  checkpoint_mngr, metrics_mngr, tensorboard_mngr = _setup_outputs(
      root_output_dir, experiment_name, hparam_dict)

  logging.info('Asking checkpoint manager to load checkpoint.')
  state, round_num = checkpoint_mngr.load_latest_checkpoint(initial_state)

  # TODO(b/172867399): we disable restarting from checkpoint when shuffling
  # client IDs by epochs. A non-trivial amount of change has to be made to
  # make sure disjoint clients are used across rounds when restarting. A
  # better design of the client dataset generator, using a random seed instead
  # of a `client_datasets_fn` accepting `epoch` as an argument, could help.
  epoch = 0 if total_epochs > 0 else -1
  if state is None or total_epochs > 0:
    state = initial_state
    round_num = 0
    logging.info('Initializing experiment from scratch at round %d.',
                 round_num)
  else:
    logging.info('Restarted from checkpoint round %d', round_num)
    round_num += 1  # Increment to avoid overwriting current checkpoint
  metrics_mngr.clear_metrics(round_num)

  loop_start_time = time.time()
  while epoch < total_epochs and round_num < total_rounds:
    # TODO(b/172867399): add restart functionality for FTRLM when
    # `total_epochs` is larger than one.
    data_prep_start_time = time.time()
    federated_train_data, epoch = client_datasets_fn(round_num, epoch)
    train_metrics = {
        'prepare_datasets_secs': time.time() - data_prep_start_time
    }

    training_start_time = time.time()
    prev_model = _get_model_weights(state)
    state, loss = iterative_process.next(state, federated_train_data)
    train_metrics['training_secs'] = time.time() - training_start_time
    train_metrics['model_delta_l2_norm'] = _compute_numpy_l2_difference(
        _get_model_weights(state), prev_model)
    train_metrics['loss'] = loss

    logging.info('Round {:2d}, {:.2f}s per round in average.'.format(
        round_num, (time.time() - loop_start_time) / (round_num + 1)))

    if (round_num % rounds_per_checkpoint == 0 or
        round_num == total_rounds - 1):
      save_checkpoint_start_time = time.time()
      checkpoint_mngr.save_checkpoint(state, round_num)
      train_metrics['save_checkpoint_secs'] = (
          time.time() - save_checkpoint_start_time)

    metrics = {'train': train_metrics}

    if train_eval_fn and round_num % rounds_per_train_eval == 0:
      # Compute metrics over the entire training dataset
      train_eval_start = time.time()
      train_eval_metrics = train_eval_fn(_get_model_weights(state))
      train_eval_metrics['evaluate_secs'] = time.time() - train_eval_start
      metrics['train_eval'] = train_eval_metrics

    if round_num % rounds_per_eval == 0:
      # Compute validation metrics
      evaluate_start_time = time.time()
      validation_metrics = validation_fn(_get_model_weights(state))
      validation_metrics['evaluate_secs'] = time.time() - evaluate_start_time
      metrics['eval'] = validation_metrics

    _write_metrics(metrics_mngr, tensorboard_mngr, metrics, round_num)
    round_num += 1

  # Final metrics evaluation once the training has completed
  metrics = {}

  # Validation metrics
  evaluate_start_time = time.time()
  validation_metrics = validation_fn(_get_model_weights(state))
  validation_metrics['evaluate_secs'] = time.time() - evaluate_start_time
  metrics['eval'] = validation_metrics

  # Training set metrics
  if train_eval_fn:
    train_eval_start = time.time()
    train_eval_metrics = train_eval_fn(_get_model_weights(state))
    train_eval_metrics['evaluate_secs'] = time.time() - train_eval_start
    metrics['train_eval'] = train_eval_metrics

  # Test set metrics
  if test_fn:
    test_start_time = time.time()
    test_metrics = test_fn(_get_model_weights(state))
    test_metrics['evaluate_secs'] = time.time() - test_start_time
    metrics['test'] = test_metrics

  _write_metrics(metrics_mngr, tensorboard_mngr, metrics, round_num)

  return state
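

# Sketch of an epoch-based `client_datasets_fn` in the spirit of the
# `ClientIDShuffler` mentioned in the docstring above. The class below is a
# hypothetical stand-in, not the original implementation: it walks through a
# shuffled list of client IDs without replacement and bumps the epoch index
# once the list is exhausted, which is the contract the loop above relies on
# (`federated_train_data, epoch = client_datasets_fn(round_num, epoch)`).
# It assumes `np` is imported and `client_data` behaves like a
# `tff.simulation.datasets.ClientData`.
class ExampleClientIDShuffler(object):
  """Yields disjoint batches of client datasets until an epoch is exhausted."""

  def __init__(self, clients_per_round, client_data, seed=None):
    self._clients_per_round = clients_per_round
    self._client_data = client_data
    self._client_ids = list(client_data.client_ids)
    self._rng = np.random.RandomState(seed)
    self._shuffle()

  def _shuffle(self):
    self._rng.shuffle(self._client_ids)
    self._cursor = 0

  def sample_client_datasets(self, round_num, epoch):
    del round_num  # Sampling depends only on the shuffled cursor position.
    if self._cursor + self._clients_per_round > len(self._client_ids):
      # Not enough unused clients left: reshuffle and start the next epoch.
      self._shuffle()
      epoch += 1
    sampled_ids = self._client_ids[self._cursor:self._cursor +
                                   self._clients_per_round]
    self._cursor += self._clients_per_round
    datasets = [
        self._client_data.create_tf_dataset_for_client(cid)
        for cid in sampled_ids
    ]
    return datasets, epoch


# Example wiring with the loop above (names and values illustrative):
#   shuffler = ExampleClientIDShuffler(clients_per_round=100,
#                                      client_data=train_client_data)
#   run(iterative_process, shuffler.sample_client_datasets, validation_fn,
#       total_epochs=1, total_rounds=2000, experiment_name='shuffling_example')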