def test_returns_correct_example_of_broadcast_dependent_on_aggregate(self):
  """A broadcast of an aggregate must fail with a message matching 'acc_param'."""
  # NOTE(review): 'acc_param' is presumably the default accumulate-parameter
  # name used by the dummy aggregate helper -- confirm in test_utils.
  offending_comp = building_block_factory.create_federated_broadcast(
      test_utils.create_dummy_called_federated_aggregate())
  with self.assertRaisesRegex(ValueError, 'acc_param'):
    tree_analysis.check_broadcast_not_dependent_on_aggregate(offending_comp)
def test_finds_broadcast_dependent_on_aggregate(self):
  """Broadcasting the result of an aggregate is rejected with `ValueError`."""
  offending_comp = building_block_factory.create_federated_broadcast(
      test_utils.create_dummy_called_federated_aggregate())
  with self.assertRaises(ValueError):
    tree_analysis.check_broadcast_not_dependent_on_aggregate(offending_comp)
def test_returns_correct_example_of_broadcast_dependent_on_aggregate(self):
  """The error names the accumulate parameter of the offending aggregate."""
  called_aggregate = computation_test_utils.create_dummy_called_federated_aggregate(
      'accumulate_parameter', 'merge_parameter', 'report_parameter')
  offending_comp = building_block_factory.create_federated_broadcast(
      called_aggregate)
  with self.assertRaisesRegex(ValueError, 'accumulate_parameter'):
    tree_analysis.check_broadcast_not_dependent_on_aggregate(offending_comp)
def check_iterative_process_compatible_with_map_reduce_form(
    ip: iterative_process.IterativeProcess):
  """Tests compatibility with `tff.backends.mapreduce.MapReduceForm`.

  Note: the conditions here are specified in the documentation for
    `get_map_reduce_form_for_iterative_process`. Changes to this function should
    be propagated to that documentation.

  Args:
    ip: An instance of `tff.templates.IterativeProcess` to check for
      compatibility with `tff.backends.mapreduce.MapReduceForm`.

  Returns:
    A `(initialize_tree, next_tree)` pair of TFF-internal building blocks
    representing the validated and simplified `initialize` and `next`
    computations. Both have had intrinsics replaced with their bodies, and
    `next_tree` has additionally had its lambda body put in call-dominant form.

  Raises:
    TypeError: If the arguments are of the wrong types, if `initialize` is not
      a no-arg function returning a single `SERVER`-placed value, or if `next`
      does not take two arguments and return two values.
    ValueError: If `next` contains a broadcast that depends on an aggregate
      (raised by `tree_analysis.check_broadcast_not_dependent_on_aggregate`).
  """
  py_typecheck.check_type(ip, iterative_process.IterativeProcess)
  initialize_tree = ip.initialize.to_building_block()
  next_tree = ip.next.to_building_block()

  init_type = initialize_tree.type_signature
  _check_type_is_no_arg_fn(init_type, '`initialize`', TypeError)
  if (not init_type.result.is_federated() or
      init_type.result.placement != placements.SERVER):
    raise TypeError(
        'Expected `initialize` to return a single federated value '
        'placed at server (type `T@SERVER`), found return type:\n'
        f'{init_type.result}')

  next_type = next_tree.type_signature
  _check_type_is_fn(next_type, '`next`', TypeError)
  next_parameter = next_type.parameter
  # A function type with no argument presumably has `parameter is None`; guard
  # so that case raises the documented TypeError instead of AttributeError.
  if (next_parameter is None or not next_parameter.is_struct() or
      len(next_parameter) != 2):
    raise TypeError(
        'Expected `next` to take two arguments, found parameter '
        f'type:\n{next_parameter}')
  if not next_type.result.is_struct() or len(next_type.result) != 2:
    raise TypeError('Expected `next` to return two values, found result '
                    f'type:\n{next_type.result}')

  initialize_tree, _ = intrinsic_reductions.replace_intrinsics_with_bodies(
      initialize_tree)
  next_tree, _ = intrinsic_reductions.replace_intrinsics_with_bodies(next_tree)
  next_tree = _replace_lambda_body_with_call_dominant_form(next_tree)

  tree_analysis.check_contains_only_reducible_intrinsics(initialize_tree)
  tree_analysis.check_contains_only_reducible_intrinsics(next_tree)
  tree_analysis.check_broadcast_not_dependent_on_aggregate(next_tree)

  return initialize_tree, next_tree
def test_does_not_find_aggregate_dependent_on_broadcast(self):
  """An aggregate that consumes a broadcast must not raise."""
  broadcast = test_utils.create_dummy_called_federated_broadcast()
  member = broadcast.type_signature.member

  def data(name):
    # Every dummy value in this test shares the broadcast's member type.
    return building_blocks.Data(name, member)

  zero = data('zero')
  accumulate = building_blocks.Lambda(
      'accumulate_parameter', [member, member], data('accumulate_result'))
  merge = building_blocks.Lambda(
      'merge_parameter', [member, member], data('merge_result'))
  report = building_blocks.Lambda(
      'report_parameter', member, data('report_result'))

  tree_analysis.check_broadcast_not_dependent_on_aggregate(
      building_block_factory.create_federated_aggregate(
          broadcast, zero, accumulate, merge, report))
def test_raises_on_none_comp(self):
  """Passing `None` instead of a building block raises `TypeError`."""
  self.assertRaises(
      TypeError, tree_analysis.check_broadcast_not_dependent_on_aggregate,
      None)
def get_canonical_form_for_iterative_process(ip):
  """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

  This function transforms computations from the input `ip` into an instance
  of `tff.backends.mapreduce.CanonicalForm`.

  The `next` computation is split first at its `federated_broadcast` (if any)
  and then at its `federated_aggregate` / `federated_secure_sum` calls. Each
  extracted piece is checked against the types computed by `_get_type_info`
  before being wrapped as a computation.

  Args:
    ip: An instance of `tff.templates.IterativeProcess`.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If `next` contains neither a `federated_aggregate` nor a
      `federated_secure_sum`, or if it contains a broadcast that depends on an
      aggregate (raised by
      `tree_analysis.check_broadcast_not_dependent_on_aggregate`).
    transformations.CanonicalFormCompilationError: If the compilation
      process fails.
  """
  py_typecheck.check_type(ip, iterative_process.IterativeProcess)

  initialize_comp = building_blocks.ComputationBuildingBlock.from_proto(
      ip.initialize._computation_proto)  # pylint: disable=protected-access
  next_comp = building_blocks.ComputationBuildingBlock.from_proto(
      ip.next._computation_proto)  # pylint: disable=protected-access
  _check_iterative_process_compatible_with_canonical_form(
      initialize_comp, next_comp)

  initialize_comp = _replace_intrinsics_with_bodies(initialize_comp)
  next_comp = _replace_intrinsics_with_bodies(next_comp)
  tree_analysis.check_contains_only_reducible_intrinsics(initialize_comp)
  tree_analysis.check_contains_only_reducible_intrinsics(next_comp)
  tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

  # Split `next` around its broadcast. When there is none, a helper
  # synthesizes the before/after pair instead.
  if tree_analysis.contains_called_intrinsic(
      next_comp, intrinsic_defs.FEDERATED_BROADCAST.uri):
    before_broadcast, after_broadcast = (
        transformations.force_align_and_split_by_intrinsics(
            next_comp, [intrinsic_defs.FEDERATED_BROADCAST.uri]))
  else:
    before_broadcast, after_broadcast = (
        _create_before_and_after_broadcast_for_no_broadcast(next_comp))

  contains_federated_aggregate = tree_analysis.contains_called_intrinsic(
      next_comp, intrinsic_defs.FEDERATED_AGGREGATE.uri)
  contains_federated_secure_sum = tree_analysis.contains_called_intrinsic(
      next_comp, intrinsic_defs.FEDERATED_SECURE_SUM.uri)
  if not (contains_federated_aggregate or contains_federated_secure_sum):
    raise ValueError(
        'Expected an `tff.templates.IterativeProcess` containing at least one '
        '`federated_aggregate` or `federated_secure_sum`, found none.')

  # Split the post-broadcast part around the aggregation intrinsics. When only
  # one kind is present, a helper synthesizes the missing one's side.
  if contains_federated_aggregate and contains_federated_secure_sum:
    before_aggregate, after_aggregate = (
        transformations.force_align_and_split_by_intrinsics(
            after_broadcast, [
                intrinsic_defs.FEDERATED_AGGREGATE.uri,
                intrinsic_defs.FEDERATED_SECURE_SUM.uri,
            ]))
  elif contains_federated_secure_sum:
    assert not contains_federated_aggregate
    before_aggregate, after_aggregate = (
        _create_before_and_after_aggregate_for_no_federated_aggregate(
            after_broadcast))
  else:
    assert contains_federated_aggregate and not contains_federated_secure_sum
    before_aggregate, after_aggregate = (
        _create_before_and_after_aggregate_for_no_federated_secure_sum(
            after_broadcast))

  type_info = _get_type_info(initialize_comp, before_broadcast, after_broadcast,
                             before_aggregate, after_aggregate)

  # Extract each CanonicalForm component and check its type right away, so a
  # failure points at the first component that came out wrong.
  initialize = transformations.consolidate_and_extract_local_processing(
      initialize_comp)
  _check_type_equal(initialize.type_signature, type_info['initialize_type'])

  prepare = _extract_prepare(before_broadcast)
  _check_type_equal(prepare.type_signature, type_info['prepare_type'])

  work = _extract_work(before_aggregate)
  _check_type_equal(work.type_signature, type_info['work_type'])

  zero, accumulate, merge, report = _extract_federated_aggregate_functions(
      before_aggregate)
  _check_type_equal(zero.type_signature, type_info['zero_type'])
  _check_type_equal(accumulate.type_signature, type_info['accumulate_type'])
  _check_type_equal(merge.type_signature, type_info['merge_type'])
  _check_type_equal(report.type_signature, type_info['report_type'])

  bitwidth = _extract_federated_secure_sum_functions(before_aggregate)
  _check_type_equal(bitwidth.type_signature, type_info['bitwidth_type'])

  update = _extract_update(after_aggregate)
  _check_type_equal(update.type_signature, type_info['update_type'])

  return canonical_form.CanonicalForm(
      computation_wrapper_instances.building_block_to_computation(initialize),
      computation_wrapper_instances.building_block_to_computation(prepare),
      computation_wrapper_instances.building_block_to_computation(work),
      computation_wrapper_instances.building_block_to_computation(zero),
      computation_wrapper_instances.building_block_to_computation(accumulate),
      computation_wrapper_instances.building_block_to_computation(merge),
      computation_wrapper_instances.building_block_to_computation(report),
      computation_wrapper_instances.building_block_to_computation(bitwidth),
      computation_wrapper_instances.building_block_to_computation(update))
def get_canonical_form_for_iterative_process(iterative_process):
  """Compiles an iterative process into `tff.backends.mapreduce.CanonicalForm`.

  The `initialize` and `next` computations of `iterative_process` are
  deserialized into building blocks, reduced to whitelisted intrinsics, and
  then split at the broadcast and aggregate boundaries. The resulting pieces
  become the components of a `tff.backends.mapreduce.CanonicalForm`.

  Args:
    iterative_process: An instance of `tff.utils.IterativeProcess`.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process.

  Raises:
    TypeError: If the arguments are of the wrong types.
    transformations.CanonicalFormCompilationError: If the compilation
      process fails.
  """
  py_typecheck.check_type(iterative_process, computation_utils.IterativeProcess)

  # pylint: disable=protected-access
  init_bb = building_blocks.ComputationBuildingBlock.from_proto(
      iterative_process.initialize._computation_proto)
  next_bb = building_blocks.ComputationBuildingBlock.from_proto(
      iterative_process.next._computation_proto)
  # pylint: enable=protected-access
  _check_iterative_process_compatible_with_canonical_form(init_bb, next_bb)

  if len(next_bb.type_signature.result) == 2:
    next_bb = _create_next_with_fake_client_output(next_bb)

  init_bb = _replace_intrinsics_with_bodies(init_bb)
  next_bb = _replace_intrinsics_with_bodies(next_bb)
  for bb in (init_bb, next_bb):
    tree_analysis.check_intrinsics_whitelisted_for_reduction(bb)
  tree_analysis.check_broadcast_not_dependent_on_aggregate(next_bb)

  broadcast_uri = intrinsic_defs.FEDERATED_BROADCAST.uri
  if not tree_analysis.contains_called_intrinsic(next_bb, broadcast_uri):
    before_broadcast, after_broadcast = (
        _create_before_and_after_broadcast_for_no_broadcast(next_bb))
  else:
    before_broadcast, after_broadcast = (
        transformations.force_align_and_split_by_intrinsics(
            next_bb, [broadcast_uri]))

  before_aggregate, after_aggregate = (
      transformations.force_align_and_split_by_intrinsics(
          after_broadcast, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))

  type_info = _get_type_info(init_bb, before_broadcast, after_broadcast,
                             before_aggregate, after_aggregate)

  def checked(bb, type_key):
    # Verifies `bb` against the expected type and passes it through.
    _check_type_equal(bb.type_signature, type_info[type_key])
    return bb

  initialize = checked(
      transformations.consolidate_and_extract_local_processing(init_bb),
      'initialize_type')
  prepare = checked(_extract_prepare(before_broadcast), 'prepare_type')
  work = checked(_extract_work(before_aggregate, after_aggregate), 'work_type')
  zero, accumulate, merge, report = _extract_aggregate_functions(
      before_aggregate)
  for bb, type_key in ((zero, 'zero_type'), (accumulate, 'accumulate_type'),
                       (merge, 'merge_type'), (report, 'report_type')):
    checked(bb, type_key)
  update = checked(_extract_update(after_aggregate), 'update_type')

  to_computation = computation_wrapper_instances.building_block_to_computation
  components = [
      to_computation(bb) for bb in (initialize, prepare, work, zero,
                                    accumulate, merge, report, update)
  ]
  return canonical_form.CanonicalForm(*components)
def get_canonical_form_for_iterative_process(iterative_process):
  """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

  This function transforms computations from the input `iterative_process` into
  an instance of `tff.backends.mapreduce.CanonicalForm`.

  Args:
    iterative_process: An instance of `tff.utils.IterativeProcess`.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process.

  Raises:
    TypeError: If the arguments are of the wrong types, including when `next`
      does not both take and return `tff.NamedTupleType` values.
    transformations.CanonicalFormCompilationError: If the compilation
      process fails, e.g. when the extracted `initialize` result type differs
      from the type expected by the canonical form.
  """
  py_typecheck.check_type(iterative_process, computation_utils.IterativeProcess)

  initialize_comp = building_blocks.ComputationBuildingBlock.from_proto(
      iterative_process.initialize._computation_proto)  # pylint: disable=protected-access
  next_comp = building_blocks.ComputationBuildingBlock.from_proto(
      iterative_process.next._computation_proto)  # pylint: disable=protected-access

  if not (isinstance(next_comp.type_signature.parameter,
                     computation_types.NamedTupleType) and
          isinstance(next_comp.type_signature.result,
                     computation_types.NamedTupleType)):
    raise TypeError(
        'Any IterativeProcess compatible with CanonicalForm must '
        'have a `next` function which takes and returns instances '
        'of `tff.NamedTupleType`; your next function takes '
        'parameters of type {} and returns results of type {}'.format(
            next_comp.type_signature.parameter,
            next_comp.type_signature.result))

  if len(next_comp.type_signature.result) == 2:
    next_comp = _create_next_with_fake_client_output(next_comp)

  initialize_comp = replace_intrinsics_with_bodies(initialize_comp)
  next_comp = replace_intrinsics_with_bodies(next_comp)

  tree_analysis.check_intrinsics_whitelisted_for_reduction(initialize_comp)
  tree_analysis.check_intrinsics_whitelisted_for_reduction(next_comp)
  tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

  if tree_analysis.contains_called_intrinsic(
      next_comp, intrinsic_defs.FEDERATED_BROADCAST.uri):
    before_broadcast, after_broadcast = (
        transformations.force_align_and_split_by_intrinsics(
            next_comp, [intrinsic_defs.FEDERATED_BROADCAST.uri]))
  else:
    before_broadcast, after_broadcast = (
        _create_before_and_after_broadcast_for_no_broadcast(next_comp))

  before_aggregate, after_aggregate = (
      transformations.force_align_and_split_by_intrinsics(
          after_broadcast, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))

  # Each stage's packed type info is threaded into the next check; the final
  # result maps component names to their expected types.
  init_info_packed = pack_initialize_comp_type_signature(
      initialize_comp.type_signature)
  next_info_packed = pack_next_comp_type_signature(next_comp.type_signature,
                                                   init_info_packed)
  before_broadcast_info_packed = (
      check_and_pack_before_broadcast_type_signature(
          before_broadcast.type_signature, next_info_packed))
  before_aggregate_info_packed = (
      check_and_pack_before_aggregate_type_signature(
          before_aggregate.type_signature, before_broadcast_info_packed))
  canonical_form_types = check_and_pack_after_aggregate_type_signature(
      after_aggregate.type_signature, before_aggregate_info_packed)

  initialize = transformations.consolidate_and_extract_local_processing(
      initialize_comp)
  expected_initialize_result_type = (
      canonical_form_types['initialize_type'].member)
  if initialize.type_signature.result != expected_initialize_result_type:
    # Report the exact type the extracted result was compared against.
    raise transformations.CanonicalFormCompilationError(
        'Compilation of initialize has failed. Expected to extract a '
        '`building_blocks.CompiledComputation` of type {}, instead we extracted '
        'a {} of type {}.'.format(expected_initialize_result_type,
                                  type(initialize),
                                  initialize.type_signature.result))

  prepare = extract_prepare(before_broadcast, canonical_form_types)
  work = extract_work(before_aggregate, after_aggregate, canonical_form_types)
  zero_noarg_function, accumulate, merge, report = extract_aggregate_functions(
      before_aggregate, canonical_form_types)
  update = extract_update(after_aggregate, canonical_form_types)

  cf = canonical_form.CanonicalForm(
      computation_wrapper_instances.building_block_to_computation(initialize),
      computation_wrapper_instances.building_block_to_computation(prepare),
      computation_wrapper_instances.building_block_to_computation(work),
      computation_wrapper_instances.building_block_to_computation(
          zero_noarg_function),
      computation_wrapper_instances.building_block_to_computation(accumulate),
      computation_wrapper_instances.building_block_to_computation(merge),
      computation_wrapper_instances.building_block_to_computation(report),
      computation_wrapper_instances.building_block_to_computation(update))
  return cf
def get_canonical_form_for_iterative_process(
    ip: iterative_process.IterativeProcess,
    grappler_config: Optional[
        tf.compat.v1.ConfigProto] = _GRAPPLER_DEFAULT_CONFIG):
  """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

  This function transforms computations from the input `ip` into an instance
  of `tff.backends.mapreduce.CanonicalForm`.

  Args:
    ip: An instance of `tff.templates.IterativeProcess` that is compatible with
      canonical form. Iterative processes are only compatible if:
        - `initialize` returns a single federated value placed at `SERVER`.
        - `next` takes exactly two arguments. The first must be the state value
          placed at `SERVER`.
        - `next` returns exactly two values.
    grappler_config: An optional instance of `tf.compat.v1.ConfigProto` to
      configure Grappler graph optimization of the TensorFlow graphs backing the
      resulting `tff.backends.mapreduce.CanonicalForm`. These options are
      combined with a set of defaults that aggressively configure Grappler. If
      `None`, Grappler is bypassed.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If `next` contains neither a `federated_aggregate` nor a
      `federated_secure_sum`, or if it contains a broadcast that depends on an
      aggregate (raised by
      `tree_analysis.check_broadcast_not_dependent_on_aggregate`).
    transformations.CanonicalFormCompilationError: If the compilation
      process fails.
  """
  py_typecheck.check_type(ip, iterative_process.IterativeProcess)
  if grappler_config is not None:
    py_typecheck.check_type(grappler_config, tf.compat.v1.ConfigProto)
    # Work on a fresh proto so the module-level default is never mutated.
    overridden_grappler_config = tf.compat.v1.ConfigProto()
    overridden_grappler_config.CopyFrom(_GRAPPLER_DEFAULT_CONFIG)
    # `MergeFrom` concatenates repeated fields, so merging the default into a
    # copy of itself would duplicate any repeated entries it contains.
    if grappler_config is not _GRAPPLER_DEFAULT_CONFIG:
      overridden_grappler_config.MergeFrom(grappler_config)
    grappler_config = overridden_grappler_config

  initialize_comp = building_blocks.ComputationBuildingBlock.from_proto(
      ip.initialize._computation_proto)  # pylint: disable=protected-access
  next_comp = building_blocks.ComputationBuildingBlock.from_proto(
      ip.next._computation_proto)  # pylint: disable=protected-access
  _check_iterative_process_compatible_with_canonical_form(
      initialize_comp, next_comp)

  initialize_comp = _replace_intrinsics_with_bodies(initialize_comp)
  next_comp = _replace_intrinsics_with_bodies(next_comp)
  next_comp = _replace_lambda_body_with_call_dominant_form(next_comp)

  tree_analysis.check_contains_only_reducible_intrinsics(initialize_comp)
  tree_analysis.check_contains_only_reducible_intrinsics(next_comp)
  tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

  # Split `next` around its broadcast. When there is none, a helper
  # synthesizes the before/after pair instead.
  if tree_analysis.contains_called_intrinsic(
      next_comp, intrinsic_defs.FEDERATED_BROADCAST.uri):
    before_broadcast, after_broadcast = (
        transformations.force_align_and_split_by_intrinsics(
            next_comp, [intrinsic_defs.FEDERATED_BROADCAST.uri]))
  else:
    before_broadcast, after_broadcast = (
        _create_before_and_after_broadcast_for_no_broadcast(next_comp))

  contains_federated_aggregate = tree_analysis.contains_called_intrinsic(
      next_comp, intrinsic_defs.FEDERATED_AGGREGATE.uri)
  contains_federated_secure_sum = tree_analysis.contains_called_intrinsic(
      next_comp, intrinsic_defs.FEDERATED_SECURE_SUM.uri)
  if not (contains_federated_aggregate or contains_federated_secure_sum):
    raise ValueError(
        'Expected an `tff.templates.IterativeProcess` containing at least one '
        '`federated_aggregate` or `federated_secure_sum`, found none.')

  # Split the post-broadcast part around the aggregation intrinsics. When only
  # one kind is present, a helper synthesizes the missing one's side.
  if contains_federated_aggregate and contains_federated_secure_sum:
    before_aggregate, after_aggregate = (
        transformations.force_align_and_split_by_intrinsics(
            after_broadcast, [
                intrinsic_defs.FEDERATED_AGGREGATE.uri,
                intrinsic_defs.FEDERATED_SECURE_SUM.uri,
            ]))
  elif contains_federated_secure_sum:
    assert not contains_federated_aggregate
    before_aggregate, after_aggregate = (
        _create_before_and_after_aggregate_for_no_federated_aggregate(
            after_broadcast))
  else:
    assert contains_federated_aggregate and not contains_federated_secure_sum
    before_aggregate, after_aggregate = (
        _create_before_and_after_aggregate_for_no_federated_secure_sum(
            after_broadcast))

  type_info = _get_type_info(initialize_comp, before_broadcast, after_broadcast,
                             before_aggregate, after_aggregate)

  # Extract each component (with the chosen Grappler config) and check its
  # type right away, so a failure points at the first bad component.
  initialize = transformations.consolidate_and_extract_local_processing(
      initialize_comp, grappler_config)
  _check_type_equal(initialize.type_signature, type_info['initialize_type'])

  prepare = _extract_prepare(before_broadcast, grappler_config)
  _check_type_equal(prepare.type_signature, type_info['prepare_type'])

  work = _extract_work(before_aggregate, grappler_config)
  _check_type_equal(work.type_signature, type_info['work_type'])

  zero, accumulate, merge, report = _extract_federated_aggregate_functions(
      before_aggregate, grappler_config)
  _check_type_equal(zero.type_signature, type_info['zero_type'])
  _check_type_equal(accumulate.type_signature, type_info['accumulate_type'])
  _check_type_equal(merge.type_signature, type_info['merge_type'])
  _check_type_equal(report.type_signature, type_info['report_type'])

  bitwidth = _extract_federated_secure_sum_functions(before_aggregate,
                                                     grappler_config)
  _check_type_equal(bitwidth.type_signature, type_info['bitwidth_type'])

  update = _extract_update(after_aggregate, grappler_config)
  _check_type_equal(update.type_signature, type_info['update_type'])

  # `next` was validated above to take exactly two arguments; their names (if
  # any) become the CanonicalForm labels.
  server_state_label, client_data_label = [
      name for name, _ in structure.iter_elements(
          ip.next.type_signature.parameter)
  ]

  return canonical_form.CanonicalForm(
      computation_wrapper_instances.building_block_to_computation(initialize),
      computation_wrapper_instances.building_block_to_computation(prepare),
      computation_wrapper_instances.building_block_to_computation(work),
      computation_wrapper_instances.building_block_to_computation(zero),
      computation_wrapper_instances.building_block_to_computation(accumulate),
      computation_wrapper_instances.building_block_to_computation(merge),
      computation_wrapper_instances.building_block_to_computation(report),
      computation_wrapper_instances.building_block_to_computation(bitwidth),
      computation_wrapper_instances.building_block_to_computation(update),
      server_state_label=server_state_label,
      client_data_label=client_data_label)