def get_gemm_feature(target):
    """Build a randomly scheduled N x N matmul and return its flattened features.

    The computation is ``C[y, x] = sum_k A[y, k] * B[k, x]`` where ``N`` is
    taken from the enclosing module. The output is tiled with factors 8 x 8,
    and the tile axes plus the reduction axis are reordered by a random
    permutation drawn from ``np.random``, so results vary with the global
    NumPy random state.

    Parameters
    ----------
    target
        Object that exposes a ``keys`` collection and can be used as a
        context manager (presumably a TVM target -- confirm against callers).
        When ``"gpu"`` is among its keys, the first three non-reduction axes
        of the shuffled order are bound to ``blockIdx.x``, ``vthread`` and
        ``threadIdx.y`` respectively.

    Returns
    -------
    The result of ``feature.flatten_itervar_feature`` applied to the
    itervar features extracted from the schedule under ``target``.
    """
    reduce_k = tvm.reduce_axis((0, N), 'k')
    lhs = tvm.placeholder((N, N), name='A')
    rhs = tvm.placeholder((N, N), name='B')
    # NOTE(review): the lambda parameter names are kept as ``y, x`` because
    # tvm.compute may derive the axis names from them -- confirm before renaming.
    out = tvm.compute(
        lhs.shape,
        lambda y, x: tvm.sum(lhs[y, reduce_k] * rhs[reduce_k, x], axis=reduce_k),
        name='C',
    )

    sched = tvm.create_schedule(out.op)
    stage = sched[out]

    row_axis, col_axis = stage.op.axis
    # Tile axes first, reduction axis last: it therefore sits at index 4
    # (assuming ``tile`` yields four axes, as the 5-way permutation implies).
    candidates = list(stage.tile(row_axis, col_axis, 8, 8))
    candidates.append(reduce_k)

    order = np.random.permutation(5)
    shuffled = [candidates[idx] for idx in order]
    stage.reorder(*shuffled)

    if "gpu" in target.keys:
        # Drop the reduction axis; only the remaining axes get thread bindings.
        spatial = [
            axis for idx, axis in zip(order, shuffled) if idx != 4
        ]
        stage.bind(spatial[0], tvm.thread_axis("blockIdx.x"))
        stage.bind(spatial[1], tvm.thread_axis("vthread"))
        stage.bind(spatial[2], tvm.thread_axis("threadIdx.y"))

    with target:
        raw_feas = feature.get_itervar_feature(sched, [lhs, rhs, out])
        feas = feature.flatten_itervar_feature(raw_feas)
    return feas