def foo(x):
  """Federated-aggregates `x`, starting from an `Accumulator(0, 0)` zero.

  `accumulate`, `merge` and `report` are module-level names looked up at
  call time and passed straight through to `tff.federated_aggregate`.

  NOTE(review): `tff.federated_aggregate` expects `x` to be a CLIENTS-placed
  federated value -- confirm at call sites.
  """
  aggregate = tff.federated_aggregate
  initial_accumulator = Accumulator(0, 0)
  return aggregate(x, initial_accumulator, accumulate, merge, report)
def foo(x):
  """Federated-aggregates `x`, seeded by the value of `build_federated_zero()`.

  `accumulate`, `merge` and `report` are module-level names looked up at
  call time and passed straight through to `tff.federated_aggregate`.

  NOTE(review): `tff.federated_aggregate` expects `x` to be a CLIENTS-placed
  federated value -- confirm at call sites.
  """
  aggregate = tff.federated_aggregate
  initial_state = build_federated_zero()
  return aggregate(x, initial_state, accumulate, merge, report)
def foo(x):
  """Federated-aggregates `x`, seeded by `build_empty_accumulator()`.

  `accumulate`, `merge` and `report` are module-level names looked up at
  call time and passed straight through to `tff.federated_aggregate`.

  NOTE(review): `tff.federated_aggregate` expects `x` to be a CLIENTS-placed
  federated value -- confirm at call sites.
  """
  aggregate = tff.federated_aggregate
  empty_accumulator = build_empty_accumulator()
  return aggregate(x, empty_accumulator, accumulate, merge, report)
def next_fn(global_state, value, weight=None):
  """Defines next_fn for StatefulAggregateFn.

  Performs one aggregation round driven by `query`:
    1. derives sample params from `global_state`,
    2. broadcasts them,
    3. preprocesses each member of `value` with the broadcast params,
    4. aggregates the preprocessed records with `tff.federated_aggregate`,
    5. calls `query.get_noised_result` on the aggregate and `global_state`.

  `initialize_fn`, `query`, `from_anon_tuple_fn`, `value_type_fn` and
  `tff_framework` are free names here; they are defined outside this
  function and are not visible in this view.

  Args:
    global_state: The query's global state. Its member type is expected to
      match `initialize_fn.type_signature.result`. Presumably SERVER-placed,
      since params derived from it are passed to `tff.federated_broadcast`
      -- confirm with caller.
    value: A federated value; its `type_signature.member` is used as the
      per-record input type. Presumably CLIENTS-placed, since it is zipped
      with the broadcast params -- confirm with caller.
    weight: Ignored; weighted aggregation is not supported.

  Returns:
    The result of `tff.federated_map(post_process, (agg_result,
    global_state))`. Its member is the pair `(new_global_state, result)`.
    Note this is the reverse of the order returned by
    `query.get_noised_result`, which is `(result, new_global_state)`.
  """
  # Weighted aggregation is not supported.
  # TODO(b/140236959): Add an assertion that weight is None here, so the
  # contract of this method is better established. Will likely cause some
  # downstream breaks.
  del weight

  #######################################
  # Define local tf_computations

  # TODO(b/129567727): Make most of these tf_computations polymorphic
  # so type manipulation isn't needed.

  # NOTE(review): each computation below takes its parameter types from
  # computations defined before it, so definition order matters. TFF may
  # also derive the parameter type of multi-argument computations from
  # their Python parameter names. Check the resulting type signatures
  # before reordering or renaming anything here.

  # Member type of the global state, taken from the externally defined
  # `initialize_fn`.
  global_state_type = initialize_fn.type_signature.result

  @tff.tf_computation(global_state_type)
  def derive_sample_params(global_state):
    return query.derive_sample_params(global_state)

  # Mapped over (broadcast sample params, value) in the orchestration
  # section below, one record at a time.
  @tff.tf_computation(derive_sample_params.type_signature.result,
                      value.type_signature.member)
  def preprocess_record(params, record):
    # TODO(b/123092620): Once TFF passes the expected container type (instead
    # of AnonymousTuple), we shouldn't need this.
    record = from_anon_tuple_fn(record)
    return query.preprocess_record(params, record)

  # TODO(b/123092620): We should have the expected container type here.
  # `tensor_specs` is used only to build the initial (zero) sample state.
  value_type = value_type_fn(value)
  tensor_specs = tff_framework.type_to_tf_tensor_specs(value_type)

  @tff.tf_computation
  def zero():
    return query.initial_sample_state(tensor_specs)

  # The zero's result type fixes the sample-state type used by
  # accumulate/merge/post_process below.
  sample_state_type = zero.type_signature.result

  @tff.tf_computation(sample_state_type,
                      preprocess_record.type_signature.result)
  def accumulate(sample_state, preprocessed_record):
    return query.accumulate_preprocessed_record(
        sample_state, preprocessed_record)

  @tff.tf_computation(sample_state_type, sample_state_type)
  def merge(sample_state_1, sample_state_2):
    return query.merge_sample_states(sample_state_1, sample_state_2)

  # Identity: the merged sample state is reported unchanged.
  @tff.tf_computation(merge.type_signature.result)
  def report(sample_state):
    return sample_state

  @tff.tf_computation(sample_state_type, global_state_type)
  def post_process(sample_state, global_state):
    # `get_noised_result` returns (result, new_global_state); the pair is
    # deliberately swapped so the new state comes first.
    result, new_global_state = query.get_noised_result(
        sample_state, global_state)
    return new_global_state, result

  #######################################
  # Orchestration logic

  # Derive params from the global state, then send them to every client.
  sample_params = tff.federated_map(derive_sample_params, global_state)
  client_sample_params = tff.federated_broadcast(sample_params)
  # Preprocess each client's record with the broadcast params.
  preprocessed_record = tff.federated_map(preprocess_record,
                                          (client_sample_params, value))
  # `zero()` supplies the initial sample state for the aggregation.
  agg_result = tff.federated_aggregate(preprocessed_record, zero(), accumulate,
                                       merge, report)
  # Combine the aggregate with the current global state to produce
  # (new_global_state, result).
  return tff.federated_map(post_process, (agg_result, global_state))