Example #1
0
# flatbuffers needs importlib.util but fails to import it itself.
import importlib.util  # noqa: F401
from typing import List

import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.mhlo as mhlo

from . import _pocketfft
from . import pocketfft_flatbuffers_py_generated as pd
import numpy as np

import flatbuffers
from jaxlib import xla_client

# Register every custom-call target exported by the native _pocketfft
# extension with XLA for the CPU platform. This runs at module import time;
# if any registrations exist, the loop variables _name/_value stay bound as
# module-level globals afterwards.
for _name, _value in _pocketfft.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="cpu")

# Alias of xla_client's FFT kind enum (e.g. FFT, RFFT), used in the
# _pocketfft_descriptor signature.
FftType = xla_client.FftType

# True when the installed flatbuffers package exposes ``__version__``.
# NOTE(review): presumably a proxy for "flatbuffers >= 2" (the name suggests
# older releases lacked the attribute) -- confirm against flatbuffers history.
flatbuffers_version_2 = hasattr(flatbuffers, "__version__")


def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
                          fft_lengths: List[int]) -> bytes:
    n = len(shape)
    assert len(fft_lengths) >= 1
    assert len(fft_lengths) <= n, (fft_lengths, n)

    builder = flatbuffers.Builder(128)

    forward = fft_type in (FftType.FFT, FftType.RFFT)
Example #2
0
from functools import partial
import operator

import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.mhlo as mhlo

import numpy as np

from jaxlib import xla_client

from .mhlo_helpers import custom_call

# Optional CUDA extensions. A failed import is tolerated and recorded as
# ``None`` so callers can test for availability. Registration of the
# extension's custom-call targets happens in the ``else`` clause, outside the
# ``try`` body, so that an ImportError raised *during* registration is not
# mistaken for the extension being absent.
try:
    from .cuda import _cublas
except ImportError:
    _cublas = None
else:
    for _name, _value in _cublas.registrations().items():
        xla_client.register_custom_call_target(_name, _value, platform="CUDA")

try:
    from .cuda import _cusolver
except ImportError:
    _cusolver = None
else:
    for _name, _value in _cusolver.registrations().items():
        xla_client.register_custom_call_target(_name, _value, platform="CUDA")

try:
    from .rocm import _hipblas
    for _name, _value in _hipblas.registrations().items():
        xla_client.register_custom_call_target(_name, _value, platform="ROCM")
except ImportError: