예제 #1
0
def _ordered_effect_lowering(ctx, token, *args, **params):
  """Emit a debug callback whose host call also carries `token`.

  The token is prepended to the operands and to both aval lists, so the
  emitted Python callback receives it as its first argument and hands it back
  as its first result.

  Args:
    ctx: lowering rule context; supplies the platform and the in/out avals.
    token: token value to pass through the host call.
    *args: lowered operands forwarded to `debug_callback_p.impl`.
    **params: keyword parameters forwarded to `debug_callback_p.impl`.

  Returns:
    A `(result, keepalive, token)` triple: the non-token results of the
    emitted callback, the keepalive object returned by
    `mlir.emit_python_callback`, and the token result of the callback.
  """
  def _run_with_token(tok, *operands):
    # The token is returned untouched ahead of the real outputs.
    outputs = debug_callback_p.impl(*operands, **params)
    return (tok, *outputs)

  token_aval = core.abstract_token
  # NOTE(review): the trailing positional `True` is presumably the
  # "has side effect" flag of emit_python_callback -- confirm.
  emitted, keepalive = mlir.emit_python_callback(
      ctx.module_context.platform,
      _run_with_token,
      [token, *args],
      [token_aval, *ctx.avals_in],
      [token_aval, *ctx.avals_out],
      True)
  # First emitted result is the token; everything after it is the payload.
  token, *result = emitted
  return result, keepalive, token
예제 #2
0
def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback,
                             out_avals, effect):
    """Lower an effectful callback primitive to a Python host callback.

    When `effect` is in `core.ordered_effects`, the effect's input token is
    passed through the host call as an extra leading operand/result, and the
    token that comes back is recorded on `ctx` via `set_tokens_out`.
    Otherwise `callback` is emitted directly with the operands as given.

    Args:
      ctx: lowering rule context; supplies the platform, the in/out avals and
        the input tokens.
      *args: lowered operands forwarded to `callback`.
      callback: Python callable invoked by the emitted host call.
      out_avals: accepted but unused; output avals are read from
        `ctx.avals_out` instead.
      effect: the effect attached to this callback.

    Returns:
      The results produced by `mlir.emit_python_callback`; in the ordered
      case the leading token result is stripped off first.
    """
    del out_avals
    if effect in core.ordered_effects:

        def _token_callback(token, *args):
            # Pass the token straight through, followed by the flattened
            # outputs of the user callback.
            out = callback(*args)
            flat_out = jax.tree_util.tree_leaves(out)
            return (token, *flat_out)

        # Token sets are read as sequences: take the first entry for `effect`.
        token_in = ctx.tokens_in.get(effect)[0]
        # NOTE(review): the trailing positional `True` is presumably the
        # "has side effect" flag of emit_python_callback -- confirm.
        (token_out, *out_op), keep_alive = mlir.emit_python_callback(
            ctx.module_context.platform, _token_callback, [token_in, *args],
            [core.abstract_token, *ctx.avals_in],
            [core.abstract_token, *ctx.avals_out], True)
        # Store the new token as a one-element sequence so that it has the
        # same shape as what `ctx.tokens_in.get(effect)[0]` reads above.
        ctx.set_tokens_out(
            ctx.tokens_in.update_tokens(mlir.TokenSet({effect: (token_out,)})))
    else:
        out_op, keep_alive = mlir.emit_python_callback(
            ctx.module_context.platform, callback, list(args),
            list(ctx.avals_in), list(ctx.avals_out), True)
    # Hand the keepalive object to the module context.
    ctx.module_context.add_keepalive(keep_alive)
    return out_op
예제 #3
0
def debug_callback_lowering(ctx, *args, effect, callback, **params):
  """Lowering rule for the debug callback primitive.

  If `effect` is not in `core.ordered_effects`, the callback is emitted
  directly. Otherwise the effect's token is passed through
  `_ordered_effect_lowering` and the resulting token is stored on `ctx`.

  Args:
    ctx: lowering rule context; supplies the platform, avals and input tokens.
    *args: lowered operands forwarded to the callback implementation.
    effect: the effect attached to this debug callback.
    callback: forwarded to `debug_callback_p.impl` as a keyword parameter.
    **params: remaining keyword parameters forwarded to the implementation.

  Returns:
    The results of the emitted host callback (without the token in the
    ordered case).
  """
  if effect not in core.ordered_effects:
    def _host_fn(*operands):
      outputs = debug_callback_p.impl(
          *operands, effect=effect, callback=callback, **params)
      return tuple(outputs)

    # NOTE(review): the trailing positional `True` is presumably the
    # "has side effect" flag of emit_python_callback -- confirm.
    result, keepalive = mlir.emit_python_callback(
        ctx.module_context.platform, _host_fn, list(args),
        ctx.avals_in, ctx.avals_out, True)
  else:
    # Token sets are read as sequences: take the first entry for `effect`.
    token_in = ctx.tokens_in.get(effect)[0]
    result, keepalive, token_out = _ordered_effect_lowering(
        ctx, token_in, *args, effect=effect, callback=callback, **params)
    ctx.set_tokens_out(mlir.TokenSet({effect: (token_out,)}))
  # Hand the keepalive object to the module context.
  ctx.module_context.add_keepalive(keepalive)
  return result