def run_one_round_tff(server_state, federated_dataset):
  """Orchestration logic for one round of optimization.

  Args:
    server_state: a `tff.learning.framework.ServerState` named tuple.
    federated_dataset: a federated `tf.Dataset` with placement tff.CLIENTS.

  Returns:
    A tuple of updated `tff.learning.framework.ServerState` and the result of
    `tff.learning.Model.federated_output_computation`.
  """
  # Push the current global model (and broadcaster state) out to the clients.
  broadcaster_state_after, distributed_model = stateful_model_broadcast_fn(
      server_state.model_broadcast_state, server_state.model)

  # Run local optimization on every client against its own dataset.
  per_client_results = tff.federated_map(
      tf_client_delta, (federated_dataset, distributed_model))

  # TODO(b/124070381): We hope to remove this explicit cast once we have a
  # full solution for type analysis in multiplications and divisions
  # inside TFF
  mean_weight = tff.federated_map(_cast_weight_to_float,
                                  per_client_results.weights_delta_weight)

  # Aggregate the per-client deltas into a single server-side model delta.
  aggregator_state_after, aggregated_delta = stateful_delta_aggregate_fn(
      server_state.delta_aggregate_state,
      per_client_results.weights_delta,
      weight=mean_weight)

  # TODO(b/123408447): remove tff.federated_apply and call
  # tf_server_update directly once T <-> T@SERVER isomorphism is
  # supported.
  updated_server_state = tff.federated_apply(
      tf_server_update,
      (server_state, aggregated_delta, aggregator_state_after,
       broadcaster_state_after))

  # Compute federated training metrics from the clients' local outputs.
  metrics = dummy_model_for_metadata.federated_output_computation(
      per_client_results.model_output)

  # TODO(b/131429028): Ideally this federated_zip shouldn't ever be needed.
  if isinstance(metrics.type_signature, tff.NamedTupleType):
    # Promote the FederatedType outside the NamedTupleType.
    metrics = tff.federated_zip(metrics)

  return updated_server_state, metrics
def run_one_round_tff(server_state, federated_dataset):
  """Orchestration logic for one round of optimization.

  Args:
    server_state: a `tff.learning.framework.ServerState` named tuple.
    federated_dataset: a federated `tf.Dataset` with placement tff.CLIENTS.

  Returns:
    A tuple of updated `tff.learning.framework.ServerState` and the result of
    `tff.learning.Model.federated_output_computation`.
  """
  # NOTE(review): `server_state_type`, `tf_dataset_type`, `model_fn`,
  # `stateful_model_broadcast_fn`, `stateful_delta_aggregate_fn`,
  # `server_optimizer_fn` and `dummy_model_for_metadata` are closed over from
  # the enclosing builder scope — confirm against the surrounding function.
  model_weights_type = server_state_type.model

  @tff.tf_computation(tf_dataset_type, model_weights_type)
  def client_delta_tf(tf_dataset, initial_model_weights):
    """Performs client local model optimization.

    Args:
      tf_dataset: a `tf.data.Dataset` that provides training examples.
      initial_model_weights: a `model_utils.ModelWeights` containing the
        starting weights.

    Returns:
      A `ClientOutput` structure.
    """
    client_delta_fn = model_to_client_delta_fn(model_fn)
    client_output = client_delta_fn(tf_dataset, initial_model_weights)
    return client_output

  # Broadcast the global model to the clients via the stateful broadcaster.
  new_broadcaster_state, client_model = stateful_model_broadcast_fn(
      server_state.model_broadcast_state, server_state.model)

  # Run local training on each client.
  client_outputs = tff.federated_map(client_delta_tf,
                                     (federated_dataset, client_model))

  # The decorator arguments are federated/member type signatures taken from
  # the actual `server_state` value at tracing time.
  @tff.tf_computation(
      server_state_type, model_weights_type.trainable,
      server_state.delta_aggregate_state.type_signature.member,
      server_state.model_broadcast_state.type_signature.member)
  def server_update_tf(server_state, model_delta, new_delta_aggregate_state,
                       new_broadcaster_state):
    """Converts args to correct python types and calls server_update_model."""
    py_typecheck.check_type(server_state, ServerState)
    # Rebuild the ServerState, swapping in the updated aggregator and
    # broadcaster states.
    server_state = ServerState(
        model=server_state.model,
        optimizer_state=list(server_state.optimizer_state),
        delta_aggregate_state=new_delta_aggregate_state,
        model_broadcast_state=new_broadcaster_state)
    return server_update_model(
        server_state,
        model_delta,
        model_fn=model_fn,
        optimizer_fn=server_optimizer_fn)

  # TODO(b/124070381): We hope to remove this explicit cast once we have a
  # full solution for type analysis in multiplications and divisions
  # inside TFF
  fed_weight_type = client_outputs.weights_delta_weight.type_signature.member
  py_typecheck.check_type(fed_weight_type, tff.TensorType)
  if fed_weight_type.dtype.is_integer:
    # Integer weights would cause integer division in the weighted mean;
    # cast them to float32 on the clients first.
    @tff.tf_computation(fed_weight_type)
    def _cast_to_float(x):
      return tf.cast(x, tf.float32)

    weight_denom = tff.federated_map(_cast_to_float,
                                     client_outputs.weights_delta_weight)
  else:
    weight_denom = client_outputs.weights_delta_weight

  # Aggregate the client deltas into one model delta at the server.
  new_delta_aggregate_state, round_model_delta = stateful_delta_aggregate_fn(
      server_state.delta_aggregate_state,
      client_outputs.weights_delta,
      weight=weight_denom)

  # TODO(b/123408447): remove tff.federated_apply and call
  # server_update_tf directly once T <-> T@SERVER isomorphism is
  # supported.
  server_state = tff.federated_apply(
      server_update_tf, (server_state, round_model_delta,
                         new_delta_aggregate_state, new_broadcaster_state))

  # Compute aggregate metrics from the clients' local model outputs.
  aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
      client_outputs.model_output)
  # Promote the FederatedType outside the NamedTupleType
  aggregated_outputs = tff.federated_zip(aggregated_outputs)

  return server_state, aggregated_outputs
def _state_incrementing_broadcast_next(server_state, server_value):
  """Broadcasts `server_value` to clients while bumping the int32 counter.

  Args:
    server_state: a server-placed int32 counter.
    server_value: the server-placed value to broadcast.

  Returns:
    A tuple of (incremented counter, value broadcast to tff.CLIENTS).
  """
  increment = tff.tf_computation(lambda value: value + 1, tf.int32)
  incremented_state = tff.federated_apply(increment, server_state)
  broadcast_value = tff.federated_broadcast(server_value)
  return (incremented_state, broadcast_value)
def _state_incrementing_mean_next(server_state, client_value, weight=None):
  """Takes a (weighted) federated mean while bumping the int32 counter.

  Args:
    server_state: a server-placed int32 counter.
    client_value: the client-placed values to average.
    weight: optional client-placed weights for the mean.

  Returns:
    A tuple of (incremented counter, federated mean of `client_value`).
  """
  increment = tff.tf_computation(lambda value: value + 1, tf.int32)
  incremented_state = tff.federated_apply(increment, server_state)
  averaged_value = tff.federated_mean(client_value, weight=weight)
  return (incremented_state, averaged_value)
def run_one_round_tff(server_state, federated_dataset):
  """Orchestration logic for one round of optimization.

  Args:
    server_state: a `tff.learning.framework.ServerState` named tuple.
    federated_dataset: a federated `tf.Dataset` with placement tff.CLIENTS.

  Returns:
    A tuple of updated `tff.learning.framework.ServerState` and the result of
    `tff.learning.Model.federated_output_computation`.
  """
  # NOTE(review): `federated_server_state_type`, `tf_dataset_type`,
  # `server_state_type`, `model_fn`, `server_optimizer_fn`,
  # `dummy_model_for_metadata` and the graph `g` are closed over from the
  # enclosing builder scope — confirm against the surrounding function.
  model_weights_type = federated_server_state_type.member.model

  @tff.tf_computation(tf_dataset_type, model_weights_type)
  def client_delta_tf(tf_dataset, initial_model_weights):
    """Performs client local model optimization.

    Args:
      tf_dataset: a `tf.data.Dataset` that provides training examples.
      initial_model_weights: a `model_utils.ModelWeights` containing the
        starting weights.

    Returns:
      A `ClientOutput` structure.
    """
    client_delta_fn = model_to_client_delta_fn(model_fn)

    # TODO(b/123092620): this can be removed once AnonymousTuple works with
    # tf.contrib.framework.nest, or the following behavior is moved to
    # anonymous_tuple module.
    if isinstance(initial_model_weights, anonymous_tuple.AnonymousTuple):
      initial_model_weights = model_utils.ModelWeights.from_tff_value(
          initial_model_weights)

    client_output = client_delta_fn(tf_dataset, initial_model_weights)
    return client_output

  # Broadcast the current model and run local training on each client.
  client_outputs = tff.federated_map(
      client_delta_tf,
      (federated_dataset, tff.federated_broadcast(server_state.model)))

  @tff.tf_computation(server_state_type, model_weights_type.trainable)
  def server_update_model_tf(server_state, model_delta):
    """Converts args to correct python types and calls server_update_model."""
    # We need to convert TFF types to the types server_update_model expects.
    # TODO(b/123092620): Mixing AnonymousTuple with other nested types is not
    # pretty, fold this into anonymous_tuple module or get working with
    # tf.contrib.framework.nest.
    py_typecheck.check_type(model_delta, anonymous_tuple.AnonymousTuple)
    model_delta = anonymous_tuple.to_odict(model_delta)
    py_typecheck.check_type(server_state, anonymous_tuple.AnonymousTuple)
    server_state = ServerState(
        model=model_utils.ModelWeights.from_tff_value(server_state.model),
        optimizer_state=list(server_state.optimizer_state))
    return server_update_model(
        server_state,
        model_delta,
        model_fn=model_fn,
        optimizer_fn=server_optimizer_fn)

  # TODO(b/124070381): We hope to remove this explicit cast once we have a
  # full solution for type analysis in multiplications and divisions
  # inside TFF
  fed_weight_type = client_outputs.weights_delta_weight.type_signature.member
  py_typecheck.check_type(fed_weight_type, tff.TensorType)
  if fed_weight_type.dtype.is_integer:
    # Integer weights would cause integer division in the weighted mean;
    # cast them to float32 on the clients first.
    @tff.tf_computation(fed_weight_type)
    def _cast_to_float(x):
      return tf.cast(x, tf.float32)

    weight_denom = tff.federated_map(_cast_to_float,
                                     client_outputs.weights_delta_weight)
  else:
    weight_denom = client_outputs.weights_delta_weight

  # Weighted average of the per-client model deltas.
  round_model_delta = tff.federated_mean(
      client_outputs.weights_delta, weight=weight_denom)

  # TODO(b/123408447): remove tff.federated_apply and call
  # server_update_model_tf directly once T <-> T@SERVER isomorphism is
  # supported.
  server_state = tff.federated_apply(server_update_model_tf,
                                     (server_state, round_model_delta))

  # Re-use graph used to construct `model`, since it has the variables, which
  # need to be read in federated_output_computation to get the correct shapes
  # and types for the federated aggregation.
  with g.as_default():
    aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
        client_outputs.model_output)

  # Promote the FederatedType outside the NamedTupleType
  aggregated_outputs = tff.federated_zip(aggregated_outputs)

  return server_state, aggregated_outputs