def test_replaces_lambda_to_called_composition_of_tf_blocks_with_tf_of_same_type_named_param(
    self):
  """Parses a composition of TF blocks over a named-tuple parameter."""
  param_type = computation_types.StructType([('a', tf.int32),
                                             ('b', tf.float32)])
  select_first = _create_compiled_computation(lambda x: x[0], param_type)
  plus_one = _create_compiled_computation(
      lambda x: x + 1, computation_types.TensorType(tf.int32))
  arg_ref = building_blocks.Reference('x', [('a', tf.int32),
                                            ('b', tf.float32)])
  selected = building_blocks.Call(select_first, arg_ref)
  incremented = building_blocks.Call(plus_one, selected)
  wrapper = building_blocks.Lambda('x', [('a', tf.int32), ('b', tf.float32)],
                                   incremented)
  parsed, modified = parse_tff_to_tf(wrapper)
  self.assertIsInstance(parsed, building_blocks.CompiledComputation)
  self.assertTrue(modified)
  # TODO(b/157172423): change to assertEqual when Py container is preserved.
  parsed.type_signature.check_equivalent_to(wrapper.type_signature)
  run_lambda = computation_wrapper_instances.building_block_to_computation(
      wrapper)
  run_tf = computation_wrapper_instances.building_block_to_computation(parsed)
  self.assertEqual(
      run_lambda({
          'a': 15,
          'b': 16.
      }), run_tf({
          'a': 15,
          'b': 16.
      }))
def test_replaces_lambda_to_called_graph_on_tuple_of_selections_from_arg_with_tf_of_same_type(
    self):
  """Parses a lambda calling a TF identity on selected tuple elements.

  Verifies the parsed result is a single `CompiledComputation` with the same
  type signature and the same input/output behavior as the original lambda.
  """
  identity_tf_block = building_block_factory.create_compiled_identity(
      [tf.int32, tf.bool])
  tuple_ref = building_blocks.Reference('x', [tf.int32, tf.float32, tf.bool])
  selected_int = building_blocks.Selection(tuple_ref, index=0)
  selected_bool = building_blocks.Selection(tuple_ref, index=2)
  created_tuple = building_blocks.Tuple([selected_int, selected_bool])
  called_tf_block = building_blocks.Call(identity_tf_block, created_tuple)
  lambda_wrapper = building_blocks.Lambda('x',
                                          [tf.int32, tf.float32, tf.bool],
                                          called_tf_block)
  parsed, modified = parse_tff_to_tf(lambda_wrapper)
  self.assertIsInstance(parsed, building_blocks.CompiledComputation)
  self.assertTrue(modified)
  self.assertEqual(parsed.type_signature, lambda_wrapper.type_signature)
  # Build the executable computations exactly once; the original test
  # redundantly constructed this pair twice.
  exec_lambda = computation_wrapper_instances.building_block_to_computation(
      lambda_wrapper)
  exec_tf = computation_wrapper_instances.building_block_to_computation(
      parsed)
  self.assertEqual(exec_lambda([7, 8., True]), exec_tf([7, 8., True]))
def test_replaces_lambda_to_called_graph_on_selection_from_arg_with_tf_of_same_type_with_names(
    self):
  """Parses a lambda selecting a named field and calling a TF identity."""
  tf_identity = building_block_factory.create_compiled_identity(tf.int32)
  named_tuple_ref = building_blocks.Reference('x', [('a', tf.int32),
                                                    ('b', tf.float32)])
  first_element = building_blocks.Selection(named_tuple_ref, index=0)
  graph_call = building_blocks.Call(tf_identity, first_element)
  wrapper = building_blocks.Lambda('x', [('a', tf.int32), ('b', tf.float32)],
                                   graph_call)
  parsed, modified = parse_tff_to_tf(wrapper)
  self.assertIsInstance(parsed, building_blocks.CompiledComputation)
  self.assertTrue(modified)
  self.assertEqual(parsed.type_signature, wrapper.type_signature)
  run_lambda = computation_wrapper_instances.building_block_to_computation(
      wrapper)
  run_tf = computation_wrapper_instances.building_block_to_computation(parsed)
  self.assertEqual(
      run_lambda({
          'a': 5,
          'b': 6.
      }), run_tf({
          'a': 5,
          'b': 6.
      }))
def test_replaces_lambda_to_called_composition_of_tf_blocks_with_tf_of_same_type_named_param(
    self):
  """Parses composed TF blocks (select then add one) over named params."""
  named_type = [('a', tf.int32), ('b', tf.float32)]
  select_first = _create_compiled_computation(lambda x: x[0], named_type)
  plus_one = _create_compiled_computation(lambda x: x + 1, tf.int32)
  arg_ref = building_blocks.Reference('x', [('a', tf.int32),
                                            ('b', tf.float32)])
  selected = building_blocks.Call(select_first, arg_ref)
  incremented = building_blocks.Call(plus_one, selected)
  wrapper = building_blocks.Lambda('x', [('a', tf.int32), ('b', tf.float32)],
                                   incremented)
  parsed, modified = parse_tff_to_tf(wrapper)
  self.assertIsInstance(parsed, building_blocks.CompiledComputation)
  self.assertTrue(modified)
  self.assertEqual(parsed.type_signature, wrapper.type_signature)
  run_lambda = computation_wrapper_instances.building_block_to_computation(
      wrapper)
  run_tf = computation_wrapper_instances.building_block_to_computation(parsed)
  self.assertEqual(
      run_lambda({
          'a': 15,
          'b': 16.
      }), run_tf({
          'a': 15,
          'b': 16.
      }))
def test_replaces_lambda_to_named_tuple_of_called_graphs_with_tf_of_same_type(
    self):
  """Parses a lambda returning a named tuple of two called TF identities."""
  int_identity = building_block_factory.create_compiled_identity(tf.int32)
  float_identity = building_block_factory.create_compiled_identity(tf.float32)
  param_ref = building_blocks.Reference('x', [tf.int32, tf.float32])
  int_sel = building_blocks.Selection(param_ref, index=0)
  float_sel = building_blocks.Selection(param_ref, index=1)
  named_result = building_blocks.Tuple([
      ('a', building_blocks.Call(int_identity, int_sel)),
      ('b', building_blocks.Call(float_identity, float_sel)),
  ])
  wrapper = building_blocks.Lambda('x', [tf.int32, tf.float32], named_result)
  parsed, modified = parse_tff_to_tf(wrapper)
  self.assertIsInstance(parsed, building_blocks.CompiledComputation)
  self.assertTrue(modified)
  self.assertEqual(parsed.type_signature, wrapper.type_signature)
  run_lambda = computation_wrapper_instances.building_block_to_computation(
      wrapper)
  run_tf = computation_wrapper_instances.building_block_to_computation(parsed)
  self.assertEqual(run_lambda([13, 14.]), run_tf([13, 14.]))
def test_replaces_lambda_to_named_tuple_of_called_graphs_with_tf_of_same_type(
    self):
  """Parses a lambda returning a named struct of two called TF identities."""
  int_identity = building_block_factory.create_compiled_identity(
      computation_types.TensorType(tf.int32))
  float_identity = building_block_factory.create_compiled_identity(
      computation_types.TensorType(tf.float32))
  param_ref = building_blocks.Reference('x', [tf.int32, tf.float32])
  int_sel = building_blocks.Selection(param_ref, index=0)
  float_sel = building_blocks.Selection(param_ref, index=1)
  named_result = building_blocks.Struct([
      ('a', building_blocks.Call(int_identity, int_sel)),
      ('b', building_blocks.Call(float_identity, float_sel)),
  ])
  wrapper = building_blocks.Lambda('x', [tf.int32, tf.float32], named_result)
  parsed, modified = parse_tff_to_tf(wrapper)
  self.assertIsInstance(parsed, building_blocks.CompiledComputation)
  self.assertTrue(modified)
  # TODO(b/157172423): change to assertEqual when Py container is preserved.
  parsed.type_signature.check_equivalent_to(wrapper.type_signature)
  run_lambda = computation_wrapper_instances.building_block_to_computation(
      wrapper)
  run_tf = computation_wrapper_instances.building_block_to_computation(parsed)
  self.assertEqual(run_lambda([13, 14.]), run_tf([13, 14.]))
def test_replaces_lambda_to_selection_from_called_graph_with_tf_of_same_type(
    self):
  """Parses a lambda selecting one element from a called TF identity."""
  graph_type = computation_types.StructType([tf.int32, tf.float32])
  tf_identity = building_block_factory.create_compiled_identity(graph_type)
  param_ref = building_blocks.Reference('x', [tf.int32, tf.float32])
  graph_call = building_blocks.Call(tf_identity, param_ref)
  second_element = building_blocks.Selection(graph_call, index=1)
  wrapper = building_blocks.Lambda('x', [tf.int32, tf.float32],
                                   second_element)
  parsed, modified = parse_tff_to_tf(wrapper)
  self.assertIsInstance(parsed, building_blocks.CompiledComputation)
  self.assertTrue(modified)
  # TODO(b/157172423): change to assertEqual when Py container is preserved.
  parsed.type_signature.check_equivalent_to(wrapper.type_signature)
  run_lambda = computation_wrapper_instances.building_block_to_computation(
      wrapper)
  run_tf = computation_wrapper_instances.building_block_to_computation(parsed)
  self.assertEqual(run_lambda([0, 1.]), run_tf([0, 1.]))
def test_replaces_lambda_to_called_graph_on_tuple_of_selections_from_arg_with_tf_of_same_type(
    self):
  """Parses a lambda calling a TF identity on a struct of selections.

  Verifies the parsed result is a single `CompiledComputation` equivalent in
  type and behavior to the original lambda.
  """
  identity_tf_block_type = computation_types.StructType([tf.int32, tf.bool])
  identity_tf_block = building_block_factory.create_compiled_identity(
      identity_tf_block_type)
  tuple_ref = building_blocks.Reference('x', [tf.int32, tf.float32, tf.bool])
  selected_int = building_blocks.Selection(tuple_ref, index=0)
  selected_bool = building_blocks.Selection(tuple_ref, index=2)
  created_tuple = building_blocks.Struct([selected_int, selected_bool])
  called_tf_block = building_blocks.Call(identity_tf_block, created_tuple)
  lambda_wrapper = building_blocks.Lambda('x',
                                          [tf.int32, tf.float32, tf.bool],
                                          called_tf_block)
  parsed, modified = parse_tff_to_tf(lambda_wrapper)
  self.assertIsInstance(parsed, building_blocks.CompiledComputation)
  self.assertTrue(modified)
  # TODO(b/157172423): change to assertEqual when Py container is preserved.
  parsed.type_signature.check_equivalent_to(lambda_wrapper.type_signature)
  # Build the executable computations exactly once; the original test
  # redundantly constructed this pair twice.
  exec_lambda = computation_wrapper_instances.building_block_to_computation(
      lambda_wrapper)
  exec_tf = computation_wrapper_instances.building_block_to_computation(
      parsed)
  self.assertEqual(exec_lambda([7, 8., True]), exec_tf([7, 8., True]))
def test_reduces_unplaced_lambda_to_equivalent_tf(self):
  """An unplaced identity lambda extracts to behaviorally-equivalent TF."""
  identity = building_blocks.Lambda(
      'x', tf.int32, building_blocks.Reference('x', tf.int32))
  extracted = transformations.consolidate_and_extract_local_processing(
      identity)
  run_tf = computation_wrapper_instances.building_block_to_computation(
      extracted)
  run_lambda = computation_wrapper_instances.building_block_to_computation(
      identity)
  # Spot-check equivalence over a small range of inputs.
  for value in range(10):
    self.assertEqual(run_tf(value), run_lambda(value))
def test_reduces_federated_apply_to_equivalent_function(self):
  """A federated map of an identity lambda extracts to equivalent TF."""
  identity = building_blocks.Lambda(
      'x', tf.int32, building_blocks.Reference('x', tf.int32))
  fed_arg = building_blocks.Reference(
      'arg', computation_types.FederatedType(tf.int32, placements.CLIENTS))
  mapped = building_block_factory.create_federated_map_or_apply(
      identity, fed_arg)
  extracted = mapreduce_transformations.consolidate_and_extract_local_processing(
      mapped)
  self.assertIsInstance(extracted, building_blocks.CompiledComputation)
  run_tf = computation_wrapper_instances.building_block_to_computation(
      extracted)
  run_lambda = computation_wrapper_instances.building_block_to_computation(
      identity)
  # Spot-check equivalence over a small range of inputs.
  for value in range(10):
    self.assertEqual(run_tf(value), run_lambda(value))
def transform_to_native_form(
    comp: computation_base.Computation) -> computation_base.Computation:
  """Compiles a computation for execution in the TFF native runtime.

  This function transforms the proto underlying `comp` by transforming it
  to call-dominant form (see `tff.framework.transform_to_call_dominant` for
  definition).

  NOTE(review): an earlier docstring here described dependency analysis and
  tuple remapping; the body only performs the call-dominant transformation,
  so the docstring has been corrected to match.

  Args:
    comp: Instance of `computation_base.Computation` to compile.

  Returns:
    A new `computation_base.Computation` representing the compiled version of
    `comp`, or `comp` itself if compilation fails.
  """
  proto = computation_impl.ComputationImpl.get_proto(comp)
  computation_building_block = building_blocks.ComputationBuildingBlock.from_proto(
      proto)
  try:
    logging.debug('Compiling TFF computation.')
    call_dominant_form, _ = transformations.transform_to_call_dominant(
        computation_building_block)
    logging.debug('Computation compiled to:')
    logging.debug(call_dominant_form.formatted_representation())
    return computation_wrapper_instances.building_block_to_computation(
        call_dominant_form)
  except ValueError as e:
    # Compilation is best-effort: fall back to the uncompiled computation.
    logging.debug('Compilation for native runtime failed with error %s', e)
    logging.debug('computation: %s',
                  computation_building_block.compact_representation())
    return comp
def compiler(comp):
  """Rewrites secure intrinsics in `comp` into insecure bodies for tests."""
  # Compile secure_sum and secure_sum_bitwidth intrinsics to insecure
  # TensorFlow computations for testing purposes.
  insecure_bodies, _ = (
      intrinsic_reductions.replace_secure_intrinsics_with_insecure_bodies(
          comp.to_building_block()))
  return computation_wrapper_instances.building_block_to_computation(
      insecure_bodies)
def transform_mathematical_functions_to_tensorflow(
    comp: computation_base.Computation,
) -> computation_base.Computation:
  """Compiles all mathematical functions in `comp` to TensorFlow blocks.

  Notice that this does not necessarily represent a strict performance
  improvement. In particular, this compilation will not attempt to
  deduplicate across the boundaries of communication operators, and therefore
  it may be the case that compiling eagerly to TensorFlow hides the
  opportunity for a dynamic cache to be used.

  Args:
    comp: Instance of `computation_base.Computation` to compile.

  Returns:
    A new `computation_base.Computation` representing the compiled version of
    `comp`, or `comp` itself if compilation fails.
  """
  building_block = building_blocks.ComputationBuildingBlock.from_proto(
      computation_impl.ComputationImpl.get_proto(comp))
  try:
    logging.debug('Compiling local computations to TensorFlow.')
    compiled, _ = transformations.compile_local_computation_to_tensorflow(
        building_block)
    logging.debug('Local computations compiled to TF:')
    logging.debug(compiled.formatted_representation())
    return computation_wrapper_instances.building_block_to_computation(
        compiled)
  except ValueError as e:
    # Best-effort: on failure, return the original computation unchanged.
    logging.debug(
        'Compilation of local computation to TensorFlow failed with error %s',
        e)
    logging.debug('computation: %s', building_block.compact_representation())
    return comp
def test_identity_lambda_executes_as_identity(self):
  """An identity lambda wrapped as a computation returns its input."""
  identity = building_blocks.Lambda(
      'x', tf.int32, building_blocks.Reference('x', tf.int32))
  wrapped = computation_wrapper_instances.building_block_to_computation(
      identity)
  for value in range(10):
    self.assertEqual(wrapped(value), value)
def transform_to_native_form(
    comp: computation_base.Computation) -> computation_base.Computation:
  """Compiles a computation for execution in the TFF native runtime.

  This function transforms the proto underlying `comp` by transforming it
  to call-dominant form (see `tff.framework.transform_to_call_dominant` for
  definition).

  Args:
    comp: Instance of `computation_base.Computation` to compile.

  Returns:
    A new `computation_base.Computation` representing the compiled version of
    `comp`, or `comp` itself if compilation fails.
  """
  building_block = building_blocks.ComputationBuildingBlock.from_proto(
      computation_impl.ComputationImpl.get_proto(comp))
  try:
    logging.debug('Compiling TFF computation.')
    cdf, _ = transformations.transform_to_call_dominant(building_block)
    logging.debug('Computation compiled to:')
    logging.debug(cdf.formatted_representation())
    return computation_wrapper_instances.building_block_to_computation(cdf)
  except ValueError as e:
    # Best-effort: on failure, return the original computation unchanged.
    logging.debug('Compilation for native runtime failed with error %s', e)
    logging.debug('computation: %s', building_block.compact_representation())
    return comp
def test_reduces_federated_value_at_server_to_equivalent_noarg_function(self):
  """A federated value at SERVER extracts to a no-arg TF computation."""
  fed_zero = intrinsics.federated_value(0, placements.SERVER)._comp
  tf_block = mapreduce_transformations.consolidate_and_extract_local_processing(
      fed_zero)
  runnable = computation_wrapper_instances.building_block_to_computation(
      tf_block)
  self.assertEqual(runnable(), 0)
def test_converts_building_block_to_computation(self):
  """The wrapper yields a `computation_impl.ComputationImpl` instance."""
  identity = building_blocks.Lambda(
      'x', tf.int32, building_blocks.Reference('x', tf.int32))
  wrapped = computation_wrapper_instances.building_block_to_computation(
      identity)
  self.assertIsInstance(wrapped, computation_impl.ComputationImpl)
def get_broadcast_form_for_computation(
    comp: computation_base.Computation,
    grappler_config: tf.compat.v1.ConfigProto = _GRAPPLER_DEFAULT_CONFIG
) -> forms.BroadcastForm:
  """Constructs `tff.backends.mapreduce.BroadcastForm` given a computation.

  Args:
    comp: An instance of `tff.Computation` that is compatible with broadcast
      form. Computations are only compatible if they take in a single value
      placed at server, return a single value placed at clients, and do not
      contain any aggregations.
    grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
      Grappler graph optimization of the Tensorflow graphs backing the
      resulting `tff.backends.mapreduce.BroadcastForm`. These options are
      combined with a set of defaults that aggressively configure Grappler.
      If `grappler_config_proto` has
      `graph_options.rewrite_options.disable_meta_optimizer=True`, Grappler
      is bypassed.

  Returns:
    An instance of `tff.backends.mapreduce.BroadcastForm` equivalent to the
    provided `tff.Computation`.

  Raises:
    ValueError: If `comp` contains aggregations.
  """
  py_typecheck.check_type(comp, computation_base.Computation)
  _check_function_signature_compatible_with_broadcast_form(comp.type_signature)
  py_typecheck.check_type(grappler_config, tf.compat.v1.ConfigProto)
  grappler_config = _merge_grappler_config_with_default(grappler_config)

  bb = comp.to_building_block()
  bb, _ = intrinsic_reductions.replace_intrinsics_with_bodies(bb)
  bb = _replace_lambda_body_with_call_dominant_form(bb)
  tree_analysis.check_contains_only_reducible_intrinsics(bb)
  aggregations = tree_analysis.find_aggregations_in_tree(bb)
  if aggregations:
    # Bug fix: the final fragment of this message was a plain (non-f) string,
    # so `{aggregations}` was emitted literally instead of being interpolated.
    raise ValueError(
        f'`get_broadcast_form_for_computation` called with computation '
        f'containing {len(aggregations)} aggregations, but broadcast form '
        f'does not allow aggregation. Full list of aggregations:'
        f'\n{aggregations}')

  before_broadcast, after_broadcast = _split_ast_on_broadcast(bb)
  compute_server_context = _extract_compute_server_context(
      before_broadcast, grappler_config)
  client_processing = _extract_client_processing(after_broadcast,
                                                 grappler_config)
  # Wrap both extracted building blocks as executable computations.
  compute_server_context, client_processing = (
      computation_wrapper_instances.building_block_to_computation(bb)
      for bb in (compute_server_context, client_processing))

  comp_param_names = structure.name_list_with_nones(
      comp.type_signature.parameter)
  server_data_label, client_data_label = comp_param_names

  return forms.BroadcastForm(
      compute_server_context,
      client_processing,
      server_data_label=server_data_label,
      client_data_label=client_data_label)
def get_map_reduce_form_for_iterative_process(
    ip: iterative_process.IterativeProcess,
    grappler_config: tf.compat.v1.ConfigProto = _GRAPPLER_DEFAULT_CONFIG
) -> forms.MapReduceForm:
  """Constructs `tff.backends.mapreduce.MapReduceForm` given iterative process.

  Args:
    ip: An instance of `tff.templates.IterativeProcess` that is compatible
      with MapReduce form. Iterative processes are only compatible if
      `initialize_fn` returns a single federated value placed at `SERVER`,
      `next` takes exactly two arguments (the first being the state value
      placed at `SERVER`), and `next` returns exactly two values.
    grappler_config: An optional instance of `tf.compat.v1.ConfigProto` to
      configure Grappler graph optimization of the TensorFlow graphs backing
      the resulting `tff.backends.mapreduce.MapReduceForm`. These options are
      combined with a set of defaults that aggressively configure Grappler.
      If the input `grappler_config` has
      `graph_options.rewrite_options.disable_meta_optimizer=True`, Grappler
      is bypassed.

  Returns:
    An instance of `tff.backends.mapreduce.MapReduceForm` equivalent to the
    provided `tff.templates.IterativeProcess`.

  Raises:
    TypeError: If the arguments are of the wrong types.
    transformations.MapReduceFormCompilationError: If the compilation process
      fails.
  """
  py_typecheck.check_type(ip, iterative_process.IterativeProcess)
  initialize_bb, next_bb = (
      check_iterative_process_compatible_with_map_reduce_form(ip))
  py_typecheck.check_type(grappler_config, tf.compat.v1.ConfigProto)
  grappler_config = _merge_grappler_config_with_default(grappler_config)
  # Unique reference names are required before splitting the AST.
  next_bb, _ = tree_transformations.uniquify_reference_names(next_bb)
  # Split `next` at the broadcast, then split the remainder at the aggregate.
  before_broadcast, after_broadcast = _split_ast_on_broadcast(next_bb)
  before_aggregate, after_aggregate = _split_ast_on_aggregate(after_broadcast)
  # Extract each MapReduceForm component as a TF computation.
  initialize = transformations.consolidate_and_extract_local_processing(
      initialize_bb, grappler_config)
  prepare = _extract_prepare(before_broadcast, grappler_config)
  work = _extract_work(before_aggregate, grappler_config)
  zero, accumulate, merge, report = _extract_federated_aggregate_functions(
      before_aggregate, grappler_config)
  bitwidth = _extract_federated_secure_sum_bitwidth_functions(
      before_aggregate, grappler_config)
  update = _extract_update(after_aggregate, grappler_config)
  next_parameter_names = structure.name_list_with_nones(
      ip.next.type_signature.parameter)
  server_state_label, client_data_label = next_parameter_names
  comps = (computation_wrapper_instances.building_block_to_computation(bb)
           for bb in (initialize, prepare, work, zero, accumulate, merge,
                      report, bitwidth, update))
  return forms.MapReduceForm(
      *comps,
      server_state_label=server_state_label,
      client_data_label=client_data_label)
def test_replaces_lambda_to_called_graph_with_tf_of_same_type(self):
  """A lambda wrapping a called TF identity parses to a single TF block."""
  tf_identity = building_block_factory.create_compiled_identity(tf.int32)
  arg_ref = building_blocks.Reference('x', tf.int32)
  graph_call = building_blocks.Call(tf_identity, arg_ref)
  wrapper = building_blocks.Lambda('x', tf.int32, graph_call)
  parsed, modified = parse_tff_to_tf(wrapper)
  self.assertIsInstance(parsed, building_blocks.CompiledComputation)
  self.assertTrue(modified)
  self.assertEqual(parsed.type_signature, wrapper.type_signature)
  run_lambda = computation_wrapper_instances.building_block_to_computation(
      wrapper)
  run_tf = computation_wrapper_instances.building_block_to_computation(parsed)
  self.assertEqual(run_lambda(2), run_tf(2))
def test_replaces_lambda_to_called_tf_block_with_replicated_lambda_arg_with_tf_block_of_same_type(
    self):
  """Parses a lambda that feeds its argument twice into one TF block."""
  sum_plus_one = _create_compiled_computation(lambda x: x[0] + x[1] + 1,
                                              [tf.int32, tf.int32])
  arg_ref = building_blocks.Reference('x', tf.int32)
  replicated_arg = building_blocks.Tuple((arg_ref, arg_ref))
  result = building_blocks.Call(sum_plus_one, replicated_arg)
  wrapper = building_blocks.Lambda('x', tf.int32, result)
  parsed, modified = parse_tff_to_tf(wrapper)
  self.assertIsInstance(parsed, building_blocks.CompiledComputation)
  self.assertTrue(modified)
  self.assertEqual(parsed.type_signature, wrapper.type_signature)
  run_lambda = computation_wrapper_instances.building_block_to_computation(
      wrapper)
  run_tf = computation_wrapper_instances.building_block_to_computation(parsed)
  self.assertEqual(run_lambda(17), run_tf(17))
def test_replaces_lambda_to_called_graph_with_tf_of_same_type(self):
  """A lambda wrapping a called TF identity parses to a single TF block."""
  tf_identity = building_block_factory.create_compiled_identity(
      computation_types.TensorType(tf.int32))
  arg_ref = building_blocks.Reference('x', tf.int32)
  graph_call = building_blocks.Call(tf_identity, arg_ref)
  wrapper = building_blocks.Lambda('x', tf.int32, graph_call)
  parsed, modified = parse_tff_to_tf(wrapper)
  self.assertIsInstance(parsed, building_blocks.CompiledComputation)
  self.assertTrue(modified)
  # TODO(b/157172423): change to assertEqual when Py container is preserved.
  parsed.type_signature.check_equivalent_to(wrapper.type_signature)
  run_lambda = computation_wrapper_instances.building_block_to_computation(
      wrapper)
  run_tf = computation_wrapper_instances.building_block_to_computation(parsed)
  self.assertEqual(run_lambda(2), run_tf(2))
def test_reduces_federated_value_at_server_to_equivalent_noarg_function(self):
  """A constant federated value at SERVER extracts to a no-arg TF block."""
  zero_const = building_block_factory.create_tensorflow_constant(
      computation_types.TensorType(tf.int32, shape=[]), 0)
  fed_value = building_block_factory.create_federated_value(
      zero_const, placements.SERVER)
  tf_block = transformations.consolidate_and_extract_local_processing(
      fed_value, DEFAULT_GRAPPLER_CONFIG)
  runnable = computation_wrapper_instances.building_block_to_computation(
      tf_block)
  self.assertEqual(runnable(), 0)
def test_next_computation_returning_tensor_fails_well(self):
  """Canonical-form conversion rejects a `next` with the wrong signature."""
  cf = test_utils.get_temperature_sensor_example()
  it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
  state_type = it.initialize.type_signature.result
  # An identity over the state alone is not a valid `next` computation.
  identity = building_blocks.Lambda(
      'x', state_type, building_blocks.Reference('x', state_type))
  bad_it = iterative_process.IterativeProcess(
      it.initialize,
      computation_wrapper_instances.building_block_to_computation(identity))
  with self.assertRaises(TypeError):
    canonical_form_utils.get_canonical_form_for_iterative_process(bad_it)
def test_replaces_lambda_to_selection_from_called_graph_with_tf_of_same_type(
    self):
  """Parses a lambda selecting one element from a called TF identity."""
  tf_identity = building_block_factory.create_compiled_identity(
      [tf.int32, tf.float32])
  param_ref = building_blocks.Reference('x', [tf.int32, tf.float32])
  graph_call = building_blocks.Call(tf_identity, param_ref)
  second_element = building_blocks.Selection(graph_call, index=1)
  wrapper = building_blocks.Lambda('x', [tf.int32, tf.float32],
                                   second_element)
  parsed, modified = parse_tff_to_tf(wrapper)
  self.assertIsInstance(parsed, building_blocks.CompiledComputation)
  self.assertTrue(modified)
  self.assertEqual(parsed.type_signature, wrapper.type_signature)
  run_lambda = computation_wrapper_instances.building_block_to_computation(
      wrapper)
  run_tf = computation_wrapper_instances.building_block_to_computation(parsed)
  self.assertEqual(run_lambda([0, 1.]), run_tf([0, 1.]))
def test_replaces_lambda_to_called_tf_block_with_replicated_lambda_arg_with_tf_block_of_same_type(
    self):
  """Parses a lambda that feeds its argument twice into one TF block."""
  pair_type = computation_types.StructType([tf.int32, tf.int32])
  sum_plus_one = _create_compiled_computation(lambda x: x[0] + x[1] + 1,
                                              pair_type)
  arg_ref = building_blocks.Reference('x', tf.int32)
  replicated_arg = building_blocks.Struct((arg_ref, arg_ref))
  result = building_blocks.Call(sum_plus_one, replicated_arg)
  wrapper = building_blocks.Lambda('x', tf.int32, result)
  parsed, modified = parse_tff_to_tf(wrapper)
  self.assertIsInstance(parsed, building_blocks.CompiledComputation)
  self.assertTrue(modified)
  # TODO(b/157172423): change to assertEqual when Py container is preserved.
  parsed.type_signature.check_equivalent_to(wrapper.type_signature)
  run_lambda = computation_wrapper_instances.building_block_to_computation(
      wrapper)
  run_tf = computation_wrapper_instances.building_block_to_computation(parsed)
  self.assertEqual(run_lambda(17), run_tf(17))
def deserialize_computation(
    computation_proto: pb.Computation) -> computation_base.Computation:
  """Deserializes a `pb.Computation` into a `tff.Computation`.

  NOTE(review): the previous docstring had the direction reversed
  ("Deserializes 'tff.Computation' as a pb.Computation"); the code takes a
  proto and returns a computation.

  Args:
    computation_proto: An instance of `pb.Computation`.

  Returns:
    The corresponding instance of `tff.Computation`.

  Raises:
    TypeError: If the argument is of the wrong type.
  """
  py_typecheck.check_type(computation_proto, pb.Computation)
  return computation_wrapper_instances.building_block_to_computation(
      building_blocks.ComputationBuildingBlock.from_proto(computation_proto))
def _do_not_use_transform_to_native_form(comp):
  """Use `tff.backends.native.transform_to_native_form`."""
  building_block = building_blocks.ComputationBuildingBlock.from_proto(
      computation_impl.ComputationImpl.get_proto(comp))
  try:
    logging.debug('Compiling TFF computation.')
    cdf, _ = transformations.transform_to_call_dominant(building_block)
    logging.debug('Computation compiled to:')
    logging.debug(cdf.formatted_representation())
    return computation_wrapper_instances.building_block_to_computation(cdf)
  except ValueError as e:
    # Best-effort: on failure, return the original computation unchanged.
    logging.debug('Compilation for native runtime failed with error %s', e)
    logging.debug('computation: %s', building_block.compact_representation())
    return comp
def transform_to_native_form(
    comp: computation_base.Computation,
    transform_math_to_tf: bool = False) -> computation_base.Computation:
  """Compiles a computation for execution in the TFF native runtime.

  This function transforms the proto underlying `comp` by transforming it
  to call-dominant form (see `tff.framework.transform_to_call_dominant` for
  definition).

  Args:
    comp: Instance of `computation_base.Computation` to compile.
    transform_math_to_tf: Whether to additional transform math to TensorFlow
      graphs. Necessary if running on a execution state without
      ReferenceResolvingExecutors underneath FederatingExecutors.

  Returns:
    A new `computation_base.Computation` representing the compiled version of
    `comp`, or `comp` itself if compilation fails.
  """
  building_block = building_blocks.ComputationBuildingBlock.from_proto(
      computation_impl.ComputationImpl.get_proto(comp))
  try:
    logging.debug('Compiling TFF computation to CDF.')
    cdf, _ = transformations.transform_to_call_dominant(building_block)
    logging.debug('Computation compiled to:')
    logging.debug(cdf.formatted_representation())
    if transform_math_to_tf:
      logging.debug('Compiling local computations to TensorFlow.')
      cdf, _ = transformations.compile_local_computation_to_tensorflow(cdf)
      logging.debug('Computation compiled to:')
      logging.debug(cdf.formatted_representation())
    # Grappler is disabled on the call ops so the runtime controls it.
    cdf, _ = tree_transformations.transform_tf_call_ops_to_disable_grappler(
        cdf)
    return computation_wrapper_instances.building_block_to_computation(cdf)
  except ValueError as e:
    # Best-effort: on failure, return the original computation unchanged.
    logging.debug('Compilation for native runtime failed with error %s', e)
    logging.debug('computation: %s', building_block.compact_representation())
    return comp
def test_broadcast_dependent_on_aggregate_fails_well(self):
  """Chaining two rounds of `next` makes broadcast depend on aggregate."""
  mrf = mapreduce_test_utils.get_temperature_sensor_example()
  it = form_utils.get_iterative_process_for_map_reduce_form(mrf)
  next_bb = it.next.to_building_block()
  param_ref = building_blocks.Reference(next_bb.parameter_name,
                                        next_bb.parameter_type)
  first_round = building_blocks.Call(next_bb, param_ref)
  # Feed the first round's state, with the original client data, back in.
  second_arg = building_blocks.Struct([
      building_blocks.Selection(first_round, index=0),
      building_blocks.Selection(param_ref, index=1),
  ])
  second_round = building_blocks.Call(next_bb, second_arg)
  chained = building_blocks.Lambda(next_bb.parameter_name,
                                   next_bb.parameter_type, second_round)
  chained_it = iterative_process.IterativeProcess(
      it.initialize,
      computation_wrapper_instances.building_block_to_computation(chained))
  with self.assertRaisesRegex(ValueError, 'broadcast dependent on aggregate'):
    form_utils.get_map_reduce_form_for_iterative_process(chained_it)