def get_broadcast_form_for_computation(
    comp: computation_base.Computation,
    grappler_config: tf.compat.v1.ConfigProto = _GRAPPLER_DEFAULT_CONFIG
) -> forms.BroadcastForm:
  """Constructs `tff.backends.mapreduce.BroadcastForm` given a computation.

  The computation is converted to a building block, its intrinsics are
  replaced with their bodies, and the result is split around the broadcast
  into a server-side and a client-side piece.

  Args:
    comp: An instance of `tff.Computation` that is compatible with broadcast
      form. Computations are only compatible if they take in a single value
      placed at server, return a single value placed at clients, and do not
      contain any aggregations. Note that the parameter names of `comp` are
      unpacked into exactly two labels (server data and client data).
    grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
      Grappler graph optimization of the Tensorflow graphs backing the
      resulting `tff.backends.mapreduce.BroadcastForm`. These options are
      combined with a set of defaults that aggressively configure Grappler.
      If `grappler_config` has
      `graph_options.rewrite_options.disable_meta_optimizer=True`, Grappler is
      bypassed.

  Returns:
    An instance of `tff.backends.mapreduce.BroadcastForm` equivalent to the
    provided `tff.Computation`.

  Raises:
    ValueError: If the reduced computation still contains aggregations.
  """
  py_typecheck.check_type(comp, computation_base.Computation)
  _check_function_signature_compatible_with_broadcast_form(
      comp.type_signature)
  py_typecheck.check_type(grappler_config, tf.compat.v1.ConfigProto)
  grappler_config = _merge_grappler_config_with_default(grappler_config)

  bb = comp.to_building_block()
  bb, _ = intrinsic_reductions.replace_intrinsics_with_bodies(bb)
  bb = _replace_lambda_body_with_call_dominant_form(bb)

  tree_analysis.check_contains_only_reducible_intrinsics(bb)
  aggregations = tree_analysis.find_aggregations_in_tree(bb)
  if aggregations:
    # Each fragment that interpolates needs its own `f` prefix: adjacent
    # literals are concatenated, but the prefix does not carry over.
    raise ValueError(
        '`get_broadcast_form_for_computation` called with computation '
        f'containing {len(aggregations)} aggregations, but broadcast form '
        'does not allow aggregation. '
        f'Full list of aggregations:\n{aggregations}')

  before_broadcast, after_broadcast = _split_ast_on_broadcast(bb)
  compute_server_context = _extract_compute_server_context(
      before_broadcast, grappler_config)
  client_processing = _extract_client_processing(after_broadcast,
                                                 grappler_config)

  # Wrap both extracted building blocks as computations.
  compute_server_context, client_processing = (
      computation_wrapper_instances.building_block_to_computation(piece)
      for piece in (compute_server_context, client_processing))

  # Unpacking fails unless `comp` has exactly two (possibly unnamed)
  # parameters.
  comp_param_names = structure.name_list_with_nones(
      comp.type_signature.parameter)
  server_data_label, client_data_label = comp_param_names
  return forms.BroadcastForm(
      compute_server_context,
      client_processing,
      server_data_label=server_data_label,
      client_data_label=client_data_label)
def test_passes_with_federated_map(self):
  """Runs the reducible-intrinsics check on a `federated_map` intrinsic.

  The test passes when the check returns without raising.
  """
  mapped_fn_type = computation_types.FunctionType(tf.int32, tf.float32)
  ints_at_clients = computation_types.FederatedType(tf.int32,
                                                    placements.CLIENTS)
  floats_at_clients = computation_types.FederatedType(tf.float32,
                                                      placements.CLIENTS)
  federated_map_type = computation_types.FunctionType(
      [mapped_fn_type, ints_at_clients], floats_at_clients)
  intrinsic = building_blocks.Intrinsic(intrinsic_defs.FEDERATED_MAP.uri,
                                        federated_map_type)
  tree_analysis.check_contains_only_reducible_intrinsics(intrinsic)
def test_raises_with_federated_mean(self):
  """Expects a `ValueError` naming the intrinsic for `federated_mean`."""
  ints_at_clients = computation_types.FederatedType(tf.int32,
                                                    placements.CLIENTS)
  int_at_server = computation_types.FederatedType(tf.int32,
                                                  placements.SERVER)
  federated_mean_type = computation_types.FunctionType(ints_at_clients,
                                                       int_at_server)
  intrinsic = building_blocks.Intrinsic(intrinsic_defs.FEDERATED_MEAN.uri,
                                        federated_mean_type)
  # The error message must match the intrinsic's compact representation.
  expected_pattern = intrinsic.compact_representation()
  with self.assertRaisesRegex(ValueError, expected_pattern):
    tree_analysis.check_contains_only_reducible_intrinsics(intrinsic)
def check_iterative_process_compatible_with_map_reduce_form(
    ip: iterative_process.IterativeProcess):
  """Validates an iterative process against `MapReduceForm` requirements.

  Note: the conditions checked here are described in the documentation for
  `get_map_reduce_form_for_iterative_process`; keep the two in sync.

  The checks performed, in order:
    - `initialize` is a no-arg function returning a federated value placed at
      `SERVER`.
    - `next` is a function whose parameter is a two-element struct.
    - `next` returns a two-element struct.
    - After intrinsic reduction, both trees contain only reducible intrinsics
      and no broadcast in `next` depends on an aggregate.

  Args:
    ip: An instance of `tff.templates.IterativeProcess` to check for
      compatibility with `tff.backends.mapreduce.MapReduceForm`.

  Returns:
    A pair of TFF-internal building blocks: the validated and simplified
    `initialize` and `next` computations.

  Raises:
    TypeError: If the arguments are of the wrong types.
  """
  py_typecheck.check_type(ip, iterative_process.IterativeProcess)
  init_bb = ip.initialize.to_building_block()
  next_bb = ip.next.to_building_block()

  # --- Signature of `initialize` ---
  init_signature = init_bb.type_signature
  _check_type_is_no_arg_fn(init_signature, '`initialize`', TypeError)
  init_result = init_signature.result
  if (not init_result.is_federated() or
      init_result.placement != placements.SERVER):
    raise TypeError(
        'Expected `initialize` to return a single federated value '
        'placed at server (type `T@SERVER`), found return type:\n'
        f'{init_result}')

  # --- Signature of `next` ---
  next_signature = next_bb.type_signature
  _check_type_is_fn(next_signature, '`next`', TypeError)
  next_parameter = next_signature.parameter
  if not (next_parameter.is_struct() and len(next_parameter) == 2):
    raise TypeError(
        'Expected `next` to take two arguments, found parameter '
        f' type:\n{next_parameter}')
  next_result = next_signature.result
  if not (next_result.is_struct() and len(next_result) == 2):
    raise TypeError('Expected `next` to return two values, found result '
                    f'type:\n{next_result}')

  # --- Reduce intrinsics, then analyze the simplified trees ---
  init_bb, _ = intrinsic_reductions.replace_intrinsics_with_bodies(init_bb)
  next_bb, _ = intrinsic_reductions.replace_intrinsics_with_bodies(next_bb)
  next_bb = _replace_lambda_body_with_call_dominant_form(next_bb)

  tree_analysis.check_contains_only_reducible_intrinsics(init_bb)
  tree_analysis.check_contains_only_reducible_intrinsics(next_bb)
  tree_analysis.check_broadcast_not_dependent_on_aggregate(next_bb)

  return init_bb, next_bb
def test_generic_plus_reduces(self):
  """Checks that `generic_plus` is fully replaced by its body."""
  uri = intrinsic_defs.GENERIC_PLUS.uri
  binary_float_op_type = computation_types.FunctionType(
      [tf.float32, tf.float32], tf.float32)
  comp = building_blocks.Intrinsic(uri, binary_float_op_type)

  occurrences_before = _count_intrinsics(comp, uri)
  reduction_result = intrinsic_reductions.replace_intrinsics_with_bodies(comp)
  reduced, modified = reduction_result
  occurrences_after = _count_intrinsics(reduced, uri)

  self.assertTrue(modified)
  # Reduction must leave the type signature untouched.
  self.assert_types_identical(comp.type_signature, reduced.type_signature)
  self.assertGreater(occurrences_before, 0)
  self.assertEqual(occurrences_after, 0)
  tree_analysis.check_contains_only_reducible_intrinsics(reduced)
def test_generic_divide_reduces(self):
  """Checks that `generic_divide` is fully replaced by its body."""
  uri = intrinsic_defs.GENERIC_DIVIDE.uri
  context_stack = context_stack_impl.context_stack
  binary_float_op_type = computation_types.FunctionType(
      [tf.float32, tf.float32], tf.float32)
  comp = building_blocks.Intrinsic(uri, binary_float_op_type)

  occurrences_before = _count_intrinsics(comp, uri)
  reduction_result = value_transformations.replace_intrinsics_with_bodies(
      comp, context_stack)
  reduced, modified = reduction_result
  occurrences_after = _count_intrinsics(reduced, uri)

  self.assertGreater(occurrences_before, 0)
  self.assertEqual(occurrences_after, 0)
  tree_analysis.check_contains_only_reducible_intrinsics(reduced)
  self.assertTrue(modified)
def test_raises_on_none(self):
  """Expects a `TypeError` when the check is handed `None`."""
  self.assertRaises(
      TypeError, tree_analysis.check_contains_only_reducible_intrinsics, None)
def get_canonical_form_for_iterative_process(ip):
  """Builds a `tff.backends.mapreduce.CanonicalForm` from an iterative process.

  The `initialize` and `next` computations of `ip` are deserialized into
  building blocks, their intrinsics are replaced with bodies, and `next` is
  split first around `federated_broadcast` and then around the aggregation
  intrinsics (`federated_aggregate` / `federated_secure_sum`). Each extracted
  piece is type-checked against the expected type before being wrapped as a
  computation.

  Args:
    ip: An instance of `tff.templates.IterativeProcess`.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If `next` contains neither a `federated_aggregate` nor a
      `federated_secure_sum`.
    transformations.CanonicalFormCompilationError: If the compilation
      process fails.
  """
  py_typecheck.check_type(ip, iterative_process.IterativeProcess)

  from_proto = building_blocks.ComputationBuildingBlock.from_proto
  initialize_bb = from_proto(
      ip.initialize._computation_proto)  # pylint: disable=protected-access
  next_bb = from_proto(
      ip.next._computation_proto)  # pylint: disable=protected-access
  _check_iterative_process_compatible_with_canonical_form(
      initialize_bb, next_bb)

  initialize_bb = _replace_intrinsics_with_bodies(initialize_bb)
  next_bb = _replace_intrinsics_with_bodies(next_bb)
  tree_analysis.check_contains_only_reducible_intrinsics(initialize_bb)
  tree_analysis.check_contains_only_reducible_intrinsics(next_bb)
  tree_analysis.check_broadcast_not_dependent_on_aggregate(next_bb)

  # Split `next` around the broadcast; synthesize a split when there is none.
  has_broadcast = tree_analysis.contains_called_intrinsic(
      next_bb, intrinsic_defs.FEDERATED_BROADCAST.uri)
  if not has_broadcast:
    broadcast_split = _create_before_and_after_broadcast_for_no_broadcast(
        next_bb)
  else:
    broadcast_split = transformations.force_align_and_split_by_intrinsics(
        next_bb, [intrinsic_defs.FEDERATED_BROADCAST.uri])
  before_broadcast, after_broadcast = broadcast_split

  contains_federated_aggregate = tree_analysis.contains_called_intrinsic(
      next_bb, intrinsic_defs.FEDERATED_AGGREGATE.uri)
  contains_federated_secure_sum = tree_analysis.contains_called_intrinsic(
      next_bb, intrinsic_defs.FEDERATED_SECURE_SUM.uri)
  if not contains_federated_aggregate and not contains_federated_secure_sum:
    raise ValueError(
        'Expected an `tff.templates.IterativeProcess` containing at least one '
        '`federated_aggregate` or `federated_secure_sum`, found none.')

  # Split the post-broadcast part around whichever aggregations are present.
  if contains_federated_aggregate and contains_federated_secure_sum:
    aggregate_split = transformations.force_align_and_split_by_intrinsics(
        after_broadcast, [
            intrinsic_defs.FEDERATED_AGGREGATE.uri,
            intrinsic_defs.FEDERATED_SECURE_SUM.uri,
        ])
  elif contains_federated_aggregate:
    assert contains_federated_aggregate and not contains_federated_secure_sum
    aggregate_split = (
        _create_before_and_after_aggregate_for_no_federated_secure_sum(
            after_broadcast))
  else:
    assert not contains_federated_aggregate
    aggregate_split = (
        _create_before_and_after_aggregate_for_no_federated_aggregate(
            after_broadcast))
  before_aggregate, after_aggregate = aggregate_split

  type_info = _get_type_info(initialize_bb, before_broadcast, after_broadcast,
                             before_aggregate, after_aggregate)

  # Extract each piece and verify its type right after extraction.
  initialize = transformations.consolidate_and_extract_local_processing(
      initialize_bb)
  _check_type_equal(initialize.type_signature, type_info['initialize_type'])

  prepare = _extract_prepare(before_broadcast)
  _check_type_equal(prepare.type_signature, type_info['prepare_type'])

  work = _extract_work(before_aggregate)
  _check_type_equal(work.type_signature, type_info['work_type'])

  zero, accumulate, merge, report = _extract_federated_aggregate_functions(
      before_aggregate)
  aggregate_checks = (
      (zero, 'zero_type'),
      (accumulate, 'accumulate_type'),
      (merge, 'merge_type'),
      (report, 'report_type'),
  )
  for aggregate_fn, type_key in aggregate_checks:
    _check_type_equal(aggregate_fn.type_signature, type_info[type_key])

  bitwidth = _extract_federated_secure_sum_functions(before_aggregate)
  _check_type_equal(bitwidth.type_signature, type_info['bitwidth_type'])

  update = _extract_update(after_aggregate)
  _check_type_equal(update.type_signature, type_info['update_type'])

  # The positional order below is the order `CanonicalForm` receives them.
  pieces = (initialize, prepare, work, zero, accumulate, merge, report,
            bitwidth, update)
  return canonical_form.CanonicalForm(*[
      computation_wrapper_instances.building_block_to_computation(piece)
      for piece in pieces
  ])
def get_canonical_form_for_iterative_process(
    ip: iterative_process.IterativeProcess,
    grappler_config: Optional[
        tf.compat.v1.ConfigProto] = _GRAPPLER_DEFAULT_CONFIG):
  """Builds a `tff.backends.mapreduce.CanonicalForm` from an iterative process.

  The `initialize` and `next` computations of `ip` are deserialized into
  building blocks, their intrinsics are replaced with bodies, and `next` is
  split first around `federated_broadcast` and then around the aggregation
  intrinsics (`federated_aggregate` / `federated_secure_sum`). Each extracted
  piece is type-checked against the expected type before being wrapped as a
  computation.

  Args:
    ip: An instance of `tff.templates.IterativeProcess` that is compatible
      with canonical form. Iterative processes are only compatible if:
      - `initialize_fn` returns a single federated value placed at `SERVER`.
      - `next` takes exactly two arguments. The first must be the state value
        placed at `SERVER`.
      - `next` returns exactly two values.
    grappler_config: An optional instance of `tf.compat.v1.ConfigProto` to
      configure Grappler graph optimization of the TensorFlow graphs backing
      the resulting `tff.backends.mapreduce.CanonicalForm`. These options are
      combined with a set of defaults that aggressively configure Grappler.
      If `None`, Grappler is bypassed.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If `next` contains neither a `federated_aggregate` nor a
      `federated_secure_sum`.
    transformations.CanonicalFormCompilationError: If the compilation
      process fails.
  """
  py_typecheck.check_type(ip, iterative_process.IterativeProcess)
  if grappler_config is not None:
    py_typecheck.check_type(grappler_config, tf.compat.v1.ConfigProto)
    # Start from a copy of the defaults and layer the caller's options on top.
    merged_config = tf.compat.v1.ConfigProto()
    merged_config.CopyFrom(_GRAPPLER_DEFAULT_CONFIG)
    merged_config.MergeFrom(grappler_config)
    grappler_config = merged_config

  from_proto = building_blocks.ComputationBuildingBlock.from_proto
  initialize_bb = from_proto(
      ip.initialize._computation_proto)  # pylint: disable=protected-access
  next_bb = from_proto(
      ip.next._computation_proto)  # pylint: disable=protected-access
  _check_iterative_process_compatible_with_canonical_form(
      initialize_bb, next_bb)

  initialize_bb = _replace_intrinsics_with_bodies(initialize_bb)
  next_bb = _replace_intrinsics_with_bodies(next_bb)
  next_bb = _replace_lambda_body_with_call_dominant_form(next_bb)

  tree_analysis.check_contains_only_reducible_intrinsics(initialize_bb)
  tree_analysis.check_contains_only_reducible_intrinsics(next_bb)
  tree_analysis.check_broadcast_not_dependent_on_aggregate(next_bb)

  # Split `next` around the broadcast; synthesize a split when there is none.
  has_broadcast = tree_analysis.contains_called_intrinsic(
      next_bb, intrinsic_defs.FEDERATED_BROADCAST.uri)
  if not has_broadcast:
    broadcast_split = _create_before_and_after_broadcast_for_no_broadcast(
        next_bb)
  else:
    broadcast_split = transformations.force_align_and_split_by_intrinsics(
        next_bb, [intrinsic_defs.FEDERATED_BROADCAST.uri])
  before_broadcast, after_broadcast = broadcast_split

  contains_federated_aggregate = tree_analysis.contains_called_intrinsic(
      next_bb, intrinsic_defs.FEDERATED_AGGREGATE.uri)
  contains_federated_secure_sum = tree_analysis.contains_called_intrinsic(
      next_bb, intrinsic_defs.FEDERATED_SECURE_SUM.uri)
  if not contains_federated_aggregate and not contains_federated_secure_sum:
    raise ValueError(
        'Expected an `tff.templates.IterativeProcess` containing at least one '
        '`federated_aggregate` or `federated_secure_sum`, found none.')

  # Split the post-broadcast part around whichever aggregations are present.
  if contains_federated_aggregate and contains_federated_secure_sum:
    aggregate_split = transformations.force_align_and_split_by_intrinsics(
        after_broadcast, [
            intrinsic_defs.FEDERATED_AGGREGATE.uri,
            intrinsic_defs.FEDERATED_SECURE_SUM.uri,
        ])
  elif contains_federated_aggregate:
    assert contains_federated_aggregate and not contains_federated_secure_sum
    aggregate_split = (
        _create_before_and_after_aggregate_for_no_federated_secure_sum(
            after_broadcast))
  else:
    assert not contains_federated_aggregate
    aggregate_split = (
        _create_before_and_after_aggregate_for_no_federated_aggregate(
            after_broadcast))
  before_aggregate, after_aggregate = aggregate_split

  type_info = _get_type_info(initialize_bb, before_broadcast, after_broadcast,
                             before_aggregate, after_aggregate)

  # Extract each piece and verify its type right after extraction.
  initialize = transformations.consolidate_and_extract_local_processing(
      initialize_bb, grappler_config)
  _check_type_equal(initialize.type_signature, type_info['initialize_type'])

  prepare = _extract_prepare(before_broadcast, grappler_config)
  _check_type_equal(prepare.type_signature, type_info['prepare_type'])

  work = _extract_work(before_aggregate, grappler_config)
  _check_type_equal(work.type_signature, type_info['work_type'])

  zero, accumulate, merge, report = _extract_federated_aggregate_functions(
      before_aggregate, grappler_config)
  aggregate_checks = (
      (zero, 'zero_type'),
      (accumulate, 'accumulate_type'),
      (merge, 'merge_type'),
      (report, 'report_type'),
  )
  for aggregate_fn, type_key in aggregate_checks:
    _check_type_equal(aggregate_fn.type_signature, type_info[type_key])

  bitwidth = _extract_federated_secure_sum_functions(before_aggregate,
                                                     grappler_config)
  _check_type_equal(bitwidth.type_signature, type_info['bitwidth_type'])

  update = _extract_update(after_aggregate, grappler_config)
  _check_type_equal(update.type_signature, type_info['update_type'])

  # `next` must have exactly two parameters for this unpacking to succeed.
  server_state_label, client_data_label = [
      name
      for name, _ in structure.iter_elements(ip.next.type_signature.parameter)
  ]

  # The positional order below is the order `CanonicalForm` receives them.
  pieces = (initialize, prepare, work, zero, accumulate, merge, report,
            bitwidth, update)
  return canonical_form.CanonicalForm(
      *[
          computation_wrapper_instances.building_block_to_computation(piece)
          for piece in pieces
      ],
      server_state_label=server_state_label,
      client_data_label=client_data_label)