Esempio n. 1
0
 def test_returns_false_for_tuples_with_different_names(self):
     # Both tuples hold equal elements (same URI and dtype), so the element
     # names are the only difference between them. Using different dtypes
     # here would let the assertion pass even if names were ignored.
     data_1 = building_blocks.Data('data', tf.int32)
     tuple_1 = building_blocks.Tuple([('a', data_1), ('b', data_1)])
     data_2 = building_blocks.Data('data', tf.int32)
     tuple_2 = building_blocks.Tuple([('c', data_2), ('d', data_2)])
     self.assertFalse(tree_analysis._trees_equal(tuple_1, tuple_2))
Esempio n. 2
0
 def test_returns_false_for_tuples_with_different_elements(self):
     # Unnamed two-element tuples whose leaves share a URI but differ in dtype.
     int_leaf = building_blocks.Data('data', tf.int32)
     float_leaf = building_blocks.Data('data', tf.float32)
     int_tuple = building_blocks.Tuple([int_leaf, int_leaf])
     float_tuple = building_blocks.Tuple([float_leaf, float_leaf])
     self.assertFalse(tree_analysis._trees_equal(int_tuple, float_tuple))
Esempio n. 3
0
def _dictlike_items_to_value(items, context_stack,
                             container_type) -> ValueImpl:
    """Packs `(key, value)` pairs into a `ValueImpl` wrapping a named tuple.

    Each value is converted with `to_value(value, None, context_stack)`, and
    its underlying building block is taken via `ValueImpl.get_comp`.

    Args:
      items: An iterable of `(name, value)` pairs, consumed once in order.
      context_stack: Passed through to `to_value` and to the returned
        `ValueImpl`.
      container_type: Forwarded to `building_blocks.Tuple` as its second
        argument.

    Returns:
      A `ValueImpl` wrapping the constructed `building_blocks.Tuple`.
    """
    named_elements = []
    for key, raw_value in items:
        converted = to_value(raw_value, None, context_stack)
        named_elements.append((key, ValueImpl.get_comp(converted)))
    tuple_comp = building_blocks.Tuple(named_elements, container_type)
    return ValueImpl(tuple_comp, context_stack)
Esempio n. 4
0
def create_nested_syntax_tree():
    r"""Constructs computation with explicit ordering for testing traversals.

    The goal of this computation is to exercise each switch in
    transform_postorder_with_symbol_bindings, at least all those that recurse.

    The computation this function constructs can be represented as below.

    Notice that the body of the Lambda *does not depend on the Lambda's
    parameter*, so if we were actually executing this call the argument would
    be thrown away.

    All leaf nodes are instances of `building_blocks.Data`. Every leaf has
    dtype `tf.float32` except `Data('h')`, which is `tf.int32`. The
    `Selection` bound to 'v' selects index 0 of the lower `Tuple`.

                              Call
                             /    \
                   Lambda('arg')   Data('k')
                       |
                     Block('y','z')-------------
                    /                          |
    ['y'=Data('a'),'z'=Data('b')]              |
                                             Tuple
                                           /       \
                                     Block('v')     Block('x')-------
                                       / \              |            |
                         ['v'=Selection]   Data('g') ['x'=Data('h')] |
                               |                                     |
                               |                                     |
                               |                                 Block('w')
                               |                                   /   \
                             Tuple ------            ['w'=Data('i')]    Data('j')
                           /              \
                   Block('t')             Block('u')
                    /     \              /          \
      ['t'=Data('c')]    Data('d') ['u'=Data('e')]  Data('f')


    Postorder traversals:
    If we are reading Data URIs, results of a postorder traversal should be:
    [a, b, c, d, e, f, g, h, i, j, k]

    If we are reading locals declarations, results of a postorder traversal
    should be:
    [t, u, v, w, x, y, z]

    And if we are reading both in an interleaved fashion, results of a
    postorder traversal should be:
    [a, b, c, d, t, e, f, u, g, v, h, i, j, w, x, y, z, k]

    Preorder traversals:
    If we are reading Data URIs, results of a preorder traversal should be:
    [a, b, c, d, e, f, g, h, i, j, k]

    If we are reading locals declarations, results of a preorder traversal
    should be:
    [y, z, v, t, u, x, w]

    And if we are reading both in an interleaved fashion, results of a preorder
    traversal should be:
    [y, z, a, b, v, t, c, d, u, e, f, g, x, h, w, i, j, k]

    Since we are also exposing the ability to hook into variable declarations,
    it is worthwhile considering the order in which variables are assigned in
    this tree. Notice that this order maps neither to preorder nor to postorder
    when purely considering the nodes of the tree above. This would be:
    [arg, y, z, t, u, v, x, w]

    Returns:
      An instance of `building_blocks.ComputationBuildingBlock`
      satisfying the description above.
    """
    # Lowest tuple: <Block('t'), Block('u')>.
    t_block = building_blocks.Block(
        [('t', building_blocks.Data('c', tf.float32))],
        building_blocks.Data('d', tf.float32))
    u_block = building_blocks.Block(
        [('u', building_blocks.Data('e', tf.float32))],
        building_blocks.Data('f', tf.float32))
    lower_tuple = building_blocks.Tuple([t_block, u_block])

    # Left branch of the upper tuple: 'v' is bound to element 0 of the lower
    # tuple.
    v_block = building_blocks.Block(
        [('v', building_blocks.Selection(lower_tuple, index=0))],
        building_blocks.Data('g', tf.float32))

    # Right branch of the upper tuple: Block('x') whose result is Block('w').
    w_block = building_blocks.Block(
        [('w', building_blocks.Data('i', tf.float32))],
        building_blocks.Data('j', tf.float32))
    # Note: 'h' is the only int32 leaf in the tree.
    x_block = building_blocks.Block(
        [('x', building_blocks.Data('h', tf.int32))], w_block)

    upper_tuple = building_blocks.Tuple([v_block, x_block])
    yz_block = building_blocks.Block(
        [('y', building_blocks.Data('a', tf.float32)),
         ('z', building_blocks.Data('b', tf.float32))],
        upper_tuple)

    # The lambda never references 'arg'; the call supplies Data('k').
    unused_param_lambda = building_blocks.Lambda('arg', tf.float32, yz_block)
    return building_blocks.Call(unused_param_lambda,
                                building_blocks.Data('k', tf.float32))
Esempio n. 5
0
def transform_preorder(
    comp: building_blocks.ComputationBuildingBlock,
    transform: Callable[[building_blocks.ComputationBuildingBlock],
                        TransformReturnType]
) -> TransformReturnType:
  """Walks the AST of `comp` preorder, calling `transform` on the way down.

  Notice that this function will stop walking the tree when its transform
  function modifies a node; this is to prevent the caller from unexpectedly
  kicking off an infinite recursion. For this purpose the transform function
  must identify when it has transformed the structure of a building block; if
  the structure of the building block is modified but `False` is returned as
  the second element of the tuple returned by `transform`, `transform_preorder`
  may result in an infinite recursion.

  Args:
    comp: Instance of `building_blocks.ComputationBuildingBlock` to be
      transformed in a preorder fashion.
    transform: Transform function to be applied to the nodes of `comp`. Must
      return a two-tuple whose first element is a
      `building_blocks.ComputationBuildingBlock` and whose second element is a
      Boolean. If the computation which is passed to `transform` is returned
      in a modified state, must return `True` for the second element.

  Returns:
    A two-tuple, whose first element is modified version of `comp`, and
    whose second element is a Boolean indicating whether `comp` was transformed
    during the walk.

  Raises:
    TypeError: If the argument types don't match those specified above.
    NotImplementedError: If `comp`, or a node beneath it, is a kind of
      building block this function does not recognize.
  """

  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  py_typecheck.check_callable(transform)
  inner_comp, modified = transform(comp)
  if modified:
    # Do not descend into a freshly transformed node (see docstring).
    return inner_comp, modified
  # From here on this node itself is unmodified; only a change in one of its
  # children can cause it to be rebuilt.
  if isinstance(inner_comp, (
      building_blocks.CompiledComputation,
      building_blocks.Data,
      building_blocks.Intrinsic,
      building_blocks.Placement,
      building_blocks.Reference,
  )):
    return inner_comp, False
  elif isinstance(inner_comp, building_blocks.Lambda):
    transformed_result, result_modified = transform_preorder(
        inner_comp.result, transform)
    if not result_modified:
      return inner_comp, False
    return building_blocks.Lambda(inner_comp.parameter_name,
                                  inner_comp.parameter_type,
                                  transformed_result), True
  elif isinstance(inner_comp, building_blocks.Tuple):
    elements_modified = False
    elements = []
    for name, val in anonymous_tuple.iter_elements(inner_comp):
      result, result_modified = transform_preorder(val, transform)
      # Accumulate across all elements: a change in ANY element must trigger
      # a rebuild, not only a change in the last one.
      elements_modified = elements_modified or result_modified
      elements.append((name, result))
    if not elements_modified:
      return inner_comp, False
    return building_blocks.Tuple(elements), True
  elif isinstance(inner_comp, building_blocks.Selection):
    transformed_source, source_modified = transform_preorder(
        inner_comp.source, transform)
    if not source_modified:
      return inner_comp, False
    return building_blocks.Selection(transformed_source, inner_comp.name,
                                     inner_comp.index), True
  elif isinstance(inner_comp, building_blocks.Call):
    transformed_fn, fn_modified = transform_preorder(inner_comp.function,
                                                     transform)
    if inner_comp.argument is not None:
      transformed_arg, arg_modified = transform_preorder(
          inner_comp.argument, transform)
    else:
      transformed_arg = None
      arg_modified = False
    if not (fn_modified or arg_modified):
      return inner_comp, False
    return building_blocks.Call(transformed_fn, transformed_arg), True
  elif isinstance(inner_comp, building_blocks.Block):
    transformed_variables = []
    values_modified = False
    for key, value in inner_comp.locals:
      transformed_value, value_modified = transform_preorder(value, transform)
      transformed_variables.append((key, transformed_value))
      values_modified = values_modified or value_modified
    # Walk the result of the node returned by `transform` (not the original
    # `comp`), consistent with every other branch.
    transformed_result, result_modified = transform_preorder(
        inner_comp.result, transform)
    if not (values_modified or result_modified):
      return inner_comp, False
    return building_blocks.Block(transformed_variables,
                                 transformed_result), True
  else:
    raise NotImplementedError(
        'Unrecognized computation building block: {}'.format(str(inner_comp)))
Esempio n. 6
0
def transform_postorder(comp, transform):
  """Rewrites `comp` bottom-up by applying `transform` to every node.

  The children of each building block are visited before the block itself,
  left-to-right in the order they are listed in the block's constructor. For
  example, in `Call(f, x)` first `f` becomes `f'`, then `x` becomes `x'`, and
  only then is `Call(f', x')` handed to `transform`. Whenever any child
  changes, the parent is rebuilt around the new children before `transform`
  sees it.

  `transform` should act as the identity on kinds of building blocks it does
  not care to transform.

  Args:
    comp: The `building_blocks.ComputationBuildingBlock` to traverse and
      transform bottom-up.
    transform: A callable that takes a building block and returns a
      `(building_block, modified)` pair. The building block is either the
      original or a transformed one, and `modified` indicates whether it was
      changed.

  Returns:
    A pair `(new_comp, modified)`: the result of applying `transform`
    throughout `comp` bottom-up, and `True` if anything was transformed,
    `False` otherwise.

  Raises:
    TypeError: If `comp` is not a `building_blocks.ComputationBuildingBlock`.
    NotImplementedError: If a node of `comp` is a kind of computation building
      block that is currently not recognized.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  leaf_kinds = (
      building_blocks.CompiledComputation,
      building_blocks.Data,
      building_blocks.Intrinsic,
      building_blocks.Placement,
      building_blocks.Reference,
  )
  if isinstance(comp, leaf_kinds):
    return transform(comp)

  # Each branch below recurses into the children, rebuilds `comp` only if a
  # child changed, and records that fact in `children_modified`.
  if isinstance(comp, building_blocks.Selection):
    new_source, children_modified = transform_postorder(comp.source, transform)
    if children_modified:
      comp = building_blocks.Selection(new_source, comp.name, comp.index)
  elif isinstance(comp, building_blocks.Tuple):
    new_elements = []
    children_modified = False
    for name, element in anonymous_tuple.iter_elements(comp):
      new_element, element_modified = transform_postorder(element, transform)
      new_elements.append((name, new_element))
      children_modified = children_modified or element_modified
    if children_modified:
      comp = building_blocks.Tuple(new_elements)
  elif isinstance(comp, building_blocks.Call):
    new_fn, fn_modified = transform_postorder(comp.function, transform)
    if comp.argument is None:
      new_arg, arg_modified = None, False
    else:
      new_arg, arg_modified = transform_postorder(comp.argument, transform)
    children_modified = fn_modified or arg_modified
    if children_modified:
      comp = building_blocks.Call(new_fn, new_arg)
  elif isinstance(comp, building_blocks.Lambda):
    new_result, children_modified = transform_postorder(comp.result, transform)
    if children_modified:
      comp = building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                    new_result)
  elif isinstance(comp, building_blocks.Block):
    new_locals = []
    locals_modified = False
    for name, local_value in comp.locals:
      new_value, value_modified = transform_postorder(local_value, transform)
      new_locals.append((name, new_value))
      locals_modified = locals_modified or value_modified
    new_result, result_modified = transform_postorder(comp.result, transform)
    children_modified = locals_modified or result_modified
    if children_modified:
      comp = building_blocks.Block(new_locals, new_result)
  else:
    raise NotImplementedError(
        'Unrecognized computation building block: {}'.format(str(comp)))

  comp, self_modified = transform(comp)
  return comp, self_modified or children_modified
Esempio n. 7
0
def _extract_update(after_aggregate):
    """Extracts `update` from `after_aggregate`.

    This function is intended to be used by
    `get_canonical_form_for_iterative_process` only. As a result, this function
    does not assert that `after_aggregate` has the expected structure, the
    caller is expected to perform these checks before calling this function.

    Args:
      after_aggregate: The second result of splitting `after_broadcast` on
        aggregate intrinsics.

    Returns:
      `update` as specified by `canonical_form.CanonicalForm`: the result of
      `transformations.consolidate_and_extract_local_processing` applied to
      the wrapper lambda built here. It is expected to be an instance of
      `building_blocks.CompiledComputation`.

    Raises:
      transformations.CanonicalFormCompilationError: If we extract an AST of
        the wrong type. NOTE(review): this function performs no such check
        itself; presumably the error, if any, is raised inside
        `consolidate_and_extract_local_processing`. Confirm.
    """
    # Keep only elements 0 and 1 of the result of `after_aggregate`.
    s7_elements_in_after_aggregate_result = [0, 1]
    s7_output_extracted = transformations.select_output_from_lambda(
        after_aggregate, s7_elements_in_after_aggregate_result)
    # Same parameter as the extracted lambda, but the selected outputs are
    # combined into a single value with a federated zip.
    s7_output_zipped = building_blocks.Lambda(
        s7_output_extracted.parameter_name, s7_output_extracted.parameter_type,
        building_block_factory.create_federated_zip(
            s7_output_extracted.result))
    # Index paths into the parameter of `after_aggregate`; per the TODO below
    # these correspond to `s1`, `s3` and `s4`.
    s6_elements_in_after_aggregate_parameter = [[0, 0, 0], [1, 0], [1, 1]]
    # Keep only the function of the call forming the transformed lambda's
    # result. Its parameter type is federated with a three-element member
    # (see the `.member[0..2]` accesses below).
    s6_to_s7_computation = (
        transformations.zip_selection_as_argument_to_lower_level_lambda(
            s7_output_zipped,
            s6_elements_in_after_aggregate_parameter).result.function)

    # TODO(b/148942011): The transformation
    # `zip_selection_as_argument_to_lower_level_lambda` does not support selecting
    # from nested structures, therefore we need to pack the type signature
    # `<s1, s3, s4>` as `<s1, <s3, s4>>`.
    # Name source for the new references below; presumably yields names not
    # already used in `s6_to_s7_computation`.
    name_generator = building_block_factory.unique_name_generator(
        s6_to_s7_computation)

    # Packed parameter type: <s1, <s3, s4>>.
    pack_ref_name = next(name_generator)
    pack_ref_type = computation_types.NamedTupleType([
        s6_to_s7_computation.parameter_type.member[0],
        computation_types.NamedTupleType([
            s6_to_s7_computation.parameter_type.member[1],
            s6_to_s7_computation.parameter_type.member[2],
        ]),
    ])
    # pack_fn: (p -> <p[0], p[1][0], p[1][1]>), i.e. flattens
    # <s1, <s3, s4>> back into <s1, s3, s4>.
    pack_ref = building_blocks.Reference(pack_ref_name, pack_ref_type)
    sel_s1 = building_blocks.Selection(pack_ref, index=0)
    sel = building_blocks.Selection(pack_ref, index=1)
    sel_s3 = building_blocks.Selection(sel, index=0)
    sel_s4 = building_blocks.Selection(sel, index=1)
    result = building_blocks.Tuple([sel_s1, sel_s3, sel_s4])
    pack_fn = building_blocks.Lambda(pack_ref.name, pack_ref.type_signature,
                                     result)
    # Parameter of the returned computation: <s1, <s3, s4>> placed at SERVER.
    ref_name = next(name_generator)
    ref_type = computation_types.FederatedType(pack_ref_type,
                                               placements.SERVER)
    ref = building_blocks.Reference(ref_name, ref_type)
    # Flatten the federated argument with `pack_fn`, then feed it to
    # `s6_to_s7_computation`.
    unpacked_args = building_block_factory.create_federated_map_or_apply(
        pack_fn, ref)
    call = building_blocks.Call(s6_to_s7_computation, unpacked_args)
    fn = building_blocks.Lambda(ref.name, ref.type_signature, call)
    return transformations.consolidate_and_extract_local_processing(fn)
Esempio n. 8
0
    def test_constructs_aggregate_of_tuple_with_one_element(self):
        # The same called federated_aggregate appears twice in one tuple; the
        # dedupe-and-merge pass should leave exactly one aggregate behind.
        called_intrinsic = test_utils.create_dummy_called_federated_aggregate(
            accumulate_parameter_name='a',
            merge_parameter_name='b',
            report_parameter_name='c')
        duplicated_calls = building_blocks.Tuple(
            (called_intrinsic, called_intrinsic))

        transformed_comp, modified = (
            compiler_transformations.dedupe_and_merge_tuple_intrinsics(
                duplicated_calls, intrinsic_defs.FEDERATED_AGGREGATE.uri))

        found_aggregates = []

        def _collect_federated_aggregates(node):
            # Read-only visitor: records matching nodes, never rewrites them.
            if building_block_analysis.is_called_intrinsic(
                    node, intrinsic_defs.FEDERATED_AGGREGATE.uri):
                found_aggregates.append(node)
            return node, False

        transformation_utils.transform_postorder(
            transformed_comp, _collect_federated_aggregates)
        self.assertTrue(modified)
        self.assertLen(found_aggregates, 1)
        self.assertLen(found_aggregates[0].type_signature.member, 1)
        expected_representation = (
            '(_var1 -> <\n'
            '  _var1[0],\n'
            '  _var1[0]\n'
            '>)((x -> <\n'
            '  x[0]\n'
            '>)((let\n'
            '  value=federated_aggregate(<\n'
            '    federated_map(<\n'
            '      (arg -> <\n'
            '        arg\n'
            '      >),\n'
            '      <\n'
            '        data\n'
            '      >[0]\n'
            '    >),\n'
            '    <\n'
            '      data\n'
            '    >,\n'
            '    (let\n'
            '      _var1=<\n'
            '        (a -> data)\n'
            '      >\n'
            '     in (_var2 -> <\n'
            '      _var1[0](<\n'
            '        <\n'
            '          _var2[0][0],\n'
            '          _var2[1][0]\n'
            '        >\n'
            '      >[0])\n'
            '    >)),\n'
            '    (let\n'
            '      _var3=<\n'
            '        (b -> data)\n'
            '      >\n'
            '     in (_var4 -> <\n'
            '      _var3[0](<\n'
            '        <\n'
            '          _var4[0][0],\n'
            '          _var4[1][0]\n'
            '        >\n'
            '      >[0])\n'
            '    >)),\n'
            '    (let\n'
            '      _var5=<\n'
            '        (c -> data)\n'
            '      >\n'
            '     in (_var6 -> <\n'
            '      _var5[0](_var6[0])\n'
            '    >))\n'
            '  >)\n'
            ' in <\n'
            '  federated_apply(<\n'
            '    (arg -> arg[0]),\n'
            '    value\n'
            '  >)\n'
            '>)))')
        self.assertEqual(transformed_comp.formatted_representation(),
                         expected_representation)
Esempio n. 9
0
def _create_before_and_after_broadcast_for_no_broadcast(tree):
    r"""Builds trivial before/after broadcast computations around `tree`.

    Intended for `get_canonical_form_for_iterative_process` when `tree`
    contains no `intrinsic_defs.FEDERATED_BROADCAST`. Neither that
    precondition nor the structure of `tree` is checked here; the caller is
    expected to perform these checks before calling this function.

    The two ASTs returned are:

    before: a `building_blocks.Lambda` whose parameter type is the parameter
      type of `tree` and whose result is
      `building_block_factory.create_federated_value` of an empty
      `building_blocks.Tuple` at `placements.SERVER`. Its type signature
      satisfies the requirements of before broadcast.

    after:
           Lambda(x)
           |
           Call
          /    \
      Comp      Sel(0)
               /
         Ref(x)

    In `after`, `Comp` is `tree`, and `x` has type `<P, M@CLIENTS>`, where `P`
    is the parameter type of `tree` and `M` is the member type of the result
    of `before`. Only element 0 of `x` is passed to `Comp`; the second element
    (`c2`) is intentionally dropped on the floor.

    Args:
      tree: An instance of `building_blocks.ComputationBuildingBlock`.

    Returns:
      A pair of the form `(before, after)`, each a `building_blocks.Lambda`
      representing a part of the result as specified by
      `transformations.force_align_and_split_by_intrinsics`.
    """
    names = building_block_factory.unique_name_generator(tree)

    before_param_name = next(names)
    server_empty_value = building_block_factory.create_federated_value(
        building_blocks.Tuple([]), placements.SERVER)
    before_broadcast = building_blocks.Lambda(
        before_param_name, tree.type_signature.parameter, server_empty_value)

    after_param_name = next(names)
    broadcast_type = computation_types.FederatedType(
        before_broadcast.type_signature.result.member, placements.CLIENTS)
    after_param = building_blocks.Reference(
        after_param_name,
        computation_types.NamedTupleType(
            [tree.type_signature.parameter, broadcast_type]))
    call_with_original_arg = building_blocks.Call(
        tree, building_blocks.Selection(after_param, index=0))
    after_broadcast = building_blocks.Lambda(
        after_param.name, after_param.type_signature, call_with_original_arg)

    return before_broadcast, after_broadcast
Esempio n. 10
0
def _create_before_and_after_aggregate_for_no_federated_secure_sum(tree):
    r"""Creates before and after aggregate computations for the given `tree`.

    Intended for `get_canonical_form_for_iterative_process` when `tree`
    contains no `intrinsic_defs.FEDERATED_SECURE_SUM`. Neither that
    precondition nor the structure of `tree` is checked here; the caller is
    expected to perform these checks before calling this function.

    `tree` is first force aligned and split by
    `intrinsic_defs.FEDERATED_AGGREGATE.uri` into a before aggregate and an
    after aggregate, which are then wrapped as follows.

    before:
      Lambda
      |
      Tuple
      |
      [Comp, Tuple]
             |
             [federated_value(<>)@CLIENTS, <>]

    `before` keeps the parameter of the split's before aggregate, and `Comp`
    is that before aggregate's result. The second `Tuple` is an empty
    placeholder for the secure sum argument: an empty tuple placed at
    `placements.CLIENTS` (built with
    `building_block_factory.create_federated_value`), and an empty tuple as
    the bitwidth. Its type signature satisfies the requirements of before
    aggregate.

    after:
           Lambda(x)
           |
           Call
          /    \
      Comp      Tuple
                |
                [Sel(0),      Sel(0)]
                /            /
             Ref(x)    Sel(1)
                      /
                Ref(x)

    In `after`, `Comp` is the split's after aggregate, and `x` has type
    `<A0, <A1, S4>>`, where `A0` and `A1` are elements 0 and 1 of `Comp`'s
    parameter type and `S4` is an empty structure placed at
    `placements.SERVER`. `Comp` is called with `<x[0], x[1][0]>`, which
    intentionally drops `s4` on the floor.

    Args:
      tree: An instance of `building_blocks.ComputationBuildingBlock`.

    Returns:
      A pair of the form `(before, after)`, each a `building_blocks.Lambda`
      representing a part of the result as specified by
      `transformations.force_align_and_split_by_intrinsics`.
    """
    names = building_block_factory.unique_name_generator(tree)

    split_before, split_after = (
        transformations.force_align_and_split_by_intrinsics(
            tree, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))

    # Placeholder secure-sum argument: <empty value at CLIENTS, empty bitwidth>.
    empty_tuple = building_blocks.Tuple([])
    clients_empty_value = building_block_factory.create_federated_value(
        empty_tuple, placements.CLIENTS)
    secure_sum_args = building_blocks.Tuple([clients_empty_value, empty_tuple])
    before_result = building_blocks.Tuple([split_before.result, secure_sum_args])
    before_aggregate = building_blocks.Lambda(split_before.parameter_name,
                                              split_before.parameter_type,
                                              before_result)

    param_name = next(names)
    s4_type = computation_types.FederatedType([], placements.SERVER)
    param_type = computation_types.NamedTupleType([
        split_after.parameter_type[0],
        computation_types.NamedTupleType([
            split_after.parameter_type[1],
            s4_type,
        ]),
    ])
    param = building_blocks.Reference(param_name, param_type)
    first_arg = building_blocks.Selection(param, index=0)
    aggregate_outputs = building_blocks.Selection(param, index=1)
    s3_output = building_blocks.Selection(aggregate_outputs, index=0)
    forwarded_arg = building_blocks.Tuple([first_arg, s3_output])
    after_aggregate = building_blocks.Lambda(
        param.name, param.type_signature,
        building_blocks.Call(split_after, forwarded_arg))

    return before_aggregate, after_aggregate
Esempio n. 11
0
 def test_reduces_lambda_returning_empty_tuple_to_tf(self):
   self.skipTest('Depends on a lower level fix, currently in review.')
   # The lambda `(x -> <>)` with an int32 parameter and an empty-tuple result.
   lambda_to_reduce = building_blocks.Lambda(
       'x', tf.int32, building_blocks.Tuple([]))
   reduced = transformations.consolidate_and_extract_local_processing(
       lambda_to_reduce)
   self.assertIsInstance(reduced, building_blocks.CompiledComputation)
Esempio n. 12
0
 def _create_empty_function(type_elements):
   """Builds a lambda that ignores its parameter and returns an empty tuple.

   The parameter has type `computation_types.NamedTupleType(type_elements)`;
   its name is drawn from `name_generator` in the enclosing scope.
   """
   param = building_blocks.Reference(
       next(name_generator), computation_types.NamedTupleType(type_elements))
   return building_blocks.Lambda(param.name, param.type_signature,
                                 building_blocks.Tuple([]))
Esempio n. 13
0
def get_canonical_form_for_iterative_process(iterative_process):
  """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

  This function transforms computations from the input `iterative_process` into
  an instance of `tff.backends.mapreduce.CanonicalForm`.

  If `next` returns exactly two results, an empty value placed at
  `placements.CLIENTS` is appended as a placeholder third result before
  compilation.

  Args:
    iterative_process: An instance of `tff.utils.IterativeProcess`.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process.

  Raises:
    TypeError: If the arguments are of the wrong types, or if `next` does not
      both take and return a `tff.NamedTupleType`.
    transformations.CanonicalFormCompilationError: If the compilation
      process fails.
  """
  py_typecheck.check_type(iterative_process, computation_utils.IterativeProcess)

  initialize_comp = building_blocks.ComputationBuildingBlock.from_proto(
      iterative_process.initialize._computation_proto)  # pylint: disable=protected-access

  next_comp = building_blocks.ComputationBuildingBlock.from_proto(
      iterative_process.next._computation_proto)  # pylint: disable=protected-access

  if not (isinstance(next_comp.type_signature.parameter,
                     computation_types.NamedTupleType) and
          isinstance(next_comp.type_signature.result,
                     computation_types.NamedTupleType)):
    raise TypeError(
        'Any IterativeProcess compatible with CanonicalForm must '
        'have a `next` function which takes and returns instances '
        'of `tff.NamedTupleType`; your next function takes '
        'parameters of type {} and returns results of type {}'.format(
            next_comp.type_signature.parameter,
            next_comp.type_signature.result))

  if len(next_comp.type_signature.result) == 2:
    next_result = next_comp.result
    if isinstance(next_result, building_blocks.Tuple):
      # A literal tuple: reuse its two elements directly.
      leading_elements = [next_result[0], next_result[1]]
    else:
      # Any other two-element result: select its elements explicitly.
      leading_elements = [
          building_blocks.Selection(next_result, index=0),
          building_blocks.Selection(next_result, index=1),
      ]
    # Placeholder third result: an empty value placed at CLIENTS.
    empty_clients_value = intrinsics.federated_value(
        [], placements.CLIENTS)._comp  # pylint: disable=protected-access
    dummy_clients_metrics_appended = building_blocks.Tuple(
        leading_elements + [empty_clients_value])
    next_comp = building_blocks.Lambda(next_comp.parameter_name,
                                       next_comp.parameter_type,
                                       dummy_clients_metrics_appended)

  initialize_comp = replace_intrinsics_with_bodies(initialize_comp)
  next_comp = replace_intrinsics_with_bodies(next_comp)

  tree_analysis.check_intrinsics_whitelisted_for_reduction(initialize_comp)
  tree_analysis.check_intrinsics_whitelisted_for_reduction(next_comp)
  tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

  before_broadcast, after_broadcast = (
      transformations.force_align_and_split_by_intrinsic(
          next_comp, intrinsic_defs.FEDERATED_BROADCAST.uri))

  before_aggregate, after_aggregate = (
      transformations.force_align_and_split_by_intrinsic(
          after_broadcast, intrinsic_defs.FEDERATED_AGGREGATE.uri))

  init_info_packed = pack_initialize_comp_type_signature(
      initialize_comp.type_signature)

  next_info_packed = pack_next_comp_type_signature(next_comp.type_signature,
                                                   init_info_packed)

  before_broadcast_info_packed = (
      check_and_pack_before_broadcast_type_signature(
          before_broadcast.type_signature, next_info_packed))

  before_aggregate_info_packed = (
      check_and_pack_before_aggregate_type_signature(
          before_aggregate.type_signature, before_broadcast_info_packed))

  canonical_form_types = check_and_pack_after_aggregate_type_signature(
      after_aggregate.type_signature, before_aggregate_info_packed)

  initialize = transformations.consolidate_and_extract_local_processing(
      initialize_comp)

  expected_initialize_type = canonical_form_types['initialize_type'].member
  if not (isinstance(initialize, building_blocks.CompiledComputation) and
          initialize.type_signature.result == expected_initialize_type):
    # Report the type actually checked against, and the full type signature
    # of what was extracted: it may not be function-typed, in which case
    # `.result` would itself raise and mask this error.
    raise transformations.CanonicalFormCompilationError(
        'Compilation of initialize has failed. Expected to extract a '
        '`building_blocks.CompiledComputation` with result type {}, instead we '
        'extracted a {} of type {}.'.format(expected_initialize_type,
                                            type(initialize),
                                            initialize.type_signature))

  prepare = extract_prepare(before_broadcast, canonical_form_types)

  work = extract_work(before_aggregate, after_aggregate, canonical_form_types)

  zero_noarg_function, accumulate, merge, report = extract_aggregate_functions(
      before_aggregate, canonical_form_types)

  update = extract_update(after_aggregate, canonical_form_types)

  cf = canonical_form.CanonicalForm(
      computation_wrapper_instances.building_block_to_computation(initialize),
      computation_wrapper_instances.building_block_to_computation(prepare),
      computation_wrapper_instances.building_block_to_computation(work),
      computation_wrapper_instances.building_block_to_computation(
          zero_noarg_function),
      computation_wrapper_instances.building_block_to_computation(accumulate),
      computation_wrapper_instances.building_block_to_computation(merge),
      computation_wrapper_instances.building_block_to_computation(report),
      computation_wrapper_instances.building_block_to_computation(update))
  return cf
Esempio n. 14
0
 def test_returns_true_for_tuples(self):
     # Distinct `Data` objects with the same URI and dtype, so equality must
     # come from structure rather than object identity.
     lhs_leaf = building_blocks.Data('data', tf.int32)
     rhs_leaf = building_blocks.Data('data', tf.int32)
     lhs = building_blocks.Tuple([lhs_leaf, lhs_leaf])
     rhs = building_blocks.Tuple([rhs_leaf, rhs_leaf])
     self.assertTrue(tree_analysis._trees_equal(lhs, rhs))
Esempio n. 15
0
def create_tensorflow_representing_block(block):
  """Generates non-duplicated TensorFlow for Block locals binding called graphs.

  Assuming that the argument `block` satisfies the following conditions:

  1. The local variables in `block` are all called graphs, with arbitrary
     arguments.
  2. The result of the Block contains tuples, selections and references,
     but nothing else.

  Then `create_tensorflow_representing_block` will generate a structure, which
  may contain tensorflow functions, calls to tensorflow functions, and
  references, but which have generated this TensorFlow code without duplicating
  work done by referencing the block locals.

  Args:
    block: Instance of `building_blocks.Block`, whose local variables are all
      called instances of `building_blocks.CompiledComputation`, and whose
      result contains only instances of `building_blocks.Reference`,
      `building_blocks.Selection` or `building_blocks.Tuple`.

  Returns:
    A 2-tuple `(comp, modified)`. `comp` is a transformed version of `block`,
    built by `construct_tensorflow_calling_lambda_on_concrete_arg`, which has
    pushed references to the called graphs in the locals of `block` into
    TensorFlow. `modified` is always `True`.

  Raises:
    TypeError: If `block` is not an instance of `building_blocks.Block`.
    ValueError: If the locals of `block` are anything other than called graphs,
      or if the result of `block` contains anything other than selections,
      references and tuples.
  """
  _check_parameters_for_tf_block_generation(block)

  name_generator = building_block_factory.unique_name_generator(block)

  def _construct_reference_representing(comp_to_represent):
    """Returns a new `Reference` typed like `comp_to_represent`.

    The name is drawn from `name_generator` (built from `block`) for name
    safety.
    """
    arg_type = comp_to_represent.type_signature
    arg_name = next(name_generator)
    return building_blocks.Reference(arg_name, arg_type)

  top_level_ref = _get_unbound_ref(block)
  named_comp_classes = transformations.group_block_locals_by_namespace(block)

  # Names bound in `block.locals`. A set replaces the original nested linear
  # scan; lookups are by equality only, so results are unchanged.
  block_local_names = {x[0] for x in block.locals}

  # Name -> index later handed to
  # `_replace_references_in_comp_with_selections_from_arg`, presumably the
  # position of that name's value in `output_comp`'s result.
  name_to_output_index = {}

  def _update_name_to_output_index(name_class):
    """Assigns consecutive indices to `name_class`, after existing entries.

    Only names that are locals of `block` are recorded. Every entry of
    `name_class` still advances the index, as in the original nested loop.
    """
    offset = len(name_to_output_index)
    for idx, comp_name in enumerate(name_class):
      if comp_name in block_local_names:
        name_to_output_index[comp_name] = idx + offset

  if top_level_ref:
    # The unbound reference occupies index 0. The first namespace class is
    # packed into the same initial TensorFlow call right after it.
    first_class = named_comp_classes[0]
    first_comps = [x[1] for x in first_class]
    tup = building_blocks.Tuple([top_level_ref] + first_comps)
    output_comp = construct_tensorflow_calling_lambda_on_concrete_arg(
        top_level_ref, tup, top_level_ref)
    name_to_output_index[top_level_ref.name] = 0
    _update_name_to_output_index([x[0] for x in first_class])
    remaining_comp_classes = named_comp_classes[1:]
  else:
    output_comp = building_block_factory.create_compiled_empty_tuple()
    remaining_comp_classes = named_comp_classes[:]

  # Fold each remaining (non-empty) namespace class into `output_comp`, one
  # TensorFlow call per class.
  for named_comp_class in remaining_comp_classes:
    if named_comp_class:
      comp_class = [x[1] for x in named_comp_class]
      name_class = [x[0] for x in named_comp_class]
      arg_ref = _construct_reference_representing(output_comp)
      output_comp = _construct_tensorflow_representing_single_local_assignment(
          arg_ref, comp_class, output_comp, name_to_output_index)
      _update_name_to_output_index(name_class)

  # Rewrite the Block's result so references to locals become selections from
  # the accumulated output, then bind it to the computed TensorFlow.
  arg_ref = _construct_reference_representing(output_comp)
  result_replaced = _replace_references_in_comp_with_selections_from_arg(
      block.result, arg_ref, name_to_output_index)
  comp_called = construct_tensorflow_calling_lambda_on_concrete_arg(
      arg_ref, result_replaced, output_comp)

  return comp_called, True
Esempio n. 16
0
def to_value(
    arg: Any,
    type_spec,
    context_stack: context_stack_base.ContextStack,
) -> ValueImpl:
    """Converts the argument into an instance of `tff.Value`.

    The types of non-`tff.Value` arguments that are currently convertible to
    `tff.Value` include the following:

    * Lists, tuples, anonymous tuples, named tuples, attrs classes, and
      dictionaries, all of which are converted into instances of `tff.Tuple`.
    * Placement literals, converted into instances of `tff.Placement`.
    * Computations.
    * Python constants of type `str`, `int`, `float`, `bool`.
    * Numpy objects inheriting from `np.ndarray` or `np.generic` (the parent
      of numpy scalar types).

    Args:
      arg: Either an instance of `tff.Value`, or an argument convertible to
        `tff.Value`. The argument must not be `None`.
      type_spec: An optional `computation_types.Type` or value convertible to
        it by `computation_types.to_type` which specifies the desired type
        signature of the resulting value. This allows for disambiguating the
        target type (e.g., when two TFF types can be mapped to the same Python
        representations), or `None` if none available, in which case TFF
        tries to determine the type of the TFF value automatically.
      context_stack: The context stack to use.

    Returns:
      An instance of `tff.Value` corresponding to the given `arg`, and of TFF
      type matching the `type_spec` if specified (not `None`).

    Raises:
      TypeError: if `arg` is of an unsupported type, or of a type that does
        not match `type_spec`. Raises explicit error message if TensorFlow
        constructs are encountered, as TensorFlow code should be sealed away
        from TFF federated context.
    """
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    if type_spec is not None:
        type_spec = computation_types.to_type(type_spec)
        type_utils.check_well_formed(type_spec)
    if isinstance(arg, ValueImpl):
        result = arg
    elif isinstance(arg, building_blocks.ComputationBuildingBlock):
        result = ValueImpl(arg, context_stack)
    elif isinstance(arg, placement_literals.PlacementLiteral):
        result = ValueImpl(building_blocks.Placement(arg), context_stack)
    elif isinstance(arg, computation_base.Computation):
        result = ValueImpl(
            building_blocks.CompiledComputation(
                computation_impl.ComputationImpl.get_proto(arg)),
            context_stack)
    elif isinstance(type_spec, computation_types.SequenceType):
        # Checked before the container branches so that a Python list with a
        # sequence `type_spec` becomes a sequence rather than a `tff.Tuple`.
        # (`isinstance(None, ...)` is False, so no separate `None` check.)
        result = _wrap_sequence_as_value(arg, type_spec.element, context_stack)
    elif isinstance(arg, anonymous_tuple.AnonymousTuple):
        result = ValueImpl(
            building_blocks.Tuple([
                (k, ValueImpl.get_comp(to_value(v, None, context_stack)))
                for k, v in anonymous_tuple.iter_elements(arg)
            ]), context_stack)
    elif py_typecheck.is_named_tuple(arg):
        result = to_value(arg._asdict(), None, context_stack)  # pytype: disable=attribute-error
    elif py_typecheck.is_attrs(arg):
        result = to_value(
            attr.asdict(arg,
                        dict_factory=collections.OrderedDict,
                        recurse=False), None, context_stack)
    elif isinstance(arg, dict):
        # OrderedDicts keep insertion order. Plain dicts are sorted by key so
        # the resulting tuple has a deterministic element order.
        if isinstance(arg, collections.OrderedDict):
            items = arg.items()
        else:
            items = sorted(arg.items())
        value = building_blocks.Tuple([
            (k, ValueImpl.get_comp(to_value(v, None, context_stack)))
            for k, v in items
        ])
        result = ValueImpl(value, context_stack)
    elif isinstance(arg, (tuple, list)):
        result = ValueImpl(
            building_blocks.Tuple([
                ValueImpl.get_comp(to_value(x, None, context_stack))
                for x in arg
            ]), context_stack)
    elif isinstance(arg, tensorflow_utils.TENSOR_REPRESENTATION_TYPES):
        result = _wrap_constant_as_value(arg, context_stack)
    elif isinstance(arg, (tf.Tensor, tf.Variable)):
        raise TypeError(
            'TensorFlow construct {} has been encountered in a federated '
            'context. TFF does not support mixing TF and federated orchestration '
            'code. Please wrap any TensorFlow constructs with '
            '`tff.tf_computation`.'.format(arg))
    elif isinstance(arg, function_utils.PolymorphicFunction):
        # TODO(b/129567727) remove this case when this is no longer an error
        raise TypeError(
            'Polymorphic computations cannot be converted to a TFF value. Consider '
            'explicitly specifying the argument types of a computation before '
            'passing it to a function that requires a TFF value (such as a TFF '
            'intrinsic like federated_map).')
    else:
        raise TypeError(
            'Unable to interpret an argument of type {} as a TFF value.'.format(
                py_typecheck.type_string(type(arg))))
    py_typecheck.check_type(result, ValueImpl)
    if (type_spec is not None and not type_utils.is_assignable_from(
            type_spec, result.type_signature)):
        raise TypeError(
            'The supplied argument maps to TFF type {}, which is incompatible with '
            'the requested type {}.'.format(result.type_signature, type_spec))
    return result
Esempio n. 17
0
 def test_reduces_lambda_returning_empty_tuple_to_tf(self):
   empty_tuple = building_blocks.Tuple([])
   lam = building_blocks.Lambda('x', tf.int32, empty_tuple)
   extracted_tf = transformations.consolidate_and_extract_local_processing(lam)
   self.assertIsInstance(extracted_tf, building_blocks.CompiledComputation)