def encoded_sum_fn(state, values, weight=None):
  """Encoded sum federated_computation.

  Derives per-round encoder parameters from `state`, broadcasts the ones the
  clients need, encodes the client values, aggregates them, then decodes the
  aggregate and updates the encoder state on the server.
  """
  del weight  # Weighted aggregation is not supported; the weight is ignored.

  all_params = tff.federated_map(nest_encoder.get_params_fn, state)
  encode_params, decode_before_sum_params, decode_after_sum_params = all_params

  # Clients need the encode params and the decode-before-sum params.
  client_encode_params = tff.federated_broadcast(encode_params)
  client_pre_sum_params = tff.federated_broadcast(decode_before_sum_params)

  encoded_values = tff.federated_map(
      nest_encoder.encode_fn,
      [values, client_encode_params, client_pre_sum_params])

  aggregated = tff.federated_aggregate(encoded_values, nest_encoder.zero_fn(),
                                       nest_encoder.accumulate_fn,
                                       nest_encoder.merge_fn,
                                       nest_encoder.report_fn)

  # Decode-after-sum params stay on the server for the final decode step.
  decoded_values = tff.federated_map(
      nest_encoder.decode_after_sum_fn,
      [aggregated.values, decode_after_sum_params])
  updated_state = tff.federated_map(
      nest_encoder.update_state_fn,
      [state, aggregated.state_update_tensors])
  return updated_state, decoded_values
def foo(temperatures, threshold):
  """Returns the federated count of client temperatures above `threshold`.

  Each client value contributes 1 if it exceeds the broadcast threshold and 0
  otherwise; the indicators are then summed across clients.
  """
  exceeds_threshold = tff.tf_computation(
      lambda x, y: tf.cast(tf.greater(x, y), tf.int32),
      [tf.float32, tf.float32])
  client_threshold = tff.federated_broadcast(threshold)
  indicators = tff.federated_map(exceeds_threshold,
                                 [temperatures, client_threshold])
  return tff.federated_sum(indicators)
def encoded_broadcast_fn(state, value):
  """Encoded broadcast federated_computation.

  Encodes `value` on the server, ships the encoded form to clients, and
  decodes it client-side; returns the updated encoder state alongside the
  decoded client value.
  """
  updated_state, server_encoded = tff.federated_apply(encode_fn,
                                                      (state, value))
  encoded_at_clients = tff.federated_broadcast(server_encoded)
  decoded_at_clients = tff.federated_map(decode_fn, encoded_at_clients)
  return updated_state, decoded_at_clients
def broadcast_next_fn(state, value):
  """Increments the server-side call counter and broadcasts `value`."""

  @tff.tf_computation(tf.int32)
  def increment(count):
    return count + 1

  new_state = {
      'call_count': tff.federated_apply(increment, state.call_count),
  }
  return new_state, tff.federated_broadcast(value)
def next_fn(server_state, client_data):
  """One round: broadcast the state, add 1 per client, sum the results.

  The client data sequence is received but unused by the transform; the
  server additionally emits the constant 1234 as its output.
  """
  state_at_clients = tff.federated_broadcast(server_state)

  @tff.tf_computation(tf.int32, tff.SequenceType(tf.float32))
  @tf.function
  def some_transform(state, data):
    del data  # The client dataset is not used by this transform.
    return state + 1

  per_client_update = tff.federated_map(some_transform,
                                        (state_at_clients, client_data))
  summed_update = tff.federated_sum(per_client_update)
  server_output = tff.federated_value(1234, tff.SERVER)
  return summed_update, server_output
def encoded_broadcast_fn(state, value):
  """Broadcast function, to be wrapped as federated_computation.

  Builds encode/decode computations for the member types of `state` and
  `value`, encodes server-side, broadcasts, and decodes on the clients.
  """
  encode_comp, decode_comp = _build_encode_decode_tf_computations_for_broadcast(
      state.type_signature.member, value.type_signature.member, encoders)
  updated_state, server_encoded = tff.federated_apply(encode_comp,
                                                      (state, value))
  clients_encoded = tff.federated_broadcast(server_encoded)
  clients_decoded = tff.federated_map(decode_comp, clients_encoded)
  return updated_state, clients_decoded
def next_computation(arg):
  """The logic of a single MapReduce processing round."""
  server_state = arg[0]
  client_data = arg[1]

  # Server prepares a value and broadcasts it to the clients.
  prepared = tff.federated_apply(cf.prepare, server_state)
  broadcast_input = tff.federated_broadcast(prepared)

  # Clients combine their data with the broadcast input and do the work.
  work_arg = tff.federated_zip([client_data, broadcast_input])
  work_result = tff.federated_map(cf.work, work_arg)
  client_update = work_result[0]
  client_output = work_result[1]

  # Aggregate client updates, then update the server state.
  aggregate = tff.federated_aggregate(client_update, cf.zero(), cf.accumulate,
                                      cf.merge, cf.report)
  update_arg = tff.federated_zip([server_state, aggregate])
  update_result = tff.federated_apply(cf.update, update_arg)
  new_server_state = update_result[0]
  server_output = update_result[1]
  return new_server_state, server_output, client_output
def next_fn(global_state, value, weight=None):
  """Defines next_fn for StatefulAggregateFn.

  Runs one round of query-based aggregation: derives per-round sample
  parameters from `global_state` on the server, broadcasts them to clients,
  preprocesses each client record, aggregates the preprocessed records via
  `tff.federated_aggregate`, and finally post-processes the aggregate
  together with the global state via `query.get_noised_result`.

  Args:
    global_state: Server-placed global state of the query; its type matches
      `initialize_fn.type_signature.result`.
    value: Client-placed values to be aggregated.
    weight: Unused; weighted aggregation is not supported.

  Returns:
    The result of `post_process`: a `(new_global_state, result)` pair as
    produced by `query.get_noised_result`.
  """
  # Weighted aggregation is not supported.
  # TODO(b/140236959): Add an assertion that weight is None here, so the
  # contract of this method is better established. Will likely cause some
  # downstream breaks.
  del weight

  #######################################
  # Define local tf_computations

  # TODO(b/129567727): Make most of these tf_computations polymorphic
  # so type manipulation isn't needed.

  global_state_type = initialize_fn.type_signature.result

  @tff.tf_computation(global_state_type)
  def derive_sample_params(global_state):
    # Server-side: computes the per-round parameters to send to clients.
    return query.derive_sample_params(global_state)

  @tff.tf_computation(derive_sample_params.type_signature.result,
                      value.type_signature.member)
  def preprocess_record(params, record):
    # Client-side: converts then preprocesses a single client record.
    # TODO(b/123092620): Once TFF passes the expected container type (instead
    # of AnonymousTuple), we shouldn't need this.
    record = from_anon_tuple_fn(record)
    return query.preprocess_record(params, record)

  # TODO(b/123092620): We should have the expected container type here.
  value_type = value_type_fn(value)

  tensor_specs = tff_framework.type_to_tf_tensor_specs(value_type)

  @tff.tf_computation
  def zero():
    # Initial accumulator (sample state) for the federated aggregation.
    return query.initial_sample_state(tensor_specs)

  sample_state_type = zero.type_signature.result

  @tff.tf_computation(sample_state_type,
                      preprocess_record.type_signature.result)
  def accumulate(sample_state, preprocessed_record):
    # Folds one preprocessed client record into the running sample state.
    return query.accumulate_preprocessed_record(
        sample_state, preprocessed_record)

  @tff.tf_computation(sample_state_type, sample_state_type)
  def merge(sample_state_1, sample_state_2):
    # Combines two partial sample states (e.g. from different sub-aggregates).
    return query.merge_sample_states(sample_state_1, sample_state_2)

  @tff.tf_computation(merge.type_signature.result)
  def report(sample_state):
    # Identity: the merged sample state is reported to the server unchanged.
    return sample_state

  @tff.tf_computation(sample_state_type, global_state_type)
  def post_process(sample_state, global_state):
    # Produces the final (noised) result and the updated global state.
    result, new_global_state = query.get_noised_result(
        sample_state, global_state)
    return new_global_state, result

  #######################################
  # Orchestration logic
  sample_params = tff.federated_apply(derive_sample_params, global_state)
  client_sample_params = tff.federated_broadcast(sample_params)
  preprocessed_record = tff.federated_map(
      preprocess_record, (client_sample_params, value))
  agg_result = tff.federated_aggregate(preprocessed_record, zero(),
                                       accumulate, merge, report)
  return tff.federated_apply(post_process, (agg_result, global_state))
def _(x):
  """Broadcasts the server-placed value `x` to the clients."""
  broadcast_result = tff.federated_broadcast(x)
  return broadcast_result