Example #1
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.compat.v1.enable_v2_behavior()

    experiment_output_dir = FLAGS.root_output_dir
    tensorboard_dir = os.path.join(experiment_output_dir, 'logdir',
                                   FLAGS.experiment_name)
    results_dir = os.path.join(experiment_output_dir, 'results',
                               FLAGS.experiment_name)

    for path in [experiment_output_dir, tensorboard_dir, results_dir]:
        try:
            tf.io.gfile.makedirs(path)
        except tf.errors.OpError:
            pass  # Directory already exists.

    hparam_dict = collections.OrderedDict([(name, FLAGS[name].value)
                                           for name in hparam_flags])
    hparam_dict['results_file'] = results_dir
    hparams_file = os.path.join(results_dir, 'hparams.csv')
    logging.info('Saving hyperparameters to: [%s]', hparams_file)
    utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

    train_client_data, test_client_data = (
        tff.simulation.datasets.shakespeare.load_data())

    def preprocess(ds):
        return dataset.convert_snippets_to_character_sequence_examples(
            ds, FLAGS.batch_size, epochs=1).cache()

    train_dataset = train_client_data.create_tf_dataset_from_all_clients()
    if FLAGS.shuffle_train_data:
        train_dataset = train_dataset.shuffle(buffer_size=10000)
    train_dataset = preprocess(train_dataset)

    eval_dataset = preprocess(
        test_client_data.create_tf_dataset_from_all_clients())

    optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()

    pad_token, _, _, _ = dataset.get_special_tokens()

    # Vocabulary with one OOV ID and zero for the mask.
    vocab_size = len(dataset.CHAR_VOCAB) + 2
    model = models.create_recurrent_model(vocab_size=vocab_size,
                                          batch_size=FLAGS.batch_size)
    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras_metrics.MaskedCategoricalAccuracy(masked_tokens=[pad_token])
        ])

    logging.info('Training model:')
    model.summary(print_fn=logging.info)

    csv_logger_callback = keras_callbacks.AtomicCSVLogger(results_dir)
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=tensorboard_dir)

    # Reduce the learning rate every 20 epochs.
    def decay_lr(epoch, lr):
        if (epoch + 1) % 20 == 0:
            return lr * 0.1
        else:
            return lr

    lr_callback = tf.keras.callbacks.LearningRateScheduler(decay_lr, verbose=1)

    history = model.fit(
        train_dataset,
        validation_data=eval_dataset,
        epochs=FLAGS.num_epochs,
        callbacks=[lr_callback, tensorboard_callback, csv_logger_callback])

    logging.info('Final metrics:')
    for name in ['loss', 'accuracy']:
        metric = history.history['val_{}'.format(name)][-1]
        logging.info('\t%s: %.4f', name, metric)
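
A side note on `utils_impl.atomic_write_to_csv`, used throughout these examples: the name suggests the usual write-to-temp-then-rename pattern. A minimal sketch of that pattern under that assumption (not the actual `utils_impl` implementation):

import os
import tempfile

import pandas as pd


def atomic_write_to_csv_sketch(series: pd.Series, output_path: str) -> None:
    # Write to a temporary file in the target directory, then rename it into
    # place; the rename is atomic on POSIX filesystems, so readers never
    # observe a partially written file.
    with tempfile.NamedTemporaryFile(
            mode='w', dir=os.path.dirname(output_path), delete=False) as tmp:
        series.to_csv(tmp, header=False)
        tmp_name = tmp.name
    os.replace(tmp_name, output_path)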
Example #2
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.enable_v2_behavior()

    experiment_output_dir = FLAGS.root_output_dir
    tensorboard_dir = os.path.join(experiment_output_dir, 'logdir',
                                   FLAGS.experiment_name)
    results_dir = os.path.join(experiment_output_dir, 'results',
                               FLAGS.experiment_name)

    for path in [experiment_output_dir, tensorboard_dir, results_dir]:
        try:
            tf.io.gfile.makedirs(path)
        except tf.errors.OpError:
            pass  # Directory already exists.

    hparam_dict = collections.OrderedDict([(name, FLAGS[name].value)
                                           for name in hparam_flags])
    hparam_dict['results_file'] = results_dir
    hparams_file = os.path.join(results_dir, 'hparams.csv')

    logging.info('Saving hyperparameters to: [%s]', hparams_file)
    utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

    train_dataset, eval_dataset = dataset.get_centralized_stackoverflow_datasets(
        batch_size=FLAGS.batch_size,
        vocab_tokens_size=FLAGS.vocab_tokens_size,
        vocab_tags_size=FLAGS.vocab_tags_size)

    optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()

    model = models.create_logistic_model(
        vocab_tokens_size=FLAGS.vocab_tokens_size,
        vocab_tags_size=FLAGS.vocab_tags_size)

    model.compile(loss=tf.keras.losses.BinaryCrossentropy(
        from_logits=False, reduction=tf.keras.losses.Reduction.SUM),
                  optimizer=optimizer,
                  metrics=[
                      tf.keras.metrics.Precision(),
                      tf.keras.metrics.Recall(top_k=5)
                  ])

    logging.info('Training model:')
    model.summary(print_fn=logging.info)

    csv_logger_callback = keras_callbacks.AtomicCSVLogger(results_dir)
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=tensorboard_dir)

    # Reduce the learning rate after a fixed number of epochs.
    def decay_lr(epoch, learning_rate):
        if (epoch + 1) % FLAGS.decay_epochs == 0:
            return learning_rate * FLAGS.lr_decay
        else:
            return learning_rate

    lr_callback = tf.keras.callbacks.LearningRateScheduler(decay_lr, verbose=1)

    history = model.fit(
        train_dataset,
        validation_data=eval_dataset,
        epochs=FLAGS.num_epochs,
        callbacks=[lr_callback, tensorboard_callback, csv_logger_callback])

    logging.info('Final metrics:')
    for name in ['loss', 'precision', 'recall']:
        metric = history.history['val_{}'.format(name)][-1]
        logging.info('\t%s: %.4f', name, metric)
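
For reference, `tf.keras.metrics.Recall(top_k=5)` here counts a true tag as recalled only if its predicted score ranks among the top 5 for that example. A minimal numpy illustration of recall@k (ignoring the metric's score threshold):

import numpy as np


def recall_at_k_sketch(y_true, y_score, k=5):
    # y_true: [batch, num_tags] binary indicators; y_score: same shape.
    top_k = np.argsort(-y_score, axis=-1)[:, :k]
    hits = sum(row[idx].sum() for row, idx in zip(y_true, top_k))
    return hits / max(y_true.sum(), 1)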
Example #3
def run_experiment():
  """Runs the training experiment."""
  _, validation_dataset, test_dataset = dataset.construct_word_level_datasets(
      FLAGS.vocab_size, FLAGS.batch_size, 1, FLAGS.sequence_length, -1,
      FLAGS.num_validation_examples)
  train_dataset = dataset.get_centralized_train_dataset(
      FLAGS.vocab_size, FLAGS.batch_size, FLAGS.sequence_length,
      FLAGS.shuffle_buffer_size)

  model = models.create_recurrent_model(
      vocab_size=FLAGS.vocab_size,
      name='stackoverflow-lstm',
      embedding_size=FLAGS.embedding_size,
      latent_size=FLAGS.latent_size,
      num_layers=FLAGS.num_layers,
      shared_embedding=FLAGS.shared_embedding)

  logging.info('Training model:')
  model.summary(print_fn=logging.info)
  optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()
  pad_token, oov_token, _, eos_token = dataset.get_special_tokens(
      FLAGS.vocab_size)
  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=optimizer,
      metrics=[
          # Plus 4 for pad, oov, bos, eos
          keras_metrics.MaskedCategoricalAccuracy(
              name='accuracy_with_oov', masked_tokens=[pad_token]),
          keras_metrics.MaskedCategoricalAccuracy(
              name='accuracy_no_oov', masked_tokens=[pad_token, oov_token]),
          keras_metrics.MaskedCategoricalAccuracy(
              name='accuracy_no_oov_or_eos',
              masked_tokens=[pad_token, oov_token, eos_token]),
      ])

  train_results_path = os.path.join(FLAGS.root_output_dir, 'train_results',
                                    FLAGS.experiment_name)
  test_results_path = os.path.join(FLAGS.root_output_dir, 'test_results',
                                   FLAGS.experiment_name)

  train_csv_logger = keras_callbacks.AtomicCSVLogger(train_results_path)
  test_csv_logger = keras_callbacks.AtomicCSVLogger(test_results_path)

  log_dir = os.path.join(FLAGS.root_output_dir, 'logdir', FLAGS.experiment_name)
  try:
    tf.io.gfile.makedirs(log_dir)
    tf.io.gfile.makedirs(train_results_path)
    tf.io.gfile.makedirs(test_results_path)
  except tf.errors.OpError:
    pass  # Directories already exist.

  train_tensorboard_callback = tf.keras.callbacks.TensorBoard(
      log_dir=log_dir,
      write_graph=True,
      update_freq=FLAGS.tensorboard_update_frequency)

  test_tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)

  # Write the hyperparameters to a CSV:
  hparam_dict = collections.OrderedDict([
      (name, FLAGS[name].value) for name in hparam_flags
  ])
  hparams_file = os.path.join(FLAGS.root_output_dir, FLAGS.experiment_name,
                              'hparams.csv')
  utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

  model.fit(
      train_dataset,
      epochs=FLAGS.epochs,
      verbose=0,
      validation_data=validation_dataset,
      callbacks=[train_csv_logger, train_tensorboard_callback])
  score = model.evaluate(
      test_dataset,
      verbose=0,
      callbacks=[test_csv_logger, test_tensorboard_callback])
  logging.info('Final test loss: %.4f', score[0])
  logging.info('Final test accuracy: %.4f', score[1])
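
All three `MaskedCategoricalAccuracy` metrics above share one mechanism: positions whose ground-truth token appears in `masked_tokens` are dropped from both the numerator and the denominator of the accuracy. A minimal sketch in plain TensorFlow (an illustration, not the `keras_metrics` implementation):

import tensorflow as tf


def masked_accuracy_sketch(y_true, logits, masked_tokens):
    # y_true: [batch, time] integer labels; logits: [batch, time, vocab].
    predictions = tf.cast(tf.argmax(logits, axis=-1), y_true.dtype)
    correct = tf.cast(tf.equal(predictions, y_true), tf.float32)
    # Keep only positions whose true label is not a masked special token.
    keep = tf.ones_like(y_true, dtype=tf.bool)
    for token in masked_tokens:
        keep = tf.logical_and(keep, tf.not_equal(y_true, token))
    keep = tf.cast(keep, tf.float32)
    return tf.reduce_sum(correct * keep) / tf.reduce_sum(keep)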
Example #4
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))
    tff.backends.native.set_local_execution_context(max_fanout=10)

    model_builder = functools.partial(
        stackoverflow_models.create_recurrent_model,
        vocab_size=FLAGS.vocab_size,
        embedding_size=FLAGS.embedding_size,
        latent_size=FLAGS.latent_size,
        num_layers=FLAGS.num_layers,
        shared_embedding=FLAGS.shared_embedding)

    loss_builder = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

    special_tokens = stackoverflow_dataset.get_special_tokens(FLAGS.vocab_size)
    pad_token = special_tokens.pad
    oov_tokens = special_tokens.oov
    eos_token = special_tokens.eos

    def metrics_builder():
        return [
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy_with_oov',
                                                    masked_tokens=[pad_token]),
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy_no_oov',
                                                    masked_tokens=[pad_token] +
                                                    oov_tokens),
            # Notice BOS never appears in ground truth.
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_no_oov_or_eos',
                masked_tokens=[pad_token, eos_token] + oov_tokens),
            keras_metrics.NumBatchesCounter(),
            keras_metrics.NumTokensCounter(masked_tokens=[pad_token]),
        ]

    datasets = stackoverflow_dataset.construct_word_level_datasets(
        FLAGS.vocab_size, FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round, FLAGS.sequence_length,
        FLAGS.max_elements_per_user, FLAGS.num_validation_examples)
    train_dataset, validation_dataset, test_dataset = datasets

    if FLAGS.uniform_weighting:

        def client_weight_fn(local_outputs):
            del local_outputs
            return 1.0
    else:

        def client_weight_fn(local_outputs):
            return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

    def model_fn():
        return tff.learning.from_keras_model(
            model_builder(),
            loss_builder(),
            input_spec=validation_dataset.element_spec,
            metrics=metrics_builder())

    if FLAGS.noise_multiplier is not None:
        if not FLAGS.uniform_weighting:
            raise ValueError(
                'Differential privacy is only implemented for uniform weighting.'
            )

        dp_query = tff.utils.build_dp_query(
            clip=FLAGS.clip,
            noise_multiplier=FLAGS.noise_multiplier,
            expected_total_weight=FLAGS.clients_per_round,
            adaptive_clip_learning_rate=FLAGS.adaptive_clip_learning_rate,
            target_unclipped_quantile=FLAGS.target_unclipped_quantile,
            clipped_count_budget_allocation=FLAGS.clipped_count_budget_allocation,
            expected_num_clients=FLAGS.clients_per_round,
            per_vector_clipping=FLAGS.per_vector_clipping,
            model=model_fn())

        weights_type = tff.learning.framework.weights_type_from_model(model_fn)
        aggregation_process = tff.utils.build_dp_aggregate_process(
            weights_type.trainable, dp_query)
    else:
        aggregation_process = None

    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')

    iterative_process = tff.learning.build_federated_averaging_process(
        model_fn=model_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weight_fn=client_weight_fn,
        client_optimizer_fn=client_optimizer_fn,
        aggregation_process=aggregation_process)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        train_dataset, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_evaluate_fn(
        model_builder=model_builder,
        eval_dataset=validation_dataset,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        assign_weights_to_keras_model=dp_utils.assign_weights_to_keras_model)

    test_fn = training_utils.build_evaluate_fn(
        model_builder=model_builder,
        # Use both val and test for symmetry with other experiments, which
        # evaluate on the entire test set.
        eval_dataset=validation_dataset.concatenate(test_dataset),
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        assign_weights_to_keras_model=dp_utils.assign_weights_to_keras_model)

    logging.info('Training model:')
    model_builder().summary(print_fn=logging.info)

    training_loop.run(iterative_process,
                      client_datasets_fn,
                      evaluate_fn,
                      test_fn=test_fn)
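
In the simplest non-adaptive, single-vector case, the `build_dp_query` configuration above amounts to clipping each client delta to L2 norm `clip` and adding Gaussian noise scaled by `noise_multiplier` before averaging. A rough numpy sketch under those simplifying assumptions (not the TFF implementation):

import numpy as np


def dp_average_sketch(client_updates, clip, noise_multiplier, rng=None):
    # client_updates: list of flattened model deltas (1-D numpy arrays).
    rng = rng or np.random.default_rng()
    clipped = [
        u * min(1.0, clip / max(np.linalg.norm(u), 1e-12))
        for u in client_updates
    ]
    total = np.sum(clipped, axis=0)
    # Noise stddev is calibrated to the clip norm; with uniform weighting the
    # noised sum is divided by the (expected) number of clients per round.
    noise = rng.normal(scale=noise_multiplier * clip, size=total.shape)
    return (total + noise) / len(client_updates)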
Example #5
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    experiment_output_dir = FLAGS.root_output_dir
    tensorboard_dir = os.path.join(experiment_output_dir, 'logdir',
                                   FLAGS.experiment_name)
    results_dir = os.path.join(experiment_output_dir, 'results',
                               FLAGS.experiment_name)

    for path in [experiment_output_dir, tensorboard_dir, results_dir]:
        try:
            tf.io.gfile.makedirs(path)
        except tf.errors.OpError:
            pass  # Directory already exists.

    hparam_dict = collections.OrderedDict([(name, FLAGS[name].value)
                                           for name in hparam_flags])
    hparam_dict['results_file'] = results_dir
    hparams_file = os.path.join(results_dir, 'hparams.csv')

    logging.info('Saving hyperparameters to: [%s]', hparams_file)
    utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

    train_dataset, eval_dataset = emnist_dataset.get_centralized_emnist_datasets(
        batch_size=FLAGS.batch_size, only_digits=False)

    optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()

    if FLAGS.model == 'cnn':
        model = emnist_models.create_conv_dropout_model(only_digits=False)
    elif FLAGS.model == '2nn':
        model = emnist_models.create_two_hidden_layer_model(only_digits=False)
    else:
        raise ValueError('Cannot handle model flag [{!s}].'.format(
            FLAGS.model))

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                  optimizer=optimizer,
                  metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

    logging.info('Training model:')
    model.summary(print_fn=logging.info)

    csv_logger_callback = keras_callbacks.AtomicCSVLogger(results_dir)
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=tensorboard_dir)

    # Reduce the learning rate after a fixed number of epochs.
    def decay_lr(epoch, learning_rate):
        if (epoch + 1) % FLAGS.decay_epochs == 0:
            return learning_rate * FLAGS.lr_decay
        else:
            return learning_rate

    lr_callback = tf.keras.callbacks.LearningRateScheduler(decay_lr, verbose=1)

    history = model.fit(
        train_dataset,
        validation_data=eval_dataset,
        epochs=FLAGS.num_epochs,
        callbacks=[lr_callback, tensorboard_callback, csv_logger_callback])

    logging.info('Final metrics:')
    for name in ['loss', 'sparse_categorical_accuracy']:
        metric = history.history['val_{}'.format(name)][-1]
        logging.info('\t%s: %.4f', name, metric)
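
The `decay_lr` staircase above (Examples #2 and #11 use the same rule) is easy to sanity-check in isolation, here with hypothetical values decay_epochs=20 and lr_decay=0.1:

def decay_lr_sketch(epoch, lr, decay_epochs=20, lr_decay=0.1):
    # Same rule as decay_lr above, with the flags replaced by keyword args.
    return lr * lr_decay if (epoch + 1) % decay_epochs == 0 else lr


lr = 0.1
for epoch in range(60):
    lr = decay_lr_sketch(epoch, lr)
print(lr)  # ~1e-4: the decay fires at epoch indices 19, 39 and 59.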
Example #6
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('server')

  client_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=FLAGS.client_learning_rate,
      decay_factor=FLAGS.client_decay_factor,
      min_delta=FLAGS.min_delta,
      min_lr=FLAGS.min_lr,
      window_size=FLAGS.window_size,
      patience=FLAGS.patience)

  server_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=FLAGS.server_learning_rate,
      decay_factor=FLAGS.server_decay_factor,
      min_delta=FLAGS.min_delta,
      min_lr=FLAGS.min_lr,
      window_size=FLAGS.window_size,
      patience=FLAGS.patience)

  def iterative_process_builder(
      model_fn: Callable[[], tff.learning.Model],
      client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
  ) -> tff.templates.IterativeProcess:
    """Creates an iterative process using a given TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor providing the weight
        in the federated average of model deltas. If not provided, the default
        is the total number of examples processed on device.

    Returns:
      A `tff.templates.IterativeProcess`.
    """

    return adaptive_fed_avg.build_fed_avg_process(
        model_fn,
        client_lr_callback,
        server_lr_callback,
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weight_fn=client_weight_fn)

  hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())

  shared_args = utils_impl.lookup_flag_values(shared_flags)
  shared_args['iterative_process_builder'] = iterative_process_builder

  if FLAGS.task == 'cifar100':
    hparam_dict['cifar100_crop_size'] = FLAGS.cifar100_crop_size
    federated_cifar100.run_federated(
        **shared_args,
        crop_size=FLAGS.cifar100_crop_size,
        hparam_dict=hparam_dict)

  elif FLAGS.task == 'emnist_cr':
    federated_emnist.run_federated(
        **shared_args, model=FLAGS.emnist_cr_model, hparam_dict=hparam_dict)

  elif FLAGS.task == 'emnist_ae':
    federated_emnist_ae.run_federated(**shared_args, hparam_dict=hparam_dict)

  elif FLAGS.task == 'shakespeare':
    federated_shakespeare.run_federated(
        **shared_args,
        sequence_length=FLAGS.shakespeare_sequence_length,
        hparam_dict=hparam_dict)

  elif FLAGS.task == 'stackoverflow_nwp':
    so_nwp_flags = collections.OrderedDict()
    for flag_name in task_flags:
      if flag_name.startswith('so_nwp_'):
        so_nwp_flags[flag_name[7:]] = FLAGS[flag_name].value
    federated_stackoverflow.run_federated(
        **shared_args, **so_nwp_flags, hparam_dict=hparam_dict)

  elif FLAGS.task == 'stackoverflow_lr':
    so_lr_flags = collections.OrderedDict()
    for flag_name in task_flags:
      if flag_name.startswith('so_lr_'):
        so_lr_flags[flag_name[6:]] = FLAGS[flag_name].value
    federated_stackoverflow_lr.run_federated(
        **shared_args, **so_lr_flags, hparam_dict=hparam_dict)
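
The reduce-on-plateau callbacks configured at the top of this example drive the adaptive learning rates in `adaptive_fed_avg`. As a rough illustration of the rule those flags parameterize (an assumed shape, not the actual `callbacks` implementation):

def reduce_lr_on_plateau_sketch(loss_stream, learning_rate, decay_factor,
                                min_delta, min_lr, window_size, patience):
    """Yields the learning rate to use after each observed round loss."""
    window, best, wait = [], float('inf'), 0
    for loss in loss_stream:
        # Smooth over a trailing window before testing for improvement.
        window = (window + [loss])[-window_size:]
        smoothed = sum(window) / len(window)
        if smoothed < best - min_delta:
            best, wait = smoothed, 0
        else:
            wait += 1
            if wait >= patience:
                learning_rate = max(learning_rate * decay_factor, min_lr)
                wait = 0
        yield learning_rate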
Example #7
def test_create_optimizer_fn_from_flags_invalid_optimizer(self):
    FLAGS['{}_optimizer'.format(TEST_CLIENT_FLAG_PREFIX)].value = 'foo'
    with self.assertRaisesRegex(ValueError, 'not a valid optimizer'):
        optimizer_utils.create_optimizer_fn_from_flags(TEST_CLIENT_FLAG_PREFIX)
Example #8
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    emnist_train, emnist_test = emnist_dataset.get_emnist_datasets(
        FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round,
        only_digits=False)

    if FLAGS.model == 'cnn':
        model_builder = functools.partial(
            emnist_models.create_conv_dropout_model, only_digits=False)
    elif FLAGS.model == '2nn':
        model_builder = functools.partial(
            emnist_models.create_two_hidden_layer_model, only_digits=False)
    else:
        raise ValueError('Cannot handle model flag [{!s}].'.format(
            FLAGS.model))

    loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
    metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

    if FLAGS.uniform_weighting:

        def client_weight_fn(local_outputs):
            del local_outputs
            return 1.0

    else:
        client_weight_fn = None  # Defaults to the number of examples per client.

    def model_fn():
        return tff.learning.from_keras_model(
            model_builder(),
            loss_builder(),
            input_spec=emnist_test.element_spec,
            metrics=metrics_builder())

    if FLAGS.noise_multiplier is not None:
        if not FLAGS.uniform_weighting:
            raise ValueError(
                'Differential privacy is only implemented for uniform weighting.'
            )

        dp_query = tff.utils.build_dp_query(
            clip=FLAGS.clip,
            noise_multiplier=FLAGS.noise_multiplier,
            expected_total_weight=FLAGS.clients_per_round,
            adaptive_clip_learning_rate=FLAGS.adaptive_clip_learning_rate,
            target_unclipped_quantile=FLAGS.target_unclipped_quantile,
            clipped_count_budget_allocation=FLAGS.clipped_count_budget_allocation,
            expected_num_clients=FLAGS.clients_per_round,
            per_vector_clipping=FLAGS.per_vector_clipping,
            model=model_fn())

        dp_aggregate_fn, _ = tff.utils.build_dp_aggregate(dp_query)
    else:
        dp_aggregate_fn = None

    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    training_process = (
        tff.learning.federated_averaging.build_federated_averaging_process(
            model_fn=model_fn,
            server_optimizer_fn=server_optimizer_fn,
            client_weight_fn=client_weight_fn,
            client_optimizer_fn=client_optimizer_fn,
            stateful_delta_aggregate_fn=dp_aggregate_fn))

    adaptive_clipping = (FLAGS.adaptive_clip_learning_rate > 0)
    training_process = dp_utils.DPFedAvgProcessAdapter(
        training_process, FLAGS.per_vector_clipping, adaptive_clipping)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        emnist_train, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_evaluate_fn(
        eval_dataset=emnist_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder,
        assign_weights_to_keras_model=dp_utils.assign_weights_to_keras_model)

    logging.info('Training model:')
    model_builder().summary(print_fn=logging.info)

    training_loop.run(
        iterative_process=training_process,
        client_datasets_fn=client_datasets_fn,
        evaluate_fn=evaluate_fn,
    )
Example #9
def from_flags(
    input_spec,
    model_builder: ModelBuilder,
    loss_builder: LossBuilder,
    metrics_builder: MetricsBuilder,
    client_weight_fn: Optional[ClientWeightFn] = None,
    *,
    dataset_preprocess_comp: Optional[tff.Computation] = None,
) -> fed_avg_schedule.FederatedAveragingProcessAdapter:
  """Builds a `tff.templates.IterativeProcess` instance from flags.

  The iterative process is designed to incorporate learning rate schedules,
  which are configured via flags.

  Args:
    input_spec: A value convertible to a `tff.Type`, representing the data which
      will be fed into the `tff.templates.IterativeProcess.next` function over
      the course of training. Generally, this can be found by accessing the
      `element_spec` attribute of a client `tf.data.Dataset`.
    model_builder: A no-arg function that returns an uncompiled `tf.keras.Model`
      object.
    loss_builder: A no-arg function returning a `tf.keras.losses.Loss` object.
    metrics_builder: A no-arg function that returns a list of
      `tf.keras.metrics.Metric` objects.
    client_weight_fn: An optional callable that takes the result of
      `tff.learning.Model.report_local_outputs` from the model returned by
      `model_builder`, and returns a scalar client weight. If `None`, defaults
      to the number of examples processed over all batches.
    dataset_preprocess_comp: Optional `tff.Computation` that sets up a data
      pipeline on the clients. The computation must take a sequence of values
      and return a sequence of values, or in TFF type shorthand `(U* -> V*)`. If
      `None`, no dataset preprocessing is applied. If specified, `input_spec` is
      optional, as the necessary type signatures will be taken from the
      computation.

  Returns:
    A `fed_avg_schedule.FederatedAveragingProcessAdapter`.
  """
  # TODO(b/147808007): Assert that model_builder() returns an uncompiled keras
  # model.
  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('server')

  client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('client')
  server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('server')

  if dataset_preprocess_comp is not None:
    if input_spec is not None:
      print('Specified both `dataset_preprocess_comp` and `input_spec` when '
            'only one is necessary. Ignoring `input_spec` and using type '
            'signature of `dataset_preprocess_comp`.')
    model_input_spec = dataset_preprocess_comp.type_signature.result.element
  else:
    model_input_spec = input_spec

  def tff_model_fn() -> tff.learning.Model:
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        input_spec=model_input_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

  return fed_avg_schedule.build_fed_avg_process(
      model_fn=tff_model_fn,
      client_optimizer_fn=client_optimizer_fn,
      client_lr=client_lr_schedule,
      server_optimizer_fn=server_optimizer_fn,
      server_lr=server_lr_schedule,
      client_weight_fn=client_weight_fn,
      dataset_preprocess_comp=dataset_preprocess_comp)
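
A hypothetical call site for `from_flags`, with builders along the lines of the other examples on this page (the dataset and model names are illustrative, not part of the API):

import functools

import tensorflow as tf

iterative_process = from_flags(
    input_spec=train_dataset.element_spec,  # spec of one client tf.data.Dataset
    model_builder=functools.partial(
        emnist_models.create_conv_dropout_model, only_digits=False),
    loss_builder=tf.keras.losses.SparseCategoricalCrossentropy,
    metrics_builder=lambda: [tf.keras.metrics.SparseCategoricalAccuracy()])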
Example #10
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('server')

  client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('client')
  server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('server')

  def iterative_process_builder(
      model_fn: Callable[[], tff.learning.Model],
      client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
  ) -> tff.templates.IterativeProcess:
    """Creates an iterative process using a given TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor providing the weight
        in the federated average of model deltas. If not provided, the default
        is the total number of examples processed on device.

    Returns:
      A `tff.templates.IterativeProcess`.
    """

    return fed_avg_schedule.build_fed_avg_process(
        model_fn=model_fn,
        client_optimizer_fn=client_optimizer_fn,
        client_lr=client_lr_schedule,
        server_optimizer_fn=server_optimizer_fn,
        server_lr=server_lr_schedule,
        client_weight_fn=client_weight_fn)

  assign_weights_fn = fed_avg_schedule.ServerState.assign_weights_to_keras_model
  hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())

  shared_args = utils_impl.lookup_flag_values(shared_flags)
  shared_args['iterative_process_builder'] = iterative_process_builder
  shared_args['assign_weights_fn'] = assign_weights_fn

  if FLAGS.task == 'cifar100':
    hparam_dict['cifar100_crop_size'] = FLAGS.cifar100_crop_size
    federated_cifar100.run_federated(
        **shared_args,
        crop_size=FLAGS.cifar100_crop_size,
        hparam_dict=hparam_dict)

  elif FLAGS.task == 'emnist_cr':
    federated_emnist.run_federated(
        **shared_args,
        emnist_model=FLAGS.emnist_cr_model,
        hparam_dict=hparam_dict)

  elif FLAGS.task == 'emnist_ae':
    federated_emnist_ae.run_federated(**shared_args, hparam_dict=hparam_dict)

  elif FLAGS.task == 'shakespeare':
    federated_shakespeare.run_federated(
        **shared_args,
        sequence_length=FLAGS.shakespeare_sequence_length,
        hparam_dict=hparam_dict)

  elif FLAGS.task == 'stackoverflow_nwp':
    so_nwp_flags = collections.OrderedDict()
    for flag_name in task_flags:
      if flag_name.startswith('so_nwp_'):
        so_nwp_flags[flag_name[7:]] = FLAGS[flag_name].value
    federated_stackoverflow.run_federated(
        **shared_args, **so_nwp_flags, hparam_dict=hparam_dict)

  elif FLAGS.task == 'stackoverflow_lr':
    so_lr_flags = collections.OrderedDict()
    for flag_name in task_flags:
      if flag_name.startswith('so_lr_'):
        so_lr_flags[flag_name[6:]] = FLAGS[flag_name].value
    federated_stackoverflow_lr.run_federated(
        **shared_args, **so_lr_flags, hparam_dict=hparam_dict)

  else:
    raise ValueError(
        '--task flag {} is not supported, must be one of {}.'.format(
            FLAGS.task, _SUPPORTED_TASKS))
Example #11
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.compat.v1.enable_v2_behavior()

    experiment_output_dir = FLAGS.root_output_dir
    tensorboard_dir = os.path.join(experiment_output_dir, 'logdir',
                                   FLAGS.experiment_name)
    results_dir = os.path.join(experiment_output_dir, 'results',
                               FLAGS.experiment_name)

    for path in [experiment_output_dir, tensorboard_dir, results_dir]:
        try:
            tf.io.gfile.makedirs(path)
        except tf.errors.OpError:
            pass  # Directory already exists.

    hparam_dict = collections.OrderedDict([(name, FLAGS[name].value)
                                           for name in hparam_flags])
    hparam_dict['results_file'] = results_dir
    hparams_file = os.path.join(results_dir, 'hparams.csv')

    logging.info('Saving hyperparameters to: [%s]', hparams_file)
    utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

    cifar_train, cifar_test = dataset.get_centralized_cifar100(
        train_batch_size=FLAGS.batch_size, crop_shape=CROP_SHAPE)

    optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()
    model = resnet_models.create_resnet18(input_shape=CROP_SHAPE,
                                          num_classes=NUM_CLASSES)
    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                  optimizer=optimizer,
                  metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

    logging.info('Training model:')
    model.summary(print_fn=logging.info)

    csv_logger_callback = keras_callbacks.AtomicCSVLogger(results_dir)
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=tensorboard_dir)

    # Reduce the learning rate after a fixed number of epochs.
    def decay_lr(epoch, learning_rate):
        if (epoch + 1) % FLAGS.decay_epochs == 0:
            return learning_rate * FLAGS.lr_decay
        else:
            return learning_rate

    lr_callback = tf.keras.callbacks.LearningRateScheduler(decay_lr, verbose=1)

    history = model.fit(
        cifar_train,
        validation_data=cifar_test,
        epochs=FLAGS.num_epochs,
        callbacks=[lr_callback, tensorboard_callback, csv_logger_callback])

    logging.info('Final metrics:')
    for name in ['loss', 'sparse_categorical_accuracy']:
        metric = history.history['val_{}'.format(name)][-1]
        logging.info('\t%s: %.4f', name, metric)
Example #12
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    client_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.client_learning_rate,
        decay_factor=FLAGS.client_decay_factor,
        min_delta=FLAGS.min_delta,
        min_lr=FLAGS.min_lr,
        window_size=FLAGS.window_size,
        patience=FLAGS.patience)

    server_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.server_learning_rate,
        decay_factor=FLAGS.server_decay_factor,
        min_delta=FLAGS.min_delta,
        min_lr=FLAGS.min_lr,
        window_size=FLAGS.window_size,
        patience=FLAGS.patience)

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
        dataset_preprocess_comp: Optional[tff.Computation] = None,
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor providing the weight
        in the federated average of model deltas. If not provided, the default
        is the total number of examples processed on device.
      dataset_preprocess_comp: Optional `tff.Computation` that sets up a data
        pipeline on the clients. The computation must take a squence of values
        and return a sequence of values, or in TFF type shorthand `(U* -> V*)`.
        If `None`, no dataset preprocessing is applied.

    Returns:
      A `tff.templates.IterativeProcess`.
    """

        return adaptive_fed_avg.build_fed_avg_process(
            model_fn,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=client_optimizer_fn,
            server_optimizer_fn=server_optimizer_fn,
            client_weight_fn=client_weight_fn,
            dataset_preprocess_comp=dataset_preprocess_comp)

    assign_weights_fn = adaptive_fed_avg.ServerState.assign_weights_to_keras_model

    common_args = collections.OrderedDict([
        ('iterative_process_builder', iterative_process_builder),
        ('assign_weights_fn', assign_weights_fn),
        ('client_epochs_per_round', FLAGS.client_epochs_per_round),
        ('client_batch_size', FLAGS.client_batch_size),
        ('clients_per_round', FLAGS.clients_per_round),
        ('max_batches_per_client', FLAGS.max_batches_per_client),
        ('client_datasets_random_seed', FLAGS.client_datasets_random_seed)
    ])

    if FLAGS.task == 'cifar100':
        federated_cifar100.run_federated(**common_args,
                                         crop_size=FLAGS.cifar100_crop_size)

    elif FLAGS.task == 'emnist_cr':
        federated_emnist.run_federated(**common_args,
                                       emnist_model=FLAGS.emnist_cr_model)

    elif FLAGS.task == 'emnist_ae':
        federated_emnist_ae.run_federated(**common_args)

    elif FLAGS.task == 'shakespeare':
        federated_shakespeare.run_federated(
            **common_args, sequence_length=FLAGS.shakespeare_sequence_length)

    elif FLAGS.task == 'stackoverflow_nwp':
        so_nwp_flags = collections.OrderedDict()
        for flag_name in FLAGS:
            if flag_name.startswith('so_nwp_'):
                so_nwp_flags[flag_name[7:]] = FLAGS[flag_name].value
        federated_stackoverflow.run_federated(**common_args, **so_nwp_flags)

    elif FLAGS.task == 'stackoverflow_lr':
        so_lr_flags = collections.OrderedDict()
        for flag_name in FLAGS:
            if flag_name.startswith('so_lr_'):
                so_lr_flags[flag_name[6:]] = FLAGS[flag_name].value
        federated_stackoverflow_lr.run_federated(**common_args, **so_lr_flags)

    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))