def test_build_evaluate_fn(self):
  """End-to-end check that the built evaluate fn reports a loss metric."""
  build_loss = tf.keras.losses.MeanSquaredError

  def build_metrics():
    return [tf.keras.metrics.MeanSquaredError()]

  def tff_model_fn():
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        dummy_batch=get_sample_batch(),
        loss=build_loss(),
        metrics=build_metrics())

  # Build a real federated averaging process and take its initial state,
  # since the evaluate fn consumes server state.
  iterative_process = fed_avg_schedule.build_fed_avg_process(
      tff_model_fn, client_optimizer_fn=tf.keras.optimizers.SGD)
  state = iterative_process.initialize()

  eval_dataset = create_tf_dataset_for_client(1)
  evaluate_fn = training_utils.build_evaluate_fn(
      eval_dataset, model_builder, build_loss, build_metrics)
  test_metrics = evaluate_fn(state)
  self.assertIn('loss', test_metrics)
def main(argv):
  """Trains a federated logistic model on Stack Overflow tag prediction.

  Args:
    argv: Command-line arguments; only the program name is expected.

  Raises:
    app.UsageError: If extra command-line arguments are supplied.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Use the compat.v1 namespace for consistency with the sibling experiment
  # binaries; the bare `tf.enable_v2_behavior` is only available on the
  # TF1-style import and fails under a TF2 import.
  tf.compat.v1.enable_v2_behavior()
  # TODO(b/139129100): Remove this once the local executor is the default.
  tff.framework.set_default_executor(
      tff.framework.local_executor_factory(max_fanout=25))

  stackoverflow_train, stackoverflow_validation, stackoverflow_test = (
      dataset.get_stackoverflow_datasets(
          vocab_tokens_size=FLAGS.vocab_tokens_size,
          vocab_tags_size=FLAGS.vocab_tags_size,
          client_batch_size=FLAGS.client_batch_size,
          client_epochs_per_round=FLAGS.client_epochs_per_round,
          max_training_elements_per_user=FLAGS.max_elements_per_user,
          num_validation_examples=FLAGS.num_validation_examples))

  sample_client_dataset = stackoverflow_train.create_tf_dataset_for_client(
      stackoverflow_train.client_ids[0])
  # TODO(b/144382142): Sample batches cannot be eager tensors, since they are
  # passed (implicitly) to tff.learning.build_federated_averaging_process.
  sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                       next(iter(sample_client_dataset)))

  model_builder = functools.partial(
      models.create_logistic_model,
      vocab_tokens_size=FLAGS.vocab_tokens_size,
      vocab_tags_size=FLAGS.vocab_tags_size)

  loss_builder = functools.partial(
      tf.keras.losses.BinaryCrossentropy,
      from_logits=False,
      reduction=tf.keras.losses.Reduction.SUM)

  # NOTE(review): `metrics_builder` is not defined in this function;
  # presumably it is a module-level helper — confirm it exists at module
  # scope, otherwise this raises NameError at runtime.
  training_process = iterative_process_builder.from_flags(
      dummy_batch=sample_batch,
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  client_datasets_fn = training_utils.build_client_datasets_fn(
      stackoverflow_train, FLAGS.clients_per_round)

  evaluate_fn = training_utils.build_evaluate_fn(
      model_builder=model_builder,
      eval_dataset=stackoverflow_validation,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder,
      # Use both val and test for symmetry with other experiments, which
      # evaluate on the entire test set.
      test_dataset=stackoverflow_validation.concatenate(stackoverflow_test))

  logging.info('Training model:')
  logging.info(model_builder().summary())

  training_loop.run(
      iterative_process=training_process,
      client_datasets_fn=client_datasets_fn,
      evaluate_fn=evaluate_fn,
  )
def main(argv):
  """Trains a federated EMNIST classifier chosen by the --model flag."""
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  tf.compat.v1.enable_v2_behavior()
  # TODO(b/139129100): Remove this once the local executor is the default.
  tff.framework.set_default_executor(
      tff.framework.local_executor_factory(max_fanout=25))

  emnist_train, emnist_test = dataset.get_emnist_datasets(
      FLAGS.client_batch_size, FLAGS.client_epochs_per_round,
      only_digits=False)

  first_client_id = emnist_train.client_ids[0]
  sample_client_dataset = emnist_train.create_tf_dataset_for_client(
      first_client_id)
  # TODO(b/144382142): Sample batches cannot be eager tensors, since they are
  # passed (implicitly) to tff.learning.build_federated_averaging_process.
  sample_batch = tf.nest.map_structure(
      lambda x: x.numpy(), next(iter(sample_client_dataset)))

  if FLAGS.model == 'cnn':
    model_builder = functools.partial(
        models.create_conv_dropout_model, only_digits=False)
  elif FLAGS.model == '2nn':
    model_builder = functools.partial(
        models.create_two_hidden_layer_model, only_digits=False)
  else:
    raise ValueError('Cannot handle model flag [{!s}].'.format(FLAGS.model))

  loss_builder = tf.keras.losses.SparseCategoricalCrossentropy

  def metrics_builder():
    return [tf.keras.metrics.SparseCategoricalAccuracy()]

  training_process = iterative_process_builder.from_flags(
      dummy_batch=sample_batch,
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  client_datasets_fn = training_utils.build_client_datasets_fn(
      emnist_train, FLAGS.clients_per_round)

  evaluate_fn = training_utils.build_evaluate_fn(
      eval_dataset=emnist_test,
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  logging.info('Training model:')
  logging.info(model_builder().summary())

  training_loop.run(
      iterative_process=training_process,
      client_datasets_fn=client_datasets_fn,
      evaluate_fn=evaluate_fn,
  )
def main(argv):
  """Trains a character-level recurrent model via federated averaging.

  Args:
    argv: Command-line arguments; only the program name is expected.

  Raises:
    app.UsageError: If extra command-line arguments are supplied.
    ValueError: If no training client has any batches of data.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.compat.v1.enable_v2_behavior()
  tff.framework.set_default_executor(
      tff.framework.local_executor_factory(max_fanout=25))

  train_clientdata, test_dataset = dataset.construct_character_level_datasets(
      FLAGS.client_batch_size, FLAGS.client_epochs_per_round,
      FLAGS.sequence_length)
  test_dataset = test_dataset.cache()

  loss_fn_builder = functools.partial(
      tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

  # Need to iterate until we find a client with data.
  sample_batch = None
  for client_id in train_clientdata.client_ids:
    try:
      sample_batch = next(
          iter(train_clientdata.create_tf_dataset_for_client(client_id)))
      break
    except StopIteration:
      pass  # Client had no batches.
  if sample_batch is None:
    # Previously `sample_batch` was left unbound when every client dataset
    # was empty, producing an UnboundLocalError below; fail fast instead.
    raise ValueError('No client in the training data had any batches.')
  sample_batch = tf.nest.map_structure(lambda t: t.numpy(), sample_batch)

  def client_weight_fn(local_outputs):
    # Num_tokens is a tensor with type int64[1], to use as a weight need
    # a float32 scalar.
    return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

  # NOTE(review): `model_builder` and `metrics_builder` are not defined in
  # this function; presumably module-level helpers — confirm.
  training_process = iterative_process_builder.from_flags(
      dummy_batch=sample_batch,
      model_builder=model_builder,
      loss_builder=loss_fn_builder,
      metrics_builder=metrics_builder,
      client_weight_fn=client_weight_fn)

  logging.info('Training model:')
  logging.info(model_builder().summary())

  training_loop.run(
      iterative_process=training_process,
      client_datasets_fn=training_utils.build_client_datasets_fn(
          train_clientdata, FLAGS.clients_per_round),
      evaluate_fn=training_utils.build_evaluate_fn(
          eval_dataset=test_dataset,
          model_builder=model_builder,
          loss_builder=loss_fn_builder,
          metrics_builder=metrics_builder),
  )
def main(argv):
  """Trains a federated ResNet-18 on CIFAR-100."""
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  tf.compat.v1.enable_v2_behavior()
  # TODO(b/139129100): Remove this once the local executor is the default.
  tff.framework.set_default_executor(
      tff.framework.local_executor_factory(max_fanout=25))

  cifar_train, cifar_test = dataset.get_federated_cifar100(
      client_epochs_per_round=FLAGS.client_epochs_per_round,
      train_batch_size=FLAGS.client_batch_size,
      crop_shape=CROP_SHAPE)

  first_client_id = cifar_train.client_ids[0]
  sample_client_dataset = cifar_train.create_tf_dataset_for_client(
      first_client_id)
  sample_batch = tf.nest.map_structure(
      lambda x: x.numpy(), next(iter(sample_client_dataset)))

  model_builder = functools.partial(
      resnet_models.create_resnet18,
      input_shape=CROP_SHAPE,
      num_classes=NUM_CLASSES)

  logging.info('Training model:')
  logging.info(model_builder().summary())

  loss_builder = tf.keras.losses.SparseCategoricalCrossentropy

  def metrics_builder():
    return [tf.keras.metrics.SparseCategoricalAccuracy()]

  training_process = iterative_process_builder.from_flags(
      dummy_batch=sample_batch,
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  training_loop.run(
      iterative_process=training_process,
      client_datasets_fn=training_utils.build_client_datasets_fn(
          cifar_train, FLAGS.clients_per_round),
      evaluate_fn=training_utils.build_evaluate_fn(
          eval_dataset=cifar_test,
          model_builder=model_builder,
          loss_builder=loss_builder,
          metrics_builder=metrics_builder),
  )
def main(_):
  """Trains a federated EMNIST autoencoder.

  Args:
    _: Unused command-line arguments.
  """
  # Use the compat.v1 namespace for consistency with the sibling experiment
  # binaries; the bare `tf.enable_v2_behavior` is only available on the
  # TF1-style import and fails under a TF2 import.
  tf.compat.v1.enable_v2_behavior()
  # TODO(b/139129100): Remove this once the local executor is the default.
  tff.framework.set_default_executor(
      tff.framework.local_executor_factory(max_fanout=25))

  emnist_train, emnist_test = dataset.get_emnist_datasets(
      FLAGS.client_batch_size, FLAGS.client_epochs_per_round,
      only_digits=False)

  sample_client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[0])
  # TODO(b/144382142): Sample batches cannot be eager tensors, since they are
  # passed (implicitly) to tff.learning.build_federated_averaging_process.
  sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                       next(iter(sample_client_dataset)))

  model_builder = models.create_autoencoder_model
  loss_builder = tf.keras.losses.MeanSquaredError
  metrics_builder = lambda: [tf.keras.metrics.MeanSquaredError()]

  training_process = iterative_process_builder.from_flags(
      dummy_batch=sample_batch,
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  client_datasets_fn = training_utils.build_client_datasets_fn(
      emnist_train, FLAGS.clients_per_round)

  evaluate_fn = training_utils.build_evaluate_fn(
      eval_dataset=emnist_test,
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  logging.info('Training model:')
  logging.info(model_builder().summary())

  training_loop.run(
      iterative_process=training_process,
      client_datasets_fn=client_datasets_fn,
      evaluate_fn=evaluate_fn,
  )
def main(argv):
  """Trains a word-level recurrent model on Stack Overflow data."""
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  tf.compat.v1.enable_v2_behavior()
  tff.framework.set_default_executor(
      tff.framework.local_executor_factory(max_fanout=10))

  # Pick the recurrent cell type from the flag.
  if FLAGS.lstm:
    def _layer_fn(x):
      return tf.keras.layers.LSTM(x, return_sequences=True)
  else:
    def _layer_fn(x):
      return tf.keras.layers.GRU(x, return_sequences=True)

  model_builder = functools.partial(
      models.create_recurrent_model,
      vocab_size=FLAGS.vocab_size,
      recurrent_layer_fn=_layer_fn,
      shared_embedding=FLAGS.shared_embedding)

  loss_builder = functools.partial(
      tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

  pad_token, oov_token, _, eos_token = dataset.get_special_tokens(
      FLAGS.vocab_size)

  def metrics_builder():
    """Builds accuracy metrics over the special-token-extended vocabulary."""
    extended_vocab = FLAGS.vocab_size + 4  # Plus 4 for PAD, OOV, BOS and EOS.
    return [
        keras_metrics.FlattenedCategoricalAccuracy(
            vocab_size=extended_vocab,
            name='accuracy_with_oov',
            masked_tokens=pad_token),
        keras_metrics.FlattenedCategoricalAccuracy(
            vocab_size=extended_vocab,
            name='accuracy_no_oov',
            masked_tokens=[pad_token, oov_token]),
        # Notice BOS never appears in ground truth.
        keras_metrics.FlattenedCategoricalAccuracy(
            vocab_size=extended_vocab,
            name='accuracy_no_oov_or_eos',
            masked_tokens=[pad_token, oov_token, eos_token]),
        keras_metrics.NumBatchesCounter(),
        keras_metrics.FlattenedNumExamplesCounter(
            name='num_tokens', mask_zero=True),
    ]

  (stackoverflow_train, stackoverflow_validation,
   stackoverflow_test) = dataset.construct_word_level_datasets(
       FLAGS.vocab_size, FLAGS.client_batch_size,
       FLAGS.client_epochs_per_round, FLAGS.sequence_length,
       FLAGS.max_elements_per_user, FLAGS.num_validation_examples)

  sample_batch = tf.nest.map_structure(
      lambda x: x.numpy(), next(iter(stackoverflow_validation)))

  def client_weight_fn(local_outputs):
    # Num_tokens is a tensor with type int64[1], to use as a weight need
    # a float32 scalar.
    return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

  training_process = iterative_process_builder.from_flags(
      dummy_batch=sample_batch,
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder,
      client_weight_fn=client_weight_fn)

  client_datasets_fn = training_utils.build_client_datasets_fn(
      stackoverflow_train, FLAGS.clients_per_round)

  eval_fn = training_utils.build_evaluate_fn(
      model_builder=model_builder,
      eval_dataset=stackoverflow_validation,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder,
      # Use both val and test for symmetry with other experiments, which
      # evaluate on the entire test set.
      test_dataset=stackoverflow_validation.concatenate(stackoverflow_test))

  logging.info('Training model:')
  logging.info(model_builder().summary())

  training_loop.run(training_process, client_datasets_fn, eval_fn)