コード例 #1
0
def test_squeeze():
    """Run the squeeze checks, then a kernel indexed by a data-dependent value.

    The first part forwards four (shape, axis) combinations to
    ``verify_squeeze``. The second part builds a one-element compute whose
    index into the squeezed tensor is derived from the input's first element,
    and runs it on each of "cuda" and "opencl" that is reported as enabled.
    """
    shape_axis_cases = (
        ((1, 2, 3, 4), 0),
        ((1, 2, 1, 4), None),
        ((1, 1, 1, 4), (1, 2)),
        ((1, 1, 1, 1), None),
    )
    for src_shape, axis in shape_axis_cases:
        verify_squeeze(src_shape, axis)

    # a special case to trigger inline let expression
    inp = te.placeholder((2,), "float32", "A")
    squeezed = topi.squeeze(inp)
    # The lambda parameter name is kept as ``i``; the index itself ignores it
    # and is computed from the first input element instead.
    out = te.compute(
        (1,), lambda i: squeezed[(2 * inp[0] - 1).astype("int32")]
    )

    for device in ("cuda", "opencl"):
        ctx = tvm.context(device, 0)
        if not tvm.testing.device_enabled(device):
            continue

        with tvm.target.Target(device):
            schedule = tvm.topi.testing.get_injective_schedule(device)(out)
            kernel = tvm.build(schedule, [inp, out])

        host_input = np.array((1, 2)).astype("float32")
        dev_input = tvm.nd.array(host_input, ctx=ctx)
        dev_output = tvm.nd.empty((1,), dtype="float32", ctx=ctx)
        kernel(dev_input, dev_output)
        # With input (1, 2) the index is 2 * 1 - 1 = 1, so element 2 is expected.
        assert dev_output.asnumpy()[0] == 2
コード例 #2
0
def test_squeeze():
    """Run the squeeze checks, then a kernel indexed by a data-dependent value.

    Four (shape, axis) combinations go through ``verify_squeeze``. Afterwards
    a one-element compute is built whose index into the squeezed tensor comes
    from the input's first element; it is executed on each of "cuda" and
    "opencl" whose context reports ``exist``.
    """
    verify_squeeze(src_shape=(1, 2, 3, 4), axis=0)
    verify_squeeze(src_shape=(1, 2, 1, 4), axis=None)
    verify_squeeze(src_shape=(1, 1, 1, 4), axis=(1, 2))
    verify_squeeze(src_shape=(1, 1, 1, 1), axis=None)

    # a special case to trigger inline let expression
    source = te.placeholder((2,), "float32", "A")
    squeezed = topi.squeeze(source)

    def pick_element(i):
        # The index ignores ``i`` and is derived from the first input element.
        return squeezed[(2 * source[0] - 1).astype("int32")]

    result = te.compute((1,), pick_element)

    for device in ["cuda", "opencl"]:
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            # Device not available here; nothing to build or run.
            continue

        with tvm.target.create(device):
            sched = tvm.topi.testing.get_injective_schedule(device)(result)
            func = tvm.build(sched, [source, result])

        in_nd = tvm.nd.array(np.array((1, 2)).astype("float32"), ctx=ctx)
        out_nd = tvm.nd.empty((1,), dtype="float32", ctx=ctx)
        func(in_nd, out_nd)
        # With input (1, 2) the index is 2 * 1 - 1 = 1, so element 2 is expected.
        assert out_nd.asnumpy()[0] == 2
コード例 #3
0
def verify_squeeze(src_shape, axis):
    """Compare ``topi.squeeze`` with ``np.squeeze`` on every enabled target.

    Parameters
    ----------
    src_shape : shape of the input placeholder and of the random test data.
    axis : forwarded unchanged to both ``topi.squeeze`` and ``np.squeeze``.

    For each (device, ctx) pair yielded by ``tvm.testing.enabled_targets()``
    the squeeze is scheduled, built, run on random normal data, and its output
    is checked against the NumPy result with ``tvm.testing.assert_allclose``.
    """
    A = te.placeholder(shape=src_shape, name="A")
    B = topi.squeeze(A, axis=axis)

    for device, ctx in tvm.testing.enabled_targets():
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            schedule = tvm.topi.testing.get_injective_schedule(device)(B)

        squeeze_fn = tvm.build(schedule, [A, B], device, name="squeeze")

        # NumPy provides both the input data and the reference output.
        host_in = np.random.normal(size=src_shape).astype(A.dtype)
        expected = np.squeeze(host_in, axis=axis)

        dev_in = tvm.nd.array(host_in, ctx)
        dev_out = tvm.nd.empty(expected.shape, ctx=ctx, dtype=B.dtype)
        squeeze_fn(dev_in, dev_out)

        tvm.testing.assert_allclose(dev_out.asnumpy(), expected)
コード例 #4
0
def verify_squeeze(src_shape, axis):
    """Compare ``topi.squeeze`` with ``np.squeeze`` on each available backend.

    Parameters
    ----------
    src_shape : shape of the input placeholder and of the random test data.
    axis : forwarded unchanged to both ``topi.squeeze`` and ``np.squeeze``.

    Iterates the devices returned by ``get_all_backend()``. A device whose
    context does not report ``exist`` is skipped with a printed notice;
    otherwise the squeeze is scheduled, built, run on random normal data, and
    checked against the NumPy result with ``tvm.testing.assert_allclose``.
    """
    A = te.placeholder(shape=src_shape, name="A")
    B = topi.squeeze(A, axis=axis)

    for device in get_all_backend():
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            continue

        print("Running on target: %s" % device)
        with tvm.target.create(device):
            schedule = tvm.topi.testing.get_injective_schedule(device)(B)

        squeeze_fn = tvm.build(schedule, [A, B], device, name="squeeze")

        # NumPy provides both the input data and the reference output.
        host_in = np.random.normal(size=src_shape).astype(A.dtype)
        expected = np.squeeze(host_in, axis=axis)

        dev_in = tvm.nd.array(host_in, ctx)
        dev_out = tvm.nd.empty(expected.shape, ctx=ctx, dtype=B.dtype)
        squeeze_fn(dev_in, dev_out)

        tvm.testing.assert_allclose(dev_out.asnumpy(), expected)