Пример #1
0
def trivial_aggregate():
    """Runs `federated_aggregate` over an empty structure placed at clients.

    Every stage (accumulate, merge, report) ignores its inputs and returns an
    empty tuple, and the zero is an empty tuple as well.

    Returns:
      The value produced by `intrinsics.federated_aggregate`.
    """
    client_value = intrinsics.federated_value((), placements.CLIENTS)
    initial_state = ()
    # Accumulate and merge are wrapped separately even though they are
    # textually identical, so that each stage is its own computation.
    accumulate_fn = computations.tf_computation(lambda _a, _b: ())
    merge_fn = computations.tf_computation(lambda _a, _b: ())
    report_fn = computations.tf_computation(lambda _: ())
    return intrinsics.federated_aggregate(
        client_value, initial_state, accumulate_fn, merge_fn, report_fn)
Пример #2
0
def simple_aggregate():
    """Aggregates the constant 1 placed at clients by addition.

    Accumulate and merge both add their two arguments, starting from a zero
    of 0; report returns its argument unchanged.

    Returns:
      The value produced by `intrinsics.federated_aggregate`.
    """
    client_ones = intrinsics.federated_value(1, placement_literals.CLIENTS)
    initial_total = 0
    accumulate_fn = computations.tf_computation(lambda a, b: a + b)
    merge_fn = computations.tf_computation(lambda a, b: a + b)
    report_fn = computations.tf_computation(lambda a: a)
    return intrinsics.federated_aggregate(
        client_ones, initial_total, accumulate_fn, merge_fn, report_fn)
 def test_construction_with_empty_state_does_not_raise(self):
     """Checks that an `EstimationProcess` can be built with `()` as state.

     The test fails (rather than errors) if construction raises.
     """
     initialize_fn = computations.tf_computation()(lambda: ())
     next_fn = computations.tf_computation(())(lambda x: (x, 1.0))
     report_fn = computations.tf_computation(())(lambda x: x)
     try:
         estimation_process.EstimationProcess(initialize_fn, next_fn,
                                              report_fn)
     except Exception:  # pylint: disable=broad-except
         # Catch `Exception` instead of using a bare `except` so that
         # KeyboardInterrupt and SystemExit still propagate instead of being
         # reported as a construction failure.
         self.fail(
             'Could not construct an EstimationProcess with empty state.')
Пример #4
0
 def next_fn(state, value):
     """Increments the state and the summed value, with constant measurements.

     Args:
       state: Federated value to which 1 is added via `federated_map`.
       value: Federated value that is summed with `federated_sum`; 1 is then
         added to every leaf of the sum.

     Returns:
       A `measured_process.MeasuredProcessOutput` holding the new state, the
       incremented sum and `MEASUREMENT_CONSTANT` placed at the server.
     """
     add_one = computations.tf_computation(lambda x: x + 1)
     state = intrinsics.federated_map(add_one, state)

     add_one_to_leaves = computations.tf_computation(
         lambda x: tf.nest.map_structure(lambda y: y + 1, x))
     summed_value = intrinsics.federated_sum(value)
     result = intrinsics.federated_map(add_one_to_leaves, summed_value)

     measurements = intrinsics.federated_value(MEASUREMENT_CONSTANT,
                                               placements.SERVER)
     return measured_process.MeasuredProcessOutput(
         state=state, result=result, measurements=measurements)
Пример #5
0
 def test_invoke_returns_value_with_correct_type(self):
     """Invoking a no-arg TF computation yields a `Value` typed `int32`."""
     fed_context = federated_computation_context.FederatedComputationContext(
         context_stack_impl.context_stack)
     constant_ten = computations.tf_computation(lambda: tf.constant(10))

     invoked = fed_context.invoke(constant_ten, None)

     self.assertIsInstance(invoked, value_base.Value)
     type_string = str(invoked.type_signature)
     self.assertEqual(type_string, 'int32')
Пример #6
0
  def test_roundtrip(self):
    """Converts a computation to broadcast form and back, checking both."""
    add_fn = computations.tf_computation(lambda x, y: x + y)
    int_at_server = computation_types.at_server(tf.int32)
    int_at_clients = computation_types.at_clients(tf.int32)

    # The parameter names below are asserted on as data labels further down.
    @computations.federated_computation(int_at_server, int_at_clients)
    def add_server_number_plus_one(server_number, client_numbers):
      one = intrinsics.federated_value(1, placements.SERVER)
      server_context = intrinsics.federated_map(add_fn, (one, server_number))
      client_context = intrinsics.federated_broadcast(server_context)
      return intrinsics.federated_map(add_fn,
                                      (client_context, client_numbers))

    broadcast_form = form_utils.get_broadcast_form_for_computation(
        add_server_number_plus_one)
    self.assertEqual(broadcast_form.server_data_label, 'server_number')
    self.assertEqual(broadcast_form.client_data_label, 'client_numbers')

    # Server half: int32 in, single-element tuple of int32 out.
    expected_server_type = computation_types.FunctionType(
        tf.int32, (tf.int32,))
    self.assert_types_equivalent(
        broadcast_form.compute_server_context.type_signature,
        expected_server_type)
    self.assertEqual(2, broadcast_form.compute_server_context(1)[0])

    # Client half: (context tuple, client int32) in, int32 out.
    expected_client_type = computation_types.FunctionType(
        ((tf.int32,), tf.int32), tf.int32)
    self.assert_types_equivalent(
        broadcast_form.client_processing.type_signature,
        expected_client_type)
    self.assertEqual(3, broadcast_form.client_processing((1,), 2))

    rebuilt_comp = form_utils.get_computation_for_broadcast_form(
        broadcast_form)
    self.assert_types_equivalent(rebuilt_comp.type_signature,
                                 add_server_number_plus_one.type_signature)
    # 2 (server data) + 1 (constant in comp) + 2 (client data) = 5 (output)
    self.assertEqual([5, 6, 7], rebuilt_comp(2, [2, 3, 4]))
Пример #7
0
 def test_next_return_namedtuple_raises(self):
   """A look-alike namedtuple result from `next` is rejected."""
   # Same name and fields as the real output type, but a different class.
   lookalike_output = collections.namedtuple(
       'MeasuredProcessOutput', ['state', 'result', 'measurements'])
   wrap_int32 = computations.tf_computation(tf.int32)
   namedtuple_next_fn = wrap_int32(
       lambda state: lookalike_output(state, (), ()))
   with self.assertRaises(errors.TemplateNotMeasuredProcessOutputError):
     measured_process.MeasuredProcess(test_initialize_fn, namedtuple_next_fn)
Пример #8
0
    def test_sequence_reduce(self):
        """Checks `sequence_reduce` types for unplaced and placed sequences."""
        add_numbers = computations.tf_computation(
            lambda a, b: tf.add(a, b),  # pylint: disable=unnecessary-lambda
            [tf.int32, tf.int32])
        int_sequence = computation_types.SequenceType(tf.int32)
        sequence_at_server = computation_types.FederatedType(
            int_sequence, placements.SERVER)
        sequence_at_clients = computation_types.FederatedType(
            int_sequence, placements.CLIENTS)

        # Unplaced sequence with a plain Python zero.
        @computations.federated_computation(int_sequence)
        def foo1(x):
            reduced = intrinsics.sequence_reduce(x, 0, add_numbers)
            self.assertIsInstance(reduced, value_base.Value)
            return reduced

        self.assert_type(foo1, '(int32* -> int32)')

        # Server-placed sequence with a server-placed zero.
        @computations.federated_computation(sequence_at_server)
        def foo2(x):
            server_zero = intrinsics.federated_value(0, placements.SERVER)
            reduced = intrinsics.sequence_reduce(x, server_zero, add_numbers)
            self.assertIsInstance(reduced, value_base.Value)
            return reduced

        self.assert_type(foo2, '(int32*@SERVER -> int32@SERVER)')

        # Client-placed sequences with a client-placed zero.
        @computations.federated_computation(sequence_at_clients)
        def foo3(x):
            client_zero = intrinsics.federated_value(0, placements.CLIENTS)
            reduced = intrinsics.sequence_reduce(x, client_zero, add_numbers)
            self.assertIsInstance(reduced, value_base.Value)
            return reduced

        self.assert_type(foo3, '({int32*}@CLIENTS -> {int32}@CLIENTS)')
Пример #9
0
 def test_raises_zeroing_norm_fn_bad_arg(self):
     """An int32-argument `zeroing_norm_fn` is rejected with a `TypeError`."""
     int_arg_norm_fn = computations.tf_computation(lambda x: x + 3,
                                                   tf.int32)
     expected_message = 'Argument of `zeroing_norm_fn`'
     with self.assertRaisesRegex(TypeError, expected_message):
         clipping_factory.ZeroingClippingFactory(
             2.0, int_arg_norm_fn, mean_factory.MeanFactory())
Пример #10
0
 def initialize_comp():
     """Evaluates `stateful_fn.initialize` at the server.

     If `stateful_fn.initialize` is not already a
     `computation_base.Computation`, it is first wrapped with
     `computations.tf_computation`.

     Returns:
       The result of `intrinsics.federated_eval` at `placements.SERVER`.
     """
     initialize = stateful_fn.initialize
     if not isinstance(initialize, computation_base.Computation):
         initialize = computations.tf_computation(initialize)
     return intrinsics.federated_eval(initialize, placements.SERVER)
Пример #11
0
    def test_sequence_reduce(self):
        """Checks `sequence_reduce` types when the zero is a plain `0`."""
        add_numbers = computations.tf_computation(tf.add, [tf.int32, tf.int32])
        int_sequence = computation_types.SequenceType(tf.int32)
        sequence_at_server = computation_types.FederatedType(
            int_sequence, placements.SERVER)
        sequence_at_clients = computation_types.FederatedType(
            int_sequence, placements.CLIENTS)

        # Unplaced sequence.
        @computations.federated_computation(int_sequence)
        def foo1(x):
            reduced = intrinsics.sequence_reduce(x, 0, add_numbers)
            self.assertIsInstance(reduced, value_base.Value)
            return reduced

        self.assert_type(foo1, '(int32* -> int32)')

        # Sequence placed at the server.
        @computations.federated_computation(sequence_at_server)
        def foo2(x):
            reduced = intrinsics.sequence_reduce(x, 0, add_numbers)
            self.assertIsInstance(reduced, value_base.Value)
            return reduced

        self.assert_type(foo2, '(int32*@SERVER -> int32@SERVER)')

        # Sequences placed at the clients.
        @computations.federated_computation(sequence_at_clients)
        def foo3(x):
            reduced = intrinsics.sequence_reduce(x, 0, add_numbers)
            self.assertIsInstance(reduced, value_base.Value)
            return reduced

        self.assert_type(foo3, '({int32*}@CLIENTS -> {int32}@CLIENTS)')
Пример #12
0
    def test_sequence_reduce(self):
        """Checks the type signatures produced by `sequence_reduce`."""
        add_numbers = computations.tf_computation(tf.add, [tf.int32, tf.int32])
        int_sequence = computation_types.SequenceType(tf.int32)

        # Unplaced sequence.
        @computations.federated_computation(int_sequence)
        def foo1(x):
            return intrinsics.sequence_reduce(x, 0, add_numbers)

        foo1_signature = str(foo1.type_signature)
        self.assertEqual(foo1_signature, '(int32* -> int32)')

        # Server-placed sequence, explicitly marked all-equal.
        sequence_at_server = computation_types.FederatedType(
            int_sequence, placements.SERVER, all_equal=True)

        @computations.federated_computation(sequence_at_server)
        def foo2(x):
            return intrinsics.sequence_reduce(x, 0, add_numbers)

        foo2_signature = str(foo2.type_signature)
        self.assertEqual(foo2_signature, '(int32*@SERVER -> int32@SERVER)')

        # Client-placed sequences.
        sequence_at_clients = computation_types.FederatedType(
            int_sequence, placements.CLIENTS)

        @computations.federated_computation(sequence_at_clients)
        def foo3(x):
            return intrinsics.sequence_reduce(x, 0, add_numbers)

        foo3_signature = str(foo3.type_signature)
        self.assertEqual(foo3_signature,
                         '({int32*}@CLIENTS -> {int32}@CLIENTS)')
Пример #13
0
    def test_one_param(self):
        """Wraps a one-argument `tf.function` with an explicit int32 type."""
        @tf.function
        def foo(x):
            return x + 1

        increment_comp = computations.tf_computation(foo, tf.int32)
        incremented = increment_comp(1)
        self.assertEqual(incremented, 2)
Пример #14
0
 def test_next_state_not_assignable_tuple_result(self):
     """A `next` returning a float32 state in a tuple is rejected."""
     wrap_two_floats = computations.tf_computation(tf.float32, tf.float32)
     float_next_fn = wrap_two_floats(
         lambda state, x: (tf.cast(state, tf.float32), x))
     with self.assertRaises(errors.TemplateStateNotAssignableError):
         iterative_process.IterativeProcess(
             test_initialize_fn, float_next_fn)
Пример #15
0
def apply(transform_fn: computation_base.Computation,
          arg_process: EstimationProcess):
    """Builds an `EstimationProcess` whose estimate is passed through `transform_fn`.

    Args:
      transform_fn: A `computation_base.Computation` to apply to the estimate
        of `arg_process`.
      arg_process: An `EstimationProcess` to which the transformation will be
        applied.

    Returns:
      An `EstimationProcess` that reuses `arg_process.initialize` and
      `arg_process.next`, and whose estimate function applies `transform_fn`
      to the result of `arg_process.get_estimate`.

    Raises:
      errors.TemplateStateNotAssignableError: If the result type of
        `arg_process.get_estimate` is not assignable to the parameter type of
        `transform_fn`.
    """
    py_typecheck.check_type(transform_fn, computation_base.Computation)
    py_typecheck.check_type(arg_process, EstimationProcess)

    transform_fn_arg_type = transform_fn.type_signature.parameter
    arg_process_estimate_type = arg_process.get_estimate.type_signature.result

    types_compatible = transform_fn_arg_type.is_assignable_from(
        arg_process_estimate_type)
    if not types_compatible:
        raise errors.TemplateStateNotAssignableError(
            f'The return type of `get_estimate` of `arg_process` must be '
            f'assignable to the input argument of `transform_fn`, but '
            f'`get_estimate` returns type:\n{arg_process_estimate_type}\n'
            f'and the argument of `transform_fn` is:\n'
            f'{transform_fn_arg_type}')

    def _transformed_estimate(state):
        # Compose: state -> estimate -> transformed estimate.
        estimate = arg_process.get_estimate(state)
        return transform_fn(estimate)

    transformed_estimate_fn = computations.tf_computation(
        _transformed_estimate, arg_process.state_type)

    return EstimationProcess(initialize_fn=arg_process.initialize,
                             next_fn=arg_process.next,
                             get_estimate_fn=transformed_estimate_fn)
Пример #16
0
 def next_fn(strings, val):
     """Appends 'abc' to the string state and sums `val`.

     Args:
       strings: Federated value mapped through a computation that concatenates
         `['abc']` along axis 0.
       val: Federated value passed to `intrinsics.federated_sum`.

     Returns:
       A `MeasuredProcessOutput` of the new state, the sum, and the constant 1
       placed at the server.
     """
     append_abc = computations.tf_computation()(
         lambda s: tf.concat([s, tf.constant(['abc'])], axis=0))
     new_state = intrinsics.federated_map(append_abc, strings)
     summed = intrinsics.federated_sum(val)
     server_one = intrinsics.federated_value(1, placements.SERVER)
     return MeasuredProcessOutput(new_state, summed, server_one)
Пример #17
0
    def test_namedtuple_param(self):
        """A namedtuple argument reaches the wrapped `tf.function` intact."""

        MyType = collections.namedtuple('MyType', ['x', 'y'])  # pylint: disable=invalid-name

        @tf.function
        def foo(t):
            # The argument must arrive as the namedtuple, not a plain tuple.
            self.assertIsInstance(t, MyType)
            return t.x + t.y

        # Explicit type
        explicit_comp = computations.tf_computation(
            foo, MyType(tf.int32, tf.int32))
        self.assertEqual(explicit_comp(MyType(1, 2)), 3)

        # Polymorphic
        polymorphic_comp = computations.tf_computation(foo)
        self.assertEqual(polymorphic_comp(MyType(1, 2)), 3)
Пример #18
0
class ContainsAggregationShared(parameterized.TestCase):
    """Checks applied to both the unsecure and secure aggregation finders."""

    def _assert_no_aggregation_found(self, comp):
        """Asserts that both finders return an empty result for `comp`."""
        for find_aggregation in (
                tree_analysis.find_unsecure_aggregation_in_tree,
                tree_analysis.find_secure_aggregation_in_tree):
            self.assertEmpty(find_aggregation(comp.to_building_block()))

    @parameterized.named_parameters([
        ('trivial_tf', computations.tf_computation(lambda: ())),
        ('trivial_tff', computations.federated_computation(lambda: ())),
        ('non_aggregation_intrinsics', non_aggregation_intrinsics),
        ('unused_aggregation', unused_aggregation),
        ('trivial_aggregate', trivial_aggregate),
        ('trivial_collect', trivial_collect),
        ('trivial_mean', trivial_mean),
        ('trivial_reduce', trivial_reduce),
        ('trivial_sum', trivial_sum),
        # TODO(b/120439632) Enable once federated_mean accepts structured weight.
        # ('trivial_weighted_mean', trivial_weighted_mean),
        ('trivial_secure_sum', trivial_secure_sum),
    ])
    def test_returns_none(self, comp):
        """Neither finder reports anything for the parameterized computation."""
        self._assert_no_aggregation_found(comp)

    def test_throws_on_unresolvable_function_call(self):
        """Both finders raise `ValueError` for an opaque client-output call."""
        client_int = computation_types.FederatedType(
            tf.int32, placement_literals.CLIENTS)
        unknown_func_type = computation_types.FunctionType((), client_int)

        @computations.federated_computation(unknown_func_type)
        def comp(unknown_func):
            return unknown_func(())

        for find_aggregation in (
                tree_analysis.find_unsecure_aggregation_in_tree,
                tree_analysis.find_secure_aggregation_in_tree):
            with self.assertRaises(ValueError):
                find_aggregation(comp.to_building_block())

    # functions without a federated output can't aggregate
    def test_returns_none_on_unresolvable_function_call_with_non_federated_output(
            self):
        """Both finders return empty for an opaque call with unplaced output."""
        client_int = computation_types.FederatedType(
            tf.int32, placement_literals.CLIENTS)
        unknown_func_type = computation_types.FunctionType(
            client_int, tf.int32)

        @computations.federated_computation(unknown_func_type)
        def comp(unknown_func):
            client_one = intrinsics.federated_value(
                1, placement_literals.CLIENTS)
            return unknown_func(client_one)

        self._assert_no_aggregation_found(comp)
Пример #19
0
 def next_fn(state, weights, updates):
     """Adds `updates` to the trainable weights, leaving `state` unchanged.

     Args:
       state: Returned as-is as the output state.
       weights: Object whose `trainable` attribute is added to `updates`.
       updates: Federated value added to `weights.trainable`.

     Returns:
       A `measured_process.MeasuredProcessOutput` with the input state, the
       zipped `model_utils.ModelWeights` built from the sum and an empty
       second field, and the result of `empty_at_server()` as measurements.
     """
     add_fn = computations.tf_computation(lambda x, y: x + y)
     updated_trainable = intrinsics.federated_map(
         add_fn, (weights.trainable, updates))
     # Second `ModelWeights` field is left as an empty tuple.
     new_weights = intrinsics.federated_zip(
         model_utils.ModelWeights(updated_trainable, ()))
     measurements = empty_at_server()
     return measured_process.MeasuredProcessOutput(
         state, new_weights, measurements)
Пример #20
0
    def test_py_and_tf_args(self):
        """Binds the Python-only `add` flag of a `tf.function` via lambdas."""
        @tf.function(autograph=False)
        def foo(x, y, add=True):
            if add:
                return x + y
            return x - y

        # Note: tf.Functions support mixing tensorflow and Python arguments,
        # usually with the semantics you would expect. Currently, TFF does not
        # support this kind of mixing, even for Polymorphic TFF functions.
        # However, you can work around this by explicitly binding any Python
        # arguments on a tf.Function:
        tf_poly_add = computations.tf_computation(
            lambda x, y: foo(x, y, True))
        tf_poly_sub = computations.tf_computation(
            lambda x, y: foo(x, y, False))

        self.assertEqual(tf_poly_add(2, 1), 3)
        self.assertEqual(tf_poly_add(2., 1.), 3.)
        self.assertEqual(tf_poly_sub(2, 1), 1)
Пример #21
0
    def test_explicit_tuple_param(self):
        """Wraps a tuple-taking `tf.function` with an explicit tuple type."""
        # See also test_polymorphic_tuple_input
        @tf.function
        def foo(t):
            return t[0] + t[1]

        int_pair_type = (tf.int32, tf.int32)
        pair_sum_comp = computations.tf_computation(foo, int_pair_type)
        total = pair_sum_comp((1, 2))
        self.assertEqual(total, 3)
Пример #22
0
 def foo(temperatures, threshold):
     """Counts how many temperatures are greater than the threshold.

     Args:
       temperatures: Federated float32 values, compared element by element
         with the broadcast threshold (presumably placed at clients, since
         they are paired with a broadcast value; confirm against the
         decorator, which is not visible here).
       threshold: Value passed to `intrinsics.federated_broadcast` before the
         comparison.

     Returns:
       The `intrinsics.federated_sum` of the int32 comparison results.
     """
     # `tf.cast` replaces `tf.to_int32`, which is deprecated and is not part
     # of the TensorFlow 2.x API; the resulting int32 0/1 values are the same.
     exceeds_threshold = computations.tf_computation(
         lambda x, y: tf.cast(tf.greater(x, y), tf.int32),
         [tf.float32, tf.float32])
     return intrinsics.federated_sum(
         intrinsics.federated_map(
             exceeds_threshold,
             [temperatures,
              intrinsics.federated_broadcast(threshold)]))
 def test_map_estimate_not_assignable(self):
     """`map` raises `EstimateNotAssignableError` for the int32 `map_fn`."""
     wrap_int32 = computations.tf_computation(tf.int32)
     map_fn = wrap_int32(lambda estimate: estimate)
     process = estimation_process.EstimationProcess(
         test_initialize_fn, test_next_fn, test_report_fn)
     expected_error = estimation_process.EstimateNotAssignableError
     with self.assertRaises(expected_error):
         process.map(map_fn)
Пример #24
0
 def next_comp(state, value):
     """Maps `_add_one` over the state and broadcasts `value`.

     Args:
       state: Federated value mapped through `_add_one`.
       value: Federated value that is broadcast as the result and also used
         to compute the measurements.

     Returns:
       A `measured_process.MeasuredProcessOutput` whose measurements are the
       global norm of the flattened `value` plus 3.0.
     """
     new_state = intrinsics.federated_map(_add_one, state)
     broadcast_value = intrinsics.federated_broadcast(value)
     # Arbitrary metrics for testing.
     norm_plus_three = computations.tf_computation(
         lambda v: tf.linalg.global_norm(tf.nest.flatten(v)) + 3.0)
     measurements = intrinsics.federated_map(norm_plus_three, value)
     return measured_process.MeasuredProcessOutput(
         state=new_state,
         result=broadcast_value,
         measurements=measurements)
Пример #25
0
 def foo(temperatures, threshold):
     """Counts how many temperatures are greater than the threshold.

     The comparison results are cast to int32 and summed with
     `intrinsics.federated_sum`; the sum is asserted to be a
     `value_base.Value` before it is returned.
     """
     exceeds_threshold = computations.tf_computation(
         lambda x, y: tf.cast(tf.greater(x, y), tf.int32))
     broadcast_threshold = intrinsics.federated_broadcast(threshold)
     flags = intrinsics.federated_map(
         exceeds_threshold, [temperatures, broadcast_threshold])
     val = intrinsics.federated_sum(flags)
     self.assertIsInstance(val, value_base.Value)
     return val
Пример #26
0
    def test_non_federated_init_next_raises(self):
        """`DistributionProcess` rejects non-federated init and next."""
        # Both computations below are plain TF computations without placement.
        initialize_fn = computations.tf_computation(lambda: 0)

        @computations.tf_computation(tf.int32, tf.float32)
        def next_fn(state, val):
            return MeasuredProcessOutput(state, val, ())

        expected_error = errors.TemplateNotFederatedError
        with self.assertRaises(expected_error):
            distributors.DistributionProcess(initialize_fn, next_fn)
Пример #27
0
  def test_non_federated_init_next_raises(self):
    """`AggregationProcess` rejects non-federated init and next."""
    # Both computations below are plain TF computations without placement.
    initialize_fn = computations.tf_computation(lambda: 0)

    @computations.tf_computation(tf.int32, tf.float32)
    def next_fn(state, val):
      return MeasuredProcessOutput(state, val, ())

    expected_error = aggregation_process.AggregationNotFederatedError
    with self.assertRaises(expected_error):
      aggregation_process.AggregationProcess(initialize_fn, next_fn)
Пример #28
0
 def test_federated_map_with_client_dataset_reduce(self):
     """Maps a dataset-reducing TF computation over client sequences."""
     int_sequence = computation_types.SequenceType(tf.int32)
     client_datasets = _mock_data_of_type(
         computation_types.at_clients(int_sequence, all_equal=True))
     # Sums the elements of a dataset, starting from an int32 zero.
     sum_dataset = computations.tf_computation(
         lambda ds: ds.reduce(np.int32(0), lambda x, y: x + y))
     val = intrinsics.federated_map(sum_dataset, client_datasets)
     self.assert_value(val, '{int32}@CLIENTS')
Пример #29
0
    def test_nested_tuple_input_explicit_types(self):
        """Sums every leaf of two (nested) tuple arguments, typed explicitly."""
        @tf.function(autograph=False)
        def foo(tuple1, tuple2):
            first_total = tuple1[0] + tuple1[1][0] + tuple1[1][1]
            return first_total + tuple2[0] + tuple2[1]

        nested_first = (tf.int32, (tf.int32, tf.int32))
        flat_second = (tf.int32, tf.int32)
        tff_type = [nested_first, flat_second]
        tf_comp = computations.tf_computation(foo, tff_type)
        total = tf_comp((1, (2, 3)), (0, 0))
        self.assertEqual(total, 6)
    def test_tensorflow_computation_with_lambda_and_selection(self):
        """Passes a TF computation as a functional argument and applies it twice."""
        int_to_int = computation_types.FunctionType(tf.int32, tf.int32)

        @computations.federated_computation(tf.int32, int_to_int)
        def apply_twice(x, f):
            once = f(x)
            return f(once)

        add_one = computations.tf_computation(lambda x: x + 1, tf.int32)

        result = apply_twice(5, add_one)
        self.assertEqual(result, 7)