def federated_zip(self, value):
  """Implements `federated_zip` as defined in `api/intrinsics.py`.

  Converts `value` into a `value_base.Value` and requires its type signature to
  be a `computation_types.StructType`. It then builds a federated zip over the
  underlying building block and binds that zip via
  `self._bind_comp_as_reference`.

  Args:
    value: Anything accepted by `value_impl.to_value`. After conversion, its
      type signature must be a `computation_types.StructType`.

  Returns:
    A `value_impl.ValueImpl` that wraps the bound zip computation, created with
    `self._context_stack`.

  Raises:
    Whatever `py_typecheck.check_type` raises on a type mismatch (presumably
    `TypeError` -- TODO confirm).
  """
  # TODO(b/113112108): Extend this to accept *args.

  # TODO(b/113112108): We use the iterate/unwrap approach below because
  # our type system is not powerful enough to express the concept of
  # "an operation that takes tuples of T of arbitrary length", and therefore
  # the intrinsic federated_zip must only take a fixed number of arguments,
  # here fixed at 2. There are other potential approaches to getting around
  # this problem (e.g. having the operator act on sequences and thereby
  # sidestepping the issue) which we may want to explore.
  converted = value_impl.to_value(value, None, self._context_stack)
  py_typecheck.check_type(converted, value_base.Value)
  py_typecheck.check_type(converted.type_signature,
                          computation_types.StructType)

  struct_comp = value_impl.ValueImpl.get_comp(converted)
  zip_comp = building_block_factory.create_federated_zip(struct_comp)
  bound_comp = self._bind_comp_as_reference(zip_comp)
  return value_impl.ValueImpl(bound_comp, self._context_stack)
def extract_update(after_aggregate, canonical_form_types):
  """Converts `after_aggregate` to `update`.

  Keeps elements 0 and 1 of `after_aggregate`'s result and zips them. It then
  re-expresses the lambda as a function of the selections `[0, 0, 0]` and `[1]`
  of its parameter, and extracts the local processing as `update`.

  Args:
    after_aggregate: The second result of splitting `after_broadcast` on
      `intrinsic_defs.FEDERATED_AGGREGATE`.
    canonical_form_types: `dict` holding the `canonical_form.CanonicalForm`
      type signatures specified by the `tff.utils.IterativeProcess` we are
      compiling. Only the `'update_type'` entry is read here.

  Returns:
    `update` as specified by `canonical_form.CanonicalForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: If the extracted block's
      type signature differs from `canonical_form_types['update_type']`.
  """
  # See `get_iterative_process_for_canonical_form()` for the meaning of the
  # `s4`/`s5` variable names used below.
  s5_selected = transformations.select_output_from_lambda(
      after_aggregate, [0, 1])
  s5_zipped = building_blocks.Lambda(
      s5_selected.parameter_name,
      s5_selected.parameter_type,
      building_block_factory.create_federated_zip(s5_selected.result),
  )

  # Rewrite the zipped lambda so it takes the chosen pieces of its parameter as
  # a single zipped argument, then keep only the inner function.
  s4_selection_paths = [[0, 0, 0], [1]]
  zipped_call = transformations.zip_selection_as_argument_to_lower_level_lambda(
      s5_zipped, s4_selection_paths)
  s4_to_s5 = zipped_call.result.function

  update = transformations.consolidate_and_extract_local_processing(s4_to_s5)
  expected_type = canonical_form_types['update_type']
  if update.type_signature != expected_type:
    raise transformations.CanonicalFormCompilationError(
        'Extracted a TF block of the wrong type. Expected a function with type '
        '{}, but the type signature of the TF block was {}'.format(
            expected_type, update.type_signature))
  return update
def _extract_work(before_aggregate, grappler_config):
  """Extracts `work` from `before_aggregate`.

  This function is intended to be used by
  `get_canonical_form_for_iterative_process` only. It therefore does not check
  that `before_aggregate` has the expected structure; the caller must perform
  those checks before calling it.

  Args:
    before_aggregate: The first result of splitting `after_broadcast` on
      aggregate intrinsics.
    grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
      Grappler graph optimization. It is forwarded to
      `transformations.consolidate_and_extract_local_processing`.

  Returns:
    `work` as specified by `canonical_form.CanonicalForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: If we extract an AST of the
      wrong type. This is not raised directly here; presumably it propagates
      from the transformation helpers.
  """
  # Parameter selections that make up `c3`.
  c3_paths = [[0, 1], [1]]
  c3_to_before_aggregate = (
      transformations.zip_selection_as_argument_to_lower_level_lambda(
          before_aggregate, c3_paths).result.function)

  # Result selections that make up `c4`. They are still unzipped at this point.
  c4_paths = [[0, 0], [1, 0]]
  c3_to_unzipped_c4 = transformations.select_output_from_lambda(
      c3_to_before_aggregate, c4_paths)

  zipped_c4 = building_block_factory.create_federated_zip(
      c3_to_unzipped_c4.result)
  c3_to_c4 = building_blocks.Lambda(c3_to_unzipped_c4.parameter_name,
                                    c3_to_unzipped_c4.parameter_type,
                                    zipped_c4)
  return transformations.consolidate_and_extract_local_processing(
      c3_to_c4, grappler_config)
def _extract_update(after_aggregate):
  """Extracts `update` from `after_aggregate`.

  This function is intended to be used by
  `get_canonical_form_for_iterative_process` only. It therefore does not check
  that `after_aggregate` has the expected structure; the caller must perform
  those checks before calling it.

  Args:
    after_aggregate: The second result of splitting `after_broadcast` on
      aggregate intrinsics.

  Returns:
    `update` as specified by `canonical_form.CanonicalForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: If we extract an AST of the
      wrong type. This is not raised directly here; presumably it propagates
      from the transformation helpers.
  """
  s7_selected = transformations.select_output_from_lambda(
      after_aggregate, [0, 1])
  s7_zipped = building_blocks.Lambda(
      s7_selected.parameter_name,
      s7_selected.parameter_type,
      building_block_factory.create_federated_zip(s7_selected.result),
  )
  s6_selection_paths = [[0, 0, 0], [1, 0], [1, 1]]
  s6_to_s7 = transformations.zip_selection_as_argument_to_lower_level_lambda(
      s7_zipped, s6_selection_paths).result.function

  # TODO(b/148942011): The transformation
  # `zip_selection_as_argument_to_lower_level_lambda` does not support selecting
  # from nested structures. We therefore accept the nested form `<s1, <s3, s4>>`
  # and flatten it back to `<s1, s3, s4>` before calling `s6_to_s7`.
  name_generator = building_block_factory.unique_name_generator(s6_to_s7)
  packed_name = next(name_generator)
  flat_members = s6_to_s7.parameter_type.member
  packed_type = computation_types.StructType([
      flat_members[0],
      computation_types.StructType([flat_members[1], flat_members[2]]),
  ])
  packed_ref = building_blocks.Reference(packed_name, packed_type)

  # `<s1, <s3, s4>> -> <s1, s3, s4>`. Both inner selections share one
  # `Selection(packed_ref, 1)` node.
  inner_pair = building_blocks.Selection(packed_ref, index=1)
  flattened = building_blocks.Struct([
      building_blocks.Selection(packed_ref, index=0),
      building_blocks.Selection(inner_pair, index=0),
      building_blocks.Selection(inner_pair, index=1),
  ])
  unpack_fn = building_blocks.Lambda(packed_ref.name,
                                     packed_ref.type_signature, flattened)

  # `update = v -> s6_to_s7(federated_map_or_apply(unpack_fn, v))`, where `v`
  # is placed at SERVER.
  server_name = next(name_generator)
  server_type = computation_types.FederatedType(packed_type, placements.SERVER)
  server_ref = building_blocks.Reference(server_name, server_type)
  flat_args = building_block_factory.create_federated_map_or_apply(
      unpack_fn, server_ref)
  update_fn = building_blocks.Lambda(server_ref.name,
                                     server_ref.type_signature,
                                     building_blocks.Call(s6_to_s7, flat_args))
  return transformations.consolidate_and_extract_local_processing(update_fn)
def _extract_update(after_aggregate, grappler_config):
  """Extracts `update` from `after_aggregate`.

  This function is intended to be used by
  `get_map_reduce_form_for_iterative_process` only. It therefore does not check
  that `after_aggregate` has the expected structure; the caller must perform
  those checks before calling it.

  Args:
    after_aggregate: The second result of splitting `after_broadcast` on
      aggregate intrinsics.
    grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
      Grappler graph optimization. It is forwarded to
      `compiler.consolidate_and_extract_local_processing`.

  Returns:
    `update` as specified by `forms.MapReduceForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    compiler.MapReduceFormCompilationError: If we extract an AST of the wrong
      type. This is not raised directly here; presumably it propagates from the
      helpers called below.
  """
  after_aggregate_zipped = building_blocks.Lambda(
      after_aggregate.parameter_name,
      after_aggregate.parameter_type,
      building_block_factory.create_federated_zip(after_aggregate.result),
  )
  # `create_federated_zip` does not produce unique reference names, but
  # `_as_function_of_some_federated_subparameters` requires them.
  after_aggregate_zipped, _ = tree_transformations.uniquify_reference_names(
      after_aggregate_zipped)

  # Paths into the zipped lambda's parameter that `update` is re-expressed as a
  # function of: the server state first, then each aggregation result.
  subparameter_paths = (
      ('original_arg', 'original_arg', 0),
      ('intrinsic_results', 'federated_aggregate_result'),
      ('intrinsic_results', 'federated_secure_sum_bitwidth_result'),
      ('intrinsic_results', 'federated_secure_sum_result'),
      ('intrinsic_results', 'federated_secure_modular_sum_result'),
  )
  update_with_flat_inputs = _as_function_of_some_federated_subparameters(
      after_aggregate_zipped, subparameter_paths)

  # TODO(b/148942011): The transformation
  # `zip_selection_as_argument_to_lower_level_lambda` does not support selecting
  # from nested structures, therefore we need to transform the input from
  # <server_state, <aggregation_results...>> into
  # <server_state, aggregation_results...>
  name_generator = building_block_factory.unique_name_generator(
      update_with_flat_inputs)
  unpack_param_name = next(name_generator)
  flat_member_types = update_with_flat_inputs.parameter_type.member
  nested_param_type = computation_types.StructType([
      flat_member_types[0],
      computation_types.StructType(flat_member_types[1:]),
  ])
  nested_param_ref = building_blocks.Reference(unpack_param_name,
                                               nested_param_type)

  def _select(comp, index):
    return building_blocks.Selection(comp, index=index)

  # unpack = <v, <...>> -> <v, ...>. Each trailing element gets its own fresh
  # `Selection(nested_param_ref, 1)` node.
  flat_elements = [_select(nested_param_ref, 0)]
  flat_elements.extend(
      _select(_select(nested_param_ref, 1), i)
      for i in range(len(flat_member_types) - 1))
  unpack = building_blocks.Lambda(unpack_param_name, nested_param_type,
                                  building_blocks.Struct(flat_elements))

  # update = v -> update_with_flat_inputs(federated_map(unpack, v))
  param_name = next(name_generator)
  param_type = computation_types.at_server(nested_param_type)
  param_ref = building_blocks.Reference(param_name, param_type)
  flat_arg = building_block_factory.create_federated_map_or_apply(
      unpack, param_ref)
  update = building_blocks.Lambda(
      param_name, param_type,
      building_blocks.Call(update_with_flat_inputs, flat_arg))
  return compiler.consolidate_and_extract_local_processing(
      update, grappler_config)
def _merge_args(
    abstract_parameter_type,
    args: List[building_blocks.ComputationBuildingBlock],
    name_generator,
) -> building_blocks.ComputationBuildingBlock:
  """Merges the arguments of multiple function invocations into one.

  The merge strategy depends on the kind of `abstract_parameter_type`:

    * federated: the arguments are zipped with `create_federated_zip`, and
      reference names are then made unique.
    * tensor or abstract: an unnamed `Struct` of the arguments.
    * function: a `Lambda` that invokes each argument function with its share
      of a merged parameter and returns an unnamed `Struct` of the results.
    * struct: each argument is bound to a fresh name in a `Block`, and the
      arguments are merged element by element (recursively).

  Args:
    abstract_parameter_type: The abstract parameter type specification for the
      function being invoked. This is used to determine whether any functional
      parameters accept multiple arguments.
    args: A list where each element contains the arguments to a single call.
    name_generator: A generator used to create unique names.

  Returns:
    A building block to use as the new (merged) argument.

  Raises:
    TypeError: If `abstract_parameter_type` is not federated, tensor, abstract,
      function, or struct.
  """
  if abstract_parameter_type.is_federated():
    zipped = building_block_factory.create_federated_zip(
        building_blocks.Struct(args))
    # `create_federated_zip` introduces repeated names.
    zipped, _ = tree_transformations.uniquify_reference_names(
        zipped, name_generator)
    return zipped

  if abstract_parameter_type.is_tensor() or abstract_parameter_type.is_abstract():
    return building_blocks.Struct([(None, arg) for arg in args])

  if abstract_parameter_type.is_function():
    # Functions are composed differently depending on whether the abstract
    # function (from the intrinsic definition) takes more than one parameter.
    #
    # Single parameter (e.g. the `fn` argument to `federated_map`): the merged
    # parameter is a struct with one slot per function:
    #   `(fn0(arg[0]), fn1(arg[1]), ..., fnN(arg[n]))`
    #
    # Multiple parameters (e.g. the `accumulate` argument to
    # `federated_aggregate`): the merged parameter is grouped by parameter
    # position, and each function picks out its own column:
    #   `(
    #      fn0(arg[0][0], arg[1][0]),
    #      fn1(arg[0][1], arg[1][1]),
    #      ...
    #      fnN(arg[0][n], arg[1][n]),
    #    )`
    param_name = next(name_generator)
    if abstract_parameter_type.parameter.is_struct():
      arity = len(abstract_parameter_type.parameter)
      # Row `i` holds the i-th parameter type of every function, in order.
      param_type = computation_types.StructType([
          [fn.type_signature.parameter[i] for fn in args]
          for i in range(arity)
      ])
      param_ref = building_blocks.Reference(param_name, param_type)
      calls = []
      for n, fn in enumerate(args):
        fn_args = [
            building_blocks.Selection(
                building_blocks.Selection(param_ref, index=i), index=n)
            for i in range(arity)
        ]
        calls.append(
            building_blocks.Call(
                fn, building_blocks.Struct([(None, a) for a in fn_args])))
    else:
      param_type = computation_types.StructType(
          [fn.type_signature.parameter for fn in args])
      param_ref = building_blocks.Reference(param_name, param_type)
      calls = [
          building_blocks.Call(fn, building_blocks.Selection(param_ref, index=n))
          for n, fn in enumerate(args)
      ]
    return building_blocks.Lambda(
        parameter_name=param_name,
        parameter_type=param_type,
        result=building_blocks.Struct([(None, call) for call in calls]))

  if abstract_parameter_type.is_struct():
    # Bind each argument to a fresh name so it can be referenced once per
    # element below. Names are drawn in argument order.
    arg_locals = [(next(name_generator), arg) for arg in args]
    arg_refs = [
        building_blocks.Reference(name, arg.type_signature)
        for name, arg in arg_locals
    ]
    merged_elements = []
    for i in range(len(abstract_parameter_type)):
      element_args = [
          building_blocks.Selection(ref, index=i) for ref in arg_refs
      ]
      merged_elements.append(
          _merge_args(abstract_parameter_type[i], element_args, name_generator))
    return building_blocks.Block(
        arg_locals,
        building_blocks.Struct([(None, elem) for elem in merged_elements]))

  raise TypeError(f'Cannot merge args of type: {abstract_parameter_type}')