def test_iterative_process_type_mismatch(self):
    """Checks that IterativeProcess rejects mismatched next_fn signatures.

    Expects a TypeError about initialize_fn's return type when next_fn takes
    float32 arguments, and a TypeError about next_fn's return type when an
    int32-parameter next_fn returns a float or a (float, int) pair.
    """
    next_fn_mismatch = (
        'The return type of next_fn should match the first parameter')

    with self.assertRaisesRegex(
        TypeError, r'The return type of initialize_fn should match.*'):

        @computations.federated_computation([tf.float32, tf.float32])
        def float_sum(current, val):
            return current + val

        computation_utils.IterativeProcess(
            initialize_fn=initialize, next_fn=float_sum)

    with self.assertRaisesRegex(TypeError, next_fn_mismatch):

        @computations.federated_computation(tf.int32)
        def returns_float(_):
            return 0.0

        computation_utils.IterativeProcess(
            initialize_fn=initialize, next_fn=returns_float)

    with self.assertRaisesRegex(TypeError, next_fn_mismatch):

        @computations.federated_computation(tf.int32)
        def returns_pair(_):
            return 0.0, 0

        computation_utils.IterativeProcess(
            initialize_fn=initialize, next_fn=returns_pair)
  def test_iterative_process_initialize_bad_type(self):
    """Checks that initialize_fn must be a no-argument Computation."""
    # A non-Computation value (None) is rejected outright.
    with self.assertRaisesRegex(TypeError, r'Expected .*\.Computation, .*'):
      computation_utils.IterativeProcess(initialize_fn=None, next_fn=add_int32)

    # A Computation that declares a parameter is rejected as well.
    with self.assertRaisesRegex(
        TypeError, r'initialize_fn must be a no-arg tff.Computation'):

      @tff.federated_computation(tf.int32)
      def initialize_with_arg(one_arg):
        del one_arg  # unused
        return tff.to_value(0)

      computation_utils.IterativeProcess(
          initialize_fn=initialize_with_arg, next_fn=add_int32)
Example #3
0
def get_iterative_process_for_sum_example_with_no_server_state():
    """Builds a sum-example `IterativeProcess` whose server state is empty.

    Every client's `work` yields `[1, 1]` plus an empty client output. The
    first component is combined with `federated_sum`; the second with
    `federated_secure_sum` at bitwidth 8. The zipped pair goes through
    `update`, which returns an empty new state and the pair as server output.
    """

    @computations.federated_computation
    def init_fn():
        """The `init` function for `computation_utils.IterativeProcess`."""
        return intrinsics.federated_value([], placements.SERVER)

    @computations.tf_computation(tf.int32)
    def work(client_data):
        del client_data  # Unused
        return [1, 1], []

    @computations.tf_computation([tf.int32, tf.int32])
    def update(global_update):
        return [], global_update

    server_state_type = computation_types.FederatedType([], placements.SERVER)
    client_data_type = computation_types.FederatedType(
        tf.int32, placements.CLIENTS)

    @computations.federated_computation([server_state_type, client_data_type])
    def next_fn(server_state, client_data):
        """The `next` function for `computation_utils.IterativeProcess`."""
        del server_state  # Unused
        per_client_updates, client_output = intrinsics.federated_map(
            work, client_data)
        plain_total = intrinsics.federated_sum(per_client_updates[0])
        secure_total = intrinsics.federated_secure_sum(
            per_client_updates[1], 8)
        combined = intrinsics.federated_zip([plain_total, secure_total])
        new_server_state, server_output = intrinsics.federated_map(
            update, combined)
        return new_server_state, server_output, client_output

    return computation_utils.IterativeProcess(init_fn, next_fn)
Example #4
0
def get_iterative_process_for_concise_sum_example():
    """Builds a compact sum-example `IterativeProcess`.

    The `[0, 0]` server state is broadcast and zipped with each client's
    data. `work` ignores both and yields `[1, 1]`. The first component is
    summed with `federated_sum`, the second with `federated_secure_sum`
    (bitwidth 8), and the zipped pair becomes the new state. The server
    output is an empty structure.
    """

    @computations.federated_computation
    def init_fn():
        """The `init` function for `computation_utils.IterativeProcess`."""
        return intrinsics.federated_value([0, 0], placements.SERVER)

    @computations.tf_computation(tf.int32, [tf.int32, tf.int32])
    def work(client_data, client_input):
        del client_data, client_input  # Unused
        return [1, 1]

    server_state_type = computation_types.FederatedType(
        [tf.int32, tf.int32], placements.SERVER)
    client_data_type = computation_types.FederatedType(
        tf.int32, placements.CLIENTS)

    @computations.federated_computation([server_state_type, client_data_type])
    def next_fn(server_state, client_data):
        """The `next` function for `computation_utils.IterativeProcess`."""
        broadcast_state = intrinsics.federated_broadcast(server_state)
        work_inputs = intrinsics.federated_zip([client_data, broadcast_state])
        per_client = intrinsics.federated_map(work, work_inputs)
        plain_total = intrinsics.federated_sum(per_client[0])
        secure_total = intrinsics.federated_secure_sum(per_client[1], 8)
        new_server_state = intrinsics.federated_zip(
            [plain_total, secure_total])
        server_output = intrinsics.federated_value([], placements.SERVER)
        return new_server_state, server_output

    return computation_utils.IterativeProcess(init_fn, next_fn)
Example #5
0
def get_iterative_process_for_canonical_form_example():
    """Constructs a simple `IterativeProcess` compatible with `CanonicalForm`.

    The computation itself is nonsensical; it only demonstrates the type
    signatures that `CanonicalForm` requires. The float32 server state is
    broadcast to the clients and combined there with the int32 client data.
    The per-client results are averaged back onto the server as the new
    state, and a constant server-side output is emitted.

    Returns:
      An `IterativeProcess` compatible with `CanonicalForm`.
    """
    @computations.tf_computation(tf.int32, tf.float32)
    def add_two(x_int, y_float):
        # Cast the int32 client value so it can be added to the float32
        # broadcast server value.
        return tf.cast(x_int, tf.float32) + y_float

    @computations.federated_computation
    def init_fn():
        return intrinsics.federated_value(1.234, placements.SERVER)

    @computations.federated_computation([
        computation_types.FederatedType(tf.float32, placements.SERVER),
        computation_types.FederatedType(tf.int32, placements.CLIENTS)
    ])
    def next_fn(server_val, client_val):
        """Broadcasts, maps and averages; returns (new state, side output)."""
        broadcast_val = intrinsics.federated_broadcast(server_val)
        values_on_clients = intrinsics.federated_zip(
            (client_val, broadcast_val))
        result_on_clients = intrinsics.federated_map(add_two,
                                                     values_on_clients)
        aggregated_result = intrinsics.federated_mean(result_on_clients)
        side_output = intrinsics.federated_value([1, 2, 3, 4, 5],
                                                 placements.SERVER)
        return aggregated_result, side_output

    return computation_utils.IterativeProcess(init_fn, next_fn)
Example #6
0
  def test_returns_canonical_form_with_no_broadcast(self):
    """A next_fn that never broadcasts server state still yields a CanonicalForm."""
    server_type = computation_types.FederatedType(tf.bool, placements.SERVER)
    client_type = computation_types.FederatedType(tf.int32, placements.CLIENTS)

    @computations.tf_computation(tf.int32)
    @tf.function
    def map_fn(client_val):
      del client_val  # unused
      return 1

    @computations.federated_computation
    def init_fn():
      return intrinsics.federated_value(False, placements.SERVER)

    @computations.federated_computation(server_type, client_type)
    def next_fn(server_val, client_val):
      del server_val  # Unused
      # The new state is a constant; the client sum is the server output.
      client_total = intrinsics.federated_sum(
          intrinsics.federated_map(map_fn, client_val))
      new_state = intrinsics.federated_value(False, placements.SERVER)
      return new_state, client_total

    process = computation_utils.IterativeProcess(init_fn, next_fn)
    canonical = canonical_form_utils.get_canonical_form_for_iterative_process(
        process)
    self.assertIsInstance(canonical, canonical_form.CanonicalForm)
    def test_iterative_process_state_tuple_arg(self):
        """After feeding 0..9 through add_int32, the state equals their sum."""
        process = computation_utils.IterativeProcess(initialize, add_int32)
        num_rounds = 10
        state = process.initialize()
        for step_value in range(num_rounds):
            state = process.next(state, step_value)
        self.assertEqual(state, sum(range(num_rounds)))
  def test_iterative_process_state_multiple_return_values(self):
    """next() may return (state, extra); both are checked after 10 rounds."""
    process = computation_utils.IterativeProcess(initialize, add_mul_int32)
    num_rounds = 10
    state = process.initialize()
    for step_value in range(num_rounds):
      state, product = process.next(state, step_value)
    last_value = num_rounds - 1
    self.assertEqual(state, sum(range(num_rounds)))
    # Expected: the state before the final round times the final input.
    self.assertEqual(product, sum(range(last_value)) * last_value)
    def test_iterative_process_state_only(self):
        """A state-only next_fn (count_int32) ends at the number of rounds."""
        process = computation_utils.IterativeProcess(initialize, count_int32)
        num_rounds = 10
        state = process.initialize()
        for _ in range(num_rounds):
            # TODO(b/122321354): remove the .item() call on `state` once numpy.int32
            # type is supported.
            state = process.next(state.item())
        self.assertEqual(state, num_rounds)
 def test_tensor_computation_fails_well(self):
   """An identity Lambda over the init result type is rejected as next_fn."""
   example_cf = test_utils.get_temperature_sensor_example()
   process = canonical_form_utils.get_iterative_process_for_canonical_form(
       example_cf)
   state_type = process.initialize.type_signature.result
   identity = building_blocks.Lambda(
       'x', state_type, building_blocks.Reference('x', state_type))
   bad_process = computation_utils.IterativeProcess(
       process.initialize,
       computation_wrapper_instances.building_block_to_computation(identity))
   with self.assertRaisesRegex(TypeError,
                               'instances of `tff.NamedTupleType`.'):
     canonical_form_utils.get_canonical_form_for_iterative_process(bad_process)
def get_iterative_process_for_canonical_form(cf):
    """Creates `tff.utils.IterativeProcess` from a canonical form.

    The returned process's `initialize` places `cf.initialize()` on the
    server. Its `next` takes `(server_state, client_data)` and wires the `cf`
    computations into a single MapReduce round:

    - `prepare` runs on the server, and its result is broadcast.
    - `work` runs on each client.
    - Client results reach the server through `federated_aggregate` (using
      `zero`, `accumulate`, `merge`, `report`) and `federated_secure_sum`
      (using `bitwidth`).
    - `update` runs on the server.

    Args:
      cf: An instance of `tff.backends.mapreduce.CanonicalForm`.

    Returns:
      An instance of `tff.utils.IterativeProcess` that corresponds to `cf`.
      Its `next` returns `(new_server_state, server_output, client_output)`.

    Raises:
      TypeError: If the arguments are of the wrong types.
    """
    py_typecheck.check_type(cf, canonical_form.CanonicalForm)

    @computations.federated_computation
    def init_computation():
        return intrinsics.federated_value(cf.initialize(), placements.SERVER)

    # The client-side parameter type is the type of `cf.work`'s first input.
    @computations.federated_computation(
        init_computation.type_signature.result,
        computation_types.FederatedType(cf.work.type_signature.parameter[0],
                                        placements.CLIENTS))
    def next_computation(arg):
        """The logic of a single MapReduce processing round."""
        s1 = arg[0]  # server state
        c1 = arg[1]  # client data
        s2 = intrinsics.federated_map(cf.prepare, s1)
        c2 = intrinsics.federated_broadcast(s2)
        c3 = intrinsics.federated_zip([c1, c2])
        c4 = intrinsics.federated_map(cf.work, c3)
        # c4 = (client updates, client output). The client updates are split
        # into a part for federated_aggregate and a part for secure sum.
        c5 = c4[0]
        c6 = c5[0]
        c7 = c5[1]
        c8 = c4[1]
        s3 = intrinsics.federated_aggregate(c6, cf.zero(), cf.accumulate,
                                            cf.merge, cf.report)
        s4 = intrinsics.federated_secure_sum(c7, cf.bitwidth())
        s5 = intrinsics.federated_zip([s3, s4])
        # `update` consumes the previous server state and both aggregates.
        s6 = intrinsics.federated_zip([s1, s5])
        s7 = intrinsics.federated_map(cf.update, s6)
        s8 = s7[0]  # new server state
        s9 = s7[1]  # server output
        return s8, s9, c8

    return computation_utils.IterativeProcess(init_computation,
                                              next_computation)
Example #12
0
def get_unused_tf_computation_arg_iterative_process():
  """Returns an iterative process with a @tf.function with an unused arg.

  Both the client count and the server output are produced by
  `_bind_tf_function`. It wraps a zero-argument `tf.function` in a
  `tf_computation` that accepts, but never reads, one argument, and maps that
  computation over the given federated value.

  Returns:
    A `computation_utils.IterativeProcess` whose server state is an
    `OrderedDict` with a single `num_clients` int32 entry.
  """
  server_state_type = computation_types.NamedTupleType([('num_clients',
                                                         tf.int32)])

  def _bind_tf_function(unused_input, tf_func):
    """Maps zero-arg `tf_func` over `unused_input`, ignoring its value."""
    # The lambda discards its single argument, so the resulting computation
    # has a parameter that its body never reads.
    tf_wrapper = tf.function(lambda _: tf_func())
    input_federated_type = unused_input.type_signature
    wrapper = computations.tf_computation(tf_wrapper,
                                          input_federated_type.member)
    return intrinsics.federated_map(wrapper, unused_input)

  def count_clients_federated(client_data):
    """Sums a constant int32 1 contributed by every client."""

    @tf.function
    def client_ones_fn():
      return tf.ones(shape=[], dtype=tf.int32)

    client_ones = _bind_tf_function(client_data, client_ones_fn)
    return intrinsics.federated_sum(client_ones)

  @computations.federated_computation
  def init_fn():
    return intrinsics.federated_value(
        collections.OrderedDict([('num_clients', 0)]), placements.SERVER)

  @computations.federated_computation([
      computation_types.FederatedType(server_state_type, placements.SERVER),
      computation_types.FederatedType(
          computation_types.SequenceType(tf.string), placements.CLIENTS)
  ])
  def next_fn(server_state, client_val):
    """`next` function for `computation_utils.IterativeProcess`."""
    server_update = intrinsics.federated_zip(
        collections.OrderedDict([('num_clients',
                                  count_clients_federated(client_val))]))

    # Server output: the sum of a per-client `tf.timestamp`, computed by a
    # function that ignores the broadcast server state it is mapped over.
    # (A dead `federated_value((), SERVER)` assignment that was immediately
    # overwritten here has been removed.)
    server_output = intrinsics.federated_sum(
        _bind_tf_function(
            intrinsics.federated_broadcast(server_state), tf.timestamp))

    return server_update, server_output

  return computation_utils.IterativeProcess(init_fn, next_fn)
Example #13
0
def get_unused_lambda_arg_iterative_process():
    """Returns an iterative process having a Lambda not referencing its arg.

    `_bind_federated_value` wraps an already-built federated value in a
    `federated_computation` whose lambda ignores its parameter, then calls
    that computation on `unused_input`. The resulting Lambda's body does not
    reference its own argument. The pattern is used for both the client count
    and the server output.

    Returns:
      A `computation_utils.IterativeProcess` whose server state is an
      `OrderedDict` with a single `num_clients` int32 entry.
    """
    server_state_type = computation_types.NamedTupleType([('num_clients',
                                                           tf.int32)])

    def _bind_federated_value(unused_input, input_type,
                              federated_output_value):
        """Calls a computation on `unused_input` that just returns
        `federated_output_value`, which was built outside the computation."""
        federated_input_type = computation_types.FederatedType(
            input_type, placements.CLIENTS)
        # The lambda captures `federated_output_value` from the enclosing
        # scope and never uses its own argument.
        wrapper = computations.federated_computation(
            lambda _: federated_output_value, federated_input_type)
        return wrapper(unused_input)

    def count_clients_federated(client_data):
        """Sums a constant 1 per client, routed through an unused-arg Lambda."""
        client_ones = intrinsics.federated_value(1, placements.CLIENTS)

        client_ones = _bind_federated_value(
            client_data, computation_types.SequenceType(tf.string),
            client_ones)
        return intrinsics.federated_sum(client_ones)

    @computations.federated_computation
    def init_fn():
        return intrinsics.federated_value(
            collections.OrderedDict([('num_clients', 0)]), placements.SERVER)

    @computations.federated_computation([
        computation_types.FederatedType(server_state_type, placements.SERVER),
        computation_types.FederatedType(
            computation_types.SequenceType(tf.string), placements.CLIENTS)
    ])
    def next_fn(server_state, client_val):
        """`next` function for `computation_utils.IterativeProcess`."""
        server_update = intrinsics.federated_zip(
            collections.OrderedDict([('num_clients',
                                      count_clients_federated(client_val))]))

        # The empty server value is passed through an unused-arg Lambda whose
        # (ignored) argument is the broadcast server state.
        server_output = intrinsics.federated_value((), placements.SERVER)
        server_output = _bind_federated_value(
            intrinsics.federated_broadcast(server_state), server_state_type,
            server_output)

        return server_update, server_output

    return computation_utils.IterativeProcess(init_fn, next_fn)
Example #14
0
def get_iterative_process_for_canonical_form(cf):
    """Creates `tff.utils.IterativeProcess` from a canonical form.

    This variant runs the server-side `prepare` and `update` steps with
    `tff.federated_apply`. It aggregates the clients' first `work` output
    with `tff.federated_aggregate` only; there is no secure-sum component.

    Args:
      cf: An instance of `tff.backends.mapreduce.CanonicalForm`.

    Returns:
      An instance of `tff.utils.IterativeProcess` that corresponds to `cf`.
      Its `next` returns `(new_server_state, server_output, client_output)`.

    Raises:
      TypeError: If the arguments are of the wrong types.
    """
    py_typecheck.check_type(cf, canonical_form.CanonicalForm)

    @tff.federated_computation
    def init_computation():
        return tff.federated_value(cf.initialize(), tff.SERVER)

    # The client-side parameter type is the type of `cf.work`'s first input.
    @tff.federated_computation(init_computation.type_signature.result,
                               tff.FederatedType(
                                   cf.work.type_signature.parameter[0],
                                   tff.CLIENTS))
    def next_computation(arg):
        """The logic of a single MapReduce processing round."""
        s1 = arg[0]  # server state
        c1 = arg[1]  # client data
        s2 = tff.federated_apply(cf.prepare, s1)
        c2 = tff.federated_broadcast(s2)
        c3 = tff.federated_zip([c1, c2])
        c4 = tff.federated_map(cf.work, c3)
        c5 = c4[0]  # per-client update, aggregated onto the server below
        c6 = c4[1]  # per-client output, returned unchanged
        s3 = tff.federated_aggregate(c5, cf.zero(), cf.accumulate, cf.merge,
                                     cf.report)
        # `update` consumes the previous server state and the aggregate.
        s4 = tff.federated_zip([s1, s3])
        s5 = tff.federated_apply(cf.update, s4)
        s6 = s5[0]  # new server state
        s7 = s5[1]  # server output
        return s6, s7, c6

    return computation_utils.IterativeProcess(init_computation,
                                              next_computation)
Example #15
0
def get_iterative_process_for_sum_example_with_no_aggregation():
    """Returns an iterative process for a sum example without aggregation.

    The clients run `work`, but its update output is discarded. The two
    values fed to `update` are server-side constants (`1`, `1`) rather than
    aggregates of client values. `update` returns them as the new state and
    an empty server output, and `work`'s (empty) second output is returned
    as the client output.
    """
    @computations.federated_computation
    def init_fn():
        """The `init` function for `computation_utils.IterativeProcess`."""
        return intrinsics.federated_value([0, 0], placements.SERVER)

    @computations.tf_computation([tf.int32, tf.int32])
    def prepare(server_state):
        # Identity: the server state is broadcast unchanged.
        return server_state

    @computations.tf_computation(tf.int32, [tf.int32, tf.int32])
    def work(client_data, client_input):
        del client_data  # Unused
        del client_input  # Unused
        return [1, 1], []

    @computations.tf_computation([tf.int32, tf.int32], [tf.int32, tf.int32])
    def update(server_state, global_update):
        del server_state  # Unused
        return global_update, []

    @computations.federated_computation([
        computation_types.FederatedType([tf.int32, tf.int32],
                                        placements.SERVER),
        computation_types.FederatedType(tf.int32, placements.CLIENTS),
    ])
    def next_fn(server_state, client_data):
        """The `next` function for `computation_utils.IterativeProcess`."""
        s2 = intrinsics.federated_map(prepare, server_state)
        client_input = intrinsics.federated_broadcast(s2)
        c3 = intrinsics.federated_zip([client_data, client_input])
        # The per-client update (first output of `work`) is intentionally
        # dropped; nothing is aggregated from the clients.
        _, client_output = intrinsics.federated_map(work, c3)
        unsecure_update = intrinsics.federated_value(1, placements.SERVER)
        secure_update = intrinsics.federated_value(1, placements.SERVER)
        s6 = intrinsics.federated_zip(
            [server_state, [unsecure_update, secure_update]])
        new_server_state, server_output = intrinsics.federated_map(update, s6)
        return new_server_state, server_output, client_output

    return computation_utils.IterativeProcess(init_fn, next_fn)
Example #16
0
    def test_returns_canonical_form_with_next_fn_returning_call_directly(self):
        """A next_fn whose body only calls another computation still converts."""
        server_type = computation_types.FederatedType(
            tf.int32, placements.SERVER)
        client_type = computation_types.FederatedType(
            computation_types.SequenceType(tf.float32), placements.CLIENTS)

        @computations.federated_computation
        def init_fn():
            return intrinsics.federated_value(42, placements.SERVER)

        @computations.federated_computation(server_type, client_type)
        def next_fn(server_state, client_data):
            broadcast_state = intrinsics.federated_broadcast(server_state)

            @computations.tf_computation(
                tf.int32, computation_types.SequenceType(tf.float32))
            @tf.function
            def some_transform(x, y):
                del y  # Unused
                return x + 1

            client_update = intrinsics.federated_map(
                some_transform, (broadcast_state, client_data))
            aggregate_update = intrinsics.federated_sum(client_update)
            server_output = intrinsics.federated_value(1234, placements.SERVER)
            return aggregate_update, server_output

        @computations.federated_computation(server_type, client_type)
        def nested_next_fn(server_state, client_data):
            # The entire body is a single call into next_fn.
            return next_fn(server_state, client_data)

        process = computation_utils.IterativeProcess(init_fn, nested_next_fn)
        canonical = canonical_form_utils.get_canonical_form_for_iterative_process(
            process)
        self.assertIsInstance(canonical, canonical_form.CanonicalForm)
  def test_broadcast_dependent_on_aggregate_fails_well(self):
    """Two chained rounds of `next` inside one computation cannot be reduced.

    The second round's broadcast depends on the first round's aggregate, so
    conversion to CanonicalForm is expected to raise a ValueError.
    """
    example_cf = test_utils.get_temperature_sensor_example()
    process = canonical_form_utils.get_iterative_process_for_canonical_form(
        example_cf)
    next_block = test_utils.computation_to_building_block(process.next)
    param_name = next_block.parameter_name
    param_type = next_block.parameter_type
    outer_param = building_blocks.Reference(param_name, param_type)
    # Round one: apply next to the original argument.
    round_one = building_blocks.Call(next_block, outer_param)
    # Round two: round one's first result paired with the original element 1.
    round_two_arg = building_blocks.Tuple([
        building_blocks.Selection(round_one, index=0),
        building_blocks.Selection(outer_param, index=1)
    ])
    round_two = building_blocks.Call(next_block, round_two_arg)
    two_rounds = building_blocks.Lambda(param_name, param_type, round_two)
    two_round_process = computation_utils.IterativeProcess(
        process.initialize,
        computation_wrapper_instances.building_block_to_computation(
            two_rounds))

    with self.assertRaisesRegex(ValueError, 'broadcast dependent on aggregate'):
      canonical_form_utils.get_canonical_form_for_iterative_process(
          two_round_process)
 def test_iterative_process_next_bad_type(self):
     """A next_fn that is not a Computation (None) raises TypeError."""
     with self.assertRaisesRegex(TypeError, r'Expected .*\.Computation, .*'):
         computation_utils.IterativeProcess(
             initialize_fn=initialize, next_fn=None)