def next_comp(state, value):
    """Run one step of the test process.

    Applies `_add_one` to `state`, broadcasts `value` as the step result, and
    emits an arbitrary measurement computed from `value`.

    Args:
      state: Federated value that is mapped through `_add_one`.
      value: Federated value that is broadcast with `tff.federated_broadcast`
        and also fed to the measurement computation.

    Returns:
      A `collections.OrderedDict` with keys `state`, `result` and
      `measurements`, in that order.
    """
    incremented = tff.federated_map(_add_one, state)
    broadcast = tff.federated_broadcast(value)

    def _norm_plus_three(v):
        # Arbitrary metric for testing: `tf.linalg.global_norm` over every
        # flattened leaf of `v`, shifted by 3.0.
        return tf.linalg.global_norm(tf.nest.flatten(v)) + 3.0

    metric = tff.federated_map(tff.tf_computation(_norm_plus_three), value)
    return collections.OrderedDict(
        state=incremented,
        result=broadcast,
        measurements=metric,
    )
# Example 2
# 0
def _state_incrementing_broadcast_next(server_state, server_value):
    """Increment `server_state` by one and broadcast `server_value`.

    Args:
      server_state: Federated value mapped through a `tf.int32` "+1"
        computation.
      server_value: Federated value passed to `tff.federated_broadcast`.

    Returns:
      A 2-tuple `(new_state, broadcast_value)`.
    """
    increment = tff.tf_computation(lambda x: x + 1, tf.int32)
    incremented_state = tff.federated_map(increment, server_state)
    broadcast_value = tff.federated_broadcast(server_value)
    return incremented_state, broadcast_value
# Example 3
# 0
def _state_incrementing_mean_next(server_state, client_value, weight=None):
    """Increment `server_state` by one and average `client_value`.

    Args:
      server_state: Federated value mapped through a `tf.int32` "+1"
        computation.
      client_value: Federated value passed to `tff.federated_mean`.
      weight: Optional weight forwarded unchanged to `tff.federated_mean`.

    Returns:
      A 2-tuple `(new_state, mean_value)`.
    """
    increment = tff.tf_computation(lambda x: x + 1, tf.int32)
    incremented_state = tff.federated_map(increment, server_state)
    mean_value = tff.federated_mean(client_value, weight=weight)
    return incremented_state, mean_value
# Example 4
# 0
 def server_init_tff():
   """Orchestration logic for server model initialization.

   Wraps a no-argument call to `server_init(model_fn, server_optimizer_fn)`
   as a TF computation, invokes it, and places the result at `tff.SERVER`
   with `tff.federated_value`.
   """

   @tff.tf_computation
   def server_init_tf():
     # `model_fn`, `server_optimizer_fn` and `server_init` come from the
     # enclosing scope (not visible here).
     return server_init(model_fn, server_optimizer_fn)

   return tff.federated_value(server_init_tf(), tff.SERVER)
# Example 5
# 0
 def initialize_comp():
   """Evaluate `stateful_fn.initialize` at `tff.SERVER`.

   If `stateful_fn.initialize` is already a `tff.Computation` it is used
   as-is. Otherwise it is wrapped with `tff.tf_computation` first. The
   result is evaluated with `tff.federated_eval`.
   """
   init_fn = stateful_fn.initialize
   # Only non-`tff.Computation` callables get wrapped.
   initialize = (
       init_fn if isinstance(init_fn, tff.Computation)
       else tff.tf_computation(init_fn))
   return tff.federated_eval(initialize, tff.SERVER)