def test_storage_sync():
    m = tvm.size_var('m')
    l = tvm.size_var('l')
    A = tvm.placeholder((m, l), name='A')
    A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
    A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')

    s = tvm.create_schedule(A2.op)
    xo, xi = s[A2].split(A2.op.axis[0], factor=8)
    s[A2].bind(xo, tvm.thread_axis("blockIdx.x"))
    s[A1].compute_at(s[A2], xo)
    s[A1].set_scope("shared")

    bounds = tvm.schedule.InferBound(s)
    assert isinstance(bounds, tvm.container.Map)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
    A2b = tvm.decl_buffer(A2.shape, A2.dtype, name='A2')
    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
    f = tvm.ir_pass.MakeAPI(stmt, "test", [Ab, A2b], 0, True)
    flist = tvm.ir_pass.SplitHostDevice(f)
    f = flist[1]
    f = tvm.ir_pass.ThreadSync(f, "shared")
    body_list = tvm.make.stmt_list(f.body.body.body.body)
    assert (body_list[1].value.name == "tvm_storage_sync")

def test_dynamic_tensor():
    dtype = 'float32'
    stype = 'csr'
    target = 'llvm'
    ctx = tvm.context(target, 0)
    nr, nc, n = tvm.size_var('nr'), tvm.size_var('nc'), tvm.size_var('n')
    A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, name='A', dtype=dtype)
    assert(A.stype == 'csr')
    C = tvm.compute(A.data.shape, lambda i: A.data[i] * 2., tag='cs_scatter')
    s = tvm.create_schedule(C.op)
    _nr, _nc = 3, 5
    a = np.maximum(np.random.uniform(size=(_nr, _nc)).astype(dtype) - .6, 0.)
    a = tvmsp.array(a, ctx)
    assert a.data.dtype == a.dtype
    Ab = namedtuple('CSRBuffer', ['data', 'indices', 'indptr'])
    Ab.data = tvm.decl_buffer(a.data.shape, a.data.dtype, name='A_data')
    Ab.indices = tvm.decl_buffer(a.data.shape, a.data.dtype, name='A_indices')
    binds = {A.data: Ab.data, A.indices: Ab.indices}
    f = tvm.build(s, [nr, A.data, C], target, binds=binds)
    c = tvmsp.array(np.zeros((_nr, _nc), dtype), ctx)
    c.data = tvm.nd.empty(a.data.shape, dtype)
    c.indices = a.indices
    c.indptr = a.indptr
    f(a.data.shape[0], a.data, c.data)
    tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() * 2., rtol=1e-5)

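# The relay tests below call two helpers that are not defined in this section.
# The sketches here are assumptions modeled on the relay test utilities of the
# same TVM era, not the exact definitions: run_infer_type runs type inference on
# an expression, and ctx_list enumerates the (target, context) pairs to test on
# ("llvm" and "cuda" are assumed here, filtered to the ones that are enabled).
def run_infer_type(expr):
    mod = relay.Module.from_expr(expr)
    mod = relay.transform.InferType()(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body


def ctx_list():
    return [(tgt, tvm.context(tgt, 0)) for tgt in ("llvm", "cuda")
            if tvm.context(tgt, 0).exist]
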
def test_infer_type_leaky_relu():
    n, c, h, w = tvm.size_var("n"), tvm.size_var("c"), tvm.size_var("h"), tvm.size_var("w")
    x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
    y = relay.nn.leaky_relu(x, alpha=0.1)
    assert "alpha=0.1" in y.astext()
    yy = run_infer_type(y)
    assert yy.checked_type == relay.TensorType((n, c, h, w), "float32")

    shape = (1, 5, 10, 10)
    dtype = "float32"
    x = relay.var("x", relay.TensorType(shape, dtype))
    z = relay.nn.leaky_relu(x, alpha=0.1)
    assert "alpha=0.1" in z.astext()
    zz = run_infer_type(z)
    assert zz.checked_type == relay.TensorType(shape, dtype)
    func = relay.Function([x], z)
    x_data = np.random.uniform(low=-1, high=1, size=shape).astype(dtype)
    ref_res = np.where(x_data > 0, x_data, x_data * 0.1)

    for target, ctx in ctx_list():
        intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
        intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
        op_res1 = intrp1.evaluate(func)(x_data)
        tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5)
        op_res2 = intrp2.evaluate(func)(x_data)
        tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)

def test_with():
    n = tvm.size_var('n')
    m = tvm.size_var('m')
    l = tvm.size_var('l')
    A = tvm.placeholder((n, l), name='A')
    B = tvm.placeholder((m, l), name='B')
    with tvm.tag_scope(tag="gemm"):
        k = tvm.reduce_axis((0, l), name='k')
        C = tvm.compute((n, m), lambda i, j: tvm.sum(A[i, k] * B[j, k], axis=k),
                        attrs={"hello": 1, "arr": [10, 12]})

    assert C.op.tag == 'gemm'
    assert "hello" in C.op.attrs
    assert "xx" not in C.op.attrs
    assert C.op.attrs["hello"].value == 1
    CC = tvm.load_json(tvm.save_json(C))
    assert CC.op.attrs["hello"].value == 1
    assert CC.op.attrs["arr"][0].value == 10
    # str format happened to be json compatible
    assert json.loads(str(CC.op.attrs))["arr"][1] == 12

def test_tuple_with_different_deps():
    m = tvm.size_var('m')
    n = tvm.size_var('n')
    A0 = tvm.placeholder((m, n), name='A1')
    A1 = tvm.placeholder((m, n), name='A2')
    B0, B1 = tvm.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name='B')
    C = tvm.compute((m, n), lambda i, j: B0[i, j] + 4, name='C')

    s = tvm.create_schedule(C.op)
    xo, xi = s[C].split(C.op.axis[0], factor=10)
    s[B0.op].compute_at(s[C], xo)
    sch = s.normalize()
    bounds = tvm.schedule.InferBound(sch)
    stmt = tvm.schedule.ScheduleOps(sch, bounds)

    def get_B1_realize(x):
        if isinstance(x, tvm.stmt.Realize) and \
           x.func == B1.op and x.value_index == 1:
            ret.append(x)

    ret = []
    tvm.ir_pass.PostOrderVisit(stmt, get_B1_realize)

    assert stmt.node == C.op and len(ret) == 1

def test_rocm_cross_thread_reduction():
    # based on the reduction tutorial
    n = tvm.size_var("n")
    m = tvm.size_var("m")
    A = tvm.placeholder((n, m), name='A')
    k = tvm.reduce_axis((0, m), "k")
    B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name="B")
    s = tvm.create_schedule(B.op)
    ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
    BF = s.rfactor(B, ki)
    xo, xi = s[B].split(s[B].op.axis[0], factor=32)
    # thread axes to bind to (defined at module level in the original test file)
    bx = tvm.thread_axis("blockIdx.x")
    ty = tvm.thread_axis("threadIdx.y")
    tx = tvm.thread_axis("threadIdx.x")
    s[B].bind(xo, bx)
    s[B].bind(xi, ty)
    s[B].bind(s[B].op.reduce_axis[0], tx)
    s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
    s[B].set_store_predicate(tx.var.equal(0))
    frocm = tvm.build(s, [A, B], "rocm")

    nn = 128
    ctx = tvm.rocm(0)
    a = tvm.nd.array(np.random.uniform(size=(nn, nn)).astype(A.dtype), ctx)
    b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), ctx)
    frocm(a, b)
    tvm.testing.assert_allclose(b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-4)

def test_scan_group():
    m = tvm.size_var("m")
    n = tvm.size_var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    s_state = tvm.placeholder((m, n))
    s_init = tvm.compute((1, n), lambda _, i: x[0, i])

    s_update1 = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + x[t, i])
    s_update2 = tvm.compute((m, n), lambda t, i: s_update1[t, i] + 1)
    s_update3 = tvm.compute((m, n), lambda t, i: s_update2[t, i] + 1)
    res = tvm.scan(s_init, s_update3, s_state, inputs=x)

    s = tvm.create_schedule(res.op)
    assert s[s_update1].group is not None
    assert s[s_update2].group == s[s_update1].group
    # Assign within group, is valid
    s[s_update1].compute_at(s[s_update2], s_update2.op.axis[1])
    # create a new group, for [s_update2 and s_update1]
    g2 = s.create_group(outputs=s_update2, inputs=[s_state, x])
    assert g2.group is not None
    assert g2.group == s[s_update3].group
    assert s[s_update2].group == g2
    assert s[s_update1].group == g2
    g2.compute_at(s[s_update3], s_update3.op.axis[1])
    assert g2.attach_stage == s[s_update3]
    try:
        # compute outside group error.
        s[s_update2].compute_at(s[s_init], s_init.op.axis[0])
        assert False
    except tvm.TVMError:
        pass

def test_tensor_intrin_scalar_params():
    n = tvm.size_var("n")
    x = tvm.placeholder((n,), name='x')
    v = tvm.size_var("v")
    w = tvm.size_var("w")
    z = tvm.compute((n,), lambda i: x[i]*v + w, name='z')

    def intrin_func(ins, outs, sp):
        assert(isinstance(ins[0], tvm.schedule.Buffer))
        assert(ins[0].shape[0] == n)
        assert(sp[0] == v)
        assert(sp[1] == w)
        return tvm.call_packed("hw_func", ins[0].data, outs[0].data, sp[0], sp[1])

    with tvm.build_config(offset_factor=1):
        intrin = tvm.decl_tensor_intrin(z.op, intrin_func, scalar_params=[v, w])
    assert intrin.op == z.op
    assert intrin.reduce_init is None
    assert tuple(intrin.inputs) == tuple(z.op.input_tensors)
    assert(intrin.buffers[0].shape[0] == n)
    assert tuple(intrin.scalar_params) == tuple((v, w))

    A = tvm.placeholder((10, 10), name='A')
    # Pass scalar inputs to the TensorIntrin, interleaved with tensor inputs
    C = tvm.compute((10, 10), lambda i, j: intrin(i*i, A[i, j], i+j), name="C")
    s = tvm.create_schedule(C.op)
    stmt = tvm.lower(s, [A, C], simple_mode=True)
    assert isinstance(stmt.body.body.body, tvm.stmt.Evaluate)
    assert len(stmt.body.body.body.value.args) == 5
    assert str(stmt.body.body.body.value.args[3]) == "(i*i)"
    assert str(stmt.body.body.body.value.args[4]) == "(i + j)"

def test_schedule_create():
    m = tvm.size_var('m')
    n = tvm.size_var('n')
    l = tvm.size_var('l')
    A = tvm.placeholder((m, l), name='A')
    B = tvm.placeholder((n, l), name='B')
    AA = tvm.compute((m, l), lambda i, j: A[i, j])
    T = tvm.compute((m, n, l), lambda i, j, k: AA(i, k) * B(j, k))
    s = tvm.create_schedule(T.op)
    s[AA].set_scope("shared")
    xo, xi = s[T].split(T.op.axis[0], factor=10)
    xi1, xi2 = s[T].split(xi, factor=2)
    s[AA].compute_at(s[T], xi1)
    xo, xi = s[AA].split(AA.op.axis[0], factor=10)
    s[T].reorder(xi2, xi1)
    assert T.op.axis[1] in s[T].leaf_iter_vars

    # save load json
    json_str = tvm.save_json(s)
    s_loaded = tvm.load_json(json_str)
    assert isinstance(s_loaded, tvm.schedule.Schedule)
    assert(str(s_loaded.outputs[0].body) == str(s.outputs[0].body))

    # pickle unpickle
    dump = pkl.dumps(s)
    s_loaded = pkl.loads(dump)
    assert isinstance(s_loaded, tvm.schedule.Schedule)
    assert(str(s_loaded.outputs[0].body) == str(s.outputs[0].body))

def test_buffer_vload():
    m = tvm.size_var('m')
    n = tvm.size_var('n')
    Ab = tvm.decl_buffer((m, n), tvm.float32, elem_offset=100)
    load = Ab.vload([2, 3])
    offset = tvm.ir_pass.Simplify(load.index)
    assert tvm.ir_pass.Equal(offset, n * 2 + 103)

def test_tensor_comm_reducer():
    m = tvm.size_var('m')
    n = tvm.size_var('n')
    A = tvm.placeholder((m, n), name='A')
    k = tvm.reduce_axis((0, n), "k")
    mysum = tvm.comm_reducer(lambda x, y: x + y, lambda t: tvm.const(0, dtype=t))
    C = tvm.compute((m,), lambda i: mysum(A[i, k], axis=k))

def test_tensor_scan():
    m = tvm.size_var("m")
    n = tvm.size_var("n")
    x = tvm.placeholder((m, n))
    s = tvm.placeholder((m, n))
    res = tvm.scan(tvm.compute((1, n), lambda _, i: x[0, i]),
                   tvm.compute((m, n), lambda t, i: s[t - 1, i] + x[t, i]),
                   s)
    assert tuple(res.shape) == (m, n)

def test_expand_dims_infer_type():
    for dtype in ['float16', 'float32']:
        n, t, d = tvm.size_var("n"), tvm.size_var("t"), 100
        x = relay.var("x", shape=(n, t, d), dtype=dtype)
        y = relay.expand_dims(x, axis=2)
        assert "axis=2" in y.astext()
        yy = run_infer_type(y)
        assert yy.checked_type == relay.TensorType((n, t, 1, 100), dtype)

def test_bitserial_dense():
    m, k = tvm.size_var("m"), tvm.size_var("k")
    x = relay.var("x", relay.TensorType((m, k), "int16"))
    w = relay.var("w", relay.TensorType((k, 32), "int16"))
    y = relay.nn.bitserial_dense(x, w, units=32)
    assert "units=32" in y.astext()
    yy = run_infer_type(y)
    assert yy.checked_type == relay.TensorType((m, 32), "int16")

def test_tensor_reduce_multi_axis():
    m = tvm.size_var('m')
    n = tvm.size_var('n')
    A = tvm.placeholder((m, n), name='A')
    k1 = tvm.reduce_axis((0, n), "k")
    k2 = tvm.reduce_axis((0, m), "k")
    C = tvm.compute((1,), lambda _: tvm.sum(A[k1, k2], axis=(k1, k2)))
    C = tvm.compute((1,), lambda _: tvm.sum(A[k1, k2], axis=[k1, k2]))

def test_verify_compute():
    n = tvm.size_var("n")
    m = tvm.size_var("m")
    A = tvm.placeholder((n, m), name='A')
    k = tvm.reduce_axis((0, m), "k")
    k_ = tvm.reduce_axis((0, m - 1), "k_")
    f1 = lambda i: tvm.sum(A[i, k], axis=k)
    f2 = lambda i: A[i, 0] + 1
    f3 = lambda i: tvm.sum(A[i, k], axis=k) + 1
    f4 = lambda i: A[i, 0] * (tvm.sum(A[i, k], axis=k) + 1)
    f5 = lambda i: (tvm.sum(A[i, k], axis=k), A[i, 0] + 1)
    f6 = lambda i: (tvm.sum(A[i, k], axis=k), tvm.sum(A[i, k_], axis=k_))

    # Valid compute
    try:
        B = tvm.compute((n,), f1, name="B")
    except tvm._ffi.base.TVMError as ex:
        assert False

    # Valid compute
    try:
        B = tvm.compute((n,), f2, name="B")
    except tvm._ffi.base.TVMError as ex:
        assert False

    # Invalid compute with non top level reduction
    try:
        B = tvm.compute((n,), f3, name="B")
        assert False
    except tvm._ffi.base.TVMError as ex:
        pass

    # Invalid compute with non top level reduction
    try:
        B = tvm.compute((n,), f4, name="B")
        assert False
    except tvm._ffi.base.TVMError as ex:
        pass

    # Invalid compute with reduction and non-reduction batch ops
    try:
        B0, B1 = tvm.compute((n,), f5, name="B")
        assert False
    except tvm._ffi.base.TVMError as ex:
        pass

    # Invalid compute with unequal batch reduction ops
    try:
        B0, B1 = tvm.compute((n,), f6, name="B")
        assert False
    except tvm._ffi.base.TVMError as ex:
        pass

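# test_outer_product below inspects the lowered body of a hybrid-script kernel named
# outer_product that is not defined in this section. The sketch here is reconstructed
# from the assertions in the test (loop variables 'i'/'j', the "index out of range!"
# assert, and c[i, j] = a[i] * b[j]); the original definition lives elsewhere in the
# test suite and may differ in details such as the docstring.
@tvm.hybrid.script
def outer_product(n, m, a, b):
    c = output_tensor((n, m), a.dtype)  # hybrid-script intrinsic allocating the result
    for i in range(n):
        for j in range(m):
            assert i < n and j < m, "index out of range!"
            c[i, j] = a[i] * b[j]
    return c
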
def test_outer_product():
    n = tvm.size_var('n')
    m = tvm.size_var('m')
    a = tvm.placeholder((n,), name='a')
    b = tvm.placeholder((m,), name='b')

    try:
        c = outer_product(n, m, a, b)
        ir = c.op.body
    except IOError as err:
        assert sys.version_info[0] == 2 and str(err) == 'could not get source code'
        return

    # Check for i in (0, n)
    assert isinstance(ir, tvm.stmt.For)
    assert ir.loop_var.name == 'i'
    assert ir.min.value == 0
    assert ir.extent.name == 'n'
    ibody = ir.body
    assert isinstance(ibody, tvm.stmt.For)
    # Check for j in (0, m)
    assert ibody.loop_var.name == 'j'
    assert ibody.min.value == 0
    assert ibody.extent.name == 'm'
    # Check loop body
    jblock = ibody.body
    assert isinstance(jblock, tvm.stmt.SeqStmt)
    jbody = jblock[0]
    assert isinstance(jbody, tvm.stmt.AssertStmt)
    assert isinstance(jbody.message, tvm.expr.StringImm)
    assert jbody.message.value == "index out of range!"
    jbody = jblock[1]
    assert isinstance(jbody, tvm.stmt.Provide)
    assert jbody.func.name == 'c'
    assert len(jbody.args) == 2
    assert jbody.args[0].name == 'i'
    assert jbody.args[1].name == 'j'
    assert isinstance(jbody.value, tvm.expr.Mul)
    mul = jbody.value
    assert isinstance(mul.a, tvm.expr.Call)
    assert mul.a.name == 'a'
    assert mul.b.name == 'b'

    func, ins, outs = run_and_check(outer_product, [n, m, a, b], {n: 99, m: 101})
    temp = util.tempdir()
    path = temp.relpath('%s.py' % func.name)
    func.save(path)
    func_ = tvm.hybrid.HybridModule()
    func_.load(path)
    run_and_check(func_, ins, {n: 99, m: 101}, outs=outs)

    for key, _ in HYBRID_GLOBALS.items():
        assert key not in globals().keys()
        assert key not in outer_product.__globals__.keys()

def test_tile():
    m = tvm.size_var('m')
    n = tvm.size_var('n')
    A = tvm.placeholder((m, n), name='A')
    T = tvm.compute((m, n), lambda i, j: A[i, j])

    s = tvm.create_schedule(T.op)
    xo, yo, xi, yi = s[T].tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5)
    assert tuple(s[T].leaf_iter_vars) == (xo, yo, xi, yi)

def test_dropout():
    for dtype in ['float16', 'float32']:
        n, t, d = tvm.size_var("n"), tvm.size_var("t"), tvm.size_var("d")
        input_ty = relay.TensorType((n, t, d), dtype)
        x = relay.var("x", input_ty)
        y = relay.nn.dropout(x, rate=0.75)
        assert "rate=" in y.astext()
        yy = run_infer_type(y)
        assert yy.checked_type == input_ty

def test_buffer():
    m = tvm.size_var('m')
    n = tvm.size_var('n')
    l = tvm.size_var('l')
    Ab = tvm.decl_buffer((m, n), tvm.float32)
    Bb = tvm.decl_buffer((n, l), tvm.float32)

    assert isinstance(Ab, tvm.tir.Buffer)
    assert Ab.dtype == tvm.float32
    assert tuple(Ab.shape) == (m, n)

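# _compute_binary_scalar_logic below looks up the element-wise expression in
# _bin_scalar_logic_op_map, which is defined elsewhere. The table sketched here is
# hypothetical (the real key names and entries may differ); it only illustrates the
# expected shape of the map: op name -> lambda taking (tensor, scalar, *indices).
_bin_scalar_logic_op_map = {
    'equal':   lambda a, b, *idx: a[idx] == b,
    'greater': lambda a, b, *idx: a[idx] > b,
    'less':    lambda a, b, *idx: a[idx] < b,
}
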
def _compute_binary_scalar_logic(op, dtype, ndim):
    a = tvm.placeholder([tvm.size_var() for _ in range(ndim)], name='a', dtype=dtype)
    b = tvm.var('b', dtype='float64')
    c = tvm.compute([tvm.size_var() for _ in range(ndim)],
                    lambda *idx: _bin_scalar_logic_op_map[op](a, b, *idx), name='c')
    s = tvm.create_schedule(c.op)
    return s, a, b, c

def test_buffer_access_ptr():
    m = tvm.size_var('m')
    n = tvm.size_var('n')
    Ab = tvm.decl_buffer((m, n), tvm.float32, strides=[n + 1, 1])
    aptr = Ab.access_ptr("rw")
    assert tvm.ir_pass.Equal(aptr.args[3], Ab.strides[0] * m)
    assert aptr.args[0].dtype == Ab.dtype
    assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
    aptr = Ab.access_ptr("w")
    assert aptr.args[4].value == Buffer.WRITE

def test_buffer_access_ptr_extent():
    m = tvm.size_var('m')
    n = tvm.size_var('n')
    Ab = tvm.decl_buffer((m, n), tvm.float32)
    aptr = Ab.access_ptr("rw")
    assert tvm.ir_pass.Equal(aptr.args[3], m * n)
    aptr = Ab.access_ptr("rw", offset=100)
    assert tvm.ir_pass.Equal(aptr.args[3], m * n - 100)
    Ab = tvm.decl_buffer((m, n), tvm.float32, strides=[n + 1, 1])
    aptr = Ab.access_ptr("rw", offset=100)
    assert tvm.ir_pass.Equal(aptr.args[3], Ab.strides[0] * m - 100)

def test_fuse():
    m = tvm.size_var('m')
    n = tvm.size_var('n')
    A = tvm.placeholder((m, n), name='A')
    T = tvm.compute((m, n), lambda i, j: A[i, j])

    s = tvm.create_schedule(T.op)
    xo, yo, xi, yi = s[T].tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5)
    fused = s[T].fuse(xo, yo)
    assert any(isinstance(x, tvm.schedule.Fuse) for x in s[T].relations)
    assert tuple(s[T].leaf_iter_vars) == (fused, xi, yi)

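# test_infer_type_prelu below relies on a verify_infer_type_prelu helper defined
# elsewhere in the test suite. A plausible sketch is given here, assuming it only
# checks the inferred type (the real helper may also run the op and compare against
# a NumPy reference).
def verify_infer_type_prelu(data, alpha, axis, output, dtype="float32"):
    x = relay.var("data", relay.TensorType(data, dtype))
    if alpha:
        y = relay.var("alpha", relay.TensorType(alpha, dtype))
    else:
        y = relay.var("alpha")  # leave the alpha shape to type inference
    z = relay.nn.prelu(x, y, axis=axis)
    zz = run_infer_type(z)
    assert zz.checked_type == relay.TensorType(output, dtype)
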
def test_infer_type_prelu():
    n, c, h, w = tvm.size_var("n"), tvm.size_var("c"), tvm.size_var("h"), tvm.size_var("w")
    verify_infer_type_prelu((n, c, h, w), (c,), 1, (n, c, h, w))
    verify_infer_type_prelu((n, h, w, c), (c,), 3, (n, h, w, c))
    verify_infer_type_prelu((n, c, h, w), None, 1, (n, c, h, w))
    verify_infer_type_prelu((n, h, w, c), None, 3, (n, h, w, c))
    verify_infer_type_prelu((1, 3, 2, 2), (3,), 1, (1, 3, 2, 2))
    verify_infer_type_prelu((1, 2, 2, 3), (3,), 3, (1, 2, 2, 3))
    verify_infer_type_prelu((1, 3, 2, 2), None, 1, (1, 3, 2, 2))
    verify_infer_type_prelu((1, 2, 2, 3), None, 3, (1, 2, 2, 3))

def test_concatenate():
    for dtype in ['float16', 'float32']:
        n, t, d = tvm.size_var("n"), tvm.size_var("t"), 100
        x = relay.var("x", shape=(n, t, d))
        y = relay.var("y", shape=(n, t, d))
        z = relay.concatenate((x, y), axis=-1)
        assert "axis=" in z.astext()
        zz = run_infer_type(z)
        assert zz.checked_type == relay.TensorType((n, t, 200))

        x = relay.exp(x)
        z = relay.concatenate((x, y), axis=2)
        zz = run_infer_type(z)
        assert zz.checked_type == relay.TensorType((n, t, 200))

        z = relay.concatenate((x, y), axis=1)
        zz = run_infer_type(z)
        assert zz.checked_type == relay.TensorType((n, t + t, 100))

        # check shape mismatches (the following case is expected to raise tvm._ffi.base.TVMError)
        try:
            x = relay.var('p1', shape=(2, 5))
            y = relay.var('p2', shape=(2, 3))
            c = relay.concatenate([x, y], axis=0)
            func = relay.Function([x, y], c)
            zz = run_infer_type(func)
        except tvm._ffi.base.TVMError:
            pass
        else:
            assert False

        x = relay.var("x", shape=(10, 5), dtype=dtype)
        y = relay.var("y", shape=(10, 5), dtype=dtype)
        t = relay.var("z", shape=(), dtype=dtype)
        z = relay.concatenate((x, y), axis=1)
        z = relay.add(z, t)
        # Check result.
        func = relay.Function([x, y, t], z)
        x_data = np.random.rand(10, 5).astype(dtype)
        y_data = np.random.rand(10, 5).astype(dtype)
        t_data = np.random.uniform(size=()).astype(dtype)
        ref_res = np.concatenate((x_data, y_data), axis=1) + t_data

        for target, ctx in ctx_list():
            if dtype == 'float16' and target == 'cuda' and not have_fp16(tvm.gpu(0).compute_version):
                continue
            intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
            intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
            op_res1 = intrp1.evaluate(func)(x_data, y_data, t_data)
            tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=0.01)
            op_res2 = intrp2.evaluate(func)(x_data, y_data, t_data)
            tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=0.01)

def compute_add(dtype, ndim):
    A = tvm.placeholder([tvm.size_var() for _ in range(ndim)], name='A', dtype=dtype)
    B = tvm.placeholder([tvm.size_var() for _ in range(ndim)], name='B', dtype=dtype)
    C = tvm.compute([tvm.size_var() for _ in range(ndim)],
                    lambda *index: A[index] + B[index], name='C')
    s = tvm.create_schedule(C.op)
    return s, A, B, C

def test_flatten_prefetch():
    A = tvm.placeholder((25, 100, 4), name='A')
    _A = tvm.decl_buffer(A.shape, A.dtype, name='A')
    i = tvm.size_var('i')
    j = tvm.size_var('j')
    region = [tvm.ir.Range.make_by_min_extent(x[0], x[1])
              for x in [(i, 2), (j, 8), (0, 4)]]
    stmt = tvm.tir.Prefetch(A.op, 0, A.dtype, region)
    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: _A}, 64)
    stmt = tvm.ir_pass.Simplify(stmt)
    assert stmt.extent.value == 2
    assert isinstance(stmt.body, tvm.tir.For)
    assert stmt.body.extent.value == 2

def test_transpose_infer_type():
    n, t, d = tvm.size_var("n"), tvm.size_var("t"), 100
    x = relay.var("x", relay.TensorType((n, t, d), "float32"))
    y = relay.transpose(x, axes=(1, 0, 2))
    assert "axes=" in y.astext()
    yy = run_infer_type(y)
    assert yy.checked_type == relay.TensorType((t, n, 100), "float32")

    y = relay.transpose(x)
    assert "axes=" in y.astext()
    yy = run_infer_type(y)
    assert yy.checked_type == relay.TensorType((100, t, n), "float32")

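# test_batch_matmul below delegates its numeric checks to verify_batch_matmul, which
# is defined elsewhere. A hedged sketch: relay.nn.batch_matmul computes x @ y^T per
# batch, so NumPy's matmul against the transposed second operand serves as the
# reference here; the real helper may use a different reference implementation.
def verify_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"):
    x = relay.var("x", relay.TensorType(x_shape, dtype))
    y = relay.var("y", relay.TensorType(y_shape, dtype))
    z = relay.nn.batch_matmul(x, y)
    zz = run_infer_type(z)
    assert zz.checked_type == relay.TensorType(out_shape, dtype)

    func = relay.Function([x, y], z)
    x_np = np.random.uniform(size=x_shape).astype(dtype)
    y_np = np.random.uniform(size=y_shape).astype(dtype)
    ref = np.matmul(x_np, y_np.transpose(0, 2, 1))
    for target, ctx in ctx_list():
        intrp = relay.create_executor("graph", ctx=ctx, target=target)
        op_res = intrp.evaluate(func)(x_np, y_np)
        tvm.testing.assert_allclose(op_res.asnumpy(), ref, rtol=1e-5)
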
def test_batch_matmul():
    b, m, n, k = tvm.size_var("b"), tvm.size_var("m"), tvm.size_var("n"), tvm.size_var("k")
    x = relay.var("x", relay.TensorType((b, m, k), "float32"))
    y = relay.var("y", relay.TensorType((b, n, k), "float32"))
    z = relay.nn.batch_matmul(x, y)
    zz = run_infer_type(z)
    assert zz.checked_type == relay.TensorType((b, m, n), "float32")

    verify_batch_matmul((1, 16, 32), (1, 16, 32), (1, 16, 16))
    verify_batch_matmul((5, 16, 32), (5, 16, 32), (5, 16, 16))
    verify_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20))
    verify_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20))