コード例 #1
0
def matmul_mul(x,
               y,
               c,
               b,
               out_dtype,
               left_format="zZ",
               right_format="nZ",
               out_format="zN",
               transpose_x=False,
               transpose_y=False,
               attrs=None,
               target="cce"):
    """Build ``matmul(x, y)`` followed by ``mul`` with ``c``.

    Args:
        x: left matmul operand.
        y: right matmul operand.
        c: second operand given to ``mul`` together with the matmul result.
        b: forwarded to ``matmul`` as its ``b`` argument (presumably a bias;
            ``None`` is passed elsewhere in this file -- confirm semantics).
        out_dtype: output dtype requested from ``matmul``.
        left_format, right_format, out_format: layout names forwarded to
            ``matmul``.
        transpose_x, transpose_y: transpose flags forwarded to ``matmul``.
        attrs: attribute dict forwarded to ``matmul``.
        target: backend target passed to ``mul``.

    Returns:
        tuple: ``(res, attrs)`` where ``res`` is ``mul(matmul_res, c)`` and
        ``attrs`` is the attribute object returned by ``matmul``.
    """
    matmul_res, attrs = matmul(x,
                               y,
                               b,
                               out_dtype,
                               left_format,
                               right_format,
                               out_format,
                               transpose_x,
                               transpose_y,
                               attrs=attrs)
    res = mul(matmul_res, c, target=target)
    return res, attrs
コード例 #2
0
def matmul_addn(x, y, adds, b, out_dtype, left_format="zZ", right_format="nZ", out_format="zN", transpose_x=False, transpose_y=False,
                attrs=None, target='cce'):
    """Build ``Add(matmul(x, y), Addn(adds))``.

    Args:
        x: left matmul operand.
        y: right matmul operand.
        adds: collection of tensors reduced by ``Addn``.
        b: forwarded to ``matmul`` as its ``b`` argument (presumably a bias).
        out_dtype: output dtype requested from ``matmul``.
        left_format, right_format, out_format: layout names forwarded to
            ``matmul``.
        transpose_x, transpose_y: transpose flags forwarded to ``matmul``.
        attrs: attribute dict forwarded to ``matmul``.
        target: backend target passed to ``Addn`` and ``Add``.

    Returns:
        tuple: ``(res, attrs)`` where ``attrs`` is the attribute object
        returned by ``matmul``.
    """
    # Forward the caller's attrs, as matmul_addn_transdata does, instead of
    # replacing them with None and silently ignoring the parameter.
    matmul_res, attrs = matmul(x, y, b, out_dtype, left_format, right_format, out_format, transpose_x, transpose_y,
                               attrs=attrs)
    addn_res = Addn(adds, target=target)
    res = Add(matmul_res, addn_res, target=target)
    return res, attrs
コード例 #3
0
def batch_matmul(x1,
                 x2,
                 transpose_a=False,
                 transpose_b=False,
                 target=utils.CCE):
    """Batched matmul using the cube version of ``math.matmul``.

    Both operands and the result use the "zN" layout, no ``b`` is supplied,
    and the output dtype is taken from ``x1``.

    Args:
        x1: left operand.
        x2: right operand.
        transpose_a: forwarded as ``transpose_x``.
        transpose_b: forwarded as ``transpose_y``.
        target: backend target forwarded to ``math.matmul``.

    Returns:
        Whatever ``math.matmul`` returns.
    """
    layout = "zN"
    return math.matmul(x=x1, y=x2, b=None, out_dtype=x1.dtype,
                       left_format=layout, right_format=layout, out_format=layout,
                       transpose_x=transpose_a, transpose_y=transpose_b,
                       target=target)
コード例 #4
0
def mat_mul(x1,
            x2,
            out_dtype,
            transpose_a=False,
            transpose_b=False,
            target=utils.CCE):
    """MatMul via ``math.matmul`` with every operand in the "zN" layout.

    Args:
        x1: left operand.
        x2: right operand.
        out_dtype: output dtype forwarded to ``math.matmul``.
        transpose_a: forwarded as ``transpose_x``.
        transpose_b: forwarded as ``transpose_y``.
        target: backend target forwarded to ``math.matmul``.

    Returns:
        Whatever ``math.matmul`` returns.
    """
    formats = {"left_format": "zN", "right_format": "zN", "out_format": "zN"}
    return math.matmul(x=x1,
                       y=x2,
                       b=None,
                       out_dtype=out_dtype,
                       transpose_x=transpose_a,
                       transpose_y=transpose_b,
                       target=target,
                       **formats)
コード例 #5
0
ファイル: matmul4d_ad.py プロジェクト: mindspore-ai/akg
def matmul4d_ad(head, x, y, b, out_dtype, adj_x=False, adj_y=False):
    """Build the gradient of a fractal-layout matmul with respect to ``x``.

    A placeholder with ``x``'s dtype stands in for ``x`` in a forward
    ``matmul`` (formats "zN"/"nZ"/"zN", ``transpose_x=False``).
    ``akg.differentiate`` then produces the adjoint of that output w.r.t. the
    placeholder, seeded with ``head``. The adjoint is re-indexed back into
    ``x``'s shape.

    Args:
        head: incoming gradient passed to ``akg.differentiate``.
        x: left operand. ``get_shape(x)`` must have 5 dims (the ``grad``
            lambdas take 5 indices); when ``adj_x`` is False they are unpacked
            as ``(batch_num, m_o, k_o, m_i, k_i)``.
        y: right operand of the forward matmul.
        b: forwarded to the forward ``matmul`` as its ``b`` argument.
        out_dtype: output dtype of the forward matmul.
        adj_x: selects the ``x`` layout handling described in the body.
        adj_y: forwarded as ``transpose_y`` to the forward matmul.

    Returns:
        tuple: ``(grad, attrs)`` where ``grad`` is the ``akg.tvm.compute``
        tensor of shape ``get_shape(x)`` and ``attrs`` holds the
        ``pragma_*`` build flags set below.
    """
    shape_xx = get_shape(x)

    if adj_x:
        # The forward placeholder takes x's shape unchanged.
        shape_xx_forward = shape_xx
    else:
        # Swap both the outer (m_o, k_o) and inner (m_i, k_i) pairs so the
        # forward kernel can always be built with transpose_x=False.
        batch_num, m_o, k_o, m_i, k_i = shape_xx
        shape_xx_forward = (batch_num, k_o, m_o, k_i, m_i)

    # ---- forward kernel ----
    x_temp = akg.tvm.placeholder(shape_xx_forward,
                                 name="input_1",
                                 dtype=x.dtype)

    # Every case is reduced to the adj_x=False form; only the tensor is kept.
    out = matmul(x_temp, y, b, out_dtype, "zN", "nZ", "zN", False, adj_y)[0]

    # ---- backward kernel ----
    jacs = list(akg.differentiate(out, [x_temp], head))

    if adj_x:
        # NOTE(review): the jacobian already has x's shape here, yet the last
        # two axes are swapped. The index stays in bounds only if the two inner
        # dims are equal. Confirm this transpose is intended.
        grad = akg.tvm.compute(
            shape_xx, lambda n, ko, mo, ki, mi: jacs[0][n, ko, mo, mi, ki])
    else:
        # Undo the layout swap applied to the forward placeholder.
        grad = akg.tvm.compute(
            shape_xx, lambda n, mo, ko, mi, ki: jacs[0][n, ko, mo, mi, ki])

    attrs = {
        "pragma_data_transpose": "Y",
        "pragma_data_transpose_block": "Y",
    }
    if not adj_y:
        attrs["pragma_weight_transpose"] = "Y"

    return grad, attrs
コード例 #6
0
def transdata_matmul(x,
                     y,
                     b,
                     out_dtype,
                     left_format="zZ",
                     right_format="nZ",
                     out_format="zN",
                     transpose_x=False,
                     transpose_y=False,
                     attrs=None,
                     target="cce"):
    """Convert ``x`` and ``y`` to FRACTAL_NZ with ``TransData``, then matmul them.

    The target shape of each conversion comes from
    ``get_matmul_fractal_shape(tensor, 'zN')``. The converted operands are
    passed to ``matmul`` along with the format, transpose and ``attrs``
    arguments. ``target`` is accepted but not used by this function.

    Returns:
        tuple: ``(res, attrs)`` as returned by ``matmul``.
    """
    operands = (x, y)
    fractal_shapes = [get_matmul_fractal_shape(t, 'zN') for t in operands]

    trans_data = akg.tvm.get_global_func("TransData")
    x_nz, y_nz = [
        trans_data([t], {
            "src_format": "DefaultFormat",
            "dst_format": "FRACTAL_NZ",
            "output_shape": shape
        })
        for t, shape in zip(operands, fractal_shapes)
    ]

    out, out_attrs = matmul(x_nz, y_nz, b, out_dtype,
                            left_format, right_format, out_format,
                            transpose_x, transpose_y,
                            attrs=attrs)
    return out, out_attrs
コード例 #7
0
def matmul_addn_transdata(x,
                          y,
                          adds,
                          b,
                          out_dtype,
                          left_format="zZ",
                          right_format="nZ",
                          out_format="zN",
                          transpose_x=False,
                          transpose_y=False,
                          attrs=None,
                          target='cce'):
    """Build ``Add(matmul(x, y), Addn(adds))`` and convert it to DefaultFormat.

    The last four dims of the matmul result are read as fractal blocks and
    collapsed into a 2-D ``[m1 * m0, n1 * n0]`` tail. That shape is the
    ``output_shape`` of the final ``TransData`` call.

    Args:
        x, y: matmul operands.
        adds: collection of tensors reduced by ``Addn``.
        b: forwarded to ``matmul`` as its ``b`` argument.
        out_dtype: output dtype requested from ``matmul``.
        left_format, right_format: layout names forwarded to ``matmul``.
        out_format: matmul output layout; only 'zN' and 'zZ' are supported.
        transpose_x, transpose_y: transpose flags forwarded to ``matmul``.
        attrs: attribute dict forwarded to ``matmul``.
        target: backend target passed to ``Addn`` and ``Add``.

    Returns:
        tuple: ``(res, attrs_mat)`` where ``attrs_mat`` is the attribute
        object returned by ``matmul``.

    Raises:
        ValueError: if ``out_format`` is not 'zN' or 'zZ'. Without this check,
            ``new_shape`` would be unbound further down.
    """
    if out_format not in ('zN', 'zZ'):
        raise ValueError(
            "matmul_addn_transdata: unsupported out_format {!r}; "
            "expected 'zN' or 'zZ'".format(out_format))

    matmul_res, attrs_mat = matmul(x,
                                   y,
                                   b,
                                   out_dtype,
                                   left_format,
                                   right_format,
                                   out_format,
                                   transpose_x,
                                   transpose_y,
                                   attrs=attrs)
    addn_res = Addn(adds, target=target)
    res = Add(matmul_res, addn_res, target=target)

    # Only the order of the two outer block dims differs between the layouts.
    if out_format == 'zN':
        n1, m1, m0, n0 = matmul_res.shape[-4:]
    else:  # 'zZ'
        m1, n1, m0, n0 = matmul_res.shape[-4:]
    # NOTE(review): the shape comes from matmul_res but is applied to res.
    # This assumes Add does not change the shape -- confirm.
    new_shape = matmul_res.shape[:-4] + [m1 * m0, n1 * n0]

    # NOTE(review): src_format is FRACTAL_NZ even when out_format is 'zZ';
    # confirm TransData handles that layout correctly.
    func = akg.tvm.get_global_func("TransData")
    res = func(
        [res], {
            "src_format": "FRACTAL_NZ",
            "dst_format": "DefaultFormat",
            "output_shape": new_shape
        })
    return res, attrs_mat
コード例 #8
0
def matmul_transdata(x,
                     y,
                     b,
                     out_dtype,
                     left_format="zZ",
                     right_format="nZ",
                     out_format="zN",
                     transpose_x=False,
                     transpose_y=False,
                     attrs=None,
                     target="cce"):
    """Build ``matmul(x, y)`` and convert the result to DefaultFormat.

    The last four dims of the matmul result are read as fractal blocks and
    collapsed into a 2-D ``[m1 * m0, n1 * n0]`` tail. That shape is the
    ``output_shape`` of the ``TransData`` call.

    Args:
        x, y: matmul operands.
        b: forwarded to ``matmul`` as its ``b`` argument.
        out_dtype: output dtype requested from ``matmul``.
        left_format, right_format: layout names forwarded to ``matmul``.
        out_format: matmul output layout; only 'zN' and 'zZ' are supported.
        transpose_x, transpose_y: transpose flags forwarded to ``matmul``.
        attrs: attribute dict forwarded to ``matmul``.
        target: accepted but not used by this function.

    Returns:
        tuple: ``(res, attrs)`` where ``attrs`` is the attribute object
        returned by ``matmul``.

    Raises:
        ValueError: if ``out_format`` is not 'zN' or 'zZ'. Without this check,
            ``new_shape`` would be unbound further down.
    """
    if out_format not in ('zN', 'zZ'):
        raise ValueError(
            "matmul_transdata: unsupported out_format {!r}; "
            "expected 'zN' or 'zZ'".format(out_format))

    # Forward the caller's attrs instead of discarding them.
    matmul_res, attrs = matmul(x,
                               y,
                               b,
                               out_dtype,
                               left_format,
                               right_format,
                               out_format,
                               transpose_x,
                               transpose_y,
                               attrs=attrs)

    # Only the order of the two outer block dims differs between the layouts.
    if out_format == 'zN':
        n1, m1, m0, n0 = matmul_res.shape[-4:]
    else:  # 'zZ'
        m1, n1, m0, n0 = matmul_res.shape[-4:]
    new_shape = matmul_res.shape[:-4] + [m1 * m0, n1 * n0]

    # NOTE(review): src_format is FRACTAL_NZ even when out_format is 'zZ';
    # confirm TransData handles that layout correctly.
    func = akg.tvm.get_global_func("TransData")
    res = func(
        [matmul_res], {
            "src_format": "FRACTAL_NZ",
            "dst_format": "DefaultFormat",
            "output_shape": new_shape
        })
    return res, attrs
コード例 #9
0
def matmul_gelugrad(x,
                    y,
                    dy,
                    b,
                    out_dtype,
                    left_format="zZ",
                    right_format="nZ",
                    out_format="zN",
                    transpose_x=False,
                    transpose_y=False,
                    attrs=None):
    """Build ``gelu_grad.gelu_grad(matmul(x, y), dy)``.

    Args:
        x, y: matmul operands.
        dy: second argument given to ``gelu_grad.gelu_grad``.
        b: forwarded to ``matmul`` as its ``b`` argument.
        out_dtype: output dtype requested from ``matmul``.
        left_format, right_format, out_format: layout names forwarded to
            ``matmul``.
        transpose_x, transpose_y: transpose flags forwarded to ``matmul``.
        attrs: attribute dict forwarded to ``matmul``.

    Returns:
        tuple: ``(res, attrs)`` where ``attrs`` is the attribute object
        returned by ``matmul``.
    """
    # Forward the caller's attrs instead of replacing them with None.
    matmul_res, attrs = matmul(x,
                               y,
                               b,
                               out_dtype,
                               left_format,
                               right_format,
                               out_format,
                               transpose_x,
                               transpose_y,
                               attrs=attrs)
    res = gelu_grad.gelu_grad(matmul_res, dy)
    return res, attrs
コード例 #10
0
def matmul_tanh(x,
                y,
                b,
                out_dtype,
                left_format="zZ",
                right_format="nZ",
                out_format="zN",
                transpose_x=False,
                transpose_y=False,
                attrs=None,
                target="cce"):
    """Build ``Tanh(matmul(x, y))``.

    Args:
        x, y: matmul operands.
        b: forwarded to ``matmul`` as its ``b`` argument.
        out_dtype: output dtype requested from ``matmul``.
        left_format, right_format, out_format: layout names forwarded to
            ``matmul``.
        transpose_x, transpose_y: transpose flags forwarded to ``matmul``.
        attrs: attribute dict forwarded to ``matmul``.
        target: accepted but currently not used.

    Returns:
        tuple: ``(res, attrs)`` where ``attrs`` is the attribute object
        returned by ``matmul``.
    """
    # Forward the caller's attrs instead of replacing them with None.
    matmul_res, attrs = matmul(x,
                               y,
                               b,
                               out_dtype,
                               left_format,
                               right_format,
                               out_format,
                               transpose_x,
                               transpose_y,
                               attrs=attrs)

    # NOTE(review): sibling builders pass target=target to their element-wise
    # op. Check whether Tanh accepts a target and should receive it here.
    res = Tanh(matmul_res)
    return res, attrs
コード例 #11
0
def matmul_tensoradd(x, y, c, b, out_dtype, left_format="zZ", right_format="nZ", out_format="zN", transpose_x=False, transpose_y=False,
                    attrs=None, target='cce'):
    """Build ``Add(matmul(x, y), c)``.

    Args:
        x, y: matmul operands.
        c: second operand given to ``Add``.
        b: forwarded to ``matmul`` as its ``b`` argument.
        out_dtype: output dtype requested from ``matmul``.
        left_format, right_format, out_format: layout names forwarded to
            ``matmul``.
        transpose_x, transpose_y: transpose flags forwarded to ``matmul``.
        attrs: attribute dict forwarded to ``matmul``.
        target: backend target passed to ``Add``.

    Returns:
        tuple: ``(res, attrs)`` where ``attrs`` is the attribute object
        returned by ``matmul``.
    """
    # Forward the caller's attrs instead of replacing them with None.
    matmul_res, attrs = matmul(x, y, b, out_dtype, left_format, right_format, out_format, transpose_x, transpose_y,
                               attrs=attrs)
    res = Add(matmul_res, c, target=target)
    return res, attrs