Exemplo n.º 1
0
    def get_gemm_feature(target):
        """Return flattened itervar features of a randomly reordered GEMM schedule.

        Declares an (N, N) by (N, N) matrix multiply, tiles its two output
        axes with factors 8 and 8, shuffles the loop axes with
        ``np.random.permutation`` and, when ``"gpu"`` is in ``target.keys``,
        binds the first three non-reduction axes to thread axes.

        Args:
            target: object entered as a context manager around feature
                extraction; its ``keys`` attribute is checked for ``"gpu"``.

        Returns:
            Whatever ``feature.flatten_itervar_feature`` produces for the
            schedule built here.
        """
        k = tvm.reduce_axis((0, N), 'k')
        lhs = tvm.placeholder((N, N), name='A')
        rhs = tvm.placeholder((N, N), name='B')
        # The lambda's parameter names are left untouched on purpose, in case
        # the framework derives axis names from them.
        out = tvm.compute(
            lhs.shape,
            lambda y, x: tvm.sum(lhs[y, k] * rhs[k, x], axis=k),
            name='C',
        )

        sched = tvm.create_schedule(out.op)
        stage = sched[out]

        row_axis, col_axis = stage.op.axis
        # Tile axes first, reduction axis appended last.
        loop_axes = [*stage.tile(row_axis, col_axis, 8, 8), k]
        order = np.random.permutation(5)
        shuffled = [loop_axes[idx] for idx in order]
        stage.reorder(*shuffled)

        if "gpu" in target.keys:
            # Original index 4 is the reduction axis; it is never thread-bound.
            spatial = [ax for idx, ax in zip(order, shuffled) if idx != 4]
            thread_tags = ("blockIdx.x", "vthread", "threadIdx.y")
            for slot, tag in enumerate(thread_tags):
                stage.bind(spatial[slot], tvm.thread_axis(tag))

        with target:
            raw = feature.get_itervar_feature(sched, [lhs, rhs, out])
            flat = feature.flatten_itervar_feature(raw)
        return flat
Exemplo n.º 2
0
    def get_gemm_feature(target):
        """Build a shuffled, tiled GEMM schedule and extract its itervar features.

        The computation is ``C[y, x] = sum_k A[y, k] * B[k, x]`` over (N, N)
        placeholders. The output axes are tiled with factors 8 and 8, the
        loop axes (tile axes plus the reduction axis) are reordered by a
        random permutation, and for targets whose ``keys`` contain ``"gpu"``
        three non-reduction axes are bound to thread axes.

        Args:
            target: entered with ``with target:`` while features are
                extracted; must expose a ``keys`` attribute.

        Returns:
            The result of ``feature.flatten_itervar_feature`` applied to the
            output of ``feature.get_itervar_feature``.
        """
        red = tvm.reduce_axis((0, N), 'k')
        mat_a = tvm.placeholder((N, N), name='A')
        mat_b = tvm.placeholder((N, N), name='B')
        # Keep the lambda parameters named ``y, x`` exactly as before.
        mat_c = tvm.compute(mat_a.shape,
                            lambda y, x: tvm.sum(mat_a[y, red] * mat_b[red, x],
                                                 axis=red),
                            name='C')

        s = tvm.create_schedule(mat_c.op)

        outer_dim, inner_dim = s[mat_c].op.axis
        candidates = list(s[mat_c].tile(outer_dim, inner_dim, 8, 8))
        candidates.append(red)
        perm = np.random.permutation(5)
        reordered = [candidates[p] for p in perm]
        s[mat_c].reorder(*reordered)

        if "gpu" in target.keys:
            # Drop the reduction axis (original position 4) before binding.
            bindable = [reordered[pos] for pos, src in enumerate(perm)
                        if src != 4]
            first, second, third = bindable[0], bindable[1], bindable[2]
            s[mat_c].bind(first, tvm.thread_axis("blockIdx.x"))
            s[mat_c].bind(second, tvm.thread_axis("vthread"))
            s[mat_c].bind(third, tvm.thread_axis("threadIdx.y"))

        with target:
            itervar_feas = feature.get_itervar_feature(s, [mat_a, mat_b, mat_c])
            result = feature.flatten_itervar_feature(itervar_feas)
        return result