def test_quantile_aggregation_for_num_examples(self):
    """Tests quantile aggregation of per-client example counts.

    Client ``k`` contributes ``k + 1`` examples, so the counts across the
    five clients are [1, 2, 3, 4, 5] and every requested quantile lands
    exactly on one of those values.
    """
    client_ids = [0, 1, 2, 3, 4]
    quantiles = [0, 0.25, 0.5, 0.75, 1.0]

    def create_single_value_ds(client_id):
      # Client `client_id` holds `client_id + 1` identical (x, y) examples.
      client_value = [[0.0]] * (client_id + 1)
      return tf.data.Dataset.from_tensor_slices(
          collections.OrderedDict(x=client_value, y=client_value)).batch(1)

    # Fix: loop variable renamed from `id`, which shadowed the builtin `id`.
    client_data = [create_single_value_ds(cid) for cid in client_ids]

    metrics_builder = lambda: [tf.keras.metrics.MeanSquaredError()]
    eval_fn = training_utils.build_federated_evaluate_fn(
        model_builder,
        metrics_builder,
        quantiles=quantiles)
    model_weights = tff.learning.ModelWeights.from_model(tff_model_fn())
    eval_metrics = eval_fn(model_weights, client_data)
    logging.info('Eval metrics: %s', eval_metrics)

    num_examples_quantiles = eval_metrics['num_examples']['quantiles']
    expected_quantile_values = [1.0, 2.0, 3.0, 4.0, 5.0]
    expected_quantiles = collections.OrderedDict(
        zip(quantiles, expected_quantile_values))
    self.assertAllClose(num_examples_quantiles, expected_quantiles)
  def test_eval_metrics_for_unbalanced_client_data(self):
    """Checks metric keys and values when clients hold unequal data."""
    client_data = self._create_client_data(balanced=False)
    eval_fn = training_utils.build_federated_evaluate_fn(
        model_builder, lambda: [tf.keras.metrics.MeanSquaredError()])
    model_weights = tff.learning.ModelWeights.from_model(tff_model_fn())
    eval_metrics = eval_fn(model_weights, client_data)
    logging.info('Eval metrics: %s', eval_metrics)

    self.assertCountEqual(eval_metrics.keys(),
                          ['mean_squared_error', 'num_examples'])

    num_examples = eval_metrics['num_examples']
    mse = eval_metrics['mean_squared_error']
    expected_keys = ['example_weighted', 'uniform_weighted', 'quantiles']
    self.assertCountEqual(num_examples.keys(), expected_keys)
    self.assertCountEqual(mse.keys(), expected_keys)

    # With unbalanced clients the example-weighted values differ from the
    # uniform-weighted ones, so assert each separately.
    self.assertNear(num_examples['uniform_weighted'], 2.5, err=1e-6)
    self.assertNear(num_examples['example_weighted'], 2.6, err=1e-6)

    self.assertNear(mse['uniform_weighted'], 1.5, err=1e-6)
    self.assertNear(mse['example_weighted'], 1.8, err=1e-6)
# Example 3
    def test_quantile_aggregation_for_mse(self):
        """Tests quantile aggregation of per-client mean squared error."""
        client_ids = [0, 1, 2, 3, 4]
        quantiles = [0, 0.25, 0.5, 0.75, 1.0]

        def make_client_ds(client_id):
            # Each client holds a single (x, y) example equal to its id.
            value = [[float(client_id)]]
            return tf.data.Dataset.from_tensor_slices(
                collections.OrderedDict(x=value, y=value)).batch(1)

        client_data = tff.simulation.client_data.ConcreteClientData(
            client_ids=client_ids,
            create_tf_dataset_for_client_fn=make_client_ds)

        eval_fn = training_utils.build_federated_evaluate_fn(
            client_data,
            model_builder,
            lambda: [tf.keras.metrics.MeanSquaredError()],
            clients_per_round=5,
            quantiles=quantiles)
        model_weights = tff.learning.ModelWeights.from_model(tff_model_fn())
        eval_metrics = eval_fn(model_weights, round_num=1)
        logging.info('Eval metrics: %s', eval_metrics)

        # Expected per-client MSE values are [0, 1, 4, 9, 16]; each quantile
        # falls exactly on one of them.
        expected_quantiles = collections.OrderedDict(
            zip(quantiles, [0.0, 1.0, 4.0, 9.0, 16.0]))
        self.assertEqual(eval_metrics['mean_squared_error']['quantiles'],
                         expected_quantiles)
# Example 4
    def test_eval_metrics_for_balanced_client_data(self):
        """Checks metric keys and values when clients hold equal data."""
        client_data = self._create_client_data(balanced=True)
        eval_fn = training_utils.build_federated_evaluate_fn(
            client_data,
            model_builder,
            lambda: [tf.keras.metrics.MeanSquaredError()],
            clients_per_round=2)
        model_weights = tff.learning.ModelWeights.from_model(tff_model_fn())
        eval_metrics = eval_fn(model_weights, round_num=1)
        logging.info('Eval metrics: %s', eval_metrics)

        self.assertCountEqual(eval_metrics.keys(),
                              ['mean_squared_error', 'num_examples'])

        # Sum-based metrics expose 'summed' rather than 'example_weighted'.
        num_examples = eval_metrics['num_examples']
        self.assertCountEqual(num_examples.keys(),
                              ['summed', 'uniform_weighted', 'quantiles'])
        self.assertEqual(num_examples['summed'], 6)
        self.assertEqual(num_examples['uniform_weighted'], 3.0)

        # Mean-based metrics carry both weightings; with balanced clients
        # the two averages coincide on the same value.
        mse = eval_metrics['mean_squared_error']
        self.assertCountEqual(
            mse.keys(),
            ['example_weighted', 'uniform_weighted', 'quantiles'])

        expected_mse = 0.875
        self.assertNear(mse['uniform_weighted'], expected_mse, err=1e-6)
        self.assertNear(mse['example_weighted'], expected_mse, err=1e-6)
# Example 5
def _evaluation_fn():
    """Returns an evaluation function that ignores its round number.

    The federated evaluation function and the federated data are built
    once here; the returned closure re-uses them on every call.
    """
    fed_eval_fn = training_utils.build_federated_evaluate_fn(
        model_builder=_keras_model_builder,
        metrics_builder=lambda:
        [tf.keras.metrics.SparseCategoricalAccuracy()])
    fed_data = _federated_data()

    # `round_num` is accepted for interface compatibility but unused.
    return lambda model, round_num: fed_eval_fn(model, fed_data)
# Example 6
def configure_training(task_spec: training_specs.TaskSpec,
                       eval_spec: Optional[training_specs.EvalSpec] = None,
                       model: str = 'cnn') -> training_specs.RunnerSpec:
  """Configures training for the EMNIST character recognition task.

  This method will load and pre-process datasets and construct a model used for
  the task. It then uses `iterative_process_builder` to create an iterative
  process compatible with `federated_research.utils.training_loop`.

  Args:
    task_spec: A `TaskSpec` class for creating federated training tasks.
    eval_spec: An `EvalSpec` class for configuring federated evaluation. If set
      to None, centralized evaluation is used for validation and testing
      instead.
    model: A string specifying the model used for character recognition. Can be
      one of `cnn` and `2nn`, corresponding to a CNN model and a densely
      connected 2-layer model (respectively).

  Returns:
    A `RunnerSpec` containing attributes used for running the newly created
    federated task.

  Raises:
    ValueError: If `model` is not one of the supported model strings.
  """
  # NOTE(review): the task label is 'digit_recognition' while the datasets and
  # models below use only_digits=False (full character set) — confirm this
  # pairing is intentional.
  emnist_task = 'digit_recognition'

  emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
      only_digits=False)

  train_preprocess_fn = emnist_dataset.create_preprocess_fn(
      num_epochs=task_spec.client_epochs_per_round,
      batch_size=task_spec.client_batch_size,
      emnist_task=emnist_task)

  # Element structure of a preprocessed client dataset; used below to build
  # the TFF model from the Keras model.
  input_spec = train_preprocess_fn.type_signature.result.element

  if model == 'cnn':
    model_builder = functools.partial(
        emnist_models.create_conv_dropout_model, only_digits=False)
  elif model == '2nn':
    model_builder = functools.partial(
        emnist_models.create_two_hidden_layer_model, only_digits=False)
  else:
    raise ValueError(
        'Cannot handle model flag [{!s}], must be one of {!s}.'.format(
            model, EMNIST_MODELS))

  loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
  metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

  def tff_model_fn() -> tff.learning.Model:
    # Fresh Keras model, loss, and metrics per invocation, as TFF requires.
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        input_spec=input_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

  iterative_process = task_spec.iterative_process_builder(tff_model_fn)

  # Never sample more clients per round than exist in the train split.
  clients_per_train_round = min(task_spec.clients_per_round,
                                TOTAL_NUM_TRAIN_CLIENTS)

  if hasattr(emnist_train, 'dataset_computation'):
    # Preferred path: the client data exposes a serialized dataset
    # computation, so client-id -> preprocessed-dataset mapping can be fused
    # into the iterative process and clients are sampled by id.

    @tff.tf_computation(tf.string)
    def build_train_dataset_from_client_id(client_id):
      client_dataset = emnist_train.dataset_computation(client_id)
      return train_preprocess_fn(client_dataset)

    training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
        build_train_dataset_from_client_id, iterative_process)
    client_ids_fn = training_utils.build_sample_fn(
        emnist_train.client_ids,
        size=clients_per_train_round,
        replace=False,
        random_seed=task_spec.sampling_random_seed)
    # We convert the output to a list (instead of an np.ndarray) so that it can
    # be used as input to the iterative process.
    client_sampling_fn = lambda x: list(client_ids_fn(x))

  else:
    # Fallback path: compose only the preprocessing with the iterative
    # process and sample materialized client datasets directly.
    training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
        train_preprocess_fn, iterative_process)
    client_sampling_fn = training_utils.build_client_datasets_fn(
        dataset=emnist_train,
        clients_per_round=clients_per_train_round,
        random_seed=task_spec.sampling_random_seed)

  # Preserve access to model weights on the composed process, which does not
  # automatically carry over this attribute.
  training_process.get_model_weights = iterative_process.get_model_weights

  if eval_spec:
    # Federated evaluation: sample test clients each round for validation,
    # and once (deterministically) for the final test.

    # A `None` client count means "evaluate on all test clients".
    if eval_spec.clients_per_validation_round is None:
      clients_per_validation_round = TOTAL_NUM_TEST_CLIENTS
    else:
      clients_per_validation_round = min(eval_spec.clients_per_validation_round,
                                         TOTAL_NUM_TEST_CLIENTS)

    if eval_spec.clients_per_test_round is None:
      clients_per_test_round = TOTAL_NUM_TEST_CLIENTS
    else:
      clients_per_test_round = min(eval_spec.clients_per_test_round,
                                   TOTAL_NUM_TEST_CLIENTS)

    # Evaluation uses a single pass and no shuffling.
    test_preprocess_fn = emnist_dataset.create_preprocess_fn(
        num_epochs=1,
        batch_size=eval_spec.client_batch_size,
        shuffle_buffer_size=1,
        emnist_task=emnist_task)
    emnist_test = emnist_test.preprocess(test_preprocess_fn)

    def eval_metrics_builder():
      # Track loss as a metric in addition to accuracy for evaluation.
      return [
          tf.keras.metrics.SparseCategoricalCrossentropy(),
          tf.keras.metrics.SparseCategoricalAccuracy()
      ]

    federated_eval_fn = training_utils.build_federated_evaluate_fn(
        model_builder=model_builder, metrics_builder=eval_metrics_builder)

    validation_client_sampling_fn = training_utils.build_client_datasets_fn(
        emnist_test,
        clients_per_validation_round,
        random_seed=eval_spec.sampling_random_seed)
    test_client_sampling_fn = training_utils.build_client_datasets_fn(
        emnist_test,
        clients_per_test_round,
        random_seed=eval_spec.sampling_random_seed)

    def validation_fn(model_weights, round_num):
      # Validation clients vary with the round number.
      validation_clients = validation_client_sampling_fn(round_num)
      return federated_eval_fn(model_weights, validation_clients)

    def test_fn(model_weights):
      # We fix the round number to get deterministic behavior
      test_round_num = 0
      test_clients = test_client_sampling_fn(test_round_num)
      return federated_eval_fn(model_weights, test_clients)

  else:
    # Centralized evaluation: a single test dataset, also reused (with the
    # round number discarded) as the validation function.
    _, central_emnist_test = emnist_dataset.get_centralized_datasets(
        only_digits=False, emnist_task=emnist_task)

    test_fn = training_utils.build_centralized_evaluate_fn(
        eval_dataset=central_emnist_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    validation_fn = lambda model_weights, round_num: test_fn(model_weights)

  return training_specs.RunnerSpec(
      iterative_process=training_process,
      client_datasets_fn=client_sampling_fn,
      validation_fn=validation_fn,
      test_fn=test_fn)