def test_reduces_federated_identity_to_member_identity(self):
  """Identity on a CLIENTS-placed int reduces to identity on the member."""
  clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  identity_fn = building_blocks.Lambda(
      'x', clients_int, building_blocks.Reference('x', clients_int))

  compiled = mapreduce_transformations.consolidate_and_extract_local_processing(
      identity_fn)

  self.assertIsInstance(compiled, building_blocks.CompiledComputation)
  member = clients_int.member
  expected_type = computation_types.FunctionType(member, member)
  self.assertEqual(compiled.type_signature, expected_type)
def test_already_reduced_case(self):
  """The example's `initialize` extracts to a TensorFlow-backed computation."""
  sensor_example = test_utils.get_temperature_sensor_example()
  process = canonical_form_utils.get_iterative_process_for_canonical_form(
      sensor_example)
  building_block = test_utils.computation_to_building_block(process.initialize)

  extracted = (
      mapreduce_transformations.consolidate_and_extract_local_processing(
          building_block))

  self.assertIsInstance(extracted, building_blocks.CompiledComputation)
  proto = extracted.proto
  self.assertIsInstance(proto, computation_pb2.Computation)
  self.assertEqual(proto.WhichOneof('computation'), 'tensorflow')
def test_already_reduced_case(self):
  """The example's `initialize` extracts to a TensorFlow-backed computation."""
  sensor_example = mapreduce_test_utils.get_temperature_sensor_example()
  process = form_utils.get_iterative_process_for_map_reduce_form(sensor_example)
  building_block = process.initialize.to_building_block()

  extracted = transformations.consolidate_and_extract_local_processing(
      building_block, DEFAULT_GRAPPLER_CONFIG)

  self.assertIsInstance(extracted, building_blocks.CompiledComputation)
  proto = extracted.proto
  self.assertIsInstance(proto, computation_pb2.Computation)
  self.assertEqual(proto.WhichOneof('computation'), 'tensorflow')
def _extract_work(before_aggregate, after_aggregate):
  """Builds `work` out of `before_aggregate` and `after_aggregate`.

  Only `get_canonical_form_for_iterative_process` is meant to call this
  helper. Consequently no structural validation of `before_aggregate` or
  `after_aggregate` happens here; the caller must have checked both already.

  Args:
    before_aggregate: The first result of splitting `after_broadcast` on
      aggregate intrinsics.
    after_aggregate: The second result of splitting `after_broadcast` on
      aggregate intrinsics.

  Returns:
    `work` as specified by `canonical_form.CanonicalForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: Documented as raised when
      an AST of the wrong type is extracted. Nothing in this function raises
      it directly, so it would come from the transformations called here.
  """

  def _restrict_parameter(comp, parameter_paths):
    # `.result.function` of `zip_selection_as_argument_to_lower_level_lambda`.
    return transformations.zip_selection_as_argument_to_lower_level_lambda(
        comp, parameter_paths).result.function

  # `c3` parameter elements -> outputs [[0, 0], [1, 0]] of `before_aggregate`
  # (the `c6` values).
  c3_to_before_aggregate = _restrict_parameter(before_aggregate,
                                               [[0, 1], [1]])
  c3_to_c6 = transformations.select_output_from_lambda(
      c3_to_before_aggregate, [[0, 0], [1, 0]])

  # Output 2 of `after_aggregate` (the `c8` value), re-expressed in terms of
  # the `c3` elements of its parameter.
  after_aggregate_to_c8 = transformations.select_output_from_lambda(
      after_aggregate, 2)
  c3_to_c8 = _restrict_parameter(after_aggregate_to_c8, [[0, 0, 1], [0, 1]])

  # Concatenate both outputs, then pass them through `create_federated_zip`.
  c3_to_unzipped_c4 = transformations.concatenate_function_outputs(
      c3_to_c6, c3_to_c8)
  zipped_result = building_block_factory.create_federated_zip(
      c3_to_unzipped_c4.result)
  c3_to_c4 = building_blocks.Lambda(c3_to_unzipped_c4.parameter_name,
                                    c3_to_unzipped_c4.parameter_type,
                                    zipped_result)
  return transformations.consolidate_and_extract_local_processing(c3_to_c4)
def _extract_aggregate_functions(before_aggregate):
  """Compiles the aggregation functions held in `before_aggregate`'s output.

  Args:
    before_aggregate: The first result of splitting `after_broadcast` on
      `intrinsic_defs.FEDERATED_AGGREGATE`.

  Returns:
    A tuple `(zero, accumulate, merge, report)` as specified by
    `canonical_form.CanonicalForm`. Each element is an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: If we extract an AST of the
      wrong type.
  """
  # See `get_iterative_process_for_canonical_form()` above for the meaning of
  # the variable names used in this module.
  # Outputs 1, 2, 3 and 4 of `before_aggregate` become zero, accumulate,
  # merge and report, respectively. All outputs are selected first, and only
  # then compiled, one by one.
  selected_outputs = [
      transformations.select_output_from_lambda(before_aggregate,
                                                position).result
      for position in (1, 2, 3, 4)
  ]
  zero, accumulate, merge, report = [
      transformations.consolidate_and_extract_local_processing(output)
      for output in selected_outputs
  ]
  return zero, accumulate, merge, report
def test_reduces_federated_apply_to_equivalent_function(self):
  """A map-or-apply of identity extracts to a TF function matching identity."""
  identity = building_blocks.Lambda(
      'x', tf.int32, building_blocks.Reference('x', tf.int32))
  clients_arg = building_blocks.Reference(
      'arg', computation_types.FederatedType(tf.int32, placements.CLIENTS))
  federated_call = building_block_factory.create_federated_map_or_apply(
      identity, clients_arg)

  extracted = (
      mapreduce_transformations.consolidate_and_extract_local_processing(
          federated_call))
  self.assertIsInstance(extracted, building_blocks.CompiledComputation)

  to_computation = computation_wrapper_instances.building_block_to_computation
  compiled_fn = to_computation(extracted)
  reference_fn = to_computation(identity)
  # Compare both on the inputs 0..9.
  for value in range(10):
    self.assertEqual(compiled_fn(value), reference_fn(value))
def _extract_client_processing(after_broadcast, grappler_config):
  """Extracts `client_processing` from `after_broadcast`.

  Args:
    after_broadcast: The second result of splitting on broadcast.
    grappler_config: Grappler configuration forwarded to
      `transformations.consolidate_and_extract_local_processing`.

  Returns:
    The result of `consolidate_and_extract_local_processing` applied to
    `after_broadcast` restricted to the selected subparameters.
  """
  # Subparameters are taken in `(params, data)` order, matching the input
  # computation. This intentionally differs from `work`, which takes
  # `(data, params)`.
  selected_subparameters = [
      (1,),  # context from the server
      (0, 1),  # client data
  ]
  client_processing = _as_function_of_some_federated_subparameters(
      after_broadcast, selected_subparameters)
  return transformations.consolidate_and_extract_local_processing(
      client_processing, grappler_config)
def _extract_work(before_aggregate, after_aggregate):
  """Combines `before_aggregate` and `after_aggregate` into `work`.

  Args:
    before_aggregate: The first result of splitting `after_broadcast` on
      `intrinsic_defs.FEDERATED_AGGREGATE`.
    after_aggregate: The second result of splitting `after_broadcast` on
      `intrinsic_defs.FEDERATED_AGGREGATE`.

  Returns:
    `work` as specified by `canonical_form.CanonicalForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: If we extract an AST of the
      wrong type.
  """
  # Names follow the `c*` notation of
  # `get_iterative_process_for_canonical_form()`.

  # `c3` parameter elements -> output 0 of `before_aggregate` (`c5`).
  c3_to_before_aggregate = (
      transformations.zip_selection_as_argument_to_lower_level_lambda(
          before_aggregate, [[0, 1], [1]]).result.function)
  c3_to_c5 = transformations.select_output_from_lambda(
      c3_to_before_aggregate, 0)

  # Output 2 of `after_aggregate` (`c6`), re-expressed in terms of `c3`.
  after_aggregate_to_c6 = transformations.select_output_from_lambda(
      after_aggregate, 2)
  c3_to_c6 = (
      transformations.zip_selection_as_argument_to_lower_level_lambda(
          after_aggregate_to_c6, [[0, 0, 1], [0, 1]]).result.function)

  # Concatenate the two outputs and zip them into `c4`.
  c3_to_unzipped_c4 = transformations.concatenate_function_outputs(
      c3_to_c5, c3_to_c6)
  c3_to_c4 = building_blocks.Lambda(
      c3_to_unzipped_c4.parameter_name, c3_to_unzipped_c4.parameter_type,
      building_block_factory.create_federated_zip(c3_to_unzipped_c4.result))
  return transformations.consolidate_and_extract_local_processing(c3_to_c4)
def extract_update(after_aggregate, canonical_form_types):
  """Converts `after_aggregate` to `update` and validates the result.

  Args:
    after_aggregate: The second result of splitting `after_broadcast` on
      `intrinsic_defs.FEDERATED_AGGREGATE`.
    canonical_form_types: `dict` holding the `canonical_form.CanonicalForm`
      type signatures specified by the `tff.utils.IterativeProcess` we are
      compiling. Its `'update_type'` entry is the expected type of `update`.

  Returns:
    `update` as specified by `canonical_form.CanonicalForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: If the extracted block is
      not a `building_blocks.CompiledComputation`, or if its type signature
      differs from `canonical_form_types['update_type']`.
  """
  # Names follow the `s*` notation of
  # `get_iterative_process_for_canonical_form()`.

  # Outputs 0 and 1 of `after_aggregate` (`s5`), passed through
  # `create_federated_zip`.
  s5_extracted = transformations.select_output_from_lambda(
      after_aggregate, [0, 1])
  s5_zipped = building_blocks.Lambda(
      s5_extracted.parameter_name, s5_extracted.parameter_type,
      building_block_factory.create_federated_zip(s5_extracted.result))

  # Re-express the zipped output in terms of the `s4` parameter elements.
  s4_to_s5 = transformations.zip_selection_as_argument_to_lower_level_lambda(
      s5_zipped, [[0, 0, 0], [1]]).result.function
  update = transformations.consolidate_and_extract_local_processing(s4_to_s5)

  if not isinstance(update, building_blocks.CompiledComputation):
    raise transformations.CanonicalFormCompilationError(
        'Failed to extract a `building_blocks.CompiledComputation` from '
        'update, instead received a {} (of type {}).'.format(
            type(update), update.type_signature))

  expected_type = canonical_form_types['update_type']
  if update.type_signature != expected_type:
    raise transformations.CanonicalFormCompilationError(
        'Extracted a TF block of the wrong type. Expected a function with type '
        '{}, but the type signature of the TF block was {}'.format(
            expected_type, update.type_signature))
  return update
def _extract_prepare(before_broadcast):
  """Builds `prepare` from `before_broadcast`.

  Args:
    before_broadcast: The first result of splitting `next_comp` on
      `intrinsic_defs.FEDERATED_BROADCAST`.

  Returns:
    `prepare` as specified by `canonical_form.CanonicalForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: If we extract an AST of the
      wrong type.
  """
  # Parameter element 0 of `before_broadcast` is `s1`. Binding it yields the
  # `s1 -> s2` computation.
  binding = (
      transformations.bind_single_selection_as_argument_to_lower_level_lambda(
          before_broadcast, 0))
  s1_to_s2 = binding.result.function
  return transformations.consolidate_and_extract_local_processing(s1_to_s2)
def extract_prepare(before_broadcast, canonical_form_types):
  """Converts `before_broadcast` into `prepare` and validates the result.

  Args:
    before_broadcast: The first result of splitting `next_comp` on
      `tff_framework.FEDERATED_BROADCAST`.
    canonical_form_types: `dict` holding the `canonical_form.CanonicalForm`
      type signatures specified by the `tff.utils.IterativeProcess` we are
      compiling. Its `'prepare_type'` entry is the expected type of `prepare`.

  Returns:
    `prepare` as specified by `canonical_form.CanonicalForm`, an instance of
    `tff_framework.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: If the extracted block is
      not a `tff_framework.CompiledComputation`, or if its type signature
      differs from `canonical_form_types['prepare_type']`.
  """
  # Parameter element 0 of `before_broadcast` is `s1` (notation of
  # `get_iterative_process_for_canonical_form()`). Binding it yields the
  # `s1 -> s2` computation.
  binding = (
      transformations.bind_single_selection_as_argument_to_lower_level_lambda(
          before_broadcast, 0))
  prepare = transformations.consolidate_and_extract_local_processing(
      binding.result.function)

  if not isinstance(prepare, tff_framework.CompiledComputation):
    raise transformations.CanonicalFormCompilationError(
        'Failed to extract a `tff_framework.CompiledComputation` from '
        'prepare, instead received a {} (of type {}).'.format(
            type(prepare), prepare.type_signature))

  expected_type = canonical_form_types['prepare_type']
  if prepare.type_signature != expected_type:
    raise transformations.CanonicalFormCompilationError(
        'Extracted a TF block of the wrong type. Expected a function with type '
        '{}, but the type signature of the TF block was {}'.format(
            expected_type, prepare.type_signature))
  return prepare
def _extract_work(before_aggregate, grappler_config):
  """Builds `work` from `before_aggregate`.

  Only `get_map_reduce_form_for_iterative_process` is meant to call this
  helper. It therefore does not check that `before_aggregate` has the
  expected structure; the caller must do that beforehand.

  Args:
    before_aggregate: The first result of splitting `after_broadcast` on
      aggregate intrinsics.
    grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
      Grappler graph optimization.

  Returns:
    `work` as specified by `forms.MapReduceForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    transformations.MapReduceFormCompilationError: If we extract an AST of the
      wrong type.
  """
  # Restrict `before_aggregate` to the `c3` subparameters.
  c3_to_before_aggregate = _as_function_of_some_federated_subparameters(
      before_aggregate, [(0, 1), (1,)])

  # Keep outputs [[0, 0], [1, 0]] (`c4`), then zip them together.
  c3_to_unzipped_c4 = transformations.select_output_from_lambda(
      c3_to_before_aggregate, [[0, 0], [1, 0]])
  zipped_c4 = building_block_factory.create_federated_zip(
      c3_to_unzipped_c4.result)
  c3_to_c4 = building_blocks.Lambda(c3_to_unzipped_c4.parameter_name,
                                    c3_to_unzipped_c4.parameter_type,
                                    zipped_c4)
  return transformations.consolidate_and_extract_local_processing(
      c3_to_c4, grappler_config)
def get_canonical_form_for_iterative_process(iterative_process):
  """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

  This function transforms computations from the input `iterative_process`
  into an instance of `tff.backends.mapreduce.CanonicalForm`.

  If the result of `next` has exactly two elements, a third element (an empty
  CLIENTS-placed `federated_value`) is appended to it before compilation.

  Args:
    iterative_process: An instance of `tff.utils.IterativeProcess`.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process.

  Raises:
    TypeError: If the arguments are of the wrong types, or if `next` does not
      both take and return a `tff.NamedTupleType`.
    transformations.CanonicalFormCompilationError: If the compilation
      process fails.
  """
  py_typecheck.check_type(iterative_process, computation_utils.IterativeProcess)

  initialize_comp = tff_framework.ComputationBuildingBlock.from_proto(
      iterative_process.initialize._computation_proto)  # pylint: disable=protected-access

  next_comp = tff_framework.ComputationBuildingBlock.from_proto(
      iterative_process.next._computation_proto)  # pylint: disable=protected-access

  if not (isinstance(next_comp.type_signature.parameter, tff.NamedTupleType) and
          isinstance(next_comp.type_signature.result, tff.NamedTupleType)):
    raise TypeError(
        'Any IterativeProcess compatible with CanonicalForm must '
        'have a `next` function which takes and returns instances '
        'of `tff.NamedTupleType`; your next function takes '
        'parameters of type {} and returns results of type {}'.format(
            next_comp.type_signature.parameter,
            next_comp.type_signature.result))

  if len(next_comp.type_signature.result) == 2:
    # Pad the two-element result with a dummy, empty CLIENTS-placed value.
    # NOTE(review): presumably a placeholder for client metrics — confirm.
    next_result = next_comp.result
    dummy_clients_metrics_appended = tff_framework.Tuple([
        next_result[0],
        next_result[1],
        tff.federated_value([], tff.CLIENTS)._comp  # pylint: disable=protected-access
    ])
    next_comp = tff_framework.Lambda(next_comp.parameter_name,
                                     next_comp.parameter_type,
                                     dummy_clients_metrics_appended)

  initialize_comp = tff_framework.replace_intrinsics_with_bodies(
      initialize_comp)
  next_comp = tff_framework.replace_intrinsics_with_bodies(next_comp)

  tff_framework.check_intrinsics_whitelisted_for_reduction(initialize_comp)
  tff_framework.check_intrinsics_whitelisted_for_reduction(next_comp)
  tff_framework.check_broadcast_not_dependent_on_aggregate(next_comp)

  before_broadcast, after_broadcast = (
      transformations.force_align_and_split_by_intrinsic(
          next_comp, tff_framework.FEDERATED_BROADCAST.uri))

  before_aggregate, after_aggregate = (
      transformations.force_align_and_split_by_intrinsic(
          after_broadcast, tff_framework.FEDERATED_AGGREGATE.uri))

  # Derive and cross-check the expected type of every canonical-form piece.
  init_info_packed = pack_initialize_comp_type_signature(
      initialize_comp.type_signature)

  next_info_packed = pack_next_comp_type_signature(next_comp.type_signature,
                                                   init_info_packed)

  before_broadcast_info_packed = (
      check_and_pack_before_broadcast_type_signature(
          before_broadcast.type_signature, next_info_packed))

  before_aggregate_info_packed = (
      check_and_pack_before_aggregate_type_signature(
          before_aggregate.type_signature, before_broadcast_info_packed))

  canonical_form_types = check_and_pack_after_aggregate_type_signature(
      after_aggregate.type_signature, before_aggregate_info_packed)

  initialize = transformations.consolidate_and_extract_local_processing(
      initialize_comp)

  expected_initialize_result = canonical_form_types['initialize_type'].member
  if not (isinstance(initialize, tff_framework.CompiledComputation) and
          initialize.type_signature.result == expected_initialize_result):
    # Report the type that the check above actually compares against. If the
    # extracted block is not function-typed, fall back to its whole type
    # signature so that an AttributeError does not mask this error.
    extracted_type = getattr(initialize.type_signature, 'result',
                             initialize.type_signature)
    raise transformations.CanonicalFormCompilationError(
        'Compilation of initialize has failed. Expected to extract a '
        '`tff_framework.CompiledComputation` of type {}, instead we extracted '
        'a {} of type {}.'.format(expected_initialize_result, type(initialize),
                                  extracted_type))

  prepare = extract_prepare(before_broadcast, canonical_form_types)

  work = extract_work(before_aggregate, after_aggregate, canonical_form_types)

  zero_noarg_function, accumulate, merge, report = extract_aggregate_functions(
      before_aggregate, canonical_form_types)

  update = extract_update(after_aggregate, canonical_form_types)

  return canonical_form.CanonicalForm(
      tff_framework.building_block_to_computation(initialize),
      tff_framework.building_block_to_computation(prepare),
      tff_framework.building_block_to_computation(work),
      tff_framework.building_block_to_computation(zero_noarg_function),
      tff_framework.building_block_to_computation(accumulate),
      tff_framework.building_block_to_computation(merge),
      tff_framework.building_block_to_computation(report),
      tff_framework.building_block_to_computation(update))
def test_raises_reference_to_functional_type(self):
  """A bare reference of function type is rejected with a `ValueError`."""
  int_to_int = computation_types.FunctionType(tf.int32, tf.int32)
  function_ref = building_blocks.Reference('x', int_to_int)
  with self.assertRaisesRegex(ValueError, 'of functional type passed'):
    mapreduce_transformations.consolidate_and_extract_local_processing(
        function_ref)
def test_raises_on_none(self):
  """Passing `None` instead of a building block raises a `TypeError`."""
  self.assertRaises(
      TypeError,
      mapreduce_transformations.consolidate_and_extract_local_processing,
      None)
def get_canonical_form_for_iterative_process(ip):
  """Compiles an iterative process into `tff.backends.mapreduce.CanonicalForm`.

  The `initialize` and `next` computations of `ip` are turned into building
  blocks. `next` is split around its broadcast and aggregation intrinsics.
  Each resulting piece is reduced to local processing and checked against its
  expected type with `_check_type_equal` before the pieces are assembled.

  Args:
    ip: An instance of `tff.templates.IterativeProcess`.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If `next` calls neither `federated_aggregate` nor
      `federated_secure_sum`.
    transformations.CanonicalFormCompilationError: If the compilation
      process fails.
  """
  py_typecheck.check_type(ip, iterative_process.IterativeProcess)

  # pylint: disable=protected-access
  initialize_comp = building_blocks.ComputationBuildingBlock.from_proto(
      ip.initialize._computation_proto)
  next_comp = building_blocks.ComputationBuildingBlock.from_proto(
      ip.next._computation_proto)
  # pylint: enable=protected-access
  _check_iterative_process_compatible_with_canonical_form(
      initialize_comp, next_comp)

  initialize_comp = _replace_intrinsics_with_bodies(initialize_comp)
  next_comp = _replace_intrinsics_with_bodies(next_comp)
  tree_analysis.check_contains_only_reducible_intrinsics(initialize_comp)
  tree_analysis.check_contains_only_reducible_intrinsics(next_comp)
  tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

  # Split `next` around its broadcast, or synthesize the split if it has none.
  broadcast_uri = intrinsic_defs.FEDERATED_BROADCAST.uri
  if not tree_analysis.contains_called_intrinsic(next_comp, broadcast_uri):
    before_broadcast, after_broadcast = (
        _create_before_and_after_broadcast_for_no_broadcast(next_comp))
  else:
    before_broadcast, after_broadcast = (
        transformations.force_align_and_split_by_intrinsics(
            next_comp, [broadcast_uri]))

  # Split the post-broadcast part around whichever aggregations are present.
  aggregate_uri = intrinsic_defs.FEDERATED_AGGREGATE.uri
  secure_sum_uri = intrinsic_defs.FEDERATED_SECURE_SUM.uri
  has_aggregate = tree_analysis.contains_called_intrinsic(
      next_comp, aggregate_uri)
  has_secure_sum = tree_analysis.contains_called_intrinsic(
      next_comp, secure_sum_uri)
  if not has_aggregate and not has_secure_sum:
    raise ValueError(
        'Expected an `tff.templates.IterativeProcess` containing at least one '
        '`federated_aggregate` or `federated_secure_sum`, found none.')

  if has_aggregate and has_secure_sum:
    before_aggregate, after_aggregate = (
        transformations.force_align_and_split_by_intrinsics(
            after_broadcast, [aggregate_uri, secure_sum_uri]))
  elif has_secure_sum:
    assert not has_aggregate
    before_aggregate, after_aggregate = (
        _create_before_and_after_aggregate_for_no_federated_aggregate(
            after_broadcast))
  else:
    assert has_aggregate and not has_secure_sum
    before_aggregate, after_aggregate = (
        _create_before_and_after_aggregate_for_no_federated_secure_sum(
            after_broadcast))

  type_info = _get_type_info(initialize_comp, before_broadcast, after_broadcast,
                             before_aggregate, after_aggregate)

  # Extract and type-check each piece in turn.
  initialize = transformations.consolidate_and_extract_local_processing(
      initialize_comp)
  _check_type_equal(initialize.type_signature, type_info['initialize_type'])

  prepare = _extract_prepare(before_broadcast)
  _check_type_equal(prepare.type_signature, type_info['prepare_type'])

  work = _extract_work(before_aggregate)
  _check_type_equal(work.type_signature, type_info['work_type'])

  zero, accumulate, merge, report = _extract_federated_aggregate_functions(
      before_aggregate)
  for extracted, type_key in ((zero, 'zero_type'),
                              (accumulate, 'accumulate_type'),
                              (merge, 'merge_type'),
                              (report, 'report_type')):
    _check_type_equal(extracted.type_signature, type_info[type_key])

  bitwidth = _extract_federated_secure_sum_functions(before_aggregate)
  _check_type_equal(bitwidth.type_signature, type_info['bitwidth_type'])

  update = _extract_update(after_aggregate)
  _check_type_equal(update.type_signature, type_info['update_type'])

  to_computation = computation_wrapper_instances.building_block_to_computation
  return canonical_form.CanonicalForm(*[
      to_computation(block)
      for block in (initialize, prepare, work, zero, accumulate, merge, report,
                    bitwidth, update)
  ])
def _extract_update(after_aggregate):
  """Builds `update` from `after_aggregate`.

  Only `get_canonical_form_for_iterative_process` is meant to call this
  helper. It therefore does not check that `after_aggregate` has the expected
  structure; the caller must do that beforehand.

  Args:
    after_aggregate: The second result of splitting `after_broadcast` on
      aggregate intrinsics.

  Returns:
    `update` as specified by `canonical_form.CanonicalForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: If we extract an AST of the
      wrong type.
  """
  # Outputs 0 and 1 of `after_aggregate` (`s7`), passed through
  # `create_federated_zip`.
  s7_extracted = transformations.select_output_from_lambda(
      after_aggregate, [0, 1])
  s7_zipped = building_blocks.Lambda(
      s7_extracted.parameter_name, s7_extracted.parameter_type,
      building_block_factory.create_federated_zip(s7_extracted.result))

  # Re-express that in terms of the `s6` parameter elements.
  s6_to_s7 = transformations.zip_selection_as_argument_to_lower_level_lambda(
      s7_zipped, [[0, 0, 0], [1, 0], [1, 1]]).result.function

  # TODO(b/148942011): The transformation
  # `zip_selection_as_argument_to_lower_level_lambda` does not support selecting
  # from nested structures, therefore we need to pack the type signature
  # `<s1, s3, s4>` as `<s1, <s3, s4>>`. The returned function accepts the packed
  # form and unpacks it before calling `s6_to_s7`.
  names = building_block_factory.unique_name_generator(s6_to_s7)

  # Both names come from the same generator; the draw order is significant.
  packed_name = next(names)
  member = s6_to_s7.parameter_type.member
  packed_type = computation_types.StructType([
      member[0],
      computation_types.StructType([member[1], member[2]]),
  ])
  packed_ref = building_blocks.Reference(packed_name, packed_type)
  select_s1 = building_blocks.Selection(packed_ref, index=0)
  select_s3_s4 = building_blocks.Selection(packed_ref, index=1)
  flattened = building_blocks.Struct([
      select_s1,
      building_blocks.Selection(select_s3_s4, index=0),
      building_blocks.Selection(select_s3_s4, index=1),
  ])
  unpack_fn = building_blocks.Lambda(packed_ref.name, packed_ref.type_signature,
                                     flattened)

  server_name = next(names)
  server_ref = building_blocks.Reference(
      server_name,
      computation_types.FederatedType(packed_type, placements.SERVER))
  unpacked_args = building_block_factory.create_federated_map_or_apply(
      unpack_fn, server_ref)
  update_fn = building_blocks.Lambda(
      server_ref.name, server_ref.type_signature,
      building_blocks.Call(s6_to_s7, unpacked_args))
  return transformations.consolidate_and_extract_local_processing(update_fn)
def test_reduces_lambda_returning_empty_tuple_to_tf(self):
  """A lambda whose body is an empty struct still compiles to TensorFlow."""
  returns_empty = building_blocks.Lambda('x', tf.int32,
                                         building_blocks.Struct([]))
  extracted = transformations.consolidate_and_extract_local_processing(
      returns_empty, DEFAULT_GRAPPLER_CONFIG)
  self.assertIsInstance(extracted, building_blocks.CompiledComputation)
def test_raises_on_none(self):
  """Passing `None` instead of a building block raises a `TypeError`."""
  self.assertRaises(TypeError,
                    transformations.consolidate_and_extract_local_processing,
                    None, DEFAULT_GRAPPLER_CONFIG)
def get_canonical_form_for_iterative_process(
    ip: iterative_process.IterativeProcess,
    grappler_config: Optional[
        tf.compat.v1.ConfigProto] = _GRAPPLER_DEFAULT_CONFIG):
  """Compiles an iterative process into `tff.backends.mapreduce.CanonicalForm`.

  Args:
    ip: An instance of `tff.templates.IterativeProcess` that is compatible
      with canonical form. Iterative processes are only compatible if:
        - `initialize_fn` returns a single federated value placed at `SERVER`.
        - `next` takes exactly two arguments. The first must be the state
          value placed at `SERVER`.
        - `next` returns exactly two values.
    grappler_config: An optional instance of `tf.compat.v1.ConfigProto` to
      configure Grappler graph optimization of the TensorFlow graphs backing
      the resulting `tff.backends.mapreduce.CanonicalForm`. When given, it is
      merged on top of a copy of `_GRAPPLER_DEFAULT_CONFIG`. If `None`,
      Grappler is bypassed.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process. Its `server_state_label` and `client_data_label` are the
    parameter names of `ip.next`.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If `next` calls neither `federated_aggregate` nor
      `federated_secure_sum`.
    transformations.CanonicalFormCompilationError: If the compilation
      process fails.
  """
  py_typecheck.check_type(ip, iterative_process.IterativeProcess)
  if grappler_config is not None:
    py_typecheck.check_type(grappler_config, tf.compat.v1.ConfigProto)
    merged_config = tf.compat.v1.ConfigProto()
    merged_config.CopyFrom(_GRAPPLER_DEFAULT_CONFIG)
    # Protobuf `MergeFrom`: singular fields set in `grappler_config` overwrite
    # the defaults; repeated fields are appended.
    merged_config.MergeFrom(grappler_config)
    grappler_config = merged_config

  # pylint: disable=protected-access
  initialize_comp = building_blocks.ComputationBuildingBlock.from_proto(
      ip.initialize._computation_proto)
  next_comp = building_blocks.ComputationBuildingBlock.from_proto(
      ip.next._computation_proto)
  # pylint: enable=protected-access
  _check_iterative_process_compatible_with_canonical_form(
      initialize_comp, next_comp)

  initialize_comp = _replace_intrinsics_with_bodies(initialize_comp)
  next_comp = _replace_intrinsics_with_bodies(next_comp)
  next_comp = _replace_lambda_body_with_call_dominant_form(next_comp)
  tree_analysis.check_contains_only_reducible_intrinsics(initialize_comp)
  tree_analysis.check_contains_only_reducible_intrinsics(next_comp)
  tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

  # Split `next` around its broadcast, or synthesize the split if it has none.
  broadcast_uri = intrinsic_defs.FEDERATED_BROADCAST.uri
  if not tree_analysis.contains_called_intrinsic(next_comp, broadcast_uri):
    before_broadcast, after_broadcast = (
        _create_before_and_after_broadcast_for_no_broadcast(next_comp))
  else:
    before_broadcast, after_broadcast = (
        transformations.force_align_and_split_by_intrinsics(
            next_comp, [broadcast_uri]))

  # Split the post-broadcast part around whichever aggregations are present.
  aggregate_uri = intrinsic_defs.FEDERATED_AGGREGATE.uri
  secure_sum_uri = intrinsic_defs.FEDERATED_SECURE_SUM.uri
  has_aggregate = tree_analysis.contains_called_intrinsic(
      next_comp, aggregate_uri)
  has_secure_sum = tree_analysis.contains_called_intrinsic(
      next_comp, secure_sum_uri)
  if not has_aggregate and not has_secure_sum:
    raise ValueError(
        'Expected an `tff.templates.IterativeProcess` containing at least one '
        '`federated_aggregate` or `federated_secure_sum`, found none.')

  if has_aggregate and has_secure_sum:
    before_aggregate, after_aggregate = (
        transformations.force_align_and_split_by_intrinsics(
            after_broadcast, [aggregate_uri, secure_sum_uri]))
  elif has_secure_sum:
    assert not has_aggregate
    before_aggregate, after_aggregate = (
        _create_before_and_after_aggregate_for_no_federated_aggregate(
            after_broadcast))
  else:
    assert has_aggregate and not has_secure_sum
    before_aggregate, after_aggregate = (
        _create_before_and_after_aggregate_for_no_federated_secure_sum(
            after_broadcast))

  type_info = _get_type_info(initialize_comp, before_broadcast, after_broadcast,
                             before_aggregate, after_aggregate)

  # Extract and type-check each piece in turn.
  initialize = transformations.consolidate_and_extract_local_processing(
      initialize_comp, grappler_config)
  _check_type_equal(initialize.type_signature, type_info['initialize_type'])

  prepare = _extract_prepare(before_broadcast, grappler_config)
  _check_type_equal(prepare.type_signature, type_info['prepare_type'])

  work = _extract_work(before_aggregate, grappler_config)
  _check_type_equal(work.type_signature, type_info['work_type'])

  zero, accumulate, merge, report = _extract_federated_aggregate_functions(
      before_aggregate, grappler_config)
  for extracted, type_key in ((zero, 'zero_type'),
                              (accumulate, 'accumulate_type'),
                              (merge, 'merge_type'),
                              (report, 'report_type')):
    _check_type_equal(extracted.type_signature, type_info[type_key])

  bitwidth = _extract_federated_secure_sum_functions(before_aggregate,
                                                     grappler_config)
  _check_type_equal(bitwidth.type_signature, type_info['bitwidth_type'])

  update = _extract_update(after_aggregate, grappler_config)
  _check_type_equal(update.type_signature, type_info['update_type'])

  # The two parameter names of `next` label the server state and client data.
  server_state_label, client_data_label = [
      name
      for name, _ in structure.iter_elements(ip.next.type_signature.parameter)
  ]

  to_computation = computation_wrapper_instances.building_block_to_computation
  computations = [
      to_computation(block)
      for block in (initialize, prepare, work, zero, accumulate, merge, report,
                    bitwidth, update)
  ]
  return canonical_form.CanonicalForm(
      *computations,
      server_state_label=server_state_label,
      client_data_label=client_data_label)
def get_canonical_form_for_iterative_process(iterative_process):
  """Builds a `tff.backends.mapreduce.CanonicalForm` from an iterative process.

  The `initialize` and `next` computations of `iterative_process` are turned
  into building blocks, normalized, and split around the broadcast and
  aggregate intrinsics. The resulting pieces are then reduced to the local
  computations (`initialize`, `prepare`, `work`, `zero`, `accumulate`, `merge`,
  `report`, `update`) that make up a `CanonicalForm`. Each piece is checked
  against the type computed by `_get_type_info` as soon as it is extracted.

  Args:
    iterative_process: An instance of `tff.utils.IterativeProcess`.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process.

  Raises:
    TypeError: If the arguments are of the wrong types.
    transformations.CanonicalFormCompilationError: If the compilation process
      fails.
  """
  py_typecheck.check_type(iterative_process, computation_utils.IterativeProcess)

  # Deserialize the two computations into building blocks.
  # pylint: disable=protected-access
  init_bb = building_blocks.ComputationBuildingBlock.from_proto(
      iterative_process.initialize._computation_proto)
  next_bb = building_blocks.ComputationBuildingBlock.from_proto(
      iterative_process.next._computation_proto)
  # pylint: enable=protected-access

  _check_iterative_process_compatible_with_canonical_form(init_bb, next_bb)

  # When `next` yields exactly two results, it is rewritten by
  # `_create_next_with_fake_client_output` (presumably to add a placeholder
  # client-output element — confirm against that helper).
  if len(next_bb.type_signature.result) == 2:
    next_bb = _create_next_with_fake_client_output(next_bb)

  init_bb = _replace_intrinsics_with_bodies(init_bb)
  next_bb = _replace_intrinsics_with_bodies(next_bb)

  for bb in (init_bb, next_bb):
    tree_analysis.check_intrinsics_whitelisted_for_reduction(bb)
  tree_analysis.check_broadcast_not_dependent_on_aggregate(next_bb)

  # Split `next` around its broadcast, synthesizing the split when `next` does
  # not call the broadcast intrinsic at all.
  broadcast_uri = intrinsic_defs.FEDERATED_BROADCAST.uri
  if not tree_analysis.contains_called_intrinsic(next_bb, broadcast_uri):
    before_broadcast, after_broadcast = (
        _create_before_and_after_broadcast_for_no_broadcast(next_bb))
  else:
    before_broadcast, after_broadcast = (
        transformations.force_align_and_split_by_intrinsics(
            next_bb, [broadcast_uri]))

  # Split the post-broadcast part around the aggregation.
  before_aggregate, after_aggregate = (
      transformations.force_align_and_split_by_intrinsics(
          after_broadcast, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))

  type_info = _get_type_info(init_bb, before_broadcast, after_broadcast,
                             before_aggregate, after_aggregate)

  def checked(comp, type_key):
    # Verifies `comp` against the expected type and hands it back unchanged.
    _check_type_equal(comp.type_signature, type_info[type_key])
    return comp

  initialize = checked(
      transformations.consolidate_and_extract_local_processing(init_bb),
      'initialize_type')
  prepare = checked(_extract_prepare(before_broadcast), 'prepare_type')
  work = checked(
      _extract_work(before_aggregate, after_aggregate), 'work_type')

  zero, accumulate, merge, report = _extract_aggregate_functions(
      before_aggregate)
  for comp, type_key in ((zero, 'zero_type'),
                         (accumulate, 'accumulate_type'),
                         (merge, 'merge_type'),
                         (report, 'report_type')):
    checked(comp, type_key)

  update = checked(_extract_update(after_aggregate), 'update_type')

  to_computation = computation_wrapper_instances.building_block_to_computation
  form_parts = (initialize, prepare, work, zero, accumulate, merge, report,
                update)
  return canonical_form.CanonicalForm(
      *[to_computation(part) for part in form_parts])
def test_reduces_unplaced_lambda_leaving_type_signature_alone(self):
  # An unplaced identity lambda over int32 should consolidate into a single
  # compiled computation whose type signature matches the original lambda's.
  param_ref = building_blocks.Reference('x', tf.int32)
  identity_fn = building_blocks.Lambda('x', tf.int32, param_ref)

  compiled = transformations.consolidate_and_extract_local_processing(
      identity_fn)

  self.assertIsInstance(compiled, building_blocks.CompiledComputation)
  self.assertEqual(compiled.type_signature, identity_fn.type_signature)
def test_reduces_lambda_returning_empty_tuple_to_tf(self):
  # Skipped at runtime: per the skip message, this depends on a lower-level
  # fix. The remaining body records the intended check, which is that a lambda
  # returning an empty tuple consolidates to a compiled computation.
  self.skipTest('Depends on a lower level fix, currently in review.')
  no_elements = building_blocks.Tuple([])
  returns_empty = building_blocks.Lambda('x', tf.int32, no_elements)
  compiled = transformations.consolidate_and_extract_local_processing(
      returns_empty)
  self.assertIsInstance(compiled, building_blocks.CompiledComputation)