def __init__(self, name):
  """Create a multiple-results primitive whose rules all derive from `self.impl`.

  Registers abstract evaluation, a JVP rule, a batching rule, an XLA
  translation and an MLIR lowering for this primitive. Every rule reads
  `self.impl` when it is invoked, not when `__init__` runs.

  Args:
    name: primitive name, forwarded to the base-class constructor.
  """
  super(FlatPrimitive, self).__init__(name)
  self.multiple_results = True

  def abstract_eval_rule(*flat_avals, **params):
    return pe.abstract_eval_fun(self.impl, *flat_avals, **params)

  def jvp_rule(primals, tangents, **params):
    jvp_fun = ad.jvp(lu.wrap_init(self.impl, params))
    return jvp_fun.call_wrapped(primals, tangents)

  def batching_rule(args, dims, **params):
    batched_fun, out_dims_thunk = batch_fun(
        lu.wrap_init(self.impl, params), dims)
    outs = batched_fun.call_wrapped(*args)
    # The out-dims thunk is deliberately called only after call_wrapped ran.
    return outs, out_dims_thunk()

  def xla_rule(c, *xla_args, **params):
    lowered = xla.lower_fun(self.impl, multiple_results=True)
    return lowered(c, *xla_args, **params)

  def mlir_rule(c, *mlir_args, **params):
    lowered = mlir.lower_fun(self.impl, multiple_results=True)
    return lowered(c, *mlir_args, **params)

  # Register every rule for this primitive, in the same order as before.
  self.def_abstract_eval(abstract_eval_rule)
  ad.primitive_jvps[self] = jvp_rule
  batching.primitive_batchers[self] = batching_rule
  xla.translations[self] = xla_rule
  mlir.register_lowering(self, mlir_rule)
def hop_lowering(prim):
  """Register an MLIR lowering for `prim`, a primitive carrying `call_jaxpr`.

  The installed rule forwards to `mlir._call_lowering`. It passes `name`
  twice, then `call_jaxpr`, `backend`, the module context, the input and
  output avals, and finally the operands. Any other primitive params are
  ignored.

  Returns:
    The lowering rule that was registered.
  """
  def rule(ctx, *args, backend, name, call_jaxpr, **_params):
    call_args = (name, name, call_jaxpr, backend, ctx.module_context,
                 ctx.avals_in, ctx.avals_out, *args)
    return mlir._call_lowering(*call_args)  # pylint: disable=protected-access

  mlir.register_lowering(prim, rule)
  return rule
def test_should_not_pass_tokens_into_unordered_effect(self):
  """The lowering rule for effect 'bar' must see an empty `tokens_in`."""
  def assert_no_tokens_lowering(ctx, *, effect):
    self.assertEmpty(ctx.tokens_in)
    return []

  mlir.register_lowering(effect_p, assert_no_tokens_lowering)

  def f(x):
    effect_p.bind(effect='bar')
    return x + 1.

  jax.jit(f).lower(2.)
def test_lowering_ordered_effect_should_create_tokens(self):
  """Ordered effects produce leading `mhlo.create_token` ops in `main`.

  Binding effect 'foo' once should give one leading `mhlo.create_token` op.
  Binding 'foo' and 'foo2' should give two leading `mhlo.create_token` ops.
  """
  def effect_lowering(ctx, *, effect):
    ctx.set_tokens_out(ctx.tokens_in)
    return []
  mlir.register_lowering(effect_p, effect_lowering)

  def leading_op_names(fun, count):
    """Return the names of the first `count` ops of the lowered `main`."""
    module = fun.lower(2.).compiler_ir()
    main = module.body.operations[0]
    ops = main.body.blocks[0].operations
    return [ops[i].operation.name for i in range(count)]

  @jax.jit
  def one_effect(x):
    effect_p.bind(effect='foo')
    return x + 1.

  @jax.jit
  def two_effects(x):
    effect_p.bind(effect='foo')
    effect_p.bind(effect='foo2')
    return x + 1.

  self.assertEqual(leading_op_names(one_effect, 1), ["mhlo.create_token"])
  self.assertEqual(leading_op_names(two_effects, 2),
                   ["mhlo.create_token", "mhlo.create_token"])
def test_should_pass_tokens_into_ordered_effect(self):
  """The lowering rule for effect 'foo' must see exactly the 'foo' token."""
  def check_tokens_lowering(ctx, *, effect):
    received = list(ctx.tokens_in.effects())
    self.assertListEqual(received, ['foo'])
    ctx.set_tokens_out(ctx.tokens_in)
    return []

  mlir.register_lowering(effect_p, check_tokens_lowering)

  def f(x):
    effect_p.bind(effect='foo')
    return x + 1.

  jax.jit(f).lower(2.)
def test_lowering_that_doesnt_set_tokens_should_cause_error(self):
  """Lowering fails when an effectful rule never sets its output tokens."""
  def lowering_without_tokens(ctx, *, effect):
    # Intentionally never calls `ctx.set_tokens_out`.
    del ctx, effect
    return []

  mlir.register_lowering(effect_p, lowering_without_tokens)

  def f(x):
    effect_p.bind(effect='foo')
    return x + 1.

  jitted_f = jax.jit(f)
  expected_message = ('Lowering rule for `effect` needs to '
                      'set `tokens_out`')
  with self.assertRaisesRegex(ValueError, expected_message):
    jitted_f.lower(2.)
def test_lowering_that_sets_wrong_tokens_should_cause_error(self):
  """Lowering fails when a rule returns a token keyed by a different effect."""
  def relabelling_lowering(ctx, *, effect):
    # Hands the 'foo' token back under the unrelated effect name 'bar'.
    mislabelled = mlir.TokenSet(bar=ctx.tokens_in.get('foo'))
    ctx.set_tokens_out(mislabelled)
    return []

  mlir.register_lowering(effect_p, relabelling_lowering)

  def f(x):
    effect_p.bind(effect='foo')
    return x + 1.

  jitted_f = jax.jit(f)
  expected_message = ('Lowering rule for `effect` returns '
                      'incorrect set of output token.')
  with self.assertRaisesRegex(ValueError, expected_message):
    jitted_f.lower(2.)
def setUp(self):
  """Disable x64 and install a token-checking lowering for `effect_p`.

  The previous x64 flag and the previous `effect_p` lowering are saved on
  `self.old_x64` and `self._old_lowering`. Runtime tokens are cleared.
  """
  super().setUp()
  self.old_x64 = config.jax_enable_x64
  config.update('jax_enable_x64', False)
  self._old_lowering = mlir._lowerings[effect_p]

  def token_checking_lowering(ctx, *, effect):
    # Only effects listed in core.ordered_effects expect an input token.
    expected = [effect] if effect in core.ordered_effects else []
    self.assertListEqual(list(ctx.tokens_in.effects()), expected)
    ctx.set_tokens_out(ctx.tokens_in)
    return []

  mlir.register_lowering(effect_p, token_checking_lowering)
  dispatch.runtime_tokens.clear()
def test_nontrivial_lowering_with_unordered_effect_should_consume_token(self):
  """Check how `function_effect_lowering` lowers the unordered effect 'bar'.

  The first op of `main` should be a `func.call` to `@effect` with no
  operands. The module's second top-level op should be a function named
  "effect" with no inputs and no results.
  """
  mlir.register_lowering(effect_p, function_effect_lowering)

  def f(x):
    effect_p.bind(effect='bar')
    return x + 1.

  module = jax.jit(f).lower(2.).compiler_ir()
  top_level_ops = module.body.operations

  call_op = top_level_ops[0].body.blocks[0].operations[0]
  self.assertEqual(call_op.operation.name, "func.call")
  self.assertEqual(str(call_op.attributes["callee"]), "@effect")
  self.assertLen(list(call_op.operands), 0)

  callee = top_level_ops[1]
  self.assertEqual(callee.name.value, "effect")
  self.assertLen(list(callee.type.inputs), 0)
  self.assertLen(list(callee.type.results), 0)
# NOTE(review): whitespace-collapsed fragment (line breaks and indentation
# lost); it ends inside _custom_jvp_call_jaxpr_jvp. The `#` inside the line
# comments out the remainder, so restore line breaks before editing.
# - _custom_jvp_call_jaxpr_abstract_eval: ignores the operands and params and
#   returns fun_jaxpr.out_avals.
# - custom_jvp_call_jaxpr_p: a multiple-results AxisPrimitive. Its impl is also
#   lowered to MLIR via mlir.lower_fun, and it is installed as
#   CustomJVPCallPrimitive.initial_style.
# - _custom_jvp_call_jaxpr_jvp: raises ad.CustomJVPException if any tangent of
#   the first `num_consts` operands is not a symbolic Zero.
def _custom_jvp_call_jaxpr_abstract_eval(*args, fun_jaxpr: core.ClosedJaxpr, **params): del args, params return fun_jaxpr.out_avals custom_jvp_call_jaxpr_p = core.AxisPrimitive('custom_jvp_call_jaxpr') custom_jvp_call_jaxpr_p.multiple_results = True custom_jvp_call_jaxpr_p.def_impl(_custom_jvp_call_jaxpr_impl) custom_jvp_call_jaxpr_p.def_abstract_eval(_custom_jvp_call_jaxpr_abstract_eval) CustomJVPCallPrimitive.initial_style = custom_jvp_call_jaxpr_p mlir.register_lowering( custom_jvp_call_jaxpr_p, mlir.lower_fun(_custom_jvp_call_jaxpr_impl, multiple_results=True)) def _custom_jvp_call_jaxpr_jvp(primals, tangents, *, fun_jaxpr: core.ClosedJaxpr, jvp_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]], num_consts: int): _, args = split_list(primals, [num_consts]) consts_dot, args_dot = split_list(tangents, [num_consts]) if any(type(t) is not Zero for t in consts_dot): raise ad.CustomJVPException() jvp_jaxpr, jvp_consts = jvp_jaxpr_thunk() # consts can be tracers! args_dot = map(ad.instantiate_zeros, args_dot)
# NOTE(review): whitespace-collapsed fragment (line breaks and indentation
# lost); it starts inside a JVP rule whose head is not visible here.
# - Tail of that rule: `outs` must have even length. It is split in half into
#   (primal outputs, tangent outputs).
# - custom_vmap_p: a multiple-results primitive with impl, abstract-eval,
#   batching, JVP, initial-style XLA and MLIR (lower_fun of the impl) rules.
# - tree_split(mask, tree) -> (lhs, rhs): lhs keeps the leaves whose mask
#   entry is truthy (None elsewhere); rhs keeps the other leaves.
# - tree_merge(mask, lhs_tree, rhs_tree): leafwise select. It takes lhs where
#   the mask is truthy and rhs otherwise.
rule=jvp_of_rule_rule, in_tree=jvp_in_tree, out_tree=jvp_out_tree) assert len(outs) % 2 == 0, len(outs) out_primals, out_tangents = util.split_list(outs, [len(outs) // 2]) return out_primals, out_tangents custom_vmap_p = core.Primitive('custom_vmap_call') custom_vmap_p.multiple_results = True custom_vmap_p.def_impl(custom_vmap_impl) custom_vmap_p.def_abstract_eval(custom_vmap_abstract_eval) batching.primitive_batchers[custom_vmap_p] = custom_vmap_batching ad.primitive_jvps[custom_vmap_p] = custom_vmap_jvp xla.register_initial_style_primitive(custom_vmap_p) mlir.register_lowering(custom_vmap_p, mlir.lower_fun(custom_vmap_impl, multiple_results=True)) # -- custom vmap applications def tree_split(mask, tree): lhs = tree_map(lambda l, x: x if l else None, mask, tree) rhs = tree_map(lambda l, x: None if l else x, mask, tree) return lhs, rhs def tree_merge(mask, lhs_tree, rhs_tree): return tree_map(lambda l, x_l, x_r: x_l if l else x_r, mask, lhs_tree, rhs_tree)
# NOTE(review): whitespace-collapsed fragment. Because it begins with `#`, this
# entire line is currently a Python comment; restore line breaks before use.
# It starts inside a transpose helper that returns lax.conj(out).
# - _fft_transpose_rule: RFFT and IRFFT cotangents go through
#   _rfft_transpose / _irfft_transpose. Any other fft_type re-applies fft with
#   the same type. The rule returns a 1-tuple.
# - _fft_batching_rule: moves the batch dim to axis 0, applies fft there and
#   reports batch dim 0.
# - fft_p registrations. A CPU-only lowering (_fft_lowering_cpu) is registered
#   when `pocketfft` is truthy.
# Use JAX's convention for complex gradients # https://github.com/google/jax/issues/6223#issuecomment-807740707 return lax.conj(out) def _fft_transpose_rule(t, operand, fft_type, fft_lengths): if fft_type == xla_client.FftType.RFFT: result = _rfft_transpose(t, fft_lengths) elif fft_type == xla_client.FftType.IRFFT: result = _irfft_transpose(t, fft_lengths) else: result = fft(t, fft_type, fft_lengths) return result, def _fft_batching_rule(batched_args, batch_dims, fft_type, fft_lengths): x, = batched_args bd, = batch_dims x = batching.moveaxis(x, bd, 0) return fft(x, fft_type, fft_lengths), 0 fft_p = Primitive('fft') fft_p.def_impl(_fft_impl) fft_p.def_abstract_eval(fft_abstract_eval) mlir.register_lowering(fft_p, _fft_lowering) ad.deflinear2(fft_p, _fft_transpose_rule) batching.primitive_batchers[fft_p] = _fft_batching_rule if pocketfft: mlir.register_lowering(fft_p, _fft_lowering_cpu, platform='cpu')
# NOTE(review): whitespace-collapsed fragment. Because it begins with `#`, this
# entire line is currently a Python comment; restore line breaks before use.
# - Tail of a device-array copy helper. It moves the buffer via the host:
#   x.device_buffer.to_py(), then buffer_from_pyval on the target device's
#   backend.
# - _device_put_impl: device arrays are copied with
#   _copy_device_array_to_device. Other values are abstractified (a TypeError
#   is re-raised with a clearer message) and put through a result handler.
# - device_put_p: abstract eval, XLA translation and MLIR lowering are all
#   identity. It is linear, with a transpose that passes the cotangent
#   through, and is vectorized for masking and batching.
# buffers from different XLA backends are passed through the host. backend = xb.get_device_backend(device) moved_buf = backend.buffer_from_pyval(x.device_buffer.to_py(), device) return device_array.make_device_array(x.aval, device, moved_buf) def _device_put_impl(x, device: Optional[Device] = None): if device_array.type_is_device_array(x): return _copy_device_array_to_device(x, device) try: a = xla.abstractify(x) except TypeError as err: raise TypeError( f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err return aval_to_result_handler(device, a)(*device_put(x, device)) device_put_p = core.Primitive('device_put') device_put_p.def_impl(_device_put_impl) device_put_p.def_abstract_eval(lambda x, device=None: x) xla.translations[device_put_p] = lambda c, x, device=None: x ad.deflinear2(device_put_p, lambda cotangent, _, **kwargs: [cotangent]) masking.defvectorized(device_put_p) batching.defvectorized(device_put_p) def _device_put_lowering(ctx, x, *, device): return [x] mlir.register_lowering(device_put_p, _device_put_lowering)
# NOTE(review): whitespace-collapsed fragment. It starts inside the nest
# translation rule and ends inside the `nest` docstring. The `#` inside the
# line comments out the remainder.
# - _nest_lowering: lowers through mlir._xla_call_lower, with the call name
#   wrapped as `nest[<scope>]` and no donated inputs.
# - _nest_transpose_rule: delegates to ad.call_transpose for nest_p.
# - nest(f, *, scope): definition starts here and is cut off.
donated_invars=(False, ) * len(args)) xla.register_translation(nest_p, _nest_translation_rule) def _nest_lowering(ctx, *args, name, call_jaxpr, scope, **_): return mlir._xla_call_lower( # pylint: disable=protected-access ctx, *args, name=jax_util.wrap_name(name, f'nest[{scope}]'), call_jaxpr=call_jaxpr, donated_invars=(False, ) * len(args)) mlir.register_lowering(nest_p, _nest_lowering) def _nest_transpose_rule(*args, **kwargs): return ad.call_transpose(nest_p, *args, **kwargs) ad.primitive_transposes[nest_p] = _nest_transpose_rule def nest(f, *, scope: str): """Wraps a function to create a new scope for harvested values. Harvested values live in one dynamic name scope (for a particular tag), and in strict mode, values with the same name cannot be collected or injected more than once. nest(f, scope=<name>) will take all tagged values in `f` and
# NOTE(review): whitespace-collapsed fragment. It starts inside the
# coo_todense JVP and ends mid-call to mlir.register_lowering. The `#` inside
# the line comments out the remainder.
# - _coo_todense_transpose: linear only in `data`. It raises ValueError if
#   row/col are linear, asserts shape/dtype agreement, and returns
#   (_coo_extract(row, col, ct), row, col).
# - A default MLIR lowering is registered. CUDA and ROCm lowerings (a partial
#   of _coo_todense_gpu_lowering with the backend kernel) are added when
#   gpu_sparse reports support. A further registration under `sparse_apis` is
#   cut off.
return _coo_todense(data_dot, row, col, spinfo=spinfo) def _coo_todense_transpose(ct, data, row, col, *, spinfo): # Note: we assume that transpose has the same sparsity pattern. # Can we check this? assert ad.is_undefined_primal(data) if ad.is_undefined_primal(row) or ad.is_undefined_primal(col): raise ValueError("Cannot transpose with respect to sparse indices") assert ct.shape == spinfo.shape assert row.aval.dtype == col.aval.dtype assert ct.dtype == data.aval.dtype return _coo_extract(row, col, ct), row, col ad.defjvp(coo_todense_p, _coo_todense_jvp, None, None) ad.primitive_transposes[coo_todense_p] = _coo_todense_transpose mlir.register_lowering(coo_todense_p, _coo_todense_lowering) if gpu_sparse: if gpu_sparse.cuda_is_supported: mlir.register_lowering( coo_todense_p, partial(_coo_todense_gpu_lowering, gpu_sparse.cuda_coo_todense), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( coo_todense_p, partial(_coo_todense_gpu_lowering, gpu_sparse.rocm_coo_todense), platform='rocm') if sparse_apis and sparse_apis.is_supported: mlir.register_lowering( coo_todense_p,
# NOTE(review): whitespace-collapsed fragment. It starts inside the
# linear-solve batching rule, and the `#` near its start comments out the rest
# of the line; restore line breaks before use.
# - Tail of the batching rule: b operands that become batched are broadcast to
#   `axis_size` at axis 0, and already-batched ones are moved to axis 0.
#   Outputs flagged in solve_x_bat are batched along axis 0; the others are
#   not_mapped.
# - linear_solve_p registrations. Its partial_eval_jaxpr_custom rule is
#   explicitly "not implemented".
] # Broadcast out b if necessary new_b = [ batching.broadcast(x, axis_size, 0) if now_bat and not was_bat else batching.moveaxis(x, d, 0) if now_bat and d != 0 else x for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat) ] outs = linear_solve_p.bind(*(new_params + new_b), const_lengths=const_lengths, jaxprs=batched_jaxprs) out_dims = [ 0 if batched else batching.not_mapped for batched in solve_x_bat ] return outs, out_dims linear_solve_p = core.AxisPrimitive('custom_linear_solve') linear_solve_p.multiple_results = True linear_solve_p.def_impl(_custom_linear_solve_impl) linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval) ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp xla.register_initial_style_primitive(linear_solve_p) mlir.register_lowering( linear_solve_p, mlir.lower_fun(_custom_linear_solve_impl, multiple_results=True)) ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule batching.axis_primitive_batchers[linear_solve_p] = _linear_solve_batching_rule pe.partial_eval_jaxpr_custom_rules[linear_solve_p] = \ partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'linear_solve')
# NOTE(review): whitespace-collapsed fragment; it ends inside the
# `debug_callback` docstring.
# - debug_callback_lowering: for an effect in core.ordered_effects, the
#   effect's input token is threaded through _ordered_effect_lowering and set
#   as the output token for that effect. Otherwise a Python callback that
#   invokes debug_callback_p.impl is emitted via mlir.emit_python_callback.
#   The meaning of its final positional `True` is not visible here (TODO
#   confirm). In both cases the returned keepalive is attached to the module
#   context.
# - In this fragment the rule is registered only for platform="cpu".
def debug_callback_lowering(ctx, *args, effect, callback, **params): if effect in core.ordered_effects: token = ctx.tokens_in.get(effect)[0] result, keepalive, token = _ordered_effect_lowering(ctx, token, *args, effect=effect, callback=callback, **params) ctx.set_tokens_out(mlir.TokenSet({effect: (token,)})) else: def _callback(*flat_args): return tuple(debug_callback_p.impl( *flat_args, effect=effect, callback=callback, **params)) result, keepalive = mlir.emit_python_callback(ctx.module_context.platform, _callback, list(args), ctx.avals_in, ctx.avals_out, True) ctx.module_context.add_keepalive(keepalive) return result mlir.register_lowering(debug_callback_p, debug_callback_lowering, platform="cpu") def debug_callback(callback: Callable[..., Any], effect: DebugEffect, *args, **kwargs): """Calls a stageable Python callback. `debug_callback` enables you to pass in a Python function that can be called inside of a staged JAX program. A `debug_callback` follows existing JAX transformation *pure* operational semantics, which are therefore unaware of side-effects. This means the effect could be dropped, duplicated, or potentially reordered in the presence of higher-order primitives and transformations. We want this behavior because we'd like `debug_callback` to be "innocuous", i.e. we want these primitives to change the JAX computation as little as possible while revealing as much about them as possible, such as which parts
# NOTE(review): whitespace-collapsed fragment. It starts inside the todense
# transpose rule and ends inside `empty`.
# - Transpose tail: BCOO -> bcoo_extract on its indices; COO -> _coo_extract
#   on row/col; any other type raises NotImplementedError.
# - _todense_batching_rule: vmaps _todense_impl over `batch_dims` and reports
#   batch dim 0.
# - empty(...): supported keys visible in `formats` are 'bcoo', 'coo', 'csr'
#   and 'csc'.
elif isinstance(obj, BCOO): _, indices = bufs return bcoo_extract(indices, ct), indices elif isinstance(obj, COO): _, row, col = bufs return _coo_extract(row, col, ct), row, col else: raise NotImplementedError(f"todense_transpose for {type(obj)}") def _todense_batching_rule(batched_args, batch_dims, *, tree): return jax.vmap(partial(_todense_impl, tree=tree), batch_dims)(*batched_args), 0 ad.primitive_jvps[todense_p] = _todense_jvp ad.primitive_transposes[todense_p] = _todense_transpose batching.primitive_batchers[todense_p] = _todense_batching_rule mlir.register_lowering(todense_p, mlir.lower_fun( _todense_impl, multiple_results=False)) def empty(shape, dtype=None, index_dtype='int32', sparse_format='bcoo', **kwds): """Create an empty sparse array. Args: shape: sequence of integers giving the array shape. dtype: (optional) dtype of the array. index_dtype: (optional) dtype of the index arrays. sparse_format: string specifying the matrix format (e.g. ['bcoo']). **kwds: additional keywords passed to the format-specific _empty constructor. Returns: mat: empty sparse matrix. """ formats = {'bcoo': BCOO, 'coo': COO, 'csr': CSR, 'csc': CSC}
# NOTE(review): whitespace-collapsed fragment. It ends inside nan_error_check,
# and the `#` inside the line comments out the remainder.
# - assert_abstract_eval: no outputs; reports the CheckEffect effect.
# - assert_lowering_rule: always raises ValueError, so assert_p cannot be
#   lowered as-is. NOTE(review): the message says "Cannot abstractly evaluate"
#   although it is raised from the lowering rule; consider rewording.
# - CheckEffect is added to mlir.lowerable_effects and cf.allowed_effects.
# - summary(): str of source_info_util.summarize(source_info_util.current()).
# - nan_error_check: returns (out, error) unchanged when ErrorCategory.NAN is
#   not enabled; the rest is cut off.
@assert_p.def_effectful_abstract_eval def assert_abstract_eval(pred, code, payload, *, msgs): return [], {CheckEffect} def assert_lowering_rule(*a, **k): # TODO(lenamartens): actually throw an error through emit_python_callable # TODO(lenamartens) add in-depth error explanation to link to in module docs. raise ValueError( 'Cannot abstractly evaluate a checkify.check which was not' ' functionalized. This probably means you tried to stage' ' (jit/scan/pmap/...) a `check` without functionalizing it' ' through `checkify.checkify`.') mlir.register_lowering(assert_p, assert_lowering_rule) mlir.lowerable_effects.add(CheckEffect) cf.allowed_effects.add(CheckEffect) ## checkify rules def summary() -> str: return str(source_info_util.summarize(source_info_util.current())) def nan_error_check(prim, error, enabled_errors, *in_vals, **params): out = prim.bind(*in_vals, **params) if ErrorCategory.NAN not in enabled_errors: return out, error no_nans = jnp.logical_not(jnp.any(jnp.isnan(out)))
# NOTE(review): whitespace-collapsed fragment. It starts inside a transpose
# rule, presumably custom_transpose_transpose_rule (registered below).
# - Tail of that rule: it asserts that no residual leaf is an undefined
#   primal, instantiates symbolic Zero cotangents, calls
#   `transpose(res_arg, ct_out)`, and checks the result's tree structure. It
#   returns None per residual leaf, followed by the flattened linear
#   cotangents (None leaves broadcast over lin_tree).
# - custom_transpose_lowering: evaluates `call_jaxpr` via core.jaxpr_as_fun.
#   It is lowered to MLIR with mlir.lower_fun.
res_arg, lin_arg = tree_unflatten(call_in_tree, args) del lin_arg assert all(not ad.is_undefined_primal(x) for x in tree_leaves(res_arg)) cts = [ ad_util.zeros_like_aval(ct.aval) if type(ct) is ad_util.Zero else ct for ct in cts ] ct_out = tree_unflatten(out_tree, cts) ct_lin = transpose(res_arg, ct_out) check_transpose_rule_trees(transpose, lin_tree, tree_structure(ct_lin)) ct_lin_flat, _ = tree_flatten(tree_broadcast(lin_tree, ct_lin, is_leaf=lambda x: x is None), is_leaf=lambda x: x is None) return [None] * len(tree_leaves(res_arg)) + ct_lin_flat def custom_transpose_lowering(*args, call_jaxpr, **params): return core.jaxpr_as_fun(call_jaxpr)(*args) custom_transpose_p = CustomTransposePrimitive('custom_transpose_call') core.custom_typechecks[custom_transpose_p] = custom_transpose_typecheck ad.primitive_transposes[custom_transpose_p] = custom_transpose_transpose_rule mlir.register_lowering( custom_transpose_p, mlir.lower_fun(custom_transpose_lowering, multiple_results=True)) xla.register_initial_style_primitive(custom_transpose_p)
# NOTE(review): whitespace-collapsed fragment. It starts inside remat_dce, and
# the `#` inside the line comments out the remainder.
# - remat_dce tail: if no inputs or outputs are used and the DCE'd jaxpr has
#   no effects, it returns (used_inputs, None), dropping the equation.
#   Otherwise it rebuilds the equation with only the used invars/outvars and
#   the DCE'd jaxpr.
# - checkpoint_name(x, name): binds name_p, which is the identity in impl,
#   abstract eval and MLIR lowering. Its JVP names only the primal. Its
#   batcher re-binds and keeps the batch dim.
new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs) new_params = dict(eqn.params, jaxpr=new_jaxpr) if not any(used_inputs) and not any(used_outputs) and not new_jaxpr.effects: return used_inputs, None else: new_eqn = pe.new_jaxpr_eqn( [v for v, used in zip(eqn.invars, used_inputs) if used], [v for v, used in zip(eqn.outvars, used_outputs) if used], eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info) return used_inputs, new_eqn pe.dce_rules[remat_p] = remat_dce def checkpoint_name(x, name): return name_p.bind(x, name=name) name_p.def_impl(lambda x, *, name: x) name_p.def_abstract_eval(lambda x, *, name: x) def name_jvp(primals, tangents, *, name): (x,), (xdot,) = primals, tangents return name_p.bind(x, name=name), xdot # don't name the tangent value ad.primitive_jvps[name_p] = name_jvp mlir.register_lowering(name_p, lambda ctx, x, *, name: [x]) def name_batcher(args, dims, *, name): (x,), (d,) = args, dims return name_p.bind(x, name=name), d batching.primitive_batchers[name_p] = name_batcher
# NOTE(review): whitespace-collapsed fragment. It starts inside an enclosing
# function, presumably _for_impl (it is lowered below), and ends inside
# _for_jvp. The `#` inside the line comments out the remainder.
# - Loop tail: lax.while_loop runs counter i from 0 while i < nsteps. The
#   iteration index is nsteps - i - 1 when `reverse`. Each step evaluates
#   discharged_jaxpr on the state.
# - for_p is lowered via mlir.lower_fun; its impl is xla.apply_primitive.
# - _for_jvp: starts computing which tangents are nonzero with a fixed point
#   on the discharged jaxpr; it is cut off.
def cond(carry): i, _ = carry return i < nsteps def body(carry): i, state = carry i_ = nsteps - i - 1 if reverse else i next_state = core.eval_jaxpr(discharged_jaxpr, consts, i_, *state) return i + 1, next_state _, state = lax.while_loop(cond, body, (jnp.int32(0), list(args))) return state mlir.register_lowering(for_p, mlir.lower_fun(_for_impl, multiple_results=True)) for_p.def_impl(partial(xla.apply_primitive, for_p)) def _for_jvp(primals, tangents, *, jaxpr, nsteps, reverse, which_linear): nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents] # We need to find out which `Ref`s have nonzero tangents after running the # for loop. Ordinarily we do this with a fixed point on the body jaxpr but # a `for` body jaxpr is stateful and has no outputs. We therefore discharge # the state effect from the jaxpr and we will now have a "symmetric" jaxpr # where the inputs line up with the outputs. We use this discharged jaxpr # for the fixed point. discharged_jaxpr, body_consts = discharge_state(jaxpr, ()) for _ in range(len(nonzero_tangents)): _, out_nonzero_tangents = ad.jvp_jaxpr(core.ClosedJaxpr( discharged_jaxpr, body_consts), [False] + nonzero_tangents,
def _coo_todense_transpose(ct, data, row, col, *, spinfo):
  """Transpose rule for coo_todense; only `data` may be a linear input.

  The cotangent `ct` is gathered back onto the (row, col) positions with
  `_coo_extract`; the index operands are returned unchanged.

  Raises:
    ValueError: if `row` or `col` is an undefined (linear) primal.
  """
  # Note: we assume that transpose has the same sparsity pattern.
  # Can we check this?
  assert ad.is_undefined_primal(data)
  indices_are_linear = any(map(ad.is_undefined_primal, (row, col)))
  if indices_are_linear:
    raise ValueError("Cannot transpose with respect to sparse indices")
  # Internal consistency checks on the cotangent and index operands.
  assert ct.shape == spinfo.shape
  assert row.aval.dtype == col.aval.dtype
  assert ct.dtype == data.aval.dtype
  data_ct = _coo_extract(row, col, ct)
  return data_ct, row, col


ad.defjvp(coo_todense_p, _coo_todense_jvp, None, None)
ad.primitive_transposes[coo_todense_p] = _coo_todense_transpose
xla.register_translation(coo_todense_p, _coo_todense_translation_rule)
mlir.register_lowering(coo_todense_p, _coo_todense_lowering)

# GPU rules are registered only when `sparse_apis` reports support. The GPU
# MLIR lowering is additionally gated on jax._src.lib.version > (0, 3, 5).
if sparse_apis and sparse_apis.is_supported:
  xla.register_translation(
      coo_todense_p, _coo_todense_gpu_translation_rule, platform='gpu')
  if jax._src.lib.version > (0, 3, 5):
    mlir.register_lowering(
        coo_todense_p, _coo_todense_gpu_lowering, platform='gpu')

#--------------------------------------------------------------------
# coo_fromdense

coo_fromdense_p = core.Primitive('coo_fromdense')
coo_fromdense_p.multiple_results = True
# NOTE(review): whitespace-collapsed fragment. It starts inside remat_impl
# under an outer condition that is not visible, and ends inside
# _optimization_barrier_lowering_rule. The `#` inside the line comments out
# the remainder.
# - remat_impl tail: picks the opt-barrier, while (when is_gpu_platform) or
#   cond translation, or otherwise plain core.eval_jaxpr. It then calls the
#   choice through jax.named_call named wrap_name(name, "remat").
# - Both pe.remat_call_p and ad_checkpoint.remat_p get a default lowering and
#   a "gpu" lowering that passes is_gpu_platform=True.
# - _optimization_barrier_abstract_eval returns its argument avals unchanged.
translation_rule = _remat_translation_using_opt_barrier elif is_gpu_platform: translation_rule = _remat_translation_using_while else: translation_rule = _remat_translation_using_cond else: translation_rule = lambda *args, jaxpr: core.eval_jaxpr( jaxpr, (), *args) return jax.named_call(translation_rule, name=wrap_name(name, "remat"))(*args, jaxpr=jaxpr) for remat_primitive in (pe.remat_call_p, ad_checkpoint.remat_p): # type: ignore mlir.register_lowering(remat_primitive, mlir.lower_fun(remat_impl, multiple_results=True)) mlir.register_lowering(remat_primitive, mlir.lower_fun(partial(remat_impl, is_gpu_platform=True), multiple_results=True), platform="gpu") def _optimization_barrier_abstract_eval(*args): return args def _optimization_barrier_lowering_rule(ctx, *args): barrier_types = _map(mlir.aval_to_ir_types, ctx.avals_in) flat_barrier_types = util.flatten(barrier_types)
# NOTE(review): whitespace-collapsed fragment. It ends inside
# function_effect_lowering, and the `#` inside the line comments out the
# remainder.
# - Test-module setup: 'while1'/'while2' are added to lcf.allowed_effects.
#   TPU is always in disabled_backends; GPU is added when
#   jaxlib.version < (0, 3, 11).
# - trivial_effect_lowering passes tokens_in through as tokens_out and is
#   registered as effect_p's lowering.
# - function_effect_lowering emits the token-passing rule as a separate
#   function (mlir._emit_lowering_rule_as_fun) and starts building a call to
#   it; it is cut off.
lcf.allowed_effects.add('while1') lcf.allowed_effects.add('while2') # TODO(sharadmv): remove jaxlib guards for GPU tests when jaxlib minimum # version is >= 0.3.11 disabled_backends = ['tpu'] if jaxlib.version < (0, 3, 11): disabled_backends.append('gpu') def trivial_effect_lowering(ctx, *, effect): ctx.set_tokens_out(ctx.tokens_in) return [] mlir.register_lowering(effect_p, trivial_effect_lowering) def function_effect_lowering(ctx, *, effect): def _f(ctx): ctx.set_tokens_out(ctx.tokens_in) return [] func = mlir._emit_lowering_rule_as_fun(_f, ctx) output_types = map(mlir.aval_to_ir_types, ctx.avals_out) token_types = [mlir.token_type() for _ in ctx.tokens_in.items()] output_types = [*token_types, *output_types] flat_output_types = util.flatten(output_types) call = mlir.func_dialect.CallOp( flat_output_types, mlir.ir.FlatSymbolRefAttr.get(func.name.value),
# NOTE(review): whitespace-collapsed fragment. It starts inside
# _generic_reduce_window_lower and ends inside
# _reduce_window_sum_translation_rule.
# - Lowering tail: builds mhlo.ReduceWindowOp with the window attributes and
#   an int64 padding attribute. It appends a reducer block whose arguments
#   are scalar_types twice, and lowers `jaxpr` into that block with
#   mlir.jaxpr_subcomp.
# - _reduce_window_sum_shape_rule: raises TypeError for non-number operand
#   dtypes and otherwise defers to _common_reduce_window_shape_rule.
rw = mhlo.ReduceWindowOp( map(mlir.aval_to_ir_type, ctx.avals_out), operands, init_values, mlir.dense_int_elements(window_dimensions), mlir.dense_int_elements(window_strides), mlir.dense_int_elements(base_dilation), mlir.dense_int_elements(window_dilation), ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64))) reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types)) with ir.InsertionPoint(reducer): out_nodes = mlir.jaxpr_subcomp(ctx.module_context, jaxpr, consts, *([a] for a in reducer.arguments)) mhlo.ReturnOp(util.flatten(out_nodes)) return rw.results mlir.register_lowering(reduce_window_p, _generic_reduce_window_lower) def _reduce_window_sum_shape_rule(operand, *, window_dimensions, window_strides, padding, base_dilation, window_dilation): if not dtypes.issubdtype(operand.dtype, np.number): msg = "operand to reduce_window_sum must have a number dtype, got {}" raise TypeError(msg.format(np.dtype(operand.dtype).name)) return _common_reduce_window_shape_rule(operand, window_dimensions, window_strides, padding, base_dilation, window_dilation) def _reduce_window_sum_translation_rule(ctx, avals_in, avals_out, operand, *, window_dimensions, window_strides,
def tearDown(self):
  """Clear runtime tokens and restore the x64 flag and effect_p lowering.

  The restored values are the ones held in `self.old_x64` and
  `self._old_lowering`.
  """
  super().tearDown()
  dispatch.runtime_tokens.clear()
  config.update('jax_enable_x64', self.old_x64)
  previous_lowering = self._old_lowering
  mlir.register_lowering(effect_p, previous_lowering)
@sp_indices_p.def_abstract_eval
def _sp_indices_abstract_eval(mat):
  """Abstract value of sp_indices: the aval of the object's indices."""
  return mat.indices_aval


def _sp_indices_translation_rule(ctx, avals_in, avals_out, data, indices):
  """XLA translation for sp_indices: forward the `indices` operand."""
  del ctx, avals_in, avals_out, data  # Unused.
  return [indices]

# Note: cannot use lower_fun to define attribute access primitives
# because it leads to infinite recursion.
xla.register_translation(sp_indices_p, _sp_indices_translation_rule)


def _sp_indices_mhlo_lowering(ctx, data_and_indices):
  """MLIR lowering for sp_indices: select the second lowered value."""
  del ctx  # Unused.
  return [data_and_indices[1]]

mlir.register_lowering(sp_indices_p, _sp_indices_mhlo_lowering)

sp_data_p = core.Primitive('sp_data')


@sp_data_p.def_impl
def _sp_data_impl(mat):
  """Concrete impl of sp_data: return `mat.data`."""
  return mat.data


@sp_data_p.def_abstract_eval
def _sp_data_abstract_eval(mat):
  """Abstract value of sp_data: the aval of the object's data."""
  return mat.data_aval


def _sp_data_translation_rule(ctx, avals_in, avals_out, data, indices):
  """XLA translation for sp_data: forward the `data` operand."""
  del ctx, avals_in, avals_out, indices  # Unused.
  return [data]

# Note: cannot use lower_fun to define attribute access primitives
# NOTE(review): whitespace-collapsed fragment; it ends inside the
# `sharded_jit` docstring.
# - _sharded_call_impl: builds a compiled function with _sharded_callable,
#   passing all partitioning params plus the abstractified args, then calls
#   it on `args`.
# - sharded_call_p: a CallPrimitive. `sharded_call` is an alias for its bind.
#   It has an impl and an MLIR lowering (_sharded_jit_lowering).
# - sharded_jit(...): definition starts here and is cut off.
def _sharded_call_impl(fun, *args, nparts, in_parts, out_parts_thunk, local_in_parts, local_out_parts_thunk, local_nparts, name): compiled_fun = _sharded_callable(fun, nparts, in_parts, out_parts_thunk, local_in_parts, local_out_parts_thunk, local_nparts, name, *map(xla.abstractify, args)) return compiled_fun(*args) sharded_call_p = core.CallPrimitive("sharded_call") sharded_call = sharded_call_p.bind sharded_call_p.def_impl(_sharded_call_impl) mlir.register_lowering(sharded_call_p, _sharded_jit_lowering) def sharded_jit( fun: Callable, in_parts, out_parts, num_partitions: Optional[int] = None, local_in_parts=None, local_out_parts=None, local_num_partitions=None, static_argnums: Union[int, Iterable[int]] = (), ): """Like ``jit``, but partitions ``fun`` across multiple devices. WARNING: this feature is still under active development! It may not work well,
# NOTE(review): whitespace-collapsed fragment. It starts inside
# _conv_general_dilated_lower and ends inside _reshape_axis_into. The `#`
# inside the line comments out the remainder.
# - Lowering tail: emits a single mhlo.ConvOp with window_reversal all False
#   (one entry per spatial dim, i.e. len(rhs_spec) - 2).
# - A default lowering is registered. "cpu" and "gpu" use
#   expand_complex_convolutions=True, which per the TODO lowers complex
#   convolutions away on those backends.
output_spatial_dimensions=list(out_spec[2:])) num_spatial_dims = len(rhs_spec) - 2 window_reversal = mlir.dense_bool_elements([False] * num_spatial_dims) return [ mhlo.ConvOp(mlir.aval_to_ir_type(aval_out), lhs, rhs, mlir.dense_int_elements(window_strides), mlir.dense_int_elements(padding), mlir.dense_int_elements(lhs_dilation), mlir.dense_int_elements(rhs_dilation), window_reversal, dnums, mlir.i64_attr(feature_group_count), mlir.i64_attr(batch_group_count), lax.precision_attr(precision)).result ] mlir.register_lowering(conv_general_dilated_p, _conv_general_dilated_lower) # TODO(b/161124619, b/161126248): XLA does not support complex convolution on # GPU, and on CPU it uses a slow loop-based implementation; # on these backends, lower complex convolutions away. mlir.register_lowering(conv_general_dilated_p, partial(_conv_general_dilated_lower, expand_complex_convolutions=True), platform='cpu') mlir.register_lowering(conv_general_dilated_p, partial(_conv_general_dilated_lower, expand_complex_convolutions=True), platform='gpu') def _reshape_axis_into(src, dst, x): perm = [i for i in range(x.ndim) if i != src]