예제 #1
0
def main(argv):
    """Trains a federated logistic tag-prediction model on Stack Overflow.

    Builds the Stack Overflow tag datasets from FLAGS, derives a dummy batch
    from the first training client, builds the iterative training process
    from flags, and runs the federated training loop with validation-set
    evaluation.

    Args:
        argv: Command-line arguments; only the program name is expected.

    Raises:
        app.UsageError: If extra command-line arguments are supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Consistency fix: the sibling entry points in this file call the
    # `tf.compat.v1` alias; `tf.enable_v2_behavior` is the legacy spelling.
    tf.compat.v1.enable_v2_behavior()
    # TODO(b/139129100): Remove this once the local executor is the default.
    tff.framework.set_default_executor(
        tff.framework.local_executor_factory(max_fanout=25))

    stackoverflow_train, stackoverflow_validation, stackoverflow_test = dataset.get_stackoverflow_datasets(
        vocab_tokens_size=FLAGS.vocab_tokens_size,
        vocab_tags_size=FLAGS.vocab_tags_size,
        client_batch_size=FLAGS.client_batch_size,
        client_epochs_per_round=FLAGS.client_epochs_per_round,
        max_training_elements_per_user=FLAGS.max_elements_per_user,
        num_validation_examples=FLAGS.num_validation_examples)

    sample_client_dataset = stackoverflow_train.create_tf_dataset_for_client(
        stackoverflow_train.client_ids[0])
    # TODO(b/144382142): Sample batches cannot be eager tensors, since they are
    # passed (implicitly) to tff.learning.build_federated_averaging_process.
    sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                         next(iter(sample_client_dataset)))

    model_builder = functools.partial(
        models.create_logistic_model,
        vocab_tokens_size=FLAGS.vocab_tokens_size,
        vocab_tags_size=FLAGS.vocab_tags_size)

    # Binary cross-entropy on probabilities (from_logits=False), with SUM
    # reduction rather than the default mean.
    loss_builder = functools.partial(tf.keras.losses.BinaryCrossentropy,
                                     from_logits=False,
                                     reduction=tf.keras.losses.Reduction.SUM)

    # NOTE(review): `metrics_builder` is not defined in this function;
    # presumably a module-level helper — confirm it exists at module scope.
    training_process = iterative_process_builder.from_flags(
        dummy_batch=sample_batch,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        stackoverflow_train, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_evaluate_fn(
        model_builder=model_builder,
        eval_dataset=stackoverflow_validation,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        # Use both val and test for symmetry with other experiments, which
        # evaluate on the entire test set.
        test_dataset=stackoverflow_validation.concatenate(stackoverflow_test))

    logging.info('Training model:')
    logging.info(model_builder().summary())

    training_loop.run(
        iterative_process=training_process,
        client_datasets_fn=client_datasets_fn,
        evaluate_fn=evaluate_fn,
    )
예제 #2
0
def main(argv):
    """Runs flag-driven federated training of an EMNIST classifier."""
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    tf.compat.v1.enable_v2_behavior()
    # TODO(b/139129100): Remove this once the local executor is the default.
    tff.framework.set_default_executor(
        tff.framework.local_executor_factory(max_fanout=25))

    train_data, test_data = dataset.get_emnist_datasets(
        FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round,
        only_digits=False)

    # TODO(b/144382142): Sample batches cannot be eager tensors, since they are
    # passed (implicitly) to tff.learning.build_federated_averaging_process.
    first_client_data = train_data.create_tf_dataset_for_client(
        train_data.client_ids[0])
    dummy_batch = tf.nest.map_structure(
        lambda tensor: tensor.numpy(), next(iter(first_client_data)))

    # Dispatch on the model flag via a lookup table.
    model_constructors = {
        'cnn': models.create_conv_dropout_model,
        '2nn': models.create_two_hidden_layer_model,
    }
    if FLAGS.model not in model_constructors:
        raise ValueError('Cannot handle model flag [{!s}].'.format(
            FLAGS.model))
    model_builder = functools.partial(
        model_constructors[FLAGS.model], only_digits=False)

    loss_builder = tf.keras.losses.SparseCategoricalCrossentropy

    def metrics_builder():
        return [tf.keras.metrics.SparseCategoricalAccuracy()]

    training_process = iterative_process_builder.from_flags(
        dummy_batch=dummy_batch,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        train_data, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_evaluate_fn(
        eval_dataset=test_data,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    logging.info('Training model:')
    logging.info(model_builder().summary())

    training_loop.run(
        iterative_process=training_process,
        client_datasets_fn=client_datasets_fn,
        evaluate_fn=evaluate_fn,
    )
 def test_build_client_datasets_fn(self):
     """Checks the built sampling fn returns the single available client."""
     # One client (id 2) backed by `create_tf_dataset_for_client`, which is
     # defined elsewhere in this module.
     tff_dataset = tff.simulation.client_data.ConcreteClientData(
         [2], create_tf_dataset_for_client)
     # Build a fn that samples 1 client per invocation.
     client_datasets_fn = training_utils.build_client_datasets_fn(
         tff_dataset, 1)
     # The argument is presumably a round number/seed — with only one client
     # available, the sample must contain it regardless. TODO: confirm
     # against training_utils.build_client_datasets_fn.
     client_datasets, client_ids = client_datasets_fn(7)
     sample_batch = next(iter(client_datasets[0]))
     reference_batch = next(iter(create_tf_dataset_for_client(2)))
     # The sampled client's first batch must match the dataset produced
     # directly for client id 2.
     self.assertAllClose(sample_batch, reference_batch)
     self.assertEqual(client_ids, [2])
예제 #4
0
def main(argv):
    """Trains a federated character-level model on the flag-configured data.

    Constructs character-level datasets from FLAGS, locates a client with at
    least one batch to serve as the dummy batch, builds the training process
    from flags, and runs the federated training loop.

    Args:
        argv: Command-line arguments; only the program name is expected.

    Raises:
        app.UsageError: If extra command-line arguments are supplied.
        ValueError: If no training client has any batches of data.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.compat.v1.enable_v2_behavior()
    tff.framework.set_default_executor(
        tff.framework.local_executor_factory(max_fanout=25))

    train_clientdata, test_dataset = dataset.construct_character_level_datasets(
        FLAGS.client_batch_size, FLAGS.client_epochs_per_round,
        FLAGS.sequence_length)
    test_dataset = test_dataset.cache()

    loss_fn_builder = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

    # Need to iterate until we find a client with data.
    for client_id in train_clientdata.client_ids:
        try:
            sample_batch = next(
                iter(train_clientdata.create_tf_dataset_for_client(client_id)))
            break
        except StopIteration:
            pass  # Client had no batches.
    else:
        # Bug fix: previously `sample_batch` was left unbound when every
        # client was empty, raising a confusing NameError below instead of
        # a clear error.
        raise ValueError('No client with at least one batch of data found.')
    sample_batch = tf.nest.map_structure(lambda t: t.numpy(), sample_batch)

    def client_weight_fn(local_outputs):
        # Num_tokens is a tensor with type int64[1], to use as a weight need
        # a float32 scalar.
        return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

    # NOTE(review): `model_builder` and `metrics_builder` are not defined in
    # this function; presumably module-level helpers — confirm.
    training_process = iterative_process_builder.from_flags(
        dummy_batch=sample_batch,
        model_builder=model_builder,
        loss_builder=loss_fn_builder,
        metrics_builder=metrics_builder,
        client_weight_fn=client_weight_fn)

    logging.info('Training model:')
    logging.info(model_builder().summary())

    training_loop.run(
        iterative_process=training_process,
        client_datasets_fn=training_utils.build_client_datasets_fn(
            train_clientdata, FLAGS.clients_per_round),
        evaluate_fn=training_utils.build_evaluate_fn(
            eval_dataset=test_dataset,
            model_builder=model_builder,
            loss_builder=loss_fn_builder,
            metrics_builder=metrics_builder),
    )
예제 #5
0
def main(argv):
  """Runs flag-driven federated training of a ResNet-18 on CIFAR-100."""
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  tf.compat.v1.enable_v2_behavior()
  # TODO(b/139129100): Remove this once the local executor is the default.
  tff.framework.set_default_executor(
      tff.framework.local_executor_factory(max_fanout=25))

  train_data, test_data = dataset.get_federated_cifar100(
      client_epochs_per_round=FLAGS.client_epochs_per_round,
      train_batch_size=FLAGS.client_batch_size,
      crop_shape=CROP_SHAPE)

  # Derive a numpy dummy batch from the first training client's dataset.
  first_client = train_data.client_ids[0]
  first_client_data = train_data.create_tf_dataset_for_client(first_client)
  dummy_batch = tf.nest.map_structure(
      lambda tensor: tensor.numpy(), next(iter(first_client_data)))

  model_builder = functools.partial(
      resnet_models.create_resnet18,
      input_shape=CROP_SHAPE,
      num_classes=NUM_CLASSES)

  logging.info('Training model:')
  logging.info(model_builder().summary())

  loss_builder = tf.keras.losses.SparseCategoricalCrossentropy

  def metrics_builder():
    return [tf.keras.metrics.SparseCategoricalAccuracy()]

  training_process = iterative_process_builder.from_flags(
      dummy_batch=dummy_batch,
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  client_datasets_fn = training_utils.build_client_datasets_fn(
      train_data, FLAGS.clients_per_round)
  evaluate_fn = training_utils.build_evaluate_fn(
      eval_dataset=test_data,
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  training_loop.run(
      iterative_process=training_process,
      client_datasets_fn=client_datasets_fn,
      evaluate_fn=evaluate_fn,
  )
예제 #6
0
def main(_):
    """Trains a federated autoencoder on EMNIST.

    Builds EMNIST train/test datasets from FLAGS, derives a dummy batch from
    the first training client, builds the training process from flags, and
    runs the federated training loop with a mean-squared-error objective.

    Args:
        _: Unused command-line arguments.
    """
    # Consistency fix: the sibling entry points in this file call the
    # `tf.compat.v1` alias; `tf.enable_v2_behavior` is the legacy spelling.
    tf.compat.v1.enable_v2_behavior()
    # TODO(b/139129100): Remove this once the local executor is the default.
    tff.framework.set_default_executor(
        tff.framework.local_executor_factory(max_fanout=25))

    emnist_train, emnist_test = dataset.get_emnist_datasets(
        FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round,
        only_digits=False)

    sample_client_dataset = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[0])
    # TODO(b/144382142): Sample batches cannot be eager tensors, since they are
    # passed (implicitly) to tff.learning.build_federated_averaging_process.
    sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                         next(iter(sample_client_dataset)))

    model_builder = models.create_autoencoder_model

    # Reconstruction objective: MSE as both the loss and the metric.
    loss_builder = tf.keras.losses.MeanSquaredError
    metrics_builder = lambda: [tf.keras.metrics.MeanSquaredError()]

    training_process = iterative_process_builder.from_flags(
        dummy_batch=sample_batch,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        emnist_train, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_evaluate_fn(
        eval_dataset=emnist_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    logging.info('Training model:')
    logging.info(model_builder().summary())

    training_loop.run(
        iterative_process=training_process,
        client_datasets_fn=client_datasets_fn,
        evaluate_fn=evaluate_fn,
    )
예제 #7
0
def main(argv):
    """Trains a federated next-word-prediction RNN on Stack Overflow.

    Builds a recurrent (LSTM or GRU, selected by FLAGS.lstm) word-level
    model, constructs the Stack Overflow word-level datasets from FLAGS,
    and runs the federated training loop with token-count client weighting.

    Args:
        argv: Command-line arguments; only the program name is expected.

    Raises:
        app.UsageError: If extra command-line arguments are supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))
    tf.compat.v1.enable_v2_behavior()
    # NOTE(review): presumably the same workaround as b/139129100 in the
    # sibling entry points — remove once the local executor is the default.
    tff.framework.set_default_executor(
        tff.framework.local_executor_factory(max_fanout=10))
    # Pick the recurrent cell from the flag; both return full sequences
    # (return_sequences=True) so the model emits a prediction per position.
    if FLAGS.lstm:

        def _layer_fn(x):
            return tf.keras.layers.LSTM(x, return_sequences=True)
    else:

        def _layer_fn(x):
            return tf.keras.layers.GRU(x, return_sequences=True)

    model_builder = functools.partial(models.create_recurrent_model,
                                      vocab_size=FLAGS.vocab_size,
                                      recurrent_layer_fn=_layer_fn,
                                      shared_embedding=FLAGS.shared_embedding)

    # from_logits=True: the loss expects raw (unnormalized) model outputs.
    loss_builder = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

    pad_token, oov_token, _, eos_token = dataset.get_special_tokens(
        FLAGS.vocab_size)

    def metrics_builder():
        """Returns accuracy metrics with progressively more masked tokens."""
        return [
            keras_metrics.FlattenedCategoricalAccuracy(
                # Plus 4 for PAD, OOV, BOS and EOS.
                vocab_size=FLAGS.vocab_size + 4,
                name='accuracy_with_oov',
                masked_tokens=pad_token),
            keras_metrics.FlattenedCategoricalAccuracy(
                vocab_size=FLAGS.vocab_size + 4,
                name='accuracy_no_oov',
                masked_tokens=[pad_token, oov_token]),
            # Notice BOS never appears in ground truth.
            keras_metrics.FlattenedCategoricalAccuracy(
                vocab_size=FLAGS.vocab_size + 4,
                name='accuracy_no_oov_or_eos',
                masked_tokens=[pad_token, oov_token, eos_token]),
            keras_metrics.NumBatchesCounter(),
            keras_metrics.FlattenedNumExamplesCounter(name='num_tokens',
                                                      mask_zero=True),
        ]

    (stackoverflow_train, stackoverflow_validation,
     stackoverflow_test) = dataset.construct_word_level_datasets(
         FLAGS.vocab_size, FLAGS.client_batch_size,
         FLAGS.client_epochs_per_round, FLAGS.sequence_length,
         FLAGS.max_elements_per_user, FLAGS.num_validation_examples)

    # NOTE(review): the dummy batch is drawn from the *validation* set here,
    # unlike the sibling entry points which draw from a training client —
    # confirm this is intentional.
    sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                         next(iter(stackoverflow_validation)))

    def client_weight_fn(local_outputs):
        # Num_tokens is a tensor with type int64[1], to use as a weight need
        # a float32 scalar.
        return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

    training_process = iterative_process_builder.from_flags(
        dummy_batch=sample_batch,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        client_weight_fn=client_weight_fn)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        stackoverflow_train, FLAGS.clients_per_round)

    eval_fn = training_utils.build_evaluate_fn(
        model_builder=model_builder,
        eval_dataset=stackoverflow_validation,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        # Use both val and test for symmetry with other experiments, which
        # evaluate on the entire test set.
        test_dataset=stackoverflow_validation.concatenate(stackoverflow_test))

    logging.info('Training model:')
    logging.info(model_builder().summary())

    training_loop.run(training_process, client_datasets_fn, eval_fn)