def test_with_structure_replacing_federated_zip(self):
    """Unzip-then-zip of a CLIENTS tuple simplifies to no lambdas/blocks."""
    clients_tuple_type = computation_types.FederatedType([tf.int32] * 3,
                                                         placements.CLIENTS)
    tuple_ref = building_blocks.Reference('tup', clients_tuple_type)
    # Round-trip the federated tuple through unzip and zip.
    round_tripped = building_block_factory.create_federated_zip(
        building_block_factory.create_federated_unzip(tuple_ref))
    unwrapped, _ = tree_transformations.unwrap_placement(round_tripped)
    without_placement = unwrapped.argument
    simplified, modified = transformations.remove_lambdas_and_blocks(
        without_placement)
    self.assertTrue(modified)
    self.assertNoLambdasOrBlocks(simplified)
# Example #2 (original page marker: "Пример #2")
def consolidate_and_extract_local_processing(comp, grappler_config_proto):
    """Consolidates all the local processing in `comp`.

  The input computation `comp` must have the following properties:

  1. The output of `comp` may be of a federated type or unplaced. We refer to
     the placement `p` of that type as the placement of `comp`. There is no
     placement anywhere in the body of `comp` different than `p`. If `comp`
     is of a functional type, and has a parameter, the type of that parameter
     is a federated type placed at `p` as well, or unplaced if the result of
     the function is unplaced.

  2. The only intrinsics that may appear in the body of `comp` are those that
     manipulate data locally within the same placement. The exact set of these
     intrinsics will be gradually updated. At the moment, we support only the
     following:

     * Either `federated_apply` or `federated_map`, depending on whether `comp`
       is `SERVER`- or `CLIENTS`-placed. `federated_map_all_equal` is also
       allowed in the `CLIENTS`-placed case.

     * Either `federated_value_at_server` or `federated_value_at_clients`,
       likewise placement-dependent.

     * Either `federated_zip_at_server` or `federated_zip_at_clients`, again
       placement-dependent.

     Anything else, including `sequence_*` operators, should have been reduced
     already prior to calling this function.

  3. There are no lambdas in the body of `comp` except for `comp` itself being
     possibly a (top-level) lambda. All other lambdas must have been reduced.
     This requirement may eventually be relaxed by embedding lambda reducer into
     this helper method.

  4. If `comp` is of a functional type, it is either an instance of
     `building_blocks.CompiledComputation`, in which case there is nothing for
     us to do here, or a `building_blocks.Lambda`.

  5. There is at most one unbound reference under `comp`, and this is only
     allowed in the case that `comp` is not of a functional type.

  Aside from the intrinsics specified above, and the possibility of allowing
  lambdas, blocks, and references given the constraints above, the remaining
  constructs in `comp` include a combination of tuples, selections, calls, and
  sections of TensorFlow (as `CompiledComputation`s). This helper function does
  contain the logic to consolidate these constructs.

  The output of this transformation is always a single section of TensorFlow,
  which we henceforth refer to as `result`, the exact form of which depends on
  the placement of `comp` and the presence or absence of an argument.

  a. If there is no argument in `comp`, and `comp` is `SERVER`-placed, then
     the `result` is such that `comp` can be equivalently represented as:

     ```
     federated_value_at_server(result())
     ```

  b. If there is no argument in `comp`, and `comp` is `CLIENTS`-placed, then
     the `result` is such that `comp` can be equivalently represented as:

     ```
     federated_value_at_clients(result())
     ```

  c. If there is an argument in `comp`, and `comp` is `SERVER`-placed, then
     the `result` is such that `comp` can be equivalently represented as:

     ```
     (arg -> federated_apply(<result, arg>))
     ```

  d. If there is an argument in `comp`, and `comp` is `CLIENTS`-placed, then
     the `result` is such that `comp` can be equivalently represented as:

     ```
     (arg -> federated_map(<result, arg>))
     ```

  If the type of `comp` is `T@p` (thus `comp` is non-functional), the type of
  `result` is `T`, where `p` is the specific (concrete) placement of `comp`.

  If the type of `comp` is `(T@p -> U@p)`, then the type of `result` must be
  `(T -> U)`, where `p` is again a specific placement.

  Args:
    comp: An instance of `building_blocks.ComputationBuildingBlock` that serves
      as the input to this transformation, as described above.
    grappler_config_proto: An instance of `tf.compat.v1.ConfigProto` to
      configure Grappler graph optimization of the generated TensorFlow graph.
      If `grappler_config_proto` has
      `graph_options.rewrite_options.disable_meta_optimizer=True`, Grappler is
      bypassed.

  Returns:
    An instance of `building_blocks.CompiledComputation` that holds the
    TensorFlow section produced by this extraction step, as described above.
  """
    py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
    if comp.type_signature.is_function():
        if comp.is_compiled_computation():
            # Property 4 of the docstring: already a single compiled TF
            # section, nothing to extract.
            return comp
        elif not comp.is_lambda():
            # We normalize on lambdas for ease of calling unwrap_placement below.
            # The constructed lambda here simply forwards its argument to `comp`.
            arg = building_blocks.Reference(
                next(building_block_factory.unique_name_generator(comp)),
                comp.type_signature.parameter)
            called_fn = building_blocks.Call(comp, arg)
            comp = building_blocks.Lambda(arg.name, arg.type_signature,
                                          called_fn)
        if comp.type_signature.result.is_federated():
            unwrapped, _ = tree_transformations.unwrap_placement(comp.result)
            # Unwrapped can be a call to `federated_value_at_P`, or
            # `federated_apply/map`.
            if unwrapped.function.uri in (intrinsic_defs.FEDERATED_APPLY.uri,
                                          intrinsic_defs.FEDERATED_MAP.uri):
                # apply/map case: the unplaced function being mapped is the
                # first element of the intrinsic's argument tuple; compile it
                # directly to a TF section (docstring cases c/d).
                extracted = parse_tff_to_tf(unwrapped.argument[0],
                                            grappler_config_proto)
                check_extraction_result(unwrapped.argument[0], extracted)
                return extracted
            else:
                # `federated_value_at_P` case: rebuild an unplaced lambda over
                # the intrinsic's unplaced argument before compiling.
                # NOTE(review): taking `.member` assumes the parameter type is
                # federated whenever it is present — confirm with callers.
                member_type = None if comp.parameter_type is None else comp.parameter_type.member
                rebound = building_blocks.Lambda(comp.parameter_name,
                                                 member_type,
                                                 unwrapped.argument)
                extracted = parse_tff_to_tf(rebound, grappler_config_proto)
                check_extraction_result(rebound, extracted)
                return extracted
        else:
            # Functional but fully unplaced: compile the lambda as-is.
            extracted = parse_tff_to_tf(comp, grappler_config_proto)
            check_extraction_result(comp, extracted)
            return extracted
    elif comp.type_signature.is_federated():
        # Non-functional federated value (docstring cases a/b).
        unwrapped, _ = tree_transformations.unwrap_placement(comp)
        # Unwrapped can be a call to `federated_value_at_P`, or
        # `federated_apply/map`.
        if unwrapped.function.uri in (intrinsic_defs.FEDERATED_APPLY.uri,
                                      intrinsic_defs.FEDERATED_MAP.uri):
            # As above: compile the mapped (unplaced) function itself.
            extracted = parse_tff_to_tf(unwrapped.argument[0],
                                        grappler_config_proto)
            check_extraction_result(unwrapped.argument[0], extracted)
            return extracted
        else:
            # `federated_value_at_P`: compile the unplaced argument; the
            # resulting building block's `.function` is the bare TF section.
            extracted = parse_tff_to_tf(unwrapped.argument,
                                        grappler_config_proto)
            check_extraction_result(unwrapped.argument, extracted)
            return extracted.function
    else:
        # Fully unplaced, non-functional computation: compile directly and
        # return the TF section underlying the resulting call.
        called_tf = parse_tff_to_tf(comp, grappler_config_proto)
        check_extraction_result(comp, called_tf)
        return called_tf.function