def test_client_datasets_fn_with_random_seed(self): tff_dataset = tff.simulation.client_data.ConcreteClientData( [0, 1, 2, 3, 4], create_tf_dataset_for_client) client_datasets_fn_1 = training_utils.build_client_datasets_fn( tff_dataset, clients_per_round=1, random_seed=363) client_datasets_1 = client_datasets_fn_1(round_num=5) sample_batch_1 = next(iter(client_datasets_1[0])) client_datasets_fn_2 = training_utils.build_client_datasets_fn( tff_dataset, clients_per_round=1, random_seed=363) client_datasets_2 = client_datasets_fn_2(round_num=5) sample_batch_2 = next(iter(client_datasets_2[0])) self.assertAllClose(sample_batch_1, sample_batch_2)
def test_build_client_datasets_fn(self): tff_dataset = tff.simulation.client_data.ConcreteClientData( [2], create_tf_dataset_for_client) client_datasets_fn = training_utils.build_client_datasets_fn( tff_dataset, clients_per_round=1) client_datasets = client_datasets_fn(round_num=7) sample_batch = next(iter(client_datasets[0])) reference_batch = next(iter(create_tf_dataset_for_client(2))) self.assertAllClose(sample_batch, reference_batch)
def test_different_random_seed_give_different_clients(self): tff_dataset = tff.simulation.client_data.ConcreteClientData( list(range(100)), create_tf_dataset_for_client) client_datasets_fn_1 = training_utils.build_client_datasets_fn( tff_dataset, clients_per_round=50, random_seed=1) client_datasets_1 = client_datasets_fn_1(round_num=1001) sample_batches_1 = [ next(iter(client_dataset)) for client_dataset in client_datasets_1 ] client_datasets_fn_2 = training_utils.build_client_datasets_fn( tff_dataset, clients_per_round=50, random_seed=2) client_datasets_2 = client_datasets_fn_2(round_num=1001) sample_batches_2 = [ next(iter(client_dataset)) for client_dataset in client_datasets_2 ] self.assertNotAllClose(sample_batches_1, sample_batches_2)
def test_client_datasets_fn_without_random_seed(self): tff_dataset = tff.simulation.client_data.ConcreteClientData( list(range(100)), create_tf_dataset_for_client) client_datasets_fn = training_utils.build_client_datasets_fn( tff_dataset, train_clients_per_round=50) client_datasets_1 = client_datasets_fn(round_num=0) sample_batches_1 = [ next(iter(client_dataset)) for client_dataset in client_datasets_1 ] client_datasets_2 = client_datasets_fn(round_num=0) sample_batches_2 = [ next(iter(client_dataset)) for client_dataset in client_datasets_2 ] self.assertNotAllClose(sample_batches_1, sample_batches_2)
def run_experiment(): """Data preprocessing and experiment execution.""" emnist_train, emnist_test = emnist_dataset.get_emnist_datasets( FLAGS.client_batch_size, FLAGS.client_epochs_per_round, only_digits=FLAGS.only_digits) example_dataset = emnist_train.create_tf_dataset_for_client( emnist_train.client_ids[0]) input_spec = example_dataset.element_spec client_datasets_fn = training_utils.build_client_datasets_fn( emnist_train, FLAGS.clients_per_round) evaluate_fn = training_utils.build_centralized_evaluate_fn( eval_dataset=emnist_test, model_builder=model_builder, loss_builder=loss_builder, metrics_builder=metrics_builder) client_optimizer_fn = functools.partial( utils_impl.create_optimizer_from_flags, 'client') server_optimizer_fn = functools.partial( utils_impl.create_optimizer_from_flags, 'server') def tff_model_fn(): keras_model = model_builder() return tff.learning.from_keras_model(keras_model, input_spec=input_spec, loss=loss_builder(), metrics=metrics_builder()) if FLAGS.use_compression: # We create a `MeasuredProcess` for broadcast process and a # `MeasuredProcess` for aggregate process by providing the # `_broadcast_encoder_fn` and `_mean_encoder_fn` to corresponding utilities. # The fns are called once for each of the model weights created by # tff_model_fn, and return instances of appropriate encoders. encoded_broadcast_process = ( tff.learning.framework.build_encoded_broadcast_process_from_model( tff_model_fn, _broadcast_encoder_fn)) encoded_mean_process = ( tff.learning.framework.build_encoded_mean_process_from_model( tff_model_fn, _mean_encoder_fn)) else: encoded_broadcast_process = None encoded_mean_process = None iterative_process = tff.learning.build_federated_averaging_process( model_fn=tff_model_fn, client_optimizer_fn=client_optimizer_fn, server_optimizer_fn=server_optimizer_fn, aggregation_process=encoded_mean_process, broadcast_process=encoded_broadcast_process) hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags()) training_loop_dict = utils_impl.lookup_flag_values(training_loop_flags) training_loop.run(iterative_process=iterative_process, client_datasets_fn=client_datasets_fn, validation_fn=evaluate_fn, hparam_dict=hparam_dict, **training_loop_dict)
def run_federated( iterative_process_builder: Callable[..., tff.templates.IterativeProcess], evaluation_computation_builder: Callable[..., tff.Computation], client_batch_size: int, clients_per_round: int, global_variables_only: bool, vocab_size: int = 10000, num_oov_buckets: int = 1, sequence_length: int = 20, max_elements_per_user: int = 1000, embedding_size: int = 96, latent_size: int = 670, num_layers: int = 1, total_rounds: int = 1500, experiment_name: str = 'federated_so_nwp', root_output_dir: str = '/tmp/fed_recon', split_dataset_strategy: str = federated_trainer_utils .SPLIT_STRATEGY_AGGREGATED, split_dataset_proportion: int = 2, compose_dataset_computation: bool = False, **kwargs): """Runs an iterative process on the Stack Overflow next word prediction task. This method will load and pre-process dataset and construct a model used for the task. It then uses `iterative_process_builder` to create an iterative process that it applies to the task, using `federated_research.utils.training_loop`. This model only sends updates for its embeddings corresponding to the most common words. Embeddings for out of vocabulary buckets are reconstructed on device at the beginning of each round, and destroyed at the end of these rounds. We assume that the iterative process has the following functional type signatures: * `initialize`: `( -> S@SERVER)` where `S` represents the server state. * `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S` represents the server state, `{B*}` represents the client datasets, and `T` represents a python `Mapping` object. The iterative process must also have a callable attribute `get_model_weights` that takes as input the state of the iterative process, and returns a `tff.learning.ModelWeights` object. Args: iterative_process_builder: A function that accepts a no-arg `model_fn`, a `loss_fn`, a `metrics_fn`, and a `client_weight_fn`, and returns a `tff.templates.IterativeProcess`. The `model_fn` must return a `reconstruction_model.ReconstructionModel`. See `federated_trainer.py` for an example. evaluation_computation_builder: A function that accepts a no-arg `model_fn`, a loss_fn`, and a `metrics_fn`, and returns a `tff.Computation` for federated reconstruction evaluation. The `model_fn` must return a `reconstruction_model.ReconstructionModel`. See `federated_trainer.py` for an example. client_batch_size: An integer representing the batch size used on clients. clients_per_round: An integer representing the number of clients participating in each round. global_variables_only: If True, the `ReconstructionModel` contains all model variables as global variables. This can be useful for baselines involving aggregating all variables. vocab_size: Integer dictating the number of most frequent words to use in the vocabulary. num_oov_buckets: The number of out-of-vocabulary buckets to use. sequence_length: The maximum number of words to take for each sequence. max_elements_per_user: The maximum number of elements processed for each client's dataset. embedding_size: The dimension of the word embedding layer. latent_size: The dimension of the latent units in the recurrent layers. num_layers: The number of stacked recurrent layers to use. total_rounds: The number of federated training rounds. experiment_name: The name of the experiment being run. This will be appended to the `root_output_dir` for purposes of writing outputs. root_output_dir: The name of the root output directory for writing experiment outputs. 
split_dataset_strategy: The method to use to split the data. Must be one of `skip`, in which case every `split_dataset_proportion` example is used for reconstruction, or `aggregated`, when the first 1/`split_dataset_proportion` proportion of the examples is used for reconstruction. split_dataset_proportion: Parameter controlling how much of the data is used for reconstruction. If `split_dataset_proportion` is n, then 1 / n of the data is used for reconstruction. compose_dataset_computation: Whether to compose dataset computation with training and evaluation computations. If True, may speed up experiments by parallelizing dataset computations in multimachine setups. Not currently supported in OSS. **kwargs: Additional arguments configuring the training loop. For details on supported arguments, see `training_loop.py`. """ loss_fn = functools.partial( tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True) special_tokens = stackoverflow_word_prediction.get_special_tokens( vocab_size, num_oov_buckets) pad_token = special_tokens.pad oov_tokens = special_tokens.oov eos_token = special_tokens.eos def metrics_fn(): return [ keras_metrics.MaskedCategoricalAccuracy( name='accuracy_with_oov', masked_tokens=[pad_token]), keras_metrics.MaskedCategoricalAccuracy( name='accuracy_no_oov', masked_tokens=[pad_token] + oov_tokens), # Notice BOS never appears in ground truth. keras_metrics.MaskedCategoricalAccuracy( name='accuracy_no_oov_or_eos', masked_tokens=[pad_token, eos_token] + oov_tokens), keras_metrics.NumBatchesCounter(), keras_metrics.NumTokensCounter(masked_tokens=[pad_token]) ] train_clientdata, validation_clientdata, test_clientdata = ( tff.simulation.datasets.stackoverflow.load_data()) vocab = stackoverflow_word_prediction.create_vocab(vocab_size) dataset_preprocess_comp = stackoverflow_dataset.create_preprocess_fn( vocab=vocab, num_oov_buckets=num_oov_buckets, client_batch_size=client_batch_size, max_sequence_length=sequence_length, max_elements_per_client=max_elements_per_user, feature_dtypes=train_clientdata.element_type_structure, sort_by_date=True) input_spec = dataset_preprocess_comp.type_signature.result.element model_fn = functools.partial( models.create_recurrent_reconstruction_model, vocab_size=vocab_size, num_oov_buckets=num_oov_buckets, embedding_size=embedding_size, latent_size=latent_size, num_layers=num_layers, input_spec=input_spec, global_variables_only=global_variables_only) def client_weight_fn(local_outputs): # Num_tokens is a tensor with type int64[1], to use as a weight need # a float32 scalar. return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32) iterative_process = iterative_process_builder( model_fn, loss_fn=loss_fn, metrics_fn=metrics_fn, client_weight_fn=client_weight_fn, dataset_split_fn_builder=functools.partial( federated_trainer_utils.build_dataset_split_fn, split_dataset_strategy=split_dataset_strategy, split_dataset_proportion=split_dataset_proportion)) base_eval_computation = evaluation_computation_builder( model_fn, loss_fn=loss_fn, metrics_fn=metrics_fn, dataset_split_fn_builder=functools.partial( federated_trainer_utils.build_dataset_split_fn, split_dataset_strategy=split_dataset_strategy, split_dataset_proportion=split_dataset_proportion)) if compose_dataset_computation: # Compose dataset computations with client training and evaluation to avoid # linear cost of computing centrally. This changes the expected input of # the `IterativeProcess` and `tff.Computation` to be a list of client IDs # instead of datasets. 
training_process = ( tff.simulation.compose_dataset_computation_with_iterative_process( dataset_preprocess_comp, iterative_process)) training_process = ( tff.simulation.compose_dataset_computation_with_iterative_process( train_clientdata.dataset_computation, training_process)) training_process.get_model_weights = iterative_process.get_model_weights base_eval_computation = ( tff.simulation.compose_dataset_computation_with_computation( dataset_preprocess_comp, base_eval_computation)) val_computation = ( tff.simulation.compose_dataset_computation_with_computation( validation_clientdata.dataset_computation, base_eval_computation)) test_computation = ( tff.simulation.compose_dataset_computation_with_computation( test_clientdata.dataset_computation, base_eval_computation)) # Create client sampling functions for each of train/val/test. # We need to sample client IDs, not datasets, and we do not need to apply # `dataset_preprocess_comp` since this is applied as part of the training # process and evaluation computation. train_client_datasets_fn = federated_trainer_utils.build_list_sample_fn( train_clientdata.client_ids, size=clients_per_round, replace=False) val_client_datasets_fn = federated_trainer_utils.build_list_sample_fn( validation_clientdata.client_ids, size=clients_per_round, replace=False) test_client_datasets_fn = federated_trainer_utils.build_list_sample_fn( test_clientdata.client_ids, size=clients_per_round, replace=False) else: training_process = iterative_process val_computation = base_eval_computation test_computation = base_eval_computation # Apply dataset computations. train_clientdata = train_clientdata.preprocess(dataset_preprocess_comp) validation_clientdata = validation_clientdata.preprocess( dataset_preprocess_comp) test_clientdata = test_clientdata.preprocess(dataset_preprocess_comp) # Create client sampling functions for each of train/val/test. train_client_datasets_fn = training_utils.build_client_datasets_fn( train_clientdata, clients_per_round=clients_per_round) val_client_datasets_fn = training_utils.build_client_datasets_fn( validation_clientdata, clients_per_round=clients_per_round) test_client_datasets_fn = training_utils.build_client_datasets_fn( test_clientdata, clients_per_round=clients_per_round) # Create final evaluation functions to pass to `training_loop`. val_fn = federated_trainer_utils.build_eval_fn( evaluation_computation=val_computation, client_datasets_fn=val_client_datasets_fn, get_model=training_process.get_model_weights) test_fn = federated_trainer_utils.build_eval_fn( evaluation_computation=test_computation, client_datasets_fn=test_client_datasets_fn, get_model=training_process.get_model_weights) test_fn = functools.partial(test_fn, round_num=0) training_loop.run( iterative_process=training_process, client_datasets_fn=train_client_datasets_fn, validation_fn=val_fn, test_fn=test_fn, total_rounds=total_rounds, experiment_name=experiment_name, root_output_dir=root_output_dir, **kwargs)
def configure_training( task_spec: training_specs.TaskSpec, vocab_tokens_size: int = 10000, vocab_tags_size: int = 500, max_elements_per_user: int = 1000, num_validation_examples: int = 10000) -> training_specs.RunnerSpec: """Configures training for the Stack Overflow tag prediction task. This tag prediction is performed via multi-class one-versus-rest logistic regression. This method will load and pre-process datasets and construct a model used for the task. It then uses `iterative_process_builder` to create an iterative process compatible with `federated_research.utils.training_loop`. Args: task_spec: A `TaskSpec` class for creating federated training tasks. vocab_tokens_size: Integer dictating the number of most frequent words to use in the vocabulary. vocab_tags_size: Integer dictating the number of most frequent tags to use in the label creation. max_elements_per_user: The maximum number of elements processed for each client's dataset. num_validation_examples: The number of test examples to use for validation. Returns: A `RunnerSpec` containing attributes used for running the newly created federated task. """ stackoverflow_train, _, _ = tff.simulation.datasets.stackoverflow.load_data( ) _, stackoverflow_validation, stackoverflow_test = stackoverflow_tag_prediction.get_centralized_datasets( train_batch_size=task_spec.client_batch_size, word_vocab_size=vocab_tokens_size, tag_vocab_size=vocab_tags_size, num_validation_examples=num_validation_examples) word_vocab = stackoverflow_tag_prediction.create_word_vocab( vocab_tokens_size) tag_vocab = stackoverflow_tag_prediction.create_tag_vocab(vocab_tags_size) train_preprocess_fn = stackoverflow_tag_prediction.create_preprocess_fn( word_vocab=word_vocab, tag_vocab=tag_vocab, client_batch_size=task_spec.client_batch_size, client_epochs_per_round=task_spec.client_epochs_per_round, max_elements_per_client=max_elements_per_user) input_spec = train_preprocess_fn.type_signature.result.element model_builder = functools.partial( stackoverflow_lr_models.create_logistic_model, vocab_tokens_size=vocab_tokens_size, vocab_tags_size=vocab_tags_size) loss_builder = functools.partial(tf.keras.losses.BinaryCrossentropy, from_logits=False, reduction=tf.keras.losses.Reduction.SUM) def tff_model_fn() -> tff.learning.Model: return tff.learning.from_keras_model(keras_model=model_builder(), input_spec=input_spec, loss=loss_builder(), metrics=metrics_builder()) iterative_process = task_spec.iterative_process_builder(tff_model_fn) if hasattr(stackoverflow_train, 'dataset_computation'): @tff.tf_computation(tf.string) def build_train_dataset_from_client_id(client_id): client_dataset = stackoverflow_train.dataset_computation(client_id) return train_preprocess_fn(client_dataset) training_process = tff.simulation.compose_dataset_computation_with_iterative_process( build_train_dataset_from_client_id, iterative_process) client_ids_fn = training_utils.build_sample_fn( stackoverflow_train.client_ids, size=task_spec.clients_per_round, replace=False, random_seed=task_spec.client_datasets_random_seed) # We convert the output to a list (instead of an np.ndarray) so that it can # be used as input to the iterative process. 
client_sampling_fn = lambda x: list(client_ids_fn(x)) else: training_process = tff.simulation.compose_dataset_computation_with_iterative_process( train_preprocess_fn, iterative_process) client_sampling_fn = training_utils.build_client_datasets_fn( dataset=stackoverflow_train, clients_per_round=task_spec.clients_per_round, random_seed=task_spec.client_datasets_random_seed) training_process.get_model_weights = iterative_process.get_model_weights evaluate_fn = training_utils.build_centralized_evaluate_fn( model_builder=model_builder, eval_dataset=stackoverflow_validation, loss_builder=loss_builder, metrics_builder=metrics_builder) validation_fn = lambda model_weights, round_num: evaluate_fn(model_weights) test_fn = training_utils.build_centralized_evaluate_fn( model_builder=model_builder, # Use both val and test for symmetry with other experiments, which # evaluate on the entire test set. eval_dataset=stackoverflow_validation.concatenate(stackoverflow_test), loss_builder=loss_builder, metrics_builder=metrics_builder) return training_specs.RunnerSpec(iterative_process=training_process, client_datasets_fn=client_sampling_fn, validation_fn=validation_fn, test_fn=test_fn)
def configure_training(task_spec: training_specs.TaskSpec, eval_spec: Optional[training_specs.EvalSpec] = None, model: str = 'cnn') -> training_specs.RunnerSpec: """Configures training for the EMNIST character recognition task. This method will load and pre-process datasets and construct a model used for the task. It then uses `iterative_process_builder` to create an iterative process compatible with `federated_research.utils.training_loop`. Args: task_spec: A `TaskSpec` class for creating federated training tasks. eval_spec: An `EvalSpec` class for configuring federated evaluation. If set to None, centralized evaluation is used for validation and testing instead. model: A string specifying the model used for character recognition. Can be one of `cnn` and `2nn`, corresponding to a CNN model and a densely connected 2-layer model (respectively). Returns: A `RunnerSpec` containing attributes used for running the newly created federated task. """ emnist_task = 'digit_recognition' emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data( only_digits=False) train_preprocess_fn = emnist_dataset.create_preprocess_fn( num_epochs=task_spec.client_epochs_per_round, batch_size=task_spec.client_batch_size, emnist_task=emnist_task) input_spec = train_preprocess_fn.type_signature.result.element if model == 'cnn': model_builder = functools.partial( emnist_models.create_conv_dropout_model, only_digits=False) elif model == '2nn': model_builder = functools.partial( emnist_models.create_two_hidden_layer_model, only_digits=False) else: raise ValueError( 'Cannot handle model flag [{!s}], must be one of {!s}.'.format( model, EMNIST_MODELS)) loss_builder = tf.keras.losses.SparseCategoricalCrossentropy metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()] def tff_model_fn() -> tff.learning.Model: return tff.learning.from_keras_model( keras_model=model_builder(), input_spec=input_spec, loss=loss_builder(), metrics=metrics_builder()) iterative_process = task_spec.iterative_process_builder(tff_model_fn) clients_per_train_round = min(task_spec.clients_per_round, TOTAL_NUM_TRAIN_CLIENTS) if hasattr(emnist_train, 'dataset_computation'): @tff.tf_computation(tf.string) def build_train_dataset_from_client_id(client_id): client_dataset = emnist_train.dataset_computation(client_id) return train_preprocess_fn(client_dataset) training_process = tff.simulation.compose_dataset_computation_with_iterative_process( build_train_dataset_from_client_id, iterative_process) client_ids_fn = training_utils.build_sample_fn( emnist_train.client_ids, size=clients_per_train_round, replace=False, random_seed=task_spec.sampling_random_seed) # We convert the output to a list (instead of an np.ndarray) so that it can # be used as input to the iterative process. 
client_sampling_fn = lambda x: list(client_ids_fn(x)) else: training_process = tff.simulation.compose_dataset_computation_with_iterative_process( train_preprocess_fn, iterative_process) client_sampling_fn = training_utils.build_client_datasets_fn( dataset=emnist_train, clients_per_round=clients_per_train_round, random_seed=task_spec.sampling_random_seed) training_process.get_model_weights = iterative_process.get_model_weights if eval_spec: if eval_spec.clients_per_validation_round is None: clients_per_validation_round = TOTAL_NUM_TEST_CLIENTS else: clients_per_validation_round = min(eval_spec.clients_per_validation_round, TOTAL_NUM_TEST_CLIENTS) if eval_spec.clients_per_test_round is None: clients_per_test_round = TOTAL_NUM_TEST_CLIENTS else: clients_per_test_round = min(eval_spec.clients_per_test_round, TOTAL_NUM_TEST_CLIENTS) test_preprocess_fn = emnist_dataset.create_preprocess_fn( num_epochs=1, batch_size=eval_spec.client_batch_size, shuffle_buffer_size=1, emnist_task=emnist_task) emnist_test = emnist_test.preprocess(test_preprocess_fn) def eval_metrics_builder(): return [ tf.keras.metrics.SparseCategoricalCrossentropy(), tf.keras.metrics.SparseCategoricalAccuracy() ] federated_eval_fn = training_utils.build_federated_evaluate_fn( model_builder=model_builder, metrics_builder=eval_metrics_builder) validation_client_sampling_fn = training_utils.build_client_datasets_fn( emnist_test, clients_per_validation_round, random_seed=eval_spec.sampling_random_seed) test_client_sampling_fn = training_utils.build_client_datasets_fn( emnist_test, clients_per_test_round, random_seed=eval_spec.sampling_random_seed) def validation_fn(model_weights, round_num): validation_clients = validation_client_sampling_fn(round_num) return federated_eval_fn(model_weights, validation_clients) def test_fn(model_weights): # We fix the round number to get deterministic behavior test_round_num = 0 test_clients = test_client_sampling_fn(test_round_num) return federated_eval_fn(model_weights, test_clients) else: _, central_emnist_test = emnist_dataset.get_centralized_datasets( only_digits=False, emnist_task=emnist_task) test_fn = training_utils.build_centralized_evaluate_fn( eval_dataset=central_emnist_test, model_builder=model_builder, loss_builder=loss_builder, metrics_builder=metrics_builder) validation_fn = lambda model_weights, round_num: test_fn(model_weights) return training_specs.RunnerSpec( iterative_process=training_process, client_datasets_fn=client_sampling_fn, validation_fn=validation_fn, test_fn=test_fn)
def run_federated( iterative_process_builder: Callable[..., tff.templates.IterativeProcess], client_epochs_per_round: int, client_batch_size: int, clients_per_round: int, client_datasets_random_seed: Optional[int] = None, model: Optional[str] = 'cnn', total_rounds: Optional[int] = 1500, experiment_name: Optional[str] = 'federated_emnist_cr', root_output_dir: Optional[str] = '/tmp/fed_opt', **kwargs): """Runs an iterative process on the EMNIST character recognition task. This method will load and pre-process dataset and construct a model used for the task. It then uses `iterative_process_builder` to create an iterative process that it applies to the task, using `federated_research.utils.training_loop`. We assume that the iterative process has the following functional type signatures: * `initialize`: `( -> S@SERVER)` where `S` represents the server state. * `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S` represents the server state, `{B*}` represents the client datasets, and `T` represents a python `Mapping` object. The iterative process must also have a callable attribute `get_model_weights` that takes as input the state of the iterative process, and returns a `tff.learning.ModelWeights` object. Args: iterative_process_builder: A function that accepts a no-arg `model_fn`, and returns a `tff.templates.IterativeProcess`. The `model_fn` must return a `tff.learning.Model`. client_epochs_per_round: An integer representing the number of epochs of training performed per client in each training round. client_batch_size: An integer representing the batch size used on clients. clients_per_round: An integer representing the number of clients participating in each round. client_datasets_random_seed: An optional int used to seed which clients are sampled at each round. If `None`, no seed is used. model: A string specifying the model used for character recognition. Can be one of `cnn` and `2nn`, corresponding to a CNN model and a densely connected 2-layer model (respectively). total_rounds: The number of federated training rounds. experiment_name: The name of the experiment being run. This will be appended to the `root_output_dir` for purposes of writing outputs. root_output_dir: The name of the root output directory for writing experiment outputs. **kwargs: Additional arguments configuring the training loop. For details on supported arguments, see `federated_research/utils/training_utils.py`. 
""" emnist_train, _ = emnist_dataset.get_federated_datasets( train_client_batch_size=client_batch_size, train_client_epochs_per_round=client_epochs_per_round, only_digits=False) _, emnist_test = emnist_dataset.get_centralized_datasets(only_digits=False) input_spec = emnist_train.create_tf_dataset_for_client( emnist_train.client_ids[0]).element_spec if model == 'cnn': model_builder = functools.partial( emnist_models.create_conv_dropout_model, only_digits=False) elif model == '2nn': model_builder = functools.partial( emnist_models.create_two_hidden_layer_model, only_digits=False) else: raise ValueError( 'Cannot handle model flag [{!s}], must be one of {!s}.'.format( model, EMNIST_MODELS)) loss_builder = tf.keras.losses.SparseCategoricalCrossentropy metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()] def tff_model_fn() -> tff.learning.Model: return tff.learning.from_keras_model( keras_model=model_builder(), input_spec=input_spec, loss=loss_builder(), metrics=metrics_builder()) training_process = iterative_process_builder(tff_model_fn) client_datasets_fn = training_utils.build_client_datasets_fn( dataset=emnist_train, clients_per_round=clients_per_round, random_seed=client_datasets_random_seed) test_fn = training_utils.build_centralized_evaluate_fn( eval_dataset=emnist_test, model_builder=model_builder, loss_builder=loss_builder, metrics_builder=metrics_builder) validation_fn = lambda model_weights, round_num: test_fn(model_weights) logging.info('Training model:') logging.info(model_builder().summary()) training_loop.run( iterative_process=training_process, client_datasets_fn=client_datasets_fn, validation_fn=validation_fn, test_fn=test_fn, total_rounds=total_rounds, experiment_name=experiment_name, root_output_dir=root_output_dir, **kwargs)
def main(argv): if len(argv) > 1: raise app.UsageError('Expected no command-line arguments, ' 'got: {}'.format(argv)) tff.backends.native.set_local_execution_context(max_fanout=10) model_builder = functools.partial( stackoverflow_models.create_recurrent_model, vocab_size=FLAGS.vocab_size, embedding_size=FLAGS.embedding_size, latent_size=FLAGS.latent_size, num_layers=FLAGS.num_layers, shared_embedding=FLAGS.shared_embedding) loss_builder = functools.partial( tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True) special_tokens = stackoverflow_dataset.get_special_tokens(FLAGS.vocab_size) pad_token = special_tokens.pad oov_tokens = special_tokens.oov eos_token = special_tokens.eos def metrics_builder(): return [ keras_metrics.MaskedCategoricalAccuracy(name='accuracy_with_oov', masked_tokens=[pad_token]), keras_metrics.MaskedCategoricalAccuracy(name='accuracy_no_oov', masked_tokens=[pad_token] + oov_tokens), # Notice BOS never appears in ground truth. keras_metrics.MaskedCategoricalAccuracy( name='accuracy_no_oov_or_eos', masked_tokens=[pad_token, eos_token] + oov_tokens), keras_metrics.NumBatchesCounter(), keras_metrics.NumTokensCounter(masked_tokens=[pad_token]), ] datasets = stackoverflow_dataset.construct_word_level_datasets( FLAGS.vocab_size, FLAGS.client_batch_size, FLAGS.client_epochs_per_round, FLAGS.sequence_length, FLAGS.max_elements_per_user, FLAGS.num_validation_examples) train_dataset, validation_dataset, test_dataset = datasets if FLAGS.uniform_weighting: def client_weight_fn(local_outputs): del local_outputs return 1.0 else: def client_weight_fn(local_outputs): return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32) def model_fn(): return tff.learning.from_keras_model( model_builder(), loss_builder(), input_spec=validation_dataset.element_spec, metrics=metrics_builder()) if FLAGS.noise_multiplier is not None: if not FLAGS.uniform_weighting: raise ValueError( 'Differential privacy is only implemented for uniform weighting.' ) dp_query = tff.utils.build_dp_query( clip=FLAGS.clip, noise_multiplier=FLAGS.noise_multiplier, expected_total_weight=FLAGS.clients_per_round, adaptive_clip_learning_rate=FLAGS.adaptive_clip_learning_rate, target_unclipped_quantile=FLAGS.target_unclipped_quantile, clipped_count_budget_allocation=FLAGS. clipped_count_budget_allocation, expected_clients_per_round=FLAGS.clients_per_round) weights_type = tff.learning.framework.weights_type_from_model(model_fn) aggregation_process = tff.utils.build_dp_aggregate_process( weights_type.trainable, dp_query) else: aggregation_process = None server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags( 'server') client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags( 'client') iterative_process = tff.learning.build_federated_averaging_process( model_fn=model_fn, server_optimizer_fn=server_optimizer_fn, client_weight_fn=client_weight_fn, client_optimizer_fn=client_optimizer_fn, aggregation_process=aggregation_process) client_datasets_fn = training_utils.build_client_datasets_fn( train_dataset, FLAGS.clients_per_round) evaluate_fn = training_utils.build_centralized_evaluate_fn( model_builder=model_builder, eval_dataset=validation_dataset, loss_builder=loss_builder, metrics_builder=metrics_builder) test_fn = training_utils.build_centralized_evaluate_fn( model_builder=model_builder, # Use both val and test for symmetry with other experiments, which # evaluate on the entire test set. 
eval_dataset=validation_dataset.concatenate(test_dataset), loss_builder=loss_builder, metrics_builder=metrics_builder) logging.info('Training model:') logging.info(model_builder().summary()) hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags()) training_loop_dict = utils_impl.lookup_flag_values(training_loop_flags) training_loop.run(iterative_process=iterative_process, client_datasets_fn=client_datasets_fn, validation_fn=evaluate_fn, test_fn=test_fn, hparam_dict=hparam_dict, **training_loop_dict)
def run_experiment(): """Data preprocessing and experiment execution.""" emnist_train, _ = emnist_dataset.get_federated_datasets( train_client_batch_size=FLAGS.client_batch_size, train_client_epochs_per_round=FLAGS.client_epochs_per_round, only_digits=False) _, emnist_test = emnist_dataset.get_centralized_datasets() example_dataset = emnist_train.create_tf_dataset_for_client( emnist_train.client_ids[0]) input_spec = example_dataset.element_spec client_datasets_fn = training_utils.build_client_datasets_fn( emnist_train, FLAGS.clients_per_round) evaluate_fn = training_utils.build_centralized_evaluate_fn( eval_dataset=emnist_test, model_builder=model_builder, loss_builder=loss_builder, metrics_builder=metrics_builder) validation_fn = lambda model_weights, round_num: evaluate_fn(model_weights) client_optimizer_fn = functools.partial( utils_impl.create_optimizer_from_flags, 'client') server_optimizer_fn = functools.partial( utils_impl.create_optimizer_from_flags, 'server') def tff_model_fn(): keras_model = model_builder() return tff.learning.from_keras_model(keras_model, input_spec=input_spec, loss=loss_builder(), metrics=metrics_builder()) if FLAGS.use_compression: # We create a `MeasuredProcess` for broadcast process and a # `MeasuredProcess` for aggregate process by providing the # `_broadcast_encoder_fn` and `_mean_encoder_fn` to corresponding utilities. # The fns are called once for each of the model weights created by # tff_model_fn, and return instances of appropriate encoders. encoded_broadcast_process = ( tff.learning.framework.build_encoded_broadcast_process_from_model( tff_model_fn, _broadcast_encoder_fn)) encoded_mean_process = ( tff.learning.framework.build_encoded_mean_process_from_model( tff_model_fn, _mean_encoder_fn)) else: encoded_broadcast_process = None encoded_mean_process = None iterative_process = tff.learning.build_federated_averaging_process( model_fn=tff_model_fn, client_optimizer_fn=client_optimizer_fn, server_optimizer_fn=server_optimizer_fn, aggregation_process=encoded_mean_process, broadcast_process=encoded_broadcast_process) # Log hyperparameters to CSV hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags()) results_dir = os.path.join(FLAGS.root_output_dir, 'results', FLAGS.experiment_name) utils_impl.create_directory_if_not_exists(results_dir) hparam_file = os.path.join(results_dir, 'hparams.csv') utils_impl.atomic_write_series_to_csv(hparam_dict, hparam_file) training_loop.run(iterative_process=iterative_process, client_datasets_fn=client_datasets_fn, validation_fn=validation_fn, total_rounds=FLAGS.total_rounds, experiment_name=FLAGS.experiment_name, root_output_dir=FLAGS.root_output_dir, rounds_per_eval=FLAGS.rounds_per_eval, rounds_per_checkpoint=FLAGS.rounds_per_checkpoint, rounds_per_profile=FLAGS.rounds_per_profile)
def run_federated( iterative_process_builder: Callable[..., tff.templates.IterativeProcess], client_epochs_per_round: int, client_batch_size: int, clients_per_round: int, client_datasets_random_seed: Optional[int] = None, vocab_size: Optional[int] = 10000, num_oov_buckets: Optional[int] = 1, sequence_length: Optional[int] = 20, max_elements_per_user: Optional[int] = 1000, num_validation_examples: Optional[int] = 10000, embedding_size: Optional[int] = 96, latent_size: Optional[int] = 670, num_layers: Optional[int] = 1, shared_embedding: Optional[bool] = False, total_rounds: Optional[int] = 1500, experiment_name: Optional[str] = 'federated_so_nwp', root_output_dir: Optional[str] = '/tmp/fed_opt', **kwargs): """Runs an iterative process on the Stack Overflow next word prediction task. This method will load and pre-process dataset and construct a model used for the task. It then uses `iterative_process_builder` to create an iterative process that it applies to the task, using `federated_research.utils.training_loop`. We assume that the iterative process has the following functional type signatures: * `initialize`: `( -> S@SERVER)` where `S` represents the server state. * `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S` represents the server state, `{B*}` represents the client datasets, and `T` represents a python `Mapping` object. The iterative process must also have a callable attribute `get_model_weights` that takes as input the state of the iterative process, and returns a `tff.learning.ModelWeights` object. Args: iterative_process_builder: A function that accepts a no-arg `model_fn`, a `client_weight_fn` and returns a `tff.templates.IterativeProcess`. The `model_fn` must return a `tff.learning.Model`. client_epochs_per_round: An integer representing the number of epochs of training performed per client in each training round. client_batch_size: An integer representing the batch size used on clients. clients_per_round: An integer representing the number of clients participating in each round. client_datasets_random_seed: An optional int used to seed which clients are sampled at each round. If `None`, no seed is used. vocab_size: Integer dictating the number of most frequent words to use in the vocabulary. num_oov_buckets: The number of out-of-vocabulary buckets to use. sequence_length: The maximum number of words to take for each sequence. max_elements_per_user: The maximum number of elements processed for each client's dataset. num_validation_examples: The number of test examples to use for validation. embedding_size: The dimension of the word embedding layer. latent_size: The dimension of the latent units in the recurrent layers. num_layers: The number of stacked recurrent layers to use. shared_embedding: Boolean indicating whether to tie input and output embeddings. total_rounds: The number of federated training rounds. experiment_name: The name of the experiment being run. This will be appended to the `root_output_dir` for purposes of writing outputs. root_output_dir: The name of the root output directory for writing experiment outputs. **kwargs: Additional arguments configuring the training loop. For details on supported arguments, see `federated_research/utils/training_utils.py`. 
""" model_builder = functools.partial( stackoverflow_models.create_recurrent_model, vocab_size=vocab_size, num_oov_buckets=num_oov_buckets, embedding_size=embedding_size, latent_size=latent_size, num_layers=num_layers, shared_embedding=shared_embedding) loss_builder = functools.partial( tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True) special_tokens = stackoverflow_word_prediction.get_special_tokens( vocab_size, num_oov_buckets) pad_token = special_tokens.pad oov_tokens = special_tokens.oov eos_token = special_tokens.eos def metrics_builder(): return [ keras_metrics.MaskedCategoricalAccuracy(name='accuracy_with_oov', masked_tokens=[pad_token]), keras_metrics.MaskedCategoricalAccuracy(name='accuracy_no_oov', masked_tokens=[pad_token] + oov_tokens), # Notice BOS never appears in ground truth. keras_metrics.MaskedCategoricalAccuracy( name='accuracy_no_oov_or_eos', masked_tokens=[pad_token, eos_token] + oov_tokens), keras_metrics.NumBatchesCounter(), keras_metrics.NumTokensCounter(masked_tokens=[pad_token]) ] train_clientdata, _, _ = tff.simulation.datasets.stackoverflow.load_data() # TODO(b/161914546): consider moving evaluation to use # `tff.learning.build_federated_evaluation` to get metrics over client # distributions, as well as the example weight means from this centralized # evaluation. _, validation_dataset, test_dataset = stackoverflow_word_prediction.get_centralized_datasets( vocab_size=vocab_size, max_sequence_length=sequence_length, num_validation_examples=num_validation_examples, num_oov_buckets=num_oov_buckets) train_dataset_preprocess_comp = stackoverflow_word_prediction.create_preprocess_fn( vocab=stackoverflow_word_prediction.create_vocab(vocab_size), num_oov_buckets=num_oov_buckets, client_batch_size=client_batch_size, client_epochs_per_round=client_epochs_per_round, max_sequence_length=sequence_length, max_elements_per_client=max_elements_per_user) input_spec = train_dataset_preprocess_comp.type_signature.result.element def tff_model_fn() -> tff.learning.Model: return tff.learning.from_keras_model(keras_model=model_builder(), input_spec=input_spec, loss=loss_builder(), metrics=metrics_builder()) def client_weight_fn(local_outputs): # Num_tokens is a tensor with type int64[1], to use as a weight need # a float32 scalar. return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32) iterative_process = iterative_process_builder( tff_model_fn, client_weight_fn=client_weight_fn) training_process = tff.simulation.compose_dataset_computation_with_iterative_process( train_dataset_preprocess_comp, iterative_process) training_process.get_model_weights = iterative_process.get_model_weights client_datasets_fn = training_utils.build_client_datasets_fn( dataset=train_clientdata, clients_per_round=clients_per_round, random_seed=client_datasets_random_seed) evaluate_fn = training_utils.build_centralized_evaluate_fn( model_builder=model_builder, eval_dataset=validation_dataset, loss_builder=loss_builder, metrics_builder=metrics_builder) validation_fn = lambda model_weights, round_num: evaluate_fn(model_weights) test_fn = training_utils.build_centralized_evaluate_fn( model_builder=model_builder, # Use both val and test for symmetry with other experiments, which # evaluate on the entire test set. 
eval_dataset=validation_dataset.concatenate(test_dataset), loss_builder=loss_builder, metrics_builder=metrics_builder) logging.info('Training model:') logging.info(model_builder().summary()) training_loop.run(iterative_process=training_process, client_datasets_fn=client_datasets_fn, validation_fn=validation_fn, test_fn=test_fn, total_rounds=total_rounds, experiment_name=experiment_name, root_output_dir=root_output_dir, **kwargs)
def configure_training(task_spec: training_specs.TaskSpec, sequence_length: int = 80) -> training_specs.RunnerSpec: """Configures training for the Shakespeare next-character prediction task. This method will load and pre-process datasets and construct a model used for the task. It then uses `iterative_process_builder` to create an iterative process compatible with `federated_research.utils.training_loop`. Args: task_spec: A `TaskSpec` class for creating federated training tasks. sequence_length: An int specifying the length of the character sequences used for prediction. Returns: A `RunnerSpec` containing attributes used for running the newly created federated task. """ shakespeare_train, _ = tff.simulation.datasets.shakespeare.load_data() _, shakespeare_test = shakespeare_dataset.get_centralized_datasets( sequence_length=sequence_length) train_preprocess_fn = shakespeare_dataset.create_preprocess_fn( num_epochs=task_spec.client_epochs_per_round, batch_size=task_spec.client_batch_size, sequence_length=sequence_length) input_spec = train_preprocess_fn.type_signature.result.element model_builder = functools.partial(create_shakespeare_model, sequence_length=sequence_length) loss_builder = functools.partial( tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True) def tff_model_fn() -> tff.learning.Model: return tff.learning.from_keras_model(keras_model=model_builder(), input_spec=input_spec, loss=loss_builder(), metrics=metrics_builder()) iterative_process = task_spec.iterative_process_builder(tff_model_fn) if hasattr(shakespeare_train, 'dataset_computation'): @tff.tf_computation(tf.string) def build_train_dataset_from_client_id(client_id): client_dataset = shakespeare_train.dataset_computation(client_id) return train_preprocess_fn(client_dataset) training_process = tff.simulation.compose_dataset_computation_with_iterative_process( build_train_dataset_from_client_id, iterative_process) client_ids_fn = training_utils.build_sample_fn( shakespeare_train.client_ids, size=task_spec.clients_per_round, replace=False, random_seed=task_spec.client_datasets_random_seed) # We convert the output to a list (instead of an np.ndarray) so that it can # be used as input to the iterative process. client_sampling_fn = lambda x: list(client_ids_fn(x)) else: training_process = tff.simulation.compose_dataset_computation_with_iterative_process( train_preprocess_fn, iterative_process) client_sampling_fn = training_utils.build_client_datasets_fn( dataset=shakespeare_train, clients_per_round=task_spec.clients_per_round, random_seed=task_spec.client_datasets_random_seed) training_process.get_model_weights = iterative_process.get_model_weights test_fn = training_utils.build_centralized_evaluate_fn( eval_dataset=shakespeare_test, model_builder=model_builder, loss_builder=loss_builder, metrics_builder=metrics_builder) validation_fn = lambda model_weights, round_num: test_fn(model_weights) return training_specs.RunnerSpec(iterative_process=training_process, client_datasets_fn=client_sampling_fn, validation_fn=validation_fn, test_fn=test_fn)
def main(argv): if len(argv) > 1: raise app.UsageError('Expected no command-line arguments, ' 'got: {}'.format(argv)) emnist_train, emnist_test = emnist_dataset.get_emnist_datasets( FLAGS.client_batch_size, FLAGS.client_epochs_per_round, only_digits=False) if FLAGS.model == 'cnn': model_builder = functools.partial( emnist_models.create_conv_dropout_model, only_digits=False) elif FLAGS.model == '2nn': model_builder = functools.partial( emnist_models.create_two_hidden_layer_model, only_digits=False) else: raise ValueError('Cannot handle model flag [{!s}].'.format( FLAGS.model)) loss_builder = tf.keras.losses.SparseCategoricalCrossentropy metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()] if FLAGS.uniform_weighting: def client_weight_fn(local_outputs): del local_outputs return 1.0 else: client_weight_fn = None # Defaults to the number of examples per client. def model_fn(): return tff.learning.from_keras_model( model_builder(), loss_builder(), input_spec=emnist_test.element_spec, metrics=metrics_builder()) if FLAGS.noise_multiplier is not None: if not FLAGS.uniform_weighting: raise ValueError( 'Differential privacy is only implemented for uniform weighting.' ) dp_query = tff.utils.build_dp_query( clip=FLAGS.clip, noise_multiplier=FLAGS.noise_multiplier, expected_total_weight=FLAGS.clients_per_round, adaptive_clip_learning_rate=FLAGS.adaptive_clip_learning_rate, target_unclipped_quantile=FLAGS.target_unclipped_quantile, clipped_count_budget_allocation=FLAGS. clipped_count_budget_allocation, expected_clients_per_round=FLAGS.clients_per_round, per_vector_clipping=FLAGS.per_vector_clipping, model=model_fn()) weights_type = tff.learning.framework.weights_type_from_model(model_fn) aggregation_process = tff.utils.build_dp_aggregate_process( weights_type.trainable, dp_query) else: aggregation_process = None server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags( 'server') client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags( 'client') iterative_process = tff.learning.build_federated_averaging_process( model_fn=model_fn, server_optimizer_fn=server_optimizer_fn, client_weight_fn=client_weight_fn, client_optimizer_fn=client_optimizer_fn, aggregation_process=aggregation_process) client_datasets_fn = training_utils.build_client_datasets_fn( emnist_train, FLAGS.clients_per_round) evaluate_fn = training_utils.build_evaluate_fn( eval_dataset=emnist_test, model_builder=model_builder, loss_builder=loss_builder, metrics_builder=metrics_builder) logging.info('Training model:') logging.info(model_builder().summary()) hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags()) training_loop_dict = utils_impl.lookup_flag_values(training_loop_flags) training_loop.run(iterative_process=iterative_process, client_datasets_fn=client_datasets_fn, validation_fn=evaluate_fn, hparam_dict=hparam_dict, **training_loop_dict)
def run_federated( iterative_process_builder: Callable[..., tff.templates.IterativeProcess], client_epochs_per_round: int, client_batch_size: int, clients_per_round: int, schedule: Optional[str] = 'none', beta: Optional[float] = 0., max_batches_per_client: Optional[int] = -1, client_datasets_random_seed: Optional[int] = None, model: Optional[str] = 'cnn', total_rounds: Optional[int] = 1500, experiment_name: Optional[str] = 'federated_synthetic', root_output_dir: Optional[str] = '/tmp/fed_opt', max_eval_batches: Optional[int] = None, alpha: Optional[float] = 0., beta_data: Optional[float] = 0., iid: Optional[int] = 0, num_users: Optional[int] = 100, **kwargs): """Runs an iterative process on the EMNIST character recognition task. This method will load and pre-process dataset and construct a model used for the task. It then uses `iterative_process_builder` to create an iterative process that it applies to the task, using `federated_research.utils.training_loop`. We assume that the iterative process has the following functional type signatures: * `initialize`: `( -> S@SERVER)` where `S` represents the server state. * `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S` represents the server state, `{B*}` represents the client datasets, and `T` represents a python `Mapping` object. Moreover, the server state must have an attribute `model` of type `tff.learning.ModelWeights`. Args: iterative_process_builder: A function that accepts a no-arg `model_fn`, and returns a `tff.templates.IterativeProcess`. The `model_fn` must return a `tff.learning.Model`. client_epochs_per_round: An integer representing the number of epochs of training performed per client in each training round. client_batch_size: An integer representing the batch size used on clients. clients_per_round: An integer representing the number of clients participating in each round. max_batches_per_client: An optional int specifying the number of batches taken by each client at each round. If `-1`, the entire client dataset is used. client_datasets_random_seed: An optional int used to seed which clients are sampled at each round. If `None`, no seed is used. model: A string specifying the model used for character recognition. Can be one of `cnn` and `2nn`, corresponding to a CNN model and a densely connected 2-layer model (respectively). total_rounds: The number of federated training rounds. experiment_name: The name of the experiment being run. This will be appended to the `root_output_dir` for purposes of writing outputs. root_output_dir: The name of the root output directory for writing experiment outputs. max_eval_batches: If set to a positive integer, evaluation datasets are capped to at most that many batches. If set to None or a nonpositive integer, the full evaluation datasets are used. **kwargs: Additional arguments configuring the training loop. For details on supported arguments, see `federated_research/utils/training_utils.py`. 
""" logging.info(f' DATA PARAMS: ') logging.info(f' Num Users: {num_users}') logging.info(f' alpha: {alpha}') logging.info(f' beta: {beta_data}') logging.info(f' iid: {iid}') train_data, test_data, federated_test = synthetic_dataset.generate_federated_softmax_data( batch_size=client_batch_size, client_epochs_per_round=client_epochs_per_round, test_batch_size=100, alpha=alpha, beta=beta_data, iid=iid, num_users=num_users) input_spec = train_data.create_tf_dataset_for_client( train_data.client_ids[0]).element_spec model_builder = functools.partial(create_logistic_regression_model) loss_builder = tf.keras.losses.SparseCategoricalCrossentropy metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()] def tff_model_fn() -> tff.learning.Model: return tff.learning.from_keras_model(keras_model=model_builder(), input_spec=input_spec, loss=loss_builder(), metrics=metrics_builder()) training_process = iterative_process_builder(tff_model_fn) evaluate_fn = training_utils.build_evaluate_fn( eval_dataset=test_data, model_builder=model_builder, loss_builder=loss_builder, metrics_builder=metrics_builder) test_fn = training_utils.build_unweighted_test_fn( federated_eval_dataset=federated_test, model_builder=model_builder, loss_builder=loss_builder, metrics_builder=metrics_builder) logging.info('Training model:') logging.info(model_builder().summary()) try: var = kwargs['hparam_dict']['var_q_clients'] print(f'Variance: {var}') q_client = np.load( f'/home/monica/AVAIL_VECTORS/q_client_{var}_synthetic.npy') #if q_client is None: #logging.info('Could not load q_client - initializing random availabilities') #q_client=None except: logging.info( 'Could not load q_client - initializing random availabilities') q_client = None if schedule == 'none': client_datasets_fn = training_utils.build_client_datasets_fn( train_dataset=train_data, train_clients_per_round=clients_per_round, random_seed=client_datasets_random_seed, min_clients=kwargs['hparam_dict']['min_clients'], var_q_clients=kwargs['hparam_dict']['var_q_clients'], f_mult=kwargs['hparam_dict']['f_mult'], f_intercept=kwargs['hparam_dict']['f_intercept'], sine_wave=kwargs['hparam_dict']['sine_wave'], use_p=True, q_client=q_client, ) training_loop.run(iterative_process=training_process, client_datasets_fn=client_datasets_fn, validation_fn=evaluate_fn, test_fn=test_fn, total_rounds=total_rounds, experiment_name=experiment_name, root_output_dir=root_output_dir, **kwargs) elif schedule == 'loss': if 'loss_pool_size' in kwargs['hparam_dict'] and kwargs['hparam_dict'][ 'loss_pool_size'] is not None: loss_pool_size = kwargs['hparam_dict']['loss_pool_size'] logging.info(f'Loss pool size: {loss_pool_size}') client_datasets_fn = training_utils.build_client_datasets_fn( train_dataset=train_data, train_clients_per_round=loss_pool_size, random_seed=client_datasets_random_seed, min_clients=kwargs['hparam_dict']['min_clients'], var_q_clients=kwargs['hparam_dict']['var_q_clients'], f_mult=kwargs['hparam_dict']['f_mult'], f_intercept=kwargs['hparam_dict']['f_intercept'], sine_wave=kwargs['hparam_dict']['sine_wave'], use_p=True, q_client=q_client) training_loop_loss.run(iterative_process=training_process, client_datasets_fn=client_datasets_fn, validation_fn=evaluate_fn, test_fn=test_fn, total_rounds=total_rounds, total_clients=loss_pool_size, experiment_name=experiment_name, root_output_dir=root_output_dir, **kwargs) else: raise ValueError('Loss pool size not specified') else: init_p = kwargs['hparam_dict']['initialize_p'] logging.info(f'Initializing as p = {init_p}') 
client_datasets_fn = training_utils.build_availability_client_datasets_fn( train_dataset=train_data, train_clients_per_round=clients_per_round, beta=beta, min_clients=kwargs['hparam_dict']['min_clients'], var_q_clients=kwargs['hparam_dict']['var_q_clients'], f_mult=kwargs['hparam_dict']['f_mult'], f_intercept=kwargs['hparam_dict']['f_intercept'], sine_wave=kwargs['hparam_dict']['sine_wave'], q_client=q_client, initialize_p=init_p, ) training_loop_importance.run(iterative_process=training_process, client_datasets_fn=client_datasets_fn, validation_fn=evaluate_fn, test_fn=test_fn, total_rounds=total_rounds, experiment_name=experiment_name, root_output_dir=root_output_dir, **kwargs)
def main(argv): if len(argv) > 1: raise app.UsageError('Expected no command-line arguments, ' 'got: {}'.format(argv)) tff.backends.native.set_local_execution_context(max_fanout=10) model_builder = functools.partial( stackoverflow_models.create_recurrent_model, vocab_size=FLAGS.vocab_size, embedding_size=FLAGS.embedding_size, latent_size=FLAGS.latent_size, num_layers=FLAGS.num_layers, shared_embedding=FLAGS.shared_embedding) loss_builder = functools.partial( tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True) special_tokens = stackoverflow_word_prediction.get_special_tokens( FLAGS.vocab_size) pad_token = special_tokens.pad oov_tokens = special_tokens.oov eos_token = special_tokens.eos def metrics_builder(): return [ keras_metrics.MaskedCategoricalAccuracy(name='accuracy_with_oov', masked_tokens=[pad_token]), keras_metrics.MaskedCategoricalAccuracy(name='accuracy_no_oov', masked_tokens=[pad_token] + oov_tokens), # Notice BOS never appears in ground truth. keras_metrics.MaskedCategoricalAccuracy( name='accuracy_no_oov_or_eos', masked_tokens=[pad_token, eos_token] + oov_tokens), keras_metrics.NumBatchesCounter(), keras_metrics.NumTokensCounter(masked_tokens=[pad_token]), ] train_dataset, _ = stackoverflow_word_prediction.get_federated_datasets( vocab_size=FLAGS.vocab_size, train_client_batch_size=FLAGS.client_batch_size, train_client_epochs_per_round=FLAGS.client_epochs_per_round, max_sequence_length=FLAGS.sequence_length, max_elements_per_train_client=FLAGS.max_elements_per_user) _, validation_dataset, test_dataset = stackoverflow_word_prediction.get_centralized_datasets( vocab_size=FLAGS.vocab_size, max_sequence_length=FLAGS.sequence_length, num_validation_examples=FLAGS.num_validation_examples) if FLAGS.uniform_weighting: client_weighting = tff.learning.ClientWeighting.UNIFORM else: client_weighting = tff.learning.ClientWeighting.NUM_EXAMPLES def model_fn(): return tff.learning.from_keras_model( model_builder(), loss_builder(), input_spec=validation_dataset.element_spec, metrics=metrics_builder()) if FLAGS.noise_multiplier is not None: if not FLAGS.uniform_weighting: raise ValueError( 'Differential privacy is only implemented for uniform weighting.' 
) if FLAGS.noise_multiplier <= 0: raise ValueError( 'noise_multiplier must be positive if DP is enabled.') if FLAGS.clip is None or FLAGS.clip <= 0: raise ValueError('clip must be positive if DP is enabled.') if not FLAGS.adaptive_clip_learning_rate: aggregation_factory = tff.aggregators.DifferentiallyPrivateFactory.gaussian_fixed( noise_multiplier=FLAGS.noise_multiplier, clients_per_round=FLAGS.clients_per_round, clip=FLAGS.clip) else: if FLAGS.adaptive_clip_learning_rate <= 0: raise ValueError( 'adaptive_clip_learning_rate must be positive if ' 'adaptive clipping is enabled.') aggregation_factory = tff.aggregators.DifferentiallyPrivateFactory.gaussian_adaptive( noise_multiplier=FLAGS.noise_multiplier, clients_per_round=FLAGS.clients_per_round, initial_l2_norm_clip=FLAGS.clip, target_unclipped_quantile=FLAGS.target_unclipped_quantile, learning_rate=FLAGS.adaptive_clip_learning_rate) else: if FLAGS.uniform_weighting: aggregation_factory = tff.aggregators.UnweightedMeanFactory() else: aggregation_factory = tff.aggregators.MeanFactory() server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags( 'server') client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags( 'client') iterative_process = tff.learning.build_federated_averaging_process( model_fn=model_fn, server_optimizer_fn=server_optimizer_fn, client_weighting=client_weighting, client_optimizer_fn=client_optimizer_fn, model_update_aggregation_factory=aggregation_factory) client_datasets_fn = training_utils.build_client_datasets_fn( train_dataset, FLAGS.clients_per_round) evaluate_fn = training_utils.build_centralized_evaluate_fn( model_builder=model_builder, eval_dataset=validation_dataset, loss_builder=loss_builder, metrics_builder=metrics_builder) validation_fn = lambda state, round_num: evaluate_fn(state.model) evaluate_test_fn = training_utils.build_centralized_evaluate_fn( model_builder=model_builder, # Use both val and test for symmetry with other experiments, which # evaluate on the entire test set. eval_dataset=validation_dataset.concatenate(test_dataset), loss_builder=loss_builder, metrics_builder=metrics_builder) test_fn = lambda state: evaluate_test_fn(state.model) logging.info('Training model:') logging.info(model_builder().summary()) # Log hyperparameters to CSV hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags()) results_dir = os.path.join(FLAGS.root_output_dir, 'results', FLAGS.experiment_name) utils_impl.create_directory_if_not_exists(results_dir) hparam_file = os.path.join(results_dir, 'hparams.csv') utils_impl.atomic_write_series_to_csv(hparam_dict, hparam_file) training_loop.run(iterative_process=iterative_process, client_datasets_fn=client_datasets_fn, validation_fn=validation_fn, test_fn=test_fn, total_rounds=FLAGS.total_rounds, experiment_name=FLAGS.experiment_name, root_output_dir=FLAGS.root_output_dir, rounds_per_eval=FLAGS.rounds_per_eval, rounds_per_checkpoint=FLAGS.rounds_per_checkpoint)
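# A minimal, self-contained sketch of the fixed-clip DP aggregation path used
# above, pulled out of the flag plumbing. The tiny model here is a stand-in and
# the hyperparameter values are illustrative; the `gaussian_fixed` call and the
# `model_update_aggregation_factory` argument mirror the ones in `main` above.
import collections
import tensorflow as tf
import tensorflow_federated as tff

def _toy_model_fn():
  keras_model = tf.keras.Sequential(
      [tf.keras.layers.Dense(2, input_shape=(4,))])
  input_spec = collections.OrderedDict(
      x=tf.TensorSpec(shape=[None, 4], dtype=tf.float32),
      y=tf.TensorSpec(shape=[None], dtype=tf.int32))
  return tff.learning.from_keras_model(
      keras_model,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      input_spec=input_spec,
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

dp_aggregator = tff.aggregators.DifferentiallyPrivateFactory.gaussian_fixed(
    noise_multiplier=1.0, clients_per_round=100, clip=0.5)
dp_fedavg = tff.learning.build_federated_averaging_process(
    model_fn=_toy_model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1),
    # DP requires uniform client weighting, as the flag checks above enforce.
    client_weighting=tff.learning.ClientWeighting.UNIFORM,
    model_update_aggregation_factory=dp_aggregator)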
def configure_training( task_spec: training_specs.TaskSpec, crop_size: int = 24, distort_train_images: bool = True) -> training_specs.RunnerSpec: """Configures training for the CIFAR-100 classification task. This method will load and pre-process datasets and construct a model used for the task. It then uses `iterative_process_builder` to create an iterative process compatible with `federated_research.utils.training_loop`. Args: task_spec: A `TaskSpec` class for creating federated training tasks. crop_size: An optional integer representing the resulting size of input images after preprocessing. distort_train_images: A boolean indicating whether to distort training images during preprocessing via random crops, as opposed to simply resizing the image. Returns: A `RunnerSpec` containing attributes used for running the newly created federated task. """ crop_shape = (crop_size, crop_size, 3) cifar_train, _ = tff.simulation.datasets.cifar100.load_data() _, cifar_test = cifar100_dataset.get_centralized_datasets( train_batch_size=task_spec.client_batch_size, crop_shape=crop_shape) train_preprocess_fn = cifar100_dataset.create_preprocess_fn( num_epochs=task_spec.client_epochs_per_round, batch_size=task_spec.client_batch_size, crop_shape=crop_shape, distort_image=distort_train_images) input_spec = train_preprocess_fn.type_signature.result.element model_builder = functools.partial(resnet_models.create_resnet18, input_shape=crop_shape, num_classes=NUM_CLASSES) loss_builder = tf.keras.losses.SparseCategoricalCrossentropy metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()] def tff_model_fn() -> tff.learning.Model: return tff.learning.from_keras_model(keras_model=model_builder(), input_spec=input_spec, loss=loss_builder(), metrics=metrics_builder()) iterative_process = task_spec.iterative_process_builder(tff_model_fn) if hasattr(cifar_train, 'dataset_computation'): @tff.tf_computation(tf.string) def build_train_dataset_from_client_id(client_id): client_dataset = cifar_train.dataset_computation(client_id) return train_preprocess_fn(client_dataset) training_process = tff.simulation.compose_dataset_computation_with_iterative_process( build_train_dataset_from_client_id, iterative_process) client_ids_fn = training_utils.build_sample_fn( cifar_train.client_ids, size=task_spec.clients_per_round, replace=False, random_seed=task_spec.client_datasets_random_seed) # We convert the output to a list (instead of an np.ndarray) so that it can # be used as input to the iterative process. client_sampling_fn = lambda x: list(client_ids_fn(x)) else: training_process = tff.simulation.compose_dataset_computation_with_iterative_process( train_preprocess_fn, iterative_process) client_sampling_fn = training_utils.build_client_datasets_fn( dataset=cifar_train, clients_per_round=task_spec.clients_per_round, random_seed=task_spec.client_datasets_random_seed) training_process.get_model_weights = iterative_process.get_model_weights test_fn = training_utils.build_centralized_evaluate_fn( eval_dataset=cifar_test, model_builder=model_builder, loss_builder=loss_builder, metrics_builder=metrics_builder) validation_fn = lambda model_weights, round_num: test_fn(model_weights) return training_specs.RunnerSpec(iterative_process=training_process, client_datasets_fn=client_sampling_fn, validation_fn=validation_fn, test_fn=test_fn)
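# `task_spec.iterative_process_builder` is supplied by the caller of
# `configure_training`; this is a minimal sketch (not from this repo) of a
# compatible builder. It produces FedAvg with an adaptive server optimizer
# ("FedAdam"-style); the learning rates and epsilon are illustrative only. In
# recent TFF releases the returned process already exposes `get_model_weights`,
# which `configure_training` re-attaches after dataset composition.
import tensorflow as tf
import tensorflow_federated as tff

def example_iterative_process_builder(model_fn):
  return tff.learning.build_federated_averaging_process(
      model_fn=model_fn,
      client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1),
      server_optimizer_fn=lambda: tf.keras.optimizers.Adam(
          learning_rate=0.01, epsilon=1e-3))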
def run_federated( iterative_process_builder: Callable[..., tff.templates.IterativeProcess], client_epochs_per_round: int, client_batch_size: int, clients_per_round: int, client_datasets_random_seed: Optional[int] = None, crop_size: Optional[int] = 24, total_rounds: Optional[int] = 1500, experiment_name: Optional[str] = 'federated_cifar10', root_output_dir: Optional[str] = '/tmp/fed_opt', **kwargs): """Runs an iterative process on the CIFAR-10 classification task. This method will load and pre-process dataset and construct a model used for the task. It then uses `iterative_process_builder` to create an iterative process that it applies to the task, using `federated_research.utils.training_loop`. We assume that the iterative process has the following functional type signatures: * `initialize`: `( -> S@SERVER)` where `S` represents the server state. * `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S` represents the server state, `{B*}` represents the client datasets, and `T` represents a python `Mapping` object. The iterative process must also have a callable attribute `get_model_weights` that takes as input the state of the iterative process, and returns a `tff.learning.ModelWeights` object. Args: iterative_process_builder: A function that accepts a no-arg `model_fn`, and returns a `tff.templates.IterativeProcess`. The `model_fn` must return a `tff.learning.Model`. client_epochs_per_round: An integer representing the number of epochs of training performed per client in each training round. client_batch_size: An integer representing the batch size used on clients. clients_per_round: An integer representing the number of clients participating in each round. client_datasets_random_seed: An optional int used to seed which clients are sampled at each round. If `None`, no seed is used. crop_size: An optional integer representing the resulting size of input images after preprocessing. total_rounds: The number of federated training rounds. experiment_name: The name of the experiment being run. This will be appended to the `root_output_dir` for purposes of writing outputs. root_output_dir: The name of the root output directory for writing experiment outputs. **kwargs: Additional arguments configuring the training loop. For details on supported arguments, see `federated_research/utils/training_utils.py`. 
""" crop_shape = (crop_size, crop_size, 3) cifar_train, _ = cifar10_dataset.get_federated_datasets( train_client_epochs_per_round=client_epochs_per_round, train_client_batch_size=client_batch_size, crop_shape=crop_shape) _, cifar_test = cifar10_dataset.get_centralized_datasets( crop_shape=crop_shape) input_spec = cifar_train.create_tf_dataset_for_client( cifar_train.client_ids[0]).element_spec model_builder = functools.partial(resnet_models.create_resnet18, input_shape=crop_shape, num_classes=NUM_CLASSES) loss_builder = tf.keras.losses.SparseCategoricalCrossentropy metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()] def tff_model_fn() -> tff.learning.Model: return tff.learning.from_keras_model(keras_model=model_builder(), input_spec=input_spec, loss=loss_builder(), metrics=metrics_builder()) training_process = iterative_process_builder(tff_model_fn) client_datasets_fn = training_utils.build_client_datasets_fn( dataset=cifar_train, clients_per_round=clients_per_round, random_seed=client_datasets_random_seed) test_fn = training_utils.build_centralized_evaluate_fn( eval_dataset=cifar_test, model_builder=model_builder, loss_builder=loss_builder, metrics_builder=metrics_builder) validation_fn = lambda model_weights, round_num: test_fn(model_weights) logging.info('Training model:') logging.info(model_builder().summary()) training_loop.run(iterative_process=training_process, train_client_datasets_fn=client_datasets_fn, evaluation_fn=validation_fn, test_fn=test_fn, total_rounds=total_rounds, experiment_name=experiment_name, root_output_dir=root_output_dir, **kwargs)
def run_federated( iterative_process_builder: Callable[..., tff.templates.IterativeProcess], client_epochs_per_round: int, client_batch_size: int, clients_per_round: int, max_batches_per_client: Optional[int] = -1, client_datasets_random_seed: Optional[int] = None, sequence_length: Optional[int] = 80, total_rounds: Optional[int] = 1500, experiment_name: Optional[str] = 'federated_shakespeare', root_output_dir: Optional[str] = '/tmp/fed_opt', max_eval_batches: Optional[int] = None, **kwargs): """Runs an iterative process on a Shakespeare next character prediction task. This method will load and pre-process dataset and construct a model used for the task. It then uses `iterative_process_builder` to create an iterative process that it applies to the task, using `federated_research.utils.training_loop`. We assume that the iterative process has the following functional type signatures: * `initialize`: `( -> S@SERVER)` where `S` represents the server state. * `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S` represents the server state, `{B*}` represents the client datasets, and `T` represents a python `Mapping` object. Moreover, the server state must have an attribute `model` of type `tff.learning.ModelWeights`. Args: iterative_process_builder: A function that accepts a no-arg `model_fn`, and a `client_weight_fn`, and returns a `tff.templates.IterativeProcess`. The `model_fn` must return a `tff.learning.Model`. client_epochs_per_round: An integer representing the number of epochs of training performed per client in each training round. client_batch_size: An integer representing the batch size used on clients. clients_per_round: An integer representing the number of clients participating in each round. max_batches_per_client: An optional int specifying the number of batches taken by each client at each round. If `-1`, the entire client dataset is used. client_datasets_random_seed: An optional int used to seed which clients are sampled at each round. If `None`, no seed is used. sequence_length: An int specifying the length of the character sequences used for prediction. total_rounds: The number of federated training rounds. experiment_name: The name of the experiment being run. This will be appended to the `root_output_dir` for purposes of writing outputs. root_output_dir: The name of the root output directory for writing experiment outputs. max_eval_batches: If set to a positive integer, evaluation datasets are capped to at most that many batches. If set to None or a nonpositive integer, the full evaluation datasets are used. **kwargs: Additional arguments configuring the training loop. For details on supported arguments, see `federated_research/utils/training_utils.py`. 
""" train_clientdata = shakespeare_dataset.construct_character_level_datasets( client_batch_size=client_batch_size, client_epochs_per_round=client_epochs_per_round, sequence_length=sequence_length, max_batches_per_client=max_batches_per_client) _, test_dataset = shakespeare_dataset.get_centralized_datasets( train_batch_size=client_batch_size, max_test_batches=max_eval_batches, sequence_length=sequence_length) model_builder = functools.partial(create_shakespeare_model, sequence_length=sequence_length) loss_builder = functools.partial( tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True) input_spec = train_clientdata.create_tf_dataset_for_client( train_clientdata.client_ids[0]).element_spec def client_weight_fn(local_outputs): # Num_tokens is a tensor with type int64[1], to use as a weight need # a float32 scalar. return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32) def tff_model_fn() -> tff.learning.Model: return tff.learning.from_keras_model(keras_model=model_builder(), input_spec=input_spec, loss=loss_builder(), metrics=metrics_builder()) training_process = iterative_process_builder( tff_model_fn, client_weight_fn=client_weight_fn) client_datasets_fn = training_utils.build_client_datasets_fn( train_dataset=train_clientdata, train_clients_per_round=clients_per_round, random_seed=client_datasets_random_seed) evaluate_fn = training_utils.build_evaluate_fn( eval_dataset=test_dataset, model_builder=model_builder, loss_builder=loss_builder, metrics_builder=metrics_builder) logging.info('Training model:') logging.info(model_builder().summary()) training_loop.run(iterative_process=training_process, client_datasets_fn=client_datasets_fn, validation_fn=evaluate_fn, test_fn=evaluate_fn, total_rounds=total_rounds, total_clients=clients_per_round, experiment_name=experiment_name, root_output_dir=root_output_dir, **kwargs)
def run_federated( iterative_process_builder: Callable[..., tff.templates.IterativeProcess], client_epochs_per_round: int, client_batch_size: int, clients_per_round: int, client_datasets_random_seed: Optional[int] = None, vocab_tokens_size: Optional[int] = 10000, vocab_tags_size: Optional[int] = 500, max_elements_per_user: Optional[int] = 1000, num_validation_examples: Optional[int] = 10000, total_rounds: Optional[int] = 1500, experiment_name: Optional[str] = 'federated_so_lr', root_output_dir: Optional[str] = '/tmp/fed_opt', max_eval_batches: Optional[int] = None, **kwargs): """Runs an iterative process on the Stack Overflow logistic regression task. This method will load and pre-process dataset and construct a model used for the task. It then uses `iterative_process_builder` to create an iterative process that it applies to the task, using `federated_research.utils.training_loop`. We assume that the iterative process has the following functional type signatures: * `initialize`: `( -> S@SERVER)` where `S` represents the server state. * `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S` represents the server state, `{B*}` represents the client datasets, and `T` represents a python `Mapping` object. The iterative process must also have a callable attribute `get_model_weights` that takes as input the state of the iterative process, and returns a `tff.learning.ModelWeights` object. Args: iterative_process_builder: A function that accepts a no-arg `model_fn`, and returns a `tff.templates.IterativeProcess`. The `model_fn` must return a `tff.learning.Model`. client_epochs_per_round: An integer representing the number of epochs of training performed per client in each training round. client_batch_size: An integer representing the batch size used on clients. clients_per_round: An integer representing the number of clients participating in each round. client_datasets_random_seed: An optional int used to seed which clients are sampled at each round. If `None`, no seed is used. vocab_tokens_size: Integer dictating the number of most frequent words to use in the vocabulary. vocab_tags_size: Integer dictating the number of most frequent tags to use in the label creation. max_elements_per_user: The maximum number of elements processed for each client's dataset. num_validation_examples: The number of test examples to use for validation. total_rounds: The number of federated training rounds. experiment_name: The name of the experiment being run. This will be appended to the `root_output_dir` for purposes of writing outputs. root_output_dir: The name of the root output directory for writing experiment outputs. max_eval_batches: If set to a positive integer, evaluation datasets are capped to at most that many batches. If set to None or a nonpositive integer, the full evaluation datasets are used. **kwargs: Additional arguments configuring the training loop. For details on supported arguments, see `federated_research/utils/training_utils.py`. 
""" stackoverflow_train, _ = stackoverflow_tag_prediction.get_federated_datasets( word_vocab_size=vocab_tokens_size, tag_vocab_size=vocab_tags_size, train_client_batch_size=client_batch_size, train_client_epochs_per_round=client_epochs_per_round, max_elements_per_train_client=max_elements_per_user) _, stackoverflow_validation, stackoverflow_test = stackoverflow_tag_prediction.get_centralized_datasets( train_batch_size=client_batch_size, word_vocab_size=vocab_tokens_size, tag_vocab_size=vocab_tags_size, num_validation_examples=num_validation_examples) if max_eval_batches and max_eval_batches >= 1: stackoverflow_validation = stackoverflow_validation.take(max_eval_batches) stackoverflow_test = stackoverflow_test.take(max_eval_batches) input_spec = stackoverflow_train.create_tf_dataset_for_client( stackoverflow_train.client_ids[0]).element_spec model_builder = functools.partial( stackoverflow_lr_models.create_logistic_model, vocab_tokens_size=vocab_tokens_size, vocab_tags_size=vocab_tags_size) loss_builder = functools.partial( tf.keras.losses.BinaryCrossentropy, from_logits=False, reduction=tf.keras.losses.Reduction.SUM) def tff_model_fn() -> tff.learning.Model: return tff.learning.from_keras_model( keras_model=model_builder(), input_spec=input_spec, loss=loss_builder(), metrics=metrics_builder()) training_process = iterative_process_builder(tff_model_fn) client_datasets_fn = training_utils.build_client_datasets_fn( dataset=stackoverflow_train, clients_per_round=clients_per_round, random_seed=client_datasets_random_seed) evaluate_fn = training_utils.build_centralized_evaluate_fn( model_builder=model_builder, eval_dataset=stackoverflow_validation, loss_builder=loss_builder, metrics_builder=metrics_builder) test_fn = training_utils.build_centralized_evaluate_fn( model_builder=model_builder, # Use both val and test for symmetry with other experiments, which # evaluate on the entire test set. eval_dataset=stackoverflow_validation.concatenate(stackoverflow_test), loss_builder=loss_builder, metrics_builder=metrics_builder) logging.info('Training model:') logging.info(model_builder().summary()) training_loop.run( iterative_process=training_process, client_datasets_fn=client_datasets_fn, validation_fn=evaluate_fn, test_fn=test_fn, total_rounds=total_rounds, experiment_name=experiment_name, root_output_dir=root_output_dir, **kwargs)
def configure_training( task_spec: training_specs.TaskSpec) -> training_specs.RunnerSpec: """Configures training for the EMNIST autoencoder task. This method will load and pre-process datasets and construct a model used for the task. It then uses `iterative_process_builder` to create an iterative process compatible with `federated_research.utils.training_loop`. Args: task_spec: A `TaskSpec` class for creating federated training tasks. Returns: A `RunnerSpec` containing attributes used for running the newly created federated task. """ emnist_task = 'autoencoder' emnist_train, _ = tff.simulation.datasets.emnist.load_data(only_digits=False) _, emnist_test = emnist_dataset.get_centralized_datasets( only_digits=False, emnist_task=emnist_task) train_preprocess_fn = emnist_dataset.create_preprocess_fn( num_epochs=task_spec.client_epochs_per_round, batch_size=task_spec.client_batch_size, emnist_task=emnist_task) input_spec = train_preprocess_fn.type_signature.result.element model_builder = emnist_ae_models.create_autoencoder_model loss_builder = functools.partial( tf.keras.losses.MeanSquaredError, reduction=tf.keras.losses.Reduction.SUM) metrics_builder = lambda: [tf.keras.metrics.MeanSquaredError()] def tff_model_fn() -> tff.learning.Model: return tff.learning.from_keras_model( keras_model=model_builder(), input_spec=input_spec, loss=loss_builder(), metrics=metrics_builder()) iterative_process = task_spec.iterative_process_builder(tff_model_fn) if hasattr(emnist_train, 'dataset_computation'): @tff.tf_computation(tf.string) def build_train_dataset_from_client_id(client_id): client_dataset = emnist_train.dataset_computation(client_id) return train_preprocess_fn(client_dataset) training_process = tff.simulation.compose_dataset_computation_with_iterative_process( build_train_dataset_from_client_id, iterative_process) client_ids_fn = training_utils.build_sample_fn( emnist_train.client_ids, size=task_spec.clients_per_round, replace=False, random_seed=task_spec.sampling_random_seed) # We convert the output to a list (instead of an np.ndarray) so that it can # be used as input to the iterative process. client_sampling_fn = lambda x: list(client_ids_fn(x)) else: training_process = tff.simulation.compose_dataset_computation_with_iterative_process( train_preprocess_fn, iterative_process) client_sampling_fn = training_utils.build_client_datasets_fn( dataset=emnist_train, clients_per_round=task_spec.clients_per_round, random_seed=task_spec.sampling_random_seed) training_process.get_model_weights = iterative_process.get_model_weights test_fn = training_utils.build_centralized_evaluate_fn( eval_dataset=emnist_test, model_builder=model_builder, loss_builder=loss_builder, metrics_builder=metrics_builder) validation_fn = lambda model_weights, round_num: test_fn(model_weights) return training_specs.RunnerSpec( iterative_process=training_process, client_datasets_fn=client_sampling_fn, validation_fn=validation_fn, test_fn=test_fn)
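# Sketch of how the `RunnerSpec` produced above is expected to be driven by a
# generic training loop (assumed shape; the real loop lives in
# `federated_research.utils.training_loop`, and `task_spec` is supplied by the
# caller). When `dataset_computation` was composed into the process,
# `client_datasets_fn` yields client id strings and the process materializes
# the datasets itself; otherwise it yields preprocessed client datasets.
spec = configure_training(task_spec)
state = spec.iterative_process.initialize()
for round_num in range(1, 11):
  round_inputs = spec.client_datasets_fn(round_num)
  state, round_metrics = spec.iterative_process.next(state, round_inputs)
  if round_num % 5 == 0:
    model_weights = spec.iterative_process.get_model_weights(state)
    validation_metrics = spec.validation_fn(model_weights, round_num)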
def main(argv): if len(argv) > 1: raise app.UsageError('Expected no command-line arguments, ' 'got: {}'.format(argv)) emnist_train, _ = emnist_dataset.get_federated_datasets( train_client_batch_size=FLAGS.client_batch_size, train_client_epochs_per_round=FLAGS.client_epochs_per_round, only_digits=False) _, emnist_test = emnist_dataset.get_centralized_datasets() if FLAGS.model == 'cnn': model_builder = functools.partial( emnist_models.create_conv_dropout_model, only_digits=False) elif FLAGS.model == '2nn': model_builder = functools.partial( emnist_models.create_two_hidden_layer_model, only_digits=False) else: raise ValueError('Cannot handle model flag [{!s}].'.format(FLAGS.model)) loss_builder = tf.keras.losses.SparseCategoricalCrossentropy metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()] if FLAGS.uniform_weighting: client_weighting = tff.learning.ClientWeighting.UNIFORM else: client_weighting = tff.learning.ClientWeighting.NUM_EXAMPLES def model_fn(): return tff.learning.from_keras_model( model_builder(), loss_builder(), input_spec=emnist_test.element_spec, metrics=metrics_builder()) if FLAGS.noise_multiplier is not None: if not FLAGS.uniform_weighting: raise ValueError( 'Differential privacy is only implemented for uniform weighting.') if FLAGS.noise_multiplier <= 0: raise ValueError('noise_multiplier must be positive if DP is enabled.') if FLAGS.clip is None or FLAGS.clip <= 0: raise ValueError('clip must be positive if DP is enabled.') if not FLAGS.adaptive_clip_learning_rate: aggregation_factory = tff.aggregators.DifferentiallyPrivateFactory.gaussian_fixed( noise_multiplier=FLAGS.noise_multiplier, clients_per_round=FLAGS.clients_per_round, clip=FLAGS.clip) else: if FLAGS.adaptive_clip_learning_rate <= 0: raise ValueError('adaptive_clip_learning_rate must be positive if ' 'adaptive clipping is enabled.') aggregation_factory = tff.aggregators.DifferentiallyPrivateFactory.gaussian_adaptive( noise_multiplier=FLAGS.noise_multiplier, clients_per_round=FLAGS.clients_per_round, initial_l2_norm_clip=FLAGS.clip, target_unclipped_quantile=FLAGS.target_unclipped_quantile, learning_rate=FLAGS.adaptive_clip_learning_rate) else: if FLAGS.uniform_weighting: aggregation_factory = tff.aggregators.UnweightedMeanFactory() else: aggregation_factory = tff.aggregators.MeanFactory() server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('server') client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('client') iterative_process = tff.learning.build_federated_averaging_process( model_fn=model_fn, server_optimizer_fn=server_optimizer_fn, client_weighting=client_weighting, client_optimizer_fn=client_optimizer_fn, model_update_aggregation_factory=aggregation_factory) client_datasets_fn = training_utils.build_client_datasets_fn( emnist_train, FLAGS.clients_per_round) evaluate_fn = training_utils.build_centralized_evaluate_fn( eval_dataset=emnist_test, model_builder=model_builder, loss_builder=loss_builder, metrics_builder=metrics_builder) validation_fn = lambda model_weights, round_num: evaluate_fn(model_weights) logging.info('Training model:') logging.info(model_builder().summary()) # Log hyperparameters to CSV hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags()) results_dir = os.path.join(FLAGS.root_output_dir, 'results', FLAGS.experiment_name) utils_impl.create_directory_if_not_exists(results_dir) hparam_file = os.path.join(results_dir, 'hparams.csv') utils_impl.atomic_write_series_to_csv(hparam_dict, hparam_file) training_loop.run( 
iterative_process=iterative_process, client_datasets_fn=client_datasets_fn, validation_fn=validation_fn, total_rounds=FLAGS.total_rounds, experiment_name=FLAGS.experiment_name, root_output_dir=FLAGS.root_output_dir, rounds_per_eval=FLAGS.rounds_per_eval, rounds_per_checkpoint=FLAGS.rounds_per_checkpoint, rounds_per_profile=FLAGS.rounds_per_profile)
def run_federated( iterative_process_builder: Callable[..., tff.templates.IterativeProcess], client_epochs_per_round: int, client_batch_size: int, clients_per_round: int, max_batches_per_client: Optional[int] = -1, client_datasets_random_seed: Optional[int] = None, total_rounds: Optional[int] = 1500, experiment_name: Optional[str] = 'federated_emnist_ae', root_output_dir: Optional[str] = '/tmp/fed_opt', max_eval_batches: Optional[int] = None, **kwargs): """Runs an iterative process on the EMNIST autoencoder task. This method will load and pre-process dataset and construct a model used for the task. It then uses `iterative_process_builder` to create an iterative process that it applies to the task, using `federated_research.utils.training_loop`. We assume that the iterative process has the following functional type signatures: * `initialize`: `( -> S@SERVER)` where `S` represents the server state. * `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S` represents the server state, `{B*}` represents the client datasets, and `T` represents a python `Mapping` object. Moreover, the server state must have an attribute `model` of type `tff.learning.ModelWeights`. Args: iterative_process_builder: A function that accepts a no-arg `model_fn`, and returns a `tff.templates.IterativeProcess`. The `model_fn` must return a `tff.learning.Model`. client_epochs_per_round: An integer representing the number of epochs of training performed per client in each training round. client_batch_size: An integer representing the batch size used on clients. clients_per_round: An integer representing the number of clients participating in each round. max_batches_per_client: An optional int specifying the number of batches taken by each client at each round. If `-1`, the entire client dataset is used. client_datasets_random_seed: An optional int used to seed which clients are sampled at each round. If `None`, no seed is used. total_rounds: The number of federated training rounds. experiment_name: The name of the experiment being run. This will be appended to the `root_output_dir` for purposes of writing outputs. root_output_dir: The name of the root output directory for writing experiment outputs. max_eval_batches: If set to a positive integer, evaluation datasets are capped to at most that many batches. If set to None or a nonpositive integer, the full evaluation datasets are used. **kwargs: Additional arguments configuring the training loop. For details on supported arguments, see `federated_research/utils/training_utils.py`. 
""" emnist_train, _ = emnist_ae_dataset.get_emnist_datasets( client_batch_size=client_batch_size, client_epochs_per_round=client_epochs_per_round, max_batches_per_client=max_batches_per_client, only_digits=False) _, emnist_test = emnist_ae_dataset.get_centralized_datasets( train_batch_size=client_batch_size, max_test_batches=max_eval_batches, only_digits=False) input_spec = emnist_train.create_tf_dataset_for_client( emnist_train.client_ids[0]).element_spec model_builder = emnist_ae_models.create_autoencoder_model loss_builder = functools.partial(tf.keras.losses.MeanSquaredError, reduction=tf.keras.losses.Reduction.SUM) metrics_builder = lambda: [tf.keras.metrics.MeanSquaredError()] def tff_model_fn() -> tff.learning.Model: return tff.learning.from_keras_model(keras_model=model_builder(), input_spec=input_spec, loss=loss_builder(), metrics=metrics_builder()) training_process = iterative_process_builder(tff_model_fn) client_datasets_fn = training_utils.build_client_datasets_fn( train_dataset=emnist_train, train_clients_per_round=clients_per_round, random_seed=client_datasets_random_seed) evaluate_fn = training_utils.build_evaluate_fn( eval_dataset=emnist_test, model_builder=model_builder, loss_builder=loss_builder, metrics_builder=metrics_builder) logging.info('Training model:') logging.info(model_builder().summary()) training_loop.run(iterative_process=training_process, client_datasets_fn=client_datasets_fn, validation_fn=evaluate_fn, test_fn=evaluate_fn, total_rounds=total_rounds, experiment_name=experiment_name, root_output_dir=root_output_dir, **kwargs)
def configure_training( task_spec: training_specs.TaskSpec, vocab_size: int = 10000, num_oov_buckets: int = 1, sequence_length: int = 20, max_elements_per_user: int = 1000, num_validation_examples: int = 10000, embedding_size: int = 96, latent_size: int = 670, num_layers: int = 1, shared_embedding: bool = False) -> training_specs.RunnerSpec: """Configures training for Stack Overflow next-word prediction. This method will load and pre-process datasets and construct a model used for the task. It then uses `iterative_process_builder` to create an iterative process compatible with `federated_research.utils.training_loop`. Args: task_spec: A `TaskSpec` class for creating federated training tasks. vocab_size: Integer dictating the number of most frequent words to use in the vocabulary. num_oov_buckets: The number of out-of-vocabulary buckets to use. sequence_length: The maximum number of words to take for each sequence. max_elements_per_user: The maximum number of elements processed for each client's dataset. num_validation_examples: The number of test examples to use for validation. embedding_size: The dimension of the word embedding layer. latent_size: The dimension of the latent units in the recurrent layers. num_layers: The number of stacked recurrent layers to use. shared_embedding: Boolean indicating whether to tie input and output embeddings. Returns: A `RunnerSpec` containing attributes used for running the newly created federated task. """ model_builder = functools.partial( stackoverflow_models.create_recurrent_model, vocab_size=vocab_size, num_oov_buckets=num_oov_buckets, embedding_size=embedding_size, latent_size=latent_size, num_layers=num_layers, shared_embedding=shared_embedding) loss_builder = functools.partial( tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True) special_tokens = stackoverflow_word_prediction.get_special_tokens( vocab_size, num_oov_buckets) pad_token = special_tokens.pad oov_tokens = special_tokens.oov eos_token = special_tokens.eos def metrics_builder(): return [ keras_metrics.MaskedCategoricalAccuracy(name='accuracy_with_oov', masked_tokens=[pad_token]), keras_metrics.MaskedCategoricalAccuracy(name='accuracy_no_oov', masked_tokens=[pad_token] + oov_tokens), # Notice BOS never appears in ground truth. keras_metrics.MaskedCategoricalAccuracy( name='accuracy_no_oov_or_eos', masked_tokens=[pad_token, eos_token] + oov_tokens), keras_metrics.NumBatchesCounter(), keras_metrics.NumTokensCounter(masked_tokens=[pad_token]) ] train_clientdata, _, _ = tff.simulation.datasets.stackoverflow.load_data() # TODO(b/161914546): consider moving evaluation to use # `tff.learning.build_federated_evaluation` to get metrics over client # distributions, as well as the example weight means from this centralized # evaluation. 
_, validation_dataset, test_dataset = stackoverflow_word_prediction.get_centralized_datasets( vocab_size=vocab_size, max_sequence_length=sequence_length, num_validation_examples=num_validation_examples, num_oov_buckets=num_oov_buckets) train_dataset_preprocess_comp = stackoverflow_word_prediction.create_preprocess_fn( vocab=stackoverflow_word_prediction.create_vocab(vocab_size), num_oov_buckets=num_oov_buckets, client_batch_size=task_spec.client_batch_size, client_epochs_per_round=task_spec.client_epochs_per_round, max_sequence_length=sequence_length, max_elements_per_client=max_elements_per_user) input_spec = train_dataset_preprocess_comp.type_signature.result.element def tff_model_fn() -> tff.learning.Model: return tff.learning.from_keras_model(keras_model=model_builder(), input_spec=input_spec, loss=loss_builder(), metrics=metrics_builder()) iterative_process = task_spec.iterative_process_builder(tff_model_fn) if hasattr(train_clientdata, 'dataset_computation'): @tff.tf_computation(tf.string) def train_dataset_computation(client_id): client_train_data = train_clientdata.dataset_computation(client_id) return train_dataset_preprocess_comp(client_train_data) training_process = tff.simulation.compose_dataset_computation_with_iterative_process( train_dataset_computation, iterative_process) client_ids_fn = training_utils.build_sample_fn( train_clientdata.client_ids, size=task_spec.clients_per_round, replace=False, random_seed=task_spec.client_datasets_random_seed) # We convert the output to a list (instead of an np.ndarray) so that it can # be used as input to the iterative process. client_sampling_fn = lambda x: list(client_ids_fn(x)) else: training_process = tff.simulation.compose_dataset_computation_with_iterative_process( train_dataset_preprocess_comp, iterative_process) client_sampling_fn = training_utils.build_client_datasets_fn( dataset=train_clientdata, clients_per_round=task_spec.clients_per_round, random_seed=task_spec.client_datasets_random_seed) training_process.get_model_weights = iterative_process.get_model_weights evaluate_fn = training_utils.build_centralized_evaluate_fn( model_builder=model_builder, eval_dataset=validation_dataset, loss_builder=loss_builder, metrics_builder=metrics_builder) validation_fn = lambda model_weights, round_num: evaluate_fn(model_weights) test_fn = training_utils.build_centralized_evaluate_fn( model_builder=model_builder, # Use both val and test for symmetry with other experiments, which # evaluate on the entire test set. eval_dataset=validation_dataset.concatenate(test_dataset), loss_builder=loss_builder, metrics_builder=metrics_builder) return training_specs.RunnerSpec(iterative_process=training_process, client_datasets_fn=client_sampling_fn, validation_fn=validation_fn, test_fn=test_fn)
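# The `keras_metrics.MaskedCategoricalAccuracy` metrics used above are
# repo-local. The snippet below is a minimal, standalone illustration of the
# underlying idea: positions whose ground-truth token is masked (e.g. padding)
# are excluded from both the numerator and the denominator of the accuracy.
import tensorflow as tf

pad_token = 0
y_true = tf.constant([3, 5, 0, 0])       # two real tokens followed by padding
y_pred_ids = tf.constant([3, 7, 0, 0])   # argmax of the model's logits
mask = tf.not_equal(y_true, pad_token)
correct = tf.logical_and(tf.equal(y_true, y_pred_ids), mask)
masked_accuracy = (tf.reduce_sum(tf.cast(correct, tf.float32)) /
                   tf.reduce_sum(tf.cast(mask, tf.float32)))
# masked_accuracy == 0.5: one of the two unmasked positions is correct.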
def run_federated( iterative_process_builder: Callable[..., tff.templates.IterativeProcess], client_epochs_per_round: int, client_batch_size: int, clients_per_round: int, schedule: Optional[str] = 'none', beta: Optional[float] = 0., max_batches_per_client: Optional[int] = -1, client_datasets_random_seed: Optional[int] = None, crop_size: Optional[int] = 24, total_rounds: Optional[int] = 1500, experiment_name: Optional[str] = 'federated_cifar100', root_output_dir: Optional[str] = '/tmp/fed_opt', max_eval_batches: Optional[int] = None, **kwargs): """Runs an iterative process on the CIFAR-100 classification task. This method will load and pre-process dataset and construct a model used for the task. It then uses `iterative_process_builder` to create an iterative process that it applies to the task, using `federated_research.utils.training_loop`. We assume that the iterative process has the following functional type signatures: * `initialize`: `( -> S@SERVER)` where `S` represents the server state. * `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S` represents the server state, `{B*}` represents the client datasets, and `T` represents a python `Mapping` object. Moreover, the server state must have an attribute `model` of type `tff.learning.ModelWeights`. Args: iterative_process_builder: A function that accepts a no-arg `model_fn`, and returns a `tff.templates.IterativeProcess`. The `model_fn` must return a `tff.learning.Model`. client_epochs_per_round: An integer representing the number of epochs of training performed per client in each training round. client_batch_size: An integer representing the batch size used on clients. clients_per_round: An integer representing the number of clients participating in each round. schedule: An optional string specifying how clients are selected each round: 'none' samples clients via `training_utils.build_client_datasets_fn`, 'loss' selects a pool of `loss_pool_size` clients and runs loss-based selection via `training_loop_loss`, and any other value uses the availability-based selection in `training_loop_importance`. beta: An optional float forwarded to the availability-based client selection; unused when `schedule` is 'none' or 'loss'. max_batches_per_client: An optional int specifying the number of batches taken by each client at each round. If `-1`, the entire client dataset is used. client_datasets_random_seed: An optional int used to seed which clients are sampled at each round. If `None`, no seed is used. crop_size: An optional integer representing the resulting size of input images after preprocessing. total_rounds: The number of federated training rounds. experiment_name: The name of the experiment being run. This will be appended to the `root_output_dir` for purposes of writing outputs. root_output_dir: The name of the root output directory for writing experiment outputs. max_eval_batches: If set to a positive integer, evaluation datasets are capped to at most that many batches. If set to None or a nonpositive integer, the full evaluation datasets are used. **kwargs: Additional arguments configuring the training loop. For details on supported arguments, see `federated_research/utils/training_utils.py`. 
""" crop_shape = (crop_size, crop_size, 3) cifar_train, _, fed_test_data = cifar100_dataset.get_federated_cifar100( client_epochs_per_round=client_epochs_per_round, train_batch_size=client_batch_size, crop_shape=crop_shape, max_batches_per_client=max_batches_per_client) _, cifar_test = cifar100_dataset.get_centralized_datasets( train_batch_size=client_batch_size, max_test_batches=max_eval_batches, crop_shape=crop_shape) input_spec = cifar_train.create_tf_dataset_for_client( cifar_train.client_ids[0]).element_spec model_builder = functools.partial(resnet_models.create_resnet18, input_shape=crop_shape, num_classes=NUM_CLASSES) loss_builder = tf.keras.losses.SparseCategoricalCrossentropy metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()] def tff_model_fn() -> tff.learning.Model: return tff.learning.from_keras_model(keras_model=model_builder(), input_spec=input_spec, loss=loss_builder(), metrics=metrics_builder()) training_process = iterative_process_builder(tff_model_fn) evaluate_fn = training_utils.build_evaluate_fn( eval_dataset=cifar_test, model_builder=model_builder, loss_builder=loss_builder, metrics_builder=metrics_builder) test_fn = training_utils.build_unweighted_test_fn( federated_eval_dataset=fed_test_data, model_builder=model_builder, loss_builder=loss_builder, metrics_builder=metrics_builder) logging.info('Training model:') logging.info(model_builder().summary()) try: var = kwargs['hparam_dict']['var_q_clients'] q_client = np.load( f'/home/monica/AVAIL_VECTORS/q_client_{var}_cifar.npy') except: logging.info( 'Could not load q_client - initializing random availabilities') q_client = None if schedule == 'none': client_datasets_fn = training_utils.build_client_datasets_fn( train_dataset=cifar_train, train_clients_per_round=clients_per_round, random_seed=client_datasets_random_seed, var_q_clients=kwargs['hparam_dict']['var_q_clients'], f_mult=kwargs['hparam_dict']['f_mult'], f_intercept=kwargs['hparam_dict']['f_intercept'], sine_wave=kwargs['hparam_dict']['sine_wave'], use_p=True, q_client=q_client, ) training_loop.run(iterative_process=training_process, client_datasets_fn=client_datasets_fn, validation_fn=evaluate_fn, test_fn=test_fn, total_rounds=total_rounds, experiment_name=experiment_name, root_output_dir=root_output_dir, **kwargs) elif schedule == 'loss': if 'loss_pool_size' in kwargs['hparam_dict'] and kwargs['hparam_dict'][ 'loss_pool_size'] is not None: loss_pool_size = kwargs['hparam_dict']['loss_pool_size'] logging.info(f'Loss pool size: {loss_pool_size}') client_datasets_fn = training_utils.build_client_datasets_fn( train_dataset=cifar_train, train_clients_per_round=loss_pool_size, random_seed=client_datasets_random_seed, var_q_clients=kwargs['hparam_dict']['var_q_clients'], f_mult=kwargs['hparam_dict']['f_mult'], f_intercept=kwargs['hparam_dict']['f_intercept'], sine_wave=kwargs['hparam_dict']['sine_wave'], use_p=True, q_client=q_client, ) training_loop_loss.run(iterative_process=training_process, client_datasets_fn=client_datasets_fn, validation_fn=evaluate_fn, test_fn=test_fn, total_rounds=total_rounds, total_clients=loss_pool_size, experiment_name=experiment_name, root_output_dir=root_output_dir, **kwargs) else: raise ValueError('Loss pool size not specified') else: client_datasets_fn = training_utils.build_availability_client_datasets_fn( train_dataset=cifar_train, train_clients_per_round=clients_per_round, random_seed=client_datasets_random_seed, beta=beta, var_q_clients=kwargs['hparam_dict']['var_q_clients'], 
f_mult=kwargs['hparam_dict']['f_mult'], f_intercept=kwargs['hparam_dict']['f_intercept'], sine_wave=kwargs['hparam_dict']['sine_wave'], q_client=q_client, ) training_loop_importance.run(iterative_process=training_process, client_datasets_fn=client_datasets_fn, validation_fn=evaluate_fn, test_fn=test_fn, total_rounds=total_rounds, experiment_name=experiment_name, root_output_dir=root_output_dir, **kwargs)
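# The availability vectors loaded from `q_client_{var}_cifar.npy` above are
# produced outside this file; how they are generated is not shown. The snippet
# below is a purely hypothetical way to create a file of the expected shape
# (one availability probability per client, and the federated CIFAR-100
# training split has 500 clients) so the `np.load` call succeeds; the actual
# distribution used by the experiments may differ.
import numpy as np

num_clients = 500
var = 0.25  # should match the `var_q_clients` hyperparameter in the filename
rng = np.random.default_rng(seed=0)
q_client = np.clip(rng.normal(loc=0.5, scale=np.sqrt(var), size=num_clients),
                   0.05, 0.95)
np.save(f'q_client_{var}_cifar.npy', q_client)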
def run_federated( iterative_process_builder: Callable[..., tff.templates.IterativeProcess], client_epochs_per_round: int, client_batch_size: int, clients_per_round: int, max_elements_per_user: int, total_rounds: int = 3000, vocab_size: int = 10000, num_oov_buckets: int = 1, sequence_length: int = 20, num_validation_examples: int = 10000, dim_embed: int = 96, dim_model: int = 512, dim_hidden: int = 2048, num_heads: int = 8, num_layers: int = 1, max_position_encoding: int = 1000, dropout: float = 0.1, client_datasets_random_seed: Optional[int] = None, experiment_name: str = 'federated_stackoverflow', root_output_dir: str = '/tmp/fedopt_guide', max_val_test_batches: Optional[int] = None, **kwargs) -> None: """Runs an iterative process on the Stack Overflow next-word prediction task. This method will load and pre-process dataset and construct a model used for the task. It then uses `iterative_process_builder` to create an iterative process that it applies to the task, using `federated_research/fedopt_guide/training_loop`. We assume that the iterative process has the following functional type signatures: * `initialize`: `( -> S@SERVER)` where `S` represents the server state. * `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S` represents the server state, `{B*}` represents the client datasets, and `T` represents a python `Mapping` object. The iterative process must also have a callable attribute `get_model_weights` that takes as input the state of the iterative process, and returns a `tff.learning.ModelWeights` object. Args: iterative_process_builder: A function that accepts a no-arg `model_fn`, a `client_weight_fn` and returns a `tff.templates.IterativeProcess`. The `model_fn` must return a `tff.learning.Model`. client_epochs_per_round: An integer representing the number of epochs of training performed per client in each training round. client_batch_size: An integer representing the batch size used on clients. clients_per_round: An integer representing the number of clients participating in each round. max_elements_per_user: The maximum number of elements processed for each client's dataset. This has to be a positive value or -1 (which means that all elements are taken for training). total_rounds: The number of federated training rounds. vocab_size: Integer dictating the number of most frequent words to use in the vocabulary. num_oov_buckets: The number of out-of-vocabulary buckets to use. sequence_length: The maximum number of words to take for each sequence. num_validation_examples: The number of test examples to use for validation. dim_embed: An integer for the dimension of the token embeddings. dim_model: An integer for the dimension of features of MultiHeadAttention layers. dim_hidden: An integer for the dimension of hidden layers of the FFN. num_heads: An integer for the number of attention heads. num_layers: An integer for the number of Transformer blocks. max_position_encoding: Maximum number of positions for position embeddings. dropout: Dropout rate. client_datasets_random_seed: An optional int used to seed which clients are sampled at each round. If `None`, no seed is used. experiment_name: The name of the experiment being run. This will be appended to the `root_output_dir` for purposes of writing outputs. root_output_dir: The name of the root output directory for writing experiment outputs. max_val_test_batches: If set to a positive integer, val and test datasets are capped to at most that many batches. If set to None or a nonpositive integer, the full datasets are used. 
**kwargs: Additional arguments configuring the training loop. For details on supported arguments, see `federated_research/fedopt_guide/training_utils.py`. """ train_clientdata, _, _ = tff.simulation.datasets.stackoverflow.load_data() _, validation_dataset, test_dataset = stackoverflow_word_prediction.get_centralized_datasets( vocab_size=vocab_size, max_sequence_length=sequence_length, num_validation_examples=num_validation_examples, num_oov_buckets=num_oov_buckets) if max_val_test_batches and max_val_test_batches >= 1: validation_dataset = validation_dataset.take(max_val_test_batches) test_dataset = test_dataset.take(max_val_test_batches) model_builder = functools.partial( transformer_models.create_transformer_lm, vocab_size=vocab_size, num_oov_buckets=num_oov_buckets, d_embed=dim_embed, d_model=dim_model, d_hidden=dim_hidden, num_heads=num_heads, num_layers=num_layers, max_position_encoding=max_position_encoding, dropout=dropout, name='stackoverflow-transformer') loss_builder = functools.partial( tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True) special_tokens = stackoverflow_word_prediction.get_special_tokens( vocab_size, num_oov_buckets) pad_token = special_tokens.pad oov_tokens = special_tokens.oov eos_token = special_tokens.eos def metrics_builder(): return [ keras_metrics.MaskedCategoricalAccuracy( name='accuracy_with_oov', masked_tokens=[pad_token]), keras_metrics.MaskedCategoricalAccuracy( name='accuracy_no_oov', masked_tokens=[pad_token] + oov_tokens), # Notice BOS never appears in ground truth. keras_metrics.MaskedCategoricalAccuracy( name='accuracy_no_oov_or_eos', masked_tokens=[pad_token, eos_token] + oov_tokens), keras_metrics.NumBatchesCounter(), keras_metrics.NumTokensCounter(masked_tokens=[pad_token]) ] train_dataset_preprocess_comp = stackoverflow_word_prediction.create_preprocess_fn( vocab=stackoverflow_word_prediction.create_vocab(vocab_size), num_oov_buckets=num_oov_buckets, client_batch_size=client_batch_size, client_epochs_per_round=client_epochs_per_round, max_sequence_length=sequence_length, max_elements_per_client=max_elements_per_user) input_spec = train_dataset_preprocess_comp.type_signature.result.element def tff_model_fn() -> tff.learning.Model: return tff.learning.from_keras_model( keras_model=model_builder(), input_spec=input_spec, loss=loss_builder(), metrics=metrics_builder()) def client_weight_fn(local_outputs): # Num_tokens is a tensor with type int64[1], to use as a weight need # a float32 scalar. return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32) iterative_process = iterative_process_builder( tff_model_fn, client_weight_fn=client_weight_fn) if hasattr(train_clientdata, 'dataset_computation'): @tff.tf_computation(tf.string) def train_dataset_computation(client_id): client_train_data = train_clientdata.dataset_computation(client_id) return train_dataset_preprocess_comp(client_train_data) training_process = tff.simulation.compose_dataset_computation_with_iterative_process( train_dataset_computation, iterative_process) client_ids_fn = training_utils.build_sample_fn( train_clientdata.client_ids, size=clients_per_round, replace=False, random_seed=client_datasets_random_seed) # We convert the output to a list (instead of an np.ndarray) so that it can # be used as input to the iterative process. 
client_sampling_fn = lambda x: list(client_ids_fn(x)) else: training_process = tff.simulation.compose_dataset_computation_with_iterative_process( train_dataset_preprocess_comp, iterative_process) client_sampling_fn = training_utils.build_client_datasets_fn( dataset=train_clientdata, clients_per_round=clients_per_round, random_seed=client_datasets_random_seed) training_process.get_model_weights = iterative_process.get_model_weights evaluate_fn = training_utils.build_centralized_evaluate_fn( model_builder=model_builder, eval_dataset=validation_dataset, loss_builder=loss_builder, metrics_builder=metrics_builder) validation_fn = lambda model_weights, round_num: evaluate_fn(model_weights) test_fn = training_utils.build_centralized_evaluate_fn( model_builder=model_builder, # Use both val and test for symmetry with other experiments, which # evaluate on the entire test set. eval_dataset=validation_dataset.concatenate(test_dataset), loss_builder=loss_builder, metrics_builder=metrics_builder) logging.info('Training model:') logging.info(model_builder().summary()) training_loop.run( iterative_process=training_process, train_client_datasets_fn=client_sampling_fn, evaluation_fn=validation_fn, test_fn=test_fn, total_rounds=total_rounds, experiment_name=experiment_name, root_output_dir=root_output_dir, **kwargs)