Esempio n. 1
0
    input_bufs = in_handler(args)
    out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
    return out_handler(out_bufs)


def _sharded_call_impl(fun, *args, nparts, in_parts, out_parts_thunk,
                       local_in_parts, local_out_parts_thunk, local_nparts,
                       name):
    """Impl rule for ``sharded_call_p``: build a sharded callable and run it.

    The abstract value of every positional argument (via
    ``xla.abstractify``) is appended to the partitioning parameters and passed
    to ``_sharded_callable``. The callable it returns is then invoked on the
    concrete ``args`` and its result is returned unchanged.
    """
    arg_avals = [xla.abstractify(arg) for arg in args]
    executable = _sharded_callable(fun, nparts, in_parts, out_parts_thunk,
                                   local_in_parts, local_out_parts_thunk,
                                   local_nparts, name, *arg_avals)
    return executable(*args)


# Call primitive named "sharded_call"; `sharded_call` is its bind function
# (presumably what `sharded_jit` binds — confirm against its body).
sharded_call_p = core.CallPrimitive("sharded_call")
sharded_call = sharded_call_p.bind
# Impl rule: builds a sharded callable for the argument avals and runs it.
sharded_call_p.def_impl(_sharded_call_impl)
# Register the MLIR lowering rule for this primitive.
mlir.register_lowering(sharded_call_p, _sharded_jit_lowering)


def sharded_jit(
        fun: Callable,
        in_parts,
        out_parts,
        num_partitions: Optional[int] = None,
        local_in_parts=None,
        local_out_parts=None,
        local_num_partitions=None,
        static_argnums: Union[int, Iterable[int]] = (),
):
Esempio n. 2
0
  Returns:
    The original `value` that was passed in.
  """
    if key is not None:
        value = prim.tie_in(key, value)
    flat_args, in_tree = tree_util.tree_flatten(value)
    out_flat = sow_p.bind(*flat_args,
                          name=name,
                          tag=tag,
                          mode=mode,
                          tree=in_tree)
    return tree_util.tree_unflatten(in_tree, out_flat)


# Call-style primitive named 'nest'.
nest_p = jax_core.CallPrimitive('nest')


def _nest_impl(f, *args, **_):
    """Impl rule for ``nest_p``: call ``f`` on ``args`` in a new sublevel.

    Any keyword parameters bound on the primitive are ignored.
    """
    with jax_core.new_sublevel():
        # Return from inside the context so the sublevel is still active
        # while the wrapped function runs.
        outputs = f.call_wrapped(*args)
        return outputs


# Use `_nest_impl` as the impl rule when `nest_p` is evaluated directly.
nest_p.def_impl(_nest_impl)


def _nest_translation_rule(*args, backend, name, call_jaxpr, scope, **_):
    return xla._xla_call_translation_rule(  # pylint: disable=protected-access
        *args,
        name=jax_util.wrap_name(name, f'nest[{scope}]'),
        backend=backend,
Esempio n. 3
0
                                        donated_invars, tuple_args)
    if any(donated_invars):
        # TODO(tomhennigan): At call time we should mark these buffers as deleted.
        unused_donations = [
            str(c.GetShape(a)) for a, d in zip(xla_args, donated_invars) if d
        ]
        msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation."
        if platform not in platforms_with_donation:
            msg = f"Donation is not implemented for {platform}.\n{msg}"
        warnings.warn(
            f"Some donated buffers were not usable: {', '.join(unused_donations)}.\n{msg}"
        )
    return c.build(output)


# Call primitive named 'xla_call'; `xla_call` is its bind function.
xla_call_p: core.CallPrimitive = core.CallPrimitive('xla_call')
xla_call = xla_call_p.bind


def _xla_call_partial_eval_update_params(params, kept_inputs, num_new_inputs):
    donated_invars = params['donated_invars']
    if not kept_inputs and donated_invars:
        # JaxprTrace.post_process_call creates a call with no input tracers
        donated_invars = (False, ) * num_new_inputs
    else:
        assert len(kept_inputs) == len(donated_invars)
        # JaxprTrace.process_call drops known input tracers
        donated_invars = [
            d for d, kept in zip(donated_invars, kept_inputs) if kept
        ]
        # Any new inputs are prepended to the left, so mark those as not donated.
Esempio n. 4
0
                           _partition_knowns)
from ..core import raise_to_shaped, get_aval, Literal, Jaxpr
from ..custom_derivatives import _initial_style_jaxpr, _resolve_kwargs
from ..api_util import flatten_fun_nokwargs
from ..tree_util import tree_flatten, tree_unflatten
from ..util import safe_map, safe_zip, unzip2, split_list, cache
from .. import source_info_util

# Shadow the builtins `map`/`zip` for the rest of this module with
# `safe_map`/`safe_zip` from `..util` (presumably strict about argument
# lengths — verify against `..util`).
map = safe_map
zip = safe_zip

################################################################################
# Reverse call primitive
################################################################################

# Call primitive named 'invertible_call'; `invertible_call` is its bind
# function, and direct evaluation uses the generic `core.call_impl` rule.
invertible_call_p = core.CallPrimitive('invertible_call')
invertible_call = invertible_call_p.bind
invertible_call_p.def_impl(core.call_impl)

def _invertible_call_make_output_tracers(trace, in_tracers, out_tracers, params):
  uks = [not t.pval.is_known() for t in out_tracers]
  out_tracers_known, out_tracers_unknown = _partition_knowns(out_tracers, uks)

  # Add dummy arguments representing the outputs to the jaxpr. Those should
  # remain unused if the expression is evaluated, but they make it well-formed.
  out_known_avals = [raise_to_shaped(t.pval.get_aval()) for t in out_tracers_known]
  out_consts = [trace.instantiate_const(t) for t in out_tracers_known]
  new_jaxpr = _append_invars(params['call_jaxpr'], tuple(out_known_avals))
  new_in_tracers = (*in_tracers, *out_consts)

  # Append dummy outputs that correspond to known outputs left in the call_jaxpr
Esempio n. 5
0
import numpy as onp
from oryx.core import state
from oryx.core import trace_util
from oryx.core.interpreters import harvest
from oryx.core.interpreters import unzip

# Local alias for `state.variable`.
variable = state.variable
# `unzip.unzip` with its `tag` argument pre-bound to `state.VARIABLE`.
unzip_variable = functools.partial(unzip.unzip, tag=state.VARIABLE)


def call_impl(f, *args, **params):
    """Impl rule for ``call_p``: invoke the wrapped function on ``args``.

    Primitive parameters arriving in ``params`` are accepted and discarded.
    """
    del params  # unused: accepted only so any bound parameters are tolerated
    result = f.call_wrapped(*args)
    return result


# Call primitive named 'call'; `call_bind` is its bind function, and
# direct evaluation runs `call_impl`, which just calls the wrapped function.
call_p = jax_core.CallPrimitive('call')
call_bind = call_p.bind
call_p.def_impl(call_impl)


def call(f):
    """Wrap ``f`` so that each invocation is routed through ``call_p``.

    The positional arguments are flattened into leaves. ``f`` is wrapped with
    the keyword arguments via ``lu.wrap_init`` and made to accept and return
    flat leaves. The outputs of ``call_p.bind`` are then unflattened back
    into ``f``'s output structure.

    Args:
      f: the Python callable to wrap.

    Returns:
      A function with the same calling convention as ``f``. Its ``__name__``,
      ``__doc__`` and ``__wrapped__`` are taken from ``f`` via
      ``functools.wraps``.
    """

    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        # Keyword arguments are not flattened into `flat_args`; they are
        # handed to `lu.wrap_init` instead.
        fun = lu.wrap_init(f, kwargs)
        # NOTE(review): `jax.tree_flatten` / `jax.flatten_fun_nokwargs` are
        # top-level aliases from older jax releases — confirm they exist in
        # the pinned jax version.
        flat_args, in_tree = jax.tree_flatten(args)
        flat_fun, out_tree = jax.flatten_fun_nokwargs(fun, in_tree)
        ans = call_p.bind(flat_fun, *flat_args)
        # `out_tree` is invoked as a thunk, only after `call_p.bind` returns.
        return jax.tree_unflatten(out_tree(), ans)

    return wrapped

Esempio n. 6
0
from haiku._src import base
from haiku._src import stateful

import jax
from jax import api
from jax import core
from jax.interpreters import ad
from jax.interpreters import xla
import jax.linear_util as lu

# Short aliases for the XLA client module and its private `_xla` extension.
xc = jax.lib.xla_client
xe = xc._xla  # pylint: disable=protected-access

# Register `named_call` as a call primitive.
named_call_p = core.CallPrimitive('named_call')
# named_call is implemented as a plain core.call and only diverges
# under compilation (see _named_call_translation_rule)
named_call_p.def_impl(core.call_impl)


def _named_call_translation_rule(
    comp_builder: xe.XlaBuilder,
    axis_env: xla.AxisEnv,
    in_nodes: Sequence[xe.XlaOp],
    name_stack: str,
    backend: Optional[Any],
    name: str,
    call_jaxpr: core.Jaxpr,
) -> xe.XlaOp:
    """Compile and add a custom name to the XLA metadata."""