Example 1
def potrf_mhlo(dtype, a, lower=False):
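    """Cholesky decomposition."""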
    a_type = ir.RankedTensorType(a.type)
    dims = a_type.shape
    m, n = dims[-2:]
    if m != n:
        raise ValueError(f"potrf expects a square matrix, got {a_type}")
    if dtype == np.float32:
        fn = b"lapack_spotrf"
    elif dtype == np.float64:
        fn = b"lapack_dpotrf"
    elif dtype == np.complex64:
        fn = b"lapack_cpotrf"
    elif dtype == np.complex128:
        fn = b"lapack_zpotrf"
    else:
        raise NotImplementedError(f"Unsupported dtype {dtype}")
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    b = 1
    for d in batch_dims:
        b *= d

    scalar_layout = []
    layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
    info_layout = tuple(range(num_bd - 1, -1, -1))
    out = custom_call(fn, [
        a.type,
        ir.RankedTensorType.get(batch_dims, ir.IntegerType.get_signless(32))
    ], [_mhlo_s32(int(lower)),
        _mhlo_s32(b), _mhlo_s32(n), a],
                      operand_layouts=[scalar_layout] * 3 + [layout],
                      result_layouts=[layout, info_layout])
    return out[:2]
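The custom_call and _mhlo_s32 helpers used above come from elsewhere in the module. As a point of reference, here is a minimal sketch of what _mhlo_s32 plausibly looks like (a hypothetical reconstruction, not the actual jaxlib definition): it wraps a Python int as a rank-0 signless-i32 MHLO constant so it can be passed as a scalar operand.

def _mhlo_s32(x):
    # Rank-0 i32 constant; older MLIR API versions spell this mhlo.ConstOp
    # rather than mhlo.ConstantOp (compare Example 22).
    return mhlo.ConstantOp(
        ir.DenseElementsAttr.get(np.array(x, dtype=np.int32))).result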
Example 2
def _geqrf_mhlo(platform, gpu_solver, dtype, a):
    """QR decomposition."""
    a_type = ir.RankedTensorType(a.type)
    dims = a_type.shape
    assert len(dims) >= 2
    m, n = dims[-2:]
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    batch = _prod(batch_dims)

    lwork, opaque = gpu_solver.build_geqrf_descriptor(np.dtype(dtype), batch,
                                                      m, n)

    layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
    i32_type = ir.IntegerType.get_signless(32)
    out = custom_call(f"{platform}solver_geqrf", [
        a.type,
        ir.RankedTensorType.get(batch_dims +
                                (min(m, n), ), a_type.element_type),
        ir.RankedTensorType.get(batch_dims, i32_type),
        ir.RankedTensorType.get([lwork], a_type.element_type),
    ], [a],
                      backend_config=opaque,
                      operand_layouts=[layout],
                      result_layouts=[
                          layout,
                          tuple(range(num_bd, -1, -1)),
                          tuple(range(num_bd - 1, -1, -1)),
                          [0],
                      ])
    return out[:3]
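The layout tuples in these lowerings are XLA layouts given in minor-to-major order: each matrix is column-major in its trailing two dimensions, and the batch dimensions are major. A small worked instance of the expression used above (illustrative only):

num_bd = 2                          # e.g. batch_dims == (3, 4)
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
# layout == (2, 3, 1, 0): dimension 2 (rows) varies fastest, then
# dimension 3 (columns), then the batch dimensions 1 and 0.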
Example 3
def _potrf_mhlo(platform, gpu_solver, dtype, a, lower):
    """Cholesky decomposition."""
    a_type = ir.RankedTensorType(a.type)
    dims = a_type.shape
    m, n = dims[-2:]
    assert m == n
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    batch = _prod(batch_dims)

    lwork, opaque = gpu_solver.build_potrf_descriptor(np.dtype(dtype), lower,
                                                      batch, n)

    layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
    info_layout = tuple(range(num_bd - 1, -1, -1))
    i32_type = ir.IntegerType.get_signless(32)
    info_type = ir.RankedTensorType.get(batch_dims, i32_type)
    work_layout = [0]
    out = custom_call(f"{platform}solver_potrf", [
        a.type,
        info_type,
        ir.RankedTensorType.get([lwork], ir.IntegerType.get_signless(8)),
    ], [a],
                      backend_config=opaque,
                      operand_layouts=[layout],
                      result_layouts=[layout, info_layout, work_layout])
    return out[:2]
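_prod collapses the batch dimensions into a single batch count. A plausible definition, stated here as an assumption since the real helper is defined elsewhere in jaxlib:

import functools
import operator

def _prod(xs):
    # Product of an iterable of ints; empty batch_dims yields 1.
    return functools.reduce(operator.mul, xs, 1)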
Example 4
def csr_matmat_mhlo(data,
                    indices,
                    indptr,
                    B,
                    *,
                    shape,
                    transpose=False,
                    compute_dtype=None,
                    compute_type=None,
                    index_dtype,
                    data_dtype,
                    B_dtype):
    """CSR from dense matrix."""
    data_type, index_type, nnz = _validate_csr_mhlo(data, indices, indptr,
                                                    shape)
    rows, cols = shape
    B_shape = ir.RankedTensorType(B.type).shape
    _, Ccols = B_shape

    if compute_dtype is None:
        compute_dtype = data_dtype
        compute_type = data_type

    buffer_size, opaque = _hipsparse.build_csr_matmat_descriptor(
        data_dtype, B_dtype, compute_dtype, index_dtype, rows, cols, Ccols,
        nnz, transpose)
    out_size = cols if transpose else rows

    i32_type = ir.IntegerType.get_signless(32)
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple([
                ir.RankedTensorType.get([out_size, Ccols], compute_type),
                ir.RankedTensorType.get([buffer_size],
                                        ir.IntegerType.get_signless(8)),
            ])
        ], [data, indices, indptr, B],
        call_target_name=ir.StringAttr.get("hipsparse_csr_matmat"),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(opaque),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.array([1, 0]),
                                        type=ir.IndexType.get()),
        ]),
        result_layouts=ir.ArrayAttr.get([
            ir.DenseIntElementsAttr.get(np.array([1, 0]),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]))
    return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result
Example 5
def syevd_mhlo(dtype, a, lower=False):
    """Symmetric (Hermitian) eigendecomposition."""
    a_type = ir.RankedTensorType(a.type)
    dims = a_type.shape
    assert len(dims) >= 2
    m, n = dims[-2:]
    assert m == n
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    batch = _prod(batch_dims)

    if n <= 32:
        kernel = b"cusolver_syevj"
        lwork, opaque = _cusolver.build_syevj_descriptor(
            np.dtype(dtype), lower, batch, n)
    else:
        kernel = b"cusolver_syevd"
        lwork, opaque = _cusolver.build_syevd_descriptor(
            np.dtype(dtype), lower, batch, n)

    if ir.ComplexType.isinstance(a_type.element_type):
        eigvals_type = ir.ComplexType(a_type.element_type).element_type
    else:
        eigvals_type = a_type.element_type

    layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())
    i32_type = ir.IntegerType.get_signless(32)
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple([
                a.type,
                ir.RankedTensorType.get(batch_dims + (n, ), eigvals_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
                ir.RankedTensorType.get([lwork], a_type.element_type),
            ])
        ], [a],
        call_target_name=ir.StringAttr.get(kernel),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(opaque),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([layout]),
        result_layouts=ir.ArrayAttr.get([
            layout,
            ir.DenseIntElementsAttr.get(np.array(range(num_bd, -1, -1)),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.array(range(num_bd - 1, -1, -1)),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]))
    return [
        mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
        for i in range(3)
    ]
Example 6
def trsm_mhlo(dtype, a, b, left_side=False, lower=False, trans_a=False,
              conj_a=False, diag=False):
  """Batched triangular solve.

  XLA implements unbatched triangular solve directly, so we need only implement
  the batched case."""
  b_type = ir.RankedTensorType(b.type)
  dims = b_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  batch = _prod(batch_dims)
  k = m if left_side else n

  a_type = ir.RankedTensorType(a.type)
  if (batch_dims + (k, k) != tuple(a_type.shape) or
      a_type.element_type != b_type.element_type):
    raise ValueError("Argument mismatch for trsm, got {} and {}".format(
      a_type, b_type))

  if conj_a and not trans_a:
    raise NotImplementedError("Conjugation without transposition not supported")

  lwork, opaque = _hipblas.build_trsm_batched_descriptor(
    np.dtype(dtype), batch, m, n, left_side, lower, trans_a, conj_a, diag)
  layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())
  work_type = ir.RankedTensorType.get([lwork], ir.IntegerType.get_signless(8))
  work_layout = ir.DenseIntElementsAttr.get(np.array([0]),
                                            type=ir.IndexType.get())
  i32_type = ir.IntegerType.get_signless(32)
  tup = mhlo.CustomCallOp(
      [ir.TupleType.get_tuple([b_type, work_type, work_type])],
      [a, b],
      call_target_name = ir.StringAttr.get("hipblas_trsm_batched"),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(opaque),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([layout] * 2),
      result_layouts=ir.ArrayAttr.get([layout, work_layout, work_layout])).result
  return mhlo.GetTupleElementOp(
      tup, ir.IntegerAttr.get(i32_type, 0)).result
Example 7
def trsm_mhlo(dtype, alpha, a, b, left_side=False, lower=False, trans_a=False,
               conj_a=False, diag=False):
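  """Triangular solve."""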
  a_type = ir.RankedTensorType(a.type)
  b_type = ir.RankedTensorType(b.type)

  dims = b_type.shape
  m, n = dims[-2:]
  k = m if left_side else n

  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  num_b = 1
  for d in batch_dims:
    num_b *= d

  if (batch_dims + (k, k) != tuple(a_type.shape) or
      a_type.element_type != b_type.element_type):
    raise ValueError("Argument mismatch for trsm, got {} and {}".format(
      a_type, b_type))

  if dtype == np.float32:
    fn = "blas_strsm"
  elif dtype == np.float64:
    fn = "blas_dtrsm"
  elif dtype == np.complex64:
    fn = "blas_ctrsm"
  elif dtype == np.complex128:
    fn = "blas_ztrsm"
  else:
    raise NotImplementedError(f"Unsupported dtype {dtype}")

  if conj_a and not trans_a:
    raise NotImplementedError("Conjugation without transposition not supported")
  scalar_layout = []
  layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
  return custom_call(
      fn,
      [b.type],
      [_mhlo_s32(int(left_side)), _mhlo_s32(int(lower)),
       _mhlo_s32((2 if conj_a else 1) if trans_a else 0), _mhlo_s32(int(diag)),
       _mhlo_s32(m), _mhlo_s32(n), _mhlo_s32(num_b),
       alpha, a, b],
      operand_layouts=[scalar_layout] * 8 + [layout] * 2,
      result_layouts=[layout])
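The third scalar operand packs the transposition mode into a single integer, presumably matching the BLAS 'N'/'T'/'C' convention:

trans = (2 if conj_a else 1) if trans_a else 0
# 0: no transpose, 1: transpose, 2: conjugate transpose

This is also why conj_a without trans_a raises NotImplementedError above: the encoding has no value for conjugation alone.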
Example 8
def getrf_mhlo(dtype, a):
    """LU decomposition."""
    a_type = ir.RankedTensorType(a.type)
    dims = a_type.shape
    assert len(dims) >= 2
    m, n = dims[-2:]
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    batch = _prod(batch_dims)

    if batch > 1 and m == n and m // batch <= 128:
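        # Heuristic: several small square matrices go to the batched
        # getrf kernel; everything else uses the per-matrix solver below.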
        lwork, opaque = _hipblas.build_getrf_batched_descriptor(
            np.dtype(dtype), batch, m)
        workspace = ir.RankedTensorType.get([lwork],
                                            ir.IntegerType.get_signless(8))
        kernel = "hipblas_getrf_batched"
    else:
        lwork, opaque = _hipsolver.build_getrf_descriptor(
            np.dtype(dtype), batch, m, n)
        workspace = ir.RankedTensorType.get([lwork], a_type.element_type)
        kernel = "hipsolver_getrf"

    layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())
    i32_type = ir.IntegerType.get_signless(32)
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple([
                a.type,
                ir.RankedTensorType.get(batch_dims + (min(m, n), ), i32_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
                workspace,
            ])
        ], [a],
        call_target_name=ir.StringAttr.get(kernel),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(opaque),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([layout]),
        result_layouts=ir.ArrayAttr.get([
            layout,
            ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.arange(num_bd - 1, -1, -1),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]))
    return [
        mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
        for i in range(3)
    ]
Example 9
def getrf_mhlo(dtype, a):
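    """LU decomposition."""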
    dims = ir.RankedTensorType(a.type).shape
    assert len(dims) >= 2
    m, n = dims[-2:]
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    b = 1
    for d in batch_dims:
        b *= d

    if dtype == np.float32:
        fn = b"lapack_sgetrf"
    elif dtype == np.float64:
        fn = b"lapack_dgetrf"
    elif dtype == np.complex64:
        fn = b"lapack_cgetrf"
    elif dtype == np.complex128:
        fn = b"lapack_zgetrf"
    else:
        raise NotImplementedError("Unsupported dtype {}".format(dtype))

    scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0, ), np.int64),
                                                type=ir.IndexType.get())
    layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())
    i32_type = ir.IntegerType.get_signless(32)
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple([
                a.type,
                ir.RankedTensorType.get(batch_dims + (min(m, n), ), i32_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
            ])
        ], [_mhlo_s32(int(b)),
            _mhlo_s32(m), _mhlo_s32(n), a],
        call_target_name=ir.StringAttr.get(fn),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(""),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([scalar_layout] * 3 + [layout]),
        result_layouts=ir.ArrayAttr.get([
            layout,
            ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.arange(num_bd - 1, -1, -1),
                                        type=ir.IndexType.get()),
        ]))
    return [
        mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
        for i in range(3)
    ]
Example 10
def _threefry2x32_lowering(prng, platform, keys, data):
    """ThreeFry2x32 kernel for GPU."""
    assert len(keys) == 2, keys
    assert len(data) == 2, data
    assert (ir.RankedTensorType(
        keys[0].type).element_type == ir.IntegerType.get_unsigned(32)
            ), keys[0].type
    typ = keys[0].type
    dims = ir.RankedTensorType(typ).shape

    for x in itertools.chain(keys, data):
        assert x.type == typ, (x.type, typ)
    ndims = len(dims)

    opaque = prng.threefry2x32_descriptor(_prod(dims))
    layout = tuple(range(ndims - 1, -1, -1))
    return custom_call(f"{platform}_threefry2x32", [typ, typ],
                       [keys[0], keys[1], data[0], data[1]],
                       backend_config=opaque,
                       operand_layouts=[layout] * 4,
                       result_layouts=[layout] * 2)
Example 11
def potrf_mhlo(dtype, a, lower=False):
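    """Cholesky decomposition."""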
    a_type = ir.RankedTensorType(a.type)
    dims = a_type.shape
    m, n = dims[-2:]
    if m != n:
        raise ValueError(
            "potrf expects a square matrix, got {}".format(a_type))
    if dtype == np.float32:
        fn = b"lapack_spotrf"
    elif dtype == np.float64:
        fn = b"lapack_dpotrf"
    elif dtype == np.complex64:
        fn = b"lapack_cpotrf"
    elif dtype == np.complex128:
        fn = b"lapack_zpotrf"
    else:
        raise NotImplementedError("Unsupported dtype {}".format(dtype))
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    b = 1
    for d in batch_dims:
        b *= d

    scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0, ), np.int64),
                                                type=ir.IndexType.get())
    layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())
    i32_type = ir.IntegerType.get_signless(32)
    info_layout = ir.DenseIntElementsAttr.get(np.array(
        range(num_bd - 1, -1, -1)),
                                              type=ir.IndexType.get())
    tup = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple([
                a.type,
                ir.RankedTensorType.get(batch_dims, i32_type),
            ])
        ], [_mhlo_s32(int(lower)),
            _mhlo_s32(b), _mhlo_s32(n), a],
        call_target_name=ir.StringAttr.get(fn),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(""),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([scalar_layout] * 3 + [layout]),
        result_layouts=ir.ArrayAttr.get([layout, info_layout])).result
    return [
        mhlo.GetTupleElementOp(tup, ir.IntegerAttr.get(i32_type, 0)).result,
        mhlo.GetTupleElementOp(tup, ir.IntegerAttr.get(i32_type, 1)).result,
    ]
Example 12
def geqrf_mhlo(dtype, a):
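    """QR decomposition."""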
    a_type = ir.RankedTensorType(a.type)
    dims = a_type.shape
    assert len(dims) >= 2
    m, n = dims[-2:]
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    b = 1
    for d in batch_dims:
        b *= d

    if dtype == np.float32:
        fn = b"lapack_sgeqrf"
        lwork = _lapack.lapack_sgeqrf_workspace(m, n)
    elif dtype == np.float64:
        fn = b"lapack_dgeqrf"
        lwork = _lapack.lapack_dgeqrf_workspace(m, n)
    elif dtype == np.complex64:
        fn = b"lapack_cgeqrf"
        lwork = _lapack.lapack_cgeqrf_workspace(m, n)
    elif dtype == np.complex128:
        fn = b"lapack_zgeqrf"
        lwork = _lapack.lapack_zgeqrf_workspace(m, n)
    else:
        raise NotImplementedError(f"Unsupported dtype {dtype}")

    scalar_layout = []
    layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
    i32_type = ir.IntegerType.get_signless(32)
    out = custom_call(fn, [
        a.type,
        ir.RankedTensorType.get(batch_dims +
                                (min(m, n), ), a_type.element_type),
        ir.RankedTensorType.get(batch_dims, i32_type),
        ir.RankedTensorType.get([lwork], a_type.element_type),
    ], [_mhlo_s32(int(b)),
        _mhlo_s32(m),
        _mhlo_s32(n),
        _mhlo_s32(lwork), a],
                      operand_layouts=[scalar_layout] * 4 + [layout],
                      result_layouts=[
                          layout,
                          tuple(range(num_bd, -1, -1)),
                          tuple(range(num_bd - 1, -1, -1)),
                          [0],
                      ])
    return out[:3]
Example 13
def _syevd_mhlo(platform,
                gpu_solver,
                have_jacobi_solver,
                dtype,
                a,
                lower=False):
    """Symmetric (Hermitian) eigendecomposition."""
    a_type = ir.RankedTensorType(a.type)
    dims = a_type.shape
    assert len(dims) >= 2
    m, n = dims[-2:]
    assert m == n
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    batch = _prod(batch_dims)
    layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))

    if have_jacobi_solver and n <= 32:
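        # Heuristic: for small matrices the Jacobi eigensolver (syevj)
        # tends to be faster than syevd.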
        kernel = f"{platform}solver_syevj"
        lwork, opaque = gpu_solver.build_syevj_descriptor(
            np.dtype(dtype), lower, batch, n)
    else:
        kernel = f"{platform}solver_syevd"
        lwork, opaque = gpu_solver.build_syevd_descriptor(
            np.dtype(dtype), lower, batch, n)

    if ir.ComplexType.isinstance(a_type.element_type):
        eigvals_type = ir.ComplexType(a_type.element_type).element_type
    else:
        eigvals_type = a_type.element_type

    i32_type = ir.IntegerType.get_signless(32)
    out = custom_call(kernel, [
        a.type,
        ir.RankedTensorType.get(batch_dims + (n, ), eigvals_type),
        ir.RankedTensorType.get(batch_dims, i32_type),
        ir.RankedTensorType.get([lwork], a_type.element_type),
    ], [a],
                      backend_config=opaque,
                      operand_layouts=[layout],
                      result_layouts=[
                          layout,
                          tuple(range(num_bd, -1, -1)),
                          tuple(range(num_bd - 1, -1, -1)),
                          [0],
                      ])
    return out[:3]
Example 14
def potrf_mhlo(dtype, a, lower):
    """Cholesky decomposition."""
    a_type = ir.RankedTensorType(a.type)
    dims = a_type.shape
    m, n = dims[-2:]
    assert m == n
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    batch = _prod(batch_dims)

    lwork, opaque = _hipsolver.build_potrf_descriptor(np.dtype(dtype), lower,
                                                      batch, n)
    kernel = b"hipsolver_potrf"

    layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())
    info_layout = ir.DenseIntElementsAttr.get(np.array(
        range(num_bd - 1, -1, -1)),
                                              type=ir.IndexType.get())
    i32_type = ir.IntegerType.get_signless(32)
    info_type = ir.RankedTensorType.get(batch_dims, i32_type)
    work_layout = ir.DenseIntElementsAttr.get(np.array([0]),
                                              type=ir.IndexType.get())
    tup = mhlo.CustomCallOp([
        ir.TupleType.get_tuple([
            a.type,
            info_type,
            ir.RankedTensorType.get([lwork], ir.IntegerType.get_signless(8)),
        ])
    ], [a],
                            call_target_name=ir.StringAttr.get(kernel),
                            has_side_effect=ir.BoolAttr.get(False),
                            backend_config=ir.StringAttr.get(opaque),
                            api_version=ir.IntegerAttr.get(
                                ir.IntegerType.get_signless(32), 2),
                            called_computations=ir.ArrayAttr.get([]),
                            operand_layouts=ir.ArrayAttr.get([layout]),
                            result_layouts=ir.ArrayAttr.get(
                                [layout, info_layout, work_layout])).result
    return [
        mhlo.GetTupleElementOp(tup, ir.IntegerAttr.get(i32_type, 0)).result,
        mhlo.GetTupleElementOp(tup, ir.IntegerAttr.get(i32_type, 1)).result,
    ]
Example 15
def _csr_fromdense_mhlo(platform, gpu_sparse, mat, *, nnz, index_dtype,
                        data_dtype, index_type):
    """CSR from dense matrix."""
    mat_type = ir.RankedTensorType(mat.type)
    rows, cols = mat_type.shape

    buffer_size, opaque = gpu_sparse.build_csr_fromdense_descriptor(
        data_dtype, index_dtype, rows, cols, nnz)

    out = custom_call(f"{platform}sparse_csr_fromdense", [
        ir.RankedTensorType.get([nnz], mat_type.element_type),
        ir.RankedTensorType.get([nnz], index_type),
        ir.RankedTensorType.get([rows + 1], index_type),
        ir.RankedTensorType.get([buffer_size], ir.IntegerType.get_signless(8)),
    ], [mat],
                      backend_config=opaque,
                      operand_layouts=[[1, 0]],
                      result_layouts=[[0]] * 4)
    return out[:3]
Example 16
def _getrf_mhlo(platform, gpu_blas, gpu_solver, dtype, a):
    """LU decomposition."""
    a_type = ir.RankedTensorType(a.type)
    dims = a_type.shape
    assert len(dims) >= 2
    m, n = dims[-2:]
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    batch = _prod(batch_dims)

    if batch > 1 and m == n and m // batch <= 128:
        lwork, opaque = gpu_blas.build_getrf_batched_descriptor(
            np.dtype(dtype), batch, m)
        workspace = ir.RankedTensorType.get([lwork],
                                            ir.IntegerType.get_signless(8))
        kernel = f"{platform}blas_getrf_batched"
    else:
        lwork, opaque = gpu_solver.build_getrf_descriptor(
            np.dtype(dtype), batch, m, n)
        workspace = ir.RankedTensorType.get([lwork], a_type.element_type)
        kernel = f"{platform}solver_getrf"

    layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
    i32_type = ir.IntegerType.get_signless(32)
    out = custom_call(kernel, [
        a.type,
        ir.RankedTensorType.get(batch_dims + (min(m, n), ), i32_type),
        ir.RankedTensorType.get(batch_dims, i32_type),
        workspace,
    ], [a],
                      backend_config=opaque,
                      operand_layouts=[layout],
                      result_layouts=[
                          layout,
                          tuple(range(num_bd, -1, -1)),
                          tuple(range(num_bd - 1, -1, -1)),
                          [0],
                      ])
    return out[:3]
Example 17
def _csr_matmat_mhlo(platform,
                     gpu_sparse,
                     data,
                     indices,
                     indptr,
                     B,
                     *,
                     shape,
                     transpose=False,
                     compute_dtype=None,
                     compute_type=None,
                     index_dtype,
                     data_dtype,
                     B_dtype):
    """CSR from dense matrix."""
    data_type, index_type, nnz = _validate_csr_mhlo(data, indices, indptr,
                                                    shape)
    rows, cols = shape
    B_shape = ir.RankedTensorType(B.type).shape
    _, Ccols = B_shape

    if compute_dtype is None:
        compute_dtype = data_dtype
        compute_type = data_type

    buffer_size, opaque = gpu_sparse.build_csr_matmat_descriptor(
        data_dtype, B_dtype, compute_dtype, index_dtype, rows, cols, Ccols,
        nnz, transpose)
    out_size = cols if transpose else rows

    out = custom_call(f"{platform}sparse_csr_matmat", [
        ir.RankedTensorType.get([out_size, Ccols], compute_type),
        ir.RankedTensorType.get([buffer_size], ir.IntegerType.get_signless(8)),
    ], [data, indices, indptr, B],
                      backend_config=opaque,
                      operand_layouts=[[0], [0], [0], [1, 0]],
                      result_layouts=[[1, 0], [0]])
    return out[0]
Example 18
def coo_fromdense_mhlo(mat, *, nnz, data_dtype, index_dtype, index_type):
    """COO from dense matrix."""
    mat_type = ir.RankedTensorType(mat.type)
    rows, cols = mat_type.shape

    buffer_size, opaque = _hipsparse.build_coo_fromdense_descriptor(
        data_dtype, index_dtype, rows, cols, nnz)

    i32_type = ir.IntegerType.get_signless(32)
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple([
                ir.RankedTensorType.get([nnz], mat_type.element_type),
                ir.RankedTensorType.get([nnz], index_type),
                ir.RankedTensorType.get([nnz], index_type),
                ir.RankedTensorType.get([buffer_size],
                                        ir.IntegerType.get_signless(8)),
            ])
        ], [mat],
        call_target_name=ir.StringAttr.get("hipsparse_coo_fromdense"),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(opaque),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([
            ir.DenseIntElementsAttr.get(np.array([1, 0]),
                                        type=ir.IndexType.get()),
        ]),
        result_layouts=ir.ArrayAttr.get([
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ] * 4))
    return [
        mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
        for i in range(3)
    ]
Example 19
def _lu_pivots_to_permutation_mhlo(platform, gpu_linalg, pivots, *,
                                   permutation_size):
    """Kernel for the transformation of pivots to permutations on GPU."""
    typ = ir.RankedTensorType(pivots.type)
    dims = typ.shape
    i32_type = ir.IntegerType.get_signless(32)

    assert typ.element_type == i32_type, typ

    batch_size = _prod(dims[:-1])
    pivot_size = dims[-1]

    opaque = gpu_linalg.lu_pivots_to_permutation_descriptor(
        batch_size, pivot_size, permutation_size)
    pivots_layout = tuple(range(len(dims) - 1, -1, -1))
    permutations_layout = pivots_layout
    permutations_dims = list(dims)
    permutations_dims[-1] = permutation_size
    permutations_type = ir.RankedTensorType.get(permutations_dims, i32_type)
    return custom_call(f"{platform}_lu_pivots_to_permutation",
                       [permutations_type], [pivots],
                       backend_config=opaque,
                       operand_layouts=[pivots_layout],
                       result_layouts=[permutations_layout])
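For intuition, the transformation this kernel performs can be written as a NumPy reference (an illustrative sketch assuming 0-based pivot indices, not the actual GPU implementation):

import numpy as np

def pivots_to_permutation_reference(pivots, permutation_size):
    # Replay the row swaps recorded during LU on an identity permutation;
    # entries beyond len(pivots) stay in place.
    permutation = np.arange(permutation_size)
    for i, p in enumerate(pivots):
        permutation[i], permutation[p] = permutation[p], permutation[i]
    return permutation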
Example 20
def syevd_mhlo(dtype, a, lower=False):
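    """Symmetric (Hermitian) eigendecomposition."""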
    a_type = ir.RankedTensorType(a.type)
    dims = a_type.shape
    assert len(dims) >= 2
    m, n = dims[-2:]
    assert m == n
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    b = 1
    for d in batch_dims:
        b *= d

    i32_type = ir.IntegerType.get_signless(32)
    if dtype == np.float32:
        fn = b"lapack_ssyevd"
        eigvals_type = ir.F32Type.get()
        workspace = [
            ir.RankedTensorType.get([_lapack.syevd_work_size(n)],
                                    a_type.element_type),
            ir.RankedTensorType.get([_lapack.syevd_iwork_size(n)], i32_type),
        ]
        workspace_layouts = [
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]
    elif dtype == np.float64:
        fn = b"lapack_dsyevd"
        eigvals_type = ir.F64Type.get()
        workspace = [
            ir.RankedTensorType.get([_lapack.syevd_work_size(n)],
                                    a_type.element_type),
            ir.RankedTensorType.get([_lapack.syevd_iwork_size(n)], i32_type),
        ]
        workspace_layouts = [
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]
    elif dtype == np.complex64:
        fn = b"lapack_cheevd"
        eigvals_type = ir.F32Type.get()
        workspace = [
            ir.RankedTensorType.get([_lapack.heevd_work_size(n)],
                                    a_type.element_type),
            ir.RankedTensorType.get([_lapack.heevd_rwork_size(n)],
                                    eigvals_type),
            ir.RankedTensorType.get([_lapack.syevd_iwork_size(n)], i32_type),
        ]
        workspace_layouts = [
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]
    elif dtype == np.complex128:
        fn = b"lapack_zheevd"
        eigvals_type = ir.F64Type.get()
        workspace = [
            ir.RankedTensorType.get([_lapack.heevd_work_size(n)],
                                    a_type.element_type),
            ir.RankedTensorType.get([_lapack.heevd_rwork_size(n)],
                                    eigvals_type),
            ir.RankedTensorType.get([_lapack.syevd_iwork_size(n)], i32_type),
        ]
        workspace_layouts = [
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]
    else:
        raise NotImplementedError("Unsupported dtype {}".format(dtype))

    scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0, ), np.int64),
                                                type=ir.IndexType.get())
    layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple([
                a.type,
                ir.RankedTensorType.get(batch_dims + (n, ), eigvals_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
            ] + workspace)
        ], [_mhlo_s32(1 if lower else 0),
            _mhlo_s32(b),
            _mhlo_s32(n), a],
        call_target_name=ir.StringAttr.get(fn),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(""),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([scalar_layout] * 3 + [layout]),
        result_layouts=ir.ArrayAttr.get([
            layout,
            ir.DenseIntElementsAttr.get(np.array(range(num_bd, -1, -1)),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.array(range(num_bd - 1, -1, -1)),
                                        type=ir.IndexType.get()),
        ] + workspace_layouts))
    return [
        mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
        for i in range(3)
    ]
Example 21
def gesvd_mhlo(dtype, a, full_matrices=True, compute_uv=True):
    """Singular value decomposition."""
    a_type = ir.RankedTensorType(a.type)
    dims = a_type.shape
    assert len(dims) >= 2
    m, n = dims[-2:]
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    b = _prod(batch_dims)
    if ir.ComplexType.isinstance(a_type.element_type):
        singular_vals_type = ir.ComplexType(a_type.element_type).element_type
    else:
        singular_vals_type = a_type.element_type

    scalar_layout = ir.DenseIntElementsAttr.get(np.array(
        tuple(range(num_bd - 1, -1, -1))),
                                                type=ir.IndexType.get())
    vector_layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, ) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())
    matrix_layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())
    i32_type = ir.IntegerType.get_signless(32)

    if m < n:
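        # gesvd expects m >= n, so a wide matrix is handled by factoring
        # its transpose: the descriptor gets swapped dimensions and the
        # U / V^T results trade tuple slots below.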
        lwork, opaque = _hipsolver.build_gesvd_descriptor(
            np.dtype(dtype), b, n, m, compute_uv, full_matrices)
        k = n if full_matrices else m
        out = mhlo.CustomCallOp(
            [
                ir.TupleType.get_tuple([
                    a.type,
                    ir.RankedTensorType.get(batch_dims +
                                            (min(m, n), ), singular_vals_type),
                    ir.RankedTensorType.get(batch_dims +
                                            (k, n), a_type.element_type),
                    ir.RankedTensorType.get(batch_dims +
                                            (m, m), a_type.element_type),
                    ir.RankedTensorType.get(batch_dims, i32_type),
                    ir.RankedTensorType.get([lwork], a_type.element_type),
                ])
            ], [a],
            call_target_name=ir.StringAttr.get("hipsolver_gesvd"),
            has_side_effect=ir.BoolAttr.get(False),
            backend_config=ir.StringAttr.get(opaque),
            api_version=ir.IntegerAttr.get(i32_type, 2),
            called_computations=ir.ArrayAttr.get([]),
            operand_layouts=ir.ArrayAttr.get([matrix_layout]),
            result_layouts=ir.ArrayAttr.get([
                matrix_layout,
                vector_layout,
                matrix_layout,
                matrix_layout,
                scalar_layout,
                ir.DenseIntElementsAttr.get(np.array([0]),
                                            type=ir.IndexType.get()),
            ]))
        s = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 1)).result
        vt = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type,
                                                            2)).result
        u = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 3)).result
        info = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type,
                                                              4)).result
    else:
        lwork, opaque = _hipsolver.build_gesvd_descriptor(
            np.dtype(dtype), b, m, n, compute_uv, full_matrices)
        k = m if full_matrices else n
        out = mhlo.CustomCallOp(
            [
                ir.TupleType.get_tuple([
                    a.type,
                    ir.RankedTensorType.get(batch_dims +
                                            (min(m, n), ), singular_vals_type),
                    ir.RankedTensorType.get(batch_dims +
                                            (m, k), a_type.element_type),
                    ir.RankedTensorType.get(batch_dims +
                                            (n, n), a_type.element_type),
                    ir.RankedTensorType.get(batch_dims, i32_type),
                    ir.RankedTensorType.get([lwork], a_type.element_type),
                ])
            ], [a],
            call_target_name=ir.StringAttr.get("hipsolver_gesvd"),
            has_side_effect=ir.BoolAttr.get(False),
            backend_config=ir.StringAttr.get(opaque),
            api_version=ir.IntegerAttr.get(i32_type, 2),
            called_computations=ir.ArrayAttr.get([]),
            operand_layouts=ir.ArrayAttr.get([matrix_layout]),
            result_layouts=ir.ArrayAttr.get([
                matrix_layout,
                vector_layout,
                matrix_layout,
                matrix_layout,
                scalar_layout,
                ir.DenseIntElementsAttr.get(np.array([0]),
                                            type=ir.IndexType.get()),
            ]))
        s = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 1)).result
        u = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 2)).result
        vt = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type,
                                                            3)).result
        info = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type,
                                                              4)).result
    return s, u, vt, info
Example 22
def pocketfft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
    """PocketFFT kernel for CPU."""
    a_type = ir.RankedTensorType(a.type)
    n = len(a_type.shape)

    fft_lengths = list(fft_lengths)
    descriptor_bytes, out_dtype, out_shape = _pocketfft_descriptor(
        list(a_type.shape), dtype, fft_type, fft_lengths)

    if out_dtype == np.float32:
        out_type = ir.F32Type.get()
    elif out_dtype == np.float64:
        out_type = ir.F64Type.get()
    elif out_dtype == np.complex64:
        out_type = ir.ComplexType.get(ir.F32Type.get())
    elif out_dtype == np.complex128:
        out_type = ir.ComplexType.get(ir.F64Type.get())
    else:
        raise ValueError(f"Unknown output type {out_dtype}")

    if 0 in a_type.shape or 0 in out_shape:
        if xla_client._version >= 64:
            if jax._src.lib.mlir_api_version < 21:
                zero = mhlo.ConstOp(
                    ir.DenseElementsAttr.get(np.array(0, dtype=out_dtype),
                                             type=out_type))
            else:
                zero = mhlo.ConstantOp(
                    ir.DenseElementsAttr.get(np.array(0, dtype=out_dtype),
                                             type=out_type))
        else:
            if jax._src.lib.mlir_api_version < 21:
                zero = mhlo.ConstOp(
                    ir.RankedTensorType.get([], out_type),
                    ir.DenseElementsAttr.get(np.array(0, dtype=out_dtype),
                                             type=out_type))
            else:
                zero = mhlo.ConstantOp(
                    ir.RankedTensorType.get([], out_type),
                    ir.DenseElementsAttr.get(np.array(0, dtype=out_dtype),
                                             type=out_type))
        return mhlo.BroadcastOp(
            zero, ir.DenseElementsAttr.get(np.asarray(out_shape,
                                                      np.int64))).result

    u8_type = ir.IntegerType.get_unsigned(8)
    if xla_client._version >= 64:
        if jax._src.lib.mlir_api_version < 21:
            descriptor = mhlo.ConstOp(
                ir.DenseElementsAttr.get(np.frombuffer(descriptor_bytes,
                                                       dtype=np.uint8),
                                         type=u8_type))
        else:
            descriptor = mhlo.ConstantOp(
                ir.DenseElementsAttr.get(np.frombuffer(descriptor_bytes,
                                                       dtype=np.uint8),
                                         type=u8_type))
    else:
        if jax._src.lib.mlir_api_version < 21:
            descriptor = mhlo.ConstOp(
                ir.RankedTensorType.get([len(descriptor_bytes)], u8_type),
                ir.DenseElementsAttr.get(np.frombuffer(descriptor_bytes,
                                                       dtype=np.uint8),
                                         type=u8_type))
        else:
            descriptor = mhlo.ConstantOp(
                ir.RankedTensorType.get([len(descriptor_bytes)], u8_type),
                ir.DenseElementsAttr.get(np.frombuffer(descriptor_bytes,
                                                       dtype=np.uint8),
                                         type=u8_type))
    layout = tuple(range(n - 1, -1, -1))
    return custom_call("pocketfft",
                       [ir.RankedTensorType.get(out_shape, out_type)],
                       [descriptor, a],
                       operand_layouts=[[0], layout],
                       result_layouts=[layout])
Example 23
def gesvd_mhlo(dtype, a, full_matrices=True, compute_uv=True):
    """Singular value decomposition."""
    a_type = ir.RankedTensorType(a.type)
    dims = a_type.shape
    assert len(dims) >= 2
    m, n = dims[-2:]
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    b = _prod(batch_dims)
    if ir.ComplexType.isinstance(a_type.element_type):
        singular_vals_type = ir.ComplexType(a_type.element_type).element_type
    else:
        singular_vals_type = a_type.element_type

    scalar_layout = ir.DenseIntElementsAttr.get(np.array(
        tuple(range(num_bd - 1, -1, -1))),
                                                type=ir.IndexType.get())
    vector_layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, ) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())
    i32_type = ir.IntegerType.get_signless(32)

    if m < 32 and n < 32:
        # The batched kernel doesn't support "econ" mode.
        econ = not full_matrices and b == 1
        lwork, opaque = _cusolver.build_gesvdj_descriptor(
            np.dtype(dtype), b, m, n, compute_uv, 1 if econ else 0)
        k = min(m, n)
        matrix_layout = ir.DenseIntElementsAttr.get(
            np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
            type=ir.IndexType.get())
        out = mhlo.CustomCallOp(
            [
                ir.TupleType.get_tuple([
                    a.type,
                    ir.RankedTensorType.get(batch_dims +
                                            (min(m, n), ), singular_vals_type),
                    ir.RankedTensorType.get(batch_dims + (m, k if econ else m),
                                            a_type.element_type),
                    ir.RankedTensorType.get(batch_dims + (n, k if econ else n),
                                            a_type.element_type),
                    ir.RankedTensorType.get(batch_dims, i32_type),
                    ir.RankedTensorType.get([lwork], a_type.element_type),
                ])
            ], [a],
            call_target_name=ir.StringAttr.get("cusolver_gesvdj"),
            has_side_effect=ir.BoolAttr.get(False),
            backend_config=ir.StringAttr.get(opaque),
            api_version=ir.IntegerAttr.get(i32_type, 2),
            called_computations=ir.ArrayAttr.get([]),
            operand_layouts=ir.ArrayAttr.get([matrix_layout]),
            result_layouts=ir.ArrayAttr.get([
                matrix_layout,
                vector_layout,
                matrix_layout,
                matrix_layout,
                scalar_layout,
                ir.DenseIntElementsAttr.get(np.array([0]),
                                            type=ir.IndexType.get()),
            ]))
        s = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 1)).result
        u = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 2)).result
        v = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 3))
        info = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type,
                                                              4)).result
        vt = mhlo.TransposeOp(
            v,
            ir.DenseIntElementsAttr.get(
                np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd)))).result
        if np.issubdtype(dtype, np.complexfloating):
            vt = mhlo.ComplexOp(mhlo.RealOp(vt),
                                mhlo.NegOp(mhlo.ImagOp(vt))).result
        if not full_matrices and not econ:
            u = mhlo.SliceOp(
                u, ir.DenseIntElementsAttr.get(np.zeros([len(dims)],
                                                        np.int64)),
                ir.DenseIntElementsAttr.get(
                    np.array(batch_dims + (m, min(m, n)))),
                ir.DenseIntElementsAttr.get(np.ones([len(dims)],
                                                    np.int64))).result
            vt = mhlo.SliceOp(
                vt, ir.DenseIntElementsAttr.get(np.zeros([len(dims)],
                                                         np.int64)),
                ir.DenseIntElementsAttr.get(
                    np.array(batch_dims + (min(m, n), n))),
                ir.DenseIntElementsAttr.get(np.ones([len(dims)],
                                                    np.int64))).result
    elif m < n:
        lwork, opaque = _cusolver.build_gesvd_descriptor(
            np.dtype(dtype), b, n, m, compute_uv, full_matrices)
        k = n if full_matrices else m
        matrix_layout = ir.DenseIntElementsAttr.get(
            np.array((num_bd + 1, num_bd) + tuple(range(num_bd - 1, -1, -1))),
            type=ir.IndexType.get())
        out = mhlo.CustomCallOp(
            [
                ir.TupleType.get_tuple([
                    a.type,
                    ir.RankedTensorType.get(batch_dims +
                                            (min(m, n), ), singular_vals_type),
                    ir.RankedTensorType.get(batch_dims +
                                            (k, n), a_type.element_type),
                    ir.RankedTensorType.get(batch_dims +
                                            (m, m), a_type.element_type),
                    ir.RankedTensorType.get(batch_dims, i32_type),
                    ir.RankedTensorType.get([lwork], a_type.element_type),
                ])
            ], [a],
            call_target_name=ir.StringAttr.get("cusolver_gesvd"),
            has_side_effect=ir.BoolAttr.get(False),
            backend_config=ir.StringAttr.get(opaque),
            api_version=ir.IntegerAttr.get(i32_type, 2),
            called_computations=ir.ArrayAttr.get([]),
            operand_layouts=ir.ArrayAttr.get([matrix_layout]),
            result_layouts=ir.ArrayAttr.get([
                matrix_layout,
                vector_layout,
                matrix_layout,
                matrix_layout,
                scalar_layout,
                ir.DenseIntElementsAttr.get(np.array([0]),
                                            type=ir.IndexType.get()),
            ]))
        s = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 1)).result
        vt = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type,
                                                            2)).result
        u = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 3)).result
        info = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type,
                                                              4)).result
    else:
        lwork, opaque = _cusolver.build_gesvd_descriptor(
            np.dtype(dtype), b, m, n, compute_uv, full_matrices)
        k = m if full_matrices else n
        matrix_layout = ir.DenseIntElementsAttr.get(
            np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
            type=ir.IndexType.get())
        out = mhlo.CustomCallOp(
            [
                ir.TupleType.get_tuple([
                    a.type,
                    ir.RankedTensorType.get(batch_dims +
                                            (min(m, n), ), singular_vals_type),
                    ir.RankedTensorType.get(batch_dims +
                                            (m, k), a_type.element_type),
                    ir.RankedTensorType.get(batch_dims +
                                            (n, n), a_type.element_type),
                    ir.RankedTensorType.get(batch_dims, i32_type),
                    ir.RankedTensorType.get([lwork], a_type.element_type),
                ])
            ], [a],
            call_target_name=ir.StringAttr.get("cusolver_gesvd"),
            has_side_effect=ir.BoolAttr.get(False),
            backend_config=ir.StringAttr.get(opaque),
            api_version=ir.IntegerAttr.get(i32_type, 2),
            called_computations=ir.ArrayAttr.get([]),
            operand_layouts=ir.ArrayAttr.get([matrix_layout]),
            result_layouts=ir.ArrayAttr.get([
                matrix_layout,
                vector_layout,
                matrix_layout,
                matrix_layout,
                scalar_layout,
                ir.DenseIntElementsAttr.get(np.array([0]),
                                            type=ir.IndexType.get()),
            ]))
        s = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 1)).result
        u = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 2)).result
        vt = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type,
                                                            3)).result
        info = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type,
                                                              4)).result
    return s, u, vt, info
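In the gesvdj branch above, the kernel returns V rather than V^H, so the lowering transposes the trailing two dimensions and, for complex dtypes, negates the imaginary part. The NumPy equivalent is simply (illustrative, not part of the lowering):

vt = np.conj(np.swapaxes(v, -2, -1))   # V -> V^H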
Example 24
def gesdd_mhlo(dtype, a, full_matrices=True, compute_uv=True):
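    """Singular value decomposition."""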
    a_type = ir.RankedTensorType(a.type)
    dims = a_type.shape
    assert len(dims) >= 2
    m, n = dims[-2:]
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    b = 1
    for d in batch_dims:
        b *= d

    i32_type = ir.IntegerType.get_signless(32)
    if dtype == np.float32:
        fn = b"lapack_sgesdd"
        singular_vals_type = ir.F32Type.get()
        lwork = _lapack.sgesdd_work_size(m, n, compute_uv, full_matrices)
        workspace = [
            ir.RankedTensorType.get([_lapack.gesdd_iwork_size(m, n)],
                                    i32_type),
            ir.RankedTensorType.get([lwork], a_type.element_type),
        ]
        workspace_layouts = [
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]
    elif dtype == np.float64:
        fn = b"lapack_dgesdd"
        singular_vals_type = ir.F64Type.get()
        lwork = _lapack.dgesdd_work_size(m, n, compute_uv, full_matrices)
        workspace = [
            ir.RankedTensorType.get([_lapack.gesdd_iwork_size(m, n)],
                                    i32_type),
            ir.RankedTensorType.get([lwork], a_type.element_type),
        ]
        workspace_layouts = [
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]
    elif dtype == np.complex64:
        fn = b"lapack_cgesdd"
        singular_vals_type = ir.F32Type.get()
        lwork = _lapack.cgesdd_work_size(m, n, compute_uv, full_matrices)
        workspace = [
            ir.RankedTensorType.get([_lapack.gesdd_iwork_size(m, n)],
                                    i32_type),
            ir.RankedTensorType.get(
                [_lapack.cgesdd_rwork_size(m, n, int(compute_uv))],
                ir.F32Type.get()),
            ir.RankedTensorType.get([lwork], a_type.element_type),
        ]
        workspace_layouts = [
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]
    elif dtype == np.complex128:
        fn = b"lapack_zgesdd"
        singular_vals_type = ir.F64Type.get()
        lwork = _lapack.zgesdd_work_size(m, n, compute_uv, full_matrices)
        workspace = [
            ir.RankedTensorType.get([_lapack.gesdd_iwork_size(m, n)],
                                    i32_type),
            ir.RankedTensorType.get(
                [_lapack.cgesdd_rwork_size(m, n, int(compute_uv))],
                ir.F64Type.get()),
            ir.RankedTensorType.get([lwork], a_type.element_type),
        ]
        workspace_layouts = [
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]
    else:
        raise NotImplementedError("Unsupported dtype {}".format(dtype))

    scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0, ), np.int64),
                                                type=ir.IndexType.get())
    layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())
    out = mhlo.CustomCallOp([
        ir.TupleType.get_tuple([
            a.type,
            ir.RankedTensorType.get(batch_dims +
                                    (min(m, n), ), singular_vals_type),
            ir.RankedTensorType.get(
                batch_dims +
                (m, m if full_matrices else min(m, n)), a_type.element_type),
            ir.RankedTensorType.get(
                batch_dims +
                (n if full_matrices else min(m, n), n), a_type.element_type),
            ir.RankedTensorType.get(batch_dims, i32_type),
        ] + workspace)
    ], [
        _mhlo_s32(int(full_matrices)),
        _mhlo_s32(int(compute_uv)),
        _mhlo_s32(b),
        _mhlo_s32(m),
        _mhlo_s32(n),
        _mhlo_s32(lwork), a
    ],
                            call_target_name=ir.StringAttr.get(fn),
                            has_side_effect=ir.BoolAttr.get(False),
                            backend_config=ir.StringAttr.get(""),
                            api_version=ir.IntegerAttr.get(i32_type, 2),
                            called_computations=ir.ArrayAttr.get([]),
                            operand_layouts=ir.ArrayAttr.get([scalar_layout] *
                                                             6 + [layout]),
                            result_layouts=ir.ArrayAttr.get([
                                layout,
                                ir.DenseIntElementsAttr.get(
                                    np.array((num_bd, ) +
                                             tuple(range(num_bd - 1, -1, -1))),
                                    type=ir.IndexType.get()),
                                layout,
                                layout,
                                ir.DenseIntElementsAttr.get(
                                    np.array(range(num_bd - 1, -1, -1)),
                                    type=ir.IndexType.get()),
                            ] + workspace_layouts))
    return [
        mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
        for i in range(1, 5)
    ]
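
For orientation, a minimal NumPy-only sketch (not part of jaxlib) of the
semantics behind lapack_*gesdd; the four tensors unpacked above are the
singular values, U, Vt, and the per-batch LAPACK info status. The shapes
here are illustrative assumptions.

import numpy as np

a = np.random.randn(3, 4, 5).astype(np.float32)   # batch of three 4x5 matrices
u, s, vt = np.linalg.svd(a, full_matrices=True)
# Shapes mirror the result types built above when full_matrices=True:
assert u.shape == (3, 4, 4) and s.shape == (3, 4) and vt.shape == (3, 5, 5)
# In the real custom call, info == 0 per batch element signals convergence.
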
Example No. 25
def gees_mhlo(a, jobvs=True, sort=False, select=None):
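    """Schur decomposition."""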
    a_type = ir.RankedTensorType(a.type)
    etype = a_type.element_type
    dims = a_type.shape
    assert len(dims) >= 2
    m, n = dims[-2:]
    assert m == n
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    b = 1
    for d in batch_dims:
        b *= d
    layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())

    if sort:
        raise NotImplementedError(
            "The sort feature of LAPACK's gees routine is not implemented.")

    jobvs_c = ord('V' if jobvs else 'N')
    sort_c = ord('S' if sort else 'N')

    if not ir.ComplexType.isinstance(etype):
        fn = "lapack_sgees" if etype == ir.F32Type.get() else "lapack_dgees"
        schurvecs_type = etype
        workspaces = [ir.RankedTensorType.get(dims, schurvecs_type)]
        workspace_layouts = [layout]
        eigvals = [ir.RankedTensorType.get(batch_dims + (n, ), etype)] * 2
        eigvals_layouts = [
            ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                        type=ir.IndexType.get())
        ] * 2
    else:
        fn = ("lapack_cgees" if etype == ir.ComplexType.get(ir.F32Type.get())
              else "lapack_zgees")
        schurvecs_type = etype
        workspaces = [
            ir.RankedTensorType.get(dims, schurvecs_type),
            ir.RankedTensorType.get([n],
                                    ir.ComplexType(etype).element_type),
        ]
        workspace_layouts = [
            layout,
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]
        eigvals = [ir.RankedTensorType.get(batch_dims + (n, ), etype)]
        eigvals_layouts = [
            ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                        type=ir.IndexType.get())
        ]

    i32_type = ir.IntegerType.get_signless(32)

    scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0, ), np.int64),
                                                type=ir.IndexType.get())
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple(workspaces + eigvals + [
                ir.RankedTensorType.get(dims, schurvecs_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
            ])
        ],
        [
            _mhlo_s32(b),
            _mhlo_s32(n),
            _mhlo_u8(np.uint8(jobvs_c)),
            _mhlo_u8(np.uint8(sort_c)),
            # TODO: figure out how to put the callable select function here
            a
        ],
        call_target_name=ir.StringAttr.get(fn),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(""),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([scalar_layout] * 4 + [layout]),
        result_layouts=ir.ArrayAttr.get(workspace_layouts + eigvals_layouts + [
            layout,
            ir.DenseIntElementsAttr.get(np.arange(num_bd - 1, -1, -1),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.arange(num_bd - 1, -1, -1),
                                        type=ir.IndexType.get()),
        ]))
    i32_attr = lambda i: ir.IntegerAttr.get(i32_type, i)
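    # Note: sort is rejected above, so sort_c == ord('N') and the first branch
    # below is dead code, kept for when sorting support is implemented.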
    if sort_c == ord('S'):
        return (mhlo.GetTupleElementOp(out, i32_attr(0)).result,
                mhlo.GetTupleElementOp(out, i32_attr(3)).result,
                mhlo.GetTupleElementOp(out, i32_attr(4)).result,
                mhlo.GetTupleElementOp(out, i32_attr(5)).result)
    else:
        return (mhlo.GetTupleElementOp(out, i32_attr(0)).result,
                mhlo.GetTupleElementOp(out, i32_attr(3)).result,
                mhlo.GetTupleElementOp(out, i32_attr(5)).result)
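
For reference only (this uses SciPy, not the jaxlib API): the custom call
wraps LAPACK's gees, i.e. the Schur decomposition a = Z @ T @ Z^H, and the
non-sorting return above corresponds to (T, Schur vectors Z, info).

import numpy as np
from scipy.linalg import schur

a = np.random.randn(4, 4)
t, z = schur(a, output='real')   # T quasi-upper-triangular, Z orthogonal
assert np.allclose(z @ t @ z.T, a)
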
Example No. 26
def geev_mhlo(dtype, a, jobvl=True, jobvr=True):
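    """Eigendecomposition of a general matrix."""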
    dims = ir.RankedTensorType(a.type).shape
    assert len(dims) >= 2
    m, n = dims[-2:]
    assert m == n
    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    b = 1
    for d in batch_dims:
        b *= d
    layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())

    jobvl_c = ord('V' if jobvl else 'N')
    jobvr_c = ord('V' if jobvr else 'N')

    if dtype == np.float32:
        fn = b"lapack_sgeev"
        real = True
        eigvecs_type = ir.ComplexType.get(ir.F32Type.get())
        workspaces = [
            ir.RankedTensorType.get([n, n], ir.F32Type.get()),
            ir.RankedTensorType.get([n, n], ir.F32Type.get()),
            ir.RankedTensorType.get([n, n], ir.F32Type.get())
        ]
        workspace_layouts = [
            ir.DenseIntElementsAttr.get(np.array([0, 1]),
                                        type=ir.IndexType.get())
        ] * 3
        eigvals = [
            ir.RankedTensorType.get(batch_dims + (n, ), ir.F32Type.get()),
            ir.RankedTensorType.get(batch_dims + (n, ), ir.F32Type.get())
        ]
        eigvals_layouts = [
            ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                        type=ir.IndexType.get())
        ] * 2
    elif dtype == np.float64:
        fn = b"lapack_dgeev"
        real = True
        eigvecs_type = ir.ComplexType.get(ir.F64Type.get())
        workspaces = [
            ir.RankedTensorType.get([n, n], ir.F64Type.get()),
            ir.RankedTensorType.get([n, n], ir.F64Type.get()),
            ir.RankedTensorType.get([n, n], ir.F64Type.get())
        ]
        workspace_layouts = [
            ir.DenseIntElementsAttr.get(np.array([0, 1]),
                                        type=ir.IndexType.get())
        ] * 3
        eigvals = [
            ir.RankedTensorType.get(batch_dims + (n, ), ir.F64Type.get()),
            ir.RankedTensorType.get(batch_dims + (n, ), ir.F64Type.get())
        ]
        eigvals_layouts = [
            ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                        type=ir.IndexType.get())
        ] * 2
    elif dtype == np.complex64:
        fn = b"lapack_cgeev"
        real = False
        eigvecs_type = ir.ComplexType.get(ir.F32Type.get())
        workspaces = [
            ir.RankedTensorType.get([n, n],
                                    ir.ComplexType.get(ir.F32Type.get())),
            ir.RankedTensorType.get([2 * n], ir.F32Type.get())
        ]
        workspace_layouts = [
            ir.DenseIntElementsAttr.get(np.array([0, 1]),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get())
        ]
        eigvals = [
            ir.RankedTensorType.get(batch_dims + (n, ),
                                    ir.ComplexType.get(ir.F32Type.get()))
        ]
        eigvals_layouts = [
            ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                        type=ir.IndexType.get())
        ]
    elif dtype == np.complex128:
        fn = b"lapack_zgeev"
        real = False
        eigvecs_type = ir.ComplexType.get(ir.F64Type.get())
        workspaces = [
            ir.RankedTensorType.get([n, n],
                                    ir.ComplexType.get(ir.F64Type.get())),
            ir.RankedTensorType.get([2 * n], ir.F64Type.get())
        ]
        workspace_layouts = [
            ir.DenseIntElementsAttr.get(np.array([0, 1]),
                                        type=ir.IndexType.get()),
            ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get())
        ]
        eigvals = [
            ir.RankedTensorType.get(batch_dims + (n, ),
                                    ir.ComplexType.get(ir.F64Type.get()))
        ]
        eigvals_layouts = [
            ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                        type=ir.IndexType.get())
        ]
    else:
        raise NotImplementedError("Unsupported dtype {}".format(dtype))

    i32_type = ir.IntegerType.get_signless(32)
    scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0, ), np.int64),
                                                type=ir.IndexType.get())
    info_layout = ir.DenseIntElementsAttr.get(np.arange(num_bd - 1, -1, -1),
                                              type=ir.IndexType.get())
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple(workspaces + eigvals + [
                ir.RankedTensorType.get(dims, eigvecs_type),
                ir.RankedTensorType.get(dims, eigvecs_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
            ])
        ],
        [_mhlo_s32(b),
         _mhlo_s32(n),
         _mhlo_u8(jobvl_c),
         _mhlo_u8(jobvr_c), a],
        call_target_name=ir.StringAttr.get(fn),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(""),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([scalar_layout] * 4 + [layout]),
        result_layouts=ir.ArrayAttr.get(workspace_layouts + eigvals_layouts +
                                        [layout] * 2 + [info_layout])).result
    i32_attr = lambda i: ir.IntegerAttr.get(i32_type, i)
    if real:
        return (mhlo.ComplexOp(mhlo.GetTupleElementOp(out, i32_attr(3)),
                               mhlo.GetTupleElementOp(out,
                                                      i32_attr(4))).result,
                mhlo.GetTupleElementOp(out, i32_attr(5)).result,
                mhlo.GetTupleElementOp(out, i32_attr(6)).result,
                mhlo.GetTupleElementOp(out, i32_attr(7)).result)
    else:
        return (mhlo.GetTupleElementOp(out, i32_attr(2)).result,
                mhlo.GetTupleElementOp(out, i32_attr(3)).result,
                mhlo.GetTupleElementOp(out, i32_attr(4)).result,
                mhlo.GetTupleElementOp(out, i32_attr(5)).result)
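
A hedged SciPy cross-check of what lapack_*geev computes (eigenvalues plus
left and right eigenvectors); note that for real dtypes the wrapper above
assembles the split wr/wi outputs into one complex eigenvalue tensor via
mhlo.ComplexOp before returning.

import numpy as np
from scipy.linalg import eig

a = np.random.randn(4, 4)
w, vl, vr = eig(a, left=True, right=True)   # eigenvalues, left/right eigenvectors
# Column i of vr satisfies a @ vr[:, i] == w[i] * vr[:, i]:
assert np.allclose(a @ vr, vr * w)
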
Example No. 27
def trsm_mhlo(dtype,
              alpha,
              a,
              b,
              left_side=False,
              lower=False,
              trans_a=False,
              conj_a=False,
              diag=False):
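    """Triangular solve with multiple right-hand sides (trsm)."""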
    a_type = ir.RankedTensorType(a.type)
    b_type = ir.RankedTensorType(b.type)

    dims = b_type.shape
    m, n = dims[-2:]
    k = m if left_side else n

    batch_dims = tuple(dims[:-2])
    num_bd = len(batch_dims)
    num_b = 1
    for d in batch_dims:
        num_b *= d

    if (batch_dims + (k, k) != tuple(a_type.shape)
            or a_type.element_type != b_type.element_type):
        raise ValueError("Argument mismatch for trsm, got {} and {}".format(
            a_type, b_type))

    if dtype == np.float32:
        fn = "blas_strsm"
    elif dtype == np.float64:
        fn = "blas_dtrsm"
    elif dtype == np.complex64:
        fn = "blas_ctrsm"
    elif dtype == np.complex128:
        fn = "blas_ztrsm"
    else:
        raise NotImplementedError("Unsupported dtype {}".format(dtype))

    if conj_a and not trans_a:
        raise NotImplementedError(
            "Conjugation without transposition not supported")
    scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0, ), np.int64),
                                                type=ir.IndexType.get())
    layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())
    return mhlo.CustomCallOp(
        [b.type], [
            _mhlo_s32(int(left_side)),
            _mhlo_s32(int(lower)),
            _mhlo_s32((2 if conj_a else 1) if trans_a else 0),
            _mhlo_s32(int(diag)),
            _mhlo_s32(m),
            _mhlo_s32(n),
            _mhlo_s32(num_b), alpha, a, b
        ],
        call_target_name=ir.StringAttr.get(fn),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(""),
        api_version=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([scalar_layout] * 8 + [layout] * 2),
        result_layouts=ir.ArrayAttr.get([layout])).result
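
A plain-NumPy sketch of the BLAS trsm contract being targeted, shown here
only for the assumed left_side=True, no-transpose case: solve
op(A) @ X = alpha * B for triangular A, with X returned in place of B.

import numpy as np

n, k = 4, 3
a = np.tril(np.random.randn(n, n)) + n * np.eye(n)   # well-conditioned lower-triangular A
b = np.random.randn(n, k)
alpha = 2.0
x = np.linalg.solve(a, alpha * b)                    # what trsm would write over B
assert np.allclose(a @ x, alpha * b)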