def test_raises_no_repeat_and_no_take(self):
  with self.assertRaisesRegex(
      ValueError, 'Argument client_epochs_per_round is set to -1'):
    stackoverflow_lr_dataset.get_stackoverflow_datasets(
        vocab_tokens_size=1000,
        vocab_tags_size=500,
        max_training_elements_per_user=128,
        client_batch_size=10,
        client_epochs_per_round=-1,
        max_batches_per_user=-1,
        num_validation_examples=500)
def test_take_with_repeat(self):
  so_train, _, _ = stackoverflow_lr_dataset.get_stackoverflow_datasets(
      vocab_tokens_size=1000,
      vocab_tags_size=500,
      max_training_elements_per_user=128,
      client_batch_size=10,
      client_epochs_per_round=-1,
      max_batches_per_user=8,
      num_validation_examples=500)
  for i in range(10):
    client_ds = so_train.create_tf_dataset_for_client(so_train.client_ids[i])
    self.assertEqual(_compute_length_of_dataset(client_ds), 8)
def test_stackoverflow_dataset_structure(self):
  stackoverflow_train, stackoverflow_validation, stackoverflow_test = (
      stackoverflow_lr_dataset.get_stackoverflow_datasets(
          vocab_tokens_size=100,
          vocab_tags_size=5,
          max_training_elements_per_user=10,
          client_batch_size=10,
          client_epochs_per_round=1,
          num_validation_examples=10000))
  self.assertEqual(len(stackoverflow_train.client_ids), 342477)
  sample_train_ds = stackoverflow_train.create_tf_dataset_for_client(
      stackoverflow_train.client_ids[0])

  train_batch = next(iter(sample_train_ds))
  valid_batch = next(iter(stackoverflow_validation))
  test_batch = next(iter(stackoverflow_test))
  self.assertEqual(train_batch[0].shape.as_list(), [10, 100])
  self.assertEqual(train_batch[1].shape.as_list(), [10, 5])
  self.assertEqual(valid_batch[0].shape.as_list(), [TEST_BATCH_SIZE, 100])
  self.assertEqual(valid_batch[1].shape.as_list(), [TEST_BATCH_SIZE, 5])
  self.assertEqual(test_batch[0].shape.as_list(), [TEST_BATCH_SIZE, 100])
  self.assertEqual(test_batch[1].shape.as_list(), [TEST_BATCH_SIZE, 5])
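
# The tests above call a `_compute_length_of_dataset` helper that is not shown
# in this excerpt. A minimal sketch, assuming it simply counts the batches
# yielded by a `tf.data.Dataset` (the name and behavior are inferred from its
# usage in `test_take_with_repeat`, not confirmed by this excerpt):
def _compute_length_of_dataset(ds):
  # `tf.data.Dataset.reduce` iterates the dataset once, adding 1 per element
  # (here, per batch), and returns the count as a scalar tensor.
  return ds.reduce(0, lambda x, _: x + 1)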
def run_federated(
    iterative_process_builder: Callable[..., tff.templates.IterativeProcess],
    client_epochs_per_round: int,
    client_batch_size: int,
    clients_per_round: int,
    max_batches_per_client: Optional[int] = -1,
    client_datasets_random_seed: Optional[int] = None,
    vocab_tokens_size: Optional[int] = 10000,
    vocab_tags_size: Optional[int] = 500,
    max_elements_per_user: Optional[int] = 1000,
    num_validation_examples: Optional[int] = 10000,
    total_rounds: Optional[int] = 1500,
    experiment_name: Optional[str] = 'federated_so_lr',
    root_output_dir: Optional[str] = '/tmp/fed_opt',
    max_eval_batches: Optional[int] = None,
    **kwargs):
  """Runs an iterative process on the Stack Overflow logistic regression task.

  This method will load and pre-process the dataset and construct a model used
  for the task. It then uses `iterative_process_builder` to create an
  iterative process that it applies to the task, using
  `federated_research.utils.training_loop`.

  We assume that the iterative process has the following functional type
  signatures:

    * `initialize`: `( -> S@SERVER)` where `S` represents the server state.
    * `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S`
      represents the server state, `{B*}` represents the client datasets, and
      `T` represents a python `Mapping` object.

  Moreover, the server state must have an attribute `model` of type
  `tff.learning.ModelWeights`.

  Args:
    iterative_process_builder: A function that accepts a no-arg `model_fn`,
      and returns a `tff.templates.IterativeProcess`. The `model_fn` must
      return a `tff.learning.Model`.
    client_epochs_per_round: An integer representing the number of epochs of
      training performed per client in each training round.
    client_batch_size: An integer representing the batch size used on clients.
    clients_per_round: An integer representing the number of clients
      participating in each round.
    max_batches_per_client: An optional int specifying the number of batches
      taken by each client at each round. If `-1`, the entire client dataset
      is used.
    client_datasets_random_seed: An optional int used to seed which clients
      are sampled at each round. If `None`, no seed is used.
    vocab_tokens_size: Integer dictating the number of most frequent words to
      use in the vocabulary.
    vocab_tags_size: Integer dictating the number of most frequent tags to use
      in the label creation.
    max_elements_per_user: The maximum number of elements processed for each
      client's dataset.
    num_validation_examples: The number of test examples to use for
      validation.
    total_rounds: The number of federated training rounds.
    experiment_name: The name of the experiment being run. This will be
      appended to the `root_output_dir` for purposes of writing outputs.
    root_output_dir: The name of the root output directory for writing
      experiment outputs.
    max_eval_batches: If set to a positive integer, evaluation datasets are
      capped to at most that many batches. If set to `None` or a nonpositive
      integer, the full evaluation datasets are used.
    **kwargs: Additional arguments configuring the training loop. For details
      on supported arguments, see
      `federated_research/utils/training_utils.py`.
""" stackoverflow_train, _, _ = stackoverflow_lr_dataset.get_stackoverflow_datasets( vocab_tokens_size=vocab_tokens_size, vocab_tags_size=vocab_tags_size, client_batch_size=client_batch_size, client_epochs_per_round=client_epochs_per_round, max_training_elements_per_user=max_elements_per_user, max_batches_per_user=max_batches_per_client, num_validation_examples=num_validation_examples) _, stackoverflow_validation, stackoverflow_test = stackoverflow_lr_dataset.get_centralized_datasets( train_batch_size=client_batch_size, vocab_tokens_size=vocab_tokens_size, vocab_tags_size=vocab_tags_size, num_validation_examples=num_validation_examples, max_validation_batches=max_eval_batches, max_test_batches=max_eval_batches) input_spec = stackoverflow_train.create_tf_dataset_for_client( stackoverflow_train.client_ids[0]).element_spec model_builder = functools.partial( stackoverflow_lr_models.create_logistic_model, vocab_tokens_size=vocab_tokens_size, vocab_tags_size=vocab_tags_size) loss_builder = functools.partial(tf.keras.losses.BinaryCrossentropy, from_logits=False, reduction=tf.keras.losses.Reduction.SUM) def tff_model_fn() -> tff.learning.Model: return tff.learning.from_keras_model(keras_model=model_builder(), input_spec=input_spec, loss=loss_builder(), metrics=metrics_builder()) training_process = iterative_process_builder(tff_model_fn) client_datasets_fn = training_utils.build_client_datasets_fn( train_dataset=stackoverflow_train, train_clients_per_round=clients_per_round, random_seed=client_datasets_random_seed) evaluate_fn = training_utils.build_evaluate_fn( model_builder=model_builder, eval_dataset=stackoverflow_validation, loss_builder=loss_builder, metrics_builder=metrics_builder) test_fn = training_utils.build_evaluate_fn( model_builder=model_builder, # Use both val and test for symmetry with other experiments, which # evaluate on the entire test set. eval_dataset=stackoverflow_validation.concatenate(stackoverflow_test), loss_builder=loss_builder, metrics_builder=metrics_builder) logging.info('Training model:') logging.info(model_builder().summary()) training_loop.run(iterative_process=training_process, client_datasets_fn=client_datasets_fn, validation_fn=evaluate_fn, test_fn=test_fn, total_rounds=total_rounds, experiment_name=experiment_name, root_output_dir=root_output_dir, **kwargs)