Example #1
  def test_unsupported_op(self):
    p = core.Primitive('unsupported_op')
    p.def_abstract_eval(lambda x: x)
    p.def_impl(lambda x: x)

    def thunk():
      mask(p.bind, ['n'], 'n')([np.arange(3)], {'n': 2})

    message = "Masking rule for unsupported_op not implemented yet."
    self.assertRaisesWithLiteralMatch(NotImplementedError, message, thunk)
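A minimal sketch (not from the original test) of the point this test relies on: registering only def_impl and def_abstract_eval is enough for an eager bind, but the mask() transform additionally needs a per-primitive masking rule, which is what raises the NotImplementedError above.

import numpy as np
from jax import core

p = core.Primitive('unsupported_op')
p.def_impl(lambda x: x)            # eager evaluation rule
p.def_abstract_eval(lambda x: x)   # abstract rule used under tracing

print(p.bind(np.arange(3)))        # eager bind dispatches to the impl: [0 1 2]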
Example #2
def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None,
                       weak_type_rule=None, named_shape_rule=None):
  weak_type_rule = weak_type_rule or _standard_weak_type_rule
  named_shape_rule = named_shape_rule or standard_named_shape_rule
  prim = core.Primitive(name)
  prim.def_impl(partial(xla.apply_primitive, prim))
  prim.def_abstract_eval(
      partial(standard_abstract_eval, prim, shape_rule, dtype_rule,
              weak_type_rule, named_shape_rule))
  xla.register_translation(
      prim, translation_rule or partial(_standard_translate, name))
  return prim
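A hedged usage sketch of how a factory like this is typically invoked for a unary op that preserves shape and dtype. The rule signatures (each receiving the input avals) are an assumption not shown in this excerpt, and lowering through the default _standard_translate path is not exercised here.

def _my_unop_shape_rule(x_aval):
  return x_aval.shape

def _my_unop_dtype_rule(x_aval):
  return x_aval.dtype

my_unop_p = standard_primitive(_my_unop_shape_rule, _my_unop_dtype_rule, 'my_unop')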
Example #3
    def test_shapecheck_unsupported_op(self):
        p = jc.Primitive('unsupported_op')
        p.def_impl(lambda x: x)

        def thunk():
            @shapecheck(['n'], 'n')
            def identity(x):
                return p.bind(x)

        self.assertRaisesRegex(
            NotImplementedError,
            "Shape rule for unsupported_op not implemented yet.", thunk)
Example #4
def setup_spec(spec, grad=True):
    xla_client.register_cpu_custom_call_target(
        spec["xla_name"],
        getattr(xla_ops, spec["name"])())

    prim = core.Primitive("celerite2_" + spec["name"])
    prim.multiple_results = True
    spec["base_primitive"] = prim

    prim.def_impl(partial(xla.apply_primitive, prim))
    prim.def_abstract_eval(partial(_abstract_eval, spec))
    xla.backend_specific_translations["cpu"][prim] = partial(
        _translation_rule, spec)

    if not grad:
        return prim

    xla_client.register_cpu_custom_call_target(
        spec["xla_name"] + b"_rev",
        getattr(xla_ops, spec["name"] + "_rev")())

    jvp = core.Primitive("celerite2_" + spec["name"] + "_jvp")
    jvp.multiple_results = True
    rev = core.Primitive("celerite2_" + spec["name"] + "_rev")
    rev.multiple_results = True
    spec["jvp_primitive"] = jvp
    spec["rev_primitive"] = rev

    ad.primitive_jvps[prim] = partial(_jvp, spec)
    jvp.def_abstract_eval(partial(_jvp_abstract_eval, spec))
    ad.primitive_transposes[jvp] = partial(_jvp_transpose, spec)

    rev.def_impl(partial(xla.apply_primitive, rev))
    rev.def_abstract_eval(partial(_rev_abstract_eval, spec))
    xla.backend_specific_translations["cpu"][rev] = partial(
        _rev_translation_rule, spec)

    return prim
Example #5
def _build_op(name, spec):
    xla_client.register_cpu_custom_call_target(
        name,
        getattr(xla_ops, spec["name"])())

    prim = core.Primitive(f"celerite2_{spec['name']}")
    prim.multiple_results = True
    prim.def_impl(partial(xla.apply_primitive, prim))
    prim.def_abstract_eval(partial(_abstract_eval, spec))
    xla.backend_specific_translations["cpu"][prim] = partial(
        _translation_rule, name, spec)

    if not spec["has_rev"]:
        return prim

    xla_client.register_cpu_custom_call_target(
        name + b"_rev",
        getattr(xla_ops, f"{spec['name']}_rev")())

    jvp_prim = core.Primitive(f"celerite2_{spec['name']}_jvp")
    jvp_prim.multiple_results = True
    rev_prim = core.Primitive(f"celerite2_{spec['name']}_rev")
    rev_prim.multiple_results = True

    # Set up a dummy JVP rule
    ad.primitive_jvps[prim] = partial(_jvp, prim, jvp_prim, spec)
    jvp_prim.def_abstract_eval(partial(_jvp_abstract_eval, spec))
    ad.primitive_transposes[jvp_prim] = partial(_jvp_transpose, rev_prim, spec)

    # Handle reverse pass using custom op
    rev_prim.def_impl(partial(xla.apply_primitive, rev_prim))
    rev_prim.def_abstract_eval(partial(_rev_abstract_eval, spec))
    xla.backend_specific_translations["cpu"][rev_prim] = partial(
        _rev_translation_rule, name + b"_rev", spec)

    return prim
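A hypothetical spec dict for _build_op, inferred only from the keys accessed above; the real celerite2 specs, the _abstract_eval/_translation_rule helpers, and the compiled xla_ops extension module are not part of this excerpt.

example_spec = {
    "name": "factor",   # looked up as getattr(xla_ops, "factor") and getattr(xla_ops, "factor_rev")
    "has_rev": True,    # if True, the JVP/rev primitives and their rules are also wired up
}
# factor_p = _build_op(b"celerite2_factor", example_spec)  # requires the real xla_ops module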
Example #6
from jax import core
from jax.interpreters import xla
from jax.lib import cusparse
from jax.lib import xla_bridge
from jax.lib import xla_client
import jax.numpy as jnp
import numpy as np

xb = xla_bridge
xops = xla_client.ops

#--------------------------------------------------------------------
# csr_todense

csr_todense_p = core.Primitive('csr_todense')


def csr_todense(data, indices, indptr, *, shape):
    """Convert CSR-format sparse matrix to a dense matrix.

  Args:
    data : array of shape ``(nnz,)``.
    indices : array of shape ``(nnz,)``
    indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
    shape : length-2 tuple representing the matrix shape

  Returns:
    mat : array with specified shape and dtype matching ``data``
  """
    return csr_todense_p.bind(data, indices, indptr, shape=shape)
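A hypothetical call based only on the docstring above. Note this excerpt does not show the impl, abstract-eval, or translation registrations that csr_todense_p still needs before bind can actually run.

import numpy as np

data = np.array([1., 2., 3.])     # (nnz,) nonzero values
indices = np.array([0, 2, 1])     # (nnz,) column indices
indptr = np.array([0, 2, 2, 3])   # (shape[0] + 1,) row start offsets
mat = csr_todense(data, indices, indptr, shape=(3, 3))
# expected dense result:
# [[1. 0. 2.]
#  [0. 0. 0.]
#  [0. 3. 0.]]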
Example #7
                           local_nparts=local_nparts,
                           name=flat_fun.__name__)
        return tree_unflatten(out_tree(), out)

    return wrapped


def _sharding_constraint_impl(x, partitions):
    # TODO(skye): can we also prevent this from being called in other
    # non-sharded_jit contexts? (e.g. pmap, control flow)
    raise NotImplementedError(
        "with_sharding_constraint() should only be called inside sharded_jit()"
    )


sharding_constraint_p = core.Primitive("sharding_constraint")
sharding_constraint_p.def_impl(_sharding_constraint_impl)
sharding_constraint_p.def_abstract_eval(lambda x, partitions: x)
ad.deflinear2(
    sharding_constraint_p, lambda ct, _, partitions:
    (with_sharding_constraint(ct, partitions), ))


def _sharding_constraint_lowering(ctx, x_node, partitions):
    return [
        mlir.wrap_with_sharding_op(x_node, xla.sharding_to_proto(partitions))
    ]


mlir.register_lowering(sharding_constraint_p, _sharding_constraint_lowering)
Example #8
from oryx.core.interpreters import log_prob as lp
from oryx.core.ppl import transformations

seed = random.PRNGKey
conditional = transformations.conditional
graph_replace = transformations.graph_replace
joint_log_prob = transformations.joint_log_prob
joint_sample = transformations.joint_sample
log_prob = transformations.log_prob
intervene = transformations.intervene
random_variable = transformations.random_variable

# Define a random normal primitive so we can register it with the `log_prob`
# transformation.
random_normal_p = jax_core.Primitive('random_normal')


def random_normal(key):
    return random_normal_p.bind(key)


def random_normal_impl(rng):
    return random.normal(rng)


def random_normal_abstract(_):
    return abstract_arrays.ShapedArray((), np.float32)


def random_normal_log_prob(_, x):
Example #9
    for operand in contract_fake_ops:
        idx = tuple(i for i, fake_op in enumerate(fake_ops)
                    if operand is fake_op)
        assert len(idx) == 1
        contract_operands.append(operands[idx[0]])
    return contract_operands, contractions


lax_numpy._polymorphic_einsum_contract_path_handlers[
    _DimPolynomial] = _einsum_contract_path

# A JAX primitive with no array arguments but with a dimension parameter
# that is a DimPoly. The value of the primitive is the value of the dimension.
# This primitive is used only in the context of jax2tf, so it does not need
# XLA translation rules.
dim_as_value_p = core.Primitive("dim_as_value")


def _dim_as_value_abstract(dim: DimSize) -> core.AbstractValue:
    return core.ShapedArray((), np.int32)


dim_as_value_p.def_abstract_eval(_dim_as_value_abstract)


def _dim_as_value(dim: DimSize):
    return dim_as_value_p.bind(dim=dim)


class PolyShape(tuple):
    """Tuple of polymorphic dimension specifications.
Example #10
from jax import lax
from jax import linear_util as lu
from jax.config import config
from jax.experimental import maps
from jax.experimental import pjit
from jax.interpreters import mlir
from jax._src import lib as jaxlib
from jax._src import dispatch
from jax._src import test_util as jtu
from jax._src import util
from jax._src.lax import control_flow as lcf
import numpy as np

config.parse_flags_with_absl()

effect_p = core.Primitive('effect')
effect_p.multiple_results = True


@effect_p.def_effectful_abstract_eval
def _(*, effect):
    return [], {effect}


mlir.lowerable_effects.add('foo')
mlir.lowerable_effects.add('foo2')
mlir.lowerable_effects.add('bar')
mlir.lowerable_effects.add('while')
mlir.lowerable_effects.add('while1')
mlir.lowerable_effects.add('while2')
core.ordered_effects.add('foo')
Example #11
def _custom_ivjp(fun, ivjp, args):
  in_avals = [raise_to_shaped(get_aval(x)) for x in args]
  fun_jaxpr = custom_derivatives._initial_style_jaxpr(fun, in_avals)
  try:
    ivjp_jaxpr = custom_derivatives._initial_style_jaxpr(
        ivjp, in_avals + fun_jaxpr.out_avals * 2)
  except RecursionError:
    raise ValueError("Calls to {} from its custom ivjp aren't supported yet".format(fun.__name__))
  return custom_ivjp_p.bind(*args, fun_jaxpr=fun_jaxpr,
                                   ivjp_jaxpr=ivjp_jaxpr)

def _custom_ivjp_impl(*args, fun_jaxpr, **_):
  return core.jaxpr_as_fun(fun_jaxpr)(*args)

custom_ivjp_p = core.Primitive('custom_ivjp')
custom_ivjp_p.multiple_results = True
custom_ivjp_p.def_impl(_custom_ivjp_impl)
custom_ivjp_p.def_abstract_eval(lambda *_, fun_jaxpr, **__: fun_jaxpr.out_avals)

def _custom_ivjp_jvp(primals, tangents, *, fun_jaxpr, ivjp_jaxpr):
  primals_out = custom_ivjp_p.bind(*primals, fun_jaxpr=fun_jaxpr,
                                             ivjp_jaxpr=ivjp_jaxpr)
  fun = core.jaxpr_as_fun(fun_jaxpr)
  # FIXME: This might compute the primals multiple times, but we only need to do
  #        this trick while linearizing. It should be possible to do it through
  #        a custom partial eval rule.
  _, tangents_out = ad.jvp(lu.wrap_init(fun)).call_wrapped(primals, tangents)
  return primals_out, tangents_out
ad.primitive_jvps[custom_ivjp_p] = _custom_ivjp_jvp
Example #12
                           _partition_knowns)
from ..core import raise_to_shaped, get_aval, Literal, Jaxpr
from ..custom_derivatives import _initial_style_jaxpr, _resolve_kwargs
from ..api_util import flatten_fun_nokwargs
from ..tree_util import tree_flatten, tree_unflatten
from ..util import safe_map, safe_zip, unzip2, split_list, cache
from .. import source_info_util

map = safe_map
zip = safe_zip

################################################################################
# Reverse call primitive
################################################################################

invertible_call_p = core.Primitive('invertible_call')
invertible_call_p.call_primitive = True
invertible_call = partial(core.call_bind, invertible_call_p)
invertible_call_p.def_custom_bind(invertible_call)
invertible_call_p.def_impl(core.call_impl)
invertible_call_p.multiple_results = True

def _invertible_call_make_output_tracers(trace, in_tracers, out_tracers, params):
  uks = [not t.pval.is_known() for t in out_tracers]
  out_tracers_known, out_tracers_unknown = _partition_knowns(out_tracers, uks)

  # Add dummy arguments representing the outputs to the jaxpr. Those should
  # remain unused if the expression is evaluated, but they make it well-formed.
  out_known_avals = [raise_to_shaped(t.pval.get_aval()) for t in out_tracers_known]
  out_consts = [trace.instantiate_const(t) for t in out_tracers_known]
  new_jaxpr = _append_invars(params['call_jaxpr'], tuple(out_known_avals))
Example #13
        c.GetShape(k2).dimensions(),
        c.GetShape(x1).dimensions(),
        c.GetShape(x2).dimensions())
    rank = len(shape)

    def _broadcast(x):
        ndims = c.GetShape(x).rank()
        return xla_client.ops.BroadcastInDim(x, shape,
                                             tuple(range(rank - ndims, rank)))

    return cuda_prng.threefry2x32(xla_bridge.computation_builder_shim(c),
                                  (_broadcast(k1), _broadcast(k2)),
                                  (_broadcast(x1), _broadcast(x2)))


threefry2x32_p = core.Primitive("threefry2x32")
threefry2x32_p.multiple_results = True
threefry2x32_p.def_impl(partial(xla.apply_primitive, threefry2x32_p))
threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
batching.defbroadcasting(threefry2x32_p)
xla.translations[threefry2x32_p] = xla.lower_fun(
    partial(_threefry2x32_lowering, use_rolled_loops=False))
xla.backend_specific_translations['cpu'][threefry2x32_p] = xla.lower_fun(
    partial(_threefry2x32_lowering, use_rolled_loops=True))
if cuda_prng:
    xla.backend_specific_translations['gpu'][threefry2x32_p] = \
        _threefry2x32_gpu_translation_rule


@jit
def threefry_2x32(keypair, count):
Example #14
@custom_transforms
def cumsum(x):
    return np.cumsum(x, axis=-1)


defjvp(cumsum, lambda g, ans, x: np.cumsum(g, axis=-1))


# XXX work around the issue: batching rule for 'reduce_window' not implemented
# when using @custom_transforms decorator
def _cumprod_impl(x):
    return np.cumprod(x, axis=-1)


cumprod_p = core.Primitive('cumprod')
cumprod_p.def_impl(_cumprod_impl)
cumprod_p.def_abstract_eval(
    partial(partial_eval.abstract_eval_fun, _cumprod_impl))
xla.translations[cumprod_p] = partial(xla.lower_fun, _cumprod_impl)
# XXX this implementation does not address the case x=0, hence the result in that case will be nan
# Ref: https://stackoverflow.com/questions/40916955/how-to-compute-gradient-of-cumprod-safely
ad.defjvp2(cumprod_p, lambda g, ans, x: np.cumsum(g / x, axis=-1) * ans)
batching.defvectorized(cumprod_p)


def cumprod(x):
    return cumprod_p.bind(x)
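
# (Illustrative note, not in the original.) An eager call such as
# cumprod(np.array([1., 2., 3.])) dispatches to _cumprod_impl and returns
# [1., 2., 6.]; the ad.defjvp2 registration above gives the primitive a
# forward-mode (JVP) rule, with the x = 0 caveat noted in the comment.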


def promote_shapes(*args, shape=()):
Example #15
File: ad.py Project: jbampton/jax
                         new_invars, new_outvars, jaxpr.jaxpr.eqns)
  return core.ClosedJaxpr(new_jaxpr, jaxpr.consts)

def _perm(primal_counts, tangent_counts, lst):
  n = sum(primal_counts)
  primals, tangents = lst[:n], lst[n:]
  primal_groups = split_list(primals, primal_counts[:-1])
  tangent_groups = split_list(tangents, tangent_counts[:-1])
  return _interleave(primal_groups, tangent_groups)

def _interleave(xs, ys):
  assert len(xs) == len(ys)
  return [e for pair in zip(xs, ys) for l in pair for e in l]


custom_lin_p: core.Primitive = core.Primitive('custom_lin')
custom_lin_p.def_abstract_eval(lambda *_, out_avals, **__: out_avals)
custom_lin_p.multiple_results = True

def _raise_custom_vjp_error_on_jvp(*_, **__):
  raise TypeError("can't apply forward-mode autodiff (jvp) to a custom_vjp "
                  "function.")
custom_lin_p.def_impl(_raise_custom_vjp_error_on_jvp)

def _custom_lin_transpose(cts_out, *invals, num_res, bwd, out_avals):
  res, _ = split_list(invals, [num_res])
  cts_out = map(instantiate_zeros_aval, out_avals, cts_out)
  cts_in = bwd.call_wrapped(*res, *cts_out)
  return [None] * num_res + list(cts_in)
primitive_transposes[custom_lin_p] = _custom_lin_transpose
Example #16
    cts = [
        zeros_like_aval(a) if type(ct) is Zero else ct
        for ct, a in zip(cts, cts_avals)
    ]

    cts_out = linear_call_p.bind(*t_consts,
                                 *f_consts,
                                 *operands_res,
                                 *cts,
                                 callee=transpose,
                                 transpose=callee,
                                 num_callee_consts=len(t_consts),
                                 num_transpose_consts=len(f_consts),
                                 num_res=len(operands_res))

    return [None
            ] * (num_callee_consts + num_transpose_consts + num_res) + cts_out


def _linear_call_abstract_eval(*args, **kwargs):
    return map(core.raise_to_shaped, kwargs['callee'].out_avals)


linear_call_p = core.Primitive('linear_call')
linear_call_p.multiple_results = True
linear_call_p.def_impl(_linear_call_impl)
linear_call_p.def_abstract_eval(_linear_call_abstract_eval)
ad.primitive_transposes[linear_call_p] = _linear_call_transpose_rule
xla.initial_style_translations[linear_call_p] = xla.lower_fun_initial_style(
    _linear_call_impl)
Example #17
  return tuple(x)


def _threefry2x32_gpu_translation_rule(c, k1, k2, x1, x2):
  shape = lax.broadcast_shapes(
      c.GetShape(k1).dimensions(), c.GetShape(k2).dimensions(),
      c.GetShape(x1).dimensions(), c.GetShape(x2).dimensions())
  rank = len(shape)
  def _broadcast(x):
    ndims = c.GetShape(x).rank()
    return c.BroadcastInDim(x, shape, tuple(range(rank - ndims, rank)))
  return cuda_prng.threefry2x32(
      c, (_broadcast(k1), _broadcast(k2)), (_broadcast(x1), _broadcast(x2)))

threefry2x32_p = core.Primitive("threefry2x32")
threefry2x32_p.multiple_results = True
threefry2x32_p.def_impl(partial(xla.apply_primitive, threefry2x32_p))
threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
batching.defbroadcasting(threefry2x32_p)
xla.translations[threefry2x32_p] = xla.lower_fun(
    partial(_threefry2x32_lowering, use_rolled_loops=False), instantiate=True)
xla.backend_specific_translations['cpu'][threefry2x32_p] = xla.lower_fun(
    partial(_threefry2x32_lowering, use_rolled_loops=True), instantiate=True)
if cuda_prng:
  xla.backend_specific_translations['gpu'][threefry2x32_p] = \
      _threefry2x32_gpu_translation_rule

@jit
def threefry_2x32(keypair, count):
  """Apply the Threefry 2x32 hash.
Example #18
    # buffers from different XLA backends are passed through the host.
    backend = xb.get_device_backend(device)
    moved_buf = backend.buffer_from_pyval(x.device_buffer.to_py(), device)
  return device_array.make_device_array(x.aval, device, moved_buf)


def _device_put_impl(x, device: Optional[Device] = None):
  if device_array.type_is_device_array(x):
    return _copy_device_array_to_device(x, device)

  try:
    a = xla.abstractify(x)
  except TypeError as err:
    raise TypeError(
        f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
  return aval_to_result_handler(device, a)(*device_put(x, device))

device_put_p = core.Primitive('device_put')
device_put_p.def_impl(_device_put_impl)
device_put_p.def_abstract_eval(lambda x, device=None: x)
xla.translations[device_put_p] = lambda c, x, device=None: x
ad.deflinear2(device_put_p, lambda cotangent, _, **kwargs: [cotangent])
masking.defvectorized(device_put_p)
batching.defvectorized(device_put_p)

def _device_put_lowering(ctx, x, *, device):
  return [x]


mlir.register_lowering(device_put_p, _device_put_lowering)
Example #19
from oryx.core import trace_util

__all__ = [
    'HarvestTrace',
    'HarvestTracer',
    'call_and_reap',
    'harvest',
    'nest',
    'plant',
    'reap',
    'sow',
]

Value = Any

sow_p = jax_core.Primitive('sow')
sow_p.multiple_results = True


@sow_p.def_impl
def _sow_impl(*args, **_):
    return args


@sow_p.def_abstract_eval
def _sow_abstract_eval(*avals, **_):
    return avals


@functools.partial(ad.deflinear, sow_p)
def _sow_transpose(cts_in, *_, **__):
Example #20
custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call')


def _custom_jvp_call_jaxpr_impl(*args, fun_jaxpr: core.ClosedJaxpr, **params):
    del params  # other params ignored because we're just executing the primal fun
    return core.jaxpr_as_fun(fun_jaxpr)(*args)


def _custom_jvp_call_jaxpr_abstract_eval(*args, fun_jaxpr: core.ClosedJaxpr,
                                         **params):
    del args, params
    return fun_jaxpr.out_avals


custom_jvp_call_jaxpr_p = core.Primitive('custom_jvp_call_jaxpr')
custom_jvp_call_jaxpr_p.multiple_results = True
custom_jvp_call_jaxpr_p.def_impl(_custom_jvp_call_jaxpr_impl)
custom_jvp_call_jaxpr_p.def_abstract_eval(_custom_jvp_call_jaxpr_abstract_eval)
CustomJVPCallPrimitive.initial_style = custom_jvp_call_jaxpr_p


def _custom_jvp_call_jaxpr_jvp(primals, tangents, *,
                               fun_jaxpr: core.ClosedJaxpr,
                               jvp_jaxpr_thunk: Callable[[],
                                                         Tuple[core.Jaxpr,
                                                               Sequence[Any]]],
                               num_consts: int):
    _, args = split_list(primals, [num_consts])
    consts_dot, args_dot = split_list(tangents, [num_consts])
    if any(type(t) is not Zero for t in consts_dot):
Example #21
def standard_pmap_primitive(name):
    prim = core.Primitive(name)
    prim.def_impl(partial(pxla.apply_parallel_primitive, prim))
    prim.def_abstract_eval(lambda x, *args, **params: x)
    return prim
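A minimal usage sketch: the factory simply stamps out a primitive whose abstract evaluation forwards the first argument's aval unchanged; a real collective would still need its parallel translation/evaluation rules registered elsewhere.

my_collective_p = standard_pmap_primitive('my_collective')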
Example #22
# https://github.com/google/jax/issues/1142
# courtesy of mattjj
def mybar_impl(w):
    A, _ = pymbar.BAR(w[0], w[1])
    return A


def mybar_vjp(g, w):
    return g * tmbar.dG_dw(w)


def mybar(x):
    return mybar_p.bind(x)


mybar_p = core.Primitive('mybar')
mybar_p.def_impl(mybar_impl)
ad.defvjp(mybar_p, mybar_vjp)


def BAR_leg(insertion_du_dls, deletion_du_dls, lambda_schedule):
    insertion_W = math_utils.trapz(insertion_du_dls, lambda_schedule)
    deletion_W = math_utils.trapz(deletion_du_dls, lambda_schedule)

    return mybar(jnp.stack([insertion_W, deletion_W]))


def BAR_loss(
        complex_insertion_du_dls,  # [C, N]
        complex_deletion_du_dls,  # [C, N]
        solvent_insertion_du_dls,  # [C, N]
Example #23
  else:
    val_out, arg_out = approx_min_k(operand, k, reduction_dimension,
                                    recall_target,
                                    reduction_input_size_override,
                                    aggregate_to_topk)
  if type(tangent) is ad_util.Zero:
    tangent_out = ad_util.Zero.from_value(val_out)
  else:
    arg_shape = arg_out.shape
    rank = len(arg_shape)
    if reduction_dimension < 0:
      reduction_dimension += rank
    iotas = [
        lax.broadcasted_iota(arg_out.dtype, arg_shape, i) for i in range(rank)
    ]
    idx = tuple(
        arg_out if i == reduction_dimension else iotas[i] for i in range(rank))
    tangent_out = tangent[idx]
  return (val_out, arg_out), (tangent_out, ad_util.Zero.from_value(arg_out))


approx_top_k_p = core.Primitive('approx_top_k')
approx_top_k_p.multiple_results = True
approx_top_k_p.def_impl(partial(xla.apply_primitive, approx_top_k_p))
approx_top_k_p.def_abstract_eval(_approx_top_k_abstract_eval)
xla.register_translation(approx_top_k_p, _approx_top_k_fallback_translation)
xla.register_translation(approx_top_k_p, _approx_top_k_tpu_translation,
                         platform='tpu')
batching.primitive_batchers[approx_top_k_p] = _approx_top_k_batch_rule
ad.primitive_jvps[approx_top_k_p] = _approx_top_k_jvp
Example #24
                   axis_name=axis_name, axis_env=axis_env,
                   axis_index_groups=axis_index_groups, platform=platform)
    dtype = c.get_shape(val).numpy_dtype()
    if dtypes.issubdtype(dtype, np.complexfloating):
      return xops.Complex(psum(xops.Real(val)), psum(xops.Imag(val)))
    else:
      return psum(val)
  return xops.Tuple(c, list(map(_translate, args)))

def _psum_transpose_rule(cts, axis_name, axis_index_groups):
  nonzero_out_cts, treedef = tree_util.tree_flatten(cts)
  nonzero_in_cts = psum_p.bind(*nonzero_out_cts, axis_name=axis_name,
                               axis_index_groups=axis_index_groups)
  return tree_util.tree_unflatten(treedef, nonzero_in_cts)

psum_p = core.Primitive('psum')
psum_p.multiple_results = True
psum_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
pxla.soft_pmap_rules[psum_p] = \
    partial(_allreduce_soft_pmap_rule, psum_p, lax._reduce_sum)
xla.parallel_translations[psum_p] = _psum_translation_rule
ad.deflinear(psum_p, _psum_transpose_rule)
pxla.multi_host_supported_collectives.add(psum_p)
batching.split_axis_rules[psum_p] = partial(_split_axis_comm_assoc, psum_p)
batching.primitive_batchers[psum_p] = partial(_collective_batcher, psum_p)
batching.collective_rules[psum_p] = \
  partial(_batched_reduction_collective,
          psum_p,
          lambda v, d: v.sum(d),
          lambda v, axis_size: axis_size * v)
Example #25
      fun = lu.wrap_init(f, kwargs)
      flat_args, in_tree = tree_util.tree_flatten(args)
      flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
      out_tree_dest = None
      out = prim.bind(flat_fun, *flat_args, num_args=len(flat_args),
                      name=f.__name__,
                      in_tree=in_tree,
                      out_tree=lambda: out_tree_dest,
                      **params)
      out_tree_dest = out_tree()
      return tree_util.tree_unflatten(out_tree_dest, out)
    return wrapped
  return bind


tie_all_p = jax_core.Primitive('tie_all')
tie_all_p.multiple_results = True
tie_all_p.def_impl(lambda *args: args)
tie_all_p.def_abstract_eval(lambda *args: safe_map(  # pylint: disable=g-long-lambda
    abstract_arrays.raise_to_shaped, args))
xla.translations[tie_all_p] = lambda c, *args: xc.ops.Tuple(c, args)


def _tie_all_batch_rule(batched_args, batch_dims):
  return batched_args, batch_dims


def _tie_all_transpose(cts_in, *args, **params):
  del args, params
  return cts_in
ad.deflinear(tie_all_p, _tie_all_transpose)
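A minimal sketch of eager use, assuming (as elsewhere in these examples) that outside any trace bind dispatches straight to the def_impl above, so the values come back unchanged.

import jax.numpy as jnp

x, y = jnp.ones(3), jnp.zeros(3)
x_out, y_out = tie_all_p.bind(x, y)   # the impl returns its arguments as-is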
Example #26
    for data dependency, for implementing the "result" feature, and for
    the current token.
  * tapped_args_treedef_: the treedef of the tapped positional arguments.
  * tap_func_: the actual (Python) function to invoke with the tapped positional
    arguments (unflattened according to tapped_args_treedef_) and
    the parameters that were passed to the id_tap function.
  * transforms: a tuple of the transformations that have been applied. Each
    element of the tuple is itself a tuple with the first element the name
    of the transform. The remaining elements depend on the transform. For
    example, for `batch`, the parameters are the dimensions that have been
    batched, and for `mask` the logical shapes. These are unpacked by
    _ConsumerCallable before passing to the user function.
  * the remaining parameters are from the user's invocation of the id_tap
    API function and are passed to the tap function.
"""
id_tap_p = core.Primitive("id_tap")
id_tap_p.multiple_results = True
xla.outfeed_primitives.add(id_tap_p)


def _add_transform(params: Dict, name: str, *transform_params) -> Dict:
  """Adds the `transform` to the params["transforms"].

  Uses a tuple representation internally, will be unpacked before the
  callback by _ConsumerCallable.
  """
  new_transform = (name, *transform_params)
  return dict(
      params, transforms=(params.get("transforms", ()) + (new_transform,)))
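A small worked example of the helper above, with hypothetical transform parameters: each call appends one (name, *params) tuple to params["transforms"].

params = {}
params = _add_transform(params, "batch", (0,))
params = _add_transform(params, "mask", ((3,),))
# params["transforms"] == (("batch", (0,)), ("mask", ((3,),)))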

Example #27
    tangents = map(ad.instantiate_zeros, tangents)
    jvp_call, _ = ad.jvp_jaxpr(call, [True] * len(primals), True)
    jvp_in_tree = treedef_tuple((in_tree, in_tree))
    jvp_out_tree = treedef_tuple((out_tree, out_tree))
    outs = custom_vmap_p.bind(*primals,
                              *tangents,
                              call=jvp_call,
                              rule=jvp_of_rule_rule,
                              in_tree=jvp_in_tree,
                              out_tree=jvp_out_tree)
    assert len(outs) % 2 == 0, len(outs)
    out_primals, out_tangents = util.split_list(outs, [len(outs) // 2])
    return out_primals, out_tangents


custom_vmap_p = core.Primitive('custom_vmap_call')
custom_vmap_p.multiple_results = True
custom_vmap_p.def_impl(custom_vmap_impl)
custom_vmap_p.def_abstract_eval(custom_vmap_abstract_eval)
batching.primitive_batchers[custom_vmap_p] = custom_vmap_batching
ad.primitive_jvps[custom_vmap_p] = custom_vmap_jvp
xla.register_initial_style_primitive(custom_vmap_p)
mlir.register_lowering(custom_vmap_p,
                       mlir.lower_fun(custom_vmap_impl, multiple_results=True))

# -- custom vmap applications


def tree_split(mask, tree):
    lhs = tree_map(lambda l, x: x if l else None, mask, tree)
    rhs = tree_map(lambda l, x: None if l else x, mask, tree)
Example #28
def sparse_array_constant_handler(c, val, canonicalize_dtypes):
    return (xb.constant(val.data, canonicalize_dtypes),
            xb.constant(val.indices, canonicalize_dtypes))


core.pytype_aval_mappings[SparseArray] = lambda x: x.aval
core.raise_to_shaped_mappings[AbstractSparseArray] = lambda aval, _: aval
xla.pytype_aval_mappings[SparseArray] = lambda x: x.aval
xla.canonicalize_dtype_handlers[SparseArray] = lambda x: x
xla.device_put_handlers[SparseArray] = sparse_array_device_put_handler
xla.xla_result_handlers[AbstractSparseArray] = sparse_array_result_handler
xla.xla_shape_handlers[AbstractSparseArray] = sparse_array_shape_handler
xb.register_constant_handler(SparseArray, sparse_array_constant_handler)

sp_indices_p = core.Primitive('sp_indices')


@sp_indices_p.def_impl
def _sp_indices_impl(mat):
    return mat.indices


@sp_indices_p.def_abstract_eval
def _sp_indices_abstract_eval(mat):
    return mat.indices_aval


def _sp_indices_translation_rule(c, data, indices):
    return indices
Example #29
  return False

def checkpoint_dots(prim, *_, **__) -> bool:
  # Matrix multiplies are expensive, so let's save them (and nothing else).
  return prim in {jax._src.lax.lax.dot_general_p,
                  jax._src.lax.convolution.conv_general_dilated_p}

def dot_with_no_batch_dims(prim, *_, **params) -> bool:
  # This is a useful heuristic for transformers.
  if prim is jax._src.lax.lax.dot_general_p:
    (_, _), (lhs_b, rhs_b) = params['dimension_numbers']
    if not lhs_b and not rhs_b:
      return True
  return False
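
# (Illustrative note, not in the original.) Predicates like checkpoint_dots and
# dot_with_no_batch_dims are meant to be passed as a rematerialization policy,
# e.g. jax.checkpoint(f, policy=dot_with_no_batch_dims), in JAX versions where
# checkpoint/remat accepts a `policy` argument.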

name_p = core.Primitive('name')

def save_any_names_but_these(*names_not_to_save):
  # Save named values, excluding the names given.
  names_not_to_save = frozenset(names_not_to_save)
  def policy(prim, *_, **params):
    if prim is name_p:
      return params['name'] not in names_not_to_save
    return False  # only allow saving named values
  return policy

def save_only_these_names(*names_which_can_be_saved):
  # Save named values, only among the names given.
  names_which_can_be_saved = set(names_which_can_be_saved)
  def policy(prim, *_, **params):
    if prim is name_p: