def test_passes_unbound_type_signature_obscured_under_block(self):
  # The block rebinds 'x' to unplaced data after 'y' has captured the
  # SERVER-placed 'x'. The test only requires that strip_placement does not
  # raise on this input; its return value is not inspected.
  server_int_type = computation_types.FederatedType(tf.int32, placements.SERVER)
  outer_x = building_blocks.Reference('x', server_int_type)
  block_locals = [
      ('y', outer_x),
      ('x', building_blocks.Data('whimsy', tf.int32)),
      ('z', building_blocks.Reference('x', tf.int32)),
  ]
  block_result = building_blocks.Reference('y', outer_x.type_signature)
  tree_transformations.strip_placement(
      building_blocks.Block(block_locals, block_result))
def test_passes_noarg_lambda(self):
  # A zero-argument lambda evaluated through federated_eval_at_server; the
  # test only requires that strip_placement does not raise.
  no_arg_fn = building_blocks.Lambda(None, None,
                                     building_blocks.Data('a', tf.int32))
  eval_signature = computation_types.FunctionType(
      no_arg_fn.type_signature,
      computation_types.FederatedType(tf.int32, placements.SERVER))
  eval_intrinsic = building_blocks.Intrinsic(
      intrinsic_defs.FEDERATED_EVAL_AT_SERVER.uri, eval_signature)
  tree_transformations.strip_placement(
      building_blocks.Call(eval_intrinsic, no_arg_fn))
def test_raises_multiple_placements(self):
  at_server_ref = building_blocks.Reference(
      'x', computation_types.at_server(tf.int32))
  at_clients_ref = building_blocks.Reference(
      'y', computation_types.at_clients(tf.int32))
  # A SERVER-placed local feeding a block whose result is CLIENTS-placed, so
  # the computation mixes two placements.
  mixed_placements = building_blocks.Block([('x', at_server_ref)],
                                           at_clients_ref)
  with self.assertRaisesRegex(ValueError, 'multiple different placements'):
    tree_transformations.strip_placement(mixed_placements)
def test_raises_disallowed_intrinsic(self):
  server_ref = building_blocks.Reference(
      'x', computation_types.FederatedType(tf.int32, placements.SERVER))
  broadcast_result_type = computation_types.FederatedType(
      server_ref.type_signature.member, placements.CLIENTS, all_equal=True)
  broadcast_fn = building_blocks.Intrinsic(
      intrinsic_defs.FEDERATED_BROADCAST.uri,
      computation_types.FunctionType(server_ref.type_signature,
                                     broadcast_result_type))
  # federated_broadcast moves a value from SERVER to CLIENTS; per this test's
  # name it is not an intrinsic strip_placement allows.
  broadcast_call = building_blocks.Call(broadcast_fn, server_ref)
  with self.assertRaises(ValueError):
    tree_transformations.strip_placement(broadcast_call)
def test_unwrap_removes_federated_zips_at_clients(self):
  """Zipping an unzipped CLIENTS-placed struct strips down to the bare struct.

  The input must be CLIENTS-placed. Placing it at SERVER would only repeat
  the server variant of this test and never exercise the CLIENTS-placed
  zip/map intrinsics.
  """
  list_type = computation_types.to_type([tf.int32, tf.float32] * 2)
  clients_list_type = computation_types.at_clients(list_type)
  fed_tuple = building_blocks.Reference('tup', clients_list_type)
  unzipped = building_block_factory.create_federated_unzip(fed_tuple)
  before = building_block_factory.create_federated_zip(unzipped)
  after, modified = tree_transformations.strip_placement(before)
  self.assertTrue(modified)
  self.assert_has_no_intrinsics_nor_federated_types(after)
  type_test_utils.assert_types_identical(before.type_signature,
                                         clients_list_type)
  type_test_utils.assert_types_identical(after.type_signature, list_type)
def test_removes_federated_types_under_function(self):
  member_type = tf.int32
  identity_fn = building_blocks.Lambda(
      'x', member_type, building_blocks.Reference('x', member_type))
  server_value = building_blocks.Reference(
      'x', computation_types.at_server(member_type))
  # Apply the unplaced identity twice to a SERVER-placed value; stripping
  # must remove both the intrinsics and every federated type.
  inner_apply = building_block_factory.create_federated_map_or_apply(
      identity_fn, server_value)
  before = building_block_factory.create_federated_map_or_apply(
      identity_fn, inner_apply)
  after, modified = tree_transformations.strip_placement(before)
  self.assertTrue(modified)
  self.assert_has_no_intrinsics_nor_federated_types(after)
def test_strip_placement_with_called_lambda(self):
  tensor_type = computation_types.TensorType(tf.int32)
  placed_type = computation_types.at_server(tensor_type)
  # An identity lambda over a SERVER-placed parameter, applied to a
  # SERVER-placed reference.
  placed_identity = building_blocks.Lambda(
      'inner', placed_type, building_blocks.Reference('inner', placed_type))
  before = building_blocks.Call(
      placed_identity, building_blocks.Reference('outer', placed_type))
  after, modified = tree_transformations.strip_placement(before)
  self.assertTrue(modified)
  self.assert_has_no_intrinsics_nor_federated_types(after)
  type_test_utils.assert_types_identical(before.type_signature, placed_type)
  type_test_utils.assert_types_identical(after.type_signature, tensor_type)
def test_strip_placement_nested_federated_type(self):
  tensor_type = computation_types.TensorType(tf.int32)
  placed_type = computation_types.at_server(tensor_type)
  placed_ref = building_blocks.Reference('x', placed_type)
  # A tuple holding the same SERVER-placed reference twice.
  before = building_blocks.Struct([placed_ref] * 2, container_type=tuple)
  after, modified = tree_transformations.strip_placement(before)
  self.assertTrue(modified)
  self.assert_has_no_intrinsics_nor_federated_types(after)
  expected_before = computation_types.to_type((placed_type, placed_type))
  expected_after = computation_types.to_type((tensor_type, tensor_type))
  type_test_utils.assert_types_identical(before.type_signature,
                                         expected_before)
  type_test_utils.assert_types_identical(after.type_signature, expected_after)
def test_strip_placement_federated_value_at_clients(self):
  # One CLIENTS-placed federated value per dtype, zipped into a single
  # CLIENTS-placed tuple.
  clients_values = [
      building_block_factory.create_federated_value(
          building_blocks.Data('x', dtype), placements.CLIENTS)
      for dtype in (tf.int32, tf.float32)
  ]
  before = building_block_factory.create_federated_zip(
      building_blocks.Struct(clients_values, container_type=tuple))
  after, modified = tree_transformations.strip_placement(before)
  self.assertTrue(modified)
  self.assert_has_no_intrinsics_nor_federated_types(after)
  member_type = computation_types.StructWithPythonType(
      [(None, tf.int32), (None, tf.float32)], tuple)
  type_test_utils.assert_types_identical(
      before.type_signature, computation_types.at_clients(member_type))
  type_test_utils.assert_types_identical(after.type_signature, member_type)
def test_strip_placement_removes_federated_maps(self):
  tensor_type = computation_types.TensorType(tf.int32)
  placed_type = computation_types.at_clients(tensor_type)
  identity_fn = building_blocks.Lambda(
      'x', tensor_type, building_blocks.Reference('x', tensor_type))
  # Two nested federated_maps of the identity over a CLIENTS-placed value.
  inner_map = building_block_factory.create_federated_map_or_apply(
      identity_fn, building_blocks.Reference('x', placed_type))
  before = building_block_factory.create_federated_map_or_apply(
      identity_fn, inner_map)
  after, modified = tree_transformations.strip_placement(before)
  self.assertTrue(modified)
  self.assert_has_no_intrinsics_nor_federated_types(after)
  type_test_utils.assert_types_identical(before.type_signature, placed_type)
  type_test_utils.assert_types_identical(after.type_signature, tensor_type)
  # The federated_maps should be rewritten into plain nested calls.
  expected_before_repr = (
      'federated_map(<(x -> x),federated_map(<(x -> x),x>)>)')
  expected_after_repr = '(x -> x)((x -> x)(x))'
  self.assertEqual(before.compact_representation(), expected_before_repr)
  self.assertEqual(after.compact_representation(), expected_after_repr)
def consolidate_and_extract_local_processing(comp, grappler_config_proto):
  """Consolidates all of the local processing in `comp`.

  Requirements on the input `comp`:

  1. The result of `comp` is either of a federated type or unplaced. The
     placement `p` of that type is called the placement of `comp`, and no
     placement other than `p` may appear anywhere in the body of `comp`. If
     `comp` has a functional type with a parameter, that parameter's type is
     federated and placed at `p` as well. If the function's result is
     unplaced, the parameter is unplaced too.
  2. Only intrinsics that manipulate data locally, within a single placement,
     may appear in the body of `comp`. This set will be updated gradually.
     It is currently:

     * `federated_apply` when `comp` is `SERVER`-placed, or `federated_map`
       when it is `CLIENTS`-placed. `federated_map_all_equal` is also allowed
       in the `CLIENTS` case.
     * `federated_value_at_server` or `federated_value_at_clients`, chosen by
       placement in the same way.
     * `federated_zip_at_server` or `federated_zip_at_clients`, likewise.

     Anything else, including `sequence_*` operators, must already have been
     reduced before this function is called.
  3. `comp` contains no lambdas other than, possibly, `comp` itself as a
     top-level lambda. All other lambdas must already have been reduced. This
     requirement may eventually be relaxed by embedding a lambda reducer in
     this helper.
  4. A functional `comp` is either a `building_blocks.CompiledComputation`,
     in which case there is nothing to do, or a `building_blocks.Lambda`.
  5. `comp` has at most one unbound reference, and only when `comp` is not of
     a functional type.

  Apart from the intrinsics above and the lambdas, blocks and references
  these constraints permit, `comp` consists of tuples, selections, calls,
  and sections of TensorFlow (as `CompiledComputation`s). This helper
  contains the logic that consolidates those constructs.

  The output is always a single section of TensorFlow, called `result`
  below. Its exact form depends on the placement of `comp` and on whether
  `comp` takes an argument:

  a. No argument, `SERVER`-placed: `comp` is equivalent to
     ```
     federated_value_at_server(result())
     ```
  b. No argument, `CLIENTS`-placed: `comp` is equivalent to
     ```
     federated_value_at_clients(result())
     ```
  c. With an argument, `SERVER`-placed: `comp` is equivalent to
     ```
     (arg -> federated_apply(<result, arg>))
     ```
  d. With an argument, `CLIENTS`-placed: `comp` is equivalent to
     ```
     (arg -> federated_map(<result, arg>))
     ```

  If `comp` has the non-functional type `T@p`, `result` has type `T`, where
  `p` is the specific (concrete) placement of `comp`. If `comp` has type
  `(T@p -> U@p)`, `result` must have type `(T -> U)`.

  NOTE(review): the body calls `comp.type_signature.check_function()`, which
  presumably rejects a non-functional `comp` like those in cases a/b and
  requirement 5 above. Confirm which contract is intended.

  Args:
    comp: A `building_blocks.ComputationBuildingBlock` that is the input to
      this transformation, subject to the requirements above.
    grappler_config_proto: A `tf.compat.v1.ConfigProto` that configures
      Grappler optimization of the generated TensorFlow graph. If it sets
      `graph_options.rewrite_options.disable_meta_optimizer=True`, Grappler
      is bypassed.

  Returns:
    A `building_blocks.CompiledComputation` holding the TensorFlow section
    produced by this extraction step, as described above.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  comp.type_signature.check_function()
  # Call-dominant form drops unused subcomputations. Such subcomputations
  # might reference placements that differ from the result's placement.
  call_dominant = transformations.to_call_dominant(comp)
  placement_free, _ = tree_transformations.strip_placement(call_dominant)
  compiled = parse_tff_to_tf(placement_free, grappler_config_proto)
  check_extraction_result(placement_free, compiled)
  return compiled
def _compile_to_tf(fn):
  """Converts `fn` to call-dominant form, strips placement, and compiles it.

  Returns whatever
  `compiler.compile_local_subcomputations_to_tensorflow` produces for the
  placement-stripped computation.
  """
  placement_free, _ = tree_transformations.strip_placement(
      transformations.to_call_dominant(fn))
  return compiler.compile_local_subcomputations_to_tensorflow(placement_free)
def test_computation_non_federated_type(self):
  # An unplaced Data node has no placement to strip. The result must equal
  # the input and be reported as unmodified.
  unplaced_data = building_blocks.Data('x', tf.int32)
  result, was_modified = tree_transformations.strip_placement(unplaced_data)
  self.assertEqual(unplaced_data, result)
  self.assertFalse(was_modified)
def test_raises_on_none(self):
  # Passing None instead of a building block must raise TypeError.
  self.assertRaises(TypeError, tree_transformations.strip_placement, None)