def test_returns_trees_with_one_federated_aggregate_and_one_federated_secure_sum_for_federated_secure_sum_first(
            self):
        """Splitting on [secure_sum, aggregate] removes both intrinsics."""
        aggregate_call = compiler_test_utils.create_dummy_called_federated_aggregate(
            accumulate_parameter_name='a',
            merge_parameter_name='b',
            report_parameter_name='c')
        secure_sum_call = compiler_test_utils.create_dummy_called_federated_secure_sum()
        comp = building_blocks.Lambda(
            'd', tf.int32,
            building_blocks.Struct([aggregate_call, secure_sum_call]))
        # Per the test name, secure sum is listed first among the URIs, which is
        # the opposite of the order the calls appear in the struct above.
        uri = [
            intrinsic_defs.FEDERATED_SECURE_SUM.uri,
            intrinsic_defs.FEDERATED_AGGREGATE.uri,
        ]

        before, after = transformations.force_align_and_split_by_intrinsics(
            comp, uri)

        # Both halves must be lambdas with no remaining call to either intrinsic.
        for half in (before, after):
            self.assertIsInstance(half, building_blocks.Lambda)
            self.assertFalse(tree_analysis.contains_called_intrinsic(half, uri))
    def test_handles_federated_broadcasts_nested_in_tuple(self):
        """A broadcast of a selection from a struct holding a broadcast splits."""
        inner_broadcast = compiler_test_utils.create_dummy_called_federated_broadcast()
        server_int_type = computation_types.FederatedType(
            computation_types.TensorType(tf.int32), placements.SERVER)
        packed = building_blocks.Struct(
            [building_blocks.Data('a', server_int_type), inner_broadcast])
        outer_broadcast = building_block_factory.create_federated_broadcast(
            building_blocks.Selection(packed, index=0))
        # Normalize to call-dominant form before wrapping in the test lambda.
        call_dominant, _ = compiler_transformations.transform_to_call_dominant(
            outer_broadcast)
        comp = building_blocks.Lambda('a', tf.int32, call_dominant)
        uri = [intrinsic_defs.FEDERATED_BROADCAST.uri]

        before, after = transformations.force_align_and_split_by_intrinsics(
            comp, uri)

        for half in (before, after):
            self.assertIsInstance(half, building_blocks.Lambda)
            self.assertFalse(tree_analysis.contains_called_intrinsic(half, uri))
 def test_cannot_split_on_chained_intrinsic(self):
   """Splitting fails when one `federated_map` consumes another's output."""
   int_type = computation_types.TensorType(tf.int32)
   client_int_type = computation_types.at_clients(int_type)

   def identity_fn(param_name):
     # `(param_name -> param_name)` over plain int32.
     return building_blocks.Lambda(
         param_name, int_type, building_blocks.Reference(param_name, int_type))

   def client_ref(name):
     return building_blocks.Reference(name, client_int_type)

   first_map = building_block_factory.create_federated_map(
       identity_fn('p1'), client_ref('param'))
   # The second map reads `a`, i.e. the first map's result, forming a chain.
   second_map = building_block_factory.create_federated_map(
       identity_fn('p2'), client_ref('a'))
   body = building_blocks.Block([('a', first_map), ('b', second_map)],
                                client_ref('b'))
   comp = building_blocks.Lambda('param', int_type, body)

   with self.assertRaises(transformations._NonAlignableAlongIntrinsicError):
     transformations.force_align_and_split_by_intrinsics(
         comp, [building_block_factory.create_null_federated_map()])
    def test_returns_trees_with_one_federated_secure_sum(self):
        """Splitting on a lone secure-sum call yields intrinsic-free lambdas."""
        comp = building_blocks.Lambda(
            'a', tf.int32,
            building_blocks.Struct([
                compiler_test_utils.create_whimsy_called_federated_secure_sum()
            ]))
        uri = [intrinsic_defs.FEDERATED_SECURE_SUM.uri]

        before, after = transformations.force_align_and_split_by_intrinsics(
            comp, uri)

        for half in (before, after):
            self.assertIsInstance(half, building_blocks.Lambda)
            self.assertFalse(tree_analysis.contains_called_intrinsic(half, uri))
  def test_returns_trees_with_one_federated_broadcast(self):
    """Splitting on a lone broadcast call yields intrinsic-free lambdas."""
    comp = building_blocks.Lambda(
        'a', tf.int32,
        building_blocks.Tuple([
            compiler_test_utils.create_dummy_called_federated_broadcast()
        ]))
    uri = [intrinsic_defs.FEDERATED_BROADCAST.uri]

    before, after = transformations.force_align_and_split_by_intrinsics(
        comp, uri)

    for half in (before, after):
      self.assertIsInstance(half, building_blocks.Lambda)
      self.assertFalse(tree_analysis.contains_called_intrinsic(half, uri))
def _split_ast_on_broadcast(bb):
    """Splits `bb` around its `federated_broadcast` calls.

    Args:
      bb: An AST of arbitrary shape, potentially containing a broadcast.

    Returns:
      Whatever `_untuple_broadcast_only_before_after` returns for the
      `(before, after)` pair produced by
      `transformations.force_align_and_split_by_intrinsics`: the first AST
      maps comp's input to the argument of broadcast, and the second maps
      comp's input and broadcast's output to comp's output.
    """
    null_broadcast = building_block_factory.create_null_federated_broadcast()
    before_bb, after_bb = transformations.force_align_and_split_by_intrinsics(
        bb, [null_broadcast])
    return _untuple_broadcast_only_before_after(before_bb, after_bb)
  def test_returns_tree(self):
    """Checks the no-`federated_aggregate` split against a secure-sum split."""
    ip = get_iterative_process_for_sum_example_with_no_federated_aggregate()
    next_tree = building_blocks.ComputationBuildingBlock.from_proto(
        ip.next._computation_proto)

    before_aggregate, after_aggregate = canonical_form_utils._create_before_and_after_aggregate_for_no_federated_aggregate(
        next_tree)

    # Reference split: splitting directly on secure sum, used below to check
    # the second element of `before_aggregate` and the function called by
    # `after_aggregate`.
    before_federated_secure_sum, after_federated_secure_sum = (
        transformations.force_align_and_split_by_intrinsics(
            next_tree, [intrinsic_defs.FEDERATED_SECURE_SUM.uri]))
    self.assertIsInstance(before_aggregate, building_blocks.Lambda)
    self.assertIsInstance(before_aggregate.result, building_blocks.Tuple)
    self.assertLen(before_aggregate.result, 2)

    # pyformat: disable
    self.assertEqual(
        before_aggregate.result[0].formatted_representation(),
        '<\n'
        '  federated_value_at_clients(<>),\n'
        '  <>,\n'
        '  (_var1 -> <>),\n'
        '  (_var2 -> <>),\n'
        '  (_var3 -> <>)\n'
        '>'
    )
    # pyformat: enable

    self.assertEqual(
        before_aggregate.result[1].formatted_representation(),
        before_federated_secure_sum.result.formatted_representation())

    self.assertIsInstance(after_aggregate, building_blocks.Lambda)
    self.assertIsInstance(after_aggregate.result, building_blocks.Call)
    # Reference names are uniquified on both sides so that the comparison
    # does not depend on how each tree happened to be named.
    actual_tree, _ = tree_transformations.uniquify_reference_names(
        after_aggregate.result.function)
    expected_tree, _ = tree_transformations.uniquify_reference_names(
        after_federated_secure_sum)
    self.assertEqual(actual_tree.formatted_representation(),
                     expected_tree.formatted_representation())

    # pyformat: disable
    self.assertEqual(
        after_aggregate.result.argument.formatted_representation(),
        '<\n'
        '  _var4[0],\n'
        '  _var4[1][1]\n'
        '>'
    )
    # pyformat: enable
  def assert_splits_on(self, comp, calls):
    """Asserts that `force_align_and_split_by_intrinsics` removes intrinsics.

    Args:
      comp: The computation to split. It must expose `parameter_type`
        (presumably a `building_blocks.Lambda`).
      calls: A single called intrinsic, or a list of them, to split `comp` on.
        Each one's `function.uri` identifies the intrinsic to remove.
    """
    if not isinstance(calls, list):
      calls = [calls]
    uris = [call.function.uri for call in calls]
    # Removal isn't interesting to test for if it wasn't there to begin with.
    # Check this precondition before splitting so a bad input fails clearly.
    self.assertTrue(tree_analysis.contains_called_intrinsic(comp, uris))

    before, after = transformations.force_align_and_split_by_intrinsics(
        comp, calls)

    # Ensure that the resulting computations no longer contain the split
    # intrinsics.
    self.assertFalse(tree_analysis.contains_called_intrinsic(before, uris))
    self.assertFalse(tree_analysis.contains_called_intrinsic(after, uris))

    self.assert_types_equivalent(comp.parameter_type, before.parameter_type)
    # There must be one parameter for each intrinsic in `calls`.
    before.type_signature.result.check_struct()
    self.assertLen(before.type_signature.result, len(calls))

    # Check that `after`'s parameter is a structure like:
    # {
    #   'original_arg': comp.parameter_type,
    #   'intrinsic_results': [...],
    # }
    after.parameter_type.check_struct()
    self.assertLen(after.parameter_type, 2)
    self.assert_types_equivalent(comp.parameter_type,
                                 after.parameter_type.original_arg)
    # There must be one result for each intrinsic in `calls`.
    self.assertLen(after.parameter_type.intrinsic_results, len(calls))

    # Check that each pair of (param, result) is a valid type substitution
    # for the intrinsic in question.
    for i, call in enumerate(calls):
      concrete_signature = computation_types.FunctionType(
          before.type_signature.result[i],
          after.parameter_type.intrinsic_results[i])
      abstract_signature = call.function.intrinsic_def().type_signature
      # `force_align_and_split_by_intrinsics` loses all-equal data due to
      # zipping and unzipping. This is okay because the resulting computations
      # are not used together directly, but are compiled into unplaced TF code.
      abstract_signature = _remove_client_all_equals_from_type(
          abstract_signature)
      concrete_signature = _remove_client_all_equals_from_type(
          concrete_signature)
      type_analysis.check_concrete_instance_of(concrete_signature,
                                               abstract_signature)
def _split_ast_on_aggregate(bb):
    """Splits `bb` around its reducible aggregation calls.

    The split is taken on `federated_aggregate` and
    `federated_secure_sum_bitwidth`, in that order.

    Args:
      bb: An AST containing `federated_aggregate` or
        `federated_secure_sum_bitwidth` aggregations.

    Returns:
      The `(before, after)` pair from
      `transformations.force_align_and_split_by_intrinsics`. `before` maps
      comp's input to the arguments of both aggregations. `after` maps comp's
      input and the aggregation outputs to comp's output.
    """
    aggregation_calls = [
        building_block_factory.create_null_federated_aggregate(),
        building_block_factory.create_null_federated_secure_sum_bitwidth(),
    ]
    return transformations.force_align_and_split_by_intrinsics(
        bb, aggregation_calls)
def get_canonical_form_for_iterative_process(ip):
  """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

  This function transforms computations from the input `ip` into
  an instance of `tff.backends.mapreduce.CanonicalForm`.

  Args:
    ip: An instance of `tff.templates.IterativeProcess`.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If, after intrinsics are replaced with their bodies, `ip.next`
      contains neither a `federated_aggregate` nor a `federated_secure_sum`
      call.
    transformations.CanonicalFormCompilationError: If the compilation
      process fails.
  """
  py_typecheck.check_type(ip, iterative_process.IterativeProcess)

  initialize_comp = building_blocks.ComputationBuildingBlock.from_proto(
      ip.initialize._computation_proto)  # pylint: disable=protected-access
  next_comp = building_blocks.ComputationBuildingBlock.from_proto(
      ip.next._computation_proto)  # pylint: disable=protected-access
  _check_iterative_process_compatible_with_canonical_form(
      initialize_comp, next_comp)

  initialize_comp = _replace_intrinsics_with_bodies(initialize_comp)
  next_comp = _replace_intrinsics_with_bodies(next_comp)
  tree_analysis.check_contains_only_reducible_intrinsics(initialize_comp)
  tree_analysis.check_contains_only_reducible_intrinsics(next_comp)
  tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

  # Split `next` around its broadcast, or synthesize an equivalent split when
  # there is no broadcast to split on.
  if tree_analysis.contains_called_intrinsic(
      next_comp, intrinsic_defs.FEDERATED_BROADCAST.uri):
    before_broadcast, after_broadcast = (
        transformations.force_align_and_split_by_intrinsics(
            next_comp, [intrinsic_defs.FEDERATED_BROADCAST.uri]))
  else:
    before_broadcast, after_broadcast = (
        _create_before_and_after_broadcast_for_no_broadcast(next_comp))

  contains_federated_aggregate = tree_analysis.contains_called_intrinsic(
      next_comp, intrinsic_defs.FEDERATED_AGGREGATE.uri)
  contains_federated_secure_sum = tree_analysis.contains_called_intrinsic(
      next_comp, intrinsic_defs.FEDERATED_SECURE_SUM.uri)
  if not (contains_federated_aggregate or contains_federated_secure_sum):
    raise ValueError(
        'Expected an `tff.templates.IterativeProcess` containing at least one '
        '`federated_aggregate` or `federated_secure_sum`, found none.')

  # Split the post-broadcast part around the aggregations. When one of the two
  # aggregation intrinsics is absent, a helper supplies a placeholder for it so
  # that `before_aggregate`/`after_aggregate` always have the same shape.
  if contains_federated_aggregate and contains_federated_secure_sum:
    before_aggregate, after_aggregate = (
        transformations.force_align_and_split_by_intrinsics(
            after_broadcast, [
                intrinsic_defs.FEDERATED_AGGREGATE.uri,
                intrinsic_defs.FEDERATED_SECURE_SUM.uri,
            ]))
  elif contains_federated_secure_sum:
    assert not contains_federated_aggregate
    before_aggregate, after_aggregate = (
        _create_before_and_after_aggregate_for_no_federated_aggregate(
            after_broadcast))
  else:
    assert contains_federated_aggregate and not contains_federated_secure_sum
    before_aggregate, after_aggregate = (
        _create_before_and_after_aggregate_for_no_federated_secure_sum(
            after_broadcast))

  type_info = _get_type_info(initialize_comp, before_broadcast, after_broadcast,
                             before_aggregate, after_aggregate)

  initialize = transformations.consolidate_and_extract_local_processing(
      initialize_comp)
  _check_type_equal(initialize.type_signature, type_info['initialize_type'])

  prepare = _extract_prepare(before_broadcast)
  _check_type_equal(prepare.type_signature, type_info['prepare_type'])

  work = _extract_work(before_aggregate)
  _check_type_equal(work.type_signature, type_info['work_type'])

  zero, accumulate, merge, report = _extract_federated_aggregate_functions(
      before_aggregate)
  _check_type_equal(zero.type_signature, type_info['zero_type'])
  _check_type_equal(accumulate.type_signature, type_info['accumulate_type'])
  _check_type_equal(merge.type_signature, type_info['merge_type'])
  _check_type_equal(report.type_signature, type_info['report_type'])

  bitwidth = _extract_federated_secure_sum_functions(before_aggregate)
  _check_type_equal(bitwidth.type_signature, type_info['bitwidth_type'])

  update = _extract_update(after_aggregate)
  _check_type_equal(update.type_signature, type_info['update_type'])

  return canonical_form.CanonicalForm(
      computation_wrapper_instances.building_block_to_computation(initialize),
      computation_wrapper_instances.building_block_to_computation(prepare),
      computation_wrapper_instances.building_block_to_computation(work),
      computation_wrapper_instances.building_block_to_computation(zero),
      computation_wrapper_instances.building_block_to_computation(accumulate),
      computation_wrapper_instances.building_block_to_computation(merge),
      computation_wrapper_instances.building_block_to_computation(report),
      computation_wrapper_instances.building_block_to_computation(bitwidth),
      computation_wrapper_instances.building_block_to_computation(update))
def _create_before_and_after_aggregate_for_no_federated_secure_sum(tree):
  r"""Creates before and after aggregate computations for the given `tree`.

  Lambda
  |
  Tuple
  |
  [Comp, Tuple]
         |
         [Call, Tuple]
          |      |
          |      []
          federated_value_at_clients(<>)

       Lambda(x)
       |
       Call
      /    \
  Comp      Tuple
            |
            [Sel(0),      Sel(0)]
            /            /
         Ref(x)    Sel(1)
                  /
            Ref(x)

  In the first AST, the first element returned by `Lambda`, `Comp`, is the
  result of the before aggregate returned by force aligning and splitting `tree`
  by `intrinsic_defs.FEDERATED_AGGREGATE.uri`. The second element returned by
  `Lambda` is the placeholder structure
  `<federated_value_at_clients(<>), <>>`, standing in for the (value,
  bitwidth) arguments to the secure sum intrinsic. Therefore, the first AST has
  a type signature satisfying the requirements of before aggregate.

  In the second AST, `Comp` is the after aggregate returned by force aligning
  and splitting `tree` by `intrinsic_defs.FEDERATED_AGGREGATE.uri`; `Lambda`
  has a type signature satisfying the requirements of after aggregate; and the
  argument passed to `Comp` is a selection from the parameter of `Lambda` which
  intentionally drops `s4` (the server-placed `<>` secure sum result slot) on
  the floor.

  This function is intended to be used by
  `get_canonical_form_for_iterative_process` to create before and after
  aggregate computations for the given `tree` when there is no
  `intrinsic_defs.FEDERATED_SECURE_SUM` in `tree`. As a result, this function
  does not assert that there is no `intrinsic_defs.FEDERATED_SECURE_SUM` in
  `tree` and it does not assert that `tree` has the expected structure; the
  caller is expected to perform these checks before calling this function.

  Args:
    tree: An instance of `building_blocks.ComputationBuildingBlock`.

  Returns:
    A pair of the form `(before, after)`, where each of `before` and `after`
    is a `tff_framework.ComputationBuildingBlock` that represents a part of the
    result as specified by
    `transformations.force_align_and_split_by_intrinsics`.
  """
  # The generator is seeded from `tree` (not from the split results) so the
  # name chosen for the new lambda parameter below is stable.
  name_generator = building_block_factory.unique_name_generator(tree)

  before_aggregate, after_aggregate = (
      transformations.force_align_and_split_by_intrinsics(
          tree, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))

  # Placeholder (value, bitwidth) arguments for the absent secure sum.
  empty_tuple = building_blocks.Struct([])
  value = building_block_factory.create_federated_value(empty_tuple,
                                                        placements.CLIENTS)
  bitwidth = empty_tuple
  args = building_blocks.Struct([value, bitwidth])
  result = building_blocks.Struct([before_aggregate.result, args])
  before_aggregate = building_blocks.Lambda(before_aggregate.parameter_name,
                                            before_aggregate.parameter_type,
                                            result)

  # The new `after` parameter is <original_arg, <aggregate_result, s4>>, where
  # `s4` is the server-placed `<>` slot for the secure sum result.
  ref_name = next(name_generator)
  s4_type = computation_types.FederatedType([], placements.SERVER)
  ref_type = computation_types.StructType([
      after_aggregate.parameter_type[0],
      computation_types.StructType([
          after_aggregate.parameter_type[1],
          s4_type,
      ]),
  ])
  ref = building_blocks.Reference(ref_name, ref_type)
  sel_arg = building_blocks.Selection(ref, index=0)
  sel = building_blocks.Selection(ref, index=1)
  sel_s3 = building_blocks.Selection(sel, index=0)
  # Forward only <original_arg, aggregate_result>; `s4` is dropped.
  arg = building_blocks.Struct([sel_arg, sel_s3])
  call = building_blocks.Call(after_aggregate, arg)
  after_aggregate = building_blocks.Lambda(ref.name, ref.type_signature, call)

  return before_aggregate, after_aggregate
Example #12
0
def get_canonical_form_for_iterative_process(iterative_process):
    """Builds a `tff.backends.mapreduce.CanonicalForm` from an iterative process.

    The `initialize` and `next` computations of `iterative_process` are
    decomposed into the pieces that make up a
    `tff.backends.mapreduce.CanonicalForm`. Each piece's type is checked
    against the expected type as soon as it is extracted.

    Args:
      iterative_process: An instance of `tff.utils.IterativeProcess`.

    Returns:
      An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
      process.

    Raises:
      TypeError: If the arguments are of the wrong types.
      transformations.CanonicalFormCompilationError: If the compilation
        process fails.
    """
    py_typecheck.check_type(iterative_process,
                            computation_utils.IterativeProcess)

    def _to_building_block(computation):
        return building_blocks.ComputationBuildingBlock.from_proto(
            computation._computation_proto)  # pylint: disable=protected-access

    initialize_comp = _to_building_block(iterative_process.initialize)
    next_comp = _to_building_block(iterative_process.next)
    _check_iterative_process_compatible_with_canonical_form(
        initialize_comp, next_comp)

    # A two-element `next` result is padded with a placeholder client output.
    if len(next_comp.type_signature.result) == 2:
        next_comp = _create_next_with_fake_client_output(next_comp)

    initialize_comp = _replace_intrinsics_with_bodies(initialize_comp)
    next_comp = _replace_intrinsics_with_bodies(next_comp)
    for reduced_comp in (initialize_comp, next_comp):
        tree_analysis.check_intrinsics_whitelisted_for_reduction(reduced_comp)
    tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

    has_broadcast = tree_analysis.contains_called_intrinsic(
        next_comp, intrinsic_defs.FEDERATED_BROADCAST.uri)
    if has_broadcast:
        before_broadcast, after_broadcast = (
            transformations.force_align_and_split_by_intrinsics(
                next_comp, [intrinsic_defs.FEDERATED_BROADCAST.uri]))
    else:
        before_broadcast, after_broadcast = (
            _create_before_and_after_broadcast_for_no_broadcast(next_comp))

    before_aggregate, after_aggregate = (
        transformations.force_align_and_split_by_intrinsics(
            after_broadcast, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))

    type_info = _get_type_info(initialize_comp, before_broadcast,
                               after_broadcast, before_aggregate,
                               after_aggregate)

    def _checked(extracted, type_key):
        # Verifies `extracted` against `type_info[type_key]` and passes it on.
        _check_type_equal(extracted.type_signature, type_info[type_key])
        return extracted

    initialize = _checked(
        transformations.consolidate_and_extract_local_processing(
            initialize_comp), 'initialize_type')
    prepare = _checked(_extract_prepare(before_broadcast), 'prepare_type')
    work = _checked(_extract_work(before_aggregate, after_aggregate),
                    'work_type')

    zero, accumulate, merge, report = _extract_aggregate_functions(
        before_aggregate)
    for aggregate_fn, type_key in ((zero, 'zero_type'),
                                   (accumulate, 'accumulate_type'),
                                   (merge, 'merge_type'),
                                   (report, 'report_type')):
        _check_type_equal(aggregate_fn.type_signature, type_info[type_key])

    update = _checked(_extract_update(after_aggregate), 'update_type')

    canonical_parts = (initialize, prepare, work, zero, accumulate, merge,
                       report, update)
    return canonical_form.CanonicalForm(*[
        computation_wrapper_instances.building_block_to_computation(part)
        for part in canonical_parts
    ])
def get_canonical_form_for_iterative_process(iterative_process):
    """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

    This function transforms computations from the input `iterative_process`
    into an instance of `tff.backends.mapreduce.CanonicalForm`.

    Args:
      iterative_process: An instance of `tff.utils.IterativeProcess`.

    Returns:
      An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
      process.

    Raises:
      TypeError: If the arguments are of the wrong types, including when
        `next` does not both take and return a `NamedTupleType`.
      transformations.CanonicalFormCompilationError: If the compilation
        process fails.
    """
    py_typecheck.check_type(iterative_process,
                            computation_utils.IterativeProcess)

    initialize_comp = building_blocks.ComputationBuildingBlock.from_proto(
        iterative_process.initialize._computation_proto)  # pylint: disable=protected-access

    next_comp = building_blocks.ComputationBuildingBlock.from_proto(
        iterative_process.next._computation_proto)  # pylint: disable=protected-access

    if not (isinstance(next_comp.type_signature.parameter,
                       computation_types.NamedTupleType)
            and isinstance(next_comp.type_signature.result,
                           computation_types.NamedTupleType)):
        raise TypeError(
            'Any IterativeProcess compatible with CanonicalForm must '
            'have a `next` function which takes and returns instances '
            'of `tff.NamedTupleType`; your next function takes '
            'parameters of type {} and returns results of type {}'.format(
                next_comp.type_signature.parameter,
                next_comp.type_signature.result))

    # A two-element `next` result is padded with a placeholder client output.
    if len(next_comp.type_signature.result) == 2:
        next_comp = _create_next_with_fake_client_output(next_comp)

    initialize_comp = replace_intrinsics_with_bodies(initialize_comp)
    next_comp = replace_intrinsics_with_bodies(next_comp)

    tree_analysis.check_intrinsics_whitelisted_for_reduction(initialize_comp)
    tree_analysis.check_intrinsics_whitelisted_for_reduction(next_comp)
    tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

    if tree_analysis.contains_called_intrinsic(
            next_comp, intrinsic_defs.FEDERATED_BROADCAST.uri):
        before_broadcast, after_broadcast = (
            transformations.force_align_and_split_by_intrinsics(
                next_comp, [intrinsic_defs.FEDERATED_BROADCAST.uri]))
    else:
        before_broadcast, after_broadcast = (
            _create_before_and_after_broadcast_for_no_broadcast(next_comp))

    before_aggregate, after_aggregate = (
        transformations.force_align_and_split_by_intrinsics(
            after_broadcast, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))

    # Each packing step checks one split's type signature against the
    # information accumulated by the previous steps.
    init_info_packed = pack_initialize_comp_type_signature(
        initialize_comp.type_signature)

    next_info_packed = pack_next_comp_type_signature(next_comp.type_signature,
                                                     init_info_packed)

    before_broadcast_info_packed = (
        check_and_pack_before_broadcast_type_signature(
            before_broadcast.type_signature, next_info_packed))

    before_aggregate_info_packed = (
        check_and_pack_before_aggregate_type_signature(
            before_aggregate.type_signature, before_broadcast_info_packed))

    canonical_form_types = check_and_pack_after_aggregate_type_signature(
        after_aggregate.type_signature, before_aggregate_info_packed)

    initialize = transformations.consolidate_and_extract_local_processing(
        initialize_comp)

    expected_initialize_type = canonical_form_types['initialize_type'].member
    if initialize.type_signature.result != expected_initialize_type:
        # Report the type actually compared against, not an unrelated one.
        raise transformations.CanonicalFormCompilationError(
            'Compilation of initialize has failed. Expected to extract a '
            '`building_blocks.CompiledComputation` of type {}, instead we extracted '
            'a {} of type {}.'.format(expected_initialize_type,
                                      type(initialize),
                                      initialize.type_signature.result))

    prepare = extract_prepare(before_broadcast, canonical_form_types)

    work = extract_work(before_aggregate, after_aggregate,
                        canonical_form_types)

    zero_noarg_function, accumulate, merge, report = extract_aggregate_functions(
        before_aggregate, canonical_form_types)

    update = extract_update(after_aggregate, canonical_form_types)

    cf = canonical_form.CanonicalForm(
        computation_wrapper_instances.building_block_to_computation(
            initialize),
        computation_wrapper_instances.building_block_to_computation(prepare),
        computation_wrapper_instances.building_block_to_computation(work),
        computation_wrapper_instances.building_block_to_computation(
            zero_noarg_function),
        computation_wrapper_instances.building_block_to_computation(
            accumulate),
        computation_wrapper_instances.building_block_to_computation(merge),
        computation_wrapper_instances.building_block_to_computation(report),
        computation_wrapper_instances.building_block_to_computation(update))
    return cf
def get_canonical_form_for_iterative_process(
    ip: iterative_process.IterativeProcess,
    grappler_config: Optional[
        tf.compat.v1.ConfigProto] = _GRAPPLER_DEFAULT_CONFIG):
    """Compiles an iterative process into a `tff.backends.mapreduce.CanonicalForm`.

    The `initialize` and `next` computations of `ip` are converted to building
    blocks. `next` is split around its broadcast and aggregation intrinsics.
    Each resulting local piece is extracted and checked against the type
    information derived from the split trees before being wrapped into a
    `CanonicalForm`.

    Args:
      ip: An instance of `tff.templates.IterativeProcess` that is compatible
        with canonical form. Iterative processes are only compatible if:

        * `initialize_fn` returns a single federated value placed at `SERVER`.
        * `next` takes exactly two arguments, the first of which must be the
          state value placed at `SERVER`.
        * `next` returns exactly two values.
      grappler_config: An optional instance of `tf.compat.v1.ConfigProto` to
        configure Grappler graph optimization of the TensorFlow graphs backing
        the resulting `tff.backends.mapreduce.CanonicalForm`. These options are
        combined with a set of defaults that aggressively configure Grappler.
        If `None`, Grappler is bypassed.

    Returns:
      An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
      process.

    Raises:
      TypeError: If the arguments are of the wrong types.
      ValueError: If `ip.next` contains neither a `federated_aggregate` nor a
        `federated_secure_sum`.
      transformations.CanonicalFormCompilationError: If the compilation
        process fails.
    """
    py_typecheck.check_type(ip, iterative_process.IterativeProcess)
    if grappler_config is not None:
        py_typecheck.check_type(grappler_config, tf.compat.v1.ConfigProto)
        # Start from the aggressive defaults and let the caller's options win.
        merged_config = tf.compat.v1.ConfigProto()
        merged_config.CopyFrom(_GRAPPLER_DEFAULT_CONFIG)
        merged_config.MergeFrom(grappler_config)
        grappler_config = merged_config

    # pylint: disable=protected-access
    init_bb = building_blocks.ComputationBuildingBlock.from_proto(
        ip.initialize._computation_proto)
    next_bb = building_blocks.ComputationBuildingBlock.from_proto(
        ip.next._computation_proto)
    # pylint: enable=protected-access
    _check_iterative_process_compatible_with_canonical_form(init_bb, next_bb)

    init_bb = _replace_intrinsics_with_bodies(init_bb)
    next_bb = _replace_lambda_body_with_call_dominant_form(
        _replace_intrinsics_with_bodies(next_bb))

    tree_analysis.check_contains_only_reducible_intrinsics(init_bb)
    tree_analysis.check_contains_only_reducible_intrinsics(next_bb)
    tree_analysis.check_broadcast_not_dependent_on_aggregate(next_bb)

    # Split `next` around its broadcasts, synthesizing an empty split when
    # there are none.
    broadcast_uri = intrinsic_defs.FEDERATED_BROADCAST.uri
    if not tree_analysis.contains_called_intrinsic(next_bb, broadcast_uri):
        before_broadcast, after_broadcast = (
            _create_before_and_after_broadcast_for_no_broadcast(next_bb))
    else:
        before_broadcast, after_broadcast = (
            transformations.force_align_and_split_by_intrinsics(
                next_bb, [broadcast_uri]))

    aggregate_uri = intrinsic_defs.FEDERATED_AGGREGATE.uri
    secure_sum_uri = intrinsic_defs.FEDERATED_SECURE_SUM.uri
    has_aggregate = tree_analysis.contains_called_intrinsic(
        next_bb, aggregate_uri)
    has_secure_sum = tree_analysis.contains_called_intrinsic(
        next_bb, secure_sum_uri)
    if not has_aggregate and not has_secure_sum:
        raise ValueError(
            'Expected an `tff.templates.IterativeProcess` containing at least one '
            '`federated_aggregate` or `federated_secure_sum`, found none.')

    # Split the post-broadcast part around the aggregation intrinsics; when only
    # one kind is present, a helper fills in the missing half of the split.
    if has_aggregate and has_secure_sum:
        before_aggregate, after_aggregate = (
            transformations.force_align_and_split_by_intrinsics(
                after_broadcast, [aggregate_uri, secure_sum_uri]))
    elif has_aggregate:
        before_aggregate, after_aggregate = (
            _create_before_and_after_aggregate_for_no_federated_secure_sum(
                after_broadcast))
    else:
        before_aggregate, after_aggregate = (
            _create_before_and_after_aggregate_for_no_federated_aggregate(
                after_broadcast))

    type_info = _get_type_info(init_bb, before_broadcast, after_broadcast,
                               before_aggregate, after_aggregate)

    def _expect(block, key):
        # Every extracted piece must agree with the type derived from the splits.
        _check_type_equal(block.type_signature, type_info[key])
        return block

    initialize = _expect(
        transformations.consolidate_and_extract_local_processing(
            init_bb, grappler_config), 'initialize_type')
    prepare = _expect(
        _extract_prepare(before_broadcast, grappler_config), 'prepare_type')
    work = _expect(
        _extract_work(before_aggregate, grappler_config), 'work_type')

    zero, accumulate, merge, report = _extract_federated_aggregate_functions(
        before_aggregate, grappler_config)
    for aggregate_piece, key in ((zero, 'zero_type'),
                                 (accumulate, 'accumulate_type'),
                                 (merge, 'merge_type'),
                                 (report, 'report_type')):
        _expect(aggregate_piece, key)

    bitwidth = _expect(
        _extract_federated_secure_sum_functions(before_aggregate,
                                                grappler_config),
        'bitwidth_type')
    update = _expect(
        _extract_update(after_aggregate, grappler_config), 'update_type')

    # The two parameter names of `next` become the CanonicalForm labels.
    label_iter = (
        label for label, _ in structure.iter_elements(
            ip.next.type_signature.parameter))
    server_state_label, client_data_label = label_iter

    to_computation = computation_wrapper_instances.building_block_to_computation
    return canonical_form.CanonicalForm(
        *(to_computation(piece) for piece in (initialize, prepare, work, zero,
                                              accumulate, merge, report,
                                              bitwidth, update)),
        server_state_label=server_state_label,
        client_data_label=client_data_label)
Exemple #15
0
    def test_returns_type_info(self):
        """Checks the type info computed for the sum example process.

        `next` is split around its broadcast and then around its
        `federated_aggregate`. The labelled types produced by
        `canonical_form_utils._get_type_info` are compared against a fixed
        table of compact type representations.
        """
        # NOTE(review): a later method at this same indentation is also named
        # `test_returns_type_info`; if both live in the same class, only the
        # later definition runs. Confirm and rename one of them.
        ip = get_iterative_process_for_sum_example()
        initialize_tree = building_blocks.ComputationBuildingBlock.from_proto(
            ip.initialize._computation_proto)
        next_tree = building_blocks.ComputationBuildingBlock.from_proto(
            ip.next._computation_proto)
        initialize_tree = canonical_form_utils._replace_intrinsics_with_bodies(
            initialize_tree)
        next_tree = canonical_form_utils._replace_intrinsics_with_bodies(
            next_tree)
        before_broadcast, after_broadcast = (
            mapreduce_transformations.force_align_and_split_by_intrinsics(
                next_tree, [intrinsic_defs.FEDERATED_BROADCAST.uri]))
        before_aggregate, after_aggregate = (
            mapreduce_transformations.force_align_and_split_by_intrinsics(
                after_broadcast, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))

        # `_get_type_info` takes the initialize tree plus the four split trees;
        # the unsplit `next_tree` is not part of its signature.
        type_info = canonical_form_utils._get_type_info(
            initialize_tree, before_broadcast, after_broadcast,
            before_aggregate, after_aggregate)

        actual = {
            label: type_signature.compact_representation()
            for label, type_signature in type_info.items()
        }
        # pyformat: disable
        expected = {
            'accumulate_type':
            '(<<int32,int32,int32,int32,int32,int32>,<int32,int32,int32,int32,int32,int32>> -> <int32,int32,int32,int32,int32,int32>)',
            'c1_type': '{int32}@CLIENTS',
            'c2_type':
            '<<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>@CLIENTS',
            'c3_type':
            '{<int32,<<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>>}@CLIENTS',
            'c4_type': '{<<int32,int32,int32,int32,int32,int32>,<>>}@CLIENTS',
            'c5_type': '{<int32,int32,int32,int32,int32,int32>}@CLIENTS',
            'c6_type': '{<>}@CLIENTS',
            'initialize_type': '( -> <int32,int32>)',
            'merge_type':
            '(<<int32,int32,int32,int32,int32,int32>,<int32,int32,int32,int32,int32,int32>> -> <int32,int32,int32,int32,int32,int32>)',
            'prepare_type':
            '(<int32,int32> -> <<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>)',
            'report_type':
            '(<int32,int32,int32,int32,int32,int32> -> <int32,int32,int32,int32,int32,int32>)',
            's1_type': '<int32,int32>@SERVER',
            's2_type':
            '<<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>@SERVER',
            's3_type': '<int32,int32,int32,int32,int32,int32>@SERVER',
            's4_type':
            '<<int32,int32>,<int32,int32,int32,int32,int32,int32>>@SERVER',
            's5_type': '<<int32,int32>,<>>@SERVER',
            's6_type': '<int32,int32>@SERVER',
            's7_type': '<>@SERVER',
            'update_type':
            '(<<int32,int32>,<int32,int32,int32,int32,int32,int32>> -> <<int32,int32>,<>>)',
            'work_type':
            '(<int32,<<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>> -> <<int32,int32,int32,int32,int32,int32>,<>>)',
            'zero_type': '( -> <int32,int32,int32,int32,int32,int32>)'
        }
        # pyformat: enable

        self.assertEqual(actual, expected)
    def test_returns_type_info(self):
        """Pins the ordered type info `_get_type_info` yields for the sum example.

        `next` is split around its broadcast and then around its
        `federated_aggregate`. The resulting labelled types, in order, are
        compared against a fixed table of compact representations.
        """
        process = get_iterative_process_for_sum_example()
        init_tree = building_blocks.ComputationBuildingBlock.from_proto(
            process.initialize._computation_proto)
        step_tree = building_blocks.ComputationBuildingBlock.from_proto(
            process.next._computation_proto)
        init_tree = canonical_form_utils._replace_intrinsics_with_bodies(
            init_tree)
        step_tree = canonical_form_utils._replace_intrinsics_with_bodies(
            step_tree)

        split_at = mapreduce_transformations.force_align_and_split_by_intrinsics
        before_broadcast, after_broadcast = split_at(
            step_tree, [intrinsic_defs.FEDERATED_BROADCAST.uri])
        before_aggregate, after_aggregate = split_at(
            after_broadcast, [intrinsic_defs.FEDERATED_AGGREGATE.uri])

        type_info = canonical_form_utils._get_type_info(
            init_tree, before_broadcast, after_broadcast, before_aggregate,
            after_aggregate)

        actual = collections.OrderedDict(
            (label, signature.compact_representation())
            for label, signature in type_info.items())
        # Note: THE CONTENTS OF THIS DICTIONARY ARE NOT IMPORTANT. The purpose
        # of this test is not to assert that the value returned by
        # `canonical_form_utils._get_type_info` is correct, but to act as a
        # signal when refactoring the code involved in compiling a
        # `tff.templates.IterativeProcess` into a
        # `tff.backends.mapreduce.CanonicalForm`.
        # pyformat: disable
        expected = collections.OrderedDict(
            initialize_type='( -> <int32,int32>)',
            s1_type='<int32,int32>@SERVER',
            c1_type='{int32}@CLIENTS',
            s2_type=
            '<<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>@SERVER',
            prepare_type=
            '(<int32,int32> -> <<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>)',
            c2_type=
            '<<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>@CLIENTS',
            c3_type=
            '{<int32,<<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>>}@CLIENTS',
            c4_type='{<<int32,int32,int32,int32,int32,int32>,<>>}@CLIENTS',
            work_type=
            '(<int32,<<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>> -> <<int32,int32,int32,int32,int32,int32>,<>>)',
            c5_type='{<int32,int32,int32,int32,int32,int32>}@CLIENTS',
            c6_type='{<>}@CLIENTS',
            zero_type='( -> <int32,int32,int32,int32,int32,int32>)',
            accumulate_type=
            '(<<int32,int32,int32,int32,int32,int32>,<int32,int32,int32,int32,int32,int32>> -> <int32,int32,int32,int32,int32,int32>)',
            merge_type=
            '(<<int32,int32,int32,int32,int32,int32>,<int32,int32,int32,int32,int32,int32>> -> <int32,int32,int32,int32,int32,int32>)',
            report_type=
            '(<int32,int32,int32,int32,int32,int32> -> <int32,int32,int32,int32,int32,int32>)',
            s3_type='<int32,int32,int32,int32,int32,int32>@SERVER',
            s4_type=
            '<<int32,int32>,<int32,int32,int32,int32,int32,int32>>@SERVER',
            s5_type='<<int32,int32>,<>>@SERVER',
            update_type=
            '(<<int32,int32>,<int32,int32,int32,int32,int32,int32>> -> <<int32,int32>,<>>)',
            s6_type='<int32,int32>@SERVER',
            s7_type='<>@SERVER',
        )
        # pyformat: enable

        # Both sides are OrderedDicts, so label order is compared as well.
        self.assertEqual(actual, expected)
  def test_returns_type_info_for_sum_example(self):
    """Pins the ordered type info `_get_type_info` yields for the sum example.

    `next` is split around its broadcast and then around both
    `federated_aggregate` and `federated_secure_sum`. Each labelled type is
    compared in order against a fixed table, with a per-label failure message.
    """
    ip = get_iterative_process_for_sum_example()
    initialize_tree = building_blocks.ComputationBuildingBlock.from_proto(
        ip.initialize._computation_proto)
    next_tree = building_blocks.ComputationBuildingBlock.from_proto(
        ip.next._computation_proto)
    initialize_tree = canonical_form_utils._replace_intrinsics_with_bodies(
        initialize_tree)
    next_tree = canonical_form_utils._replace_intrinsics_with_bodies(next_tree)
    before_broadcast, after_broadcast = (
        transformations.force_align_and_split_by_intrinsics(
            next_tree, [intrinsic_defs.FEDERATED_BROADCAST.uri]))
    before_aggregate, after_aggregate = (
        transformations.force_align_and_split_by_intrinsics(
            after_broadcast, [
                intrinsic_defs.FEDERATED_AGGREGATE.uri,
                intrinsic_defs.FEDERATED_SECURE_SUM.uri,
            ]))

    type_info = canonical_form_utils._get_type_info(initialize_tree,
                                                    before_broadcast,
                                                    after_broadcast,
                                                    before_aggregate,
                                                    after_aggregate)

    actual = collections.OrderedDict([
        (label, type_signature.compact_representation())
        for label, type_signature in type_info.items()
    ])
    # Note: THE CONTENTS OF THIS DICTIONARY ARE NOT IMPORTANT. The purpose of
    # this test is not to assert that the value returned by
    # `canonical_form_utils._get_type_info` is correct, but to act as a signal
    # when refactoring the code involved in compiling a
    # `tff.templates.IterativeProcess` into a
    # `tff.backends.mapreduce.CanonicalForm`. If you are sure this needs to be
    # updated, one recommendation is to print 'k=\'v\',' while iterating over
    # the k-v pairs of the ordereddict.
    # pyformat: disable
    expected = collections.OrderedDict(
        initialize_type='( -> <int32,int32>)',
        s1_type='<int32,int32>@SERVER',
        c1_type='{int32}@CLIENTS',
        prepare_type='(<int32,int32> -> <<int32,int32>>)',
        s2_type='<<int32,int32>>@SERVER',
        c2_type='<<int32,int32>>@CLIENTS',
        c3_type='{<int32,<<int32,int32>>>}@CLIENTS',
        work_type='(<int32,<<int32,int32>>> -> <<<int32>,<int32>>,<>>)',
        c4_type='{<<<int32>,<int32>>,<>>}@CLIENTS',
        c5_type='{<<int32>,<int32>>}@CLIENTS',
        c6_type='{<int32>}@CLIENTS',
        c7_type='{<int32>}@CLIENTS',
        c8_type='{<>}@CLIENTS',
        zero_type='( -> <int32>)',
        accumulate_type='(<<int32>,<int32>> -> <int32>)',
        merge_type='(<<int32>,<int32>> -> <int32>)',
        report_type='(<int32> -> <int32>)',
        s3_type='<int32>@SERVER',
        bitwidth_type='( -> <int32>)',
        s4_type='<int32>@SERVER',
        s5_type='<<int32>,<int32>>@SERVER',
        s6_type='<<int32,int32>,<<int32>,<int32>>>@SERVER',
        update_type='(<<int32,int32>,<<int32>,<int32>>> -> <<int32,int32>,<>>)',
        s7_type='<<int32,int32>,<>>@SERVER',
        s8_type='<int32,int32>@SERVER',
        s9_type='<>@SERVER',
    )
    # pyformat: enable

    # Compare pairwise first so a mismatching value names its label.
    items = zip(actual.items(), expected.items())
    for (actual_key, actual_value), (expected_key, expected_value) in items:
      self.assertEqual(actual_key, expected_key)
      self.assertEqual(
          actual_value, expected_value,
          'The value of \'{}\' is not equal to the expected value'.format(
              actual_key))
    # `zip` stops at the shorter mapping, so labels missing from (or extra in)
    # the tail of `actual` would otherwise go unnoticed.
    self.assertEqual(
        list(actual.keys()), list(expected.keys()),
        'The labels returned by `_get_type_info` differ from the expected '
        'labels')
Exemple #18
0
    def test_returns_tree(self):
        ip = get_iterative_process_for_sum_example_with_no_federated_secure_sum(
        )
        next_tree = building_blocks.ComputationBuildingBlock.from_proto(
            ip.next._computation_proto)
        next_tree = canonical_form_utils._replace_intrinsics_with_bodies(
            next_tree)

        before_aggregate, after_aggregate = canonical_form_utils._create_before_and_after_aggregate_for_no_federated_secure_sum(
            next_tree)

        before_federated_aggregate, after_federated_aggregate = (
            transformations.force_align_and_split_by_intrinsics(
                next_tree, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))
        self.assertIsInstance(before_aggregate, building_blocks.Lambda)
        self.assertIsInstance(before_aggregate.result, building_blocks.Struct)
        self.assertLen(before_aggregate.result, 2)

        # trees_equal will fail if computations refer to unbound references, so we
        # create a new dummy computation to bind them.
        unbound_refs_in_before_agg_result = transformation_utils.get_map_of_unbound_references(
            before_aggregate.result[0])[before_aggregate.result[0]]
        unbound_refs_in_before_fed_agg_result = transformation_utils.get_map_of_unbound_references(
            before_federated_aggregate.result)[
                before_federated_aggregate.result]

        dummy_data = building_blocks.Data('data',
                                          computation_types.AbstractType('T'))

        blk_binding_refs_in_before_agg = building_blocks.Block(
            [(name, dummy_data) for name in unbound_refs_in_before_agg_result],
            before_aggregate.result[0])
        blk_binding_refs_in_before_fed_agg = building_blocks.Block(
            [(name, dummy_data)
             for name in unbound_refs_in_before_fed_agg_result],
            before_federated_aggregate.result)

        self.assertTrue(
            tree_analysis.trees_equal(blk_binding_refs_in_before_agg,
                                      blk_binding_refs_in_before_fed_agg))

        # pyformat: disable
        self.assertEqual(
            before_aggregate.result[1].formatted_representation(), '<\n'
            '  federated_value_at_clients(<>),\n'
            '  <>\n'
            '>')
        # pyformat: enable

        self.assertIsInstance(after_aggregate, building_blocks.Lambda)
        self.assertIsInstance(after_aggregate.result, building_blocks.Call)

        self.assertTrue(
            tree_analysis.trees_equal(after_aggregate.result.function,
                                      after_federated_aggregate))

        # pyformat: disable
        self.assertEqual(
            after_aggregate.result.argument.formatted_representation(), '<\n'
            '  _var1[0],\n'
            '  _var1[1][0]\n'
            '>')