Example #1
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    client_optimizer_fn = lambda: utils_impl.create_optimizer_from_flags(
        'client')
    server_optimizer_fn = lambda: utils_impl.create_optimizer_from_flags(
        'server')

    compression_dict = utils_impl.lookup_flag_values(compression_flags)
    dp_dict = utils_impl.lookup_flag_values(dp_flags)

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`."""

        model_trainable_variables = model_fn().trainable_variables

        # Most logic for deciding what to run is here.
        aggregation_factory = fl_utils.build_aggregator(
            compression_flags=compression_dict,
            dp_flags=dp_dict,
            num_clients=get_total_num_clients(FLAGS.task),
            num_clients_per_round=FLAGS.clients_per_round,
            num_rounds=FLAGS.total_rounds,
            client_template=model_trainable_variables)

        return tff.learning.build_federated_averaging_process(
            model_fn=model_fn,
            server_optimizer_fn=server_optimizer_fn,
            client_weighting=tff.learning.ClientWeighting.UNIFORM,
            client_optimizer_fn=client_optimizer_fn,
            model_update_aggregation_factory=aggregation_factory)

    task_spec = training_specs.TaskSpec(
        iterative_process_builder=iterative_process_builder,
        client_epochs_per_round=FLAGS.client_epochs_per_round,
        client_batch_size=FLAGS.client_batch_size,
        clients_per_round=FLAGS.clients_per_round,
        client_datasets_random_seed=FLAGS.client_datasets_random_seed)

    if FLAGS.task == 'stackoverflow_lr':
        runner_spec = federated_stackoverflow_lr.configure_training(task_spec)
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))

    training_loop.run(iterative_process=runner_spec.iterative_process,
                      client_datasets_fn=runner_spec.client_datasets_fn,
                      validation_fn=runner_spec.validation_fn,
                      test_fn=runner_spec.test_fn,
                      total_rounds=FLAGS.total_rounds,
                      experiment_name=FLAGS.experiment_name,
                      root_output_dir=FLAGS.root_output_dir,
                      rounds_per_eval=FLAGS.rounds_per_eval,
                      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint)
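All of these examples revolve around `utils_impl.lookup_flag_values`. The tests reproduced in Examples #4, #9 and #16 below pin down its contract: given an iterable of flag names, it returns a `collections.OrderedDict` mapping each name to the flag's current value, rejecting non-string names and undefined flags. A minimal stand-in consistent with those tests (not the real implementation) looks like this:

import collections

from absl import flags

def lookup_flag_values_sketch(flag_names):
    """Maps each flag name to its current value, per the tests below."""
    flag_odict = collections.OrderedDict()
    for name in flag_names:
        if not isinstance(name, str):
            raise ValueError('All flag names must be strings.')
        if name not in flags.FLAGS:
            raise ValueError('"{}" is not a defined flag.'.format(name))
        flag_odict[name] = flags.FLAGS[name].value
    return flag_odict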
Example #2
def _write_hparam_flags():
  """Returns an ordered dictionary of pertinent hyperparameter flags."""
  hparam_dict = utils_impl.lookup_flag_values(shared_flags)

  # Update with optimizer flags corresponding to the chosen optimizers.
  opt_flag_dict = utils_impl.lookup_flag_values(optimizer_flags)
  opt_flag_dict = optimizer_utils.remove_unused_flags('client', opt_flag_dict)
  opt_flag_dict = optimizer_utils.remove_unused_flags('server', opt_flag_dict)
  hparam_dict.update(opt_flag_dict)

  results_dir = os.path.join(FLAGS.root_output_dir, 'results',
                             FLAGS.experiment_name)
  utils_impl.create_directory_if_not_exists(results_dir)
  hparam_file = os.path.join(results_dir, 'hparams.csv')
  utils_impl.atomic_write_series_to_csv(hparam_dict, hparam_file)
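`optimizer_utils.remove_unused_flags` presumably prunes optimizer hyperparameter flags that do not belong to the optimizer actually selected for the given prefix ('client' or 'server'), so only settings that were used end up in the hparams CSV. A hedged sketch of that assumed behavior, with an illustrative optimizer list:

import collections

def remove_unused_flags_sketch(prefix, flag_dict):
    """Illustrative only: drop flags namespaced under unselected optimizers."""
    chosen = flag_dict['{}_optimizer'.format(prefix)]  # e.g. 'sgd' or 'adam'
    unused_prefixes = ['{}_{}'.format(prefix, opt)
                       for opt in ('sgd', 'adam', 'adagrad') if opt != chosen]
    return collections.OrderedDict(
        (name, value) for name, value in flag_dict.items()
        if not any(name.startswith(p) for p in unused_prefixes))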
Example #3
def get_hparam_flags():
    """Returns an ordered dictionary of pertinent hyperparameter flags."""
    hparam_dict = utils_impl.lookup_flag_values(shared_flags)

    # Update with optimizer flags corresponding to the chosen optimizers.
    opt_flag_dict = utils_impl.lookup_flag_values(optimizer_flags)
    opt_flag_dict = optimizer_utils.remove_unused_flags(
        'client', opt_flag_dict)
    opt_flag_dict = optimizer_utils.remove_unused_flags(
        'server', opt_flag_dict)
    hparam_dict.update(opt_flag_dict)

    # Update with task-specific flags.
    task_hparam_dict = utils_impl.lookup_flag_values(gld_flags)
    hparam_dict.update(task_hparam_dict)

    return hparam_dict
Example #4
    def test_convert_flag_names_to_odict(self):
        with utils_impl.record_new_flags() as hparam_flags:
            flags.DEFINE_integer('flag1', 1, 'This is the first flag.')
            flags.DEFINE_float('flag2', 2.0, 'This is the second flag.')

        hparam_odict = utils_impl.lookup_flag_values(hparam_flags)
        expected_odict = collections.OrderedDict(flag1=1, flag2=2.0)

        self.assertEqual(hparam_odict, expected_odict)
Example #5
def _write_hparam_flags():
  """Creates an ordered dictionary of hyperparameter flags and writes to CSV."""
  hparam_dict = utils_impl.lookup_flag_values(shared_flags)

  # Update with optimizer flags corresponding to the chosen optimizers.
  opt_flag_dict = utils_impl.lookup_flag_values(optimizer_flags)
  opt_flag_dict = optimizer_utils.remove_unused_flags('client', opt_flag_dict)
  opt_flag_dict = optimizer_utils.remove_unused_flags('server', opt_flag_dict)
  hparam_dict.update(opt_flag_dict)

  # Update with task-specific flags.
  task_name = FLAGS.task
  if task_name in TASK_FLAGS:
    task_hparam_dict = utils_impl.lookup_flag_values(TASK_FLAGS[task_name])
    hparam_dict.update(task_hparam_dict)

  results_dir = os.path.join(FLAGS.root_output_dir, 'results',
                             FLAGS.experiment_name)
  utils_impl.create_directory_if_not_exists(results_dir)
  hparam_file = os.path.join(results_dir, 'hparams.csv')
  utils_impl.atomic_write_series_to_csv(hparam_dict, hparam_file)
Example #6
def _get_task_args():
    """Returns an ordered dictionary of task-specific arguments.

  This method returns a dict of (arg_name, arg_value) pairs, where the
  arg_name has had the task name removed as a prefix (if it exists), as well
  as any leading `-` or `_` characters.

  Returns:
    An ordered dictionary of (arg_name, arg_value) pairs.
  """
    task_name = FLAGS.task
    task_args = collections.OrderedDict()

    if task_name in TASK_FLAGS:
        task_flag_list = TASK_FLAGS[task_name]
        task_flag_dict = utils_impl.lookup_flag_values(task_flag_list)
        task_flag_prefix = TASK_FLAG_PREFIXES[task_name]
        for (key, value) in task_flag_dict.items():
            if key.startswith(task_flag_prefix):
                key = key[len(task_flag_prefix):].lstrip('_-')
            task_args[key] = value
    return task_args
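As a concrete illustration of the prefix stripping above, using the 'so_nwp' prefix that appears in later examples: the flag 'so_nwp_vocab_size' becomes the argument name 'vocab_size'.

# Worked instance of the loop body in `_get_task_args`.
task_flag_prefix = 'so_nwp'
key = 'so_nwp_vocab_size'
if key.startswith(task_flag_prefix):
    key = key[len(task_flag_prefix):].lstrip('_-')
print(key)  # -> 'vocab_size'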
Example #7
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    hparams_dict = utils_impl.lookup_flag_values(hparam_flags)
    hparams_dict = optimizer_utils.remove_unused_flags('centralized',
                                                       hparams_dict)

    centralized_main.run_centralized(
        optimizer_utils.create_optimizer_fn_from_flags('centralized')(),
        FLAGS.num_epochs,
        FLAGS.batch_size,
        vocab_size=FLAGS.vocab_size,
        d_embed=FLAGS.d_embed,
        d_model=FLAGS.d_model,
        d_hidden=FLAGS.d_hidden,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        dropout=FLAGS.dropout,
        experiment_name=FLAGS.experiment_name,
        root_output_dir=FLAGS.root_output_dir,
        max_batches=FLAGS.max_batches,
        hparams_dict=hparams_dict)
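Note the two-step call `create_optimizer_fn_from_flags('centralized')()` above. Judging from this usage, the factory returns a no-arg optimizer constructor, which is then invoked to produce the optimizer instance handed to `run_centralized`; split out, the call reads:

# Equivalent split form of the call above (illustration only).
optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('centralized')
optimizer = optimizer_fn()  # Builds the configured tf.keras optimizer.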
Example #8
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    hparams_dict = utils_impl.lookup_flag_values(hparam_flags)
    hparams_dict = optimizer_utils.remove_unused_flags('centralized',
                                                       hparams_dict)

    dataset_type = dataset.DatasetType.GLD23K
    if FLAGS.dataset_type == 'gld160k':
        dataset_type = dataset.DatasetType.GLD160K

    centralized_main.run_centralized(
        optimizer=optimizer_utils.create_optimizer_fn_from_flags(
            'centralized')(),
        image_size=FLAGS.image_size,
        num_epochs=FLAGS.num_epochs,
        batch_size=FLAGS.batch_size,
        num_groups=FLAGS.num_groups,
        dataset_type=dataset_type,
        experiment_name=FLAGS.experiment_name,
        root_output_dir=FLAGS.root_output_dir,
        dropout_prob=FLAGS.dropout_prob,
        hparams_dict=hparams_dict)
Example #9
    def test_convert_undefined_flag_names(self):
        with self.assertRaisesRegex(ValueError,
                                    '"bad_flag" is not a defined flag'):
            utils_impl.lookup_flag_values(['bad_flag'])
Example #10
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    client_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.client_learning_rate,
        decay_factor=FLAGS.client_decay_factor,
        min_delta=FLAGS.min_delta,
        min_lr=FLAGS.min_lr,
        window_size=FLAGS.window_size,
        patience=FLAGS.patience)

    server_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.server_learning_rate,
        decay_factor=FLAGS.server_decay_factor,
        min_delta=FLAGS.min_delta,
        min_lr=FLAGS.min_lr,
        window_size=FLAGS.window_size,
        patience=FLAGS.patience)

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor providing the weight
        in the federated average of model deltas. If not provided, the default
        is the total number of examples processed on device.

    Returns:
      A `tff.templates.IterativeProcess`.
    """

        return adaptive_fed_avg.build_fed_avg_process(
            model_fn,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=client_optimizer_fn,
            server_optimizer_fn=server_optimizer_fn,
            client_weight_fn=client_weight_fn)

    hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())

    task_spec = training_specs.TaskSpec(
        iterative_process_builder=iterative_process_builder,
        client_epochs_per_round=FLAGS.client_epochs_per_round,
        client_batch_size=FLAGS.client_batch_size,
        clients_per_round=FLAGS.clients_per_round,
        sampling_random_seed=FLAGS.client_datasets_random_seed)

    if FLAGS.task == 'cifar100':
        runner_spec = federated_cifar100.configure_training(
            task_spec, crop_size=FLAGS.cifar100_crop_size)
    elif FLAGS.task == 'emnist_cr':
        runner_spec = federated_emnist.configure_training(
            task_spec, model=FLAGS.emnist_cr_model)
    elif FLAGS.task == 'emnist_ae':
        runner_spec = federated_emnist_ae.configure_training(task_spec)
    elif FLAGS.task == 'shakespeare':
        runner_spec = federated_shakespeare.configure_training(
            task_spec, sequence_length=FLAGS.shakespeare_sequence_length)
    elif FLAGS.task == 'stackoverflow_nwp':
        runner_spec = federated_stackoverflow.configure_training(
            task_spec,
            vocab_size=FLAGS.so_nwp_vocab_size,
            num_oov_buckets=FLAGS.so_nwp_num_oov_buckets,
            sequence_length=FLAGS.so_nwp_sequence_length,
            max_elements_per_user=FLAGS.so_nwp_max_elements_per_user,
            num_validation_examples=FLAGS.so_nwp_num_validation_examples)
    elif FLAGS.task == 'stackoverflow_lr':
        runner_spec = federated_stackoverflow_lr.configure_training(
            task_spec,
            vocab_tokens_size=FLAGS.so_lr_vocab_tokens_size,
            vocab_tags_size=FLAGS.so_lr_vocab_tags_size,
            max_elements_per_user=FLAGS.so_lr_max_elements_per_user,
            num_validation_examples=FLAGS.so_lr_num_validation_examples)
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))

    training_loop.run(iterative_process=runner_spec.iterative_process,
                      client_datasets_fn=runner_spec.client_datasets_fn,
                      validation_fn=runner_spec.validation_fn,
                      test_fn=runner_spec.test_fn,
                      total_rounds=FLAGS.total_rounds,
                      experiment_name=FLAGS.experiment_name,
                      root_output_dir=FLAGS.root_output_dir,
                      rounds_per_eval=FLAGS.rounds_per_eval,
                      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
                      rounds_per_profile=FLAGS.rounds_per_profile,
                      hparam_dict=hparam_dict)
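The arguments to `callbacks.create_reduce_lr_on_plateau` suggest Keras-style plateau decay: a monitored quantity is smoothed over `window_size` rounds, and once it has failed to improve by at least `min_delta` for `patience` rounds, the learning rate is multiplied by `decay_factor`, never dropping below `min_lr`. A rough self-contained sketch of that assumed behavior (not the real `callbacks` implementation):

class ReduceLrOnPlateauSketch:
    """Illustrative plateau decay for a per-round training metric."""

    def __init__(self, learning_rate, decay_factor, min_delta, min_lr,
                 window_size, patience):
        self.learning_rate = learning_rate
        self.decay_factor = decay_factor
        self.min_delta = min_delta
        self.min_lr = min_lr
        self.window_size = window_size
        self.patience = patience
        self._history = []
        self._best = float('inf')
        self._wait = 0

    def update(self, loss):
        """Records one round's loss; decays the learning rate on plateau."""
        self._history.append(loss)
        window = self._history[-self.window_size:]
        smoothed = sum(window) / len(window)
        if smoothed < self._best - self.min_delta:
            self._best = smoothed
            self._wait = 0
        elif self._wait + 1 >= self.patience:
            self.learning_rate = max(
                self.learning_rate * self.decay_factor, self.min_lr)
            self._wait = 0
        else:
            self._wait += 1
        return self.learning_rate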
Example #11
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  emnist_task = 'digit_recognition'
  emnist_train, _ = tff.simulation.datasets.emnist.load_data(only_digits=False)
  _, emnist_test = emnist_dataset.get_centralized_datasets(
      only_digits=False, emnist_task=emnist_task)

  train_preprocess_fn = emnist_dataset.create_preprocess_fn(
      num_epochs=FLAGS.client_epochs_per_round,
      batch_size=FLAGS.client_batch_size,
      emnist_task=emnist_task)

  input_spec = train_preprocess_fn.type_signature.result.element

  if FLAGS.model == 'cnn':
    model_builder = functools.partial(
        emnist_models.create_conv_dropout_model, only_digits=FLAGS.only_digits)
  elif FLAGS.model == '2nn':
    model_builder = functools.partial(
        emnist_models.create_two_hidden_layer_model,
        only_digits=FLAGS.only_digits)
  elif FLAGS.model == '1m_cnn':
    model_builder = functools.partial(
        create_1m_cnn_model, only_digits=FLAGS.only_digits)
  else:
    raise ValueError('Cannot handle model flag [{!s}].'.format(FLAGS.model))

  logging.info('Training model:')
  logging.info(model_builder().summary())

  loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
  metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

  compression_dict = utils_impl.lookup_flag_values(compression_flags)
  dp_dict = utils_impl.lookup_flag_values(dp_flags)

  # Most logic for deciding what baseline to run is here.
  aggregation_factory = fl_utils.build_aggregator(
      compression_flags=compression_dict,
      dp_flags=dp_dict,
      num_clients=len(emnist_train.client_ids),
      num_clients_per_round=FLAGS.clients_per_round,
      num_rounds=FLAGS.total_rounds,
      client_template=model_builder().trainable_variables)

  def tff_model_fn():
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        loss=loss_builder(),
        input_spec=input_spec,
        metrics=metrics_builder())

  server_optimizer_fn = lambda: utils_impl.create_optimizer_from_flags('server')
  client_optimizer_fn = lambda: utils_impl.create_optimizer_from_flags('client')

  iterative_process = tff.learning.build_federated_averaging_process(
      model_fn=tff_model_fn,
      server_optimizer_fn=server_optimizer_fn,
      client_weighting=tff.learning.ClientWeighting.UNIFORM,
      client_optimizer_fn=client_optimizer_fn,
      model_update_aggregation_factory=aggregation_factory)

  @tff.tf_computation(tf.string)
  def build_train_dataset_from_client_id(client_id):
    client_dataset = emnist_train.dataset_computation(client_id)
    return train_preprocess_fn(client_dataset)

  training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
      build_train_dataset_from_client_id, iterative_process)
  training_process.get_model_weights = iterative_process.get_model_weights

  client_ids_fn = functools.partial(
      tff.simulation.build_uniform_sampling_fn(
          emnist_train.client_ids,
          replace=False,
          random_seed=FLAGS.client_datasets_random_seed),
      size=FLAGS.clients_per_round)

  # We convert the output to a list (instead of an np.ndarray) so that it can
  # be used as input to the iterative process.
  client_sampling_fn = lambda x: list(client_ids_fn(x))

  evaluate_fn = tff.learning.build_federated_evaluation(tff_model_fn)

  def test_fn(state):
    return evaluate_fn(
        iterative_process.get_model_weights(state), [emnist_test])

  def validation_fn(state, round_num):
    del round_num
    return evaluate_fn(
        iterative_process.get_model_weights(state), [emnist_test])

  training_loop.run(
      iterative_process=training_process,
      client_datasets_fn=client_sampling_fn,
      validation_fn=validation_fn,
      test_fn=test_fn,
      total_rounds=FLAGS.total_rounds,
      experiment_name=FLAGS.experiment_name,
      root_output_dir=FLAGS.root_output_dir,
      rounds_per_eval=FLAGS.rounds_per_eval,
      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint)
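`tff.simulation.build_uniform_sampling_fn` is used above to turn a round number into a reproducible uniform sample of client ids without replacement (note the comment about converting its np.ndarray output to a list). A stand-in with the assumed semantics, though the real function's internals may differ:

import numpy as np

def build_uniform_sampling_fn_sketch(client_ids, replace, random_seed):
    """Returns a fn mapping (round_num, size) to sampled client ids."""
    def sample(round_num, size):
        # Derive per-round randomness from the seed so that sampling is
        # reproducible across runs (an assumption about the real API).
        rng = np.random.RandomState((random_seed + round_num) % 2**32)
        return rng.choice(client_ids, size=size, replace=replace)
    return sample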
Example #12
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('server')

  client_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=FLAGS.client_learning_rate,
      decay_factor=FLAGS.client_decay_factor,
      min_delta=FLAGS.min_delta,
      min_lr=FLAGS.min_lr,
      window_size=FLAGS.window_size,
      patience=FLAGS.patience)

  server_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=FLAGS.server_learning_rate,
      decay_factor=FLAGS.server_decay_factor,
      min_delta=FLAGS.min_delta,
      min_lr=FLAGS.min_lr,
      window_size=FLAGS.window_size,
      patience=FLAGS.patience)

  def iterative_process_builder(
      model_fn: Callable[[], tff.learning.Model],
      client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
  ) -> tff.templates.IterativeProcess:
    """Creates an iterative process using a given TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor providing the weight
        in the federated average of model deltas. If not provided, the default
        is the total number of examples processed on device.

    Returns:
      A `tff.templates.IterativeProcess`.
    """

    return adaptive_fed_avg.build_fed_avg_process(
        model_fn,
        client_lr_callback,
        server_lr_callback,
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weight_fn=client_weight_fn)

  hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())

  shared_args = utils_impl.lookup_flag_values(shared_flags)
  shared_args['iterative_process_builder'] = iterative_process_builder

  if FLAGS.task == 'cifar100':
    hparam_dict['cifar100_crop_size'] = FLAGS.cifar100_crop_size
    federated_cifar100.run_federated(
        **shared_args,
        crop_size=FLAGS.cifar100_crop_size,
        hparam_dict=hparam_dict)

  elif FLAGS.task == 'emnist_cr':
    federated_emnist.run_federated(
        **shared_args, model=FLAGS.emnist_cr_model, hparam_dict=hparam_dict)

  elif FLAGS.task == 'emnist_ae':
    federated_emnist_ae.run_federated(**shared_args, hparam_dict=hparam_dict)

  elif FLAGS.task == 'shakespeare':
    federated_shakespeare.run_federated(
        **shared_args,
        sequence_length=FLAGS.shakespeare_sequence_length,
        hparam_dict=hparam_dict)

  elif FLAGS.task == 'stackoverflow_nwp':
    so_nwp_flags = collections.OrderedDict()
    for flag_name in task_flags:
      if flag_name.startswith('so_nwp_'):
        so_nwp_flags[flag_name[7:]] = FLAGS[flag_name].value
    federated_stackoverflow.run_federated(
        **shared_args, **so_nwp_flags, hparam_dict=hparam_dict)

  elif FLAGS.task == 'stackoverflow_lr':
    so_lr_flags = collections.OrderedDict()
    for flag_name in task_flags:
      if flag_name.startswith('so_lr_'):
        so_lr_flags[flag_name[6:]] = FLAGS[flag_name].value
    federated_stackoverflow_lr.run_federated(
        **shared_args, **so_lr_flags, hparam_dict=hparam_dict)
Example #13
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'client')
    server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'server')

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor providing the weight
        in the federated average of model deltas. If not provided, the default
        is the total number of examples processed on device.

    Returns:
      A `tff.templates.IterativeProcess`.
    """

        return fed_avg_schedule.build_fed_avg_process(
            model_fn=model_fn,
            client_optimizer_fn=client_optimizer_fn,
            client_lr=client_lr_schedule,
            server_optimizer_fn=server_optimizer_fn,
            server_lr=server_lr_schedule,
            client_weight_fn=client_weight_fn)

    shared_args = utils_impl.lookup_flag_values(shared_flags)
    shared_args['iterative_process_builder'] = iterative_process_builder
    task_args = _get_task_args()
    hparam_dict = _get_hparam_flags()

    if FLAGS.task == 'cifar100':
        run_federated_fn = federated_cifar100.run_federated
    elif FLAGS.task == 'emnist_cr':
        run_federated_fn = federated_emnist.run_federated
    elif FLAGS.task == 'emnist_ae':
        run_federated_fn = federated_emnist_ae.run_federated
    elif FLAGS.task == 'shakespeare':
        run_federated_fn = federated_shakespeare.run_federated
    elif FLAGS.task == 'stackoverflow_nwp':
        run_federated_fn = federated_stackoverflow.run_federated
    elif FLAGS.task == 'stackoverflow_lr':
        run_federated_fn = federated_stackoverflow_lr.run_federated
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))

    run_federated_fn(**shared_args, **task_args, hparam_dict=hparam_dict)
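The if/elif dispatch above can equally be written as a lookup table; a behavior-preserving sketch using the same task names and modules:

_RUN_FEDERATED_FNS = {
    'cifar100': federated_cifar100.run_federated,
    'emnist_cr': federated_emnist.run_federated,
    'emnist_ae': federated_emnist_ae.run_federated,
    'shakespeare': federated_shakespeare.run_federated,
    'stackoverflow_nwp': federated_stackoverflow.run_federated,
    'stackoverflow_lr': federated_stackoverflow_lr.run_federated,
}
if FLAGS.task not in _RUN_FEDERATED_FNS:
    raise ValueError('--task flag {} is not supported, must be one of {}.'.format(
        FLAGS.task, _SUPPORTED_TASKS))
run_federated_fn = _RUN_FEDERATED_FNS[FLAGS.task]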
Example #14
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))
    tff.backends.native.set_local_execution_context(max_fanout=10)

    model_builder = functools.partial(
        stackoverflow_models.create_recurrent_model,
        vocab_size=FLAGS.vocab_size,
        embedding_size=FLAGS.embedding_size,
        latent_size=FLAGS.latent_size,
        num_layers=FLAGS.num_layers,
        shared_embedding=FLAGS.shared_embedding)

    loss_builder = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

    special_tokens = stackoverflow_dataset.get_special_tokens(FLAGS.vocab_size)
    pad_token = special_tokens.pad
    oov_tokens = special_tokens.oov
    eos_token = special_tokens.eos

    def metrics_builder():
        return [
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy_with_oov',
                                                    masked_tokens=[pad_token]),
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy_no_oov',
                                                    masked_tokens=[pad_token] +
                                                    oov_tokens),
            # Notice BOS never appears in ground truth.
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_no_oov_or_eos',
                masked_tokens=[pad_token, eos_token] + oov_tokens),
            keras_metrics.NumBatchesCounter(),
            keras_metrics.NumTokensCounter(masked_tokens=[pad_token]),
        ]

    datasets = stackoverflow_dataset.construct_word_level_datasets(
        FLAGS.vocab_size, FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round, FLAGS.sequence_length,
        FLAGS.max_elements_per_user, FLAGS.num_validation_examples)
    train_dataset, validation_dataset, test_dataset = datasets

    if FLAGS.uniform_weighting:

        def client_weight_fn(local_outputs):
            del local_outputs
            return 1.0
    else:

        def client_weight_fn(local_outputs):
            return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

    def model_fn():
        return tff.learning.from_keras_model(
            model_builder(),
            loss_builder(),
            input_spec=validation_dataset.element_spec,
            metrics=metrics_builder())

    if FLAGS.noise_multiplier is not None:
        if not FLAGS.uniform_weighting:
            raise ValueError(
                'Differential privacy is only implemented for uniform weighting.'
            )

        dp_query = tff.utils.build_dp_query(
            clip=FLAGS.clip,
            noise_multiplier=FLAGS.noise_multiplier,
            expected_total_weight=FLAGS.clients_per_round,
            adaptive_clip_learning_rate=FLAGS.adaptive_clip_learning_rate,
            target_unclipped_quantile=FLAGS.target_unclipped_quantile,
            clipped_count_budget_allocation=FLAGS.
            clipped_count_budget_allocation,
            expected_clients_per_round=FLAGS.clients_per_round)

        weights_type = tff.learning.framework.weights_type_from_model(model_fn)
        aggregation_process = tff.utils.build_dp_aggregate_process(
            weights_type.trainable, dp_query)
    else:
        aggregation_process = None

    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')

    iterative_process = tff.learning.build_federated_averaging_process(
        model_fn=model_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weight_fn=client_weight_fn,
        client_optimizer_fn=client_optimizer_fn,
        aggregation_process=aggregation_process)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        train_dataset, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_centralized_evaluate_fn(
        model_builder=model_builder,
        eval_dataset=validation_dataset,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    test_fn = training_utils.build_centralized_evaluate_fn(
        model_builder=model_builder,
        # Use both val and test for symmetry with other experiments, which
        # evaluate on the entire test set.
        eval_dataset=validation_dataset.concatenate(test_dataset),
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    logging.info('Training model:')
    logging.info(model_builder().summary())

    hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
    training_loop_dict = utils_impl.lookup_flag_values(training_loop_flags)

    training_loop.run(iterative_process=iterative_process,
                      client_datasets_fn=client_datasets_fn,
                      validation_fn=evaluate_fn,
                      test_fn=test_fn,
                      hparam_dict=hparam_dict,
                      **training_loop_dict)
Example #15
def run_experiment():
  """Data preprocessing and experiment execution."""
  emnist_train, _ = emnist_dataset.get_federated_datasets(
      train_client_batch_size=FLAGS.client_batch_size,
      train_client_epochs_per_round=FLAGS.client_epochs_per_round,
      only_digits=FLAGS.only_digits)
  _, emnist_test = emnist_dataset.get_centralized_datasets(
      only_digits=FLAGS.only_digits)

  example_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[0])
  input_spec = example_dataset.element_spec

  # Build optimizer functions from flags
  client_optimizer_fn = functools.partial(
      utils_impl.create_optimizer_from_flags, 'client')
  server_optimizer_fn = functools.partial(
      utils_impl.create_optimizer_from_flags, 'server')

  def tff_model_fn():
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        input_spec=input_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

  if FLAGS.use_compression:
    # We create a `tff.templates.MeasuredProcess` for broadcast process and a
    # `tff.aggregators.WeightedAggregationFactory` for aggregation by providing
    # the `_broadcast_encoder_fn` and `_mean_encoder_fn` to corresponding
    # utilities. The fns are called once for each of the model weights created
    # by tff_model_fn, and return instances of appropriate encoders.
    encoded_broadcast_process = (
        tff.learning.framework.build_encoded_broadcast_process_from_model(
            tff_model_fn, _broadcast_encoder_fn))
    aggregator = tff.aggregators.MeanFactory(
        tff.aggregators.EncodedSumFactory(_mean_encoder_fn))
  else:
    encoded_broadcast_process = None
    aggregator = None

  # Construct the iterative process
  iterative_process = tff.learning.build_federated_averaging_process(
      model_fn=tff_model_fn,
      client_optimizer_fn=client_optimizer_fn,
      server_optimizer_fn=server_optimizer_fn,
      broadcast_process=encoded_broadcast_process,
      model_update_aggregation_factory=aggregator)

  iterative_process = (
      tff.simulation.compose_dataset_computation_with_iterative_process(
          emnist_train.dataset_computation, iterative_process))

  # Create a client sampling function, mapping integer round numbers to lists
  # of client ids.
  client_selection_fn = functools.partial(
      tff.simulation.build_uniform_sampling_fn(emnist_train.client_ids),
      size=FLAGS.clients_per_round)

  # Create a validation function
  evaluate_fn = tff.learning.build_federated_evaluation(tff_model_fn)

  def validation_fn(state, round_num):
    if round_num % FLAGS.rounds_per_eval == 0:
      return evaluate_fn(state.model, [emnist_test])
    else:
      return {}

  # Log hyperparameters to CSV
  hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
  results_dir = os.path.join(FLAGS.root_output_dir, 'results',
                             FLAGS.experiment_name)
  utils_impl.create_directory_if_not_exists(results_dir)
  hparam_file = os.path.join(results_dir, 'hparams.csv')
  utils_impl.atomic_write_series_to_csv(hparam_dict, hparam_file)

  checkpoint_manager, metrics_managers = _configure_managers()

  tff.simulation.run_simulation(
      process=iterative_process,
      client_selection_fn=client_selection_fn,
      validation_fn=validation_fn,
      total_rounds=FLAGS.total_rounds,
      file_checkpoint_manager=checkpoint_manager,
      metrics_managers=metrics_managers)
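Examples #15, #21 and #22 reference `_broadcast_encoder_fn` and `_mean_encoder_fn` without defining them. One plausible definition, adapted from the TensorFlow Federated compression tutorial (8-bit uniform quantization for large tensors, identity otherwise), is sketched here; the 10,000-element threshold and the `te` import path are assumptions:

import tensorflow as tf
from tensorflow_model_optimization.python.core.internal import tensor_encoding as te

def _broadcast_encoder_fn(value):
    """Encoder for server-to-client broadcast of one weight tensor."""
    spec = tf.TensorSpec(value.shape, value.dtype)
    if value.shape.num_elements() > 10000:
        return te.encoders.as_simple_encoder(
            te.encoders.uniform_quantization(bits=8), spec)
    return te.encoders.as_simple_encoder(te.encoders.identity(), spec)

def _mean_encoder_fn(value):
    """Encoder for client-to-server aggregation of one weight tensor."""
    # `value` may be a weight tensor or a tf.TensorSpec depending on the
    # TFF version; both expose `.shape` and `.dtype`, so this works either way.
    spec = tf.TensorSpec(value.shape, value.dtype)
    if value.shape.num_elements() > 10000:
        return te.encoders.as_gather_encoder(
            te.encoders.uniform_quantization(bits=8), spec)
    return te.encoders.as_gather_encoder(te.encoders.identity(), spec)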
Example #16
    def test_convert_nonstr_flag(self):
        with self.assertRaisesRegex(ValueError,
                                    'All flag names must be strings'):
            utils_impl.lookup_flag_values([300])
Example #17
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))
    tff.backends.native.set_local_execution_context(max_fanout=10)

    model_builder = functools.partial(
        stackoverflow_models.create_recurrent_model,
        vocab_size=FLAGS.vocab_size,
        embedding_size=FLAGS.embedding_size,
        latent_size=FLAGS.latent_size,
        num_layers=FLAGS.num_layers,
        shared_embedding=FLAGS.shared_embedding)

    loss_builder = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

    special_tokens = stackoverflow_word_prediction.get_special_tokens(
        FLAGS.vocab_size)
    pad_token = special_tokens.pad
    oov_tokens = special_tokens.oov
    eos_token = special_tokens.eos

    def metrics_builder():
        return [
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy_with_oov',
                                                    masked_tokens=[pad_token]),
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy_no_oov',
                                                    masked_tokens=[pad_token] +
                                                    oov_tokens),
            # Notice BOS never appears in ground truth.
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_no_oov_or_eos',
                masked_tokens=[pad_token, eos_token] + oov_tokens),
            keras_metrics.NumBatchesCounter(),
            keras_metrics.NumTokensCounter(masked_tokens=[pad_token]),
        ]

    train_dataset, _ = stackoverflow_word_prediction.get_federated_datasets(
        vocab_size=FLAGS.vocab_size,
        train_client_batch_size=FLAGS.client_batch_size,
        train_client_epochs_per_round=FLAGS.client_epochs_per_round,
        max_sequence_length=FLAGS.sequence_length,
        max_elements_per_train_client=FLAGS.max_elements_per_user)
    _, validation_dataset, test_dataset = stackoverflow_word_prediction.get_centralized_datasets(
        vocab_size=FLAGS.vocab_size,
        max_sequence_length=FLAGS.sequence_length,
        num_validation_examples=FLAGS.num_validation_examples)

    if FLAGS.uniform_weighting:
        client_weighting = tff.learning.ClientWeighting.UNIFORM
    else:
        client_weighting = tff.learning.ClientWeighting.NUM_EXAMPLES

    def model_fn():
        return tff.learning.from_keras_model(
            model_builder(),
            loss_builder(),
            input_spec=validation_dataset.element_spec,
            metrics=metrics_builder())

    if FLAGS.noise_multiplier is not None:
        if not FLAGS.uniform_weighting:
            raise ValueError(
                'Differential privacy is only implemented for uniform weighting.'
            )
        if FLAGS.noise_multiplier <= 0:
            raise ValueError(
                'noise_multiplier must be positive if DP is enabled.')
        if FLAGS.clip is None or FLAGS.clip <= 0:
            raise ValueError('clip must be positive if DP is enabled.')

        if not FLAGS.adaptive_clip_learning_rate:
            aggregation_factory = tff.aggregators.DifferentiallyPrivateFactory.gaussian_fixed(
                noise_multiplier=FLAGS.noise_multiplier,
                clients_per_round=FLAGS.clients_per_round,
                clip=FLAGS.clip)
        else:
            if FLAGS.adaptive_clip_learning_rate <= 0:
                raise ValueError(
                    'adaptive_clip_learning_rate must be positive if '
                    'adaptive clipping is enabled.')
            aggregation_factory = tff.aggregators.DifferentiallyPrivateFactory.gaussian_adaptive(
                noise_multiplier=FLAGS.noise_multiplier,
                clients_per_round=FLAGS.clients_per_round,
                initial_l2_norm_clip=FLAGS.clip,
                target_unclipped_quantile=FLAGS.target_unclipped_quantile,
                learning_rate=FLAGS.adaptive_clip_learning_rate)
    else:
        if FLAGS.uniform_weighting:
            aggregation_factory = tff.aggregators.UnweightedMeanFactory()
        else:
            aggregation_factory = tff.aggregators.MeanFactory()

    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')

    iterative_process = tff.learning.build_federated_averaging_process(
        model_fn=model_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weighting=client_weighting,
        client_optimizer_fn=client_optimizer_fn,
        model_update_aggregation_factory=aggregation_factory)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        train_dataset, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_centralized_evaluate_fn(
        model_builder=model_builder,
        eval_dataset=validation_dataset,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)
    validation_fn = lambda state, round_num: evaluate_fn(state.model)

    evaluate_test_fn = training_utils.build_centralized_evaluate_fn(
        model_builder=model_builder,
        # Use both val and test for symmetry with other experiments, which
        # evaluate on the entire test set.
        eval_dataset=validation_dataset.concatenate(test_dataset),
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)
    test_fn = lambda state: evaluate_test_fn(state.model)

    logging.info('Training model:')
    logging.info(model_builder().summary())

    # Log hyperparameters to CSV
    hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
    results_dir = os.path.join(FLAGS.root_output_dir, 'results',
                               FLAGS.experiment_name)
    utils_impl.create_directory_if_not_exists(results_dir)
    hparam_file = os.path.join(results_dir, 'hparams.csv')
    utils_impl.atomic_write_series_to_csv(hparam_dict, hparam_file)

    training_loop.run(iterative_process=iterative_process,
                      client_datasets_fn=client_datasets_fn,
                      validation_fn=validation_fn,
                      test_fn=test_fn,
                      total_rounds=FLAGS.total_rounds,
                      experiment_name=FLAGS.experiment_name,
                      root_output_dir=FLAGS.root_output_dir,
                      rounds_per_eval=FLAGS.rounds_per_eval,
                      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint)
Example #18
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  emnist_train, _ = emnist_dataset.get_federated_datasets(
      train_client_batch_size=FLAGS.client_batch_size,
      train_client_epochs_per_round=FLAGS.client_epochs_per_round,
      only_digits=False)

  _, emnist_test = emnist_dataset.get_centralized_datasets()

  if FLAGS.model == 'cnn':
    model_builder = functools.partial(
        emnist_models.create_conv_dropout_model, only_digits=False)
  elif FLAGS.model == '2nn':
    model_builder = functools.partial(
        emnist_models.create_two_hidden_layer_model, only_digits=False)
  else:
    raise ValueError('Cannot handle model flag [{!s}].'.format(FLAGS.model))

  loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
  metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

  if FLAGS.uniform_weighting:
    client_weighting = tff.learning.ClientWeighting.UNIFORM
  else:
    client_weighting = tff.learning.ClientWeighting.NUM_EXAMPLES

  def model_fn():
    return tff.learning.from_keras_model(
        model_builder(),
        loss_builder(),
        input_spec=emnist_test.element_spec,
        metrics=metrics_builder())

  if FLAGS.noise_multiplier is not None:
    if not FLAGS.uniform_weighting:
      raise ValueError(
          'Differential privacy is only implemented for uniform weighting.')
    if FLAGS.noise_multiplier <= 0:
      raise ValueError('noise_multiplier must be positive if DP is enabled.')
    if FLAGS.clip is None or FLAGS.clip <= 0:
      raise ValueError('clip must be positive if DP is enabled.')

    if not FLAGS.adaptive_clip_learning_rate:
      aggregation_factory = tff.aggregators.DifferentiallyPrivateFactory.gaussian_fixed(
          noise_multiplier=FLAGS.noise_multiplier,
          clients_per_round=FLAGS.clients_per_round,
          clip=FLAGS.clip)
    else:
      if FLAGS.adaptive_clip_learning_rate <= 0:
        raise ValueError('adaptive_clip_learning_rate must be positive if '
                         'adaptive clipping is enabled.')
      aggregation_factory = tff.aggregators.DifferentiallyPrivateFactory.gaussian_adaptive(
          noise_multiplier=FLAGS.noise_multiplier,
          clients_per_round=FLAGS.clients_per_round,
          initial_l2_norm_clip=FLAGS.clip,
          target_unclipped_quantile=FLAGS.target_unclipped_quantile,
          learning_rate=FLAGS.adaptive_clip_learning_rate)
  else:
    if FLAGS.uniform_weighting:
      aggregation_factory = tff.aggregators.UnweightedMeanFactory()
    else:
      aggregation_factory = tff.aggregators.MeanFactory()

  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('server')
  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  iterative_process = tff.learning.build_federated_averaging_process(
      model_fn=model_fn,
      server_optimizer_fn=server_optimizer_fn,
      client_weighting=client_weighting,
      client_optimizer_fn=client_optimizer_fn,
      model_update_aggregation_factory=aggregation_factory)

  client_datasets_fn = training_utils.build_client_datasets_fn(
      emnist_train, FLAGS.clients_per_round)

  evaluate_fn = training_utils.build_centralized_evaluate_fn(
      eval_dataset=emnist_test,
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)
  validation_fn = lambda model_weights, round_num: evaluate_fn(model_weights)

  logging.info('Training model:')
  logging.info(model_builder().summary())

  # Log hyperparameters to CSV
  hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
  results_dir = os.path.join(FLAGS.root_output_dir, 'results',
                             FLAGS.experiment_name)
  utils_impl.create_directory_if_not_exists(results_dir)
  hparam_file = os.path.join(results_dir, 'hparams.csv')
  utils_impl.atomic_write_series_to_csv(hparam_dict, hparam_file)

  training_loop.run(
      iterative_process=iterative_process,
      client_datasets_fn=client_datasets_fn,
      validation_fn=validation_fn,
      total_rounds=FLAGS.total_rounds,
      experiment_name=FLAGS.experiment_name,
      root_output_dir=FLAGS.root_output_dir,
      rounds_per_eval=FLAGS.rounds_per_eval,
      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
      rounds_per_profile=FLAGS.rounds_per_profile)
Example #19
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'client')
    server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'server')

    if FLAGS.schedule == 'importance':
        fed_avg_schedule = importance_schedule
    elif FLAGS.schedule == 'loss':
        fed_avg_schedule = fed_loss
    else:
        fed_avg_schedule = fed_avg

    if FLAGS.schedule == 'importance':

        def iterative_process_builder(
            model_fn: Callable[[], tff.learning.Model],
            client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
        ) -> tff.templates.IterativeProcess:

            factory = importance_aggregation_factory.ImportanceSamplingFactory(
                FLAGS.clients_per_round)
            weights_type = importance_aggregation_factory.weights_type_from_model_fn(
                model_fn)
            importance_aggregation_process = factory.create(
                value_type=weights_type,
                weight_type=tff.TensorType(tf.float32))

            return importance_schedule.build_fed_avg_process(
                model_fn=model_fn,
                client_optimizer_fn=client_optimizer_fn,
                client_lr=client_lr_schedule,
                server_optimizer_fn=server_optimizer_fn,
                server_lr=server_lr_schedule,
                aggregation_process=importance_aggregation_process)
    elif FLAGS.schedule == 'loss':

        def iterative_process_builder(
            model_fn: Callable[[], tff.learning.Model],
            client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
        ) -> tff.templates.IterativeProcess:
            """Creates an iterative process using a given TFF `model_fn`.

      Args:
        model_fn: A no-arg function returning a `tff.learning.Model`.
        client_weight_fn: Optional function that takes the output of
          `model.report_local_outputs` and returns a tensor providing the weight
          in the federated average of model deltas. If not provided, the default
          is the total number of examples processed on device.

      Returns:
        A `tff.templates.IterativeProcess`.
      """
            return fed_avg_schedule.build_fed_avg_process(
                total_clients=FLAGS.loss_pool_size,
                effective_num_clients=FLAGS.clients_per_round,
                model_fn=model_fn,
                client_optimizer_fn=client_optimizer_fn,
                client_lr=client_lr_schedule,
                server_optimizer_fn=server_optimizer_fn,
                server_lr=server_lr_schedule,
                client_weight_fn=client_weight_fn,
                aggregation_process=None)
    else:

        def iterative_process_builder(
            model_fn: Callable[[], tff.learning.Model],
            client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
        ) -> tff.templates.IterativeProcess:
            """Creates an iterative process using a given TFF `model_fn`.

      Args:
        model_fn: A no-arg function returning a `tff.learning.Model`.
        client_weight_fn: Optional function that takes the output of
          `model.report_local_outputs` and returns a tensor providing the weight
          in the federated average of model deltas. If not provided, the default
          is the total number of examples processed on device.

      Returns:
        A `tff.templates.IterativeProcess`.
      """
            return fed_avg_schedule.build_fed_avg_process(
                model_fn=model_fn,
                client_optimizer_fn=client_optimizer_fn,
                client_lr=client_lr_schedule,
                server_optimizer_fn=server_optimizer_fn,
                server_lr=server_lr_schedule,
                client_weight_fn=client_weight_fn)

    shared_args = utils_impl.lookup_flag_values(shared_flags)
    shared_args['iterative_process_builder'] = iterative_process_builder
    task_args = _get_task_args()
    hparam_dict = _get_hparam_flags()
    # shared_args['prob_transmit'] = FLAGS.prob_transmit

    if FLAGS.task == 'cifar100':
        run_federated_fn = federated_cifar100.run_federated

    elif FLAGS.task == 'emnist_cr':
        run_federated_fn = federated_emnist.run_federated
    elif FLAGS.task == 'emnist_ae':
        run_federated_fn = federated_emnist_ae.run_federated
    elif FLAGS.task == 'shakespeare':
        run_federated_fn = federated_shakespeare.run_federated
    elif FLAGS.task == 'stackoverflow_nwp':
        run_federated_fn = federated_stackoverflow.run_federated
    elif FLAGS.task == 'stackoverflow_lr':
        run_federated_fn = federated_stackoverflow_lr.run_federated
    elif FLAGS.task == 'synthetic':
        run_federated_fn = federated_synthetic.run_federated
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))
    run_federated_fn(**shared_args,
                     **task_args,
                     beta=FLAGS.beta,
                     hparam_dict=hparam_dict,
                     schedule=FLAGS.schedule)
Example #20
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    emnist_train, emnist_test = emnist_dataset.get_emnist_datasets(
        FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round,
        only_digits=False)

    if FLAGS.model == 'cnn':
        model_builder = functools.partial(
            emnist_models.create_conv_dropout_model, only_digits=False)
    elif FLAGS.model == '2nn':
        model_builder = functools.partial(
            emnist_models.create_two_hidden_layer_model, only_digits=False)
    else:
        raise ValueError('Cannot handle model flag [{!s}].'.format(
            FLAGS.model))

    loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
    metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

    if FLAGS.uniform_weighting:

        def client_weight_fn(local_outputs):
            del local_outputs
            return 1.0

    else:
        client_weight_fn = None  # Defaults to the number of examples per client.

    def model_fn():
        return tff.learning.from_keras_model(
            model_builder(),
            loss_builder(),
            input_spec=emnist_test.element_spec,
            metrics=metrics_builder())

    if FLAGS.noise_multiplier is not None:
        if not FLAGS.uniform_weighting:
            raise ValueError(
                'Differential privacy is only implemented for uniform weighting.'
            )

        dp_query = tff.utils.build_dp_query(
            clip=FLAGS.clip,
            noise_multiplier=FLAGS.noise_multiplier,
            expected_total_weight=FLAGS.clients_per_round,
            adaptive_clip_learning_rate=FLAGS.adaptive_clip_learning_rate,
            target_unclipped_quantile=FLAGS.target_unclipped_quantile,
            clipped_count_budget_allocation=FLAGS.
            clipped_count_budget_allocation,
            expected_clients_per_round=FLAGS.clients_per_round,
            per_vector_clipping=FLAGS.per_vector_clipping,
            model=model_fn())

        weights_type = tff.learning.framework.weights_type_from_model(model_fn)
        aggregation_process = tff.utils.build_dp_aggregate_process(
            weights_type.trainable, dp_query)
    else:
        aggregation_process = None

    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    iterative_process = tff.learning.build_federated_averaging_process(
        model_fn=model_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weight_fn=client_weight_fn,
        client_optimizer_fn=client_optimizer_fn,
        aggregation_process=aggregation_process)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        emnist_train, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_evaluate_fn(
        eval_dataset=emnist_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    logging.info('Training model:')
    logging.info(model_builder().summary())

    hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
    training_loop_dict = utils_impl.lookup_flag_values(training_loop_flags)

    training_loop.run(iterative_process=iterative_process,
                      client_datasets_fn=client_datasets_fn,
                      validation_fn=evaluate_fn,
                      hparam_dict=hparam_dict,
                      **training_loop_dict)
Example #21
def run_experiment():
    """Data preprocessing and experiment execution."""
    emnist_train, emnist_test = emnist_dataset.get_emnist_datasets(
        FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round,
        only_digits=FLAGS.only_digits)

    example_dataset = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[0])
    input_spec = example_dataset.element_spec

    client_datasets_fn = training_utils.build_client_datasets_fn(
        emnist_train, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_centralized_evaluate_fn(
        eval_dataset=emnist_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    client_optimizer_fn = functools.partial(
        utils_impl.create_optimizer_from_flags, 'client')
    server_optimizer_fn = functools.partial(
        utils_impl.create_optimizer_from_flags, 'server')

    def tff_model_fn():
        keras_model = model_builder()
        return tff.learning.from_keras_model(keras_model,
                                             input_spec=input_spec,
                                             loss=loss_builder(),
                                             metrics=metrics_builder())

    if FLAGS.use_compression:
        # We create a `MeasuredProcess` for broadcast process and a
        # `MeasuredProcess` for aggregate process by providing the
        # `_broadcast_encoder_fn` and `_mean_encoder_fn` to corresponding utilities.
        # The fns are called once for each of the model weights created by
        # tff_model_fn, and return instances of appropriate encoders.
        encoded_broadcast_process = (
            tff.learning.framework.build_encoded_broadcast_process_from_model(
                tff_model_fn, _broadcast_encoder_fn))
        encoded_mean_process = (
            tff.learning.framework.build_encoded_mean_process_from_model(
                tff_model_fn, _mean_encoder_fn))
    else:
        encoded_broadcast_process = None
        encoded_mean_process = None

    iterative_process = tff.learning.build_federated_averaging_process(
        model_fn=tff_model_fn,
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn,
        aggregation_process=encoded_mean_process,
        broadcast_process=encoded_broadcast_process)

    hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
    training_loop_dict = utils_impl.lookup_flag_values(training_loop_flags)

    training_loop.run(iterative_process=iterative_process,
                      client_datasets_fn=client_datasets_fn,
                      validation_fn=evaluate_fn,
                      hparam_dict=hparam_dict,
                      **training_loop_dict)
Example #22
def run_experiment():
    """Data preprocessing and experiment execution."""
    emnist_train, _ = emnist_dataset.get_federated_datasets(
        train_client_batch_size=FLAGS.client_batch_size,
        train_client_epochs_per_round=FLAGS.client_epochs_per_round,
        only_digits=False)

    _, emnist_test = emnist_dataset.get_centralized_datasets()

    example_dataset = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[0])
    input_spec = example_dataset.element_spec

    client_datasets_fn = training_utils.build_client_datasets_fn(
        emnist_train, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_centralized_evaluate_fn(
        eval_dataset=emnist_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)
    validation_fn = lambda model_weights, round_num: evaluate_fn(model_weights)

    client_optimizer_fn = functools.partial(
        utils_impl.create_optimizer_from_flags, 'client')
    server_optimizer_fn = functools.partial(
        utils_impl.create_optimizer_from_flags, 'server')

    def tff_model_fn():
        keras_model = model_builder()
        return tff.learning.from_keras_model(keras_model,
                                             input_spec=input_spec,
                                             loss=loss_builder(),
                                             metrics=metrics_builder())

    if FLAGS.use_compression:
        # We create a `MeasuredProcess` for broadcast process and a
        # `MeasuredProcess` for aggregate process by providing the
        # `_broadcast_encoder_fn` and `_mean_encoder_fn` to corresponding utilities.
        # The fns are called once for each of the model weights created by
        # tff_model_fn, and return instances of appropriate encoders.
        encoded_broadcast_process = (
            tff.learning.framework.build_encoded_broadcast_process_from_model(
                tff_model_fn, _broadcast_encoder_fn))
        encoded_mean_process = (
            tff.learning.framework.build_encoded_mean_process_from_model(
                tff_model_fn, _mean_encoder_fn))
    else:
        encoded_broadcast_process = None
        encoded_mean_process = None

    iterative_process = tff.learning.build_federated_averaging_process(
        model_fn=tff_model_fn,
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn,
        aggregation_process=encoded_mean_process,
        broadcast_process=encoded_broadcast_process)

    # Log hyperparameters to CSV
    hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
    results_dir = os.path.join(FLAGS.root_output_dir, 'results',
                               FLAGS.experiment_name)
    utils_impl.create_directory_if_not_exists(results_dir)
    hparam_file = os.path.join(results_dir, 'hparams.csv')
    utils_impl.atomic_write_series_to_csv(hparam_dict, hparam_file)

    training_loop.run(iterative_process=iterative_process,
                      client_datasets_fn=client_datasets_fn,
                      validation_fn=validation_fn,
                      total_rounds=FLAGS.total_rounds,
                      experiment_name=FLAGS.experiment_name,
                      root_output_dir=FLAGS.root_output_dir,
                      rounds_per_eval=FLAGS.rounds_per_eval,
                      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
                      rounds_per_profile=FLAGS.rounds_per_profile)
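
Note: this example assumes `model_builder`, `loss_builder`, and
`metrics_builder` exist at module level. A minimal sketch under the assumption
of a small dense classifier (the architecture is illustrative; EMNIST with
only_digits=False has 62 label classes):

import tensorflow as tf


def model_builder():
    return tf.keras.Sequential([
        tf.keras.layers.Input(shape=(28, 28, 1)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(256, activation='relu'),
        tf.keras.layers.Dense(62),
    ])


def loss_builder():
    return tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)


def metrics_builder():
    return [tf.keras.metrics.SparseCategoricalAccuracy()]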
Exemplo n.º 23
0
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    if FLAGS.task == 'stackoverflow_nwp_finetune':
        if not FLAGS.global_variables_only:
            raise ValueError('`FLAGS.global_variables_only` must be True for '
                             'a `stackoverflow_nwp_finetune` task.')
        if not FLAGS.client_epochs_per_round:
            raise ValueError('`FLAGS.client_epochs_per_round` must be set for '
                             'a `stackoverflow_nwp_finetune` task.')
        finetune_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
            'finetune')
    else:
        reconstruction_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
            'reconstruction')

    def iterative_process_builder(
        model_fn: Callable[[], reconstruction_model.ReconstructionModel],
        loss_fn: Callable[[], List[tf.keras.losses.Loss]],
        metrics_fn: Optional[Callable[[],
                                      List[tf.keras.metrics.Metric]]] = None,
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
        dataset_split_fn_builder: Callable[
            ..., reconstruction_utils.DatasetSplitFn] = reconstruction_utils.
        build_dataset_split_fn,
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

    For a `stackoverflow_nwp_finetune` task, the `model_fn` must return a model
    that has only global variables, and the argument `dataset_split_fn_builder`
    is ignored. The returned iterative process is basically the same as the one
    created by the standard `tff.learning.build_federated_averaging_process`.

    For other tasks, the returned iterative process performs the federated
    reconstruction algorithm defined by
    `training_process.build_federated_reconstruction_process`.

    Args:
      model_fn: A no-arg function returning a
        `reconstruction_model.ReconstructionModel`. The returned model must have
        only global variables for a `stackoverflow_nwp_finetune` task.
      loss_fn: A no-arg function returning a list of `tf.keras.losses.Loss`.
      metrics_fn: A no-arg function returning a list of
        `tf.keras.metrics.Metric`.
      client_weight_fn: Optional function that takes the local model's output,
        and returns a tensor that provides the weight in the federated average
        of model deltas. If not provided, the default is the total number of
        examples processed on device. If DP is used, this argument is ignored,
        and uniform client weighting is used.
      dataset_split_fn_builder: `DatasetSplitFn` builder. Returns a method used
        to split the examples into a reconstruction, and post-reconstruction
        set. Ignored for a `stackoverflow_nwp_finetune` task.

    Raises:
      ValueError: if `model_fn` returns a model with local variables for a
        `stackoverflow_nwp_finetune` task.

    Returns:
      A `tff.templates.IterativeProcess`.
    """

        # Get aggregation factory for DP, if needed.
        aggregation_factory = None
        client_weighting = client_weight_fn
        if FLAGS.dp_noise_multiplier is not None:
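            # `tff.learning.dp_aggregator` builds an aggregation factory that
            # clips each client update and adds Gaussian noise scaled by
            # `noise_multiplier` to the aggregate; with `zeroing=True` it also
            # zeroes out abnormally large updates before averaging.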
            aggregation_factory = tff.learning.dp_aggregator(
                noise_multiplier=FLAGS.dp_noise_multiplier,
                clients_per_round=float(FLAGS.clients_per_round),
                zeroing=FLAGS.dp_zeroing)
            # DP is only implemented for uniform weighting.
            client_weighting = lambda _: 1.0

        if FLAGS.task == 'stackoverflow_nwp_finetune':

            if not reconstruction_utils.has_only_global_variables(model_fn()):
                raise ValueError(
                    '`model_fn` must return a model with only global variables '
                    'for a `stackoverflow_nwp_finetune` task.')

            def fake_dataset_split_fn(
                client_dataset: tf.data.Dataset, round_num: tf.Tensor
            ) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
                del round_num
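                # `repeat(0)` yields an empty reconstruction dataset, so all
                # client data flows to the post-reconstruction phase for
                # `client_epochs_per_round` local epochs, reducing the process
                # to standard federated averaging of the global variables.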
                return client_dataset.repeat(0), client_dataset.repeat(
                    FLAGS.client_epochs_per_round)

            return training_process.build_federated_reconstruction_process(
                model_fn=model_fn,
                loss_fn=loss_fn,
                metrics_fn=metrics_fn,
                server_optimizer_fn=lambda: server_optimizer_fn(
                    FLAGS.server_learning_rate),
                client_optimizer_fn=lambda: client_optimizer_fn(
                    FLAGS.client_learning_rate),
                dataset_split_fn=fake_dataset_split_fn,
                client_weight_fn=client_weighting,
                aggregation_factory=aggregation_factory)

        return training_process.build_federated_reconstruction_process(
            model_fn=model_fn,
            loss_fn=loss_fn,
            metrics_fn=metrics_fn,
            server_optimizer_fn=lambda: server_optimizer_fn(
                FLAGS.server_learning_rate),
            client_optimizer_fn=lambda: client_optimizer_fn(
                FLAGS.client_learning_rate),
            reconstruction_optimizer_fn=functools.partial(
                reconstruction_optimizer_fn,
                FLAGS.reconstruction_learning_rate),
            dataset_split_fn=dataset_split_fn_builder(
                recon_epochs_max=FLAGS.recon_epochs_max,
                recon_epochs_constant=FLAGS.recon_epochs_constant,
                recon_steps_max=FLAGS.recon_steps_max,
                post_recon_epochs=FLAGS.post_recon_epochs,
                post_recon_steps_max=FLAGS.post_recon_steps_max,
                split_dataset=FLAGS.split_dataset),
            evaluate_reconstruction=FLAGS.evaluate_reconstruction,
            jointly_train_variables=FLAGS.jointly_train_variables,
            client_weight_fn=client_weighting,
            aggregation_factory=aggregation_factory)

    def evaluation_computation_builder(
        model_fn: Callable[[], reconstruction_model.ReconstructionModel],
        loss_fn: Callable[[], tf.keras.losses.Loss],
        metrics_fn: Callable[[], List[tf.keras.metrics.Metric]],
        dataset_split_fn_builder: Callable[
            ..., reconstruction_utils.DatasetSplitFn] = reconstruction_utils.
        build_dataset_split_fn,
    ) -> tff.Computation:
        """Creates a `tff.Computation` for federated evaluation.

    For a `stackoverflow_nwp_finetune` task, the returned `tff.Computation` is
    created by `federated_evaluation.build_federated_finetune_evaluation`. For
    other tasks, the returned `tff.Computation` is given by
    `evaluation_computation.build_federated_reconstruction_evaluation`.

    Args:
      model_fn: A no-arg function that returns a `ReconstructionModel`. The
        returned model must have only global variables for a
        `stackoverflow_nwp_finetune` task. This method must *not* capture
        Tensorflow tensors or variables and use them. Must be constructed
        entirely from scratch on each invocation, returning the same model each
        call will result in an error.
      loss_fn: A no-arg function returning a `tf.keras.losses.Loss` to use to
        evaluate the model. The final loss metric is the example-weighted mean
        loss across batches (and across clients).
      metrics_fn: A no-arg function returning a list of
        `tf.keras.metrics.Metric`s to use to evaluate the model. The final
        metrics are the example-weighted mean metrics across batches (and across
        clients).
      dataset_split_fn_builder: `DatasetSplitFn` builder. Returns a method used
        to split the examples into a reconstruction set (which is used as a
        fine-tuning set for a `stackoverflow_nwp_finetune` task), and an
        evaluation set.

    Returns:
      A `tff.Computation` for federated evaluation.
    """

        # For a `stackoverflow_nwp_finetune` task, the first dataset returned by
        # `dataset_split_fn` is used for fine-tuning global variables. For other
        # tasks, the first dataset is used for reconstructing local variables.
        dataset_split_fn = dataset_split_fn_builder(
            recon_epochs_max=FLAGS.recon_epochs_max,
            recon_epochs_constant=FLAGS.recon_epochs_constant,
            recon_steps_max=FLAGS.recon_steps_max,
            post_recon_epochs=FLAGS.post_recon_epochs,
            post_recon_steps_max=FLAGS.post_recon_steps_max,
            # Getting meaningful evaluation metrics requires splitting the data.
            split_dataset=True)

        if FLAGS.task == 'stackoverflow_nwp_finetune':
            return federated_evaluation.build_federated_finetune_evaluation(
                model_fn=model_fn,
                loss_fn=loss_fn,
                metrics_fn=metrics_fn,
                finetune_optimizer_fn=functools.partial(
                    finetune_optimizer_fn, FLAGS.finetune_learning_rate),
                dataset_split_fn=dataset_split_fn)

        return evaluation_computation.build_federated_reconstruction_evaluation(
            model_fn=model_fn,
            loss_fn=loss_fn,
            metrics_fn=metrics_fn,
            reconstruction_optimizer_fn=functools.partial(
                reconstruction_optimizer_fn,
                FLAGS.reconstruction_learning_rate),
            dataset_split_fn=dataset_split_fn)

    # Shared args, useful to support more tasks.
    shared_args = utils_impl.lookup_flag_values(shared_flags)
    shared_args['iterative_process_builder'] = iterative_process_builder
    shared_args[
        'evaluation_computation_builder'] = evaluation_computation_builder

    task_args = _get_task_args()
    _write_hparam_flags()

    if FLAGS.task in ['stackoverflow_nwp', 'stackoverflow_nwp_finetune']:
        run_federated_fn = federated_stackoverflow.run_federated
    elif FLAGS.task == 'movielens_mf':
        run_federated_fn = federated_movielens.run_federated
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))

    run_federated_fn(**shared_args, **task_args)
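
Note: `optimizer_utils.create_optimizer_fn_from_flags` (as used above) returns
a function that takes a learning rate and builds an optimizer; the code then
pins the rate with a lambda or `functools.partial` so downstream builders can
call it with no arguments. A standalone sketch of the pattern, with a plain
Keras SGD standing in for the flag-driven factory (names hypothetical):

import functools

import tensorflow as tf


def optimizer_fn(learning_rate):
    # Stand-in for the function returned by create_optimizer_fn_from_flags.
    return tf.keras.optimizers.SGD(learning_rate=learning_rate)


server_optimizer_fn = functools.partial(optimizer_fn, 0.1)
optimizer = server_optimizer_fn()  # SGD with learning rate 0.1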