def build_federated_evaluation(model_fn):
  """Builds the TFF computation for federated evaluation of the given model.

  Args:
    model_fn: A no-argument function that returns a `tff.learning.Model`.

  Returns:
    A federated computation (an instance of `tff.Computation`) that accepts
    model parameters and federated data, and returns the evaluation metrics as
    aggregated by `tff.learning.Model.federated_output_computation`.
  """
  # Stamp a throwaway model into a scratch graph purely to read off the weight
  # and input-batch types needed to declare the computations below.
  # TODO(b/124477628): Ideally replace the need for stamping throwaway models
  # with some other mechanism.
  with tf.Graph().as_default():
    model = model_utils.enhance(model_fn())
    model_weights_type = tff.to_type(
        tf.nest.map_structure(
            lambda v: tff.TensorType(v.dtype.base_dtype, v.shape),
            model.weights))
    batch_type = tff.to_type(model.input_spec)

  @tff.tf_computation(model_weights_type, tff.SequenceType(batch_type))
  def client_eval(incoming_model_weights, dataset):
    """Returns local outputs after evaluating `model_weights` on `dataset`."""
    model = model_utils.enhance(model_fn())

    # TODO(b/124477598): Remove dummy when b/121400757 has been fixed.
    @tf.function
    def reduce_fn(accumulated_loss, batch):
      output = model.forward_pass(batch, training=False)
      return accumulated_loss + tf.cast(output.loss, tf.float64)

    # TODO(b/123898430): The control dependencies below have been inserted as a
    # temporary workaround. These control dependencies need to be removed, and
    # defuns and datasets supported together fully.
    assign_op = tff.utils.assign(model.weights, incoming_model_weights)
    with tf.control_dependencies([assign_op]):
      dummy = dataset.reduce(tf.constant(0.0, dtype=tf.float64), reduce_fn)
    with tf.control_dependencies([dummy]):
      return collections.OrderedDict([
          ('local_outputs', model.report_local_outputs()),
          ('workaround for b/121400757', dummy),
      ])

  @tff.federated_computation(
      tff.FederatedType(model_weights_type, tff.SERVER),
      tff.FederatedType(tff.SequenceType(batch_type), tff.CLIENTS))
  def server_eval(server_model_weights, federated_dataset):
    broadcast_weights = tff.federated_broadcast(server_model_weights)
    client_outputs = tff.federated_map(
        client_eval, [broadcast_weights, federated_dataset])
    return model.federated_output_computation(client_outputs.local_outputs)

  return server_eval
def build_federated_evaluation(model_fn):
  """Builds the TFF computation for federated evaluation of the given model.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This
      method must *not* capture TensorFlow tensors or variables and use them.
      The model must be constructed entirely from scratch on each invocation;
      returning the same pre-constructed model each call will result in an
      error.

  Returns:
    A federated computation (an instance of `tff.Computation`) that accepts
    model parameters and federated data, and returns the evaluation metrics as
    aggregated by `tff.learning.Model.federated_output_computation`.
  """
  # Stamp a throwaway model into a scratch graph purely to read off the weight
  # and input-batch types needed to declare the computations below.
  # TODO(b/124477628): Ideally replace the need for stamping throwaway models
  # with some other mechanism.
  with tf.Graph().as_default():
    model = model_utils.enhance(model_fn())
    model_weights_type = tff.framework.type_from_tensors(model.weights)
    batch_type = tff.to_type(model.input_spec)

  @tff.tf_computation(model_weights_type, tff.SequenceType(batch_type))
  def client_eval(incoming_model_weights, dataset):
    """Returns local outputs after evaluating `model_weights` on `dataset`."""
    model = model_utils.enhance(model_fn())

    @tf.function
    def _tf_client_eval(incoming_model_weights, dataset):
      """Evaluation TF work."""
      tff.utils.assign(model.weights, incoming_model_weights)

      def accumulate_loss(total_loss, batch):
        output = model.forward_pass(batch, training=False)
        return total_loss + tf.cast(output.loss, tf.float64)

      # The reduction only drives the forward passes; its numeric result is
      # discarded — the metrics come from the model's local output variables.
      dataset.reduce(tf.constant(0.0, dtype=tf.float64), accumulate_loss)

      return collections.OrderedDict(
          [('local_outputs', model.report_local_outputs())])

    return _tf_client_eval(incoming_model_weights, dataset)

  @tff.federated_computation(
      tff.FederatedType(model_weights_type, tff.SERVER),
      tff.FederatedType(tff.SequenceType(batch_type), tff.CLIENTS))
  def server_eval(server_model_weights, federated_dataset):
    broadcast_weights = tff.federated_broadcast(server_model_weights)
    client_outputs = tff.federated_map(
        client_eval, [broadcast_weights, federated_dataset])
    return model.federated_output_computation(client_outputs.local_outputs)

  return server_eval
def _create_stateless_int_dataset_reduction_iterative_process():
  """Builds an `IterativeProcess` whose `next` sums int64 client datasets."""

  @tff.tf_computation()
  def make_zero():
    # The server-placed state is just a constant int64 zero.
    return tf.cast(0, tf.int64)

  @tff.federated_computation()
  def init():
    return tff.federated_eval(make_zero, tff.SERVER)

  @tff.tf_computation(tff.SequenceType(tf.int64))
  def reduce_dataset(ds):
    return ds.reduce(
        tf.cast(0, tf.int64), lambda total, element: total + element)

  @tff.federated_computation(
      (init.type_signature.result,
       tff.FederatedType(tff.SequenceType(tf.int64), tff.CLIENTS)))
  def next_fn(empty_tup, x):
    del empty_tup  # Unused
    per_client_sums = tff.federated_map(reduce_dataset, x)
    return tff.federated_sum(per_client_sums)

  return tff.templates.IterativeProcess(initialize_fn=init, next_fn=next_fn)
def test_orchestration_type_signature(self):
  iterative_process = optimizer_utils.build_model_delta_optimizer_process(
      model_fn=model_examples.TrainableLinearRegression,
      model_to_client_delta_fn=DummyClientDeltaFn,
      server_optimizer_fn=lambda: gradient_descent.SGD(learning_rate=1.0))

  model_weights_type = model_utils.ModelWeights(
      collections.OrderedDict([('a', tff.TensorType(tf.float32, [2, 1])),
                               ('b', tf.float32)]),
      collections.OrderedDict([('c', tf.float32)]))

  # ServerState holds a model plus optimizer/aggregator/broadcaster state.
  # Those states are supplied by TensorFlow; TFF does not care what the actual
  # values are, hence AnyType.
  server_state_type = tff.FederatedType(
      optimizer_utils.ServerState(model_weights_type, test.AnyType(),
                                  test.AnyType(), test.AnyType()),
      placement=tff.SERVER,
      all_equal=True)

  dataset_type = tff.FederatedType(
      tff.SequenceType(model_examples.TrainableLinearRegression().input_spec),
      tff.CLIENTS,
      all_equal=False)

  metrics_type = tff.FederatedType(
      collections.OrderedDict([
          ('loss', tff.TensorType(tf.float32, [])),
          ('num_examples', tff.TensorType(tf.int32, [])),
      ]),
      tff.SERVER,
      all_equal=True)

  # `initialize` must be a no-argument function producing a ServerState.
  self.assertEqual(
      tff.FunctionType(parameter=None, result=server_state_type),
      iterative_process.initialize.type_signature)

  # `next` must map (ServerState, Datasets) to (ServerState, metrics).
  self.assertEqual(
      tff.FunctionType(
          parameter=[server_state_type, dataset_type],
          result=(server_state_type, metrics_type)),
      iterative_process.next.type_signature)
def test_orchestration_typecheck(self):
  iterative_process = federated_sgd.build_federated_sgd_process(
      model_fn=model_examples.LinearRegression)

  model_weights_type = model_utils.ModelWeights(
      collections.OrderedDict([('a', tff.TensorType(tf.float32, [2, 1])),
                               ('b', tf.float32)]),
      collections.OrderedDict([('c', tf.float32)]))

  # ServerState holds a model plus an optimizer_state. The optimizer_state is
  # supplied by TensorFlow; TFF does not care what the actual value is, hence
  # AnyType.
  server_state_type = tff.FederatedType(
      optimizer_utils.ServerState(model_weights_type, test.AnyType()),
      placement=tff.SERVER,
      all_equal=True)

  batch_type = model_examples.LinearRegression.make_batch(
      tff.TensorType(tf.float32, [None, 2]),
      tff.TensorType(tf.float32, [None, 1]))
  dataset_type = tff.FederatedType(
      tff.SequenceType(batch_type), tff.CLIENTS, all_equal=False)

  metrics_type = tff.FederatedType(
      collections.OrderedDict([
          ('loss', tff.TensorType(tf.float32, [])),
          ('num_examples', tff.TensorType(tf.int32, [])),
      ]),
      tff.SERVER,
      all_equal=True)

  # `initialize` must be a no-argument function producing a ServerState.
  self.assertEqual(
      tff.FunctionType(parameter=None, result=server_state_type),
      iterative_process.initialize.type_signature)

  # `next` must map (ServerState, Datasets) to (ServerState, metrics).
  self.assertEqual(
      tff.FunctionType(
          parameter=[server_state_type, dataset_type],
          result=(server_state_type, metrics_type)),
      iterative_process.next.type_signature)
def test_construction(self):
  iterative_process = optimizer_utils.build_model_delta_optimizer_process(
      model_fn=model_examples.LinearRegression,
      model_to_client_delta_fn=DummyClientDeltaFn,
      server_optimizer_fn=tf.keras.optimizers.SGD)

  model_weights_type = model_utils.ModelWeights(
      trainable=[
          tff.TensorType(tf.float32, [2, 1]),
          tff.TensorType(tf.float32)
      ],
      non_trainable=[tff.TensorType(tf.float32)])
  server_state_type = tff.FederatedType(
      optimizer_utils.ServerState(
          model=model_weights_type,
          optimizer_state=[tf.int64],
          delta_aggregate_state=(),
          model_broadcast_state=()), tff.SERVER)

  # `initialize` is a no-argument function producing the server state.
  self.assertEqual(
      str(iterative_process.initialize.type_signature),
      str(tff.FunctionType(parameter=None, result=server_state_type)))

  dataset_type = tff.FederatedType(
      tff.SequenceType(
          collections.OrderedDict(
              x=tff.TensorType(tf.float32, [None, 2]),
              y=tff.TensorType(tf.float32, [None, 1]))), tff.CLIENTS)
  metrics_type = tff.FederatedType(
      collections.OrderedDict(
          broadcast=(),
          aggregation=(),
          train=collections.OrderedDict(
              loss=tff.TensorType(tf.float32),
              num_examples=tff.TensorType(tf.int32))), tff.SERVER)

  # `next` maps (ServerState, datasets) to (ServerState, metrics).
  self.assertEqual(
      str(iterative_process.next.type_signature),
      str(
          tff.FunctionType(
              parameter=(server_state_type, dataset_type),
              result=(server_state_type, metrics_type))))
def build_model_delta_optimizer_process(
    model_fn,
    model_to_client_delta_fn,
    server_optimizer_fn,
    stateful_delta_aggregate_fn=build_stateless_mean(),
    stateful_model_broadcast_fn=build_stateless_broadcaster()):
  """Constructs `tff.utils.IterativeProcess` for Federated Averaging or SGD.

  This provides the TFF orchestration logic connecting the common server logic
  which applies aggregated model deltas to the server model with a
  `ClientDeltaFn` that specifies how `weight_deltas` are computed on device.

  Note: We pass in functions rather than constructed objects so we can ensure
  any variables or ops created in constructors are placed in the correct graph.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    model_to_client_delta_fn: A function from a `model_fn` to a
      `ClientDeltaFn`.
    server_optimizer_fn: A no-arg function that returns a `tf.Optimizer`. The
      `apply_gradients` method of this optimizer is used to apply client
      updates to the server model.
    stateful_delta_aggregate_fn: A `tff.utils.StatefulAggregateFn` where the
      `next_fn` performs a federated aggregation and updates state. That is, it
      has TFF type `(state@SERVER, value@CLIENTS, weights@CLIENTS) ->
      (state@SERVER, aggregate@SERVER)`, where the `value` type is
      `tff.learning.framework.ModelWeights.trainable` corresponding to the
      object returned by `model_fn`.
    stateful_model_broadcast_fn: A `tff.utils.StatefulBroadcastFn` where the
      `next_fn` performs a federated broadcast and updates state. That is, it
      has TFF type `(state@SERVER, value@SERVER) -> (state@SERVER,
      value@CLIENTS)`, where the `value` type is
      `tff.learning.framework.ModelWeights` corresponding to the object
      returned by `model_fn`.

  Returns:
    A `tff.utils.IterativeProcess`.
  """
  py_typecheck.check_callable(model_fn)
  py_typecheck.check_callable(model_to_client_delta_fn)
  py_typecheck.check_callable(server_optimizer_fn)
  py_typecheck.check_type(stateful_delta_aggregate_fn,
                          tff.utils.StatefulAggregateFn)
  py_typecheck.check_type(stateful_model_broadcast_fn,
                          tff.utils.StatefulBroadcastFn)

  # Construct a throwaway model in a scratch graph just to obtain the type
  # metadata (input spec) used to declare the computations below.
  # TODO(b/122081673): would be nice not to have to construct a throwaway model
  # here just to get the types. After fully moving to TF2.0 and eager-mode, we
  # should re-evaluate what happens here.
  with tf.Graph().as_default():
    dummy_model_for_metadata = model_utils.enhance(model_fn())

  # ===========================================================================
  # TensorFlow Computations

  @tff.tf_computation
  def tf_init_fn():
    # Builds the initial ServerState via `server_init`, seeded with the
    # initial states of the aggregator and broadcaster.
    return server_init(model_fn, server_optimizer_fn,
                       stateful_delta_aggregate_fn.initialize(),
                       stateful_model_broadcast_fn.initialize())

  tf_dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)
  # Derive the server state type from the init computation's result type so
  # the two can never drift apart.
  server_state_type = tf_init_fn.type_signature.result

  @tff.tf_computation(tf_dataset_type, server_state_type.model)
  def tf_client_delta(tf_dataset, initial_model_weights):
    """Performs client local model optimization.

    Args:
      tf_dataset: a `tf.data.Dataset` that provides training examples.
      initial_model_weights: a `model_utils.ModelWeights` containing the
        starting weights.

    Returns:
      A `ClientOutput` structure.
    """
    client_delta_fn = model_to_client_delta_fn(model_fn)
    client_output = client_delta_fn(tf_dataset, initial_model_weights)
    return client_output

  @tff.tf_computation(server_state_type, server_state_type.model.trainable,
                      server_state_type.delta_aggregate_state,
                      server_state_type.model_broadcast_state)
  def tf_server_update(server_state, model_delta, new_delta_aggregate_state,
                       new_broadcaster_state):
    """Converts args to correct python types and calls server_update_model."""
    py_typecheck.check_type(server_state, ServerState)
    # Rebuild the ServerState carrying over the model but installing the
    # updated aggregator/broadcaster states from this round.
    server_state = ServerState(
        model=server_state.model,
        optimizer_state=list(server_state.optimizer_state),
        delta_aggregate_state=new_delta_aggregate_state,
        model_broadcast_state=new_broadcaster_state)
    return server_update_model(
        server_state,
        model_delta,
        model_fn=model_fn,
        optimizer_fn=server_optimizer_fn)

  weight_type = tf_client_delta.type_signature.result.weights_delta_weight

  @tff.tf_computation(weight_type)
  def _cast_weight_to_float(x):
    return tf.cast(x, tf.float32)

  # ===========================================================================
  # Federated Computations

  @tff.federated_computation
  def server_init_tff():
    """Orchestration logic for server model initialization."""
    return tff.federated_value(tf_init_fn(), tff.SERVER)

  federated_server_state_type = server_init_tff.type_signature.result
  federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

  @tff.federated_computation(federated_server_state_type,
                             federated_dataset_type)
  def run_one_round_tff(server_state, federated_dataset):
    """Orchestration logic for one round of optimization.

    Args:
      server_state: a `tff.learning.framework.ServerState` named tuple.
      federated_dataset: a federated `tf.Dataset` with placement tff.CLIENTS.

    Returns:
      A tuple of updated `tff.learning.framework.ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    # Broadcast the global model to the clients (possibly updating the
    # broadcaster's state, e.g. for compression schemes).
    new_broadcaster_state, client_model = stateful_model_broadcast_fn(
        server_state.model_broadcast_state, server_state.model)

    client_outputs = tff.federated_map(tf_client_delta,
                                       (federated_dataset, client_model))

    # TODO(b/124070381): We hope to remove this explicit cast once we have a
    # full solution for type analysis in multiplications and divisions
    # inside TFF
    weight_denom = tff.federated_map(_cast_weight_to_float,
                                     client_outputs.weights_delta_weight)

    # Weighted aggregation of the client model deltas.
    new_delta_aggregate_state, round_model_delta = stateful_delta_aggregate_fn(
        server_state.delta_aggregate_state,
        client_outputs.weights_delta,
        weight=weight_denom)

    # TODO(b/123408447): remove tff.federated_apply and call
    # tf_server_update directly once T <-> T@SERVER isomorphism is
    # supported.
    server_state = tff.federated_apply(
        tf_server_update, (server_state, round_model_delta,
                           new_delta_aggregate_state, new_broadcaster_state))

    aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
        client_outputs.model_output)

    # Promote the FederatedType outside the NamedTupleType
    aggregated_outputs = tff.federated_zip(aggregated_outputs)

    return server_state, aggregated_outputs

  return tff.utils.IterativeProcess(
      initialize_fn=server_init_tff, next_fn=run_one_round_tff)
def benchmark_fc_api_mnist(self):
  """Code adapted from FC API tutorial ipynb.

  Benchmarks two things: how long it takes to trace/build the federated
  computations, and the average wall-time of one federated training round.
  """
  n_rounds = 10

  # Hand-declared TFF types for an MNIST softmax model: a flattened 784-pixel
  # image batch with integer labels, and dense weights/bias.
  batch_type = tff.NamedTupleType([("x",
                                    tff.TensorType(tf.float32, [None, 784])),
                                   ("y", tff.TensorType(tf.int32, [None]))])

  model_type = tff.NamedTupleType([("weights",
                                    tff.TensorType(tf.float32, [784, 10])),
                                   ("bias", tff.TensorType(tf.float32, [10]))])

  local_data_type = tff.SequenceType(batch_type)

  server_model_type = tff.FederatedType(model_type, tff.SERVER, all_equal=True)
  client_data_type = tff.FederatedType(local_data_type, tff.CLIENTS)
  server_float_type = tff.FederatedType(tf.float32, tff.SERVER, all_equal=True)

  computation_building_start = time.time()

  # pylint: disable=missing-docstring
  @tff.tf_computation(model_type, batch_type)
  def batch_loss(model, batch):
    # Cross-entropy loss of the softmax predictions against one-hot labels.
    predicted_y = tf.nn.softmax(tf.matmul(batch.x, model.weights) + model.bias)
    return -tf.reduce_mean(
        tf.reduce_sum(
            tf.one_hot(batch.y, 10) * tf.log(predicted_y),
            reduction_indices=[1]))

  initial_model = {
      "weights": np.zeros([784, 10], dtype=np.float32),
      "bias": np.zeros([10], dtype=np.float32)
  }

  @tff.tf_computation(model_type, batch_type, tf.float32)
  def batch_train(initial_model, batch, learning_rate):
    # Graph-mode SGD step: the control dependencies force
    # assign -> minimize -> read ordering within the stamped graph.
    model_vars = tff.utils.get_variables("v", model_type)
    init_model = tff.utils.assign(model_vars, initial_model)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    with tf.control_dependencies([init_model]):
      train_model = optimizer.minimize(batch_loss(model_vars, batch))
    with tf.control_dependencies([train_model]):
      return tff.utils.identity(model_vars)

  @tff.federated_computation(model_type, tf.float32, local_data_type)
  def local_train(initial_model, learning_rate, all_batches):

    # `learning_rate` is captured from the enclosing computation; the reduce
    # only threads the model through successive batch_train calls.
    @tff.federated_computation(model_type, batch_type)
    def batch_fn(model, batch):
      return batch_train(model, batch, learning_rate)

    return tff.sequence_reduce(all_batches, initial_model, batch_fn)

  @tff.federated_computation(server_model_type, server_float_type,
                             client_data_type)
  def federated_train(model, learning_rate, data):
    return tff.federated_average(
        tff.federated_map(local_train, [
            tff.federated_broadcast(model),
            tff.federated_broadcast(learning_rate), data
        ]))

  computation_building_stop = time.time()
  building_time = computation_building_stop - computation_building_start
  self.report_benchmark(
      name="computation_building_time, FC API",
      wall_time=building_time,
      iters=1)

  model = initial_model
  learning_rate = 0.1
  federated_data = generate_fake_mnist_data()

  # Time each round separately so mean and std-dev can be reported.
  execution_array = []
  for _ in range(n_rounds):
    execution_start = time.time()
    model = federated_train(model, learning_rate, federated_data)
    execution_stop = time.time()
    execution_array.append(execution_stop - execution_start)
  self.report_benchmark(
      name="Average per round execution time, FC API",
      wall_time=np.mean(execution_array),
      iters=n_rounds,
      extras={"std_dev": np.std(execution_array)})
def build_model_delta_optimizer_process(model_fn, model_to_client_delta_fn,
                                        server_optimizer_fn):
  """Constructs `tff.utils.IterativeProcess` for Federated Averaging or SGD.

  This provides the TFF orchestration logic connecting the common server logic
  which applies aggregated model deltas to the server model with a
  ClientDeltaFn that specifies how weight_deltas are computed on device.

  Note: We pass in functions rather than constructed objects so we can ensure
  any variables or ops created in constructors are placed in the correct graph.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    model_to_client_delta_fn: A function from a model_fn to a `ClientDeltaFn`.
    server_optimizer_fn: A no-arg function that returns a `tf.Optimizer`. The
      `apply_gradients` method of this optimizer is used to apply client
      updates to the server model.

  Returns:
    A `tff.utils.IterativeProcess`.
  """
  py_typecheck.check_callable(model_fn)
  py_typecheck.check_callable(model_to_client_delta_fn)
  py_typecheck.check_callable(server_optimizer_fn)

  # Construct a throwaway model in a dedicated graph `g`; the graph is
  # re-entered below when reading the model's output metadata.
  # TODO(b/122081673): would be nice not to have to construct a throwaway model
  # here just to get the types. After fully moving to TF2.0 and eager-mode, we
  # should re-evaluate what happens here and where `g` is used below.
  with tf.Graph().as_default() as g:
    dummy_model_for_metadata = model_utils.enhance(model_fn())

  @tff.federated_computation
  def server_init_tff():
    """Orchestration logic for server model initialization."""
    no_arg_server_init_fn = lambda: server_init(model_fn, server_optimizer_fn)
    server_init_tf = tff.tf_computation(no_arg_server_init_fn)
    return tff.federated_value(server_init_tf(), tff.SERVER)

  # Derive the state/dataset types from the init computation and model
  # metadata so they cannot drift from the actual values.
  federated_server_state_type = server_init_tff.type_signature.result
  server_state_type = federated_server_state_type.member

  tf_dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)
  federated_dataset_type = tff.FederatedType(
      tf_dataset_type, tff.CLIENTS, all_equal=False)

  @tff.federated_computation(federated_server_state_type,
                             federated_dataset_type)
  def run_one_round_tff(server_state, federated_dataset):
    """Orchestration logic for one round of optimization.

    Args:
      server_state: a `tff.learning.framework.ServerState` named tuple.
      federated_dataset: a federated `tf.Dataset` with placement tff.CLIENTS.

    Returns:
      A tuple of updated `tff.learning.framework.ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    model_weights_type = federated_server_state_type.member.model

    @tff.tf_computation(tf_dataset_type, model_weights_type)
    def client_delta_tf(tf_dataset, initial_model_weights):
      """Performs client local model optimization.

      Args:
        tf_dataset: a `tf.data.Dataset` that provides training examples.
        initial_model_weights: a `model_utils.ModelWeights` containing the
          starting weights.

      Returns:
        A `ClientOutput` structure.
      """
      client_delta_fn = model_to_client_delta_fn(model_fn)

      # TODO(b/123092620): this can be removed once AnonymousTuple works with
      # tf.contrib.framework.nest, or the following behavior is moved to
      # anonymous_tuple module.
      if isinstance(initial_model_weights, anonymous_tuple.AnonymousTuple):
        initial_model_weights = model_utils.ModelWeights.from_tff_value(
            initial_model_weights)

      client_output = client_delta_fn(tf_dataset, initial_model_weights)
      return client_output

    client_outputs = tff.federated_map(
        client_delta_tf,
        (federated_dataset, tff.federated_broadcast(server_state.model)))

    @tff.tf_computation(server_state_type, model_weights_type.trainable)
    def server_update_model_tf(server_state, model_delta):
      """Converts args to correct python types and calls server_update_model."""
      # We need to convert TFF types to the types server_update_model expects.
      # TODO(b/123092620): Mixing AnonymousTuple with other nested types is not
      # pretty, fold this into anonymous_tuple module or get working with
      # tf.contrib.framework.nest.
      py_typecheck.check_type(model_delta, anonymous_tuple.AnonymousTuple)
      model_delta = anonymous_tuple.to_odict(model_delta)
      py_typecheck.check_type(server_state, anonymous_tuple.AnonymousTuple)
      server_state = ServerState(
          model=model_utils.ModelWeights.from_tff_value(server_state.model),
          optimizer_state=list(server_state.optimizer_state))

      return server_update_model(
          server_state,
          model_delta,
          model_fn=model_fn,
          optimizer_fn=server_optimizer_fn)

    # TODO(b/124070381): We hope to remove this explicit cast once we have a
    # full solution for type analysis in multiplications and divisions
    # inside TFF
    fed_weight_type = client_outputs.weights_delta_weight.type_signature.member
    py_typecheck.check_type(fed_weight_type, tff.TensorType)
    if fed_weight_type.dtype.is_integer:

      @tff.tf_computation(fed_weight_type)
      def _cast_to_float(x):
        return tf.cast(x, tf.float32)

      weight_denom = tff.federated_map(_cast_to_float,
                                       client_outputs.weights_delta_weight)
    else:
      weight_denom = client_outputs.weights_delta_weight

    round_model_delta = tff.federated_mean(
        client_outputs.weights_delta, weight=weight_denom)

    # TODO(b/123408447): remove tff.federated_apply and call
    # server_update_model_tf directly once T <-> T@SERVER isomorphism is
    # supported.
    server_state = tff.federated_apply(server_update_model_tf,
                                       (server_state, round_model_delta))

    # Re-use graph used to construct `model`, since it has the variables, which
    # need to be read in federated_output_computation to get the correct shapes
    # and types for the federated aggregation.
    with g.as_default():
      aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
          client_outputs.model_output)

      # Promote the FederatedType outside the NamedTupleType
      aggregated_outputs = tff.federated_zip(aggregated_outputs)

    return server_state, aggregated_outputs

  return tff.utils.IterativeProcess(
      initialize_fn=server_init_tff, next_fn=run_one_round_tff)
def _build_one_round_computation(
    *,
    model_fn: _ModelConstructor,
    server_optimizer_fn: _OptimizerConstructor,
    model_to_client_delta_fn: Callable[[Callable[[], model_lib.Model]],
                                       ClientDeltaFn],
    broadcast_process: tff.templates.MeasuredProcess,
    aggregation_process: tff.templates.MeasuredProcess,
) -> tff.Computation:
  """Builds the `next` computation for a model delta averaging process.

  Args:
    model_fn: a no-argument callable that constructs and returns a
      `tff.learning.Model`. *Must* construct and return a new model when
      called. Returning captured models from other scopes will raise errors.
    server_optimizer_fn: a no-argument callable that constructs and returns a
      `tf.keras.optimizers.Optimizer`. *Must* construct and return a new
      optimizer when called. Returning captured optimizers from other scopes
      will raise errors.
    model_to_client_delta_fn: a callable that takes a single no-arg callable
      that returns `tff.learning.Model` as an argument and returns a
      `ClientDeltaFn` which performs the local training loop and model delta
      computation.
    broadcast_process: a `tff.templates.MeasuredProcess` to broadcast the
      global model to the clients.
    aggregation_process: a `tff.templates.MeasuredProcess` to aggregate client
      model deltas.

  Returns:
    A `tff.Computation` that performs one round of the process. The computation
    takes a tuple of `(ServerState@SERVER, tf.data.Dataset@CLIENTS)` argument,
    and returns a tuple of `(ServerState@SERVER, metrics@SERVER)`.
  """
  # Stamp throwaway model/optimizer into a scratch graph purely to read off
  # the weight and optimizer-variable types used to declare the computations.
  # TODO(b/124477628): would be nice not to have to construct a throwaway model
  # here just to get the types. After fully moving to TF2.0 and eager-mode, we
  # should re-evaluate what happens here.
  # TODO(b/144382142): Keras name uniquification is probably the main reason we
  # still need this.
  with tf.Graph().as_default():
    dummy_model_for_metadata = model_fn()
    model_weights_type = tff.framework.type_from_tensors(
        model_utils.ModelWeights.from_model(dummy_model_for_metadata))

    dummy_optimizer = server_optimizer_fn()
    # We must force variable creation for momentum and adaptive optimizers.
    _eagerly_create_optimizer_variables(
        model=dummy_model_for_metadata, optimizer=dummy_optimizer)
    optimizer_variable_type = tff.framework.type_from_tensors(
        dummy_optimizer.variables())

  @tff.tf_computation(model_weights_type, model_weights_type.trainable,
                      optimizer_variable_type)
  def server_update(global_model, model_delta, optimizer_state):
    """Converts args to correct python types and calls server_update_model."""
    # Construct variables first.
    model = model_fn()
    optimizer = server_optimizer_fn()
    # We must force variable creation for momentum and adaptive optimizers.
    _eagerly_create_optimizer_variables(model=model, optimizer=optimizer)

    @tf.function
    def update_model_inner(weights_delta):
      """Applies the update to the global model."""
      model_variables = model_utils.ModelWeights.from_model(model)
      optimizer_variables = optimizer.variables()
      # We might have a NaN value e.g. if all of the clients processed
      # had no data, so the denominator in the federated_mean is zero.
      # If we see any NaNs, zero out the whole update.
      no_nan_weights_delta, _ = tensor_utils.zero_all_if_any_non_finite(
          weights_delta)
      # TODO(b/124538167): We should increment a server counter to
      # track the fact a non-finite weights_delta was encountered.

      # Set the variables to the current global model (before update).
      tf.nest.map_structure(lambda a, b: a.assign(b),
                            (model_variables, optimizer_variables),
                            (global_model, optimizer_state))
      # Update the variables with the delta, and return the new global model.
      _apply_delta(optimizer=optimizer, model=model, delta=no_nan_weights_delta)
      return model_variables, optimizer_variables

    return update_model_inner(model_delta)

  dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)

  @tff.tf_computation(dataset_type, model_weights_type)
  def _compute_local_training_and_client_delta(dataset, initial_model_weights):
    """Performs client local model optimization.

    Args:
      dataset: a `tf.data.Dataset` that provides training examples.
      initial_model_weights: a `model_utils.ModelWeights` containing the
        starting weights.

    Returns:
      A `ClientOutput` structure.
    """
    client_delta_fn = model_to_client_delta_fn(model_fn)
    client_output = client_delta_fn(dataset, initial_model_weights)
    return client_output

  # The (unplaced) state types carried by the broadcast/aggregation processes,
  # read off their initialize signatures.
  broadcast_state = broadcast_process.initialize.type_signature.result.member
  aggregation_state = aggregation_process.initialize.type_signature.result.member

  server_state_type = ServerState(
      model=model_weights_type,
      optimizer_state=optimizer_variable_type,
      delta_aggregate_state=aggregation_state,
      model_broadcast_state=broadcast_state)

  @tff.federated_computation(
      tff.FederatedType(server_state_type, tff.SERVER),
      tff.FederatedType(dataset_type, tff.CLIENTS))
  def one_round_computation(server_state, federated_dataset):
    """Orchestration logic for one round of optimization.

    Args:
      server_state: a `tff.learning.framework.ServerState` named tuple.
      federated_dataset: a federated `tf.Dataset` with placement tff.CLIENTS.

    Returns:
      A tuple of updated `tff.learning.framework.ServerState` and the result of
      `tff.learning.Model.federated_output_computation`, both having
      `tff.SERVER` placement.
    """
    # 1. Broadcast the global model to the clients.
    broadcast_output = broadcast_process.next(
        server_state.model_broadcast_state, server_state.model)
    # 2. Run local training on each client and compute model deltas.
    client_outputs = tff.federated_map(
        _compute_local_training_and_client_delta,
        (federated_dataset, broadcast_output.result))
    # 3. Aggregate the (weighted) client deltas on the server.
    aggregation_output = aggregation_process.next(
        server_state.delta_aggregate_state, client_outputs.weights_delta,
        client_outputs.weights_delta_weight)
    # 4. Apply the aggregated delta to the global model.
    new_global_model, new_optimizer_state = tff.federated_map(
        server_update, (server_state.model, aggregation_output.result,
                        server_state.optimizer_state))
    new_server_state = tff.federated_zip(
        ServerState(new_global_model, new_optimizer_state,
                    aggregation_output.state, broadcast_output.state))
    # 5. Aggregate client metrics and bundle them with the broadcast and
    # aggregation measurements.
    aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
        client_outputs.model_output)
    measurements = tff.federated_zip(
        collections.OrderedDict(
            broadcast=broadcast_output.measurements,
            aggregation=aggregation_output.measurements,
            train=aggregated_outputs))
    return new_server_state, measurements

  return one_round_computation
def build_personalization_eval(model_fn,
                               personalize_fn_dict,
                               baseline_evaluate_fn,
                               max_num_samples=100,
                               context_tff_type=None):
  """Builds the TFF computation for evaluating personalization strategies.

  The returned TFF computation broadcasts model weights from SERVER to CLIENTS.
  Each client evaluates the personalization strategies given in
  `personalize_fn_dict`. Evaluation metrics from at most `max_num_samples`
  participating clients are collected to the SERVER.

  Args:
    model_fn: A no-argument function that returns a `tff.learning.Model`.
    personalize_fn_dict: An `OrderedDict` that maps a `string` (representing a
      strategy name) to a no-argument function that returns a `tf.function`.
      Each `tf.function` represents a personalization strategy: it accepts a
      `tff.learning.Model` (with weights already initialized to the provided
      model weights when users invoke the returned TFF computation), a training
      `tf.dataset.Dataset`, a test `tf.dataset.Dataset`, and an arbitrary
      context object (which is used to hold any extra information that a
      personalization strategy may use), trains a personalized model, and
      returns the evaluation metrics. The evaluation metrics are usually
      represented as an `OrderedDict` (or a nested `OrderedDict`) of `string`
      metric names to scalar `tf.Tensor`s.
    baseline_evaluate_fn: A `tf.function` that accepts a `tff.learning.Model`
      (with weights already initialized to the provided model weights when
      users invoke the returned TFF computation), and a `tf.dataset.Dataset`,
      evaluates the model on the dataset, and returns the evaluation metrics.
      The evaluation metrics are usually represented as an `OrderedDict` (or a
      nested `OrderedDict`) of `string` metric names to scalar `tf.Tensor`s.
      This function is *only* used to compute the baseline metrics of the
      initial model.
    max_num_samples: A positive `int` specifying the maximum number of metric
      samples to collect in a round. Each sample contains the personalization
      metrics from a single client. If the number of participating clients in a
      round is smaller than this value, all clients' metrics are collected.
    context_tff_type: A `tff.Type` of the optional context object used by the
      personalization strategies defined in `personalization_fn_dict`. We use a
      context object to hold any extra information (in addition to the training
      dataset) that personalization may use. If context is used in
      `personalization_fn_dict`, its `tff.Type` must be provided here.

  Returns:
    A federated `tff.Computation` that maps
    < model_weights@SERVER, input@CLIENTS > -> personalization_metrics@SERVER,
    where:
    - model_weights is a `tff.learning.framework.ModelWeights`.
    - each client's input is an `OrderedDict` of at least two keys `train_data`
      and `test_data`, and each key is mapped to a `tf.dataset.Dataset`. If
      context is used in `personalize_fn_dict`, then client input has a third
      key `context` that is mapped to an object whose `tff.Type` is provided by
      the `context_tff_type` argument.
    - personalization_metrics is an `OrderedDict` that maps a key
      'baseline_metrics' to the evaluation metrics of the initial model
      (computed by `baseline_evaluate_fn`), and maps keys (strategy names) in
      `personalize_fn_dict` to the evaluation metrics of the corresponding
      personalization strategies.
    - Note: only metrics from at most `max_num_samples` participating clients
      are collected to the SERVER. All collected metrics are stored in a single
      `OrderedDict` (the personalization_metrics shown above), where each
      metric is mapped to a list of scalars (each scalar comes from one
      client). Metric values at the same position, e.g., metric_1[i],
      metric_2[i]..., all come from the same client.

  Raises:
    TypeError: If arguments are of the wrong types.
    ValueError: If `baseline_metrics` is used as a key in
      `personalize_fn_dict`.
    ValueError: If `max_num_samples` is not positive.
  """
  # Validate the cheap scalar argument first so that invalid values fail fast,
  # before any TensorFlow graph or TFF computation is constructed below.
  py_typecheck.check_type(max_num_samples, int)
  if max_num_samples <= 0:
    raise ValueError('max_num_samples must be a positive integer.')

  # Obtain the types by constructing the model first.
  # TODO(b/124477628): Replace it with other ways of handling metadata.
  with tf.Graph().as_default():
    py_typecheck.check_callable(model_fn)
    model = model_utils.enhance(model_fn())
    model_weights_type = tff.framework.type_from_tensors(model.weights)
    batch_type = tff.to_type(model.input_spec)

  # Define the `tff.Type` of each client's input: two datasets plus an
  # optional opaque context object used by the personalization strategies.
  client_input_type = collections.OrderedDict([
      ('train_data', tff.SequenceType(batch_type)),
      ('test_data', tff.SequenceType(batch_type))
  ])
  if context_tff_type is not None:
    py_typecheck.check_type(context_tff_type, tff.Type)
    client_input_type['context'] = context_tff_type
  client_input_type = tff.to_type(client_input_type)

  @tff.tf_computation(model_weights_type, client_input_type)
  def _client_computation(initial_model_weights, client_input):
    """TFF computation that runs on each client."""
    model = model_fn()
    train_data = client_input['train_data']
    test_data = client_input['test_data']
    # `context` is only present when the caller supplied `context_tff_type`.
    context = client_input.get('context', None)

    return _client_fn(model, initial_model_weights, train_data, test_data,
                      personalize_fn_dict, baseline_evaluate_fn, context)

  @tff.federated_computation(
      tff.FederatedType(model_weights_type, tff.SERVER),
      tff.FederatedType(client_input_type, tff.CLIENTS))
  def personalization_eval(server_model_weights, federated_client_input):
    """TFF orchestration logic."""
    client_init_weights = tff.federated_broadcast(server_model_weights)
    client_final_metrics = tff.federated_map(
        _client_computation, (client_init_weights, federated_client_input))

    # WARNING: Collecting information from clients can be risky. Users have to
    # make sure that it is proper to collect those metrics from clients.
    # TODO(b/147889283): Add a link to the TFF doc once it exists.
    results = tff.utils.federated_sample(client_final_metrics, max_num_samples)
    return results

  return personalization_eval
def build_personalization_eval(model_fn,
                               personalize_fn_dict,
                               baseline_evaluate_fn,
                               max_num_samples=100,
                               context_tff_type=None):
  """Builds the TFF computation for evaluating personalization strategies.

  The returned TFF computation broadcasts model weights from `tff.SERVER` to
  `tff.CLIENTS`. Each client evaluates the personalization strategies given in
  `personalize_fn_dict`. Evaluation metrics from at most `max_num_samples`
  participating clients are collected to the server.

  NOTE: The functions in `personalize_fn_dict` and `baseline_evaluate_fn` are
  expected to take as input *unbatched* datasets, and are responsible for
  applying batching, if any, to the provided input datasets.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This
      method must *not* capture TensorFlow tensors or variables and use them.
      The model must be constructed entirely from scratch on each invocation;
      returning the same pre-constructed model each call will result in an
      error.
    personalize_fn_dict: An `OrderedDict` that maps a `string` (representing a
      strategy name) to a no-argument function that returns a `tf.function`.
      Each `tf.function` represents a personalization strategy: it accepts a
      `tff.learning.Model` (with weights already initialized to the given
      model weights when users invoke the returned TFF computation), an
      unbatched `tf.data.Dataset` for train, an unbatched `tf.data.Dataset`
      for test, and an arbitrary context object (which is used to hold any
      extra information that a personalization strategy may use), trains a
      personalized model, and returns the evaluation metrics. The evaluation
      metrics are represented as an `OrderedDict` (or a nested `OrderedDict`)
      of `string` metric names to scalar `tf.Tensor`s.
    baseline_evaluate_fn: A `tf.function` that accepts a `tff.learning.Model`
      (with weights already initialized to the provided model weights when
      users invoke the returned TFF computation), and an unbatched
      `tf.data.Dataset`, evaluates the model on the dataset, and returns the
      evaluation metrics. The evaluation metrics are represented as an
      `OrderedDict` (or a nested `OrderedDict`) of `string` metric names to
      scalar `tf.Tensor`s. This function is *only* used to compute the
      baseline metrics of the initial model.
    max_num_samples: A positive `int` specifying the maximum number of metric
      samples to collect in a round. Each sample contains the personalization
      metrics from a single client. If the number of participating clients in
      a round is smaller than this value, all clients' metrics are collected.
    context_tff_type: A `tff.Type` of the optional context object used by the
      personalization strategies defined in `personalization_fn_dict`. We use
      a context object to hold any extra information (in addition to the
      training dataset) that personalization may use. If context is used in
      `personalization_fn_dict`, its `tff.Type` must be provided here.

  Returns:
    A federated `tff.Computation` with the functional type signature
    `(<model_weights@SERVER, input@CLIENTS> -> personalization_metrics@SERVER)`:

    *   `model_weights` is a `tff.learning.ModelWeights`.
    *   Each client's input is an `OrderedDict` of two required keys
        `train_data` and `test_data`; each key is mapped to an unbatched
        `tf.data.Dataset`. If extra context (e.g., extra datasets) is used in
        `personalize_fn_dict`, then client input has a third key `context`
        that is mapped to an object whose `tff.Type` is provided by the
        `context_tff_type` argument.
    *   `personalization_metrics` is an `OrderedDict` that maps a key
        'baseline_metrics' to the evaluation metrics of the initial model
        (computed by `baseline_evaluate_fn`), and maps keys (strategy names)
        in `personalize_fn_dict` to the evaluation metrics of the
        corresponding personalization strategies.
    *   Note: only metrics from at most `max_num_samples` participating
        clients (sampled without replacement) are collected to the SERVER.
        All collected metrics are stored in a single `OrderedDict`
        (`personalization_metrics` shown above), where each metric is mapped
        to a list of scalars (each scalar comes from one client). Metric
        values at the same position, e.g., metric_1[i], metric_2[i]..., all
        come from the same client.

  Raises:
    TypeError: If arguments are of the wrong types.
    ValueError: If `baseline_metrics` is used as a key in
      `personalize_fn_dict`.
    ValueError: If `max_num_samples` is not positive.
  """
  # Validate the cheap arguments up front so that invalid values fail fast,
  # before the model is constructed or any TFF computation is traced. These
  # checks raise the same exception types documented above.
  py_typecheck.check_type(max_num_samples, int)
  if max_num_samples <= 0:
    raise ValueError('max_num_samples must be a positive integer.')
  py_typecheck.check_type(personalize_fn_dict, collections.OrderedDict)
  # 'baseline_metrics' is reserved for the initial model's metrics below, so a
  # strategy with the same name would silently collide in `final_metrics`.
  if 'baseline_metrics' in personalize_fn_dict:
    raise ValueError('baseline_metrics should not be used as a key in '
                     'personalize_fn_dict.')

  # Obtain the types by constructing the model first.
  # TODO(b/124477628): Replace it with other ways of handling metadata.
  with tf.Graph().as_default():
    py_typecheck.check_callable(model_fn)
    model = model_utils.enhance(model_fn())
    model_weights_type = tff.framework.type_from_tensors(model.weights)
    batch_type = model.input_spec

  # Define the `tff.Type` of each client's input. Since batching (as well as
  # other preprocessing of datasets) is done within each personalization
  # strategy (i.e., by functions in `personalize_fn_dict`), the client-side
  # input should contain unbatched elements.
  element_type = _remove_batch_dim(batch_type)
  client_input_type = collections.OrderedDict([
      ('train_data', tff.SequenceType(element_type)),
      ('test_data', tff.SequenceType(element_type))
  ])
  if context_tff_type is not None:
    py_typecheck.check_type(context_tff_type, tff.Type)
    client_input_type['context'] = context_tff_type
  client_input_type = tff.to_type(client_input_type)

  @tff.tf_computation(model_weights_type, client_input_type)
  def _client_computation(initial_model_weights, client_input):
    """TFF computation that runs on each client."""
    train_data = client_input['train_data']
    test_data = client_input['test_data']
    # `context` is only present when the caller supplied `context_tff_type`.
    context = client_input.get('context', None)

    final_metrics = collections.OrderedDict()
    # Compute the evaluation metrics of the initial model.
    final_metrics['baseline_metrics'] = _compute_baseline_metrics(
        model_fn, initial_model_weights, test_data, baseline_evaluate_fn)

    # Compute the evaluation metrics of the personalized models. The returned
    # `p13n_metrics` is an `OrderedDict` that maps keys (strategy names) in
    # `personalize_fn_dict` to the evaluation metrics of the corresponding
    # personalization strategies.
    p13n_metrics = _compute_p13n_metrics(model_fn, initial_model_weights,
                                         train_data, test_data,
                                         personalize_fn_dict, context)
    final_metrics.update(p13n_metrics)
    return final_metrics

  @tff.federated_computation(
      tff.FederatedType(model_weights_type, tff.SERVER),
      tff.FederatedType(client_input_type, tff.CLIENTS))
  def personalization_eval(server_model_weights, federated_client_input):
    """TFF orchestration logic."""
    client_init_weights = tff.federated_broadcast(server_model_weights)
    client_final_metrics = tff.federated_map(
        _client_computation, (client_init_weights, federated_client_input))

    # WARNING: Collecting information from clients can be risky. Users have to
    # make sure that it is proper to collect those metrics from clients.
    # TODO(b/147889283): Add a link to the TFF doc once it exists.
    results = tff.utils.federated_sample(client_final_metrics, max_num_samples)
    return results

  return personalization_eval