def dex_call_cpu_translation(b, *args, func_atom):
    """XLA CPU translation rule for calling a compiled Dex function.

    Builds (and caches, keyed on ``func_atom``) a ctypes trampoline that XLA
    invokes as a custom call, then emits a ``CustomCall`` op into builder ``b``.

    Args:
        b: XLA builder for the computation being constructed.
        *args: XLA operands of the call.
        func_atom: The Dex function atom; also the cache key.

    Returns:
        The XLA ``CustomCall`` op producing the Dex function's result.
    """
    xla_shapes = list(map(b.get_shape, args))
    # Abstract-eval on ShapedArrays derived from the operand shapes; also
    # yields the concrete values of the function's shape variables.
    result_aval, shape_vars = dex_call_abstract_eval_with_shape(
        *(jax.core.ShapedArray(xshape.dimensions(), xshape.numpy_dtype())
          for xshape in xla_shapes),
        func_atom=func_atom)
    result_xshape = xc.Shape.array_shape(result_aval.dtype, result_aval.shape)
    custom_call = custom_call_cache.get(func_atom, None)
    native = get_compiled(func_atom)
    if custom_call is None:
        assert len(args) == len(native.explicit_argument_signature)
        # Only single-result functions are supported here.
        assert 1 == len(native.result_signature)
        # Callback signature: (result_ptr, pointer-to-array-of-arg-pointers).
        custom_call_ctype = ctypes.CFUNCTYPE(
            None, ctypes.c_void_p,
            ctypes.POINTER(ctypes.c_void_p * len(args)))

        @custom_call_ctype
        def trampoline(result_ptr, arg_ptr_array):
            # Implicit shape-variable arguments come first, as IdxRepTy values.
            name_to_cval = {
                name: IdxRepTy(value) for name, value in shape_vars.items()
            }
            # Convert each raw XLA buffer pointer to the ctype the native
            # Dex entry point expects for that binder.
            for binder, ptr in zip(native.explicit_argument_signature,
                                   arg_ptr_array.contents):
                if isinstance(binder.type, ScalarType):
                    # Scalars are passed by value: dereference the pointer.
                    cval = ctypes.cast(
                        ptr, ctypes.POINTER(binder.type.arg_ctype)).contents
                elif isinstance(binder.type, RectContArrayType):
                    # Arrays are passed as pointers.
                    cval = ctypes.cast(ptr, binder.type.arg_ctype)
                else:
                    raise AssertionError("Unexpected binder type")
                name_to_cval[binder.name] = cval
            # The (single) result is written through an out-pointer.
            result_binder = native.result_signature[0]
            name_to_cval[result_binder.name] = ctypes.cast(
                result_ptr, result_binder.type.ref_ctype)
            # Invoke with arguments ordered per the native calling convention.
            native.callable(*(name_to_cval[name]
                              for name in native.ccall_signature))

        trampoline_addr = ctypes.c_void_p.from_param(trampoline)
        # Fresh globally-unique target name for this func_atom.
        custom_call_name = f"dex_custom_call{next(custom_call_id)}".encode(
            'ascii')
        xc.register_custom_call_target(
            custom_call_name, make_custom_call_target(trampoline_addr))
        # Cache keeps the trampoline object alive: if it were collected, XLA
        # would call into freed memory.
        custom_call_cache[func_atom] = (custom_call_name, trampoline)
        # TODO: Unregister custom calls at some point?
    else:
        custom_call_name, *_ = custom_call
    return xc.ops.CustomCall(b, custom_call_name, operands=args,
                             shape=result_xshape)
import numpy as np
import jax.numpy as jnp
import functools
import itertools
import operator

from jax import abstract_arrays
from jax.core import Primitive
from jax.lib import xla_client
from jax.interpreters import xla

# The CUDA extension is optional (only present when the native .so kernels
# were built). The import MUST live inside the try-block: previously it sat
# before the `try`, so the `except ImportError` below could never fire and a
# missing extension crashed the whole module at import time.
try:
    import superbee_cuda

    for _name, _value in superbee_cuda.gpu_custom_call_targets.items():
        xla_client.register_custom_call_target(_name, _value, platform="gpu")
except ImportError:
    print("could not import cuda_superbee_kernels. Are .so files present?")


def _prod(xs):
    """Return the product of the elements of `xs` (1 for an empty sequence)."""
    return functools.reduce(operator.mul, xs, 1)


# TODO(phawkins): remove after we no longer need to support old jax releases.
def _unpack_builder(c):
    # If `c` is a ComputationBuilder object, extracts the underlying XlaBuilder.
    return getattr(c, "_builder", c)


def _superbee(builder, var, u_wgrid, v_wgrid, w_wgrid, maskW, dxt, dyt, dzw,
              cost, cosu, dt_tracer):
    """Superbee kernel for GPU."""
# Module-level mpi4jax initialization: configure logging and GPU copy
# behavior from environment variables, then register the MPI custom-call
# targets with XLA.

# Enable bridge logging when MPI4JAX_DEBUG is set to a truthy value.
enable_logging = is_truthy(os.environ.get("MPI4JAX_DEBUG", ""))
mpi_xla_bridge.set_logging(enable_logging)

if HAS_GPU_EXT:
    # Decide whether the MPI library can read GPU buffers directly
    # (CUDA-aware MPI). Unset/unrecognized values default to host copies,
    # with a warning explaining how to silence or override it.
    gpu_copy_behavior = os.environ.get("MPI4JAX_USE_CUDA_MPI", "")

    if is_truthy(gpu_copy_behavior):
        has_cuda_mpi = True
    elif is_falsy(gpu_copy_behavior):
        has_cuda_mpi = False
    else:
        has_cuda_mpi = False
        warn_msg = (
            "Not using CUDA-enabled MPI. "
            "If you are sure that your MPI library is built with CUDA support, "
            "set MPI4JAX_USE_CUDA_MPI=1. To silence this warning, "
            "set MPI4JAX_USE_CUDA_MPI=0.")
        warnings.warn(warn_msg)

    # When MPI is not CUDA-aware, device buffers are staged through the host.
    mpi_xla_bridge_gpu.set_copy_to_host(not has_cuda_mpi)

# register custom call targets
for name, fn in mpi_xla_bridge_cpu.cpu_custom_call_targets.items():
    xla_client.register_custom_call_target(name, fn, platform="cpu")

if HAS_GPU_EXT:
    for name, fn in mpi_xla_bridge_gpu.gpu_custom_call_targets.items():
        xla_client.register_custom_call_target(name, fn, platform="gpu")

# Keep the module namespace clean; these names are only needed during setup.
del os, xla_client
def _xla_translation_cpu(numba_fn, abstract_eval_fn, xla_builder, *args):
    """Returns the XLA CustomCall for the given numba function.

    Args:
        numba_fn: A numba function. For its signature, see the module
            docstring.
        abstract_eval_fn: The abstract shape evaluation function.
        xla_builder: The XlaBuilder instance.
        *args: The positional arguments to be passed to `numba_fn`.

    Returns:
        The XLA CustomCall operation calling into the numba function.

    Raises:
        NotImplementedError: if there are more than 4 inputs or more than 4
            outputs; the hand-unrolled trampoline below only covers up to 4
            of each.
    """
    if config.FLAGS["NETKET_DEBUG"]:
        print("Encoding the CPU variant of numba4jax function")

    input_shapes = [xla_builder.get_shape(arg) for arg in args]
    # TODO(josipd): Check that the input layout is the numpy default.
    output_abstract_arrays = abstract_eval_fn(
        *[_xla_shape_to_abstract(shape) for shape in input_shapes]
    )
    output_shapes = tuple(array.shape for array in output_abstract_arrays)
    output_dtypes = tuple(array.dtype for array in output_abstract_arrays)

    # Every output uses the default row-major (C) layout.
    layout_for_shape = lambda shape: range(len(shape) - 1, -1, -1)
    output_layouts = map(layout_for_shape, output_shapes)
    xla_output_shapes = [
        xla_client.Shape.array_shape(*arg)
        for arg in zip(output_dtypes, output_shapes, output_layouts)
    ]
    xla_output_shape = xla_client.Shape.tuple_shape(xla_output_shapes)

    input_dtypes = tuple(shape.element_type() for shape in input_shapes)
    input_dimensions = tuple(shape.dimensions() for shape in input_shapes)

    n_out = len(output_shapes)
    n_in = len(input_dimensions)
    # The callback below unrolls buffer unpacking by hand, so only a bounded
    # number of buffers is supported. Fail early with a clear message instead
    # of the obscure numba typing error an unbound `args_out`/`args_in`
    # would otherwise produce.
    if n_out > 4 or n_in > 4:
        raise NotImplementedError(
            "numba4jax CPU custom calls support at most 4 inputs and "
            f"4 outputs (got {n_in} inputs and {n_out} outputs)."
        )

    xla_call_sig = nb_types.void(
        nb_types.CPointer(nb_types.voidptr),  # output_ptrs
        nb_types.CPointer(nb_types.voidptr),  # input_ptrs
    )

    @numba.cfunc(xla_call_sig)
    def xla_custom_call_target(output_ptrs, input_ptrs):
        # Manually unroll input and output args because numba is relatively
        # dumb and cannot always infer getitem on inhomogeneous tuples.
        if n_out == 1:
            args_out = (
                numba.carray(output_ptrs[0], output_shapes[0], dtype=output_dtypes[0]),
            )
        elif n_out == 2:
            args_out = (
                numba.carray(output_ptrs[0], output_shapes[0], dtype=output_dtypes[0]),
                numba.carray(output_ptrs[1], output_shapes[1], dtype=output_dtypes[1]),
            )
        elif n_out == 3:
            args_out = (
                numba.carray(output_ptrs[0], output_shapes[0], dtype=output_dtypes[0]),
                numba.carray(output_ptrs[1], output_shapes[1], dtype=output_dtypes[1]),
                numba.carray(output_ptrs[2], output_shapes[2], dtype=output_dtypes[2]),
            )
        elif n_out == 4:
            args_out = (
                numba.carray(output_ptrs[0], output_shapes[0], dtype=output_dtypes[0]),
                numba.carray(output_ptrs[1], output_shapes[1], dtype=output_dtypes[1]),
                numba.carray(output_ptrs[2], output_shapes[2], dtype=output_dtypes[2]),
                numba.carray(output_ptrs[3], output_shapes[3], dtype=output_dtypes[3]),
            )

        if n_in == 1:
            args_in = (
                numba.carray(input_ptrs[0], input_dimensions[0], dtype=input_dtypes[0]),
            )
        elif n_in == 2:
            args_in = (
                numba.carray(input_ptrs[0], input_dimensions[0], dtype=input_dtypes[0]),
                numba.carray(input_ptrs[1], input_dimensions[1], dtype=input_dtypes[1]),
            )
        elif n_in == 3:
            args_in = (
                numba.carray(input_ptrs[0], input_dimensions[0], dtype=input_dtypes[0]),
                numba.carray(input_ptrs[1], input_dimensions[1], dtype=input_dtypes[1]),
                numba.carray(input_ptrs[2], input_dimensions[2], dtype=input_dtypes[2]),
            )
        elif n_in == 4:
            args_in = (
                numba.carray(input_ptrs[0], input_dimensions[0], dtype=input_dtypes[0]),
                numba.carray(input_ptrs[1], input_dimensions[1], dtype=input_dtypes[1]),
                numba.carray(input_ptrs[2], input_dimensions[2], dtype=input_dtypes[2]),
                numba.carray(input_ptrs[3], input_dimensions[3], dtype=input_dtypes[3]),
            )

        numba_fn(args_out + args_in)

    # Register the compiled trampoline under its (unique) native symbol name.
    target_name = xla_custom_call_target.native_name.encode("ascii")
    capsule = _create_xla_target_capsule(xla_custom_call_target.address)
    xla_client.register_custom_call_target(target_name, capsule, "cpu")

    return xla_client.ops.CustomCallWithLayout(
        xla_builder,
        target_name,
        operands=args,
        shape_with_layout=xla_output_shape,
        operand_shapes_with_layout=input_shapes,
    )
def _SetupJax(self):
    """Set up the JAX primitives used internally by TFC.

    Registers the basis-function XLA custom call backed by
    ``self.basisClass.xlaCapsule`` and defines the ``H`` primitive with its
    impl, abstract-eval, XLA, batching, and JVP rules, so that derivatives
    of TFC constrained expressions behave as desired. The resulting binder
    is stored on ``self._Hjax``.
    """

    # Helper functions for embedding constants into the XLA graph.
    def _constant_bool(c, a):
        return xla_client.ops.Constant(c, bool(a))

    def _constant_s32_scalar(c, a):
        return xla_client.ops.Constant(c, int(a))

    def _unpack_builder(c):
        # If `c` is a ComputationBuilder object, extracts the underlying XlaBuilder.
        return getattr(c, "_builder", c)

    # Register XLA function.
    obj = self.basisClass.xlaCapsule
    xlaName = "BasisFunc" + str(self.basisClass.identifier)
    xlaName = xlaName.encode("utf-8")
    xla_client.register_custom_call_target(xlaName, obj, platform="cpu")

    # Create primitives.
    H_p = core.Primitive("H")

    def Hjax(x, d=0, full=False):
        return H_p.bind(x, d=d, full=full)

    # Concrete (eager) implementation.
    def H_impl(x, d=0, full=False):
        return self.basisClass.H(x, d, full)

    H_p.def_impl(H_impl)

    # Abstract evaluation: (dim1,) for scalar x, (x.shape[0], dim1) otherwise.
    def H_abstract_eval(x, d=0, full=False):
        if full:
            dim1 = self.basisClass.m
        else:
            dim1 = self.basisClass.m - self.basisClass.numC
        if len(x.shape) == 0:
            dims = (dim1, )
        else:
            dims = (x.shape[0], dim1)
        return abstract_arrays.ShapedArray(dims, x.dtype)

    H_p.def_abstract_eval(H_abstract_eval)

    # XLA compilation: lower to the registered custom call.
    def H_xla(c, x, d=0, full=False):
        c = _unpack_builder(c)
        x_shape = c.get_shape(x)
        dims = x_shape.dimensions()
        dtype = x_shape.element_type()
        dim0 = dims[0]
        if full:
            dim1 = self.basisClass.m
        else:
            dim1 = self.basisClass.m - self.basisClass.numC
        return xla_client.ops.CustomCall(
            c,
            xlaName,
            (
                _constant_s32_scalar(c, self.basisClass.identifier),
                x,
                _constant_s32_scalar(c, d),
                _constant_bool(c, full),
                _constant_s32_scalar(c, dim0),
                _constant_s32_scalar(c, dim1),
            ),
            xla_client.Shape.array_shape(dtype, (dim0, dim1)),
        )

    xla.backend_specific_translations["cpu"][H_p] = H_xla

    # Define batching translation.
    def H_batch(vec, batch, d=0, full=False):
        return Hjax(*vec, d=d, full=full), batch[0]

    batching.primitive_batchers[H_p] = H_batch

    # Define jacobian vector product: d/dx H(x, d) = H(x, d + 1) * dx.
    def H_jvp(arg_vals, arg_tans, d=0, full=False):
        x = arg_vals[0]
        dx = arg_tans[0]

        def _zero_tangent():
            # Zero tangent with the primal output's shape.
            dim0 = x.shape[0]
            if full:
                dim1 = self.basisClass.m
            else:
                dim1 = self.basisClass.m - self.basisClass.numC
            return np.zeros((dim0, dim1))

        if not (dx is ad.Zero):
            if type(dx) is batching.BatchTracer:
                flag = onp.any(dx.val != 0)
            else:
                flag = onp.any(dx != 0)
            if flag:
                if len(dx.shape) == 1:
                    out_tans = Hjax(x, d=d + 1, full=full) * onp.expand_dims(dx, 1)
                else:
                    out_tans = Hjax(x, d=d + 1, full=full) * dx
            else:
                out_tans = _zero_tangent()
        else:
            # BUG FIX: previously the symbolic-zero case left `out_tans`
            # unbound, raising UnboundLocalError at the return below. Emit a
            # zero tangent instead (matching SetupJAX, which pre-initializes
            # out_tans to zeros).
            out_tans = _zero_tangent()
        return (Hjax(x, d=d, full=full), out_tans)

    ad.primitive_jvps[H_p] = H_jvp

    # Provide pointer for TFC class.
    self._Hjax = Hjax
def SetupJAX(self): """This function is used internally by TFC to setup autograd primatives and create desired behavior when taking derivatives of TFC constrained expressions.""" # Helper functions def _constant_bool(c, a): return xla_client.ops.Constant(c, bool(a)) def _constant_s32_scalar(c, a): return xla_client.ops.Constant(c, int(a)) def _constant_array(c, a): return xla_client.ops.Constant(c, a) def _unpack_builder(c): # If `c` is a ComputationBuilder object, extracts the underlying XlaBuilder. return getattr(c, "_builder", c) d0 = onp.zeros(self.dim, dtype=np.int32) # Regiser XLA function obj = self.basisClass.xlaCapsule xlaName = "BasisFunc" + str(self.basisClass.identifier) xlaName = xlaName.encode("utf-8") xla_client.register_custom_call_target(xlaName, obj, platform="cpu") # Create Primitives H_p = core.Primitive("H") def Hjax(*x, d=d0, full=False): return H_p.bind(*x, d=d, full=full) # Implicit translations def H_impl(*x, d=d0, full=False): return self.basisClass.H(np.array(x), d, full) H_p.def_impl(H_impl) # Define abstract evaluation def H_abstract_eval(*x, d=d0, full=False): if full: dim1 = self.basisClass.numBasisFuncFull else: dim1 = self.basisClass.numBasisFunc if len(x[0].shape) == 0: dims = (dim1, ) else: dims = (x[0].shape[0], dim1) return abstract_arrays.ShapedArray(dims, x[0].dtype) H_p.def_abstract_eval(H_abstract_eval) # XLA compilation def H_xla(c, *x, d=d0, full=False): c = _unpack_builder(c) x_shape = c.get_shape(x[0]) dims = x_shape.dimensions() dtype = x_shape.element_type() dim0 = dims[0] if full: dim1 = self.basisClass.numBasisFuncFull else: dim1 = self.basisClass.numBasisFunc return xla_client.ops.CustomCall( c, xlaName, ( _constant_s32_scalar(c, self.basisClass.identifier), xla_client.ops.ConcatInDim(c, x, 0), _constant_array(c, d), _constant_s32_scalar(c, self.dim), _constant_bool(c, full), _constant_s32_scalar(c, dim0), _constant_s32_scalar(c, dim1), ), xla_client.Shape.array_shape(dtype, (dim0, dim1)), ) 
xla.backend_specific_translations["cpu"][H_p] = H_xla # Batching translation def H_batch(vec, batch, d=d0, full=False): return Hjax(*vec, d=d, full=full), batch[0] batching.primitive_batchers[H_p] = H_batch # Jacobian vector translation def H_jvp(arg_vals, arg_tans, d=d0, full=False): n = len(arg_vals) flat = len(arg_vals[0].shape) == 1 dim0 = arg_vals[0].shape[0] if full: dim1 = self.basisClass.numBasisFuncFull else: dim1 = self.basisClass.numBasisFunc out_tans = np.zeros((dim0, dim1)) for k in range(n): if not (type(arg_tans[k]) is ad.Zero): if type(arg_tans[k]) is batching.BatchTracer: flag = onp.any(arg_tans[k].val != 0) else: flag = onp.any(arg_tans[k] != 0) if flag: dark = copy(d) dark[k] += 1 if flat: out_tans += Hjax(*arg_vals, d=dark, full=full) * np.expand_dims( arg_tans[k], 1) else: out_tans += Hjax(*arg_vals, d=dark, full=full) * arg_tans[k] return (Hjax(*arg_vals, d=d, full=full), out_tans) ad.primitive_jvps[H_p] = H_jvp self._Hjax = Hjax