示例#1
0
  def test_constructor_with_type_mismatch(self):
    """Invalid `next_fn` type signatures are rejected by the constructor."""
    # Case 1: `next_fn` takes float32 state, which the test expects not to
    # match the result type of `initialize` (defined elsewhere in this file).
    with self.assertRaises(
        iterative_process.NextMustAcceptStateFromInitializeError):

      @computations.federated_computation(tf.float32, tf.float32)
      def float_state_next(current, val):
        return current + val

      iterative_process.IterativeProcess(initialize, float_state_next)

    # Case 2: `next_fn` accepts int32 state but returns a float.
    with self.assertRaises(iterative_process.NextMustReturnStateError):

      @computations.federated_computation(tf.int32)
      def float_result_next(_):
        return 0.0

      iterative_process.IterativeProcess(initialize, float_result_next)

    # Case 3: like case 2, but the float is the first of several results.
    with self.assertRaises(iterative_process.NextMustReturnStateError):

      @computations.federated_computation(tf.int32)
      def float_first_result_next(_):
        return 0.0, 0

      iterative_process.IterativeProcess(initialize, float_first_result_next)
示例#2
0
    def test_constructor_with_type_mismatch(self):
        """Mismatched `initialize_fn`/`next_fn` types raise a `TypeError`."""
        initialize_mismatch_regex = (
            r'The return type of initialize_fn should match.*')
        next_mismatch_regex = (
            'The return type of next_fn should match the first parameter')

        # `next_fn` takes float32 state, which the test expects not to match
        # the result of `initialize` (defined elsewhere in this file).
        with self.assertRaisesRegex(TypeError, initialize_mismatch_regex):

            @computations.federated_computation([tf.float32, tf.float32])
            def add_float32(current, val):
                return current + val

            iterative_process.IterativeProcess(initialize, add_float32)

        # `next_fn` takes int32 state but returns a float.
        with self.assertRaisesRegex(TypeError, next_mismatch_regex):

            @computations.federated_computation(tf.int32)
            def add_bad_result(_):
                return 0.0

            iterative_process.IterativeProcess(initialize, add_bad_result)

        # Same as above, with the float as the first of two results.
        with self.assertRaisesRegex(TypeError, next_mismatch_regex):

            @computations.federated_computation(tf.int32)
            def add_bad_multi_result(_):
                return 0.0, 0

            iterative_process.IterativeProcess(initialize,
                                               add_bad_multi_result)
示例#3
0
  def test_constructor_with_initialize_bad_type(self):
    """The constructor validates the kind and arity of `initialize_fn`."""
    # A non-computation `initialize_fn` is rejected with a TypeError.
    with self.assertRaisesRegex(TypeError, r'Expected .*\.Computation, .*'):
      iterative_process.IterativeProcess(None, add_int32)

    # An `initialize_fn` that itself takes an argument is rejected.
    with self.assertRaises(iterative_process.InitializeFnHasArgsError):

      @computations.federated_computation(tf.int32)
      def one_arg_initialize(one_arg):
        del one_arg  # Unused.
        return values.to_value(0)

      iterative_process.IterativeProcess(one_arg_initialize, add_int32)
示例#4
0
def get_iterative_process_for_sum_example_with_no_aggregation():
    """Returns an iterative process for a sum example.

    This iterative process does not have a call to `federated_aggregate` or
    `federated_secure_sum` and as a result it should fail to compile to
    `canonical_form.CanonicalForm`.

    Returns:
      A `tff.templates.IterativeProcess` whose state is a server-placed pair of
      int32 values, initialized to `[0, 0]`.
    """
    @computations.federated_computation
    def init_fn():
        """The `init` function for `tff.templates.IterativeProcess`."""
        return intrinsics.federated_value([0, 0], placements.SERVER)

    @computations.tf_computation([tf.int32, tf.int32], [tf.int32, tf.int32])
    def update(server_state, global_update):
        # The new state is `global_update` itself; the server output is an
        # empty structure.
        del server_state  # Unused
        return global_update, []

    @computations.federated_computation([
        computation_types.FederatedType([tf.int32, tf.int32],
                                        placements.SERVER),
        computation_types.FederatedType(tf.int32, placements.CLIENTS),
    ])
    def next_fn(server_state, client_data):
        """The `next` function for `tff.templates.IterativeProcess`."""
        # Client data is ignored, so nothing flows from clients to server.
        del client_data
        # No call to `federated_aggregate`.
        unsecure_update = intrinsics.federated_value(1, placements.SERVER)
        # No call to `federated_secure_sum`.
        secure_update = intrinsics.federated_value(1, placements.SERVER)
        # Pack (state, (unsecure, secure)) so a single server-side map can
        # compute the new state.
        s6 = intrinsics.federated_zip(
            [server_state, [unsecure_update, secure_update]])
        new_server_state, server_output = intrinsics.federated_map(update, s6)
        return new_server_state, server_output

    return iterative_process.IterativeProcess(init_fn, next_fn)
示例#5
0
    def test_returns_iterproc_accepting_dataset_in_third_index_of_next(self):
        """Composition works when the dataset is the third `next` argument."""
        base_process = (
            _create_stateless_int_dataset_reduction_iterative_process())
        base_params = base_process.next.type_signature.parameter

        # Insert an extra int32 argument between the state and the dataset, so
        # the dataset moves from index 1 to index 2.
        @computations.federated_computation(
            computation_types.StructType(
                [base_params[0], tf.int32, base_params[1]]))
        def new_next(param):
            # Drop the inserted int32 and forward (state, dataset).
            return base_process.next([param[0], param[2]])

        process_with_dataset_third = iterative_process.IterativeProcess(
            base_process.initialize, new_next)
        composed = iterative_process_compositions.compose_dataset_computation_with_iterative_process(
            int_dataset_computation, process_with_dataset_third)

        expected_next_type = computation_types.FunctionType([
            computation_types.FederatedType(tf.int64, placements.SERVER),
            tf.int32,
            computation_types.FederatedType(tf.string, placements.CLIENTS)
        ], computation_types.FederatedType(tf.int64, placements.SERVER))
        self.assertTrue(
            expected_next_type.is_equivalent_to(composed.next.type_signature))
 def test_construction_with_empty_state_does_not_raise(self):
   """Constructing an IterativeProcess whose state is `()` does not raise."""
   initialize_fn = computations.tf_computation()(lambda: ())
   next_fn = computations.tf_computation(())(lambda x: (x, 1.0))
   # Catch `Exception` rather than using a bare `except:`, so that
   # KeyboardInterrupt/SystemExit still propagate. The original error stays
   # attached as the implicit context of the failure.
   try:
     iterative_process.IterativeProcess(initialize_fn, next_fn)
   except Exception:  # pylint: disable=broad-except
     self.fail('Could not construct an IterativeProcess with empty state.')
示例#7
0
def _create_stateless_int_vector_unknown_dim_dataset_reduction_iterative_process(
):
    """Builds a process summing client int64 vectors of unknown length.

    Tests handling client data of unknown shape and summing to fixed shape:
    each client dataset element is an int64 vector whose length is not known
    statically, and every dataset is reduced to a one-element vector.
    """

    @computations.tf_computation()
    def make_zero():
        zero = tf.cast(0, tf.int64)
        return tf.reshape(zero, shape=[1])

    @computations.federated_computation()
    def init():
        return intrinsics.federated_eval(make_zero, placements.SERVER)

    element_type = computation_types.TensorType(tf.int64, shape=[None])

    @computations.tf_computation(computation_types.SequenceType(element_type))
    def reduce_dataset(x):
        initial = tf.cast(tf.constant([0]), tf.int64)
        return x.reduce(initial, lambda acc, elem: acc + tf.reduce_sum(elem))

    @computations.federated_computation(
        computation_types.FederatedType(element_type, placements.SERVER),
        computation_types.FederatedType(
            computation_types.SequenceType(element_type), placements.CLIENTS))
    def next_fn(server_state, client_data):
        del server_state  # Unused
        client_sums = intrinsics.federated_map(reduce_dataset, client_data)
        return intrinsics.federated_sum(client_sums)

    return iterative_process.IterativeProcess(init, next_fn)
 def test_federated_init_state_not_assignable(self):
     """A clients-placed `next` parameter cannot accept server-placed state."""

     @federated_computation.federated_computation()
     def initialize_fn():
         return intrinsics.federated_value(0, placements.SERVER)

     @federated_computation.federated_computation(
         FederatedType(tf.int32, placements.CLIENTS))
     def next_fn(state):
         return state

     with self.assertRaises(errors.TemplateStateNotAssignableError):
         iterative_process.IterativeProcess(initialize_fn, next_fn)
def _create_stateless_int_dataset_reduction_iterative_process():
    """Builds a process whose `next` sums int64 client datasets at the server.

    The incoming server state is ignored. The new state is the federated sum
    of each client's dataset total.
    """

    @computations.tf_computation()
    def make_zero():
        return tf.cast(0, tf.int64)

    @computations.federated_computation()
    def init():
        return intrinsics.federated_eval(make_zero, placement_literals.SERVER)

    dataset_type = computation_types.SequenceType(tf.int64)

    @computations.tf_computation(dataset_type)
    def reduce_dataset(x):
        return x.reduce(tf.cast(0, tf.int64), lambda total, elem: total + elem)

    client_data_type = computation_types.FederatedType(
        dataset_type, placement_literals.CLIENTS)

    # NOTE(review): despite its name, `empty_tup` receives the int64 server
    # state produced by `init`. It is not renamed because parameter names may
    # become part of the computation's type signature.
    @computations.federated_computation(
        (init.type_signature.result, client_data_type))
    def next_fn(empty_tup, x):
        del empty_tup  # Unused
        per_client_totals = intrinsics.federated_map(reduce_dataset, x)
        return intrinsics.federated_sum(per_client_totals)

    return iterative_process.IterativeProcess(init, next_fn)
示例#10
0
    def test_disallows_broadcast_dependent_on_aggregate(self):
        """A broadcast fed by an aggregate is incompatible with MapReduce form.

        `next_comp` sums a clients-placed value to the server, broadcasts that
        sum back to the clients and sums it again. The compatibility check is
        expected to reject this with a `ValueError`.
        """
        @federated_computation.federated_computation
        def init_comp():
            # Server-placed integer state, initialized to 0.
            return intrinsics.federated_value(0, placements.SERVER)

        @federated_computation.federated_computation(
            computation_types.at_server(tf.int32),
            computation_types.at_clients(()))
        def next_comp(server_state, client_data):
            # Inputs are ignored; only the shape of the traced graph matters.
            del server_state, client_data
            client_val = intrinsics.federated_value(0, placements.CLIENTS)
            server_agg = intrinsics.federated_sum(client_val)
            # This broadcast is dependent on the result of the above aggregation,
            # which is not supported by MapReduce form.
            broadcasted = intrinsics.federated_broadcast(server_agg)
            server_agg_again = intrinsics.federated_sum(broadcasted)
            # `next` must return two values (new state, server output).
            return server_agg_again, intrinsics.federated_value(
                (), placements.SERVER)

        ip = iterative_process.IterativeProcess(init_comp, next_comp)

        with self.assertRaises(ValueError):
            form_utils.check_iterative_process_compatible_with_map_reduce_form(
                ip)
 def test_next_state_not_assignable_tuple_result(self):
     """A `next` returning (float32 state, x) is rejected as not assignable.

     NOTE(review): assumes `test_initialize_fn` (defined elsewhere) yields a
     state type that float32 is not assignable to; confirm against it.
     """

     @computations.tf_computation(tf.float32, tf.float32)
     def float_next_fn(state, x):
         return tf.cast(state, tf.float32), x

     with self.assertRaises(errors.TemplateStateNotAssignableError):
         iterative_process.IterativeProcess(test_initialize_fn, float_next_fn)
示例#12
0
    def get_map_reduce_form_for_client_to_server_fn(
            self, client_to_server_fn) -> forms.MapReduceForm:
        """Produces a `MapReduceForm` for the provided `client_to_server_fn`.

        Wraps `client_to_server_fn` in an `iterative_process.IterativeProcess`.
        Its state is an empty server-placed tuple that is passed through
        unchanged. Its `next` maps int32 `client_data` to `server_output` with
        `client_to_server_fn`. The process is then converted with
        `form_utils.get_map_reduce_form_for_iterative_process`.

        Args:
          client_to_server_fn: A function from client-placed data to
            server-placed output.

        Returns:
          A `forms.MapReduceForm` which uses the embedded `client_to_server_fn`.
        """

        @federated_computation.federated_computation
        def init_fn():
            return intrinsics.federated_value((), placements.SERVER)

        @federated_computation.federated_computation([
            computation_types.at_server(()),
            computation_types.at_clients(tf.int32),
        ])
        def next_fn(server_state, client_data):
            # State passes through untouched; only the output uses client data.
            return server_state, client_to_server_fn(client_data)

        process = iterative_process.IterativeProcess(init_fn, next_fn)
        return form_utils.get_map_reduce_form_for_iterative_process(process)
 def test_federated_next_state_not_assignable(self):
     """A `next` whose result is clients-placed state is rejected."""
     initialize_fn = computations.federated_computation()(
         lambda: intrinsics.federated_value(0, placements.SERVER))
     state_type = initialize_fn.type_signature.result
     # `federated_broadcast` moves the server state to the clients, so the
     # result cannot be assigned back to the server-placed state.
     broadcast_next_fn = computations.federated_computation(state_type)(
         intrinsics.federated_broadcast)
     with self.assertRaises(errors.TemplateStateNotAssignableError):
         iterative_process.IterativeProcess(initialize_fn, broadcast_next_fn)
示例#14
0
    def test_constructor_with_state_tuple_arg(self):
        """`next` may take the state plus an extra per-round argument."""
        process = iterative_process.IterativeProcess(initialize, add_int32)

        num_rounds = 10
        state = process.initialize()
        for round_value in range(num_rounds):
            state = process.next(state, round_value)

        expected_total = sum(range(num_rounds))
        self.assertEqual(state, expected_total)
示例#15
0
  def test_constructor_with_empty_tuple(self):
    """A process with empty state can be initialized and iterated."""
    process = iterative_process.IterativeProcess(initialize_empty_tuple,
                                                 next_empty_tuple)
    num_rounds = 2

    state = process.initialize()
    for _ in range(num_rounds):
      state = process.next(state)

    self.assertEqual(state, [])
示例#16
0
  def test_constructor_with_state_multiple_return_values(self):
    """`next` may return extra values alongside the updated state."""
    process = iterative_process.IterativeProcess(initialize, add_mul_int32)
    num_rounds = 10

    state = process.initialize()
    for round_value in range(num_rounds):
      state, product = process.next(state, round_value)

    # `product` holds the value from the final round only.
    last_value = num_rounds - 1
    self.assertEqual(state, sum(range(num_rounds)))
    self.assertEqual(product, sum(range(last_value)) * last_value)
示例#17
0
def get_iterative_process_with_nested_broadcasts():
    """Returns an iterative process with nested federated broadcasts.

    This iterative process contains all the components required to compile to
    `forms.MapReduceForm`.

    Returns:
      A `tff.templates.IterativeProcess` whose state is a server-placed pair of
      int32 values and whose `next` takes int32 client data.
    """
    @federated_computation.federated_computation
    def init_fn():
        """The `init` function for `tff.templates.IterativeProcess`."""
        return intrinsics.federated_value([0, 0], placements.SERVER)

    @tensorflow_computation.tf_computation([tf.int32, tf.int32])
    def prepare(server_state):
        # Identity: the broadcast payload is the server state itself.
        return server_state

    @tensorflow_computation.tf_computation(tf.int32, [tf.int32, tf.int32])
    def work(client_data, client_input):
        # Ignores its inputs and emits constant (unsecure, secure) updates.
        del client_data  # Unused
        del client_input  # Unused
        return 1, 1

    @tensorflow_computation.tf_computation([tf.int32, tf.int32],
                                           [tf.int32, tf.int32])
    def update(server_state, global_update):
        # The new state is `global_update`; the server output is empty.
        del server_state  # Unused
        return global_update, []

    @federated_computation.federated_computation(
        computation_types.FederatedType([tf.int32, tf.int32],
                                        placements.SERVER))
    def broadcast_and_return_arg_and_result(x):
        # Returns [broadcast of x, x]; the broadcast here is the "nested" one
        # that sits inside a separately traced federated computation.
        broadcasted = intrinsics.federated_broadcast(x)
        return [broadcasted, x]

    @federated_computation.federated_computation([
        computation_types.FederatedType([tf.int32, tf.int32],
                                        placements.SERVER),
        computation_types.FederatedType(tf.int32, placements.CLIENTS),
    ])
    def next_fn(server_state, client_data):
        """The `next` function for `tff.templates.IterativeProcess`."""
        s2 = intrinsics.federated_map(prepare, server_state)
        # The nested broadcast's result is discarded; only the passed-through
        # server value is broadcast again below.
        unused_client_input, to_broadcast = broadcast_and_return_arg_and_result(
            s2)
        client_input = intrinsics.federated_broadcast(to_broadcast)
        c3 = intrinsics.federated_zip([client_data, client_input])
        client_updates = intrinsics.federated_map(work, c3)
        unsecure_update = intrinsics.federated_sum(client_updates[0])
        # 8 is the bitwidth argument to the secure sum.
        secure_update = intrinsics.federated_secure_sum_bitwidth(
            client_updates[1], 8)
        s6 = intrinsics.federated_zip(
            [server_state, [unsecure_update, secure_update]])
        new_server_state, server_output = intrinsics.federated_map(update, s6)
        return new_server_state, server_output

    return iterative_process.IterativeProcess(init_fn, next_fn)
def _create_test_iterative_process(state_type, state_init):
  """Builds an iterative process whose `next` returns its state unchanged.

  Args:
    state_type: The TFF type of the state parameter accepted by `next_fn`.
    state_init: Value passed to `tf.constant` to produce the initial state.

  Returns:
    An `iterative_process.IterativeProcess` with an identity `next_fn`.
  """

  @tensorflow_computation.tf_computation(state_type)
  def next_fn(state):
    return state

  @tensorflow_computation.tf_computation
  def initialize_fn():
    return tf.constant(state_init)

  return iterative_process.IterativeProcess(
      initialize_fn=initialize_fn, next_fn=next_fn)
示例#19
0
    def test_constructor_with_state_only(self):
        """`next` may take the state as its only argument."""
        process = iterative_process.IterativeProcess(initialize, count_int32)

        num_rounds = 10
        state = process.initialize()
        for _ in range(num_rounds):
            # TODO(b/122321354): remove the .item() call on `state` once
            # numpy.int32 type is supported.
            state = process.next(state.item())

        self.assertEqual(state, num_rounds)
示例#20
0
 def test_next_computation_returning_tensor_fails_well(self):
     """An identity `next` over the state fails MapReduce form conversion."""
     map_reduce_form = mapreduce_test_utils.get_temperature_sensor_example()
     process = form_utils.get_iterative_process_for_map_reduce_form(
         map_reduce_form)
     state_type = process.initialize.type_signature.result
     # Replace `next` with an identity lambda over the state (no client
     # input). The conversion below is expected to reject it with TypeError.
     identity = building_blocks.Lambda(
         'x', state_type, building_blocks.Reference('x', state_type))
     bad_process = iterative_process.IterativeProcess(
         process.initialize,
         computation_impl.ConcreteComputation.from_building_block(identity))
     with self.assertRaises(TypeError):
         form_utils.get_map_reduce_form_for_iterative_process(bad_process)
 def test_next_computation_returning_tensor_fails_well(self):
   """An identity `next` over the state fails canonical form conversion."""
   canonical = test_utils.get_temperature_sensor_example()
   process = canonical_form_utils.get_iterative_process_for_canonical_form(
       canonical)
   state_type = process.initialize.type_signature.result
   # Replace `next` with an identity lambda over the state (no client input).
   identity = building_blocks.Lambda(
       'x', state_type, building_blocks.Reference('x', state_type))
   bad_process = iterative_process.IterativeProcess(
       process.initialize,
       computation_wrapper_instances.building_block_to_computation(identity))
   with self.assertRaises(TypeError):
     canonical_form_utils.get_canonical_form_for_iterative_process(bad_process)
示例#22
0
def _create_whimsy_iterative_process():
    """Builds a trivial process: empty state and an identity `next`."""

    @computations.tf_computation()
    def init():
        return []

    state_type = init.type_signature.result

    @computations.tf_computation(state_type)
    def next_fn(x):
        return x

    return iterative_process.IterativeProcess(init, next_fn)
示例#23
0
 def test_raises_on_invalid_distributor(self):
     """FedAvg rejects a plain IterativeProcess as `model_distributor`."""
     linear_model_weights = model_utils.ModelWeights.from_model(
         model_examples.LinearRegression())
     model_weights_type = type_conversions.type_from_tensors(
         linear_model_weights)
     broadcast_process = distributors.build_broadcast_process(
         model_weights_type)
     # Re-wrap the distributor's computations in a plain IterativeProcess.
     # Presumably the builder requires the distributor-specific process type;
     # the test expects a TypeError.
     invalid_distributor = iterative_process.IterativeProcess(
         broadcast_process.initialize, broadcast_process.next)
     with self.assertRaises(TypeError):
         fed_avg.build_weighted_fed_avg(
             model_fn=model_examples.LinearRegression,
             client_optimizer_fn=sgdm.build_sgdm(1.0),
             model_distributor=invalid_distributor)
示例#24
0
def get_iterative_process_for_map_reduce_form(
        mrf: forms.MapReduceForm) -> iterative_process.IterativeProcess:
    """Creates `tff.templates.IterativeProcess` from a MapReduce form.

    Args:
      mrf: An instance of `tff.backends.mapreduce.MapReduceForm`.

    Returns:
      An instance of `tff.templates.IterativeProcess` that corresponds to `mrf`.

    Raises:
      TypeError: If the arguments are of the wrong types.
    """
    py_typecheck.check_type(mrf, forms.MapReduceForm)

    @federated_computation.federated_computation
    def init_computation():
        # Places the result of `mrf.initialize()` at the server.
        return intrinsics.federated_value(mrf.initialize(), placements.SERVER)

    # `next` takes a named 2-element struct: the server state (labelled by
    # `mrf.server_state_label`) and clients-placed data typed as the first
    # parameter of `mrf.work` (labelled by `mrf.client_data_label`).
    next_parameter_type = computation_types.StructType([
        (mrf.server_state_label, init_computation.type_signature.result),
        (mrf.client_data_label,
         computation_types.FederatedType(mrf.work.type_signature.parameter[0],
                                         placements.CLIENTS)),
    ])

    @federated_computation.federated_computation(next_parameter_type)
    def next_computation(arg):
        """The logic of a single MapReduce processing round."""
        server_state, client_data = arg
        # Server -> clients: prepare the broadcast payload and broadcast it.
        broadcast_input = intrinsics.federated_map(mrf.prepare, server_state)
        broadcast_result = intrinsics.federated_broadcast(broadcast_input)
        # Clients: run `work` on (client data, broadcast payload). Its output
        # is unpacked into one input per aggregation kind below.
        work_arg = intrinsics.federated_zip([client_data, broadcast_result])
        (aggregate_input, secure_sum_bitwidth_input, secure_sum_input,
         secure_modular_sum_input) = intrinsics.federated_map(
             mrf.work, work_arg)
        # Clients -> server: one general aggregation and three secure sums,
        # each parameterized by values supplied by `mrf`.
        aggregate_result = intrinsics.federated_aggregate(
            aggregate_input, mrf.zero(), mrf.accumulate, mrf.merge, mrf.report)
        secure_sum_bitwidth_result = intrinsics.federated_secure_sum_bitwidth(
            secure_sum_bitwidth_input, mrf.secure_sum_bitwidth())
        secure_sum_result = intrinsics.federated_secure_sum(
            secure_sum_input, mrf.secure_sum_max_input())
        secure_modular_sum_result = intrinsics.federated_secure_modular_sum(
            secure_modular_sum_input, mrf.secure_modular_sum_modulus())
        # Server: combine the old state with all aggregation results in `update`.
        update_arg = intrinsics.federated_zip(
            (server_state, (aggregate_result, secure_sum_bitwidth_result,
                            secure_sum_result, secure_modular_sum_result)))
        updated_server_state, server_output = intrinsics.federated_map(
            mrf.update, update_arg)
        return updated_server_state, server_output

    return iterative_process.IterativeProcess(init_computation,
                                              next_computation)
示例#25
0
 def test_raises_on_invalid_distributor(self):
     """MIME-lite rejects a plain IterativeProcess as `model_distributor`."""
     linear_model_weights = model_utils.ModelWeights.from_model(
         model_examples.LinearRegression())
     model_weights_type = type_conversions.type_from_tensors(
         linear_model_weights)
     broadcast_process = distributors.build_broadcast_process(
         model_weights_type)
     # Re-wrap the distributor's computations in a plain IterativeProcess.
     # Presumably the builder requires the distributor-specific process type;
     # the test expects a TypeError.
     invalid_distributor = iterative_process.IterativeProcess(
         broadcast_process.initialize, broadcast_process.next)
     with self.assertRaises(TypeError):
         mime.build_weighted_mime_lite(
             model_fn=model_examples.LinearRegression,
             base_optimizer=sgdm.build_sgdm(learning_rate=0.01,
                                            momentum=0.9),
             model_distributor=invalid_distributor)
  def test_construction_with_unknown_dimension_does_not_raise(self):
    """Parameter types with a statically unknown shape are accepted."""
    initialize_fn = computations.tf_computation()(
        lambda: tf.constant([], dtype=tf.string))

    @computations.tf_computation(
        computation_types.TensorType(shape=[None], dtype=tf.string))
    def next_fn(strings):
      return tf.concat([strings, tf.constant(['abc'])], axis=0)

    # Catch `Exception` rather than using a bare `except:`, so that
    # KeyboardInterrupt/SystemExit still propagate. The original error stays
    # attached as the implicit context of the failure.
    try:
      iterative_process.IterativeProcess(initialize_fn, next_fn)
    except Exception:  # pylint: disable=broad-except
      self.fail('Could not construct an IterativeProcess with parameter types '
                'with statically unknown shape.')
示例#27
0
def get_iterative_process_for_map_reduce_form(
        mrf: forms.MapReduceForm) -> iterative_process.IterativeProcess:
    """Creates `tff.templates.IterativeProcess` from a MapReduce form.

    Args:
      mrf: An instance of `tff.backends.mapreduce.MapReduceForm`.

    Returns:
      An instance of `tff.templates.IterativeProcess` that corresponds to `mrf`.

    Raises:
      TypeError: If the arguments are of the wrong types.
    """
    py_typecheck.check_type(mrf, forms.MapReduceForm)

    @computations.federated_computation
    def init_computation():
        # Places the result of `mrf.initialize()` at the server.
        return intrinsics.federated_value(mrf.initialize(), placements.SERVER)

    # Named 2-element struct: server state, then clients-placed data typed as
    # the first parameter of `mrf.work`.
    next_parameter_type = computation_types.StructType([
        (mrf.server_state_label, init_computation.type_signature.result),
        (mrf.client_data_label,
         computation_types.FederatedType(mrf.work.type_signature.parameter[0],
                                         placements.CLIENTS)),
    ])

    @computations.federated_computation(next_parameter_type)
    def next_computation(arg):
        """The logic of a single MapReduce processing round."""
        s1 = arg[0]  # Server state.
        c1 = arg[1]  # Client data.
        s2 = intrinsics.federated_map(mrf.prepare, s1)
        c2 = intrinsics.federated_broadcast(s2)
        c3 = intrinsics.federated_zip([c1, c2])
        c4 = intrinsics.federated_map(mrf.work, c3)
        c5 = c4[0]  # Input to the general aggregation.
        c6 = c4[1]  # Input to the bitwidth-bounded secure sum.
        s3 = intrinsics.federated_aggregate(c5, mrf.zero(), mrf.accumulate,
                                            mrf.merge, mrf.report)
        s4 = intrinsics.federated_secure_sum_bitwidth(c6, mrf.bitwidth())
        s5 = intrinsics.federated_zip([s3, s4])
        s6 = intrinsics.federated_zip([s1, s5])
        # `update` maps (state, (aggregate, secure sum)) to (new state, output).
        s7 = intrinsics.federated_map(mrf.update, s6)
        s8 = s7[0]
        s9 = s7[1]
        return s8, s9

    return iterative_process.IterativeProcess(init_computation,
                                              next_computation)
示例#28
0
def _create_federated_int_dataset_identity_iterative_process():
    """Builds a process whose state is a clients-placed `range(5)` dataset.

    `next` returns its state unchanged.
    """

    @computations.tf_computation()
    def create_dataset():
        return tf.data.Dataset.range(5)

    @computations.federated_computation()
    def init():
        return intrinsics.federated_eval(create_dataset, placements.CLIENTS)

    state_type = init.type_signature.result

    @computations.federated_computation(state_type)
    def next_fn(x):
        return x

    return iterative_process.IterativeProcess(init, next_fn)
示例#29
0
    def test_constructor_with_tensors_unknown_dimensions(self):
        """Tensors with a statically unknown dimension are valid state types."""
        @computations.tf_computation
        def init():
            return tf.constant([], dtype=tf.string)

        @computations.tf_computation(
            computation_types.TensorType(shape=[None], dtype=tf.string))
        def next_fn(strings):
            return tf.concat([strings, tf.constant(['abc'])], axis=0)

        # Catch `Exception` rather than using a bare `except:`, so that
        # KeyboardInterrupt/SystemExit still propagate. The original error
        # stays attached as the implicit context of the failure.
        try:
            iterative_process.IterativeProcess(init, next_fn)
        except Exception:  # pylint: disable=broad-except
            self.fail(
                'Could not construct an IterativeProcess with parameter types '
                'including unknown dimension tensors.')
示例#30
0
def get_iterative_process_for_canonical_form(cf):
    """Creates `tff.utils.IterativeProcess` from a canonical form.

    The resulting `next` returns three values: the updated server state, the
    server output, and the second element of each client's `work` output
    (which stays at the clients).

    Args:
      cf: An instance of `tff.backends.mapreduce.CanonicalForm`.

    Returns:
      An instance of `tff.utils.IterativeProcess` that corresponds to `cf`.

    Raises:
      TypeError: If the arguments are of the wrong types.
    """
    py_typecheck.check_type(cf, canonical_form.CanonicalForm)

    @computations.federated_computation
    def init_computation():
        # Places the result of `cf.initialize()` at the server.
        return intrinsics.federated_value(cf.initialize(), placements.SERVER)

    # Two parameter types are given, so `arg` is a 2-element struct:
    # (server state, clients-placed input to `cf.work`).
    @computations.federated_computation(
        init_computation.type_signature.result,
        computation_types.FederatedType(cf.work.type_signature.parameter[0],
                                        placements.CLIENTS))
    def next_computation(arg):
        """The logic of a single MapReduce processing round."""
        s1 = arg[0]  # Server state.
        c1 = arg[1]  # Client data.
        s2 = intrinsics.federated_map(cf.prepare, s1)
        c2 = intrinsics.federated_broadcast(s2)
        c3 = intrinsics.federated_zip([c1, c2])
        c4 = intrinsics.federated_map(cf.work, c3)
        # c4[0] is itself a pair: (aggregation input, secure-sum input).
        c5 = c4[0]
        c6 = c5[0]
        c7 = c5[1]
        # c4[1] is not aggregated; it is returned as-is at the clients.
        c8 = c4[1]
        s3 = intrinsics.federated_aggregate(c6, cf.zero(), cf.accumulate,
                                            cf.merge, cf.report)
        s4 = intrinsics.federated_secure_sum(c7, cf.bitwidth())
        s5 = intrinsics.federated_zip([s3, s4])
        s6 = intrinsics.federated_zip([s1, s5])
        # `update` maps (state, (aggregate, secure sum)) to (new state, output).
        s7 = intrinsics.federated_map(cf.update, s6)
        s8 = s7[0]
        s9 = s7[1]
        return s8, s9, c8

    return iterative_process.IterativeProcess(init_computation,
                                              next_computation)