示例#1
0
def verify_strided_set(in_shape, v_shape, begin, end, strides=None):
    """Check ``topi.strided_set`` against its NumPy reference on several targets.

    The compiled expression is ``topi.strided_set(A, V, b, e[, st]) + 1``.
    On each enabled target its output is compared with
    ``topi.testing.strided_set_python(...) + 1`` evaluated on random host data.

    Parameters
    ----------
    in_shape : tuple of int
        Shape of the input placeholder ``A``.
    v_shape : tuple of int
        Shape of the value placeholder ``V``.
    begin, end : sequence of int
        Slice bounds. They are fed to the kernel as int32 tensors and passed
        unchanged to the reference implementation.
    strides : sequence of int, optional
        Slice strides. When ``None``, no strides tensor is created and
        ``topi.strided_set`` is called with four arguments only.

    Targets whose context does not exist are skipped with a printed message.
    """
    data = tvm.placeholder(shape=in_shape, name="A")
    value = tvm.placeholder(shape=v_shape, name="V")
    begin_t = tvm.placeholder(shape=(len(begin),), name="b", dtype='int32')
    end_t = tvm.placeholder(shape=(len(end),), name="e", dtype='int32')
    has_strides = strides is not None

    # Optional trailing input. It stays empty when no strides were given, so
    # the op is invoked (and the kernel built) without a strides argument.
    stride_inputs = []
    if has_strides:
        stride_inputs.append(
            tvm.placeholder(shape=(len(strides),), name="st", dtype='int32'))
    result = topi.strided_set(data, value, begin_t, end_t, *stride_inputs) + 1
    kernel_params = [data, value, begin_t, end_t] + stride_inputs + [result]

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            sched = topi.generic.schedule_injective(result)
        kernel = tvm.build(sched, kernel_params, device, name="stride_set")

        # Host-side inputs in kernel-parameter order (output excluded).
        x_host = np.random.uniform(size=in_shape).astype(data.dtype)
        v_host = np.random.uniform(size=v_shape).astype(value.dtype)
        host_inputs = [x_host, v_host,
                       np.asarray(begin).astype('int32'),
                       np.asarray(end).astype('int32')]
        if has_strides:
            host_inputs.append(np.asarray(strides).astype('int32'))

        expected = topi.testing.strided_set_python(
            x_host, v_host, begin, end, strides) + 1
        device_args = [tvm.nd.array(arr, ctx) for arr in host_inputs]
        # The output buffer goes last, matching the order of kernel_params.
        device_args.append(
            tvm.nd.empty(expected.shape, ctx=ctx, dtype=data.dtype))
        kernel(*device_args)
        tvm.testing.assert_allclose(device_args[-1].asnumpy(), expected)

    for target in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
        check_device(target)
示例#2
0
def compute_strided_set(attrs, inputs, output_type):
    """Compute definition of strided_set.

    Forwards the first five entries of ``inputs`` positionally to
    ``topi.strided_set`` and returns the resulting tensor wrapped in a
    one-element list. ``attrs`` and ``output_type`` are part of the
    compute-function signature but are not read here.

    NOTE(review): presumably ``inputs`` is ordered
    (data, value, begin, end, strides) to match ``topi.strided_set``'s
    positional parameters. Confirm this against the op registration.
    """
    data, value, begin, end, strides = (inputs[idx] for idx in range(5))
    out = topi.strided_set(data, value, begin, end, strides)
    return [out]