コード例 #1
0
ファイル: pocketfft.py プロジェクト: frederikwilde/jax
def pocketfft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
    """PocketFFT kernel for CPU."""
    a_type = ir.RankedTensorType(a.type)
    n = len(a_type.shape)

    fft_lengths = list(fft_lengths)
    descriptor_bytes, out_dtype, out_shape = _pocketfft_descriptor(
        list(a_type.shape), dtype, fft_type, fft_lengths)

    if out_dtype == np.float32:
        out_type = ir.F32Type.get()
    elif out_dtype == np.float64:
        out_type = ir.F64Type.get()
    elif out_dtype == np.complex64:
        out_type = ir.ComplexType.get(ir.F32Type.get())
    elif out_dtype == np.complex128:
        out_type = ir.ComplexType.get(ir.F64Type.get())
    else:
        raise ValueError(f"Unknown output type {out_dtype}")

    if 0 in a_type.shape or 0 in out_shape:
        zero = mhlo.ConstOp(
            ir.RankedTensorType.get([], out_type),
            ir.DenseElementsAttr.get(np.array(0, dtype=out_dtype),
                                     type=out_type))
        if jax._src.lib.mlir_api_version < 9:
            return mhlo.BroadcastOp(
                ir.RankedTensorType.get(out_shape, out_type), zero,
                ir.DenseElementsAttr.get(np.asarray(out_shape,
                                                    np.int64))).result
        else:
            return mhlo.BroadcastOp(
                zero, ir.DenseElementsAttr.get(np.asarray(out_shape,
                                                          np.int64))).result

    u8_type = ir.IntegerType.get_unsigned(8)
    descriptor = mhlo.ConstOp(
        ir.RankedTensorType.get([len(descriptor_bytes)], u8_type),
        ir.DenseElementsAttr.get(np.frombuffer(descriptor_bytes,
                                               dtype=np.uint8),
                                 type=u8_type))
    layout = ir.DenseIntElementsAttr.get(np.arange(n - 1, -1, -1),
                                         type=ir.IndexType.get())
    return mhlo.CustomCallOp(
        [ir.RankedTensorType.get(out_shape, out_type)], [descriptor, a],
        call_target_name=ir.StringAttr.get("pocketfft"),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(""),
        api_version=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([
            ir.DenseIntElementsAttr.get(np.array([0], np.int64),
                                        type=ir.IndexType.get()),
            layout,
        ]),
        result_layouts=ir.ArrayAttr.get([layout])).result
コード例 #2
0
 def _mhlo_s32(x):
     typ = ir.RankedTensorType.get([], ir.IntegerType.get_signless(32))
     if jax._src.lib.mlir_api_version < 21:
         return mhlo.ConstOp(
             typ, ir.DenseElementsAttr.get(np.array(x,
                                                    dtype=np.int32))).result
     else:
         return mhlo.ConstantOp(
             typ, ir.DenseElementsAttr.get(np.array(x,
                                                    dtype=np.int32))).result
コード例 #3
0
 def _mhlo_u8(x):
     if jax._src.lib.mlir_api_version < 21:
         return mhlo.ConstOp(
             ir.DenseElementsAttr.get(
                 np.array(x, dtype=np.uint8),
                 type=ir.IntegerType.get_unsigned(8))).result
     else:
         return mhlo.ConstantOp(
             ir.DenseElementsAttr.get(
                 np.array(x, dtype=np.uint8),
                 type=ir.IntegerType.get_unsigned(8))).result
コード例 #4
0
 def _mhlo_u8(x):
     typ = ir.RankedTensorType.get([], ir.IntegerType.get_unsigned(8))
     if jax._src.lib.mlir_api_version < 21:
         return mhlo.ConstOp(
             typ,
             ir.DenseElementsAttr.get(np.array(x, dtype=np.uint8),
                                      type=typ.element_type)).result
     else:
         return mhlo.ConstantOp(
             typ,
             ir.DenseElementsAttr.get(np.array(x, dtype=np.uint8),
                                      type=typ.element_type)).result
コード例 #5
0
ファイル: pocketfft.py プロジェクト: romanngg/jax
def pocketfft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
    """PocketFFT kernel for CPU."""
    a_type = ir.RankedTensorType(a.type)
    n = len(a_type.shape)

    fft_lengths = list(fft_lengths)
    descriptor_bytes, out_dtype, out_shape = _pocketfft_descriptor(
        list(a_type.shape), dtype, fft_type, fft_lengths)

    if out_dtype == np.float32:
        out_type = ir.F32Type.get()
    elif out_dtype == np.float64:
        out_type = ir.F64Type.get()
    elif out_dtype == np.complex64:
        out_type = ir.ComplexType.get(ir.F32Type.get())
    elif out_dtype == np.complex128:
        out_type = ir.ComplexType.get(ir.F64Type.get())
    else:
        raise ValueError(f"Unknown output type {out_dtype}")

    if 0 in a_type.shape or 0 in out_shape:
        if xla_client._version >= 64:
            if jax._src.lib.mlir_api_version < 21:
                zero = mhlo.ConstOp(
                    ir.DenseElementsAttr.get(np.array(0, dtype=out_dtype),
                                             type=out_type))
            else:
                zero = mhlo.ConstantOp(
                    ir.DenseElementsAttr.get(np.array(0, dtype=out_dtype),
                                             type=out_type))
        else:
            if jax._src.lib.mlir_api_version < 21:
                zero = mhlo.ConstOp(
                    ir.RankedTensorType.get([], out_type),
                    ir.DenseElementsAttr.get(np.array(0, dtype=out_dtype),
                                             type=out_type))
            else:
                zero = mhlo.ConstantOp(
                    ir.RankedTensorType.get([], out_type),
                    ir.DenseElementsAttr.get(np.array(0, dtype=out_dtype),
                                             type=out_type))
        return mhlo.BroadcastOp(
            zero, ir.DenseElementsAttr.get(np.asarray(out_shape,
                                                      np.int64))).result

    u8_type = ir.IntegerType.get_unsigned(8)
    if xla_client._version >= 64:
        if jax._src.lib.mlir_api_version < 21:
            descriptor = mhlo.ConstOp(
                ir.DenseElementsAttr.get(np.frombuffer(descriptor_bytes,
                                                       dtype=np.uint8),
                                         type=u8_type))
        else:
            descriptor = mhlo.ConstantOp(
                ir.DenseElementsAttr.get(np.frombuffer(descriptor_bytes,
                                                       dtype=np.uint8),
                                         type=u8_type))
    else:
        if jax._src.lib.mlir_api_version < 21:
            descriptor = mhlo.ConstOp(
                ir.RankedTensorType.get([len(descriptor_bytes)], u8_type),
                ir.DenseElementsAttr.get(np.frombuffer(descriptor_bytes,
                                                       dtype=np.uint8),
                                         type=u8_type))
        else:
            descriptor = mhlo.ConstantOp(
                ir.RankedTensorType.get([len(descriptor_bytes)], u8_type),
                ir.DenseElementsAttr.get(np.frombuffer(descriptor_bytes,
                                                       dtype=np.uint8),
                                         type=u8_type))
    layout = tuple(range(n - 1, -1, -1))
    return custom_call("pocketfft",
                       [ir.RankedTensorType.get(out_shape, out_type)],
                       [descriptor, a],
                       operand_layouts=[[0], layout],
                       result_layouts=[layout])
コード例 #6
0
ファイル: lapack.py プロジェクト: frederikwilde/jax
def _mhlo_s32(x):
    typ = ir.RankedTensorType.get([], ir.IntegerType.get_signless(32))
    return mhlo.ConstOp(typ,
                        ir.DenseElementsAttr.get(np.array(
                            x, dtype=np.int32))).result
コード例 #7
0
ファイル: lapack.py プロジェクト: frederikwilde/jax
def _mhlo_u8(x):
    typ = ir.RankedTensorType.get([], ir.IntegerType.get_unsigned(8))
    return mhlo.ConstOp(
        typ,
        ir.DenseElementsAttr.get(np.array(x, dtype=np.uint8),
                                 type=typ.element_type)).result
コード例 #8
0
 def _mhlo_s32(x):
   return mhlo.ConstOp(
       ir.DenseElementsAttr.get(np.array(x, dtype=np.int32),
                                type=ir.IntegerType.get_signless(32))).result
コード例 #9
0
 def _mhlo_u8(x):
   return mhlo.ConstOp(
       ir.DenseElementsAttr.get(np.array(x, dtype=np.uint8),
                                type=ir.IntegerType.get_unsigned(8))).result