Example #1
def dex_call_cpu_translation(b, *args, func_atom):
    xla_shapes = list(map(b.get_shape, args))
    result_aval, shape_vars = dex_call_abstract_eval_with_shape(
        *(jax.core.ShapedArray(xshape.dimensions(), xshape.numpy_dtype())
          for xshape in xla_shapes),
        func_atom=func_atom)
    result_xshape = xc.Shape.array_shape(result_aval.dtype, result_aval.shape)

    custom_call = custom_call_cache.get(func_atom, None)
    native = get_compiled(func_atom)
    if custom_call is None:
        assert len(args) == len(native.explicit_argument_signature)
        assert 1 == len(native.result_signature)
        custom_call_ctype = ctypes.CFUNCTYPE(
            None, ctypes.c_void_p, ctypes.POINTER(ctypes.c_void_p * len(args)))

        @custom_call_ctype
        def trampoline(result_ptr, arg_ptr_array):
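            # Build a name -> ctypes value map: inferred shape variables, each
            # explicit argument (scalars dereferenced, arrays passed as typed
            # pointers), and the result buffer; then invoke the compiled Dex
            # function in its ccall signature order.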
            name_to_cval = {
                name: IdxRepTy(value)
                for name, value in shape_vars.items()
            }
            for binder, ptr in zip(native.explicit_argument_signature,
                                   arg_ptr_array.contents):
                if isinstance(binder.type, ScalarType):
                    cval = ctypes.cast(ptr,
                                       ctypes.POINTER(
                                           binder.type.arg_ctype)).contents
                elif isinstance(binder.type, RectContArrayType):
                    cval = ctypes.cast(ptr, binder.type.arg_ctype)
                else:
                    raise AssertionError("Unexpected binder type")
                name_to_cval[binder.name] = cval
            result_binder = native.result_signature[0]
            name_to_cval[result_binder.name] = ctypes.cast(
                result_ptr, result_binder.type.ref_ctype)
            native.callable(*(name_to_cval[name]
                              for name in native.ccall_signature))

        trampoline_addr = ctypes.c_void_p.from_param(trampoline)
        custom_call_name = f"dex_custom_call{next(custom_call_id)}".encode(
            'ascii')
        xc.register_custom_call_target(
            custom_call_name, make_custom_call_target(trampoline_addr))
        custom_call_cache[func_atom] = (custom_call_name, trampoline)
        # TODO: Unregister custom calls at some point?
    else:
        custom_call_name, *_ = custom_call
    return xc.ops.CustomCall(b,
                             custom_call_name,
                             operands=args,
                             shape=result_xshape)
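The translation rule above still has to be attached to a JAX primitive before jit-traced calls can reach it. A minimal sketch of that wiring, assuming the same legacy `xla.backend_specific_translations` hook used in the later examples; the `dex_call_p` primitive, the abstract-eval wrapper, and the `dex_call` entry point are illustrative names, not part of the original module:

from jax import core
from jax.interpreters import xla

dex_call_p = core.Primitive("dex_call")
# Reuse the shape-inference helper defined above; only the aval is needed here.
dex_call_p.def_abstract_eval(
    lambda *avals, func_atom: dex_call_abstract_eval_with_shape(
        *avals, func_atom=func_atom)[0])
xla.backend_specific_translations["cpu"][dex_call_p] = dex_call_cpu_translation

def dex_call(*args, func_atom):
    # Traced entry point; func_atom reaches the translation rule as a primitive parameter.
    return dex_call_p.bind(*args, func_atom=func_atom)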
Example #2
import numpy as np
import jax.numpy as jnp
import functools
import itertools
import operator

from jax import abstract_arrays
from jax.core import Primitive
from jax.lib import xla_client
from jax.interpreters import xla
try:
    import superbee_cuda

    for _name, _value in superbee_cuda.gpu_custom_call_targets.items():
        xla_client.register_custom_call_target(_name, _value, platform="gpu")
except ImportError:
    print("could not import cuda_superbee_kernels. Are .so files present?")

_prod = lambda xs: functools.reduce(operator.mul, xs, 1)


# TODO(phawkins): remove after we no longer need to support old jax releases.
def _unpack_builder(c):
    # If `c` is a ComputationBuilder object, extracts the underlying XlaBuilder.
    return getattr(c, "_builder", c)


def _superbee(builder, var, u_wgrid, v_wgrid, w_wgrid, maskW, dxt, dyt, dzw,
              cost, cosu, dt_tracer):
    """Superbee kernel for GPU."""
Example #3
enable_logging = is_truthy(os.environ.get("MPI4JAX_DEBUG", ""))
mpi_xla_bridge.set_logging(enable_logging)

if HAS_GPU_EXT:
    gpu_copy_behavior = os.environ.get("MPI4JAX_USE_CUDA_MPI", "")

    if is_truthy(gpu_copy_behavior):
        has_cuda_mpi = True
    elif is_falsy(gpu_copy_behavior):
        has_cuda_mpi = False
    else:
        has_cuda_mpi = False
        warn_msg = (
            "Not using CUDA-enabled MPI. "
            "If you are sure that your MPI library is built with CUDA support, "
            "set MPI4JAX_USE_CUDA_MPI=1. To silence this warning, "
            "set MPI4JAX_USE_CUDA_MPI=0.")
        warnings.warn(warn_msg)

    mpi_xla_bridge_gpu.set_copy_to_host(not has_cuda_mpi)

# register custom call targets
for name, fn in mpi_xla_bridge_cpu.cpu_custom_call_targets.items():
    xla_client.register_custom_call_target(name, fn, platform="cpu")

if HAS_GPU_EXT:
    for name, fn in mpi_xla_bridge_gpu.gpu_custom_call_targets.items():
        xla_client.register_custom_call_target(name, fn, platform="gpu")

del os, xla_client
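The `is_truthy` and `is_falsy` helpers used above are not part of this snippet. A minimal sketch of what they might look like, assuming they simply classify the usual environment-variable spellings (the accepted values are an assumption, not taken from mpi4jax):

def is_truthy(s):
    # Hypothetical helper: common "enabled" spellings.
    return s.lower() in ("1", "true", "on", "yes")


def is_falsy(s):
    # Hypothetical helper: common "disabled" spellings.
    return s.lower() in ("0", "false", "off", "no")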
Example #4
def _xla_translation_cpu(numba_fn, abstract_eval_fn, xla_builder, *args):
    """Returns the XLA CustomCall for the given numba function.

    Args:
      numba_fn: A numba function. For its signature, see the module docstring.
      abstract_eval_fn: The abstract shape evaluation function.
      xla_builder: The XlaBuilder instance.
      *args: The positional arguments to be passed to `numba_fn`.
    Returns:
      The XLA CustomCall operation calling into the numba function.
    """

    if config.FLAGS["NETKET_DEBUG"]:
        print("Encoding the CPU variant of numba4jax function")

    input_shapes = [xla_builder.get_shape(arg) for arg in args]
    # TODO(josipd): Check that the input layout is the numpy default.
    output_abstract_arrays = abstract_eval_fn(
        *[_xla_shape_to_abstract(shape) for shape in input_shapes]
    )
    output_shapes = tuple(array.shape for array in output_abstract_arrays)
    output_shapes_flattened = tuple(
        dim for array in output_abstract_arrays for dim in array.shape
    )
    output_ndims = tuple(array.ndim for array in output_abstract_arrays)
    output_ndims_offsets = tuple(np.cumsum(np.concatenate([[0], output_ndims])))
    output_dtypes = tuple(array.dtype for array in output_abstract_arrays)
    layout_for_shape = lambda shape: range(len(shape) - 1, -1, -1)
    output_layouts = map(layout_for_shape, output_shapes)
    xla_output_shapes = [
        xla_client.Shape.array_shape(*arg)
        for arg in zip(output_dtypes, output_shapes, output_layouts)
    ]
    xla_output_shape = xla_client.Shape.tuple_shape(xla_output_shapes)

    input_dtypes = tuple(shape.element_type() for shape in input_shapes)
    input_dimensions = tuple(shape.dimensions() for shape in input_shapes)

    output_i = tuple(range(len(output_shapes)))
    input_i = tuple(range(len(input_dimensions)))

    n_out = len(output_shapes)
    n_in = len(input_dimensions)

    xla_call_sig = nb_types.void(
        nb_types.CPointer(nb_types.voidptr),  # output_ptrs
        nb_types.CPointer(nb_types.voidptr),  # input_ptrs
    )

    @numba.cfunc(xla_call_sig)
    def xla_custom_call_target(output_ptrs, input_ptrs):
        # manually unroll input and output args because numba is
        # relatively dumb and cannot always infer getitem on inhomogeneous tuples
        if n_out == 1:
            args_out = (
                numba.carray(output_ptrs[0], output_shapes[0], dtype=output_dtypes[0]),
            )
        elif n_out == 2:
            args_out = (
                numba.carray(output_ptrs[0], output_shapes[0], dtype=output_dtypes[0]),
                numba.carray(output_ptrs[1], output_shapes[1], dtype=output_dtypes[1]),
            )
        elif n_out == 3:
            args_out = (
                numba.carray(output_ptrs[0], output_shapes[0], dtype=output_dtypes[0]),
                numba.carray(output_ptrs[1], output_shapes[1], dtype=output_dtypes[1]),
                numba.carray(output_ptrs[2], output_shapes[2], dtype=output_dtypes[2]),
            )
        elif n_out == 4:
            args_out = (
                numba.carray(output_ptrs[0], output_shapes[0], dtype=output_dtypes[0]),
                numba.carray(output_ptrs[1], output_shapes[1], dtype=output_dtypes[1]),
                numba.carray(output_ptrs[2], output_shapes[2], dtype=output_dtypes[2]),
                numba.carray(output_ptrs[3], output_shapes[3], dtype=output_dtypes[3]),
            )

        if n_in == 1:
            args_in = (
                numba.carray(input_ptrs[0], input_dimensions[0], dtype=input_dtypes[0]),
            )
        elif n_in == 2:
            args_in = (
                numba.carray(input_ptrs[0], input_dimensions[0], dtype=input_dtypes[0]),
                numba.carray(input_ptrs[1], input_dimensions[1], dtype=input_dtypes[1]),
            )
        elif n_in == 3:
            args_in = (
                numba.carray(input_ptrs[0], input_dimensions[0], dtype=input_dtypes[0]),
                numba.carray(input_ptrs[1], input_dimensions[1], dtype=input_dtypes[1]),
                numba.carray(input_ptrs[2], input_dimensions[2], dtype=input_dtypes[2]),
            )
        elif n_in == 4:
            args_in = (
                numba.carray(input_ptrs[0], input_dimensions[0], dtype=input_dtypes[0]),
                numba.carray(input_ptrs[1], input_dimensions[1], dtype=input_dtypes[1]),
                numba.carray(input_ptrs[2], input_dimensions[2], dtype=input_dtypes[2]),
                numba.carray(input_ptrs[3], input_dimensions[3], dtype=input_dtypes[3]),
            )

        numba_fn(args_out + args_in)

    target_name = xla_custom_call_target.native_name.encode("ascii")
    capsule = _create_xla_target_capsule(xla_custom_call_target.address)
    xla_client.register_custom_call_target(target_name, capsule, "cpu")
    # xla_extension.register_custom_call_target(target_name, capsule, "Host")
    return xla_client.ops.CustomCallWithLayout(
        xla_builder,
        target_name,
        operands=args,
        shape_with_layout=xla_output_shape,
        operand_shapes_with_layout=input_shapes,
    )
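For context, a hypothetical wiring of this translation rule into a primitive, using the same legacy `xla.backend_specific_translations` hook as the other examples. The toy kernel, its abstract-eval function, and the primitive name are illustrative; the only convention taken from the code above is the call `numba_fn(args_out + args_in)`, i.e. the kernel receives a single tuple with the output arrays first and fills them in place:

import numba
from functools import partial
from jax import core, abstract_arrays
from jax.interpreters import xla


@numba.njit
def _double_kernel(args):
    # args == (out, x): write 2*x into the preallocated output buffer.
    out, x = args
    out[:] = 2.0 * x


def _double_abstract_eval(x):
    # One output with the same shape and dtype as the input.
    return (abstract_arrays.ShapedArray(x.shape, x.dtype),)


double_p = core.Primitive("numba_double")
double_p.multiple_results = True
double_p.def_abstract_eval(_double_abstract_eval)
xla.backend_specific_translations["cpu"][double_p] = partial(
    _xla_translation_cpu, _double_kernel, _double_abstract_eval)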
Example #5
    def _SetupJax(self):
        """This function is used internally by TFC to setup JAX primatives and create desired behavior when taking derivatives of TFC constrained expressions."""

        # Helper functions
        def _constant_bool(c, a):
            return xla_client.ops.Constant(c, bool(a))

        def _constant_s32_scalar(c, a):
            return xla_client.ops.Constant(c, int(a))

        def _unpack_builder(c):
            # If `c` is a ComputationBuilder object, extracts the underlying XlaBuilder.
            return getattr(c, "_builder", c)

        # Register the XLA custom call target
        obj = self.basisClass.xlaCapsule
        xlaName = "BasisFunc" + str(self.basisClass.identifier)
        xlaName = xlaName.encode("utf-8")
        xla_client.register_custom_call_target(xlaName, obj, platform="cpu")

        # Create primitives
        H_p = core.Primitive("H")

        def Hjax(x, d=0, full=False):
            return H_p.bind(x, d=d, full=full)

        # Concrete implementation (eager evaluation)
        def H_impl(x, d=0, full=False):
            return self.basisClass.H(x, d, full)

        H_p.def_impl(H_impl)

        # Abstract evaluation
        def H_abstract_eval(x, d=0, full=False):
            if full:
                dim1 = self.basisClass.m
            else:
                dim1 = self.basisClass.m - self.basisClass.numC
            if len(x.shape) == 0:
                dims = (dim1, )
            else:
                dims = (x.shape[0], dim1)
            return abstract_arrays.ShapedArray(dims, x.dtype)

        H_p.def_abstract_eval(H_abstract_eval)

        # XLA compilation
        def H_xla(c, x, d=0, full=False):
            c = _unpack_builder(c)
            x_shape = c.get_shape(x)
            dims = x_shape.dimensions()
            dtype = x_shape.element_type()
            dim0 = dims[0]
            if full:
                dim1 = self.basisClass.m
            else:
                dim1 = self.basisClass.m - self.basisClass.numC
            return xla_client.ops.CustomCall(
                c,
                xlaName,
                (
                    _constant_s32_scalar(c, self.basisClass.identifier),
                    x,
                    _constant_s32_scalar(c, d),
                    _constant_bool(c, full),
                    _constant_s32_scalar(c, dim0),
                    _constant_s32_scalar(c, dim1),
                ),
                xla_client.Shape.array_shape(dtype, (dim0, dim1)),
            )

        xla.backend_specific_translations["cpu"][H_p] = H_xla

        # Define batching translation
        def H_batch(vec, batch, d=0, full=False):
            return Hjax(*vec, d=d, full=full), batch[0]

        batching.primitive_batchers[H_p] = H_batch

        # Define Jacobian-vector product
        def H_jvp(arg_vals, arg_tans, d=0, full=False):
            x = arg_vals[0]
            dx = arg_tans[0]
            # Decide whether the incoming tangent carries any nonzero data.
            if dx is ad.Zero:
                nonzero = False
            elif type(dx) is batching.BatchTracer:
                nonzero = onp.any(dx.val != 0)
            else:
                nonzero = onp.any(dx != 0)
            if nonzero:
                if len(dx.shape) == 1:
                    out_tans = Hjax(x, d=d + 1,
                                    full=full) * onp.expand_dims(dx, 1)
                else:
                    out_tans = Hjax(x, d=d + 1, full=full) * dx
            else:
                # Zero tangent: return zeros matching the output shape.
                if full:
                    dim1 = self.basisClass.m
                else:
                    dim1 = self.basisClass.m - self.basisClass.numC
                dims = (dim1, ) if len(x.shape) == 0 else (x.shape[0], dim1)
                out_tans = np.zeros(dims)
            return (Hjax(x, d=d, full=full), out_tans)

        ad.primitive_jvps[H_p] = H_jvp

        # Provide pointer for TFC class
        self._Hjax = Hjax
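A hypothetical usage sketch of the resulting `_Hjax` attribute, assuming an instance named `tfc` on which `_SetupJax` has already run and reusing the snippet's `onp` alias for plain NumPy (the instance name and input values are illustrative):

import jax

x = onp.linspace(0.0, 1.0, 10)
H = tfc._Hjax(x)                  # basis matrix, shape (10, dim1)
Hd = tfc._Hjax(x, d=1)            # first-derivative basis via the same custom call
_, dH = jax.jvp(tfc._Hjax, (x,), (onp.ones_like(x),))  # routed through H_jvp above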
Example #6
File: mtfc.py Project: leakec/tfc
    def SetupJAX(self):
        """This function is used internally by TFC to setup autograd primatives and create desired behavior when taking derivatives of TFC constrained expressions."""

        # Helper functions
        def _constant_bool(c, a):
            return xla_client.ops.Constant(c, bool(a))

        def _constant_s32_scalar(c, a):
            return xla_client.ops.Constant(c, int(a))

        def _constant_array(c, a):
            return xla_client.ops.Constant(c, a)

        def _unpack_builder(c):
            # If `c` is a ComputationBuilder object, extracts the underlying XlaBuilder.
            return getattr(c, "_builder", c)

        d0 = onp.zeros(self.dim, dtype=np.int32)

        # Register the XLA custom call target
        obj = self.basisClass.xlaCapsule
        xlaName = "BasisFunc" + str(self.basisClass.identifier)
        xlaName = xlaName.encode("utf-8")
        xla_client.register_custom_call_target(xlaName, obj, platform="cpu")

        # Create Primitives
        H_p = core.Primitive("H")

        def Hjax(*x, d=d0, full=False):
            return H_p.bind(*x, d=d, full=full)

        # Concrete implementation (eager evaluation)
        def H_impl(*x, d=d0, full=False):
            return self.basisClass.H(np.array(x), d, full)

        H_p.def_impl(H_impl)

        # Define abstract evaluation
        def H_abstract_eval(*x, d=d0, full=False):
            if full:
                dim1 = self.basisClass.numBasisFuncFull
            else:
                dim1 = self.basisClass.numBasisFunc
            if len(x[0].shape) == 0:
                dims = (dim1, )
            else:
                dims = (x[0].shape[0], dim1)
            return abstract_arrays.ShapedArray(dims, x[0].dtype)

        H_p.def_abstract_eval(H_abstract_eval)

        # XLA compilation
        def H_xla(c, *x, d=d0, full=False):
            c = _unpack_builder(c)
            x_shape = c.get_shape(x[0])
            dims = x_shape.dimensions()
            dtype = x_shape.element_type()
            dim0 = dims[0]
            if full:
                dim1 = self.basisClass.numBasisFuncFull
            else:
                dim1 = self.basisClass.numBasisFunc
            return xla_client.ops.CustomCall(
                c,
                xlaName,
                (
                    _constant_s32_scalar(c, self.basisClass.identifier),
                    xla_client.ops.ConcatInDim(c, x, 0),
                    _constant_array(c, d),
                    _constant_s32_scalar(c, self.dim),
                    _constant_bool(c, full),
                    _constant_s32_scalar(c, dim0),
                    _constant_s32_scalar(c, dim1),
                ),
                xla_client.Shape.array_shape(dtype, (dim0, dim1)),
            )

        xla.backend_specific_translations["cpu"][H_p] = H_xla

        # Batching translation
        def H_batch(vec, batch, d=d0, full=False):
            return Hjax(*vec, d=d, full=full), batch[0]

        batching.primitive_batchers[H_p] = H_batch

        # Jacobian vector translation
        def H_jvp(arg_vals, arg_tans, d=d0, full=False):
            n = len(arg_vals)
            flat = len(arg_vals[0].shape) == 1
            dim0 = arg_vals[0].shape[0]
            if full:
                dim1 = self.basisClass.numBasisFuncFull
            else:
                dim1 = self.basisClass.numBasisFunc
            out_tans = np.zeros((dim0, dim1))
            for k in range(n):
                if not (type(arg_tans[k]) is ad.Zero):
                    if type(arg_tans[k]) is batching.BatchTracer:
                        flag = onp.any(arg_tans[k].val != 0)
                    else:
                        flag = onp.any(arg_tans[k] != 0)
                    if flag:
                        dark = copy(d)
                        dark[k] += 1
                        if flat:
                            out_tans += Hjax(*arg_vals, d=dark,
                                             full=full) * np.expand_dims(
                                                 arg_tans[k], 1)
                        else:
                            out_tans += Hjax(*arg_vals, d=dark,
                                             full=full) * arg_tans[k]
            return (Hjax(*arg_vals, d=d, full=full), out_tans)

        ad.primitive_jvps[H_p] = H_jvp

        self._Hjax = Hjax
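A hypothetical usage sketch of the multivariate version, assuming a two-dimensional basis (`self.dim == 2`), an instance named `tfc`, and the snippet's `onp` alias for NumPy; note that `Hjax` takes the coordinate arrays as separate positional arguments, which the translation rule concatenates with `ConcatInDim`:

x = onp.linspace(0.0, 1.0, 20)
y = onp.linspace(0.0, 2.0, 20)
H = tfc._Hjax(x, y)                                          # shape (20, numBasisFunc)
Hx = tfc._Hjax(x, y, d=onp.array([1, 0], dtype=onp.int32))   # derivative w.r.t. x of each basis function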