def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.

    Broadcasts the server model to the clients and runs `client_update_fn` on
    each client's data with its copy of the model. It then averages the
    resulting `weights_delta` values, weighted by each client's
    `weights_delta_weight`. Finally it applies `server_update_fn` to fold that
    average into the server state.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    # Give every client a copy of the current server model.
    broadcast_model = tff.federated_broadcast(server_state.model)

    # Per-client work: each client's dataset is paired with its model copy.
    local_results = tff.federated_map(
        client_update_fn, (federated_dataset, broadcast_model))

    # Weighted mean of the client deltas; the weight comes from each client.
    delta_weight = local_results.weights_delta_weight
    averaged_delta = tff.federated_mean(
        local_results.weights_delta, weight=delta_weight)

    updated_state = tff.federated_apply(
        server_update_fn, (server_state, averaged_delta))

    # Aggregate client metrics with the model's own output computation, then
    # zip the resulting structure into one federated value.
    # NOTE(review): presumably the output computation returns a structure of
    # federated values; confirm against the tff version in use.
    metrics = tff.federated_zip(
        dummy_model_for_metadata.federated_output_computation(
            local_results.model_output))

    return updated_state, metrics
def run_one_round(server_state, client_states):
    """Orchestration logic for one round of federated training computation.

    Takes an unweighted `tff.federated_mean` of the client-placed
    `client_states` (no `weight=` is passed). It then applies
    `server_update_fn` to the server state together with that mean.

    Args:
      server_state: The current server state. It must be placed at
        `tff.SERVER`, as `tff.federated_apply` requires.
      client_states: The client-placed values to average. The original
        comments describe these as the clients' weights (TODO confirm).

    Returns:
      The new server state produced by `server_update_fn`.
    """
    # Federated averaging of the clients' weights (unweighted mean).
    # A leftover debug `print(str(...))` used to sit here. It only printed
    # the symbolic federated value, not the averaged numbers.
    mean_client_weights = tff.federated_mean(client_states)

    # Server update step.
    return tff.federated_apply(
        server_update_fn, (server_state, mean_client_weights))
def encoded_broadcast_fn(state, value):
    """Broadcast function, to be wrapped as federated_computation.

    Builds encode/decode computations from the member types of `state` and
    `value` and the `encoders` found in the enclosing scope (not a
    parameter). It encodes `value` together with `state` via
    `tff.federated_apply`, broadcasts the encoded form, and decodes it at
    each client via `tff.federated_map`.

    Args:
      state: Federated encoder state. It is passed to `tff.federated_apply`,
        so it is expected to be placed at `tff.SERVER`.
      value: Federated value to broadcast. It has the same placement
        expectation as `state`.

    Returns:
      A tuple `(new_state, client_value)`. `new_state` is the state returned
      by the encode step. `client_value` is the decoded value at the clients.
    """
    builders = _build_encode_decode_tf_computations_for_broadcast(
        state.type_signature.member, value.type_signature.member, encoders)
    encode_comp, decode_comp = builders

    # The encode step returns both the next encoder state and the payload.
    next_state, payload = tff.federated_apply(encode_comp, (state, value))

    # Broadcast the encoded payload, then decode it at every client.
    decoded_at_clients = tff.federated_map(
        decode_comp, tff.federated_broadcast(payload))
    return next_state, decoded_at_clients