Esempio n. 1
0
 def foo(x):
     """Aggregate `x` via `tff.federated_aggregate`.

     The aggregation starts from `Accumulator(0, 0)` as its zero and uses the
     `accumulate`, `merge` and `report` callables from the enclosing scope.
     """
     initial_accumulator = Accumulator(0, 0)
     return tff.federated_aggregate(
         x, initial_accumulator, accumulate, merge, report)
Esempio n. 2
0
 def foo(x):
     """Aggregate `x` via `tff.federated_aggregate`.

     The zero is whatever `build_federated_zero()` returns. The reducers are
     the `accumulate`, `merge` and `report` callables from the enclosing scope.
     """
     # NOTE(review): the helper's name suggests a federated-placed zero;
     # confirm against the caller whether this call is expected to be rejected.
     zero_value = build_federated_zero()
     return tff.federated_aggregate(
         x, zero_value, accumulate, merge, report)
Esempio n. 3
0
 def foo(x):
     """Aggregate `x` via `tff.federated_aggregate`.

     The zero is whatever `build_empty_accumulator()` returns. The reducers are
     the `accumulate`, `merge` and `report` callables from the enclosing scope.
     """
     empty_accumulator = build_empty_accumulator()
     return tff.federated_aggregate(
         x, empty_accumulator, accumulate, merge, report)
Esempio n. 4
0
    def next_fn(global_state, value, weight=None):
        """Defines next_fn for StatefulAggregateFn.

        Runs one aggregation round driven by `query`, in five steps:
          1. Derive sample params from `global_state`.
          2. Broadcast those params.
          3. Preprocess each member of `value` with the params.
          4. Combine the preprocessed records with `tff.federated_aggregate`.
          5. Call `query.get_noised_result` on the aggregate together with
             `global_state`.

        Args:
          global_state: A federated value whose member type is
            `initialize_fn.type_signature.result`. Presumably server-placed,
            since the params derived from it are passed to
            `tff.federated_broadcast` — confirm against the caller.
          value: A federated value. Its member type
            (`value.type_signature.member`) is used as the per-record input
            type of `preprocess_record`. Presumably client-placed, since it is
            zipped with the broadcast params — confirm.
          weight: Ignored (deleted immediately); weighted aggregation is not
            supported.

        Returns:
          The result of `tff.federated_map(post_process, ...)`. Its member is
          the pair `(new_global_state, result)`. Note that this order is
          swapped relative to what `query.get_noised_result` returns.
        """
        # Weighted aggregation is not supported.
        # TODO(b/140236959): Add an assertion that weight is None here, so the
        # contract of this method is better established. Will likely cause some
        # downstream breaks.
        del weight

        #######################################
        # Define local tf_computations
        #
        # Order matters here: several computations below take their input
        # types from the type signatures of computations defined earlier.

        # TODO(b/129567727): Make most of these tf_computations polymorphic
        # so type manipulation isn't needed.

        global_state_type = initialize_fn.type_signature.result

        # Maps the global state to the query's per-round sample params.
        @tff.tf_computation(global_state_type)
        def derive_sample_params(global_state):
            return query.derive_sample_params(global_state)

        # Applies the query's per-record preprocessing to a single member of
        # `value`, given the sample params.
        @tff.tf_computation(derive_sample_params.type_signature.result,
                            value.type_signature.member)
        def preprocess_record(params, record):
            # TODO(b/123092620): Once TFF passes the expected container type (instead
            # of AnonymousTuple), we shouldn't need this.
            record = from_anon_tuple_fn(record)

            return query.preprocess_record(params, record)

        # TODO(b/123092620): We should have the expected container type here.
        # NOTE(review): `value_type_fn` is defined outside this view; it is
        # presumed to return the container type of `value`'s members — confirm.
        value_type = value_type_fn(value)

        # Tensor specs for the value type; these seed the query's initial
        # sample state below.
        tensor_specs = tff_framework.type_to_tf_tensor_specs(value_type)

        # Builds the initial (empty) sample state. Its result type is reused as
        # the sample-state type for every computation that follows.
        @tff.tf_computation
        def zero():
            return query.initial_sample_state(tensor_specs)

        sample_state_type = zero.type_signature.result

        # Folds one preprocessed record into a sample state.
        @tff.tf_computation(sample_state_type,
                            preprocess_record.type_signature.result)
        def accumulate(sample_state, preprocessed_record):
            return query.accumulate_preprocessed_record(
                sample_state, preprocessed_record)

        # Combines two partial sample states.
        @tff.tf_computation(sample_state_type, sample_state_type)
        def merge(sample_state_1, sample_state_2):
            return query.merge_sample_states(sample_state_1, sample_state_2)

        # Identity report: the merged sample state is passed through
        # unchanged. The query's result is computed later, in post_process.
        @tff.tf_computation(merge.type_signature.result)
        def report(sample_state):
            return sample_state

        # Computes the query result and the next global state, and returns
        # them as (new_global_state, result).
        @tff.tf_computation(sample_state_type, global_state_type)
        def post_process(sample_state, global_state):
            result, new_global_state = query.get_noised_result(
                sample_state, global_state)
            return new_global_state, result

        #######################################
        # Orchestration logic

        sample_params = tff.federated_map(derive_sample_params, global_state)
        # Broadcast the sample params so they can be zipped with `value`.
        client_sample_params = tff.federated_broadcast(sample_params)
        preprocessed_record = tff.federated_map(preprocess_record,
                                                (client_sample_params, value))
        # `zero()` is invoked here to produce the aggregation's initial value.
        agg_result = tff.federated_aggregate(preprocessed_record, zero(),
                                             accumulate, merge, report)

        return tff.federated_map(post_process, (agg_result, global_state))