Example #1
0
def _insert_comp_in_top_level_lambda(comp, name, comp_to_insert):
  """Inserts a computation into `comp` with the given `name`.

  The computation is bound as the first local of a `tff_framework.Block` that
  becomes the result of the returned lambda. If the result of `comp` is already
  a `tff_framework.Block`, the new local is prepended to a copy of that block's
  locals. Otherwise, a new block wrapping the original result is created.

  Args:
    comp: The `tff_framework.Lambda` to transform. The names of lambda
      parameters and block variables in `comp` must be unique.
    name: The name to use.
    comp_to_insert: The `tff_framework.ComputationBuildingBlock` to insert.

  Returns:
    A new `tff_framework.Lambda` with the same parameter name and type as
    `comp`, whose result is a `tff_framework.Block` binding `name` to
    `comp_to_insert` ahead of any pre-existing locals.
  """
  py_typecheck.check_type(comp, tff_framework.Lambda)
  tff_framework.check_has_unique_names(comp)
  py_typecheck.check_type(name, six.string_types)
  py_typecheck.check_type(comp_to_insert,
                          tff_framework.ComputationBuildingBlock)
  result = comp.result
  if isinstance(result, tff_framework.Block):
    # Copy the existing locals so the input block is never mutated in place.
    existing_locals = list(result.locals)
    result = result.result
  else:
    existing_locals = []
  variables = [(name, comp_to_insert)] + existing_locals
  block = tff_framework.Block(variables, result)
  return tff_framework.Lambda(comp.parameter_name, comp.parameter_type, block)
Example #2
0
def concatenate_function_outputs(first_function, second_function):
  """Constructs a new function concatenating the outputs of its arguments.

  Assumes that `first_function` and `second_function` already have unique
  names, and have declared parameters of the same type. The constructed
  function will bind its parameter to each of the parameters of
  `first_function` and `second_function`, and return the result of executing
  these functions in parallel and concatenating the outputs in a tuple.

  Args:
    first_function: Instance of `tff_framework.Lambda` whose result we wish to
      concatenate with the result of `second_function`.
    second_function: Instance of `tff_framework.Lambda` whose result we wish to
      concatenate with the result of `first_function`.

  Returns:
    A new instance of `tff_framework.Lambda` with unique names representing the
    computation described above.

  Raises:
    TypeError: If the arguments are not instances of `tff_framework.Lambda`,
      or declare parameters of different types.
    AssertionError: If a reference to `first_function`'s parameter has a type
      that differs from `second_function`'s parameter type.
  """

  py_typecheck.check_type(first_function, tff_framework.Lambda)
  py_typecheck.check_type(second_function, tff_framework.Lambda)
  tff_framework.check_has_unique_names(first_function)
  tff_framework.check_has_unique_names(second_function)

  if first_function.parameter_type != second_function.parameter_type:
    # Report the parameter types (not the full function signatures), which is
    # what the message describes.
    raise TypeError('Must pass two functions which declare the same parameter '
                    'type to `concatenate_function_outputs`; you have passed '
                    'one function which declared a parameter of type {}, and '
                    'another which declares a parameter of type {}'.format(
                        first_function.parameter_type,
                        second_function.parameter_type))

  # Capture these up front; `first_function` is rebound below.
  first_param_name = first_function.parameter_name
  second_param_name = second_function.parameter_name
  second_param_type = second_function.parameter_type

  def _rename_first_function_arg(inner_comp):
    """Rewrites references to the first parameter as the second parameter."""
    if not (isinstance(inner_comp, tff_framework.Reference) and
            inner_comp.name == first_param_name):
      return inner_comp, False
    if inner_comp.type_signature != second_param_type:
      raise AssertionError('{}, {}'.format(inner_comp.type_signature,
                                           second_param_type))
    return tff_framework.Reference(second_param_name,
                                   inner_comp.type_signature), True

  first_function, _ = tff_framework.transform_postorder(
      first_function, _rename_first_function_arg)

  # Both results now refer only to the second function's parameter, so they can
  # share a single enclosing lambda.
  concatenated_function = tff_framework.Lambda(
      second_param_name, second_param_type,
      tff_framework.Tuple([first_function.result, second_function.result]))

  renamed, _ = tff_framework.uniquify_reference_names(concatenated_function)

  return renamed
Example #3
0
def bind_single_selection_as_argument_to_lower_level_lambda(comp, index):
  r"""Rebinds one selection from `comp`'s parameter as an inner lambda's param.

  Every `Selection` of element `index` from the parameter of `comp` is replaced
  by a reference to a freshly named parameter of a new inner lambda. That inner
  lambda is then called on the corresponding selection from the original
  parameter, yielding a structure equivalent to `comp` of the form:


                                    Lambda(x)
                                       |
                                      Call
                                    /      \
                              Lambda        Selection from x

  Args:
    comp: Instance of `tff_framework.Lambda` whose parameter selection should be
      rebound to a lower-level lambda. Names within `comp` must be unique.
    index: `int` giving the element of `comp`'s parameter to pass as the
      argument of the lower-level lambda.

  Returns:
    An instance of `tff_framework.Lambda`, equivalent to `comp`, with the
    pattern shown above.
  """
  py_typecheck.check_type(comp, tff_framework.Lambda)
  py_typecheck.check_type(index, int)
  tff_framework.check_has_unique_names(comp)
  comp = _prepare_for_rebinding(comp)
  fresh_names = tff_framework.unique_name_generator(comp)
  outer_param_name = comp.parameter_name
  inner_param_name = six.next(fresh_names)
  inner_param_ref = tff_framework.Reference(
      inner_param_name, comp.type_signature.parameter[index])

  def _swap_selection_for_inner_param(node):
    if not isinstance(node, tff_framework.Selection):
      return node, False
    source = node.source
    matches = (
        isinstance(source, tff_framework.Reference) and node.index == index and
        source.name == outer_param_name)
    if matches:
      return inner_param_ref, True
    return node, False

  rebound_body, _ = tff_framework.transform_postorder(
      comp.result, _swap_selection_for_inner_param)
  inner_lambda = tff_framework.Lambda(inner_param_ref.name,
                                      inner_param_ref.type_signature,
                                      rebound_body)
  _check_for_missed_binding(comp, inner_lambda)
  outer_param_ref = tff_framework.Reference(comp.parameter_name,
                                            comp.parameter_type)
  argument = tff_framework.Selection(outer_param_ref, index=index)
  return tff_framework.Lambda(comp.parameter_name, comp.parameter_type,
                              tff_framework.Call(inner_lambda, argument))
Example #4
0
def _can_extract_intrinsic_to_top_level_lambda(comp, uri):
  """Checks whether the called intrinsics for `uri` can be extracted.

  Extraction is possible when every called intrinsic with the given `uri` in
  `comp` has, as its only unbound reference, the parameter of `comp`.

  Args:
    comp: The `tff_framework.Lambda` to inspect. The names of lambda parameters
      and block variables in `comp` must be unique.
    uri: A URI of an intrinsic.

  Returns:
    `True` if the intrinsic can be extracted, `False` otherwise.
  """
  py_typecheck.check_type(comp, tff_framework.Lambda)
  tff_framework.check_has_unique_names(comp)
  py_typecheck.check_type(uri, six.string_types)
  return _are_comps_bound_exclusively_by_top_level_lambda(
      comp, _get_called_intrinsics(comp, uri))
Example #5
0
def _extract_intrinsic_as_reference_to_top_level_lambda(comp, uri):
  """Extracts the called intrinsic for `uri` from `comp` into a named local.

  The single call to the intrinsic identified by `uri` is replaced in `comp` by
  a reference with a freshly generated unique name. The intrinsic call itself
  is then bound to that name as the first local of a block at the top level of
  the lambda, via `_insert_comp_in_top_level_lambda`.

  Args:
    comp: The `tff_framework.Lambda` to transform. The names of lambda
      parameters and block variables in `comp` must be unique.
    uri: A URI of an intrinsic.

  Returns:
    A tuple `(transformed_comp, True)`, where `transformed_comp` is the new
    computation with the intrinsic extracted. The second element is always
    `True` when the function returns normally.

  Raises:
    ValueError: If `comp` does not contain exactly one called intrinsic for the
      given `uri`, or if that intrinsic is not bound exclusively by the
      parameter of `comp`.
  """
  py_typecheck.check_type(comp, tff_framework.Lambda)
  tff_framework.check_has_unique_names(comp)
  py_typecheck.check_type(uri, six.string_types)
  intrinsics = _get_called_intrinsics(comp, uri)
  length = len(intrinsics)
  if length != 1:
    raise ValueError(
        'Expected a computation with exactly one intrinsic with the uri: {}, '
        'found: {}.'.format(uri, length))
  if not _are_comps_bound_exclusively_by_top_level_lambda(comp, intrinsics):
    raise ValueError(
        'Expected a computation which binds all the references in the '
        'intrinsic with the uri: {}.'.format(uri))
  name_generator = tff_framework.unique_name_generator(comp)
  extracted_intrinsic = intrinsics[0]
  ref_name = six.next(name_generator)
  ref_type = tff.to_type(extracted_intrinsic.type_signature)
  ref = tff_framework.Reference(ref_name, ref_type)

  def _replace_intrinsic_with_ref(inner_comp):
    # Exactly one matching call exists (checked above), so every match is the
    # intrinsic being extracted.
    if tff_framework.is_called_intrinsic(inner_comp, uri):
      return ref, True
    return inner_comp, False

  transformed, _ = tff_framework.transform_postorder(
      comp, _replace_intrinsic_with_ref)
  transformed = _insert_comp_in_top_level_lambda(
      transformed, name=ref.name, comp_to_insert=extracted_intrinsic)
  return transformed, True
Example #6
0
def _are_comps_bound_exclusively_by_top_level_lambda(comp, comps):
  """Checks that each computation in `comps` depends only on `comp`'s param.

  Args:
    comp: The `tff_framework.Lambda` providing the binding context. The names
      of lambda parameters and block variables in `comp` must be unique.
    comps: A Python `list`, `tuple` or `set` of computations nested in `comp`.

  Returns:
    `True` if, for every computation in `comps`, the set of unbound reference
    names is exactly `{comp.parameter_name}`; otherwise `False`. An empty
    `comps` yields `True`.
  """
  py_typecheck.check_type(comp, tff_framework.Lambda)
  tff_framework.check_has_unique_names(comp)
  py_typecheck.check_type(comps, (list, tuple, set))
  unbound_by_comp = tff_framework.get_map_of_unbound_references(comp)
  expected_names = {comp.parameter_name}
  for candidate in comps:
    if expected_names != unbound_by_comp[candidate]:
      return False
  return True