Example #1
0
def verify_where(in_shape):
    """Compare ``topi.where`` with ``numpy.where`` on every available backend.

    Declares three placeholders of shape ``in_shape`` (a condition and two
    value tensors, all sharing the condition placeholder's dtype) and builds
    ``topi.where`` over them. For each device yielded by
    ``get_all_backend()``, the compute is scheduled, compiled, run on random
    inputs, and checked against the NumPy reference. Devices whose context
    does not exist are reported and skipped.

    Args:
        in_shape: shape shared by the condition and both value tensors.
    """
    cond_ph = tvm.placeholder(shape=in_shape, name="cond")
    dtype = cond_ph.dtype
    lhs_ph = tvm.placeholder(shape=in_shape, name="A")
    rhs_ph = tvm.placeholder(shape=in_shape, name="B")
    result = topi.where(cond_ph, lhs_ph, rhs_ph)

    def _run_on(device):
        # Guard clause: nothing to do if this backend is unavailable here.
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)

        with tvm.target.create(device):
            sched = topi.generic.schedule_broadcast(result)
        func = tvm.build(
            sched, [cond_ph, lhs_ph, rhs_ph, result], device, name="where"
        )

        # Host inputs in argument order: condition, then the two value tensors.
        # Condition values are drawn from [-1, 1); numpy.where treats every
        # nonzero entry as True.
        host_inputs = [
            np.random.uniform(low=-1, high=1, size=in_shape).astype(dtype),
            np.random.uniform(size=in_shape).astype(dtype),
            np.random.uniform(size=in_shape).astype(dtype),
        ]
        expected = np.where(*host_inputs)

        dev_inputs = [tvm.nd.array(arr, ctx) for arr in host_inputs]
        dev_out = tvm.nd.array(np.empty(expected.shape).astype(result.dtype), ctx)
        func(*dev_inputs, dev_out)
        tvm.testing.assert_allclose(dev_out.asnumpy(), expected)

    for device in get_all_backend():
        _run_on(device)
Example #2
0
 def check_device(device):
     """Compile a conv2d -> greater_equal -> where -> add graph for ``device``.

     Integration check that a graph whose ``greater_equal`` and ``where``
     inputs come from ``topi.full`` constants can be scheduled with
     ``topi.generic.schedule_conv2d_nchw`` and built. The compiled module is
     not executed: success means scheduling and ``tvm.build`` did not raise.
     Devices whose context does not exist are reported and skipped.

     Args:
         device: target string. It is used for the context, the target scope
             and the build target.
     """
     ctx = tvm.context(device, 0)
     if not ctx.exist:
         print("Skip because %s is not enabled" % device)
         return
     print("Running on target: %s" % device)
     with tvm.target.create(device):
         data = tvm.placeholder((2, 1, 2, 4), 'int8', 'data')
         w = tvm.placeholder((3, 1, 2, 2), 'int8', 'w')
         # Positional args presumably are strides=1, padding=0, dilation=1.
         # With the default NCHW layout, conv1 is then (2, 3, 1, 3), and the
         # constant tensors below must have that same shape.
         conv1 = topi.nn.conv2d(data, w, 1, 0, 1, out_dtype='int32')
         out_shape = (2, 3, 1, 3)
         zeros = topi.full(out_shape, 'int32', tvm.const(0, dtype='int32'))
         gt = topi.greater_equal(conv1, zeros)
         one = topi.full(out_shape, 'int32', tvm.const(1, dtype='int32'))
         two = topi.full(out_shape, 'int32', tvm.const(2, dtype='int32'))
         where = topi.where(gt, one, two)
         add = topi.add(conv1, where)
         s = topi.generic.schedule_conv2d_nchw([add])
         # Build for the ``device`` parameter instead of a variable from the
         # enclosing scope, so the build target always matches the context
         # and target scope used above.
         tvm.build(s, [data, w, add], target=device)