예제 #1
0
파일: hipsparse.py 프로젝트: John1Tang/jax
def coo_todense_mhlo(data, row, col, *, shape, data_dtype, index_dtype):
    """Convert a COO sparse matrix to a dense matrix with hipSPARSE.

    Args:
      data: MLIR value holding the nonzero entries.
      row: MLIR value holding the row index of each nonzero.
      col: MLIR value holding the column index of each nonzero.
      shape: ``(rows, cols)`` of the dense result.
      data_dtype: dtype of ``data``, forwarded to the descriptor builder.
      index_dtype: dtype of ``row``/``col``, forwarded to the descriptor
        builder.

    Returns:
      Element 0 of the ``hipsparse_coo_todense`` result tuple (the dense
      matrix). Element 1, an i8 scratch buffer sized by the descriptor
      builder, is discarded.
    """
    elem_type, _, num_nonzero = _validate_coo_mhlo(data, row, col, shape)
    n_rows, n_cols = shape

    workspace_bytes, descriptor = _hipsparse.build_coo_todense_descriptor(
        data_dtype, index_dtype, n_rows, n_cols, num_nonzero)

    i32 = ir.IntegerType.get_signless(32)
    byte_type = ir.IntegerType.get_signless(8)
    result_tuple = ir.TupleType.get_tuple([
        ir.RankedTensorType.get(shape, elem_type),
        ir.RankedTensorType.get([workspace_bytes], byte_type),
    ])

    # Layout {0} for every 1-D value (the three COO arrays and the scratch
    # buffer); layout {1, 0} for the 2-D dense result.
    vec_layout = ir.DenseIntElementsAttr.get(np.array([0]),
                                             type=ir.IndexType.get())
    mat_layout = ir.DenseIntElementsAttr.get(np.array([1, 0]),
                                             type=ir.IndexType.get())

    call = mhlo.CustomCallOp(
        [result_tuple],
        [data, row, col],
        call_target_name=ir.StringAttr.get("hipsparse_coo_todense"),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(descriptor),
        api_version=ir.IntegerAttr.get(i32, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([vec_layout, vec_layout, vec_layout]),
        result_layouts=ir.ArrayAttr.get([mat_layout, vec_layout]),
    )
    return mhlo.GetTupleElementOp(call, ir.IntegerAttr.get(i32, 0)).result
예제 #2
0
파일: cuda_prng.py 프로젝트: John1Tang/jax
def threefry2x32_lowering(keys, data):
  """Lower ThreeFry-2x32 to the ``cuda_threefry2x32`` GPU custom call.

  Args:
    keys: two MLIR values. The first must be a ranked tensor with unsigned
      32-bit elements, and every value in ``keys`` and ``data`` must have
      exactly that type.
    data: two MLIR values of that same type.

  Returns:
    A list with the two outputs of the kernel, extracted from its result
    tuple; each has the common input type.
  """
  assert len(keys) == 2, keys
  assert len(data) == 2, data
  assert (ir.RankedTensorType(keys[0].type).element_type ==
          ir.IntegerType.get_unsigned(32)), keys[0].type
  operand_type = keys[0].type
  shape = ir.RankedTensorType(operand_type).shape

  for operand in (*keys, *data):
    assert operand.type == operand_type, (operand.type, operand_type)
  rank = len(shape)

  descriptor = _cuda_prng.cuda_threefry2x32_descriptor(_prod(shape))
  # Descending dimension order, e.g. [2, 1, 0] for a rank-3 operand; shared
  # by all four operands and both results.
  layout = ir.DenseIntElementsAttr.get(np.arange(rank - 1, -1, -1),
                                       type=ir.IndexType.get())
  i32 = ir.IntegerType.get_signless(32)
  result_tuple = mhlo.CustomCallOp(
      [ir.TupleType.get_tuple([operand_type, operand_type])],
      [keys[0], keys[1], data[0], data[1]],
      call_target_name=ir.StringAttr.get("cuda_threefry2x32"),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(descriptor),
      api_version=ir.IntegerAttr.get(i32, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([layout, layout, layout, layout]),
      result_layouts=ir.ArrayAttr.get([layout, layout])).result
  outputs = []
  for index in (0, 1):
    outputs.append(
        mhlo.GetTupleElementOp(result_tuple,
                               ir.IntegerAttr.get(i32, index)).result)
  return outputs
예제 #3
0
def lu_pivots_to_permutation_mhlo(pivots, *, permutation_size):
    """Convert LU pivot indices to permutations with a HIP custom call.

    Args:
      pivots: ranked int32 tensor. The last axis holds the pivots and all
        leading axes are treated as batch axes.
      permutation_size: length of the last axis of the result.

    Returns:
      An int32 tensor with the shape of ``pivots``, except that the last
      dimension is ``permutation_size``.
    """
    pivots_type = ir.RankedTensorType(pivots.type)
    pivots_shape = pivots_type.shape
    i32 = ir.IntegerType.get_signless(32)

    assert pivots_type.element_type == i32, pivots_type

    descriptor = _hip_linalg.hip_lu_pivots_to_permutation_descriptor(
        _prod(pivots_shape[:-1]), pivots_shape[-1], permutation_size)

    # Input and output share rank, so one descending-order layout serves both.
    layout = ir.DenseIntElementsAttr.get(
        np.arange(len(pivots_shape) - 1, -1, -1), type=ir.IndexType.get())

    out_shape = list(pivots_shape)
    out_shape[-1] = permutation_size
    out_type = ir.RankedTensorType.get(out_shape, i32)

    call = mhlo.CustomCallOp(
        [out_type],
        [pivots],
        call_target_name=ir.StringAttr.get("hip_lu_pivots_to_permutation"),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(descriptor),
        api_version=ir.IntegerAttr.get(i32, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([layout]),
        result_layouts=ir.ArrayAttr.get([layout]))
    return call.result
예제 #4
0
파일: hipsparse.py 프로젝트: John1Tang/jax
def csr_matmat_mhlo(data,
                    indices,
                    indptr,
                    B,
                    *,
                    shape,
                    transpose=False,
                    compute_dtype=None,
                    compute_type=None,
                    index_dtype,
                    data_dtype,
                    B_dtype):
    """CSR sparse matrix times dense matrix, using hipSPARSE.

    ``A`` is the CSR matrix of size ``shape`` given by
    ``(data, indices, indptr)``. The result has ``Ccols`` columns, where
    ``Ccols`` is the trailing dimension of ``B``. It has ``rows`` rows, or
    ``cols`` rows when ``transpose`` is true.

    Args:
      data, indices, indptr: CSR arrays, validated by ``_validate_csr_mhlo``.
      B: 2-D dense MLIR value.
      shape: ``(rows, cols)`` of the sparse matrix.
      transpose: forwarded to the descriptor; also selects the output row
        count as described above.
      compute_dtype: numpy dtype used for the computation. Defaults to
        ``data_dtype``.
      compute_type: MLIR element type matching ``compute_dtype``. Required
        whenever ``compute_dtype`` is given; ignored (replaced by the data
        element type) when ``compute_dtype`` is None.
      index_dtype, data_dtype, B_dtype: dtypes forwarded to the descriptor.

    Returns:
      Element 0 of the ``hipsparse_csr_matmat`` result tuple, a
      ``[out_rows, Ccols]`` tensor of ``compute_type``. The i8 scratch buffer
      at element 1 is discarded.

    Raises:
      ValueError: if ``compute_dtype`` is given without ``compute_type``.
    """
    data_type, _, nnz = _validate_csr_mhlo(data, indices, indptr, shape)
    rows, cols = shape
    B_shape = ir.RankedTensorType(B.type).shape
    _, Ccols = B_shape

    if compute_dtype is None:
        compute_dtype = data_dtype
        compute_type = data_type
    elif compute_type is None:
        # Without an MLIR element type the result tensor type cannot be built;
        # fail clearly here instead of deep inside RankedTensorType.get.
        raise ValueError(
            "csr_matmat_mhlo: compute_type must be provided together with "
            "compute_dtype")

    buffer_size, opaque = _hipsparse.build_csr_matmat_descriptor(
        data_dtype, B_dtype, compute_dtype, index_dtype, rows, cols, Ccols,
        nnz, transpose)
    out_size = cols if transpose else rows

    i32_type = ir.IntegerType.get_signless(32)
    vec_layout = ir.DenseIntElementsAttr.get(np.array([0]),
                                             type=ir.IndexType.get())
    mat_layout = ir.DenseIntElementsAttr.get(np.array([1, 0]),
                                             type=ir.IndexType.get())
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple([
                ir.RankedTensorType.get([out_size, Ccols], compute_type),
                ir.RankedTensorType.get([buffer_size],
                                        ir.IntegerType.get_signless(8)),
            ])
        ], [data, indices, indptr, B],
        call_target_name=ir.StringAttr.get("hipsparse_csr_matmat"),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(opaque),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        # The three CSR arrays are 1-D; B is 2-D.
        operand_layouts=ir.ArrayAttr.get(
            [vec_layout, vec_layout, vec_layout, mat_layout]),
        # Dense result, then the 1-D scratch buffer.
        result_layouts=ir.ArrayAttr.get([mat_layout, vec_layout]))
    return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result
예제 #5
0
파일: cusolver.py 프로젝트: John1Tang/jax
def syevd_mhlo(dtype, a, lower=False):
    """Symmetric (Hermitian) eigendecomposition via cuSOLVER.

    Matrices with ``n <= 32`` use the ``cusolver_syevj`` kernel; larger ones
    use ``cusolver_syevd``.

    Args:
      dtype: element dtype, forwarded as ``np.dtype`` to the descriptor.
      a: batch of square matrices with shape ``batch_dims + (n, n)``.
      lower: forwarded to the descriptor builder.

    Returns:
      A list of the first three elements of the result tuple:
      - element 0, typed like ``a`` (presumably the eigenvectors; confirm
        against the kernel);
      - element 1, shape ``batch_dims + (n,)``, holding the eigenvalues;
      - element 2, an int32 tensor of shape ``batch_dims`` (presumably the
        per-matrix status info).
      The workspace at element 3 is discarded.
    """
    a_type = ir.RankedTensorType(a.type)
    dims = a_type.shape
    assert len(dims) >= 2
    m, n = dims[-2:]
    assert m == n
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    batch = _prod(batch_dims)

    if n <= 32:
        kernel = b"cusolver_syevj"
        lwork, opaque = _cusolver.build_syevj_descriptor(
            np.dtype(dtype), lower, batch, n)
    else:
        kernel = b"cusolver_syevd"
        lwork, opaque = _cusolver.build_syevd_descriptor(
            np.dtype(dtype), lower, batch, n)

    # Eigenvalues use the component type for complex inputs, otherwise the
    # input element type.
    if ir.ComplexType.isinstance(a_type.element_type):
        eigvals_type = ir.ComplexType(a_type.element_type).element_type
    else:
        eigvals_type = a_type.element_type

    index_type = ir.IndexType.get()
    # Matrix layout: the two matrix dimensions first, then the batch
    # dimensions in descending order.
    layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=index_type)
    # np.arange (as in the getrf/orgqr lowerings) yields an integer array
    # even when it is empty, i.e. when there are no batch dimensions.
    eigvals_layout = ir.DenseIntElementsAttr.get(
        np.arange(num_bd, -1, -1), type=index_type)
    info_layout = ir.DenseIntElementsAttr.get(
        np.arange(num_bd - 1, -1, -1), type=index_type)
    work_layout = ir.DenseIntElementsAttr.get(np.array([0]), type=index_type)

    i32_type = ir.IntegerType.get_signless(32)
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple([
                a.type,
                ir.RankedTensorType.get(batch_dims + (n, ), eigvals_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
                ir.RankedTensorType.get([lwork], a_type.element_type),
            ])
        ], [a],
        call_target_name=ir.StringAttr.get(kernel),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(opaque),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([layout]),
        result_layouts=ir.ArrayAttr.get(
            [layout, eigvals_layout, info_layout, work_layout]))
    return [
        mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
        for i in range(3)
    ]
예제 #6
0
def pocketfft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
    """Lower an FFT to the CPU ``pocketfft`` custom call.

    ``_pocketfft_descriptor`` serializes the FFT configuration. The kernel
    receives it as a leading u8 constant operand. If the input or output has
    a zero-sized dimension, no custom call is emitted; a zero scalar is
    broadcast to the output shape instead.

    Args:
      a: ranked input tensor value.
      dtype: input dtype, forwarded to ``_pocketfft_descriptor``.
      fft_type: kind of FFT, forwarded to ``_pocketfft_descriptor``.
      fft_lengths: FFT lengths, forwarded to ``_pocketfft_descriptor``.

    Returns:
      The output tensor value, with the shape and dtype reported by
      ``_pocketfft_descriptor``.

    Raises:
      ValueError: if the reported output dtype is not float32, float64,
        complex64 or complex128.
    """
    in_type = ir.RankedTensorType(a.type)
    rank = len(in_type.shape)

    fft_lengths = list(fft_lengths)
    descriptor_bytes, out_dtype, out_shape = _pocketfft_descriptor(
        list(in_type.shape), dtype, fft_type, fft_lengths)

    # Checked in order; the first matching dtype picks the MLIR element type.
    element_builders = (
        (np.float32, lambda: ir.F32Type.get()),
        (np.float64, lambda: ir.F64Type.get()),
        (np.complex64, lambda: ir.ComplexType.get(ir.F32Type.get())),
        (np.complex128, lambda: ir.ComplexType.get(ir.F64Type.get())),
    )
    for candidate, build_element in element_builders:
        if out_dtype == candidate:
            out_elem = build_element()
            break
    else:
        raise ValueError(f"Unknown output type {out_dtype}")

    if 0 in in_type.shape or 0 in out_shape:
        zero = mhlo.ConstOp(
            ir.RankedTensorType.get([], out_elem),
            ir.DenseElementsAttr.get(np.array(0, dtype=out_dtype),
                                     type=out_elem))
        if jax._src.lib.mlir_api_version < 9:
            # Older MHLO bindings take the result type as the first argument.
            return mhlo.BroadcastOp(
                ir.RankedTensorType.get(out_shape, out_elem), zero,
                ir.DenseElementsAttr.get(np.asarray(out_shape,
                                                    np.int64))).result
        return mhlo.BroadcastOp(
            zero, ir.DenseElementsAttr.get(np.asarray(out_shape,
                                                      np.int64))).result

    u8 = ir.IntegerType.get_unsigned(8)
    descriptor = mhlo.ConstOp(
        ir.RankedTensorType.get([len(descriptor_bytes)], u8),
        ir.DenseElementsAttr.get(np.frombuffer(descriptor_bytes,
                                               dtype=np.uint8),
                                 type=u8))
    data_layout = ir.DenseIntElementsAttr.get(np.arange(rank - 1, -1, -1),
                                              type=ir.IndexType.get())
    descriptor_layout = ir.DenseIntElementsAttr.get(np.array([0], np.int64),
                                                    type=ir.IndexType.get())
    call = mhlo.CustomCallOp(
        [ir.RankedTensorType.get(out_shape, out_elem)],
        [descriptor, a],
        call_target_name=ir.StringAttr.get("pocketfft"),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(""),
        api_version=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([descriptor_layout, data_layout]),
        result_layouts=ir.ArrayAttr.get([data_layout]))
    return call.result
예제 #7
0
def getrf_mhlo(dtype, a):
    """LU decomposition of a batch of matrices via hipBLAS / hipSOLVER.

    The batched hipBLAS kernel (with an i8 workspace) is used when
    ``batch > 1 and m == n and m // batch <= 128``. Otherwise
    ``hipsolver_getrf`` is used, with a workspace of ``a``'s element type.

    Args:
      dtype: element dtype, forwarded as ``np.dtype`` to the descriptor.
      a: matrices of shape ``batch_dims + (m, n)``.

    Returns:
      A list of the first three result-tuple elements:
      - element 0, typed like ``a`` (presumably the packed LU factors);
      - element 1, int32 of shape ``batch_dims + (min(m, n),)`` (presumably
        the pivots);
      - element 2, int32 of shape ``batch_dims`` (presumably status info).
      The workspace at element 3 is discarded.
    """
    a_type = ir.RankedTensorType(a.type)
    dims = a_type.shape
    assert len(dims) >= 2
    m, n = dims[-2:]
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    batch = _prod(batch_dims)

    # NOTE(review): the threshold compares m // batch, not m, against 128;
    # kept unchanged, but confirm this heuristic is intended.
    use_batched_blas = batch > 1 and m == n and m // batch <= 128
    if not use_batched_blas:
        lwork, opaque = _hipsolver.build_getrf_descriptor(
            np.dtype(dtype), batch, m, n)
        workspace = ir.RankedTensorType.get([lwork], a_type.element_type)
        kernel = "hipsolver_getrf"
    else:
        lwork, opaque = _hipblas.build_getrf_batched_descriptor(
            np.dtype(dtype), batch, m)
        workspace = ir.RankedTensorType.get([lwork],
                                            ir.IntegerType.get_signless(8))
        kernel = "hipblas_getrf_batched"

    index_type = ir.IndexType.get()
    matrix_layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=index_type)
    pivots_layout = ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                                type=index_type)
    info_layout = ir.DenseIntElementsAttr.get(np.arange(num_bd - 1, -1, -1),
                                              type=index_type)
    workspace_layout = ir.DenseIntElementsAttr.get(np.array([0]),
                                                   type=index_type)

    i32 = ir.IntegerType.get_signless(32)
    result_tuple = ir.TupleType.get_tuple([
        a.type,
        ir.RankedTensorType.get(batch_dims + (min(m, n), ), i32),
        ir.RankedTensorType.get(batch_dims, i32),
        workspace,
    ])
    call = mhlo.CustomCallOp(
        [result_tuple],
        [a],
        call_target_name=ir.StringAttr.get(kernel),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(opaque),
        api_version=ir.IntegerAttr.get(i32, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([matrix_layout]),
        result_layouts=ir.ArrayAttr.get(
            [matrix_layout, pivots_layout, info_layout, workspace_layout]))
    outputs = []
    for index in range(3):
        outputs.append(
            mhlo.GetTupleElementOp(call, ir.IntegerAttr.get(i32, index)).result)
    return outputs
예제 #8
0
파일: lapack.py 프로젝트: frederikwilde/jax
def getrf_mhlo(dtype, a):
    """LU decomposition via the LAPACK ``?getrf`` CPU custom calls.

    The batch size and matrix dimensions are passed to the kernel as three
    leading int32 scalar operands, ahead of ``a``.

    Args:
      dtype: one of float32, float64, complex64, complex128; selects the
        s/d/c/z kernel.
      a: matrices of shape ``batch_dims + (m, n)``.

    Returns:
      A list of the three result-tuple elements: one typed like ``a``, an
      int32 tensor of shape ``batch_dims + (min(m, n),)``, and an int32
      tensor of shape ``batch_dims``.

    Raises:
      NotImplementedError: for any other ``dtype``.
    """
    dims = ir.RankedTensorType(a.type).shape
    assert len(dims) >= 2
    m, n = dims[-2:]
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    batch = 1
    for extent in batch_dims:
        batch *= extent

    kernels = (
        (np.float32, b"lapack_sgetrf"),
        (np.float64, b"lapack_dgetrf"),
        (np.complex64, b"lapack_cgetrf"),
        (np.complex128, b"lapack_zgetrf"),
    )
    for candidate, target in kernels:
        if dtype == candidate:
            fn = target
            break
    else:
        raise NotImplementedError("Unsupported dtype {}".format(dtype))

    index_type = ir.IndexType.get()
    # Scalar operands get an empty layout.
    scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0, ), np.int64),
                                                type=index_type)
    matrix_layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=index_type)
    pivots_layout = ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                                type=index_type)
    info_layout = ir.DenseIntElementsAttr.get(np.arange(num_bd - 1, -1, -1),
                                              type=index_type)

    i32 = ir.IntegerType.get_signless(32)
    result_tuple = ir.TupleType.get_tuple([
        a.type,
        ir.RankedTensorType.get(batch_dims + (min(m, n), ), i32),
        ir.RankedTensorType.get(batch_dims, i32),
    ])
    operands = [_mhlo_s32(int(batch)), _mhlo_s32(m), _mhlo_s32(n), a]
    call = mhlo.CustomCallOp(
        [result_tuple],
        operands,
        call_target_name=ir.StringAttr.get(fn),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(""),
        api_version=ir.IntegerAttr.get(i32, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get(
            [scalar_layout, scalar_layout, scalar_layout, matrix_layout]),
        result_layouts=ir.ArrayAttr.get(
            [matrix_layout, pivots_layout, info_layout]))
    return [
        mhlo.GetTupleElementOp(call, ir.IntegerAttr.get(i32, index)).result
        for index in (0, 1, 2)
    ]
예제 #9
0
def trsm_mhlo(dtype,
              a,
              b,
              left_side=False,
              lower=False,
              trans_a=False,
              conj_a=False,
              diag=False):
    """Batched triangular solve via the ``hipblas_trsm_batched`` custom call.

    XLA implements unbatched triangular solve directly, so only the batched
    case needs this kernel.

    Args:
      dtype: element dtype, forwarded as ``np.dtype`` to the descriptor.
      a: triangular matrices of shape ``batch_dims + (k, k)``, where ``k``
        is ``m`` if ``left_side`` else ``n``.
      b: right-hand sides of shape ``batch_dims + (m, n)``.
      left_side, lower, trans_a, conj_a, diag: options forwarded to
        ``_hipblas.build_trsm_batched_descriptor``.

    Returns:
      Element 0 of the result tuple, typed like ``b``. The two i8 workspace
      results are discarded.

    Raises:
      ValueError: if ``a``'s shape or element type does not match ``b``.
      NotImplementedError: if ``conj_a`` is set without ``trans_a``.
    """
    b_type = ir.RankedTensorType(b.type)
    b_shape = b_type.shape
    assert len(b_shape) >= 2
    m, n = b_shape[-2:]
    batch_dims = tuple(b_shape[:-2])
    num_bd = len(batch_dims)
    batch = _prod(batch_dims)
    k = m if left_side else n

    a_type = ir.RankedTensorType(a.type)
    shapes_agree = tuple(a_type.shape) == batch_dims + (k, k)
    if not shapes_agree or a_type.element_type != b_type.element_type:
        raise ValueError("Argument mismatch for trsm, got {} and {}".format(
            a_type, b_type))

    if conj_a and not trans_a:
        raise NotImplementedError(
            "Conjugation without transposition not supported")

    lwork, opaque = _hipblas.build_trsm_batched_descriptor(
        np.dtype(dtype), batch, m, n, left_side, lower, trans_a, conj_a, diag)

    index_type = ir.IndexType.get()
    matrix_layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=index_type)
    workspace_type = ir.RankedTensorType.get([lwork],
                                             ir.IntegerType.get_signless(8))
    workspace_layout = ir.DenseIntElementsAttr.get(np.array([0]),
                                                   type=index_type)
    i32 = ir.IntegerType.get_signless(32)
    result_tuple = mhlo.CustomCallOp(
        [ir.TupleType.get_tuple([b_type, workspace_type, workspace_type])],
        [a, b],
        call_target_name=ir.StringAttr.get("hipblas_trsm_batched"),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(opaque),
        api_version=ir.IntegerAttr.get(i32, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([matrix_layout, matrix_layout]),
        result_layouts=ir.ArrayAttr.get(
            [matrix_layout, workspace_layout, workspace_layout])).result
    return mhlo.GetTupleElementOp(result_tuple,
                                  ir.IntegerAttr.get(i32, 0)).result
예제 #10
0
파일: lapack.py 프로젝트: frederikwilde/jax
def potrf_mhlo(dtype, a, lower=False):
    """Cholesky factorization via the LAPACK ``?potrf`` CPU custom calls.

    ``lower``, the batch size and ``n`` are passed to the kernel as three
    leading int32 scalar operands, ahead of ``a``.

    Args:
      dtype: one of float32, float64, complex64, complex128; selects the
        s/d/c/z kernel.
      a: square matrices of shape ``batch_dims + (n, n)``.
      lower: forwarded to the kernel as an int32 scalar.

    Returns:
      A list of two values: element 0 of the result tuple (typed like ``a``)
      and element 1 (int32 of shape ``batch_dims``).

    Raises:
      ValueError: if the last two dimensions of ``a`` differ.
      NotImplementedError: for any other ``dtype``.
    """
    a_type = ir.RankedTensorType(a.type)
    dims = a_type.shape
    m, n = dims[-2:]
    if m != n:
        raise ValueError(
            "potrf expects a square matrix, got {}".format(a_type))

    kernels = (
        (np.float32, b"lapack_spotrf"),
        (np.float64, b"lapack_dpotrf"),
        (np.complex64, b"lapack_cpotrf"),
        (np.complex128, b"lapack_zpotrf"),
    )
    for candidate, target in kernels:
        if dtype == candidate:
            fn = target
            break
    else:
        raise NotImplementedError("Unsupported dtype {}".format(dtype))

    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    batch = 1
    for extent in batch_dims:
        batch *= extent

    index_type = ir.IndexType.get()
    # Scalar operands get an empty layout.
    scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0, ), np.int64),
                                                type=index_type)
    matrix_layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=index_type)
    i32 = ir.IntegerType.get_signless(32)
    info_layout = ir.DenseIntElementsAttr.get(
        np.array(range(num_bd - 1, -1, -1)), type=index_type)

    result_tuple = ir.TupleType.get_tuple([
        a.type,
        ir.RankedTensorType.get(batch_dims, i32),
    ])
    operands = [_mhlo_s32(int(lower)), _mhlo_s32(batch), _mhlo_s32(n), a]
    call = mhlo.CustomCallOp(
        [result_tuple],
        operands,
        call_target_name=ir.StringAttr.get(fn),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(""),
        api_version=ir.IntegerAttr.get(i32, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get(
            [scalar_layout, scalar_layout, scalar_layout, matrix_layout]),
        result_layouts=ir.ArrayAttr.get([matrix_layout, info_layout])).result
    return [
        mhlo.GetTupleElementOp(call, ir.IntegerAttr.get(i32, index)).result
        for index in (0, 1)
    ]
예제 #11
0
def orgqr_mhlo(dtype, a, tau):
    """Form the product of elementary Householder reflections (hipSOLVER).

    Args:
      dtype: element dtype, forwarded as ``np.dtype`` to the descriptor.
      a: matrices of shape ``batch_dims + (m, n)``.
      tau: tensor of shape ``batch_dims + (k,)``. Its leading dimensions must
        equal ``a``'s batch dimensions (asserted).

    Returns:
      A list of two values: result-tuple element 0 (typed like ``a``) and
      element 1 (int32 of shape ``batch_dims``). The workspace at element 2
      is discarded.
    """
    a_type = ir.RankedTensorType(a.type)
    a_shape = a_type.shape
    assert len(a_shape) >= 2
    m, n = a_shape[-2:]
    batch_dims = tuple(a_shape[:-2])
    num_bd = len(batch_dims)
    batch = _prod(batch_dims)

    tau_shape = ir.RankedTensorType(tau.type).shape
    assert tau_shape[:-1] == a_shape[:-2]
    k = tau_shape[-1]

    lwork, opaque = _hipsolver.build_orgqr_descriptor(np.dtype(dtype), batch,
                                                      m, n, k)

    index_type = ir.IndexType.get()
    matrix_layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=index_type)
    tau_layout = ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                             type=index_type)
    info_layout = ir.DenseIntElementsAttr.get(np.arange(num_bd - 1, -1, -1),
                                              type=index_type)
    workspace_layout = ir.DenseIntElementsAttr.get(np.array([0]),
                                                   type=index_type)

    i32 = ir.IntegerType.get_signless(32)
    result_tuple = ir.TupleType.get_tuple([
        a.type,
        ir.RankedTensorType.get(batch_dims, i32),
        ir.RankedTensorType.get([lwork], a_type.element_type),
    ])
    call = mhlo.CustomCallOp(
        [result_tuple],
        [a, tau],
        call_target_name=ir.StringAttr.get(b"hipsolver_orgqr"),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(opaque),
        api_version=ir.IntegerAttr.get(i32, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([matrix_layout, tau_layout]),
        result_layouts=ir.ArrayAttr.get(
            [matrix_layout, info_layout, workspace_layout]))
    outputs = []
    for index in range(2):
        outputs.append(
            mhlo.GetTupleElementOp(call, ir.IntegerAttr.get(i32, index)).result)
    return outputs
예제 #12
0
def coo_matvec_mhlo(data,
                    row,
                    col,
                    x,
                    *,
                    shape,
                    transpose=False,
                    compute_dtype=None,
                    compute_type=None,
                    index_dtype,
                    data_dtype,
                    x_dtype):
    """COO sparse matrix times dense vector (``cusparse_coo_matvec``).

    Args:
      data, row, col: COO arrays, validated by ``_validate_coo_mhlo``.
      x: 1-D dense vector value.
      shape: ``(rows, cols)`` of the sparse matrix.
      transpose: forwarded to the descriptor. It also makes the output
        length ``cols`` instead of ``rows``.
      compute_dtype: computation dtype; defaults to ``data_dtype``. When it
        is None, ``compute_type`` is replaced by the data element type.
      compute_type: MLIR element type of the result.
      index_dtype, data_dtype, x_dtype: dtypes forwarded to the descriptor.

    Returns:
      Element 0 of the result tuple, a 1-D tensor of ``compute_type``. The
      i8 scratch buffer at element 1 is discarded.
    """
    data_type, _, nnz = _validate_coo_mhlo(data, row, col, shape)
    n_rows, n_cols = shape

    if compute_dtype is None:
        compute_dtype, compute_type = data_dtype, data_type

    workspace_bytes, descriptor = _cusparse.build_coo_matvec_descriptor(
        data_dtype, x_dtype, compute_dtype, index_dtype, n_rows, n_cols, nnz,
        transpose)
    out_len = n_cols if transpose else n_rows

    i32 = ir.IntegerType.get_signless(32)
    result_tuple = ir.TupleType.get_tuple([
        ir.RankedTensorType.get([out_len], compute_type),
        ir.RankedTensorType.get([workspace_bytes],
                                ir.IntegerType.get_signless(8)),
    ])
    # Every operand and result is 1-D, so they all share layout {0}.
    vec_layout = ir.DenseIntElementsAttr.get(np.array([0]),
                                             type=ir.IndexType.get())
    call = mhlo.CustomCallOp(
        [result_tuple],
        [data, row, col, x],
        call_target_name=ir.StringAttr.get("cusparse_coo_matvec"),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(descriptor),
        api_version=ir.IntegerAttr.get(i32, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([vec_layout] * 4),
        result_layouts=ir.ArrayAttr.get([vec_layout] * 2))
    return mhlo.GetTupleElementOp(call, ir.IntegerAttr.get(i32, 0)).result
예제 #13
0
def custom_call(
    call_target_name: str,
    out_types: Sequence[ir.Type],
    operands: Sequence[ir.Value],
    operand_layouts: Sequence[Sequence[int]],
    result_layouts: Sequence[Sequence[int]],
    backend_config: Optional[str] = None,
    has_side_effect: bool = False,
    api_version: int = 2,
) -> Union[ir.Value, Sequence[ir.Value]]:
    """Build an MHLO custom call op with less boilerplate.

    With exactly one output type, the op has that single result, which is
    returned directly. With several, the op returns a tuple of them, and
    each element is extracted with ``GetTupleElementOp``.

    Once https://github.com/llvm/llvm-project/issues/54932 is fixed, this
    helper may be able to go away.

    Args:
      call_target_name: custom call target.
      out_types: result types.
      operands: operand values.
      operand_layouts: one layout (sequence of ints) per operand.
      result_layouts: one layout (sequence of ints) per result.
      backend_config: opaque config string; None is sent as ``""``.
      has_side_effect: value of the ``has_side_effect`` attribute.
      api_version: value of the i32 ``api_version`` attribute.

    Returns:
      The single result value, or a list with one value per entry of
      ``out_types``.
    """
    i32 = ir.IntegerType.get_signless(32)

    def layout_attr(layout):
        # atleast_1d turns a scalar layout into a one-element int64 vector.
        return ir.DenseIntElementsAttr.get(
            np.atleast_1d(np.asarray(layout, dtype=np.int64)),
            type=ir.IndexType.get())

    single_result = len(out_types) == 1
    if single_result:
        result_types = out_types
    else:
        result_types = [ir.TupleType.get_tuple(out_types)]
    config = "" if backend_config is None else backend_config

    op = mhlo.CustomCallOp(
        result_types,
        operands,
        call_target_name=ir.StringAttr.get(call_target_name),
        has_side_effect=ir.BoolAttr.get(has_side_effect),
        backend_config=ir.StringAttr.get(config),
        api_version=ir.IntegerAttr.get(i32, api_version),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get(
            [layout_attr(layout) for layout in operand_layouts]),
        result_layouts=ir.ArrayAttr.get(
            [layout_attr(layout) for layout in result_layouts]))
    if single_result:
        return op.result
    return [
        mhlo.GetTupleElementOp(op, ir.IntegerAttr.get(i32, index)).result
        for index in range(len(out_types))
    ]
예제 #14
0
def potrf_mhlo(dtype, a, lower):
    """Cholesky decomposition via the ``hipsolver_potrf`` custom call.

    Args:
      dtype: element dtype, forwarded as ``np.dtype`` to the descriptor.
      a: square matrices of shape ``batch_dims + (n, n)`` (asserted).
      lower: forwarded to the descriptor builder.

    Returns:
      A list of two values: result-tuple element 0 (typed like ``a``) and
      element 1 (int32 of shape ``batch_dims``). The i8 workspace at
      element 2 is discarded.
    """
    a_type = ir.RankedTensorType(a.type)
    a_shape = a_type.shape
    m, n = a_shape[-2:]
    assert m == n
    batch_dims = tuple(a_shape[:-2])
    num_bd = len(batch_dims)
    batch = _prod(batch_dims)

    lwork, opaque = _hipsolver.build_potrf_descriptor(np.dtype(dtype), lower,
                                                      batch, n)
    kernel = b"hipsolver_potrf"

    index_type = ir.IndexType.get()
    matrix_layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=index_type)
    info_layout = ir.DenseIntElementsAttr.get(
        np.array(range(num_bd - 1, -1, -1)), type=index_type)
    workspace_layout = ir.DenseIntElementsAttr.get(np.array([0]),
                                                   type=index_type)

    i32 = ir.IntegerType.get_signless(32)
    result_tuple = ir.TupleType.get_tuple([
        a.type,
        ir.RankedTensorType.get(batch_dims, i32),
        ir.RankedTensorType.get([lwork], ir.IntegerType.get_signless(8)),
    ])
    call = mhlo.CustomCallOp(
        [result_tuple],
        [a],
        call_target_name=ir.StringAttr.get(kernel),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(opaque),
        api_version=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([matrix_layout]),
        result_layouts=ir.ArrayAttr.get(
            [matrix_layout, info_layout, workspace_layout])).result
    return [
        mhlo.GetTupleElementOp(call, ir.IntegerAttr.get(i32, index)).result
        for index in (0, 1)
    ]
예제 #15
0
파일: hipsparse.py 프로젝트: John1Tang/jax
def gtsv2_mhlo(dl, d, du, B, *, m, n, ldb, t):
    """Calls `hipsparse<t>gtsv2(dl, d, du, B, m, n, ldb)`.

    ``t == np.float32`` selects the f32 kernel and buffer-size query; any
    other value selects the f64 variants.

    Args:
      dl, d, du: 1-D operand values, passed in that order.
      B: 2-D operand value.
      m, n, ldb: sizes forwarded to the buffer-size query and the descriptor.
      t: element dtype selector.

    Returns:
      Element 0 of the result tuple, an ``[ldb, n]`` float tensor. The i8
      scratch buffer at element 1 is discarded.
    """
    use_f32 = t == np.float32
    if use_f32:
        workspace_bytes = _hipsparse.gtsv2_f32_buffer_size(m, n, ldb)
        suffix = "f32"
    else:
        workspace_bytes = _hipsparse.gtsv2_f64_buffer_size(m, n, ldb)
        suffix = "f64"
    elem_type = ir.F32Type.get() if use_f32 else ir.F64Type.get()

    i32 = ir.IntegerType.get_signless(32)
    result_tuple = ir.TupleType.get_tuple([
        ir.RankedTensorType.get([ldb, n], elem_type),
        ir.RankedTensorType.get([workspace_bytes],
                                ir.IntegerType.get_signless(8)),
    ])
    vec_layout = ir.DenseIntElementsAttr.get(np.array([0]),
                                             type=ir.IndexType.get())
    mat_layout = ir.DenseIntElementsAttr.get(np.array([1, 0]),
                                             type=ir.IndexType.get())
    call = mhlo.CustomCallOp(
        [result_tuple],
        [dl, d, du, B],
        call_target_name=ir.StringAttr.get("hipsparse_gtsv2_" + suffix),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(
            _hipsparse.build_gtsv2_descriptor(m, n, ldb)),
        api_version=ir.IntegerAttr.get(i32, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get(
            [vec_layout, vec_layout, vec_layout, mat_layout]),
        result_layouts=ir.ArrayAttr.get([mat_layout, vec_layout]))
    return mhlo.GetTupleElementOp(call, ir.IntegerAttr.get(i32, 0)).result
예제 #16
0
파일: hipsparse.py 프로젝트: John1Tang/jax
def coo_fromdense_mhlo(mat, *, nnz, data_dtype, index_dtype, index_type):
    """Convert a dense 2-D matrix to COO form with hipSPARSE.

    Args:
      mat: 2-D dense matrix value.
      nnz: number of nonzeros; sets the length of each COO output.
      data_dtype, index_dtype: dtypes forwarded to the descriptor builder.
      index_type: MLIR element type of the two index outputs.

    Returns:
      A list of the first three result-tuple elements: a ``[nnz]`` tensor of
      ``mat``'s element type, then two ``[nnz]`` tensors of ``index_type``
      (presumably row and column indices). The i8 scratch buffer at element
      3 is discarded.
    """
    mat_type = ir.RankedTensorType(mat.type)
    n_rows, n_cols = mat_type.shape

    workspace_bytes, descriptor = _hipsparse.build_coo_fromdense_descriptor(
        data_dtype, index_dtype, n_rows, n_cols, nnz)

    i32 = ir.IntegerType.get_signless(32)
    result_tuple = ir.TupleType.get_tuple([
        ir.RankedTensorType.get([nnz], mat_type.element_type),
        ir.RankedTensorType.get([nnz], index_type),
        ir.RankedTensorType.get([nnz], index_type),
        ir.RankedTensorType.get([workspace_bytes],
                                ir.IntegerType.get_signless(8)),
    ])
    mat_layout = ir.DenseIntElementsAttr.get(np.array([1, 0]),
                                             type=ir.IndexType.get())
    vec_layout = ir.DenseIntElementsAttr.get(np.array([0]),
                                             type=ir.IndexType.get())
    call = mhlo.CustomCallOp(
        [result_tuple],
        [mat],
        call_target_name=ir.StringAttr.get("hipsparse_coo_fromdense"),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(descriptor),
        api_version=ir.IntegerAttr.get(i32, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([mat_layout]),
        result_layouts=ir.ArrayAttr.get([vec_layout] * 4))
    outputs = []
    for index in range(3):
        outputs.append(
            mhlo.GetTupleElementOp(call, ir.IntegerAttr.get(i32, index)).result)
    return outputs
예제 #17
0
파일: lapack.py 프로젝트: frederikwilde/jax
def syevd_mhlo(dtype, a, lower=False):
    """Lower a symmetric/Hermitian eigendecomposition to a LAPACK custom call.

    Emits a call to ``lapack_{s,d}syevd`` (real dtypes) or
    ``lapack_{c,z}heevd`` (complex dtypes) over the trailing two dimensions of
    ``a``, which must be square; any leading dimensions form a batch.

    Args:
      dtype: NumPy dtype of ``a``; one of float32, float64, complex64,
        complex128.
      a: MLIR value of shape ``batch_dims + (n, n)``.
      lower: forwarded as the first int32 scalar operand (1 if true, else 0);
        presumably selects which triangle of ``a`` the kernel reads.

    Returns:
      The first three elements of the custom call's result tuple: a value of
      ``a``'s type, a real ``batch_dims + (n,)`` vector (presumably the
      eigenvalues), and a ``batch_dims`` int32 array (presumably LAPACK info).

    Raises:
      NotImplementedError: if ``dtype`` is not one of the supported types.
    """
    a_type = ir.RankedTensorType(a.type)
    dims = a_type.shape
    assert len(dims) >= 2
    m, n = dims[-2:]
    assert m == n
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    b = 1
    for d in batch_dims:
        b *= d

    i32_type = ir.IntegerType.get_signless(32)
    # Eigenvalues are real even for complex (Hermitian) inputs.
    if dtype == np.float32:
        fn, eigvals_type, is_complex = b"lapack_ssyevd", ir.F32Type.get(), False
    elif dtype == np.float64:
        fn, eigvals_type, is_complex = b"lapack_dsyevd", ir.F64Type.get(), False
    elif dtype == np.complex64:
        fn, eigvals_type, is_complex = b"lapack_cheevd", ir.F32Type.get(), True
    elif dtype == np.complex128:
        fn, eigvals_type, is_complex = b"lapack_zheevd", ir.F64Type.get(), True
    else:
        raise NotImplementedError("Unsupported dtype {}".format(dtype))

    # Workspace buffers appended to the result tuple after the three outputs.
    # Complex inputs additionally need a real-typed rwork buffer.
    if is_complex:
        workspace = [
            ir.RankedTensorType.get([_lapack.heevd_work_size(n)],
                                    a_type.element_type),
            ir.RankedTensorType.get([_lapack.heevd_rwork_size(n)],
                                    eigvals_type),
            ir.RankedTensorType.get([_lapack.syevd_iwork_size(n)], i32_type),
        ]
    else:
        workspace = [
            ir.RankedTensorType.get([_lapack.syevd_work_size(n)],
                                    a_type.element_type),
            ir.RankedTensorType.get([_lapack.syevd_iwork_size(n)], i32_type),
        ]
    workspace_layouts = [
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get())
    ] * len(workspace)

    scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0, ), np.int64),
                                                type=ir.IndexType.get())
    # Matrices column-major (rows minor), batch dimensions outermost.
    layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple([
                a.type,
                ir.RankedTensorType.get(batch_dims + (n, ), eigvals_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
            ] + workspace)
        ], [_mhlo_s32(1 if lower else 0),
            _mhlo_s32(b),
            _mhlo_s32(n), a],
        call_target_name=ir.StringAttr.get(fn),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(""),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([scalar_layout] * 3 + [layout]),
        result_layouts=ir.ArrayAttr.get([
            layout,
            ir.DenseIntElementsAttr.get(np.array(range(num_bd, -1, -1)),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.array(range(num_bd - 1, -1, -1)),
                                        type=ir.IndexType.get()),
        ] + workspace_layouts))
    return [
        mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
        for i in range(3)
    ]
예제 #18
0
파일: lapack.py 프로젝트: frederikwilde/jax
def gees_mhlo(a, jobvs=True, sort=False, select=None):
    """Lower LAPACK ``?gees`` (Schur decomposition) to an MHLO custom call.

    Args:
      a: MLIR value of shape ``batch_dims + (n, n)``; the trailing two
        dimensions must be square and any leading dimensions form a batch.
      jobvs: if true, passes 'V' as the jobvs flag, otherwise 'N'.
      sort: must be false; sorting is not implemented and raises.
      select: currently unused (see the TODO in the operand list).

    Returns:
      A tuple of MLIR values taken from the custom call's result tuple:
      elements (0, 3, 5) normally, or (0, 3, 4, 5) when sorting (that branch
      is currently unreachable because ``sort=True`` raises).

    Raises:
      NotImplementedError: if ``sort`` is true.
    """
    a_type = ir.RankedTensorType(a.type)
    etype = a_type.element_type
    dims = a_type.shape
    assert len(dims) >= 2
    m, n = dims[-2:]
    assert m == n
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    b = 1
    for d in batch_dims:
        b *= d
    # Minor-to-major (num_bd, num_bd + 1, batch dims...): each matrix is
    # stored column-major, with the batch dimensions outermost.
    layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())

    if sort:
        raise NotImplementedError(
            "The sort feature of LAPACK's gees routine is not implemented.")

    # Encode the LAPACK character flags as ASCII codes (passed as u8 below).
    # NOTE: since sort=True raised above, `sort` is always ord('N') from here.
    jobvs = ord('V' if jobvs else 'N')
    sort = ord('S' if sort else 'N')

    if not ir.ComplexType.isinstance(etype):
        # Real case; any real element type other than F32 selects dgees.
        fn = "lapack_sgees" if etype == ir.F32Type.get() else "lapack_dgees"
        schurvecs_type = etype
        workspaces = [ir.RankedTensorType.get(dims, schurvecs_type)]
        workspace_layouts = [layout]
        # Two eigenvalue vectors (presumably real and imaginary parts).
        eigvals = [ir.RankedTensorType.get(batch_dims + (n, ), etype)] * 2
        eigvals_layouts = [
            ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                        type=ir.IndexType.get())
        ] * 2
    else:
        # Complex case; any complex type other than complex<f32> selects
        # zgees. One complex eigenvalue vector, plus an extra length-n
        # real-typed scratch buffer in the workspaces.
        fn = ("lapack_cgees" if etype == ir.ComplexType.get(ir.F32Type.get())
              else "lapack_zgees")
        schurvecs_type = etype
        workspaces = [
            ir.RankedTensorType.get(dims, schurvecs_type),
            ir.RankedTensorType.get([n],
                                    ir.ComplexType(etype).element_type),
        ]
        workspace_layouts = [
            layout,
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]
        eigvals = [ir.RankedTensorType.get(batch_dims + (n, ), etype)]
        eigvals_layouts = [
            ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                        type=ir.IndexType.get())
        ]

    i32_type = ir.IntegerType.get_signless(32)

    scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0, ), np.int64),
                                                type=ir.IndexType.get())
    # Result tuple (6 elements in both cases):
    #   real:    0 dims-shaped buffer, 1-2 eigenvalue vectors,
    #   complex: 0 dims-shaped buffer, 1 scratch, 2 eigenvalue vector,
    #   then 3 Schur vectors (dims-shaped) and 4-5 batch-shaped int32 arrays
    #   (presumably sdim and LAPACK info — confirm against the kernel).
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple(workspaces + eigvals + [
                ir.RankedTensorType.get(dims, schurvecs_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
            ])
        ],
        [
            _mhlo_s32(b),
            _mhlo_s32(n),
            _mhlo_u8(np.uint8(jobvs)),
            _mhlo_u8(np.uint8(sort)),
            # TODO: figure out how to put the callable select function here
            a
        ],
        call_target_name=ir.StringAttr.get(fn),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(""),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([scalar_layout] * 4 + [layout]),
        result_layouts=ir.ArrayAttr.get(workspace_layouts + eigvals_layouts + [
            layout,
            ir.DenseIntElementsAttr.get(np.arange(num_bd - 1, -1, -1),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.arange(num_bd - 1, -1, -1),
                                        type=ir.IndexType.get()),
        ]))
    i32_attr = lambda i: ir.IntegerAttr.get(i32_type, i)
    # Element 0 is the dims-shaped first buffer (presumably the Schur form
    # written back by the kernel — confirm). The sort branch below is
    # currently unreachable; it additionally returns element 4.
    if sort == ord('S'):
        return (mhlo.GetTupleElementOp(out, i32_attr(0)).result,
                mhlo.GetTupleElementOp(out, i32_attr(3)).result,
                mhlo.GetTupleElementOp(out, i32_attr(4)).result,
                mhlo.GetTupleElementOp(out, i32_attr(5)).result)
    else:
        return (mhlo.GetTupleElementOp(out, i32_attr(0)).result,
                mhlo.GetTupleElementOp(out, i32_attr(3)).result,
                mhlo.GetTupleElementOp(out, i32_attr(5)).result)
예제 #19
0
파일: lapack.py 프로젝트: frederikwilde/jax
def trsm_mhlo(dtype,
              alpha,
              a,
              b,
              left_side=False,
              lower=False,
              trans_a=False,
              conj_a=False,
              diag=False):
    """Lower a batched triangular solve to a BLAS ``?trsm`` custom call.

    Args:
      dtype: NumPy dtype of the operands (float32, float64, complex64 or
        complex128).
      alpha: scalar MLIR value forwarded to the kernel.
      a: triangular matrix of shape ``batch_dims + (k, k)``, where ``k`` is
        ``m`` when ``left_side`` is true and ``n`` otherwise.
      b: right-hand side of shape ``batch_dims + (m, n)``.
      left_side, lower, diag: forwarded as 0/1 int32 flags.
      trans_a, conj_a: encoded as 0 (no transpose), 1 (transpose) or
        2 (conjugate transpose).

    Returns:
      The custom call's single result, typed like ``b``.

    Raises:
      ValueError: if ``a`` and ``b`` disagree in shape or element type.
      NotImplementedError: for an unsupported dtype, or for ``conj_a``
        without ``trans_a``.
    """
    lhs_type = ir.RankedTensorType(a.type)
    rhs_type = ir.RankedTensorType(b.type)

    rhs_dims = rhs_type.shape
    m, n = rhs_dims[-2:]
    k = m if left_side else n

    batch_dims = tuple(rhs_dims[:-2])
    num_bd = len(batch_dims)
    num_b = 1
    for extent in batch_dims:
        num_b *= extent

    shapes_match = tuple(lhs_type.shape) == batch_dims + (k, k)
    if not shapes_match or lhs_type.element_type != rhs_type.element_type:
        raise ValueError("Argument mismatch for trsm, got {} and {}".format(
            lhs_type, rhs_type))

    for candidate, target in ((np.float32, "blas_strsm"),
                              (np.float64, "blas_dtrsm"),
                              (np.complex64, "blas_ctrsm"),
                              (np.complex128, "blas_ztrsm")):
        if dtype == candidate:
            fn = target
            break
    else:
        raise NotImplementedError("Unsupported dtype {}".format(dtype))

    if conj_a and not trans_a:
        raise NotImplementedError(
            "Conjugation without transposition not supported")

    if not trans_a:
        trans_code = 0
    elif conj_a:
        trans_code = 2
    else:
        trans_code = 1

    index_type = ir.IndexType.get()
    scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0, ), np.int64),
                                                type=index_type)
    # Matrices column-major (rows minor), batch dimensions outermost.
    matrix_layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=index_type)

    # Seven int32 flags/sizes, then alpha (also scalar), then the matrices.
    operands = [
        _mhlo_s32(int(left_side)),
        _mhlo_s32(int(lower)),
        _mhlo_s32(trans_code),
        _mhlo_s32(int(diag)),
        _mhlo_s32(m),
        _mhlo_s32(n),
        _mhlo_s32(num_b),
        alpha,
        a,
        b,
    ]
    call = mhlo.CustomCallOp(
        [b.type],
        operands,
        call_target_name=ir.StringAttr.get(fn),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(""),
        api_version=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([scalar_layout] * 8 +
                                         [matrix_layout] * 2),
        result_layouts=ir.ArrayAttr.get([matrix_layout]))
    return call.result
예제 #20
0
파일: lapack.py 프로젝트: frederikwilde/jax
def gesdd_mhlo(dtype, a, full_matrices=True, compute_uv=True):
    """Lower a batched SVD to a LAPACK ``?gesdd`` custom call.

    Args:
      dtype: NumPy dtype of ``a`` (float32, float64, complex64, complex128).
      a: MLIR value of shape ``batch_dims + (m, n)``; leading dimensions form
        a batch.
      full_matrices: forwarded as an int32 flag; when true the singular-vector
        outputs are ``(m, m)`` and ``(n, n)``, otherwise ``min(m, n)`` wide.
      compute_uv: forwarded as an int32 flag and to the workspace-size
        queries.

    Returns:
      Elements 1-4 of the custom call's result tuple: the singular values,
      the two singular-vector arrays, and a ``batch_dims`` int32 array
      (presumably LAPACK info).

    Raises:
      NotImplementedError: for an unsupported ``dtype``.
    """
    a_type = ir.RankedTensorType(a.type)
    elem_type = a_type.element_type
    dims = a_type.shape
    assert len(dims) >= 2
    m, n = dims[-2:]
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    b = 1
    for extent in batch_dims:
        b *= extent

    i32_type = ir.IntegerType.get_signless(32)
    # Singular values are real even for complex inputs.
    if dtype == np.float32:
        fn = b"lapack_sgesdd"
        singular_vals_type = ir.F32Type.get()
        lwork = _lapack.sgesdd_work_size(m, n, compute_uv, full_matrices)
        is_complex = False
    elif dtype == np.float64:
        fn = b"lapack_dgesdd"
        singular_vals_type = ir.F64Type.get()
        lwork = _lapack.dgesdd_work_size(m, n, compute_uv, full_matrices)
        is_complex = False
    elif dtype == np.complex64:
        fn = b"lapack_cgesdd"
        singular_vals_type = ir.F32Type.get()
        lwork = _lapack.cgesdd_work_size(m, n, compute_uv, full_matrices)
        is_complex = True
    elif dtype == np.complex128:
        fn = b"lapack_zgesdd"
        singular_vals_type = ir.F64Type.get()
        lwork = _lapack.zgesdd_work_size(m, n, compute_uv, full_matrices)
        is_complex = True
    else:
        raise NotImplementedError("Unsupported dtype {}".format(dtype))

    # Workspace appended after the five outputs: int32 iwork, then (complex
    # only) a real rwork buffer, then the element-typed work array.
    workspace = [
        ir.RankedTensorType.get([_lapack.gesdd_iwork_size(m, n)], i32_type)
    ]
    if is_complex:
        workspace.append(
            ir.RankedTensorType.get(
                [_lapack.cgesdd_rwork_size(m, n, int(compute_uv))],
                singular_vals_type))
    workspace.append(ir.RankedTensorType.get([lwork], elem_type))

    index_type = ir.IndexType.get()
    workspace_layouts = [
        ir.DenseIntElementsAttr.get(np.array([0]), type=index_type)
    ] * len(workspace)
    batch_minor_to_major = tuple(range(num_bd - 1, -1, -1))
    scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0, ), np.int64),
                                                type=index_type)
    # Matrices column-major (rows minor), batch dimensions outermost.
    matrix_layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + batch_minor_to_major),
        type=index_type)
    vector_layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, ) + batch_minor_to_major), type=index_type)
    info_layout = ir.DenseIntElementsAttr.get(
        np.array(range(num_bd - 1, -1, -1)), type=index_type)

    k = min(m, n)
    u_cols = m if full_matrices else k
    vt_rows = n if full_matrices else k
    result_types = [
        a.type,
        ir.RankedTensorType.get(batch_dims + (k, ), singular_vals_type),
        ir.RankedTensorType.get(batch_dims + (m, u_cols), elem_type),
        ir.RankedTensorType.get(batch_dims + (vt_rows, n), elem_type),
        ir.RankedTensorType.get(batch_dims, i32_type),
    ] + workspace
    operands = [
        _mhlo_s32(int(full_matrices)),
        _mhlo_s32(int(compute_uv)),
        _mhlo_s32(b),
        _mhlo_s32(m),
        _mhlo_s32(n),
        _mhlo_s32(lwork),
        a,
    ]
    out = mhlo.CustomCallOp(
        [ir.TupleType.get_tuple(result_types)],
        operands,
        call_target_name=ir.StringAttr.get(fn),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(""),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([scalar_layout] * 6 +
                                         [matrix_layout]),
        result_layouts=ir.ArrayAttr.get([
            matrix_layout,
            vector_layout,
            matrix_layout,
            matrix_layout,
            info_layout,
        ] + workspace_layouts))
    return [
        mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, idx)).result
        for idx in (1, 2, 3, 4)
    ]
예제 #21
0
파일: cusolver.py 프로젝트: John1Tang/jax
def gesvd_mhlo(dtype, a, full_matrices=True, compute_uv=True):
    """Singular value decomposition.

    Lowers to ``cusolver_gesvdj`` when both matrix dimensions are below 32 and
    to ``cusolver_gesvd`` otherwise.

    Args:
      dtype: NumPy dtype of ``a``.
      a: MLIR value of shape ``batch_dims + (m, n)``; leading dimensions form
        a batch.
      full_matrices: if false, ``u`` has ``min(m, n)`` columns and ``vt`` has
        ``min(m, n)`` rows.
      compute_uv: forwarded to the descriptor builders.

    Returns:
      ``(s, u, vt, info)`` as MLIR values; ``info`` is a ``batch_dims`` int32
      array.
    """
    a_type = ir.RankedTensorType(a.type)
    dims = a_type.shape
    assert len(dims) >= 2
    m, n = dims[-2:]
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    b = _prod(batch_dims)
    # Singular values are real even for complex inputs.
    if ir.ComplexType.isinstance(a_type.element_type):
        singular_vals_type = ir.ComplexType(a_type.element_type).element_type
    else:
        singular_vals_type = a_type.element_type

    # Layout of the batch-shaped info output.
    scalar_layout = ir.DenseIntElementsAttr.get(np.array(
        tuple(range(num_bd - 1, -1, -1))),
                                                type=ir.IndexType.get())
    vector_layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, ) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())
    i32_type = ir.IntegerType.get_signless(32)

    if m < 32 and n < 32:
        # The batched kernel doesn't support "econ" mode.
        econ = not full_matrices and b == 1
        lwork, opaque = _cusolver.build_gesvdj_descriptor(
            np.dtype(dtype), b, m, n, compute_uv, 1 if econ else 0)
        k = min(m, n)
        # Matrices column-major (rows minor), batch dimensions outermost.
        matrix_layout = ir.DenseIntElementsAttr.get(
            np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
            type=ir.IndexType.get())
        out = mhlo.CustomCallOp(
            [
                ir.TupleType.get_tuple([
                    a.type,
                    ir.RankedTensorType.get(batch_dims +
                                            (min(m, n), ), singular_vals_type),
                    ir.RankedTensorType.get(batch_dims + (m, k if econ else m),
                                            a_type.element_type),
                    ir.RankedTensorType.get(batch_dims + (n, k if econ else n),
                                            a_type.element_type),
                    ir.RankedTensorType.get(batch_dims, i32_type),
                    ir.RankedTensorType.get([lwork], a_type.element_type),
                ])
            ], [a],
            call_target_name=ir.StringAttr.get("cusolver_gesvdj"),
            has_side_effect=ir.BoolAttr.get(False),
            backend_config=ir.StringAttr.get(opaque),
            api_version=ir.IntegerAttr.get(i32_type, 2),
            called_computations=ir.ArrayAttr.get([]),
            operand_layouts=ir.ArrayAttr.get([matrix_layout]),
            result_layouts=ir.ArrayAttr.get([
                matrix_layout,
                vector_layout,
                matrix_layout,
                matrix_layout,
                scalar_layout,
                ir.DenseIntElementsAttr.get(np.array([0]),
                                            type=ir.IndexType.get()),
            ]))
        s = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 1)).result
        u = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 2)).result
        # Extract as a Value, like every other tuple element here.
        v = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 3)).result
        info = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type,
                                                              4)).result
        # gesvdj yields V; swap the two matrix dimensions to get V^T.
        vt = mhlo.TransposeOp(
            v,
            ir.DenseIntElementsAttr.get(
                np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd)))).result
        # For complex inputs the adjoint is needed, so conjugate as well.
        if np.issubdtype(dtype, np.complexfloating):
            vt = mhlo.ComplexOp(mhlo.RealOp(vt),
                                mhlo.NegOp(mhlo.ImagOp(vt))).result
        # Without econ mode the kernel returned square U and V; slice them
        # down to the reduced shapes.
        if not full_matrices and not econ:
            u = mhlo.SliceOp(
                u, ir.DenseIntElementsAttr.get(np.zeros([len(dims)],
                                                        np.int64)),
                ir.DenseIntElementsAttr.get(
                    np.array(batch_dims + (m, min(m, n)))),
                ir.DenseIntElementsAttr.get(np.ones([len(dims)],
                                                    np.int64))).result
            vt = mhlo.SliceOp(
                vt, ir.DenseIntElementsAttr.get(np.zeros([len(dims)],
                                                         np.int64)),
                ir.DenseIntElementsAttr.get(
                    np.array(batch_dims + (min(m, n), n))),
                ir.DenseIntElementsAttr.get(np.ones([len(dims)],
                                                    np.int64))).result
    elif m < n:
        # The descriptor is built for the (n, m) transpose. Presenting the
        # operand and matrix results row-major (columns minor) lets the
        # column-major kernel read A^T; factoring A^T = U' S V'^T gives
        # A = V' S U'^T, so tuple element 2 is read back as vt and element 3
        # as u.
        lwork, opaque = _cusolver.build_gesvd_descriptor(
            np.dtype(dtype), b, n, m, compute_uv, full_matrices)
        k = n if full_matrices else m
        matrix_layout = ir.DenseIntElementsAttr.get(
            np.array((num_bd + 1, num_bd) + tuple(range(num_bd - 1, -1, -1))),
            type=ir.IndexType.get())
        out = mhlo.CustomCallOp(
            [
                ir.TupleType.get_tuple([
                    a.type,
                    ir.RankedTensorType.get(batch_dims +
                                            (min(m, n), ), singular_vals_type),
                    ir.RankedTensorType.get(batch_dims +
                                            (k, n), a_type.element_type),
                    ir.RankedTensorType.get(batch_dims +
                                            (m, m), a_type.element_type),
                    ir.RankedTensorType.get(batch_dims, i32_type),
                    ir.RankedTensorType.get([lwork], a_type.element_type),
                ])
            ], [a],
            call_target_name=ir.StringAttr.get("cusolver_gesvd"),
            has_side_effect=ir.BoolAttr.get(False),
            backend_config=ir.StringAttr.get(opaque),
            api_version=ir.IntegerAttr.get(i32_type, 2),
            called_computations=ir.ArrayAttr.get([]),
            operand_layouts=ir.ArrayAttr.get([matrix_layout]),
            result_layouts=ir.ArrayAttr.get([
                matrix_layout,
                vector_layout,
                matrix_layout,
                matrix_layout,
                scalar_layout,
                ir.DenseIntElementsAttr.get(np.array([0]),
                                            type=ir.IndexType.get()),
            ]))
        s = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 1)).result
        vt = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type,
                                                            2)).result
        u = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 3)).result
        info = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type,
                                                              4)).result
    else:
        lwork, opaque = _cusolver.build_gesvd_descriptor(
            np.dtype(dtype), b, m, n, compute_uv, full_matrices)
        k = m if full_matrices else n
        # Matrices column-major (rows minor), batch dimensions outermost.
        matrix_layout = ir.DenseIntElementsAttr.get(
            np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
            type=ir.IndexType.get())
        out = mhlo.CustomCallOp(
            [
                ir.TupleType.get_tuple([
                    a.type,
                    ir.RankedTensorType.get(batch_dims +
                                            (min(m, n), ), singular_vals_type),
                    ir.RankedTensorType.get(batch_dims +
                                            (m, k), a_type.element_type),
                    ir.RankedTensorType.get(batch_dims +
                                            (n, n), a_type.element_type),
                    ir.RankedTensorType.get(batch_dims, i32_type),
                    ir.RankedTensorType.get([lwork], a_type.element_type),
                ])
            ], [a],
            call_target_name=ir.StringAttr.get("cusolver_gesvd"),
            has_side_effect=ir.BoolAttr.get(False),
            backend_config=ir.StringAttr.get(opaque),
            api_version=ir.IntegerAttr.get(i32_type, 2),
            called_computations=ir.ArrayAttr.get([]),
            operand_layouts=ir.ArrayAttr.get([matrix_layout]),
            result_layouts=ir.ArrayAttr.get([
                matrix_layout,
                vector_layout,
                matrix_layout,
                matrix_layout,
                scalar_layout,
                ir.DenseIntElementsAttr.get(np.array([0]),
                                            type=ir.IndexType.get()),
            ]))
        s = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 1)).result
        u = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 2)).result
        vt = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type,
                                                            3)).result
        info = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type,
                                                              4)).result
    return s, u, vt, info
예제 #22
0
def gesvd_mhlo(dtype, a, full_matrices=True, compute_uv=True):
    """Singular value decomposition via the ``hipsolver_gesvd`` custom call.

    Args:
      dtype: NumPy dtype of ``a``.
      a: MLIR value of shape ``batch_dims + (m, n)``; leading dimensions form
        a batch.
      full_matrices: if false, ``u`` has ``min(m, n)`` columns and ``vt`` has
        ``min(m, n)`` rows.
      compute_uv: forwarded to the descriptor builder.

    Returns:
      ``(s, u, vt, info)`` as MLIR values; ``info`` is a ``batch_dims`` int32
      array.
    """
    a_type = ir.RankedTensorType(a.type)
    dims = a_type.shape
    assert len(dims) >= 2
    m, n = dims[-2:]
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    b = _prod(batch_dims)
    # Singular values are real even for complex inputs.
    if ir.ComplexType.isinstance(a_type.element_type):
        singular_vals_type = ir.ComplexType(a_type.element_type).element_type
    else:
        singular_vals_type = a_type.element_type

    # Layout of the batch-shaped info output.
    scalar_layout = ir.DenseIntElementsAttr.get(np.array(
        tuple(range(num_bd - 1, -1, -1))),
                                                type=ir.IndexType.get())
    vector_layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, ) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())
    # Matrices column-major (rows minor), batch dimensions outermost.
    col_major_layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())
    # Matrices row-major (columns minor), batch dimensions outermost.
    row_major_layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd + 1, num_bd) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())
    i32_type = ir.IntegerType.get_signless(32)

    if m < n:
        # The descriptor is built for the (n, m) transpose, so the kernel
        # must see A^T. Presenting the operand and matrix results row-major
        # lets the column-major kernel read A^T; factoring A^T = U' S V'^T
        # gives A = V' S U'^T, so tuple element 2 is read back as vt and
        # element 3 as u. (A column-major layout here would hand the kernel
        # a reshaped, not transposed, buffer.) This matches the cusolver
        # lowering.
        matrix_layout = row_major_layout
        lwork, opaque = _hipsolver.build_gesvd_descriptor(
            np.dtype(dtype), b, n, m, compute_uv, full_matrices)
        k = n if full_matrices else m
        out = mhlo.CustomCallOp(
            [
                ir.TupleType.get_tuple([
                    a.type,
                    ir.RankedTensorType.get(batch_dims +
                                            (min(m, n), ), singular_vals_type),
                    ir.RankedTensorType.get(batch_dims +
                                            (k, n), a_type.element_type),
                    ir.RankedTensorType.get(batch_dims +
                                            (m, m), a_type.element_type),
                    ir.RankedTensorType.get(batch_dims, i32_type),
                    ir.RankedTensorType.get([lwork], a_type.element_type),
                ])
            ], [a],
            call_target_name=ir.StringAttr.get("hipsolver_gesvd"),
            has_side_effect=ir.BoolAttr.get(False),
            backend_config=ir.StringAttr.get(opaque),
            api_version=ir.IntegerAttr.get(i32_type, 2),
            called_computations=ir.ArrayAttr.get([]),
            operand_layouts=ir.ArrayAttr.get([matrix_layout]),
            result_layouts=ir.ArrayAttr.get([
                matrix_layout,
                vector_layout,
                matrix_layout,
                matrix_layout,
                scalar_layout,
                ir.DenseIntElementsAttr.get(np.array([0]),
                                            type=ir.IndexType.get()),
            ]))
        s = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 1)).result
        vt = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type,
                                                            2)).result
        u = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 3)).result
        info = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type,
                                                              4)).result
    else:
        matrix_layout = col_major_layout
        lwork, opaque = _hipsolver.build_gesvd_descriptor(
            np.dtype(dtype), b, m, n, compute_uv, full_matrices)
        k = m if full_matrices else n
        out = mhlo.CustomCallOp(
            [
                ir.TupleType.get_tuple([
                    a.type,
                    ir.RankedTensorType.get(batch_dims +
                                            (min(m, n), ), singular_vals_type),
                    ir.RankedTensorType.get(batch_dims +
                                            (m, k), a_type.element_type),
                    ir.RankedTensorType.get(batch_dims +
                                            (n, n), a_type.element_type),
                    ir.RankedTensorType.get(batch_dims, i32_type),
                    ir.RankedTensorType.get([lwork], a_type.element_type),
                ])
            ], [a],
            call_target_name=ir.StringAttr.get("hipsolver_gesvd"),
            has_side_effect=ir.BoolAttr.get(False),
            backend_config=ir.StringAttr.get(opaque),
            api_version=ir.IntegerAttr.get(i32_type, 2),
            called_computations=ir.ArrayAttr.get([]),
            operand_layouts=ir.ArrayAttr.get([matrix_layout]),
            result_layouts=ir.ArrayAttr.get([
                matrix_layout,
                vector_layout,
                matrix_layout,
                matrix_layout,
                scalar_layout,
                ir.DenseIntElementsAttr.get(np.array([0]),
                                            type=ir.IndexType.get()),
            ]))
        s = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 1)).result
        u = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 2)).result
        vt = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type,
                                                            3)).result
        info = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type,
                                                              4)).result
    return s, u, vt, info
예제 #23
0
파일: lapack.py 프로젝트: frederikwilde/jax
def geev_mhlo(dtype, a, jobvl=True, jobvr=True):
    """Lower a (batched) general eigendecomposition to a LAPACK ``?geev`` call.

    Emits an MHLO custom call targeting ``lapack_{s,d,c,z}geev`` on the square
    matrix, or batch of square matrices, ``a``.

    Args:
      dtype: NumPy dtype of ``a``; one of float32, float64, complex64 or
        complex128.
      a: MLIR value of ranked tensor type whose trailing two dimensions are
        equal (``batch_dims + (n, n)``).
      jobvl: if true, pass ``'V'`` as the JOBVL flag (compute left
        eigenvectors, per the LAPACK convention), otherwise ``'N'``.
      jobvr: same as ``jobvl``, for the JOBVR flag (right eigenvectors).

    Returns:
      A tuple ``(w, vl, vr, info)`` of MLIR values:
      - ``w`` holds complex eigenvalues of shape ``batch_dims + (n,)``.
      - ``vl`` and ``vr`` are complex tensors with the same shape as ``a``.
      - ``info`` is an int32 tensor of shape ``batch_dims``.

    Raises:
      NotImplementedError: if ``dtype`` is not one of the supported types.
    """
    dims = ir.RankedTensorType(a.type).shape
    assert len(dims) >= 2
    m, n = dims[-2:]
    assert m == n
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    # Total number of matrices in the batch; passed to the kernel as a scalar.
    b = 1
    for d in batch_dims:
        b *= d
    index_type = ir.IndexType.get()
    # Minor-to-major layout for `a` and the eigenvector outputs.
    # - The row axis (num_bd) is most minor, so each matrix is column-major.
    # - The batch axes are major to the matrix axes.
    layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=index_type)

    jobvl_c = ord('V' if jobvl else 'N')
    jobvr_c = ord('V' if jobvr else 'N')

    # Only the kernel name and the real component type depend on the exact
    # dtype. Everything else depends only on real vs complex input.
    if dtype == np.float32:
        fn, real, component_type = b"lapack_sgeev", True, ir.F32Type.get()
    elif dtype == np.float64:
        fn, real, component_type = b"lapack_dgeev", True, ir.F64Type.get()
    elif dtype == np.complex64:
        fn, real, component_type = b"lapack_cgeev", False, ir.F32Type.get()
    elif dtype == np.complex128:
        fn, real, component_type = b"lapack_zgeev", False, ir.F64Type.get()
    else:
        raise NotImplementedError("Unsupported dtype {}".format(dtype))

    # Eigenvectors are always complex, even for real input.
    complex_type = ir.ComplexType.get(component_type)
    square_ws_layout = ir.DenseIntElementsAttr.get(np.array([0, 1]),
                                                   type=index_type)
    eigvals_shape = batch_dims + (n, )
    eigvals_layout = ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                                 type=index_type)
    if real:
        # Three n x n workspace outputs (not returned).
        # Eigenvalues come back as two real tensors (real part, imaginary
        # part), which are combined below.
        workspaces = [ir.RankedTensorType.get([n, n], component_type)] * 3
        workspace_layouts = [square_ws_layout] * 3
        eigvals = [ir.RankedTensorType.get(eigvals_shape, component_type)] * 2
        eigvals_layouts = [eigvals_layout] * 2
    else:
        # One complex n x n workspace plus a real workspace of length 2n
        # (neither is returned). Eigenvalues come back as one complex tensor.
        workspaces = [
            ir.RankedTensorType.get([n, n], complex_type),
            ir.RankedTensorType.get([2 * n], component_type),
        ]
        workspace_layouts = [
            square_ws_layout,
            ir.DenseIntElementsAttr.get(np.array([0]), type=index_type),
        ]
        eigvals = [ir.RankedTensorType.get(eigvals_shape, complex_type)]
        eigvals_layouts = [eigvals_layout]

    i32_type = ir.IntegerType.get_signless(32)
    scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0, ), np.int64),
                                                type=index_type)
    info_layout = ir.DenseIntElementsAttr.get(np.arange(num_bd - 1, -1, -1),
                                              type=index_type)
    # Result tuple layout: workspaces..., eigvals..., vl, vr, info.
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple(workspaces + eigvals + [
                ir.RankedTensorType.get(dims, complex_type),
                ir.RankedTensorType.get(dims, complex_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
            ])
        ],
        [_mhlo_s32(b),
         _mhlo_s32(n),
         _mhlo_u8(jobvl_c),
         _mhlo_u8(jobvr_c), a],
        call_target_name=ir.StringAttr.get(fn),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(""),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([scalar_layout] * 4 + [layout]),
        result_layouts=ir.ArrayAttr.get(workspace_layouts + eigvals_layouts +
                                        [layout] * 2 + [info_layout])).result

    def _element(i):
        """Extract element ``i`` of the custom call's result tuple."""
        return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type,
                                                              i)).result

    # Derive indices from the list lengths so they cannot drift out of sync
    # with the workspace/eigenvalue lists built above.
    eig_start = len(workspaces)
    vl_index = eig_start + len(eigvals)
    if real:
        w = mhlo.ComplexOp(_element(eig_start),
                           _element(eig_start + 1)).result
    else:
        w = _element(eig_start)
    return (w, _element(vl_index), _element(vl_index + 1),
            _element(vl_index + 2))