def foo(temperatures, threshold):
    """Count client values that are strictly greater than ``threshold``.

    ``threshold`` is broadcast to the clients. Each client computes a 0/1
    int32 indicator of ``value > threshold``. The indicators are then added
    up with ``tff.federated_sum``.

    NOTE(review): a later top-level ``def foo`` in this file rebinds this
    name, so this definition is shadowed by it.

    Args:
      temperatures: Per-client float32 values. The member type comes from
        the ``tf_computation`` signature below. Placement is presumably
        ``tff.CLIENTS``, since the values are zipped with the broadcast
        threshold; confirm against the caller.
      threshold: Value to broadcast; presumably a server-placed float32.

    Returns:
      The federated sum of the per-client int32 indicators.
    """
    # Built before the broadcast, preserving the original evaluation order.
    exceeds_threshold = tff.tf_computation(
        lambda x, y: tf.cast(tf.greater(x, y), tf.int32),
        [tf.float32, tf.float32],
    )
    threshold_at_clients = tff.federated_broadcast(threshold)
    indicators = tff.federated_map(
        exceeds_threshold, [temperatures, threshold_at_clients]
    )
    return tff.federated_sum(indicators)
def encoded_mean_fn(state, values, weight):
    """Weighted average of ``values`` built on an encoded sum.

    Steps:
      1. Combine each value with its weight via ``multiply_fn``.
      2. Aggregate the weighted values with ``encoded_sum_fn``, which
         returns an updated state alongside the decoded sum.
      3. Sum the raw weights with ``tff.federated_sum``.
      4. Combine the decoded sum with the total weight via ``divide_fn``.

    ``multiply_fn``, ``encoded_sum_fn`` and ``divide_fn`` are defined
    elsewhere. They are presumably elementwise multiply, encoded federated
    sum and elementwise divide; confirm at their definitions.

    Args:
      state: State threaded through ``encoded_sum_fn``.
      values: Values to average; presumably client-placed (confirm).
      weight: Weights zipped with ``values`` and also summed on their own.

    Returns:
      Tuple ``(updated_state, decoded_values)``: the state returned by
      ``encoded_sum_fn`` and the result of the ``divide_fn`` map.
    """
    scaled_values = tff.federated_map(multiply_fn, [values, weight])
    new_state, decoded_sum = encoded_sum_fn(state, scaled_values)
    total_weight = tff.federated_sum(weight)
    weighted_mean = tff.federated_map(divide_fn, [decoded_sum, total_weight])
    return new_state, weighted_mean
def next_fn(server_state, client_data):
    """Broadcast the state, add one on every client, and sum the results.

    The server state is broadcast to the clients. Each client returns the
    broadcast value plus one; its dataset is accepted but ignored. The
    per-client results are summed with ``tff.federated_sum``.

    Args:
      server_state: State to broadcast; ``some_transform`` expects its
        client-side value to be int32.
      client_data: Client datasets, typed for ``some_transform`` as
        sequences of float32. They are not used in the computation.

    Returns:
      Tuple ``(aggregate_update, server_output)``. ``server_output`` is the
      constant 1234 placed at ``tff.SERVER``.
    """
    state_at_clients = tff.federated_broadcast(server_state)

    # Parameter names are kept as-is: TFF uses them to name the elements of
    # the computation's parameter struct.
    @tff.tf_computation(tf.int32, tff.SequenceType(tf.float32))
    @tf.function
    def some_transform(x, y):
        del y  # Part of the signature, but intentionally unused.
        return x + 1

    per_client_values = tff.federated_map(
        some_transform, (state_at_clients, client_data)
    )
    aggregate_update = tff.federated_sum(per_client_values)
    return aggregate_update, tff.federated_value(1234, tff.SERVER)
def foo(x):
    """Return ``tff.federated_sum`` applied to ``x``.

    NOTE(review): this rebinds the module-level name ``foo`` that an earlier
    definition in this file also uses.

    Args:
      x: Federated value to sum; presumably client-placed (confirm).

    Returns:
      The result of ``tff.federated_sum(x)``.
    """
    total = tff.federated_sum(x)
    return total
def _(x):
    """Return ``tff.federated_sum`` applied to ``x``.

    Args:
      x: Federated value to sum; presumably client-placed (confirm).

    Returns:
      The result of ``tff.federated_sum(x)``.
    """
    summed = tff.federated_sum(x)
    return summed