Example #1
def _select_and_scatter_lower(ctx, operand, source, init_value, *,
                              select_jaxpr, select_consts, scatter_jaxpr,
                              scatter_consts, window_dimensions,
                              window_strides, padding):
    operand_aval, source_aval, init_value_aval = ctx.avals_in
    aval_out, = ctx.avals_out
    scalar_aval = operand_aval.update(shape=())
    scalar_type = mlir.aval_to_ir_type(scalar_aval)
    op = mhlo.SelectAndScatterOp(
        mlir.aval_to_ir_type(aval_out), operand, source, init_value,
        mlir.dense_int_elements(window_dimensions),
        mlir.dense_int_elements(window_strides),
        ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
    select = op.select.blocks.append(scalar_type, scalar_type)
    with ir.InsertionPoint(select):
        out_nodes = mlir.jaxpr_subcomp(ctx.module_context, select_jaxpr,
                                       select_consts,
                                       *([a] for a in select.arguments))
        mhlo.ReturnOp(util.flatten(out_nodes))
    scatter = op.scatter.blocks.append(scalar_type, scalar_type)
    with ir.InsertionPoint(scatter):
        out_nodes = mlir.jaxpr_subcomp(ctx.module_context, scatter_jaxpr,
                                       scatter_consts,
                                       *([a] for a in scatter.arguments))
        mhlo.ReturnOp(util.flatten(out_nodes))
    return op.results
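
A lowering rule of this shape is registered against its primitive with mlir.register_lowering. The sketch below is hypothetical: the primitive name select_and_scatter_p and its home module are assumptions inferred from the rule, not something this snippet establishes.

from jax.interpreters import mlir
from jax._src.lax import windowed_reductions as wr  # assumed location of select_and_scatter_p

# Hypothetical registration; in the JAX tree this call sits next to the primitive definition.
mlir.register_lowering(wr.select_and_scatter_p, _select_and_scatter_lower)
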
Example #2
File: mlir.py Project: rsepassi/jax
def _remat_using_while(ctx, avals_in, avals_out, *args, name, call_jaxpr):
    input_types = map(aval_to_ir_types, avals_in)
    output_types = map(aval_to_ir_types, avals_out)
    flat_output_types = util.flatten(output_types)
    int32_scalar_type = aval_to_ir_type(
        core.ShapedArray((), np.dtype(np.int32)))
    loop_carry_types = [(int32_scalar_type, )] + input_types + output_types
    flat_loop_carry_types = util.flatten(loop_carry_types)
    counter_init = ir_constants(np.array(0, np.int32))
    flat_args = flatten_lowering_ir_args((counter_init, ) + args + tuple(
        _dummy_like_aval(aval) for aval in avals_out))
    loop_carry_tuple_type = ir.TupleType.get_tuple(flat_loop_carry_types)
    init_carry = mhlo.TupleOp(loop_carry_tuple_type, flat_args)

    one = ir_constant(np.array(1, np.int32))
    while_op = mhlo.WhileOp([loop_carry_tuple_type], [init_carry.result])

    # Loop condition
    cond_block = while_op.regions[0].blocks.append(loop_carry_tuple_type)
    with ir.InsertionPoint(cond_block):
        bool_scalar_type = aval_to_ir_type(
            core.ShapedArray((), np.dtype(np.bool_)))
        two = ir_constant(np.array(2, np.int32))
        shape = ir_constant(np.array((), np.int64), canonicalize_types=False)
        rng = mhlo.RngUniformOp(one, two, shape).result
        i = mhlo.GetTupleElementOp(int32_scalar_type, cond_block.arguments[0],
                                   i32_attr(0))
        cmp = mhlo.CompareOp(bool_scalar_type, i, rng, ir.StringAttr.get("LT"),
                             ir.StringAttr.get("SIGNED")).result
        mhlo.ReturnOp([cmp])

    body_block = while_op.regions[1].blocks.append(loop_carry_tuple_type)
    with ir.InsertionPoint(body_block):
        flat_body_args = [
            mhlo.GetTupleElementOp(input_type, body_block.arguments[0],
                                   i32_attr(i)).result
            for i, input_type in enumerate(flat_loop_carry_types)
        ]
        body_args = util.unflatten(flat_body_args, map(len, loop_carry_types))
        ((i, ), ), y, _ = util.split_list(body_args, [1, len(avals_in)])
        body_ctx = ctx.replace(name_stack=xla.extend_name_stack(
            ctx.name_stack, xla.wrap_name(name, 'remat')))
        z = jaxpr_subcomp(body_ctx, call_jaxpr, (), *y)
        i_next = mhlo.AddOp(i, one).result
        new_carry = mhlo.TupleOp(loop_carry_tuple_type,
                                 [i_next, *util.flatten(y), *util.flatten(z)])
        mhlo.ReturnOp([new_carry.result])

    outputs = [
        mhlo.GetTupleElementOp(output_type, while_op.result,
                               i32_attr(1 + len(avals_in) + i)).result
        for i, output_type in enumerate(flat_output_types)
    ]
    return util.unflatten(outputs, map(len, output_types))
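
The random upper bound drawn in the loop condition lies in [1, 2), so the body runs exactly once; because the trip count is not a compile-time constant, the intent is to keep XLA from folding the rematerialized computation back into the stored forward values. From user code this path is reached via jax.checkpoint (a.k.a. jax.remat); a minimal, runnable usage sketch:

import jax
import jax.numpy as jnp

@jax.checkpoint  # recompute residuals on the backward pass instead of storing them
def f(x):
    return jnp.sin(jnp.sin(x))

print(jax.grad(f)(1.0))
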
Example #3
def _generic_reduce_window_lower(ctx, *args, jaxpr, consts, window_dimensions,
                                 window_strides, padding, base_dilation,
                                 window_dilation):
    operands, init_values = util.split_list(args, [len(args) // 2])
    _, init_value_avals = util.split_list(ctx.avals_in, [len(operands)])
    scalar_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
    rw = mhlo.ReduceWindowOp(
        map(mlir.aval_to_ir_type, ctx.avals_out),
        operands,
        init_values,
        mlir.dense_int_elements(window_dimensions),
        window_strides=mlir.dense_int_elements(window_strides),
        base_dilations=mlir.dense_int_elements(base_dilation),
        window_dilations=mlir.dense_int_elements(window_dilation),
        padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
                                            shape=(len(padding), 2)))
    reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
    with ir.InsertionPoint(reducer):
        if jaxpr.effects:
            raise NotImplementedError(
                'Cannot lower effectful `reduce_window`.')
        out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr,
                                          mlir.TokenSet(), consts,
                                          *([a] for a in reducer.arguments))
        mhlo.ReturnOp(util.flatten(out_nodes))
    return rw.results
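
This is the generic path for lax.reduce_window, used when the reducer is an arbitrary jaxpr rather than a recognized monoid (max/min/add), which dispatch to specialized rules instead. A small, runnable example that uses a non-monoid reducer (a windowed product):

import jax.numpy as jnp
from jax import lax

x = jnp.arange(1.0, 17.0).reshape(4, 4)
# Product over non-overlapping 2x2 windows; multiplication is not one of the
# recognized monoid reducers, so this exercises the generic lowering above.
y = lax.reduce_window(x, 1.0, lax.mul,
                      window_dimensions=(2, 2), window_strides=(2, 2),
                      padding=((0, 0), (0, 0)))
print(y)
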
Example #4
 def __init__(
     self,
     platform: str,
     axis_env: xla.AxisEnv,
     name_stack: str,
     context: Optional[ir.Context] = None,
     module: Optional[ir.Module] = None,
     ip: Optional[ir.InsertionPoint] = None,
     symbol_table: Optional[ir.SymbolTable] = None,
     cached_primitive_lowerings: Optional[Dict[Any,
                                               builtin.FuncOp]] = None):
     assert platform is not None
     self.context = context or ir.Context()
     self.module = module or ir.Module.create(
         loc=ir.Location.unknown(self.context))
     self.ip = ip or ir.InsertionPoint(self.module.operation.opview.body)
     self.symbol_table = symbol_table or ir.SymbolTable(
         self.module.operation)
     self.platform = platform
     self.axis_env = axis_env
     self.name_stack = name_stack
     self.cached_primitive_lowerings = ({} if
                                        cached_primitive_lowerings is None
                                        else cached_primitive_lowerings)
     mhlo.register_mhlo_dialect(self.context)
     chlo.register_chlo_dialect(self.context)
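
A hedged construction sketch for this context object; the enclosing class name (written here as ModuleContext) and the AxisEnv fields are assumptions inferred from the signature above, not taken from this snippet.

from jax.interpreters import xla

# Hypothetical: a CPU lowering context with a trivial axis environment and an empty name stack.
ctx = ModuleContext(platform="cpu",
                    axis_env=xla.AxisEnv(nreps=1, names=(), sizes=()),
                    name_stack="")
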
Example #5
def _reduce_window_lower(reduce_op, init_value, ctx, operand, *,
                         window_dimensions, window_strides, padding,
                         base_dilation, window_dilation):
    aval_out, = ctx.avals_out
    operand_aval, = ctx.avals_in
    scalar_aval = operand_aval.update(shape=())
    scalar_type = mlir.aval_to_ir_type(scalar_aval)
    rw = mhlo.ReduceWindowOp(
        mlir.aval_to_ir_types(aval_out), [operand],
        [mlir.full_like_aval(init_value(scalar_aval.dtype), scalar_aval)],
        mlir.dense_int_elements(window_dimensions),
        mlir.dense_int_elements(window_strides),
        mlir.dense_int_elements(base_dilation),
        mlir.dense_int_elements(window_dilation),
        ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
    reducer = rw.regions[0].blocks.append(scalar_type, scalar_type)
    with ir.InsertionPoint(reducer):
        mhlo.ReturnOp(reduce_op(*reducer.arguments))
    return rw.results
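
This variant is specialized at registration time: reduce_op emits the scalar reducer body and init_value supplies the reduction's identity. A hypothetical sketch of deriving a max-flavoured rule from it via functools.partial; the mhlo.MaxOp builder form and the registration target are assumptions, and the identity shown only covers floating-point dtypes.

from functools import partial
import numpy as np

# Hypothetical partial application of the generic helper above.
_reduce_window_max_lower = partial(
    _reduce_window_lower,
    lambda x, y: mhlo.MaxOp(x, y).results,         # reduce_op: values yielded by the reducer block
    lambda dtype: np.array(-np.inf, dtype=dtype))  # init_value: identity for floating-point max
# mlir.register_lowering(lax.reduce_window_max_p, _reduce_window_max_lower)  # assumed primitive
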
Example #6
def _generic_reduce_window_lower(ctx, avals_in, avals_out, *args, jaxpr,
                                 consts, window_dimensions, window_strides,
                                 padding, base_dilation, window_dilation):
    operands, init_values = util.split_list(args, [len(args) // 2])
    _, init_value_avals = util.split_list(avals_in, [len(operands)])
    scalar_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
    rw = mhlo.ReduceWindowOp(
        map(mlir.aval_to_ir_type, avals_out), operands, init_values,
        mlir.dense_int_elements(window_dimensions),
        mlir.dense_int_elements(window_strides),
        mlir.dense_int_elements(base_dilation),
        mlir.dense_int_elements(window_dilation),
        ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
    reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
    with ir.InsertionPoint(reducer):
        out_nodes = mlir.jaxpr_subcomp(ctx, jaxpr, consts,
                                       *([a] for a in reducer.arguments))
        mhlo.ReturnOp(util.flatten(out_nodes))
    return rw.results
Example #7
File: mlir.py Project: GJBoth/jax
def _emit_lowering_rule_as_fun(lowering_rule,
                               ctx: LoweringRuleContext) -> builtin.FuncOp:
  """Emits the contents of a lowering rule as a private function."""
  input_types = map(aval_to_ir_types, ctx.avals_in)
  output_types = map(aval_to_ir_types, ctx.avals_out)
  flat_input_types = util.flatten(input_types)
  flat_output_types = util.flatten(output_types)
  ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
  assert ctx.primitive is not None
  func_op = builtin.FuncOp(ctx.primitive.name, ftype, ip=ctx.module_context.ip)
  func_op.attributes["sym_visibility"] = ir.StringAttr.get("private")
  ctx.module_context.symbol_table.insert(func_op)
  entry_block = func_op.add_entry_block()
  with ir.InsertionPoint(entry_block):
    unflattened_args = util.unflatten(entry_block.arguments,
                                      map(len, input_types))
    outs = lowering_rule(ctx, *_unwrap_singleton_ir_values(unflattened_args))
    std.ReturnOp(util.flatten(map(wrap_singleton_ir_values, outs)))
  return func_op
Example #8
File: mlir.py Project: rsepassi/jax
 def __init__(self,
              platform: str,
              axis_env: xla.AxisEnv,
              name_stack: str,
              context: Optional[ir.Context] = None,
              module: Optional[ir.Module] = None,
              ip: Optional[ir.InsertionPoint] = None,
              symbol_table: Optional[ir.SymbolTable] = None):
     assert platform is not None
     self.context = context or ir.Context()
     self.module = module or ir.Module.create(
         loc=ir.Location.unknown(self.context))
     self.ip = ip or ir.InsertionPoint(self.module.operation.opview.body)
     self.symbol_table = symbol_table or ir.SymbolTable(
         self.module.operation)
     self.platform = platform
     self.axis_env = axis_env
     self.name_stack = name_stack
     mhlo.register_mhlo_dialect(self.context)
     chlo.register_chlo_dialect(self.context)
Example #9
def _cond_lowering(ctx, index, *args, branches, linear):
    del linear  # Unused.
    joined_effects = core.join_effects(*(branch.effects
                                         for branch in branches))
    ordered_effects = [
        eff for eff in joined_effects if eff in core.ordered_effects
    ]
    num_tokens = len(ordered_effects)
    tokens_in = ctx.tokens_in.subset(ordered_effects)
    output_token_types = [mlir.token_type() for _ in ordered_effects]
    output_types = [
        *output_token_types, *_map(mlir.aval_to_ir_types, ctx.avals_out)
    ]
    flat_output_types = util.flatten(output_types)

    # mhlo.CaseOp takes a single argument 'index' and the corresponding blocks
    # have no arguments; the computation within the block uses implicit
    # captures.
    case_op = mhlo.CaseOp(flat_output_types,
                          index=index,
                          num_branches=len(branches))
    name_stack = extend_name_stack(ctx.module_context.name_stack, 'cond')
    for i, jaxpr in enumerate(branches):
        branch = case_op.regions[i].blocks.append()
        with ir.InsertionPoint(branch):
            sub_ctx = ctx.module_context.replace(
                name_stack=xla.extend_name_stack(name_stack,
                                                 f'branch_{i}_fun'))
            out_vals, tokens_out = mlir.jaxpr_subcomp(
                sub_ctx, jaxpr.jaxpr, tokens_in,
                _map(mlir.ir_constants, jaxpr.consts),
                *_map(mlir.wrap_singleton_ir_values, args))
            out_tokens = [tokens_out.get(eff) for eff in ordered_effects]
            out_vals = [*out_tokens, *out_vals]
            mhlo.ReturnOp(util.flatten(out_vals))

    tokens_and_outputs = util.unflatten(case_op.results,
                                        _map(len, output_types))
    tokens, outputs = util.split_list(tokens_and_outputs, [num_tokens])
    ctx.set_tokens_out(mlir.TokenSet(zip(ordered_effects, tokens)))
    return outputs
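
At the user level the cond primitive comes from lax.cond (and lax.switch); each branch callable is traced to a jaxpr, and the rule above emits one mhlo.CaseOp region per branch while threading ordered-effect tokens through all of them. A minimal, runnable example:

import jax.numpy as jnp
from jax import lax

x = jnp.float32(3.0)
# Both branches are traced; at runtime CaseOp selects one based on the predicate.
y = lax.cond(x > 0, lambda v: v + 1.0, lambda v: v - 1.0, x)
print(y)
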
Example #10
 def __init__(
         self,
         platform: str,
         axis_context: AxisContext,
         name_stack: NameStack,
         context: Optional[ir.Context] = None,
         module: Optional[ir.Module] = None,
         ip: Optional[ir.InsertionPoint] = None,
         symbol_table: Optional[ir.SymbolTable] = None,
         cached_primitive_lowerings: Optional[Dict[Any,
                                                   FuncOpType]] = None):
     assert platform is not None
     self.context = context or make_ir_context()
     self.module = module or ir.Module.create(
         loc=ir.Location.unknown(self.context))
     self.ip = ip or ir.InsertionPoint(self.module.body)
     self.symbol_table = symbol_table or ir.SymbolTable(
         self.module.operation)
     self.platform = platform
     self.axis_context = axis_context
     self.name_stack = name_stack
     self.cached_primitive_lowerings = ({} if
                                        cached_primitive_lowerings is None
                                        else cached_primitive_lowerings)
Example #11
def lower_jaxpr_to_fun(
    ctx: ModuleContext,
    name: str,
    jaxpr: core.ClosedJaxpr,
    *,
    public: bool = False,
    replace_units_with_dummy: bool = False,
    replace_tokens_with_dummy: bool = False,
    replicated_args: Optional[Sequence[bool]] = None,
    arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
    result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
    use_sharding_annotations: bool = True,
    input_output_aliases: Optional[Sequence[Optional[int]]] = None
) -> FuncOpType:
    """Lowers jaxpr and its callees to an IR function.

  Assumes that an MLIR context, location, and insertion point are set.

  Args:
    ctx: the lowering context.
    name: the function name. The name will be uniquified by the symbol table,
      so it is ok to use the same name multiple times.
    jaxpr: the jaxpr to lower.
    public: if true, the function's visibility is set to "public".
    replace_units_with_dummy: if true, unit arguments/return values are
      replaced with bool arrays of size [0].
    replace_tokens_with_dummy: if true, token arguments/return values are
      replaced with bool arrays of size [0].
    replicated_args: if present, annotates arguments as replicated.
    arg_shardings: sharding annotations for each argument (optional).
    result_shardings: sharding annotations for each result (optional).
    use_sharding_annotations: if True, use mhlo.sharding annotations on
      parameters and return values to express sharding. If False, use
      mhlo.custom_call operators with sharding annotations.
      TODO(b/228598865): remove this option when mhlo.sharding annotations are
      propagated on non-entry functions during MHLO->HLO conversion.
    input_output_aliases: optional sequence that maps argument numbers to the
      corresponding output that should alias them.
  Returns the MLIR func op.
  """
    def aval_to_types(aval):
        if replace_units_with_dummy and aval is core.abstract_unit:
            aval = core.ShapedArray((), np.dtype(np.bool_))
        elif replace_tokens_with_dummy and aval is core.abstract_token:
            aval = core.ShapedArray((), np.dtype(np.bool_))
        return aval_to_ir_types(aval)

    input_types = map(aval_to_types, jaxpr.in_avals)
    output_types = map(aval_to_types, jaxpr.out_avals)
    flat_input_types = util.flatten(input_types)
    flat_output_types = util.flatten(output_types)
    ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
    func_op = FuncOp(name, ftype, ip=ctx.ip)
    func_op.attributes["sym_visibility"] = ir.StringAttr.get(
        "public" if public else "private")
    ctx.symbol_table.insert(func_op)
    ir_arg_shardings = None
    if arg_shardings is not None:
        ir_arg_shardings = util.flatten(
            [[sharding] * len(types)
             for sharding, types in zip(arg_shardings, input_types)])
    ir_result_shardings = None
    if result_shardings is not None:
        ir_result_shardings = util.flatten(
            [[sharding] * len(types)
             for sharding, types in zip(result_shardings, output_types)])

    if (replicated_args is not None or ir_arg_shardings is not None
            or input_output_aliases is not None):
        arg_attrs: List[Dict[str, ir.Attribute]] = [
            {} for _ in range(len(flat_input_types))
        ]

        if replicated_args is not None:
            replicated_ir_args = [
                [replicated] * len(types)
                for replicated, types in zip(replicated_args, input_types)
            ]
            for attrs, replicated in zip(arg_attrs,
                                         util.flatten(replicated_ir_args)):
                if replicated:
                    attrs["mhlo.is_same_data_across_replicas"] = ir.UnitAttr.get()

        if use_sharding_annotations and ir_arg_shardings is not None:
            for attrs, sharding in zip(arg_attrs, ir_arg_shardings):
                if sharding is not None:
                    attrs["mhlo.sharding"] = ir.StringAttr.get(
                        sharding.SerializeToString())

        if input_output_aliases is not None:
            output_ids = util.unflatten(list(range(len(flat_output_types))),
                                        map(len, output_types))
            aliases: List[Optional[int]] = []
            for types, alias in zip(input_types, input_output_aliases):
                if alias is None:
                    aliases.extend([None] * len(types))
                else:
                    aliases.extend(output_ids[alias])

            for attrs, alias in zip(arg_attrs, aliases):
                if alias is not None:
                    attrs["tf.aliasing_output"] = i32_attr(alias)

        func_op.arg_attrs = ir.ArrayAttr.get(
            [ir.DictAttr.get(attrs) for attrs in arg_attrs])

    if use_sharding_annotations and ir_result_shardings is not None:
        func_op.result_attrs = ir.ArrayAttr.get([
            ir.DictAttr.get({} if sharding is None else {
                "mhlo.sharding":
                ir.StringAttr.get(sharding.SerializeToString())
            }) for sharding in ir_result_shardings
        ])

    entry_block = func_op.add_entry_block()
    with ir.InsertionPoint(entry_block):
        flat_args = entry_block.arguments
        if not use_sharding_annotations and ir_arg_shardings is not None:
            flat_args = map(wrap_with_sharding_op, flat_args, ir_arg_shardings)

        unflattened_args = util.unflatten(flat_args, map(len, input_types))
        args: List[List[ir.Value]] = []
        for aval, arg in zip(jaxpr.in_avals, unflattened_args):
            if replace_units_with_dummy and aval is core.abstract_unit:
                args.append([])
            elif replace_tokens_with_dummy and aval is core.abstract_token:
                args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results)
            else:
                args.append(arg)
        callee_name_stack = xla.extend_name_stack(ctx.name_stack,
                                                  xla.wrap_name(name, 'jit'))
        out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
                                 jaxpr.jaxpr, map(ir_constants,
                                                  jaxpr.consts), *args)
        outs = []
        for aval, out in zip(jaxpr.out_avals, out_vals):
            if replace_units_with_dummy and aval is core.abstract_unit:
                outs.append(ir_constants(np.zeros((), np.bool_)))
            elif replace_tokens_with_dummy and aval is core.abstract_token:
                outs.append(ir_constants(np.zeros((), np.bool_)))
            else:
                outs.append(out)
        flat_outputs = util.flatten(outs)
        if not use_sharding_annotations and ir_result_shardings is not None:
            flat_outputs = map(wrap_with_sharding_op, flat_outputs,
                               ir_result_shardings)

        func_dialect.ReturnOp(flat_outputs)

    return func_op
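
lower_jaxpr_to_fun is normally reached through the staging API rather than called directly: jax.jit(...).lower(...) builds the MLIR module whose entry function this routine constructs. A usage sketch for inspecting that output (the "mhlo" dialect keyword matches the JAX versions these snippets come from):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) + 1.0

lowered = jax.jit(f).lower(jnp.float32(1.0))
# The module's entry function is produced by lower_jaxpr_to_fun.
print(lowered.compiler_ir(dialect="mhlo"))
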
Example #12
File: mlir.py Project: rsepassi/jax
def lower_jaxpr_to_fun(ctx: LoweringContext,
                       name: str,
                       jaxpr: core.ClosedJaxpr,
                       *,
                       public: bool = False,
                       replace_units_with_dummy: bool = False,
                       replace_tokens_with_dummy: bool = False) -> str:
    """Lowers jaxpr and its callees to an IR function.

  Assumes that an MLIR context, location, and insertion point are set.

  Args:
    ctx: the lowering context.
    name: the function name. The name will be uniquified by the symbol table,
      so it is ok to use the same name multiple times.
    jaxpr: the jaxpr to lower.
    public: if true, the function's visibility is set to "public".
    replace_units_with_dummy: if true, unit arguments/return values are
      replaced with bool arrays of size [0].
    replace_tokens_with_dummy: if true, token arguments/return values are
      replaced with bool arrays of size [0].
  Returns the name of the function.
  """
    def aval_to_types(aval):
        if replace_units_with_dummy and aval is core.abstract_unit:
            aval = core.ShapedArray((), np.dtype(np.bool_))
        elif replace_tokens_with_dummy and aval is core.abstract_token:
            aval = core.ShapedArray((), np.dtype(np.bool_))
        return aval_to_ir_types(aval)

    input_types = map(aval_to_types, jaxpr.in_avals)
    output_types = map(aval_to_types, jaxpr.out_avals)
    flat_input_types = util.flatten(input_types)
    flat_output_types = util.flatten(output_types)
    ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
    func_op = builtin.FuncOp(name, ftype, ip=ctx.ip)
    func_op.attributes["sym_visibility"] = ir.StringAttr.get(
        "public" if public else "private")
    symbol_name = ir.StringAttr(ctx.symbol_table.insert(func_op)).value
    entry_block = func_op.add_entry_block()
    with ir.InsertionPoint(entry_block):
        unflattened_args = util.unflatten(entry_block.arguments,
                                          map(len, input_types))
        args: List[List[ir.Value]] = []
        for aval, arg in zip(jaxpr.in_avals, unflattened_args):
            if replace_units_with_dummy and aval is core.abstract_unit:
                args.append([])
            elif replace_tokens_with_dummy and aval is core.abstract_token:
                args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results)
            else:
                args.append(arg)
        callee_name_stack = xla.extend_name_stack(ctx.name_stack,
                                                  xla.wrap_name(name, 'jit'))
        out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
                                 jaxpr.jaxpr, map(ir_constants,
                                                  jaxpr.consts), *args)
        outs = []
        for aval, out in zip(jaxpr.out_avals, out_vals):
            if replace_units_with_dummy and aval is core.abstract_unit:
                outs.append(ir_constants(np.zeros((), np.bool_)))
            elif replace_tokens_with_dummy and aval is core.abstract_token:
                outs.append(ir_constants(np.zeros((), np.bool_)))
            else:
                outs.append(out)
        std.ReturnOp(util.flatten(outs))

    return symbol_name
Example #13
def _select_and_gather_add_lowering(ctx,
                                    tangents,
                                    operand,
                                    *,
                                    select_prim,
                                    window_dimensions,
                                    window_strides,
                                    padding,
                                    base_dilation,
                                    window_dilation,
                                    max_bits=64):
    _, operand_aval, = ctx.avals_in
    out_aval, = ctx.avals_out
    dtype = operand_aval.dtype
    etype = mlir.dtype_to_ir_type(dtype)
    nbits = dtypes.finfo(dtype).bits

    assert nbits <= max_bits
    double_word_reduction = nbits * 2 <= max_bits

    const = lambda dtype, x: mlir.ir_constant(np.array(x, dtype=dtype),
                                              canonicalize_types=False)

    if jax._src.lib.mlir_api_version >= 9:

        def _broadcast(x, dims):
            return mhlo.BroadcastOp(x, mlir.dense_int_elements(dims))
    else:

        def _broadcast(x, dims):
            etype = ir.RankedTensorType(x.type).element_type
            return mhlo.BroadcastOp(ir.RankedTensorType.get(dims, etype), x,
                                    mlir.dense_int_elements(dims))

    if double_word_reduction:
        # TODO(b/73062247): XLA doesn't yet implement ReduceWindow on tuples, so
        # we implement a pair-wise ReduceWindow by packing two k-bit values into
        # 2k-bit unsigned integer using bit tricks.
        word_dtype = lax._UINT_DTYPES[nbits]
        double_word_dtype = lax._UINT_DTYPES[nbits * 2]
        word_type = mlir.dtype_to_ir_type(word_dtype)
        double_word_type = mlir.dtype_to_ir_type(double_word_dtype)

        # Packs two values into a tuple.
        def pack(a, b):
            a_dims = ir.RankedTensorType(a.type).shape
            b_dims = ir.RankedTensorType(b.type).shape
            a = mhlo.BitcastConvertOp(
                ir.RankedTensorType.get(a_dims, word_type), a)
            b = mhlo.BitcastConvertOp(
                ir.RankedTensorType.get(b_dims, word_type), b)
            a = mhlo.ConvertOp(
                ir.RankedTensorType.get(a_dims, double_word_type), a)
            b = mhlo.ConvertOp(
                ir.RankedTensorType.get(b_dims, double_word_type), b)
            a = mhlo.ShiftLeftOp(
                a, _broadcast(const(double_word_dtype, nbits), a_dims))
            return mhlo.OrOp(a, b)

        # Unpacks the first element of a tuple.
        def fst(t):
            dims = ir.RankedTensorType(t.type).shape
            st = mhlo.ShiftRightLogicalOp(t, const(double_word_dtype, nbits))
            return mhlo.BitcastConvertOp(
                ir.RankedTensorType.get(dims, etype),
                mhlo.ConvertOp(ir.RankedTensorType.get(dims, word_type),
                               st)).result

        # Unpacks the second element of a tuple.
        def snd(t):
            dims = ir.RankedTensorType(t.type).shape
            return mhlo.BitcastConvertOp(
                ir.RankedTensorType.get(dims, etype),
                mhlo.ConvertOp(ir.RankedTensorType.get(dims, word_type),
                               t)).result

    else:
        # The double-word trick above only works if we have a sufficiently large
        # type. As an alternative, we can pack two half words into a single word,
        # at the cost of precision.
        # TODO(b/73062247): add support for tuple reductions and remove this case.
        warnings.warn(
            "Using reduced precision for gradient of reduce-window "
            "min/max operator to work around missing XLA support for "
            "pair-reductions. This is likely from a second or "
            "higher derivative of a max-pooling operation.")
        r_nbits = nbits // 2
        # Drop/round the bottom mantissa bits.
        nexp = dtypes.finfo(dtype).nexp
        nmant = r_nbits - nexp - 1

        double_word_dtype = word_dtype = lax._UINT_DTYPES[nbits]
        double_word_type = word_type = mlir.dtype_to_ir_type(word_dtype)

        # Packs two values into a tuple.
        def pack(a, b):
            a_dims = ir.RankedTensorType(a.type).shape
            b_dims = ir.RankedTensorType(b.type).shape
            if jax._src.lib.mlir_api_version >= 21:
                a = mhlo.ReducePrecisionOp(a,
                                           exponent_bits=mlir.i32_attr(nexp),
                                           mantissa_bits=mlir.i32_attr(nmant))
                b = mhlo.ReducePrecisionOp(b,
                                           exponent_bits=mlir.i32_attr(nexp),
                                           mantissa_bits=mlir.i32_attr(nmant))
            else:
                a = mhlo.ReducePrecisionOp(a.type,
                                           a,
                                           exponent_bits=mlir.i32_attr(nexp),
                                           mantissa_bits=mlir.i32_attr(nmant))
                b = mhlo.ReducePrecisionOp(b.type,
                                           b,
                                           exponent_bits=mlir.i32_attr(nexp),
                                           mantissa_bits=mlir.i32_attr(nmant))
            a = mhlo.BitcastConvertOp(
                ir.RankedTensorType.get(a_dims, word_type), a)
            b = mhlo.BitcastConvertOp(
                ir.RankedTensorType.get(b_dims, word_type), b)
            b = mhlo.ShiftRightLogicalOp(
                b, _broadcast(const(word_dtype, r_nbits), b_dims))
            return mhlo.OrOp(a, b)

        # Unpacks the first element of a tuple.
        def fst(t):
            st = mhlo.AndOp(t,
                            const(word_dtype, ((1 << r_nbits) - 1) << r_nbits))
            return mhlo.BitcastConvertOp(ir.RankedTensorType.get([], etype),
                                         st).result

        # Unpacks the second element of a tuple.
        def snd(t):
            dims = ir.RankedTensorType(t.type).shape
            return mhlo.BitcastConvertOp(
                ir.RankedTensorType.get(dims, etype),
                mhlo.ShiftLeftOp(t, _broadcast(const(word_dtype, r_nbits),
                                               dims))).result

    assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim
    init = -np.inf if select_prim is lax.ge_p else np.inf
    rw = mhlo.ReduceWindowOp(
        [ir.RankedTensorType.get(out_aval.shape, double_word_type)],
        pack(operand, tangents),
        pack(const(dtype, init), const(dtype, 0)),
        mlir.dense_int_elements(window_dimensions),
        window_strides=mlir.dense_int_elements(window_strides),
        base_dilations=mlir.dense_int_elements(base_dilation),
        window_dilations=mlir.dense_int_elements(window_dilation),
        padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
                                            shape=(len(padding), 2)))
    scalar_type = ir.RankedTensorType.get([], double_word_type)
    reducer = rw.regions[0].blocks.append(scalar_type, scalar_type)
    with ir.InsertionPoint(reducer):
        x, y = reducer.arguments
        assert select_prim is lax.ge_p or select_prim is lax.le_p
        which = "GE" if select_prim is lax.ge_p else "LE"
        out = mhlo.SelectOp(mlir.compare_mhlo(fst(x), fst(y), which), x, y)
        mhlo.ReturnOp(out)
    return [snd(rw.result)]
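
select_and_gather_add shows up when differentiating through the gradient of a max/min pooling (reduce_window), and the reduced-precision branch above is only taken when the platform lacks an integer type wide enough for the double-word packing trick. A hedged sketch of a computation whose second-order derivative typically reaches this lowering:

import jax
import jax.numpy as jnp
from jax import lax

def maxpool_sum(x):
    # 2x2 max pooling followed by a scalar reduction.
    return jnp.sum(lax.reduce_window(x, -jnp.inf, lax.max, (2, 2), (2, 2),
                                     ((0, 0), (0, 0))))

x = jnp.arange(16.0).reshape(4, 4)
g = jax.grad(maxpool_sum)
# Forward-over-reverse second derivative; this is where select_and_gather_add may appear.
_, hvp = jax.jvp(g, (x,), (jnp.ones_like(x),))
print(hvp)
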