# NOTE(review): a second `def verify_expand_like` appears later in this file
# and shadows this one at import time; consider deleting one of the copies.
def verify_expand_like(in_shape, out_shape, axis):
    """Check ``topi.expand_like`` against a NumPy reference.

    Builds ``C = topi.expand_like(A, B, axis)`` where ``A`` has shape
    ``in_shape`` and ``B`` has shape ``out_shape``, compiles it for every
    target in the device list that is enabled, and compares the result with
    ``A``'s data expanded at ``axis`` and repeated along those axes until it
    reaches ``out_shape``.

    Args:
        in_shape: shape of the input placeholder ``A``.
        out_shape: shape of placeholder ``B`` and of the expected output;
            a tuple or a list.
        axis: iterable of output axes at which new dimensions are inserted;
            negative values are counted from the end of ``out_shape``.
    """
    A = tvm.placeholder(shape=in_shape, name="A")
    B = tvm.placeholder(shape=out_shape, name="B")
    C = topi.expand_like(A, B, axis)
    s = tvm.create_schedule([C.op])

    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        ctx = tvm.context(device, 0)
        f = tvm.build(s, [A, B, C], device, name="expand_like")
        a_np = np.random.uniform(size=in_shape).astype(A.dtype)
        tvm_input = tvm.nd.array(a_np, ctx)

        # NumPy reference: normalize negative axes against the output rank,
        # insert singleton dims in ascending axis order (so each lands at its
        # final output position), then repeat along each new axis to size.
        odim = len(out_shape)
        real_axis = sorted(x if x >= 0 else x + odim for x in axis)
        expected = a_np
        for x in real_axis:
            expected = np.expand_dims(expected, x).astype(A.dtype)
        for x in real_axis:
            expected = np.concatenate(
                [expected] * out_shape[x], axis=x
            ).astype(A.dtype)
        # ndarray.shape is a tuple; normalize so a list out_shape also passes.
        assert expected.shape == tuple(out_shape)

        tvm_shape_like = tvm.nd.array(np.zeros(out_shape).astype(B.dtype), ctx)
        out = tvm.nd.array(np.zeros(out_shape).astype(A.dtype), ctx)
        f(tvm_input, tvm_shape_like, out)
        # np.testing is used (rather than tvm.testing) because it needs no
        # extra submodule import.
        np.testing.assert_allclose(out.asnumpy(), expected)

    for device in ["llvm"]:
        check_device(device)
# NOTE(review): an earlier `def verify_expand_like` in this file is shadowed
# by this definition; consider deleting the redundant copy.
def verify_expand_like(in_shape, out_shape, axis):
    """Check ``topi.expand_like`` against a NumPy reference.

    Builds ``C = topi.expand_like(A, B, axis)`` where ``A`` has shape
    ``in_shape`` and ``B`` has shape ``out_shape``, compiles it for every
    target in the device list that is enabled, and compares the result with
    ``A``'s data expanded at ``axis`` and repeated along those axes until it
    reaches ``out_shape``.

    Args:
        in_shape: shape of the input placeholder ``A``.
        out_shape: shape of placeholder ``B`` and of the expected output;
            a tuple or a list.
        axis: iterable of output axes at which new dimensions are inserted;
            negative values are counted from the end of ``out_shape``.
    """
    A = tvm.placeholder(shape=in_shape, name="A")
    B = tvm.placeholder(shape=out_shape, name="B")
    C = topi.expand_like(A, B, axis)
    s = tvm.create_schedule([C.op])

    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        ctx = tvm.context(device, 0)
        f = tvm.build(s, [A, B, C], device, name="expand_like")
        a_np = np.random.uniform(size=in_shape).astype(A.dtype)
        tvm_input = tvm.nd.array(a_np, ctx)

        # NumPy reference: normalize negative axes against the output rank,
        # insert singleton dims in ascending axis order (so each lands at its
        # final output position), then repeat along each new axis to size.
        odim = len(out_shape)
        real_axis = sorted(x if x >= 0 else x + odim for x in axis)
        expected = a_np
        for x in real_axis:
            expected = np.expand_dims(expected, x).astype(A.dtype)
        for x in real_axis:
            expected = np.concatenate(
                [expected] * out_shape[x], axis=x
            ).astype(A.dtype)
        # ndarray.shape is a tuple; normalize so a list out_shape also passes.
        assert expected.shape == tuple(out_shape)

        tvm_shape_like = tvm.nd.array(np.zeros(out_shape).astype(B.dtype), ctx)
        out = tvm.nd.array(np.zeros(out_shape).astype(A.dtype), ctx)
        f(tvm_input, tvm_shape_like, out)
        np.testing.assert_allclose(out.asnumpy(), expected)

    for device in ["llvm"]:
        check_device(device)