def test_reduces_lambda_returning_empty_tuple_to_tf(self):
  """A lambda whose body is an empty struct compiles to TensorFlow."""
  to_empty = building_blocks.Lambda('x', tf.int32, building_blocks.Struct([]))
  compiled = compiler.consolidate_and_extract_local_processing(
      to_empty, DEFAULT_GRAPPLER_CONFIG)
  self.assertIsInstance(compiled, building_blocks.CompiledComputation)
def _extract_prepare(before_broadcast, grappler_config):
  """Extracts `prepare` from `before_broadcast`.

  Only `get_map_reduce_form_for_iterative_process` should call this. No
  structural validation of `before_broadcast` happens here; the caller is
  responsible for performing those checks beforehand.

  Args:
    before_broadcast: The first result of splitting `next_bb` on
      `intrinsic_defs.FEDERATED_BROADCAST`.
    grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
      Grappler graph optimization.

  Returns:
    `prepare` as specified by `forms.MapReduceForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    compiler.MapReduceFormCompilationError: If we extract an AST of the wrong
      type.
  """
  # `prepare` is expressed purely in terms of the server state, which lives at
  # subparameter index 0 of `before_broadcast`.
  server_state_subparameter = 0
  prepare_fn = _as_function_of_single_subparameter(before_broadcast,
                                                   server_state_subparameter)
  return compiler.consolidate_and_extract_local_processing(
      prepare_fn, grappler_config)
def _extract_compute_server_context(before_broadcast, grappler_config):
  """Extracts `compute_server_context` from `before_broadcast`.

  Args:
    before_broadcast: The computation preceding the broadcast. Its
      subparameter at index 0 is the server state (the same index
      `_extract_prepare` uses).
    grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
      Grappler graph optimization.

  Returns:
    `compute_server_context` as an instance of
    `building_blocks.CompiledComputation`.
  """
  # Index 0 of `before_broadcast`'s parameter is the server *state*. The
  # previous local name (`server_data_...`) misdescribed it.
  server_state_index_in_before_broadcast = 0
  compute_server_context = _as_function_of_single_subparameter(
      before_broadcast, server_state_index_in_before_broadcast)
  return compiler.consolidate_and_extract_local_processing(
      compute_server_context, grappler_config)
def test_reduces_unplaced_lambda_leaving_type_signature_alone(self):
  """Compiling an unplaced identity lambda preserves its type signature."""
  identity = building_blocks.Lambda(
      'x', tf.int32, building_blocks.Reference('x', tf.int32))
  compiled = compiler.consolidate_and_extract_local_processing(
      identity, DEFAULT_GRAPPLER_CONFIG)
  self.assertIsInstance(compiled, building_blocks.CompiledComputation)
  self.assertEqual(compiled.type_signature, identity.type_signature)
def _compile_selected_output_as_tensorflow_function(
    comp: building_blocks.Lambda, path: building_block_factory.Path,
    grappler_config) -> building_blocks.CompiledComputation:
  """Compiles the functional result of `comp` at `path` to TensorFlow.

  Args:
    comp: The lambda whose output is narrowed to the selection at `path`.
    path: Identifies which output of `comp` to keep.
    grappler_config: Grappler configuration forwarded to the compiler.

  Returns:
    The compiled `building_blocks.CompiledComputation`.
  """
  narrowed = building_block_factory.select_output_from_lambda(comp, path)
  # The selected result is itself functional, so it is compiled directly
  # rather than being wrapped in a no-argument lambda.
  return compiler.consolidate_and_extract_local_processing(
      narrowed.result, grappler_config)
def _compile_selected_output_to_no_argument_tensorflow(
    comp: building_blocks.Lambda, path: building_block_factory.Path,
    grappler_config) -> building_blocks.CompiledComputation:
  """Compiles the independent value result of `comp` at `path` to TensorFlow.

  Args:
    comp: The lambda whose output is narrowed to the selection at `path`.
    path: Identifies which output of `comp` to keep.
    grappler_config: Grappler configuration forwarded to the compiler.

  Returns:
    A `building_blocks.CompiledComputation` taking no arguments.
  """
  narrowed = building_block_factory.select_output_from_lambda(comp, path)
  # The selected value does not depend on the lambda's parameter, so wrap it
  # in a parameterless lambda before compiling.
  no_arg_fn = building_blocks.Lambda(None, None, narrowed.result)
  return compiler.consolidate_and_extract_local_processing(
      no_arg_fn, grappler_config)
def test_reduces_unplaced_lambda_to_equivalent_tf(self):
  """Compiled TF agrees with the original identity lambda on sample inputs."""
  identity = building_blocks.Lambda(
      'x', tf.int32, building_blocks.Reference('x', tf.int32))
  compiled = compiler.consolidate_and_extract_local_processing(
      identity, DEFAULT_GRAPPLER_CONFIG)
  to_concrete = computation_impl.ConcreteComputation.from_building_block
  run_compiled = to_concrete(compiled)
  run_original = to_concrete(identity)
  for value in range(10):
    self.assertEqual(run_compiled(value), run_original(value))
def test_reduces_federated_identity_to_member_identity(self):
  """A federated identity compiles to an identity over the member type."""
  clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  identity = building_blocks.Lambda(
      'x', clients_int, building_blocks.Reference('x', clients_int))
  compiled = compiler.consolidate_and_extract_local_processing(
      identity, DEFAULT_GRAPPLER_CONFIG)
  self.assertIsInstance(compiled, building_blocks.CompiledComputation)
  member = clients_int.member
  expected_signature = computation_types.FunctionType(member, member)
  self.assertEqual(compiled.type_signature, expected_signature)
def test_already_reduced_case(self):
  """Consolidating a MapReduceForm-derived `initialize` yields TensorFlow."""
  example_form = mapreduce_test_utils.get_temperature_sensor_example()
  process = form_utils.get_iterative_process_for_map_reduce_form(example_form)
  init_bb = process.initialize.to_building_block()
  compiled = compiler.consolidate_and_extract_local_processing(
      init_bb, DEFAULT_GRAPPLER_CONFIG)
  self.assertIsInstance(compiled, building_blocks.CompiledComputation)
  self.assertIsInstance(compiled.proto, computation_pb2.Computation)
  self.assertEqual(compiled.proto.WhichOneof('computation'), 'tensorflow')
def test_reduces_federated_value_at_clients_to_equivalent_noarg_function(
    self):
  """A no-arg function returning a CLIENTS-placed constant compiles to TF."""
  scalar_int = computation_types.TensorType(tf.int32, shape=[])
  zero = building_block_factory.create_tensorflow_constant(scalar_int, 0)
  at_clients = building_block_factory.create_federated_value(
      zero, placements.CLIENTS)
  no_arg_fn = building_blocks.Lambda(None, None, at_clients)
  compiled = compiler.consolidate_and_extract_local_processing(
      no_arg_fn, DEFAULT_GRAPPLER_CONFIG)
  run_compiled = computation_impl.ConcreteComputation.from_building_block(
      compiled)
  self.assertEqual(run_compiled(), 0)
def _extract_client_processing(after_broadcast, grappler_config):
  """Extracts `client_processing` from `after_broadcast`.

  Args:
    after_broadcast: The computation following the broadcast, from which the
      server context and client data subparameters are selected.
    grappler_config: Grappler configuration forwarded to the compiler.

  Returns:
    `client_processing` as a `building_blocks.CompiledComputation`.
  """
  # Paths into the parameter of `after_broadcast`.
  server_context_path = (1,)
  client_data_path = (0, 1)
  # NOTE: unlike `work`, which takes `(data, params)`, this follows the
  # `(params, data)` ordering of the input iterative process / computation.
  subparameter_paths = [server_context_path, client_data_path]
  client_processing_fn = _as_function_of_some_federated_subparameters(
      after_broadcast, subparameter_paths)
  return compiler.consolidate_and_extract_local_processing(
      client_processing_fn, grappler_config)
def _extract_work(before_aggregate, grappler_config):
  """Extracts `work` from `before_aggregate`.

  Only `get_map_reduce_form_for_iterative_process` should call this. The
  structure of `before_aggregate` is not validated here; the caller must
  perform those checks before calling this function.

  Args:
    before_aggregate: The first result of splitting `after_broadcast` on
      aggregate intrinsics.
    grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
      Grappler graph optimization.

  Returns:
    `work` as specified by `forms.MapReduceForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    compiler.MapReduceFormCompilationError: If we extract an AST of the wrong
      type.
  """
  # Inputs of `work`, as paths into the parameter of `before_aggregate`.
  # Order matters: `work` takes `(client data, broadcast result)`.
  work_input_paths = [
      ('original_arg', 1),
      ('federated_broadcast_result',),
  ]
  work_to_before_aggregate = _as_function_of_some_federated_subparameters(
      before_aggregate, work_input_paths)

  # Outputs of `work`, as paths into the result of `before_aggregate`: one
  # value fed into each aggregation intrinsic.
  work_output_paths = [
      ('federated_aggregate_param', 0),
      ('federated_secure_sum_bitwidth_param', 0),
      ('federated_secure_sum_param', 0),
      ('federated_secure_modular_sum_param', 0),
  ]
  selected = building_block_factory.select_output_from_lambda(
      work_to_before_aggregate, work_output_paths)

  # Zip the selected federated outputs into a single federated struct.
  work_fn = building_blocks.Lambda(
      selected.parameter_name, selected.parameter_type,
      building_block_factory.create_federated_zip(selected.result))
  return compiler.consolidate_and_extract_local_processing(
      work_fn, grappler_config)
def test_reduces_federated_apply_to_equivalent_function(self):
  """Mapping an identity over a CLIENTS value compiles to an equivalent fn."""
  identity = building_blocks.Lambda(
      'x', tf.int32, building_blocks.Reference('x', tf.int32))
  clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  arg_ref = building_blocks.Reference('arg', clients_int)
  mapped = building_block_factory.create_federated_map_or_apply(
      identity, arg_ref)
  mapping_fn = building_blocks.Lambda('arg', clients_int, mapped)
  compiled = compiler.consolidate_and_extract_local_processing(
      mapping_fn, DEFAULT_GRAPPLER_CONFIG)
  self.assertIsInstance(compiled, building_blocks.CompiledComputation)
  to_concrete = computation_impl.ConcreteComputation.from_building_block
  run_compiled = to_concrete(compiled)
  run_identity = to_concrete(identity)
  for value in range(10):
    self.assertEqual(run_compiled(value), run_identity(value))
def test_raises_on_none(self):
  """Passing `None` instead of a building block raises `TypeError`."""
  missing_comp = None
  with self.assertRaises(TypeError):
    compiler.consolidate_and_extract_local_processing(
        missing_comp, DEFAULT_GRAPPLER_CONFIG)
def get_map_reduce_form_for_iterative_process(
    ip: iterative_process.IterativeProcess,
    grappler_config: tf.compat.v1.ConfigProto = _GRAPPLER_DEFAULT_CONFIG,
    *,
    tff_internal_preprocessing: Optional[BuildingBlockFn] = None,
) -> forms.MapReduceForm:
  """Constructs `tff.backends.mapreduce.MapReduceForm` given iterative process.

  Args:
    ip: An instance of `tff.templates.IterativeProcess` that is compatible with
      MapReduce form. Iterative processes are only compatible if:
        - `initialize_fn` returns a single federated value placed at `SERVER`.
        - `next` takes exactly two arguments. The first must be the state
          value placed at `SERVER`.
        - `next` returns exactly two values.
    grappler_config: An optional instance of `tf.compat.v1.ConfigProto` to
      configure Grappler graph optimization of the TensorFlow graphs backing
      the resulting `tff.backends.mapreduce.MapReduceForm`. These options are
      combined with a set of defaults that aggressively configure Grappler. If
      the input `grappler_config` has
      `graph_options.rewrite_options.disable_meta_optimizer=True`, Grappler is
      bypassed.
    tff_internal_preprocessing: An optional function to transform the AST of
      the iterative process.

  Returns:
    An instance of `tff.backends.mapreduce.MapReduceForm` equivalent to the
    provided `tff.templates.IterativeProcess`.

  Raises:
    TypeError: If the arguments are of the wrong types.
    compiler.MapReduceFormCompilationError: If the compilation process fails.
  """
  # Validate both argument types up front, before the (potentially expensive)
  # compatibility check and compilation.
  py_typecheck.check_type(ip, iterative_process.IterativeProcess)
  py_typecheck.check_type(grappler_config, tf.compat.v1.ConfigProto)
  initialize_bb, next_bb = (
      check_iterative_process_compatible_with_map_reduce_form(
          ip, tff_internal_preprocessing=tff_internal_preprocessing))
  grappler_config = _merge_grappler_config_with_default(grappler_config)

  # Give every reference a unique name before splitting, because the
  # subparameter-extraction helpers used below expect unique names.
  next_bb, _ = tree_transformations.uniquify_reference_names(next_bb)
  before_broadcast, after_broadcast = _split_ast_on_broadcast(next_bb)
  before_aggregate, after_aggregate = _split_ast_on_aggregate(
      after_broadcast)

  initialize = compiler.consolidate_and_extract_local_processing(
      initialize_bb, grappler_config)
  prepare = _extract_prepare(before_broadcast, grappler_config)
  work = _extract_work(before_aggregate, grappler_config)
  zero, accumulate, merge, report = _extract_federated_aggregate_functions(
      before_aggregate, grappler_config)
  # The secure-sum parameters (bitwidth / max input / modulus) do not depend
  # on the computation's argument, so they compile to no-argument functions.
  secure_sum_bitwidth = _compile_selected_output_to_no_argument_tensorflow(
      before_aggregate, ('federated_secure_sum_bitwidth_param', 1),
      grappler_config)
  secure_sum_max_input = _compile_selected_output_to_no_argument_tensorflow(
      before_aggregate, ('federated_secure_sum_param', 1), grappler_config)
  secure_sum_modulus = _compile_selected_output_to_no_argument_tensorflow(
      before_aggregate, ('federated_secure_modular_sum_param', 1),
      grappler_config)
  update = _extract_update(after_aggregate, grappler_config)

  # `next` has exactly two parameters: the server state and the client data.
  next_parameter_names = structure.name_list_with_nones(
      ip.next.type_signature.parameter)
  server_state_label, client_data_label = next_parameter_names

  # Order must match the positional parameters of `forms.MapReduceForm`.
  blocks = (initialize, prepare, work, zero, accumulate, merge, report,
            secure_sum_bitwidth, secure_sum_max_input, secure_sum_modulus,
            update)
  comps = (computation_impl.ConcreteComputation.from_building_block(bb)
           for bb in blocks)
  return forms.MapReduceForm(
      *comps,
      server_state_label=server_state_label,
      client_data_label=client_data_label)
def _extract_update(after_aggregate, grappler_config):
  """Extracts `update` from `after_aggregate`.

  This function is intended to be used by
  `get_map_reduce_form_for_iterative_process` only. As a result, this function
  does not assert that `after_aggregate` has the expected structure, the
  caller is expected to perform these checks before calling this function.

  The returned computation takes a single SERVER-placed argument shaped
  `<server_state, <aggregation_results...>>`. It flattens that argument into
  `<server_state, aggregation_results...>` and then feeds the result to the
  update logic extracted from `after_aggregate`.

  Args:
    after_aggregate: The second result of splitting `after_broadcast` on
      aggregate intrinsics.
    grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
      Grappler graph optimization.

  Returns:
    `update` as specified by `forms.MapReduceForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    compiler.MapReduceFormCompilationError: If we extract an AST of the wrong
      type.
  """
  # Zip the (federated) results of `after_aggregate` into one federated value.
  after_aggregate_zipped = building_blocks.Lambda(
      after_aggregate.parameter_name, after_aggregate.parameter_type,
      building_block_factory.create_federated_zip(after_aggregate.result))
  # `create_federated_zip` doesn't have unique reference names, but we need
  # them for `_as_function_of_some_federated_subparameters`.
  after_aggregate_zipped, _ = tree_transformations.uniquify_reference_names(
      after_aggregate_zipped)
  # Paths into the parameter of `after_aggregate_zipped` selecting the inputs
  # of `update`: the server state followed by each aggregation result.
  server_state_index = ('original_arg', 'original_arg', 0)
  aggregate_result_index = ('intrinsic_results', 'federated_aggregate_result')
  secure_sum_bitwidth_result_index = ('intrinsic_results',
                                      'federated_secure_sum_bitwidth_result')
  secure_sum_result_index = ('intrinsic_results', 'federated_secure_sum_result')
  secure_modular_sum_result_index = ('intrinsic_results',
                                     'federated_secure_modular_sum_result')
  # Rewrites the computation to take only the selected subparameters, as a
  # flat sequence in the order listed here.
  update_with_flat_inputs = _as_function_of_some_federated_subparameters(
      after_aggregate_zipped, (
          server_state_index,
          aggregate_result_index,
          secure_sum_bitwidth_result_index,
          secure_sum_result_index,
          secure_modular_sum_result_index,
      ))

  # TODO(b/148942011): The transformation
  # `zip_selection_as_argument_to_lower_level_lambda` does not support selecting
  # from nested structures, therefore we need to transform the input from
  # <server_state, <aggregation_results...>> into
  # <server_state, aggregation_results...>
  # unpack = <v, <...>> -> <v, ...>
  # NOTE: names are drawn from one generator, so `unpack_param_name` and
  # `param_name` (below) are distinct. Keep the two `next()` calls in order.
  name_generator = building_block_factory.unique_name_generator(
      update_with_flat_inputs)
  unpack_param_name = next(name_generator)
  # `.member` unwraps the placement of the flat parameter type, which is
  # presumably SERVER-placed (it is re-placed with `at_server` below).
  original_param_type = update_with_flat_inputs.parameter_type.member
  # Nested form: element 0 stays as-is; elements 1.. are grouped into a struct.
  unpack_param_type = computation_types.StructType([
      original_param_type[0],
      computation_types.StructType(original_param_type[1:]),
  ])
  unpack_param_ref = building_blocks.Reference(unpack_param_name,
                                               unpack_param_type)
  select = lambda bb, i: building_blocks.Selection(bb, index=i)
  # Body: <ref[0], ref[1][0], ref[1][1], ...>, i.e. the nested struct flattened.
  unpack = building_blocks.Lambda(
      unpack_param_name, unpack_param_type,
      building_blocks.Struct([select(unpack_param_ref, 0)] + [
          select(select(unpack_param_ref, 1), i)
          for i in range(len(original_param_type) - 1)
      ]))

  # update = v -> update_with_flat_inputs(federated_map(unpack, v))
  param_name = next(name_generator)
  param_type = computation_types.at_server(unpack_param_type)
  param_ref = building_blocks.Reference(param_name, param_type)
  update = building_blocks.Lambda(
      param_name, param_type,
      building_blocks.Call(
          update_with_flat_inputs,
          building_block_factory.create_federated_map_or_apply(
              unpack, param_ref)))
  return compiler.consolidate_and_extract_local_processing(
      update, grappler_config)