# flatbuffers needs importlib.util but fails to import it itself. import importlib.util # noqa: F401 from typing import List import jaxlib.mlir.ir as ir import jaxlib.mlir.dialects.mhlo as mhlo from . import _pocketfft from . import pocketfft_flatbuffers_py_generated as pd import numpy as np import flatbuffers from jaxlib import xla_client for _name, _value in _pocketfft.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="cpu") FftType = xla_client.FftType flatbuffers_version_2 = hasattr(flatbuffers, "__version__") def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType, fft_lengths: List[int]) -> bytes: n = len(shape) assert len(fft_lengths) >= 1 assert len(fft_lengths) <= n, (fft_lengths, n) builder = flatbuffers.Builder(128) forward = fft_type in (FftType.FFT, FftType.RFFT)
from functools import partial
import operator

import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.mhlo as mhlo

import numpy as np

from jaxlib import xla_client

from .mhlo_helpers import custom_call

# Each optional GPU extension below is imported inside a try/except ImportError.
# If the import succeeds, every custom-call target returned by the extension's
# registrations() is registered with XLA for the matching platform ("CUDA" or
# "ROCM"). If it fails, the extension name is bound to None instead of raising
# (as done for _cublas and _cusolver).
# NOTE(review): the registration loop also sits inside each try, so an
# ImportError raised while registering would be treated like a missing
# extension -- confirm that is intended.

# cuBLAS extension (.cuda._cublas), registered for platform "CUDA".
try:
  from .cuda import _cublas
  for _name, _value in _cublas.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="CUDA")
except ImportError:
  _cublas = None

# cuSOLVER extension (.cuda._cusolver), registered for platform "CUDA".
try:
  from .cuda import _cusolver
  for _name, _value in _cusolver.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="CUDA")
except ImportError:
  _cusolver = None

# hipBLAS extension (.rocm._hipblas), registered for platform "ROCM".
try:
  from .rocm import _hipblas
  for _name, _value in _hipblas.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="ROCM")
except ImportError: