Example #1
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    stackoverflow_train, stackoverflow_validation, stackoverflow_test = stackoverflow_lr_dataset.get_stackoverflow_datasets(
        vocab_tokens_size=FLAGS.vocab_tokens_size,
        vocab_tags_size=FLAGS.vocab_tags_size,
        client_batch_size=FLAGS.client_batch_size,
        client_epochs_per_round=FLAGS.client_epochs_per_round,
        max_training_elements_per_user=FLAGS.max_elements_per_user,
        num_validation_examples=FLAGS.num_validation_examples)

    input_spec = stackoverflow_train.create_tf_dataset_for_client(
        stackoverflow_train.client_ids[0]).element_spec

    model_builder = functools.partial(
        stackoverflow_lr_models.create_logistic_model,
        vocab_tokens_size=FLAGS.vocab_tokens_size,
        vocab_tags_size=FLAGS.vocab_tags_size)

    loss_builder = functools.partial(tf.keras.losses.BinaryCrossentropy,
                                     from_logits=False,
                                     reduction=tf.keras.losses.Reduction.SUM)

    training_process = iterative_process_builder.from_flags(
        input_spec=input_spec,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        train_dataset=stackoverflow_train,
        train_clients_per_round=FLAGS.clients_per_round,
        random_seed=FLAGS.client_datasets_random_seed)

    assign_weights_fn = fed_avg_schedule.ServerState.assign_weights_to_keras_model

    evaluate_fn = training_utils.build_evaluate_fn(
        model_builder=model_builder,
        eval_dataset=stackoverflow_validation,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        assign_weights_to_keras_model=assign_weights_fn)

    test_fn = training_utils.build_evaluate_fn(
        model_builder=model_builder,
        # Use both val and test for symmetry with other experiments, which
        # evaluate on the entire test set.
        eval_dataset=stackoverflow_validation.concatenate(stackoverflow_test),
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        assign_weights_to_keras_model=assign_weights_fn)

    logging.info('Training model:')
    logging.info(model_builder().summary())

    training_loop.run(training_process,
                      client_datasets_fn,
                      evaluate_fn,
                      test_fn=test_fn)
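
A note on the shared pattern: every snippet in this collection builds an iterative process from flags and hands it to training_loop.run. Under the hood, that loop is roughly the hand-rolled version below. This is a hedged sketch, not the repo's implementation; it assumes only the standard initialize/next interface of the returned process, plus the client_datasets_fn(round_num) and evaluate_fn(state) call shapes implied by the snippets (the real loop also handles checkpointing and metrics export).

def run_simple(iterative_process, client_datasets_fn, evaluate_fn,
               total_rounds=10, rounds_per_eval=5):
    # Hedged sketch of the round loop behind training_loop.run.
    state = iterative_process.initialize()
    for round_num in range(total_rounds):
        # Sample the client datasets participating in this round.
        federated_train_data = client_datasets_fn(round_num)
        state, train_metrics = iterative_process.next(
            state, federated_train_data)
        if round_num % rounds_per_eval == 0:
            print('round {}: train {}, eval {}'.format(
                round_num, train_metrics, evaluate_fn(state)))
    return state
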
Example #2
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.compat.v1.enable_v2_behavior()

    stackoverflow_train, stackoverflow_validation, stackoverflow_test = dataset.get_stackoverflow_datasets(
        vocab_tokens_size=FLAGS.vocab_tokens_size,
        vocab_tags_size=FLAGS.vocab_tags_size,
        client_batch_size=FLAGS.client_batch_size,
        client_epochs_per_round=FLAGS.client_epochs_per_round,
        max_training_elements_per_user=FLAGS.max_elements_per_user,
        num_validation_examples=FLAGS.num_validation_examples)

    sample_client_dataset = stackoverflow_train.create_tf_dataset_for_client(
        stackoverflow_train.client_ids[0])
    # TODO(b/144382142): Sample batches cannot be eager tensors, since they are
    # passed (implicitly) to tff.learning.build_federated_averaging_process.
    sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                         next(iter(sample_client_dataset)))

    model_builder = functools.partial(
        models.create_logistic_model,
        vocab_tokens_size=FLAGS.vocab_tokens_size,
        vocab_tags_size=FLAGS.vocab_tags_size)

    loss_builder = functools.partial(tf.keras.losses.BinaryCrossentropy,
                                     from_logits=False,
                                     reduction=tf.keras.losses.Reduction.SUM)

    training_process = iterative_process_builder.from_flags(
        dummy_batch=sample_batch,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        stackoverflow_train, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_evaluate_fn(
        model_builder=model_builder,
        eval_dataset=stackoverflow_validation,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    test_fn = training_utils.build_evaluate_fn(
        model_builder=model_builder,
        # Use both val and test for symmetry with other experiments, which
        # evaluate on the entire test set.
        eval_dataset=stackoverflow_validation.concatenate(stackoverflow_test),
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    logging.info('Training model:')
    logging.info(model_builder().summary())

    training_loop.run(training_process,
                      client_datasets_fn,
                      evaluate_fn,
                      test_fn=test_fn)
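
The main difference between Example #1 and Example #2 is how the model's input signature reaches from_flags: newer snippets pass an input_spec, older ones a dummy_batch of numpy arrays. Both are derived from one sampled client dataset, as this side-by-side sketch (using only calls already shown above) makes explicit:

sample_dataset = stackoverflow_train.create_tf_dataset_for_client(
    stackoverflow_train.client_ids[0])

# Newer style: a structure of tf.TensorSpec objects describing one batch.
input_spec = sample_dataset.element_spec

# Older style: a concrete batch converted to numpy, since eager tensors
# cannot be passed (implicitly) into the federated averaging computation.
dummy_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                    next(iter(sample_dataset)))
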
Example #3
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    tf.compat.v1.enable_v2_behavior()
    # TODO(b/139129100): Remove this once the local executor is the default.
    tff.framework.set_default_executor(
        tff.framework.local_executor_factory(max_fanout=25))

    emnist_train, emnist_test = dataset.get_emnist_datasets(
        FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round,
        only_digits=False)

    sample_client_dataset = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[0])
    # TODO(b/144382142): Sample batches cannot be eager tensors, since they are
    # passed (implicitly) to tff.learning.build_federated_averaging_process.
    sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                         next(iter(sample_client_dataset)))

    if FLAGS.model == 'cnn':
        model_builder = functools.partial(models.create_conv_dropout_model,
                                          only_digits=False)
    elif FLAGS.model == '2nn':
        model_builder = functools.partial(models.create_two_hidden_layer_model,
                                          only_digits=False)
    else:
        raise ValueError('Cannot handle model flag [{!s}].'.format(
            FLAGS.model))

    loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
    metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

    training_process = iterative_process_builder.from_flags(
        dummy_batch=sample_batch,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        emnist_train, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_evaluate_fn(
        eval_dataset=emnist_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    logging.info('Training model:')
    logging.info(model_builder().summary())

    training_loop.run(
        iterative_process=training_process,
        client_datasets_fn=client_datasets_fn,
        evaluate_fn=evaluate_fn,
    )
Example #4
    def test_iterative_process_no_schedule_decreases_loss(self):
        FLAGS.client_lr_schedule = 'constant'
        FLAGS.server_lr_schedule = 'constant'
        federated_data = [[_batch_fn()]]
        input_spec = _get_input_spec()
        iterproc = iterative_process_builder.from_flags(
            input_spec, model_builder, loss_builder, metrics_builder)
        _, train_outputs = self._run_rounds(iterproc, federated_data, 4)
        self.assertLess(train_outputs[-1]['loss'], train_outputs[0]['loss'])
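
The _run_rounds helper used by these test snippets is elided. A plausible reconstruction, assuming only the initialize/next interface of the built process:

def _run_rounds(self, iterproc, federated_data, num_rounds):
    # Hypothetical test helper: run num_rounds of training on fixed data
    # and collect the per-round training metrics.
    train_outputs = []
    state = iterproc.initialize()
    for _ in range(num_rounds):
        state, metrics = iterproc.next(state, federated_data)
        train_outputs.append(metrics)
    return state, train_outputs
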
Example #5
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    tf.compat.v1.enable_v2_behavior()

    emnist_train, emnist_test = dataset.get_emnist_datasets(
        FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round,
        only_digits=False)

    input_spec = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[0]).element_spec

    if FLAGS.model == 'cnn':
        model_builder = functools.partial(models.create_conv_dropout_model,
                                          only_digits=False)
    elif FLAGS.model == '2nn':
        model_builder = functools.partial(models.create_two_hidden_layer_model,
                                          only_digits=False)
    else:
        raise ValueError('Cannot handle model flag [{!s}].'.format(
            FLAGS.model))

    loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
    metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

    training_process = iterative_process_builder.from_flags(
        input_spec=input_spec,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        train_dataset=emnist_train,
        train_clients_per_round=FLAGS.clients_per_round,
        random_seed=FLAGS.client_datasets_random_seed)

    assign_weights_fn = fed_avg_schedule.ServerState.assign_weights_to_keras_model

    evaluate_fn = training_utils.build_evaluate_fn(
        eval_dataset=emnist_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        assign_weights_to_keras_model=assign_weights_fn)

    logging.info('Training model:')
    logging.info(model_builder().summary())

    training_loop.run(
        iterative_process=training_process,
        client_datasets_fn=client_datasets_fn,
        evaluate_fn=evaluate_fn,
    )
Example #6
    def test_iterative_process_with_inv_time_server_schedule(self, sched_type):
        FLAGS.client_lr_schedule = 'constant'
        FLAGS.server_lr_decay_steps = 1
        FLAGS.server_lr_decay_rate = 1.0
        FLAGS.server_lr_staircase = False
        FLAGS.server_lr_schedule = sched_type
        federated_data = [[_batch_fn()]]
        input_spec = _get_input_spec()
        iterproc = iterative_process_builder.from_flags(
            input_spec, model_builder, loss_builder, metrics_builder)
        _, train_outputs = self._run_rounds(iterproc, federated_data, 4)
        self.assertLess(train_outputs[-1]['loss'], train_outputs[0]['loss'])
Example #7
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.compat.v1.enable_v2_behavior()
    tff.framework.set_default_executor(
        tff.framework.local_executor_factory(max_fanout=25))

    train_clientdata, test_dataset = dataset.construct_character_level_datasets(
        FLAGS.client_batch_size, FLAGS.client_epochs_per_round,
        FLAGS.sequence_length)
    test_dataset = test_dataset.cache()

    loss_fn_builder = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

    # Need to iterate until we find a client with data.
    for client_id in train_clientdata.client_ids:
        try:
            sample_batch = next(
                iter(train_clientdata.create_tf_dataset_for_client(client_id)))
            break
        except StopIteration:
            pass  # Client had no batches.
    sample_batch = tf.nest.map_structure(lambda t: t.numpy(), sample_batch)

    def client_weight_fn(local_outputs):
        # num_tokens is an int64[1] tensor; to use it as a weight we need a
        # float32 scalar.
        return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

    training_process = iterative_process_builder.from_flags(
        dummy_batch=sample_batch,
        model_builder=model_builder,
        loss_builder=loss_fn_builder,
        metrics_builder=metrics_builder,
        client_weight_fn=client_weight_fn)

    logging.info('Training model:')
    logging.info(model_builder().summary())

    training_loop.run(
        iterative_process=training_process,
        client_datasets_fn=training_utils.build_client_datasets_fn(
            train_clientdata, FLAGS.clients_per_round),
        evaluate_fn=training_utils.build_evaluate_fn(
            eval_dataset=test_dataset,
            model_builder=model_builder,
            loss_builder=loss_fn_builder,
            metrics_builder=metrics_builder),
    )
Example #8
    def test_iterative_process_with_exp_decay_server_schedule(self):
        FLAGS.client_lr_schedule = 'constant'
        FLAGS.server_lr_schedule = 'exp_decay'
        FLAGS.server_lr_decay_steps = 1
        FLAGS.server_lr_decay_rate = 0.5
        FLAGS.server_lr_staircase = False

        federated_data = [[_batch_fn()]]
        dummy_batch = _batch_fn()
        iterative_process = iterative_process_builder.from_flags(
            dummy_batch, model_builder, loss_builder, metrics_builder)
        _, train_outputs = self._run_rounds(iterative_process, federated_data,
                                            4)
        self.assertLess(train_outputs[-1]['loss'], train_outputs[0]['loss'])
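
For completeness, the _batch_fn and _get_input_spec fixtures used across these tests could look like the hypothetical versions below for a toy one-feature model; the names, shapes, and values are illustrative assumptions, not the repo's actual test data:

import collections

import numpy as np
import tensorflow as tf

def _batch_fn():
    # One batch of (x, y) pairs in the OrderedDict structure that
    # tff.learning expects.
    return collections.OrderedDict(
        x=np.ones([2, 1], dtype=np.float32),
        y=np.ones([2, 1], dtype=np.float32))

def _get_input_spec():
    # The element spec matching a batch produced by _batch_fn().
    return collections.OrderedDict(
        x=tf.TensorSpec(shape=[None, 1], dtype=tf.float32),
        y=tf.TensorSpec(shape=[None, 1], dtype=tf.float32))
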
Example #9
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  tf.compat.v1.enable_v2_behavior()

  cifar_train, cifar_test = dataset.get_federated_cifar100(
      client_epochs_per_round=FLAGS.client_epochs_per_round,
      train_batch_size=FLAGS.client_batch_size,
      crop_shape=CROP_SHAPE)

  input_spec = cifar_train.create_tf_dataset_for_client(
      cifar_train.client_ids[0]).element_spec

  model_builder = functools.partial(
      resnet_models.create_resnet18,
      input_shape=CROP_SHAPE,
      num_classes=NUM_CLASSES)

  logging.info('Training model:')
  logging.info(model_builder().summary())

  loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
  metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

  training_process = iterative_process_builder.from_flags(
      input_spec=input_spec,
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  client_datasets_fn = training_utils.build_client_datasets_fn(
      train_dataset=cifar_train,
      train_clients_per_round=FLAGS.clients_per_round,
      random_seed=FLAGS.client_datasets_random_seed)

  assign_weights_fn = fed_avg_schedule.ServerState.assign_weights_to_keras_model

  evaluate_fn = training_utils.build_evaluate_fn(
      eval_dataset=cifar_test,
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder,
      assign_weights_to_keras_model=assign_weights_fn)

  training_loop.run(
      iterative_process=training_process,
      client_datasets_fn=client_datasets_fn,
      evaluate_fn=evaluate_fn)
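
CROP_SHAPE and NUM_CLASSES are module-level constants not shown in this snippet. Plausible definitions for federated CIFAR-100 (the exact crop size is an assumption):

# CIFAR-100 has 100 label classes; training images are commonly cropped
# from 32x32x3 down to 24x24x3 in this setup.
CROP_SHAPE = (24, 24, 3)
NUM_CLASSES = 100
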
Example #10
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.compat.v1.enable_v2_behavior()

    train_clientdata, test_dataset = dataset.construct_character_level_datasets(
        FLAGS.client_batch_size, FLAGS.client_epochs_per_round,
        FLAGS.sequence_length)
    test_dataset = test_dataset.cache()

    loss_fn_builder = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

    input_spec = train_clientdata.create_tf_dataset_for_client(
        train_clientdata.client_ids[0]).element_spec

    def client_weight_fn(local_outputs):
        # num_tokens is an int64[1] tensor; to use it as a weight we need a
        # float32 scalar.
        return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

    training_process = iterative_process_builder.from_flags(
        input_spec=input_spec,
        model_builder=model_builder,
        loss_builder=loss_fn_builder,
        metrics_builder=metrics_builder,
        client_weight_fn=client_weight_fn)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        train_dataset=train_clientdata,
        train_clients_per_round=FLAGS.clients_per_round,
        random_seed=FLAGS.client_datasets_random_seed)

    assign_weights_fn = fed_avg_schedule.ServerState.assign_weights_to_keras_model

    evaluate_fn = training_utils.build_evaluate_fn(
        eval_dataset=test_dataset,
        model_builder=model_builder,
        loss_builder=loss_fn_builder,
        metrics_builder=metrics_builder,
        assign_weights_to_keras_model=assign_weights_fn)

    logging.info('Training model:')
    logging.info(model_builder().summary())

    training_loop.run(iterative_process=training_process,
                      client_datasets_fn=client_datasets_fn,
                      evaluate_fn=evaluate_fn)
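
model_builder and metrics_builder are module-level definitions elided from the character-level snippets (Examples #7 and #10). A hypothetical pair, assuming a small character vocabulary; the layer sizes and VOCAB_SIZE are illustrative, not the repo's values:

import tensorflow as tf

VOCAB_SIZE = 90  # Assumed number of distinct characters.

def model_builder():
    # Embed each character, run a recurrent layer, and emit next-character
    # logits at every sequence position (hence from_logits=True above).
    return tf.keras.Sequential([
        tf.keras.layers.Embedding(VOCAB_SIZE, 8),
        tf.keras.layers.LSTM(256, return_sequences=True),
        tf.keras.layers.Dense(VOCAB_SIZE),
    ])

def metrics_builder():
    return [tf.keras.metrics.SparseCategoricalAccuracy()]
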
Example #11
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  tf.compat.v1.enable_v2_behavior()
  # TODO(b/139129100): Remove this once the local executor is the default.
  tff.framework.set_default_executor(
      tff.framework.local_executor_factory(max_fanout=25))

  cifar_train, cifar_test = dataset.get_federated_cifar100(
      client_epochs_per_round=FLAGS.client_epochs_per_round,
      train_batch_size=FLAGS.client_batch_size,
      crop_shape=CROP_SHAPE)

  sample_client_dataset = cifar_train.create_tf_dataset_for_client(
      cifar_train.client_ids[0])

  sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                       next(iter(sample_client_dataset)))

  model_builder = functools.partial(
      resnet_models.create_resnet18,
      input_shape=CROP_SHAPE,
      num_classes=NUM_CLASSES)

  logging.info('Training model:')
  logging.info(model_builder().summary())

  loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
  metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

  training_process = iterative_process_builder.from_flags(
      dummy_batch=sample_batch,
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  training_loop.run(
      iterative_process=training_process,
      client_datasets_fn=training_utils.build_client_datasets_fn(
          cifar_train, FLAGS.clients_per_round),
      evaluate_fn=training_utils.build_evaluate_fn(
          eval_dataset=cifar_test,
          model_builder=model_builder,
          loss_builder=loss_builder,
          metrics_builder=metrics_builder),
  )
Example #12
    def test_decay_factor_0_does_not_decrease_loss(self):
        FLAGS.client_lr_schedule = 'exp_decay'
        FLAGS.client_lr_decay_steps = 2
        FLAGS.client_lr_decay_rate = 0.0
        FLAGS.client_lr_staircase = True
        FLAGS.server_lr_schedule = 'constant'

        federated_data = [[_batch_fn()]]
        input_spec = _get_input_spec()
        iterproc = iterative_process_builder.from_flags(
            input_spec, model_builder, loss_builder, metrics_builder)
        _, train_outputs = self._run_rounds(iterproc, federated_data, 4)
        self.assertLess(train_outputs[1]['loss'], train_outputs[0]['loss'])
        self.assertNear(train_outputs[2]['loss'],
                        train_outputs[3]['loss'],
                        err=1e-5)
Example #13
def main(_):

    tf.enable_v2_behavior()
    # TODO(b/139129100): Remove this once the local executor is the default.
    tff.framework.set_default_executor(
        tff.framework.local_executor_factory(max_fanout=25))

    emnist_train, emnist_test = dataset.get_emnist_datasets(
        FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round,
        only_digits=False)

    sample_client_dataset = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[0])
    # TODO(b/144382142): Sample batches cannot be eager tensors, since they are
    # passed (implicitly) to tff.learning.build_federated_averaging_process.
    sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                         next(iter(sample_client_dataset)))

    model_builder = models.create_autoencoder_model

    loss_builder = tf.keras.losses.MeanSquaredError
    metrics_builder = lambda: [tf.keras.metrics.MeanSquaredError()]

    training_process = iterative_process_builder.from_flags(
        dummy_batch=sample_batch,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        emnist_train, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_evaluate_fn(
        eval_dataset=emnist_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    logging.info('Training model:')
    logging.info(model_builder().summary())

    training_loop.run(
        iterative_process=training_process,
        client_datasets_fn=client_datasets_fn,
        evaluate_fn=evaluate_fn,
    )
Example #14
def main(_):

    emnist_train, emnist_test = emnist_ae_dataset.get_emnist_datasets(
        FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round,
        only_digits=False)

    input_spec = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[0]).element_spec

    model_builder = emnist_ae_models.create_autoencoder_model

    loss_builder = functools.partial(tf.keras.losses.MeanSquaredError,
                                     reduction=tf.keras.losses.Reduction.SUM)

    metrics_builder = lambda: [tf.keras.metrics.MeanSquaredError()]

    training_process = iterative_process_builder.from_flags(
        input_spec=input_spec,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        train_dataset=emnist_train,
        train_clients_per_round=FLAGS.clients_per_round,
        random_seed=FLAGS.client_datasets_random_seed)

    assign_weights_fn = fed_avg_schedule.ServerState.assign_weights_to_keras_model

    evaluate_fn = training_utils.build_evaluate_fn(
        eval_dataset=emnist_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        assign_weights_to_keras_model=assign_weights_fn)

    logging.info('Training model:')
    logging.info(model_builder().summary())

    training_loop.run(
        iterative_process=training_process,
        client_datasets_fn=client_datasets_fn,
        evaluate_fn=evaluate_fn,
    )
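
create_autoencoder_model is defined elsewhere; for the EMNIST autoencoder experiments a dense bottleneck network over flattened 28x28 images is the usual choice. A hedged sketch along those lines (the layer widths are an assumption):

import tensorflow as tf

def create_autoencoder_model():
    # Dense bottleneck autoencoder: compress a flattened 784-pixel EMNIST
    # image to a low-dimensional code, then reconstruct the pixels.
    return tf.keras.Sequential([
        tf.keras.layers.Dense(1000, activation='sigmoid', input_shape=(784,)),
        tf.keras.layers.Dense(500, activation='sigmoid'),
        tf.keras.layers.Dense(250, activation='sigmoid'),
        tf.keras.layers.Dense(30),
        tf.keras.layers.Dense(250, activation='sigmoid'),
        tf.keras.layers.Dense(500, activation='sigmoid'),
        tf.keras.layers.Dense(1000, activation='sigmoid'),
        tf.keras.layers.Dense(784, activation='sigmoid'),
    ])
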
Example #15
    def test_iterative_process_with_custom_client_weight_fn_decreases_loss(
            self):
        FLAGS.client_lr_schedule = 'constant'
        FLAGS.server_lr_schedule = 'constant'
        federated_data = [[_batch_fn()]]
        input_spec = _get_input_spec()

        def client_weight_fn(local_outputs):
            return 1.0 / (1.0 + local_outputs['loss'][-1])

        iterproc = iterative_process_builder.from_flags(
            input_spec,
            model_builder,
            loss_builder,
            metrics_builder,
            client_weight_fn=client_weight_fn)
        _, train_outputs = self._run_rounds(iterproc, federated_data, 4)
        self.assertLess(train_outputs[-1]['loss'], train_outputs[0]['loss'])
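
To see what a client_weight_fn changes: the server combines per-client model deltas as a weighted mean, so the weight 1 / (1 + final_local_loss) above shifts influence toward clients that finished the round with lower loss. A schematic of that aggregation (not library code):

def weighted_average(client_deltas, client_weights):
    # Each client's per-layer delta is scaled by its weight, and the sum
    # is normalized by the total weight.
    total = sum(client_weights)
    return [
        sum(w * layer for w, layer in zip(client_weights, layers)) / total
        for layers in zip(*client_deltas)
    ]
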
Example #16
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))
    tf.compat.v1.enable_v2_behavior()

    model_builder = functools.partial(models.create_recurrent_model,
                                      vocab_size=FLAGS.vocab_size,
                                      embedding_size=FLAGS.embedding_size,
                                      latent_size=FLAGS.latent_size,
                                      num_layers=FLAGS.num_layers,
                                      shared_embedding=FLAGS.shared_embedding)

    loss_builder = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

    pad_token, oov_token, _, eos_token = dataset.get_special_tokens(
        FLAGS.vocab_size)

    def metrics_builder():
        return [
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy_with_oov',
                                                    masked_tokens=[pad_token]),
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_no_oov', masked_tokens=[pad_token, oov_token]),
            # Notice BOS never appears in ground truth.
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_no_oov_or_eos',
                masked_tokens=[pad_token, oov_token, eos_token]),
            keras_metrics.NumBatchesCounter(),
            keras_metrics.NumTokensCounter(masked_tokens=[pad_token])
        ]

    train_set, validation_set, test_set = dataset.construct_word_level_datasets(
        FLAGS.vocab_size, FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round, FLAGS.sequence_length,
        FLAGS.max_elements_per_user, FLAGS.num_validation_examples)

    input_spec = validation_set.element_spec

    def client_weight_fn(local_outputs):
        # num_tokens is an int64[1] tensor; to use it as a weight we need a
        # float32 scalar.
        return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

    training_process = iterative_process_builder.from_flags(
        input_spec=input_spec,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        client_weight_fn=client_weight_fn)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        train_set, FLAGS.clients_per_round)

    assign_weights_fn = fed_avg_schedule.ServerState.assign_weights_to_keras_model

    evaluate_fn = training_utils.build_evaluate_fn(
        model_builder=model_builder,
        eval_dataset=validation_set,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        assign_weights_to_keras_model=assign_weights_fn)

    test_fn = training_utils.build_evaluate_fn(
        model_builder=model_builder,
        # Use both val and test for symmetry with other experiments, which
        # evaluate on the entire test set.
        eval_dataset=validation_set.concatenate(test_set),
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        assign_weights_to_keras_model=assign_weights_fn)

    logging.info('Training model:')
    logging.info(model_builder().summary())

    training_loop.run(training_process,
                      client_datasets_fn,
                      evaluate_fn,
                      test_fn=test_fn)
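
keras_metrics.MaskedCategoricalAccuracy belongs to this codebase, not core Keras. The masking idea can be sketched by zeroing the sample weight of masked token positions; this simplified version, built on the standard Keras metric, is an illustration rather than the repo's implementation:

import tensorflow as tf

class SimpleMaskedAccuracy(tf.keras.metrics.SparseCategoricalAccuracy):
    """Accuracy that ignores positions whose true token is masked."""

    def __init__(self, masked_tokens, name='masked_accuracy'):
        super().__init__(name=name)
        self._masked_tokens = masked_tokens

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Give PAD/OOV/EOS positions zero weight so they count toward
        # neither the accuracy numerator nor the denominator.
        weights = tf.ones_like(y_true, dtype=tf.float32)
        for token in self._masked_tokens:
            weights *= tf.cast(tf.not_equal(y_true, token), tf.float32)
        super().update_state(y_true, y_pred, sample_weight=weights)
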
Example #17
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))
    tf.compat.v1.enable_v2_behavior()
    tff.framework.set_default_executor(
        tff.framework.local_executor_factory(max_fanout=10))
    if FLAGS.lstm:

        def _layer_fn(x):
            return tf.keras.layers.LSTM(x, return_sequences=True)
    else:

        def _layer_fn(x):
            return tf.keras.layers.GRU(x, return_sequences=True)

    model_builder = functools.partial(models.create_recurrent_model,
                                      vocab_size=FLAGS.vocab_size,
                                      recurrent_layer_fn=_layer_fn,
                                      shared_embedding=FLAGS.shared_embedding)

    loss_builder = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

    pad_token, oov_token, _, eos_token = dataset.get_special_tokens(
        FLAGS.vocab_size)

    def metrics_builder():
        return [
            keras_metrics.FlattenedCategoricalAccuracy(
                # Plus 4 for PAD, OOV, BOS and EOS.
                vocab_size=FLAGS.vocab_size + 4,
                name='accuracy_with_oov',
                masked_tokens=[pad_token]),
            keras_metrics.FlattenedCategoricalAccuracy(
                vocab_size=FLAGS.vocab_size + 4,
                name='accuracy_no_oov',
                masked_tokens=[pad_token, oov_token]),
            # Notice BOS never appears in ground truth.
            keras_metrics.FlattenedCategoricalAccuracy(
                vocab_size=FLAGS.vocab_size + 4,
                name='accuracy_no_oov_or_eos',
                masked_tokens=[pad_token, oov_token, eos_token]),
            keras_metrics.NumBatchesCounter(),
            keras_metrics.FlattenedNumExamplesCounter(name='num_tokens',
                                                      mask_zero=True),
        ]

    (stackoverflow_train, stackoverflow_validation,
     stackoverflow_test) = dataset.construct_word_level_datasets(
         FLAGS.vocab_size, FLAGS.client_batch_size,
         FLAGS.client_epochs_per_round, FLAGS.sequence_length,
         FLAGS.max_elements_per_user, FLAGS.num_validation_examples)

    sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                         next(iter(stackoverflow_validation)))

    def client_weight_fn(local_outputs):
        # num_tokens is an int64[1] tensor; to use it as a weight we need a
        # float32 scalar.
        return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

    training_process = iterative_process_builder.from_flags(
        dummy_batch=sample_batch,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        client_weight_fn=client_weight_fn)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        stackoverflow_train, FLAGS.clients_per_round)

    eval_fn = training_utils.build_evaluate_fn(
        model_builder=model_builder,
        eval_dataset=stackoverflow_validation,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        # Use both val and test for symmetry with other experiments, which
        # evaluate on the entire test set.
        test_dataset=stackoverflow_validation.concatenate(stackoverflow_test))

    logging.info('Training model:')
    logging.info(model_builder().summary())

    training_loop.run(training_process, client_datasets_fn, eval_fn)
Example #18
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  model_builder = functools.partial(
      stackoverflow_models.create_recurrent_model,
      vocab_size=FLAGS.vocab_size,
      embedding_size=FLAGS.embedding_size,
      latent_size=FLAGS.latent_size,
      num_layers=FLAGS.num_layers,
      shared_embedding=FLAGS.shared_embedding)

  loss_builder = functools.partial(
      tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

  pad_token, oov_token, _, eos_token = stackoverflow_dataset.get_special_tokens(
      FLAGS.vocab_size)

  def metrics_builder():
    return [
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_with_oov', masked_tokens=[pad_token]),
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_no_oov', masked_tokens=[pad_token, oov_token]),
        # Notice BOS never appears in ground truth.
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_no_oov_or_eos',
            masked_tokens=[pad_token, oov_token, eos_token]),
        keras_metrics.NumBatchesCounter(),
        keras_metrics.NumTokensCounter(masked_tokens=[pad_token])
    ]

  dataset_vocab = stackoverflow_dataset.create_vocab(FLAGS.vocab_size)

  train_clientdata, _, test_clientdata = (
      tff.simulation.datasets.stackoverflow.load_data())

  # Split the test data into test and validation sets.
  # TODO(b/161914546): consider moving evaluation to use
  # `tff.learning.build_federated_evaluation` to get metrics over client
  # distributions, as well as the example weight means from this centralized
  # evaluation.
  base_test_dataset = test_clientdata.create_tf_dataset_from_all_clients()
  preprocess_val_and_test = stackoverflow_dataset.create_test_dataset_preprocess_fn(
      dataset_vocab, FLAGS.sequence_length)
  test_set = preprocess_val_and_test(
      base_test_dataset.skip(FLAGS.num_validation_examples))
  validation_set = preprocess_val_and_test(
      base_test_dataset.take(FLAGS.num_validation_examples))

  train_dataset_preprocess_comp = stackoverflow_dataset.create_train_dataset_preprocess_fn(
      vocab=stackoverflow_dataset.create_vocab(FLAGS.vocab_size),
      client_batch_size=FLAGS.client_batch_size,
      client_epochs_per_round=FLAGS.client_epochs_per_round,
      max_seq_len=FLAGS.sequence_length,
      max_training_elements_per_user=FLAGS.max_elements_per_user)

  def client_weight_fn(local_outputs):
    # num_tokens is an int64[1] tensor; to use it as a weight we need a
    # float32 scalar.
    return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

  training_process = iterative_process_builder.from_flags(
      input_spec=None,  # Type is pulled from train_dataset_preprocess_comp.
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder,
      client_weight_fn=client_weight_fn,
      dataset_preprocess_comp=train_dataset_preprocess_comp)

  client_datasets_fn = training_utils.build_client_datasets_fn(
      train_dataset=train_clientdata,
      train_clients_per_round=FLAGS.clients_per_round,
      random_seed=FLAGS.client_datasets_random_seed)

  assign_weights_fn = fed_avg_schedule.ServerState.assign_weights_to_keras_model

  evaluate_fn = training_utils.build_evaluate_fn(
      model_builder=model_builder,
      eval_dataset=validation_set,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder,
      assign_weights_to_keras_model=assign_weights_fn)

  test_fn = training_utils.build_evaluate_fn(
      model_builder=model_builder,
      # Use both val and test for symmetry with other experiments, which
      # evaluate on the entire test set.
      eval_dataset=validation_set.concatenate(test_set),
      loss_builder=loss_builder,
      metrics_builder=metrics_builder,
      assign_weights_to_keras_model=assign_weights_fn)

  logging.info('Training model:')
  logging.info(model_builder().summary())

  training_loop.run(
      training_process, client_datasets_fn, evaluate_fn, test_fn=test_fn)