コード例 #1
0
ファイル: cusolver.py プロジェクト: John1Tang/jax
def syevd_mhlo(dtype, a, lower=False):
    """Symmetric (Hermitian) eigendecomposition via a cuSOLVER custom call.

    Matrices with ``n <= 32`` use the ``cusolver_syevj`` kernel; larger ones
    use ``cusolver_syevd``.

    Args:
      dtype: element dtype of ``a``; passed as ``np.dtype`` to the cuSOLVER
        descriptor builder.
      a: MLIR value of ranked tensor type with shape ``batch_dims + (n, n)``.
      lower: forwarded unchanged to the descriptor builder (presumably selects
        which triangle of ``a`` is read -- TODO confirm against the kernel).

    Returns:
      A list with the first three results of the custom call:
      a tensor of ``a``'s type (presumably the eigenvectors), a
      ``batch_dims + (n,)`` tensor of eigenvalues (real element type for
      complex input), and a ``batch_dims`` int32 tensor. The fourth result,
      the ``lwork``-sized workspace, is dropped.
    """
    a_type = ir.RankedTensorType(a.type)
    dims = a_type.shape
    assert len(dims) >= 2
    m, n = dims[-2:]
    assert m == n
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    batch = _prod(batch_dims)

    # Small matrices go to the Jacobi kernel; larger ones to syevd.
    if n <= 32:
        kernel = b"cusolver_syevj"
        lwork, opaque = _cusolver.build_syevj_descriptor(
            np.dtype(dtype), lower, batch, n)
    else:
        kernel = b"cusolver_syevd"
        lwork, opaque = _cusolver.build_syevd_descriptor(
            np.dtype(dtype), lower, batch, n)

    # Eigenvalues of a Hermitian matrix are real, so complex inputs produce
    # eigenvalues of the underlying real element type.
    if ir.ComplexType.isinstance(a_type.element_type):
        eigvals_type = ir.ComplexType(a_type.element_type).element_type
    else:
        eigvals_type = a_type.element_type

    # Layout of the matrix operand/result: the two matrix dimensions first,
    # followed by the batch dimensions in reverse order.
    layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())
    i32_type = ir.IntegerType.get_signless(32)
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple([
                a.type,
                ir.RankedTensorType.get(batch_dims + (n, ), eigvals_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
                ir.RankedTensorType.get([lwork], a_type.element_type),
            ])
        ], [a],
        call_target_name=ir.StringAttr.get(kernel),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(opaque),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([layout]),
        result_layouts=ir.ArrayAttr.get([
            layout,
            ir.DenseIntElementsAttr.get(np.array(range(num_bd, -1, -1)),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.array(range(num_bd - 1, -1, -1)),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]))
    # Drop tuple element 3 (the workspace buffer).
    return [
        mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
        for i in range(3)
    ]
コード例 #2
0
ファイル: gpu_solver.py プロジェクト: xueeinstein/jax
def _syevd_mhlo(platform,
                gpu_solver,
                have_jacobi_solver,
                dtype,
                a,
                lower=False):
    """Lower a symmetric (Hermitian) eigendecomposition to a GPU solver call.

    Args:
      platform: prefix of the custom-call target; the target is
        ``f"{platform}solver_syevj"`` or ``f"{platform}solver_syevd"``.
      gpu_solver: module providing ``build_syevj_descriptor`` and
        ``build_syevd_descriptor``.
      have_jacobi_solver: whether the Jacobi (syevj) kernel may be used.
      dtype: element dtype, converted with ``np.dtype`` for the descriptor.
      a: MLIR value of ranked tensor type with shape ``batch_dims + (n, n)``.
      lower: forwarded unchanged to the descriptor builder.

    Returns:
      The first three custom-call results: a tensor of ``a``'s type, the
      ``batch_dims + (n,)`` eigenvalues (real element type for complex
      input), and a ``batch_dims`` int32 tensor. The trailing workspace
      result is discarded.
    """
    operand_type = ir.RankedTensorType(a.type)
    shape = operand_type.shape
    assert len(shape) >= 2
    rows, cols = shape[-2:]
    assert rows == cols
    batch_dims = tuple(shape[:-2])
    num_bd = len(batch_dims)
    batch_size = _prod(batch_dims)
    batch_minor_to_major = tuple(range(num_bd - 1, -1, -1))
    matrix_layout = (num_bd, num_bd + 1) + batch_minor_to_major

    # Jacobi is only an option for small matrices, and only when available.
    if have_jacobi_solver and cols <= 32:
        target = f"{platform}solver_syevj"
        build_descriptor = gpu_solver.build_syevj_descriptor
    else:
        target = f"{platform}solver_syevd"
        build_descriptor = gpu_solver.build_syevd_descriptor
    lwork, opaque = build_descriptor(np.dtype(dtype), lower, batch_size, cols)

    elem_type = operand_type.element_type
    eigvals_type = (ir.ComplexType(elem_type).element_type
                    if ir.ComplexType.isinstance(elem_type) else elem_type)

    i32_type = ir.IntegerType.get_signless(32)
    result_types = [
        a.type,
        ir.RankedTensorType.get(batch_dims + (cols, ), eigvals_type),
        ir.RankedTensorType.get(batch_dims, i32_type),
        ir.RankedTensorType.get([lwork], elem_type),
    ]
    result_layouts = [
        matrix_layout,
        tuple(range(num_bd, -1, -1)),
        batch_minor_to_major,
        [0],
    ]
    results = custom_call(target, result_types, [a],
                          backend_config=opaque,
                          operand_layouts=[matrix_layout],
                          result_layouts=result_layouts)
    return results[:3]
コード例 #3
0
ファイル: gpu_solver.py プロジェクト: xueeinstein/jax
def _gesvd_mhlo(platform,
                gpu_solver,
                have_jacobi_solver,
                dtype,
                a,
                full_matrices=True,
                compute_uv=True):
    """Lower a singular value decomposition to a GPU solver custom call.

    When ``have_jacobi_solver`` is true and both trailing dimensions are
    below 32, ``f"{platform}solver_gesvdj"`` is used; otherwise
    ``f"{platform}solver_gesvd"``.

    Args:
      platform: prefix of the custom-call target name.
      gpu_solver: module providing ``build_gesvdj_descriptor`` and
        ``build_gesvd_descriptor``.
      have_jacobi_solver: whether the Jacobi kernel may be used.
      dtype: element dtype; passed as ``np.dtype`` to the descriptor builders
        and used to detect complex inputs.
      a: MLIR value of ranked tensor type, shape ``batch_dims + (m, n)``.
      full_matrices: when false, ``u``/``vt`` are reduced to ``min(m, n)``
        columns/rows.
      compute_uv: forwarded unchanged to the descriptor builders.

    Returns:
      ``(s, u, vt, info)``.
    """
    operand_type = ir.RankedTensorType(a.type)
    shape = operand_type.shape
    assert len(shape) >= 2
    m, n = shape[-2:]
    batch_dims = tuple(shape[:-2])
    num_bd = len(batch_dims)
    batch_size = _prod(batch_dims)
    elem_type = operand_type.element_type
    if ir.ComplexType.isinstance(elem_type):
        singular_vals_type = ir.ComplexType(elem_type).element_type
    else:
        singular_vals_type = elem_type

    batch_minor_to_major = tuple(range(num_bd - 1, -1, -1))
    scalar_layout = batch_minor_to_major
    vector_layout = (num_bd, ) + batch_minor_to_major
    standard_layout = (num_bd, num_bd + 1) + batch_minor_to_major
    swapped_layout = (num_bd + 1, num_bd) + batch_minor_to_major
    i32_type = ir.IntegerType.get_signless(32)
    min_mn = min(m, n)

    def run_solver(target, lwork, opaque, layout, shape_a, shape_b):
        # Result tuple: (a, s, matrix A, matrix B, info, workspace).
        return custom_call(
            target, [
                a.type,
                ir.RankedTensorType.get(batch_dims + (min_mn, ),
                                        singular_vals_type),
                ir.RankedTensorType.get(batch_dims + shape_a, elem_type),
                ir.RankedTensorType.get(batch_dims + shape_b, elem_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
                ir.RankedTensorType.get([lwork], elem_type),
            ], [a],
            backend_config=opaque,
            operand_layouts=[layout],
            result_layouts=[
                layout, vector_layout, layout, layout, scalar_layout, [0]
            ])

    def leading_slice(value, limits):
        # Static slice from the origin up to batch_dims + limits, stride 1.
        rank = len(shape)
        return mhlo.SliceOp(
            value, ir.DenseIntElementsAttr.get(np.zeros([rank], np.int64)),
            ir.DenseIntElementsAttr.get(np.array(batch_dims + limits)),
            ir.DenseIntElementsAttr.get(np.ones([rank], np.int64))).result

    if have_jacobi_solver and m < 32 and n < 32:
        # The batched kernel doesn't support "econ" mode.
        econ = not full_matrices and batch_size == 1
        lwork, opaque = gpu_solver.build_gesvdj_descriptor(
            np.dtype(dtype), batch_size, m, n, compute_uv, 1 if econ else 0)
        _, s, u, v, info, _ = run_solver(
            f"{platform}solver_gesvdj", lwork, opaque, standard_layout,
            (m, min_mn if econ else m), (n, min_mn if econ else n))
        # This kernel yields v: swap the trailing two axes, then conjugate
        # complex values, to form vt.
        swap_last_two = tuple(range(num_bd)) + (num_bd + 1, num_bd)
        vt = mhlo.TransposeOp(
            v, ir.DenseIntElementsAttr.get(np.array(swap_last_two))).result
        if np.issubdtype(dtype, np.complexfloating):
            vt = mhlo.ComplexOp(mhlo.RealOp(vt),
                                mhlo.NegOp(mhlo.ImagOp(vt))).result
        if not (full_matrices or econ):
            u = leading_slice(u, (m, min_mn))
            vt = leading_slice(vt, (min_mn, n))
    elif m < n:
        # The descriptor gets (n, m) and the matrices use the swapped layout,
        # so the first matrix result binds to vt and the second to u.
        lwork, opaque = gpu_solver.build_gesvd_descriptor(
            np.dtype(dtype), batch_size, n, m, compute_uv, full_matrices)
        k = n if full_matrices else m
        _, s, vt, u, info, _ = run_solver(f"{platform}solver_gesvd", lwork,
                                          opaque, swapped_layout, (k, n),
                                          (m, m))
    else:
        lwork, opaque = gpu_solver.build_gesvd_descriptor(
            np.dtype(dtype), batch_size, m, n, compute_uv, full_matrices)
        k = m if full_matrices else n
        _, s, u, vt, info, _ = run_solver(f"{platform}solver_gesvd", lwork,
                                          opaque, standard_layout, (m, k),
                                          (n, n))
    return s, u, vt, info
コード例 #4
0
def gesvd_mhlo(dtype, a, full_matrices=True, compute_uv=True):
    """Singular value decomposition via the ``hipsolver_gesvd`` custom call.

    When ``m < n`` the transposed ``(n, m)`` problem is factored. The
    descriptor is built with the dimensions swapped, and the operand and
    matrix results use the swapped ``(num_bd + 1, num_bd)`` layout. The
    kernel therefore sees the transpose of ``a``, and its two matrix outputs
    come back as ``vt`` and ``u``. This matches the unified GPU and cuSOLVER
    lowerings.

    Args:
      dtype: element dtype of ``a``; passed as ``np.dtype`` to the descriptor
        builder.
      a: MLIR value of ranked tensor type with shape ``batch_dims + (m, n)``.
      full_matrices: when false, ``u``/``vt`` have ``min(m, n)`` columns/rows.
      compute_uv: forwarded unchanged to the descriptor builder.

    Returns:
      ``(s, u, vt, info)``: singular values (real element type for complex
      input), the two factor matrices, and a ``batch_dims`` int32 tensor.
    """
    a_type = ir.RankedTensorType(a.type)
    dims = a_type.shape
    assert len(dims) >= 2
    m, n = dims[-2:]
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    b = _prod(batch_dims)
    elem_type = a_type.element_type
    if ir.ComplexType.isinstance(elem_type):
        singular_vals_type = ir.ComplexType(elem_type).element_type
    else:
        singular_vals_type = elem_type

    scalar_layout = ir.DenseIntElementsAttr.get(np.array(
        tuple(range(num_bd - 1, -1, -1))),
                                                type=ir.IndexType.get())
    vector_layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, ) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())
    i32_type = ir.IntegerType.get_signless(32)

    if m < n:
        # Factor the transpose; the solver's first matrix output is vt.
        lwork, opaque = _hipsolver.build_gesvd_descriptor(
            np.dtype(dtype), b, n, m, compute_uv, full_matrices)
        k = n if full_matrices else m
        matrix_layout = ir.DenseIntElementsAttr.get(
            np.array((num_bd + 1, num_bd) + tuple(range(num_bd - 1, -1, -1))),
            type=ir.IndexType.get())
        first_shape, second_shape = (k, n), (m, m)
    else:
        lwork, opaque = _hipsolver.build_gesvd_descriptor(
            np.dtype(dtype), b, m, n, compute_uv, full_matrices)
        k = m if full_matrices else n
        matrix_layout = ir.DenseIntElementsAttr.get(
            np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
            type=ir.IndexType.get())
        first_shape, second_shape = (m, k), (n, n)

    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple([
                a.type,
                ir.RankedTensorType.get(batch_dims +
                                        (min(m, n), ), singular_vals_type),
                ir.RankedTensorType.get(batch_dims + first_shape, elem_type),
                ir.RankedTensorType.get(batch_dims + second_shape, elem_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
                ir.RankedTensorType.get([lwork], elem_type),
            ])
        ], [a],
        call_target_name=ir.StringAttr.get("hipsolver_gesvd"),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(opaque),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([matrix_layout]),
        result_layouts=ir.ArrayAttr.get([
            matrix_layout,
            vector_layout,
            matrix_layout,
            matrix_layout,
            scalar_layout,
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]))
    # Tuple element 0 (the input copy) and 5 (workspace) are not returned.
    s, first, second, info = [
        mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
        for i in range(1, 5)
    ]
    if m < n:
        vt, u = first, second
    else:
        u, vt = first, second
    return s, u, vt, info
コード例 #5
0
ファイル: cusolver.py プロジェクト: John1Tang/jax
def gesvd_mhlo(dtype, a, full_matrices=True, compute_uv=True):
    """Singular value decomposition via cuSOLVER custom calls.

    If both trailing dimensions are below 32, the Jacobi kernel
    ``cusolver_gesvdj`` is used; otherwise ``cusolver_gesvd``. For ``m < n``
    the gesvd path factors the transpose: the descriptor gets ``(n, m)`` and
    the matrices use the swapped ``(num_bd + 1, num_bd)`` layout.

    Args:
      dtype: element dtype of ``a``; passed as ``np.dtype`` to the descriptor
        builders and used to detect complex inputs.
      a: MLIR value of ranked tensor type with shape ``batch_dims + (m, n)``.
      full_matrices: when false, ``u``/``vt`` have ``min(m, n)``
        columns/rows.
      compute_uv: forwarded unchanged to the descriptor builders.

    Returns:
      ``(s, u, vt, info)``: singular values (real element type for complex
      input), the two factor matrices, and a ``batch_dims`` int32 tensor.
      On the Jacobi path, ``vt`` is formed explicitly as the conjugate
      transpose of the kernel's ``v`` output.
    """
    a_type = ir.RankedTensorType(a.type)
    dims = a_type.shape
    assert len(dims) >= 2
    m, n = dims[-2:]
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    b = _prod(batch_dims)
    elem_type = a_type.element_type
    if ir.ComplexType.isinstance(elem_type):
        singular_vals_type = ir.ComplexType(elem_type).element_type
    else:
        singular_vals_type = elem_type

    scalar_layout = ir.DenseIntElementsAttr.get(np.array(
        tuple(range(num_bd - 1, -1, -1))),
                                                type=ir.IndexType.get())
    vector_layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, ) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())
    i32_type = ir.IntegerType.get_signless(32)

    use_jacobi = m < 32 and n < 32
    econ = False
    if use_jacobi:
        # The batched kernel doesn't support "econ" mode.
        econ = not full_matrices and b == 1
        lwork, opaque = _cusolver.build_gesvdj_descriptor(
            np.dtype(dtype), b, m, n, compute_uv, 1 if econ else 0)
        k = min(m, n)
        kernel = "cusolver_gesvdj"
        matrix_layout = ir.DenseIntElementsAttr.get(
            np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
            type=ir.IndexType.get())
        # Matrix results are u and v (not vt) for this kernel.
        first_shape = (m, k if econ else m)
        second_shape = (n, k if econ else n)
    elif m < n:
        # Transposed problem: the first matrix result is vt, the second u.
        lwork, opaque = _cusolver.build_gesvd_descriptor(
            np.dtype(dtype), b, n, m, compute_uv, full_matrices)
        k = n if full_matrices else m
        kernel = "cusolver_gesvd"
        matrix_layout = ir.DenseIntElementsAttr.get(
            np.array((num_bd + 1, num_bd) + tuple(range(num_bd - 1, -1, -1))),
            type=ir.IndexType.get())
        first_shape, second_shape = (k, n), (m, m)
    else:
        lwork, opaque = _cusolver.build_gesvd_descriptor(
            np.dtype(dtype), b, m, n, compute_uv, full_matrices)
        k = m if full_matrices else n
        kernel = "cusolver_gesvd"
        matrix_layout = ir.DenseIntElementsAttr.get(
            np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
            type=ir.IndexType.get())
        first_shape, second_shape = (m, k), (n, n)

    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple([
                a.type,
                ir.RankedTensorType.get(batch_dims +
                                        (min(m, n), ), singular_vals_type),
                ir.RankedTensorType.get(batch_dims + first_shape, elem_type),
                ir.RankedTensorType.get(batch_dims + second_shape, elem_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
                ir.RankedTensorType.get([lwork], elem_type),
            ])
        ], [a],
        call_target_name=ir.StringAttr.get(kernel),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(opaque),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([matrix_layout]),
        result_layouts=ir.ArrayAttr.get([
            matrix_layout,
            vector_layout,
            matrix_layout,
            matrix_layout,
            scalar_layout,
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]))
    # Unwrap every tuple element to its Value (elements 0 and 5 -- the input
    # copy and the workspace -- are not needed).
    s, first, second, info = [
        mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
        for i in range(1, 5)
    ]

    if use_jacobi:
        u = first
        # gesvdj yields v: swap the two trailing axes and, for complex
        # inputs, negate the imaginary part to obtain vt.
        vt = mhlo.TransposeOp(
            second,
            ir.DenseIntElementsAttr.get(
                np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd)))).result
        if np.issubdtype(dtype, np.complexfloating):
            vt = mhlo.ComplexOp(mhlo.RealOp(vt),
                                mhlo.NegOp(mhlo.ImagOp(vt))).result
        if not full_matrices and not econ:
            # Full factors were computed; slice down to the reduced shapes.
            u = mhlo.SliceOp(
                u, ir.DenseIntElementsAttr.get(np.zeros([len(dims)],
                                                        np.int64)),
                ir.DenseIntElementsAttr.get(
                    np.array(batch_dims + (m, min(m, n)))),
                ir.DenseIntElementsAttr.get(np.ones([len(dims)],
                                                    np.int64))).result
            vt = mhlo.SliceOp(
                vt, ir.DenseIntElementsAttr.get(np.zeros([len(dims)],
                                                         np.int64)),
                ir.DenseIntElementsAttr.get(
                    np.array(batch_dims + (min(m, n), n))),
                ir.DenseIntElementsAttr.get(np.ones([len(dims)],
                                                    np.int64))).result
    elif m < n:
        vt, u = first, second
    else:
        u, vt = first, second
    return s, u, vt, info
コード例 #6
0
ファイル: lapack.py プロジェクト: frederikwilde/jax
def gees_mhlo(a, jobvs=True, sort=False, select=None):
    """Schur decomposition via the LAPACK ``?gees`` custom calls.

    Args:
      a: MLIR value of ranked tensor type with shape ``batch_dims + (n, n)``
        and element type f32, f64, complex<f32> or complex<f64>.
      jobvs: encoded as ``ord('V')`` if truthy, else ``ord('N')``, and passed
        to the kernel as a u8 operand.
      sort: must be falsy; sorting is not implemented.
      select: currently unused (see the TODO at the operand list).

    Returns:
      Tuple elements ``(0, 3, 5)`` of the custom call's result. These are two
      ``dims``-shaped matrices and a ``batch_dims`` int32 tensor (presumably
      the Schur form, the Schur vectors and the LAPACK ``info`` status --
      TODO confirm against the kernel).

    Raises:
      NotImplementedError: if ``sort`` is requested or the element type is
        not one of the four supported types.
    """
    a_type = ir.RankedTensorType(a.type)
    etype = a_type.element_type
    dims = a_type.shape
    assert len(dims) >= 2
    m, n = dims[-2:]
    assert m == n
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    b = 1
    for d in batch_dims:
        b *= d
    layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())

    if sort:
        raise NotImplementedError(
            "The sort feature of LAPACK's gees routine is not implemented.")

    jobvs = ord('V' if jobvs else 'N')
    sort = ord('S' if sort else 'N')

    # Pick the kernel explicitly. Any other element type (e.g. f16/bf16) has
    # no matching LAPACK routine and must not fall through to the f64 or
    # complex<f64> kernel.
    f32_type = ir.F32Type.get()
    f64_type = ir.F64Type.get()
    if etype == f32_type:
        fn = "lapack_sgees"
    elif etype == f64_type:
        fn = "lapack_dgees"
    elif etype == ir.ComplexType.get(f32_type):
        fn = "lapack_cgees"
    elif etype == ir.ComplexType.get(f64_type):
        fn = "lapack_zgees"
    else:
        raise NotImplementedError(f"Unsupported element type {etype}")

    schurvecs_type = etype
    if not ir.ComplexType.isinstance(etype):
        # Real kernels: one matrix workspace and two eigenvalue vectors.
        workspaces = [ir.RankedTensorType.get(dims, schurvecs_type)]
        workspace_layouts = [layout]
        eigvals = [ir.RankedTensorType.get(batch_dims + (n, ), etype)] * 2
        eigvals_layouts = [
            ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                        type=ir.IndexType.get())
        ] * 2
    else:
        # Complex kernels: a matrix workspace plus an unbatched real [n]
        # workspace, and a single complex eigenvalue vector.
        workspaces = [
            ir.RankedTensorType.get(dims, schurvecs_type),
            ir.RankedTensorType.get([n],
                                    ir.ComplexType(etype).element_type),
        ]
        workspace_layouts = [
            layout,
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]
        eigvals = [ir.RankedTensorType.get(batch_dims + (n, ), etype)]
        eigvals_layouts = [
            ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                        type=ir.IndexType.get())
        ]

    i32_type = ir.IntegerType.get_signless(32)

    scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0, ), np.int64),
                                                type=ir.IndexType.get())
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple(workspaces + eigvals + [
                ir.RankedTensorType.get(dims, schurvecs_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
            ])
        ],
        [
            _mhlo_s32(b),
            _mhlo_s32(n),
            _mhlo_u8(np.uint8(jobvs)),
            _mhlo_u8(np.uint8(sort)),
            # TODO: figure out how to put the callable select function here
            a
        ],
        call_target_name=ir.StringAttr.get(fn),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(""),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([scalar_layout] * 4 + [layout]),
        result_layouts=ir.ArrayAttr.get(workspace_layouts + eigvals_layouts + [
            layout,
            ir.DenseIntElementsAttr.get(np.arange(num_bd - 1, -1, -1),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.arange(num_bd - 1, -1, -1),
                                        type=ir.IndexType.get()),
        ]))
    i32_attr = lambda i: ir.IntegerAttr.get(i32_type, i)
    # Sorting raises above, so only the unsorted branch is currently
    # reachable; the sorted branch also returns element 4.
    if sort == ord('S'):
        return (mhlo.GetTupleElementOp(out, i32_attr(0)).result,
                mhlo.GetTupleElementOp(out, i32_attr(3)).result,
                mhlo.GetTupleElementOp(out, i32_attr(4)).result,
                mhlo.GetTupleElementOp(out, i32_attr(5)).result)
    else:
        return (mhlo.GetTupleElementOp(out, i32_attr(0)).result,
                mhlo.GetTupleElementOp(out, i32_attr(3)).result,
                mhlo.GetTupleElementOp(out, i32_attr(5)).result)
コード例 #7
0
def gees_mhlo(dtype, a, jobvs=True, sort=False, select=None):
  """Lower a batched Schur decomposition to a LAPACK ``?gees`` custom call.

  Args:
    dtype: one of float32, float64, complex64, complex128; picks the kernel.
    a: MLIR value of ranked tensor type with shape ``batch_dims + (n, n)``.
    jobvs: encoded as ``ord('V')`` if truthy, else ``ord('N')``, and passed
      as a u8 operand.
    sort: must be falsy; sorting is not implemented.
    select: currently unused (see the TODO at the operand list).

  Returns:
    Results ``(0, 3, 5)`` of the custom call (``(0, 3, 4, 5)`` if sorting
    were enabled).

  Raises:
    NotImplementedError: if ``sort`` is requested or ``dtype`` is unsupported.
  """
  operand_type = ir.RankedTensorType(a.type)
  etype = operand_type.element_type
  shape = operand_type.shape
  assert len(shape) >= 2
  rows, n = shape[-2:]
  assert rows == n
  batch_dims = tuple(shape[:-2])
  num_bd = len(batch_dims)
  batch_size = 1
  for extent in batch_dims:
    batch_size *= extent
  batch_minor_to_major = tuple(range(num_bd - 1, -1, -1))
  layout = (num_bd, num_bd + 1) + batch_minor_to_major

  if sort:
    raise NotImplementedError(
        "The sort feature of LAPACK's gees routine is not implemented.")

  jobvs = ord('V' if jobvs else 'N')
  sort = ord('S' if sort else 'N')

  for candidate, kernel_name in ((np.float32, "lapack_sgees"),
                                 (np.float64, "lapack_dgees"),
                                 (np.complex64, "lapack_cgees"),
                                 (np.complex128, "lapack_zgees")):
    if dtype == candidate:
      fn = kernel_name
      break
  else:
    raise NotImplementedError(f"Unsupported dtype {dtype}")

  vector_layout = tuple(range(num_bd, -1, -1))
  eigvals_type = ir.RankedTensorType.get(batch_dims + (n,), etype)
  if np.issubdtype(dtype, np.complexfloating):
    # Complex: matrix workspace + real [n] workspace, one eigenvalue vector.
    workspaces = [
        ir.RankedTensorType.get(shape, etype),
        ir.RankedTensorType.get([n], ir.ComplexType(etype).element_type),
    ]
    workspace_layouts = [layout, [0]]
    eigvals = [eigvals_type]
    eigvals_layouts = [vector_layout]
  else:
    # Real: one matrix workspace, two eigenvalue vectors.
    workspaces = [ir.RankedTensorType.get(shape, etype)]
    workspace_layouts = [layout]
    eigvals = [eigvals_type] * 2
    eigvals_layouts = [vector_layout] * 2

  i32_type = ir.IntegerType.get_signless(32)
  status_type = ir.RankedTensorType.get(batch_dims, i32_type)
  out = custom_call(
      fn,
      workspaces + eigvals + [
          ir.RankedTensorType.get(shape, etype),
          status_type,
          status_type,
      ],
      [
          _mhlo_s32(batch_size),
          _mhlo_s32(n),
          _mhlo_u8(np.uint8(jobvs)),
          _mhlo_u8(np.uint8(sort)),
          # TODO: figure out how to put the callable select function here
          a
      ],
      operand_layouts=[[]] * 4 + [layout],
      result_layouts=workspace_layouts + eigvals_layouts + [
          layout,
          batch_minor_to_major,
          batch_minor_to_major,
      ]
  )
  picked = (0, 3, 4, 5) if sort == ord('S') else (0, 3, 5)
  return tuple(out[i] for i in picked)