Example #1
 def test_removes_federated_types_under_function(self):
     int_type = tf.int32
     server_int_type = computation_types.at_server(int_type)
     int_ref = building_blocks.Reference('x', int_type)
     int_id = building_blocks.Lambda('x', int_type, int_ref)
     fed_ref = building_blocks.Reference('x', server_int_type)
     applied_id = building_block_factory.create_federated_map_or_apply(
         int_id, fed_ref)
     before = building_block_factory.create_federated_map_or_apply(
         int_id, applied_id)
     after, modified = tree_transformations.strip_placement(before)
     self.assertTrue(modified)
     self.assert_has_no_intrinsics_nor_federated_types(after)
Example #2
 def test_strip_placement_removes_federated_maps(self):
     int_type = computation_types.TensorType(tf.int32)
     clients_int_type = computation_types.at_clients(int_type)
     int_ref = building_blocks.Reference('x', int_type)
     int_id = building_blocks.Lambda('x', int_type, int_ref)
     fed_ref = building_blocks.Reference('x', clients_int_type)
     applied_id = building_block_factory.create_federated_map_or_apply(
         int_id, fed_ref)
     before = building_block_factory.create_federated_map_or_apply(
         int_id, applied_id)
     after, modified = tree_transformations.strip_placement(before)
     self.assertTrue(modified)
     self.assert_has_no_intrinsics_nor_federated_types(after)
     type_test_utils.assert_types_identical(before.type_signature,
                                            clients_int_type)
     type_test_utils.assert_types_identical(after.type_signature, int_type)
     self.assertEqual(
         before.compact_representation(),
         'federated_map(<(x -> x),federated_map(<(x -> x),x>)>)')
     self.assertEqual(after.compact_representation(),
                      '(x -> x)((x -> x)(x))')
Example #3
 def test_reduces_federated_apply_to_equivalent_function(self):
   lam = building_blocks.Lambda('x', tf.int32,
                                building_blocks.Reference('x', tf.int32))
   arg = building_blocks.Reference(
       'arg', computation_types.FederatedType(tf.int32, placements.CLIENTS))
   mapped_fn = building_block_factory.create_federated_map_or_apply(lam, arg)
   extracted_tf = mapreduce_transformations.consolidate_and_extract_local_processing(
       mapped_fn)
   self.assertIsInstance(extracted_tf, building_blocks.CompiledComputation)
   executable_tf = computation_wrapper_instances.building_block_to_computation(
       extracted_tf)
   executable_lam = computation_wrapper_instances.building_block_to_computation(
       lam)
   for k in range(10):
     self.assertEqual(executable_tf(k), executable_lam(k))
Example #4
def _construct_selection_from_federated_tuple(
        federated_tuple: building_blocks.ComputationBuildingBlock, index: int,
        name_generator) -> building_blocks.ComputationBuildingBlock:
    """Selects the index `selected_index` from `federated_tuple`."""
    federated_tuple.type_signature.check_federated()
    member_type = federated_tuple.type_signature.member
    member_type.check_struct()
    param_name = next(name_generator)
    selecting_function = building_blocks.Lambda(
        param_name, member_type,
        building_blocks.Selection(
            building_blocks.Reference(param_name, member_type),
            index=index,
        ))
    return building_block_factory.create_federated_map_or_apply(
        selecting_function, federated_tuple)
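
A minimal usage sketch for the helper above, assuming `_construct_selection_from_federated_tuple` is in scope and that the TFF-internal modules are importable under the paths shown (the import paths are assumptions and have moved between TFF releases); the example type and names are hypothetical:

import tensorflow as tf
from tensorflow_federated.python.core.impl.compiler import building_block_factory
from tensorflow_federated.python.core.impl.compiler import building_blocks
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements

# A server-placed two-element struct to select from.
member_type = computation_types.StructType([tf.int32, tf.float32])
server_struct_type = computation_types.FederatedType(member_type, placements.SERVER)
federated_tuple = building_blocks.Reference('state', server_struct_type)
name_generator = building_block_factory.unique_name_generator(federated_tuple)

# Builds roughly `federated_apply(<(p -> p[1]), state>)`.
selected = _construct_selection_from_federated_tuple(
    federated_tuple, index=1, name_generator=name_generator)
print(selected.type_signature)  # expected: float32@SERVER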
Example #5
 def test_reduces_federated_apply_to_equivalent_function(self):
     lam = building_blocks.Lambda('x', tf.int32,
                                  building_blocks.Reference('x', tf.int32))
     arg_type = computation_types.FederatedType(tf.int32,
                                                placements.CLIENTS)
     arg = building_blocks.Reference('arg', arg_type)
     map_block = building_block_factory.create_federated_map_or_apply(
         lam, arg)
     mapping_fn = building_blocks.Lambda('arg', arg_type, map_block)
     extracted_tf = compiler.consolidate_and_extract_local_processing(
         mapping_fn, DEFAULT_GRAPPLER_CONFIG)
     self.assertIsInstance(extracted_tf,
                           building_blocks.CompiledComputation)
     executable_tf = computation_impl.ConcreteComputation.from_building_block(
         extracted_tf)
     executable_lam = computation_impl.ConcreteComputation.from_building_block(
         lam)
     for k in range(10):
         self.assertEqual(executable_tf(k), executable_lam(k))
Example #6
def _extract_update(after_aggregate):
  """Extracts `update` from `after_aggregate`.

  This function is intended to be used by
  `get_canonical_form_for_iterative_process` only. As a result, this function
  does not assert that `after_aggregate` has the expected structure; the
  caller is expected to perform these checks before calling this function.

  Args:
    after_aggregate: The second result of splitting `after_broadcast` on
      aggregate intrinsics.

  Returns:
    `update` as specified by `canonical_form.CanonicalForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: If we extract an AST of the
      wrong type.
  """
  s7_elements_in_after_aggregate_result = [0, 1]
  s7_output_extracted = transformations.select_output_from_lambda(
      after_aggregate, s7_elements_in_after_aggregate_result)
  s7_output_zipped = building_blocks.Lambda(
      s7_output_extracted.parameter_name, s7_output_extracted.parameter_type,
      building_block_factory.create_federated_zip(s7_output_extracted.result))
  s6_elements_in_after_aggregate_parameter = [[0, 0, 0], [1, 0], [1, 1]]
  s6_to_s7_computation = (
      transformations.zip_selection_as_argument_to_lower_level_lambda(
          s7_output_zipped,
          s6_elements_in_after_aggregate_parameter).result.function)

  # TODO(b/148942011): The transformation
  # `zip_selection_as_argument_to_lower_level_lambda` does not support selecting
  # from nested structures; therefore we need to pack the type signature
  # `<s1, s3, s4>` as `<s1, <s3, s4>>`.
  name_generator = building_block_factory.unique_name_generator(
      s6_to_s7_computation)

  pack_ref_name = next(name_generator)
  pack_ref_type = computation_types.StructType([
      s6_to_s7_computation.parameter_type.member[0],
      computation_types.StructType([
          s6_to_s7_computation.parameter_type.member[1],
          s6_to_s7_computation.parameter_type.member[2],
      ]),
  ])
  pack_ref = building_blocks.Reference(pack_ref_name, pack_ref_type)
  sel_s1 = building_blocks.Selection(pack_ref, index=0)
  sel = building_blocks.Selection(pack_ref, index=1)
  sel_s3 = building_blocks.Selection(sel, index=0)
  sel_s4 = building_blocks.Selection(sel, index=1)
  result = building_blocks.Struct([sel_s1, sel_s3, sel_s4])
  pack_fn = building_blocks.Lambda(pack_ref.name, pack_ref.type_signature,
                                   result)
  ref_name = next(name_generator)
  ref_type = computation_types.FederatedType(pack_ref_type, placements.SERVER)
  ref = building_blocks.Reference(ref_name, ref_type)
  unpacked_args = building_block_factory.create_federated_map_or_apply(
      pack_fn, ref)
  call = building_blocks.Call(s6_to_s7_computation, unpacked_args)
  fn = building_blocks.Lambda(ref.name, ref.type_signature, call)
  return transformations.consolidate_and_extract_local_processing(fn)
Example #7
def _extract_update(after_aggregate, grappler_config):
    """Extracts `update` from `after_aggregate`.

  This function is intended to be used by
  `get_map_reduce_form_for_iterative_process` only. As a result, this function
  does not assert that `after_aggregate` has the expected structure; the
  caller is expected to perform these checks before calling this function.

  Args:
    after_aggregate: The second result of splitting `after_broadcast` on
      aggregate intrinsics.
    grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
      Grappler graph optimization.

  Returns:
    `update` as specified by `forms.MapReduceForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    compiler.MapReduceFormCompilationError: If we extract an AST of the wrong
      type.
  """
    after_aggregate_zipped = building_blocks.Lambda(
        after_aggregate.parameter_name, after_aggregate.parameter_type,
        building_block_factory.create_federated_zip(after_aggregate.result))
    # `create_federated_zip` doesn't have unique reference names, but we need
    # them for `as_function_of_some_federated_subparameters`.
    after_aggregate_zipped, _ = tree_transformations.uniquify_reference_names(
        after_aggregate_zipped)
    server_state_index = ('original_arg', 'original_arg', 0)
    aggregate_result_index = ('intrinsic_results',
                              'federated_aggregate_result')
    secure_sum_bitwidth_result_index = ('intrinsic_results',
                                        'federated_secure_sum_bitwidth_result')
    secure_sum_result_index = ('intrinsic_results',
                               'federated_secure_sum_result')
    secure_modular_sum_result_index = ('intrinsic_results',
                                       'federated_secure_modular_sum_result')
    update_with_flat_inputs = _as_function_of_some_federated_subparameters(
        after_aggregate_zipped, (
            server_state_index,
            aggregate_result_index,
            secure_sum_bitwidth_result_index,
            secure_sum_result_index,
            secure_modular_sum_result_index,
        ))

    # TODO(b/148942011): The transformation
    # `zip_selection_as_argument_to_lower_level_lambda` does not support selecting
    # from nested structures; therefore we need to transform the input from
    # <server_state, <aggregation_results...>> into
    # <server_state, aggregation_results...>
    # unpack = <v, <...>> -> <v, ...>
    name_generator = building_block_factory.unique_name_generator(
        update_with_flat_inputs)
    unpack_param_name = next(name_generator)
    original_param_type = update_with_flat_inputs.parameter_type.member
    unpack_param_type = computation_types.StructType([
        original_param_type[0],
        computation_types.StructType(original_param_type[1:]),
    ])
    unpack_param_ref = building_blocks.Reference(unpack_param_name,
                                                 unpack_param_type)
    select = lambda bb, i: building_blocks.Selection(bb, index=i)
    unpack = building_blocks.Lambda(
        unpack_param_name, unpack_param_type,
        building_blocks.Struct([select(unpack_param_ref, 0)] + [
            select(select(unpack_param_ref, 1), i)
            for i in range(len(original_param_type) - 1)
        ]))

    # update = v -> update_with_flat_inputs(federated_map(unpack, v))
    param_name = next(name_generator)
    param_type = computation_types.at_server(unpack_param_type)
    param_ref = building_blocks.Reference(param_name, param_type)
    update = building_blocks.Lambda(
        param_name, param_type,
        building_blocks.Call(
            update_with_flat_inputs,
            building_block_factory.create_federated_map_or_apply(
                unpack, param_ref)))
    return compiler.consolidate_and_extract_local_processing(
        update, grappler_config)
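
The `unpack = <v, <...>> -> <v, ...>` comment above describes a purely local struct-flattening lambda. As a standalone sketch on a concrete, hypothetical type (again assuming TFF-internal import paths, which vary across releases), it looks roughly like this:

import tensorflow as tf
from tensorflow_federated.python.core.impl.compiler import building_blocks
from tensorflow_federated.python.core.impl.types import computation_types

# Hypothetical packed parameter type: <int32,<float32,bool>>.
packed_type = computation_types.StructType([
    tf.int32,
    computation_types.StructType([tf.float32, tf.bool]),
])
ref = building_blocks.Reference('v', packed_type)
select = lambda bb, i: building_blocks.Selection(bb, index=i)
# unpack = (v -> <v[0],v[1][0],v[1][1]>), flattening the nested struct.
unpack = building_blocks.Lambda(
    'v', packed_type,
    building_blocks.Struct([select(ref, 0)] +
                           [select(select(ref, 1), i) for i in range(2)]))
print(unpack.type_signature)  # roughly (<int32,<float32,bool>> -> <int32,float32,bool>)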
Example #8
def force_align_and_split_by_intrinsics(
    comp: building_blocks.Lambda,
    intrinsic_defaults: List[building_blocks.Call],
) -> Tuple[building_blocks.Lambda, building_blocks.Lambda]:
    """Divides `comp` into before-and-after of calls to one ore more intrinsics.

  The input computation `comp` must have the following properties:

  1. The computation `comp` is completely self-contained, i.e., there are no
     references to arguments introduced in a scope external to `comp`.

  2. `comp`'s return value must not contain uncalled lambdas.

  3. None of the calls to intrinsics in `intrinsic_defaults` may be
     within a lambda passed to another external function (intrinsic or
     compiled computation).

  4. No argument passed to an intrinsic in `intrinsic_defaults` may be
     dependent on the result of a call to an intrinsic in
     `intrinsic_defaults`.

  5. All intrinsics in `intrinsic_defaults` must have "merge-able" arguments.
     Structs will be merged element-wise, federated values will be zipped, and
     functions will be composed:
       `f = lambda f1_arg, f2_arg: (f1(f1_arg), f2(f2_arg))`

  6. All intrinsics in `intrinsic_defaults` must return a single federated value
     whose member is the merged result of any merged calls, i.e.:
       `f(merged_arg).member = (f1(f1_arg).member, f2(f2_arg).member)`

  Under these conditions, (and assuming `comp` is a computation with non-`None`
  argument), this function will return two `building_blocks.Lambda`s `before`
  and `after` such that `comp` is semantically equivalent to the following
  expression*:

  ```
  (arg -> (let
    x=before(arg),
    y=intrinsic1(x[0]),
    z=intrinsic2(x[1]),
    ...
   in after(<arg, <y,z,...>>)))
  ```

  If `comp` is a no-arg computation, the returned computations will be
  equivalent (in the same sense as above) to:
  ```
  ( -> (let
    x=before(),
    y=intrinsic1(x[0]),
    z=intrinsic2(x[1]),
    ...
   in after(<y,z,...>)))
  ```

  *Note that these expressions may not be entirely equivalent under
  nondeterminism: there is no way in this case to handle computations in which
  `before` creates a random variable that is then used in `after`, since the
  only way for state to pass from `before` to `after` is for it to travel
  through one of the intrinsics.

  In this expression, there is only a single call to each intrinsic, which
  results from consolidating all occurrences of that intrinsic in the original
  `comp`. All logic in `comp` that produced inputs to any of these intrinsic
  calls is now
  consolidated and jointly encapsulated in `before`, which produces a combined
  argument to all the original calls. All the remaining logic in `comp`,
  including that which consumed the outputs of the intrinsic calls, must have
  been encapsulated into `after`.

  If the original computation `comp` had type `(T -> U)`, then `before` and
  `after` would be `(T -> X)` and `(<T,Y> -> U)`, respectively, where `X` is
  the type of the argument to the single combined intrinsic call above. Note
  that `after` takes the output of the call to the intrinsic as well as the
  original argument to `comp`, as it may be dependent on both.

  Args:
    comp: The instance of `building_blocks.Lambda` that serves as the input to
      this transformation, as described above.
    intrinsic_defaults: A list of intrinsics with which to split the
      computation, provided as a list of `Call`s to insert if no intrinsic with
      a matching URI is found. Intrinsics in this list will be merged, and
      `comp` will be split across them.

  Returns:
    A pair of the form `(before, after)`, where each of `before` and `after`
    is a `building_blocks.ComputationBuildingBlock` instance that represents a
    part of the result as specified above.
  """
    py_typecheck.check_type(comp, building_blocks.Lambda)
    py_typecheck.check_type(intrinsic_defaults, list)
    comp_repr = comp.compact_representation()

    # Flatten `comp` to call-dominant form so that we're working with just a
    # linear list of intrinsic calls with no indirection via tupling, selection,
    # blocks, called lambdas, or references.
    comp = to_call_dominant(comp)

    # CDF can potentially return blocks if there are variables not dependent on
    # the top-level parameter. We normalize these away.
    if not comp.is_lambda():
        comp.check_block()
        comp.result.check_lambda()
        if comp.result.result.is_block():
            additional_locals = comp.result.result.locals
            result = comp.result.result.result
        else:
            additional_locals = []
            result = comp.result.result
        # Note: without uniqueness, a local in `comp.locals` could potentially
        # shadow `comp.result.parameter_name`. However, `to_call_dominant`
        # above ensures that names are unique, as it ends in a call to
        # `uniquify_reference_names`.
        comp = building_blocks.Lambda(
            comp.result.parameter_name, comp.result.parameter_type,
            building_blocks.Block(comp.locals + additional_locals, result))
    comp.check_lambda()

    # Simple computations with no intrinsic calls won't have a block.
    # Normalize these as well.
    if not comp.result.is_block():
        comp = building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                      building_blocks.Block([], comp.result))
    comp.result.check_block()

    name_generator = building_block_factory.unique_name_generator(comp)

    intrinsic_uris = set(call.function.uri for call in intrinsic_defaults)
    deps = _compute_intrinsic_dependencies(intrinsic_uris, comp.parameter_name,
                                           comp.result.locals, comp_repr)
    merged_intrinsics = _compute_merged_intrinsics(intrinsic_defaults,
                                                   deps.uri_to_locals,
                                                   name_generator)

    # Note: the outputs are labeled as `{uri}_param` for convenience, e.g.
    # `federated_secure_sum_param: ...`.
    before = building_blocks.Lambda(
        comp.parameter_name, comp.parameter_type,
        building_blocks.Block(
            deps.locals_not_dependent_on_intrinsics,
            building_blocks.Struct([(f'{merged.uri}_param', merged.args)
                                    for merged in merged_intrinsics])))

    after_param_name = next(name_generator)
    if comp.parameter_type is not None:
        # TODO(b/147499373): If None-arguments were uniformly represented as empty
        # tuples, we would be able to avoid this (and related) ugly casing.
        after_param_type = computation_types.StructType([
            ('original_arg', comp.parameter_type),
            ('intrinsic_results',
             computation_types.StructType([(f'{merged.uri}_result',
                                            merged.return_type)
                                           for merged in merged_intrinsics])),
        ])
    else:
        after_param_type = computation_types.StructType([
            ('intrinsic_results',
             computation_types.StructType([(f'{merged.uri}_result',
                                            merged.return_type)
                                           for merged in merged_intrinsics])),
        ])
    after_param_ref = building_blocks.Reference(after_param_name,
                                                after_param_type)
    if comp.parameter_type is not None:
        original_arg_bindings = [
            (comp.parameter_name,
             building_blocks.Selection(after_param_ref, name='original_arg'))
        ]
    else:
        original_arg_bindings = []

    unzip_bindings = []
    for merged in merged_intrinsics:
        if merged.unpack_to_locals:
            intrinsic_result = building_blocks.Selection(
                building_blocks.Selection(after_param_ref,
                                          name='intrinsic_results'),
                name=f'{merged.uri}_result')
            select_param_type = intrinsic_result.type_signature.member
            for i, binding_name in enumerate(merged.unpack_to_locals):
                select_param_name = next(name_generator)
                select_param_ref = building_blocks.Reference(
                    select_param_name, select_param_type)
                selected = building_block_factory.create_federated_map_or_apply(
                    building_blocks.Lambda(
                        select_param_name, select_param_type,
                        building_blocks.Selection(select_param_ref, index=i)),
                    intrinsic_result)
                unzip_bindings.append((binding_name, selected))
    after = building_blocks.Lambda(
        after_param_name,
        after_param_type,
        building_blocks.Block(
            original_arg_bindings +
            # Note that we must duplicate `locals_not_dependent_on_intrinsics`
            # across both the `before` and `after` computations since both can
            # rely on them, and there's no way to plumb results from `before`
            # through to `after` except via one of the intrinsics being split
            # upon. In MapReduceForm, this limitation is caused by the fact that
            # `prepare` has no output which serves as an input to `report`.
            deps.locals_not_dependent_on_intrinsics + unzip_bindings +
            deps.locals_dependent_on_intrinsics,
            comp.result.result))
    try:
        tree_analysis.check_has_unique_names(before)
        tree_analysis.check_has_unique_names(after)
    except tree_analysis.NonuniqueNameError as e:
        raise ValueError(
            f'nonunique names in result of splitting\n{comp}') from e
    return before, after
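
To make the `let`-expression in the docstring above concrete, here is a plain-Python illustration of how `before`, the merged intrinsic calls, and `after` recombine. All of the functions below are hypothetical stand-ins for the corresponding TFF computations, not building blocks:

# Ordinary Python functions playing the roles described in the docstring above.
def intrinsic1(x):
  return x ** 2

def intrinsic2(x):
  return -2 * x

def comp(arg):
  # The original, unsplit computation.
  return arg + intrinsic1(arg + 1) + intrinsic2(arg * 3)

def before(arg):
  # Produces one argument per intrinsic being split upon.
  return (arg + 1, arg * 3)

def after(packed):
  # Receives `<original_arg, <intrinsic_results...>>`.
  original_arg, (y, z) = packed
  return original_arg + y + z

# The split recombination is semantically equivalent to the original `comp`.
for arg in range(5):
  x = before(arg)
  y = intrinsic1(x[0])
  z = intrinsic2(x[1])
  assert after((arg, (y, z))) == comp(arg)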