def _ordered_effect_lowering(ctx, token, *args, **params):
  """Emit a host callback for a debug callback whose effect is ordered.

  A token value is prepended to the operands and to the results, so that the
  emitted callback consumes and produces a token alongside its real outputs.

  Args:
    ctx: lowering rule context; its ``avals_in``/``avals_out`` and
      ``module_context.platform`` are read.
    token: the incoming token IR value for this effect.
    *args: the callback's IR operands (without the token).
    **params: forwarded verbatim to ``debug_callback_p.impl``.

  Returns:
    A ``(result, keepalive, token)`` triple: the non-token results, the
    keepalive object returned by ``mlir.emit_python_callback`` (the caller is
    expected to register it), and the outgoing token IR value.
  """
  def _threaded(tok, *operands):
    # Run the eager implementation on the real operands and pass the token
    # through unchanged as the first output.
    results = debug_callback_p.impl(*operands, **params)
    return (tok, *results)

  token_aval = core.abstract_token
  outputs, keepalive = mlir.emit_python_callback(
      ctx.module_context.platform, _threaded, [token, *args],
      [token_aval, *ctx.avals_in], [token_aval, *ctx.avals_out], True)
  # The first output corresponds to the token aval prepended above.
  new_token, *result = outputs
  return result, keepalive, new_token
def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out_avals, effect):
  """Lower an effectful Python callback to an MLIR host-callback op.

  Args:
    ctx: lowering rule context. Its platform, input/output avals, incoming
      token set and keepalive registry are used.
    *args: IR operands passed to the callback.
    callback: Python callable invoked by the emitted host callback.
    out_avals: unused; the output avals are taken from ``ctx.avals_out``.
    effect: the effect of this op. If it is in ``core.ordered_effects``, the
      effect's token is threaded through the callback (in as the first operand,
      out as the first result) and the updated token set is published on
      ``ctx``.

  Returns:
    The callback's IR results, excluding any token.
  """
  del out_avals  # Output avals are read from ctx.avals_out instead.
  if effect in core.ordered_effects:
    def _token_callback(token, *args):
      out = callback(*args)
      # Flatten the callback's (possibly pytree) output into a flat result list.
      flat_out = jax.tree_util.tree_leaves(out)
      return (token, *flat_out)
    # Tokens in a TokenSet are sequences of IR values, hence the ``[0]``.
    token_in = ctx.tokens_in.get(effect)[0]
    (token_out, *out_op), keep_alive = mlir.emit_python_callback(
        ctx.module_context.platform, _token_callback,
        [token_in, *args], [core.abstract_token, *ctx.avals_in],
        [core.abstract_token, *ctx.avals_out], True)
    # Store the output token in the same sequence form that was read above
    # (and that debug-callback lowering uses), so downstream consumers that
    # index ``[0]`` or flatten the token set keep working.
    ctx.set_tokens_out(
        ctx.tokens_in.update_tokens(mlir.TokenSet({effect: (token_out,)})))
  else:
    out_op, keep_alive = mlir.emit_python_callback(
        ctx.module_context.platform, callback, list(args), list(ctx.avals_in),
        list(ctx.avals_out), True)
  # Register the keepalive object returned by emit_python_callback with the
  # module context.
  ctx.module_context.add_keepalive(keep_alive)
  return out_op
def debug_callback_lowering(ctx, *args, effect, callback, **params):
  """Lower a debug callback to an MLIR host callback.

  Unordered effects emit a plain host callback that runs
  ``debug_callback_p.impl``. Ordered effects delegate to
  ``_ordered_effect_lowering`` so the effect's token is threaded through the
  call; the resulting token is then published via ``ctx.set_tokens_out``.

  Args:
    ctx: lowering rule context.
    *args: IR operands for the callback.
    effect: the callback's effect; membership in ``core.ordered_effects``
      selects the token-threading path.
    callback: the user callback, forwarded to ``debug_callback_p.impl``.
    **params: remaining primitive params, also forwarded to the impl.

  Returns:
    The callback's IR results (without any token).
  """
  if effect not in core.ordered_effects:
    def _unordered_callback(*operands):
      return tuple(debug_callback_p.impl(
          *operands, effect=effect, callback=callback, **params))
    result, keepalive = mlir.emit_python_callback(
        ctx.module_context.platform, _unordered_callback, list(args),
        ctx.avals_in, ctx.avals_out, True)
  else:
    # Tokens are stored as sequences of IR values; take the single value.
    token_in = ctx.tokens_in.get(effect)[0]
    result, keepalive, token_out = _ordered_effect_lowering(
        ctx, token_in, *args, effect=effect, callback=callback, **params)
    ctx.set_tokens_out(mlir.TokenSet({effect: (token_out,)}))
  ctx.module_context.add_keepalive(keepalive)
  return result