예제 #1
0
def _insert_comp_in_top_level_lambda(comp, name, comp_to_insert):
  """Inserts a computation into `comp` with the given `name`.

  The new binding `(name, comp_to_insert)` becomes the first local of a block
  that forms the result of the returned lambda. If the result of `comp` is
  already a `tff_framework.Block`, the binding is prepended to that block's
  existing locals. Otherwise a new block with a single local is wrapped around
  the result.

  Args:
    comp: The `tff_framework.Lambda` to transform. The names of lambda
      parameters and block variables in `comp` must be unique.
    name: The name to use. This function does not check it for collisions with
      names already used in `comp`. The caller is responsible for choosing a
      fresh name.
    comp_to_insert: The `tff_framework.ComputationBuildingBlock` to insert.

  Returns:
    A new `tff_framework.Lambda` with the same parameter name and type as
    `comp`. Its result is a `tff_framework.Block` whose first local binds
    `name` to `comp_to_insert`. The input `comp` is not modified.

  Raises:
    TypeError: If an argument is not of the expected type (raised by
      `py_typecheck.check_type`).
  """
  py_typecheck.check_type(comp, tff_framework.Lambda)
  tff_framework.check_has_unique_names(comp)
  py_typecheck.check_type(name, six.string_types)
  py_typecheck.check_type(comp_to_insert,
                          tff_framework.ComputationBuildingBlock)
  result = comp.result
  if isinstance(result, tff_framework.Block):
    # Copy the locals so that building the new binding list can never mutate
    # the block held by `comp`. This holds whether or not `Block.locals`
    # returns its internal list.
    existing_locals = list(result.locals)
    result = result.result
  else:
    existing_locals = []
  variables = [(name, comp_to_insert)] + existing_locals
  block = tff_framework.Block(variables, result)
  return tff_framework.Lambda(comp.parameter_name, comp.parameter_type, block)
예제 #2
0
def select_output_from_lambda(comp, indices):
  """Builds a lambda returning a selection from the result of `comp`.

  The result of `comp` is bound in a block to a freshly generated name. The new
  lambda's result is one or more `tff_framework.Selection`s taken from a
  reference to that name.

  Args:
    comp: Instance of `tff_framework.Lambda` of result type `tff.NamedTupleType`
      from which we wish to select `indices`. Notice that this named tuple type
      must have elements of federated type.
    indices: Instance of `int`, `list`, or `tuple`, specifying the indices we
      wish to select from the result of `comp`. If `indices` is an `int`, the
      result of the returned `comp` will be of type at index `indices` in
      `comp.type_signature.result`. If `indices` is a `list` or `tuple`, the
      result type will be a `tff.NamedTupleType` wrapping the specified
      selections.

  Returns:
    A transformed version of `comp` with result value the selection from the
    result of `comp` specified by `indices`.

  Raises:
    TypeError: If `comp`, its result type, or `indices` has the wrong type. Also
      raised if `indices` is a sequence containing a non-`int` element.
  """
  py_typecheck.check_type(comp, tff_framework.Lambda)
  py_typecheck.check_type(comp.type_signature.result, tff.NamedTupleType)
  py_typecheck.check_type(indices, (int, tuple, list))
  original_result = comp.result
  fresh_name = six.next(tff_framework.unique_name_generator(comp))
  result_ref = tff_framework.Reference(fresh_name,
                                       original_result.type_signature)

  def _select(position):
    # Every selection reads from the single block-bound copy of the result.
    return tff_framework.Selection(result_ref, index=position)

  if isinstance(indices, (tuple, list)):
    if any(not isinstance(position, int) for position in indices):
      raise TypeError('Must select by index in `select_output_from_lambda`.')
    new_body = tff_framework.Tuple([_select(position) for position in indices])
  else:
    new_body = _select(indices)
  wrapped = tff_framework.Block([(fresh_name, original_result)], new_body)
  return tff_framework.Lambda(comp.parameter_name, comp.parameter_type, wrapped)
예제 #3
0
def _split_by_intrinsic(comp, uri):
  """Splits `comp` into `before` and `after` the intrinsic for the given `uri`.

  This function finds the intrinsic for the given `uri` in `comp`; splits `comp`
  into two computations `before` and `after` the intrinsic; and returns a Python
  tuple representing the pair of `before` and `after` computations.

  `before` has the same parameter as `comp` and returns the argument passed to
  the intrinsic call. `after` takes a single parameter: a two-element
  `tff.NamedTupleType` of (the original parameter of `comp`, the result of the
  intrinsic call). It evaluates the rest of `comp`'s top-level block against
  that parameter.

  NOTE: This function is generally safe to call on computations that do not fit
  into canonical form. It is left to the caller to determine if the resulting
  computations are expected.

  Args:
    comp: The `tff_framework.Lambda` to transform. Its result must be a
      `tff_framework.Block`. The input `comp` is not modified.
    uri: A URI of an intrinsic.

  Returns:
    A pair of `tff_framework.ComputationBuildingBlock`s representing the
    computations `before` and `after` the intrinsic.

  Raises:
    TypeError: If `comp` is not a `tff_framework.Lambda`, if `comp.result` is
      not a `tff_framework.Block`, or if `uri` is not a string (raised by
      `py_typecheck.check_type`).
    ValueError: If none of the top-level block variables of `comp` is a called
      intrinsic with the given `uri`, or if there is not exactly one intrinsic
      with the given `uri` in `comp`.
  """
  py_typecheck.check_type(comp, tff_framework.Lambda)
  py_typecheck.check_type(uri, six.string_types)
  py_typecheck.check_type(comp.result, tff_framework.Block)

  def _get_called_intrinsic_from_block_variables(variables, uri):
    """Returns `(index, name, variable)` of the first local calling `uri`."""
    for index, (name, variable) in enumerate(variables):
      if tff_framework.is_called_intrinsic(variable, uri):
        return index, name, variable
    raise ValueError(
        'Expected a lambda referencing a block referencing a collection of '
        'variables containing an intrinsic with the uri: {}, found None.'
        .format(uri))

  index, name, variable = _get_called_intrinsic_from_block_variables(
      comp.result.locals, uri)
  intrinsics = _get_called_intrinsics(comp, uri)
  length = len(intrinsics)
  if length != 1:
    raise ValueError(
        'Expected a computation with exactly one intrinsic with the uri: {}, '
        'found: {}.'.format(uri, length))
  name_generator = tff_framework.unique_name_generator(comp)
  # `before` computes the argument that is passed to the intrinsic call.
  before = tff_framework.Lambda(comp.parameter_name, comp.parameter_type,
                                variable.argument)
  # `after` receives (original parameter, intrinsic result) as one tuple.
  parameter_type = tff.NamedTupleType(
      (comp.parameter_type, variable.type_signature))
  ref_name = six.next(name_generator)
  ref = tff_framework.Reference(ref_name, parameter_type)
  sel_0 = tff_framework.Selection(ref, index=0)
  sel_1 = tff_framework.Selection(ref, index=1)
  # Copy the locals so the edits below never mutate the block held by `comp`.
  variables = list(comp.result.locals)
  # The intrinsic call is replaced by the result supplied in the parameter.
  variables[index] = (name, sel_1)
  # Re-bind the original parameter name so existing references still resolve.
  variables.insert(0, (comp.parameter_name, sel_0))
  block = tff_framework.Block(variables, comp.result.result)
  after = tff_framework.Lambda(ref.name, ref.type_signature, block)
  return before, after