Example #1
def main(argv):
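    """Estimates the Yogi initial accumulator for federated EMNIST training."""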
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    # TODO(b/139129100): Remove this once the local executor is the default.
    tff.framework.set_default_executor(
        tff.framework.local_executor_factory(max_fanout=25))

    emnist_train, _ = dataset.get_emnist_datasets(only_digits=False)
    emnist_train = emnist_train.preprocess(lambda x: x.batch(20))
    input_spec = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[0]).element_spec
    if FLAGS.model == 'cnn':
        model_builder = functools.partial(models.create_conv_dropout_model,
                                          only_digits=False)
    elif FLAGS.model == '2nn':
        model_builder = functools.partial(models.create_two_hidden_layer_model,
                                          only_digits=False)
    else:
        raise ValueError('Cannot handle model flag [{!s}].'.format(
            FLAGS.model))

    tff_model = tff.learning.from_keras_model(
        keras_model=model_builder(),
        input_spec=input_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy())

    yogi_init_accum_estimate = optimizer_utils.compute_yogi_init(
        emnist_train, tff_model, num_clients=FLAGS.num_clients)
    logging.info('Yogi initializer: %10.6E', yogi_init_accum_estimate)
Example #2
def run_experiment():
    """Data preprocessing and experiment execution."""
    emnist_train, emnist_test = dataset.get_emnist_datasets(
        FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round,
        only_digits=FLAGS.only_digits)

    example_dataset = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[0])
    sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                         next(iter(example_dataset)))

    client_datasets_fn = training_utils.build_client_datasets_fn(
        emnist_train, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_evaluate_fn(
        eval_dataset=emnist_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    client_optimizer_fn = functools.partial(
        utils_impl.create_optimizer_from_flags, 'client')
    server_optimizer_fn = functools.partial(
        utils_impl.create_optimizer_from_flags, 'server')

    def tff_model_fn():
        keras_model = model_builder()
        return tff.learning.from_keras_model(keras_model,
                                             dummy_batch=sample_batch,
                                             loss=loss_builder(),
                                             metrics=metrics_builder())

    if FLAGS.use_compression:
        # We create a `StatefulBroadcastFn` and a `StatefulAggregateFn` by
        # providing `_broadcast_encoder_fn` and `_mean_encoder_fn` to the
        # corresponding utilities. Each fn is called once per model weight
        # created by `tff_model_fn` and returns an appropriate encoder.
        encoded_broadcast_fn = (
            tff.learning.framework.build_encoded_broadcast_from_model(
                tff_model_fn, _broadcast_encoder_fn))
        encoded_mean_fn = tff.learning.framework.build_encoded_mean_from_model(
            tff_model_fn, _mean_encoder_fn)
    else:
        encoded_broadcast_fn = None
        encoded_mean_fn = None

    iterative_process = tff.learning.build_federated_averaging_process(
        model_fn=tff_model_fn,
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn,
        stateful_delta_aggregate_fn=encoded_mean_fn,
        stateful_model_broadcast_fn=encoded_broadcast_fn)
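    # Adapt the raw iterative process for consumption by training_loop.run below.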
    iterative_process = compression_process_adapter.CompressionProcessAdapter(
        iterative_process)

    training_loop.run(iterative_process=iterative_process,
                      client_datasets_fn=client_datasets_fn,
                      evaluate_fn=evaluate_fn)
Example #3
def main(argv):
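    """Runs flag-configured federated training on EMNIST."""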
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    tf.compat.v1.enable_v2_behavior()
    # TODO(b/139129100): Remove this once the local executor is the default.
    tff.framework.set_default_executor(
        tff.framework.local_executor_factory(max_fanout=25))

    emnist_train, emnist_test = dataset.get_emnist_datasets(
        FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round,
        only_digits=False)

    sample_client_dataset = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[0])
    # TODO(b/144382142): Sample batches cannot be eager tensors, since they are
    # passed (implicitly) to tff.learning.build_federated_averaging_process.
    sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                         next(iter(sample_client_dataset)))

    if FLAGS.model == 'cnn':
        model_builder = functools.partial(models.create_conv_dropout_model,
                                          only_digits=False)
    elif FLAGS.model == '2nn':
        model_builder = functools.partial(models.create_two_hidden_layer_model,
                                          only_digits=False)
    else:
        raise ValueError('Cannot handle model flag [{!s}].'.format(
            FLAGS.model))

    loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
    metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

    training_process = iterative_process_builder.from_flags(
        dummy_batch=sample_batch,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        emnist_train, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_evaluate_fn(
        eval_dataset=emnist_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    logging.info('Training model:')
    model_builder().summary(print_fn=logging.info)

    training_loop.run(
        iterative_process=training_process,
        client_datasets_fn=client_datasets_fn,
        evaluate_fn=evaluate_fn,
    )
Example #4
def main(argv):
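    """Runs federated EMNIST training built from command-line flags."""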
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    tf.compat.v1.enable_v2_behavior()

    emnist_train, emnist_test = dataset.get_emnist_datasets(
        FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round,
        only_digits=False)

    input_spec = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[0]).element_spec

    if FLAGS.model == 'cnn':
        model_builder = functools.partial(models.create_conv_dropout_model,
                                          only_digits=False)
    elif FLAGS.model == '2nn':
        model_builder = functools.partial(models.create_two_hidden_layer_model,
                                          only_digits=False)
    else:
        raise ValueError('Cannot handle model flag [{!s}].'.format(
            FLAGS.model))

    loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
    metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

    training_process = iterative_process_builder.from_flags(
        input_spec=input_spec,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        train_dataset=emnist_train,
        train_clients_per_round=FLAGS.clients_per_round,
        random_seed=FLAGS.client_datasets_random_seed)

    assign_weights_fn = fed_avg_schedule.ServerState.assign_weights_to_keras_model

    evaluate_fn = training_utils.build_evaluate_fn(
        eval_dataset=emnist_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        assign_weights_to_keras_model=assign_weights_fn)

    logging.info('Training model:')
    model_builder().summary(print_fn=logging.info)

    training_loop.run(
        iterative_process=training_process,
        client_datasets_fn=client_datasets_fn,
        evaluate_fn=evaluate_fn,
    )
Example #5
  def test_take_with_repeat(self):
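    """Tests that repeated client datasets are capped at max_batches_per_client."""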
    emnist_train, _ = dataset.get_emnist_datasets(
        client_batch_size=10,
        client_epochs_per_round=-1,
        max_batches_per_client=10,
        only_digits=True)
    self.assertEqual(len(emnist_train.client_ids), 3383)
    for i in range(10):
      client_ds = emnist_train.create_tf_dataset_for_client(
          emnist_train.client_ids[i])
      self.assertEqual(_compute_length_of_dataset(client_ds), 10)
Example #6
  def test_emnist_dataset_structure(self):
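    """Tests the element structure and batch shapes of the EMNIST datasets."""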
    emnist_train, emnist_test = dataset.get_emnist_datasets(
        client_batch_size=10, client_epochs_per_round=1, only_digits=True)
    self.assertEqual(len(emnist_train.client_ids), 3383)
    sample_train_ds = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[0])

    train_batch = next(iter(sample_train_ds))
    train_batch_shape = train_batch[0].shape
    test_batch = next(iter(emnist_test))
    test_batch_shape = test_batch[0].shape
    self.assertEqual(train_batch_shape.as_list(), [10, 28, 28, 1])
    self.assertEqual(test_batch_shape.as_list(), [TEST_BATCH_SIZE, 28, 28, 1])
Example #7
def main(argv):
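    """Estimates the Yogi initial accumulator for federated EMNIST training."""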
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    tf.compat.v1.enable_v2_behavior()
    # TODO(b/139129100): Remove this once the local executor is the default.
    tff.framework.set_default_executor(
        tff.framework.local_executor_factory(max_fanout=25))

    emnist_train, _ = dataset.get_emnist_datasets(only_digits=False)
    emnist_train = emnist_train.preprocess(lambda x: x.batch(20))
    sample_client_dataset = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[0])
    # TODO(b/144382142): Sample batches cannot be eager tensors, since they are
    # passed (implicitly) to tff.learning.build_federated_averaging_process.
    sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                         next(iter(sample_client_dataset)))
    if FLAGS.model == 'cnn':
        model_builder = functools.partial(models.create_conv_dropout_model,
                                          only_digits=False)
    elif FLAGS.model == '2nn':
        model_builder = functools.partial(models.create_two_hidden_layer_model,
                                          only_digits=False)
    else:
        raise ValueError('Cannot handle model flag [{!s}].'.format(
            FLAGS.model))

    tff_model = tff.learning.from_keras_model(
        keras_model=model_builder(),
        dummy_batch=sample_batch,
        loss=tf.keras.losses.SparseCategoricalCrossentropy())

    yogi_init_accum_estimate = optimizer_utils.compute_yogi_init(
        emnist_train, tff_model, num_clients=FLAGS.num_clients)
    logging.info('Yogi initializer: %10.6E', yogi_init_accum_estimate)
Example #8
  def test_raises_no_repeat_and_no_take(self):
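    """Tests that indefinite repeat without a batch cap raises ValueError."""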
    with self.assertRaises(ValueError):
      dataset.get_emnist_datasets(
          client_batch_size=10,
          client_epochs_per_round=-1,
          max_batches_per_client=-1)
Example #9
def main(argv):
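  """Runs federated EMNIST training, optionally with differential privacy."""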
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  tf.compat.v1.enable_v2_behavior()

  emnist_train, emnist_test = dataset.get_emnist_datasets(
      FLAGS.client_batch_size, FLAGS.client_epochs_per_round, only_digits=False)

  if FLAGS.model == 'cnn':
    model_builder = functools.partial(
        models.create_conv_dropout_model, only_digits=False)
  elif FLAGS.model == '2nn':
    model_builder = functools.partial(
        models.create_two_hidden_layer_model, only_digits=False)
  else:
    raise ValueError('Cannot handle model flag [{!s}].'.format(FLAGS.model))

  loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
  metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

  if FLAGS.uniform_weighting:

    def client_weight_fn(local_outputs):
      del local_outputs
      return 1.0

  else:
    client_weight_fn = None  # Default: weight clients by their number of examples.

  def model_fn():
    return tff.learning.from_keras_model(
        model_builder(),
        loss_builder(),
        input_spec=emnist_test.element_spec,
        metrics=metrics_builder())

  if FLAGS.noise_multiplier is not None:
    if not FLAGS.uniform_weighting:
      raise ValueError(
          'Differential privacy is only implemented for uniform weighting.')

    dp_query = tff.utils.build_dp_query(
        clip=FLAGS.clip,
        noise_multiplier=FLAGS.noise_multiplier,
        expected_total_weight=FLAGS.clients_per_round,
        adaptive_clip_learning_rate=FLAGS.adaptive_clip_learning_rate,
        target_unclipped_quantile=FLAGS.target_unclipped_quantile,
        clipped_count_budget_allocation=FLAGS.clipped_count_budget_allocation,
        expected_num_clients=FLAGS.clients_per_round,
        per_vector_clipping=FLAGS.per_vector_clipping,
        model=model_fn())

    dp_aggregate_fn, _ = tff.utils.build_dp_aggregate(dp_query)
  else:
    dp_aggregate_fn = None

  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('server')
  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  training_process = tff.learning.build_federated_averaging_process(
      model_fn=model_fn,
      server_optimizer_fn=server_optimizer_fn,
      client_weight_fn=client_weight_fn,
      client_optimizer_fn=client_optimizer_fn,
      stateful_delta_aggregate_fn=dp_aggregate_fn)

  adaptive_clipping = (FLAGS.adaptive_clip_learning_rate > 0)
  training_process = dp_utils.DPFedAvgProcessAdapter(training_process,
                                                     FLAGS.per_vector_clipping,
                                                     adaptive_clipping)

  client_datasets_fn = training_utils.build_client_datasets_fn(
      emnist_train, FLAGS.clients_per_round)

  evaluate_fn = training_utils.build_evaluate_fn(
      eval_dataset=emnist_test,
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder,
      assign_weights_to_keras_model=dp_utils.assign_weights_to_keras_model)

  logging.info('Training model:')
  model_builder().summary(print_fn=logging.info)

  training_loop.run(
      iterative_process=training_process,
      client_datasets_fn=client_datasets_fn,
      evaluate_fn=evaluate_fn,
  )