Exemplo n.º 1
0
def _select_and_scatter_lower(ctx, operand, source, init_value, *,
                              select_jaxpr, select_consts, scatter_jaxpr,
                              scatter_consts, window_dimensions,
                              window_strides, padding):
    """Lowers `select_and_scatter` to an `mhlo.SelectAndScatterOp`.

    Builds the op with two scalar regions: `select` (chooses the winning
    element within each window) and `scatter` (combines the corresponding
    source value into the output). Both regions are lowered from jaxprs by
    the shared helper below.

    Returns the op's MLIR results.
    """
    # Only the operand aval is needed here; source/init avals are unused.
    operand_aval, _, _ = ctx.avals_in
    aval_out, = ctx.avals_out
    # Both regions operate on two scalars of the operand's element type.
    scalar_aval = operand_aval.update(shape=())
    scalar_type = mlir.aval_to_ir_type(scalar_aval)
    op = mhlo.SelectAndScatterOp(
        mlir.aval_to_ir_type(aval_out), operand, source, init_value,
        mlir.dense_int_elements(window_dimensions),
        mlir.dense_int_elements(window_strides),
        ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
    # The two regions are lowered identically, just from different jaxprs.
    _lower_scalar_pair_region(ctx, op.select, scalar_type,
                              select_jaxpr, select_consts, 'select')
    _lower_scalar_pair_region(ctx, op.scatter, scalar_type,
                              scatter_jaxpr, scatter_consts, 'scatter')
    return op.results


def _lower_scalar_pair_region(ctx, region, scalar_type, jaxpr, consts, name):
    """Appends a (scalar, scalar) block to `region` and lowers `jaxpr` into it.

    `name` is used only in the error message for effectful jaxprs.
    """
    block = region.blocks.append(scalar_type, scalar_type)
    with ir.InsertionPoint(block):
        if jaxpr.effects:
            raise NotImplementedError(f'Cannot lower effectful `{name}`.')
        # jaxpr_subcomp expects one sequence of IR values per invar.
        out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr,
                                          mlir.TokenSet(), consts,
                                          *([a] for a in block.arguments))
        mhlo.ReturnOp(util.flatten(out_nodes))
Exemplo n.º 2
0
def _generic_reduce_window_lower(ctx, *args, jaxpr, consts, window_dimensions,
                                 window_strides, padding, base_dilation,
                                 window_dilation):
    """Lowers a variadic `reduce_window` to an `mhlo.ReduceWindowOp`.

    `args` is the flat sequence of operands followed by their init values
    (same count of each). The reducer region is lowered from `jaxpr`.

    Returns the op's MLIR results.
    """
    operands, init_values = util.split_list(args, [len(args) // 2])
    _, init_value_avals = util.split_list(ctx.avals_in, [len(operands)])
    # The reducer block takes one scalar per init value, twice over
    # (accumulators followed by the new window values).
    scalar_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
    # Materialize result types eagerly: a lazy builtin `map` iterator is
    # one-shot and could be consumed before the op constructor uses it.
    result_types = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out]
    rw = mhlo.ReduceWindowOp(
        result_types,
        operands,
        init_values,
        mlir.dense_int_elements(window_dimensions),
        window_strides=mlir.dense_int_elements(window_strides),
        base_dilations=mlir.dense_int_elements(base_dilation),
        window_dilations=mlir.dense_int_elements(window_dilation),
        # `padding` is flattened (low, high) pairs; reshape to (ndims, 2).
        padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
                                            shape=(len(padding), 2)))
    reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
    with ir.InsertionPoint(reducer):
        if jaxpr.effects:
            raise NotImplementedError(
                'Cannot lower effectful `reduce_window`.')
        # jaxpr_subcomp expects one sequence of IR values per invar.
        out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr,
                                          mlir.TokenSet(), consts,
                                          *([a] for a in reducer.arguments))
        mhlo.ReturnOp(util.flatten(out_nodes))
    return rw.results
Exemplo n.º 3
0
def debug_callback_lowering(ctx, *args, effect, callback, **params):
  """Lowering rule for `debug_callback_p`.

  For an ordered effect, a runtime token is threaded through the callback
  so executions preserve their relative order; otherwise an unordered host
  callback is emitted. The keepalive returned by the emitter is registered
  on the module context so the Python callable outlives compilation.
  """
  if effect not in core.ordered_effects:
    def _callback(*flat_args):
      outs = debug_callback_p.impl(
          *flat_args, effect=effect, callback=callback, **params)
      return tuple(outs)
    result, keepalive = mlir.emit_python_callback(
        ctx.module_context.platform, _callback, list(args), ctx.avals_in,
        ctx.avals_out, True)
  else:
    token = ctx.tokens_in.get(effect)[0]
    result, keepalive, token = _ordered_effect_lowering(
        ctx, token, *args, effect=effect, callback=callback, **params)
    ctx.set_tokens_out(mlir.TokenSet({effect: (token,)}))
  ctx.module_context.add_keepalive(keepalive)
  return result
Exemplo n.º 4
0
def function_effect_lowering(ctx, *, effect):
  """Lowers the effect by emitting and calling a private MLIR function.

  The emitted function merely forwards its input tokens to its outputs;
  this rule wraps it in a `func.call` and rewires the returned tokens back
  into the lowering context.
  """
  def _body(inner_ctx):
    inner_ctx.set_tokens_out(inner_ctx.tokens_in)
    return []
  func = mlir._emit_lowering_rule_as_fun(_body, ctx)

  # Call results are the tokens first, then the regular outputs.
  token_types = [mlir.token_type() for _ in ctx.tokens_in.items()]
  result_types = [*token_types, *map(mlir.aval_to_ir_types, ctx.avals_out)]
  flat_result_types = util.flatten(result_types)
  callee = mlir.ir.FlatSymbolRefAttr.get(func.name.value)
  call_operands = mlir.flatten_lowering_ir_args(ctx.tokens_in.tokens())
  call = mlir.func_dialect.CallOp(flat_result_types, callee, call_operands)
  num_tokens = len(ctx.tokens_in)
  tokens, outputs = util.split_list(call.results, [num_tokens])
  ctx.set_tokens_out(mlir.TokenSet(zip(ctx.tokens_in.effects(), tokens)))
  return outputs
def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback,
                             out_avals, effect):
    """Lowers an effectful callback by threading an ordered token through it."""
    del out_avals  # Unused; output avals come from ctx.

    def _token_callback(token, *inner_args):
        # Pass the token through unchanged, followed by the flat outputs.
        results = jax.tree_util.tree_leaves(callback(*inner_args))
        return (token, *results)

    token_in = ctx.tokens_in.get(effect)[0]
    avals_in = [core.abstract_token, *ctx.avals_in]
    avals_out = [core.abstract_token, *ctx.avals_out]
    outs, keep_alive = mlir.emit_python_callback(
        ctx.module_context.platform, _token_callback, [token_in, *args],
        avals_in, avals_out, True)
    token_out, *out_op = outs
    ctx.module_context.add_keepalive(keep_alive)
    ctx.set_tokens_out(
        ctx.tokens_in.update_tokens(mlir.TokenSet({effect: token_out})))
    return out_op
Exemplo n.º 6
0
def _cond_lowering(ctx, index, *args, branches, linear):
    """Lowers a multi-branch `cond` to an `mhlo.CaseOp`.

    Each branch jaxpr is lowered into its own region. Ordered effects are
    threaded through as tokens: each branch returns its output tokens first,
    then its regular outputs, and this rule splits them back apart.
    """
    del linear  # Unused.
    # Union of effects across all branches; only the ordered ones need tokens.
    joined_effects = core.join_effects(*(branch.effects
                                         for branch in branches))
    ordered_effects = [
        eff for eff in joined_effects if eff in core.ordered_effects
    ]
    num_tokens = len(ordered_effects)
    tokens_in = ctx.tokens_in.subset(ordered_effects)
    output_token_types = [mlir.token_type() for _ in ordered_effects]
    # Result layout: tokens first, then the regular outputs.
    output_types = [
        *output_token_types, *_map(mlir.aval_to_ir_types, ctx.avals_out)
    ]
    flat_output_types = util.flatten(output_types)

    # mhlo.CaseOp takes a single argument 'index' and the corresponding blocks
    # have no arguments; the computation within the block uses implicit
    # captures.
    case_op = mhlo.CaseOp(flat_output_types,
                          index=index,
                          num_branches=len(branches))
    name_stack = extend_name_stack(ctx.module_context.name_stack, 'cond')
    for i, jaxpr in enumerate(branches):
        branch = case_op.regions[i].blocks.append()
        with ir.InsertionPoint(branch):
            # Give each branch its own name-stack entry for debuggability.
            sub_ctx = ctx.module_context.replace(
                name_stack=xla.extend_name_stack(name_stack,
                                                 f'branch_{i}_fun'))
            out_vals, tokens_out = mlir.jaxpr_subcomp(
                sub_ctx, jaxpr.jaxpr, tokens_in,
                _map(mlir.ir_constants, jaxpr.consts),
                *_map(mlir.wrap_singleton_ir_values, args))
            # Every branch returns tokens for all ordered effects (in the
            # same order) so the CaseOp's result types line up.
            out_tokens = [tokens_out.get(eff) for eff in ordered_effects]
            out_vals = [*out_tokens, *out_vals]
            mhlo.ReturnOp(util.flatten(out_vals))

    # Split the flat results back into per-type groups, then peel off the
    # leading tokens.
    tokens_and_outputs = util.unflatten(case_op.results,
                                        _map(len, output_types))
    tokens, outputs = util.split_list(tokens_and_outputs, [num_tokens])
    ctx.set_tokens_out(mlir.TokenSet(zip(ordered_effects, tokens)))
    return outputs
Exemplo n.º 7
0
 def bad_effect_lowering(ctx, *, effect):
     """Emits an output token under a key ('bar') that does not match the
     input token key ('foo'), ignoring the `effect` argument entirely.

     NOTE(review): the name and the mismatched keys suggest this is a
     deliberately-incorrect fixture meant to trigger token validation
     elsewhere — confirm against the registering test before relying on it.
     """
     ctx.set_tokens_out(mlir.TokenSet(bar=ctx.tokens_in.get('foo')))
     return []