예제 #1
0
def verify_where(in_shape):
    """Check ``topi.where`` against ``np.where`` on every enabled target.

    Declares three placeholders of shape ``in_shape`` (condition, A, B), all
    using the condition placeholder's default dtype, and builds
    ``topi.where(cond, A, B)``. For each enabled target it compiles the
    expression with the target's broadcast schedule, runs it on random
    inputs, and compares the result with NumPy.

    Parameters
    ----------
    in_shape : tuple of int
        Shape shared by the condition, both branches and the output.
    """
    Cond = te.placeholder(shape=in_shape, name="cond")
    dtype = Cond.dtype
    A = te.placeholder(shape=in_shape, name="A")
    B = te.placeholder(shape=in_shape, name="B")
    C = topi.where(Cond, A, B)

    def check_device(device, ctx):
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_broadcast_schedule(device)(C)
        f = tvm.build(s, [Cond, A, B, C], device, name="where")

        # A condition drawn from uniform(-1, 1) is (almost surely) never 0,
        # so every element would select A and the B branch would go untested.
        # Zero out the small-magnitude entries: negative and positive nonzero
        # values still count as "true", while the zeros exercise "false".
        raw_cond = np.random.uniform(low=-1, high=1, size=in_shape)
        cond_npy = np.where(np.abs(raw_cond) < 0.5, 0, raw_cond).astype(dtype)
        x_npy = np.random.uniform(size=in_shape).astype(dtype)
        y_npy = np.random.uniform(size=in_shape).astype(dtype)
        out_npy = np.where(cond_npy, x_npy, y_npy)

        cond_nd = tvm.nd.array(cond_npy, ctx)
        x_nd = tvm.nd.array(x_npy, ctx)
        y_nd = tvm.nd.array(y_npy, ctx)
        out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx)
        f(cond_nd, x_nd, y_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for device, ctx in tvm.testing.enabled_targets():
        check_device(device, ctx)
예제 #2
0
 def check_device(device, ctx):
     """Compile a conv2d -> greater_equal/where -> add graph for ``device``.

     The graph is ``add(conv1, where(conv1 >= 0, 1, 2))``, where ``conv1`` is
     the target's NCHW conv2d of an int8 ``(2, 1, 2, 4)`` input with an int8
     ``(3, 1, 2, 2)`` kernel, accumulated in int32. The constant operands are
     ``topi.full`` tensors, and the whole graph is scheduled with the conv2d
     schedule. NOTE(review): presumably this checks that the ``full``/``where``
     ops fuse into the conv2d schedule; confirm against the enclosing test.

     The module is only built, never executed, so ``ctx`` is unused here. It is
     kept so the caller's ``(device, ctx)`` calling convention still works.
     """
     with tvm.target.Target(device):
         print("Running on target: %s" % device)
         conv2d_compute, conv2d_schedule = tvm.topi.testing.get_conv2d_nchw_implement(device)
         data = te.placeholder((2, 1, 2, 4), "int8", "data")
         w = te.placeholder((3, 1, 2, 2), "int8", "w")
         conv1 = conv2d_compute(data, w, 1, 0, 1, "int32")
         zeros = topi.full((2, 3, 1, 3), "int32", tvm.tir.const(0, dtype="int32"))
         gt = topi.greater_equal(conv1, zeros)
         one = topi.full((2, 3, 1, 3), "int32", tvm.tir.const(1, dtype="int32"))
         two = topi.full((2, 3, 1, 3), "int32", tvm.tir.const(2, dtype="int32"))
         where = topi.where(gt, one, two)
         add = topi.add(conv1, where)
         outs = [add]
         s = conv2d_schedule(outs)
         # Build for the target this function was given; the previous free
         # name `backend` was either undefined or an unrelated outer variable.
         tvm.build(s, [data, w, add], target=device)
예제 #3
0
 def check_device(device):
     """Compile a conv2d -> greater_equal/where -> add graph for ``device``.

     Returns early (after printing a skip message) when no device context for
     ``device`` exists. Otherwise it builds ``add(conv1, where(conv1 >= 0, 1, 2))``,
     where ``conv1`` is the target's NCHW conv2d of an int8 ``(2, 1, 2, 4)``
     input with an int8 ``(3, 1, 2, 2)`` kernel, accumulated in int32, and
     schedules it with the conv2d schedule. The module is only built, never
     executed. NOTE(review): presumably this checks that the ``full``/``where``
     ops fuse into the conv2d schedule; confirm against the enclosing test.
     """
     with tvm.target.create(device):
         ctx = tvm.context(device, 0)
         if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
         conv2d_compute, conv2d_schedule = tvm.topi.testing.get_conv2d_nchw_implement(device)
         data = te.placeholder((2, 1, 2, 4), 'int8', 'data')
         w = te.placeholder((3, 1, 2, 2), 'int8', 'w')
         conv1 = conv2d_compute(data, w, 1, 0, 1, 'int32')
         zeros = topi.full((2, 3, 1, 3), 'int32', tvm.tir.const(0, dtype='int32'))
         gt = topi.greater_equal(conv1, zeros)
         one = topi.full((2, 3, 1, 3), 'int32', tvm.tir.const(1, dtype='int32'))
         two = topi.full((2, 3, 1, 3), 'int32', tvm.tir.const(2, dtype='int32'))
         where = topi.where(gt, one, two)
         add = topi.add(conv1, where)
         outs = [add]
         s = conv2d_schedule(outs)
         # Build for the target this function was given; the previous free
         # name `backend` was either undefined or an unrelated outer variable.
         tvm.build(s, [data, w, add], target=device)