def fed_output(local_outputs):
  """Aggregates per-client outputs into a dict of federated metrics.

  Args:
    local_outputs: the per-client outputs. They must expose `num_examples`,
      `num_examples_float` and `loss` attributes.

  Returns:
    A dict with two keys:
      'num_examples': the federated sum of `local_outputs.num_examples`.
      'loss': the federated average of `local_outputs.loss`, weighted by
        `local_outputs.num_examples_float`.
  """
  total_examples = tff.federated_sum(local_outputs.num_examples)
  # TODO(b/124070381): Remove need for using num_examples_float here.
  # The averaging weight is read from the separate float-valued field
  # rather than from `num_examples` itself.
  mean_loss = tff.federated_average(
      local_outputs.loss, weight=local_outputs.num_examples_float)
  return {'num_examples': total_examples, 'loss': mean_loss}
def federated_train(model, learning_rate, data):
  """Runs `local_train` on every client and averages the results.

  Args:
    model: the model value to broadcast. Presumably server-placed, since it
      is passed to `tff.federated_broadcast` — confirm against the caller.
    learning_rate: the learning rate to broadcast. Presumably server-placed,
      for the same reason.
    data: the per-client data. It is passed to `local_train` unchanged.

  Returns:
    The `tff.federated_average` of the per-client `local_train` results.
  """
  broadcast_model = tff.federated_broadcast(model)
  broadcast_learning_rate = tff.federated_broadcast(learning_rate)
  # Each client trains from the broadcast model and learning rate on its own
  # data.
  client_results = tff.federated_map(
      local_train, [broadcast_model, broadcast_learning_rate, data])
  return tff.federated_average(client_results)
def run_one_round_tff(server_state, federated_dataset):
  """Orchestration logic for one round of optimization.

  The round proceeds in four steps:

  1. `server_state.model` is broadcast, and `client_delta_tf` is mapped over
     the pairs of (client dataset, broadcast weights). This yields per-client
     outputs with `weights_delta`, `weights_delta_weight` and `model_output`
     fields.
  2. The client `weights_delta` values are averaged, weighted by
     `weights_delta_weight`. If the weight has an integer dtype, it is first
     cast to `tf.float32`.
  3. `server_update_model_tf` is applied via `tff.federated_apply` to the
     current server state and the averaged delta, producing the new state.
  4. The client `model_output` values are aggregated with
     `dummy_model_for_metadata.federated_output_computation`, inside graph
     `g`. The result is then passed through `tff.federated_zip`.

  NOTE(review): this body reads names it does not define itself, for example
  `federated_server_state_type`, `tf_dataset_type`, `server_state_type`,
  `model_fn`, `server_optimizer_fn`, `g`, `dummy_model_for_metadata`,
  `ServerState`, `server_update_model` and `model_to_client_delta_fn`.
  Presumably these come from an enclosing builder function — confirm. The
  `tff.federated_*` intrinsics also need a federated-computation context.
  Presumably this function is wrapped in `tff.federated_computation` where it
  is defined — confirm.

  Args:
    server_state: a `tff.learning.framework.ServerState` named tuple.
    federated_dataset: a federated `tf.Dataset` with placement tff.CLIENTS.

  Returns:
    A tuple of updated `tff.learning.framework.ServerState` and the result of
    `tff.learning.Model.federated_output_computation`.

  Raises:
    Presumably `TypeError` (from `py_typecheck.check_type`) while tracing, if
    the client weight member type is not a `tff.TensorType` — confirm against
    `py_typecheck`.
  """
  # Model weights type, taken from the member type of the federated server
  # state type.
  model_weights_type = federated_server_state_type.member.model

  @tff.tf_computation(tf_dataset_type, model_weights_type)
  def client_delta_tf(tf_dataset, initial_model_weights):
    """Performs client local model optimization.

    Args:
      tf_dataset: a `tf.data.Dataset` that provides training examples.
      initial_model_weights: a `model_utils.ModelWeights` containing the
        starting weights. It may arrive as an `AnonymousTuple`; in that case
        it is converted with `ModelWeights.from_tff_value`.

    Returns:
      A `ClientOutput` structure.
    """
    client_delta_fn = model_to_client_delta_fn(model_fn)
    # TODO(b/123092620): this can be removed once AnonymousTuple works with
    # tf.contrib.framework.nest, or the following behavior is moved to
    # anonymous_tuple module.
    if isinstance(initial_model_weights, anonymous_tuple.AnonymousTuple):
      initial_model_weights = model_utils.ModelWeights.from_tff_value(
          initial_model_weights)
    client_output = client_delta_fn(tf_dataset, initial_model_weights)
    return client_output

  # Send the current server model to every client, then run local
  # optimization on each client's dataset.
  client_outputs = tff.federated_map(
      client_delta_tf,
      (federated_dataset, tff.federated_broadcast(server_state.model)))

  @tff.tf_computation(server_state_type, model_weights_type.trainable)
  def server_update_model_tf(server_state, model_delta):
    """Converts args to correct python types and calls server_update_model.

    `model_delta` is converted from an `AnonymousTuple` to an ordered dict.
    `server_state` is rebuilt as a `ServerState`, with its model converted to
    `ModelWeights` and its optimizer state converted to a list. Both are then
    passed to `server_update_model` together with `model_fn` and
    `server_optimizer_fn`.
    """
    # We need to convert TFF types to the types server_update_model expects.
    # TODO(b/123092620): Mixing AnonymousTuple with other nested types is not
    # pretty, fold this into anonymous_tuple module or get working with
    # tf.contrib.framework.nest.
    py_typecheck.check_type(model_delta, anonymous_tuple.AnonymousTuple)
    model_delta = anonymous_tuple.to_odict(model_delta)
    py_typecheck.check_type(server_state, anonymous_tuple.AnonymousTuple)
    server_state = ServerState(
        model=model_utils.ModelWeights.from_tff_value(
            server_state.model),
        optimizer_state=list(server_state.optimizer_state))
    return server_update_model(server_state, model_delta, model_fn=model_fn, optimizer_fn=server_optimizer_fn)

  # TODO(b/124070381): We hope to remove this explicit cast once we have a
  # full solution for type analysis in multiplications and divisions
  # inside TFF
  fed_weight_type = client_outputs.weights_delta_weight.type_signature.member
  py_typecheck.check_type(fed_weight_type, tff.TensorType)
  if fed_weight_type.dtype.is_integer:

    @tff.tf_computation(fed_weight_type)
    def _cast_to_float(x):
      # Per-client cast of an integer weight to float32 before it is used as
      # the averaging weight.
      return tf.cast(x, tf.float32)

    weight_denom = tff.federated_map(
        _cast_to_float, client_outputs.weights_delta_weight)
  else:
    # A non-integer weight is used as-is.
    weight_denom = client_outputs.weights_delta_weight
  # Weighted average of the per-client model deltas.
  round_model_delta = tff.federated_average(client_outputs.weights_delta, weight=weight_denom)
  # TODO(b/123408447): remove tff.federated_apply and call
  # server_update_model_tf directly once T <-> T@SERVER isomorphism is
  # supported.
  server_state = tff.federated_apply(server_update_model_tf, (server_state, round_model_delta))
  # Re-use graph used to construct `model`, since it has the variables, which
  # need to be read in federated_output_computation to get the correct shapes
  # and types for the federated aggregation.
  with g.as_default():
    aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
        client_outputs.model_output)
  # Promote the FederatedType outside the NamedTupleType
  aggregated_outputs = tff.federated_zip(aggregated_outputs)
  return server_state, aggregated_outputs