Code example #1
File: primitive.py Project: tensorflow/probability
    def __init__(self, name):
        super(FlatPrimitive, self).__init__(name)
        self.multiple_results = True

        def _abstract(*flat_avals, **params):
            return pe.abstract_eval_fun(self.impl, *flat_avals, **params)

        self.def_abstract_eval(_abstract)

        def _jvp(primals, tangents, **params):
            return ad.jvp(lu.wrap_init(self.impl, params)).call_wrapped(
                primals, tangents)

        ad.primitive_jvps[self] = _jvp

        def _batch(args, dims, **params):
            batched, out_dims = batch_fun(lu.wrap_init(self.impl, params),
                                          dims)
            return batched.call_wrapped(*args), out_dims()

        batching.primitive_batchers[self] = _batch

        def _xla(c, *xla_args, **params):
            translation = xla.lower_fun(self.impl, multiple_results=True)
            return translation(c, *xla_args, **params)

        xla.translations[self] = _xla

        def _mlir(c, *mlir_args, **params):
            lowering = mlir.lower_fun(self.impl, multiple_results=True)
            return lowering(c, *mlir_args, **params)

        mlir.register_lowering(self, _mlir)
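
A note on the shared pattern: each example defines a core.Primitive, gives it an impl and an abstract-eval rule, then registers an MLIR lowering, often by tracing the impl with mlir.lower_fun. A minimal self-contained sketch of that pattern, using only public JAX APIs of this era (the 'double' primitive and its rules are illustrative, not taken from the file above):

import jax
from jax import core
from jax.interpreters import mlir

double_p = core.Primitive('double')

def _double_impl(x):
  return x * 2

double_p.def_impl(_double_impl)
double_p.def_abstract_eval(lambda x: x)  # output aval matches the input aval

# Lower the primitive by tracing its impl, as many examples here do:
mlir.register_lowering(double_p,
                       mlir.lower_fun(_double_impl, multiple_results=False))

print(jax.jit(double_p.bind)(3.))  # 6.0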
Code example #2
File: primitive.py Project: tensorflow/probability
def hop_lowering(prim):
    def rule(ctx, *args, backend, name, call_jaxpr, **_params):
        return mlir._call_lowering(  # pylint: disable=protected-access
            name, name, call_jaxpr, backend, ctx.module_context, ctx.avals_in,
            ctx.avals_out, *args)

    mlir.register_lowering(prim, rule)
    return rule
Code example #3
  def test_should_not_pass_tokens_into_unordered_effect(self):

    def effect_lowering(ctx, *, effect):
      self.assertEmpty(ctx.tokens_in)
      return []
    mlir.register_lowering(effect_p, effect_lowering)

    @jax.jit
    def f(x):
      effect_p.bind(effect='bar')
      return x + 1.
    f.lower(2.)
Code example #4
    def test_lowering_ordered_effect_should_create_tokens(self):
        def effect_lowering(ctx, *, effect):
            ctx.set_tokens_out(ctx.tokens_in)
            return []

        mlir.register_lowering(effect_p, effect_lowering)

        @jax.jit
        def f(x):
            effect_p.bind(effect='foo')
            return x + 1.

        mhlo = f.lower(2.).compiler_ir()
        main = mhlo.body.operations[0]
        first_op = main.body.blocks[0].operations[0]
        self.assertEqual(first_op.operation.name, "mhlo.create_token")

        @jax.jit
        def f(x):
            effect_p.bind(effect='foo')
            effect_p.bind(effect='foo2')
            return x + 1.

        mhlo = f.lower(2.).compiler_ir()
        main = mhlo.body.operations[0]
        first_op = main.body.blocks[0].operations[0]
        self.assertEqual(first_op.operation.name, "mhlo.create_token")
        second_op = main.body.blocks[0].operations[1]
        self.assertEqual(second_op.operation.name, "mhlo.create_token")

Code example #5
  def test_should_pass_tokens_into_ordered_effect(self):

    def _effect_lowering(ctx, *, effect):
      self.assertListEqual(list(ctx.tokens_in.effects()), ['foo'])
      ctx.set_tokens_out(ctx.tokens_in)
      return []
    mlir.register_lowering(effect_p, _effect_lowering)

    @jax.jit
    def f(x):
      effect_p.bind(effect='foo')
      return x + 1.
    f.lower(2.)
Code example #6
  def test_lowering_that_doesnt_set_tokens_should_cause_error(self):

    def bad_effect_lowering(ctx, *, effect):
      # Doesn't call `ctx.set_tokens_out`!
      return []
    mlir.register_lowering(effect_p, bad_effect_lowering)

    @jax.jit
    def f(x):
      effect_p.bind(effect='foo')
      return x + 1.
    with self.assertRaisesRegex(ValueError, 'Lowering rule for `effect` needs to '
        'set `tokens_out`'):
      f.lower(2.)
Code example #7
  def test_lowering_that_sets_wrong_tokens_should_cause_error(self):

    def bad_effect_lowering(ctx, *, effect):
      ctx.set_tokens_out(mlir.TokenSet(bar=ctx.tokens_in.get('foo')))
      return []
    mlir.register_lowering(effect_p, bad_effect_lowering)

    @jax.jit
    def f(x):
      effect_p.bind(effect='foo')
      return x + 1.
    with self.assertRaisesRegex(ValueError, 'Lowering rule for `effect` returns '
        'incorrect set of output token.'):
      f.lower(2.)
Code example #8
    def setUp(self):
        super().setUp()
        self.old_x64 = config.jax_enable_x64
        config.update('jax_enable_x64', False)
        self._old_lowering = mlir._lowerings[effect_p]

        def _effect_lowering(ctx, *, effect):
            if effect in core.ordered_effects:
                expected_effects = [effect]
            else:
                expected_effects = []
            self.assertListEqual(list(ctx.tokens_in.effects()),
                                 expected_effects)
            ctx.set_tokens_out(ctx.tokens_in)
            return []

        mlir.register_lowering(effect_p, _effect_lowering)
        dispatch.runtime_tokens.clear()
Code example #9
  def test_nontrivial_lowering_with_unordered_effect_should_consume_token(self):

    mlir.register_lowering(effect_p, function_effect_lowering)

    @jax.jit
    def f(x):
      effect_p.bind(effect='bar')
      return x + 1.

    mhlo = f.lower(2.).compiler_ir()
    main = mhlo.body.operations[0]
    first_op = main.body.blocks[0].operations[0]
    self.assertEqual(first_op.operation.name, "func.call")
    self.assertEqual(str(first_op.attributes["callee"]), "@effect")
    self.assertLen(list(first_op.operands), 0)
    func = mhlo.body.operations[1]
    self.assertEqual(func.name.value, "effect")
    self.assertLen(list(func.type.inputs), 0)
    self.assertLen(list(func.type.results), 0)
Code example #10

def _custom_jvp_call_jaxpr_abstract_eval(*args, fun_jaxpr: core.ClosedJaxpr,
                                         **params):
    del args, params
    return fun_jaxpr.out_avals


custom_jvp_call_jaxpr_p = core.AxisPrimitive('custom_jvp_call_jaxpr')
custom_jvp_call_jaxpr_p.multiple_results = True
custom_jvp_call_jaxpr_p.def_impl(_custom_jvp_call_jaxpr_impl)
custom_jvp_call_jaxpr_p.def_abstract_eval(_custom_jvp_call_jaxpr_abstract_eval)
CustomJVPCallPrimitive.initial_style = custom_jvp_call_jaxpr_p

mlir.register_lowering(
    custom_jvp_call_jaxpr_p,
    mlir.lower_fun(_custom_jvp_call_jaxpr_impl, multiple_results=True))


def _custom_jvp_call_jaxpr_jvp(primals, tangents, *,
                               fun_jaxpr: core.ClosedJaxpr,
                               jvp_jaxpr_thunk: Callable[[],
                                                         Tuple[core.Jaxpr,
                                                               Sequence[Any]]],
                               num_consts: int):
    _, args = split_list(primals, [num_consts])
    consts_dot, args_dot = split_list(tangents, [num_consts])
    if any(type(t) is not Zero for t in consts_dot):
        raise ad.CustomJVPException()
    jvp_jaxpr, jvp_consts = jvp_jaxpr_thunk()  # consts can be tracers!
    args_dot = map(ad.instantiate_zeros, args_dot)
Code example #11
                              rule=jvp_of_rule_rule,
                              in_tree=jvp_in_tree,
                              out_tree=jvp_out_tree)
    assert len(outs) % 2 == 0, len(outs)
    out_primals, out_tangents = util.split_list(outs, [len(outs) // 2])
    return out_primals, out_tangents


custom_vmap_p = core.Primitive('custom_vmap_call')
custom_vmap_p.multiple_results = True
custom_vmap_p.def_impl(custom_vmap_impl)
custom_vmap_p.def_abstract_eval(custom_vmap_abstract_eval)
batching.primitive_batchers[custom_vmap_p] = custom_vmap_batching
ad.primitive_jvps[custom_vmap_p] = custom_vmap_jvp
xla.register_initial_style_primitive(custom_vmap_p)
mlir.register_lowering(custom_vmap_p,
                       mlir.lower_fun(custom_vmap_impl, multiple_results=True))

# -- custom vmap applications


def tree_split(mask, tree):
    lhs = tree_map(lambda l, x: x if l else None, mask, tree)
    rhs = tree_map(lambda l, x: None if l else x, mask, tree)
    return lhs, rhs


def tree_merge(mask, lhs_tree, rhs_tree):
    return tree_map(lambda l, x_l, x_r: x_l
                    if l else x_r, mask, lhs_tree, rhs_tree)

Code example #12
File: fft.py Project: frederikwilde/jax
    # Use JAX's convention for complex gradients
    # https://github.com/google/jax/issues/6223#issuecomment-807740707
    return lax.conj(out)


def _fft_transpose_rule(t, operand, fft_type, fft_lengths):
    if fft_type == xla_client.FftType.RFFT:
        result = _rfft_transpose(t, fft_lengths)
    elif fft_type == xla_client.FftType.IRFFT:
        result = _irfft_transpose(t, fft_lengths)
    else:
        result = fft(t, fft_type, fft_lengths)
    return result,


def _fft_batching_rule(batched_args, batch_dims, fft_type, fft_lengths):
    x, = batched_args
    bd, = batch_dims
    x = batching.moveaxis(x, bd, 0)
    return fft(x, fft_type, fft_lengths), 0


fft_p = Primitive('fft')
fft_p.def_impl(_fft_impl)
fft_p.def_abstract_eval(fft_abstract_eval)
mlir.register_lowering(fft_p, _fft_lowering)
ad.deflinear2(fft_p, _fft_transpose_rule)
batching.primitive_batchers[fft_p] = _fft_batching_rule
if pocketfft:
    mlir.register_lowering(fft_p, _fft_lowering_cpu, platform='cpu')
Code example #13
File: dispatch.py Project: jbampton/jax
    # buffers from different XLA backends are passed through the host.
    backend = xb.get_device_backend(device)
    moved_buf = backend.buffer_from_pyval(x.device_buffer.to_py(), device)
  return device_array.make_device_array(x.aval, device, moved_buf)


def _device_put_impl(x, device: Optional[Device] = None):
  if device_array.type_is_device_array(x):
    return _copy_device_array_to_device(x, device)

  try:
    a = xla.abstractify(x)
  except TypeError as err:
    raise TypeError(
        f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
  return aval_to_result_handler(device, a)(*device_put(x, device))

device_put_p = core.Primitive('device_put')
device_put_p.def_impl(_device_put_impl)
device_put_p.def_abstract_eval(lambda x, device=None: x)
xla.translations[device_put_p] = lambda c, x, device=None: x
ad.deflinear2(device_put_p, lambda cotangent, _, **kwargs: [cotangent])
masking.defvectorized(device_put_p)
batching.defvectorized(device_put_p)

def _device_put_lowering(ctx, x, *, device):
  return [x]


mlir.register_lowering(device_put_p, _device_put_lowering)
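
For context, jax.device_put is the public entry point that binds device_put_p; a one-line usage sketch:

import jax
y = jax.device_put(3., jax.devices()[0])  # commit a value to a specific device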
Code example #14
File: harvest.py Project: mederrata/probability
        donated_invars=(False, ) * len(args))


xla.register_translation(nest_p, _nest_translation_rule)


def _nest_lowering(ctx, *args, name, call_jaxpr, scope, **_):
    return mlir._xla_call_lower(  # pylint: disable=protected-access
        ctx,
        *args,
        name=jax_util.wrap_name(name, f'nest[{scope}]'),
        call_jaxpr=call_jaxpr,
        donated_invars=(False, ) * len(args))


mlir.register_lowering(nest_p, _nest_lowering)


def _nest_transpose_rule(*args, **kwargs):
    return ad.call_transpose(nest_p, *args, **kwargs)


ad.primitive_transposes[nest_p] = _nest_transpose_rule


def nest(f, *, scope: str):
    """Wraps a function to create a new scope for harvested values.

  Harvested values live in one dynamic name scope (for a particular tag),
  and in strict mode, values with the same name cannot be collected or injected
  more than once. nest(f, scope=<name>) will take all tagged values in `f` and
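
A usage sketch of the documented scoping behavior, assuming harvest's sow and reap combinators from the same module (the tag, names, and expected output are illustrative):

def f(x):
  return sow(x ** 2, tag='intermediate', name='y')

# reap collects sown values; nest groups them under the scope name:
harvested = reap(nest(f, scope='inner'), tag='intermediate')(2.)
# expected: {'inner': {'y': 4.0}}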
Code example #15
File: coo.py Project: cloudhan/jax
  return _coo_todense(data_dot, row, col, spinfo=spinfo)

def _coo_todense_transpose(ct, data, row, col, *, spinfo):
  # Note: we assume that transpose has the same sparsity pattern.
  # Can we check this?
  assert ad.is_undefined_primal(data)
  if ad.is_undefined_primal(row) or ad.is_undefined_primal(col):
    raise ValueError("Cannot transpose with respect to sparse indices")
  assert ct.shape == spinfo.shape
  assert row.aval.dtype == col.aval.dtype
  assert ct.dtype == data.aval.dtype
  return _coo_extract(row, col, ct), row, col

ad.defjvp(coo_todense_p, _coo_todense_jvp, None, None)
ad.primitive_transposes[coo_todense_p] = _coo_todense_transpose
mlir.register_lowering(coo_todense_p, _coo_todense_lowering)
if gpu_sparse:
  if gpu_sparse.cuda_is_supported:
    mlir.register_lowering(
        coo_todense_p,
        partial(_coo_todense_gpu_lowering, gpu_sparse.cuda_coo_todense),
        platform='cuda')
  if gpu_sparse.rocm_is_supported:
    mlir.register_lowering(
        coo_todense_p,
        partial(_coo_todense_gpu_lowering, gpu_sparse.rocm_coo_todense),
        platform='rocm')

if sparse_apis and sparse_apis.is_supported:
  mlir.register_lowering(
      coo_todense_p,
Code example #16
File: solves.py Project: xueeinstein/jax
    ]
    # Broadcast out b if necessary
    new_b = [
        batching.broadcast(x, axis_size, 0) if now_bat and not was_bat else
        batching.moveaxis(x, d, 0) if now_bat and d != 0 else x
        for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat)
    ]

    outs = linear_solve_p.bind(*(new_params + new_b),
                               const_lengths=const_lengths,
                               jaxprs=batched_jaxprs)
    out_dims = [
        0 if batched else batching.not_mapped for batched in solve_x_bat
    ]
    return outs, out_dims


linear_solve_p = core.AxisPrimitive('custom_linear_solve')
linear_solve_p.multiple_results = True
linear_solve_p.def_impl(_custom_linear_solve_impl)
linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval)
ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp
xla.register_initial_style_primitive(linear_solve_p)
mlir.register_lowering(
    linear_solve_p,
    mlir.lower_fun(_custom_linear_solve_impl, multiple_results=True))
ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule
batching.axis_primitive_batchers[linear_solve_p] = _linear_solve_batching_rule
pe.partial_eval_jaxpr_custom_rules[linear_solve_p] = \
    partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'linear_solve')
Code example #17
def debug_callback_lowering(ctx, *args, effect, callback, **params):
  if effect in core.ordered_effects:
    token = ctx.tokens_in.get(effect)[0]
    result, keepalive, token = _ordered_effect_lowering(ctx, token,
        *args, effect=effect, callback=callback, **params)
    ctx.set_tokens_out(mlir.TokenSet({effect: (token,)}))
  else:
    def _callback(*flat_args):
      return tuple(debug_callback_p.impl(
        *flat_args, effect=effect, callback=callback, **params))
    result, keepalive = mlir.emit_python_callback(
        ctx.module_context.platform, _callback, list(args), ctx.avals_in,
        ctx.avals_out, True)
  ctx.module_context.add_keepalive(keepalive)
  return result
mlir.register_lowering(debug_callback_p, debug_callback_lowering,
                       platform="cpu")

def debug_callback(callback: Callable[..., Any], effect: DebugEffect, *args,
                   **kwargs):
  """Calls a stageable Python callback.

  `debug_callback` enables you to pass in a Python function that can be called
  inside of a staged JAX program. A `debug_callback` follows existing JAX
  transformation *pure* operational semantics, which are therefore unaware of
  side-effects. This means the effect could be dropped, duplicated, or
  potentially reordered in the presence of higher-order primitives and
  transformations.

  We want this behavior because we'd like `debug_callback` to be "innocuous",
  i.e. we want these primitives to change the JAX computation as little as
  possible while revealing as much about them as possible, such as which parts
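
A usage sketch of the documented signature (assumes an unordered DebugEffect value, here named debug_effect, is defined in this module; the callback is illustrative):

def _print_value(x):
  print('value:', x)

@jax.jit
def f(x):
  debug_callback(_print_value, debug_effect, x)
  return x + 1.

f(2.)  # prints 'value: 2.0' when the staged callback runs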
Code example #18
File: api.py Project: frederikwilde/jax
  elif isinstance(obj, BCOO):
    _, indices = bufs
    return bcoo_extract(indices, ct), indices
  elif isinstance(obj, COO):
    _, row, col = bufs
    return _coo_extract(row, col, ct), row, col
  else:
    raise NotImplementedError(f"todense_transpose for {type(obj)}")

def _todense_batching_rule(batched_args, batch_dims, *, tree):
  return jax.vmap(partial(_todense_impl, tree=tree), batch_dims)(*batched_args), 0

ad.primitive_jvps[todense_p] = _todense_jvp
ad.primitive_transposes[todense_p] = _todense_transpose
batching.primitive_batchers[todense_p] = _todense_batching_rule
mlir.register_lowering(todense_p, mlir.lower_fun(
    _todense_impl, multiple_results=False))


def empty(shape, dtype=None, index_dtype='int32', sparse_format='bcoo', **kwds):
  """Create an empty sparse array.

  Args:
    shape: sequence of integers giving the array shape.
    dtype: (optional) dtype of the array.
    index_dtype: (optional) dtype of the index arrays.
    sparse_format: string specifying the matrix format (e.g. 'bcoo').
    **kwds: additional keywords passed to the format-specific _empty constructor.
  Returns:
    mat: empty sparse matrix.
  """
  formats = {'bcoo': BCOO, 'coo': COO, 'csr': CSR, 'csc': CSC}
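
A usage sketch matching the docstring (the comment describes the expected result; the exact contents of the returned matrix beyond "empty" are an assumption):

import jax.numpy as jnp
mat = empty((2, 3), dtype=jnp.float32, sparse_format='bcoo')
# mat: a BCOO matrix of shape (2, 3) representing all zeros,
# with int32 index arrays per the index_dtype default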
Code example #19
@assert_p.def_effectful_abstract_eval
def assert_abstract_eval(pred, code, payload, *, msgs):
    return [], {CheckEffect}


def assert_lowering_rule(*a, **k):
    # TODO(lenamartens): actually throw an error through emit_python_callback.
    # TODO(lenamartens): add an in-depth error explanation to link to in module docs.
    raise ValueError(
        'Cannot abstractly evaluate a checkify.check which was not'
        ' functionalized. This probably means you tried to stage'
        ' (jit/scan/pmap/...) a `check` without functionalizing it'
        ' through `checkify.checkify`.')


mlir.register_lowering(assert_p, assert_lowering_rule)
mlir.lowerable_effects.add(CheckEffect)
cf.allowed_effects.add(CheckEffect)

## checkify rules


def summary() -> str:
    return str(source_info_util.summarize(source_info_util.current()))


def nan_error_check(prim, error, enabled_errors, *in_vals, **params):
    out = prim.bind(*in_vals, **params)
    if ErrorCategory.NAN not in enabled_errors:
        return out, error
    no_nans = jnp.logical_not(jnp.any(jnp.isnan(out)))
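
A sketch of the functionalization flow the error message above points to, using the public checkify API (the checked function is illustrative):

import jax
from jax.experimental import checkify

def f(x):
  checkify.check(x > 0., 'x must be positive')
  return x

checked_f = checkify.checkify(f)    # functionalize the check
err, out = jax.jit(checked_f)(-1.)  # now safe to stage with jit
err.throw()                         # raises: x must be positive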
Code example #20
    res_arg, lin_arg = tree_unflatten(call_in_tree, args)
    del lin_arg
    assert all(not ad.is_undefined_primal(x) for x in tree_leaves(res_arg))

    cts = [
        ad_util.zeros_like_aval(ct.aval) if type(ct) is ad_util.Zero else ct
        for ct in cts
    ]
    ct_out = tree_unflatten(out_tree, cts)
    ct_lin = transpose(res_arg, ct_out)
    check_transpose_rule_trees(transpose, lin_tree, tree_structure(ct_lin))
    ct_lin_flat, _ = tree_flatten(tree_broadcast(lin_tree,
                                                 ct_lin,
                                                 is_leaf=lambda x: x is None),
                                  is_leaf=lambda x: x is None)
    return [None] * len(tree_leaves(res_arg)) + ct_lin_flat


def custom_transpose_lowering(*args, call_jaxpr, **params):
    return core.jaxpr_as_fun(call_jaxpr)(*args)


custom_transpose_p = CustomTransposePrimitive('custom_transpose_call')
core.custom_typechecks[custom_transpose_p] = custom_transpose_typecheck
ad.primitive_transposes[custom_transpose_p] = custom_transpose_transpose_rule
mlir.register_lowering(
    custom_transpose_p,
    mlir.lower_fun(custom_transpose_lowering, multiple_results=True))
xla.register_initial_style_primitive(custom_transpose_p)
Code example #21
File: ad_checkpoint.py Project: romanngg/jax
  new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs)
  new_params = dict(eqn.params, jaxpr=new_jaxpr)
  if not any(used_inputs) and not any(used_outputs) and not new_jaxpr.effects:
    return used_inputs, None
  else:
    new_eqn = pe.new_jaxpr_eqn(
        [v for v, used in zip(eqn.invars, used_inputs) if used],
        [v for v, used in zip(eqn.outvars, used_outputs) if used],
        eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info)
    return used_inputs, new_eqn
pe.dce_rules[remat_p] = remat_dce


def checkpoint_name(x, name):
  return name_p.bind(x, name=name)

name_p.def_impl(lambda x, *, name: x)
name_p.def_abstract_eval(lambda x, *, name: x)

def name_jvp(primals, tangents, *, name):
  (x,), (xdot,) = primals, tangents
  return name_p.bind(x, name=name), xdot  # don't name the tangent value
ad.primitive_jvps[name_p] = name_jvp

mlir.register_lowering(name_p, lambda ctx, x, *, name: [x])

def name_batcher(args, dims, *, name):
  (x,), (d,) = args, dims
  return name_p.bind(x, name=name), d
batching.primitive_batchers[name_p] = name_batcher
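
A usage sketch of checkpoint_name combined with a name-based remat policy (jax.checkpoint and jax.checkpoint_policies.save_only_these_names are public APIs; the function is illustrative):

import jax
import jax.numpy as jnp

def f(x):
  h = checkpoint_name(jnp.sin(x), 'hidden')
  return jnp.sin(h)

f_remat = jax.checkpoint(
    f, policy=jax.checkpoint_policies.save_only_these_names('hidden'))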
Code example #22
    def cond(carry):
        i, _ = carry
        return i < nsteps

    def body(carry):
        i, state = carry
        i_ = nsteps - i - 1 if reverse else i
        next_state = core.eval_jaxpr(discharged_jaxpr, consts, i_, *state)
        return i + 1, next_state

    _, state = lax.while_loop(cond, body, (jnp.int32(0), list(args)))
    return state


mlir.register_lowering(for_p, mlir.lower_fun(_for_impl, multiple_results=True))
for_p.def_impl(partial(xla.apply_primitive, for_p))


def _for_jvp(primals, tangents, *, jaxpr, nsteps, reverse, which_linear):
    nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
    # We need to find out which `Ref`s have nonzero tangents after running the
    # for loop. Ordinarily we do this with a fixed point on the body jaxpr but
    # a `for` body jaxpr is stateful and has no outputs. We therefore discharge
    # the state effect from the jaxpr and we will now have a "symmetric" jaxpr
    # where the inputs line up with the outputs. We use this discharged jaxpr
    # for the fixed point.
    discharged_jaxpr, body_consts = discharge_state(jaxpr, ())
    for _ in range(len(nonzero_tangents)):
        _, out_nonzero_tangents = ad.jvp_jaxpr(core.ClosedJaxpr(
            discharged_jaxpr, body_consts), [False] + nonzero_tangents,
Code example #23
File: coo.py Project: John1Tang/jax
def _coo_todense_transpose(ct, data, row, col, *, spinfo):
    # Note: we assume that transpose has the same sparsity pattern.
    # Can we check this?
    assert ad.is_undefined_primal(data)
    if ad.is_undefined_primal(row) or ad.is_undefined_primal(col):
        raise ValueError("Cannot transpose with respect to sparse indices")
    assert ct.shape == spinfo.shape
    assert row.aval.dtype == col.aval.dtype
    assert ct.dtype == data.aval.dtype
    return _coo_extract(row, col, ct), row, col


ad.defjvp(coo_todense_p, _coo_todense_jvp, None, None)
ad.primitive_transposes[coo_todense_p] = _coo_todense_transpose
xla.register_translation(coo_todense_p, _coo_todense_translation_rule)
mlir.register_lowering(coo_todense_p, _coo_todense_lowering)
if sparse_apis and sparse_apis.is_supported:
    xla.register_translation(coo_todense_p,
                             _coo_todense_gpu_translation_rule,
                             platform='gpu')
    if jax._src.lib.version > (0, 3, 5):
        mlir.register_lowering(coo_todense_p,
                               _coo_todense_gpu_lowering,
                               platform='gpu')

#--------------------------------------------------------------------
# coo_fromdense

coo_fromdense_p = core.Primitive('coo_fromdense')
coo_fromdense_p.multiple_results = True
Code example #24
File: remat_impl.py Project: xueeinstein/jax
            translation_rule = _remat_translation_using_opt_barrier
        elif is_gpu_platform:
            translation_rule = _remat_translation_using_while
        else:
            translation_rule = _remat_translation_using_cond
    else:
        translation_rule = lambda *args, jaxpr: core.eval_jaxpr(
            jaxpr, (), *args)

    return jax.named_call(translation_rule,
                          name=wrap_name(name, "remat"))(*args, jaxpr=jaxpr)


for remat_primitive in (pe.remat_call_p,
                        ad_checkpoint.remat_p):  # type: ignore
    mlir.register_lowering(remat_primitive,
                           mlir.lower_fun(remat_impl, multiple_results=True))
    mlir.register_lowering(remat_primitive,
                           mlir.lower_fun(partial(remat_impl,
                                                  is_gpu_platform=True),
                                          multiple_results=True),
                           platform="gpu")


def _optimization_barrier_abstract_eval(*args):
    return args


def _optimization_barrier_lowering_rule(ctx, *args):
    barrier_types = _map(mlir.aval_to_ir_types, ctx.avals_in)
    flat_barrier_types = util.flatten(barrier_types)
Code example #25
lcf.allowed_effects.add('while1')
lcf.allowed_effects.add('while2')

# TODO(sharadmv): remove jaxlib guards for GPU tests when jaxlib minimum
#                 version is >= 0.3.11
disabled_backends = ['tpu']
if jaxlib.version < (0, 3, 11):
    disabled_backends.append('gpu')


def trivial_effect_lowering(ctx, *, effect):
    ctx.set_tokens_out(ctx.tokens_in)
    return []


mlir.register_lowering(effect_p, trivial_effect_lowering)


def function_effect_lowering(ctx, *, effect):
    def _f(ctx):
        ctx.set_tokens_out(ctx.tokens_in)
        return []

    func = mlir._emit_lowering_rule_as_fun(_f, ctx)

    output_types = map(mlir.aval_to_ir_types, ctx.avals_out)
    token_types = [mlir.token_type() for _ in ctx.tokens_in.items()]
    output_types = [*token_types, *output_types]
    flat_output_types = util.flatten(output_types)
    call = mlir.func_dialect.CallOp(
        flat_output_types, mlir.ir.FlatSymbolRefAttr.get(func.name.value),
Code example #26
File: windowed_reductions.py Project: wayfeng/jax
    rw = mhlo.ReduceWindowOp(
        map(mlir.aval_to_ir_type, ctx.avals_out), operands, init_values,
        mlir.dense_int_elements(window_dimensions),
        mlir.dense_int_elements(window_strides),
        mlir.dense_int_elements(base_dilation),
        mlir.dense_int_elements(window_dilation),
        ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
    reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
    with ir.InsertionPoint(reducer):
        out_nodes = mlir.jaxpr_subcomp(ctx.module_context, jaxpr, consts,
                                       *([a] for a in reducer.arguments))
        mhlo.ReturnOp(util.flatten(out_nodes))
    return rw.results


mlir.register_lowering(reduce_window_p, _generic_reduce_window_lower)


def _reduce_window_sum_shape_rule(operand, *, window_dimensions,
                                  window_strides, padding, base_dilation,
                                  window_dilation):
    if not dtypes.issubdtype(operand.dtype, np.number):
        msg = "operand to reduce_window_sum must have a number dtype, got {}"
        raise TypeError(msg.format(np.dtype(operand.dtype).name))
    return _common_reduce_window_shape_rule(operand, window_dimensions,
                                            window_strides, padding,
                                            base_dilation, window_dilation)


def _reduce_window_sum_translation_rule(ctx, avals_in, avals_out, operand, *,
                                        window_dimensions, window_strides,
Code example #27
  def tearDown(self):
    super().tearDown()
    dispatch.runtime_tokens.clear()
    config.update('jax_enable_x64', self.old_x64)
    mlir.register_lowering(effect_p, self._old_lowering)
Code example #28
@sp_indices_p.def_abstract_eval
def _sp_indices_abstract_eval(mat):
  return mat.indices_aval

def _sp_indices_translation_rule(ctx, avals_in, avals_out, data, indices):
  return [indices]

# Note: cannot use lower_fun to define attribute access primitives
# because it leads to infinite recursion.
xla.register_translation(sp_indices_p, _sp_indices_translation_rule)

def _sp_indices_mhlo_lowering(ctx, data_and_indices):
  return [data_and_indices[1]]

mlir.register_lowering(sp_indices_p, _sp_indices_mhlo_lowering)

sp_data_p = core.Primitive('sp_data')

@sp_data_p.def_impl
def _sp_data_impl(mat):
  return mat.data

@sp_data_p.def_abstract_eval
def _sp_data_abstract_eval(mat):
  return mat.data_aval

def _sp_data_translation_rule(ctx, avals_in, avals_out, data, indices):
  return [data]

# Note: cannot use lower_fun to define attribute access primitives
Code example #29
File: sharded_jit.py Project: xueeinstein/jax

def _sharded_call_impl(fun, *args, nparts, in_parts, out_parts_thunk,
                       local_in_parts, local_out_parts_thunk, local_nparts,
                       name):
    compiled_fun = _sharded_callable(fun, nparts, in_parts, out_parts_thunk,
                                     local_in_parts, local_out_parts_thunk,
                                     local_nparts, name,
                                     *map(xla.abstractify, args))
    return compiled_fun(*args)


sharded_call_p = core.CallPrimitive("sharded_call")
sharded_call = sharded_call_p.bind
sharded_call_p.def_impl(_sharded_call_impl)
mlir.register_lowering(sharded_call_p, _sharded_jit_lowering)


def sharded_jit(
        fun: Callable,
        in_parts,
        out_parts,
        num_partitions: Optional[int] = None,
        local_in_parts=None,
        local_out_parts=None,
        local_num_partitions=None,
        static_argnums: Union[int, Iterable[int]] = (),
):
    """Like ``jit``, but partitions ``fun`` across multiple devices.

  WARNING: this feature is still under active development! It may not work well,
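
A possible usage sketch of the documented signature; the PartitionSpec type (aliased P) and its availability alongside sharded_jit are assumptions, and this API has since been removed from JAX in favor of pjit:

import functools
P = PartitionSpec  # assumed to be importable from the same module

@functools.partial(sharded_jit, in_parts=P(2, 1), out_parts=P(2, 1))
def f(x):
  return x + 1.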
Code example #30
File: convolution.py Project: frederikwilde/jax
        output_spatial_dimensions=list(out_spec[2:]))
    num_spatial_dims = len(rhs_spec) - 2
    window_reversal = mlir.dense_bool_elements([False] * num_spatial_dims)
    return [
        mhlo.ConvOp(mlir.aval_to_ir_type(aval_out), lhs, rhs,
                    mlir.dense_int_elements(window_strides),
                    mlir.dense_int_elements(padding),
                    mlir.dense_int_elements(lhs_dilation),
                    mlir.dense_int_elements(rhs_dilation), window_reversal,
                    dnums, mlir.i64_attr(feature_group_count),
                    mlir.i64_attr(batch_group_count),
                    lax.precision_attr(precision)).result
    ]


mlir.register_lowering(conv_general_dilated_p, _conv_general_dilated_lower)
# TODO(b/161124619, b/161126248): XLA does not support complex convolution on
# GPU, and on CPU it uses a slow loop-based implementation;
# on these backends, lower complex convolutions away.
mlir.register_lowering(conv_general_dilated_p,
                       partial(_conv_general_dilated_lower,
                               expand_complex_convolutions=True),
                       platform='cpu')
mlir.register_lowering(conv_general_dilated_p,
                       partial(_conv_general_dilated_lower,
                               expand_complex_convolutions=True),
                       platform='gpu')


def _reshape_axis_into(src, dst, x):
    perm = [i for i in range(x.ndim) if i != src]