示例#1
0
def verify_strided_set(in_shape, v_shape, begin, end, strides=None):
    """Check ``topi.strided_set(...) + 1`` against the Python reference.

    A compute is declared over placeholders, built for each target in a fixed
    list, run on random inputs, and compared with
    ``tvm.topi.testing.strided_set_python(...) + 1``.

    Parameters
    ----------
    in_shape
        Shape of the data placeholder ``A``.
    v_shape
        Shape of the placeholder ``V`` passed as the second operand.
    begin, end
        Sequences whose lengths size the 1-D int32 placeholders ``b`` and
        ``e``; their values are fed at run time as int32 arrays.
    strides
        Optional sequence handled like ``begin``/``end``. When ``None`` the
        stride operand is left out of both the compute and the built function.
    """
    has_strides = strides is not None

    data_ph = te.placeholder(shape=in_shape, name="A")
    value_ph = te.placeholder(shape=v_shape, name="V")
    index_phs = [
        te.placeholder(shape=(len(begin), ), name="b", dtype="int32"),
        te.placeholder(shape=(len(end), ), name="e", dtype="int32"),
    ]
    if has_strides:
        index_phs.append(
            te.placeholder(shape=(len(strides), ), name="st", dtype="int32"))
    # The stride placeholder, when present, is simply the last index operand.
    result = topi.strided_set(data_ph, value_ph, *index_phs) + 1

    def run_on(target):
        ctx = tvm.context(target, 0)
        if not tvm.testing.device_enabled(target):
            print("Skip because %s is not enabled" % target)
            return
        print("Running on target: %s" % target)
        with tvm.target.Target(target):
            sched = tvm.topi.testing.get_injective_schedule(target)(result)

        func = tvm.build(sched, [data_ph, value_ph, *index_phs, result],
                         target, name="stride_set")
        # Holds zero or one device array so the call below needs no branching.
        stride_inputs = []
        if has_strides:
            stride_host = np.asarray(strides).astype("int32")
            stride_inputs.append(tvm.nd.array(stride_host, ctx))

        x_np = np.random.uniform(size=in_shape).astype(data_ph.dtype)
        v_np = np.random.uniform(size=v_shape).astype(value_ph.dtype)
        b_np = np.asarray(begin).astype("int32")
        e_np = np.asarray(end).astype("int32")
        expected = tvm.topi.testing.strided_set_python(
            x_np, v_np, begin, end, strides) + 1

        device_inputs = [
            tvm.nd.array(host, ctx) for host in (x_np, v_np, b_np, e_np)
        ]
        out_nd = tvm.nd.empty(expected.shape, ctx=ctx, dtype=data_ph.dtype)
        func(*device_inputs, *stride_inputs, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), expected)

    for target in ("llvm", "opencl", "sdaccel", "aocl_sw_emu"):
        run_on(target)
示例#2
0
def compute_strided_set(attrs, inputs, output_type):
    """Compute definition of strided_set.

    Forwards the first five entries of ``inputs``, in order, to
    ``topi.strided_set`` and returns the result wrapped in a one-element
    list. ``attrs`` and ``output_type`` are not used.
    """
    data = inputs[0]
    value = inputs[1]
    begin = inputs[2]
    end = inputs[3]
    strides = inputs[4]
    return [topi.strided_set(data, value, begin, end, strides)]