Ejemplo n.º 1
0
def _create_before_and_after_aggregate_for_no_federated_secure_sum(tree):
    r"""Creates before and after aggregate computations for the given `tree`.

    This function returns the two ASTs:

    Lambda
    |
    Tuple
    |
    [Comp, Tuple]
           |
           [Value, []]

         Lambda(x)
         |
         Call
        /    \
    Comp      Tuple
              |
              [Sel(0),      Sel(0)]
              /            /
           Ref(x)    Sel(1)
                    /
              Ref(x)

    In the first AST, the first element returned by `Lambda`, `Comp`, is the
    result of the before aggregate returned by force aligning and splitting
    `tree` by `intrinsic_defs.FEDERATED_AGGREGATE.uri`. The second element
    returned by `Lambda` is a pair that stands in for the argument to the secure
    sum intrinsic. Its first element, `Value`, is
    `building_block_factory.create_federated_value(<>, placements.CLIENTS)`,
    i.e. an empty structure placed at CLIENTS. Its second element, `[]`, is an
    empty tuple used as the bitwidth. Therefore, the first AST has a type
    signature satisfying the requirements of before aggregate. The first AST
    keeps the parameter name and type of the split before aggregate.

    In the second AST, `Comp` is the after aggregate returned by force aligning
    and splitting `tree` by `intrinsic_defs.FEDERATED_AGGREGATE.uri`. `Lambda`
    takes a parameter of type `<p0, <p1, s4>>`, where `p0` and `p1` are the two
    parameter elements of the split after aggregate and `s4` has type
    `<>@SERVER`. This gives `Lambda` a type signature satisfying the
    requirements of after aggregate. The argument passed to `Comp` is
    `<p0, p1>`, selected from the parameter of `Lambda`, which intentionally
    drops `s4` on the floor.

    This function is intended to be used by
    `get_canonical_form_for_iterative_process` to create before and after
    aggregate computations for the given `tree` when there is no
    `intrinsic_defs.FEDERATED_SECURE_SUM` in `tree`. As a result, this function
    does not assert that there is no `intrinsic_defs.FEDERATED_SECURE_SUM` in
    `tree`, and it does not assert that `tree` has the expected structure. The
    caller is expected to perform these checks before calling this function.

    Args:
      tree: An instance of `building_blocks.ComputationBuildingBlock`.

    Returns:
      A pair of the form `(before, after)`, where each of `before` and `after`
      is a `building_blocks.ComputationBuildingBlock` that represents a part of
      the result as specified by
      `transformations.force_align_and_split_by_intrinsics`.
    """
    name_generator = building_block_factory.unique_name_generator(tree)

    before_aggregate, after_aggregate = (
        transformations.force_align_and_split_by_intrinsics(
            tree, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))

    # Placeholder secure-sum argument: an empty value at CLIENTS paired with an
    # empty bitwidth, appended next to the real before-aggregate result.
    empty_tuple = building_blocks.Tuple([])
    value = building_block_factory.create_federated_value(
        empty_tuple, placements.CLIENTS)
    bitwidth = empty_tuple
    args = building_blocks.Tuple([value, bitwidth])
    result = building_blocks.Tuple([before_aggregate.result, args])
    before_aggregate = building_blocks.Lambda(before_aggregate.parameter_name,
                                              before_aggregate.parameter_type,
                                              result)

    # Widen the after-aggregate parameter from <p0, p1> to <p0, <p1, s4>> so the
    # signature matches the secure-sum case; s4 is accepted but never used.
    ref_name = next(name_generator)
    s4_type = computation_types.FederatedType([], placements.SERVER)
    ref_type = computation_types.NamedTupleType([
        after_aggregate.parameter_type[0],
        computation_types.NamedTupleType([
            after_aggregate.parameter_type[1],
            s4_type,
        ]),
    ])
    ref = building_blocks.Reference(ref_name, ref_type)
    sel_arg = building_blocks.Selection(ref, index=0)
    sel = building_blocks.Selection(ref, index=1)
    sel_s3 = building_blocks.Selection(sel, index=0)
    # Forward only <p0, p1> to the original after aggregate, dropping s4.
    arg = building_blocks.Tuple([sel_arg, sel_s3])
    call = building_blocks.Call(after_aggregate, arg)
    after_aggregate = building_blocks.Lambda(ref.name, ref.type_signature,
                                             call)

    return before_aggregate, after_aggregate
Ejemplo n.º 2
0
 def test_returns_true_for_lambdas(self):
   # Two independently built identity lambdas over an int32 parameter `a`.
   first_param = building_blocks.Reference('a', tf.int32)
   first_fn = building_blocks.Lambda(first_param.name,
                                     first_param.type_signature, first_param)
   second_param = building_blocks.Reference('a', tf.int32)
   second_fn = building_blocks.Lambda(second_param.name,
                                      second_param.type_signature, second_param)
   are_equal = tree_analysis.trees_equal(first_fn, second_fn)
   self.assertTrue(are_equal)
Ejemplo n.º 3
0
 def test_returns_true_for_references(self):
   # Two distinct Reference objects with the same name and type.
   lhs, rhs = (building_blocks.Reference('a', tf.int32) for _ in range(2))
   self.assertTrue(tree_analysis.trees_equal(lhs, rhs))
Ejemplo n.º 4
0
 def test_returns_false_for_comps_with_different_types(self):
   # A Data block and a Reference must not compare equal, in either order.
   data_block = building_blocks.Data('data', tf.int32)
   reference = building_blocks.Reference('a', tf.int32)
   for lhs, rhs in ((data_block, reference), (reference, data_block)):
     self.assertFalse(tree_analysis.trees_equal(lhs, rhs))
Ejemplo n.º 5
0
 def test_returns_false_for_lambdas_with_different_parameter_types(self):
   # Identity lambdas that differ only in their parameter dtype.
   def identity_on(dtype):
     param = building_blocks.Reference('a', dtype)
     return building_blocks.Lambda(param.name, param.type_signature, param)

   int_fn = identity_on(tf.int32)
   float_fn = identity_on(tf.float32)
   self.assertFalse(tree_analysis.trees_equal(int_fn, float_fn))
Ejemplo n.º 6
0
 def test_raises_type_error_with_int_excluding(self):
   # Passing an int as the `excluding` argument should raise TypeError.
   param = building_blocks.Reference('a', tf.int32)
   identity = building_blocks.Lambda(param.name, param.type_signature, param)
   self.assertRaises(TypeError, tree_analysis.contains_no_unbound_references,
                     identity, 1)
Ejemplo n.º 7
0
 def test_returns_true_with_excluded_reference(self):
   # The lambda binds `b`, its body refers to `a`, and `a` is excluded.
   unbound_a = building_blocks.Reference('a', tf.int32)
   lam = building_blocks.Lambda('b', tf.int32, unbound_a)
   verdict = tree_analysis.contains_no_unbound_references(lam, excluding='a')
   self.assertTrue(verdict)
Ejemplo n.º 8
0
 def test_raises_reference_to_functional_type(self):
     # Consolidating a bare reference of functional type should raise.
     fn_ref = building_blocks.Reference(
         'x', computation_types.FunctionType(tf.int32, tf.int32))
     self.assertRaisesRegex(
         ValueError, 'of functional type passed',
         mapreduce_transformations.consolidate_and_extract_local_processing,
         fn_ref)
Ejemplo n.º 9
0
 def test_raises_on_non_lambda_comp(self):
     # A Reference (rather than a Lambda) should be rejected with TypeError.
     not_a_lambda = building_blocks.Reference('x', [tf.int32])
     self.assertRaises(
         TypeError,
         mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda,
         not_a_lambda, [0])
 def test_returns_true_for_selections_with_indexes(self):
     # Selecting index 0 from two equivalent references gives equal trees.
     first, second = [
         building_blocks.Selection(
             building_blocks.Reference('a', [tf.int32, tf.int32]), index=0)
         for _ in range(2)
     ]
     self.assertTrue(tree_analysis.trees_equal(first, second))
 def test_raises_with_reference(self):
     # Expect a ValueError mentioning 'tensorflow' for a plain Reference.
     self.assertRaisesRegex(ValueError, 'tensorflow',
                            building_block_analysis.get_device_placement_in,
                            building_blocks.Reference('x', tf.int32))
Ejemplo n.º 12
0
 def _construct_reference_representing(comp_to_represent):
     """Returns a fresh Reference typed like `comp_to_represent`.

     The name is drawn from the enclosing `name_generator`, so it does not
     clash with names already used in the tree.
     """
     ref_type = comp_to_represent.type_signature
     return building_blocks.Reference(next(name_generator), ref_type)
Ejemplo n.º 13
0
    def test_returns_string_for_reference(self):
        # Each textual rendering of a bare reference, paired with its expected text.
        comp = building_blocks.Reference('a', tf.int32)
        expectations = (
            (comp.compact_representation, 'a'),
            (comp.formatted_representation, 'a'),
            (comp.structural_representation, 'Ref(a)'),
        )
        for render, expected in expectations:
            self.assertEqual(render(), expected)
Ejemplo n.º 14
0
def _extract_update(after_aggregate):
    """Extracts `update` from `after_aggregate`.

    Elements 0 and 1 of `after_aggregate`'s result are selected and combined
    with a federated zip. `zip_selection_as_argument_to_lower_level_lambda` is
    then applied with the parameter paths `[0, 0, 0]`, `[1, 0]` and `[1, 1]`.
    The function of the resulting call is kept as `s6_to_s7_computation`. That
    computation is wrapped so it accepts the packed `<s1, <s3, s4>>` form placed
    at SERVER. Finally, the local processing is extracted with
    `transformations.consolidate_and_extract_local_processing`.

    This function is intended to be used by
    `get_canonical_form_for_iterative_process` only. As a result, this function
    does not assert that `after_aggregate` has the expected structure. The
    caller is expected to perform these checks before calling this function.

    Args:
      after_aggregate: The second result of splitting `after_broadcast` on
        aggregate intrinsics.

    Returns:
      `update` as specified by `canonical_form.CanonicalForm`, an instance of
      `building_blocks.CompiledComputation`.

    Raises:
      transformations.CanonicalFormCompilationError: If we extract an AST of the
        wrong type. NOTE(review): nothing in this body raises it directly;
        presumably it propagates from one of the `transformations` helpers —
        confirm.
    """
    # Keep only elements 0 and 1 of the result (s7) and zip them together.
    s7_elements_in_after_aggregate_result = [0, 1]
    s7_output_extracted = transformations.select_output_from_lambda(
        after_aggregate, s7_elements_in_after_aggregate_result)
    s7_output_zipped = building_blocks.Lambda(
        s7_output_extracted.parameter_name, s7_output_extracted.parameter_type,
        building_block_factory.create_federated_zip(
            s7_output_extracted.result))
    # Parameter paths feeding s7; presumably s1, s3 and s4 in that order
    # (matching the `<s1, s3, s4>` signature mentioned in the TODO below).
    s6_elements_in_after_aggregate_parameter = [[0, 0, 0], [1, 0], [1, 1]]
    s6_to_s7_computation = (
        transformations.zip_selection_as_argument_to_lower_level_lambda(
            s7_output_zipped,
            s6_elements_in_after_aggregate_parameter).result.function)

    # TODO(b/148942011): The transformation
    # `zip_selection_as_argument_to_lower_level_lambda` does not support selecting
    # from nested structures, therefore we need to pack the type signature
    # `<s1, s3, s4>` as `<s1, <s3, s4>>`.
    name_generator = building_block_factory.unique_name_generator(
        s6_to_s7_computation)

    # `pack_fn` maps <s1, <s3, s4>> to the flat <s1, s3, s4> that
    # `s6_to_s7_computation` expects as its (federated) member type.
    pack_ref_name = next(name_generator)
    pack_ref_type = computation_types.NamedTupleType([
        s6_to_s7_computation.parameter_type.member[0],
        computation_types.NamedTupleType([
            s6_to_s7_computation.parameter_type.member[1],
            s6_to_s7_computation.parameter_type.member[2],
        ]),
    ])
    pack_ref = building_blocks.Reference(pack_ref_name, pack_ref_type)
    sel_s1 = building_blocks.Selection(pack_ref, index=0)
    sel = building_blocks.Selection(pack_ref, index=1)
    sel_s3 = building_blocks.Selection(sel, index=0)
    sel_s4 = building_blocks.Selection(sel, index=1)
    result = building_blocks.Tuple([sel_s1, sel_s3, sel_s4])
    pack_fn = building_blocks.Lambda(pack_ref.name, pack_ref.type_signature,
                                     result)
    # Outer function: takes the packed value placed at SERVER, flattens it
    # with a federated map/apply of `pack_fn`, then calls
    # `s6_to_s7_computation` on the flattened value.
    ref_name = next(name_generator)
    ref_type = computation_types.FederatedType(pack_ref_type,
                                               placements.SERVER)
    ref = building_blocks.Reference(ref_name, ref_type)
    unpacked_args = building_block_factory.create_federated_map_or_apply(
        pack_fn, ref)
    call = building_blocks.Call(s6_to_s7_computation, unpacked_args)
    fn = building_blocks.Lambda(ref.name, ref.type_signature, call)
    # Extract the local processing as the standalone `update` computation.
    return transformations.consolidate_and_extract_local_processing(fn)
 def test_should_not_transform_reference(self):
     # A plain Reference should not be selected for optimization.
     optimizer = compiled_computation_transformations.TensorFlowOptimizer(
         tf.compat.v1.ConfigProto())
     reference = building_blocks.Reference('x', tf.int32)
     self.assertFalse(optimizer.should_transform(reference))
Ejemplo n.º 16
0
 def test_raises_on_non_tuple_parameter(self):
     # The lambda's parameter is a bare int32, not a tuple: expect TypeError.
     zip_selection = (
         mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda)
     identity = building_blocks.Lambda(
         'x', tf.int32, building_blocks.Reference('x', tf.int32))
     self.assertRaises(TypeError, zip_selection, identity, [[0]])
Ejemplo n.º 17
0
def zero_or_one_arg_fn_to_building_block(
    fn,
    parameter_name: Optional[str],
    parameter_type: Optional[Any],
    context_stack: context_stack_base.ContextStack,
    suggested_name: Optional[str] = None,
) -> Tuple[building_blocks.ComputationBuildingBlock, computation_types.Type]:
  """Converts a zero- or one-argument `fn` into a computation building block.

  `fn` is invoked once, inside a new `FederatedComputationContext` installed on
  `context_stack`. If the current context is itself a federated computation
  context, the new context is nested under it. The value `fn` returns is
  converted to a building block and wrapped in a `building_blocks.Lambda` over
  the (optional) parameter.

  Args:
    fn: A function with 0 or 1 arguments that contains orchestration logic,
      i.e., that expects zero or one `values_base.Value` and returns a result
      convertible to the same.
    parameter_name: The name of the parameter, or `None` if there isn't any.
      When given, it is prefixed with the new context's name
      (`'<context name>_<parameter_name>'`).
    parameter_type: The TFF type of the parameter, or `None` if there's none.
    context_stack: The context stack to use.
    suggested_name: The optional suggested name to use for the federated context
      that will be used to serialize this function's body (ideally the name of
      the underlying Python function). It might be modified to avoid conflicts.

  Returns:
    A tuple of `(building_blocks.ComputationBuildingBlock,
    computation_types.Type)`, where the first element contains the logic from
    `fn`, and the second element contains potentially annotated type information
    for the result of `fn`.

  Raises:
    ValueError: if `fn` is incompatible with `parameter_type`, or if `fn`
      returns `None`.
  """
  py_typecheck.check_callable(fn)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  if suggested_name is not None:
    py_typecheck.check_type(suggested_name, str)
  parameter_type = computation_types.to_type(parameter_type)
  # Nest under the currently active federated context, if there is one.
  if isinstance(context_stack.current,
                federated_computation_context.FederatedComputationContext):
    parent_context = context_stack.current
  else:
    parent_context = None
  context = federated_computation_context.FederatedComputationContext(
      context_stack, suggested_name=suggested_name, parent=parent_context)
  if parameter_name is not None:
    py_typecheck.check_type(parameter_name, str)
    parameter_name = '{}_{}'.format(context.name, str(parameter_name))
  with context_stack.install(context):
    if parameter_type is not None:
      result = fn(
          value_impl.ValueImpl(
              building_blocks.Reference(parameter_name, parameter_type),
              context_stack))
    else:
      result = fn()
    if result is None:
      # `check_callable` accepts callables without `__code__` (e.g.
      # `functools.partial` objects or instances defining `__call__`). Describe
      # those by `repr` rather than failing with an AttributeError here.
      code = getattr(fn, '__code__', None)
      if code is not None:
        origin = 'The function defined on line {} of file {}'.format(
            code.co_firstlineno, code.co_filename)
      else:
        origin = 'The function {!r}'.format(fn)
      raise ValueError(
          '{} has returned a `NoneType`, but all TFF functions must return '
          'some non-`None` value.'.format(origin))
    annotated_result_type = type_conversions.infer_type(result)
    result = value_impl.to_value(result, annotated_result_type, context_stack)
    result_comp = value_impl.ValueImpl.get_comp(result)
    annotated_type = computation_types.FunctionType(parameter_type,
                                                    annotated_result_type)
    return building_blocks.Lambda(parameter_name, parameter_type,
                                  result_comp), annotated_type
Ejemplo n.º 18
0
 def test_raises_on_selection_from_non_tuple(self):
     # Path [0, 0] indexes into element 0, which is an int32, not a tuple.
     zip_selection = (
         mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda)
     identity = building_blocks.Lambda(
         'x', [tf.int32], building_blocks.Reference('x', [tf.int32]))
     self.assertRaisesRegex(TypeError, 'nonexistent index', zip_selection,
                            identity, [[0, 0]])
Ejemplo n.º 19
0
 def test_returns_true(self):
   # An identity lambda binds the only name its body refers to.
   param = building_blocks.Reference('a', tf.int32)
   identity = building_blocks.Lambda(param.name, param.type_signature, param)
   verdict = tree_analysis.contains_no_unbound_references(identity)
   self.assertTrue(verdict)
Ejemplo n.º 20
0
 def test_raises_on_non_federated_selection(self):
     # Path [0] selects a plain int32 element, which is not federated.
     non_federated_fn = building_blocks.Lambda(
         'x', [tf.int32], building_blocks.Reference('x', [tf.int32]))
     self.assertRaises(
         TypeError,
         mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda,
         non_federated_fn, [[0]])
Ejemplo n.º 21
0
 def test_returns_false(self):
   # The lambda binds `b`, but its body refers to `a`, which stays unbound.
   body = building_blocks.Reference('a', tf.int32)
   lam = building_blocks.Lambda('b', tf.int32, body)
   has_no_unbound = tree_analysis.contains_no_unbound_references(lam)
   self.assertFalse(has_no_unbound)
Ejemplo n.º 22
0
 def test_raises_on_non_lambda(self):
     # A Reference to a tuple of int32@CLIENTS (not a Lambda) must be rejected.
     clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
     not_a_lambda = building_blocks.Reference('x', [clients_int])
     self.assertRaises(TypeError,
                       mapreduce_transformations.select_output_from_lambda,
                       not_a_lambda, 0)
Ejemplo n.º 23
0
 def test_returns_true_for_lambdas_representing_identical_functions(self):
   # Identity lambdas whose parameters have different names.
   def identity_named(name):
     param = building_blocks.Reference(name, tf.int32)
     return building_blocks.Lambda(name, param.type_signature, param)

   self.assertTrue(
       tree_analysis.trees_equal(identity_named('a'), identity_named('b')))
Ejemplo n.º 24
0
 def test_raises_on_reference(self):
     # Counting TensorFlow ops in a bare Reference should raise ValueError.
     self.assertRaises(ValueError,
                       building_block_analysis.count_tensorflow_ops_in,
                       building_blocks.Reference('x', tf.int32))
Ejemplo n.º 25
0
 def test_returns_true_for_lambdas_referring_to_same_unbound_variables(self):
   # Both lambdas bind `a` and return the same free reference `x`.
   free_x = building_blocks.Reference('x', tf.int32)
   lambdas = [building_blocks.Lambda('a', tf.int32, free_x) for _ in range(2)]
   self.assertTrue(tree_analysis.trees_equal(*lambdas))
 def test_should_not_transform_non_compiled_computations(self):
     # AddUniqueIDs should not select a plain Reference for transformation.
     transform = compiled_computation_transformations.AddUniqueIDs()
     reference = building_blocks.Reference('x', tf.int32)
     self.assertFalse(transform.should_transform(reference))
Ejemplo n.º 27
0
 def test_returns_false_for_references_with_different_names(self):
   # Same type, different names: the references are not equal.
   ref_a, ref_b = (
       building_blocks.Reference(name, tf.int32) for name in ('a', 'b'))
   self.assertFalse(tree_analysis.trees_equal(ref_a, ref_b))
 def test_should_not_transform_non_compiled_computations(self):
     # With an empty disallowed-op set, a plain Reference is not a target.
     checker = compiled_computation_transformations.RaiseOnDisallowedOp(
         frozenset())
     reference = building_blocks.Reference('x', tf.int32)
     should_transform = checker.should_transform(reference)
     self.assertFalse(should_transform)
Ejemplo n.º 29
0
 def test_returns_false_for_selections_with_differet_sources(self):
   # Only the source reference differs (`a` vs `b`). Both selections use the
   # same index, so a `False` result can only come from comparing the sources.
   # Previously the indexes also differed (0 vs 1), so this test would have
   # passed even if source comparison were broken.
   ref_1 = building_blocks.Reference('a', [tf.int32, tf.int32])
   selection_1 = building_blocks.Selection(ref_1, index=0)
   ref_2 = building_blocks.Reference('b', [tf.int32, tf.int32])
   selection_2 = building_blocks.Selection(ref_2, index=0)
   self.assertFalse(tree_analysis.trees_equal(selection_1, selection_2))
Ejemplo n.º 30
0
def _create_before_and_after_broadcast_for_no_broadcast(tree):
    r"""Creates before and after broadcast computations for the given `tree`.

    This function returns the two ASTs:

    Lambda
    |
    Value

         Lambda(x)
         |
         Call
        /    \
    Comp      Sel(0)
             /
       Ref(x)

    In the first AST, `Value` is
    `building_block_factory.create_federated_value(<>, placements.SERVER)`,
    i.e. an empty structure placed at SERVER. `Lambda` takes a parameter of the
    same type as the parameter of `tree`. The first AST therefore has a type
    signature satisfying the requirements of before broadcast.

    In the second AST, `Comp` is `tree`. `Lambda` takes a parameter of type
    `<P, c2>`, where `P` is the parameter type of `tree` and `c2` is the member
    type of the first AST's result placed at CLIENTS. This gives `Lambda` a type
    signature satisfying the requirements of after broadcast. The argument
    passed to `Comp` is the selection `x[0]` from the parameter of `Lambda`,
    which intentionally drops `c2` on the floor.

    This function is intended to be used by
    `get_canonical_form_for_iterative_process` to create before and after
    broadcast computations for the given `tree` when there is no
    `intrinsic_defs.FEDERATED_BROADCAST` in `tree`. As a result, this function
    does not assert that there is no `intrinsic_defs.FEDERATED_BROADCAST` in
    `tree`, and it does not assert that `tree` has the expected structure. The
    caller is expected to perform these checks before calling this function.

    Args:
      tree: An instance of `building_blocks.ComputationBuildingBlock`.

    Returns:
      A pair of the form `(before, after)`, where each of `before` and `after`
      is a `building_blocks.ComputationBuildingBlock` that represents a part of
      the result as specified by
      `transformations.force_align_and_split_by_intrinsics`.
    """
    name_generator = building_block_factory.unique_name_generator(tree)

    # Before broadcast: ignore the argument and produce an empty value at SERVER.
    parameter_name = next(name_generator)
    empty_tuple = building_blocks.Tuple([])
    value = building_block_factory.create_federated_value(
        empty_tuple, placements.SERVER)
    before_broadcast = building_blocks.Lambda(parameter_name,
                                              tree.type_signature.parameter,
                                              value)

    # After broadcast: accept <original arg, c2@CLIENTS> and call `tree` on the
    # original argument only.
    parameter_name = next(name_generator)
    type_signature = computation_types.FederatedType(
        before_broadcast.type_signature.result.member, placements.CLIENTS)
    parameter_type = computation_types.NamedTupleType(
        [tree.type_signature.parameter, type_signature])
    ref = building_blocks.Reference(parameter_name, parameter_type)
    arg = building_blocks.Selection(ref, index=0)
    call = building_blocks.Call(tree, arg)
    after_broadcast = building_blocks.Lambda(ref.name, ref.type_signature,
                                             call)

    return before_broadcast, after_broadcast