Example #1
def validator(
  model_fn: MODEL_FN,
  client_state_fn: CLIENT_STATE_FN
):
  model = model_fn()
  client_state = client_state_fn()

  dataset_type = tff.SequenceType(model.input_spec)
  client_state_type = tff.framework.type_from_tensors(client_state)

  validate_client_tf = tff.tf_computation(
    lambda dataset, state: __validate_client(
      dataset,
      state,
      model_fn,
      tf.function(client.validate)
    ),
    (dataset_type, client_state_type)
  )

  federated_dataset_type = tff.type_at_clients(dataset_type)
  federated_client_state_type = tff.type_at_clients(client_state_type)    

  def validate(datasets, client_states):
    outputs = tff.federated_map(validate_client_tf, (datasets, client_states))
    metrics = model.federated_output_computation(outputs.metrics)

    return metrics

  return tff.federated_computation(
    validate,
    (federated_dataset_type, federated_client_state_type)
  )
Example #2
def evaluator(coefficient_fn: COEFFICIENT_FN, model_fn: MODEL_FN,
              client_state_fn: CLIENT_STATE_FN):
    model = model_fn()
    client_state = client_state_fn()

    dataset_type = tff.SequenceType(model.input_spec)
    client_state_type = tff.framework.type_from_tensors(client_state)
    weights_type = tff.framework.type_from_tensors(
        tff.learning.ModelWeights.from_model(model))

    evaluate_client_tf = tff.tf_computation(
        lambda dataset, state, weights: __evaluate_client(
            dataset, state, weights, coefficient_fn, model_fn,
            tf.function(client.evaluate)),
        (dataset_type, client_state_type, weights_type))

    federated_weights_type = tff.type_at_server(weights_type)
    federated_dataset_type = tff.type_at_clients(dataset_type)
    federated_client_state_type = tff.type_at_clients(client_state_type)

    def evaluate(weights, datasets, client_states):
        broadcast = tff.federated_broadcast(weights)
        outputs = tff.federated_map(evaluate_client_tf,
                                    (datasets, client_states, broadcast))

        confusion_matrix = tff.federated_sum(outputs.confusion_matrix)
        aggregated_metrics = model.federated_output_computation(
            outputs.metrics)
        collected_metrics = tff.federated_collect(outputs.metrics)

        return confusion_matrix, aggregated_metrics, collected_metrics

    return tff.federated_computation(
        evaluate, (federated_weights_type, federated_dataset_type,
                   federated_client_state_type))
Example #3
def validator(coefficient_fn: COEFFICIENT_FN, model_fn: MODEL_FN,
              client_state_fn: CLIENT_STATE_FN):
    model = model_fn()
    client_state = client_state_fn()

    dataset_type = tff.SequenceType(model.input_spec)
    client_state_type = tff.framework.type_from_tensors(client_state)
    weights_type = tff.learning.framework.weights_type_from_model(model)

    validate_client_tf = tff.tf_computation(
        lambda dataset, state, weights: __validate_client(
            dataset, state, weights, coefficient_fn, model_fn,
            tf.function(client.validate)),
        (dataset_type, client_state_type, weights_type))

    federated_weights_type = tff.type_at_server(weights_type)
    federated_dataset_type = tff.type_at_clients(dataset_type)
    federated_client_state_type = tff.type_at_clients(client_state_type)

    def validate(weights, datasets, client_states):
        broadcast = tff.federated_broadcast(weights)
        outputs = tff.federated_map(validate_client_tf,
                                    (datasets, client_states, broadcast))
        metrics = model.federated_output_computation(outputs.metrics)

        return metrics

    return tff.federated_computation(
        validate, (federated_weights_type, federated_dataset_type,
                   federated_client_state_type))
Example #4
def iterator(coefficient_fn: COEFFICIENT_FN, model_fn: MODEL_FN,
             client_state_fn: CLIENT_STATE_FN,
             server_optimizer_fn: OPTIMIZER_FN,
             client_optimizer_fn: OPTIMIZER_FN):
    model = model_fn()
    client_state = client_state_fn()

    init_tf = tff.tf_computation(
        lambda: __initialize_server(model_fn, server_optimizer_fn))

    server_state_type = init_tf.type_signature.result
    client_state_type = tff.framework.type_from_tensors(client_state)

    update_server_tf = tff.tf_computation(
        lambda state, weights_delta: __update_server(
            state, weights_delta, model_fn, server_optimizer_fn,
            tf.function(server.update)),
        (server_state_type, server_state_type.model.trainable))

    state_to_message_tf = tff.tf_computation(
        lambda state: __state_to_message(state, tf.function(server.to_message)),
        server_state_type)

    dataset_type = tff.SequenceType(model.input_spec)
    server_message_type = state_to_message_tf.type_signature.result

    update_client_tf = tff.tf_computation(
        lambda dataset, state, message: __update_client(
            dataset, state, message, coefficient_fn, model_fn,
            client_optimizer_fn, tf.function(client.update)),
        (dataset_type, client_state_type, server_message_type))

    federated_server_state_type = tff.type_at_server(server_state_type)
    federated_dataset_type = tff.type_at_clients(dataset_type)
    federated_client_state_type = tff.type_at_clients(client_state_type)

    def init_tff():
        return tff.federated_value(init_tf(), tff.SERVER)

    def next_tff(server_state, datasets, client_states):
        message = tff.federated_map(state_to_message_tf, server_state)
        broadcast = tff.federated_broadcast(message)

        outputs = tff.federated_map(update_client_tf,
                                    (datasets, client_states, broadcast))
        weights_delta = tff.federated_mean(outputs.weights_delta,
                                           weight=outputs.client_weight)

        metrics = model.federated_output_computation(outputs.metrics)

        next_state = tff.federated_map(update_server_tf,
                                       (server_state, weights_delta))

        return next_state, metrics, outputs.client_state

    return tff.templates.IterativeProcess(
        initialize_fn=tff.federated_computation(init_tff),
        next_fn=tff.federated_computation(
            next_tff, (federated_server_state_type, federated_dataset_type,
                       federated_client_state_type)))
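A minimal driving sketch for the iterative process above, assuming the default TFF simulation context; `coefficient_fn`, `model_fn`, `client_state_fn`, `client_datasets` and `client_states` are hypothetical names defined elsewhere:

process = iterator(coefficient_fn, model_fn, client_state_fn,
                   lambda: tf.keras.optimizers.SGD(learning_rate=1.0),  # server optimizer
                   lambda: tf.keras.optimizers.SGD(learning_rate=0.1))  # client optimizer
server_state = process.initialize()
# One round: next_tff returns (next_state, metrics, updated client states).
server_state, metrics, client_states = process.next(
    server_state, client_datasets, client_states)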
Example #5
    def create(self, value_type, weight_type):
        @tff.federated_computation()
        def init_fn():
            return tff.federated_value((), tff.SERVER)

        @tff.tf_computation(tf.float32, value_type, value_type)
        def update_weight_fn(weight, server_model, client_model):
            sqnorms = tf.nest.map_structure(lambda a, b: tf.norm(a - b)**2,
                                            server_model, client_model)
            sqnorm = tf.reduce_sum(sqnorms)
            return tf.math.divide_no_nan(
                weight, tf.math.maximum(self._tolerance, tf.math.sqrt(sqnorm)))

        @tff.federated_computation(init_fn.type_signature.result,
                                   tff.type_at_clients(value_type),
                                   tff.type_at_clients(weight_type))
        def next_fn(state, value, weight):
            aggregate = tff.federated_mean(value, weight=weight)
            for _ in range(self._num_communication_passes - 1):
                aggregate_at_client = tff.federated_broadcast(aggregate)
                updated_weight = tff.federated_map(
                    update_weight_fn, (weight, aggregate_at_client, value))
                aggregate = tff.federated_mean(value, weight=updated_weight)
            no_metrics = tff.federated_value((), tff.SERVER)
            return tff.templates.MeasuredProcessOutput(state, aggregate,
                                                       no_metrics)

        return tff.templates.AggregationProcess(init_fn, next_fn)
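A sketch of driving the returned aggregation process in simulation, assuming `factory` is an instance of the surrounding class (with `_tolerance` and `_num_communication_passes` set) and the default TFF execution context; the types and client data are illustrative:

process = factory.create(tff.to_type((tf.float32, [2])), tff.to_type(tf.float32))
state = process.initialize()
# Two clients each contribute a float32[2] value and a scalar weight; the result
# is the iteratively re-weighted mean placed at the server.
output = process.next(state, [[1.0, 2.0], [3.0, 4.0]], [1.0, 2.0])
weighted_mean = output.result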
Example #6
    def __attrs_post_init__(self):
        self.gen_input_type = tensor_spec_for_batch(self.dummy_gen_input)
        self.real_data_type = tensor_spec_for_batch(self.dummy_real_data)

        # Model-weights based types
        self._generator = self.generator_model_fn()
        _ = self._generator(self.dummy_gen_input)
        if not isinstance(self._generator, tf.keras.models.Model):
            raise TypeError(
                'Expected `tf.keras.models.Model`, found {}.'.format(
                    type(self._generator)))
        self._discriminator = self.discriminator_model_fn()
        _ = self._discriminator(self.dummy_real_data)
        if not isinstance(self._discriminator, tf.keras.models.Model):
            raise TypeError(
                'Expected `tf.keras.models.Model`, found {}.'.format(
                    type(self._discriminator)))

        def vars_to_type(var_struct):
            # TODO(b/131681951): read_value() shouldn't be needed
            return tf.nest.map_structure(
                lambda v: tf.TensorSpec.from_tensor(v.read_value()),
                var_struct)

        self.discriminator_weights_type = vars_to_type(
            self._discriminator.weights)
        self.generator_weights_type = vars_to_type(self._generator.weights)

        self.from_server_type = gan_training_tf_fns.FromServer(
            generator_weights=self.generator_weights_type,
            discriminator_weights=self.discriminator_weights_type,
            meta_gen=self.generator_weights_type,
            meta_disc=self.discriminator_weights_type)

        self.client_gen_input_type = tff.type_at_clients(
            tff.SequenceType(self.gen_input_type))
        self.client_real_data_type = tff.type_at_clients(
            tff.SequenceType(self.real_data_type))
        self.server_gen_input_type = tff.type_at_server(
            tff.SequenceType(self.gen_input_type))

        if self.train_discriminator_dp_average_query is not None:
            self.aggregation_process = tff.aggregators.DifferentiallyPrivateFactory(
                query=self.train_discriminator_dp_average_query).create(
                    value_type=tff.to_type(self.discriminator_weights_type))
        else:
            self.aggregation_process = tff.aggregators.MeanFactory().create(
                value_type=tff.to_type(self.discriminator_weights_type),
                weight_type=tff.to_type(tf.float32))
Example #7
    def __attrs_post_init__(self):
        self.gen_input_type = tensor_spec_for_batch(self.dummy_gen_input)
        self.real_data_type = tensor_spec_for_batch(self.dummy_real_data)

        # Model-weights based types
        self._generator = self.generator_model_fn()
        _ = self._generator(self.dummy_gen_input)
        if not isinstance(self._generator, tf.keras.models.Model):
            raise TypeError(
                'Expected `tf.keras.models.Model`, found {}.'.format(
                    type(self._generator)))
        self._discriminator = self.discriminator_model_fn()
        _ = self._discriminator(self.dummy_real_data)
        if not isinstance(self._discriminator, tf.keras.models.Model):
            raise TypeError(
                'Expected `tf.keras.models.Model`, found {}.'.format(
                    type(self._discriminator)))

        def vars_to_type(var_struct):
            # TODO(b/131681951): read_value() shouldn't be needed
            return tf.nest.map_structure(
                lambda v: tf.TensorSpec.from_tensor(v.read_value()),
                var_struct)

        self.discriminator_weights_type = vars_to_type(
            self._discriminator.weights)
        self.generator_weights_type = vars_to_type(self._generator.weights)

        self.from_server_type = gan_training_tf_fns.FromServer(
            generator_weights=self.generator_weights_type,
            discriminator_weights=self.discriminator_weights_type)

        self.client_gen_input_type = tff.type_at_clients(
            tff.SequenceType(self.gen_input_type))
        self.client_real_data_type = tff.type_at_clients(
            tff.SequenceType(self.real_data_type))
        self.server_gen_input_type = tff.type_at_server(
            tff.SequenceType(self.gen_input_type))

        # Right now, the logic in this library is effectively "if DP use stateful
        # aggregator, else don't use stateful aggregator". An alternative
        # formulation would be to always use a stateful aggregator, but when not
        # using DP default the aggregator to be a stateless mean, e.g.,
        # https://github.com/tensorflow/federated/blob/master/tensorflow_federated/python/learning/framework/optimizer_utils.py#L283.
        if self.train_discriminator_dp_average_query is not None:
            self.dp_averaging_fn = tff.utils.build_dp_aggregate_process(
                value_type=tff.to_type(self.discriminator_weights_type),
                query=self.train_discriminator_dp_average_query)
Example #8
    def test_executes_dataset_concat_aggregation(self):

        tensor_spec = tf.TensorSpec(shape=[2], dtype=tf.float32)

        @tff.tf_computation
        def create_empty_ds():
            empty_tensor = tf.zeros(shape=[0] + tensor_spec.shape,
                                    dtype=tensor_spec.dtype)
            return tf.data.Dataset.from_tensor_slices(empty_tensor)

        @tff.tf_computation
        def concat_datasets(ds1, ds2):
            return ds1.concatenate(ds2)

        @tff.tf_computation
        def identity(ds):
            return ds

        @tff.federated_computation(
            tff.type_at_clients(tff.SequenceType(tensor_spec)))
        def do_a_federated_aggregate(client_ds):
            return tff.federated_aggregate(value=client_ds,
                                           zero=create_empty_ds(),
                                           accumulate=concat_datasets,
                                           merge=concat_datasets,
                                           report=identity)

        input_data = tf.data.Dataset.from_tensor_slices([[0.1, 0.2]])
        ds = do_a_federated_aggregate([input_data])
        self.assertIsInstance(ds, tf.data.Dataset)
Example #9
    def test_bad_type_coercion_raises(self):
        tensor_type = tff.TensorType(shape=[None], dtype=tf.float32)

        @tff.tf_computation(tensor_type)
        def foo(x):
            # We will pass in a tensor which passes the TFF type check, but fails the
            # reshape.
            return tf.reshape(x, [])

        @tff.federated_computation(tff.type_at_clients(tensor_type))
        def map_foo_at_clients(x):
            return tff.federated_map(foo, x)

        @tff.federated_computation(tff.type_at_server(tensor_type))
        def map_foo_at_server(x):
            return tff.federated_map(foo, x)

        bad_tensor = tf.constant([1.] * 10, dtype=tf.float32)
        good_tensor = tf.constant([1.], dtype=tf.float32)
        # Ensure running this computation at both placements, or unplaced, still
        # raises.
        with self.assertRaises(Exception):
            foo(bad_tensor)
        with self.assertRaises(Exception):
            map_foo_at_server(bad_tensor)
        with self.assertRaises(Exception):
            map_foo_at_clients([bad_tensor] * 10)
        # We give the distributed runtime a chance to clean itself up, otherwise
        # workers may be getting SIGABRT while they are handling another exception,
        # causing the test infra to crash. Making a successful call ensures that
        # cleanup happens after failures have been handled.
        map_foo_at_clients([good_tensor] * 10)
Example #10
    def test_clip_type_properties_simple(self, value_type):
        factory = _clipped_sum()
        value_type = tff.to_type(value_type)
        process = factory.create(value_type)

        self.assertIsInstance(process, tff.templates.AggregationProcess)

        server_state_type = tff.type_at_server(
            ())  # Inner SumFactory has no state

        expected_initialize_type = tff.FunctionType(parameter=None,
                                                    result=server_state_type)
        self.assertTrue(
            process.initialize.type_signature.is_equivalent_to(
                expected_initialize_type))

        expected_measurements_type = tff.type_at_server(
            collections.OrderedDict(agg_process=()))

        expected_next_type = tff.FunctionType(
            parameter=collections.OrderedDict(
                state=server_state_type,
                value=tff.type_at_clients(value_type)),
            result=tff.templates.MeasuredProcessOutput(
                state=server_state_type,
                result=tff.type_at_server(value_type),
                measurements=expected_measurements_type))
        self.assertTrue(
            process.next.type_signature.is_equivalent_to(expected_next_type))
Example #11
    def test_computations_run_with_worker_restarts(self, context,
                                                   first_contexts,
                                                   second_contexts):
        @tff.tf_computation(tf.int32)
        def add_one(x):
            return x + 1

        @tff.federated_computation(tff.type_at_clients(tf.int32))
        def map_add_one(federated_arg):
            return tff.federated_map(add_one, federated_arg)

        context_stack = tff.framework.get_context_stack()
        with context_stack.install(context):

            with contextlib.ExitStack() as stack:
                for server_context in first_contexts:
                    stack.enter_context(server_context)
                result = map_add_one([0, 1])
                self.assertEqual(result, [1, 2])

            # Closing and re-entering the server contexts serves to simulate failures
            # and restarts at the workers. Restarts leave the workers in a state that
            # needs initialization again; entering the second context ensures that the
            # servers need to be reinitialized by the controller.
            with contextlib.ExitStack() as stack:
                for server_context in second_contexts:
                    stack.enter_context(server_context)
                result = map_add_one([0, 1])
                self.assertEqual(result, [1, 2])
Example #12
  def test_computations_run_with_worker_restarts_and_aggregation(
      self, context, aggregation_contexts, first_worker_contexts,
      second_worker_contexts):

    @tff.tf_computation(tf.int32)
    def add_one(x):
      return x + 1

    @tff.federated_computation(tff.type_at_clients(tf.int32))
    def map_add_one(federated_arg):
      return tff.federated_map(add_one, federated_arg)

    context_stack = tff.framework.get_context_stack()
    with context_stack.install(context):

      with contextlib.ExitStack() as aggregation_stack:
        for server_context in aggregation_contexts:
          aggregation_stack.enter_context(server_context)
        with contextlib.ExitStack() as first_worker_stack:
          for server_context in first_worker_contexts:
            first_worker_stack.enter_context(server_context)

          result = map_add_one([0, 1])
          self.assertEqual(result, [1, 2])

        # Reinitializing the workers without leaving the aggregation context
        # simulates a worker failure, while the aggregator keeps running.
        with contextlib.ExitStack() as second_worker_stack:
          for server_context in second_worker_contexts:
            second_worker_stack.enter_context(server_context)
          result = map_add_one([0, 1])
          self.assertEqual(result, [1, 2])
Example #13
    def test_federated_zip(self):
        @tff.federated_computation([tff.type_at_clients(tf.int32)] * 2)
        def foo(x):
            return tff.federated_zip(x)

        result = foo([[1, 2], [3, 4]])
        self.assertIsNotNone(result)
Example #14
  def create(self, value_type: tff.Type) -> tff.templates.AggregationProcess:
    self._dp_sum_process = self._dp_sum.create(value_type)

    @tff.federated_computation()
    def init():
      # Invoke here to instantiate anything we need
      return self._dp_sum_process.initialize()

    @tff.tf_computation(value_type, tf.int32)
    def div(x, y):
      # Opaque shape manipulations
      return [tf.squeeze(tf.math.divide_no_nan(x, tf.cast(y, tf.float32)), 0)]

    @tff.federated_computation(init.type_signature.result,
                               tff.type_at_clients(value_type))
    def next_fn(state, value):
      one_at_clients = tff.federated_value(1, tff.CLIENTS)
      dp_sum = self._dp_sum_process.next(state, value)
      summed_one = tff.federated_sum(one_at_clients)
      return tff.templates.MeasuredProcessOutput(
          state=dp_sum.state,
          result=tff.federated_map(div, (dp_sum.result, summed_one)),
          measurements=dp_sum.measurements)

    return tff.templates.AggregationProcess(initialize_fn=init, next_fn=next_fn)
Example #15
    def test_executes_empty_sum(self):
        @tff.federated_computation(tff.type_at_clients(tf.int32))
        def fed_sum(x):
            return tff.federated_sum(x)

        result = fed_sum([])
        self.assertEqual(result, 0)
Example #16
    def test_computations_run_with_changing_clients(self, context,
                                                    server_contexts):
        self.skipTest('b/175155128')

        @tff.tf_computation(tf.int32)
        @tf.function
        def add_one(x):
            return x + 1

        @tff.federated_computation(tff.type_at_clients(tf.int32))
        def map_add_one(federated_arg):
            return tff.federated_map(add_one, federated_arg)

        context_stack = tff.framework.get_context_stack()
        with context_stack.install(context):

            with contextlib.ExitStack() as stack:
                for server_context in server_contexts:
                    stack.enter_context(server_context)
                result_two_clients = map_add_one([0, 1])
                self.assertEqual(result_two_clients, [1, 2])
                # Moving to three clients should be fine
                result_three_clients = map_add_one([0, 1, 2])
                # Running a 0-client function should also be OK
                self.assertEqual(add_one(0), 1)
                self.assertEqual(result_three_clients, [1, 2, 3])
                # Changing back to 2 clients should still succeed.
                second_result_two_clients = map_add_one([0, 1])
                self.assertEqual(second_result_two_clients, [1, 2])
                # Similarly, 3 clients again should be fine.
                second_result_three_clients = map_add_one([0, 1, 2])
                self.assertEqual(second_result_three_clients, [1, 2, 3])
Example #17
def federated_output_computation_from_metrics(
        metrics: List[tf.keras.metrics.Metric]) -> tff.federated_computation:
    """Produces a federated computation for aggregating Keras metrics.

  This can be used to evaluate both Keras and non-Keras models using Keras
  metrics. Aggregates metrics across clients by summing their internal
  variables, producing new metrics with summed internal variables, and calling
  metric.result() on each. See `federated_aggregate_keras_metric` for details.

  Args:
    metrics: A List of `tf.keras.metrics.Metric` to aggregate.

  Returns:
    A `tff.federated_computation` aggregating metrics across clients by summing
    their internal variables, producing new metrics with summed internal
    variables, and calling metric.result() on each.
  """
    # Get a sample of metric variables to use to determine its type.
    sample_metric_variables = read_metric_variables(metrics)

    metric_variable_type_dict = tf.nest.map_structure(
        tf.TensorSpec.from_tensor, sample_metric_variables)
    federated_local_outputs_type = tff.type_at_clients(
        metric_variable_type_dict)

    def federated_output(local_outputs):
        return federated_aggregate_keras_metric(metrics, local_outputs)

    federated_output_computation = tff.federated_computation(
        federated_output, federated_local_outputs_type)
    return federated_output_computation
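A construction-only usage sketch; the metric choice is illustrative, and the helpers referenced above (`read_metric_variables`, `federated_aggregate_keras_metric`) are assumed to be importable from the same module:

metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
aggregator = federated_output_computation_from_metrics(metrics)
# Maps metric variables placed at CLIENTS to aggregated metric results at SERVER.
print(aggregator.type_signature)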
Example #18
def iterator(
  model_fn: MODEL_FN,
  client_state_fn: CLIENT_STATE_FN,
  client_optimizer_fn: OPTIMIZER_FN
):
  model = model_fn()
  client_state = client_state_fn()

  init_tf = tff.tf_computation(
    lambda: ()
  )
  
  server_state_type = init_tf.type_signature.result
  client_state_type = tff.framework.type_from_tensors(client_state)
  dataset_type = tff.SequenceType(model.input_spec)
  
  update_client_tf = tff.tf_computation(
    lambda dataset, state: __update_client(
      dataset,
      state,
      model_fn,
      client_optimizer_fn,
      tf.function(client.update)
    ),
    (dataset_type, client_state_type)
  )
  
  federated_server_state_type = tff.type_at_server(server_state_type)
  federated_dataset_type = tff.type_at_clients(dataset_type)
  federated_client_state_type = tff.type_at_clients(client_state_type)

  def init_tff():
    return tff.federated_value(init_tf(), tff.SERVER)
  
  def next_tff(server_state, datasets, client_states):
    outputs = tff.federated_map(update_client_tf, (datasets, client_states))
    metrics = model.federated_output_computation(outputs.metrics)

    return server_state, metrics, outputs.client_state

  return tff.templates.IterativeProcess(
    initialize_fn=tff.federated_computation(init_tff),
    next_fn=tff.federated_computation(
      next_tff,
      (federated_server_state_type, federated_dataset_type, federated_client_state_type)
    )
  )
Example #19
def build_federated_evaluation(
    model_fn: Callable[[], tff.learning.Model],
    metrics_builder: Callable[[], Sequence[tf.keras.metrics.Metric]]
) -> tff.federated_computation:
    """Builds a federated evaluation `tff.federated_computation`.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    metrics_builder: A no-arg function that returns a sequence of
      `tf.keras.metrics.Metric` objects. These metrics must have a callable
      `update_state` accepting `y_true` and `y_pred` arguments, corresponding to
      the true and predicted label, respectively.

  Returns:
    A `tff.federated_computation` that accepts model weights and federated data,
    and returns the evaluation metrics, aggregated in both uniform- and
    example-weighted manners.
  """
    # Wrap model construction in a graph to avoid polluting the global context
    # with variables created for this model.
    with tf.Graph().as_default():
        placeholder_model = model_fn()
        model_weights_type = tff.learning.framework.weights_type_from_model(
            placeholder_model)
        model_input_type = tff.SequenceType(placeholder_model.input_spec)

    @tff.tf_computation(model_weights_type, model_input_type)
    def compute_client_metrics(model_weights, federated_dataset):
        model = model_fn()
        metrics = metrics_builder()
        return compute_metrics(model, model_weights, metrics,
                               federated_dataset)

    @tff.federated_computation(tff.type_at_server(model_weights_type),
                               tff.type_at_clients(model_input_type))
    def federated_evaluate(model_weights, federated_dataset):
        client_model = tff.federated_broadcast(model_weights)
        client_metrics = tff.federated_map(compute_client_metrics,
                                           (client_model, federated_dataset))
        # Extract the number of examples in order to compute client weights
        num_examples = client_metrics.num_examples
        uniform_weighted_metrics = tff.federated_mean(client_metrics,
                                                      weight=None)
        example_weighted_metrics = tff.federated_mean(client_metrics,
                                                      weight=num_examples)
        # Aggregate the metrics in a single nested dictionary
        aggregate_metrics = collections.OrderedDict()
        aggregate_metrics[
            AggregationMethods.EXAMPLE_WEIGHTED.value] = example_weighted_metrics
        aggregate_metrics[
            AggregationMethods.UNIFORM_WEIGHTED.value] = uniform_weighted_metrics

        return aggregate_metrics

    return federated_evaluate
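A hypothetical wiring sketch for the builder above; `create_keras_model`, `element_spec`, `model_weights` and `client_datasets` are placeholders, not names defined in the snippet:

def model_fn():
    # Any tff.learning.Model works; a Keras wrapper is shown only as an illustration.
    return tff.learning.from_keras_model(
        create_keras_model(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        input_spec=element_spec)

evaluate = build_federated_evaluation(
    model_fn, lambda: [tf.keras.metrics.SparseCategoricalAccuracy()])
# metrics = evaluate(model_weights, client_datasets)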
Example #20
def build_federated_averaging_process_attacked(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    aggregation_process=None,
    client_update_tf=ClientExplicitBoosting(boost_factor=1.0)):
    """Builds the TFF computations for optimization using federated averaging with potentially malicious clients.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`, used during local client training.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`, used to apply updates to the global model.
    aggregation_process: A `tff.templates.MeasuredProcess` that aggregates model
      deltas placed@CLIENTS to an aggregated model delta placed@SERVER.
    client_update_tf: A `tf.function` that computes the `ClientOutput`.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
    with tf.Graph().as_default():
        dummy_model_for_metadata = model_fn()
        weights_type = tff.learning.framework.weights_type_from_model(
            dummy_model_for_metadata)

    if aggregation_process is None:
        aggregation_process = tff.learning.framework.build_stateless_mean(
            model_delta_type=weights_type.trainable)

    server_init = build_server_init_fn(model_fn, server_optimizer_fn,
                                       aggregation_process.initialize)
    server_state_type = server_init.type_signature.result.member
    server_update_fn = build_server_update_fn(model_fn, server_optimizer_fn,
                                              server_state_type,
                                              server_state_type.model)
    tf_dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)

    client_update_fn = build_client_update_fn(model_fn, client_optimizer_fn,
                                              client_update_tf,
                                              tf_dataset_type,
                                              server_state_type.model)

    federated_server_state_type = tff.type_at_server(server_state_type)

    federated_dataset_type = tff.type_at_clients(tf_dataset_type)

    run_one_round_tff = build_run_one_round_fn_attacked(
        server_update_fn, client_update_fn, aggregation_process,
        dummy_model_for_metadata, federated_server_state_type,
        federated_dataset_type)

    return tff.templates.IterativeProcess(initialize_fn=server_init,
                                          next_fn=run_one_round_tff)
Example #21
    def test_empty_mean_returns_nan(self):
        self.skipTest('b/200970992')
        # TODO(b/200970992): Standardize handling of this case. We currently have a
        # ZeroDivisionError, a RuntimeError, and a context that returns nan.

        @tff.federated_computation(tff.type_at_clients(tf.float32))
        def fed_mean(x):
            return tff.federated_mean(x)

        with self.assertRaises(RuntimeError):
            fed_mean([])
Example #22
    def create(self, value_type, weight_type):
        @tff.federated_computation()
        def initialize_fn():
            # state = AggregationState(self._num_participants)
            return tff.federated_value(self._num_participants, tff.SERVER)

        @tff.federated_computation(initialize_fn.type_signature.result,
                                   tff.type_at_clients(value_type),
                                   tff.type_at_clients(weight_type))
        def next_fn(state, value, weight):
            weighted_values = tff.federated_map(_mul, (value, weight))
            summed_value = tff.federated_sum(weighted_values)
            normalized_value = tff.federated_map(_div, (summed_value, state))
            measurements = tff.federated_value((), tff.SERVER)
            return tff.templates.MeasuredProcessOutput(
                state=state,
                result=normalized_value,
                measurements=measurements)

        return tff.templates.AggregationProcess(initialize_fn, next_fn)
Example #23
  def test_federated_zip_with_twenty_elements(self):
    # This test will fail if execution scales factorially with number of
    # elements zipped.
    num_element = 20
    num_clients = 2

    @tff.federated_computation([tff.type_at_clients(tf.int32)] * num_element)
    def foo(x):
      return tff.federated_zip(x)

    value = [list(range(num_clients))] * num_element
    result = foo(value)
    self.assertIsNotNone(result)
Example #24
    def test_repeated_invocations_of_map(self):
        @tff.tf_computation(tf.int32)
        def add_one(x):
            return x + 1

        @tff.federated_computation(tff.type_at_clients(tf.int32))
        def map_add_one(federated_arg):
            return tff.federated_map(add_one, federated_arg)

        result1 = map_add_one([0, 1])
        result2 = map_add_one([0, 1])

        self.assertIsNotNone(result1)
        self.assertEqual(result1, result2)
Example #25
def evaluator(
  model_fn: MODEL_FN,
  client_state_fn: CLIENT_STATE_FN
):
  model = model_fn()
  client_state = client_state_fn()

  dataset_type = tff.SequenceType(model.input_spec)
  client_state_type = tff.framework.type_from_tensors(client_state)

  evaluate_client_tf = tff.tf_computation(
    lambda dataset, state: __evaluate_client(
      dataset,
      state,
      model_fn,
      tf.function(client.evaluate)
    ),
    (dataset_type, client_state_type)
  )

  federated_dataset_type = tff.type_at_clients(dataset_type)
  federated_client_state_type = tff.type_at_clients(client_state_type)    

  def evaluate(datasets, client_states):
    outputs = tff.federated_map(evaluate_client_tf, (datasets, client_states))
    
    confusion_matrix = tff.federated_sum(outputs.confusion_matrix)
    aggregated_metrics = model.federated_output_computation(outputs.metrics)
    collected_metrics = tff.federated_collect(outputs.metrics)

    return confusion_matrix, aggregated_metrics, collected_metrics

  return tff.federated_computation(
    evaluate,
    (federated_dataset_type, federated_client_state_type)
  )
Example #26
def build_federated_averaging_process(
    model_fn,
    client_optimizer_fn,
    server_optimizer_fn=lambda: flars_optimizer.FLARSOptimizer(
        learning_rate=1.0)):
    """Builds the TFF computations for optimization using federated averaging.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` for the local client training.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` for applying updates on the server.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
    with tf.Graph().as_default():
        dummy_model_for_metadata = model_fn()
    type_signature_grads_norm = tuple(
        weight.dtype for weight in tf.nest.flatten(
            dummy_model_for_metadata.trainable_variables))

    server_init_tf = build_server_init_fn(model_fn, server_optimizer_fn)

    server_state_type = server_init_tf.type_signature.result
    server_update_fn = build_server_update_fn(model_fn, server_optimizer_fn,
                                              server_state_type,
                                              server_state_type.model,
                                              type_signature_grads_norm)

    tf_dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)
    client_update_fn = build_client_update_fn(model_fn, client_optimizer_fn,
                                              tf_dataset_type,
                                              server_state_type.model)

    federated_server_state_type = tff.type_at_server(server_state_type)
    federated_dataset_type = tff.type_at_clients(tf_dataset_type)
    run_one_round_tff = build_run_one_round_fn(server_update_fn,
                                               client_update_fn,
                                               dummy_model_for_metadata,
                                               federated_server_state_type,
                                               federated_dataset_type)

    return tff.templates.IterativeProcess(
        initialize_fn=tff.federated_computation(
            lambda: tff.federated_eval(server_init_tf, tff.SERVER)),
        next_fn=run_one_round_tff)
Example #27
    def test_build_with_preprocess_function(self):
        test_dataset = tf.data.Dataset.range(5)
        client_datasets_type = tff.type_at_clients(
            tff.SequenceType(test_dataset.element_spec))

        @tff.tf_computation(tff.SequenceType(test_dataset.element_spec))
        def preprocess_dataset(ds):
            def to_batch(x):
                return _Batch(
                    tf.fill(dims=(784, ), value=float(x) * 2.0),
                    tf.expand_dims(tf.cast(x + 1, dtype=tf.int64), axis=0))

            return ds.map(to_batch).batch(2)

        iterproc = fed_avg_schedule.build_fed_avg_process(
            _uncompiled_model_builder,
            TAU,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        iterproc = tff.simulation.compose_dataset_computation_with_iterative_process(
            preprocess_dataset, iterproc)

        with tf.Graph().as_default():
            test_model_for_types = _uncompiled_model_builder()

        server_state_type = tff.FederatedType(
            fed_avg_schedule.ServerState(
                model=tff.framework.type_from_tensors(
                    tff.learning.ModelWeights(
                        test_model_for_types.trainable_variables,
                        test_model_for_types.non_trainable_variables)),
                optimizer_state=(tf.int64,),
                round_num=tf.float32,
                client_drift=tf.float32),
            tff.SERVER)
        metrics_type = test_model_for_types.federated_output_computation.type_signature.result

        expected_parameter_type = collections.OrderedDict(
            server_state=server_state_type,
            federated_dataset=client_datasets_type,
        )
        expected_result_type = (server_state_type, metrics_type)

        expected_type = tff.FunctionType(parameter=expected_parameter_type,
                                         result=expected_result_type)
        self.assertTrue(
            iterproc.next.type_signature.is_equivalent_to(expected_type),
            msg='{s}\n!={t}'.format(s=iterproc.next.type_signature,
                                    t=expected_type))
Example #28
def get_federated_tokenize_fn(dataset_name, dataset_element_type_structure):
    """Get a federated tokenizer function."""
    @tff.tf_computation(tff.SequenceType(dataset_element_type_structure))
    def tokenize_dataset(dataset):
        """The TF computation to tokenize a dataset."""
        dataset = tokenize(dataset, dataset_name)
        return dataset

    @tff.federated_computation(
        tff.type_at_clients(tff.SequenceType(dataset_element_type_structure)))
    def tokenize_datasets(datasets):
        """The TFF computation to compute tokenized datasets."""
        tokenized_datasets = tff.federated_map(tokenize_dataset, datasets)
        return tokenized_datasets

    return tokenize_datasets
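A construction-only sketch for the tokenizer builder above; `element_spec` and `client_datasets` are hypothetical, and the dataset name must be one the surrounding `tokenize` helper understands:

tokenize_fn = get_federated_tokenize_fn('my_dataset', element_spec)
tokenized_client_datasets = tokenize_fn(client_datasets)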
Example #29
    def test_polymorphism(self):
        @tff.tf_computation(tf.int32)
        def add_one(x):
            return x + 1

        @tff.federated_computation(tff.type_at_clients(tf.int32))
        def map_add_one(federated_arg):
            return tff.federated_map(add_one, federated_arg)

        result1 = map_add_one([0, 1])
        result2 = map_add_one([0, 1, 2])

        self.assertIsNotNone(result1)
        self.assertIsNotNone(result2)

        self.assertLen(result1, 2)
        self.assertLen(result2, 3)
Example #30
    def test_federated_collect_large_numbers_of_parameters(self):
        num_clients = 10
        model_size = 10**6
        client_models = [tf.ones([model_size]) for _ in range(num_clients)]
        client_data_type = tff.type_at_clients((tf.float32, [model_size]))

        @tff.federated_computation(client_data_type)
        def comp(client_data):
            return tff.federated_collect(client_data)

        start_time_seconds = time.time()
        result = comp(client_models)
        end_time_seconds = time.time()
        runtime = end_time_seconds - start_time_seconds
        if runtime > 10:
            raise RuntimeError(
                'comp should take well under 10 seconds, but took ' +
                str(runtime))
        del result