def test_returns_comp_with_block_untransformed(self):
    """Checks that a `Block` built only from `Data` is handed back unchanged.

    The block binds two locals ('x' and 'y') to the same `Data` computation
    and returns that `Data`. The transformation is expected to return a
    computation equal to the input together with a falsy modified indicator.
    """
    leaf = building_blocks.Data('a', tf.int32)
    local_bindings = [('x', leaf), ('y', leaf)]
    original_block = building_blocks.Block(local_bindings, leaf)

    result = compiler_transformations.remove_duplicate_called_graphs(
        original_block)
    returned_comp, was_modified = result

    self.assertEqual(returned_comp, original_block)
    self.assertFalse(was_modified)
def test_returns_higher_level_lambda_untransformed(self):
    """Checks that a lambda returning another lambda is handed back unchanged.

    The inner lambda is an identity on an int32 parameter 'x'; the outer
    lambda takes an int32 parameter 'y' and returns the inner lambda. The
    transformation is expected to return a computation equal to the input
    together with a falsy modified indicator.
    """
    inner_body = building_blocks.Reference('x', tf.int32)
    inner_fn = building_blocks.Lambda('x', tf.int32, inner_body)
    outer_fn = building_blocks.Lambda('y', tf.int32, inner_fn)

    result = compiler_transformations.remove_duplicate_called_graphs(outer_fn)
    returned_comp, was_modified = result

    self.assertEqual(returned_comp, outer_fn)
    self.assertFalse(was_modified)
def test_returns_tf_computation_with_functional_type(self):
    """Checks that a lambda over repeated selections becomes compiled TF.

    The lambda takes a two-field struct parameter and returns a tuple that
    repeats the selection of its first field three times. The result must be
    flagged as modified, be a `CompiledComputation`, and keep the lambda's
    type signature.
    """
    arg_type = [('a', tf.int32), ('b', tf.float32)]
    arg_ref = building_blocks.Reference('x', arg_type)
    first_field = building_blocks.Selection(source=arg_ref, index=0)
    # The very same Selection object is used for all three tuple elements.
    repeated = building_blocks.Tuple([first_field] * 3)
    fn = building_blocks.Lambda(arg_ref.name, arg_ref.type_signature, repeated)

    result = compiler_transformations.remove_duplicate_called_graphs(fn)
    compiled, was_modified = result

    self.assertTrue(was_modified)
    self.assertIsInstance(compiled, building_blocks.CompiledComputation)
    self.assertEqual(compiled.type_signature, fn.type_signature)
def _count_ops_parameterized_by_layers(k):
    """Counts TensorFlow ops for a `k`-layer inlined tuple, compiled two ways.

    The computation produced by `_construct_inlined_tuple(k)` is compiled
    once with `remove_duplicate_called_graphs` and once by applying a
    `TFParser` through `transform_postorder`; the TensorFlow ops under each
    result are then counted.

    Args:
      k: Forwarded unchanged to `_construct_inlined_tuple`.

    Returns:
      A pair `(deduplicated_op_count, naive_op_count)`, in that order.
    """
    comp = _construct_inlined_tuple(k)

    # Path 1: the transformation under test.
    deduplicated_tf, _ = compiler_transformations.remove_duplicate_called_graphs(
        comp)
    deduplicated_op_count = tree_analysis.count_tensorflow_ops_under(
        deduplicated_tf)

    # Path 2: the plain TFParser applied in postorder, for comparison.
    tf_parser = transformations.TFParser()
    naive_tf, _ = transformation_utils.transform_postorder(comp, tf_parser)
    naive_op_count = tree_analysis.count_tensorflow_ops_under(naive_tf)

    return deduplicated_op_count, naive_op_count
def test_returns_called_tf_computation_with_non_functional_type(self):
    """Checks that a non-functional tuple becomes a no-arg call to compiled TF.

    The input is a tuple that repeats, three times, the selection of the
    first element of a TensorFlow constant. The result must be flagged as
    modified, keep the tuple's type signature, and be a `Call` whose function
    is a `CompiledComputation` and whose argument is `None`.
    """
    constant = building_block_factory.create_tensorflow_constant(
        [tf.int32, tf.float32], 1)
    first_elem = building_blocks.Selection(source=constant, index=0)
    # The very same Selection object is used for all three tuple elements.
    repeated = building_blocks.Tuple([first_elem] * 3)

    result = compiler_transformations.remove_duplicate_called_graphs(repeated)
    called_tf, was_modified = result

    self.assertTrue(was_modified)
    self.assertEqual(called_tf.type_signature, repeated.type_signature)
    self.assertIsInstance(called_tf, building_blocks.Call)
    self.assertIsInstance(
        called_tf.function, building_blocks.CompiledComputation)
    self.assertIsNone(called_tf.argument)
def test_raises_non_unique_names(self):
    """Checks that a block binding the name 'x' twice raises `ValueError`."""
    leaf = building_blocks.Data('a', tf.int32)
    duplicate_bindings = [('x', leaf), ('x', leaf)]
    clashing_block = building_blocks.Block(duplicate_bindings, leaf)

    self.assertRaises(
        ValueError,
        compiler_transformations.remove_duplicate_called_graphs,
        clashing_block,
    )
def test_raises_bad_type(self):
    """Checks that passing a plain integer raises `TypeError`."""
    not_a_computation = 1
    self.assertRaises(
        TypeError,
        compiler_transformations.remove_duplicate_called_graphs,
        not_a_computation,
    )