def _create_before_and_after_aggregate_for_no_federated_secure_sum(tree):
  r"""Creates before and after aggregate computations for the given `tree`.

  This function returns the two ASTs:

                    Lambda
                      |
                    Tuple
                      |
                [Comp, Tuple]
                         |
                   [Tuple, []]
                      |
                      []

       Lambda(x)
          |
         Call
        /    \
    Comp      Tuple
                |
          [Sel(0),      Sel(0)]
            /            /
         Ref(x)       Sel(1)
                       /
                    Ref(x)

  In the first AST, the first element returned by `Lambda`, `Comp`, is the
  result of the before aggregate returned by force aligning and splitting
  `tree` by `intrinsic_defs.FEDERATED_AGGREGATE.uri`. The second element
  returned by `Lambda` stands in for the argument to the secure sum intrinsic.
  It is a pair of an empty structure placed at `CLIENTS` (the value) and an
  empty structure (the bitwidth). Therefore, the first AST has a type
  signature satisfying the requirements of before aggregate.

  In the second AST, `Comp` is the after aggregate returned by force aligning
  and splitting `tree` by `intrinsic_defs.FEDERATED_AGGREGATE.uri`. `Lambda`
  has a type signature satisfying the requirements of after aggregate. The
  argument passed to `Comp` is a selection from the parameter of `Lambda`
  which intentionally drops `s4` on the floor. `s4` is typed here as an empty
  structure placed at `SERVER`.

  This function is intended to be used by
  `get_canonical_form_for_iterative_process` to create before and after
  aggregate computations for the given `tree` when there is no
  `intrinsic_defs.FEDERATED_SECURE_SUM` in `tree`. As a result, this function
  does not assert that there is no `intrinsic_defs.FEDERATED_SECURE_SUM` in
  `tree`, and it does not assert that `tree` has the expected structure. The
  caller is expected to perform these checks before calling this function.

  Args:
    tree: An instance of `building_blocks.ComputationBuildingBlock`.

  Returns:
    A pair of the form `(before, after)`, where each of `before` and `after`
    is a `building_blocks.ComputationBuildingBlock` that represents a part of
    the result as specified by
    `transformations.force_align_and_split_by_intrinsics`.
  """
  name_generator = building_block_factory.unique_name_generator(tree)

  before_aggregate, after_aggregate = (
      transformations.force_align_and_split_by_intrinsics(
          tree, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))

  # Stand-in for the secure sum argument `<value, bitwidth>`: an empty
  # structure placed at CLIENTS, paired with an empty structure.
  empty_tuple = building_blocks.Tuple([])
  value = building_block_factory.create_federated_value(
      empty_tuple, placements.CLIENTS)
  bitwidth = empty_tuple
  args = building_blocks.Tuple([value, bitwidth])
  result = building_blocks.Tuple([before_aggregate.result, args])
  before_aggregate = building_blocks.Lambda(before_aggregate.parameter_name,
                                            before_aggregate.parameter_type,
                                            result)

  # The new after aggregate takes `<p[0], <p[1], s4>>`, where `p` is the
  # parameter of the split-off after aggregate. It forwards `<p[0], p[1]>`
  # to it and ignores `s4`.
  ref_name = next(name_generator)
  s4_type = computation_types.FederatedType([], placements.SERVER)
  ref_type = computation_types.NamedTupleType([
      after_aggregate.parameter_type[0],
      computation_types.NamedTupleType([
          after_aggregate.parameter_type[1],
          s4_type,
      ]),
  ])

  ref = building_blocks.Reference(ref_name, ref_type)
  sel_arg = building_blocks.Selection(ref, index=0)
  sel = building_blocks.Selection(ref, index=1)
  sel_s3 = building_blocks.Selection(sel, index=0)
  arg = building_blocks.Tuple([sel_arg, sel_s3])
  call = building_blocks.Call(after_aggregate, arg)
  after_aggregate = building_blocks.Lambda(ref.name, ref.type_signature, call)

  return before_aggregate, after_aggregate
def test_returns_true_for_lambdas(self):
  """Independently built identity lambdas over `int32` compare equal."""

  def build_identity_fn():
    param = building_blocks.Reference('a', tf.int32)
    return building_blocks.Lambda(param.name, param.type_signature, param)

  self.assertTrue(
      tree_analysis.trees_equal(build_identity_fn(), build_identity_fn()))
def test_returns_true_for_references(self):
  """References with the same name and type compare equal."""
  first, second = (building_blocks.Reference('a', tf.int32) for _ in range(2))
  self.assertTrue(tree_analysis.trees_equal(first, second))
def test_returns_false_for_comps_with_different_types(self):
  """A `Data` node and a `Reference` node are unequal in either order."""
  data = building_blocks.Data('data', tf.int32)
  ref = building_blocks.Reference('a', tf.int32)
  for lhs, rhs in ((data, ref), (ref, data)):
    self.assertFalse(tree_analysis.trees_equal(lhs, rhs))
def test_returns_false_for_lambdas_with_different_parameter_types(self):
  """Identity lambdas that differ only in parameter dtype are not equal."""

  def identity_over(dtype):
    param = building_blocks.Reference('a', dtype)
    return building_blocks.Lambda(param.name, param.type_signature, param)

  int_fn = identity_over(tf.int32)
  float_fn = identity_over(tf.float32)
  self.assertFalse(tree_analysis.trees_equal(int_fn, float_fn))
def test_raises_type_error_with_int_excluding(self):
  """Passing an `int` as the `excluding` argument raises `TypeError`."""
  param = building_blocks.Reference('a', tf.int32)
  identity = building_blocks.Lambda(param.name, param.type_signature, param)
  with self.assertRaises(TypeError):
    tree_analysis.contains_no_unbound_references(identity, 1)
def test_returns_true_with_excluded_reference(self):
  """A free reference named in `excluding` is not counted as unbound."""
  free_ref = building_blocks.Reference('a', tf.int32)
  # The lambda binds `b`, so `a` in its body is free; it is excluded below.
  lam = building_blocks.Lambda('b', tf.int32, free_ref)
  no_unbound = tree_analysis.contains_no_unbound_references(
      lam, excluding='a')
  self.assertTrue(no_unbound)
def test_raises_reference_to_functional_type(self):
  """A bare reference of functional type is rejected with `ValueError`."""
  fn_type = computation_types.FunctionType(tf.int32, tf.int32)
  functional_ref = building_blocks.Reference('x', fn_type)
  with self.assertRaisesRegex(ValueError, 'of functional type passed'):
    mapreduce_transformations.consolidate_and_extract_local_processing(
        functional_ref)
def test_raises_on_non_lambda_comp(self):
  """A computation that is not a `Lambda` is rejected with `TypeError`."""
  not_a_lambda = building_blocks.Reference('x', [tf.int32])
  transform = (
      mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda)
  with self.assertRaises(TypeError):
    transform(not_a_lambda, [0])
def test_returns_true_for_selections_with_indexes(self):
  """Index selections from equal sources at equal indexes compare equal."""
  selections = []
  for _ in range(2):
    source = building_blocks.Reference('a', [tf.int32, tf.int32])
    selections.append(building_blocks.Selection(source, index=0))
  self.assertTrue(tree_analysis.trees_equal(*selections))
def test_raises_with_reference(self):
  """Querying device placement of a bare reference raises `ValueError`."""
  int_ref = building_blocks.Reference('x', tf.int32)
  with self.assertRaisesRegex(ValueError, 'tensorflow'):
    building_block_analysis.get_device_placement_in(int_ref)
def _construct_reference_representing(comp_to_represent):
  """Returns a new `Reference` with the type of `comp_to_represent`.

  Closes over `name_generator` from the enclosing scope, so the reference's
  name is drawn from that shared generator (for name safety).
  """
  ref_type = comp_to_represent.type_signature
  return building_blocks.Reference(next(name_generator), ref_type)
def test_returns_string_for_reference(self):
  """Every string representation of a reference has the expected form."""
  comp = building_blocks.Reference('a', tf.int32)
  expectations = (
      (comp.compact_representation, 'a'),
      (comp.formatted_representation, 'a'),
      (comp.structural_representation, 'Ref(a)'),
  )
  for render, want in expectations:
    self.assertEqual(render(), want)
def _extract_update(after_aggregate):
  """Extracts `update` from `after_aggregate`.

  Only `get_canonical_form_for_iterative_process` is expected to call this
  function. It therefore does not verify that `after_aggregate` has the
  expected structure; the caller is responsible for those checks.

  Args:
    after_aggregate: The second result of splitting `after_broadcast` on
      aggregate intrinsics.

  Returns:
    `update` as specified by `canonical_form.CanonicalForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: If we extract an AST of the
      wrong type.
  """
  # Keep result elements 0 and 1 (`s7`) and zip them into one federated value.
  s7_selected = transformations.select_output_from_lambda(after_aggregate,
                                                          [0, 1])
  s7_zipped = building_blocks.Lambda(
      s7_selected.parameter_name,
      s7_selected.parameter_type,
      building_block_factory.create_federated_zip(s7_selected.result),
  )

  # Narrow the parameter to the `s6` elements, which end up as the flat
  # parameter `<s1, s3, s4>` of `s6_to_s7`.
  s6_to_s7 = transformations.zip_selection_as_argument_to_lower_level_lambda(
      s7_zipped, [[0, 0, 0], [1, 0], [1, 1]]).result.function

  # TODO(b/148942011): The transformation
  # `zip_selection_as_argument_to_lower_level_lambda` does not support selecting
  # from nested structures, therefore we need to pack the type signature
  # `<s1, s3, s4>` as `<s1, <s3, s4>>`.
  names = building_block_factory.unique_name_generator(s6_to_s7)
  packed_name = next(names)
  flat_members = s6_to_s7.parameter_type.member
  packed_type = computation_types.NamedTupleType([
      flat_members[0],
      computation_types.NamedTupleType([flat_members[1], flat_members[2]]),
  ])
  packed_ref = building_blocks.Reference(packed_name, packed_type)

  # Lambda mapping `<s1, <s3, s4>>` back to the flat `<s1, s3, s4>`.
  nested = building_blocks.Selection(packed_ref, index=1)
  unpack_fn = building_blocks.Lambda(
      packed_ref.name,
      packed_ref.type_signature,
      building_blocks.Tuple([
          building_blocks.Selection(packed_ref, index=0),
          building_blocks.Selection(nested, index=0),
          building_blocks.Selection(nested, index=1),
      ]),
  )

  server_ref = building_blocks.Reference(
      next(names),
      computation_types.FederatedType(packed_type, placements.SERVER))
  flat_args = building_block_factory.create_federated_map_or_apply(
      unpack_fn, server_ref)
  wrapper = building_blocks.Lambda(
      server_ref.name,
      server_ref.type_signature,
      building_blocks.Call(s6_to_s7, flat_args),
  )
  return transformations.consolidate_and_extract_local_processing(wrapper)
def test_should_not_transform_reference(self):
  """The TensorFlow optimizer declines to transform a plain reference."""
  ref = building_blocks.Reference('x', tf.int32)
  optimizer = compiled_computation_transformations.TensorFlowOptimizer(
      tf.compat.v1.ConfigProto())
  self.assertFalse(optimizer.should_transform(ref))
def test_raises_on_non_tuple_parameter(self):
  """A lambda whose parameter is a scalar `int32` is rejected."""
  identity = building_blocks.Lambda(
      'x', tf.int32, building_blocks.Reference('x', tf.int32))
  transform = (
      mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda)
  with self.assertRaises(TypeError):
    transform(identity, [[0]])
def zero_or_one_arg_fn_to_building_block(
    fn,
    parameter_name: Optional[str],
    parameter_type: Optional[Any],
    context_stack: context_stack_base.ContextStack,
    suggested_name: Optional[str] = None,
) -> Tuple[building_blocks.ComputationBuildingBlock, computation_types.Type]:
  """Converts a zero- or one-argument `fn` into a computation building block.

  Args:
    fn: A function with 0 or 1 arguments that contains orchestration logic,
      i.e., that expects zero or one `values_base.Value` and returns a result
      convertible to the same.
    parameter_name: The name of the parameter, or `None` if there isn't any.
    parameter_type: The TFF type of the parameter, or `None` if there's none.
      When it is `None`, `fn` is called with no arguments.
    context_stack: The context stack to use.
    suggested_name: The optional suggested name to use for the federated
      context that will be used to serialize this function's body (ideally the
      name of the underlying Python function). It might be modified to avoid
      conflicts.

  Returns:
    A tuple of `(building_blocks.ComputationBuildingBlock,
    computation_types.Type)`, where the first element contains the logic from
    `fn`, and the second element contains potentially annotated type
    information for the result of `fn`.

  Raises:
    ValueError: if `fn` is incompatible with `parameter_type`, or if `fn`
      returns `None`.
  """
  py_typecheck.check_callable(fn)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  if suggested_name is not None:
    py_typecheck.check_type(suggested_name, str)
  parameter_type = computation_types.to_type(parameter_type)
  # If we are already inside a federated computation context, record it as the
  # parent of the new context.
  if isinstance(context_stack.current,
                federated_computation_context.FederatedComputationContext):
    parent_context = context_stack.current
  else:
    parent_context = None
  context = federated_computation_context.FederatedComputationContext(
      context_stack, suggested_name=suggested_name, parent=parent_context)
  if parameter_name is not None:
    py_typecheck.check_type(parameter_name, str)
    # Prefix the parameter name with the context name.
    parameter_name = '{}_{}'.format(context.name, str(parameter_name))
  with context_stack.install(context):
    if parameter_type is not None:
      result = fn(
          value_impl.ValueImpl(
              building_blocks.Reference(parameter_name, parameter_type),
              context_stack))
    else:
      result = fn()
    if result is None:
      # `check_callable` admits callables without `__code__` (e.g.
      # `functools.partial` objects or instances defining `__call__`). Fall
      # back to `repr(fn)` so we raise the intended ValueError, not an
      # AttributeError.
      code = getattr(fn, '__code__', None)
      if code is not None:
        description = 'The function defined on line {} of file {}'.format(
            code.co_firstlineno, code.co_filename)
      else:
        description = 'The function {!r}'.format(fn)
      raise ValueError(
          '{} has returned a `NoneType`, but all TFF functions must return '
          'some non-`None` value.'.format(description))
    annotated_result_type = type_conversions.infer_type(result)
    result = value_impl.to_value(result, annotated_result_type, context_stack)
    result_comp = value_impl.ValueImpl.get_comp(result)
    annotated_type = computation_types.FunctionType(parameter_type,
                                                    annotated_result_type)
    return building_blocks.Lambda(parameter_name, parameter_type,
                                  result_comp), annotated_type
def test_raises_on_selection_from_non_tuple(self):
  """Selecting path `[0, 0]` indexes into an `int32`, which is an error."""
  param_ref = building_blocks.Reference('x', [tf.int32])
  identity = building_blocks.Lambda('x', [tf.int32], param_ref)
  transform = (
      mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda)
  with self.assertRaisesRegex(TypeError, 'nonexistent index'):
    transform(identity, [[0, 0]])
def test_returns_true(self):
  """A lambda whose body references only its own parameter is closed."""
  param = building_blocks.Reference('a', tf.int32)
  closed_fn = building_blocks.Lambda(param.name, param.type_signature, param)
  self.assertTrue(tree_analysis.contains_no_unbound_references(closed_fn))
def test_raises_on_non_federated_selection(self):
  """Selecting the non-federated `int32` element raises `TypeError`."""
  param_ref = building_blocks.Reference('x', [tf.int32])
  identity = building_blocks.Lambda('x', [tf.int32], param_ref)
  transform = (
      mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda)
  with self.assertRaises(TypeError):
    transform(identity, [[0]])
def test_returns_false(self):
  """A lambda binding `b` but referencing `a` has an unbound reference."""
  free_ref = building_blocks.Reference('a', tf.int32)
  open_fn = building_blocks.Lambda('b', tf.int32, free_ref)
  self.assertFalse(tree_analysis.contains_no_unbound_references(open_fn))
def test_raises_on_non_lambda(self):
  """`select_output_from_lambda` rejects a non-`Lambda` with `TypeError`."""
  clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  tuple_ref = building_blocks.Reference('x', [clients_int])
  with self.assertRaises(TypeError):
    mapreduce_transformations.select_output_from_lambda(tuple_ref, 0)
def test_returns_true_for_lambdas_representing_identical_functions(self):
  """Identity lambdas differing only in parameter name compare equal."""
  fns = []
  for name in ('a', 'b'):
    param = building_blocks.Reference(name, tf.int32)
    fns.append(building_blocks.Lambda(name, param.type_signature, param))
  self.assertTrue(tree_analysis.trees_equal(*fns))
def test_raises_on_reference(self):
  """Counting TensorFlow ops in a bare reference raises `ValueError`."""
  int_ref = building_blocks.Reference('x', tf.int32)
  with self.assertRaises(ValueError):
    building_block_analysis.count_tensorflow_ops_in(int_ref)
def test_returns_true_for_lambdas_referring_to_same_unbound_variables(self):
  """Lambdas whose bodies share the same free `x` compare equal."""
  free_x = building_blocks.Reference('x', tf.int32)
  fn_1, fn_2 = [
      building_blocks.Lambda('a', tf.int32, free_x) for _ in range(2)
  ]
  self.assertTrue(tree_analysis.trees_equal(fn_1, fn_2))
def test_should_not_transform_non_compiled_computations(self):
  """`AddUniqueIDs` declines to transform a plain reference."""
  ref = building_blocks.Reference('x', tf.int32)
  add_ids = compiled_computation_transformations.AddUniqueIDs()
  self.assertFalse(add_ids.should_transform(ref))
def test_returns_false_for_references_with_different_names(self):
  """Same-typed references with different names are not equal."""
  ref_a, ref_b = (building_blocks.Reference(n, tf.int32) for n in ('a', 'b'))
  self.assertFalse(tree_analysis.trees_equal(ref_a, ref_b))
def test_should_not_transform_non_compiled_computations(self):
  """`RaiseOnDisallowedOp` declines to transform a plain reference."""
  ref = building_blocks.Reference('x', tf.int32)
  empty_ops = frozenset()
  raise_on_disallowed = (
      compiled_computation_transformations.RaiseOnDisallowedOp(empty_ops))
  self.assertFalse(raise_on_disallowed.should_transform(ref))
def test_returns_false_for_selections_with_differet_sources(self):
  """Selections that differ only in their source do not compare equal."""
  # Both selections use the same index, so the source name is the only
  # difference between the trees. If the indexes also differed, the test
  # would pass even when `trees_equal` ignored the source entirely.
  ref_1 = building_blocks.Reference('a', [tf.int32, tf.int32])
  selection_1 = building_blocks.Selection(ref_1, index=0)
  ref_2 = building_blocks.Reference('b', [tf.int32, tf.int32])
  selection_2 = building_blocks.Selection(ref_2, index=0)
  self.assertFalse(tree_analysis.trees_equal(selection_1, selection_2))
def _create_before_and_after_broadcast_for_no_broadcast(tree):
  r"""Creates before and after broadcast computations for the given `tree`.

  This function returns the two ASTs:

         Lambda
            |
     federated_value
            |
          Tuple
            |
            []

       Lambda(x)
          |
         Call
        /    \
    Comp      Sel(0)
             /
          Ref(x)

  The first AST is a `Lambda` whose result is an empty structure placed at
  `SERVER` (built with `building_block_factory.create_federated_value`). Its
  type signature satisfies the requirements of before broadcast.

  In the second AST, `Comp` is `tree` and `Lambda` has a type signature
  satisfying the requirements of after broadcast. The argument passed to
  `Comp` is a selection from the parameter of `Lambda`, which intentionally
  drops `c2` on the floor. `c2` is typed as the member of the first AST's
  result, placed at `CLIENTS`.

  This function is intended to be used by
  `get_canonical_form_for_iterative_process` to create before and after
  broadcast computations for the given `tree` when there is no
  `intrinsic_defs.FEDERATED_BROADCAST` in `tree`. As a result, this function
  does not assert that there is no `intrinsic_defs.FEDERATED_BROADCAST` in
  `tree`, and it does not assert that `tree` has the expected structure. The
  caller is expected to perform these checks before calling this function.

  Args:
    tree: An instance of `building_blocks.ComputationBuildingBlock`.

  Returns:
    A pair of the form `(before, after)`, where each of `before` and `after`
    is a `building_blocks.ComputationBuildingBlock` that represents a part of
    the result as specified by
    `transformations.force_align_and_split_by_intrinsics`.
  """
  names = building_block_factory.unique_name_generator(tree)

  # Before broadcast: ignore the parameter; produce an empty structure placed
  # at the server.
  before_name = next(names)
  server_empty = building_block_factory.create_federated_value(
      building_blocks.Tuple([]), placements.SERVER)
  before_broadcast = building_blocks.Lambda(before_name,
                                            tree.type_signature.parameter,
                                            server_empty)

  # After broadcast: accept `<original parameter, c2>` and forward only the
  # original parameter to `tree`.
  after_name = next(names)
  c2_type = computation_types.FederatedType(
      before_broadcast.type_signature.result.member, placements.CLIENTS)
  after_ref = building_blocks.Reference(
      after_name,
      computation_types.NamedTupleType(
          [tree.type_signature.parameter, c2_type]))
  forwarded = building_blocks.Selection(after_ref, index=0)
  after_broadcast = building_blocks.Lambda(
      after_ref.name,
      after_ref.type_signature,
      building_blocks.Call(tree, forwarded),
  )
  return before_broadcast, after_broadcast