def _select_and_scatter_lower(
    ctx, operand, source, init_value, *, select_jaxpr, select_consts,
    scatter_jaxpr, scatter_consts, window_dimensions, window_strides, padding):
  operand_aval, source_aval, init_value_aval = ctx.avals_in
  aval_out, = ctx.avals_out
  scalar_aval = operand_aval.update(shape=())
  scalar_type = mlir.aval_to_ir_type(scalar_aval)
  op = mhlo.SelectAndScatterOp(
      mlir.aval_to_ir_type(aval_out), operand, source, init_value,
      mlir.dense_int_elements(window_dimensions),
      mlir.dense_int_elements(window_strides),
      ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
  select = op.select.blocks.append(scalar_type, scalar_type)
  with ir.InsertionPoint(select):
    out_nodes = mlir.jaxpr_subcomp(ctx.module_context, select_jaxpr,
                                   select_consts,
                                   *([a] for a in select.arguments))
    mhlo.ReturnOp(util.flatten(out_nodes))
  scatter = op.scatter.blocks.append(scalar_type, scalar_type)
  with ir.InsertionPoint(scatter):
    out_nodes = mlir.jaxpr_subcomp(ctx.module_context, scatter_jaxpr,
                                   scatter_consts,
                                   *([a] for a in scatter.arguments))
    mhlo.ReturnOp(util.flatten(out_nodes))
  return op.results
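# A hypothetical NumPy reference (illustration only; not a JAX API) for what
# the SelectAndScatterOp above computes in the common 1-D case: a "GE" select
# with an additive scatter, i.e. the shape of the gradient of max-pooling.
# The real op takes arbitrary `select`/`scatter` jaxprs; this sketch fixes
# them to argmax-select and add.
import numpy as np

def select_and_scatter_1d(operand, source, init, wdim, stride):
  out = np.full(operand.shape, init, operand.dtype)
  for w in range(source.shape[0]):
    lo = w * stride
    window = range(lo, min(lo + wdim, operand.shape[0]))
    best = max(window, key=lambda j: operand[j])  # the `select` region
    out[best] += source[w]                        # the `scatter` region
  return out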
def _remat_using_while(ctx, avals_in, avals_out, *args, name, call_jaxpr):
  input_types = map(aval_to_ir_types, avals_in)
  output_types = map(aval_to_ir_types, avals_out)
  flat_output_types = util.flatten(output_types)
  int32_scalar_type = aval_to_ir_type(core.ShapedArray((), np.dtype(np.int32)))
  loop_carry_types = [(int32_scalar_type,)] + input_types + output_types
  flat_loop_carry_types = util.flatten(loop_carry_types)
  counter_init = ir_constants(np.array(0, np.int32))
  flat_args = flatten_lowering_ir_args(
      (counter_init,) + args +
      tuple(_dummy_like_aval(aval) for aval in avals_out))
  loop_carry_tuple_type = ir.TupleType.get_tuple(flat_loop_carry_types)
  init_carry = mhlo.TupleOp(loop_carry_tuple_type, flat_args)

  one = ir_constant(np.array(1, np.int32))
  while_op = mhlo.WhileOp([loop_carry_tuple_type], [init_carry.result])

  # Loop condition
  cond_block = while_op.regions[0].blocks.append(loop_carry_tuple_type)
  with ir.InsertionPoint(cond_block):
    bool_scalar_type = aval_to_ir_type(
        core.ShapedArray((), np.dtype(np.bool_)))
    two = ir_constant(np.array(2, np.int32))
    shape = ir_constant(np.array((), np.int64), canonicalize_types=False)
    rng = mhlo.RngUniformOp(one, two, shape).result
    i = mhlo.GetTupleElementOp(int32_scalar_type, cond_block.arguments[0],
                               i32_attr(0))
    cmp = mhlo.CompareOp(bool_scalar_type, i, rng, ir.StringAttr.get("LT"),
                         ir.StringAttr.get("SIGNED")).result
    mhlo.ReturnOp([cmp])

  # Loop body
  body_block = while_op.regions[1].blocks.append(loop_carry_tuple_type)
  with ir.InsertionPoint(body_block):
    flat_body_args = [
        mhlo.GetTupleElementOp(input_type, body_block.arguments[0],
                               i32_attr(i)).result
        for i, input_type in enumerate(flat_loop_carry_types)]
    body_args = util.unflatten(flat_body_args, map(len, loop_carry_types))
    ((i,),), y, _ = util.split_list(body_args, [1, len(avals_in)])
    body_ctx = ctx.replace(
        name_stack=xla.extend_name_stack(ctx.name_stack,
                                         xla.wrap_name(name, 'remat')))
    z = jaxpr_subcomp(body_ctx, call_jaxpr, (), *y)
    i_next = mhlo.AddOp(i, one).result
    new_carry = mhlo.TupleOp(loop_carry_tuple_type,
                             [i_next, *util.flatten(y), *util.flatten(z)])
    mhlo.ReturnOp([new_carry.result])

  outputs = [
      mhlo.GetTupleElementOp(output_type, while_op.result,
                             i32_attr(1 + len(avals_in) + i)).result
      for i, output_type in enumerate(flat_output_types)]
  return util.unflatten(outputs, map(len, output_types))
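# Why lower remat as a while loop at all? A minimal NumPy-level sketch
# (hypothetical, illustration only) of the control flow emitted above: the
# condition compares a counter against an RNG draw from [1, 2), which with
# int32 bounds is always 1, so the body provably runs exactly once. Because
# the bound is a runtime RNG value, however, XLA cannot prove that, and so
# cannot common-subexpression-eliminate the recomputation with the forward
# pass, which is exactly what remat needs.
import numpy as np

def remat_as_while(f, *args):
  i, outs = 0, None
  while i < np.random.randint(1, 2):  # always 1, but opaque to the compiler
    outs = f(*args)
    i += 1
  return outs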
def _generic_reduce_window_lower(ctx, *args, jaxpr, consts, window_dimensions,
                                 window_strides, padding, base_dilation,
                                 window_dilation):
  operands, init_values = util.split_list(args, [len(args) // 2])
  _, init_value_avals = util.split_list(ctx.avals_in, [len(operands)])
  scalar_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
  rw = mhlo.ReduceWindowOp(
      map(mlir.aval_to_ir_type, ctx.avals_out), operands, init_values,
      mlir.dense_int_elements(window_dimensions),
      window_strides=mlir.dense_int_elements(window_strides),
      base_dilations=mlir.dense_int_elements(base_dilation),
      window_dilations=mlir.dense_int_elements(window_dilation),
      padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
                                          shape=(len(padding), 2)))
  reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
  with ir.InsertionPoint(reducer):
    if jaxpr.effects:
      raise NotImplementedError('Cannot lower effectful `reduce_window`.')
    out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr,
                                      mlir.TokenSet(), consts,
                                      *([a] for a in reducer.arguments))
    mhlo.ReturnOp(util.flatten(out_nodes))
  return rw.results
def __init__(self,
             platform: str,
             axis_env: xla.AxisEnv,
             name_stack: str,
             context: Optional[ir.Context] = None,
             module: Optional[ir.Module] = None,
             ip: Optional[ir.InsertionPoint] = None,
             symbol_table: Optional[ir.SymbolTable] = None,
             cached_primitive_lowerings: Optional[Dict[
                 Any, builtin.FuncOp]] = None):
  assert platform is not None
  self.context = context or ir.Context()
  self.module = module or ir.Module.create(
      loc=ir.Location.unknown(self.context))
  self.ip = ip or ir.InsertionPoint(self.module.operation.opview.body)
  self.symbol_table = symbol_table or ir.SymbolTable(self.module.operation)
  self.platform = platform
  self.axis_env = axis_env
  self.name_stack = name_stack
  self.cached_primitive_lowerings = ({} if cached_primitive_lowerings is None
                                     else cached_primitive_lowerings)
  mhlo.register_mhlo_dialect(self.context)
  chlo.register_chlo_dialect(self.context)
def _reduce_window_lower(reduce_op, init_value, ctx, operand, *,
                         window_dimensions, window_strides, padding,
                         base_dilation, window_dilation):
  aval_out, = ctx.avals_out
  operand_aval, = ctx.avals_in
  scalar_aval = operand_aval.update(shape=())
  scalar_type = mlir.aval_to_ir_type(scalar_aval)
  rw = mhlo.ReduceWindowOp(
      mlir.aval_to_ir_types(aval_out), [operand],
      [mlir.full_like_aval(init_value(scalar_aval.dtype), scalar_aval)],
      mlir.dense_int_elements(window_dimensions),
      mlir.dense_int_elements(window_strides),
      mlir.dense_int_elements(base_dilation),
      mlir.dense_int_elements(window_dilation),
      ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
  reducer = rw.regions[0].blocks.append(scalar_type, scalar_type)
  with ir.InsertionPoint(reducer):
    mhlo.ReturnOp(reduce_op(*reducer.arguments))
  return rw.results
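# A hypothetical NumPy reference (illustration only; dilations omitted) for
# the 1-D computation a ReduceWindowOp performs. Note the `padding` layout
# the lowerings above encode: one (low, high) pair per spatial dimension,
# with padded elements taking the init value.
import numpy as np

def reduce_window_1d(operand, init, reduce_op, wdim, stride, pad):
  lo, hi = pad[0]
  x = np.concatenate([np.full(lo, init, operand.dtype), operand,
                      np.full(hi, init, operand.dtype)])
  n_out = (x.shape[0] - wdim) // stride + 1
  out = np.full(n_out, init, operand.dtype)
  for w in range(n_out):
    for j in range(w * stride, w * stride + wdim):
      out[w] = reduce_op(out[w], x[j])
  return out

# e.g. max-pooling: reduce_window_1d(np.arange(4.), -np.inf, max, 2, 2,
#                                    [(0, 0)]) == [1., 3.]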
def _generic_reduce_window_lower(ctx, avals_in, avals_out, *args, jaxpr,
                                 consts, window_dimensions, window_strides,
                                 padding, base_dilation, window_dilation):
  operands, init_values = util.split_list(args, [len(args) // 2])
  _, init_value_avals = util.split_list(avals_in, [len(operands)])
  scalar_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
  rw = mhlo.ReduceWindowOp(
      map(mlir.aval_to_ir_type, avals_out), operands, init_values,
      mlir.dense_int_elements(window_dimensions),
      mlir.dense_int_elements(window_strides),
      mlir.dense_int_elements(base_dilation),
      mlir.dense_int_elements(window_dilation),
      ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
  reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
  with ir.InsertionPoint(reducer):
    out_nodes = mlir.jaxpr_subcomp(ctx, jaxpr, consts,
                                   *([a] for a in reducer.arguments))
    mhlo.ReturnOp(util.flatten(out_nodes))
  return rw.results
def _emit_lowering_rule_as_fun(lowering_rule,
                               ctx: LoweringRuleContext) -> builtin.FuncOp:
  """Emits the contents of a lowering rule as a private function."""
  input_types = map(aval_to_ir_types, ctx.avals_in)
  output_types = map(aval_to_ir_types, ctx.avals_out)
  flat_input_types = util.flatten(input_types)
  flat_output_types = util.flatten(output_types)
  ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
  assert ctx.primitive is not None
  func_op = builtin.FuncOp(ctx.primitive.name, ftype, ip=ctx.module_context.ip)
  func_op.attributes["sym_visibility"] = ir.StringAttr.get("private")
  ctx.module_context.symbol_table.insert(func_op)
  entry_block = func_op.add_entry_block()
  with ir.InsertionPoint(entry_block):
    unflattened_args = util.unflatten(entry_block.arguments,
                                      map(len, input_types))
    outs = lowering_rule(ctx, *_unwrap_singleton_ir_values(unflattened_args))
    std.ReturnOp(util.flatten(map(wrap_singleton_ir_values, outs)))
  return func_op
def __init__(self,
             platform: str,
             axis_env: xla.AxisEnv,
             name_stack: str,
             context: Optional[ir.Context] = None,
             module: Optional[ir.Module] = None,
             ip: Optional[ir.InsertionPoint] = None,
             symbol_table: Optional[ir.SymbolTable] = None):
  assert platform is not None
  self.context = context or ir.Context()
  self.module = module or ir.Module.create(
      loc=ir.Location.unknown(self.context))
  self.ip = ip or ir.InsertionPoint(self.module.operation.opview.body)
  self.symbol_table = symbol_table or ir.SymbolTable(self.module.operation)
  self.platform = platform
  self.axis_env = axis_env
  self.name_stack = name_stack
  mhlo.register_mhlo_dialect(self.context)
  chlo.register_chlo_dialect(self.context)
def _cond_lowering(ctx, index, *args, branches, linear):
  del linear  # Unused.
  joined_effects = core.join_effects(*(branch.effects for branch in branches))
  ordered_effects = [eff for eff in joined_effects
                     if eff in core.ordered_effects]
  num_tokens = len(ordered_effects)
  tokens_in = ctx.tokens_in.subset(ordered_effects)
  output_token_types = [mlir.token_type() for _ in ordered_effects]
  output_types = [
      *output_token_types, *_map(mlir.aval_to_ir_types, ctx.avals_out)]
  flat_output_types = util.flatten(output_types)

  # mhlo.CaseOp takes a single argument 'index' and the corresponding blocks
  # have no arguments; the computation within the block uses implicit
  # captures.
  case_op = mhlo.CaseOp(flat_output_types, index=index,
                        num_branches=len(branches))
  name_stack = extend_name_stack(ctx.module_context.name_stack, 'cond')
  for i, jaxpr in enumerate(branches):
    branch = case_op.regions[i].blocks.append()
    with ir.InsertionPoint(branch):
      sub_ctx = ctx.module_context.replace(
          name_stack=xla.extend_name_stack(name_stack, f'branch_{i}_fun'))
      out_vals, tokens_out = mlir.jaxpr_subcomp(
          sub_ctx, jaxpr.jaxpr, tokens_in,
          _map(mlir.ir_constants, jaxpr.consts),
          *_map(mlir.wrap_singleton_ir_values, args))
      out_tokens = [tokens_out.get(eff) for eff in ordered_effects]
      out_vals = [*out_tokens, *out_vals]
      mhlo.ReturnOp(util.flatten(out_vals))

  tokens_and_outputs = util.unflatten(case_op.results, _map(len, output_types))
  tokens, outputs = util.split_list(tokens_and_outputs, [num_tokens])
  ctx.set_tokens_out(mlir.TokenSet(zip(ordered_effects, tokens)))
  return outputs
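# The tail of `_cond_lowering` leans on the flatten/unflatten convention used
# throughout these lowerings: each aval lowers to a *list* of IR values, ops
# consume and produce the flat concatenation, and the per-aval lengths are
# used to restore the nesting. A sketch of the two helpers (matching the
# semantics of `jax.util.flatten`/`unflatten`, shown only to make the
# convention concrete):
def flatten(xss):
  return [x for xs in xss for x in xs]

def unflatten(xs, lens):
  out, i = [], 0
  for n in lens:
    out.append(xs[i:i + n])
    i += n
  return out

# e.g. unflatten(['tok', 'a', 'b'], [1, 2]) == [['tok'], ['a', 'b']]; the
# leading `num_tokens` groups are then the ordered-effect tokens.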
def __init__(self,
             platform: str,
             axis_context: AxisContext,
             name_stack: NameStack,
             context: Optional[ir.Context] = None,
             module: Optional[ir.Module] = None,
             ip: Optional[ir.InsertionPoint] = None,
             symbol_table: Optional[ir.SymbolTable] = None,
             cached_primitive_lowerings: Optional[Dict[
                 Any, FuncOpType]] = None):
  assert platform is not None
  self.context = context or make_ir_context()
  self.module = module or ir.Module.create(
      loc=ir.Location.unknown(self.context))
  self.ip = ip or ir.InsertionPoint(self.module.body)
  self.symbol_table = symbol_table or ir.SymbolTable(self.module.operation)
  self.platform = platform
  self.axis_context = axis_context
  self.name_stack = name_stack
  self.cached_primitive_lowerings = ({} if cached_primitive_lowerings is None
                                     else cached_primitive_lowerings)
def lower_jaxpr_to_fun(
    ctx: ModuleContext,
    name: str,
    jaxpr: core.ClosedJaxpr,
    *,
    public: bool = False,
    replace_units_with_dummy: bool = False,
    replace_tokens_with_dummy: bool = False,
    replicated_args: Optional[Sequence[bool]] = None,
    arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
    result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
    use_sharding_annotations: bool = True,
    input_output_aliases: Optional[Sequence[Optional[int]]] = None
) -> FuncOpType:
  """Lowers jaxpr and its callees to an IR function.

  Assumes that an MLIR context, location, and insertion point are set.

  Args:
    ctx: the lowering context.
    name: the function name. The name will be uniquified by the symbol table,
      so it is ok to use the same name multiple times.
    jaxpr: the jaxpr to lower.
    public: if true, the function's visibility is set to "public".
    replace_units_with_dummy: if true, unit arguments/return values are
      replaced with bool arrays of size [0].
    replace_tokens_with_dummy: if true, token arguments/return values are
      replaced with bool arrays of size [0].
    replicated_args: if present, annotates arguments as replicated.
    arg_shardings: sharding annotations for each argument (optional).
    result_shardings: sharding annotations for each result (optional).
    use_sharding_annotations: if True, use mhlo.sharding annotations on
      parameters and return values to express sharding. If False, use
      mhlo.custom_call operators with sharding annotations.
      TODO(b/228598865): remove this option when mhlo.sharding annotations are
      propagated on non-entry functions during MHLO->HLO conversion.
    input_output_aliases: optional sequence that maps argument numbers to the
      corresponding output that should alias them.

  Returns the MLIR func op.
  """
  def aval_to_types(aval):
    if replace_units_with_dummy and aval is core.abstract_unit:
      aval = core.ShapedArray((), np.dtype(np.bool_))
    elif replace_tokens_with_dummy and aval is core.abstract_token:
      aval = core.ShapedArray((), np.dtype(np.bool_))
    return aval_to_ir_types(aval)

  input_types = map(aval_to_types, jaxpr.in_avals)
  output_types = map(aval_to_types, jaxpr.out_avals)
  flat_input_types = util.flatten(input_types)
  flat_output_types = util.flatten(output_types)
  ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
  func_op = FuncOp(name, ftype, ip=ctx.ip)
  func_op.attributes["sym_visibility"] = ir.StringAttr.get(
      "public" if public else "private")
  ctx.symbol_table.insert(func_op)
  ir_arg_shardings = None
  if arg_shardings is not None:
    ir_arg_shardings = util.flatten(
        [[sharding] * len(types)
         for sharding, types in zip(arg_shardings, input_types)])
  ir_result_shardings = None
  if result_shardings is not None:
    ir_result_shardings = util.flatten(
        [[sharding] * len(types)
         for sharding, types in zip(result_shardings, output_types)])

  if (replicated_args is not None or ir_arg_shardings is not None
      or input_output_aliases is not None):
    arg_attrs: List[Dict[str, ir.Attribute]] = [
        {} for _ in range(len(flat_input_types))]

    if replicated_args is not None:
      replicated_ir_args = [[replicated] * len(types) for replicated, types
                            in zip(replicated_args, input_types)]
      for attrs, replicated in zip(arg_attrs,
                                   util.flatten(replicated_ir_args)):
        if replicated:
          attrs["mhlo.is_same_data_across_replicas"] = ir.UnitAttr.get()

    if use_sharding_annotations and ir_arg_shardings is not None:
      for attrs, sharding in zip(arg_attrs, ir_arg_shardings):
        if sharding is not None:
          attrs["mhlo.sharding"] = ir.StringAttr.get(
              sharding.SerializeToString())

    if input_output_aliases is not None:
      output_ids = util.unflatten(list(range(len(flat_output_types))),
                                  map(len, output_types))
      aliases: List[Optional[int]] = []
      for types, alias in zip(input_types, input_output_aliases):
        if alias is None:
          aliases.extend([None] * len(types))
        else:
          aliases.extend(output_ids[alias])
      for attrs, alias in zip(arg_attrs, aliases):
        if alias is not None:
          attrs["tf.aliasing_output"] = i32_attr(alias)

    func_op.arg_attrs = ir.ArrayAttr.get(
        [ir.DictAttr.get(attrs) for attrs in arg_attrs])

  if use_sharding_annotations and ir_result_shardings is not None:
    func_op.result_attrs = ir.ArrayAttr.get([
        ir.DictAttr.get({} if sharding is None else {
            "mhlo.sharding": ir.StringAttr.get(sharding.SerializeToString())
        }) for sharding in ir_result_shardings])

  entry_block = func_op.add_entry_block()
  with ir.InsertionPoint(entry_block):
    flat_args = entry_block.arguments
    if not use_sharding_annotations and ir_arg_shardings is not None:
      flat_args = map(wrap_with_sharding_op, flat_args, ir_arg_shardings)

    unflattened_args = util.unflatten(flat_args, map(len, input_types))
    args: List[List[ir.Value]] = []
    for aval, arg in zip(jaxpr.in_avals, unflattened_args):
      if replace_units_with_dummy and aval is core.abstract_unit:
        args.append([])
      elif replace_tokens_with_dummy and aval is core.abstract_token:
        args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results)
      else:
        args.append(arg)
    callee_name_stack = xla.extend_name_stack(ctx.name_stack,
                                              xla.wrap_name(name, 'jit'))
    out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
                             jaxpr.jaxpr, map(ir_constants, jaxpr.consts),
                             *args)
    outs = []
    for aval, out in zip(jaxpr.out_avals, out_vals):
      if replace_units_with_dummy and aval is core.abstract_unit:
        outs.append(ir_constants(np.zeros((), np.bool_)))
      elif replace_tokens_with_dummy and aval is core.abstract_token:
        outs.append(ir_constants(np.zeros((), np.bool_)))
      else:
        outs.append(out)
    flat_outputs = util.flatten(outs)
    if not use_sharding_annotations and ir_result_shardings is not None:
      flat_outputs = map(wrap_with_sharding_op, flat_outputs,
                         ir_result_shardings)
    func_dialect.ReturnOp(flat_outputs)

  return func_op
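# A worked example (values illustrative only) of the aliasing bookkeeping in
# `lower_jaxpr_to_fun`: with two inputs whose avals lower to 1 and 2 IR values
# respectively, outputs lowering to 2 and 1 values, and
# input_output_aliases = [None, 0], the second input aliases output 0, so its
# two flattened arguments get `tf.aliasing_output` attributes pointing at the
# corresponding flattened output positions:
#
#   input_types          = [[f32], [f32, f32]]
#   output_types         = [[f32, f32], [i32]]
#   output_ids           = [[0, 1], [2]]
#   aliases              = [None, 0, 1]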
def lower_jaxpr_to_fun(ctx: LoweringContext, name: str,
                       jaxpr: core.ClosedJaxpr, *,
                       public: bool = False,
                       replace_units_with_dummy: bool = False,
                       replace_tokens_with_dummy: bool = False) -> str:
  """Lowers jaxpr and its callees to an IR function.

  Assumes that an MLIR context, location, and insertion point are set.

  Args:
    ctx: the lowering context.
    name: the function name. The name will be uniquified by the symbol table,
      so it is ok to use the same name multiple times.
    jaxpr: the jaxpr to lower.
    public: if true, the function's visibility is set to "public".
    replace_units_with_dummy: if true, unit arguments/return values are
      replaced with bool arrays of size [0].
    replace_tokens_with_dummy: if true, token arguments/return values are
      replaced with bool arrays of size [0].

  Returns the name of the function.
  """
  def aval_to_types(aval):
    if replace_units_with_dummy and aval is core.abstract_unit:
      aval = core.ShapedArray((), np.dtype(np.bool_))
    elif replace_tokens_with_dummy and aval is core.abstract_token:
      aval = core.ShapedArray((), np.dtype(np.bool_))
    return aval_to_ir_types(aval)

  input_types = map(aval_to_types, jaxpr.in_avals)
  output_types = map(aval_to_types, jaxpr.out_avals)
  flat_input_types = util.flatten(input_types)
  flat_output_types = util.flatten(output_types)
  ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
  func_op = builtin.FuncOp(name, ftype, ip=ctx.ip)
  func_op.attributes["sym_visibility"] = ir.StringAttr.get(
      "public" if public else "private")
  symbol_name = ir.StringAttr(ctx.symbol_table.insert(func_op)).value
  entry_block = func_op.add_entry_block()
  with ir.InsertionPoint(entry_block):
    unflattened_args = util.unflatten(entry_block.arguments,
                                      map(len, input_types))
    args: List[List[ir.Value]] = []
    for aval, arg in zip(jaxpr.in_avals, unflattened_args):
      if replace_units_with_dummy and aval is core.abstract_unit:
        args.append([])
      elif replace_tokens_with_dummy and aval is core.abstract_token:
        args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results)
      else:
        args.append(arg)
    callee_name_stack = xla.extend_name_stack(ctx.name_stack,
                                              xla.wrap_name(name, 'jit'))
    out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
                             jaxpr.jaxpr, map(ir_constants, jaxpr.consts),
                             *args)
    outs = []
    for aval, out in zip(jaxpr.out_avals, out_vals):
      if replace_units_with_dummy and aval is core.abstract_unit:
        outs.append(ir_constants(np.zeros((), np.bool_)))
      elif replace_tokens_with_dummy and aval is core.abstract_token:
        outs.append(ir_constants(np.zeros((), np.bool_)))
      else:
        outs.append(out)
    std.ReturnOp(util.flatten(outs))
  return symbol_name
def _select_and_gather_add_lowering(
    ctx, tangents, operand, *, select_prim,
    window_dimensions, window_strides, padding, base_dilation,
    window_dilation, max_bits=64):
  _, operand_aval, = ctx.avals_in
  out_aval, = ctx.avals_out
  dtype = operand_aval.dtype
  etype = mlir.dtype_to_ir_type(dtype)
  nbits = dtypes.finfo(dtype).bits

  assert nbits <= max_bits
  double_word_reduction = nbits * 2 <= max_bits

  const = lambda dtype, x: mlir.ir_constant(np.array(x, dtype=dtype),
                                            canonicalize_types=False)

  if jax._src.lib.mlir_api_version >= 9:
    def _broadcast(x, dims):
      return mhlo.BroadcastOp(x, mlir.dense_int_elements(dims))
  else:
    def _broadcast(x, dims):
      etype = ir.RankedTensorType(x.type).element_type
      return mhlo.BroadcastOp(ir.RankedTensorType.get(dims, etype), x,
                              mlir.dense_int_elements(dims))

  if double_word_reduction:
    # TODO(b/73062247): XLA doesn't yet implement ReduceWindow on tuples, so
    # we implement a pair-wise ReduceWindow by packing two k-bit values into
    # a 2k-bit unsigned integer using bit tricks.
    word_dtype = lax._UINT_DTYPES[nbits]
    double_word_dtype = lax._UINT_DTYPES[nbits * 2]
    word_type = mlir.dtype_to_ir_type(word_dtype)
    double_word_type = mlir.dtype_to_ir_type(double_word_dtype)

    # Packs two values into a double-word: the first in the high bits, the
    # second in the low bits.
    def pack(a, b):
      a_dims = ir.RankedTensorType(a.type).shape
      b_dims = ir.RankedTensorType(b.type).shape
      a = mhlo.BitcastConvertOp(ir.RankedTensorType.get(a_dims, word_type), a)
      b = mhlo.BitcastConvertOp(ir.RankedTensorType.get(b_dims, word_type), b)
      a = mhlo.ConvertOp(ir.RankedTensorType.get(a_dims, double_word_type), a)
      b = mhlo.ConvertOp(ir.RankedTensorType.get(b_dims, double_word_type), b)
      a = mhlo.ShiftLeftOp(
          a, _broadcast(const(double_word_dtype, nbits), a_dims))
      return mhlo.OrOp(a, b)

    # Unpacks the first element of a packed pair.
    def fst(t):
      dims = ir.RankedTensorType(t.type).shape
      st = mhlo.ShiftRightLogicalOp(t, const(double_word_dtype, nbits))
      return mhlo.BitcastConvertOp(
          ir.RankedTensorType.get(dims, etype),
          mhlo.ConvertOp(ir.RankedTensorType.get(dims, word_type), st)).result

    # Unpacks the second element of a packed pair.
    def snd(t):
      dims = ir.RankedTensorType(t.type).shape
      return mhlo.BitcastConvertOp(
          ir.RankedTensorType.get(dims, etype),
          mhlo.ConvertOp(ir.RankedTensorType.get(dims, word_type), t)).result

  else:
    # The double-word trick above only works if we have a sufficiently large
    # type. As an alternative, we can pack two half words into a single word,
    # at the cost of precision.
    # TODO(b/73062247): add support for tuple reductions and remove this case.
    warnings.warn("Using reduced precision for gradient of reduce-window "
                  "min/max operator to work around missing XLA support for "
                  "pair-reductions. This is likely from a second or "
                  "higher derivative of a max-pooling operation.")
    r_nbits = nbits // 2
    # Drop/round the bottom mantissa bits.
    nexp = dtypes.finfo(dtype).nexp
    nmant = r_nbits - nexp - 1

    double_word_dtype = word_dtype = lax._UINT_DTYPES[nbits]
    double_word_type = word_type = mlir.dtype_to_ir_type(word_dtype)

    # Packs two values into a single word: the first in the high half, the
    # second, reduced to half precision, in the low half.
    def pack(a, b):
      a_dims = ir.RankedTensorType(a.type).shape
      b_dims = ir.RankedTensorType(b.type).shape
      if jax._src.lib.mlir_api_version >= 21:
        a = mhlo.ReducePrecisionOp(a, exponent_bits=mlir.i32_attr(nexp),
                                   mantissa_bits=mlir.i32_attr(nmant))
        b = mhlo.ReducePrecisionOp(b, exponent_bits=mlir.i32_attr(nexp),
                                   mantissa_bits=mlir.i32_attr(nmant))
      else:
        a = mhlo.ReducePrecisionOp(a.type, a,
                                   exponent_bits=mlir.i32_attr(nexp),
                                   mantissa_bits=mlir.i32_attr(nmant))
        b = mhlo.ReducePrecisionOp(b.type, b,
                                   exponent_bits=mlir.i32_attr(nexp),
                                   mantissa_bits=mlir.i32_attr(nmant))
      a = mhlo.BitcastConvertOp(ir.RankedTensorType.get(a_dims, word_type), a)
      b = mhlo.BitcastConvertOp(ir.RankedTensorType.get(b_dims, word_type), b)
      b = mhlo.ShiftRightLogicalOp(
          b, _broadcast(const(word_dtype, r_nbits), b_dims))
      return mhlo.OrOp(a, b)

    # Unpacks the first element of a packed pair.
    def fst(t):
      st = mhlo.AndOp(t, const(word_dtype, ((1 << r_nbits) - 1) << r_nbits))
      return mhlo.BitcastConvertOp(ir.RankedTensorType.get([], etype),
                                   st).result

    # Unpacks the second element of a packed pair.
    def snd(t):
      dims = ir.RankedTensorType(t.type).shape
      return mhlo.BitcastConvertOp(
          ir.RankedTensorType.get(dims, etype),
          mhlo.ShiftLeftOp(t, _broadcast(const(word_dtype, r_nbits),
                                         dims))).result

  assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim
  init = -np.inf if select_prim is lax.ge_p else np.inf
  rw = mhlo.ReduceWindowOp(
      [ir.RankedTensorType.get(out_aval.shape, double_word_type)],
      pack(operand, tangents),
      pack(const(dtype, init), const(dtype, 0)),
      mlir.dense_int_elements(window_dimensions),
      window_strides=mlir.dense_int_elements(window_strides),
      base_dilations=mlir.dense_int_elements(base_dilation),
      window_dilations=mlir.dense_int_elements(window_dilation),
      padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
                                          shape=(len(padding), 2)))
  scalar_type = ir.RankedTensorType.get([], double_word_type)
  reducer = rw.regions[0].blocks.append(scalar_type, scalar_type)
  with ir.InsertionPoint(reducer):
    x, y = reducer.arguments
    assert select_prim is lax.ge_p or select_prim is lax.le_p
    which = "GE" if select_prim is lax.ge_p else "LE"
    out = mhlo.SelectOp(mlir.compare_mhlo(fst(x), fst(y), which), x, y)
    mhlo.ReturnOp(out)
  return [snd(rw.result)]
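# A hypothetical NumPy sketch (illustration only) of the double-word trick
# above, specialized to float32 packed into uint64: the value goes in the
# high bits and the tangent in the low bits; the reducer compares only the
# unpacked values but selects whole words, so the winning value carries its
# tangent along, and `snd` recovers the tangent of the selected element.
import numpy as np

def pack(v, t):
  v = np.float32(v).view(np.uint32).astype(np.uint64)
  t = np.float32(t).view(np.uint32).astype(np.uint64)
  return (v << np.uint64(32)) | t

def fst(w):  # the value half, fed to the GE/LE compare
  return (w >> np.uint64(32)).astype(np.uint32).view(np.float32)

def snd(w):  # the tangent half, the final result
  return w.astype(np.uint32).view(np.float32)

x, y = pack(3.0, 10.0), pack(7.0, 20.0)
best = x if fst(x) >= fst(y) else y   # the reducer's select
assert snd(best) == np.float32(20.0)  # tangent of the max element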