def test_passes_with_federated_map(self):
  """A `federated_map` intrinsic passes the reduction whitelist check."""
  mapper_type = computation_types.FunctionType(tf.int32, tf.float32)
  clients_int_type = computation_types.FederatedType(tf.int32,
                                                     placements.CLIENTS)
  clients_float_type = computation_types.FederatedType(tf.float32,
                                                       placements.CLIENTS)
  map_type = computation_types.FunctionType([mapper_type, clients_int_type],
                                            clients_float_type)
  intrinsic = building_blocks.Intrinsic(intrinsic_defs.FEDERATED_MAP.uri,
                                        map_type)
  # No exception is expected here.
  tree_analysis.check_intrinsics_whitelisted_for_reduction(intrinsic)
def test_raises_with_federated_mean(self):
  """A `federated_mean` intrinsic is rejected by the whitelist check.

  The raised `ValueError` message must match the intrinsic's compact
  representation.
  """
  mean_type = computation_types.FunctionType(
      computation_types.FederatedType(tf.int32, placements.CLIENTS),
      computation_types.FederatedType(tf.int32, placements.SERVER))
  intrinsic = building_blocks.Intrinsic(intrinsic_defs.FEDERATED_MEAN.uri,
                                        mean_type)
  expected_message = intrinsic.compact_representation()
  with self.assertRaisesRegex(ValueError, expected_message):
    tree_analysis.check_intrinsics_whitelisted_for_reduction(intrinsic)
def test_generic_divide_reduces(self):
  """`generic_divide` is fully replaced by its body during reduction.

  Also checks that the reduced computation passes the whitelist check and
  that the transformation reports a modification.
  """
  divide_uri = intrinsic_defs.GENERIC_DIVIDE.uri
  stack = context_stack_impl.context_stack
  divide_type = computation_types.FunctionType([tf.float32, tf.float32],
                                               tf.float32)
  divide_comp = building_blocks.Intrinsic(divide_uri, divide_type)

  occurrences_before = _count_intrinsics(divide_comp, divide_uri)
  reduced_comp, was_modified = (
      value_transformations.replace_all_intrinsics_with_bodies(
          divide_comp, stack))
  occurrences_after = _count_intrinsics(reduced_comp, divide_uri)

  self.assertGreater(occurrences_before, 0)
  self.assertEqual(occurrences_after, 0)
  tree_analysis.check_intrinsics_whitelisted_for_reduction(reduced_comp)
  self.assertTrue(was_modified)
def test_raises_on_none(self):
  """Passing `None` instead of a building block raises `TypeError`."""
  self.assertRaises(TypeError,
                    tree_analysis.check_intrinsics_whitelisted_for_reduction,
                    None)
def get_canonical_form_for_iterative_process(iterative_process):
  """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

  This function transforms computations from the input `iterative_process`
  into an instance of `tff.backends.mapreduce.CanonicalForm`.

  Args:
    iterative_process: An instance of `tff.utils.IterativeProcess`.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If `next` contains neither a `federated_aggregate` nor a
      `federated_secure_sum`.
    transformations.CanonicalFormCompilationError: If the compilation
      process fails.
  """
  py_typecheck.check_type(iterative_process, computation_utils.IterativeProcess)

  initialize_comp = building_blocks.ComputationBuildingBlock.from_proto(
      iterative_process.initialize._computation_proto)  # pylint: disable=protected-access
  next_comp = building_blocks.ComputationBuildingBlock.from_proto(
      iterative_process.next._computation_proto)  # pylint: disable=protected-access
  _check_iterative_process_compatible_with_canonical_form(
      initialize_comp, next_comp)

  # A two-element `next` result is rewritten by
  # `_create_next_with_fake_client_output` (presumably adding a placeholder
  # client output so the shape matches the three-element case -- confirm).
  if len(next_comp.type_signature.result) == 2:
    next_comp = _create_next_with_fake_client_output(next_comp)

  initialize_comp = _replace_intrinsics_with_bodies(initialize_comp)
  next_comp = _replace_intrinsics_with_bodies(next_comp)
  tree_analysis.check_intrinsics_whitelisted_for_reduction(initialize_comp)
  tree_analysis.check_intrinsics_whitelisted_for_reduction(next_comp)
  tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

  # Split `next` around its `federated_broadcast`; a dedicated helper handles
  # the case where `next` performs no broadcast.
  if tree_analysis.contains_called_intrinsic(
      next_comp, intrinsic_defs.FEDERATED_BROADCAST.uri):
    before_broadcast, after_broadcast = (
        transformations.force_align_and_split_by_intrinsics(
            next_comp, [intrinsic_defs.FEDERATED_BROADCAST.uri]))
  else:
    before_broadcast, after_broadcast = (
        _create_before_and_after_broadcast_for_no_broadcast(next_comp))

  contains_federated_aggregate = tree_analysis.contains_called_intrinsic(
      next_comp, intrinsic_defs.FEDERATED_AGGREGATE.uri)
  contains_federated_secure_sum = tree_analysis.contains_called_intrinsic(
      next_comp, intrinsic_defs.FEDERATED_SECURE_SUM.uri)
  # Reject the "neither" case before dispatching; otherwise it would fall
  # into the no-`federated_aggregate` branch below instead of raising.
  if not (contains_federated_aggregate or contains_federated_secure_sum):
    raise ValueError(
        'Expected an `tff.utils.IterativeProcess` containing at least one '
        '`federated_aggregate` or `federated_secure_sum`, found none.')

  # Split the post-broadcast computation around its aggregation intrinsics.
  if contains_federated_aggregate and contains_federated_secure_sum:
    before_aggregate, after_aggregate = (
        transformations.force_align_and_split_by_intrinsics(
            after_broadcast, [
                intrinsic_defs.FEDERATED_AGGREGATE.uri,
                intrinsic_defs.FEDERATED_SECURE_SUM.uri,
            ]))
  elif contains_federated_secure_sum:
    # Only `federated_secure_sum` is present.
    before_aggregate, after_aggregate = (
        _create_before_and_after_aggregate_for_no_federated_aggregate(
            after_broadcast))
  else:
    # Only `federated_aggregate` is present.
    before_aggregate, after_aggregate = (
        _create_before_and_after_aggregate_for_no_federated_secure_sum(
            after_broadcast))

  # Expected types for every CanonicalForm component; each extracted piece is
  # checked against the corresponding entry below.
  type_info = _get_type_info(initialize_comp, before_broadcast, after_broadcast,
                             before_aggregate, after_aggregate)

  initialize = transformations.consolidate_and_extract_local_processing(
      initialize_comp)
  _check_type_equal(initialize.type_signature, type_info['initialize_type'])

  prepare = _extract_prepare(before_broadcast)
  _check_type_equal(prepare.type_signature, type_info['prepare_type'])

  work = _extract_work(before_aggregate, after_aggregate)
  _check_type_equal(work.type_signature, type_info['work_type'])

  zero, accumulate, merge, report = _extract_federated_aggregate_functions(
      before_aggregate)
  _check_type_equal(zero.type_signature, type_info['zero_type'])
  _check_type_equal(accumulate.type_signature, type_info['accumulate_type'])
  _check_type_equal(merge.type_signature, type_info['merge_type'])
  _check_type_equal(report.type_signature, type_info['report_type'])

  bitwidth = _extract_federated_secure_sum_functions(before_aggregate)
  _check_type_equal(bitwidth.type_signature, type_info['bitwidth_type'])

  update = _extract_update(after_aggregate)
  _check_type_equal(update.type_signature, type_info['update_type'])

  return canonical_form.CanonicalForm(
      computation_wrapper_instances.building_block_to_computation(
          initialize),
      computation_wrapper_instances.building_block_to_computation(prepare),
      computation_wrapper_instances.building_block_to_computation(work),
      computation_wrapper_instances.building_block_to_computation(zero),
      computation_wrapper_instances.building_block_to_computation(
          accumulate),
      computation_wrapper_instances.building_block_to_computation(merge),
      computation_wrapper_instances.building_block_to_computation(report),
      computation_wrapper_instances.building_block_to_computation(bitwidth),
      computation_wrapper_instances.building_block_to_computation(update))
def get_canonical_form_for_iterative_process(iterative_process):
  """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

  This function transforms computations from the input `iterative_process`
  into an instance of `tff.backends.mapreduce.CanonicalForm`.

  Args:
    iterative_process: An instance of `tff.utils.IterativeProcess`.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process.

  Raises:
    TypeError: If the arguments are of the wrong types, or if the `next`
      computation does not both take and return a `tff.NamedTupleType`.
    transformations.CanonicalFormCompilationError: If the compilation
      process fails.
  """
  py_typecheck.check_type(iterative_process, computation_utils.IterativeProcess)

  initialize_comp = building_blocks.ComputationBuildingBlock.from_proto(
      iterative_process.initialize._computation_proto)  # pylint: disable=protected-access
  next_comp = building_blocks.ComputationBuildingBlock.from_proto(
      iterative_process.next._computation_proto)  # pylint: disable=protected-access

  next_type = next_comp.type_signature
  if not (isinstance(next_type.parameter, computation_types.NamedTupleType) and
          isinstance(next_type.result, computation_types.NamedTupleType)):
    raise TypeError(
        'Any IterativeProcess compatible with CanonicalForm must '
        'have a `next` function which takes and returns instances '
        'of `tff.NamedTupleType`; your next function takes '
        'parameters of type {} and returns results of type {}'.format(
            next_type.parameter, next_type.result))

  # A two-element `next` result is rewritten by
  # `_create_next_with_fake_client_output` (presumably adding a placeholder
  # client output so the shape matches the three-element case -- confirm).
  if len(next_comp.type_signature.result) == 2:
    next_comp = _create_next_with_fake_client_output(next_comp)

  initialize_comp = replace_intrinsics_with_bodies(initialize_comp)
  next_comp = replace_intrinsics_with_bodies(next_comp)
  tree_analysis.check_intrinsics_whitelisted_for_reduction(initialize_comp)
  tree_analysis.check_intrinsics_whitelisted_for_reduction(next_comp)
  tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

  # Split `next` around its `federated_broadcast`; a dedicated helper handles
  # the case where `next` performs no broadcast.
  if tree_analysis.contains_called_intrinsic(
      next_comp, intrinsic_defs.FEDERATED_BROADCAST.uri):
    before_broadcast, after_broadcast = (
        transformations.force_align_and_split_by_intrinsics(
            next_comp, [intrinsic_defs.FEDERATED_BROADCAST.uri]))
  else:
    before_broadcast, after_broadcast = (
        _create_before_and_after_broadcast_for_no_broadcast(next_comp))

  before_aggregate, after_aggregate = (
      transformations.force_align_and_split_by_intrinsics(
          after_broadcast, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))

  # Each stage consumes the packed type info of the previous one; the final
  # result is indexed by component name (e.g. 'initialize_type') and is passed
  # to every extraction helper below.
  init_info_packed = pack_initialize_comp_type_signature(
      initialize_comp.type_signature)
  next_info_packed = pack_next_comp_type_signature(next_comp.type_signature,
                                                   init_info_packed)
  before_broadcast_info_packed = (
      check_and_pack_before_broadcast_type_signature(
          before_broadcast.type_signature, next_info_packed))
  before_aggregate_info_packed = (
      check_and_pack_before_aggregate_type_signature(
          before_aggregate.type_signature, before_broadcast_info_packed))
  canonical_form_types = check_and_pack_after_aggregate_type_signature(
      after_aggregate.type_signature, before_aggregate_info_packed)

  initialize = transformations.consolidate_and_extract_local_processing(
      initialize_comp)
  expected_initialize_type = canonical_form_types['initialize_type'].member
  if initialize.type_signature.result != expected_initialize_type:
    # Report the type that was actually compared against, so the message
    # names the real expectation.
    raise transformations.CanonicalFormCompilationError(
        'Compilation of initialize has failed. Expected to extract a '
        '`building_blocks.CompiledComputation` of type {}, instead we extracted '
        'a {} of type {}.'.format(expected_initialize_type, type(initialize),
                                  initialize.type_signature.result))

  prepare = extract_prepare(before_broadcast, canonical_form_types)
  work = extract_work(before_aggregate, after_aggregate, canonical_form_types)
  zero_noarg_function, accumulate, merge, report = extract_aggregate_functions(
      before_aggregate, canonical_form_types)
  update = extract_update(after_aggregate, canonical_form_types)

  return canonical_form.CanonicalForm(
      computation_wrapper_instances.building_block_to_computation(
          initialize),
      computation_wrapper_instances.building_block_to_computation(prepare),
      computation_wrapper_instances.building_block_to_computation(work),
      computation_wrapper_instances.building_block_to_computation(
          zero_noarg_function),
      computation_wrapper_instances.building_block_to_computation(
          accumulate),
      computation_wrapper_instances.building_block_to_computation(merge),
      computation_wrapper_instances.building_block_to_computation(report),
      computation_wrapper_instances.building_block_to_computation(update))