def testGetOps(self):
  """Checks the (op, kernel) pairs extracted from the two test GraphDefs.

  The same expectation is asserted twice: once on the graphs as parsed, and
  once after the `device` field has been cleared on nodes 0 and 2 of the first
  graph, which must leave the result unchanged. The NoOp/_Recv/_Send entries
  match the names listed in `default_ops` (presumably they come from there).
  """
  default_ops = 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'
  graphs = [
      text_format.Parse(graph_txt, graph_pb2.GraphDef())
      for graph_txt in (GRAPH_DEF_TXT, GRAPH_DEF_TXT_2)
  ]

  def _extract():
    # Invoked before and after the device fields are cleared; each call
    # passes the (possibly mutated) `graphs` through WriteGraphFiles again.
    return selective_registration_header_lib.get_ops_and_kernels(
        'rawproto', self.WriteGraphFiles(graphs), default_ops)

  actual = _extract()
  matmul_prefix = 'Batch'
  expected = [
      ('AccumulateNV2', None),
      ('BiasAdd', 'BiasOp<CPUDevice, float>'),
      ('Const', 'ConstantOp'),
      ('MatMul',
       matmul_prefix + 'MatMulOp<CPUDevice, double, double, double, true>'),
      ('MatMul',
       matmul_prefix + 'MatMulOp<CPUDevice, float, float, float, true>'),
      ('Maximum', 'BinaryOp<CPUDevice, functor::maximum<int64_t>>'),
      ('NoOp', 'NoOp'),
      ('Reshape', 'ReshapeOp'),
      ('_Recv', 'RecvOp'),
      ('_Send', 'SendOp'),
  ]
  self.assertListEqual(expected, actual)

  # Dropping explicit device placement on these nodes must not alter the
  # extracted (op, kernel) list.
  graphs[0].node[0].ClearField('device')
  graphs[0].node[2].ClearField('device')
  self.assertListEqual(expected, _extract())
def testAll(self):
  """With default_ops='all', the generated header registers everything.

  Also checks that the one-shot `get_header` produces the same lines as the
  two-step `get_ops_and_kernels` + `get_header_from_ops_and_kernels` path.
  """
  default_ops = 'all'
  lib = selective_registration_header_lib
  graphs = [
      text_format.Parse(graph_txt, graph_pb2.GraphDef())
      for graph_txt in (GRAPH_DEF_TXT, GRAPH_DEF_TXT_2)
  ]

  ops_and_kernels = lib.get_ops_and_kernels(
      'rawproto', self.WriteGraphFiles(graphs), default_ops)
  header_lines = lib.get_header_from_ops_and_kernels(
      ops_and_kernels, include_all_ops_and_kernels=True).split('\n')

  expected_lines = [
      '// This file was autogenerated by %s' % self.script_name,
      '#ifndef OPS_TO_REGISTER',
      '#define OPS_TO_REGISTER',
      '#define SHOULD_REGISTER_OP(op) true',
      '#define SHOULD_REGISTER_OP_KERNEL(clz) true',
      '#define SHOULD_REGISTER_OP_GRADIENT true',
      '#endif',
  ]
  self.assertListEqual(expected_lines, header_lines)

  one_shot_lines = lib.get_header(
      self.WriteGraphFiles(graphs), 'rawproto', default_ops).split('\n')
  self.assertListEqual(header_lines, one_shot_lines)
def testAll(self):
  """With default_ops='all', the generated header registers everything.

  Also checks that the one-shot `get_header` produces the same lines as the
  two-step `get_ops_and_kernels` + `get_header_from_ops_and_kernels` path.
  """
  default_ops = 'all'
  lib = selective_registration_header_lib
  graphs = [
      text_format.Parse(graph_txt, graph_pb2.GraphDef())
      for graph_txt in (GRAPH_DEF_TXT, GRAPH_DEF_TXT_2)
  ]

  ops_and_kernels = lib.get_ops_and_kernels(
      'rawproto', self.WriteGraphFiles(graphs), default_ops)
  header_lines = lib.get_header_from_ops_and_kernels(
      ops_and_kernels, include_all_ops_and_kernels=True).split('\n')

  expected_lines = [
      '// This file was autogenerated by %s' % self.script_name,
      '#ifndef OPS_TO_REGISTER',
      '#define OPS_TO_REGISTER',
      '#define SHOULD_REGISTER_OP(op) true',
      '#define SHOULD_REGISTER_OP_KERNEL(clz) true',
      '#define SHOULD_REGISTER_OP_GRADIENT true',
      '#endif',
  ]
  self.assertListEqual(expected_lines, header_lines)

  one_shot_lines = lib.get_header(
      self.WriteGraphFiles(graphs), 'rawproto', default_ops).split('\n')
  self.assertListEqual(header_lines, one_shot_lines)
def testGetOps(self):
  """Checks the (op, kernel) pairs extracted from the two test GraphDefs.

  The same expectation is asserted twice: once on the graphs as parsed, and
  once after the `device` field has been cleared on nodes 0 and 2 of the first
  graph, which must leave the result unchanged.
  """
  default_ops = 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'
  graphs = [
      text_format.Parse(graph_txt, graph_pb2.GraphDef())
      for graph_txt in (GRAPH_DEF_TXT, GRAPH_DEF_TXT_2)
  ]

  def _extract():
    # Invoked before and after the device fields are cleared.
    return selective_registration_header_lib.get_ops_and_kernels(
        'rawproto', self.WriteGraphFiles(graphs), default_ops)

  actual = _extract()
  # When MKL is enabled the expected MatMul kernel class names carry an
  # 'Mkl' prefix.
  matmul_prefix = 'Mkl' if test_util.IsMklEnabled() else ''
  expected = [
      ('AccumulateNV2', None),
      ('BiasAdd', 'BiasOp<CPUDevice, float>'),
      ('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, double, false >'),
      ('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, float, false >'),
      ('NoOp', 'NoOp'),
      ('Reshape', 'ReshapeOp'),
      ('_Recv', 'RecvOp'),
      ('_Send', 'SendOp'),
  ]
  self.assertListEqual(expected, actual)

  graphs[0].node[0].ClearField('device')
  graphs[0].node[2].ClearField('device')
  self.assertListEqual(expected, _extract())
def testGetOps(self):
  """Checks the (op, kernel) pairs extracted from the two test GraphDefs.

  The same expectation is asserted twice: once on the graphs as parsed, and
  once after the `device` field has been cleared on nodes 0 and 2 of the first
  graph, which must leave the result unchanged.
  """
  default_ops = 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'
  graphs = [
      text_format.Parse(graph_txt, graph_pb2.GraphDef())
      for graph_txt in (GRAPH_DEF_TXT, GRAPH_DEF_TXT_2)
  ]

  def _extract():
    # Invoked before and after the device fields are cleared.
    return selective_registration_header_lib.get_ops_and_kernels(
        'rawproto', self.WriteGraphFiles(graphs), default_ops)

  actual = _extract()
  # When MKL is enabled the expected MatMul kernel class names carry an
  # 'Mkl' prefix.
  matmul_prefix = 'Mkl' if test_util.IsMklEnabled() else ''
  expected = [
      ('BiasAdd', 'BiasOp<CPUDevice, float>'),
      ('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, double, false >'),
      ('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, float, false >'),
      ('NoOp', 'NoOp'),
      ('Reshape', 'ReshapeOp'),
      ('_Recv', 'RecvOp'),
      ('_Send', 'SendOp'),
  ]
  self.assertListEqual(expected, actual)

  graphs[0].node[0].ClearField('device')
  graphs[0].node[2].ClearField('device')
  self.assertListEqual(expected, _extract())
def testGetOpsFromList(self):
  """Checks (op, kernel) extraction from 'ops_list' text files.

  Covers: two distinct ops, a single op, a duplicated op (deduplicated), an
  op with an empty kernel string (reported as None), two identical files
  (deduplicated), and an empty file (must raise).
  """
  default_ops = ''

  def _get(files):
    return selective_registration_header_lib.get_ops_and_kernels(
        'ops_list', files, default_ops)

  single_file_cases = [
      # Two different ops.
      ('[["Add", "BinaryOp<CPUDevice, functor::add<float>>"], '
       '["Softplus", "SoftplusOp<CPUDevice, float>"]]',
       [('Add', 'BinaryOp<CPUDevice, functor::add<float>>'),
        ('Softplus', 'SoftplusOp<CPUDevice, float>')]),
      # A single op.
      ('[["Softplus", "SoftplusOp<CPUDevice, float>"]]',
       [('Softplus', 'SoftplusOp<CPUDevice, float>')]),
      # The same op listed twice is reported once.
      ('[["Add", "BinaryOp<CPUDevice, functor::add<float>>"], '
       '["Add", "BinaryOp<CPUDevice, functor::add<float>>"]]',
       [('Add', 'BinaryOp<CPUDevice, functor::add<float>>')]),
      # An empty kernel string is reported as None.
      ('[["Softplus", ""]]',
       [('Softplus', None)]),
  ]
  for ops_list, expected in single_file_cases:
    self.assertListEqual(expected, _get(self.WriteTextFile(ops_list)))

  # Two files with identical contents yield a single entry.
  softplus_list = '[["Softplus", "SoftplusOp<CPUDevice, float>"]]'
  self.assertListEqual(
      [('Softplus', 'SoftplusOp<CPUDevice, float>')],
      _get(self.WriteTextFile(softplus_list) +
           self.WriteTextFile(softplus_list)))

  # An empty file is rejected.
  with self.assertRaises(Exception):
    _get(self.WriteTextFile(''))