def check_erf(dev, n, dtype):
    """Build an elementwise ``erf`` kernel and inspect the generated device source.

    Asserts that the generated source calls ``erf`` exactly once and never
    calls ``erff``.

    Parameters
    ----------
    dev : object
        Not used by this check; kept so the signature matches the other
        ``check_erf`` helper. NOTE(review): presumably supplied by a caller
        that iterates over devices -- confirm.
    n : int
        Length of the 1-D input placeholder.
    dtype : str
        Element dtype string passed to ``te.placeholder``.

    Raises
    ------
    AssertionError
        If an ``erff`` call is present, or if the number of ``erf`` calls is
        not exactly one. The message includes the generated source.
    """
    A = te.placeholder((n,), name="A", dtype=dtype)
    C = te.compute(A.shape, lambda *i: te.erf(A(*i)), name="C")
    s = te.create_schedule(C.op)
    s[C].bind(s[C].op.axis[0], te.thread_axis("threadIdx.x"))
    # `target` is not a parameter; it is resolved from the enclosing scope.
    fun = tvm.build(s, [A, C], target)
    # The first imported module is presumably the device module whose
    # source text is the generated kernel.
    source_str = fun.imported_modules[0].get_source()
    # Word boundaries keep identifiers that merely contain "erf" from being
    # counted, and stop an "erff" call from also counting as an "erf" call.
    erf_calls = re.findall(r"\berf\b", source_str)
    erff_calls = re.findall(r"\berff\b", source_str)
    assert not erff_calls, (
        f"found {len(erff_calls)} 'erff' call(s) in generated source:\n{source_str}"
    )
    assert len(erf_calls) == 1, (
        f"expected exactly one 'erf' call, found {len(erf_calls)}:\n{source_str}"
    )
def check_erf(dev, n, dtype):
    """Compile an elementwise ``erf`` kernel and launch it once on ``dev``.

    No numerical result is checked. Both buffers are allocated with
    ``tvm.nd.empty`` and never filled, so the call only shows that the kernel
    builds and can be invoked.

    Parameters
    ----------
    dev : object
        Device on which the input and output buffers are allocated.
    n : int
        Length of the 1-D input and output.
    dtype : str
        Element dtype string passed to ``te.placeholder``.

    Notes
    -----
    ``tx`` (the iter var the output axis is bound to, presumably a
    ``te.thread_axis``) and ``target`` come from the enclosing scope, not from
    the parameters.
    """
    inp = te.placeholder((n,), name="A", dtype=dtype)
    out = te.compute(inp.shape, lambda *idx: te.erf(inp(*idx)), name="C")

    sched = te.create_schedule(out.op)
    sched[out].bind(sched[out].op.axis[0], tx)
    kernel = tvm.build(sched, [inp, out], target)

    # Input buffer first, output buffer second; both use the placeholder's dtype.
    buffers = [tvm.nd.empty((n,), inp.dtype, dev) for _ in range(2)]
    kernel(*buffers)
def erf(x):
    """Compute the Gauss error function of ``x`` element by element.

    Parameters
    ----------
    x : tvm.te.Tensor
        Input tensor.

    Returns
    -------
    y : tvm.te.Tensor
        Tensor with the same shape as ``x``. Each element is ``te.erf``
        applied to the corresponding element of ``x``.
    """

    def _erf_at(*indices):
        # Read the input at the same coordinates and apply the intrinsic.
        return te.erf(x(*indices))

    return te.compute(x.shape, _erf_at)