def test_returns_false_for_blocks_with_different_variable_values(self):
  # Two blocks share the same result but bind local 'a' to data of different
  # types (float32 vs. bool); they must not compare equal.
  shared_result = building_blocks.Data('data', tf.int32)
  float_local = building_blocks.Data('data', tf.float32)
  bool_local = building_blocks.Data('data', tf.bool)

  float_block = building_blocks.Block([('a', float_local)], shared_result)
  bool_block = building_blocks.Block([('a', bool_local)], shared_result)

  are_equal = tree_analysis.trees_equal(float_block, bool_block)
  self.assertFalse(are_equal)
def test_returns_true_for_blocks_resulting_reference_to_same_local(self):
  # The blocks differ only in the name chosen for their single local ('a' vs.
  # 'b'); each returns a reference to that local, so they should be equal.
  value = building_blocks.Data('data', tf.int32)
  value_type = value.type_signature

  block_using_a = building_blocks.Block(
      [('a', value)], building_blocks.Reference('a', value_type))
  block_using_b = building_blocks.Block(
      [('b', value)], building_blocks.Reference('b', value_type))

  are_equal = tree_analysis.trees_equal(block_using_a, block_using_b)
  self.assertTrue(are_equal)
def test_returns_false_for_blocks_referring_to_different_local(self):
  # Both blocks return a reference named 'a', but the locals are bound in the
  # opposite order, so the comparison must be False in either direction.
  value = building_blocks.Data('data', tf.int32)
  value_type = value.type_signature
  a_ref = building_blocks.Reference('a', value_type)
  b_ref = building_blocks.Reference('b', value_type)

  a_then_b = building_blocks.Block([('a', value), ('b', a_ref)], a_ref)
  b_then_a = building_blocks.Block([('b', value), ('a', b_ref)], a_ref)

  for left, right in ((a_then_b, b_then_a), (b_then_a, a_then_b)):
    self.assertFalse(tree_analysis.trees_equal(left, right))
def test_removes_nested_blocks_with_unused_reference(self):
  # Neither the inner nor the outer block's local is read by the result; after
  # the transform only the bare result should remain.
  expected = building_blocks.Data('b', tf.int32)
  unused_local = building_blocks.Data('a', tf.int32)
  inner_block = building_blocks.Block([('x', unused_local)], expected)
  outer_block = building_blocks.Block([('y', expected)], inner_block)

  remover = self._unused_block_remover.transform
  stripped, was_modified = transformation_utils.transform_postorder(
      outer_block, remover)

  self.assertTrue(was_modified)
  self.assertEqual(stripped.compact_representation(),
                   expected.compact_representation())
def test_returns_tf_computation_with_functional_type_block_to_lambda_with_block(
    self):
  # A block whose result is a lambda, whose own result is again a block,
  # should compile to TensorFlow.
  int_type = computation_types.TensorType(tf.int32)
  x_ref = building_blocks.Reference('x', tf.float32)

  inner_block = building_blocks.Block([('x', x_ref)], x_ref)
  fn = building_blocks.Lambda(x_ref.name, x_ref.type_signature, inner_block)

  unused_constant = building_block_factory.create_tensorflow_constant(
      int_type, 1)
  outer_block = building_blocks.Block([('y', unused_constant)], fn)

  self.assert_compiles_to_tensorflow(outer_block)
def test_nested_blocks(self):
  # Two nested blocks each rebind the name 'a' twice; after uniquifying, every
  # binding except the outermost first one is expected to get a fresh name.
  a_ref = building_blocks.Reference('a', tf.int32)
  data = building_blocks.Data('data', tf.int32)

  def _rebinding_block(result):
    # Binds 'a' to data, then rebinds 'a' to a reference to 'a'.
    return building_blocks.Block([('a', data), ('a', a_ref)], result)

  inner = _rebinding_block(a_ref)
  outer = _rebinding_block(inner)

  renamed, was_modified = tree_transformations.uniquify_reference_names(outer)

  self.assertEqual(outer.compact_representation(),
                   '(let a=data,a=a in (let a=data,a=a in a))')
  self.assertEqual(
      renamed.compact_representation(),
      '(let a=data,_var1=a in (let _var2=data,_var3=_var2 in _var3))')
  tree_analysis.check_has_unique_names(renamed)
  self.assertTrue(was_modified)
def test_returns_single_called_graph_after_resolving_multiple_variables(
    self):
  # A chain of two identity calls bound as block locals should come back as a
  # single compiled computation called on the unbound reference 'var'.
  unbound_ref = building_blocks.Reference('var', tf.int32)

  first_identity = building_block_factory.create_compiled_identity(
      computation_types.TensorType(tf.int32))
  first_call = building_blocks.Call(first_identity, unbound_ref)
  first_call_ref = building_blocks.Reference('call',
                                             first_call.type_signature)

  second_identity = building_block_factory.create_compiled_identity(
      computation_types.TensorType(tf.int32))
  second_call = building_blocks.Call(second_identity, first_call_ref)
  second_call_ref = building_blocks.Reference('second_call',
                                              first_call.type_signature)

  block = building_blocks.Block(
      [('call', first_call), ('second_call', second_call)], second_call_ref)

  as_tf, _ = transformations.create_tensorflow_representing_block(block)

  self.assertEqual(as_tf.type_signature, block.type_signature)
  self.assertIsInstance(as_tf, building_blocks.Call)
  self.assertIsInstance(as_tf.function, building_blocks.CompiledComputation)
  self.assertIsInstance(as_tf.argument, building_blocks.Reference)
  self.assertEqual(as_tf.argument.name, 'var')
def test_returned_tensorflow_executes_correctly_with_no_unbound_refs(self):
  # The constant 1 is fed through two identity calls and returned twice in a
  # tuple; running the generated TensorFlow should produce [1, 1].
  int_type = computation_types.TensorType(tf.int32)
  constant_one = building_block_factory.create_tensorflow_constant(
      int_type, 1)

  first_identity = building_block_factory.create_compiled_identity(
      computation_types.TensorType(tf.int32))
  first_call = building_blocks.Call(first_identity, constant_one)
  first_call_ref = building_blocks.Reference('call',
                                             first_call.type_signature)

  second_identity = building_block_factory.create_compiled_identity(
      computation_types.TensorType(tf.int32))
  second_call = building_blocks.Call(second_identity, first_call_ref)
  second_call_ref = building_blocks.Reference('second_call',
                                              first_call.type_signature)

  doubled_result = building_blocks.Tuple([second_call_ref, second_call_ref])
  block = building_blocks.Block(
      [('call', first_call), ('second_call', second_call)], doubled_result)

  as_tf, _ = transformations.create_tensorflow_representing_block(block)
  executed = test_utils.run_tensorflow(as_tf.function.proto)

  self.assertAllEqual(executed, [1, 1])
def _create_complex_computation():
  """Builds a lambda that runs a broadcast -> map -> mean chain twice.

  Every intermediate intrinsic call is bound as a block local (named
  `broadcast_i`, `map_i`, `mean_i` for `i` in 0..1, in that order) and the
  block result is a struct holding the two `mean_i` references.

  Returns:
    A `building_blocks.Lambda` with parameter name `arg` whose result is the
    block described above.
  """
  identity = building_block_factory.create_compiled_identity(
      computation_types.TensorType(tf.int32), 'a')
  server_int_type = computation_types.FederatedType(tf.int32,
                                                    placements.SERVER)
  arg_ref = building_blocks.Reference('arg', server_int_type)

  block_locals = []
  mean_refs = []

  def _bind(name, value):
    # Records `value` as a block local and hands back a reference to it.
    block_locals.append((name, value))
    return building_blocks.Reference(name, value.type_signature)

  for index in range(2):
    broadcast = building_block_factory.create_federated_broadcast(arg_ref)
    broadcast_ref = _bind(f'broadcast_{index}', broadcast)

    mapped = building_block_factory.create_federated_map(
        identity, broadcast_ref)
    mapped_ref = _bind(f'map_{index}', mapped)

    mean = building_block_factory.create_federated_mean(mapped_ref, None)
    mean_refs.append(_bind(f'mean_{index}', mean))

  body = building_blocks.Block(block_locals,
                               building_blocks.Struct(mean_refs))
  # NOTE(review): the parameter is declared as tf.int32 while the `arg`
  # references above carry the SERVER-placed federated type -- confirm this
  # mismatch is intentional.
  return building_blocks.Lambda('arg', tf.int32, body)
def test_propogates_dependence_up_through_block_result(self):
  # The intrinsic is the block's result, so the block itself should be among
  # the nodes reported as consuming it.
  intrinsic = building_blocks.Intrinsic('dummy_intrinsic', tf.int32)
  int_ref = building_blocks.Reference('int', tf.int32)
  block = building_blocks.Block([('x', int_ref)], intrinsic)

  consumers = tree_analysis.extract_nodes_consuming(
      block, dummy_intrinsic_predicate)

  self.assertIn(block, consumers)
def test_raises_lambda_rebinding_of_block_variable(self):
  # The block binds 'x' and its result is a lambda whose parameter is also
  # named 'x'; the uniqueness check must reject this.
  lambda_over_x = building_blocks.Lambda(
      'x', tf.int32, building_blocks.Reference('x', tf.int32))
  bound_value = building_blocks.Data('x', tf.int32)
  block = building_blocks.Block([('x', bound_value)], lambda_over_x)

  with self.assertRaises(tree_analysis.NonuniqueNameError):
    tree_analysis.check_has_unique_names(block)
def test_with_block(self):
  # Builds `a -> let f=a.f, x=a.x in f(f(x))`, calls it with f=add_one and
  # x=10, and expects 12.
  executor = reference_resolving_executor.ReferenceResolvingExecutor(
      eager_tf_executor.EagerTFExecutor())
  loop = asyncio.get_event_loop()
  run = loop.run_until_complete

  fn_type = computation_types.FunctionType(tf.int32, tf.int32)
  arg_type = computation_types.StructType([('f', fn_type), ('x', tf.int32)])
  arg_ref = building_blocks.Reference('a', arg_type)

  block_locals = [
      ('f', building_blocks.Selection(arg_ref, name='f')),
      ('x', building_blocks.Selection(arg_ref, name='x')),
  ]
  inner_call = building_blocks.Call(
      building_blocks.Reference('f', fn_type),
      building_blocks.Reference('x', tf.int32))
  outer_call = building_blocks.Call(
      building_blocks.Reference('f', fn_type), inner_call)
  body = building_blocks.Block(block_locals, outer_call)
  comp = building_blocks.Lambda(arg_ref.name, arg_ref.type_signature, body)

  @computations.tf_computation(tf.int32)
  def add_one(x):
    return x + 1

  comp_value = run(executor.create_value(comp.proto, comp.type_signature))
  fn_value = run(executor.create_value(add_one))
  ten_value = run(executor.create_value(10, tf.int32))
  arg_value = run(
      executor.create_struct(
          anonymous_tuple.AnonymousTuple([('f', fn_value),
                                          ('x', ten_value)])))
  call_value = run(executor.create_call(comp_value, arg_value))
  result = run(call_value.compute())

  self.assertEqual(result.numpy(), 12)
def test_returns_correct_structure_with_tuple_in_result(self):
  # Same two-identity chain as the single-result case, but the block returns
  # a tuple of the final reference twice; the output should still be one
  # compiled computation called on the unbound reference 'var'.
  unbound_ref = building_blocks.Reference('var', tf.int32)

  first_identity = building_block_factory.create_compiled_identity(tf.int32)
  first_call = building_blocks.Call(first_identity, unbound_ref)
  first_call_ref = building_blocks.Reference('call',
                                             first_call.type_signature)

  second_identity = building_block_factory.create_compiled_identity(tf.int32)
  second_call = building_blocks.Call(second_identity, first_call_ref)
  second_call_ref = building_blocks.Reference('second_call',
                                              first_call.type_signature)

  doubled_result = building_blocks.Tuple([second_call_ref, second_call_ref])
  block = building_blocks.Block(
      [('call', first_call), ('second_call', second_call)], doubled_result)

  as_tf, _ = compiler_transformations.create_tensorflow_representing_block(
      block)

  self.assertEqual(as_tf.type_signature, block.type_signature)
  self.assertIsInstance(as_tf, building_blocks.Call)
  self.assertIsInstance(as_tf.function, building_blocks.CompiledComputation)
  self.assertIsInstance(as_tf.argument, building_blocks.Reference)
  self.assertEqual(as_tf.argument.name, 'var')
def _insert_comp_in_top_level_lambda(comp, name, comp_to_insert):
  """Inserts a computation into `comp` with the given `name`.

  The computation is added as the first local of the block returned by the
  top-level lambda. If the lambda's result is not a block, a new block with
  this single local is wrapped around the result.

  Args:
    comp: The `building_blocks.Lambda` to transform. The names of lambda
      parameters and block variables in `comp` must be unique.
    name: The name to use.
    comp_to_insert: The `building_blocks.ComputationBuildingBlock` to insert.

  Returns:
    A new `building_blocks.Lambda` with the same parameter as `comp` whose
    result is a block that binds `name` to `comp_to_insert` first.
  """
  py_typecheck.check_type(comp, building_blocks.Lambda)
  py_typecheck.check_type(name, str)
  py_typecheck.check_type(comp_to_insert,
                          building_blocks.ComputationBuildingBlock)
  tree_analysis.check_has_unique_names(comp)

  result = comp.result
  if result.is_block():
    existing_locals = list(result.locals)
    result = result.result
  else:
    existing_locals = []

  # Build a fresh list instead of inserting into the one handed back by
  # `result.locals`, so the input block can never be altered through a shared
  # list object.
  variables = [(name, comp_to_insert)] + existing_locals
  block = building_blocks.Block(variables, result)
  return building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                block)
def _transform(comp, context_tree):
  """Renames References in `comp` to unique names.

  Names are compared by value (`==` / `!=`), not by object identity, so two
  equal names held in distinct string objects are not treated as a rename.

  Args:
    comp: The building block to rewrite. Only references, blocks and lambdas
      are ever changed; anything else is returned as is.
    context_tree: Object providing `get_payload_with_name(name)`, whose
      payloads carry a `new_name` attribute, and
      `walk_down_one_variable_binding()`, which is called once per block
      local and once per lambda with a non-`None` parameter type.

  Returns:
    A `(comp, modified)` pair, where `modified` is `True` only if some name
    in the returned building block differs from the original.
  """
  if comp.is_reference():
    payload = context_tree.get_payload_with_name(comp.name)
    if payload is None:
      # No binding is known for this name; leave the reference alone.
      return comp, False
    new_name = payload.new_name
    if new_name == comp.name:
      return comp, False
    return building_blocks.Reference(new_name, comp.type_signature,
                                     comp.context), True
  elif comp.is_block():
    new_locals = []
    modified = False
    for name, val in comp.locals:
      # Advance to the next binding before looking up its replacement name.
      context_tree.walk_down_one_variable_binding()
      new_name = context_tree.get_payload_with_name(name).new_name
      modified = modified or (new_name != name)
      new_locals.append((new_name, val))
    return building_blocks.Block(new_locals, comp.result), modified
  elif comp.is_lambda():
    if comp.parameter_type is None:
      return comp, False
    context_tree.walk_down_one_variable_binding()
    new_name = context_tree.get_payload_with_name(
        comp.parameter_name).new_name
    if new_name == comp.parameter_name:
      return comp, False
    return building_blocks.Lambda(new_name, comp.parameter_type,
                                  comp.result), True
  return comp, False
def test_raises_with_naked_graph_as_block_local(self):
  # A compiled computation bound directly as a local (not called) must be
  # rejected with a ValueError.
  identity_graph = building_block_factory.create_compiled_identity(tf.int32)
  graph_ref = building_blocks.Reference('graph',
                                        identity_graph.type_signature)
  block = building_blocks.Block([('graph', identity_graph)], graph_ref)

  with self.assertRaises(ValueError):
    compiler_transformations.create_tensorflow_representing_block(block)
def test_returns_correct_structure_with_no_unbound_references(self):
  # With a constant at the head of the chain there is nothing unbound, so the
  # resulting call is expected to have no argument.
  int_type = computation_types.TensorType(tf.int32)
  constant_one = building_block_factory.create_tensorflow_constant(
      int_type, 1)

  first_identity = building_block_factory.create_compiled_identity(
      computation_types.TensorType(tf.int32))
  first_call = building_blocks.Call(first_identity, constant_one)
  first_call_ref = building_blocks.Reference('call',
                                             first_call.type_signature)

  second_identity = building_block_factory.create_compiled_identity(
      computation_types.TensorType(tf.int32))
  second_call = building_blocks.Call(second_identity, first_call_ref)
  second_call_ref = building_blocks.Reference('second_call',
                                              first_call.type_signature)

  doubled_result = building_blocks.Tuple([second_call_ref, second_call_ref])
  block = building_blocks.Block(
      [('call', first_call), ('second_call', second_call)], doubled_result)

  as_tf, _ = transformations.create_tensorflow_representing_block(block)

  self.assertEqual(as_tf.type_signature, block.type_signature)
  self.assertIsInstance(as_tf, building_blocks.Call)
  self.assertIsInstance(as_tf.function, building_blocks.CompiledComputation)
  self.assertIsNone(as_tf.argument)
def test_with_block(self):
  # Builds `a -> let f=a.f, x=a.x in f(f(x))`, calls it with f=add_one and
  # x=10, and expects 12.
  executor = reference_resolving_executor.ReferenceResolvingExecutor(
      eager_tf_executor.EagerTFExecutor())

  fn_type = computation_types.FunctionType(tf.int32, tf.int32)
  arg_type = computation_types.StructType([('f', fn_type), ('x', tf.int32)])
  arg_ref = building_blocks.Reference('a', arg_type)

  block_locals = [
      ('f', building_blocks.Selection(arg_ref, name='f')),
      ('x', building_blocks.Selection(arg_ref, name='x')),
  ]
  inner_call = building_blocks.Call(
      building_blocks.Reference('f', fn_type),
      building_blocks.Reference('x', tf.int32))
  outer_call = building_blocks.Call(
      building_blocks.Reference('f', fn_type), inner_call)
  body = building_blocks.Block(block_locals, outer_call)
  comp = building_blocks.Lambda(arg_ref.name, arg_ref.type_signature, body)

  @tensorflow_computation.tf_computation(tf.int32)
  def add_one(x):
    return x + 1

  comp_value = asyncio.run(
      executor.create_value(comp.proto, comp.type_signature))
  fn_value = asyncio.run(executor.create_value(add_one))
  ten_value = asyncio.run(executor.create_value(10, tf.int32))
  arg_struct = structure.Struct([('f', fn_value), ('x', ten_value)])
  arg_value = asyncio.run(executor.create_struct(arg_struct))
  call_value = asyncio.run(executor.create_call(comp_value, arg_value))
  result = asyncio.run(call_value.compute())

  self.assertEqual(result.numpy(), 12)
def test_returns_comp_with_block_untransformed(self):
  # A block containing no called graphs should come back unchanged and be
  # reported as unmodified.
  data = building_blocks.Data('a', tf.int32)
  block = building_blocks.Block([('x', data), ('y', data)], data)

  outcome = compiler_transformations.remove_duplicate_called_graphs(block)
  untransformed, was_modified = outcome

  self.assertEqual(untransformed, block)
  self.assertFalse(was_modified)
def test_basic_functionality_of_block_class(self):
  # Checks the type signature, locals, result, repr, compact representation
  # and proto serialization of a block with two locals.
  pair_type = (tf.int32, tf.int32)
  arg_ref = building_blocks.Reference('arg', pair_type)
  first_of_x = building_blocks.Selection(
      building_blocks.Reference('x', pair_type), index=0)
  block = building_blocks.Block([('x', arg_ref), ('y', first_of_x)],
                                building_blocks.Reference('y', tf.int32))

  self.assertEqual(str(block.type_signature), 'int32')

  compact_locals = [
      (key, value.compact_representation()) for key, value in block.locals
  ]
  self.assertEqual(compact_locals, [('x', 'arg'), ('y', 'x[0]')])
  self.assertEqual(block.result.compact_representation(), 'y')
  self.assertEqual(
      repr(block), 'Block([(\'x\', Reference(\'arg\', '
      'StructType([TensorType(tf.int32), TensorType(tf.int32)]) as tuple)), '
      '(\'y\', Selection(Reference(\'x\', '
      'StructType([TensorType(tf.int32), TensorType(tf.int32)]) as tuple), '
      'index=0))], '
      'Reference(\'y\', TensorType(tf.int32)))')
  self.assertEqual(block.compact_representation(), '(let x=arg,y=x[0] in y)')

  block_proto = block.proto
  deserialized_type = type_serialization.deserialize_type(block_proto.type)
  self.assertEqual(deserialized_type, block.type_signature)
  self.assertEqual(block_proto.WhichOneof('computation'), 'block')
  self.assertEqual(str(block_proto.block.result), str(block.result.proto))

  # Each serialized local must match the in-memory local at the same index.
  for position, local_proto in enumerate(block_proto.block.local):
    local_name, local_value = block.locals[position]
    self.assertEqual(local_proto.name, local_name)
    self.assertEqual(str(local_proto.value), str(local_value.proto))

  self._serialize_deserialize_roundtrip_test(block)
def test_executes_correctly_with_tuple_in_result(self):
  # The unbound reference 'var' flows through two identity calls and is
  # returned twice; running the generated TensorFlow on 1 and on 0 should
  # give [1, 1] and [0, 0].
  unbound_ref = building_blocks.Reference('var', tf.int32)

  first_identity = building_block_factory.create_compiled_identity(
      computation_types.TensorType(tf.int32))
  first_call = building_blocks.Call(first_identity, unbound_ref)
  first_call_ref = building_blocks.Reference('call',
                                             first_call.type_signature)

  second_identity = building_block_factory.create_compiled_identity(
      computation_types.TensorType(tf.int32))
  second_call = building_blocks.Call(second_identity, first_call_ref)
  second_call_ref = building_blocks.Reference('second_call',
                                              first_call.type_signature)

  doubled_result = building_blocks.Tuple([second_call_ref, second_call_ref])
  block = building_blocks.Block(
      [('call', first_call), ('second_call', second_call)], doubled_result)

  as_tf, _ = transformations.create_tensorflow_representing_block(block)

  for fed_value in (1, 0):
    executed = test_utils.run_tensorflow(as_tf.function.proto, fed_value)
    self.assertAllEqual(executed, [fed_value, fed_value])
def test_unwraps_block_with_empty_locals(self):
  # A block with no locals should be replaced by its result.
  expected = building_blocks.Data('b', tf.int32)
  empty_block = building_blocks.Block([], expected)

  remover = self._unused_block_remover.transform
  unwrapped, was_modified = transformation_utils.transform_postorder(
      empty_block, remover)

  self.assertTrue(was_modified)
  self.assertEqual(unwrapped.compact_representation(),
                   expected.compact_representation())
def test_returns_tf_computation_block_with_compiled_comp(self):
  # A block whose result is a compiled identity and whose only local is an
  # unused constant should compile to TensorFlow.
  int_type = computation_types.TensorType(tf.int32)
  identity = building_block_factory.create_compiled_identity(int_type)
  unused_constant = building_block_factory.create_tensorflow_constant(
      int_type, 1)

  block = building_blocks.Block([('x', unused_constant)], identity)

  self.assert_compiles_to_tensorflow(block)
def test_returns_tf_computation_with_functional_type_lambda_with_block(
    self):
  # A lambda over a named two-field struct whose body is a block rebinding
  # the parameter should compile to TensorFlow.
  x_ref = building_blocks.Reference('x', [('a', tf.int32),
                                          ('b', tf.float32)])
  body = building_blocks.Block([('x', x_ref)], x_ref)
  fn = building_blocks.Lambda(x_ref.name, x_ref.type_signature, body)

  self.assert_compiles_to_tensorflow(fn)
def test_passes_unbound_type_signature_obscured_under_block(self):
  # 'x' first appears as an unbound SERVER-placed reference and is then
  # rebound inside the block to unplaced data; the call below only needs to
  # complete without raising.
  server_int_type = computation_types.FederatedType(tf.int32,
                                                    placements.SERVER)
  federated_x = building_blocks.Reference('x', server_int_type)

  block_locals = [
      ('y', federated_x),
      ('x', building_blocks.Data('whimsy', tf.int32)),
      ('z', building_blocks.Reference('x', tf.int32)),
  ]
  y_ref = building_blocks.Reference('y', federated_x.type_signature)
  block = building_blocks.Block(block_locals, y_ref)

  tree_transformations.strip_placement(block)
def _split_by_intrinsics_in_top_level_lambda(comp):
  """Splits by the intrinsics in the first block local in the result of `comp`.

  This function splits `comp` into two computations `before` and `after` the
  called intrinsic or tuple of called intrinsics found as the first local in
  the `building_blocks.Block` returned by the top level lambda; and returns a
  Python tuple representing the pair of `before` and `after` computations.

  `before` takes the original parameter and returns the argument(s) of the
  called intrinsic(s). `after` takes a two-element struct holding the original
  parameter and the value of the first local, rebinds both under their
  original names, and then continues with the remaining locals and the
  original result.

  Args:
    comp: The `building_blocks.Lambda` to split.

  Returns:
    A pair of `building_blocks.ComputationBuildingBlock`s.

  Raises:
    ValueError: If the first local in the `building_blocks.Block` referenced by
      the top level lambda is not a called intrinsic or a
      `building_blocks.Struct` of called intrinsics.
  """
  py_typecheck.check_type(comp, building_blocks.Lambda)
  py_typecheck.check_type(comp.result, building_blocks.Block)
  tree_analysis.check_has_unique_names(comp)

  name_generator = building_block_factory.unique_name_generator(comp)

  # Snapshot the locals once so nothing below edits a list owned by `comp`.
  original_locals = list(comp.result.locals)
  name, first_local = original_locals[0]
  if building_block_analysis.is_called_intrinsic(first_local):
    result = first_local.argument
  elif first_local.is_struct():
    elements = []
    for element in first_local:
      if not building_block_analysis.is_called_intrinsic(element):
        raise ValueError(
            'Expected all the elements of the `building_blocks.Struct` to be '
            'called intrinsics, but found: \n{}'.format(element))
      elements.append(element.argument)
    result = building_blocks.Struct(elements)
  else:
    raise ValueError(
        'Expected either a called intrinsic or a `building_blocks.Struct` of '
        'called intrinsics, but found: \n{}'.format(first_local))

  before = building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                  result)

  ref_name = next(name_generator)
  ref_type = computation_types.StructType(
      (comp.parameter_type, first_local.type_signature))
  ref = building_blocks.Reference(ref_name, ref_type)
  sel_after_arg_1 = building_blocks.Selection(ref, index=0)
  sel_after_arg_2 = building_blocks.Selection(ref, index=1)

  # The original parameter comes first, then the first local (now read from
  # the new parameter), then every remaining local unchanged.
  variables = [
      (comp.parameter_name, sel_after_arg_1),
      (name, sel_after_arg_2),
  ] + original_locals[1:]
  block = building_blocks.Block(variables, comp.result.result)
  after = building_blocks.Lambda(ref.name, ref.type_signature, block)
  return before, after
def test_leaves_single_used_reference(self):
  # The block's only local is read by its result, so the remover must leave
  # the block untouched.
  bound_value = building_blocks.Data('a', tf.int32)
  x_ref = building_blocks.Reference('x', tf.int32)
  block = building_blocks.Block([('x', bound_value)], x_ref)

  remover = self._unused_block_remover.transform
  transformed, was_modified = transformation_utils.transform_postorder(
      block, remover)

  self.assertFalse(was_modified)
  self.assertEqual(transformed.compact_representation(),
                   block.compact_representation())
def test_propogates_dependence_up_through_block_locals(self):
  # The intrinsic appears only as a block local, yet the block itself should
  # still be among the nodes reported as consuming it.
  int_type = computation_types.TensorType(tf.int32)
  intrinsic = building_blocks.Intrinsic('dummy_intrinsic', int_type)
  int_ref = building_blocks.Reference('int', tf.int32)
  block = building_blocks.Block([('x', intrinsic)], int_ref)

  consumers = tree_analysis.extract_nodes_consuming(
      block, dummy_intrinsic_predicate)

  self.assertIn(block, consumers)
def test_raises_multiple_placements(self):
  # A block mixing a server-placed local with a clients-placed result must be
  # rejected when stripping placement.
  server_ref = building_blocks.Reference(
      'x', computation_types.at_server(tf.int32))
  clients_ref = building_blocks.Reference(
      'y', computation_types.at_clients(tf.int32))
  mixed_block = building_blocks.Block([('x', server_ref)], clients_ref)

  with self.assertRaisesRegex(ValueError, 'multiple different placements'):
    tree_transformations.strip_placement(mixed_block)
def test_with_simple_block(self):
  # `let x=a in x` should collapse to just `a` once lambdas and blocks are
  # removed.
  bound_value = building_blocks.Data('a', tf.int32)
  x_ref = building_blocks.Reference('x', tf.int32)
  block = building_blocks.Block([('x', bound_value)], x_ref)

  outcome = compiler_transformations.remove_lambdas_and_blocks(block)
  reduced, was_modified = outcome

  self.assertTrue(was_modified)
  self.assertNoLambdasOrBlocks(reduced)
  self.assertEqual(reduced.compact_representation(), 'a')