def _can_extract_intrinsics_to_top_level_lambda(comp, uri):
  """Tests whether the called intrinsics for `uri` can be hoisted.

  Extraction is possible exactly when the called intrinsics we intend to
  hoist do not close over any intermediate variables — that is, each of
  them references nothing beyond (potentially) the top-level parameter
  that `comp` itself declares.

  Args:
    comp: The `building_blocks.Lambda` to test. The names of lambda
      parameters and block variables in `comp` must be unique.
    uri: A Python `list` of URIs of intrinsics.

  Returns:
    `True` if the intrinsics can be extracted, otherwise `False`.
  """
  py_typecheck.check_type(comp, building_blocks.Lambda)
  py_typecheck.check_type(uri, list)
  for intrinsic_uri in uri:
    py_typecheck.check_type(intrinsic_uri, str)
  tree_analysis.check_has_unique_names(comp)
  called_intrinsics = _get_called_intrinsics(comp, uri)
  for called in called_intrinsics:
    if not tree_analysis.contains_no_unbound_references(
        called, comp.parameter_name):
      return False
  return True
def _as_function_of_some_federated_subparameters(
    bb: building_blocks.Lambda,
    paths,
) -> building_blocks.Lambda:
  """Turns `x -> ...only uses parts of x...` into `parts_of_x -> ...`.

  The returned lambda accepts the federated zip of the subparameters
  selected by `paths`, and its body reads those values via selections from
  the zipped parameter instead of via selections from the original `x`.

  Args:
    bb: The `building_blocks.Lambda` to rewrite; must have unique names.
    paths: A collection of selection paths into `bb.parameter_type`. Each
      path is a sequence of `int` indices or `str` field names, and must
      resolve to a federated type; all selected types must share the same
      placement.

  Returns:
    A `building_blocks.Lambda` whose parameter is the federated zip of the
    selected subparameters.

  Raises:
    _ParameterSelectionError: If a path does not resolve against
      `bb.parameter_type`.
    _NonFederatedSelectionError: If a path resolves to a non-federated type.
    _MismatchedSelectionPlacementError: If the selected federated types do
      not all share one placement.
  """
  tree_analysis.check_has_unique_names(bb)
  bb = _prepare_for_rebinding(bb)
  name_generator = building_block_factory.unique_name_generator(bb)
  type_list = []
  int_paths = []
  for path in paths:
    selected_type = bb.parameter_type
    int_path = []
    for index in path:
      if not selected_type.is_struct():
        raise _ParameterSelectionError(path, bb)
      if isinstance(index, int):
        # Bug fix: this previously read `index > len(selected_type)`, which
        # let `index == len(selected_type)` slip through even though valid
        # struct indices are 0..len-1, deferring the failure to the
        # out-of-range `selected_type[index]` lookup below.
        if index >= len(selected_type):
          raise _ParameterSelectionError(path, bb)
        int_path.append(index)
      else:
        py_typecheck.check_type(index, str)
        if not structure.has_field(selected_type, index):
          raise _ParameterSelectionError(path, bb)
        int_path.append(structure.name_to_index_map(selected_type)[index])
      selected_type = selected_type[index]
    if not selected_type.is_federated():
      raise _NonFederatedSelectionError(
          'Attempted to rebind references to parameter selection path '
          f'{path} from type {bb.parameter_type}, but the value at that path '
          f'was of non-federated type {selected_type}. Selections must all '
          f'be of federated type. Original AST:\n{bb}')
    int_paths.append(tuple(int_path))
    type_list.append(selected_type)
  placement = type_list[0].placement
  if not all(x.placement is placement for x in type_list):
    raise _MismatchedSelectionPlacementError(
        'In order to zip the argument to the lower-level lambda together, all '
        'selected arguments should be at the same placement. Your selections '
        f'have resulted in the list of types:\n{type_list}')
  zip_type = computation_types.FederatedType([x.member for x in type_list],
                                             placement=placement)
  ref_to_zip = building_blocks.Reference(next(name_generator), zip_type)
  path_to_replacement = {}
  for i, path in enumerate(int_paths):
    path_to_replacement[path] = _construct_selection_from_federated_tuple(
        ref_to_zip, i, name_generator)
  new_lambda_body = _replace_selections(bb.result, bb.parameter_name,
                                        path_to_replacement)
  lambda_with_zipped_param = building_blocks.Lambda(
      ref_to_zip.name, ref_to_zip.type_signature, new_lambda_body)
  tree_analysis.check_contains_no_new_unbound_references(
      bb, lambda_with_zipped_param)
  return lambda_with_zipped_param
def _insert_comp_in_top_level_lambda(comp, name, comp_to_insert):
  """Inserts a computation into `comp` with the given `name`.

  The binding is prepended to the block forming the lambda's result; if the
  result is not already a block, a new single-binding block is created.

  Args:
    comp: The `building_blocks.Lambda` to transform. The names of lambda
      parameters and block variables in `comp` must be unique.
    name: The name to use.
    comp_to_insert: The `building_blocks.ComputationBuildingBlock` to insert.

  Returns:
    A new computation with the transformation applied or the original `comp`.
  """
  py_typecheck.check_type(comp, building_blocks.Lambda)
  py_typecheck.check_type(name, str)
  py_typecheck.check_type(comp_to_insert,
                          building_blocks.ComputationBuildingBlock)
  tree_analysis.check_has_unique_names(comp)
  body = comp.result
  if body.is_block():
    bindings = body.locals
    body = body.result
  else:
    bindings = []
  new_block = building_blocks.Block([(name, comp_to_insert)] + bindings, body)
  return building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                new_block)
def test_raises_lambda_rebinding_of_block_variable(self):
  # A lambda binding 'x' nested inside a block that also binds 'x' must
  # trip the uniqueness check.
  ref_to_x = building_blocks.Reference('x', tf.int32)
  identity_lambda = building_blocks.Lambda('x', tf.int32, ref_to_x)
  data_x = building_blocks.Data('x', tf.int32)
  block = building_blocks.Block([('x', data_x)], identity_lambda)
  with self.assertRaises(tree_analysis.NonuniqueNameError):
    tree_analysis.check_has_unique_names(block)
def _split_by_intrinsics_in_top_level_lambda(comp):
  """Splits by the intrinsics in the first block local in the result of `comp`.

  This function splits `comp` into two computations `before` and `after` the
  called intrinsic or tuple of called intrinsics found as the first local in
  the `building_blocks.Block` returned by the top level lambda; and returns a
  Python tuple representing the pair of `before` and `after` computations.

  Args:
    comp: The `building_blocks.Lambda` to split.

  Returns:
    A pair of `building_blocks.ComputationBuildingBlock`s.

  Raises:
    ValueError: If the first local in the `building_blocks.Block` referenced
      by the top level lambda is not a called intrinsic or a
      `building_blocks.Struct` of called intrinsics.
  """
  py_typecheck.check_type(comp, building_blocks.Lambda)
  py_typecheck.check_type(comp.result, building_blocks.Block)
  tree_analysis.check_has_unique_names(comp)
  name_generator = building_block_factory.unique_name_generator(comp)

  local_name, local_value = comp.result.locals[0]
  if building_block_analysis.is_called_intrinsic(local_value):
    before_result = local_value.argument
  elif local_value.is_struct():
    # Every element of the tuple must itself be a called intrinsic; `before`
    # then returns the tuple of their arguments.
    arguments = []
    for element in local_value:
      if not building_block_analysis.is_called_intrinsic(element):
        raise ValueError(
            'Expected all the elements of the `building_blocks.Struct` to be '
            'called intrinsics, but found: \n{}'.format(element))
      arguments.append(element.argument)
    before_result = building_blocks.Struct(arguments)
  else:
    raise ValueError(
        'Expected either a called intrinsic or a `building_blocks.Struct` of '
        'called intrinsics, but found: \n{}'.format(local_value))

  before = building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                  before_result)

  # `after` takes <original_arg, intrinsic_results> and re-binds both the
  # original parameter name and the first local to selections from it.
  after_param_ref = building_blocks.Reference(
      next(name_generator),
      computation_types.StructType(
          (comp.parameter_type, local_value.type_signature)))
  original_arg_sel = building_blocks.Selection(after_param_ref, index=0)
  intrinsic_result_sel = building_blocks.Selection(after_param_ref, index=1)
  bindings = comp.result.locals
  bindings[0] = (local_name, intrinsic_result_sel)
  bindings.insert(0, (comp.parameter_name, original_arg_sel))
  after = building_blocks.Lambda(
      after_param_ref.name, after_param_ref.type_signature,
      building_blocks.Block(bindings, comp.result.result))
  return before, after
def concatenate_function_outputs(first_function, second_function):
  """Constructs a new function concatenating the outputs of its arguments.

  Assumes that `first_function` and `second_function` already have unique
  names, and have declared parameters of the same type. The constructed
  function will bind its parameter to each of the parameters of
  `first_function` and `second_function`, and return the result of executing
  these functions in parallel and concatenating the outputs in a tuple.

  Args:
    first_function: Instance of `building_blocks.Lambda` whose result we wish
      to concatenate with the result of `second_function`.
    second_function: Instance of `building_blocks.Lambda` whose result we
      wish to concatenate with the result of `first_function`.

  Returns:
    A new instance of `building_blocks.Lambda` with unique names representing
    the computation described above.

  Raises:
    TypeError: If the arguments are not instances of
      `building_blocks.Lambda`, or declare parameters of different types.
  """
  py_typecheck.check_type(first_function, building_blocks.Lambda)
  py_typecheck.check_type(second_function, building_blocks.Lambda)
  tree_analysis.check_has_unique_names(first_function)
  tree_analysis.check_has_unique_names(second_function)
  if first_function.parameter_type != second_function.parameter_type:
    # Bug fix: the message promises the *parameter* types, but previously
    # formatted the full function `type_signature`s; report
    # `parameter_type` so the error matches its own wording.
    raise TypeError(
        'Must pass two functions which declare the same parameter '
        'type to `concatenate_function_outputs`; you have passed '
        'one function which declared a parameter of type {}, and '
        'another which declares a parameter of type {}'.format(
            first_function.parameter_type, second_function.parameter_type))

  def _rename_first_function_arg(comp):
    # Rebind references to the first function's parameter so both bodies
    # read from the single shared parameter of the concatenated lambda.
    if comp.is_reference() and comp.name == first_function.parameter_name:
      if comp.type_signature != second_function.parameter_type:
        raise AssertionError('{}, {}'.format(comp.type_signature,
                                             second_function.parameter_type))
      return building_blocks.Reference(second_function.parameter_name,
                                       comp.type_signature), True
    return comp, False

  first_function, _ = transformation_utils.transform_postorder(
      first_function, _rename_first_function_arg)
  concatenated_function = building_blocks.Lambda(
      second_function.parameter_name, second_function.parameter_type,
      building_blocks.Struct([first_function.result, second_function.result]))
  renamed, _ = tree_transformations.uniquify_reference_names(
      concatenated_function)
  return renamed
def _as_function_of_single_subparameter(bb: building_blocks.Lambda,
                                        index: int) -> building_blocks.Lambda:
  """Turns `x -> ...only uses x_i...` into `x_i -> ...only uses x_i`."""
  tree_analysis.check_has_unique_names(bb)
  bb = _prepare_for_rebinding(bb)
  # Introduce a freshly named reference standing for the selected
  # subparameter, then rewrite all `x[index]` selections to it.
  fresh_name = next(building_block_factory.unique_name_generator(bb))
  param_ref = building_blocks.Reference(fresh_name,
                                        bb.type_signature.parameter[index])
  rebound_body = _replace_selections(bb.result, bb.parameter_name,
                                     {(index,): param_ref})
  result_lambda = building_blocks.Lambda(param_ref.name,
                                         param_ref.type_signature,
                                         rebound_body)
  tree_analysis.check_contains_no_new_unbound_references(bb, result_lambda)
  return result_lambda
def test_single_level_block(self):
  ref_to_a = building_blocks.Reference('a', tf.int32)
  a_data = building_blocks.Data('data', tf.int32)
  block = building_blocks.Block(
      (('a', a_data), ('a', ref_to_a), ('a', ref_to_a)), ref_to_a)
  result, modified = tree_transformations.uniquify_reference_names(block)
  # The input AST is untouched; the output re-binds each shadowing 'a'.
  self.assertEqual(block.compact_representation(),
                   '(let a=data,a=a,a=a in a)')
  self.assertEqual(result.compact_representation(),
                   '(let a=data,_var1=a,_var2=_var1 in _var2)')
  tree_analysis.check_has_unique_names(result)
  self.assertTrue(modified)
def test_nested_blocks(self):
  ref_to_a = building_blocks.Reference('a', tf.int32)
  data = building_blocks.Data('data', tf.int32)
  inner_block = building_blocks.Block([('a', data), ('a', ref_to_a)],
                                      ref_to_a)
  outer_block = building_blocks.Block([('a', data), ('a', ref_to_a)],
                                      inner_block)
  result, modified = tree_transformations.uniquify_reference_names(
      outer_block)
  self.assertEqual(outer_block.compact_representation(),
                   '(let a=data,a=a in (let a=data,a=a in a))')
  self.assertEqual(
      result.compact_representation(),
      '(let a=data,_var1=a in (let _var2=data,_var3=_var2 in _var3))')
  tree_analysis.check_has_unique_names(result)
  self.assertTrue(modified)
def test_nested_lambdas(self):
  data = building_blocks.Data('data', tf.int32)
  ref_to_a = building_blocks.Reference('a', data.type_signature)
  inner_call = building_blocks.Call(
      building_blocks.Lambda('a', ref_to_a.type_signature, ref_to_a), data)
  ref_to_b = building_blocks.Reference('b', inner_call.type_signature)
  outer_call = building_blocks.Call(
      building_blocks.Lambda('b', ref_to_b.type_signature, ref_to_b),
      inner_call)
  result, modified = tree_transformations.uniquify_reference_names(
      outer_call)
  # Names were already unique, so nothing should change.
  self.assertEqual(result.compact_representation(),
                   '(b -> b)((a -> a)(data))')
  tree_analysis.check_has_unique_names(result)
  self.assertFalse(modified)
def test_parameters_are_mapped_together(self):
  ref_to_x = building_blocks.Reference('x', tf.int32)
  x_lambda = building_blocks.Lambda('x', tf.int32, ref_to_x)
  ref_to_y = building_blocks.Reference('y', tf.int32)
  y_lambda = building_blocks.Lambda('y', tf.int32, ref_to_y)
  concatenated = transformations.concatenate_function_outputs(
      x_lambda, y_lambda)
  parameter_name = concatenated.parameter_name

  def _raise_on_other_name_reference(comp):
    # Every reference surviving concatenation must point at the single
    # shared parameter of the concatenated lambda.
    if isinstance(
        comp, building_blocks.Reference) and comp.name != parameter_name:
      raise ValueError
    return comp, True

  tree_analysis.check_has_unique_names(concatenated)
  transformation_utils.transform_postorder(concatenated,
                                           _raise_on_other_name_reference)
def test_binding_multiple_args_results_in_unique_names(self):
  clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  server_int = computation_types.FederatedType(tf.int32, placements.SERVER)
  param_type = computation_types.StructType(
      [[clients_int], server_int, [clients_int]])
  # Select both client-placed leaves, i.e. paths (0, 0) and (2, 0).
  first_selection = building_blocks.Selection(
      building_blocks.Selection(
          building_blocks.Reference('x', param_type), index=0),
      index=0)
  second_selection = building_blocks.Selection(
      building_blocks.Selection(
          building_blocks.Reference('x', param_type), index=2),
      index=0)
  lam = building_blocks.Lambda(
      'x', param_type,
      building_blocks.Struct([first_selection, second_selection]))
  new_lam = form_utils._as_function_of_some_federated_subparameters(
      lam, [(0, 0), (2, 0)])
  tree_analysis.check_has_unique_names(new_lam)
def test_blocks_nested_inside_of_locals(self):
  data = building_blocks.Data('data', tf.int32)
  bottom_block = building_blocks.Block([('a', data)], data)
  middle_block = building_blocks.Block([('a', bottom_block)], data)
  top_block = building_blocks.Block([('a', middle_block)], data)
  ref_to_a = building_blocks.Reference('a', tf.int32)
  bottom_block_with_ref = building_blocks.Block([('a', ref_to_a)], data)
  middle_block_with_ref = building_blocks.Block(
      [('a', bottom_block_with_ref)], data)
  top_block_with_ref = building_blocks.Block(
      [('a', middle_block_with_ref)], data)
  multiple_bindings_highest_block = building_blocks.Block(
      [('a', top_block), ('a', top_block_with_ref)], top_block_with_ref)
  result = self.assert_transforms(
      multiple_bindings_highest_block,
      'uniquify_names_blocks_nested_inside_of_locals.expected')
  tree_analysis.check_has_unique_names(result)
def test_binding_multiple_args_results_in_unique_names(self):
  clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  server_int = computation_types.FederatedType(tf.int32, placements.SERVER)
  param_type = computation_types.NamedTupleType(
      [[clients_int], server_int, [clients_int]])
  # Select both client-placed leaves, i.e. paths [0, 0] and [2, 0].
  first_selection = building_blocks.Selection(
      building_blocks.Selection(
          building_blocks.Reference('x', param_type), index=0),
      index=0)
  second_selection = building_blocks.Selection(
      building_blocks.Selection(
          building_blocks.Reference('x', param_type), index=2),
      index=0)
  lam = building_blocks.Lambda(
      'x', param_type,
      building_blocks.Tuple([first_selection, second_selection]))
  deep_zeroth_index_extracted = mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda(
      lam, [[0, 0], [2, 0]])
  tree_analysis.check_has_unique_names(deep_zeroth_index_extracted)
def test_block_lambda_block_lambda(self):
  ref_to_a = building_blocks.Reference('a', tf.int32)
  inner_lambda = building_blocks.Lambda('a', tf.int32, ref_to_a)
  inner_call = building_blocks.Call(inner_lambda, ref_to_a)
  inner_block = building_blocks.Block([('a', ref_to_a), ('a', ref_to_a)],
                                      inner_call)
  outer_lambda = building_blocks.Lambda('a', tf.int32, inner_block)
  outer_call = building_blocks.Call(outer_lambda, ref_to_a)
  data = building_blocks.Data('data', tf.int32)
  outer_block = building_blocks.Block([('a', data), ('a', ref_to_a)],
                                      outer_call)
  result, modified = tree_transformations.uniquify_reference_names(
      outer_block)
  self.assertEqual(
      outer_block.compact_representation(),
      '(let a=data,a=a in (a -> (let a=a,a=a in (a -> a)(a)))(a))')
  self.assertEqual(
      result.compact_representation(),
      '(let a=data,_var1=a in (_var2 -> (let _var3=_var2,_var4=_var3 in (_var5 -> _var5)(_var4)))(_var1))'
  )
  tree_analysis.check_has_unique_names(result)
  self.assertTrue(modified)
def test_ok_on_nested_lambdas_with_different_variable_name(self):
  ref_to_x = building_blocks.Reference('x', tf.int32)
  inner_lambda = building_blocks.Lambda('x', tf.int32, ref_to_x)
  outer_lambda = building_blocks.Lambda('y', tf.int32, inner_lambda)
  # Distinct binding names, so no exception expected.
  tree_analysis.check_has_unique_names(outer_lambda)
def test_raises_on_nested_lambdas_with_same_variable_name(self):
  ref_to_x = building_blocks.Reference('x', tf.int32)
  inner_lambda = building_blocks.Lambda('x', tf.int32, ref_to_x)
  outer_lambda = building_blocks.Lambda('x', tf.int32, inner_lambda)
  # Both lambdas bind 'x', which the uniqueness check must reject.
  with self.assertRaises(tree_analysis.NonuniqueNameError):
    tree_analysis.check_has_unique_names(outer_lambda)
def test_ok_on_multiple_no_arg_lambdas(self):
  data = building_blocks.Data('x', tf.int32)
  no_arg_lambda_1 = building_blocks.Lambda(None, None, data)
  no_arg_lambda_2 = building_blocks.Lambda(None, None, data)
  # No-arg lambdas bind no names, so no collision is possible.
  struct_of_lambdas = building_blocks.Struct(
      [no_arg_lambda_1, no_arg_lambda_2])
  tree_analysis.check_has_unique_names(struct_of_lambdas)
def test_ok_on_single_lambda(self):
  # A lone lambda trivially has unique names.
  ref_to_x = building_blocks.Reference('x', tf.int32)
  identity_lambda = building_blocks.Lambda('x', tf.int32, ref_to_x)
  tree_analysis.check_has_unique_names(identity_lambda)
def test_ok_on_sequential_binding_of_different_variable_in_block(self):
  # Two locals with distinct names in one block must pass the check.
  data_x = building_blocks.Data('x', tf.int32)
  block = building_blocks.Block([('x', data_x), ('y', data_x)], data_x)
  tree_analysis.check_has_unique_names(block)
def _group_by_intrinsics_in_top_level_lambda(comp):
  """Groups the intrinsics in the first block local in the result of `comp`.

  This transformation creates an AST by replacing the tuple of called
  intrinsics found as the first local in the `building_blocks.Block` returned
  by the top level lambda with two new computations. The first computation is
  a tuple of tuples of called intrinsics, representing the original tuple of
  called intrinsics grouped by URI. The second computation is a tuple of
  selections from the first computation, representing the original tuple of
  called intrinsics.

  It is necessary to group intrinsics before it is possible to merge them.

  Args:
    comp: The `building_blocks.Lambda` to transform.

  Returns:
    A `building_blocks.Lambda` that returns a `building_blocks.Block`; the
    first local variable of the returned `building_blocks.Block` will be a
    tuple of tuples of called intrinsics representing the original tuple of
    called intrinsics grouped by URI.

  Raises:
    ValueError: If the first local in the `building_blocks.Block` referenced
      by the top level lambda is not a `building_blocks.Struct` of called
      intrinsics.
  """
  py_typecheck.check_type(comp, building_blocks.Lambda)
  py_typecheck.check_type(comp.result, building_blocks.Block)
  tree_analysis.check_has_unique_names(comp)
  name_generator = building_block_factory.unique_name_generator(comp)
  name, first_local = comp.result.locals[0]
  py_typecheck.check_type(first_local, building_blocks.Struct)
  for element in first_local:
    if not building_block_analysis.is_called_intrinsic(element):
      raise ValueError(
          'Expected all the elements of the `building_blocks.Struct` to be '
          'called intrinsics, but found: \n{}'.format(element))

  # Create collections of data describing how to pack and unpack the
  # intrinsics into groups by their URI.
  #
  # packed_keys is a list of unique URI ordered by occurrence in the original
  # tuple of called intrinsics.
  # packed_groups is a `collections.OrderedDict` where each key is a URI to
  # group by and each value is a list of intrinsics with that URI.
  # packed_indexes is a list of tuples where each tuple contains two indexes:
  # the first index in the tuple is the index of the group that the intrinsic
  # was packed into; the second index in the tuple is the index of the
  # intrinsic in that group that the intrinsic was packed into; the index of
  # the tuple in packed_indexes corresponds to the index of the intrinsic in
  # the list of intrinsics that are being grouped. Therefore, packed_indexes
  # represents an implicit mapping of packed indexes, keyed by unpacked index.
  packed_keys = []
  for called_intrinsic in first_local:
    uri = called_intrinsic.function.uri
    if uri not in packed_keys:
      packed_keys.append(uri)
  # If there are no duplicates, return early.
  if len(packed_keys) == len(first_local):
    return comp, False
  packed_groups = collections.OrderedDict([(x, []) for x in packed_keys])
  packed_indexes = []
  for called_intrinsic in first_local:
    packed_group = packed_groups[called_intrinsic.function.uri]
    packed_group.append(called_intrinsic)
    packed_indexes.append((
        packed_keys.index(called_intrinsic.function.uri),
        len(packed_group) - 1,
    ))
  # Build the grouped tuple: single-member groups stay bare, multi-member
  # groups become a nested Struct.
  packed_elements = []
  for called_intrinsics in packed_groups.values():
    if len(called_intrinsics) > 1:
      element = building_blocks.Struct(called_intrinsics)
    else:
      element = called_intrinsics[0]
    packed_elements.append(element)
  packed_comp = building_blocks.Struct(packed_elements)
  packed_ref_name = next(name_generator)
  packed_ref_type = computation_types.to_type(packed_comp.type_signature)
  packed_ref = building_blocks.Reference(packed_ref_name, packed_ref_type)
  # Reconstruct the original (ungrouped) tuple as selections from the
  # grouped tuple, using the packed_indexes mapping built above.
  unpacked_elements = []
  for indexes in packed_indexes:
    group_index = indexes[0]
    sel = building_blocks.Selection(packed_ref, index=group_index)
    uri = packed_keys[group_index]
    called_intrinsics = packed_groups[uri]
    if len(called_intrinsics) > 1:
      intrinsic_index = indexes[1]
      sel = building_blocks.Selection(sel, index=intrinsic_index)
    unpacked_elements.append(sel)
  unpacked_comp = building_blocks.Struct(unpacked_elements)
  variables = comp.result.locals
  variables[0] = (name, unpacked_comp)
  variables.insert(0, (packed_ref_name, packed_comp))
  block = building_blocks.Block(variables, comp.result.result)
  fn = building_blocks.Lambda(comp.parameter_name, comp.parameter_type, block)
  return fn, True
def _extract_intrinsics_to_top_level_lambda(comp, uri):
  r"""Extracts intrinsics in `comp` for the given `uri`.

  This transformation creates an AST such that all the called intrinsics for
  the given `uri` in body of the `building_blocks.Block` returned by the top
  level lambda have been extracted to the top level lambda and replaced by
  selections from a reference to the constructed variable.

                       Lambda
                       |
                       Block
                      /     \
        [x=Struct, ...]       Comp
           |
           [Call,                  Call,                  Call]
           /    \                 /    \                 /    \
  Intrinsic      Comp    Intrinsic      Comp    Intrinsic      Comp

  The order of the extracted called intrinsics matches the order of `uri`.

  Note: if this function is passed an AST which contains nested called
  intrinsics, it will fail, as it will mutate the subcomputation containing
  the lower-level called intrinsics on the way back up the tree.

  Args:
    comp: The `building_blocks.Lambda` to transform. The names of lambda
      parameters and block variables in `comp` must be unique.
    uri: A Python `list` of URIs of intrinsics.

  Returns:
    A new computation with the transformation applied or the original `comp`.

  Raises:
    ValueError: If all the intrinsics for the given `uri` in `comp` are not
      exclusively bound by `comp`.
  """
  py_typecheck.check_type(comp, building_blocks.Lambda)
  py_typecheck.check_type(uri, list)
  for x in uri:
    py_typecheck.check_type(x, str)
  tree_analysis.check_has_unique_names(comp)
  name_generator = building_block_factory.unique_name_generator(comp)
  intrinsics = _get_called_intrinsics(comp, uri)
  # Extraction is only sound if each called intrinsic references nothing
  # beyond the top-level lambda parameter.
  for intrinsic in intrinsics:
    if not tree_analysis.contains_no_unbound_references(
        intrinsic, comp.parameter_name):
      raise ValueError(
          'Expected a computation which binds all the references in all the '
          'intrinsic with the uri: {}.'.format(uri))
  if len(intrinsics) > 1:
    # Sort the intrinsics by the first occurrence of their URI in `uri`, so
    # the extracted tuple's order matches the order of `uri`.
    order = {}
    for index, element in enumerate(uri):
      if element not in order:
        order[element] = index
    intrinsics = sorted(intrinsics, key=lambda x: order[x.function.uri])
    extracted_comp = building_blocks.Struct(intrinsics)
  else:
    extracted_comp = intrinsics[0]
  ref_name = next(name_generator)
  ref_type = computation_types.to_type(extracted_comp.type_signature)
  ref = building_blocks.Reference(ref_name, ref_type)

  def _should_transform(comp):
    return building_block_analysis.is_called_intrinsic(comp, uri)

  def _transform(comp):
    # Replace each called intrinsic with a selection from (or, in the single
    # intrinsic case, a reference to) the extracted variable.
    if not _should_transform(comp):
      return comp, False
    if len(intrinsics) > 1:
      index = intrinsics.index(comp)
      comp = building_blocks.Selection(ref, index=index)
      return comp, True
    else:
      return ref, True

  comp, _ = transformation_utils.transform_postorder(comp, _transform)
  comp = _insert_comp_in_top_level_lambda(
      comp, name=ref.name, comp_to_insert=extracted_comp)
  return comp, True
def test_ok_on_single_block(self):
  # A single binding cannot collide with anything.
  data_x = building_blocks.Data('x', tf.int32)
  block = building_blocks.Block([('x', data_x)], data_x)
  tree_analysis.check_has_unique_names(block)
def test_raises_on_sequential_binding_of_same_variable_in_block(self):
  # Rebinding 'x' within the same block must be rejected.
  data_x = building_blocks.Data('x', tf.int32)
  block = building_blocks.Block([('x', data_x), ('x', data_x)], data_x)
  with self.assertRaises(tree_analysis.NonuniqueNameError):
    tree_analysis.check_has_unique_names(block)
def remove_duplicate_called_graphs(comp):
  """Deduplicates called graphs for a subset of TFF AST constructs.

  Args:
    comp: Instance of `building_blocks.ComputationBuildingBlock` whose called
      graphs we wish to deduplicate, according to
      `tree_analysis.trees_equal`. For `comp` to be eligible here, it must be
      either a lambda itself whose body contains no lambdas or blocks, or
      another computation containing no lambdas or blocks. This restriction
      is necessary because `remove_duplicate_called_graphs` makes no effort
      to ensure that it is not pulling references out of their defining
      scope, except for the case where `comp` is a lambda itself. This
      function exits early and logs a warning if this assumption is
      violated. Additionally, `comp` must contain only computations which
      can be represented in TensorFlow, IE, satisfy the type restriction in
      `type_analysis.is_tensorflow_compatible_type`.

  Returns:
    Either a called instance of `building_blocks.CompiledComputation` or a
    `building_blocks.CompiledComputation` itself, depending on whether
    `comp` is of non-functional or functional type respectively.
    Additionally, returns a boolean to match the
    `transformation_utils.TransformSpec` pattern.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  tree_analysis.check_has_unique_names(comp)
  name_generator = building_block_factory.unique_name_generator(comp)
  if comp.is_lambda():
    comp_to_check = comp.result
  else:
    comp_to_check = comp
  # Eligibility check: any remaining lambdas or blocks mean references could
  # be pulled out of their defining scope, so bail out with a warning.
  if tree_analysis.contains_types(comp_to_check, (
      building_blocks.Block,
      building_blocks.Lambda,
  )):
    logging.warning(
        'The preprocessors have failed to remove called lambdas '
        'and blocks; falling back to less efficient, but '
        'guaranteed, TensorFlow generation with computation %s.', comp)
    return comp, False
  # Accumulates (name, called_graph) bindings; deduplication happens by
  # scanning this list for a structurally equal call before adding.
  leaf_called_graphs = []

  def _pack_called_graphs_into_block(inner_comp):
    """Packs deduplicated bindings to called graphs in `leaf_called_graphs`."""
    if inner_comp.is_call() and inner_comp.function.is_compiled_computation():
      for (name, x) in leaf_called_graphs:
        if tree_analysis.trees_equal(x, inner_comp):
          # Duplicate found: reuse the existing binding's name.
          return building_blocks.Reference(name,
                                           inner_comp.type_signature), True
      new_name = next(name_generator)
      leaf_called_graphs.append((new_name, inner_comp))
      return building_blocks.Reference(new_name,
                                       inner_comp.type_signature), True
    return inner_comp, False

  if comp.is_lambda():
    # Functional case: dedupe within the lambda body, regenerate TF, then
    # re-wrap as a lambda and parse the whole thing back to TF.
    transformed_result, _ = transformation_utils.transform_postorder(
        comp.result, _pack_called_graphs_into_block)
    packed_into_block = building_blocks.Block(leaf_called_graphs,
                                              transformed_result)
    parsed, _ = create_tensorflow_representing_block(packed_into_block)
    tff_func = building_blocks.Lambda(comp.parameter_name,
                                      comp.parameter_type, parsed)
    tf_parser_callable = tree_to_cc_transformations.TFParser()
    comp, _ = tree_transformations.insert_called_tf_identity_at_leaves(
        tff_func)
    tf_generated, _ = transformation_utils.transform_postorder(
        comp, tf_parser_callable)
  else:
    # Non-functional case: dedupe directly and regenerate TF.
    transformed_result, _ = transformation_utils.transform_postorder(
        comp, _pack_called_graphs_into_block)
    packed_into_block = building_blocks.Block(leaf_called_graphs,
                                              transformed_result)
    tf_generated, _ = create_tensorflow_representing_block(packed_into_block)
  return tf_generated, True
def test_ok_block_binding_of_new_variable(self):
  # A block binding 'x' under a lambda binding 'y' does not collide.
  data_x = building_blocks.Data('x', tf.int32)
  block = building_blocks.Block([('x', data_x)], data_x)
  enclosing_lambda = building_blocks.Lambda('y', tf.int32, block)
  tree_analysis.check_has_unique_names(enclosing_lambda)
def test_raises_on_none(self):
  # The check type-validates its argument before traversal.
  with self.assertRaises(TypeError):
    tree_analysis.check_has_unique_names(None)
def test_ok_lambda_binding_of_new_variable(self):
  # A lambda binding 'y' inside a block binding 'x' does not collide.
  ref_to_y = building_blocks.Reference('y', tf.int32)
  y_lambda = building_blocks.Lambda('y', tf.int32, ref_to_y)
  data_x = building_blocks.Data('x', tf.int32)
  block = building_blocks.Block([('x', data_x)], y_lambda)
  tree_analysis.check_has_unique_names(block)
def force_align_and_split_by_intrinsics(
    comp: building_blocks.Lambda,
    intrinsic_defaults: List[building_blocks.Call],
) -> Tuple[building_blocks.Lambda, building_blocks.Lambda]:
  """Divides `comp` into before-and-after of calls to one or more intrinsics.

  The input computation `comp` must have the following properties:

  1. The computation `comp` is completely self-contained, i.e., there are no
     references to arguments introduced in a scope external to `comp`.
  2. `comp`'s return value must not contain uncalled lambdas.
  3. None of the calls to intrinsics in `intrinsic_defaults` may be within a
     lambda passed to another external function (intrinsic or compiled
     computation).
  4. No argument passed to an intrinsic in `intrinsic_defaults` may be
     dependent on the result of a call to an intrinsic in
     `intrinsic_defaults`.
  5. All intrinsics in `intrinsic_defaults` must have "merge-able" arguments.
     Structs will be merged element-wise, federated values will be zipped, and
     functions will be composed:
       `f = lambda f1_arg, f2_arg: (f1(f1_arg), f2(f2_arg))`
  6. All intrinsics in `intrinsic_defaults` must return a single federated
     value whose member is the merged result of any merged calls, i.e.:
       `f(merged_arg).member = (f1(f1_arg).member, f2(f2_arg).member)`

  Under these conditions, (and assuming `comp` is a computation with non-`None`
  argument), this function will return two `building_blocks.Lambda`s `before`
  and `after` such that `comp` is semantically equivalent to the following
  expression*:

  ```
  (arg -> (let
    x=before(arg),
    y=intrinsic1(x[0]),
    z=intrinsic2(x[1]),
    ...
   in after(<arg, <y,z,...>>)))
  ```

  If `comp` is a no-arg computation, the returned computations will be
  equivalent (in the same sense as above) to:

  ```
  ( -> (let
    x=before(),
    y=intrinsic1(x[0]),
    z=intrinsic2(x[1]),
    ...
   in after(<y,z,...>)))
  ```

  *Note that these expressions may not be entirely equivalent under
  nondeterminism since there is no way in this case to handle computations in
  which `before` creates a random variable that is then used in `after`, since
  the only way for state to pass from `before` to `after` is for it to travel
  through one of the intrinsics.

  In this expression, there is only a single call to `intrinsic` that results
  from consolidating all occurrences of this intrinsic in the original `comp`.
  All logic in `comp` that produced inputs to any of these intrinsic calls is
  now consolidated and jointly encapsulated in `before`, which produces a
  combined argument to all the original calls. All the remaining logic in
  `comp`, including that which consumed the outputs of the intrinsic calls,
  must have been encapsulated into `after`.

  If the original computation `comp` had type `(T -> U)`, then `before` and
  `after` would be `(T -> X)` and `(<T,Y> -> U)`, respectively, where `X` is
  the type of the argument to the single combined intrinsic call above. Note
  that `after` takes the output of the call to the intrinsic as well as the
  original argument to `comp`, as it may be dependent on both.

  Args:
    comp: The instance of `building_blocks.Lambda` that serves as the input to
      this transformation, as described above.
    intrinsic_defaults: A list of intrinsics with which to split the
      computation, provided as a list of `Call`s to insert if no intrinsic with
      a matching URI is found. Intrinsics in this list will be merged, and
      `comp` will be split across them.

  Returns:
    A pair of the form `(before, after)`, where each of `before` and `after`
    is a `building_blocks.ComputationBuildingBlock` instance that represents a
    part of the result as specified above.
  """
  py_typecheck.check_type(comp, building_blocks.Lambda)
  py_typecheck.check_type(intrinsic_defaults, list)
  # Keep a snapshot of the original AST's representation for error messages,
  # since `comp` is rebound to transformed versions below.
  comp_repr = comp.compact_representation()
  # Flatten `comp` to call-dominant form so that we're working with just a
  # linear list of intrinsic calls with no indirection via tupling, selection,
  # blocks, called lambdas, or references.
  comp = to_call_dominant(comp)
  # CDF can potentially return blocks if there are variables not dependent on
  # the top-level parameter. We normalize these away.
  if not comp.is_lambda():
    comp.check_block()
    comp.result.check_lambda()
    if comp.result.result.is_block():
      additional_locals = comp.result.result.locals
      result = comp.result.result.result
    else:
      additional_locals = []
      result = comp.result.result
    # Note: without uniqueness, a local in `comp.locals` could potentially
    # shadow `comp.result.parameter_name`. However, `to_call_dominant`
    # above ensures that names are unique, as it ends in a call to
    # `uniquify_reference_names`.
    comp = building_blocks.Lambda(comp.result.parameter_name,
                                  comp.result.parameter_type,
                                  building_blocks.Block(
                                      comp.locals + additional_locals, result))
  comp.check_lambda()
  # Simple computations with no intrinsic calls won't have a block.
  # Normalize these as well.
  if not comp.result.is_block():
    comp = building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                  building_blocks.Block([], comp.result))
  comp.result.check_block()

  name_generator = building_block_factory.unique_name_generator(comp)

  # Classify the block's locals by whether they depend on a split intrinsic,
  # and merge all calls to each intrinsic URI into a single combined call.
  intrinsic_uris = set(call.function.uri for call in intrinsic_defaults)
  deps = _compute_intrinsic_dependencies(intrinsic_uris, comp.parameter_name,
                                         comp.result.locals, comp_repr)
  merged_intrinsics = _compute_merged_intrinsics(intrinsic_defaults,
                                                 deps.uri_to_locals,
                                                 name_generator)

  # Note: the outputs are labeled as `{uri}_param` for convenience, e.g.
  # `federated_secure_sum_param: ...`.
  before = building_blocks.Lambda(
      comp.parameter_name, comp.parameter_type,
      building_blocks.Block(
          deps.locals_not_dependent_on_intrinsics,
          building_blocks.Struct([(f'{merged.uri}_param', merged.args)
                                  for merged in merged_intrinsics])))

  after_param_name = next(name_generator)
  if comp.parameter_type is not None:
    # TODO(b/147499373): If None-arguments were uniformly represented as empty
    # tuples, we would be able to avoid this (and related) ugly casing.
    after_param_type = computation_types.StructType([
        ('original_arg', comp.parameter_type),
        ('intrinsic_results',
         computation_types.StructType([(f'{merged.uri}_result',
                                        merged.return_type)
                                       for merged in merged_intrinsics])),
    ])
  else:
    after_param_type = computation_types.StructType([
        ('intrinsic_results',
         computation_types.StructType([(f'{merged.uri}_result',
                                        merged.return_type)
                                       for merged in merged_intrinsics])),
    ])
  after_param_ref = building_blocks.Reference(after_param_name,
                                              after_param_type)
  if comp.parameter_type is not None:
    # Re-bind the original parameter name inside `after` to the
    # `original_arg` element of `after`'s struct parameter.
    original_arg_bindings = [
        (comp.parameter_name,
         building_blocks.Selection(after_param_ref, name='original_arg'))
    ]
  else:
    original_arg_bindings = []

  # For each merged intrinsic whose combined result must be re-split, bind
  # each original local name to a federated selection of the merged result.
  unzip_bindings = []
  for merged in merged_intrinsics:
    if merged.unpack_to_locals:
      intrinsic_result = building_blocks.Selection(
          building_blocks.Selection(after_param_ref, name='intrinsic_results'),
          name=f'{merged.uri}_result')
      select_param_type = intrinsic_result.type_signature.member
      for i, binding_name in enumerate(merged.unpack_to_locals):
        select_param_name = next(name_generator)
        select_param_ref = building_blocks.Reference(select_param_name,
                                                     select_param_type)
        selected = building_block_factory.create_federated_map_or_apply(
            building_blocks.Lambda(
                select_param_name, select_param_type,
                building_blocks.Selection(select_param_ref, index=i)),
            intrinsic_result)
        unzip_bindings.append((binding_name, selected))
  after = building_blocks.Lambda(
      after_param_name, after_param_type,
      building_blocks.Block(
          original_arg_bindings +
          # Note that we must duplicate `locals_not_dependent_on_intrinsics`
          # across both the `before` and `after` computations since both can
          # rely on them, and there's no way to plumb results from `before`
          # through to `after` except via one of the intrinsics being split
          # upon. In MapReduceForm, this limitation is caused by the fact that
          # `prepare` has no output which serves as an input to `report`.
          deps.locals_not_dependent_on_intrinsics + unzip_bindings +
          deps.locals_dependent_on_intrinsics, comp.result.result))
  try:
    tree_analysis.check_has_unique_names(before)
    tree_analysis.check_has_unique_names(after)
  except tree_analysis.NonuniqueNameError as e:
    raise ValueError(f'nonunique names in result of splitting\n{comp}') from e
  return before, after