def robust_aggregation_fn(value, weight):
  """Computes an iteratively reweighted federated mean of `value`.

  The first estimate is the plain weighted mean of `value` under `weight`.
  Then `num_communication_passes - 1` refinement passes run. Each pass
  broadcasts the current estimate to the clients and derives new per-client
  weights with `update_weight_fn(weight, broadcast_estimate, value)`. It then
  recomputes the weighted mean with those new weights. The original `weight`
  (not the previous pass's weights) is what is handed to `update_weight_fn`
  on every pass.

  `num_communication_passes` and `update_weight_fn` are free variables. They
  must be provided by an enclosing scope that is not visible here.

  Args:
    value: Client-placed federated value to aggregate.
    weight: Client-placed federated weights for the initial mean.

  Returns:
    The server-placed weighted mean after the final pass.
  """
  estimate = tff.federated_mean(value, weight=weight)
  refinement_passes = num_communication_passes - 1
  for _ in range(refinement_passes):
    estimate_at_clients = tff.federated_broadcast(estimate)
    new_weights = tff.federated_map(
        update_weight_fn, (weight, estimate_at_clients, value))
    estimate = tff.federated_mean(value, weight=new_weights)
  return estimate
def next_comp(state, value, weight):
  """Runs one aggregation step with a counter state and a client-count metric.

  Args:
    state: Server-placed state, passed through `_add_one` via
      `tff.federated_map`.
    value: Client-placed values to average.
    weight: Client-placed weights for the mean.

  Returns:
    An `OrderedDict` with these entries, in this order:
      state: The mapped server state.
      result: The weighted federated mean of `value`.
      measurements: A zipped `OrderedDict` whose single entry `num_clients`
        is the federated sum of the constant 1 placed at `tff.CLIENTS`.
  """
  incremented_state = tff.federated_map(_add_one, state)
  weighted_mean = tff.federated_mean(value, weight)
  one_per_client = tff.federated_value(1, tff.CLIENTS)
  client_count = tff.federated_sum(one_per_client)
  measurements = tff.federated_zip(
      collections.OrderedDict(num_clients=client_count))
  return collections.OrderedDict(
      state=incremented_state,
      result=weighted_mean,
      measurements=measurements)
def fed_output(local_outputs):
  """Aggregates per-client outputs into federated metrics.

  Args:
    local_outputs: Client-placed structure exposing `num_examples`,
      `num_examples_float` and `loss` attributes.

  Returns:
    A `dict` with the following keys, in this insertion order:
      num_examples: The federated sum of `local_outputs.num_examples`.
      loss: The federated mean of `local_outputs.loss`, weighted by
        `local_outputs.num_examples_float`.
  """
  total_examples = tff.federated_sum(local_outputs.num_examples)
  # TODO(b/124070381): Remove need for using num_examples_float here.
  mean_loss = tff.federated_mean(
      local_outputs.loss, weight=local_outputs.num_examples_float)
  return dict(num_examples=total_examples, loss=mean_loss)
def test_fails_stateful_aggregate_and_process(self):
  """Passing both aggregation arguments must raise DisjointArgumentError.

  Builds a federated averaging process with both a legacy
  `stateful_delta_aggregate_fn` and an `aggregation_process`. The test
  asserts that the builder rejects the combination.
  """
  model_weights_type = model_utils.weights_type_from_model(
      model_examples.LinearRegression)

  def mean_next(state, value, weight=None):
    # Pass the state through unchanged and return the weighted mean.
    return (state, tff.federated_mean(value, weight))

  with self.assertRaises(optimizer_utils.DisjointArgumentError):
    legacy_aggregate = tff.utils.StatefulAggregateFn(
        initialize_fn=lambda: (), next_fn=mean_next)
    aggregation = optimizer_utils.build_stateless_mean(
        model_delta_type=model_weights_type.trainable)
    federated_averaging.build_federated_averaging_process(
        model_fn=model_examples.LinearRegression,
        client_optimizer_fn=tf.keras.optimizers.SGD,
        stateful_delta_aggregate_fn=legacy_aggregate,
        aggregation_process=aggregation)
def test_fails_stateful_aggregate_and_process(self):
  """Passing both aggregation arguments must raise DisjointArgumentError.

  The model weights type is derived inside a fresh `tf.Graph`. The test then
  asserts that `build_model_delta_optimizer_process` rejects receiving both
  a `stateful_delta_aggregate_fn` and an `aggregation_process`.
  """
  with tf.Graph().as_default():
    linear_model = model_examples.LinearRegression()
    model_weights_type = tff.framework.type_from_tensors(
        model_utils.ModelWeights.from_model(linear_model))

  def mean_next(state, value, weight=None):
    # Pass the state through unchanged and return the weighted mean.
    return (state, tff.federated_mean(value, weight))

  with self.assertRaises(optimizer_utils.DisjointArgumentError):
    legacy_aggregate = tff.utils.StatefulAggregateFn(
        initialize_fn=lambda: (), next_fn=mean_next)
    aggregation = optimizer_utils.build_stateless_mean(
        model_delta_type=model_weights_type.trainable)
    optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_examples.LinearRegression,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=tf.keras.optimizers.SGD,
        stateful_delta_aggregate_fn=legacy_aggregate,
        aggregation_process=aggregation)
def _state_incrementing_mean_next(server_state, client_value, weight=None):
  """Federated mean whose server state is an int32 counter bumped each call.

  Args:
    server_state: Server-placed int32 counter.
    client_value: Client-placed values to average.
    weight: Optional client-placed weights, forwarded to
      `tff.federated_mean`.

  Returns:
    A `(new_state, mean)` tuple. `new_state` is `server_state + 1`, computed
    through a `tff.tf_computation` over `tf.int32`. `mean` is the federated
    mean of `client_value`.
  """
  increment = tff.tf_computation(lambda counter: counter + 1, tf.int32)
  bumped_state = tff.federated_map(increment, server_state)
  mean = tff.federated_mean(client_value, weight=weight)
  return bumped_state, mean
def build_stateless_mean():
  """Returns a `tff.utils.StatefulAggregateFn` that wraps `tff.federated_mean`.

  The aggregate's state is the empty tuple, which `next_fn` passes through
  unchanged. It is meant to serve as a default aggregation.
  """

  def _mean_next(state, value, weight=None):
    return state, tff.federated_mean(value, weight=weight)

  return tff.utils.StatefulAggregateFn(
      initialize_fn=lambda: (), next_fn=_mean_next)
def federated_train(model, learning_rate, data):
  """Trains locally on every client and averages the resulting models.

  Args:
    model: Server-placed model, broadcast to clients.
    learning_rate: Server-placed learning rate, broadcast to clients.
    data: Client-placed training data.

  Returns:
    The unweighted federated mean of the per-client `local_train` results.
  """
  per_client_args = [
      tff.federated_broadcast(model),
      tff.federated_broadcast(learning_rate),
      data,
  ]
  trained_models = tff.federated_map(local_train, per_client_args)
  return tff.federated_mean(trained_models)
def _state_incrementing_mean_next(server_state, client_value, weight=None):
  """Federated mean that also advances the server state with `_add_one`.

  Args:
    server_state: Server-placed state passed to `_add_one`.
    client_value: Client-placed values to average.
    weight: Optional client-placed weights, forwarded to
      `tff.federated_mean`.

  Returns:
    A `(new_state, mean)` tuple.
  """
  advanced_state = tff.federated_map(_add_one, server_state)
  mean = tff.federated_mean(client_value, weight=weight)
  return advanced_state, mean
def run_one_round_tff(server_state, federated_dataset):
  """Orchestration logic for one round of optimization.

  Steps:
    1. Broadcast `server_state.model` to the clients.
    2. Run `client_delta_tf` on every client.
    3. Average the clients' `weights_delta` values, weighted by
       `weights_delta_weight` (cast to float32 first when it has an integer
       dtype).
    4. Apply the averaged delta on the server via `server_update_model_tf`.
    5. Aggregate the clients' `model_output` with
       `dummy_model_for_metadata.federated_output_computation`.

  NOTE(review): `federated_server_state_type`, `tf_dataset_type`,
  `server_state_type`, `model_fn`, `model_to_client_delta_fn`,
  `server_optimizer_fn`, `server_update_model`, `ServerState`, `g` and
  `dummy_model_for_metadata` are not defined in this block. They must come
  from an enclosing or module scope.

  Args:
    server_state: a `tff.learning.framework.ServerState` named tuple.
    federated_dataset: a federated `tf.Dataset` with placement tff.CLIENTS.

  Returns:
    A tuple of updated `tff.learning.framework.ServerState` and the result of
    `tff.learning.Model.federated_output_computation`.
  """
  model_weights_type = federated_server_state_type.member.model

  @tff.tf_computation(tf_dataset_type, model_weights_type)
  def client_delta_tf(tf_dataset, initial_model_weights):
    """Performs client local model optimization.

    Args:
      tf_dataset: a `tf.data.Dataset` that provides training examples.
      initial_model_weights: a `model_utils.ModelWeights` containing the
        starting weights.

    Returns:
      A `ClientOutput` structure.
    """
    client_delta_fn = model_to_client_delta_fn(model_fn)
    # TODO(b/123092620): this can be removed once AnonymousTuple works with
    # tf.contrib.framework.nest, or the following behavior is moved to
    # anonymous_tuple module.
    # TFF may hand the weights in as an AnonymousTuple; convert them back to
    # `ModelWeights` before calling the client delta function.
    if isinstance(initial_model_weights, anonymous_tuple.AnonymousTuple):
      initial_model_weights = model_utils.ModelWeights.from_tff_value(
          initial_model_weights)
    client_output = client_delta_fn(tf_dataset, initial_model_weights)
    return client_output

  # Broadcast the current server model and run local training on each client.
  client_outputs = tff.federated_map(
      client_delta_tf,
      (federated_dataset, tff.federated_broadcast(server_state.model)))

  @tff.tf_computation(server_state_type, model_weights_type.trainable)
  def server_update_model_tf(server_state, model_delta):
    """Converts args to correct python types and calls server_update_model."""
    # We need to convert TFF types to the types server_update_model expects.
    # TODO(b/123092620): Mixing AnonymousTuple with other nested types is not
    # pretty, fold this into anonymous_tuple module or get working with
    # tf.contrib.framework.nest.
    py_typecheck.check_type(model_delta, anonymous_tuple.AnonymousTuple)
    model_delta = anonymous_tuple.to_odict(model_delta)
    py_typecheck.check_type(server_state, anonymous_tuple.AnonymousTuple)
    server_state = ServerState(
        model=model_utils.ModelWeights.from_tff_value(server_state.model),
        optimizer_state=list(server_state.optimizer_state))
    return server_update_model(
        server_state,
        model_delta,
        model_fn=model_fn,
        optimizer_fn=server_optimizer_fn)

  # TODO(b/124070381): We hope to remove this explicit cast once we have a
  # full solution for type analysis in multiplications and divisions
  # inside TFF
  fed_weight_type = client_outputs.weights_delta_weight.type_signature.member
  py_typecheck.check_type(fed_weight_type, tff.TensorType)
  # Integer weights are cast to float32 before being used as mean weights.
  if fed_weight_type.dtype.is_integer:

    @tff.tf_computation(fed_weight_type)
    def _cast_to_float(x):
      return tf.cast(x, tf.float32)

    weight_denom = tff.federated_map(_cast_to_float,
                                     client_outputs.weights_delta_weight)
  else:
    weight_denom = client_outputs.weights_delta_weight
  # Weighted average of the per-client model deltas.
  round_model_delta = tff.federated_mean(
      client_outputs.weights_delta, weight=weight_denom)

  # TODO(b/123408447): remove tff.federated_apply and call
  # server_update_model_tf directly once T <-> T@SERVER isomorphism is
  # supported.
  server_state = tff.federated_apply(server_update_model_tf,
                                     (server_state, round_model_delta))

  # Re-use graph used to construct `model`, since it has the variables, which
  # need to be read in federated_output_computation to get the correct shapes
  # and types for the federated aggregation.
  with g.as_default():
    aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
        client_outputs.model_output)

  # Promote the FederatedType outside the NamedTupleType
  aggregated_outputs = tff.federated_zip(aggregated_outputs)

  return server_state, aggregated_outputs
def stateless_mean(state, value, weight):
  """Weighted federated mean that leaves `state` untouched.

  Args:
    state: Aggregation state, returned as-is.
    value: Client-placed values to average.
    weight: Client-placed weights for the mean.

  Returns:
    An `OrderedDict` with these keys, in this order:
      state: The input `state`, unchanged.
      result: The weighted federated mean of `value`.
      measurements: An empty tuple placed at `tff.SERVER`.
  """
  no_measurements = tff.federated_value((), tff.SERVER)
  weighted_mean = tff.federated_mean(value, weight=weight)
  return collections.OrderedDict([
      ('state', state),
      ('result', weighted_mean),
      ('measurements', no_measurements),
  ])
def cast_to_float_mean(state, value, weight):
  """Weighted federated mean whose weights first pass through a float cast.

  Args:
    state: Aggregation state, returned as-is.
    value: Client-placed values to average.
    weight: Client-placed weights. Each one is mapped through
      `_cast_weight_to_float` before averaging.

  Returns:
    A `(state, mean)` tuple.
  """
  float_weight = tff.federated_map(_cast_weight_to_float, weight)
  mean = tff.federated_mean(value, weight=float_weight)
  return state, mean