def strided_head(H, s_h, s_w):
    """Spread ``H`` over a strided grid, filling the gaps with zeros.

    ``H`` must be 5-D (its shape is unpacked into five extents). Axes 2 and 3
    grow from ``h`` / ``w`` to ``(h - 1) * s_h + 1`` / ``(w - 1) * s_w + 1``;
    the remaining axes keep their extent. An output element whose axis-2 and
    axis-3 indices are both multiples of their stride takes the value
    ``H[i0, i1, i2 // s_h, i3 // s_w, i4]``; every other element is zero.

    Args:
        H: 5-D input tensor.
        s_h: stride applied along axis 2.
        s_w: stride applied along axis 3.

    Returns:
        The computed tensor, named ``H.name + "_strided"``.
    """
    n, c1, h, w, c0 = H.shape
    out_shape = (n, c1, (h - 1) * s_h + 1, (w - 1) * s_w + 1, c0)

    # NOTE(review): parameter names are kept as i0..i4 because tvm.compute is
    # understood to name its iteration axes after them -- confirm before renaming.
    def _strided_value(i0, i1, i2, i3, i4):
        # Non-zero remainder on either axis means the position lies between
        # two source samples.
        off_grid = akg.tvm.any(truncmod(i2, s_h) != 0, truncmod(i3, s_w) != 0)
        # The fill value follows H's dtype instead of a hard-coded "float16",
        # so both Select branches agree for any input dtype (unchanged for
        # float16 inputs).
        zero = akg.tvm.const(0.0, dtype=H.dtype)
        source = H[i0, i1, floordiv(i2, s_h), floordiv(i3, s_w), i4]
        return akg.tvm.expr.Select(off_grid, zero, source)

    return akg.tvm.compute(out_shape, _strided_value, name=H.name + "_strided")
def transpose_convert_head(Head):
    """Regroup ``Head`` into a 4-D tensor via reshape -> transpose -> reshape.

    Reads the first four extents of ``Head.shape`` (``d0..d3``, each through
    ``.value``) together with ``block_size`` from the enclosing scope.

    Args:
        Head: input tensor; the intermediate reshape to
            ``(d0 // block_size, block_size, d1, d2, d3, block_size)`` has to
            be valid for it.

    Returns:
        Tensor of shape
        ``(floordiv(d0, block_size) * d2 * d3, d1, block_size, block_size)``.
    """
    d0, d1, d2, d3 = (Head.shape[axis].value for axis in range(4))
    outer = floordiv(d0, block_size)

    split_shape = (outer, block_size, d1, d2, d3, block_size)
    final_shape = (outer * d2 * d3, d1, block_size, block_size)

    # Axis permutation, in the layout names used by the original comment:
    #   (N//block_size_N, block_size_N, C//block_size_C, H, W, block_size_C)
    #   -> (N//block_size_N, H, W, C//block_size_C, block_size_C, block_size_N)
    split = akg.topi.reshape(Head, split_shape)
    permuted = akg.topi.transpose(split, (0, 3, 4, 2, 5, 1))
    return akg.topi.reshape(permuted, final_shape)
def flip_weight(B, k_c, k_hw, const_shift):
    """Build a re-indexed copy of ``B`` with reversed order inside each ``k_hw`` group.

    The result has shape
    ``(B.shape[1] * k_hw, k_c // block_size, block_size, block_size)``, with
    ``block_size`` taken from the enclosing scope. Output element
    ``(i0, i1, i2, i3)`` reads
    ``B[i1 * k_hw + const_shift - i0 % k_hw, i0 // k_hw, i3, i2]``, so the
    last two axes are swapped as well.

    Args:
        B: 4-D input tensor.
        k_c: extent that is divided by ``block_size`` to give output axis 1.
        k_hw: group length used to split output axis 0.
        const_shift: constant offset added to the first source index.

    Returns:
        The computed tensor, named ``B.name + "_flipped"``.
    """
    flipped_shape = (
        B.shape[1].value * k_hw,
        k_c // block_size,
        block_size,
        block_size,
    )

    # NOTE(review): parameter names are kept as i0..i3 because tvm.compute is
    # understood to name its iteration axes after them -- confirm before renaming.
    def _flipped_value(i0, i1, i2, i3):
        # Subtracting the in-group position walks each k_hw group backwards.
        src_first = i1 * k_hw + const_shift - truncmod(i0, k_hw)
        src_second = floordiv(i0, k_hw)
        return B[src_first, src_second, i3, i2]

    return akg.tvm.compute(flipped_shape, _flipped_value, name=B.name + "_flipped")
def transpose_data(A):
    """Build a copy of ``A`` with its first two axes exchanged at block granularity.

    The result has shape
    ``(A.shape[1] * block_size, A.shape[0] // block_size, A.shape[2],
    A.shape[3], block_size)``, with ``block_size`` taken from the enclosing
    scope. Output element ``(j0, j1, j2, j3, j4)`` reads
    ``A[j1 * block_size + j4, j0 // block_size, j2, j3, j0 % block_size]``.

    Args:
        A: 5-D input tensor.

    Returns:
        The computed tensor, named ``A.name + "_transposed"``.
    """
    dim0, dim1, dim2, dim3 = (A.shape[axis].value for axis in range(4))
    transposed_shape = (dim1 * block_size, dim0 // block_size, dim2, dim3, block_size)

    # NOTE(review): parameter names are kept as j0..j4 because tvm.compute is
    # understood to name its iteration axes after them -- confirm before renaming.
    def _transposed_value(j0, j1, j2, j3, j4):
        src_first = j1 * block_size + j4
        src_second = floordiv(j0, block_size)
        src_last = truncmod(j0, block_size)
        return A[src_first, src_second, j2, j3, src_last]

    return akg.tvm.compute(transposed_shape, _transposed_value, name=A.name + "_transposed")