def one_round_computation(server_state, federated_dataset):
    """Run one round of federated optimization and gather its measurements.

    The round proceeds in this order: `broadcast_process.next` is called with
    the broadcast state and the server model;
    `_compute_local_training_and_client_delta` is mapped over the dataset
    paired with the broadcast result; `aggregation_process.next` is called
    with the aggregation state, the client weight deltas and their weights;
    `server_update` is mapped over the current model, the aggregate and the
    optimizer state; finally the model's `federated_output_computation` is
    applied to the client model outputs.

    Args:
        server_state: a `tff.learning.framework.ServerState` named tuple.
        federated_dataset: a federated `tf.Dataset` with placement
            `tff.CLIENTS`.

    Returns:
        A tuple `(new_server_state, measurements)`. The first element is the
        result of `tff.federated_zip` on a `ServerState` built from the new
        model, the new optimizer state, the aggregation process state and the
        broadcast process state. The second is the result of
        `tff.federated_zip` on an ordered dict with the keys `broadcast`,
        `aggregation` and `train`.
    """
    broadcast_step = broadcast_process.next(
        server_state.model_broadcast_state, server_state.model)

    training_results = tff.federated_map(
        _compute_local_training_and_client_delta,
        (federated_dataset, broadcast_step.result))

    aggregation_step = aggregation_process.next(
        server_state.delta_aggregate_state,
        training_results.weights_delta,
        training_results.weights_delta_weight)

    updated_model, updated_optimizer_state = tff.federated_map(
        server_update,
        (server_state.model, aggregation_step.result,
         server_state.optimizer_state))

    next_state = tff.federated_zip(
        ServerState(updated_model, updated_optimizer_state,
                    aggregation_step.state, broadcast_step.state))

    train_metrics = dummy_model_for_metadata.federated_output_computation(
        training_results.model_output)

    # One entry per stage of the round; zipped into a single value below.
    round_measurements = collections.OrderedDict(
        broadcast=broadcast_step.measurements,
        aggregation=aggregation_step.measurements,
        train=train_metrics)
    return next_state, tff.federated_zip(round_measurements)
def next_comp(state, value):
    """Increment the state, broadcast `value`, and report a test metric.

    Args:
        state: the value `_add_one` is mapped over to produce the new state.
        value: the value that is broadcast and measured.

    Returns:
        An ordered dict with the keys `state`, `result` and `measurements`.
    """
    incremented_state = tff.federated_map(_add_one, state)
    broadcast_value = tff.federated_broadcast(value)

    # Arbitrary metrics for testing.
    norm_plus_three = tff.tf_computation(
        lambda v: tf.linalg.global_norm(tf.nest.flatten(v)) + 3.0)
    test_metric = tff.federated_map(norm_plus_three, value)

    return collections.OrderedDict(
        state=incremented_state,
        result=broadcast_value,
        measurements=test_metric)
def next_comp(state, value, weight):
    """Increment the state, average `value`, and count the clients.

    Args:
        state: the value `_add_one` is mapped over to produce the new state.
        value: the value passed to `tff.federated_mean`.
        weight: the weight passed to `tff.federated_mean`.

    Returns:
        An ordered dict with the keys `state`, `result` and `measurements`;
        `measurements` is a zipped ordered dict with the single key
        `num_clients`.
    """
    incremented_state = tff.federated_map(_add_one, state)
    weighted_mean = tff.federated_mean(value, weight)

    # Summing a constant 1 placed at every client yields the client count.
    one_per_client = tff.federated_value(1, tff.CLIENTS)
    client_count = tff.federated_sum(one_per_client)
    metrics = tff.federated_zip(
        collections.OrderedDict(num_clients=client_count))

    return collections.OrderedDict(
        state=incremented_state,
        result=weighted_mean,
        measurements=metrics)
def robust_aggregation_fn(value, weight):
    """Compute a weighted mean of `value`, re-weighting on each extra pass.

    The first pass is a `tff.federated_mean` using `weight` as given. Each of
    the remaining `num_communication_passes - 1` passes broadcasts the current
    aggregate, maps `update_weight_fn` over
    `(weight, broadcast aggregate, value)` and recomputes the mean with the
    resulting weights.

    Args:
        value: the value to aggregate.
        weight: the initial weight for the mean.

    Returns:
        The aggregate produced by the last pass.
    """
    current_mean = tff.federated_mean(value, weight=weight)
    extra_passes = num_communication_passes - 1
    for _ in range(extra_passes):
        mean_at_clients = tff.federated_broadcast(current_mean)
        # Weights are always derived from the original `weight`, not from the
        # weights of the previous pass.
        reweighted = tff.federated_map(
            update_weight_fn, (weight, mean_at_clients, value))
        current_mean = tff.federated_mean(value, weight=reweighted)
    return current_mean
def run_one_round_tff(server_state, federated_dataset):
    """Run one round of federated optimization.

    The server model is broadcast with `stateful_model_broadcast_fn`,
    `tf_client_delta` is mapped over the dataset paired with the broadcast
    model, the client weight deltas are combined with
    `stateful_delta_aggregate_fn`, and `tf_server_update` is mapped over the
    server state together with the aggregate and the two new stateful-fn
    states.

    Args:
        server_state: a `tff.learning.framework.ServerState` named tuple.
        federated_dataset: a federated `tf.Dataset` with placement
            `tff.CLIENTS`.

    Returns:
        A tuple of the updated server state and the result of the model's
        `federated_output_computation`. The latter is passed through
        `tff.federated_zip` when its type signature is a
        `tff.NamedTupleType`.
    """
    broadcaster_state, model_at_clients = stateful_model_broadcast_fn(
        server_state.model_broadcast_state, server_state.model)

    per_client = tff.federated_map(
        tf_client_delta, (federated_dataset, model_at_clients))

    # TODO(b/124070381): We hope to remove this explicit cast once we have a
    # full solution for type analysis in multiplications and divisions
    # inside TFF
    float_weights = tff.federated_map(
        _cast_weight_to_float, per_client.weights_delta_weight)

    aggregate_state, mean_delta = stateful_delta_aggregate_fn(
        server_state.delta_aggregate_state,
        per_client.weights_delta,
        weight=float_weights)

    # TODO(b/123408447): remove tff.federated_map and call
    # tf_server_update directly once T <-> T@SERVER isomorphism is
    # supported.
    updated_state = tff.federated_map(
        tf_server_update,
        (server_state, mean_delta, aggregate_state, broadcaster_state))

    metrics = dummy_model_for_metadata.federated_output_computation(
        per_client.model_output)

    # TODO(b/131429028): Ideally this federated_zip shouldn't ever be needed.
    # Promote the FederatedType outside the NamedTupleType.
    needs_zip = isinstance(metrics.type_signature, tff.NamedTupleType)
    return updated_state, (tff.federated_zip(metrics) if needs_zip else metrics)
def next_fn_taking_client_ids(param):
    """Call `process.next` with one element of `param` turned into datasets.

    The element at position `dataset_index` is replaced by the result of
    mapping `dataset_computation` over it; every other element is forwarded
    unchanged and in its original position.

    Args:
        param: an indexable, iterable argument structure whose element at
            `dataset_index` is the input to `dataset_computation`.

    Returns:
        Whatever `process.next` returns for the rebuilt argument list.
    """
    datasets_on_clients = tff.federated_map(
        dataset_computation, param[dataset_index])
    rebuilt_param = [
        datasets_on_clients if position == dataset_index else element
        for position, element in enumerate(param)
    ]
    return process.next(rebuilt_param)
def personalization_eval(server_model_weights, federated_client_input):
    """TFF orchestration logic for personalization evaluation.

    Broadcasts the server model weights, maps `_client_computation` over the
    broadcast weights paired with the client input, and samples the
    per-client metrics.

    Args:
        server_model_weights: the model weights to broadcast.
        federated_client_input: the per-client input handed to
            `_client_computation`.

    Returns:
        The result of `tff.utils.federated_sample` on the client metrics,
        called with `max_num_samples`.
    """
    weights_at_clients = tff.federated_broadcast(server_model_weights)
    metrics_per_client = tff.federated_map(
        _client_computation, (weights_at_clients, federated_client_input))

    # WARNING: Collecting information from clients can be risky. Users have to
    # make sure that it is proper to collect those metrics from clients.
    # TODO(b/147889283): Add a link to the TFF doc once it exists.
    return tff.utils.federated_sample(metrics_per_client, max_num_samples)
def one_round_computation(server_state, federated_dataset):
    """Run one round of federated optimization.

    The server model is broadcast with `model_broadcast_fn`,
    `_compute_local_training_and_client_delta` is mapped over the dataset
    paired with the broadcast model, the client weight deltas are combined
    with `delta_aggregate_fn`, and `server_update` is mapped over the current
    model, the aggregated delta and the optimizer state.

    Args:
        server_state: a `tff.learning.framework.ServerState` named tuple.
        federated_dataset: a federated `tf.Dataset` with placement
            `tff.CLIENTS`.

    Returns:
        A tuple of the new server state (a zipped `ServerState`) and the
        result of the model's `federated_output_computation`. The latter is
        passed through `tff.federated_zip` when its type signature is a
        `tff.NamedTupleType`.
    """
    broadcaster_state, model_at_clients = model_broadcast_fn(
        server_state.model_broadcast_state, server_state.model)

    per_client = tff.federated_map(
        _compute_local_training_and_client_delta,
        (federated_dataset, model_at_clients))

    aggregate_state, mean_delta = delta_aggregate_fn(
        server_state.delta_aggregate_state,
        per_client.weights_delta,
        weight=per_client.weights_delta_weight)

    updated_model, updated_optimizer_state = tff.federated_map(
        server_update,
        (server_state.model, mean_delta, server_state.optimizer_state))

    next_state = tff.federated_zip(
        ServerState(updated_model, updated_optimizer_state,
                    aggregate_state, broadcaster_state))

    metrics = dummy_model_for_metadata.federated_output_computation(
        per_client.model_output)

    if not isinstance(metrics.type_signature, tff.NamedTupleType):
        return next_state, metrics
    # Promote the FederatedType outside the NamedTupleType.
    return next_state, tff.federated_zip(metrics)
def run_one_round_tff(server_state, federated_dataset):
    """Run one round of federated optimization.

    Two TensorFlow computations are defined inline: `client_delta_tf`, mapped
    over the clients, and `server_update_tf`, applied on the server state.
    Between them the model is broadcast with `stateful_model_broadcast_fn`
    and the client deltas are combined with `stateful_delta_aggregate_fn`.

    Args:
        server_state: a `tff.learning.framework.ServerState` named tuple.
        federated_dataset: a federated `tf.Dataset` with placement
            `tff.CLIENTS`.

    Returns:
        A tuple of the updated server state and the zipped result of the
        model's `federated_output_computation`.
    """
    model_weights_type = server_state_type.model

    @tff.tf_computation(tf_dataset_type, model_weights_type)
    def client_delta_tf(tf_dataset, initial_model_weights):
        """Performs client local model optimization.

        Args:
            tf_dataset: a `tf.data.Dataset` that provides training examples.
            initial_model_weights: a `model_utils.ModelWeights` containing
                the starting weights.

        Returns:
            A `ClientOutput` structure.
        """
        return model_to_client_delta_fn(model_fn)(
            tf_dataset, initial_model_weights)

    broadcaster_state, model_at_clients = stateful_model_broadcast_fn(
        server_state.model_broadcast_state, server_state.model)

    per_client = tff.federated_map(
        client_delta_tf, (federated_dataset, model_at_clients))

    @tff.tf_computation(
        server_state_type,
        model_weights_type.trainable,
        server_state.delta_aggregate_state.type_signature.member,
        server_state.model_broadcast_state.type_signature.member)
    def server_update_tf(server_state, model_delta, new_delta_aggregate_state,
                         new_broadcaster_state):
        """Converts args to correct python types and calls server_update_model."""
        py_typecheck.check_type(server_state, ServerState)
        # Rebuild the state with the freshly produced stateful-fn states and
        # the optimizer state as a plain list.
        refreshed_state = ServerState(
            model=server_state.model,
            optimizer_state=list(server_state.optimizer_state),
            delta_aggregate_state=new_delta_aggregate_state,
            model_broadcast_state=new_broadcaster_state)
        return server_update_model(
            refreshed_state,
            model_delta,
            model_fn=model_fn,
            optimizer_fn=server_optimizer_fn)

    # TODO(b/124070381): We hope to remove this explicit cast once we have a
    # full solution for type analysis in multiplications and divisions
    # inside TFF
    fed_weight_type = per_client.weights_delta_weight.type_signature.member
    py_typecheck.check_type(fed_weight_type, tff.TensorType)
    if fed_weight_type.dtype.is_integer:

        @tff.tf_computation(fed_weight_type)
        def _cast_to_float(x):
            return tf.cast(x, tf.float32)

        mean_weights = tff.federated_map(
            _cast_to_float, per_client.weights_delta_weight)
    else:
        mean_weights = per_client.weights_delta_weight

    aggregate_state, mean_delta = stateful_delta_aggregate_fn(
        server_state.delta_aggregate_state,
        per_client.weights_delta,
        weight=mean_weights)

    # TODO(b/123408447): remove tff.federated_apply and call
    # server_update_tf directly once T <-> T@SERVER isomorphism is
    # supported.
    updated_state = tff.federated_apply(
        server_update_tf,
        (server_state, mean_delta, aggregate_state, broadcaster_state))

    metrics = dummy_model_for_metadata.federated_output_computation(
        per_client.model_output)

    # Promote the FederatedType outside the NamedTupleType
    return updated_state, tff.federated_zip(metrics)
def federated_train(model, learning_rate, data):
    """Train locally on every client and average the results.

    Args:
        model: the model to broadcast to the clients.
        learning_rate: the learning rate to broadcast to the clients.
        data: the client data passed to `local_train` as-is.

    Returns:
        The `tff.federated_average` of the per-client `local_train` results.
    """
    model_at_clients = tff.federated_broadcast(model)
    learning_rate_at_clients = tff.federated_broadcast(learning_rate)
    locally_trained = tff.federated_map(
        local_train, [model_at_clients, learning_rate_at_clients, data])
    return tff.federated_average(locally_trained)
def _state_incrementing_broadcast_next(server_state, server_value):
    """Map `_add_one` over the state and broadcast `server_value`.

    Args:
        server_state: the state that `_add_one` is mapped over.
        server_value: the value to broadcast.

    Returns:
        A tuple `(new_state, broadcast_value)`.
    """
    incremented_state = tff.federated_map(_add_one, server_state)
    value_at_clients = tff.federated_broadcast(server_value)
    return (incremented_state, value_at_clients)
def _state_incrementing_mean_next(server_state, client_value, weight=None):
    """Map `_add_one` over the state and average `client_value`.

    Args:
        server_state: the state that `_add_one` is mapped over.
        client_value: the value passed to `tff.federated_mean`.
        weight: optional weight forwarded to `tff.federated_mean`.

    Returns:
        A tuple `(new_state, mean)`.
    """
    incremented_state = tff.federated_map(_add_one, server_state)
    mean_value = tff.federated_mean(client_value, weight=weight)
    return (incremented_state, mean_value)
def server_eval(server_model_weights, federated_dataset):
    """Evaluate the server model on every client and aggregate the outputs.

    Args:
        server_model_weights: the model weights to broadcast to the clients.
        federated_dataset: the client datasets passed to `client_eval`.

    Returns:
        The result of `model.federated_output_computation` applied to the
        `local_outputs` of the per-client evaluation.
    """
    weights_at_clients = tff.federated_broadcast(server_model_weights)
    per_client = tff.federated_map(
        client_eval, [weights_at_clients, federated_dataset])
    return model.federated_output_computation(per_client.local_outputs)
def next_fn(empty_tup, x):
    """Reduce each client's dataset and sum the results.

    Args:
        empty_tup: ignored.
        x: the value `reduce_dataset` is mapped over.

    Returns:
        The `tff.federated_sum` of the per-client `reduce_dataset` results.
    """
    del empty_tup  # Unused
    reduced_per_client = tff.federated_map(reduce_dataset, x)
    return tff.federated_sum(reduced_per_client)
def cast_to_float_mean(state, value, weight):
    """Average `value` using weights passed through `_cast_weight_to_float`.

    Args:
        state: returned unchanged.
        value: the value passed to `tff.federated_mean`.
        weight: the weights that `_cast_weight_to_float` is mapped over
            before they are used for the mean.

    Returns:
        A tuple `(state, mean)`.
    """
    converted_weight = tff.federated_map(_cast_weight_to_float, weight)
    weighted_mean = tff.federated_mean(value, weight=converted_weight)
    return state, weighted_mean
def _state_incrementing_mean_next(server_state, client_value, weight=None):
    """Add one to the state and average `client_value`.

    Args:
        server_state: the state to increment; the increment is a
            `tf.int32` TensorFlow computation.
        client_value: the value passed to `tff.federated_mean`.
        weight: optional weight forwarded to `tff.federated_mean`.

    Returns:
        A tuple `(new_state, mean)`.
    """
    increment = tff.tf_computation(lambda x: x + 1, tf.int32)
    incremented_state = tff.federated_map(increment, server_state)
    mean_value = tff.federated_mean(client_value, weight=weight)
    return (incremented_state, mean_value)
def _state_incrementing_broadcast_next(server_state, server_value):
    """Add one to the state and broadcast `server_value`.

    Args:
        server_state: the state to increment; the increment is a
            `tf.int32` TensorFlow computation.
        server_value: the value to broadcast.

    Returns:
        A tuple `(new_state, broadcast_value)`.
    """
    increment = tff.tf_computation(lambda x: x + 1, tf.int32)
    incremented_state = tff.federated_map(increment, server_state)
    value_at_clients = tff.federated_broadcast(server_value)
    return (incremented_state, value_at_clients)
def run_one_round_tff(server_state, federated_dataset):
    """Run one round of federated optimization.

    Two TensorFlow computations are defined inline: `client_delta_tf`, mapped
    over the clients with the broadcast server model, and
    `server_update_model_tf`, applied to the server state and the weighted
    mean of the client weight deltas.

    Args:
        server_state: a `tff.learning.framework.ServerState` named tuple.
        federated_dataset: a federated `tf.Dataset` with placement
            `tff.CLIENTS`.

    Returns:
        A tuple of the updated server state and the zipped result of the
        model's `federated_output_computation`.
    """
    model_weights_type = federated_server_state_type.member.model

    @tff.tf_computation(tf_dataset_type, model_weights_type)
    def client_delta_tf(tf_dataset, initial_model_weights):
        """Performs client local model optimization.

        Args:
            tf_dataset: a `tf.data.Dataset` that provides training examples.
            initial_model_weights: a `model_utils.ModelWeights` containing
                the starting weights.

        Returns:
            A `ClientOutput` structure.
        """
        client_delta_fn = model_to_client_delta_fn(model_fn)

        # TODO(b/123092620): this can be removed once AnonymousTuple works
        # with tf.contrib.framework.nest, or the following behavior is moved
        # to anonymous_tuple module.
        starting_weights = initial_model_weights
        if isinstance(starting_weights, anonymous_tuple.AnonymousTuple):
            starting_weights = model_utils.ModelWeights.from_tff_value(
                starting_weights)

        return client_delta_fn(tf_dataset, starting_weights)

    model_at_clients = tff.federated_broadcast(server_state.model)
    per_client = tff.federated_map(
        client_delta_tf, (federated_dataset, model_at_clients))

    @tff.tf_computation(server_state_type, model_weights_type.trainable)
    def server_update_model_tf(server_state, model_delta):
        """Converts args to correct python types and calls server_update_model."""
        # We need to convert TFF types to the types server_update_model
        # expects.
        # TODO(b/123092620): Mixing AnonymousTuple with other nested types is
        # not pretty, fold this into anonymous_tuple module or get working
        # with tf.contrib.framework.nest.
        py_typecheck.check_type(model_delta, anonymous_tuple.AnonymousTuple)
        delta_as_odict = anonymous_tuple.to_odict(model_delta)

        py_typecheck.check_type(server_state, anonymous_tuple.AnonymousTuple)
        typed_state = ServerState(
            model=model_utils.ModelWeights.from_tff_value(server_state.model),
            optimizer_state=list(server_state.optimizer_state))

        return server_update_model(
            typed_state,
            delta_as_odict,
            model_fn=model_fn,
            optimizer_fn=server_optimizer_fn)

    # TODO(b/124070381): We hope to remove this explicit cast once we have a
    # full solution for type analysis in multiplications and divisions
    # inside TFF
    fed_weight_type = per_client.weights_delta_weight.type_signature.member
    py_typecheck.check_type(fed_weight_type, tff.TensorType)
    if fed_weight_type.dtype.is_integer:

        @tff.tf_computation(fed_weight_type)
        def _cast_to_float(x):
            return tf.cast(x, tf.float32)

        mean_weights = tff.federated_map(
            _cast_to_float, per_client.weights_delta_weight)
    else:
        mean_weights = per_client.weights_delta_weight

    mean_delta = tff.federated_mean(
        per_client.weights_delta, weight=mean_weights)

    # TODO(b/123408447): remove tff.federated_apply and call
    # server_update_model_tf directly once T <-> T@SERVER isomorphism is
    # supported.
    updated_state = tff.federated_apply(
        server_update_model_tf, (server_state, mean_delta))

    # Re-use graph used to construct `model`, since it has the variables,
    # which need to be read in federated_output_computation to get the
    # correct shapes and types for the federated aggregation.
    with g.as_default():
        metrics = dummy_model_for_metadata.federated_output_computation(
            per_client.model_output)

    # Promote the FederatedType outside the NamedTupleType
    return updated_state, tff.federated_zip(metrics)