def test_ok_on_nested_lambdas_with_different_variable_name(self):
     """Nested lambdas binding `x` (inner) and `y` (outer) have unique names."""
     x_reference = building_blocks.Reference('x', tf.int32)
     inner_fn = building_blocks.Lambda('x', tf.int32, x_reference)
     outer_fn = building_blocks.Lambda('y', tf.int32, inner_fn)
     # No exception expected: each name is bound exactly once.
     tree_analysis.check_has_unique_names(outer_fn)
def _create_before_and_after_aggregate_for_no_federated_secure_sum(tree):
  r"""Creates before and after aggregate computations for the given `tree`.

  This function returns the two ASTs:

  Lambda
  |
  Tuple
  |
  [Comp, Tuple]
         |
         [Tuple, []]
          |
          []

       Lambda(x)
       |
       Call
      /    \
  Comp      Tuple
            |
            [Sel(0),      Sel(0)]
            /            /
         Ref(x)    Sel(1)
                  /
            Ref(x)

  In the first AST, the first element returned by `Lambda`, `Comp`, is the
  result of the before aggregate returned by force aligning and splitting `tree`
  by `intrinsic_defs.FEDERATED_AGGREGATE.uri`. The second element returned by
  `Lambda` is a placeholder for the arguments to the secure sum intrinsic. It is
  a pair whose first element is an empty structure placed at `CLIENTS` (the
  value) and whose second element is an unplaced empty structure (the
  bitwidth). Therefore, the first AST has a type signature satisfying the
  requirements of before aggregate.

  In the second AST, `Comp` is the after aggregate returned by force aligning
  and splitting `tree` by `intrinsic_defs.FEDERATED_AGGREGATE.uri`, and
  `Lambda` has a type signature satisfying the requirements of after
  aggregate. The parameter of `Lambda` has the form `<arg, <s3, s4>>`:
  - `arg` and `s3` have the types of the first and second parameter elements
    of the original after aggregate;
  - `s4` is an extra element of type `<>@SERVER`.
  The argument passed to `Comp` is `<arg, s3>`, selected from that parameter,
  which intentionally drops `s4` on the floor.

  This function is intended to be used by
  `get_canonical_form_for_iterative_process` to create before and after
  aggregate computations for the given `tree` when there is no
  `intrinsic_defs.FEDERATED_SECURE_SUM` in `tree`. As a result, this function
  does not assert that there is no `intrinsic_defs.FEDERATED_SECURE_SUM` in
  `tree` and it does not assert that `tree` has the expected structure; the
  caller is expected to perform these checks before calling this function.

  Args:
    tree: An instance of `building_blocks.ComputationBuildingBlock`.

  Returns:
    A pair of the form `(before, after)`, where each of `before` and `after`
    is a `building_blocks.ComputationBuildingBlock` that represents a part of
    the result as specified by
    `transformations.force_align_and_split_by_intrinsics`.
  """
  name_generator = building_block_factory.unique_name_generator(tree)

  before_aggregate, after_aggregate = (
      transformations.force_align_and_split_by_intrinsics(
          tree, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))

  # Extend the before aggregate's result with an empty placeholder for the
  # secure sum arguments: `<<>@CLIENTS, <>>` (value, bitwidth).
  empty_tuple = building_blocks.Struct([])
  value = building_block_factory.create_federated_value(empty_tuple,
                                                        placements.CLIENTS)
  bitwidth = empty_tuple
  args = building_blocks.Struct([value, bitwidth])
  result = building_blocks.Struct([before_aggregate.result, args])
  before_aggregate = building_blocks.Lambda(before_aggregate.parameter_name,
                                            before_aggregate.parameter_type,
                                            result)

  # Wrap the after aggregate in a lambda taking `<arg, <s3, s4>>`, where `s4`
  # is an (unused) `<>@SERVER` secure sum result, and forward `<arg, s3>`.
  ref_name = next(name_generator)
  s4_type = computation_types.FederatedType([], placements.SERVER)
  ref_type = computation_types.StructType([
      after_aggregate.parameter_type[0],
      computation_types.StructType([
          after_aggregate.parameter_type[1],
          s4_type,
      ]),
  ])
  ref = building_blocks.Reference(ref_name, ref_type)
  sel_arg = building_blocks.Selection(ref, index=0)
  sel = building_blocks.Selection(ref, index=1)
  sel_s3 = building_blocks.Selection(sel, index=0)
  arg = building_blocks.Struct([sel_arg, sel_s3])
  call = building_blocks.Call(after_aggregate, arg)
  after_aggregate = building_blocks.Lambda(ref.name, ref.type_signature, call)

  return before_aggregate, after_aggregate
示例#3
0
 def test_raises_type_error_with_int_excluding(self):
     """Passing an int as the exclusion argument raises TypeError."""
     a_ref = building_blocks.Reference('a', tf.int32)
     identity = building_blocks.Lambda(a_ref.name, a_ref.type_signature, a_ref)
     with self.assertRaises(TypeError):
         tree_analysis.contains_no_unbound_references(identity, 1)
示例#4
0
def get_canonical_form_for_iterative_process(iterative_process):
  """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

  This function transforms computations from the input `iterative_process` into
  an instance of `tff.backends.mapreduce.CanonicalForm`. If the `next`
  computation returns a 2-element tuple, an empty `CLIENTS`-placed value is
  appended as a third result element before compilation proceeds.

  Args:
    iterative_process: An instance of `tff.utils.IterativeProcess`.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process.

  Raises:
    TypeError: If the arguments are of the wrong types.
    transformations.CanonicalFormCompilationError: If the compilation
      process fails.
  """
  py_typecheck.check_type(iterative_process, computation_utils.IterativeProcess)

  initialize_comp = building_blocks.ComputationBuildingBlock.from_proto(
      iterative_process.initialize._computation_proto)  # pylint: disable=protected-access

  next_comp = building_blocks.ComputationBuildingBlock.from_proto(
      iterative_process.next._computation_proto)  # pylint: disable=protected-access

  if not (isinstance(next_comp.type_signature.parameter,
                     computation_types.NamedTupleType) and
          isinstance(next_comp.type_signature.result,
                     computation_types.NamedTupleType)):
    raise TypeError(
        'Any IterativeProcess compatible with CanonicalForm must '
        'have a `next` function which takes and returns instances '
        'of `tff.NamedTupleType`; your next function takes '
        'parameters of type {} and returns results of type {}'.format(
            next_comp.type_signature.parameter,
            next_comp.type_signature.result))

  if len(next_comp.type_signature.result) == 2:
    # Pad a 2-element `next` result to 3 elements by appending an empty
    # `CLIENTS`-placed value as a stand-in for the client metrics.
    next_result = next_comp.result
    if isinstance(next_result, building_blocks.Tuple):
      first_result, second_result = next_result[0], next_result[1]
    else:
      first_result = building_blocks.Selection(next_result, index=0)
      second_result = building_blocks.Selection(next_result, index=1)
    dummy_clients_metrics = intrinsics.federated_value(
        [], placements.CLIENTS)._comp  # pylint: disable=protected-access
    dummy_clients_metrics_appended = building_blocks.Tuple(
        [first_result, second_result, dummy_clients_metrics])
    next_comp = building_blocks.Lambda(next_comp.parameter_name,
                                       next_comp.parameter_type,
                                       dummy_clients_metrics_appended)

  initialize_comp = replace_intrinsics_with_bodies(initialize_comp)
  next_comp = replace_intrinsics_with_bodies(next_comp)

  tree_analysis.check_intrinsics_whitelisted_for_reduction(initialize_comp)
  tree_analysis.check_intrinsics_whitelisted_for_reduction(next_comp)
  tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

  # Split `next` around the broadcast, then split what follows the broadcast
  # around the aggregate.
  before_broadcast, after_broadcast = (
      transformations.force_align_and_split_by_intrinsic(
          next_comp, intrinsic_defs.FEDERATED_BROADCAST.uri))

  before_aggregate, after_aggregate = (
      transformations.force_align_and_split_by_intrinsic(
          after_broadcast, intrinsic_defs.FEDERATED_AGGREGATE.uri))

  init_info_packed = pack_initialize_comp_type_signature(
      initialize_comp.type_signature)

  next_info_packed = pack_next_comp_type_signature(next_comp.type_signature,
                                                   init_info_packed)

  before_broadcast_info_packed = (
      check_and_pack_before_broadcast_type_signature(
          before_broadcast.type_signature, next_info_packed))

  before_aggregate_info_packed = (
      check_and_pack_before_aggregate_type_signature(
          before_aggregate.type_signature, before_broadcast_info_packed))

  canonical_form_types = check_and_pack_after_aggregate_type_signature(
      after_aggregate.type_signature, before_aggregate_info_packed)

  initialize = transformations.consolidate_and_extract_local_processing(
      initialize_comp)

  # Report the type actually checked against, not an unrelated type from
  # `next`, so the error message reflects the real expectation.
  expected_initialize_type = canonical_form_types['initialize_type'].member
  if not (isinstance(initialize, building_blocks.CompiledComputation) and
          initialize.type_signature.result == expected_initialize_type):
    raise transformations.CanonicalFormCompilationError(
        'Compilation of initialize has failed. Expected to extract a '
        '`building_blocks.CompiledComputation` of type {}, instead we extracted '
        'a {} of type {}.'.format(expected_initialize_type,
                                  type(initialize),
                                  initialize.type_signature.result))

  prepare = extract_prepare(before_broadcast, canonical_form_types)

  work = extract_work(before_aggregate, after_aggregate, canonical_form_types)

  zero_noarg_function, accumulate, merge, report = extract_aggregate_functions(
      before_aggregate, canonical_form_types)

  update = extract_update(after_aggregate, canonical_form_types)

  cf = canonical_form.CanonicalForm(
      computation_wrapper_instances.building_block_to_computation(initialize),
      computation_wrapper_instances.building_block_to_computation(prepare),
      computation_wrapper_instances.building_block_to_computation(work),
      computation_wrapper_instances.building_block_to_computation(
          zero_noarg_function),
      computation_wrapper_instances.building_block_to_computation(accumulate),
      computation_wrapper_instances.building_block_to_computation(merge),
      computation_wrapper_instances.building_block_to_computation(report),
      computation_wrapper_instances.building_block_to_computation(update))
  return cf
def _create_before_and_after_broadcast_for_no_broadcast(tree):
  r"""Builds stand-in before/after broadcast computations for a `tree`.

  The two ASTs produced are:

  Lambda
  |
  Tuple
  |
  []

       Lambda(x)
       |
       Call
      /    \
  Comp      Sel(0)
           /
     Ref(x)

  The first AST is a lambda over the parameter type of `tree` that ignores its
  argument and returns an empty structure placed at `SERVER`. Its type
  signature satisfies the requirements of before broadcast.

  In the second AST, `Comp` is `tree` itself. The outer `Lambda` takes a pair
  whose first element has the parameter type of `tree` and whose second element
  (`c2`) is the `CLIENTS`-placed counterpart of the first AST's result. Only the
  first element is forwarded to `Comp`; `c2` is intentionally ignored. The
  resulting type signature satisfies the requirements of after broadcast.

  Intended for use by `get_canonical_form_for_iterative_process` when `tree`
  contains no `intrinsic_defs.FEDERATED_BROADCAST`. Neither that precondition
  nor the expected structure of `tree` is verified here; the caller is
  responsible for both checks.

  Args:
    tree: An instance of `building_blocks.ComputationBuildingBlock`.

  Returns:
    A pair `(before, after)`, each a `tff_framework.ComputationBuildingBlock`
    representing one part of the result as specified by
    `transformations.force_align_and_split_by_intrinsics`.
  """
  name_generator = building_block_factory.unique_name_generator(tree)

  # Before broadcast: discard the argument and produce `<>@SERVER`.
  before_param_name = next(name_generator)
  server_empty = building_block_factory.create_federated_value(
      building_blocks.Struct([]), placements.SERVER)
  before_broadcast = building_blocks.Lambda(
      before_param_name, tree.type_signature.parameter, server_empty)

  # After broadcast: accept `<original_arg, c2>` and call `tree` on the first
  # element only.
  after_param_name = next(name_generator)
  broadcast_type = computation_types.FederatedType(
      before_broadcast.type_signature.result.member, placements.CLIENTS)
  after_param = building_blocks.Reference(
      after_param_name,
      computation_types.StructType(
          [tree.type_signature.parameter, broadcast_type]))
  forwarded_call = building_blocks.Call(
      tree, building_blocks.Selection(after_param, index=0))
  after_broadcast = building_blocks.Lambda(
      after_param.name, after_param.type_signature, forwarded_call)

  return before_broadcast, after_broadcast
示例#6
0
def _group_by_intrinsics_in_top_level_lambda(comp):
  """Groups the intrinsics in the first block local in the result of `comp`.

  This transformation replaces the tuple of called intrinsics found as the
  first local of the `building_blocks.Block` returned by the top-level lambda
  with two new locals:
  - a tuple of called intrinsics grouped by URI (a group with more than one
    intrinsic is itself a tuple; a group of one is the intrinsic alone);
  - the original first local name, rebound to a tuple of selections from the
    grouped tuple that reproduces the original tuple of called intrinsics.

  It is necessary to group intrinsics before it is possible to merge them.

  Args:
    comp: The `building_blocks.Lambda` to transform.

  Returns:
    A tuple `(fn, modified)`.
    - If every called intrinsic in the first local has a distinct URI, `fn` is
      `comp` unchanged and `modified` is `False`.
    - Otherwise `fn` is a new `building_blocks.Lambda` returning a
      `building_blocks.Block` whose first local is the grouped tuple described
      above, and `modified` is `True`.

  Raises:
    TypeError: If `comp` is not a `building_blocks.Lambda`, its result is not a
      `building_blocks.Block`, or the first local is not a
      `building_blocks.Struct`.
    ValueError: If an element of the first local is not a called intrinsic.
  """
  py_typecheck.check_type(comp, building_blocks.Lambda)
  py_typecheck.check_type(comp.result, building_blocks.Block)
  tree_analysis.check_has_unique_names(comp)

  name_generator = building_block_factory.unique_name_generator(comp)

  name, first_local = comp.result.locals[0]
  py_typecheck.check_type(first_local, building_blocks.Struct)
  for element in first_local:
    if not building_block_analysis.is_called_intrinsic(element):
      raise ValueError(
          'Expected all the elements of the `building_blocks.Struct` to be '
          'called intrinsics, but found: \n{}'.format(element))

  # Assign each distinct URI a group index in order of first occurrence.
  group_index_by_uri = {}
  for called_intrinsic in first_local:
    uri = called_intrinsic.function.uri
    if uri not in group_index_by_uri:
      group_index_by_uri[uri] = len(group_index_by_uri)
  # If there are no duplicates, there is nothing to group.
  if len(group_index_by_uri) == len(first_local):
    return comp, False

  # packed_groups[g] lists the intrinsics of group g in their original order.
  # packed_indexes[i] is (group index, index within group) of the i-th
  # intrinsic, i.e. an implicit map from unpacked index to packed position.
  packed_groups = [[] for _ in range(len(group_index_by_uri))]
  packed_indexes = []
  for called_intrinsic in first_local:
    group_index = group_index_by_uri[called_intrinsic.function.uri]
    group = packed_groups[group_index]
    group.append(called_intrinsic)
    packed_indexes.append((group_index, len(group) - 1))

  packed_elements = [
      building_blocks.Struct(group) if len(group) > 1 else group[0]
      for group in packed_groups
  ]
  packed_comp = building_blocks.Struct(packed_elements)

  packed_ref_name = next(name_generator)
  packed_ref_type = computation_types.to_type(packed_comp.type_signature)
  packed_ref = building_blocks.Reference(packed_ref_name, packed_ref_type)

  unpacked_elements = []
  for group_index, intrinsic_index in packed_indexes:
    sel = building_blocks.Selection(packed_ref, index=group_index)
    # Singleton groups were not wrapped in a Struct, so no second selection.
    if len(packed_groups[group_index]) > 1:
      sel = building_blocks.Selection(sel, index=intrinsic_index)
    unpacked_elements.append(sel)
  unpacked_comp = building_blocks.Struct(unpacked_elements)

  # Work on a copy so the locals of the input block are never mutated.
  variables = list(comp.result.locals)
  variables[0] = (name, unpacked_comp)
  variables.insert(0, (packed_ref_name, packed_comp))
  block = building_blocks.Block(variables, comp.result.result)
  fn = building_blocks.Lambda(comp.parameter_name, comp.parameter_type, block)
  return fn, True
 def _call_function(arg_type):
   """Returns a lambda `x -> x()` whose parameter has type `arg_type`."""
   param_name = next(name_generator)
   param_ref = building_blocks.Reference(param_name, arg_type)
   return building_blocks.Lambda(
       param_name, arg_type, building_blocks.Call(param_ref, None))
示例#8
0
 def test_returns_true_for_lambdas_representing_identical_functions(self):
   """Identity lambdas that differ only in parameter name are equal trees."""
   identities = []
   for param_name in ('a', 'b'):
     param_ref = building_blocks.Reference(param_name, tf.int32)
     identities.append(
         building_blocks.Lambda(param_name, param_ref.type_signature,
                                param_ref))
   self.assertTrue(tree_analysis.trees_equal(*identities))
示例#9
0
 def test_returns_false_for_lambdas_with_different_parameter_types(self):
   """Identity lambdas over int32 and float32 are not equal trees."""
   int_ref = building_blocks.Reference('a', tf.int32)
   int_identity = building_blocks.Lambda(
       int_ref.name, int_ref.type_signature, int_ref)
   float_ref = building_blocks.Reference('a', tf.float32)
   float_identity = building_blocks.Lambda(
       float_ref.name, float_ref.type_signature, float_ref)
   self.assertFalse(tree_analysis.trees_equal(int_identity, float_identity))
示例#10
0
 def test_returns_true_with_excluded_reference(self):
   """A free reference to `a` is accepted when `a` is excluded."""
   free_ref = building_blocks.Reference('a', tf.int32)
   lam = building_blocks.Lambda('b', tf.int32, free_ref)
   no_unbound = tree_analysis.contains_no_unbound_references(
       lam, excluding='a')
   self.assertTrue(no_unbound)
示例#11
0
 def test_returns_false(self):
   """A lambda over `b` whose body references `a` has an unbound reference."""
   free_ref = building_blocks.Reference('a', tf.int32)
   lam = building_blocks.Lambda('b', tf.int32, free_ref)
   no_unbound = tree_analysis.contains_no_unbound_references(lam)
   self.assertFalse(no_unbound)
 def test_ok_lambda_binding_of_new_variable(self):
     """A lambda binding `y` inside a block binding `x` has unique names."""
     y_reference = building_blocks.Reference('y', tf.int32)
     block_result = building_blocks.Lambda('y', tf.int32, y_reference)
     outer_block = building_blocks.Block(
         [('x', building_blocks.Data('x', tf.int32))], block_result)
     tree_analysis.check_has_unique_names(outer_block)
 def test_ok_block_binding_of_new_variable(self):
     """A block binding `x` under a lambda binding `y` has unique names."""
     data_x = building_blocks.Data('x', tf.int32)
     inner_block = building_blocks.Block([('x', data_x)], data_x)
     outer_lambda = building_blocks.Lambda('y', tf.int32, inner_block)
     tree_analysis.check_has_unique_names(outer_lambda)
 def test_raises_block_rebinding_of_lambda_variable(self):
     """A block that re-binds the enclosing lambda's `x` is rejected."""
     data_x = building_blocks.Data('x', tf.int32)
     rebinding_block = building_blocks.Block([('x', data_x)], data_x)
     outer_lambda = building_blocks.Lambda('x', tf.int32, rebinding_block)
     with self.assertRaises(tree_analysis.NonuniqueNameError):
         tree_analysis.check_has_unique_names(outer_lambda)
def transform_postorder(comp, transform):
  """Traverses `comp` recursively postorder and replaces its constituents.

  For each element of `comp` viewed as an expression tree, the transformation
  `transform` is applied first to the building blocks it is parameterized by,
  then to the element itself. The transformation `transform` should act as an
  identity function on the kinds of elements (computation building blocks) it
  does not care to transform. This corresponds to a post-order traversal of the
  expression tree, i.e., parameters are always transformed left-to-right (in
  the order in which they are listed in building block constructors), then the
  parent is visited and transformed with the already-visited, and possibly
  transformed, arguments in place.

  NOTE: In particular, in `Call(f,x)`, both `f` and `x` are arguments to `Call`.
  Therefore, `f` is transformed into `f'`, next `x` into `x'` and finally,
  `Call(f',x')` is transformed at the end.

  Only child building blocks are traversed: a `Lambda`'s parameter and a
  `Block`'s local names are carried over unchanged and are never passed to
  `transform`.

  Args:
    comp: A `building_blocks.ComputationBuildingBlock` to traverse and
      transform bottom-up.
    transform: The transformation to apply locally to each building block in
      `comp`. It is a Python function that accepts a building block as input
      and returns a `(building_block, modified)` tuple, where `building_block`
      is a `building_blocks.ComputationBuildingBlock` representing either the
      original building block or a transformed building block, and `modified`
      is a Boolean indicating whether the building block was modified.

  Returns:
    The result of applying `transform` to parts of `comp` in a bottom-up
    fashion, along with a Boolean with the value `True` if `comp` was
    transformed and `False` if it was not.

  Raises:
    TypeError: If the arguments are of the wrong types.
    NotImplementedError: If the argument is a kind of computation building block
      that is currently not recognized.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  # Leaf building blocks have no children to visit first.
  if isinstance(comp, (
      building_blocks.CompiledComputation,
      building_blocks.Data,
      building_blocks.Intrinsic,
      building_blocks.Placement,
      building_blocks.Reference,
  )):
    return transform(comp)
  elif isinstance(comp, building_blocks.Selection):
    source, source_modified = transform_postorder(comp.source, transform)
    if source_modified:
      comp = building_blocks.Selection(source, comp.name, comp.index)
    comp, comp_modified = transform(comp)
    return comp, comp_modified or source_modified
  elif isinstance(comp, building_blocks.Tuple):
    elements = []
    elements_modified = False
    for key, value in anonymous_tuple.iter_elements(comp):
      value, value_modified = transform_postorder(value, transform)
      elements.append((key, value))
      elements_modified = elements_modified or value_modified
    # Only rebuild the parent when some child changed.
    if elements_modified:
      comp = building_blocks.Tuple(elements)
    comp, comp_modified = transform(comp)
    return comp, comp_modified or elements_modified
  elif isinstance(comp, building_blocks.Call):
    fn, fn_modified = transform_postorder(comp.function, transform)
    # A no-argument call has nothing to transform on the argument side.
    if comp.argument is not None:
      arg, arg_modified = transform_postorder(comp.argument, transform)
    else:
      arg, arg_modified = (None, False)
    if fn_modified or arg_modified:
      comp = building_blocks.Call(fn, arg)
    comp, comp_modified = transform(comp)
    return comp, comp_modified or fn_modified or arg_modified
  elif isinstance(comp, building_blocks.Lambda):
    result, result_modified = transform_postorder(comp.result, transform)
    if result_modified:
      comp = building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                    result)
    comp, comp_modified = transform(comp)
    return comp, comp_modified or result_modified
  elif isinstance(comp, building_blocks.Block):
    # Locals are visited in declaration order, then the block result.
    variables = []
    variables_modified = False
    for key, value in comp.locals:
      value, value_modified = transform_postorder(value, transform)
      variables.append((key, value))
      variables_modified = variables_modified or value_modified
    result, result_modified = transform_postorder(comp.result, transform)
    if variables_modified or result_modified:
      comp = building_blocks.Block(variables, result)
    comp, comp_modified = transform(comp)
    return comp, comp_modified or variables_modified or result_modified
  else:
    raise NotImplementedError(
        'Unrecognized computation building block: {}'.format(str(comp)))
示例#16
0
 def test_returns_true_for_lambdas_referring_to_same_unbound_variables(self):
   """Two lambdas sharing the same free reference `x` are equal trees."""
   free_x = building_blocks.Reference('x', tf.int32)
   first_fn, second_fn = (
       building_blocks.Lambda('a', tf.int32, free_x) for _ in range(2))
   self.assertTrue(tree_analysis.trees_equal(first_fn, second_fn))
示例#17
0
def consolidate_and_extract_local_processing(comp, grappler_config_proto):
  """Consolidates all the local processing in `comp`.

  Preconditions on the input computation `comp`:

  1. The result of `comp` is either of a federated type or unplaced. The
     placement `p` of that type is called the placement of `comp`. No
     placement other than `p` may appear anywhere in the body of `comp`. If
     `comp` is of a functional type and has a parameter, that parameter is
     federated at `p` as well (or unplaced, if the function's result is
     unplaced).

  2. The body of `comp` may only contain intrinsics that manipulate data
     locally within a single placement. This set will be updated gradually;
     at the moment only the following are supported:

     * `federated_apply` or `federated_map`, depending on whether `comp` is
       `SERVER`- or `CLIENTS`-placed. `federated_map_all_equal` is also
       allowed in the `CLIENTS`-placed case.

     * `federated_value_at_server` or `federated_value_at_clients`, likewise
       depending on placement.

     * `federated_zip_at_server` or `federated_zip_at_clients`, again
       depending on placement.

     Everything else, including the `sequence_*` operators, must already have
     been reduced before calling this function.

  3. Apart from `comp` itself possibly being a (top-level) lambda, the body of
     `comp` contains no lambdas; all others must already have been reduced.
     This requirement may eventually be relaxed by embedding a lambda reducer
     into this helper.

  4. If `comp` is of a functional type, it is either a
     `building_blocks.CompiledComputation`, which is returned unchanged, or a
     `building_blocks.Lambda`.

  5. At most one unbound reference occurs under `comp`, and only when `comp`
     is not of a functional type.

  Besides the intrinsics above and the lambdas, blocks, and references allowed
  by these constraints, `comp` is made of tuples, selections, calls, and
  sections of TensorFlow (as `CompiledComputation`s). This helper contains the
  logic to consolidate all of these constructs.

  The output is always a single section of TensorFlow, referred to below as
  `result`, whose exact form depends on the placement of `comp` and on whether
  it takes an argument.

  a. No argument, `SERVER`-placed: `comp` is equivalent to

     ```
     federated_value_at_server(result())
     ```

  b. No argument, `CLIENTS`-placed: `comp` is equivalent to

     ```
     federated_value_at_clients(result())
     ```

  c. With an argument, `SERVER`-placed: `comp` is equivalent to

     ```
     (arg -> federated_apply(<result, arg>))
     ```

  d. With an argument, `CLIENTS`-placed: `comp` is equivalent to

     ```
     (arg -> federated_map(<result, arg>))
     ```

  If `comp` has type `T@p` (non-functional), `result` has type `T`, where `p`
  is the specific (concrete) placement of `comp`.

  If `comp` has type `(T@p -> U@p)`, `result` must have type `(T -> U)`, where
  `p` is again a specific placement.

  Args:
    comp: An instance of `building_blocks.ComputationBuildingBlock` that serves
      as the input to this transformation, as described above.
    grappler_config_proto: An instance of `tf.compat.v1.ConfigProto` to
      configure Grappler graph optimization of the generated TensorFlow graph.
      If `None`, Grappler is bypassed.

  Returns:
    An instance of `building_blocks.CompiledComputation` that holds the
    TensorFlow section produced by this extraction step, as described above.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)

  def _parse_and_check(target):
    # Lowers `target` to TensorFlow and validates the extraction.
    extracted = parse_tff_to_tf(target, grappler_config_proto)
    check_extraction_result(target, extracted)
    return extracted

  def _is_apply_or_map(called_intrinsic):
    # `unwrap_placement` yields either a call to `federated_apply/map` or a
    # call to `federated_value_at_P`.
    return called_intrinsic.function.uri in (intrinsic_defs.FEDERATED_APPLY.uri,
                                             intrinsic_defs.FEDERATED_MAP.uri)

  if not comp.type_signature.is_function():
    if not comp.type_signature.is_federated():
      return _parse_and_check(comp).function
    unwrapped, _ = tree_transformations.unwrap_placement(comp)
    if _is_apply_or_map(unwrapped):
      return _parse_and_check(unwrapped.argument[0])
    return _parse_and_check(unwrapped.argument).function

  if comp.is_compiled_computation():
    return comp
  if not comp.is_lambda():
    # Normalize on a lambda that simply forwards its argument to `comp`, so the
    # placement unwrapping below can operate on `comp.result`.
    forwarded_arg = building_blocks.Reference(
        next(building_block_factory.unique_name_generator(comp)),
        comp.type_signature.parameter)
    forwarding_call = building_blocks.Call(comp, forwarded_arg)
    comp = building_blocks.Lambda(forwarded_arg.name,
                                  forwarded_arg.type_signature, forwarding_call)

  if not comp.type_signature.result.is_federated():
    return _parse_and_check(comp)

  unwrapped, _ = tree_transformations.unwrap_placement(comp.result)
  if _is_apply_or_map(unwrapped):
    return _parse_and_check(unwrapped.argument[0])
  if comp.parameter_type is None:
    member_type = None
  else:
    member_type = comp.parameter_type.member
  rebound = building_blocks.Lambda(comp.parameter_name, member_type,
                                   unwrapped.argument)
  return _parse_and_check(rebound)
示例#18
0
 def test_returns_true_for_lambdas(self):
   """Two independently built identity lambdas over `a` are equal trees."""

   def make_identity():
     ref = building_blocks.Reference('a', tf.int32)
     return building_blocks.Lambda(ref.name, ref.type_signature, ref)

   self.assertTrue(tree_analysis.trees_equal(make_identity(), make_identity()))
 def _identity_function(arg_type):
   """Returns a lambda `x -> x` whose parameter has type `arg_type`."""
   param_name = next(name_generator)
   return building_blocks.Lambda(
       param_name, arg_type, building_blocks.Reference(param_name, arg_type))
示例#20
0
 def test_propogates_dependence_up_through_lambda(self):
   """A lambda whose body is the intrinsic is among the consuming nodes."""
   intrinsic = building_blocks.Intrinsic('dummy_intrinsic', tf.int32)
   wrapping_lambda = building_blocks.Lambda('x', tf.int32, intrinsic)
   # NOTE(review): presumes `dummy_intrinsic_predicate` matches the intrinsic
   # named 'dummy_intrinsic'; it is defined elsewhere in the test module.
   consumers = tree_analysis.extract_nodes_consuming(
       wrapping_lambda, dummy_intrinsic_predicate)
   self.assertIn(wrapping_lambda, consumers)
示例#21
0
def create_nested_syntax_tree():
    r"""Constructs computation with explicit ordering for testing traversals.

    The goal of this computation is to exercise each switch
    in transform_postorder_with_symbol_bindings, at least all those that
    recurse.

    The computation this function constructs can be represented as below.

    Notice that the body of the Lambda *does not depend on the Lambda's
    parameter*, so that if we were actually executing this call the argument
    will be thrown away.

    All leaf nodes are instances of `building_blocks.Data`. Every leaf has type
    `tf.float32` except `h`, which has type `tf.int32`.

                              Call
                             /    \
                   Lambda('arg')   Data('k')
                       |
                     Block('y','z')-------------
                    /                          |
    ['y'=Data('a'),'z'=Data('b')]              |
                                             Tuple
                                           /       \
                                     Block('v')     Block('x')-------
                                       / \              |            |
                         ['v'=Selection]   Data('g') ['x'=Data('h')] |
                               |                                     |
                               |                                     |
                               |                                 Block('w')
                               |                                   /   \
                             Tuple ------            ['w'=Data('i')]    Data('j')
                           /              \
                   Block('t')             Block('u')
                    /     \              /          \
      ['t'=Data('c')]    Data('d') ['u'=Data('e')]  Data('f')


    Postorder traversals:
    If we are reading Data URIs, results of a postorder traversal should be:
    [a, b, c, d, e, f, g, h, i, j, k]

    If we are reading locals declarations, results of a postorder traversal
    should be:
    [t, u, v, w, x, y, z]

    And if we are reading both in an interleaved fashion, results of a
    postorder traversal should be:
    [a, b, c, d, t, e, f, u, g, v, h, i, j, w, x, y, z, k]

    Preorder traversals:
    If we are reading Data URIs, results of a preorder traversal should be:
    [a, b, c, d, e, f, g, h, i, j, k]

    If we are reading locals declarations, results of a preorder traversal
    should be:
    [y, z, v, t, u, x, w]

    And if we are reading both in an interleaved fashion, results of a preorder
    traversal should be:
    [y, z, a, b, v, t, c, d, u, e, f, g, x, h, w, i, j, k]

    Since we are also exposing the ability to hook into variable declarations,
    it is worthwhile considering the order in which variables are assigned in
    this tree. Notice that this order maps neither to preorder nor to postorder
    when purely considering the nodes of the tree above. This would be:
    [arg, y, z, t, u, v, x, w]

    Returns:
      An instance of `building_blocks.ComputationBuildingBlock`
      satisfying the description above.
    """
    # Innermost tuple: Block('t') and Block('u').
    data_c = building_blocks.Data('c', tf.float32)
    data_d = building_blocks.Data('d', tf.float32)
    left_most_leaf = building_blocks.Block([('t', data_c)], data_d)

    data_e = building_blocks.Data('e', tf.float32)
    data_f = building_blocks.Data('f', tf.float32)
    center_leaf = building_blocks.Block([('u', data_e)], data_f)
    inner_tuple = building_blocks.Struct([left_most_leaf, center_leaf])

    # Block('v') binds a selection out of the inner tuple.
    selected = building_blocks.Selection(inner_tuple, index=0)
    data_g = building_blocks.Data('g', tf.float32)
    middle_block = building_blocks.Block([('v', selected)], data_g)

    # Block('x') whose result is Block('w').
    data_i = building_blocks.Data('i', tf.float32)
    data_j = building_blocks.Data('j', tf.float32)
    right_most_endpoint = building_blocks.Block([('w', data_i)], data_j)

    data_h = building_blocks.Data('h', tf.int32)
    right_child = building_blocks.Block([('x', data_h)], right_most_endpoint)

    # Outer Block('y', 'z') wrapped in a lambda that ignores its parameter.
    result = building_blocks.Struct([middle_block, right_child])
    data_a = building_blocks.Data('a', tf.float32)
    data_b = building_blocks.Data('b', tf.float32)
    dummy_outer_block = building_blocks.Block([('y', data_a), ('z', data_b)],
                                              result)
    dummy_lambda = building_blocks.Lambda('arg', tf.float32, dummy_outer_block)
    dummy_arg = building_blocks.Data('k', tf.float32)
    called_lambda = building_blocks.Call(dummy_lambda, dummy_arg)

    return called_lambda
示例#22
0
 def test_raises_on_non_tuple_parameter(self):
     """A lambda over a bare `tf.int32` parameter triggers a `TypeError`."""
     param = building_blocks.Reference('x', tf.int32)
     identity_fn = building_blocks.Lambda('x', tf.int32, param)
     paths = [[0]]
     with self.assertRaises(TypeError):
         mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda(
             identity_fn, paths)
示例#23
0
def remove_duplicate_called_graphs(comp):
    """Shares repeated called TensorFlow graphs, then generates TensorFlow.

    Every `building_blocks.Call` of a `building_blocks.CompiledComputation`
    found in `comp` (or in its result, when `comp` is a lambda) is bound once
    in a `building_blocks.Block` under a name drawn from
    `building_block_factory.unique_name_generator(comp)`. Later calls that
    compare equal to an existing binding under `tree_analysis.trees_equal` are
    replaced by a `building_blocks.Reference` to that binding. The resulting
    block is converted with `create_tensorflow_representing_block`.

    Args:
      comp: The `building_blocks.ComputationBuildingBlock` to process. It must
        be either a lambda whose result contains no lambdas or blocks, or a
        non-lambda computation containing no lambdas or blocks. This function
        makes no effort to keep references inside their defining scope, except
        for the case where `comp` itself is a lambda. If the restriction is
        violated, a warning is logged and `comp` is returned unchanged. `comp`
        must also contain only computations representable in TensorFlow, i.e.
        satisfying `type_utils.is_tensorflow_compatible_type`.

    Returns:
      A pair `(result, modified)` in the `transformation_utils.TransformSpec`
      style. On success, `result` is a called
      `building_blocks.CompiledComputation` when `comp` is non-functional, or a
      `building_blocks.CompiledComputation` when `comp` is functional, and
      `modified` is `True`. If the restriction above is violated, the pair is
      `(comp, False)`.

    Errors raised by `py_typecheck.check_type` and
    `tree_analysis.check_has_unique_names` propagate to the caller.
    """
    py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
    tree_analysis.check_has_unique_names(comp)
    name_generator = building_block_factory.unique_name_generator(comp)

    is_lambda = isinstance(comp, building_blocks.Lambda)
    # For a lambda only the result is scanned and rewritten; the lambda is
    # rebuilt around the generated TensorFlow with the same parameter.
    body = comp.result if is_lambda else comp
    unsupported_types = (building_blocks.Lambda, building_blocks.Block)
    if tree_analysis.count_types(body, unsupported_types) > 0:
        logging.warning(
            'The preprocessors have failed to remove called lambdas '
            'and blocks; falling back to less efficient, but '
            'guaranteed, TensorFlow generation with computation %s.', comp)
        return comp, False

    # (name, called graph) pairs in first-seen order; they become the locals
    # of the Block built below.
    bindings = []

    def _share_called_graph(node):
        """Swaps a called graph for a reference to its (possibly new) binding."""
        is_called_graph = (
            isinstance(node, building_blocks.Call) and
            isinstance(node.function, building_blocks.CompiledComputation))
        if not is_called_graph:
            return node, False
        bound_name = next(
            (name for name, graph in bindings
             if tree_analysis.trees_equal(graph, node)),
            None)
        if bound_name is None:
            bound_name = next(name_generator)
            bindings.append((bound_name, node))
        return building_blocks.Reference(bound_name, node.type_signature), True

    rewritten, _ = transformation_utils.transform_postorder(
        body, _share_called_graph)
    shared_block = building_blocks.Block(bindings, rewritten)

    if not is_lambda:
        tf_generated, _ = create_tensorflow_representing_block(shared_block)
        return tf_generated, True

    parsed, _ = create_tensorflow_representing_block(shared_block)
    rewrapped = building_blocks.Lambda(comp.parameter_name,
                                       comp.parameter_type, parsed)
    parser = tree_to_cc_transformations.TFParser()
    with_identities, _ = (
        tree_transformations.insert_called_tf_identity_at_leaves(rewrapped))
    tf_generated, _ = transformation_utils.transform_postorder(
        with_identities, parser)
    return tf_generated, True
示例#24
0
 def test_raises_on_selection_from_non_tuple(self):
     """Selecting path `[0, 0]` through an int32 element raises a `TypeError`."""
     identity_fn = building_blocks.Lambda(
         'x', [tf.int32], building_blocks.Reference('x', [tf.int32]))
     nested_path = [[0, 0]]
     with self.assertRaisesRegex(TypeError, 'nonexistent index'):
         mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda(
             identity_fn, nested_path)
 def _create_empty_function(type_elements):
   """Builds a lambda over `StructType(type_elements)` returning an empty struct.

   The parameter name is the next value drawn from `name_generator`, a name
   bound in the enclosing scope.
   """
   param_name = next(name_generator)
   param_type = computation_types.StructType(type_elements)
   return building_blocks.Lambda(param_name, param_type,
                                 building_blocks.Struct([]))
示例#26
0
 def test_raises_on_non_federated_selection(self):
     """Selecting a non-federated `tf.int32` element raises a `TypeError`."""
     param_type = [tf.int32]
     identity_fn = building_blocks.Lambda(
         'x', param_type, building_blocks.Reference('x', param_type))
     with self.assertRaises(TypeError):
         mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda(
             identity_fn, [[0]])
def _extract_update(after_aggregate, grappler_config):
  """Builds the `update` computation out of `after_aggregate`.

  Only `get_canonical_form_for_iterative_process` is meant to call this. That
  caller is responsible for checking that `after_aggregate` has the expected
  structure, so no such checks are repeated here.

  Args:
    after_aggregate: The second result of splitting `after_broadcast` on
      aggregate intrinsics.
    grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
      Grappler graph optimization.

  Returns:
    `update` as specified by `canonical_form.CanonicalForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: If we extract an AST of the
      wrong type.
  """
  # Keep result elements 0 and 1 (`s7`) and zip them into one federated value.
  s7_selected = transformations.select_output_from_lambda(
      after_aggregate, [0, 1])
  s7_zipped = building_blocks.Lambda(
      s7_selected.parameter_name,
      s7_selected.parameter_type,
      building_block_factory.create_federated_zip(s7_selected.result),
  )

  # Paths into the `after_aggregate` parameter that make up `s6`; they are
  # zipped into the argument of the lower-level lambda.
  s6_paths = [[0, 0, 0], [1, 0], [1, 1]]
  lowered = transformations.zip_selection_as_argument_to_lower_level_lambda(
      s7_zipped, s6_paths)
  s6_to_s7 = lowered.result.function

  # TODO(b/148942011): The transformation
  # `zip_selection_as_argument_to_lower_level_lambda` does not support selecting
  # from nested structures, therefore we need to pack the type signature
  # `<s1, s3, s4>` as `<s1, <s3, s4>>`.
  names = building_block_factory.unique_name_generator(s6_to_s7)
  member_type = s6_to_s7.parameter_type.member
  packed_type = computation_types.StructType([
      member_type[0],
      computation_types.StructType([member_type[1], member_type[2]]),
  ])

  # `<s1, <s3, s4>> -> <s1, s3, s4>`, applied at the server.
  packed_ref = building_blocks.Reference(next(names), packed_type)
  nested = building_blocks.Selection(packed_ref, index=1)
  unpack_fn = building_blocks.Lambda(
      packed_ref.name,
      packed_ref.type_signature,
      building_blocks.Struct([
          building_blocks.Selection(packed_ref, index=0),
          building_blocks.Selection(nested, index=0),
          building_blocks.Selection(nested, index=1),
      ]),
  )
  server_ref = building_blocks.Reference(
      next(names),
      computation_types.FederatedType(packed_type, placements.SERVER))
  unpacked = building_block_factory.create_federated_map_or_apply(
      unpack_fn, server_ref)

  update_fn = building_blocks.Lambda(
      server_ref.name, server_ref.type_signature,
      building_blocks.Call(s6_to_s7, unpacked))
  return transformations.consolidate_and_extract_local_processing(
      update_fn, grappler_config)
示例#28
0
def sequence_reduce(value, zero, op):
    """Reduces a TFF sequence `value` given a `zero` and reduction operator `op`.

    This method reduces a set of elements of a TFF sequence `value`, using a
    given `zero` in the algebra (i.e., the result of reducing an empty sequence)
    of some type `U`, and a reduction operator `op` with type signature
    `(<U,T> -> U)` that incorporates a single `T`-typed element of `value` into
    the `U`-typed result of partial reduction. In the special case of `T` equal
    to `U`, this corresponds to the classical notion of reduction of a set using
    a commutative associative binary operator. The generalized reduction (with
    `T` not equal to `U`) requires that repeated application of `op` to reduce
    a set of `T` always yields the same `U`-typed result, regardless of the
    order in which elements of `T` are processed in the course of the reduction.

    One can also invoke `sequence_reduce` on a federated sequence, in which case
    the reduction is performed point-wise on each member constituent. Under the
    hood, we construct an expression of the form
    `federated_map(<x, z> -> sequence_reduce(x, z, op), <value, zero>)`. See
    also the discussion on `sequence_map`.

    Args:
      value: A value that is either a TFF sequence, or a federated sequence.
      zero: The result of reducing a sequence with no elements. When `value` is
        a federated sequence, `zero` must be a federated value as well; its
        member type is paired with the member type of `value`.
      op: An operator with type signature `(<U,T> -> U)`, where `T` is the type
        of the elements of the sequence, and `U` is the type of `zero` to be
        used in performing the reduction.

    Returns:
      The `U`-typed result of reducing elements in the sequence, or if the
      `value` is federated, a federated `U` that represents the result of
      locally reducing each member constituent of `value`.

    Raises:
      TypeError: If the arguments are not of the types specified above,
        including when `value` is federated but `zero` is not.
    """
    value = value_impl.to_value(value, None)
    zero = value_impl.to_value(zero, None)
    op = value_impl.to_value(op, None)

    if not value.type_signature.is_federated():
        # Plain sequence: reduce it directly.
        value.type_signature.check_sequence()
        comp = building_block_factory.create_sequence_reduce(
            value.comp, zero.comp, op.comp)
        comp = _bind_comp_as_reference(comp)
        return value_impl.Value(comp)

    # Federated sequence: reduce each member constituent under `federated_map`.
    value_member_type = value.type_signature.member
    value_member_type.check_sequence()
    # `zero` is paired member-wise with `value`, so it must be federated too.
    # Check explicitly so a mismatch surfaces as the documented `TypeError`
    # rather than relying on `.member` access to fail for non-federated types.
    zero.type_signature.check_federated()
    zero_member_type = zero.type_signature.member

    ref_type = computation_types.StructType(
        [value_member_type, zero_member_type])
    ref = building_blocks.Reference('arg', ref_type)
    arg1 = building_blocks.Selection(ref, index=0)
    arg2 = building_blocks.Selection(ref, index=1)
    call = building_block_factory.create_sequence_reduce(arg1, arg2, op.comp)
    fn = building_blocks.Lambda(ref.name, ref.type_signature, call)
    fn_value_impl = value_impl.Value(fn)
    args = building_blocks.Struct([value.comp, zero.comp])
    return federated_map(fn_value_impl, args)
示例#29
0
 def test_returns_false(self):
     """`contains_no_unbound_references` is False for a free reference.

     A bare reference to `a` is unbound, so the check must report False for
     it. Wrapping the same reference in `lambda a: a` binds it, so the lambda
     must be reported True.
     """
     ref = building_blocks.Reference('a', tf.int32)
     self.assertFalse(tree_analysis.contains_no_unbound_references(ref))
     fn = building_blocks.Lambda(ref.name, ref.type_signature, ref)
     self.assertTrue(tree_analysis.contains_no_unbound_references(fn))
 def test_raises_on_nested_lambdas_with_same_variable_name(self):
     """Nested lambdas that both bind `x` fail the unique-names check."""
     inner = building_blocks.Lambda(
         'x', tf.int32, building_blocks.Reference('x', tf.int32))
     outer = building_blocks.Lambda('x', tf.int32, inner)
     with self.assertRaises(tree_analysis.NonuniqueNameError):
         tree_analysis.check_has_unique_names(outer)