def testGetOps(self):
  """Check get_ops_and_kernels output with and without node devices.

  The two text-format graph definitions are parsed and written out via
  self.WriteGraphFiles, then get_ops_and_kernels is queried. The 'device'
  field is then cleared on nodes 0 and 2 of the first graph and the query is
  repeated; the test expects the same list both times.
  """
  default_ops = 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'

  graphs = []
  for graph_text in (GRAPH_DEF_TXT, GRAPH_DEF_TXT_2):
    graphs.append(text_format.Parse(graph_text, graph_pb2.GraphDef()))

  # The same list is expected before and after the device fields are cleared.
  expected = [
      ('BiasAdd', 'BiasOp<CPUDevice, float>'),
      ('MatMul', 'MatMulOp<CPUDevice, double, false >'),
      ('MatMul', 'MatMulOp<CPUDevice, float, false >'),
      ('NoOp', 'NoOp'),
      ('Reshape', 'ReshapeOp'),
      ('_Recv', 'RecvOp'),
      ('_Send', 'SendOp'),
  ]

  ops_and_kernels = print_selective_registration_header.get_ops_and_kernels(
      'rawproto', self.WriteGraphFiles(graphs), default_ops)
  self.assertListEqual(expected, ops_and_kernels)

  # Drop the explicit device from two nodes of the first graph and re-check.
  for node_index in (0, 2):
    graphs[0].node[node_index].ClearField('device')

  ops_and_kernels = print_selective_registration_header.get_ops_and_kernels(
      'rawproto', self.WriteGraphFiles(graphs), default_ops)
  self.assertListEqual(expected, ops_and_kernels)
def testGetOps(self):
  """Verify the list returned by get_ops_and_kernels for the test graphs.

  Runs the query twice: once on the graphs as parsed, and once after the
  'device' field has been cleared on nodes 0 and 2 of the first graph. Both
  runs are asserted against the same expected list.
  """
  default_ops = 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'
  graph_texts = [GRAPH_DEF_TXT, GRAPH_DEF_TXT_2]
  graphs = [
      text_format.Parse(text, graph_pb2.GraphDef()) for text in graph_texts
  ]

  expected_pairs = [
      ('BiasAdd', 'BiasOp<CPUDevice, float>'),
      ('MatMul', 'MatMulOp<CPUDevice, double, false >'),
      ('MatMul', 'MatMulOp<CPUDevice, float, false >'),
      ('NoOp', 'NoOp'),
      ('Reshape', 'ReshapeOp'),
      ('_Recv', 'RecvOp'),
      ('_Send', 'SendOp'),
  ]

  # First pass: graphs exactly as parsed.
  graph_files = self.WriteGraphFiles(graphs)
  found = print_selective_registration_header.get_ops_and_kernels(
      'rawproto', graph_files, default_ops)
  self.assertListEqual(expected_pairs, found)

  # Second pass: same graphs, with the device cleared on two nodes.
  first_graph = graphs[0]
  first_graph.node[0].ClearField('device')
  first_graph.node[2].ClearField('device')

  graph_files = self.WriteGraphFiles(graphs)
  found = print_selective_registration_header.get_ops_and_kernels(
      'rawproto', graph_files, default_ops)
  self.assertListEqual(expected_pairs, found)
def testAll(self):
  """Check the header generated when default_ops is 'all'.

  With default_ops set to 'all', the header returned by get_header is
  expected to define every SHOULD_REGISTER_* macro as true.
  """
  default_ops = 'all'
  parsed_graphs = [
      text_format.Parse(graph_text, graph_pb2.GraphDef())
      for graph_text in (GRAPH_DEF_TXT, GRAPH_DEF_TXT_2)
  ]
  graph_files = self.WriteGraphFiles(parsed_graphs)

  ops_and_kernels = print_selective_registration_header.get_ops_and_kernels(
      'rawproto', graph_files, default_ops)
  header = print_selective_registration_header.get_header(
      ops_and_kernels, default_ops)

  expected_lines = [
      '#ifndef OPS_TO_REGISTER',
      '#define OPS_TO_REGISTER',
      '#define SHOULD_REGISTER_OP(op) true',
      '#define SHOULD_REGISTER_OP_KERNEL(clz) true',
      '#define SHOULD_REGISTER_OP_GRADIENT true',
      '#endif',
  ]
  self.assertListEqual(expected_lines, header.split('\n'))
def testAll(self):
  """Verify that default_ops='all' produces the all-true header.

  The generated header is split on newlines and compared line by line with
  the expected include-guarded block of SHOULD_REGISTER_* definitions.
  """
  default_ops = 'all'

  graphs = []
  for graph_text in [GRAPH_DEF_TXT, GRAPH_DEF_TXT_2]:
    graphs.append(text_format.Parse(graph_text, graph_pb2.GraphDef()))

  ops_and_kernels = print_selective_registration_header.get_ops_and_kernels(
      'rawproto', self.WriteGraphFiles(graphs), default_ops)
  header = print_selective_registration_header.get_header(
      ops_and_kernels, default_ops)
  header_lines = header.split('\n')

  self.assertListEqual(
      [
          '#ifndef OPS_TO_REGISTER',
          '#define OPS_TO_REGISTER',
          '#define SHOULD_REGISTER_OP(op) true',
          '#define SHOULD_REGISTER_OP_KERNEL(clz) true',
          '#define SHOULD_REGISTER_OP_GRADIENT true',
          '#endif',
      ],
      header_lines)