def testGetOps(self):
    """Check the (op, kernel) pairs reported for the two test graphs.

    The same expected list is asserted twice: once for the graphs as parsed
    from GRAPH_DEF_TXT / GRAPH_DEF_TXT_2, and once more after the `device`
    field has been cleared on nodes 0 and 2 of the first graph.
    """
    default_ops = 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'

    # Both assertions below compare against this one list; it includes the
    # three op:kernel pairs named in `default_ops`.
    expected = [
        ('BiasAdd', 'BiasOp<CPUDevice, float>'),
        ('MatMul', 'MatMulOp<CPUDevice, double, false >'),
        ('MatMul', 'MatMulOp<CPUDevice, float, false >'),
        ('NoOp', 'NoOp'),
        ('Reshape', 'ReshapeOp'),
        ('_Recv', 'RecvOp'),
        ('_Send', 'SendOp'),
    ]

    parsed_graphs = []
    for graph_txt in (GRAPH_DEF_TXT, GRAPH_DEF_TXT_2):
        parsed_graphs.append(text_format.Parse(graph_txt, graph_pb2.GraphDef()))

    graph_files = self.WriteGraphFiles(parsed_graphs)
    found = print_selective_registration_header.get_ops_and_kernels(
        'rawproto', graph_files, default_ops)
    self.assertListEqual(expected, found)

    # Second pass: strip the device from two nodes of the first graph,
    # rewrite the graph files, and expect an unchanged result.
    for node_index in (0, 2):
        parsed_graphs[0].node[node_index].ClearField('device')

    graph_files = self.WriteGraphFiles(parsed_graphs)
    found = print_selective_registration_header.get_ops_and_kernels(
        'rawproto', graph_files, default_ops)
    self.assertListEqual(expected, found)
    def testGetOps(self):
        """Verify get_ops_and_kernels() output for the two test graphs.

        Runs the extraction on the parsed graphs, then again after clearing
        the `device` field of nodes 0 and 2 in the first graph; both runs are
        asserted against the same list of (op, kernel) tuples.
        """
        default_ops = 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'
        graph_defs = [
            text_format.Parse(txt, graph_pb2.GraphDef())
            for txt in (GRAPH_DEF_TXT, GRAPH_DEF_TXT_2)
        ]
        # Shared expectation for both runs below.
        want = [
            ('BiasAdd', 'BiasOp<CPUDevice, float>'),
            ('MatMul', 'MatMulOp<CPUDevice, double, false >'),
            ('MatMul', 'MatMulOp<CPUDevice, float, false >'),
            ('NoOp', 'NoOp'),
            ('Reshape', 'ReshapeOp'),
            ('_Recv', 'RecvOp'),
            ('_Send', 'SendOp'),
        ]

        written = self.WriteGraphFiles(graph_defs)
        got = print_selective_registration_header.get_ops_and_kernels(
            'rawproto', written, default_ops)
        self.assertListEqual(want, got)

        # Remove the device assignment from two nodes of the first graph and
        # repeat; the reported ops and kernels are expected to be the same.
        first_graph = graph_defs[0]
        first_graph.node[0].ClearField('device')
        first_graph.node[2].ClearField('device')

        written = self.WriteGraphFiles(graph_defs)
        got = print_selective_registration_header.get_ops_and_kernels(
            'rawproto', written, default_ops)
        self.assertListEqual(want, got)
    def testAll(self):
        """Verify the header text generated when default ops are 'all'.

        The header built from the two test graphs is split on newlines and
        compared line by line against the expected preprocessor directives.
        """
        default_ops = 'all'
        graph_defs = [
            text_format.Parse(txt, graph_pb2.GraphDef())
            for txt in (GRAPH_DEF_TXT, GRAPH_DEF_TXT_2)
        ]
        written = self.WriteGraphFiles(graph_defs)
        ops_and_kernels = print_selective_registration_header.get_ops_and_kernels(
            'rawproto', written, default_ops)
        header = print_selective_registration_header.get_header(
            ops_and_kernels, default_ops)

        # Every SHOULD_REGISTER_* macro is expected to be defined as `true`.
        want_lines = [
            '#ifndef OPS_TO_REGISTER',
            '#define OPS_TO_REGISTER',
            '#define SHOULD_REGISTER_OP(op) true',
            '#define SHOULD_REGISTER_OP_KERNEL(clz) true',
            '#define SHOULD_REGISTER_OP_GRADIENT true',
            '#endif',
        ]
        self.assertListEqual(want_lines, header.split('\n'))
  def testAll(self):
    """Check the generated header for the 'all' default-ops setting.

    Builds the header from the two test graphs and asserts that its lines
    match the expected list of preprocessor directives exactly.
    """
    default_ops = 'all'

    parsed_graphs = []
    for graph_txt in (GRAPH_DEF_TXT, GRAPH_DEF_TXT_2):
      parsed_graphs.append(text_format.Parse(graph_txt, graph_pb2.GraphDef()))

    graph_files = self.WriteGraphFiles(parsed_graphs)
    ops_and_kernels = print_selective_registration_header.get_ops_and_kernels(
        'rawproto', graph_files, default_ops)
    header_text = print_selective_registration_header.get_header(
        ops_and_kernels, default_ops)

    # In the expected header each SHOULD_REGISTER_* macro expands to `true`.
    expected_lines = [
        '#ifndef OPS_TO_REGISTER',
        '#define OPS_TO_REGISTER',
        '#define SHOULD_REGISTER_OP(op) true',
        '#define SHOULD_REGISTER_OP_KERNEL(clz) true',
        '#define SHOULD_REGISTER_OP_GRADIENT true',
        '#endif',
    ]
    actual_lines = header_text.split('\n')
    self.assertListEqual(expected_lines, actual_lines)