def test_with_simple_block(self):
  """A block binding `x` to data and returning `x` collapses to the data."""
  int_data = building_blocks.Data('a', tf.int32)
  x_ref = building_blocks.Reference('x', tf.int32)
  block = building_blocks.Block([('x', int_data)], x_ref)

  result, modified = compiler_transformations.remove_lambdas_and_blocks(block)

  self.assertTrue(modified)
  self.assertNoLambdasOrBlocks(result)
  # The only surviving node should be the bound data itself.
  self.assertEqual(result.compact_representation(), 'a')
def test_with_simple_called_lambda(self):
  """Applying the identity lambda to data reduces to that data."""
  x_ref = building_blocks.Reference('x', tf.int32)
  identity = building_blocks.Lambda('x', tf.int32, x_ref)
  operand = building_blocks.Data('a', tf.int32)
  application = building_blocks.Call(identity, operand)

  result, modified = compiler_transformations.remove_lambdas_and_blocks(
      application)

  self.assertTrue(modified)
  self.assertNoLambdasOrBlocks(result)
  # Beta-reducing identity(a) must leave exactly `a`.
  self.assertEqual(result.compact_representation(), 'a')
def test_with_structure_replacing_federated_zip(self):
  """The unplaced argument of an unzip/zip round trip loses all lambdas."""
  clients_int_triple = computation_types.FederatedType([tf.int32] * 3,
                                                       placements.CLIENTS)
  tup_ref = building_blocks.Reference('tup', clients_int_triple)
  rezipped = building_block_factory.create_federated_zip(
      building_block_factory.create_federated_unzip(tup_ref))

  unwrapped, _ = transformations.unwrap_placement(rezipped)
  # Only the argument of the unwrapped federated call is under test here.
  result, modified = compiler_transformations.remove_lambdas_and_blocks(
      unwrapped.argument)

  self.assertTrue(modified)
  self.assertNoLambdasOrBlocks(result)
def test_with_nested_called_lambdas(self):
  """A lambda whose body calls its function-typed parameter is simplified."""
  identity = building_blocks.Lambda('x', tf.int32,
                                    building_blocks.Reference('x', tf.int32))
  fn_type = identity.type_signature
  fn_ref = building_blocks.Reference('fn', fn_type)
  operand = building_blocks.Data('a', tf.int32)
  inner_call = building_blocks.Call(fn_ref, operand)
  outer = building_blocks.Lambda('fn', fn_type, inner_call)

  result, modified = compiler_transformations.remove_lambdas_and_blocks(outer)

  self.assertTrue(modified)
  self.assertNoLambdasOrBlocks(result)
def test_with_structure_replacing_federated_map(self):
  """A block binding `arg` to (fn, value) and calling arg[0](arg[1])."""
  int_to_int = computation_types.FunctionType(tf.int32, tf.int32)
  arg_ref = building_blocks.Reference('arg', [int_to_int, tf.int32])
  # Body of the block: apply the first tuple element to the second.
  body = building_blocks.Call(
      building_blocks.Selection(arg_ref, index=0),
      building_blocks.Selection(arg_ref, index=1))

  identity = building_blocks.Lambda('x', tf.int32,
                                    building_blocks.Reference('x', tf.int32))
  value = building_blocks.Data('a', tf.int32)
  binding = building_blocks.Tuple([identity, value])
  block = building_blocks.Block([('arg', binding)], body)

  result, modified = compiler_transformations.remove_lambdas_and_blocks(block)

  self.assertTrue(modified)
  self.assertNoLambdasOrBlocks(result)
def test_with_higher_level_lambdas(self):
  """Curried, higher-order lambda applications should all be reduced."""
  # Disabled until the tracked bug is resolved; everything below is kept
  # so the test can be re-enabled by deleting this line.
  self.skipTest('b/146904968')

  int_data = building_blocks.Data('a', tf.int32)
  z_ref = building_blocks.Reference('z', tf.int32)
  x_ref = building_blocks.Reference('x', tf.int32)
  innermost = building_blocks.Lambda('z', tf.int32,
                                     building_blocks.Tuple([z_ref, x_ref]))
  curried = building_blocks.Lambda('x', tf.int32, innermost)

  fn_type = curried.type_signature
  applied = building_blocks.Call(
      building_blocks.Reference('x', fn_type), int_data)
  apply_to_data = building_blocks.Lambda('x', fn_type, applied)

  partially_applied = building_blocks.Call(apply_to_data, curried)
  fully_applied = building_blocks.Call(partially_applied, int_data)

  result, modified = compiler_transformations.remove_lambdas_and_blocks(
      fully_applied)

  self.assertTrue(modified)
  self.assertNoLambdasOrBlocks(result)
def test_with_multiple_reference_indirection(self):
  """A function reached through a chain of block-local names is inlined."""
  identity = building_blocks.Lambda('x', tf.int32,
                                    building_blocks.Reference('x', tf.int32))
  # 'b' wraps a reference to 'a' in a one-element tuple...
  wrapped = building_blocks.Tuple(
      [building_blocks.Reference('a', identity.type_signature)])
  # ...and 'c' selects it back out again.
  unwrapped = building_blocks.Selection(
      building_blocks.Reference('b', wrapped.type_signature), index=0)

  c_ref = building_blocks.Reference('c', unwrapped.type_signature)
  call_through_c = building_blocks.Call(c_ref,
                                        building_blocks.Data('a', tf.int32))
  locals_ = [
      ('a', identity),
      ('b', wrapped),
      ('c', unwrapped),
  ]
  block = building_blocks.Block(locals_, call_through_c)

  result, modified = compiler_transformations.remove_lambdas_and_blocks(block)

  self.assertTrue(modified)
  self.assertNoLambdasOrBlocks(result)