コード例 #1
0
    def encoded_sum_fn(state, values, weight=None):
        """Aggregate `values` through the encode / aggregate / decode pipeline.

        Args:
          state: Encoder state; `nest_encoder.get_params_fn` is mapped over it
            to obtain the encoding and decoding parameters.
          values: The values to encode and aggregate.
          weight: Accepted for interface compatibility and discarded.

        Returns:
          A `(updated_state, decoded_values)` pair.
        """
        del weight  # Unused.

        # One mapped call yields all three parameter groups.
        params = tff.federated_map(nest_encoder.get_params_fn, state)
        (encode_params, decode_before_sum_params,
         decode_after_sum_params) = params

        # Only the groups consumed by the encode step are broadcast; the
        # decode-after-sum parameters are used as-is further below.
        broadcast_encode_params = tff.federated_broadcast(encode_params)
        broadcast_decode_before_sum_params = tff.federated_broadcast(
            decode_before_sum_params)

        encoded = tff.federated_map(
            nest_encoder.encode_fn,
            [values, broadcast_encode_params,
             broadcast_decode_before_sum_params])

        aggregated = tff.federated_aggregate(
            encoded,
            nest_encoder.zero_fn(),
            nest_encoder.accumulate_fn,
            nest_encoder.merge_fn,
            nest_encoder.report_fn)

        decoded = tff.federated_map(
            nest_encoder.decode_after_sum_fn,
            [aggregated.values, decode_after_sum_params])

        # The aggregate also carries tensors used to advance the state.
        next_state = tff.federated_map(
            nest_encoder.update_state_fn,
            [state, aggregated.state_update_tensors])
        return next_state, decoded
コード例 #2
0
 def foo(temperatures, threshold):
   """Count the `temperatures` entries that are strictly above `threshold`.

   Each temperature is compared with the broadcast threshold, the comparison
   result is cast to int32 (1 or 0), and the results are summed.
   """
   is_above = tff.tf_computation(
       lambda x, y: tf.cast(tf.greater(x, y), tf.int32),
       [tf.float32, tf.float32])
   broadcast_threshold = tff.federated_broadcast(threshold)
   indicators = tff.federated_map(
       is_above, [temperatures, broadcast_threshold])
   return tff.federated_sum(indicators)
コード例 #3
0
 def encoded_broadcast_fn(state, value):
     """Encode `value`, broadcast the encoding, then decode the broadcast.

     Returns:
       A `(new_state, client_value)` pair: the state produced by `encode_fn`
       and the result of mapping `decode_fn` over the broadcast encoding.
     """
     encode_result = tff.federated_apply(encode_fn, (state, value))
     next_state, payload = encode_result
     broadcast_payload = tff.federated_broadcast(payload)
     decoded = tff.federated_map(decode_fn, broadcast_payload)
     return next_state, decoded
コード例 #4
0
def broadcast_next_fn(state, value):
    """Increment `state.call_count` and broadcast `value`.

    Returns:
      A 2-tuple of a dict holding the incremented `'call_count'` and the
      broadcast `value`.
    """

    @tff.tf_computation(tf.int32)
    def add_one(value):
        return value + 1

    incremented_count = tff.federated_apply(add_one, state.call_count)
    next_state = {'call_count': incremented_count}
    return next_state, tff.federated_broadcast(value)
コード例 #5
0
        def next_fn(server_state, client_data):
            """Broadcast the state, add one to it per client, and sum the results.

            `client_data` is passed to the per-client transform but is not
            used by it.

            Returns:
              A pair of the summed client results and the constant 1234
              placed at `tff.SERVER`.
            """
            state_for_clients = tff.federated_broadcast(server_state)

            @tff.tf_computation(tf.int32, tff.SequenceType(tf.float32))
            @tf.function
            def some_transform(x, y):
                del y  # Unused
                return x + 1

            per_client_result = tff.federated_map(
                some_transform, (state_for_clients, client_data))
            summed_result = tff.federated_sum(per_client_result)
            return summed_result, tff.federated_value(1234, tff.SERVER)
コード例 #6
0
    def encoded_broadcast_fn(state, value):
        """Encode `value`, broadcast it, and decode the broadcast result.

        The encode and decode computations are built from the member types of
        `state` and `value` together with `encoders`.

        Returns:
          A `(new_state, client_value)` pair.
        """
        encode_fn, decode_fn = _build_encode_decode_tf_computations_for_broadcast(
            state.type_signature.member, value.type_signature.member, encoders)

        encode_result = tff.federated_apply(encode_fn, (state, value))
        next_state, payload = encode_result
        broadcast_payload = tff.federated_broadcast(payload)
        decoded = tff.federated_map(decode_fn, broadcast_payload)
        return next_state, decoded
コード例 #7
0
 def next_computation(arg):
     """The logic of a single MapReduce processing round.

     Args:
       arg: An indexable pair; `arg[0]` is fed to `cf.prepare` and
         `cf.update`, and `arg[1]` is zipped with the broadcast output of
         `cf.prepare` as the input to `cf.work`.

     Returns:
       A 3-tuple: the two components of the `cf.update` result followed by
       the second component of the `cf.work` result.
     """
     server_state = arg[0]
     client_data = arg[1]

     # Prepare, then broadcast so it can be paired with `arg[1]`.
     prepared = tff.federated_apply(cf.prepare, server_state)
     broadcast_prepared = tff.federated_broadcast(prepared)
     work_input = tff.federated_zip([client_data, broadcast_prepared])
     work_output = tff.federated_map(cf.work, work_input)
     client_update = work_output[0]
     client_output = work_output[1]

     # Reduce the first component of the work output.
     aggregate = tff.federated_aggregate(
         client_update, cf.zero(), cf.accumulate, cf.merge, cf.report)

     # Combine the original state with the aggregate.
     update_input = tff.federated_zip([server_state, aggregate])
     update_output = tff.federated_apply(cf.update, update_input)
     new_server_state = update_output[0]
     server_output = update_output[1]
     return new_server_state, server_output, client_output
コード例 #8
0
    def next_fn(global_state, value, weight=None):
        """Defines next_fn for StatefulAggregateFn.

        Builds the TensorFlow computations that wrap the methods of `query`,
        then wires them together: derive sample parameters from the global
        state, broadcast them, preprocess each record, aggregate the
        preprocessed records, and post-process the aggregate.

        Args:
          global_state: State passed to `query.derive_sample_params` and
            `query.get_noised_result`.
          value: Federated value whose member records are preprocessed and
            aggregated.
          weight: Ignored; weighted aggregation is not supported.

        Returns:
          The result of applying `post_process` to the aggregate and the
          global state; `post_process` returns `(new_global_state, result)`.
        """
        # Weighted aggregation is not supported.
        # TODO(b/140236959): Add an assertion that weight is None here, so the
        # contract of this method is better established. Will likely cause some
        # downstream breaks.
        del weight

        #######################################
        # Define local tf_computations

        # TODO(b/129567727): Make most of these tf_computations polymorphic
        # so type manipulation isn't needed.

        # The global state type is taken from what `initialize_fn` returns.
        global_state_type = initialize_fn.type_signature.result

        @tff.tf_computation(global_state_type)
        def derive_sample_params(global_state):
            return query.derive_sample_params(global_state)

        # Each computation below is typed from the result type of the ones
        # defined before it, so the definition order matters.
        @tff.tf_computation(derive_sample_params.type_signature.result,
                            value.type_signature.member)
        def preprocess_record(params, record):
            # TODO(b/123092620): Once TFF passes the expected container type (instead
            # of AnonymousTuple), we shouldn't need this.
            record = from_anon_tuple_fn(record)

            return query.preprocess_record(params, record)

        # TODO(b/123092620): We should have the expected container type here.
        value_type = value_type_fn(value)

        tensor_specs = tff_framework.type_to_tf_tensor_specs(value_type)

        # Initial accumulator for the aggregation.
        @tff.tf_computation
        def zero():
            return query.initial_sample_state(tensor_specs)

        sample_state_type = zero.type_signature.result

        @tff.tf_computation(sample_state_type,
                            preprocess_record.type_signature.result)
        def accumulate(sample_state, preprocessed_record):
            return query.accumulate_preprocessed_record(
                sample_state, preprocessed_record)

        @tff.tf_computation(sample_state_type, sample_state_type)
        def merge(sample_state_1, sample_state_2):
            return query.merge_sample_states(sample_state_1, sample_state_2)

        # The report step is the identity; the merged state is returned as-is.
        @tff.tf_computation(merge.type_signature.result)
        def report(sample_state):
            return sample_state

        @tff.tf_computation(sample_state_type, global_state_type)
        def post_process(sample_state, global_state):
            result, new_global_state = query.get_noised_result(
                sample_state, global_state)
            # Note the swap: the query returns (result, state) but this
            # computation returns (state, result).
            return new_global_state, result

        #######################################
        # Orchestration logic

        sample_params = tff.federated_apply(derive_sample_params, global_state)
        client_sample_params = tff.federated_broadcast(sample_params)
        preprocessed_record = tff.federated_map(preprocess_record,
                                                (client_sample_params, value))
        agg_result = tff.federated_aggregate(preprocessed_record, zero(),
                                             accumulate, merge, report)

        return tff.federated_apply(post_process, (agg_result, global_state))
コード例 #9
0
 def _(x):
   """Return the result of `tff.federated_broadcast` applied to `x`."""
   broadcast_value = tff.federated_broadcast(x)
   return broadcast_value