Ejemplo n.º 1
0
 def test_simple_block_inlining(self):
     """A block whose single local is referenced once gets inlined away."""
     bb = computation_building_blocks
     data = bb.Data('test_data', tf.int32)
     body = bb.Reference('test_x', data.type_signature)
     block = bb.Block([('test_x', data)], body)
     self.assertEqual(str(block), '(let test_x=test_data in test_x)')

     # The lone reference to `test_x` is replaced by its bound value, leaving
     # an empty local list in the rendered block.
     result = transformations.inline_blocks_with_n_referenced_locals(block)
     self.assertEqual(str(result), '(let  in test_data)')
Ejemplo n.º 2
0
 def test_multiple_inline_for_nested_block(self):
     """Inlining cascades through nested blocks when each local is used once."""
     bb = computation_building_blocks
     inline = transformations.inline_blocks_with_n_referenced_locals

     # Case 1: x -> used1 -> used2, each bound name referenced exactly once.
     used1 = bb.Reference('used1', tf.int32)
     used2 = bb.Reference('used2', tf.int32)
     x_ref = bb.Reference('x', used1.type_signature)
     inner = bb.Block([('x', used1)], x_ref)
     outer = bb.Block([('used1', used2)], inner)
     collapsed = inline(outer)
     # Rendering the input after the call also catches in-place mutation.
     self.assertEqual(str(outer), '(let used1=used2 in (let x=used1 in x))')
     self.assertEqual(str(collapsed), '(let  in (let  in used2))')

     # Case 2: the inner body already names `used1` directly (it appears twice
     # in the inner block). Only the inner `x` binding disappears; the outer
     # `used1=used2` binding is expected to remain.
     pre_inlined_inner = bb.Block([('x', used1)], used1)
     pre_inlined_outer = bb.Block([('used1', used2)], pre_inlined_inner)
     self.assertEqual(str(pre_inlined_outer),
                      '(let used1=used2 in (let x=used1 in used1))')
     partially = inline(pre_inlined_outer)
     self.assertEqual(str(partially),
                      '(let used1=used2 in (let  in used1))')
Ejemplo n.º 3
0
 def test_inline_conflicting_locals(self):
     """A block local that shadows the lambda parameter inlines correctly."""
     bb = computation_building_blocks
     param = bb.Reference('arg', [tf.int32, tf.int32])
     first = bb.Selection(param, index=0)
     # Inside the block, `arg` names the local (an int32), not the pair-typed
     # lambda parameter.
     shadowing_ref = bb.Reference('arg', tf.int32)
     body = bb.Block([('arg', first)], shadowing_ref)
     lam = bb.Lambda('arg', param.type_signature, body)
     self.assertEqual(str(lam), '(arg -> (let arg=arg[0] in arg))')

     # After inlining, `arg[0]` must refer to the lambda parameter.
     result = transformations.inline_blocks_with_n_referenced_locals(lam)
     self.assertEqual(str(result), '(arg -> (let  in arg[0]))')
Ejemplo n.º 4
0
 def test_conflicting_name_resolved_inlining(self):
     """The innermost binding of a shadowed name is the one inlined."""
     bb = computation_building_blocks
     decoy = bb.Reference('redherring', tf.int32)
     target = bb.Reference('used', tf.int32)
     x_ref = bb.Reference('x', target.type_signature)
     inner = bb.Block([('x', target)], x_ref)
     # The outer block also binds `x`, but the inner binding shadows it.
     outer = bb.Block([('x', decoy)], inner)
     self.assertEqual(str(outer), '(let x=redherring in (let x=used in x))')

     result = transformations.inline_blocks_with_n_referenced_locals(outer)
     self.assertEqual(str(result), '(let  in (let  in used))')
Ejemplo n.º 5
0
 def test_no_inlining_if_referenced_twice(self):
     """A local referenced more than once is left in place."""
     bb = computation_building_blocks
     data = bb.Data('test_data', tf.int32)
     # Two distinct Reference objects to the same local name.
     refs = [bb.Reference('test_x', data.type_signature) for _ in range(2)]
     pair = bb.Tuple(refs)
     block = bb.Block([('test_x', data)], pair)
     self.assertEqual(str(block),
                      '(let test_x=test_data in <test_x,test_x>)')

     # The rendered result must match the input exactly, i.e. nothing inlined.
     result = transformations.inline_blocks_with_n_referenced_locals(block)
     self.assertEqual(str(result), str(block))
Ejemplo n.º 6
0
 def test_conflicting_nested_name_inlining(self):
     """Each `x` resolves against its own enclosing binding when inlined."""
     bb = computation_building_blocks
     inner_x = bb.Reference('x', tf.int32)
     y_ref = bb.Reference('y', tf.int32)
     # The nested block rebinds `x` to `y`, shadowing the outer `x`.
     nested = bb.Block([('x', y_ref)], inner_x)
     outer_x = bb.Reference('x', tf.int32)
     elements = bb.Tuple([outer_x, nested])
     used = bb.Reference('used', tf.int32)
     used1 = bb.Reference('used1', tf.int32)
     outer = bb.Block([('x', used), ('y', used1)], elements)
     self.assertEqual(str(outer),
                      '(let x=used,y=used1 in <x,(let x=y in x)>)')

     result = transformations.inline_blocks_with_n_referenced_locals(outer)
     self.assertEqual(str(result), '(let  in <used,(let  in used1)>)')
Ejemplo n.º 7
0
 def test_inline_conflicting_lambdas(self):
     comp = computation_building_blocks.Tuple(
         [computation_building_blocks.Data('test', tf.int32)])
     input1 = computation_building_blocks.Reference('input2',
                                                    comp.type_signature)
     first_level_call = computation_building_blocks.Call(
         computation_building_blocks.Lambda('input2', input1.type_signature,
                                            input1), comp)
     input2 = computation_building_blocks.Reference(
         'input2', first_level_call.type_signature)
     second_level_call = computation_building_blocks.Call(
         computation_building_blocks.Lambda('input2', input2.type_signature,
                                            input2), first_level_call)
     self.assertEqual(str(second_level_call),
                      '(input2 -> input2)((input2 -> input2)(<test>))')
     lambda_reduced_comp = transformations.replace_called_lambda_with_block(
         second_level_call)
     self.assertEqual(
         str(lambda_reduced_comp),
         '(let input2=(let input2=<test> in input2) in input2)')
     inlined = transformations.inline_blocks_with_n_referenced_locals(
         lambda_reduced_comp)
     self.assertEqual(str(inlined), '(let  in (let  in <test>))')