def test_conv2d():
    """Smoke-test the MPS conv2d extern op by building and running it on Metal.

    Declares an input placeholder of shape ``(n, h, w, ci)`` and a weight
    placeholder of shape ``(co, kh, kw, ci)``, wires them through
    ``mps.conv2d`` with the ``"SAME"`` mode string and the test stride,
    compiles the schedule for the ``"metal"`` target and executes it once on
    ``tvm.metal(0)`` with random inputs.

    The run is skipped (with a printed message) when the
    ``tvm.contrib.mps.conv2d`` packed function is not registered.  The
    output values are not compared against a reference; the test only
    checks that build and execution complete without raising.
    """
    # NOTE(review): the shape ordering looks like NHWC input / OHWI weights —
    # confirm against the mps.conv2d contract.
    n = 1
    h = 14
    w = 14
    ci = 2
    co = 4
    kh = 3
    kw = 3
    stride = 2
    A = te.placeholder((n, h, w, ci), name="x")
    B = te.placeholder((co, kh, kw, ci), name="w")
    # Use the same `stride` that sizes the output buffer below so the two
    # cannot drift apart.
    C = mps.conv2d(A, B, "SAME", stride)
    s1 = te.create_schedule(C.op)

    def verify(A, B, C):
        """Build ``s1`` for Metal and run it once on random data."""
        # allow_missing=True: returns a falsy value instead of raising when
        # the extern function has not been registered.
        if not tvm.get_global_func("tvm.contrib.mps.conv2d", True):
            print("skip because extern function is not available")
            return
        dev = tvm.metal(0)
        f = tvm.build(s1, [A, B, C], "metal")
        a = tvm.nd.array(
            np.random.uniform(size=(n, h, w, ci)).astype(A.dtype), dev
        )
        b = tvm.nd.array(
            np.random.uniform(size=(co, kh, kw, ci)).astype(B.dtype), dev
        )
        # Output buffer: spatial dims divided by the stride, `co` channels.
        c = tvm.nd.array(
            np.zeros((n, h // stride, w // stride, co), dtype=C.dtype), dev
        )
        f(a, b, c)

    verify(A, B, C)
def test_conv2d():
    """Smoke-test the MPS conv2d extern op by building and running it on Metal.

    Returns early (with a printed message) when the ``"metal"`` module is not
    enabled, or when the ``tvm.contrib.mps.conv2d`` packed function is not
    registered.  Otherwise declares an input placeholder of shape
    ``(n, h, w, ci)`` and a weight placeholder of shape ``(co, kh, kw, ci)``,
    wires them through ``mps.conv2d`` with the ``'SAME'`` mode string and the
    test stride, compiles the schedule for ``"metal"`` and executes it once
    on ``tvm.metal(0)`` with random inputs.

    The output values are not compared against a reference; the test only
    checks that build and execution complete without raising.
    """
    if not tvm.module.enabled("metal"):
        print("skip because metal is not enabled...")
        return
    # NOTE(review): the shape ordering looks like NHWC input / OHWI weights —
    # confirm against the mps.conv2d contract.
    n = 1
    h = 14
    w = 14
    ci = 2
    co = 4
    kh = 3
    kw = 3
    stride = 2
    A = tvm.placeholder((n, h, w, ci), name="x")
    B = tvm.placeholder((co, kh, kw, ci), name="w")
    # Use the same `stride` that sizes the output buffer below so the two
    # cannot drift apart.
    C = mps.conv2d(A, B, 'SAME', stride)
    s1 = tvm.create_schedule(C.op)

    def verify(A, B, C):
        """Build ``s1`` for Metal and run it once on random data."""
        # allow_missing=True: returns a falsy value instead of raising when
        # the extern function has not been registered.
        if not tvm.get_global_func("tvm.contrib.mps.conv2d", True):
            print("skip because extern function is not available")
            return
        ctx = tvm.metal(0)
        f = tvm.build(s1, [A, B, C], "metal")
        a = tvm.nd.array(
            np.random.uniform(size=(n, h, w, ci)).astype(A.dtype), ctx
        )
        b = tvm.nd.array(
            np.random.uniform(size=(co, kh, kw, ci)).astype(B.dtype), ctx
        )
        # Output buffer: spatial dims divided by the stride, `co` channels.
        c = tvm.nd.array(
            np.zeros((n, h // stride, w // stride, co), dtype=C.dtype), ctx
        )
        f(a, b, c)

    verify(A, B, C)