Exemplo n.º 1
0
    def test_handles_federated_broadcasts_nested_in_tuple(self):
        """Splits on broadcast when a called broadcast sits inside a struct.

        Builds a struct holding a server-placed `Data` reference and a called
        broadcast, broadcasts element 0 of that struct, and checks that
        neither half of the split still contains a called
        `federated_broadcast`.
        """
        inner_broadcast = (
            compiler_test_utils.create_whimsy_called_federated_broadcast())
        server_int32 = computation_types.FederatedType(
            computation_types.TensorType(tf.int32), placements.SERVER)
        server_data = building_blocks.Data('a', server_int32)
        struct_with_broadcast = building_blocks.Struct(
            [server_data, inner_broadcast])
        selected = building_blocks.Selection(struct_with_broadcast, index=0)
        outer_broadcast = building_block_factory.create_federated_broadcast(
            selected)
        call_dominant_body, _ = (
            compiler_transformations.transform_to_call_dominant(
                outer_broadcast))
        comp = building_blocks.Lambda('a', tf.int32, call_dominant_body)
        broadcast_uris = [intrinsic_defs.FEDERATED_BROADCAST.uri]

        before, after = transformations.force_align_and_split_by_intrinsics(
            comp, broadcast_uris)

        # Both halves must be lambdas that no longer call the broadcast.
        for half in (before, after):
            self.assertIsInstance(half, building_blocks.Lambda)
            self.assertFalse(
                tree_analysis.contains_called_intrinsic(half, broadcast_uris))
Exemplo n.º 2
0
    def test_returns_trees_with_one_federated_secure_sum_and_two_federated_aggregates(
            self):
        """Splits a struct of one secure sum followed by two aggregate calls."""
        aggregate_call = (
            compiler_test_utils.create_whimsy_called_federated_aggregate(
                accumulate_parameter_name='a',
                merge_parameter_name='b',
                report_parameter_name='c'))
        secure_sum_call = (
            compiler_test_utils.create_whimsy_called_federated_secure_sum())
        # The same aggregate call object appears twice in the struct.
        struct_of_calls = building_blocks.Struct(
            [secure_sum_call] + [aggregate_call] * 2)
        comp = building_blocks.Lambda('d', tf.int32, struct_of_calls)
        split_uris = [
            intrinsic_defs.FEDERATED_AGGREGATE.uri,
            intrinsic_defs.FEDERATED_SECURE_SUM.uri,
        ]

        before, after = transformations.force_align_and_split_by_intrinsics(
            comp, split_uris)

        # Both halves must be lambdas free of the intrinsics we split on.
        for half in (before, after):
            self.assertIsInstance(half, building_blocks.Lambda)
            self.assertFalse(
                tree_analysis.contains_called_intrinsic(half, split_uris))
Exemplo n.º 3
0
    def test_returns_trees_with_one_federated_secure_sum(self):
        """Splits a lambda whose result is a struct of one secure-sum call."""
        secure_sum_call = (
            compiler_test_utils.create_whimsy_called_federated_secure_sum())
        comp = building_blocks.Lambda(
            'a', tf.int32, building_blocks.Struct([secure_sum_call]))
        secure_sum_uris = [intrinsic_defs.FEDERATED_SECURE_SUM.uri]

        before, after = transformations.force_align_and_split_by_intrinsics(
            comp, secure_sum_uris)

        # Both halves must be lambdas that no longer call the secure sum.
        for half in (before, after):
            self.assertIsInstance(half, building_blocks.Lambda)
            self.assertFalse(
                tree_analysis.contains_called_intrinsic(half, secure_sum_uris))
Exemplo n.º 4
0
  def test_returns_trees_with_one_federated_broadcast(self):
    """Splits a lambda whose result is a tuple of one broadcast call."""
    broadcast_call = (
        compiler_test_utils.create_dummy_called_federated_broadcast())
    comp = building_blocks.Lambda(
        'a', tf.int32, building_blocks.Tuple([broadcast_call]))
    broadcast_uris = [intrinsic_defs.FEDERATED_BROADCAST.uri]

    before, after = transformations.force_align_and_split_by_intrinsics(
        comp, broadcast_uris)

    # Both halves must be lambdas that no longer call the broadcast.
    for half in (before, after):
      self.assertIsInstance(half, building_blocks.Lambda)
      self.assertFalse(
          tree_analysis.contains_called_intrinsic(half, broadcast_uris))
Exemplo n.º 5
0
    def assert_splits_on(self, comp, calls):
        """Checks the result of `force_align_and_split_by_intrinsics`.

        Splits `comp` on `calls` and verifies that the intrinsics are gone
        from both halves, that the parameter and result types of the halves
        have the expected structure, and that each (argument, result) pair
        is a concrete instance of the matching intrinsic's signature.

        Args:
          comp: The computation to split. It must contain the intrinsics
            named by `calls`, and may have no parameter.
          calls: A single called intrinsic, or a list of them.
        """
        if not isinstance(calls, list):
            calls = [calls]
        expected_count = len(calls)
        uris = [call.function.uri for call in calls]
        before, after = transformations.force_align_and_split_by_intrinsics(
            comp, calls)

        # Neither half may still call any of the intrinsics we split on...
        self.assertFalse(tree_analysis.contains_called_intrinsic(before, uris))
        self.assertFalse(tree_analysis.contains_called_intrinsic(after, uris))
        # ...and the check is only meaningful if `comp` called them at all.
        self.assertTrue(tree_analysis.contains_called_intrinsic(comp, uris))

        original_parameter_type = comp.parameter_type
        if original_parameter_type is None:
            self.assertIsNone(before.parameter_type)
        else:
            type_test_utils.assert_types_equivalent(original_parameter_type,
                                                    before.parameter_type)
        # `before` returns one intrinsic argument per entry in `calls`.
        before.type_signature.result.check_struct()
        self.assertLen(before.type_signature.result, expected_count)

        # `after` takes a struct shaped like:
        # {
        #   'original_arg': comp.parameter_type, (only if `comp` has one)
        #   'intrinsic_results': [...],
        # }
        after.parameter_type.check_struct()
        if comp.parameter_type is None:
            self.assertLen(after.parameter_type, 1)
        else:
            self.assertLen(after.parameter_type, 2)
            type_test_utils.assert_types_equivalent(
                comp.parameter_type, after.parameter_type.original_arg)
        # `after` receives one intrinsic result per entry in `calls`.
        self.assertLen(after.parameter_type.intrinsic_results, expected_count)

        # Every (argument, result) pair has to be a valid type substitution
        # for the intrinsic at the same position.
        for index, call in enumerate(calls):
            concrete_signature = computation_types.FunctionType(
                before.type_signature.result[index],
                after.parameter_type.intrinsic_results[index])
            abstract_signature = call.function.intrinsic_def().type_signature
            type_analysis.check_concrete_instance_of(concrete_signature,
                                                     abstract_signature)
  def assert_splits_on(self, comp, calls):
    """Checks the result of `force_align_and_split_by_intrinsics`.

    Splits `comp` on `calls` and verifies that the intrinsics are gone from
    both halves, that the parameter and result types of the halves have the
    expected structure, and that each (argument, result) pair is a concrete
    instance of the matching intrinsic's signature once client all-equal
    information has been removed from both sides.

    Args:
      comp: The computation to split. It must contain the intrinsics named
        by `calls`.
      calls: A single called intrinsic, or a list of them.
    """
    if not isinstance(calls, list):
      calls = [calls]
    expected_count = len(calls)
    uris = [call.function.uri for call in calls]
    before, after = transformations.force_align_and_split_by_intrinsics(
        comp, calls)

    # Neither half may still call any of the intrinsics we split on...
    self.assertFalse(tree_analysis.contains_called_intrinsic(before, uris))
    self.assertFalse(tree_analysis.contains_called_intrinsic(after, uris))
    # ...and the check is only meaningful if `comp` called them at all.
    self.assertTrue(tree_analysis.contains_called_intrinsic(comp, uris))

    self.assert_types_equivalent(comp.parameter_type, before.parameter_type)
    # `before` returns one intrinsic argument per entry in `calls`.
    before.type_signature.result.check_struct()
    self.assertLen(before.type_signature.result, expected_count)

    # `after` takes a two-element struct shaped like:
    # {
    #   'original_arg': comp.parameter_type,
    #   'intrinsic_results': [...],
    # }
    after.parameter_type.check_struct()
    self.assertLen(after.parameter_type, 2)
    self.assert_types_equivalent(comp.parameter_type,
                                 after.parameter_type.original_arg)
    # `after` receives one intrinsic result per entry in `calls`.
    self.assertLen(after.parameter_type.intrinsic_results, expected_count)

    # Every (argument, result) pair has to be a valid type substitution for
    # the intrinsic at the same position.
    for index, call in enumerate(calls):
      concrete_signature = computation_types.FunctionType(
          before.type_signature.result[index],
          after.parameter_type.intrinsic_results[index])
      abstract_signature = call.function.intrinsic_def().type_signature
      # Zipping and unzipping inside `force_align_and_split_by_intrinsics`
      # drops all-equal data. That is acceptable because the two halves are
      # never used together directly; they are compiled into unplaced TF
      # code. Strip the bit from both signatures before comparing.
      abstract_signature = _remove_client_all_equals_from_type(
          abstract_signature)
      concrete_signature = _remove_client_all_equals_from_type(
          concrete_signature)
      type_analysis.check_concrete_instance_of(concrete_signature,
                                               abstract_signature)
Exemplo n.º 7
0
 def test_removes_selected_intrinsic_leaving_remaining_intrinsic(self):
     """Splitting on the aggregate must leave the secure sum in one half.

     The computation calls both a federated aggregate and a secure sum
     (bitwidth). Only the aggregate is split on, so it must vanish from both
     halves while the secure sum survives in at least one of them.
     """
     aggregate_call = (
         building_block_test_utils.create_whimsy_called_federated_aggregate(
             accumulate_parameter_name='a',
             merge_parameter_name='b',
             report_parameter_name='c'))
     secure_sum_call = (
         building_block_test_utils.
         create_whimsy_called_federated_secure_sum_bitwidth())
     comp = building_blocks.Lambda(
         'd', tf.int32,
         building_blocks.Struct([aggregate_call, secure_sum_call]))
     null_aggregate = (
         building_block_factory.create_null_federated_aggregate())
     secure_sum_uri = secure_sum_call.function.uri
     aggregate_uri = null_aggregate.function.uri
     contains = tree_analysis.contains_called_intrinsic

     before, after = transformations.force_align_and_split_by_intrinsics(
         comp, [null_aggregate])

     # The input has both intrinsics.
     self.assertTrue(contains(comp, secure_sum_uri))
     self.assertTrue(contains(comp, aggregate_uri))
     # The aggregate is removed from both halves.
     self.assertFalse(contains(before, aggregate_uri))
     self.assertFalse(contains(after, aggregate_uri))
     # The secure sum is still present in `before` or in `after`.
     self.assertTrue(
         any(contains(half, secure_sum_uri) for half in (before, after)))
Exemplo n.º 8
0
 def test_returns_false_with_unmatched_called_intrinsic(self):
     """A called broadcast does not match the `federated_map` URI."""
     broadcast_call = (
         building_block_test_utils.create_whimsy_called_federated_broadcast())
     map_uri = intrinsic_defs.FEDERATED_MAP.uri
     found = tree_analysis.contains_called_intrinsic(broadcast_call, map_uri)
     self.assertFalse(found)
Exemplo n.º 9
0
 def test_returns_false_with_no_called_intrinsic(self):
     """An identity function is reported as containing no called intrinsic."""
     identity_fn = building_block_test_utils.create_identity_function('a')
     found = tree_analysis.contains_called_intrinsic(identity_fn)
     self.assertFalse(found)
Exemplo n.º 10
0
 def test_returns_true_with_matching_uri(self):
     """A called broadcast matches the `federated_broadcast` URI."""
     broadcast_call = (
         building_block_test_utils.create_whimsy_called_federated_broadcast())
     broadcast_uri = intrinsic_defs.FEDERATED_BROADCAST.uri
     found = tree_analysis.contains_called_intrinsic(broadcast_call,
                                                     broadcast_uri)
     self.assertTrue(found)
Exemplo n.º 11
0
 def test_returns_true_with_none_uri(self):
     """With no URI argument, a called broadcast is still detected."""
     broadcast_call = (
         building_block_test_utils.create_whimsy_called_federated_broadcast())
     found = tree_analysis.contains_called_intrinsic(broadcast_call)
     self.assertTrue(found)
Exemplo n.º 12
0
 def test_raises_type_error_with_none_tree(self):
     """Passing `None` as the tree must raise `TypeError`."""
     self.assertRaises(TypeError, tree_analysis.contains_called_intrinsic,
                       None)
def get_canonical_form_for_iterative_process(ip):
  """Builds a `tff.backends.mapreduce.CanonicalForm` from an iterative process.

  The `initialize` and `next` computations of `ip` are deserialized into
  building blocks. `next` is split around its broadcast and then around its
  aggregation intrinsics, and the local processing of every piece is
  extracted and type-checked before the canonical form is assembled.

  Args:
    ip: An instance of `tff.templates.IterativeProcess`.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If `next` contains neither a `federated_aggregate` nor a
      `federated_secure_sum`.
    transformations.CanonicalFormCompilationError: If the compilation
      process fails.
  """
  py_typecheck.check_type(ip, iterative_process.IterativeProcess)

  # pylint: disable=protected-access
  deserialize = building_blocks.ComputationBuildingBlock.from_proto
  initialize_comp = deserialize(ip.initialize._computation_proto)
  next_comp = deserialize(ip.next._computation_proto)
  # pylint: enable=protected-access
  _check_iterative_process_compatible_with_canonical_form(
      initialize_comp, next_comp)

  initialize_comp = _replace_intrinsics_with_bodies(initialize_comp)
  next_comp = _replace_intrinsics_with_bodies(next_comp)
  for reduced_comp in (initialize_comp, next_comp):
    tree_analysis.check_contains_only_reducible_intrinsics(reduced_comp)
  tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

  # Split `next` around the broadcast, or synthesize the two halves when
  # `next` does not broadcast at all.
  broadcast_uri = intrinsic_defs.FEDERATED_BROADCAST.uri
  if not tree_analysis.contains_called_intrinsic(next_comp, broadcast_uri):
    before_broadcast, after_broadcast = (
        _create_before_and_after_broadcast_for_no_broadcast(next_comp))
  else:
    before_broadcast, after_broadcast = (
        transformations.force_align_and_split_by_intrinsics(
            next_comp, [broadcast_uri]))

  aggregate_uri = intrinsic_defs.FEDERATED_AGGREGATE.uri
  secure_sum_uri = intrinsic_defs.FEDERATED_SECURE_SUM.uri
  has_aggregate = tree_analysis.contains_called_intrinsic(
      next_comp, aggregate_uri)
  has_secure_sum = tree_analysis.contains_called_intrinsic(
      next_comp, secure_sum_uri)
  if not has_aggregate and not has_secure_sum:
    raise ValueError(
        'Expected an `tff.templates.IterativeProcess` containing at least one '
        '`federated_aggregate` or `federated_secure_sum`, found none.')

  # Split the post-broadcast half around whichever aggregations are present.
  if not has_secure_sum:
    assert has_aggregate and not has_secure_sum
    before_aggregate, after_aggregate = (
        _create_before_and_after_aggregate_for_no_federated_secure_sum(
            after_broadcast))
  elif not has_aggregate:
    assert not has_aggregate
    before_aggregate, after_aggregate = (
        _create_before_and_after_aggregate_for_no_federated_aggregate(
            after_broadcast))
  else:
    before_aggregate, after_aggregate = (
        transformations.force_align_and_split_by_intrinsics(
            after_broadcast, [aggregate_uri, secure_sum_uri]))

  type_info = _get_type_info(initialize_comp, before_broadcast, after_broadcast,
                             before_aggregate, after_aggregate)

  # Each piece is extracted and immediately checked against `type_info`.
  initialize = transformations.consolidate_and_extract_local_processing(
      initialize_comp)
  _check_type_equal(initialize.type_signature, type_info['initialize_type'])

  prepare = _extract_prepare(before_broadcast)
  _check_type_equal(prepare.type_signature, type_info['prepare_type'])

  work = _extract_work(before_aggregate)
  _check_type_equal(work.type_signature, type_info['work_type'])

  zero, accumulate, merge, report = _extract_federated_aggregate_functions(
      before_aggregate)
  aggregate_pieces = (
      (zero, 'zero_type'),
      (accumulate, 'accumulate_type'),
      (merge, 'merge_type'),
      (report, 'report_type'),
  )
  for piece, type_key in aggregate_pieces:
    _check_type_equal(piece.type_signature, type_info[type_key])

  bitwidth = _extract_federated_secure_sum_functions(before_aggregate)
  _check_type_equal(bitwidth.type_signature, type_info['bitwidth_type'])

  update = _extract_update(after_aggregate)
  _check_type_equal(update.type_signature, type_info['update_type'])

  to_computation = computation_wrapper_instances.building_block_to_computation
  ordered_pieces = (initialize, prepare, work, zero, accumulate, merge,
                    report, bitwidth, update)
  return canonical_form.CanonicalForm(
      *[to_computation(piece) for piece in ordered_pieces])
Exemplo n.º 14
0
def get_canonical_form_for_iterative_process(iterative_process):
    """Builds a `tff.backends.mapreduce.CanonicalForm` from an iterative process.

    The `initialize` and `next` computations of `iterative_process` are
    deserialized into building blocks. `next` is split around its broadcast
    and then around `federated_aggregate`, and the local processing of
    every piece is extracted and type-checked before the canonical form is
    assembled.

    Args:
      iterative_process: An instance of `tff.utils.IterativeProcess`.

    Returns:
      An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to
      this process.

    Raises:
      TypeError: If the arguments are of the wrong types.
      transformations.CanonicalFormCompilationError: If the compilation
        process fails.
    """
    py_typecheck.check_type(iterative_process,
                            computation_utils.IterativeProcess)

    # pylint: disable=protected-access
    deserialize = building_blocks.ComputationBuildingBlock.from_proto
    initialize_comp = deserialize(
        iterative_process.initialize._computation_proto)
    next_comp = deserialize(iterative_process.next._computation_proto)
    # pylint: enable=protected-access
    _check_iterative_process_compatible_with_canonical_form(
        initialize_comp, next_comp)

    # A `next` with a two-element result is rewritten by
    # `_create_next_with_fake_client_output` before any further processing.
    if len(next_comp.type_signature.result) == 2:
        next_comp = _create_next_with_fake_client_output(next_comp)

    initialize_comp = _replace_intrinsics_with_bodies(initialize_comp)
    next_comp = _replace_intrinsics_with_bodies(next_comp)
    for reduced_comp in (initialize_comp, next_comp):
        tree_analysis.check_intrinsics_whitelisted_for_reduction(reduced_comp)
    tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

    # Split `next` around the broadcast, or synthesize the two halves when
    # `next` does not broadcast at all.
    broadcast_uri = intrinsic_defs.FEDERATED_BROADCAST.uri
    if not tree_analysis.contains_called_intrinsic(next_comp, broadcast_uri):
        before_broadcast, after_broadcast = (
            _create_before_and_after_broadcast_for_no_broadcast(next_comp))
    else:
        before_broadcast, after_broadcast = (
            transformations.force_align_and_split_by_intrinsics(
                next_comp, [broadcast_uri]))

    aggregate_uris = [intrinsic_defs.FEDERATED_AGGREGATE.uri]
    before_aggregate, after_aggregate = (
        transformations.force_align_and_split_by_intrinsics(
            after_broadcast, aggregate_uris))

    type_info = _get_type_info(initialize_comp, before_broadcast,
                               after_broadcast, before_aggregate,
                               after_aggregate)

    # Each piece is extracted and immediately checked against `type_info`.
    initialize = transformations.consolidate_and_extract_local_processing(
        initialize_comp)
    _check_type_equal(initialize.type_signature, type_info['initialize_type'])

    prepare = _extract_prepare(before_broadcast)
    _check_type_equal(prepare.type_signature, type_info['prepare_type'])

    work = _extract_work(before_aggregate, after_aggregate)
    _check_type_equal(work.type_signature, type_info['work_type'])

    zero, accumulate, merge, report = _extract_aggregate_functions(
        before_aggregate)
    aggregate_pieces = (
        (zero, 'zero_type'),
        (accumulate, 'accumulate_type'),
        (merge, 'merge_type'),
        (report, 'report_type'),
    )
    for piece, type_key in aggregate_pieces:
        _check_type_equal(piece.type_signature, type_info[type_key])

    update = _extract_update(after_aggregate)
    _check_type_equal(update.type_signature, type_info['update_type'])

    to_computation = (
        computation_wrapper_instances.building_block_to_computation)
    ordered_pieces = (initialize, prepare, work, zero, accumulate, merge,
                      report, update)
    return canonical_form.CanonicalForm(
        *[to_computation(piece) for piece in ordered_pieces])
Exemplo n.º 15
0
def get_canonical_form_for_iterative_process(iterative_process):
    """Builds a `tff.backends.mapreduce.CanonicalForm` from an iterative process.

    The `initialize` and `next` computations of `iterative_process` are
    deserialized into building blocks. `next` is split around its broadcast
    and then around `federated_aggregate`. The type signatures of the
    pieces are checked and packed step by step, and the packed types are
    then used to extract the functions that make up the canonical form.

    Args:
      iterative_process: An instance of `tff.utils.IterativeProcess`.

    Returns:
      An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to
      this process.

    Raises:
      TypeError: If the arguments are of the wrong types, or if `next` does
        not both take and return a `tff.NamedTupleType`.
      transformations.CanonicalFormCompilationError: If the compilation
        process fails.
    """
    py_typecheck.check_type(iterative_process,
                            computation_utils.IterativeProcess)

    # pylint: disable=protected-access
    deserialize = building_blocks.ComputationBuildingBlock.from_proto
    initialize_comp = deserialize(
        iterative_process.initialize._computation_proto)
    next_comp = deserialize(iterative_process.next._computation_proto)
    # pylint: enable=protected-access

    # `next` has to map a named tuple to a named tuple.
    next_signature = next_comp.type_signature
    if (not isinstance(next_signature.parameter,
                       computation_types.NamedTupleType)
            or not isinstance(next_signature.result,
                              computation_types.NamedTupleType)):
        raise TypeError(
            'Any IterativeProcess compatible with CanonicalForm must '
            'have a `next` function which takes and returns instances '
            'of `tff.NamedTupleType`; your next function takes '
            'parameters of type {} and returns results of type {}'.format(
                next_comp.type_signature.parameter,
                next_comp.type_signature.result))

    # A `next` with a two-element result is rewritten by
    # `_create_next_with_fake_client_output` before any further processing.
    if len(next_comp.type_signature.result) == 2:
        next_comp = _create_next_with_fake_client_output(next_comp)

    initialize_comp = replace_intrinsics_with_bodies(initialize_comp)
    next_comp = replace_intrinsics_with_bodies(next_comp)

    for reduced_comp in (initialize_comp, next_comp):
        tree_analysis.check_intrinsics_whitelisted_for_reduction(reduced_comp)
    tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

    # Split `next` around the broadcast, or synthesize the two halves when
    # `next` does not broadcast at all.
    broadcast_uri = intrinsic_defs.FEDERATED_BROADCAST.uri
    if not tree_analysis.contains_called_intrinsic(next_comp, broadcast_uri):
        before_broadcast, after_broadcast = (
            _create_before_and_after_broadcast_for_no_broadcast(next_comp))
    else:
        before_broadcast, after_broadcast = (
            transformations.force_align_and_split_by_intrinsics(
                next_comp, [broadcast_uri]))

    aggregate_uris = [intrinsic_defs.FEDERATED_AGGREGATE.uri]
    before_aggregate, after_aggregate = (
        transformations.force_align_and_split_by_intrinsics(
            after_broadcast, aggregate_uris))

    # Thread the packed type information through every stage in turn; the
    # last stage yields the types used for extraction below.
    packed_types = pack_initialize_comp_type_signature(
        initialize_comp.type_signature)
    packed_types = pack_next_comp_type_signature(next_comp.type_signature,
                                                 packed_types)
    packed_types = check_and_pack_before_broadcast_type_signature(
        before_broadcast.type_signature, packed_types)
    packed_types = check_and_pack_before_aggregate_type_signature(
        before_aggregate.type_signature, packed_types)
    canonical_form_types = check_and_pack_after_aggregate_type_signature(
        after_aggregate.type_signature, packed_types)

    initialize = transformations.consolidate_and_extract_local_processing(
        initialize_comp)

    expected_initialize_result = canonical_form_types['initialize_type'].member
    if initialize.type_signature.result != expected_initialize_result:
        raise transformations.CanonicalFormCompilationError(
            'Compilation of initialize has failed. Expected to extract a '
            '`building_blocks.CompiledComputation` of type {}, instead we extracted '
            'a {} of type {}.'.format(next_comp.type_signature.parameter[0],
                                      type(initialize),
                                      initialize.type_signature.result))

    prepare = extract_prepare(before_broadcast, canonical_form_types)
    work = extract_work(before_aggregate, after_aggregate,
                        canonical_form_types)
    zero_noarg_function, accumulate, merge, report = (
        extract_aggregate_functions(before_aggregate, canonical_form_types))
    update = extract_update(after_aggregate, canonical_form_types)

    to_computation = (
        computation_wrapper_instances.building_block_to_computation)
    ordered_pieces = (initialize, prepare, work, zero_noarg_function,
                      accumulate, merge, report, update)
    return canonical_form.CanonicalForm(
        *[to_computation(piece) for piece in ordered_pieces])
Exemplo n.º 16
0
def get_canonical_form_for_iterative_process(
    ip: iterative_process.IterativeProcess,
    grappler_config: Optional[
        tf.compat.v1.ConfigProto] = _GRAPPLER_DEFAULT_CONFIG):
    """Builds a `tff.backends.mapreduce.CanonicalForm` from an iterative process.

    The `initialize` and `next` computations of `ip` are deserialized into
    building blocks. `next` is split around its broadcast and then around
    its aggregation intrinsics, and the local processing of every piece is
    extracted and type-checked before the canonical form is assembled.

    Args:
      ip: An instance of `tff.templates.IterativeProcess` that is compatible
        with canonical form. Iterative processes are only compatible if:
        - `initialize_fn` returns a single federated value placed at
          `SERVER`.
        - `next` takes exactly two arguments. The first must be the state
          value placed at `SERVER`.
        - `next` returns exactly two values.
      grappler_config: An optional instance of `tf.compat.v1.ConfigProto`
        to configure Grappler graph optimization of the TensorFlow graphs
        backing the resulting `tff.backends.mapreduce.CanonicalForm`. These
        options are combined with a set of defaults that aggressively
        configure Grappler. If `None`, Grappler is bypassed.

    Returns:
      An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to
      this process.

    Raises:
      TypeError: If the arguments are of the wrong types.
      ValueError: If `next` contains neither a `federated_aggregate` nor a
        `federated_secure_sum`.
      transformations.CanonicalFormCompilationError: If the compilation
        process fails.
    """
    py_typecheck.check_type(ip, iterative_process.IterativeProcess)
    if grappler_config is not None:
        py_typecheck.check_type(grappler_config, tf.compat.v1.ConfigProto)
        # Start from a copy of the defaults and merge the caller's options
        # on top, so neither input proto is modified.
        merged_config = tf.compat.v1.ConfigProto()
        merged_config.CopyFrom(_GRAPPLER_DEFAULT_CONFIG)
        merged_config.MergeFrom(grappler_config)
        grappler_config = merged_config

    # pylint: disable=protected-access
    deserialize = building_blocks.ComputationBuildingBlock.from_proto
    initialize_comp = deserialize(ip.initialize._computation_proto)
    next_comp = deserialize(ip.next._computation_proto)
    # pylint: enable=protected-access
    _check_iterative_process_compatible_with_canonical_form(
        initialize_comp, next_comp)

    initialize_comp = _replace_intrinsics_with_bodies(initialize_comp)
    next_comp = _replace_intrinsics_with_bodies(next_comp)
    next_comp = _replace_lambda_body_with_call_dominant_form(next_comp)

    for reduced_comp in (initialize_comp, next_comp):
        tree_analysis.check_contains_only_reducible_intrinsics(reduced_comp)
    tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

    # Split `next` around the broadcast, or synthesize the two halves when
    # `next` does not broadcast at all.
    broadcast_uri = intrinsic_defs.FEDERATED_BROADCAST.uri
    if not tree_analysis.contains_called_intrinsic(next_comp, broadcast_uri):
        before_broadcast, after_broadcast = (
            _create_before_and_after_broadcast_for_no_broadcast(next_comp))
    else:
        before_broadcast, after_broadcast = (
            transformations.force_align_and_split_by_intrinsics(
                next_comp, [broadcast_uri]))

    aggregate_uri = intrinsic_defs.FEDERATED_AGGREGATE.uri
    secure_sum_uri = intrinsic_defs.FEDERATED_SECURE_SUM.uri
    has_aggregate = tree_analysis.contains_called_intrinsic(
        next_comp, aggregate_uri)
    has_secure_sum = tree_analysis.contains_called_intrinsic(
        next_comp, secure_sum_uri)
    if not has_aggregate and not has_secure_sum:
        raise ValueError(
            'Expected an `tff.templates.IterativeProcess` containing at least one '
            '`federated_aggregate` or `federated_secure_sum`, found none.')

    # Split the post-broadcast half around whichever aggregations exist.
    if not has_secure_sum:
        assert has_aggregate and not has_secure_sum
        before_aggregate, after_aggregate = (
            _create_before_and_after_aggregate_for_no_federated_secure_sum(
                after_broadcast))
    elif not has_aggregate:
        assert not has_aggregate
        before_aggregate, after_aggregate = (
            _create_before_and_after_aggregate_for_no_federated_aggregate(
                after_broadcast))
    else:
        before_aggregate, after_aggregate = (
            transformations.force_align_and_split_by_intrinsics(
                after_broadcast, [aggregate_uri, secure_sum_uri]))

    type_info = _get_type_info(initialize_comp, before_broadcast,
                               after_broadcast, before_aggregate,
                               after_aggregate)

    # Each piece is extracted and immediately checked against `type_info`.
    initialize = transformations.consolidate_and_extract_local_processing(
        initialize_comp, grappler_config)
    _check_type_equal(initialize.type_signature, type_info['initialize_type'])

    prepare = _extract_prepare(before_broadcast, grappler_config)
    _check_type_equal(prepare.type_signature, type_info['prepare_type'])

    work = _extract_work(before_aggregate, grappler_config)
    _check_type_equal(work.type_signature, type_info['work_type'])

    zero, accumulate, merge, report = _extract_federated_aggregate_functions(
        before_aggregate, grappler_config)
    aggregate_pieces = (
        (zero, 'zero_type'),
        (accumulate, 'accumulate_type'),
        (merge, 'merge_type'),
        (report, 'report_type'),
    )
    for piece, type_key in aggregate_pieces:
        _check_type_equal(piece.type_signature, type_info[type_key])

    bitwidth = _extract_federated_secure_sum_functions(before_aggregate,
                                                       grappler_config)
    _check_type_equal(bitwidth.type_signature, type_info['bitwidth_type'])

    update = _extract_update(after_aggregate, grappler_config)
    _check_type_equal(update.type_signature, type_info['update_type'])

    # The names of the two `next` parameters become the labels of the
    # canonical form; unpacking fails unless there are exactly two.
    server_state_label, client_data_label = (
        name for name, _ in structure.iter_elements(
            ip.next.type_signature.parameter))

    to_computation = (
        computation_wrapper_instances.building_block_to_computation)
    ordered_pieces = (initialize, prepare, work, zero, accumulate, merge,
                      report, bitwidth, update)
    return canonical_form.CanonicalForm(
        *[to_computation(piece) for piece in ordered_pieces],
        server_state_label=server_state_label,
        client_data_label=client_data_label)