Beispiel #1
0
def verify_strided_slice(in_shape, begin, end, strides=None):
    """Compare ``topi.strided_slice(A, ...) + 1`` with its NumPy reference.

    A placeholder of shape ``in_shape`` is sliced with ``begin``, ``end`` and
    ``strides`` and incremented by one.  For every target in the device list
    below, the expression is scheduled, built and executed on random input,
    and the result is compared against
    ``topi.testing.strided_slice_python(...) + 1``.

    Parameters
    ----------
    in_shape : shape passed to ``tvm.placeholder`` and used for the random input.
    begin, end : slice bounds forwarded to ``topi.strided_slice``.
    strides : slice steps; ``[1, 1, 1]`` is used when ``None``.
    """
    if strides is None:
        strides = [1, 1, 1]
    A = tvm.placeholder(shape=in_shape, name="A")
    B = topi.strided_slice(A, begin, end, strides) + 1

    def check_device(device):
        """Build and run the computation on ``device`` if it is available."""
        dev_ctx = tvm.context(device, 0)
        if not dev_ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)

        with tvm.target.create(device):
            sched = topi.generic.schedule_injective(B)
        func = tvm.build(sched, [A, B], device, name="stride_slice")

        # Random input and the expected output computed on the host.
        input_np = np.random.uniform(size=in_shape).astype(A.dtype)
        sliced_np = topi.testing.strided_slice_python(
            input_np, begin, end, strides)
        expected = sliced_np + 1

        input_nd = tvm.nd.array(input_np, dev_ctx)
        result_nd = tvm.nd.empty(expected.shape, ctx=dev_ctx, dtype=A.dtype)
        func(input_nd, result_nd)
        tvm.testing.assert_allclose(result_nd.asnumpy(), expected)

    targets = ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]
    for target_name in targets:
        check_device(target_name)
Beispiel #2
0
def verify_strided_slice(in_shape, begin, end, strides=None):
    """Verify ``topi.strided_slice`` followed by ``+ 1`` on several targets.

    The TOPI result is checked with ``tvm.testing.assert_allclose`` against
    ``topi.testing.strided_slice_python`` applied to the same random input
    (also followed by ``+ 1``).  Targets whose context does not exist are
    skipped with a message.

    Args:
        in_shape: Shape of the placeholder and of the random test input.
        begin: Start indices forwarded to the slice.
        end: Stop indices forwarded to the slice.
        strides: Step sizes; defaults to ``[1, 1, 1]`` when ``None``.
    """
    A = tvm.placeholder(shape=in_shape, name="A")
    if strides is None:
        strides = [1, 1, 1]
    out_tensor = topi.strided_slice(A, begin, end, strides) + 1

    def run_on(target):
        # Skip targets that are not available in this build/runtime.
        context = tvm.context(target, 0)
        if not context.exist:
            print("Skip because %s is not enabled" % target)
            return

        print("Running on target: %s" % target)
        with tvm.target.create(target):
            schedule = topi.generic.schedule_injective(out_tensor)
        compiled = tvm.build(schedule, [A, out_tensor], target,
                             name="stride_slice")

        # Host-side reference result for the same random data.
        data_np = np.random.uniform(size=in_shape).astype(A.dtype)
        reference = 1 + topi.testing.strided_slice_python(
            data_np, begin, end, strides)

        # Execute on the target and compare.
        data_dev = tvm.nd.array(data_np, context)
        out_dev = tvm.nd.empty(reference.shape, ctx=context, dtype=A.dtype)
        compiled(data_dev, out_dev)
        tvm.testing.assert_allclose(out_dev.asnumpy(), reference)

    for target in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
        run_on(target)
Beispiel #3
0
def verify_strided_slice(in_shape, begin, end, stride=None):
    """Check ``topi.strided_slice(A, ...) + 1`` against a NumPy reference.

    A placeholder of shape ``in_shape`` is sliced and incremented by one.
    For each target in the device list below whose context exists, the
    expression is scheduled, built and run on random input, and the output
    is compared with a pure-NumPy slice of the same data.

    Parameters
    ----------
    in_shape : shape of the placeholder and of the random test input.
    begin, end : per-axis start/stop indices.
    stride : per-axis steps.  When omitted (or empty) a unit step is used for
        every axis listed in ``begin``.

    Notes
    -----
    The NumPy reference is built from ``slice`` objects, so it works for any
    number of sliced axes instead of exactly three.
    """
    # Size the default from ``begin`` rather than assuming three axes.
    stride = stride if stride else [1] * len(begin)
    A = tvm.placeholder(shape=in_shape, name="A")
    B = topi.strided_slice(A, begin, end, stride) + 1

    def test_forward(x, begin, end, stride):
        """NumPy reference: slice the leading axes of ``x`` and add one."""
        # One slice per (begin, end, stride) triple; entries beyond the rank
        # of ``x`` are dropped so over-long argument lists do not raise.
        index = tuple(
            slice(b, e, s) for b, e, s in zip(begin, end, stride)
        )[:x.ndim]
        return x[index] + 1

    def check_device(device):
        """Build and run on ``device``; skip it when it is not enabled."""
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_injective(B)

        foo = tvm.build(s, [A, B], device, name="stride_slice")
        x_np = np.random.uniform(size=in_shape).astype(A.dtype)
        out_npy = test_forward(x_np, begin, end, stride)
        data_nd = tvm.nd.array(x_np, ctx)
        out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=A.dtype)
        foo(data_nd, out_nd)
        np.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for device in ["llvm", "opencl", "sdaccel"]:
        check_device(device)