Exemplo n.º 1
0
def _select_and_scatter_lower(ctx, operand, source, init_value, *,
                              select_jaxpr, select_consts, scatter_jaxpr,
                              scatter_consts, window_dimensions,
                              window_strides, padding):
    """Lower ``select_and_scatter`` to an ``mhlo.SelectAndScatterOp``.

    The op's ``select`` and ``scatter`` regions each get a block with two
    scalar arguments of the operand's element type. Their bodies are built by
    lowering ``select_jaxpr`` / ``scatter_jaxpr`` (with their constants) over
    those block arguments.

    Args:
      ctx: lowering-rule context; ``ctx.avals_in`` must hold exactly three
        abstract values (operand, source, init value) and ``ctx.avals_out``
        exactly one.
      operand, source, init_value: IR values of the three inputs.
      select_jaxpr, select_consts: jaxpr and constants for the select region.
      scatter_jaxpr, scatter_consts: jaxpr and constants for the scatter
        region.
      window_dimensions, window_strides: window sizes and strides, encoded
        with ``mlir.dense_int_elements``.
      padding: assumed to be one ``(low, high)`` pair per dimension; it is
        encoded as a ``(len(padding), 2)`` int64 attribute.

    Returns:
      The results of the constructed ``SelectAndScatterOp``.
    """
    operand_aval, _, _ = ctx.avals_in
    aval_out, = ctx.avals_out
    # Both regions operate on scalars with the operand's element type.
    scalar_type = mlir.aval_to_ir_type(operand_aval.update(shape=()))
    op = mhlo.SelectAndScatterOp(
        mlir.aval_to_ir_type(aval_out), operand, source, init_value,
        mlir.dense_int_elements(window_dimensions),
        mlir.dense_int_elements(window_strides),
        # Pass the shape explicitly: for an empty ``padding`` (rank-0
        # operand), ``np.asarray`` yields shape (0,), which would otherwise
        # give a rank-1 attribute instead of the (N, 2) layout used for
        # non-empty padding.
        ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
                                    shape=(len(padding), 2)))
    select = op.select.blocks.append(scalar_type, scalar_type)
    with ir.InsertionPoint(select):
        out_nodes = mlir.jaxpr_subcomp(ctx.module_context, select_jaxpr,
                                       select_consts,
                                       *([a] for a in select.arguments))
        mhlo.ReturnOp(util.flatten(out_nodes))
    scatter = op.scatter.blocks.append(scalar_type, scalar_type)
    with ir.InsertionPoint(scatter):
        out_nodes = mlir.jaxpr_subcomp(ctx.module_context, scatter_jaxpr,
                                       scatter_consts,
                                       *([a] for a in scatter.arguments))
        mhlo.ReturnOp(util.flatten(out_nodes))
    return op.results
Exemplo n.º 2
0
def _generic_reduce_window_lower(ctx, *args, jaxpr, consts, window_dimensions,
                                 window_strides, padding, base_dilation,
                                 window_dilation):
    operands, init_values = util.split_list(args, [len(args) // 2])
    _, init_value_avals = util.split_list(ctx.avals_in, [len(operands)])
    scalar_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
    rw = mhlo.ReduceWindowOp(
        map(mlir.aval_to_ir_type, ctx.avals_out),
        operands,
        init_values,
        mlir.dense_int_elements(window_dimensions),
        window_strides=mlir.dense_int_elements(window_strides),
        base_dilations=mlir.dense_int_elements(base_dilation),
        window_dilations=mlir.dense_int_elements(window_dilation),
        padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
                                            shape=(len(padding), 2)))
    reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
    with ir.InsertionPoint(reducer):
        if jaxpr.effects:
            raise NotImplementedError(
                'Cannot lower effectful `reduce_window`.')
        out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr,
                                          mlir.TokenSet(), consts,
                                          *([a] for a in reducer.arguments))
        mhlo.ReturnOp(util.flatten(out_nodes))
    return rw.results
Exemplo n.º 3
0
def _generic_reduce_window_lower(ctx, avals_in, avals_out, *args, jaxpr,
                                 consts, window_dimensions, window_strides,
                                 padding, base_dilation, window_dilation):
    """Lower a ``reduce_window`` with a jaxpr-defined reducer to MHLO.

    ``args`` holds the operands followed by the same number of init values
    (it is split in half). The reducer region's block takes the init values'
    scalar types repeated twice, and its body is produced by lowering
    ``jaxpr`` with ``consts``.

    Args:
      ctx: context passed straight to ``mlir.jaxpr_subcomp``.
      avals_in: abstract values of the operands followed by the init values.
      avals_out: abstract values of the results.
      *args: operand IR values followed by init-value IR values.
      jaxpr, consts: the reducer computation and its constants.
      window_dimensions, window_strides, base_dilation, window_dilation:
        window attributes, encoded with ``mlir.dense_int_elements``.
      padding: assumed to be one ``(low, high)`` pair per dimension; encoded
        as a ``(len(padding), 2)`` int64 attribute.

    Returns:
      The results of the constructed ``ReduceWindowOp``.
    """
    operands, init_values = util.split_list(args, [len(args) // 2])
    _, init_value_avals = util.split_list(avals_in, [len(operands)])
    scalar_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
    rw = mhlo.ReduceWindowOp(
        # A concrete list works whether ``map`` is the builtin or safe_map.
        [mlir.aval_to_ir_type(aval) for aval in avals_out],
        operands, init_values,
        mlir.dense_int_elements(window_dimensions),
        mlir.dense_int_elements(window_strides),
        mlir.dense_int_elements(base_dilation),
        mlir.dense_int_elements(window_dilation),
        # Explicit shape: an empty padding (rank-0 operand) would otherwise
        # be encoded with shape (0,) rather than (0, 2).
        ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
                                    shape=(len(padding), 2)))
    reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
    with ir.InsertionPoint(reducer):
        out_nodes = mlir.jaxpr_subcomp(ctx, jaxpr, consts,
                                       *([a] for a in reducer.arguments))
        mhlo.ReturnOp(util.flatten(out_nodes))
    return rw.results
Exemplo n.º 4
0
def _cond_lowering(ctx, index, *args, branches, linear):
    """Lower ``cond`` to an ``mhlo.CaseOp`` with one region per branch.

    ``index`` selects the branch. Each region's block has no arguments; the
    branch jaxpr is lowered inside it and reads ``args`` by implicit capture.
    Effects of all branches that appear in ``core.ordered_effects`` are
    threaded as tokens: the incoming tokens for those effects are handed to
    every branch, every branch returns its output tokens ahead of its
    ordinary outputs, and the op's leading results are recorded on ``ctx``
    as the outgoing tokens.

    Returns:
      The non-token results of the case op, grouped per output aval.
    """
    del linear  # Unused.
    all_effects = core.join_effects(*(br.effects for br in branches))
    ordered = [e for e in all_effects if e in core.ordered_effects]
    tokens_in = ctx.tokens_in.subset(ordered)

    # Result layout: one token group per ordered effect, then one group of
    # IR types per output aval.
    result_types = [mlir.token_type() for _ in ordered]
    result_types.extend(_map(mlir.aval_to_ir_types, ctx.avals_out))

    # The case op takes only 'index'; branch blocks capture values implicitly.
    case_op = mhlo.CaseOp(util.flatten(result_types),
                          index=index,
                          num_branches=len(branches))
    cond_stack = extend_name_stack(ctx.module_context.name_stack, 'cond')
    for i, closed in enumerate(branches):
        block = case_op.regions[i].blocks.append()
        with ir.InsertionPoint(block):
            branch_stack = xla.extend_name_stack(cond_stack,
                                                 f'branch_{i}_fun')
            branch_ctx = ctx.module_context.replace(name_stack=branch_stack)
            results, branch_tokens = mlir.jaxpr_subcomp(
                branch_ctx, closed.jaxpr, tokens_in,
                _map(mlir.ir_constants, closed.consts),
                *_map(mlir.wrap_singleton_ir_values, args))
            # Tokens go first so the layout matches result_types.
            returned = [branch_tokens.get(e) for e in ordered]
            returned.extend(results)
            mhlo.ReturnOp(util.flatten(returned))

    grouped = util.unflatten(case_op.results, _map(len, result_types))
    token_vals, outputs = util.split_list(grouped, [len(ordered)])
    ctx.set_tokens_out(mlir.TokenSet(zip(ordered, token_vals)))
    return outputs