def test_with_simple_block(self):
   """Checks a block that binds `x` to data and returns a reference to `x`.

   The transformed result must report a modification, contain no lambdas or
   blocks, and have the compact representation 'a'.
   """
   bound_value = building_blocks.Data('a', tf.int32)
   result_ref = building_blocks.Reference('x', tf.int32)
   block_under_test = building_blocks.Block([('x', bound_value)], result_ref)

   transformed, was_modified = (
       compiler_transformations.remove_lambdas_and_blocks(block_under_test))

   self.assertTrue(was_modified)
   self.assertNoLambdasOrBlocks(transformed)
   self.assertEqual(transformed.compact_representation(), 'a')
 def test_with_simple_called_lambda(self):
   """Checks an identity lambda called directly on a data building block.

   The transformed result must report a modification, contain no lambdas or
   blocks, and have the compact representation 'a'.
   """
   parameter_ref = building_blocks.Reference('x', tf.int32)
   identity_fn = building_blocks.Lambda('x', tf.int32, parameter_ref)
   argument = building_blocks.Data('a', tf.int32)
   call_under_test = building_blocks.Call(identity_fn, argument)

   transformed, was_modified = (
       compiler_transformations.remove_lambdas_and_blocks(call_under_test))

   self.assertTrue(was_modified)
   self.assertNoLambdasOrBlocks(transformed)
   self.assertEqual(transformed.compact_representation(), 'a')
 def test_with_structure_replacing_federated_zip(self):
   """Checks the structure left after unwrapping a zip-of-unzip computation.

   A CLIENTS-placed reference to a three-int tuple is unzipped and zipped
   again; the placement is unwrapped and the `argument` of the unwrapped
   result is what gets handed to `remove_lambdas_and_blocks`.
   """
   clients_tuple_type = computation_types.FederatedType(
       [tf.int32] * 3, placements.CLIENTS)
   federated_ref = building_blocks.Reference('tup', clients_tuple_type)
   rezipped = building_block_factory.create_federated_zip(
       building_block_factory.create_federated_unzip(federated_ref))

   # Only the argument of the placement-unwrapped result is transformed.
   unwrapped, _ = transformations.unwrap_placement(rezipped)
   structure_under_test = unwrapped.argument

   transformed, was_modified = (
       compiler_transformations.remove_lambdas_and_blocks(
           structure_under_test))

   self.assertTrue(was_modified)
   self.assertNoLambdasOrBlocks(transformed)
# Example no. 4
# 0
 def test_with_nested_called_lambdas(self):
     """Checks a lambda whose body calls its own function-typed parameter.

     The outer lambda takes `fn` (typed like an int32 identity lambda) and
     its result is `fn` called on data 'a'. The transformed result must
     report a modification and contain no lambdas or blocks.
     """
     parameter_ref = building_blocks.Reference('x', tf.int32)
     identity_fn = building_blocks.Lambda('x', tf.int32, parameter_ref)
     fn_type = identity_fn.type_signature

     fn_ref = building_blocks.Reference('fn', fn_type)
     argument = building_blocks.Data('a', tf.int32)
     inner_call = building_blocks.Call(fn_ref, argument)
     outer_lambda = building_blocks.Lambda('fn', fn_type, inner_call)

     transformed, was_modified = (
         compiler_transformations.remove_lambdas_and_blocks(outer_lambda))

     self.assertTrue(was_modified)
     self.assertNoLambdasOrBlocks(transformed)
 def test_with_structure_replacing_federated_map(self):
   """Checks a block whose result calls a function selected from a tuple.

   The block binds `arg` to a tuple of (identity lambda, data 'a'); its
   result selects element 0 as the function and element 1 as the argument
   and calls one on the other. The transformed result must report a
   modification and contain no lambdas or blocks.
   """
   fn_type = computation_types.FunctionType(tf.int32, tf.int32)
   pair_ref = building_blocks.Reference('arg', [fn_type, tf.int32])
   selected_fn = building_blocks.Selection(pair_ref, index=0)
   selected_arg = building_blocks.Selection(pair_ref, index=1)
   call_on_selections = building_blocks.Call(selected_fn, selected_arg)

   # Concrete values that the block binds to the name `arg`.
   parameter_ref = building_blocks.Reference('x', tf.int32)
   identity_fn = building_blocks.Lambda('x', tf.int32, parameter_ref)
   data_arg = building_blocks.Data('a', tf.int32)
   bound_pair = building_blocks.Tuple([identity_fn, data_arg])

   block_under_test = building_blocks.Block(
       [('arg', bound_pair)], call_on_selections)

   transformed, was_modified = (
       compiler_transformations.remove_lambdas_and_blocks(block_under_test))

   self.assertTrue(was_modified)
   self.assertNoLambdasOrBlocks(transformed)
 def test_with_higher_level_lambdas(self):
   """Checks a call chain that passes a lambda-returning lambda as an argument.

   Currently skipped via `skipTest('b/146904968')`, which is the first
   statement of the test.
   """
   self.skipTest('b/146904968')

   argument = building_blocks.Data('a', tf.int32)

   # Innermost lambda: takes `z`, returns the tuple <z, x>.
   z_ref = building_blocks.Reference('z', tf.int32)
   x_ref = building_blocks.Reference('x', tf.int32)
   innermost_lambda = building_blocks.Lambda(
       'z', tf.int32, building_blocks.Tuple([z_ref, x_ref]))

   # Lambda over `x` whose result is the innermost lambda.
   curried_lambda = building_blocks.Lambda('x', tf.int32, innermost_lambda)

   # Lambda that receives a function under the name `x` and calls it on the
   # data argument.
   fn_param_ref = building_blocks.Reference('x', curried_lambda.type_signature)
   call_on_param = building_blocks.Call(fn_param_ref, argument)
   applying_lambda = building_blocks.Lambda(
       'x', curried_lambda.type_signature, call_on_param)

   first_call = building_blocks.Call(applying_lambda, curried_lambda)
   call_under_test = building_blocks.Call(first_call, argument)

   transformed, was_modified = (
       compiler_transformations.remove_lambdas_and_blocks(call_under_test))

   self.assertTrue(was_modified)
   self.assertNoLambdasOrBlocks(transformed)
 def test_with_multiple_reference_indirection(self):
   """Checks a block whose locals reach a lambda through chained references.

   Local `a` is an identity lambda, `b` is a tuple wrapping a reference to
   `a`, and `c` selects element 0 from a reference to `b`. The block result
   calls a reference to `c` on data 'a'. The transformed result must report
   a modification and contain no lambdas or blocks.
   """
   parameter_ref = building_blocks.Reference('x', tf.int32)
   identity_fn = building_blocks.Lambda('x', tf.int32, parameter_ref)

   a_ref = building_blocks.Reference('a', identity_fn.type_signature)
   tuple_of_a = building_blocks.Tuple([a_ref])

   b_ref = building_blocks.Reference('b', tuple_of_a.type_signature)
   first_of_b = building_blocks.Selection(b_ref, index=0)

   argument = building_blocks.Data('a', tf.int32)
   c_ref = building_blocks.Reference('c', first_of_b.type_signature)
   call_through_c = building_blocks.Call(c_ref, argument)

   block_locals = [
       ('a', identity_fn),
       ('b', tuple_of_a),
       ('c', first_of_b),
   ]
   block_under_test = building_blocks.Block(block_locals, call_through_c)

   transformed, was_modified = (
       compiler_transformations.remove_lambdas_and_blocks(block_under_test))

   self.assertTrue(was_modified)
   self.assertNoLambdasOrBlocks(transformed)