def validator(model_fn: MODEL_FN, client_state_fn: CLIENT_STATE_FN):
    """Builds a federated computation that runs client-local validation.

    Args:
      model_fn: No-arg callable returning the model; used here for type
        inference and re-invoked inside the traced TF computation.
      client_state_fn: No-arg callable returning a fresh client state, used
        only to derive its TFF type.

    Returns:
      A `tff.federated_computation` mapping (datasets@CLIENTS,
      client_states@CLIENTS) to the aggregated validation metrics.
    """
    placeholder_model = model_fn()
    ds_type = tff.SequenceType(placeholder_model.input_spec)
    state_type = tff.framework.type_from_tensors(client_state_fn())

    # Per-client TF computation wrapping the validation helper.
    client_validate_comp = tff.tf_computation(
        lambda dataset, state: __validate_client(
            dataset, state, model_fn, tf.function(client.validate)),
        (ds_type, state_type))

    def validate(datasets, client_states):
        client_outputs = tff.federated_map(client_validate_comp,
                                           (datasets, client_states))
        return placeholder_model.federated_output_computation(
            client_outputs.metrics)

    return tff.federated_computation(
        validate,
        (tff.type_at_clients(ds_type), tff.type_at_clients(state_type)))
def evaluator(coefficient_fn: COEFFICIENT_FN, model_fn: MODEL_FN,
              client_state_fn: CLIENT_STATE_FN):
    """Builds a federated computation for evaluating broadcast server weights.

    The returned computation broadcasts the server weights to the clients,
    runs `client.evaluate` on each client's dataset/state, and returns the
    summed confusion matrix, the aggregated metrics, and the collected
    per-client metrics.

    Args:
      coefficient_fn: Callable passed through to `__evaluate_client`.
      model_fn: No-arg callable returning the model; used here for type
        inference and re-invoked inside the traced TF computation.
      client_state_fn: No-arg callable returning a fresh client state, used
        only to derive its TFF type.

    Returns:
      A `tff.federated_computation` with signature (weights@SERVER,
      datasets@CLIENTS, client_states@CLIENTS) ->
      (confusion_matrix@SERVER, aggregated_metrics, collected_metrics).
    """
    model = model_fn()
    client_state = client_state_fn()

    dataset_type = tff.SequenceType(model.input_spec)
    client_state_type = tff.framework.type_from_tensors(client_state)
    # Consistency fix: derive the weights type with the same helper the other
    # builders in this file use, instead of hand-building it from a
    # `ModelWeights` instance; both describe the same structure.
    weights_type = tff.learning.framework.weights_type_from_model(model)

    evaluate_client_tf = tff.tf_computation(
        lambda dataset, state, weights: __evaluate_client(
            dataset, state, weights, coefficient_fn, model_fn,
            tf.function(client.evaluate)),
        (dataset_type, client_state_type, weights_type))

    federated_weights_type = tff.type_at_server(weights_type)
    federated_dataset_type = tff.type_at_clients(dataset_type)
    federated_client_state_type = tff.type_at_clients(client_state_type)

    def evaluate(weights, datasets, client_states):
        broadcast = tff.federated_broadcast(weights)
        outputs = tff.federated_map(evaluate_client_tf,
                                    (datasets, client_states, broadcast))
        confusion_matrix = tff.federated_sum(outputs.confusion_matrix)
        aggregated_metrics = model.federated_output_computation(
            outputs.metrics)
        # NOTE(review): `tff.federated_collect` is deprecated/removed in newer
        # TFF releases — confirm the pinned TFF version still provides it.
        collected_metrics = tff.federated_collect(outputs.metrics)
        return confusion_matrix, aggregated_metrics, collected_metrics

    return tff.federated_computation(
        evaluate, (federated_weights_type, federated_dataset_type,
                   federated_client_state_type))
def validator(coefficient_fn: COEFFICIENT_FN, model_fn: MODEL_FN,
              client_state_fn: CLIENT_STATE_FN):
    """Builds a federated computation validating broadcast weights on clients.

    Broadcasts the server weights, runs `client.validate` per client, and
    aggregates the resulting metrics.
    """
    placeholder_model = model_fn()
    ds_type = tff.SequenceType(placeholder_model.input_spec)
    state_type = tff.framework.type_from_tensors(client_state_fn())
    weights_type = tff.learning.framework.weights_type_from_model(
        placeholder_model)

    # Per-client TF computation wrapping the validation helper.
    client_validate_comp = tff.tf_computation(
        lambda dataset, state, weights: __validate_client(
            dataset, state, weights, coefficient_fn, model_fn,
            tf.function(client.validate)),
        (ds_type, state_type, weights_type))

    def validate(weights, datasets, client_states):
        broadcast_weights = tff.federated_broadcast(weights)
        client_outputs = tff.federated_map(
            client_validate_comp, (datasets, client_states, broadcast_weights))
        return placeholder_model.federated_output_computation(
            client_outputs.metrics)

    return tff.federated_computation(
        validate,
        (tff.type_at_server(weights_type), tff.type_at_clients(ds_type),
         tff.type_at_clients(state_type)))
def iterator(coefficient_fn: COEFFICIENT_FN, model_fn: MODEL_FN,
             client_state_fn: CLIENT_STATE_FN,
             server_optimizer_fn: OPTIMIZER_FN,
             client_optimizer_fn: OPTIMIZER_FN):
    """Builds the federated training `tff.templates.IterativeProcess`.

    Wires the server initialize/update, the state-to-message conversion, and
    the per-client update into one iterative process.

    Args:
      coefficient_fn: Passed through to `__update_client`.
      model_fn: No-arg callable returning the model; used here for type
        inference and re-invoked inside the traced TF computations.
      client_state_fn: No-arg callable returning a fresh client state; used
        only to derive its TFF type.
      server_optimizer_fn: No-arg optimizer factory for the server update.
      client_optimizer_fn: No-arg optimizer factory for the client update.

    Returns:
      A `tff.templates.IterativeProcess` whose `next` maps
      (server_state@SERVER, datasets@CLIENTS, client_states@CLIENTS) to
      (next server state, aggregated metrics, updated client states).
    """
    model = model_fn()
    client_state = client_state_fn()

    # Server initialization; its result type defines the server state type.
    init_tf = tff.tf_computation(
        lambda: __initialize_server(model_fn, server_optimizer_fn))
    server_state_type = init_tf.type_signature.result
    client_state_type = tff.framework.type_from_tensors(client_state)

    # Applies the aggregated weights delta to the server state.
    update_server_tf = tff.tf_computation(
        lambda state, weights_delta: __update_server(
            state, weights_delta, model_fn, server_optimizer_fn,
            tf.function(server.update)),
        (server_state_type, server_state_type.model.trainable))

    # Extracts the broadcastable message from the server state.
    state_to_message_tf = tff.tf_computation(
        lambda state: __state_to_message(state, tf.function(server.to_message)),
        server_state_type)

    dataset_type = tff.SequenceType(model.input_spec)
    server_message_type = state_to_message_tf.type_signature.result

    # Runs one round of local training on a single client.
    update_client_tf = tff.tf_computation(
        lambda dataset, state, message: __update_client(
            dataset, state, message, coefficient_fn, model_fn,
            client_optimizer_fn, tf.function(client.update)),
        (dataset_type, client_state_type, server_message_type))

    federated_server_state_type = tff.type_at_server(server_state_type)
    federated_dataset_type = tff.type_at_clients(dataset_type)
    federated_client_state_type = tff.type_at_clients(client_state_type)

    def init_tff():
        # Place the freshly initialized server state at the server.
        return tff.federated_value(init_tf(), tff.SERVER)

    def next_tff(server_state, datasets, client_states):
        # One round: broadcast the server message, update every client,
        # average the weight deltas (weighted per client), then update the
        # server state with the averaged delta.
        message = tff.federated_map(state_to_message_tf, server_state)
        broadcast = tff.federated_broadcast(message)
        outputs = tff.federated_map(update_client_tf,
                                    (datasets, client_states, broadcast))
        weights_delta = tff.federated_mean(
            outputs.weights_delta, weight=outputs.client_weight)
        metrics = model.federated_output_computation(outputs.metrics)
        next_state = tff.federated_map(update_server_tf,
                                       (server_state, weights_delta))
        return next_state, metrics, outputs.client_state

    return tff.templates.IterativeProcess(
        initialize_fn=tff.federated_computation(init_tff),
        next_fn=tff.federated_computation(
            next_tff, (federated_server_state_type, federated_dataset_type,
                       federated_client_state_type)))
def create(self, value_type, weight_type):
    """Builds an AggregationProcess that iteratively re-weights a federated mean.

    Each extra communication pass broadcasts the current aggregate, rescales
    every client weight by the inverse distance between the client value and
    the aggregate (clamped below by `self._tolerance`), and re-averages.
    """

    @tff.federated_computation()
    def init_fn():
        # The process keeps no server state.
        return tff.federated_value((), tff.SERVER)

    @tff.tf_computation(tf.float32, value_type, value_type)
    def rescale_weight(weight, server_model, client_model):
        squared_norms = tf.nest.map_structure(lambda a, b: tf.norm(a - b)**2,
                                              server_model, client_model)
        distance = tf.math.sqrt(tf.reduce_sum(squared_norms))
        return tf.math.divide_no_nan(
            weight, tf.math.maximum(self._tolerance, distance))

    @tff.federated_computation(init_fn.type_signature.result,
                               tff.type_at_clients(value_type),
                               tff.type_at_clients(weight_type))
    def next_fn(state, value, weight):
        aggregate = tff.federated_mean(value, weight=weight)
        for _ in range(self._num_communication_passes - 1):
            aggregate_at_client = tff.federated_broadcast(aggregate)
            updated_weight = tff.federated_map(
                rescale_weight, (weight, aggregate_at_client, value))
            aggregate = tff.federated_mean(value, weight=updated_weight)
        no_metrics = tff.federated_value((), tff.SERVER)
        return tff.templates.MeasuredProcessOutput(state, aggregate, no_metrics)

    return tff.templates.AggregationProcess(init_fn, next_fn)
def __attrs_post_init__(self):
    """Validates the GAN models and derives the TFF types and aggregation process."""
    self.gen_input_type = tensor_spec_for_batch(self.dummy_gen_input)
    self.real_data_type = tensor_spec_for_batch(self.dummy_real_data)

    # Model-weights based types
    def build_and_check_model(model_fn, dummy_batch):
        # Construct the model, trace it once on a dummy batch so its variables
        # are created, and fail fast if it is not a Keras model. Factored out
        # because the generator and discriminator checks were duplicated.
        model = model_fn()
        _ = model(dummy_batch)
        if not isinstance(model, tf.keras.models.Model):
            raise TypeError('Expected `tf.keras.models.Model`, found {}.'.format(
                type(model)))
        return model

    self._generator = build_and_check_model(self.generator_model_fn,
                                            self.dummy_gen_input)
    self._discriminator = build_and_check_model(self.discriminator_model_fn,
                                                self.dummy_real_data)

    def vars_to_type(var_struct):
        # TODO(b/131681951): read_value() shouldn't be needed
        return tf.nest.map_structure(
            lambda v: tf.TensorSpec.from_tensor(v.read_value()), var_struct)

    self.discriminator_weights_type = vars_to_type(self._discriminator.weights)
    self.generator_weights_type = vars_to_type(self._generator.weights)

    self.from_server_type = gan_training_tf_fns.FromServer(
        generator_weights=self.generator_weights_type,
        discriminator_weights=self.discriminator_weights_type,
        meta_gen=self.generator_weights_type,
        meta_disc=self.discriminator_weights_type)

    self.client_gen_input_type = tff.type_at_clients(
        tff.SequenceType(self.gen_input_type))
    self.client_real_data_type = tff.type_at_clients(
        tff.SequenceType(self.real_data_type))
    self.server_gen_input_type = tff.type_at_server(
        tff.SequenceType(self.gen_input_type))

    # Use a DP aggregator only when a DP query is configured; otherwise a
    # plain weighted mean aggregates the discriminator weights.
    if self.train_discriminator_dp_average_query is not None:
        self.aggregation_process = tff.aggregators.DifferentiallyPrivateFactory(
            query=self.train_discriminator_dp_average_query).create(
                value_type=tff.to_type(self.discriminator_weights_type))
    else:
        self.aggregation_process = tff.aggregators.MeanFactory().create(
            value_type=tff.to_type(self.discriminator_weights_type),
            weight_type=tff.to_type(tf.float32))
def __attrs_post_init__(self):
    """Validates the GAN models and derives the TFF type signatures used elsewhere."""
    self.gen_input_type = tensor_spec_for_batch(self.dummy_gen_input)
    self.real_data_type = tensor_spec_for_batch(self.dummy_real_data)

    # Model-weights based types
    def build_and_check_model(model_fn, dummy_batch):
        # Construct the model, trace it once on a dummy batch so its variables
        # are created, and fail fast if it is not a Keras model. Factored out
        # because the generator and discriminator checks were duplicated.
        model = model_fn()
        _ = model(dummy_batch)
        if not isinstance(model, tf.keras.models.Model):
            raise TypeError('Expected `tf.keras.models.Model`, found {}.'.format(
                type(model)))
        return model

    self._generator = build_and_check_model(self.generator_model_fn,
                                            self.dummy_gen_input)
    self._discriminator = build_and_check_model(self.discriminator_model_fn,
                                                self.dummy_real_data)

    def vars_to_type(var_struct):
        # TODO(b/131681951): read_value() shouldn't be needed
        return tf.nest.map_structure(
            lambda v: tf.TensorSpec.from_tensor(v.read_value()), var_struct)

    self.discriminator_weights_type = vars_to_type(self._discriminator.weights)
    self.generator_weights_type = vars_to_type(self._generator.weights)

    self.from_server_type = gan_training_tf_fns.FromServer(
        generator_weights=self.generator_weights_type,
        discriminator_weights=self.discriminator_weights_type)

    self.client_gen_input_type = tff.type_at_clients(
        tff.SequenceType(self.gen_input_type))
    self.client_real_data_type = tff.type_at_clients(
        tff.SequenceType(self.real_data_type))
    self.server_gen_input_type = tff.type_at_server(
        tff.SequenceType(self.gen_input_type))

    # Right now, the logic in this library is effectively "if DP use stateful
    # aggregator, else don't use stateful aggregator". An alternative
    # formulation would be to always use a stateful aggregator, but when not
    # using DP default the aggregator to be a stateless mean, e.g.,
    # https://github.com/tensorflow/federated/blob/master/tensorflow_federated/python/learning/framework/optimizer_utils.py#L283.
    if self.train_discriminator_dp_average_query is not None:
        self.dp_averaging_fn = tff.utils.build_dp_aggregate_process(
            value_type=tff.to_type(self.discriminator_weights_type),
            query=self.train_discriminator_dp_average_query)
def test_executes_dataset_concat_aggregation(self):
    """Aggregating client datasets by concatenation yields a dataset."""
    element_spec = tf.TensorSpec(shape=[2], dtype=tf.float32)

    @tff.tf_computation
    def make_empty_dataset():
        zeros = tf.zeros(shape=[0] + element_spec.shape,
                         dtype=element_spec.dtype)
        return tf.data.Dataset.from_tensor_slices(zeros)

    @tff.tf_computation
    def concat(first, second):
        return first.concatenate(second)

    @tff.tf_computation
    def report_identity(ds):
        return ds

    @tff.federated_computation(
        tff.type_at_clients(tff.SequenceType(element_spec)))
    def aggregate_datasets(client_ds):
        return tff.federated_aggregate(
            value=client_ds,
            zero=make_empty_dataset(),
            accumulate=concat,
            merge=concat,
            report=report_identity)

    result = aggregate_datasets(
        [tf.data.Dataset.from_tensor_slices([[0.1, 0.2]])])
    self.assertIsInstance(result, tf.data.Dataset)
def test_bad_type_coercion_raises(self):
    """A tensor that type-checks but fails reshape raises at every placement."""
    tensor_type = tff.TensorType(shape=[None], dtype=tf.float32)

    @tff.tf_computation(tensor_type)
    def reshape_to_scalar(x):
        # Any input with more than one element passes the TFF type check
        # (shape [None]) but fails this reshape.
        return tf.reshape(x, [])

    @tff.federated_computation(tff.type_at_clients(tensor_type))
    def map_at_clients(x):
        return tff.federated_map(reshape_to_scalar, x)

    @tff.federated_computation(tff.type_at_server(tensor_type))
    def map_at_server(x):
        return tff.federated_map(reshape_to_scalar, x)

    bad_tensor = tf.constant([1.] * 10, dtype=tf.float32)
    good_tensor = tf.constant([1.], dtype=tf.float32)

    # Running at both placements, or unplaced, must all raise.
    with self.assertRaises(Exception):
        reshape_to_scalar(bad_tensor)
    with self.assertRaises(Exception):
        map_at_server(bad_tensor)
    with self.assertRaises(Exception):
        map_at_clients([bad_tensor] * 10)

    # Give the distributed runtime a chance to clean itself up; otherwise
    # workers may get SIGABRT while handling another exception and crash the
    # test infra. A successful call ensures cleanup happens after the
    # failures above have been handled.
    map_at_clients([good_tensor] * 10)
def test_clip_type_properties_simple(self, value_type):
    """The clipped-sum factory produces a process with the expected types."""
    factory = _clipped_sum()
    value_type = tff.to_type(value_type)
    process = factory.create(value_type)

    self.assertIsInstance(process, tff.templates.AggregationProcess)

    # The inner SumFactory is stateless, so server state is the empty tuple.
    state_at_server = tff.type_at_server(())
    self.assertTrue(
        process.initialize.type_signature.is_equivalent_to(
            tff.FunctionType(parameter=None, result=state_at_server)))

    measurements_at_server = tff.type_at_server(
        collections.OrderedDict(agg_process=()))
    wanted_next_type = tff.FunctionType(
        parameter=collections.OrderedDict(
            state=state_at_server, value=tff.type_at_clients(value_type)),
        result=tff.templates.MeasuredProcessOutput(
            state=state_at_server,
            result=tff.type_at_server(value_type),
            measurements=measurements_at_server))
    self.assertTrue(
        process.next.type_signature.is_equivalent_to(wanted_next_type))
def test_computations_run_with_worker_restarts(self, context, first_contexts,
                                               second_contexts):
    """Computations still run after the worker contexts are torn down and rebuilt."""

    @tff.tf_computation(tf.int32)
    def add_one(x):
        return x + 1

    @tff.federated_computation(tff.type_at_clients(tf.int32))
    def map_add_one(federated_arg):
        return tff.federated_map(add_one, federated_arg)

    with tff.framework.get_context_stack().install(context):
        with contextlib.ExitStack() as stack:
            for server_context in first_contexts:
                stack.enter_context(server_context)
            self.assertEqual(map_add_one([0, 1]), [1, 2])
        # Closing and re-entering the server contexts simulates failures and
        # restarts at the workers. Restarts leave the workers in a state that
        # needs initialization again; entering the second context ensures
        # that the servers need to be reinitialized by the controller.
        with contextlib.ExitStack() as stack:
            for server_context in second_contexts:
                stack.enter_context(server_context)
            self.assertEqual(map_add_one([0, 1]), [1, 2])
def test_computations_run_with_worker_restarts_and_aggregation(
        self, context, aggregation_contexts, first_worker_contexts,
        second_worker_contexts):
    """Worker restarts under a live aggregator do not break computations."""

    @tff.tf_computation(tf.int32)
    def add_one(x):
        return x + 1

    @tff.federated_computation(tff.type_at_clients(tf.int32))
    def map_add_one(federated_arg):
        return tff.federated_map(add_one, federated_arg)

    with tff.framework.get_context_stack().install(context):
        with contextlib.ExitStack() as aggregation_stack:
            for server_context in aggregation_contexts:
                aggregation_stack.enter_context(server_context)
            with contextlib.ExitStack() as worker_stack:
                for server_context in first_worker_contexts:
                    worker_stack.enter_context(server_context)
                self.assertEqual(map_add_one([0, 1]), [1, 2])
            # Reinitializing the workers without leaving the aggregation
            # context simulates a worker failure while the aggregator keeps
            # running.
            with contextlib.ExitStack() as worker_stack:
                for server_context in second_worker_contexts:
                    worker_stack.enter_context(server_context)
                self.assertEqual(map_add_one([0, 1]), [1, 2])
def test_federated_zip(self):
    """Zipping two client-placed values produces a result."""

    @tff.federated_computation([tff.type_at_clients(tf.int32)] * 2)
    def zip_pair(x):
        return tff.federated_zip(x)

    self.assertIsNotNone(zip_pair([[1, 2], [3, 4]]))
def create(self, value_type: tff.Type) -> tff.templates.AggregationProcess:
    """Wraps the DP sum process so the result becomes a mean over clients."""
    self._dp_sum_process = self._dp_sum.create(value_type)

    @tff.federated_computation()
    def initialize():
        # Invoke here to instantiate anything we need
        return self._dp_sum_process.initialize()

    @tff.tf_computation(value_type, tf.int32)
    def divide(x, y):
        # Opaque shape manipulations
        return [tf.squeeze(tf.math.divide_no_nan(x, tf.cast(y, tf.float32)), 0)]

    @tff.federated_computation(initialize.type_signature.result,
                               tff.type_at_clients(value_type))
    def next_fn(state, value):
        ones = tff.federated_value(1, tff.CLIENTS)
        dp_sum = self._dp_sum_process.next(state, value)
        client_count = tff.federated_sum(ones)
        return tff.templates.MeasuredProcessOutput(
            state=dp_sum.state,
            result=tff.federated_map(divide, (dp_sum.result, client_count)),
            measurements=dp_sum.measurements)

    return tff.templates.AggregationProcess(initialize_fn=initialize,
                                            next_fn=next_fn)
def test_executes_empty_sum(self):
    """Summing over zero clients yields the additive identity."""

    @tff.federated_computation(tff.type_at_clients(tf.int32))
    def sum_at_server(x):
        return tff.federated_sum(x)

    self.assertEqual(sum_at_server([]), 0)
def test_computations_run_with_changing_clients(self, context, server_contexts):
    """The same computation accepts cohorts whose size changes between calls."""
    self.skipTest('b/175155128')

    @tff.tf_computation(tf.int32)
    @tf.function
    def add_one(x):
        return x + 1

    @tff.federated_computation(tff.type_at_clients(tf.int32))
    def map_add_one(federated_arg):
        return tff.federated_map(add_one, federated_arg)

    with tff.framework.get_context_stack().install(context):
        with contextlib.ExitStack() as stack:
            for server_context in server_contexts:
                stack.enter_context(server_context)
            self.assertEqual(map_add_one([0, 1]), [1, 2])
            # Moving to three clients should be fine.
            three_client_result = map_add_one([0, 1, 2])
            # Running a 0-client function should also be OK.
            self.assertEqual(add_one(0), 1)
            self.assertEqual(three_client_result, [1, 2, 3])
            # Changing back to 2 clients should still succeed.
            self.assertEqual(map_add_one([0, 1]), [1, 2])
            # Similarly, 3 clients again should be fine.
            self.assertEqual(map_add_one([0, 1, 2]), [1, 2, 3])
def federated_output_computation_from_metrics(
    metrics: List[tf.keras.metrics.Metric]) -> tff.federated_computation:
    """Produces a federated computation for aggregating Keras metrics.

    Works for both Keras and non-Keras models evaluated with Keras metrics:
    the clients' metric variables are summed, new metrics are built from the
    summed variables, and `metric.result()` is called on each. See
    `federated_aggregate_keras_metric` for details.

    Args:
      metrics: A List of `tf.keras.metrics.Metric` to aggregate.

    Returns:
      A `tff.federated_computation` performing the aggregation described
      above.
    """
    # Read one sample of the metric variables to derive their TFF type.
    sample_variables = read_metric_variables(metrics)
    variables_type_at_clients = tff.type_at_clients(
        tf.nest.map_structure(tf.TensorSpec.from_tensor, sample_variables))

    def federated_output(local_outputs):
        return federated_aggregate_keras_metric(metrics, local_outputs)

    return tff.federated_computation(federated_output,
                                     variables_type_at_clients)
def iterator(model_fn: MODEL_FN, client_state_fn: CLIENT_STATE_FN,
             client_optimizer_fn: OPTIMIZER_FN):
    """Builds an IterativeProcess whose server keeps no state (client-only training).

    Args:
      model_fn: No-arg callable returning the model; used for type inference
        and re-invoked inside the traced TF computation.
      client_state_fn: No-arg callable returning a fresh client state; used
        only to derive its TFF type.
      client_optimizer_fn: No-arg optimizer factory for the client update.

    Returns:
      A `tff.templates.IterativeProcess`; `next` returns the unchanged server
      state, aggregated metrics, and updated client states.
    """
    placeholder_model = model_fn()
    # The server state is the empty tuple.
    init_tf = tff.tf_computation(lambda: ())
    server_state_type = init_tf.type_signature.result
    state_type = tff.framework.type_from_tensors(client_state_fn())
    ds_type = tff.SequenceType(placeholder_model.input_spec)

    client_update_comp = tff.tf_computation(
        lambda dataset, state: __update_client(
            dataset, state, model_fn, client_optimizer_fn,
            tf.function(client.update)),
        (ds_type, state_type))

    def init_tff():
        return tff.federated_value(init_tf(), tff.SERVER)

    def next_tff(server_state, datasets, client_states):
        outputs = tff.federated_map(client_update_comp,
                                    (datasets, client_states))
        metrics = placeholder_model.federated_output_computation(
            outputs.metrics)
        # The server state passes through unchanged.
        return server_state, metrics, outputs.client_state

    return tff.templates.IterativeProcess(
        initialize_fn=tff.federated_computation(init_tff),
        next_fn=tff.federated_computation(
            next_tff,
            (tff.type_at_server(server_state_type),
             tff.type_at_clients(ds_type), tff.type_at_clients(state_type))))
def build_federated_evaluation(
    model_fn: Callable[[], tff.learning.Model],
    metrics_builder: Callable[[], Sequence[tf.keras.metrics.Metric]]
) -> tff.federated_computation:
    """Builds a federated evaluation `tff.federated_computation`.

    Args:
      model_fn: A no-arg function that returns a `tff.learning.Model`.
      metrics_builder: A no-arg function that returns a sequence of
        `tf.keras.metrics.Metric` objects. These metrics must have a callable
        `update_state` accepting `y_true` and `y_pred` arguments,
        corresponding to the true and predicted label, respectively.

    Returns:
      A `tff.federated_computation` that accepts model weights and federated
      data, and returns the evaluation metrics, aggregated in both uniform-
      and example-weighted manners.
    """
    # Wrap model construction in a graph to avoid polluting the global context
    # with variables created for this model.
    with tf.Graph().as_default():
        placeholder_model = model_fn()
        model_weights_type = tff.learning.framework.weights_type_from_model(
            placeholder_model)
        model_input_type = tff.SequenceType(placeholder_model.input_spec)

    @tff.tf_computation(model_weights_type, model_input_type)
    def compute_client_metrics(model_weights, federated_dataset):
        # Each client rebuilds the model and metrics and evaluates locally.
        model = model_fn()
        metrics = metrics_builder()
        return compute_metrics(model, model_weights, metrics,
                               federated_dataset)

    @tff.federated_computation(tff.type_at_server(model_weights_type),
                               tff.type_at_clients(model_input_type))
    def federated_evaluate(model_weights, federated_dataset):
        client_model = tff.federated_broadcast(model_weights)
        client_metrics = tff.federated_map(compute_client_metrics,
                                           (client_model, federated_dataset))
        # Extract the number of examples in order to compute client weights
        num_examples = client_metrics.num_examples
        uniform_weighted_metrics = tff.federated_mean(client_metrics,
                                                      weight=None)
        example_weighted_metrics = tff.federated_mean(client_metrics,
                                                      weight=num_examples)
        # Aggregate the metrics in a single nested dictionary
        aggregate_metrics = collections.OrderedDict()
        aggregate_metrics[AggregationMethods.EXAMPLE_WEIGHTED.
                          value] = example_weighted_metrics
        aggregate_metrics[AggregationMethods.UNIFORM_WEIGHTED.
                          value] = uniform_weighted_metrics
        return aggregate_metrics

    return federated_evaluate
def build_federated_averaging_process_attacked(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    aggregation_process=None,
    client_update_tf=ClientExplicitBoosting(boost_factor=1.0)):
    """Builds the TFF computations for optimization using federated averaging with potentially malicious clients.

    Args:
      model_fn: A no-arg function that returns a `tff.learning.Model`.
      client_optimizer_fn: A no-arg function that returns a
        `tf.keras.optimizers.Optimizer`, use during local client training.
      server_optimizer_fn: A no-arg function that returns a
        `tf.keras.optimizers.Optimizer`, use to apply updates to the global
        model.
      aggregation_process: A 'tff.templates.MeasuredProcess' that aggregates
        model deltas placed@CLIENTS to an aggregated model delta
        placed@SERVER.
      client_update_tf: a 'tf.function' computes the ClientOutput.

    Returns:
      A `tff.templates.IterativeProcess`.
    """
    # Build a throwaway model inside its own graph just to read off the type
    # metadata, keeping its variables out of the global context.
    with tf.Graph().as_default():
        dummy_model_for_metadata = model_fn()
        weights_type = tff.learning.framework.weights_type_from_model(
            dummy_model_for_metadata)

    # Default aggregation: a stateless mean over the trainable weights delta.
    if aggregation_process is None:
        aggregation_process = tff.learning.framework.build_stateless_mean(
            model_delta_type=weights_type.trainable)

    server_init = build_server_init_fn(model_fn, server_optimizer_fn,
                                       aggregation_process.initialize)
    server_state_type = server_init.type_signature.result.member
    server_update_fn = build_server_update_fn(model_fn, server_optimizer_fn,
                                              server_state_type,
                                              server_state_type.model)
    tf_dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)
    # The client update receives `client_update_tf`, which may implement a
    # malicious (boosted) update.
    client_update_fn = build_client_update_fn(model_fn, client_optimizer_fn,
                                              client_update_tf,
                                              tf_dataset_type,
                                              server_state_type.model)
    federated_server_state_type = tff.type_at_server(server_state_type)
    federated_dataset_type = tff.type_at_clients(tf_dataset_type)
    run_one_round_tff = build_run_one_round_fn_attacked(
        server_update_fn, client_update_fn, aggregation_process,
        dummy_model_for_metadata, federated_server_state_type,
        federated_dataset_type)
    return tff.templates.IterativeProcess(initialize_fn=server_init,
                                          next_fn=run_one_round_tff)
def test_empty_mean_returns_nan(self):
    """A federated mean over zero clients currently raises."""
    self.skipTest('b/200970992')
    # TODO(b/200970992): Standardize handling of this case. We currently have
    # a ZeroDivisionError, a RuntimeError, and a context that returns nan.

    @tff.federated_computation(tff.type_at_clients(tf.float32))
    def mean_at_server(x):
        return tff.federated_mean(x)

    with self.assertRaises(RuntimeError):
        mean_at_server([])
def create(self, value_type, weight_type):
    """Builds an AggregationProcess normalizing by a fixed participant count."""

    @tff.federated_computation()
    def initialize_fn():
        # state = AggregationState(self._num_participants)
        return tff.federated_value(self._num_participants, tff.SERVER)

    @tff.federated_computation(initialize_fn.type_signature.result,
                               tff.type_at_clients(value_type),
                               tff.type_at_clients(weight_type))
    def next_fn(state, value, weight):
        weighted = tff.federated_map(_mul, (value, weight))
        total = tff.federated_sum(weighted)
        # Divide by the configured participant count held in the state.
        normalized = tff.federated_map(_div, (total, state))
        return tff.templates.MeasuredProcessOutput(
            state=state,
            result=normalized,
            measurements=tff.federated_value((), tff.SERVER))

    return tff.templates.AggregationProcess(initialize_fn, next_fn)
def test_federated_zip_with_twenty_elements(self):
    # This test will fail if execution scales factorially with number of
    # elements zipped.
    num_element = 20
    num_clients = 2

    @tff.federated_computation([tff.type_at_clients(tf.int32)] * num_element)
    def zip_all(x):
        return tff.federated_zip(x)

    client_values = [list(range(num_clients))] * num_element
    self.assertIsNotNone(zip_all(client_values))
def test_repeated_invocations_of_map(self):
    """Invoking the same federated map twice yields identical results."""

    @tff.tf_computation(tf.int32)
    def add_one(x):
        return x + 1

    @tff.federated_computation(tff.type_at_clients(tf.int32))
    def map_add_one(federated_arg):
        return tff.federated_map(add_one, federated_arg)

    first = map_add_one([0, 1])
    second = map_add_one([0, 1])
    self.assertIsNotNone(first)
    self.assertEqual(first, second)
def evaluator(model_fn: MODEL_FN, client_state_fn: CLIENT_STATE_FN):
    """Builds a federated computation evaluating client datasets and states.

    Returns a `tff.federated_computation` mapping (datasets@CLIENTS,
    client_states@CLIENTS) to the summed confusion matrix, aggregated
    metrics, and collected per-client metrics.
    """
    placeholder_model = model_fn()
    ds_type = tff.SequenceType(placeholder_model.input_spec)
    state_type = tff.framework.type_from_tensors(client_state_fn())

    # Per-client TF computation wrapping the evaluation helper.
    client_evaluate_comp = tff.tf_computation(
        lambda dataset, state: __evaluate_client(
            dataset, state, model_fn, tf.function(client.evaluate)),
        (ds_type, state_type))

    def evaluate(datasets, client_states):
        outputs = tff.federated_map(client_evaluate_comp,
                                    (datasets, client_states))
        summed_confusion_matrix = tff.federated_sum(outputs.confusion_matrix)
        aggregated_metrics = placeholder_model.federated_output_computation(
            outputs.metrics)
        per_client_metrics = tff.federated_collect(outputs.metrics)
        return summed_confusion_matrix, aggregated_metrics, per_client_metrics

    return tff.federated_computation(
        evaluate,
        (tff.type_at_clients(ds_type), tff.type_at_clients(state_type)))
def build_federated_averaging_process(
    model_fn,
    client_optimizer_fn,
    server_optimizer_fn=lambda: flars_optimizer.FLARSOptimizer(
        learning_rate=1.0)):
    """Builds the TFF computations for optimization using federated averaging.

    Args:
      model_fn: A no-arg function that returns a `tff.learning.Model`.
      client_optimizer_fn: A no-arg function that returns a
        `tf.keras.optimizers.Optimizer` for the local client training.
      server_optimizer_fn: A no-arg function that returns a
        `tf.keras.optimizers.Optimizer` for applying updates on the server.

    Returns:
      A `tff.templates.IterativeProcess`.
    """
    # Build a throwaway model inside its own graph just to read off type
    # metadata, keeping its variables out of the global context.
    with tf.Graph().as_default():
        dummy_model_for_metadata = model_fn()
        # One dtype per flattened trainable variable; passed to the server
        # update as the signature of the per-variable gradient norms.
        type_signature_grads_norm = tuple(
            weight.dtype for weight in tf.nest.flatten(
                dummy_model_for_metadata.trainable_variables))

    server_init_tf = build_server_init_fn(model_fn, server_optimizer_fn)
    server_state_type = server_init_tf.type_signature.result
    server_update_fn = build_server_update_fn(model_fn, server_optimizer_fn,
                                              server_state_type,
                                              server_state_type.model,
                                              type_signature_grads_norm)
    tf_dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)
    client_update_fn = build_client_update_fn(model_fn, client_optimizer_fn,
                                              tf_dataset_type,
                                              server_state_type.model)
    federated_server_state_type = tff.type_at_server(server_state_type)
    federated_dataset_type = tff.type_at_clients(tf_dataset_type)
    run_one_round_tff = build_run_one_round_fn(server_update_fn,
                                               client_update_fn,
                                               dummy_model_for_metadata,
                                               federated_server_state_type,
                                               federated_dataset_type)
    return tff.templates.IterativeProcess(
        initialize_fn=tff.federated_computation(
            lambda: tff.federated_eval(server_init_tf, tff.SERVER)),
        next_fn=run_one_round_tff)
def test_build_with_preprocess_function(self):
    """Composing dataset preprocessing with the process keeps the expected `next` type."""
    test_dataset = tf.data.Dataset.range(5)
    client_datasets_type = tff.type_at_clients(
        tff.SequenceType(test_dataset.element_spec))

    @tff.tf_computation(tff.SequenceType(test_dataset.element_spec))
    def preprocess_dataset(ds):
        """Maps each scalar element to a `_Batch` and batches by two."""

        def to_batch(x):
            return _Batch(
                tf.fill(dims=(784,), value=float(x) * 2.0),
                tf.expand_dims(tf.cast(x + 1, dtype=tf.int64), axis=0))

        return ds.map(to_batch).batch(2)

    iterproc = fed_avg_schedule.build_fed_avg_process(
        _uncompiled_model_builder,
        TAU,
        client_optimizer_fn=tf.keras.optimizers.SGD,
        server_optimizer_fn=tf.keras.optimizers.SGD)
    # Pre-compose the preprocessing so `next` accepts the raw datasets.
    iterproc = tff.simulation.compose_dataset_computation_with_iterative_process(
        preprocess_dataset, iterproc)

    # Build a throwaway model in its own graph only to derive types.
    with tf.Graph().as_default():
        test_model_for_types = _uncompiled_model_builder()

    server_state_type = tff.FederatedType(
        fed_avg_schedule.ServerState(
            model=tff.framework.type_from_tensors(
                tff.learning.ModelWeights(
                    test_model_for_types.trainable_variables,
                    test_model_for_types.non_trainable_variables)),
            optimizer_state=(tf.int64,),
            round_num=tf.float32,
            client_drift=tf.float32), tff.SERVER)
    metrics_type = test_model_for_types.federated_output_computation.type_signature.result

    expected_parameter_type = collections.OrderedDict(
        server_state=server_state_type,
        federated_dataset=client_datasets_type,
    )
    expected_result_type = (server_state_type, metrics_type)
    expected_type = tff.FunctionType(parameter=expected_parameter_type,
                                     result=expected_result_type)
    self.assertTrue(
        iterproc.next.type_signature.is_equivalent_to(expected_type),
        msg='{s}\n!={t}'.format(s=iterproc.next.type_signature,
                                t=expected_type))
def get_federated_tokenize_fn(dataset_name, dataset_element_type_structure):
    """Returns a federated computation that tokenizes every client dataset."""
    element_seq_type = tff.SequenceType(dataset_element_type_structure)

    @tff.tf_computation(element_seq_type)
    def tokenize_dataset(dataset):
        """The TF computation to tokenize a dataset."""
        return tokenize(dataset, dataset_name)

    @tff.federated_computation(tff.type_at_clients(element_seq_type))
    def tokenize_datasets(datasets):
        """The TFF computation to compute tokenized datasets."""
        return tff.federated_map(tokenize_dataset, datasets)

    return tokenize_datasets
def test_polymorphism(self):
    """The same computation accepts cohorts of different sizes."""

    @tff.tf_computation(tf.int32)
    def add_one(x):
        return x + 1

    @tff.federated_computation(tff.type_at_clients(tf.int32))
    def map_add_one(federated_arg):
        return tff.federated_map(add_one, federated_arg)

    two_client_result = map_add_one([0, 1])
    three_client_result = map_add_one([0, 1, 2])
    self.assertIsNotNone(two_client_result)
    self.assertIsNotNone(three_client_result)
    self.assertLen(two_client_result, 2)
    self.assertLen(three_client_result, 3)
def test_federated_collect_large_numbers_of_parameters(self):
    """Collecting ten million-element models stays well under the time budget."""
    num_clients = 10
    model_size = 10**6
    client_models = [tf.ones([model_size]) for _ in range(num_clients)]

    @tff.federated_computation(tff.type_at_clients((tf.float32, [model_size])))
    def comp(client_data):
        return tff.federated_collect(client_data)

    start = time.time()
    result = comp(client_models)
    elapsed = time.time() - start
    if elapsed > 10:
        raise RuntimeError(
            'comp should take much less than a second, but took ' +
            str(elapsed))
    del result