def test_returns_false_for_tuples_with_different_names(self):
  data_1 = building_blocks.Data('data', tf.int32)
  tuple_1 = building_blocks.Tuple([('a', data_1), ('b', data_1)])
  data_2 = building_blocks.Data('data', tf.float32)
  tuple_2 = building_blocks.Tuple([('c', data_2), ('d', data_2)])
  self.assertFalse(tree_analysis._trees_equal(tuple_1, tuple_2))
def test_returns_false_for_tuples_with_different_elements(self):
  data_1 = building_blocks.Data('data', tf.int32)
  tuple_1 = building_blocks.Tuple([data_1, data_1])
  data_2 = building_blocks.Data('data', tf.float32)
  tuple_2 = building_blocks.Tuple([data_2, data_2])
  self.assertFalse(tree_analysis._trees_equal(tuple_1, tuple_2))
def _dictlike_items_to_value(items, context_stack, container_type) -> ValueImpl:
  value = building_blocks.Tuple(
      [(k, ValueImpl.get_comp(to_value(v, None, context_stack)))
       for k, v in items], container_type)
  return ValueImpl(value, context_stack)
def create_nested_syntax_tree():
  r"""Constructs a computation with explicit ordering for testing traversals.

  The goal of this computation is to exercise each switch in
  `transform_postorder_with_symbol_bindings`, at least all those that recurse.

  The computation this function constructs can be represented as below.

  Notice that the body of the Lambda *does not depend on the Lambda's
  parameter*, so that if we were actually executing this call the argument
  would be thrown away.

  All leaf nodes are instances of `building_blocks.Data`.

                            Call
                           /    \
                 Lambda('arg')   Data('k')
                     |
           Block('y','z')-------------
          /                           \
  ['y'=Data('a'),'z'=Data('b')]       Tuple
                                     /     \
                           Block('v')       Block('x')-------
                          /         \            |           \
                 ['v'=Selection]   Data('g')  ['x'=Data('h')]  |
                        |                                      |
                        |                                      |
                        |                                 Block('w')
                        |                                /          \
                      Tuple ------             ['w'=Data('i')]     Data('j')
                     /          \
           Block('t')            Block('u')
          /         \           /          \
  ['t'=Data('c')]  Data('d')  ['u'=Data('e')]  Data('f')

  Postorder traversals:
  If we are reading Data URIs, results of a postorder traversal should be:
  [a, b, c, d, e, f, g, h, i, j, k]

  If we are reading locals declarations, results of a postorder traversal
  should be:
  [t, u, v, w, x, y, z]

  And if we are reading both in an interleaved fashion, results of a postorder
  traversal should be:
  [a, b, c, d, t, e, f, u, g, v, h, i, j, w, x, y, z, k]

  Preorder traversals:
  If we are reading Data URIs, results of a preorder traversal should be:
  [a, b, c, d, e, f, g, h, i, j, k]

  If we are reading locals declarations, results of a preorder traversal
  should be:
  [y, z, v, t, u, x, w]

  And if we are reading both in an interleaved fashion, results of a preorder
  traversal should be:
  [y, z, a, b, v, t, c, d, u, e, f, g, x, h, w, i, j, k]

  Since we are also exposing the ability to hook into variable declarations,
  it is worthwhile considering the order in which variables are assigned in
  this tree. Notice that this order maps neither to preorder nor to postorder
  when purely considering the nodes of the tree above. This would be:
  [arg, y, z, t, u, v, x, w]

  Returns:
    An instance of `building_blocks.ComputationBuildingBlock` satisfying the
    description above.
  """
  data_c = building_blocks.Data('c', tf.float32)
  data_d = building_blocks.Data('d', tf.float32)
  left_most_leaf = building_blocks.Block([('t', data_c)], data_d)
  data_e = building_blocks.Data('e', tf.float32)
  data_f = building_blocks.Data('f', tf.float32)
  center_leaf = building_blocks.Block([('u', data_e)], data_f)
  inner_tuple = building_blocks.Tuple([left_most_leaf, center_leaf])
  selected = building_blocks.Selection(inner_tuple, index=0)
  data_g = building_blocks.Data('g', tf.float32)
  middle_block = building_blocks.Block([('v', selected)], data_g)
  data_i = building_blocks.Data('i', tf.float32)
  data_j = building_blocks.Data('j', tf.float32)
  right_most_endpoint = building_blocks.Block([('w', data_i)], data_j)
  data_h = building_blocks.Data('h', tf.int32)
  right_child = building_blocks.Block([('x', data_h)], right_most_endpoint)
  result = building_blocks.Tuple([middle_block, right_child])
  data_a = building_blocks.Data('a', tf.float32)
  data_b = building_blocks.Data('b', tf.float32)
  dummy_outer_block = building_blocks.Block([('y', data_a), ('z', data_b)],
                                            result)
  dummy_lambda = building_blocks.Lambda('arg', tf.float32, dummy_outer_block)
  dummy_arg = building_blocks.Data('k', tf.float32)
  called_lambda = building_blocks.Call(dummy_lambda, dummy_arg)
  return called_lambda
def transform_preorder(
    comp: building_blocks.ComputationBuildingBlock,
    transform: Callable[[building_blocks.ComputationBuildingBlock],
                        TransformReturnType]
) -> TransformReturnType:
  """Walks the AST of `comp` preorder, calling `transform` on the way down.

  Notice that this function will stop walking the tree when its transform
  function modifies a node; this is to prevent the caller from unexpectedly
  kicking off an infinite recursion. For this purpose the transform function
  must identify when it has transformed the structure of a building block; if
  the structure of the building block is modified but `False` is returned as
  the second element of the tuple returned by `transform`, `transform_preorder`
  may result in an infinite recursion.

  Args:
    comp: Instance of `building_blocks.ComputationBuildingBlock` to be
      transformed in a preorder fashion.
    transform: Transform function to be applied to the nodes of `comp`. Must
      return a two-tuple whose first element is a
      `building_blocks.ComputationBuildingBlock` and whose second element is a
      Boolean. If the computation passed to `transform` is returned in a
      modified state, `transform` must return `True` for the second element.

  Returns:
    A two-tuple, whose first element is a modified version of `comp`, and
    whose second element is a Boolean indicating whether `comp` was
    transformed during the walk.

  Raises:
    TypeError: If the argument types don't match those specified above.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  py_typecheck.check_callable(transform)
  inner_comp, modified = transform(comp)
  if modified:
    return inner_comp, modified
  if isinstance(inner_comp, (
      building_blocks.CompiledComputation,
      building_blocks.Data,
      building_blocks.Intrinsic,
      building_blocks.Placement,
      building_blocks.Reference,
  )):
    return inner_comp, modified
  elif isinstance(inner_comp, building_blocks.Lambda):
    transformed_result, result_modified = transform_preorder(
        inner_comp.result, transform)
    if not (modified or result_modified):
      return inner_comp, False
    return building_blocks.Lambda(inner_comp.parameter_name,
                                  inner_comp.parameter_type,
                                  transformed_result), True
  elif isinstance(inner_comp, building_blocks.Tuple):
    elements_modified = False
    elements = []
    for name, val in anonymous_tuple.iter_elements(inner_comp):
      result, result_modified = transform_preorder(val, transform)
      # Accumulate the flag across elements rather than overwriting it, so
      # that a modification to any element survives to the check below.
      elements_modified = elements_modified or result_modified
      elements.append((name, result))
    if not (modified or elements_modified):
      return inner_comp, False
    return building_blocks.Tuple(elements), True
  elif isinstance(inner_comp, building_blocks.Selection):
    transformed_source, source_modified = transform_preorder(
        inner_comp.source, transform)
    if not (modified or source_modified):
      return inner_comp, False
    return building_blocks.Selection(transformed_source, inner_comp.name,
                                     inner_comp.index), True
  elif isinstance(inner_comp, building_blocks.Call):
    transformed_fn, fn_modified = transform_preorder(inner_comp.function,
                                                     transform)
    if inner_comp.argument is not None:
      transformed_arg, arg_modified = transform_preorder(
          inner_comp.argument, transform)
    else:
      transformed_arg = None
      arg_modified = False
    if not (modified or fn_modified or arg_modified):
      return inner_comp, False
    return building_blocks.Call(transformed_fn, transformed_arg), True
  elif isinstance(inner_comp, building_blocks.Block):
    transformed_variables = []
    values_modified = False
    for key, value in inner_comp.locals:
      transformed_value, value_modified = transform_preorder(value, transform)
      transformed_variables.append((key, transformed_value))
      values_modified = values_modified or value_modified
    transformed_result, result_modified = transform_preorder(
        inner_comp.result, transform)
    if not (modified or values_modified or result_modified):
      return inner_comp, False
    return building_blocks.Block(transformed_variables,
                                 transformed_result), True
  else:
    raise NotImplementedError(
        'Unrecognized computation building block: {}'.format(str(inner_comp)))
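# Hypothetical usage sketch for `transform_preorder` (not part of the original
# module; it assumes the same `building_blocks` and `tf` imports the functions
# above rely on). The transform rewrites `Data` leaves into empty tuples and
# returns `True`, which stops the walk from descending into the replaced
# subtree, the behavior the docstring above warns about.
def _example_data_to_empty_tuple(comp):
  if isinstance(comp, building_blocks.Data):
    return building_blocks.Tuple([]), True
  return comp, False

_example_data = building_blocks.Data('data', tf.int32)
_example_tuple = building_blocks.Tuple([_example_data, _example_data])
_example_transformed, _example_modified = transform_preorder(
    _example_tuple, _example_data_to_empty_tuple)
# `_example_modified` is True; the enclosing Tuple is visited first (preorder)
# and left unchanged, while each Data leaf is replaced independently once the
# walk recurses into the Tuple's elements.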
def transform_postorder(comp, transform):
  """Traverses `comp` recursively postorder and replaces its constituents.

  For each element of `comp` viewed as an expression tree, the transformation
  `transform` is applied first to the building blocks it is parameterized by,
  then to the element itself. The transformation `transform` should act as an
  identity function on the kinds of elements (computation building blocks) it
  does not care to transform. This corresponds to a post-order traversal of
  the expression tree, i.e., parameters are always transformed left-to-right
  (in the order in which they are listed in building block constructors), then
  the parent is visited and transformed with the already-visited, and possibly
  transformed, arguments in place.

  NOTE: In particular, in `Call(f, x)`, both `f` and `x` are arguments to
  `Call`. Therefore, `f` is transformed into `f'`, next `x` into `x'` and
  finally, `Call(f', x')` is transformed at the end.

  Args:
    comp: A `building_blocks.ComputationBuildingBlock` to traverse and
      transform bottom-up.
    transform: The transformation to apply locally to each building block in
      `comp`. It is a Python function that accepts a building block as input,
      and should return a (building block, bool) tuple as output, where the
      building block is a `building_blocks.ComputationBuildingBlock`
      representing either the original building block or a transformed
      building block, and the bool is a flag indicating whether the building
      block was modified.

  Returns:
    The result of applying `transform` to parts of `comp` in a bottom-up
    fashion, along with a Boolean with the value `True` if `comp` was
    transformed and `False` if it was not.

  Raises:
    TypeError: If the arguments are of the wrong types.
    NotImplementedError: If the argument is a kind of computation building
      block that is currently not recognized.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  if isinstance(comp, (
      building_blocks.CompiledComputation,
      building_blocks.Data,
      building_blocks.Intrinsic,
      building_blocks.Placement,
      building_blocks.Reference,
  )):
    return transform(comp)
  elif isinstance(comp, building_blocks.Selection):
    source, source_modified = transform_postorder(comp.source, transform)
    if source_modified:
      comp = building_blocks.Selection(source, comp.name, comp.index)
    comp, comp_modified = transform(comp)
    return comp, comp_modified or source_modified
  elif isinstance(comp, building_blocks.Tuple):
    elements = []
    elements_modified = False
    for key, value in anonymous_tuple.iter_elements(comp):
      value, value_modified = transform_postorder(value, transform)
      elements.append((key, value))
      elements_modified = elements_modified or value_modified
    if elements_modified:
      comp = building_blocks.Tuple(elements)
    comp, comp_modified = transform(comp)
    return comp, comp_modified or elements_modified
  elif isinstance(comp, building_blocks.Call):
    fn, fn_modified = transform_postorder(comp.function, transform)
    if comp.argument is not None:
      arg, arg_modified = transform_postorder(comp.argument, transform)
    else:
      arg, arg_modified = (None, False)
    if fn_modified or arg_modified:
      comp = building_blocks.Call(fn, arg)
    comp, comp_modified = transform(comp)
    return comp, comp_modified or fn_modified or arg_modified
  elif isinstance(comp, building_blocks.Lambda):
    result, result_modified = transform_postorder(comp.result, transform)
    if result_modified:
      comp = building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                    result)
    comp, comp_modified = transform(comp)
    return comp, comp_modified or result_modified
  elif isinstance(comp, building_blocks.Block):
    variables = []
    variables_modified = False
    for key, value in comp.locals:
      value, value_modified = transform_postorder(value, transform)
      variables.append((key, value))
      variables_modified = variables_modified or value_modified
    result, result_modified = transform_postorder(comp.result, transform)
    if variables_modified or result_modified:
      comp = building_blocks.Block(variables, result)
    comp, comp_modified = transform(comp)
    return comp, comp_modified or variables_modified or result_modified
  else:
    raise NotImplementedError(
        'Unrecognized computation building block: {}'.format(str(comp)))
def _extract_update(after_aggregate):
  """Extracts `update` from `after_aggregate`.

  This function is intended to be used by
  `get_canonical_form_for_iterative_process` only. As a result, this function
  does not assert that `after_aggregate` has the expected structure; the
  caller is expected to perform these checks before calling this function.

  Args:
    after_aggregate: The second result of splitting `after_broadcast` on
      aggregate intrinsics.

  Returns:
    `update` as specified by `canonical_form.CanonicalForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: If we extract an AST of the
      wrong type.
  """
  s7_elements_in_after_aggregate_result = [0, 1]
  s7_output_extracted = transformations.select_output_from_lambda(
      after_aggregate, s7_elements_in_after_aggregate_result)
  s7_output_zipped = building_blocks.Lambda(
      s7_output_extracted.parameter_name, s7_output_extracted.parameter_type,
      building_block_factory.create_federated_zip(s7_output_extracted.result))
  s6_elements_in_after_aggregate_parameter = [[0, 0, 0], [1, 0], [1, 1]]
  s6_to_s7_computation = (
      transformations.zip_selection_as_argument_to_lower_level_lambda(
          s7_output_zipped,
          s6_elements_in_after_aggregate_parameter).result.function)

  # TODO(b/148942011): The transformation
  # `zip_selection_as_argument_to_lower_level_lambda` does not support
  # selecting from nested structures, therefore we need to pack the type
  # signature `<s1, s3, s4>` as `<s1, <s3, s4>>`.
  name_generator = building_block_factory.unique_name_generator(
      s6_to_s7_computation)
  pack_ref_name = next(name_generator)
  pack_ref_type = computation_types.NamedTupleType([
      s6_to_s7_computation.parameter_type.member[0],
      computation_types.NamedTupleType([
          s6_to_s7_computation.parameter_type.member[1],
          s6_to_s7_computation.parameter_type.member[2],
      ]),
  ])
  pack_ref = building_blocks.Reference(pack_ref_name, pack_ref_type)
  sel_s1 = building_blocks.Selection(pack_ref, index=0)
  sel = building_blocks.Selection(pack_ref, index=1)
  sel_s3 = building_blocks.Selection(sel, index=0)
  sel_s4 = building_blocks.Selection(sel, index=1)
  result = building_blocks.Tuple([sel_s1, sel_s3, sel_s4])
  pack_fn = building_blocks.Lambda(pack_ref.name, pack_ref.type_signature,
                                   result)
  ref_name = next(name_generator)
  ref_type = computation_types.FederatedType(pack_ref_type, placements.SERVER)
  ref = building_blocks.Reference(ref_name, ref_type)
  unpacked_args = building_block_factory.create_federated_map_or_apply(
      pack_fn, ref)
  call = building_blocks.Call(s6_to_s7_computation, unpacked_args)
  fn = building_blocks.Lambda(ref.name, ref.type_signature, call)
  return transformations.consolidate_and_extract_local_processing(fn)
def test_constructs_aggregate_of_tuple_with_one_element(self):
  called_intrinsic = test_utils.create_dummy_called_federated_aggregate(
      accumulate_parameter_name='a',
      merge_parameter_name='b',
      report_parameter_name='c')
  calls = building_blocks.Tuple((called_intrinsic, called_intrinsic))
  comp = calls
  transformed_comp, modified = compiler_transformations.dedupe_and_merge_tuple_intrinsics(
      comp, intrinsic_defs.FEDERATED_AGGREGATE.uri)

  federated_agg = []

  def _find_federated_aggregate(comp):
    if building_block_analysis.is_called_intrinsic(
        comp, intrinsic_defs.FEDERATED_AGGREGATE.uri):
      federated_agg.append(comp)
    return comp, False

  transformation_utils.transform_postorder(transformed_comp,
                                           _find_federated_aggregate)
  self.assertTrue(modified)
  self.assertLen(federated_agg, 1)
  self.assertLen(federated_agg[0].type_signature.member, 1)
  self.assertEqual(
      transformed_comp.formatted_representation(),
      '(_var1 -> <\n'
      ' _var1[0],\n'
      ' _var1[0]\n'
      '>)((x -> <\n'
      ' x[0]\n'
      '>)((let\n'
      ' value=federated_aggregate(<\n'
      ' federated_map(<\n'
      ' (arg -> <\n'
      ' arg\n'
      ' >),\n'
      ' <\n'
      ' data\n'
      ' >[0]\n'
      ' >),\n'
      ' <\n'
      ' data\n'
      ' >,\n'
      ' (let\n'
      ' _var1=<\n'
      ' (a -> data)\n'
      ' >\n'
      ' in (_var2 -> <\n'
      ' _var1[0](<\n'
      ' <\n'
      ' _var2[0][0],\n'
      ' _var2[1][0]\n'
      ' >\n'
      ' >[0])\n'
      ' >)),\n'
      ' (let\n'
      ' _var3=<\n'
      ' (b -> data)\n'
      ' >\n'
      ' in (_var4 -> <\n'
      ' _var3[0](<\n'
      ' <\n'
      ' _var4[0][0],\n'
      ' _var4[1][0]\n'
      ' >\n'
      ' >[0])\n'
      ' >)),\n'
      ' (let\n'
      ' _var5=<\n'
      ' (c -> data)\n'
      ' >\n'
      ' in (_var6 -> <\n'
      ' _var5[0](_var6[0])\n'
      ' >))\n'
      ' >)\n'
      ' in <\n'
      ' federated_apply(<\n'
      ' (arg -> arg[0]),\n'
      ' value\n'
      ' >)\n'
      '>)))')
def _create_before_and_after_broadcast_for_no_broadcast(tree):
  r"""Creates before and after broadcast computations for the given `tree`.

  This function returns the two ASTs:

  Lambda
  |
  Tuple
  |
  []

       Lambda(x)
       |
       Call
      /    \
  Comp      Sel(0)
           /
      Ref(x)

  The first AST is an empty structure that has a type signature satisfying
  the requirements of before broadcast.

  In the second AST, `Comp` is `tree`; `Lambda` has a type signature
  satisfying the requirements of after broadcast; and the argument passed to
  `Comp` is a selection from the parameter of `Lambda` which intentionally
  drops `c2` on the floor.

  This function is intended to be used by
  `get_canonical_form_for_iterative_process` to create before and after
  broadcast computations for the given `tree` when there is no
  `intrinsic_defs.FEDERATED_BROADCAST` in `tree`. As a result, this function
  does not assert that there is no `intrinsic_defs.FEDERATED_BROADCAST` in
  `tree` and it does not assert that `tree` has the expected structure; the
  caller is expected to perform these checks before calling this function.

  Args:
    tree: An instance of `building_blocks.ComputationBuildingBlock`.

  Returns:
    A pair of the form `(before, after)`, where each of `before` and `after`
    is a `tff_framework.ComputationBuildingBlock` that represents a part of
    the result as specified by
    `transformations.force_align_and_split_by_intrinsics`.
  """
  name_generator = building_block_factory.unique_name_generator(tree)

  parameter_name = next(name_generator)
  empty_tuple = building_blocks.Tuple([])
  value = building_block_factory.create_federated_value(
      empty_tuple, placements.SERVER)
  before_broadcast = building_blocks.Lambda(parameter_name,
                                            tree.type_signature.parameter,
                                            value)

  parameter_name = next(name_generator)
  type_signature = computation_types.FederatedType(
      before_broadcast.type_signature.result.member, placements.CLIENTS)
  parameter_type = computation_types.NamedTupleType(
      [tree.type_signature.parameter, type_signature])
  ref = building_blocks.Reference(parameter_name, parameter_type)
  arg = building_blocks.Selection(ref, index=0)
  call = building_blocks.Call(tree, arg)
  after_broadcast = building_blocks.Lambda(ref.name, ref.type_signature, call)

  return before_broadcast, after_broadcast
def _create_before_and_after_aggregate_for_no_federated_secure_sum(tree):
  r"""Creates before and after aggregate computations for the given `tree`.

  This function returns the two ASTs:

  Lambda
  |
  Tuple
  |
  [Comp, Tuple]
         |
         [Tuple, []]
          |
          []

       Lambda(x)
       |
       Call
      /    \
  Comp      Tuple
            |
            [Sel(0),  Sel(0)]
            /        /
       Ref(x)   Sel(1)
               /
          Ref(x)

  In the first AST, the first element returned by `Lambda`, `Comp`, is the
  result of the before aggregate returned by force aligning and splitting
  `tree` by `intrinsic_defs.FEDERATED_AGGREGATE.uri` and the second element
  returned by `Lambda` is an empty structure that represents the argument to
  the secure sum intrinsic. Therefore, the first AST has a type signature
  satisfying the requirements of before aggregate.

  In the second AST, `Comp` is the after aggregate returned by force aligning
  and splitting `tree` by `intrinsic_defs.FEDERATED_AGGREGATE.uri`; `Lambda`
  has a type signature satisfying the requirements of after aggregate; and
  the argument passed to `Comp` is a selection from the parameter of `Lambda`
  which intentionally drops `s4` on the floor.

  This function is intended to be used by
  `get_canonical_form_for_iterative_process` to create before and after
  aggregate computations for the given `tree` when there is no
  `intrinsic_defs.FEDERATED_SECURE_SUM` in `tree`. As a result, this function
  does not assert that there is no `intrinsic_defs.FEDERATED_SECURE_SUM` in
  `tree` and it does not assert that `tree` has the expected structure; the
  caller is expected to perform these checks before calling this function.

  Args:
    tree: An instance of `building_blocks.ComputationBuildingBlock`.

  Returns:
    A pair of the form `(before, after)`, where each of `before` and `after`
    is a `tff_framework.ComputationBuildingBlock` that represents a part of
    the result as specified by
    `transformations.force_align_and_split_by_intrinsics`.
  """
  name_generator = building_block_factory.unique_name_generator(tree)

  before_aggregate, after_aggregate = (
      transformations.force_align_and_split_by_intrinsics(
          tree, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))

  empty_tuple = building_blocks.Tuple([])
  value = building_block_factory.create_federated_value(
      empty_tuple, placements.CLIENTS)
  bitwidth = empty_tuple
  args = building_blocks.Tuple([value, bitwidth])
  result = building_blocks.Tuple([before_aggregate.result, args])
  before_aggregate = building_blocks.Lambda(before_aggregate.parameter_name,
                                            before_aggregate.parameter_type,
                                            result)

  ref_name = next(name_generator)
  s4_type = computation_types.FederatedType([], placements.SERVER)
  ref_type = computation_types.NamedTupleType([
      after_aggregate.parameter_type[0],
      computation_types.NamedTupleType([
          after_aggregate.parameter_type[1],
          s4_type,
      ]),
  ])
  ref = building_blocks.Reference(ref_name, ref_type)
  sel_arg = building_blocks.Selection(ref, index=0)
  sel = building_blocks.Selection(ref, index=1)
  sel_s3 = building_blocks.Selection(sel, index=0)
  arg = building_blocks.Tuple([sel_arg, sel_s3])
  call = building_blocks.Call(after_aggregate, arg)
  after_aggregate = building_blocks.Lambda(ref.name, ref.type_signature, call)

  return before_aggregate, after_aggregate
def test_reduces_lambda_returning_empty_tuple_to_tf(self):
  self.skipTest('Depends on a lower level fix, currently in review.')
  empty_tuple = building_blocks.Tuple([])
  lam = building_blocks.Lambda('x', tf.int32, empty_tuple)
  extracted_tf = transformations.consolidate_and_extract_local_processing(lam)
  self.assertIsInstance(extracted_tf, building_blocks.CompiledComputation)
def _create_empty_function(type_elements):
  # Note: `name_generator` is expected to be bound in the enclosing scope.
  ref_name = next(name_generator)
  ref_type = computation_types.NamedTupleType(type_elements)
  ref = building_blocks.Reference(ref_name, ref_type)
  empty_tuple = building_blocks.Tuple([])
  return building_blocks.Lambda(ref.name, ref.type_signature, empty_tuple)
def get_canonical_form_for_iterative_process(iterative_process):
  """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

  This function transforms computations from the input `iterative_process`
  into an instance of `tff.backends.mapreduce.CanonicalForm`.

  Args:
    iterative_process: An instance of `tff.utils.IterativeProcess`.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process.

  Raises:
    TypeError: If the arguments are of the wrong types.
    transformations.CanonicalFormCompilationError: If the compilation process
      fails.
  """
  py_typecheck.check_type(iterative_process,
                          computation_utils.IterativeProcess)

  initialize_comp = building_blocks.ComputationBuildingBlock.from_proto(
      iterative_process.initialize._computation_proto)  # pylint: disable=protected-access
  next_comp = building_blocks.ComputationBuildingBlock.from_proto(
      iterative_process.next._computation_proto)  # pylint: disable=protected-access

  if not (isinstance(next_comp.type_signature.parameter,
                     computation_types.NamedTupleType) and
          isinstance(next_comp.type_signature.result,
                     computation_types.NamedTupleType)):
    raise TypeError(
        'Any IterativeProcess compatible with CanonicalForm must '
        'have a `next` function which takes and returns instances '
        'of `tff.NamedTupleType`; your next function takes '
        'parameters of type {} and returns results of type {}'.format(
            next_comp.type_signature.parameter,
            next_comp.type_signature.result))

  if len(next_comp.type_signature.result) == 2:
    next_result = next_comp.result
    if isinstance(next_result, building_blocks.Tuple):
      dummy_clients_metrics_appended = building_blocks.Tuple([
          next_result[0],
          next_result[1],
          intrinsics.federated_value([], placements.CLIENTS)._comp  # pylint: disable=protected-access
      ])
    else:
      dummy_clients_metrics_appended = building_blocks.Tuple([
          building_blocks.Selection(next_result, index=0),
          building_blocks.Selection(next_result, index=1),
          intrinsics.federated_value([], placements.CLIENTS)._comp  # pylint: disable=protected-access
      ])
    next_comp = building_blocks.Lambda(next_comp.parameter_name,
                                       next_comp.parameter_type,
                                       dummy_clients_metrics_appended)

  initialize_comp = replace_intrinsics_with_bodies(initialize_comp)
  next_comp = replace_intrinsics_with_bodies(next_comp)

  tree_analysis.check_intrinsics_whitelisted_for_reduction(initialize_comp)
  tree_analysis.check_intrinsics_whitelisted_for_reduction(next_comp)
  tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

  before_broadcast, after_broadcast = (
      transformations.force_align_and_split_by_intrinsic(
          next_comp, intrinsic_defs.FEDERATED_BROADCAST.uri))
  before_aggregate, after_aggregate = (
      transformations.force_align_and_split_by_intrinsic(
          after_broadcast, intrinsic_defs.FEDERATED_AGGREGATE.uri))

  init_info_packed = pack_initialize_comp_type_signature(
      initialize_comp.type_signature)
  next_info_packed = pack_next_comp_type_signature(next_comp.type_signature,
                                                   init_info_packed)
  before_broadcast_info_packed = (
      check_and_pack_before_broadcast_type_signature(
          before_broadcast.type_signature, next_info_packed))
  before_aggregate_info_packed = (
      check_and_pack_before_aggregate_type_signature(
          before_aggregate.type_signature, before_broadcast_info_packed))
  canonical_form_types = check_and_pack_after_aggregate_type_signature(
      after_aggregate.type_signature, before_aggregate_info_packed)

  initialize = transformations.consolidate_and_extract_local_processing(
      initialize_comp)
  if not (isinstance(initialize, building_blocks.CompiledComputation) and
          initialize.type_signature.result ==
          canonical_form_types['initialize_type'].member):
    raise transformations.CanonicalFormCompilationError(
        'Compilation of initialize has failed. Expected to extract a '
        '`building_blocks.CompiledComputation` of type {}, instead we '
        'extracted a {} of type {}.'.format(
            next_comp.type_signature.parameter[0], type(initialize),
            initialize.type_signature.result))

  prepare = extract_prepare(before_broadcast, canonical_form_types)
  work = extract_work(before_aggregate, after_aggregate, canonical_form_types)
  zero_noarg_function, accumulate, merge, report = extract_aggregate_functions(
      before_aggregate, canonical_form_types)
  update = extract_update(after_aggregate, canonical_form_types)

  cf = canonical_form.CanonicalForm(
      computation_wrapper_instances.building_block_to_computation(initialize),
      computation_wrapper_instances.building_block_to_computation(prepare),
      computation_wrapper_instances.building_block_to_computation(work),
      computation_wrapper_instances.building_block_to_computation(
          zero_noarg_function),
      computation_wrapper_instances.building_block_to_computation(accumulate),
      computation_wrapper_instances.building_block_to_computation(merge),
      computation_wrapper_instances.building_block_to_computation(report),
      computation_wrapper_instances.building_block_to_computation(update))
  return cf
def test_returns_true_for_tuples(self):
  data_1 = building_blocks.Data('data', tf.int32)
  tuple_1 = building_blocks.Tuple([data_1, data_1])
  data_2 = building_blocks.Data('data', tf.int32)
  tuple_2 = building_blocks.Tuple([data_2, data_2])
  self.assertTrue(tree_analysis._trees_equal(tuple_1, tuple_2))
def create_tensorflow_representing_block(block):
  """Generates non-duplicated TensorFlow for Block locals binding called graphs.

  Assuming that the argument `block` satisfies the following conditions:

  1. The local variables in `block` are all called graphs, with arbitrary
     arguments.
  2. The result of the Block contains tuples, selections and references, but
     nothing else.

  Then `create_tensorflow_representing_block` will generate a structure, which
  may contain TensorFlow functions, calls to TensorFlow functions, and
  references, but which generates this TensorFlow code without duplicating the
  work done by referencing the block locals.

  Args:
    block: Instance of `building_blocks.Block`, whose local variables are all
      called instances of `building_blocks.CompiledComputation`, and whose
      result contains only instances of `building_blocks.Reference`,
      `building_blocks.Selection` or `building_blocks.Tuple`.

  Returns:
    A transformed version of `block`, which has pushed references to the
    called graphs in the locals of `block` into TensorFlow.

  Raises:
    TypeError: If `block` is not an instance of `building_blocks.Block`.
    ValueError: If the locals of `block` are anything other than called
      graphs, or if the result of `block` contains anything other than
      selections, references and tuples.
  """
  _check_parameters_for_tf_block_generation(block)
  name_generator = building_block_factory.unique_name_generator(block)

  def _construct_reference_representing(comp_to_represent):
    """Helper closing over `name_generator` for name safety."""
    arg_type = comp_to_represent.type_signature
    arg_name = next(name_generator)
    return building_blocks.Reference(arg_name, arg_type)

  top_level_ref = _get_unbound_ref(block)
  named_comp_classes = transformations.group_block_locals_by_namespace(block)

  if top_level_ref:
    first_comps = [x[1] for x in named_comp_classes[0]]
    tup = building_blocks.Tuple([top_level_ref] + first_comps)
    output_comp = construct_tensorflow_calling_lambda_on_concrete_arg(
        top_level_ref, tup, top_level_ref)
    name_to_output_index = {top_level_ref.name: 0}
  else:
    output_comp = building_block_factory.create_compiled_empty_tuple()
    name_to_output_index = {}

  block_local_names = [x[0] for x in block.locals]

  def _update_name_to_output_index(name_class):
    """Helper closing over `name_to_output_index` and `block_local_names`."""
    offset = len(name_to_output_index.keys())
    for idx, comp_name in enumerate(name_class):
      for var_name in block_local_names:
        if var_name == comp_name:
          name_to_output_index[var_name] = idx + offset

  if top_level_ref:
    first_names = [x[0] for x in named_comp_classes[0]]
    _update_name_to_output_index(first_names)
    remaining_comp_classes = named_comp_classes[1:]
  else:
    remaining_comp_classes = named_comp_classes[:]

  for named_comp_class in remaining_comp_classes:
    if named_comp_class:
      comp_class = [x[1] for x in named_comp_class]
      name_class = [x[0] for x in named_comp_class]
      arg_ref = _construct_reference_representing(output_comp)
      output_comp = _construct_tensorflow_representing_single_local_assignment(
          arg_ref, comp_class, output_comp, name_to_output_index)
      _update_name_to_output_index(name_class)

  arg_ref = _construct_reference_representing(output_comp)
  result_replaced = _replace_references_in_comp_with_selections_from_arg(
      block.result, arg_ref, name_to_output_index)
  comp_called = construct_tensorflow_calling_lambda_on_concrete_arg(
      arg_ref, result_replaced, output_comp)

  return comp_called, True
def to_value(
    arg: Any,
    type_spec,
    context_stack: context_stack_base.ContextStack,
) -> ValueImpl:
  """Converts the argument into an instance of `tff.Value`.

  The types of non-`tff.Value` arguments that are currently convertible to
  `tff.Value` include the following:

  * Lists, tuples, anonymous tuples, named tuples, and dictionaries, all of
    which are converted into instances of `tff.Tuple`.
  * Placement literals, converted into instances of `tff.Placement`.
  * Computations.
  * Python constants of type `str`, `int`, `float`, `bool`.
  * Numpy objects inheriting from `np.ndarray` or `np.generic` (the parent of
    numpy scalar types).

  Args:
    arg: Either an instance of `tff.Value`, or an argument convertible to
      `tff.Value`. The argument must not be `None`.
    type_spec: An optional `computation_types.Type` or value convertible to it
      by `computation_types.to_type` which specifies the desired type
      signature of the resulting value. This allows for disambiguating the
      target type (e.g., when two TFF types can be mapped to the same Python
      representations), or `None` if none available, in which case TFF tries
      to determine the type of the TFF value automatically.
    context_stack: The context stack to use.

  Returns:
    An instance of `tff.Value` corresponding to the given `arg`, and of TFF
    type matching the `type_spec` if specified (not `None`).

  Raises:
    TypeError: If `arg` is of an unsupported type, or of a type that does not
      match `type_spec`. An explicit error message is raised if TensorFlow
      constructs are encountered, as TensorFlow code should be sealed away
      from the TFF federated context.
  """
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  if type_spec is not None:
    type_spec = computation_types.to_type(type_spec)
    type_utils.check_well_formed(type_spec)
  if isinstance(arg, ValueImpl):
    result = arg
  elif isinstance(arg, building_blocks.ComputationBuildingBlock):
    result = ValueImpl(arg, context_stack)
  elif isinstance(arg, placement_literals.PlacementLiteral):
    result = ValueImpl(building_blocks.Placement(arg), context_stack)
  elif isinstance(arg, computation_base.Computation):
    result = ValueImpl(
        building_blocks.CompiledComputation(
            computation_impl.ComputationImpl.get_proto(arg)), context_stack)
  elif type_spec is not None and isinstance(type_spec,
                                            computation_types.SequenceType):
    result = _wrap_sequence_as_value(arg, type_spec.element, context_stack)
  elif isinstance(arg, anonymous_tuple.AnonymousTuple):
    result = ValueImpl(
        building_blocks.Tuple([
            (k, ValueImpl.get_comp(to_value(v, None, context_stack)))
            for k, v in anonymous_tuple.iter_elements(arg)
        ]), context_stack)
  elif py_typecheck.is_named_tuple(arg):
    result = to_value(arg._asdict(), None, context_stack)  # pytype: disable=attribute-error
  elif py_typecheck.is_attrs(arg):
    result = to_value(
        attr.asdict(arg, dict_factory=collections.OrderedDict, recurse=False),
        None, context_stack)
  elif isinstance(arg, dict):
    if isinstance(arg, collections.OrderedDict):
      items = arg.items()
    else:
      items = sorted(arg.items())
    value = building_blocks.Tuple([
        (k, ValueImpl.get_comp(to_value(v, None, context_stack)))
        for k, v in items
    ])
    result = ValueImpl(value, context_stack)
  elif isinstance(arg, (tuple, list)):
    result = ValueImpl(
        building_blocks.Tuple([
            ValueImpl.get_comp(to_value(x, None, context_stack)) for x in arg
        ]), context_stack)
  elif isinstance(arg, tensorflow_utils.TENSOR_REPRESENTATION_TYPES):
    result = _wrap_constant_as_value(arg, context_stack)
  elif isinstance(arg, (tf.Tensor, tf.Variable)):
    raise TypeError(
        'TensorFlow construct {} has been encountered in a federated '
        'context. TFF does not support mixing TF and federated orchestration '
        'code. Please wrap any TensorFlow constructs with '
        '`tff.tf_computation`.'.format(arg))
  elif isinstance(arg, function_utils.PolymorphicFunction):
    # TODO(b/129567727): remove this case when this is no longer an error.
    raise TypeError(
        'Polymorphic computations cannot be converted to a TFF value. '
        'Consider explicitly specifying the argument types of a computation '
        'before passing it to a function that requires a TFF value (such as '
        'a TFF intrinsic like federated_map).')
  else:
    raise TypeError(
        'Unable to interpret an argument of type {} as a TFF value.'.format(
            py_typecheck.type_string(type(arg))))
  py_typecheck.check_type(result, ValueImpl)
  if (type_spec is not None and
      not type_utils.is_assignable_from(type_spec, result.type_signature)):
    raise TypeError(
        'The supplied argument maps to TFF type {}, which is incompatible '
        'with the requested type {}.'.format(result.type_signature, type_spec))
  return result
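# Hypothetical usage sketch for `to_value` (not part of the original module).
# `context_stack_impl.context_stack`, the package's default context stack
# singleton, is an assumed import here; any `ContextStack` the caller already
# holds would serve the same purpose. The `_Example*` names are illustrative.
_ExamplePoint = collections.namedtuple('_ExamplePoint', ['x', 'y'])

# Named tuples go through `_asdict()` above and come back as a named Tuple.
_example_value = to_value(
    _ExamplePoint(x=1, y=2), None, context_stack_impl.context_stack)

# Plain dicts take the `dict` branch above, which sorts keys before building
# the `building_blocks.Tuple`.
_example_dict_value = to_value(
    {'b': 2.0, 'a': 1}, None, context_stack_impl.context_stack)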
def test_reduces_lambda_returning_empty_tuple_to_tf(self):
  empty_tuple = building_blocks.Tuple([])
  lam = building_blocks.Lambda('x', tf.int32, empty_tuple)
  extracted_tf = transformations.consolidate_and_extract_local_processing(lam)
  self.assertIsInstance(extracted_tf, building_blocks.CompiledComputation)