def test_simple_block_inlining(self):
  """A block whose single local is referenced once collapses to its value."""
  data = computation_building_blocks.Data('test_data', tf.int32)
  ref_x = computation_building_blocks.Reference('test_x',
                                                data.type_signature)
  bindings = [('test_x', data)]
  block = computation_building_blocks.Block(bindings, ref_x)
  self.assertEqual(str(block), '(let test_x=test_data in test_x)')

  # The lone reference to `test_x` is replaced by its bound value and the
  # now-unused binding is dropped.
  transformed = transformations.inline_blocks_with_n_referenced_locals(block)
  self.assertEqual(str(transformed), '(let in test_data)')
def test_multiple_inline_for_nested_block(self):
  """Nested blocks inline transitively; a local used twice is kept."""
  ref_used1 = computation_building_blocks.Reference('used1', tf.int32)
  ref_used2 = computation_building_blocks.Reference('used2', tf.int32)
  ref_x = computation_building_blocks.Reference('x', ref_used1.type_signature)
  inner = computation_building_blocks.Block([('x', ref_used1)], ref_x)
  outer = computation_building_blocks.Block([('used1', ref_used2)], inner)

  result = transformations.inline_blocks_with_n_referenced_locals(outer)
  # The input's rendering is asserted only after the transformation has run.
  self.assertEqual(str(outer), '(let used1=used2 in (let x=used1 in x))')
  self.assertEqual(str(result), '(let in (let in used2))')

  # Here the inner result names `used1` directly: `x` is never referenced and
  # `used1` appears twice, so only the inner binding disappears.
  manual_inner = computation_building_blocks.Block([('x', ref_used1)],
                                                   ref_used1)
  manual_outer = computation_building_blocks.Block([('used1', ref_used2)],
                                                   manual_inner)
  self.assertEqual(str(manual_outer),
                   '(let used1=used2 in (let x=used1 in used1))')
  unchanged_outer = transformations.inline_blocks_with_n_referenced_locals(
      manual_outer)
  self.assertEqual(str(unchanged_outer), '(let used1=used2 in (let in used1))')
def test_inline_conflicting_locals(self):
  """A local that shadows the enclosing lambda parameter inlines correctly."""
  pair_ref = computation_building_blocks.Reference('arg',
                                                   [tf.int32, tf.int32])
  first_element = computation_building_blocks.Selection(pair_ref, index=0)
  shadowing_ref = computation_building_blocks.Reference('arg', tf.int32)
  body = computation_building_blocks.Block([('arg', first_element)],
                                           shadowing_ref)
  fn = computation_building_blocks.Lambda('arg', pair_ref.type_signature,
                                          body)
  self.assertEqual(str(fn), '(arg -> (let arg=arg[0] in arg))')

  # The inner `arg` resolves to the local binding, i.e. the selection
  # `arg[0]` over the lambda parameter.
  result = transformations.inline_blocks_with_n_referenced_locals(fn)
  self.assertEqual(str(result), '(arg -> (let in arg[0]))')
def test_conflicting_name_resolved_inlining(self):
  """An inner `x` binding wins over an outer `x` of the same name."""
  decoy = computation_building_blocks.Reference('redherring', tf.int32)
  target = computation_building_blocks.Reference('used', tf.int32)
  ref_x = computation_building_blocks.Reference('x', target.type_signature)
  inner = computation_building_blocks.Block([('x', target)], ref_x)
  outer = computation_building_blocks.Block([('x', decoy)], inner)
  self.assertEqual(str(outer), '(let x=redherring in (let x=used in x))')

  result = transformations.inline_blocks_with_n_referenced_locals(outer)
  # `x` must resolve to the innermost binding (`used`), not `redherring`.
  self.assertEqual(str(result), '(let in (let in used))')
def test_no_inlining_if_referenced_twice(self):
  """A local referenced more than once is left untouched."""
  data = computation_building_blocks.Data('test_data', tf.int32)
  # Two distinct Reference objects, both naming `test_x`.
  refs = [
      computation_building_blocks.Reference('test_x', data.type_signature)
      for _ in range(2)
  ]
  pair = computation_building_blocks.Tuple(refs)
  block = computation_building_blocks.Block([('test_x', data)], pair)
  self.assertEqual(str(block), '(let test_x=test_data in <test_x,test_x>)')

  result = transformations.inline_blocks_with_n_referenced_locals(block)
  self.assertEqual(str(result), str(block))
def test_conflicting_nested_name_inlining(self):
  """An inner block rebinding `x` is inlined independently of the outer `x`."""
  inner_x = computation_building_blocks.Reference('x', tf.int32)
  ref_y = computation_building_blocks.Reference('y', tf.int32)
  inner_block = computation_building_blocks.Block([('x', ref_y)], inner_x)
  outer_x = computation_building_blocks.Reference('x', tf.int32)
  elements = computation_building_blocks.Tuple([outer_x, inner_block])
  ref_used = computation_building_blocks.Reference('used', tf.int32)
  ref_used1 = computation_building_blocks.Reference('used1', tf.int32)
  outer_bindings = [('x', ref_used), ('y', ref_used1)]
  outer_block = computation_building_blocks.Block(outer_bindings, elements)
  self.assertEqual(str(outer_block),
                   '(let x=used,y=used1 in <x,(let x=y in x)>)')

  # The outer `x` becomes `used`; the inner `x` goes through `y` to `used1`.
  result = transformations.inline_blocks_with_n_referenced_locals(outer_block)
  self.assertEqual(str(result), '(let in <used,(let in used1)>)')
def test_inline_conflicting_lambdas(self): comp = computation_building_blocks.Tuple( [computation_building_blocks.Data('test', tf.int32)]) input1 = computation_building_blocks.Reference('input2', comp.type_signature) first_level_call = computation_building_blocks.Call( computation_building_blocks.Lambda('input2', input1.type_signature, input1), comp) input2 = computation_building_blocks.Reference( 'input2', first_level_call.type_signature) second_level_call = computation_building_blocks.Call( computation_building_blocks.Lambda('input2', input2.type_signature, input2), first_level_call) self.assertEqual(str(second_level_call), '(input2 -> input2)((input2 -> input2)(<test>))') lambda_reduced_comp = transformations.replace_called_lambda_with_block( second_level_call) self.assertEqual( str(lambda_reduced_comp), '(let input2=(let input2=<test> in input2) in input2)') inlined = transformations.inline_blocks_with_n_referenced_locals( lambda_reduced_comp) self.assertEqual(str(inlined), '(let in (let in <test>))')