コード例 #1
0
def bind_single_selection_as_argument_to_lower_level_lambda(comp, index):
  r"""Rebinds one positional selection of `comp`'s parameter to an inner lambda.

  Every occurrence of the selection `x[index]` (where `x` is the parameter of
  `comp`) inside the body of `comp` is replaced by a reference to the parameter
  of a newly built inner lambda, and that inner lambda is then called on
  `x[index]`. The result is an equivalent computation shaped as:


                                    Lambda(x)
                                       |
                                      Call
                                    /      \
                              Lambda        Selection from x

  Args:
    comp: Instance of `tff_framework.Lambda` whose parameter selection should be
      rebound to a lower-level lambda. Names within `comp` must be unique.
    index: `int` position in the parameter of `comp` whose selection becomes
      the argument of the lower-level lambda.

  Returns:
    An instance of `tff_framework.Lambda`, equivalent to `comp`, with the
    structure shown above.
  """
  py_typecheck.check_type(comp, tff_framework.Lambda)
  py_typecheck.check_type(index, int)
  tff_framework.check_has_unique_names(comp)
  comp = _prepare_for_rebinding(comp)

  fresh_names = tff_framework.unique_name_generator(comp)
  outer_param_name = comp.parameter_name
  inner_param = tff_framework.Reference(
      six.next(fresh_names), comp.type_signature.parameter[index])

  def _swap_in_inner_param(node):
    # Only `outer_param[index]` is rewritten; every other node is kept as-is.
    is_target = (
        isinstance(node, tff_framework.Selection) and
        isinstance(node.source, tff_framework.Reference) and
        node.index == index and node.source.name == outer_param_name)
    if not is_target:
      return node, False
    return inner_param, True

  rebound_body, _ = tff_framework.transform_postorder(comp.result,
                                                      _swap_in_inner_param)
  inner_lambda = tff_framework.Lambda(inner_param.name,
                                      inner_param.type_signature, rebound_body)
  _check_for_missed_binding(comp, inner_lambda)

  # Feed `x[index]` from the outer parameter into the new inner lambda.
  outer_param_type = comp.parameter_type
  argument = tff_framework.Selection(
      tff_framework.Reference(outer_param_name, outer_param_type), index=index)
  return tff_framework.Lambda(outer_param_name, outer_param_type,
                              tff_framework.Call(inner_lambda, argument))
コード例 #2
0
def _construct_selection_from_federated_tuple(federated_tuple, selected_index,
                                              name_generator):
  """Builds a computation picking one element out of a federated tuple.

  Args:
    federated_tuple: Instance of `tff_framework.ComputationBuildingBlock` whose
      type is a `tff.FederatedType` with a `tff.NamedTupleType` member; one
      element of that member is selected.
    selected_index: `int` position of the element to select from the member
      type of `federated_tuple`.
    name_generator: `generator` yielding unique names; one name is consumed to
      name the parameter of the selection lambda.

  Returns:
    An instance of `tff_framework.ComputationBuildingBlock` holding element
    `selected_index` of `federated_tuple`, still federated at the same
    placement.
  """
  py_typecheck.check_type(federated_tuple,
                          tff_framework.ComputationBuildingBlock)
  py_typecheck.check_type(selected_index, int)
  federated_type = federated_tuple.type_signature
  py_typecheck.check_type(federated_type, tff.FederatedType)
  member_type = federated_type.member
  py_typecheck.check_type(member_type, tff.NamedTupleType)

  param_name = six.next(name_generator)
  # Lambda(param) -> param[selected_index], lifted over the federated value by
  # `create_federated_map_or_apply`.
  select_fn = tff_framework.Lambda(
      param_name, member_type,
      tff_framework.Selection(
          tff_framework.Reference(param_name, member_type),
          index=selected_index))
  return tff_framework.create_federated_map_or_apply(select_fn,
                                                     federated_tuple)
コード例 #3
0
def _insert_comp_in_top_level_lambda(comp, name, comp_to_insert):
  """Binds `comp_to_insert` to `name` as the first local of `comp`'s body.

  If the result of `comp` is already a `tff_framework.Block`, the new binding
  is prepended to that block's locals. Otherwise the result of `comp` is
  wrapped in a new `tff_framework.Block` with the single new local.

  Args:
    comp: The `tff_framework.Lambda` to transform. The names of lambda
      parameters and block variables in `comp` must be unique.
    name: The name to bind `comp_to_insert` to.
    comp_to_insert: The `tff_framework.ComputationBuildingBlock` to insert.

  Returns:
    A new `tff_framework.Lambda` with the same parameter as `comp`, whose
    result is a `tff_framework.Block` whose first local is
    `(name, comp_to_insert)`. `comp` itself is not modified.
  """
  py_typecheck.check_type(comp, tff_framework.Lambda)
  tff_framework.check_has_unique_names(comp)
  py_typecheck.check_type(name, six.string_types)
  py_typecheck.check_type(comp_to_insert,
                          tff_framework.ComputationBuildingBlock)
  result = comp.result
  if isinstance(result, tff_framework.Block):
    # Copy before prepending, so the locals of the input block are never
    # mutated regardless of whether `Block.locals` hands out its own list.
    variables = list(result.locals)
    result = result.result
  else:
    variables = []
  variables.insert(0, (name, comp_to_insert))
  block = tff_framework.Block(variables, result)
  return tff_framework.Lambda(comp.parameter_name, comp.parameter_type, block)
コード例 #4
0
 def _normalize_lambda_bit(comp):
   """Rebuilds a federated lambda parameter type from member and placement.

   Returns a `(comp, modified)` pair. A lambda whose parameter type is not a
   `tff.FederatedType` is returned unchanged with `False`. Otherwise a new
   `tff_framework.Lambda` with the same parameter name and result is returned
   with `True`. Its parameter type is a `tff.FederatedType` constructed from
   only the original member and placement.
   """
   param_type = comp.parameter_type
   if isinstance(param_type, tff.FederatedType):
     rebuilt_type = tff.FederatedType(param_type.member, param_type.placement)
     rebuilt = tff_framework.Lambda(comp.parameter_name, rebuilt_type,
                                    comp.result)
     return rebuilt, True
   return comp, False
コード例 #5
0
def extract_work(before_aggregate, after_aggregate, canonical_form_types):
  """Converts `before_aggregate` and `after_aggregate` to `work`.

  `work` is assembled from two functions of c3:
  - c3 -> c5 is element 0 of the result of `before_aggregate`.
  - c3 -> c6 is element 2 of the result of `after_aggregate`.

  Both are first rebound so they depend only on the listed selections from
  their parameter. Their outputs are then concatenated and federated-zipped
  into c4, and the local processing is extracted.

  Args:
    before_aggregate: The first result of splitting `after_broadcast` on
      `tff_framework.FEDERATED_AGGREGATE`.
    after_aggregate: The second result of splitting `after_broadcast` on
      `tff_framework.FEDERATED_AGGREGATE`.
    canonical_form_types: `dict` holding the `canonical_form.CanonicalForm` type
      signatures specified by the `tff.utils.IterativeProcess` we are compiling.

  Returns:
    `work` as specified by `canonical_form.CanonicalForm`, an instance of
    `tff_framework.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: If we fail to extract a
    `tff_framework.CompiledComputation`, or we extract one of the wrong type.
  """
  # The c3/c4/c5/c6 names follow `get_iterative_process_for_canonical_form()`.
  before_aggregate_c3_selections = [[0, 1], [1]]
  c3_to_before_aggregate = (
      transformations.zip_selection_as_argument_to_lower_level_lambda(
          before_aggregate, before_aggregate_c3_selections).result.function)
  c3_to_c5 = transformations.select_output_from_lambda(c3_to_before_aggregate,
                                                       0)

  after_aggregate_to_c6 = transformations.select_output_from_lambda(
      after_aggregate, 2)
  after_aggregate_c3_selections = [[0, 0, 1], [0, 1]]
  c3_to_c6 = (
      transformations.zip_selection_as_argument_to_lower_level_lambda(
          after_aggregate_to_c6, after_aggregate_c3_selections).result.function)

  c3_to_c5_and_c6 = transformations.concatenate_function_outputs(
      c3_to_c5, c3_to_c6)
  zipped_result = tff_framework.create_federated_zip(c3_to_c5_and_c6.result)
  c3_to_c4 = tff_framework.Lambda(c3_to_c5_and_c6.parameter_name,
                                  c3_to_c5_and_c6.parameter_type,
                                  zipped_result)

  work = transformations.consolidate_and_extract_local_processing(c3_to_c4)
  if not isinstance(work, tff_framework.CompiledComputation):
    raise transformations.CanonicalFormCompilationError(
        'Failed to extract a `tff_framework.CompiledComputation` from '
        'work, instead received a {} (of type {}).'.format(
            type(work), work.type_signature))
  expected_type = canonical_form_types['work_type']
  if work.type_signature != expected_type:
    raise transformations.CanonicalFormCompilationError(
        'Extracted a TF block of the wrong type. Expected a function with type '
        '{}, but the type signature of the TF block was {}'.format(
            expected_type, work.type_signature))
  return work
コード例 #6
0
def concatenate_function_outputs(first_function, second_function):
  """Constructs a new function concatenating the outputs of its arguments.

  Assumes that `first_function` and `second_function` each already have unique
  names, and have declared parameters of the same type. The constructed
  function declares a single parameter and rewires the parameter references of
  both functions to it. It returns a 2-tuple of their results: first, then
  second.

  Args:
    first_function: Instance of `tff_framework.Lambda` whose result we wish to
      concatenate with the result of `second_function`.
    second_function: Instance of `tff_framework.Lambda` whose result we wish to
      concatenate with the result of `first_function`.

  Returns:
    A new instance of `tff_framework.Lambda` with unique names representing the
    computation described above.

  Raises:
    TypeError: If the arguments are not instances of `tff_framework.Lambda`,
    or declare parameters of different types.
  """

  py_typecheck.check_type(first_function, tff_framework.Lambda)
  py_typecheck.check_type(second_function, tff_framework.Lambda)
  tff_framework.check_has_unique_names(first_function)
  tff_framework.check_has_unique_names(second_function)

  if first_function.parameter_type != second_function.parameter_type:
    raise TypeError('Must pass two functions which declare the same parameter '
                    'type to `concatenate_function_outputs`; you have passed '
                    'one function which declared a parameter of type {}, and '
                    'another which declares a parameter of type {}'.format(
                        first_function.parameter_type,
                        second_function.parameter_type))

  parameter_type = second_function.parameter_type

  # Uniqueness is only checked within each function. A name bound inside one
  # function (e.g. a block local) may therefore equal the other function's
  # parameter name. Rebinding both parameters to a name that is fresh in both
  # functions prevents such a binding from capturing the renamed references.
  name_generator = tff_framework.unique_name_generator(
      tff_framework.Tuple([first_function, second_function]))
  shared_parameter_name = six.next(name_generator)

  def _rebind_parameter_references(function):
    """Returns `function.result` with its parameter renamed to the shared one."""

    def _rename(comp):
      if (isinstance(comp, tff_framework.Reference) and
          comp.name == function.parameter_name):
        if comp.type_signature != parameter_type:
          raise AssertionError('{}, {}'.format(comp.type_signature,
                                               parameter_type))
        return tff_framework.Reference(shared_parameter_name,
                                       comp.type_signature), True
      return comp, False

    transformed, _ = tff_framework.transform_postorder(function, _rename)
    return transformed.result

  concatenated_function = tff_framework.Lambda(
      shared_parameter_name, parameter_type,
      tff_framework.Tuple([
          _rebind_parameter_references(first_function),
          _rebind_parameter_references(second_function),
      ]))

  renamed, _ = tff_framework.uniquify_reference_names(concatenated_function)

  return renamed
コード例 #7
0
def extract_update(after_aggregate, canonical_form_types):
  """Converts `after_aggregate` to `update`.

  Elements 0 and 1 of the result of `after_aggregate` are selected and
  federated-zipped to form s5. The resulting computation is rebound so it
  depends only on the selections `[0, 0, 0]` and `[1]` of its parameter (s4),
  and the local processing is extracted.

  Args:
    after_aggregate: The second result of splitting `after_broadcast` on
      `tff_framework.FEDERATED_AGGREGATE`.
    canonical_form_types: `dict` holding the `canonical_form.CanonicalForm` type
      signatures specified by the `tff.utils.IterativeProcess` we are compiling.

  Returns:
    `update` as specified by `canonical_form.CanonicalForm`, an instance of
    `tff_framework.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: If we fail to extract a
    `tff_framework.CompiledComputation`, or we extract one of the wrong type.
  """
  # The s4/s5 names follow `get_iterative_process_for_canonical_form()`.
  s5_only = transformations.select_output_from_lambda(after_aggregate, [0, 1])
  s5_zip_body = tff_framework.create_federated_zip(s5_only.result)
  s5_zipped = tff_framework.Lambda(s5_only.parameter_name,
                                   s5_only.parameter_type, s5_zip_body)
  s4_selections = [[0, 0, 0], [1]]
  s4_to_s5 = (
      transformations.zip_selection_as_argument_to_lower_level_lambda(
          s5_zipped, s4_selections).result.function)

  update = transformations.consolidate_and_extract_local_processing(s4_to_s5)
  if not isinstance(update, tff_framework.CompiledComputation):
    raise transformations.CanonicalFormCompilationError(
        'Failed to extract a `tff_framework.CompiledComputation` from '
        'update, instead received a {} (of type {}).'.format(
            type(update), update.type_signature))
  expected_type = canonical_form_types['update_type']
  if update.type_signature != expected_type:
    raise transformations.CanonicalFormCompilationError(
        'Extracted a TF block of the wrong type. Expected a function with type '
        '{}, but the type signature of the TF block was {}'.format(
            expected_type, update.type_signature))
  return update
コード例 #8
0
def select_output_from_lambda(comp, indices):
  """Builds a lambda whose result is the `indices` selection of `comp`'s result.

  The original result of `comp` is bound once, under a fresh name, as the only
  local of a `tff_framework.Block`. The new output is then selected from that
  bound name, so the original result expression is not duplicated.

  Args:
    comp: Instance of `tff_framework.Lambda` of result type `tff.NamedTupleType`
      from which we wish to select `indices`. Notice that this named tuple type
      must have elements of federated type.
    indices: Instance of `int`, `list`, or `tuple`, specifying the indices we
      wish to select from the result of `comp`. If `indices` is an `int`, the
      result of the returned `comp` will be of type at index `indices` in
      `comp.type_signature.result`. If `indices` is a `list` or `tuple`, the
      result type will be a `tff.NamedTupleType` wrapping the specified
      selections.

  Returns:
    A transformed version of `comp` with result value the selection from the
    result of `comp` specified by `indices`.

  Raises:
    TypeError: If `indices` is a `list`/`tuple` containing a non-`int`.
  """
  py_typecheck.check_type(comp, tff_framework.Lambda)
  py_typecheck.check_type(comp.type_signature.result, tff.NamedTupleType)
  py_typecheck.check_type(indices, (int, tuple, list))
  original_result = comp.result
  fresh_names = tff_framework.unique_name_generator(comp)
  bound_name = six.next(fresh_names)
  bound_ref = tff_framework.Reference(bound_name,
                                      original_result.type_signature)

  if isinstance(indices, int):
    new_output = tff_framework.Selection(bound_ref, index=indices)
  else:
    if any(not isinstance(i, int) for i in indices):
      raise TypeError('Must select by index in `select_output_from_lambda`.')
    new_output = tff_framework.Tuple(
        [tff_framework.Selection(bound_ref, index=i) for i in indices])

  body = tff_framework.Block([(bound_name, original_result)], new_output)
  return tff_framework.Lambda(comp.parameter_name, comp.parameter_type, body)
コード例 #9
0
def get_canonical_form_for_iterative_process(iterative_process):
  """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

  This function transforms computations from the input `iterative_process` into
  an instance of `tff.backends.mapreduce.CanonicalForm`.

  Args:
    iterative_process: An instance of `tff.utils.IterativeProcess`.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process.

  Raises:
    TypeError: If the arguments are of the wrong types.
    transformations.CanonicalFormCompilationError: If the compilation
      process fails.
  """
  py_typecheck.check_type(iterative_process,
                          computation_utils.IterativeProcess)

  initialize_comp = tff_framework.ComputationBuildingBlock.from_proto(
      iterative_process.initialize._computation_proto)  # pylint: disable=protected-access

  next_comp = tff_framework.ComputationBuildingBlock.from_proto(
      iterative_process.next._computation_proto)  # pylint: disable=protected-access

  if not (isinstance(next_comp.type_signature.parameter, tff.NamedTupleType) and
          isinstance(next_comp.type_signature.result, tff.NamedTupleType)):
    raise TypeError(
        'Any IterativeProcess compatible with CanonicalForm must '
        'have a `next` function which takes and returns instances '
        'of `tff.NamedTupleType`; your next function takes '
        'parameters of type {} and returns results of type {}'.format(
            next_comp.type_signature.parameter,
            next_comp.type_signature.result))

  if len(next_comp.type_signature.result) == 2:
    # `next` returns only two results. Append an empty CLIENTS-placed value as
    # a placeholder third result (the client metrics slot).
    # NOTE(review): indexing `next_comp.result` assumes it is a
    # `tff_framework.Tuple`.
    next_result = next_comp.result
    dummy_clients_metrics_appended = tff_framework.Tuple([
        next_result[0],
        next_result[1],
        tff.federated_value([], tff.CLIENTS)._comp  # pylint: disable=protected-access
    ])
    next_comp = tff_framework.Lambda(next_comp.parameter_name,
                                     next_comp.parameter_type,
                                     dummy_clients_metrics_appended)

  initialize_comp = tff_framework.replace_intrinsics_with_bodies(
      initialize_comp)
  next_comp = tff_framework.replace_intrinsics_with_bodies(next_comp)

  tff_framework.check_intrinsics_whitelisted_for_reduction(initialize_comp)
  tff_framework.check_intrinsics_whitelisted_for_reduction(next_comp)
  tff_framework.check_broadcast_not_dependent_on_aggregate(next_comp)

  before_broadcast, after_broadcast = (
      transformations.force_align_and_split_by_intrinsic(
          next_comp, tff_framework.FEDERATED_BROADCAST.uri))

  before_aggregate, after_aggregate = (
      transformations.force_align_and_split_by_intrinsic(
          after_broadcast, tff_framework.FEDERATED_AGGREGATE.uri))

  init_info_packed = pack_initialize_comp_type_signature(
      initialize_comp.type_signature)

  next_info_packed = pack_next_comp_type_signature(next_comp.type_signature,
                                                   init_info_packed)

  before_broadcast_info_packed = (
      check_and_pack_before_broadcast_type_signature(
          before_broadcast.type_signature, next_info_packed))

  before_aggregate_info_packed = (
      check_and_pack_before_aggregate_type_signature(
          before_aggregate.type_signature, before_broadcast_info_packed))

  canonical_form_types = check_and_pack_after_aggregate_type_signature(
      after_aggregate.type_signature, before_aggregate_info_packed)

  initialize = transformations.consolidate_and_extract_local_processing(
      initialize_comp)

  # Report the same type the check compares against. The message previously
  # reported `next_comp.type_signature.parameter[0]` instead.
  expected_initialize_type = canonical_form_types['initialize_type'].member
  if not (isinstance(initialize, tff_framework.CompiledComputation) and
          initialize.type_signature.result == expected_initialize_type):
    raise transformations.CanonicalFormCompilationError(
        'Compilation of initialize has failed. Expected to extract a '
        '`tff_framework.CompiledComputation` of type {}, instead we extracted '
        'a {} of type {}.'.format(expected_initialize_type, type(initialize),
                                  initialize.type_signature.result))

  prepare = extract_prepare(before_broadcast, canonical_form_types)

  work = extract_work(before_aggregate, after_aggregate, canonical_form_types)

  zero_noarg_function, accumulate, merge, report = extract_aggregate_functions(
      before_aggregate, canonical_form_types)

  update = extract_update(after_aggregate, canonical_form_types)

  cf = canonical_form.CanonicalForm(
      tff_framework.building_block_to_computation(initialize),
      tff_framework.building_block_to_computation(prepare),
      tff_framework.building_block_to_computation(work),
      tff_framework.building_block_to_computation(zero_noarg_function),
      tff_framework.building_block_to_computation(accumulate),
      tff_framework.building_block_to_computation(merge),
      tff_framework.building_block_to_computation(report),
      tff_framework.building_block_to_computation(update))
  return cf
コード例 #10
0
def zip_selection_as_argument_to_lower_level_lambda(comp, selected_index_lists):
  r"""Binds selections from the param of `comp` as params to lower-level lambda.

  Notice that `comp` must be a `tff_framework.Lambda`.

  The returned pattern is quite important here; given an input lambda `comp`,
  we will return an equivalent structure of the form:


                                    Lambda(x)
                                       |
                                      Call
                                    /      \
                              Lambda        <Selections from x>

  Where <Selections from x> represents a federated zip of the selections from
  the parameter `x`, as specified by `selected_index_lists`. This transform is
  necessary in order to isolate spurious dependence on arguments that are not
  in fact used. For example, after we have separated processing on the server
  from that which happens on the clients, the server-processing may still
  declare some parameters placed at the clients.

  `selected_index_lists` must be a non-empty list of lists. Each list
  represents a sequence of selections to the parameter of `comp`. For example,
  if `x` is the parameter of `comp`, the list `[0, 1, 0]` would represent the
  selection `x[0][1][0]`. The elements of these inner lists must be integers;
  that is, the selections must be positional. Notice we do not allow for tuples
  due to automatic unwrapping.

  Args:
    comp: Instance of `tff_framework.Lambda`, whose parameters we wish to rebind
      to a different lambda.
    selected_index_lists: 2-d list of `int`s, specifying the parameters of
      `comp` which we wish to rebind as the parameter to a lower-level lambda.

  Returns:
    An instance of `tff_framework.Lambda`, equivalent to `comp`, satisfying the
    pattern above, with reference names uniquified.

  Raises:
    TypeError: If the arguments have the wrong types, or if a selected element
      is not of federated type.
    ValueError: If `selected_index_lists` is empty, or if the selected elements
      are not all at the same placement.

  A `TypeError` or `IndexError` raised while applying a selection to the
  parameter type of `comp` is re-raised with the same exception type and a
  descriptive message.
  """
  py_typecheck.check_type(comp, tff_framework.Lambda)
  py_typecheck.check_type(selected_index_lists, list)
  for selection_list in selected_index_lists:
    py_typecheck.check_type(selection_list, list)
    for selected_element in selection_list:
      py_typecheck.check_type(selected_element, int)
  if not selected_index_lists:
    # Without this guard the failure surfaces as an opaque IndexError on
    # `type_list[0]` below.
    raise ValueError(
        'Must select at least one element of the lambda parameter to bind to '
        'the lower-level lambda; `selected_index_lists` is empty.')
  original_comp = comp
  comp = _prepare_for_rebinding(comp)

  top_level_parameter_type = comp.type_signature.parameter
  name_generator = tff_framework.unique_name_generator(comp)
  top_level_parameter_name = comp.parameter_name
  top_level_parameter_reference = tff_framework.Reference(
      top_level_parameter_name, comp.parameter_type)

  type_list = []
  for selection_list in selected_index_lists:
    selected_type = top_level_parameter_type
    try:
      for selection in selection_list:
        selected_type = selected_type[selection]
    except (TypeError, IndexError) as e:
      # Indexing into a non-tuple type and indexing out of range fail with
      # different exception types. Keep the type, but add context for both.
      error_cls = TypeError if isinstance(e, TypeError) else IndexError
      six.reraise(
          error_cls,
          error_cls(
              'You have tried to bind a variable to a nonexistent index in your '
              'lambda parameter type; the selection defined by {} is '
              'inadmissible for the lambda parameter type {}, in the comp {}.'
              .format(selection_list, top_level_parameter_type, original_comp)),
          sys.exc_info()[2])
    type_list.append(selected_type)

  if not all(isinstance(x, tff.FederatedType) for x in type_list):
    raise TypeError(
        'All selected arguments should be of federated type; your selections '
        'have resulted in the list of types {}'.format(type_list))
  placement = type_list[0].placement
  if not all(x.placement is placement for x in type_list):
    raise ValueError(
        'In order to zip the argument to the lower-level lambda together, all '
        'selected arguments should be at the same placement. Your selections '
        'have resulted in the list of types {}'.format(type_list))

  arg_to_lower_level_lambda_list = []
  for selection_tuple in selected_index_lists:
    selected_comp = top_level_parameter_reference
    for selection in selection_tuple:
      selected_comp = tff_framework.Selection(selected_comp, index=selection)
    arg_to_lower_level_lambda_list.append(selected_comp)
  zip_arg = tff_framework.create_federated_zip(
      tff_framework.Tuple(arg_to_lower_level_lambda_list))

  zip_type = tff.FederatedType([x.member for x in type_list],
                               placement=placement)
  ref_to_zip = tff_framework.Reference(six.next(name_generator), zip_type)

  selections_from_zip = [
      _construct_selection_from_federated_tuple(ref_to_zip, x, name_generator)
      for x in range(len(selected_index_lists))
  ]

  def _replace_selections_with_new_bindings(inner_comp):
    """Identifies selection pattern and replaces with new binding.

    Detecting this pattern is the most brittle part of this rebinding function.
    It relies on pattern-matching, and right now we cannot guarantee that this
    pattern is present in every situation we wish to replace with a new
    binding.

    Args:
      inner_comp: Instance of `tff_framework.ComputationBuildingBlock` in which
        we wish to replace the selections specified by `selected_index_lists`
        with the parallel new bindings from `selections_from_zip`.

    Returns:
      A possibly transformed version of `inner_comp` with nodes matching the
      selection patterns replaced by their new bindings.
    """
    # TODO(b/135541729): Either come up with a preprocessing way to enforce
    # this is sufficient, or rework the should_transform predicate.
    for idx, tup in enumerate(selected_index_lists):
      selection = inner_comp  # Empty selection
      tuple_pattern_matched = True
      # Walk the selection chain outermost-first, i.e. the index list reversed.
      for selected_index in tup[::-1]:
        if isinstance(
            selection,
            tff_framework.Selection) and selection.index == selected_index:
          selection = selection.source
        else:
          tuple_pattern_matched = False
          break
      if tuple_pattern_matched:
        if isinstance(selection, tff_framework.Reference
                     ) and selection.name == top_level_parameter_name:
          return selections_from_zip[idx], True
    return inner_comp, False

  variables_rebound_in_result, _ = tff_framework.transform_postorder(
      comp.result, _replace_selections_with_new_bindings)
  lambda_with_zipped_param = tff_framework.Lambda(ref_to_zip.name,
                                                  ref_to_zip.type_signature,
                                                  variables_rebound_in_result)
  _check_for_missed_binding(comp, lambda_with_zipped_param)

  zipped_lambda_called = tff_framework.Call(lambda_with_zipped_param, zip_arg)
  constructed_lambda = tff_framework.Lambda(comp.parameter_name,
                                            comp.parameter_type,
                                            zipped_lambda_called)
  names_uniquified, _ = tff_framework.uniquify_reference_names(
      constructed_lambda)
  return names_uniquified
コード例 #11
0
def _split_by_intrinsic(comp, uri):
  """Splits `comp` into `before` and `after` the intrinsic for the given `uri`.

  This function finds the intrinsic for the given `uri` in `comp`; splits `comp`
  into two computations `before` and `after` the intrinsic; and returns a Python
  tuple representing the pair of `before` and `after` computations.

  `before` has the same parameter as `comp` and returns the argument of the
  called intrinsic. `after` takes a 2-tuple of (original parameter, intrinsic
  result). Inside `after`, the original parameter name is rebound to element 0,
  and the block local that held the intrinsic call is rebound to element 1.

  NOTE: This function is generally safe to call on computations that do not fit
  into canonical form. It is left to the caller to determine if the resulting
  computations are expected.

  Args:
    comp: The `tff_framework.Lambda` to transform.
    uri: A URI of an intrinsic.

  Returns:
    A pair of `tff_framework.ComputationBuildingBlock`s representing the
    computations `before` and `after` the intrinsic. `comp` is not modified.

  Raises:
    TypeError: If `comp` is not a `tff_framework.Lambda` whose result is a
      `tff_framework.Block`, or if `uri` is not a string.
    ValueError: If the locals of that block contain no called intrinsic with
      the given `uri`, or if `comp` does not contain exactly one intrinsic with
      the given `uri`.
  """
  py_typecheck.check_type(comp, tff_framework.Lambda)
  py_typecheck.check_type(uri, six.string_types)
  py_typecheck.check_type(comp.result, tff_framework.Block)

  def _get_called_intrinsic_from_block_variables(variables, uri):
    # Returns position, name and value of the first local that calls `uri`.
    for index, (name, variable) in enumerate(variables):
      if tff_framework.is_called_intrinsic(variable, uri):
        return index, name, variable
    raise ValueError(
        'Expected a lambda referencing a block referencing a collection of '
        'variables containing an intrinsic with the uri: {}, found None.'
        .format(uri))

  index, name, variable = _get_called_intrinsic_from_block_variables(
      comp.result.locals, uri)
  intrinsics = _get_called_intrinsics(comp, uri)
  length = len(intrinsics)
  if length != 1:
    raise ValueError(
        'Expected a computation with exactly one intrinsic with the uri: {}, '
        'found: {}.'.format(uri, length))
  name_generator = tff_framework.unique_name_generator(comp)
  # NOTE(review): `variable.argument` becomes the body of `before` as-is. It
  # presumably must not reference block locals bound earlier in `comp`;
  # callers are expected to have arranged this.
  before = tff_framework.Lambda(comp.parameter_name, comp.parameter_type,
                                variable.argument)
  parameter_type = tff.NamedTupleType(
      (comp.parameter_type, variable.type_signature))
  ref_name = six.next(name_generator)
  ref = tff_framework.Reference(ref_name, parameter_type)
  sel_0 = tff_framework.Selection(ref, index=0)
  sel_1 = tff_framework.Selection(ref, index=1)
  # Work on a copy, so the block inside `comp` is never mutated regardless of
  # whether `Block.locals` hands out its own list.
  variables = list(comp.result.locals)
  variables[index] = (name, sel_1)
  variables.insert(0, (comp.parameter_name, sel_0))
  block = tff_framework.Block(variables, comp.result.result)
  after = tff_framework.Lambda(ref.name, ref.type_signature, block)
  return before, after