예제 #1
0
    def testGetOps(self):
        """Checks the ops/kernels extracted from the two sample rawproto graphs.

        With 'NoOp', '_Recv' and '_Send' supplied as default ops, the
        extracted list is compared against a fixed expectation. Extraction
        is then repeated after clearing the 'device' field on nodes 0 and 2
        of the first graph, and must yield exactly the same list.
        """
        default_ops = 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'
        graphs = [
            text_format.Parse(graph_txt, graph_pb2.GraphDef())
            for graph_txt in (GRAPH_DEF_TXT, GRAPH_DEF_TXT_2)
        ]

        def extract_ops():
            # Invoked twice: before and after the graphs are modified below.
            return selective_registration_header_lib.get_ops_and_kernels(
                'rawproto', self.WriteGraphFiles(graphs), default_ops)

        first_result = extract_ops()

        matmul_prefix = 'Batch'
        expected = [
            ('AccumulateNV2', None),
            ('BiasAdd', 'BiasOp<CPUDevice, float>'),
            ('Const', 'ConstantOp'),
            ('MatMul', matmul_prefix +
             'MatMulOp<CPUDevice, double, double, double, true>'),
            ('MatMul', matmul_prefix +
             'MatMulOp<CPUDevice, float, float, float, true>'),
            ('Maximum', 'BinaryOp<CPUDevice, functor::maximum<int64_t>>'),
            ('NoOp', 'NoOp'),
            ('Reshape', 'ReshapeOp'),
            ('_Recv', 'RecvOp'),
            ('_Send', 'SendOp'),
        ]
        self.assertListEqual(expected, first_result)

        # Drop the device placement on two nodes of the first graph; the
        # extracted ops and kernels are expected to stay the same.
        for node_index in (0, 2):
            graphs[0].node[node_index].ClearField('device')
        self.assertListEqual(expected, extract_ops())
  def testAll(self):
    """Checks the header generated when default_ops is 'all'.

    The header built from get_ops_and_kernels() output with
    include_all_ops_and_kernels=True must consist of unconditional
    registration macros. get_header(), called directly on the same graph
    files, must produce identical text.
    """
    default_ops = 'all'
    graphs = [
        text_format.Parse(graph_txt, graph_pb2.GraphDef())
        for graph_txt in (GRAPH_DEF_TXT, GRAPH_DEF_TXT_2)
    ]
    lib = selective_registration_header_lib

    ops_and_kernels = lib.get_ops_and_kernels(
        'rawproto', self.WriteGraphFiles(graphs), default_ops)
    header_lines = lib.get_header_from_ops_and_kernels(
        ops_and_kernels, include_all_ops_and_kernels=True).split('\n')

    expected_lines = [
        '// This file was autogenerated by %s' % self.script_name,
        '#ifndef OPS_TO_REGISTER',
        '#define OPS_TO_REGISTER',
        '#define SHOULD_REGISTER_OP(op) true',
        '#define SHOULD_REGISTER_OP_KERNEL(clz) true',
        '#define SHOULD_REGISTER_OP_GRADIENT true',
        '#endif',
    ]
    self.assertListEqual(expected_lines, header_lines)

    # get_header(), fed the graph files directly, must agree line for line.
    direct_lines = lib.get_header(
        self.WriteGraphFiles(graphs), 'rawproto', default_ops).split('\n')
    self.assertListEqual(header_lines, direct_lines)
  def testAll(self):
    """Checks the header generated when default_ops is 'all'.

    The header built from get_ops_and_kernels() output with
    include_all_ops_and_kernels=True must consist of unconditional
    registration macros. get_header(), called directly on the same graph
    files, must produce identical text.
    """
    default_ops = 'all'
    graphs = [
        text_format.Parse(graph_txt, graph_pb2.GraphDef())
        for graph_txt in (GRAPH_DEF_TXT, GRAPH_DEF_TXT_2)
    ]
    lib = selective_registration_header_lib

    ops_and_kernels = lib.get_ops_and_kernels(
        'rawproto', self.WriteGraphFiles(graphs), default_ops)
    header_lines = lib.get_header_from_ops_and_kernels(
        ops_and_kernels, include_all_ops_and_kernels=True).split('\n')

    expected_lines = [
        '// This file was autogenerated by %s' % self.script_name,
        '#ifndef OPS_TO_REGISTER',
        '#define OPS_TO_REGISTER',
        '#define SHOULD_REGISTER_OP(op) true',
        '#define SHOULD_REGISTER_OP_KERNEL(clz) true',
        '#define SHOULD_REGISTER_OP_GRADIENT true',
        '#endif',
    ]
    self.assertListEqual(expected_lines, header_lines)

    # get_header(), fed the graph files directly, must agree line for line.
    direct_lines = lib.get_header(
        self.WriteGraphFiles(graphs), 'rawproto', default_ops).split('\n')
    self.assertListEqual(header_lines, direct_lines)
  def testGetOps(self):
    """Checks the ops/kernels extracted from the two sample rawproto graphs.

    The MatMul kernel names carry an 'Mkl' prefix when
    test_util.IsMklEnabled() is true. Extraction is repeated after clearing
    the 'device' field on nodes 0 and 2 of the first graph, and must yield
    exactly the same list.
    """
    default_ops = 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'
    graphs = [
        text_format.Parse(graph_txt, graph_pb2.GraphDef())
        for graph_txt in (GRAPH_DEF_TXT, GRAPH_DEF_TXT_2)
    ]

    def extract_ops():
      # Invoked twice: before and after the graphs are modified below.
      return selective_registration_header_lib.get_ops_and_kernels(
          'rawproto', self.WriteGraphFiles(graphs), default_ops)

    first_result = extract_ops()
    matmul_prefix = 'Mkl' if test_util.IsMklEnabled() else ''

    expected = [
        ('AccumulateNV2', None),
        ('BiasAdd', 'BiasOp<CPUDevice, float>'),
        ('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, double, false >'),
        ('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, float, false >'),
        ('NoOp', 'NoOp'),
        ('Reshape', 'ReshapeOp'),
        ('_Recv', 'RecvOp'),
        ('_Send', 'SendOp'),
    ]
    self.assertListEqual(expected, first_result)

    # Drop the device placement on two nodes of the first graph; the
    # extracted ops and kernels are expected to stay the same.
    for node_index in (0, 2):
      graphs[0].node[node_index].ClearField('device')
    self.assertListEqual(expected, extract_ops())
  def testGetOps(self):
    """Checks the ops/kernels extracted from the two sample rawproto graphs.

    The MatMul kernel names carry an 'Mkl' prefix when
    test_util.IsMklEnabled() is true. Extraction is repeated after clearing
    the 'device' field on nodes 0 and 2 of the first graph, and must yield
    exactly the same list.
    """
    default_ops = 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'
    graphs = [
        text_format.Parse(graph_txt, graph_pb2.GraphDef())
        for graph_txt in (GRAPH_DEF_TXT, GRAPH_DEF_TXT_2)
    ]

    def extract_ops():
      # Invoked twice: before and after the graphs are modified below.
      return selective_registration_header_lib.get_ops_and_kernels(
          'rawproto', self.WriteGraphFiles(graphs), default_ops)

    first_result = extract_ops()
    matmul_prefix = 'Mkl' if test_util.IsMklEnabled() else ''

    expected = [
        ('BiasAdd', 'BiasOp<CPUDevice, float>'),
        ('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, double, false >'),
        ('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, float, false >'),
        ('NoOp', 'NoOp'),
        ('Reshape', 'ReshapeOp'),
        ('_Recv', 'RecvOp'),
        ('_Send', 'SendOp'),
    ]
    self.assertListEqual(expected, first_result)

    # Drop the device placement on two nodes of the first graph; the
    # extracted ops and kernels are expected to stay the same.
    for node_index in (0, 2):
      graphs[0].node[node_index].ClearField('device')
    self.assertListEqual(expected, extract_ops())
  def testGetOpsFromList(self):
    """Checks get_ops_and_kernels() on 'ops_list' input files.

    Covers the following cases:
      * two distinct ops;
      * a single op;
      * a duplicated op, which is reported once;
      * an op with an empty kernel string, which is reported as None;
      * the same content written to two files;
      * an empty file, which must raise.
    """
    default_ops = ''

    def ops_from(paths):
      # Every case below parses the 'ops_list' format with no default ops.
      return selective_registration_header_lib.get_ops_and_kernels(
          'ops_list', paths, default_ops)

    add = ('Add', 'BinaryOp<CPUDevice, functor::add<float>>')
    softplus = ('Softplus', 'SoftplusOp<CPUDevice, float>')
    softplus_only = '[["Softplus", "SoftplusOp<CPUDevice, float>"]]'

    # Two different ops.
    two_ops = """[["Add", "BinaryOp<CPUDevice, functor::add<float>>"],
        ["Softplus", "SoftplusOp<CPUDevice, float>"]]"""
    self.assertListEqual([add, softplus],
                         ops_from(self.WriteTextFile(two_ops)))

    # A single op.
    self.assertListEqual([softplus],
                         ops_from(self.WriteTextFile(softplus_only)))

    # The same op listed twice is reported only once.
    duplicated = """[["Add", "BinaryOp<CPUDevice, functor::add<float>>"],
        ["Add", "BinaryOp<CPUDevice, functor::add<float>>"]]"""
    self.assertListEqual([add], ops_from(self.WriteTextFile(duplicated)))

    # An empty kernel string is reported as None.
    self.assertListEqual([('Softplus', None)],
                         ops_from(self.WriteTextFile('[["Softplus", ""]]')))

    # Two ops_list files with identical content.
    self.assertListEqual(
        [softplus],
        ops_from(
            self.WriteTextFile(softplus_only) +
            self.WriteTextFile(softplus_only)))

    # An empty file must be rejected.
    with self.assertRaises(Exception):
      ops_from(self.WriteTextFile(''))