Example #1
0
 def foo(x):
     """Maps an `x > 10` tf_computation over `x` and checks the result type."""
     is_greater_than_ten = computations.tf_computation(lambda x: x > 10)
     mapped = intrinsics.federated_map(is_greater_than_ten, x)
     self.assertIsInstance(mapped, value_base.Value)
     return mapped
Example #2
0
 def _(x):
     """Applies an int32-typed `x > 10` predicate to `x` via federated_map."""
     predicate = computations.tf_computation(lambda x: x > 10, tf.int32)
     return intrinsics.federated_map(predicate, x)
Example #3
0
 def foo(ds):
     """Sums each dataset in `ds` (int32 reduce) and checks the result type."""
     # Reduce starting from np.int32(0), adding elements pairwise.
     sum_dataset = computations.tf_computation(
         lambda ds: ds.reduce(np.int32(0), lambda x, y: x + y))
     summed = intrinsics.federated_map(sum_dataset, ds)
     self.assertIsInstance(summed, value_base.Value)
     return summed
Example #4
0
 def comp(x):
     """Maps the `_identity` computation over `x`."""
     mapped = intrinsics.federated_map(_identity, x)
     return mapped
Example #5
0
 def baz(x):
     """Broadcasts `x`, then maps `add_one` over the broadcast value."""
     return intrinsics.federated_map(
         add_one, intrinsics.federated_broadcast(x))
Example #6
0
 def comp():
   """Maps `add_one` over the constant 10 placed at `placements.CLIENTS`."""
   tens_at_clients = intrinsics.federated_value(10, placements.CLIENTS)
   return intrinsics.federated_map(add_one, tens_at_clients)
Example #7
0
 def new_computation(param):
     """Maps `dataset_computation` over `param`, then applies `computation_body`.

     The federated_map result (named "datasets on clients" by the original
     author) is passed straight into `computation_body`.
     """
     return computation_body(
         intrinsics.federated_map(dataset_computation, param))
Example #8
0
 def update_state(state, value_min, value_max):
     """Advances `process` with the elementwise max of |value_min| and |value_max|."""
     larger_magnitude = computations.tf_computation(
         lambda x, y: tf.maximum(tf.abs(x), tf.abs(y)))
     bounds = (value_min, value_max)
     abs_max = intrinsics.federated_map(larger_magnitude, bounds)
     return process.next(state, abs_max)
Example #9
0
 def get_bounds(state):
     """Returns `(upper_bound, lower_bound)` for the process state.

     The upper bound comes from `process.report(state)`; the lower bound is
     that value multiplied by -1.0.
     """
     upper_bound = process.report(state)
     negate = computations.tf_computation(lambda x: x * -1.0)
     return upper_bound, intrinsics.federated_map(negate, upper_bound)
    def next_fn(global_state, value, weight=None):
        """Defines next_fn for StatefulAggregateFn.

        Runs one round of query-driven aggregation:
          1. derive sample params from `global_state`;
          2. broadcast them and preprocess each member of `value` with them;
          3. aggregate the preprocessed records with the query's
             zero/accumulate/merge operations;
          4. call `query.get_noised_result` on the aggregate to obtain the
             result and the updated global state.

        Args:
          global_state: Current aggregation state. Its type is assumed to be
            `initialize_fn.type_signature.result` (used to type
            `derive_sample_params` and `post_process`). Because the params
            derived from it are passed to `intrinsics.federated_broadcast`, it
            is presumably server-placed — NOTE(review): confirm with callers.
          value: Federated value to aggregate; `value.type_signature.member`
            is used as the per-record input type of `preprocess_record`.
          weight: Discarded; weighted aggregation is not supported.

        Returns:
          The output of `intrinsics.federated_map(post_process, ...)`, whose
          member is the pair `(new_global_state, result)`. Note this is the
          reverse of the `(result, new_global_state)` order returned by
          `query.get_noised_result`.
        """
        # Weighted aggregation is not supported.
        # TODO(b/140236959): Add an assertion that weight is None here, so the
        # contract of this method is better established. Will likely cause some
        # downstream breaks.
        del weight

        #######################################
        # Define local tf_computations

        # TODO(b/129567727): Make most of these tf_computations polymorphic
        # so type manipulation isn't needed.

        # The state type is read from the initializer so the computations
        # below can be given explicit parameter types.
        global_state_type = initialize_fn.type_signature.result

        # Wraps query.derive_sample_params as a typed tf_computation.
        @computations.tf_computation(global_state_type)
        def derive_sample_params(global_state):
            return query.derive_sample_params(global_state)

        # Per-record preprocessing, typed on (sample params, value member type).
        @computations.tf_computation(
            derive_sample_params.type_signature.result,
            value.type_signature.member)
        def preprocess_record(params, record):
            # TODO(b/123092620): Once TFF passes the expected container type (instead
            # of AnonymousTuple), we shouldn't need this.
            record = from_tff_result_fn(record)

            return query.preprocess_record(params, record)

        # TODO(b/123092620): We should have the expected container type here.
        # NOTE(review): value_type_fn is defined outside this view; presumably
        # it maps the federated `value` to its (container) member type.
        value_type = value_type_fn(value)

        # Tensor specs describing the record structure; consumed by the
        # query's initial_sample_state below.
        tensor_specs = type_utils.type_to_tf_tensor_specs(value_type)

        # No-argument computation producing the query's initial sample state;
        # used as the `zero` of federated_aggregate.
        @computations.tf_computation
        def zero():
            return query.initial_sample_state(tensor_specs)

        # The sample state type is taken from zero()'s result so that
        # accumulate/merge/post_process agree with it.
        sample_state_type = zero.type_signature.result

        # Folds one preprocessed record into a sample state.
        @computations.tf_computation(sample_state_type,
                                     preprocess_record.type_signature.result)
        def accumulate(sample_state, preprocessed_record):
            return query.accumulate_preprocessed_record(
                sample_state, preprocessed_record)

        # Combines two partial sample states.
        @computations.tf_computation(sample_state_type, sample_state_type)
        def merge(sample_state_1, sample_state_2):
            return query.merge_sample_states(sample_state_1, sample_state_2)

        # Identity report: the merged sample state is passed through unchanged;
        # noising happens afterwards in post_process.
        @computations.tf_computation(merge.type_signature.result)
        def report(sample_state):
            return sample_state

        # Produces the noised result and new state; the returned pair is
        # reordered to (new_global_state, result).
        @computations.tf_computation(sample_state_type, global_state_type)
        def post_process(sample_state, global_state):
            result, new_global_state = query.get_noised_result(
                sample_state, global_state)
            return new_global_state, result

        #######################################
        # Orchestration logic

        # Derive params from the state, then broadcast them so each client's
        # record can be preprocessed with them.
        sample_params = intrinsics.federated_map(derive_sample_params,
                                                 global_state)
        client_sample_params = intrinsics.federated_broadcast(sample_params)
        preprocessed_record = intrinsics.federated_map(
            preprocess_record, (client_sample_params, value))
        # zero() is invoked here to supply the initial accumulator value.
        agg_result = intrinsics.federated_aggregate(preprocessed_record,
                                                    zero(), accumulate, merge,
                                                    report)

        # Pair the aggregate with the current state to compute the noised
        # result and the next state.
        return intrinsics.federated_map(post_process,
                                        (agg_result, global_state))