Esempio n. 1
0
    def run_one_round(server_state, federated_dataset):
        """Runs the orchestration logic for a single round of computation.

        The current server model is broadcast to the clients, and
        `client_update_fn` runs on each client's data. The clients' weight
        deltas are averaged, weighted by `weights_delta_weight`, and the mean
        is applied to the server state with `server_update_fn`. The clients'
        model outputs are aggregated with the metadata model's output
        computation.

        Args:
          server_state: A `ServerState`.
          federated_dataset: A federated `tf.Dataset` with placement
            `tff.CLIENTS`.

        Returns:
          A tuple of updated `ServerState` and the result of
          `tff.learning.Model.federated_output_computation`.
        """
        # Ship the current server model out to the clients.
        broadcast_model = tff.federated_broadcast(server_state.model)

        # Local update on every client, fed (dataset, model) pairs.
        per_client = tff.federated_map(
            client_update_fn, (federated_dataset, broadcast_model))

        # Weighted mean of the client deltas; each client's contribution is
        # scaled by the weight it reported alongside its delta.
        delta_weight = per_client.weights_delta_weight
        mean_delta = tff.federated_mean(per_client.weights_delta,
                                        weight=delta_weight)

        # Apply the averaged delta to the server state.
        updated_state = tff.federated_apply(
            server_update_fn, (server_state, mean_delta))

        # Aggregate client outputs with the model's output computation, then
        # pass the resulting structure through tff.federated_zip.
        metrics = tff.federated_zip(
            dummy_model_for_metadata.federated_output_computation(
                per_client.model_output))

        return updated_state, metrics
    def run_one_round(server_state, client_states):
        """Orchestration logic for one round of federated training computation.

        The values held by the clients are averaged with `tff.federated_mean`.
        The current server state and that mean are then passed to
        `server_update_fn` via `tff.federated_apply`.

        Args:
          server_state: The current server state.
          client_states: Client-placed values to average (presumably model
            weights, per the local naming; confirm against the caller).

        Returns:
          The updated server state produced by `server_update_fn`.
        """
        # Federated averaging of the clients' values. No `weight` argument is
        # passed, so the mean is not weighted by client.
        mean_client_weights = tff.federated_mean(client_states)

        # The leftover debug `print(str(...))` was removed: it only dumped the
        # traced federated value's repr to stdout on every trace.

        # SERVER UPDATING STEP: fold the averaged weights into the server state.
        server_state = tff.federated_apply(server_update_fn,
                                           (server_state, mean_client_weights))

        # Return the new server state.
        return server_state
Esempio n. 3
0
    def encoded_broadcast_fn(state, value):
        """Encodes `value`, broadcasts the encoding, and decodes it on clients.

        This function is intended to be wrapped as a federated computation.

        Args:
          state: Federated encoder state. Its member type is used to build the
            encode/decode TF computations.
          value: Federated value to broadcast. Its member type is also used to
            build the computations.

        Returns:
          A `(new_state, client_value)` tuple. `new_state` is the state
          returned by the encode step. `client_value` is the value decoded
          after the broadcast.
        """
        encode_fn, decode_fn = _build_encode_decode_tf_computations_for_broadcast(
            state.type_signature.member, value.type_signature.member, encoders)

        # Encoding consumes (state, value) and yields the next state together
        # with the encoded payload.
        next_state, payload = tff.federated_apply(encode_fn, (state, value))

        # Only the encoded payload is broadcast; decoding is mapped over the
        # broadcast result.
        decoded = tff.federated_map(decode_fn, tff.federated_broadcast(payload))
        return next_state, decoded