Esempio n. 1
0
def bind_single_selection_as_argument_to_lower_level_lambda(comp, index):
  r"""Rebinds one selection from `comp`'s parameter as a lower-level lambda param.

  Given a lambda `comp` with parameter `x`, builds an equivalent computation
  with the structure:


                                    Lambda(x)
                                       |
                                      Call
                                    /      \
                              Lambda        Selection from x

  Inside the body of `comp`, every `Selection` at position `index` whose
  source is a `Reference` to `x` is replaced by a reference to the parameter
  of a new inner lambda. That inner lambda is then called on `x[index]`.

  Args:
    comp: Instance of `tff_framework.Lambda` whose parameter should be partially
      rebound to an inner lambda. All names in `comp` must be unique.
    index: `int` position within the parameter of `comp` that becomes the
      argument of the inner lambda.

  Returns:
    An instance of `tff_framework.Lambda`, equivalent to `comp`, with the
    structure drawn above.
  """
  py_typecheck.check_type(comp, tff_framework.Lambda)
  py_typecheck.check_type(index, int)
  tff_framework.check_has_unique_names(comp)
  prepared = _prepare_for_rebinding(comp)
  fresh_names = tff_framework.unique_name_generator(prepared)
  outer_param_name = prepared.parameter_name
  inner_param_ref = tff_framework.Reference(
      six.next(fresh_names), prepared.type_signature.parameter[index])

  def _swap_selection_for_inner_ref(node):
    # Only `x[index]` (a positional selection directly on the outer parameter
    # reference) is rewritten; everything else is passed through unchanged.
    # The order of the checks matters: `.source` and `.index` exist only on
    # `Selection` nodes.
    is_target = (
        isinstance(node, tff_framework.Selection) and
        isinstance(node.source, tff_framework.Reference) and
        node.index == index and node.source.name == outer_param_name)
    if is_target:
      return inner_param_ref, True
    return node, False

  rebound_body, _ = tff_framework.transform_postorder(
      prepared.result, _swap_selection_for_inner_ref)
  inner_lambda = tff_framework.Lambda(inner_param_ref.name,
                                      inner_param_ref.type_signature,
                                      rebound_body)
  _check_for_missed_binding(prepared, inner_lambda)
  outer_param_ref = tff_framework.Reference(prepared.parameter_name,
                                            prepared.parameter_type)
  return tff_framework.Lambda(
      prepared.parameter_name, prepared.parameter_type,
      tff_framework.Call(inner_lambda,
                         tff_framework.Selection(outer_param_ref, index=index)))
Esempio n. 2
0
def parse_tff_to_tf(comp):
  """Parses TFF construct `comp` into TensorFlow construct.

  Does not change the type signature of `comp`. Therefore may return either
  a `tff_framework.CompiledComputation` or a `tff_framework.Call` whose
  function is a `tff_framework.CompiledComputation`.

  Parsing happens in two stages:

  1. `comp` is reduced with `remove_lambdas_and_blocks` and given one
     post-order pass of `TFParser`. If that already yields a
     `CompiledComputation`, or a call to one, the result is returned.
  2. Otherwise, called TF identities are inserted at the leaves and a second
     parsing pass is run. For a `Lambda` this applies to the whole
     computation. For a `Call` it applies to the called function, and the
     original argument is kept.

  Args:
    comp: Instance of `tff_framework.ComputationBuildingBlock` to parse down to
      a single TF block.

  Returns:
    The result of parsing TFF to TF. If successful, this is either a single
    `tff_framework.CompiledComputation`, or a call to one. If unsuccessful,
    more TFF constructs may still remain. Notice it is not the job of this
    function, but rather its callers, to check that the result of this parse is
    as expected.
  """
  parser_callable = tff_framework.TFParser()

  def _parse(inner_comp):
    """Runs a single post-order `TFParser` pass over `inner_comp`."""
    parsed, _ = tff_framework.transform_postorder(inner_comp, parser_callable)
    return parsed

  def _decorate_leaves_and_parse(function_comp):
    """Inserts called TF identities at the leaves of `function_comp`, parses."""
    leaves_decorated, _ = tff_framework.insert_called_tf_identity_at_leaves(
        function_comp)
    return _parse(leaves_decorated)

  comp, _ = tff_framework.remove_lambdas_and_blocks(comp)
  # Parsing all the way up from the leaves can be expensive, so we check whether
  # inserting called identities at the leaves is necessary first.
  new_comp = _parse(comp)
  if isinstance(new_comp, tff_framework.CompiledComputation) or (
      isinstance(new_comp, tff_framework.Call) and
      isinstance(new_comp.function, tff_framework.CompiledComputation)):
    return new_comp
  if isinstance(new_comp, tff_framework.Lambda):
    return _decorate_leaves_and_parse(new_comp)
  if isinstance(new_comp, tff_framework.Call):
    # `new_comp.argument` is `None` for a no-argument call. Pass it through
    # rather than hard-coding `None`, so that an argument is never dropped.
    return tff_framework.Call(
        _decorate_leaves_and_parse(new_comp.function), new_comp.argument)
  return _parse(new_comp)
Esempio n. 3
0
def zip_selection_as_argument_to_lower_level_lambda(comp, selected_index_lists):
  r"""Binds selections from the param of `comp` as params to lower-level lambda.

  Notice that `comp` must be a `tff_framework.Lambda`.

  The returned pattern is quite important here; given an input lambda `comp`,
  we will return an equivalent structure of the form:


                                    Lambda(x)
                                       |
                                      Call
                                    /      \
                              Lambda        <Selections from x>

  Where <Selections from x> represents a tuple of selections from the parameter
  `x`, as specified by `selected_index_lists`. The tuple is combined into a
  single federated value with `federated_zip`. This transform is necessary in
  order to isolate spurious dependence on arguments that are not in fact used,
  for example after we have separated processing on the server from that which
  happens on the clients, but the server-processing still declares some
  parameters placed at the clients.

  `selected_index_lists` must be a non-empty list of lists. Each list represents
  a sequence of selections to the parameter of `comp`. For example, if `x`
  is the parameter of `comp`, the list `[0, 1, 0]` would represent the
  selection `x[0][1][0]`. The elements of these inner lists must be integers;
  that is, the selections must be positional. Notice we do not allow for tuples
  due to automatic unwrapping.

  Args:
    comp: Instance of `tff_framework.Lambda`, whose parameters we wish to rebind
      to a different lambda.
    selected_index_lists: Non-empty 2-d list of `int`s, specifying the
      parameters of `comp` which we wish to rebind as the parameter to a
      lower-level lambda.

  Returns:
    An instance of `tff_framework.Lambda`, equivalent to `comp`, satisfying the
    pattern above, with reference names uniquified.

  Raises:
    TypeError: If a selection in `selected_index_lists` does not exist in the
      parameter type of `comp`, or if any selected element is not of federated
      type.
    ValueError: If `selected_index_lists` is empty, or if the selected elements
      are not all at the same placement.
  """
  py_typecheck.check_type(comp, tff_framework.Lambda)
  py_typecheck.check_type(selected_index_lists, list)
  for selection_list in selected_index_lists:
    py_typecheck.check_type(selection_list, list)
    for selected_element in selection_list:
      py_typecheck.check_type(selected_element, int)
  if not selected_index_lists:
    # With no selections there is no placement to zip at. Fail explicitly here
    # instead of with a bare `IndexError` on `type_list[0]` below.
    raise ValueError(
        'At least one selection must be specified in `selected_index_lists` '
        'in order to construct the argument to the lower-level lambda.')
  original_comp = comp
  comp = _prepare_for_rebinding(comp)

  top_level_parameter_type = comp.type_signature.parameter
  name_generator = tff_framework.unique_name_generator(comp)
  top_level_parameter_name = comp.parameter_name
  top_level_parameter_reference = tff_framework.Reference(
      top_level_parameter_name, comp.parameter_type)

  # Resolve the type addressed by each selection path. These types are used to
  # validate placements and to construct the type of the zipped argument.
  type_list = []
  for selection_list in selected_index_lists:
    selected_type = top_level_parameter_type
    try:
      for selection in selection_list:
        selected_type = selected_type[selection]
    except (TypeError, IndexError):
      # Indexing may fail with `TypeError` (the type is not indexable) or with
      # `IndexError` (position out of range). Either way, the selection path
      # does not exist in the parameter type.
      six.reraise(
          TypeError,
          TypeError(
              'You have tried to bind a variable to a nonexistent index in your '
              'lambda parameter type; the selection defined by {} is '
              'inadmissible for the lambda parameter type {}, in the comp {}.'
              .format(selection_list, top_level_parameter_type, original_comp)),
          sys.exc_info()[2])
    type_list.append(selected_type)

  if not all(isinstance(x, tff.FederatedType) for x in type_list):
    raise TypeError(
        'All selected arguments should be of federated type; your selections '
        'have resulted in the list of types {}'.format(type_list))
  placement = type_list[0].placement
  if not all(x.placement is placement for x in type_list):
    raise ValueError(
        'In order to zip the argument to the lower-level lambda together, all '
        'selected arguments should be at the same placement. Your selections '
        'have resulted in the list of types {}'.format(type_list))

  # Build `<x[...], x[...], ...>` from the parameter reference and zip it into
  # a single federated value, which is passed to the lower-level lambda.
  arg_to_lower_level_lambda_list = []
  for selection_list in selected_index_lists:
    selected_comp = top_level_parameter_reference
    for selection in selection_list:
      selected_comp = tff_framework.Selection(selected_comp, index=selection)
    arg_to_lower_level_lambda_list.append(selected_comp)
  zip_arg = tff_framework.create_federated_zip(
      tff_framework.Tuple(arg_to_lower_level_lambda_list))

  zip_type = tff.FederatedType([x.member for x in type_list],
                               placement=placement)
  ref_to_zip = tff_framework.Reference(six.next(name_generator), zip_type)

  # `selections_from_zip[i]` is the replacement used inside the lower-level
  # lambda for occurrences of the selection `selected_index_lists[i]`.
  selections_from_zip = [
      _construct_selection_from_federated_tuple(ref_to_zip, x, name_generator)
      for x in range(len(selected_index_lists))
  ]

  def _replace_selections_with_new_bindings(inner_comp):
    """Identifies selection pattern and replaces with new binding.

    Detecting this pattern is the most brittle part of this rebinding function.
    It relies on pattern-matching, and right now we cannot guarantee that this
    pattern is present in every situation we wish to replace with a new
    binding.

    Args:
      inner_comp: Instance of `tff_framework.ComputationBuildingBlock` in which
        we wish to replace the selections specified by `selected_index_lists`
        with the parallel new bindings from `selections_from_zip`.

    Returns:
      A possibly transformed version of `inner_comp` with nodes matching the
      selection patterns replaced by their new bindings.
    """
    # TODO(b/135541729): Either come up with a preprocessing way to enforce
    # this is sufficient, or rework the should_transform predicate.
    for idx, tup in enumerate(selected_index_lists):
      selection = inner_comp  # Empty selection
      tuple_pattern_matched = True
      # Walk the selection chain from the outermost `Selection` inward. The
      # last index in `tup` must therefore match first.
      for selected_index in tup[::-1]:
        if isinstance(
            selection,
            tff_framework.Selection) and selection.index == selected_index:
          selection = selection.source
        else:
          tuple_pattern_matched = False
          break
      if tuple_pattern_matched:
        if isinstance(selection, tff_framework.Reference
                     ) and selection.name == top_level_parameter_name:
          return selections_from_zip[idx], True
    return inner_comp, False

  variables_rebound_in_result, _ = tff_framework.transform_postorder(
      comp.result, _replace_selections_with_new_bindings)
  lambda_with_zipped_param = tff_framework.Lambda(ref_to_zip.name,
                                                  ref_to_zip.type_signature,
                                                  variables_rebound_in_result)
  _check_for_missed_binding(comp, lambda_with_zipped_param)

  zipped_lambda_called = tff_framework.Call(lambda_with_zipped_param, zip_arg)
  constructed_lambda = tff_framework.Lambda(comp.parameter_name,
                                            comp.parameter_type,
                                            zipped_lambda_called)
  names_uniquified, _ = tff_framework.uniquify_reference_names(
      constructed_lambda)
  return names_uniquified
Esempio n. 4
0
def consolidate_and_extract_local_processing(comp):
  """Consolidates all the local processing in `comp` into one TF section.

  Preconditions on `comp`:

  1. Its output is either federated or unplaced. The placement `p` of the
     output type is treated as the placement of `comp`, and no placement
     other than `p` may occur anywhere in its body. When `comp` is a function
     with a parameter, that parameter is federated at `p` too, or unplaced if
     the function's result is unplaced.

  2. The body uses only intrinsics that manipulate data locally within a
     single placement. This set will be extended gradually; at present the
     supported intrinsics are:

     * `federated_apply` (for `SERVER`-placed `comp`) or `federated_map` (for
       `CLIENTS`-placed `comp`). The `CLIENTS` case also admits
       `federated_map_all_equal`.

     * `federated_value_at_server` or `federated_value_at_clients`, chosen by
       placement in the same way.

     * `federated_zip_at_server` or `federated_zip_at_clients`, likewise.

     Everything else, including `sequence_*` operators, must already have been
     reduced before this function is called.

  3. Apart from `comp` itself possibly being a top-level lambda, the body
     contains no lambdas; all others must have been reduced. This may later be
     relaxed by folding a lambda reducer into this helper.

  4. If `comp` has functional type, it is either a
     `tff_framework.CompiledComputation` (returned unchanged) or a
     `tff_framework.Lambda`.

  5. At most one unbound reference appears under `comp`, and only when `comp`
     is not of functional type.

  `comp` may also contain lambdas, blocks and references within the limits
  above. Beyond those and the intrinsics listed above, it may contain only
  tuples, selections, calls, and TensorFlow sections (`CompiledComputation`s).
  This function consolidates these constructs.

  The transformation always produces a single TensorFlow section, called
  `result` below. How `result` relates to `comp` depends on the placement of
  `comp` and on whether it takes an argument:

  a. No argument, `SERVER`-placed: `comp` is equivalent to

     ```
     federated_value_at_server(result())
     ```

  b. No argument, `CLIENTS`-placed: `comp` is equivalent to

     ```
     federated_value_at_clients(result())
     ```

  c. With an argument, `SERVER`-placed: `comp` is equivalent to

     ```
     (arg -> federated_apply(<result, arg>))
     ```

  d. With an argument, `CLIENTS`-placed: `comp` is equivalent to

     ```
     (arg -> federated_map(<result, arg>))
     ```

  For a non-functional `comp` of type `T@p`, `result` has type `T`, where `p`
  is the concrete placement of `comp`. For `comp` of type `(T@p -> U@p)`,
  `result` must have type `(T -> U)`.

  Args:
    comp: The `tff_framework.ComputationBuildingBlock` to transform, subject to
      the preconditions above.

  Returns:
    A `tff_framework.CompiledComputation` holding the TensorFlow section
    produced by this extraction, as described above.

  Raises:
    ValueError: If `comp`, after lambda and block removal, has functional type
      but is neither a `tff_framework.CompiledComputation` nor a
      `tff_framework.Lambda`.
  """
  py_typecheck.check_type(comp, tff_framework.ComputationBuildingBlock)
  comp, _ = tff_framework.remove_lambdas_and_blocks(comp)
  comp_type = comp.type_signature

  if isinstance(comp_type, tff.FunctionType):
    # Already TensorFlow: nothing to consolidate.
    if isinstance(comp, tff_framework.CompiledComputation):
      return comp
    if not isinstance(comp, tff_framework.Lambda):
      raise ValueError('Any `tff_framework.ComputationBuildingBlock` of '
                       'functional type passed to '
                       '`consolidate_and_extract_local_processing`  should be '
                       'either a `tff_framework.CompiledComputation` or a '
                       '`tff_framework.Lambda`; you have passed a {} of type '
                       '{}.'.format(type(comp), comp.type_signature))
    # A federated result is unwrapped so that only the function applied to
    # the federated value (argument 0 of the unwrapped call) is parsed to TF.
    to_parse = comp
    if isinstance(comp.result.type_signature, tff.FederatedType):
      unwrapped_result, _ = tff_framework.unwrap_placement(comp.result)
      to_parse = unwrapped_result.argument[0]
    extracted = parse_tff_to_tf(to_parse)
    check_extraction_result(to_parse, extracted)
    return extracted

  if not isinstance(comp_type, tff.FederatedType):
    # Unplaced, non-functional: parsing yields a call to TF, and we return the
    # called TF.
    called_tf = parse_tff_to_tf(comp)
    check_extraction_result(comp, called_tf)
    return called_tf.function

  unwrapped, _ = tff_framework.unwrap_placement(comp)
  mapping_uris = (tff_framework.FEDERATED_APPLY.uri,
                  tff_framework.FEDERATED_MAP.uri)
  if unwrapped.function.uri in mapping_uris:
    mapped_fn = unwrapped.argument[0]
    extracted = parse_tff_to_tf(mapped_fn)
    check_extraction_result(mapped_fn, extracted)
    return extracted

  # Any other intrinsic: its argument is a call. Decorate the leaves of that
  # call's function with TF identities, rebuild the call, and parse it.
  inner_call = unwrapped.argument
  decorated_func, _ = tff_framework.insert_called_tf_identity_at_leaves(
      inner_call.function)
  decorated = tff_framework.Call(decorated_func, inner_call.argument)
  extracted = parse_tff_to_tf(decorated)
  check_extraction_result(decorated, extracted)
  return extracted.function