Пример #1
0
def verify_comparator(shape, dtype, out_type='int8'):
    """Check ``topi.less`` and ``topi.greater`` against NumPy.

    Both operators are built over two placeholders of ``shape``/``dtype``.
    Each is run on the same pair of random inputs, and the result is compared
    with ``np.less`` / ``np.greater`` cast to ``out_type``.

    Parameters
    ----------
    shape : tuple of int
        Shape of both input tensors and of the output buffer.
    dtype : str
        Element type of the input tensors.
    out_type : str, optional
        Element type requested for the comparison result. The output buffer
        is allocated with this dtype, so both operators must be asked to
        produce it.
    """
    A = tvm.placeholder(shape, dtype, name="A")
    B = tvm.placeholder(shape, dtype, name="B")
    # Pass out_type explicitly, as for ``greater`` below. The shared output
    # buffer is allocated with out_type, so ``less`` must emit the same dtype.
    C = topi.less(A, B, out_type)
    s_less = tvm.create_schedule([C.op])

    D = tvm.placeholder(shape, dtype, name="D")
    E = tvm.placeholder(shape, dtype, name="E")
    F = topi.greater(D, E, out_type)
    s_greater = tvm.create_schedule([F.op])

    @memoize("topi.tests.test_topi_indicator")
    def get_ref_data():
        return [
            np.random.uniform(0, 10, size=shape).astype(dtype),
            np.random.uniform(0, 10, size=shape).astype(dtype)
        ]

    [np_l, np_r] = get_ref_data()

    # (schedule, kernel arguments, kernel name, NumPy reference op)
    cases = (
        (s_less, [A, B, C], "less", np.less),
        (s_greater, [D, E, F], "greater", np.greater),
    )

    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return

        ctx = tvm.context(device, 0)
        # One output buffer is reused for every case. Each kernel overwrites it completely.
        out = tvm.nd.array(np.zeros(shape, dtype=out_type), ctx)
        tvm_l = tvm.nd.array(np_l, ctx)
        tvm_r = tvm.nd.array(np_r, ctx)

        for sched, args, kernel_name, np_op in cases:
            f = tvm.build(sched, args, device, name=kernel_name)
            f(tvm_l, tvm_r, out)
            np.testing.assert_allclose(out.asnumpy(),
                                       np_op(np_l, np_r).astype(out_type),
                                       rtol=1e-5)

    for device in ["llvm"]:
        check_device(device)
Пример #2
0
def greater_compute(attrs, inputs, output_type, target):
    """Build the TOPI compute for the binary ``greater`` operator.

    Only ``inputs`` is consulted. ``attrs``, ``output_type`` and ``target``
    are accepted to satisfy the caller's signature but are not read here.

    Returns a single-element list holding ``topi.greater(inputs[0], inputs[1])``.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs[0], inputs[1]
    result = topi.greater(lhs, rhs)
    return [result]
Пример #3
0
def compute_greater(_, inputs, out_info):
    """Compute definition of the element-wise ``greater`` comparison.

    Applies ``topi.greater`` to the first two entries of ``inputs`` and casts
    the result to ``float32``. The first argument and ``out_info`` are unused.
    """
    lhs = inputs[0]
    rhs = inputs[1]
    mask = topi.greater(lhs, rhs)
    return mask.astype('float32')
Пример #4
0
 def greater(x, y):
     """Return ``topi.greater(x, y)`` cast to ``int8``."""
     cmp = topi.greater(x, y)
     return cmp.astype("int8")
Пример #5
0
 def greater(x, y):
     """Element-wise ``x > y`` via TOPI, returned with dtype ``int8``."""
     mask = topi.greater(x, y)
     return mask.astype("int8")