def compare_mhlo(x, y, direction, type):
    """Creates an ``mhlo.CompareOp`` using whichever builder form the bindings expect.

    The call signature of ``mhlo.CompareOp`` depends on
    ``jax._src.lib.mlir_api_version``:

    * ``>= 5``: operands plus ``mhlo.ComparisonDirectionAttr`` /
      ``mhlo.ComparisonTypeAttr``; no result type is passed.
    * ``>= 3``: an explicit result type, followed by the same enum attributes.
    * otherwise: an explicit result type, with both attributes given as
      ``ir.StringAttr``.

    Args:
      x: left operand. On the pre-5 paths ``x.type`` is wrapped in
        ``ir.RankedTensorType`` to obtain the result shape.
      y: right operand.
      direction: comparison direction, forwarded to the attribute constructor.
      type: comparison type, forwarded to the attribute constructor.

    Returns:
      The ``mhlo.CompareOp`` operation itself (not its ``.result``).
    """
    api_version = jax._src.lib.mlir_api_version

    if api_version >= 5:
        direction_attr = mhlo.ComparisonDirectionAttr.get(direction)
        type_attr = mhlo.ComparisonTypeAttr.get(type)
        return mhlo.CompareOp(x, y, direction_attr, type_attr)

    # Older builders take the result type explicitly: a signless 1-bit integer
    # tensor with the same shape as `x`.
    operand_shape = ir.RankedTensorType(x.type).shape
    result_type = ir.RankedTensorType.get(
        operand_shape, ir.IntegerType.get_signless(1))

    if api_version >= 3:
        direction_attr = mhlo.ComparisonDirectionAttr.get(direction)
        type_attr = mhlo.ComparisonTypeAttr.get(type)
    else:
        direction_attr = ir.StringAttr.get(direction)
        type_attr = ir.StringAttr.get(type)
    return mhlo.CompareOp(result_type, x, y, direction_attr, type_attr)
def _remat_using_while(ctx, avals_in, avals_out, *args, name, call_jaxpr):
    """Lowers ``call_jaxpr`` as the body of an ``mhlo.WhileOp``.

    The loop carries a single tuple laid out as
    ``(int32 counter, *flat inputs, *flat outputs)``. The output slots are
    initialised with placeholders from ``_dummy_like_aval`` and overwritten by
    the body with the results of ``call_jaxpr``.

    The loop continues while ``counter < rng``, where ``rng`` comes from
    ``mhlo.RngUniformOp(1, 2, shape=())`` rather than from a constant.
    NOTE(review): XLA's RngUniform samples from ``[a, b)``, so the bound is
    presumably always 1 and the body runs exactly once while staying opaque
    to the compiler -- confirm intent.

    Args:
      ctx: lowering context; a copy with an extended name stack is used for
        the body.
      avals_in: abstract values of the inputs.
      avals_out: abstract values of the outputs.
      *args: lowered input values, one entry per element of ``avals_in``.
      name: name used (wrapped with ``'remat'``) to extend the name stack.
      call_jaxpr: the jaxpr evaluated inside the loop body.

    Returns:
      The loop outputs, regrouped per output aval with ``util.unflatten``.
    """
    # NOTE: `input_types` / `output_types` are concatenated with a list below,
    # so this relies on the module's `map` returning a list.
    input_types = map(aval_to_ir_types, avals_in)
    output_types = map(aval_to_ir_types, avals_out)
    flat_output_types = util.flatten(output_types)
    int32_scalar_type = aval_to_ir_type(
        core.ShapedArray((), np.dtype(np.int32)))
    loop_carry_types = [(int32_scalar_type,)] + input_types + output_types
    flat_loop_carry_types = util.flatten(loop_carry_types)

    counter_init = ir_constants(np.array(0, np.int32))
    flat_args = flatten_lowering_ir_args(
        (counter_init,) + args
        + tuple(_dummy_like_aval(aval) for aval in avals_out))
    loop_carry_tuple_type = ir.TupleType.get_tuple(flat_loop_carry_types)
    init_carry = mhlo.TupleOp(loop_carry_tuple_type, flat_args)

    # Created before the while op; referenced from both regions below.
    one = ir_constant(np.array(1, np.int32))
    while_op = mhlo.WhileOp([loop_carry_tuple_type], [init_carry.result])

    # Loop condition
    cond_block = while_op.regions[0].blocks.append(loop_carry_tuple_type)
    with ir.InsertionPoint(cond_block):
        two = ir_constant(np.array(2, np.int32))
        shape = ir_constant(np.array((), np.int64), canonicalize_types=False)
        rng = mhlo.RngUniformOp(one, two, shape).result
        counter = mhlo.GetTupleElementOp(
            int32_scalar_type, cond_block.arguments[0], i32_attr(0)).result
        # Go through the shared helper so the comparison is built with the
        # CompareOp signature matching the installed MLIR API version.
        keep_going = compare_mhlo(counter, rng, "LT", "SIGNED").result
        mhlo.ReturnOp([keep_going])

    # Loop body
    body_block = while_op.regions[1].blocks.append(loop_carry_tuple_type)
    with ir.InsertionPoint(body_block):
        flat_body_args = [
            mhlo.GetTupleElementOp(
                carry_type, body_block.arguments[0], i32_attr(idx)).result
            for idx, carry_type in enumerate(flat_loop_carry_types)
        ]
        body_args = util.unflatten(flat_body_args, map(len, loop_carry_types))
        # Split into the counter, the inputs, and the (ignored) stale outputs.
        ((counter,),), y, _ = util.split_list(body_args, [1, len(avals_in)])
        body_ctx = ctx.replace(name_stack=xla.extend_name_stack(
            ctx.name_stack, xla.wrap_name(name, 'remat')))
        z = jaxpr_subcomp(body_ctx, call_jaxpr, (), *y)
        counter_next = mhlo.AddOp(counter, one).result
        new_carry = mhlo.TupleOp(
            loop_carry_tuple_type,
            [counter_next, *util.flatten(y), *util.flatten(z)])
        mhlo.ReturnOp([new_carry.result])

    # Outputs occupy the tail of the flat carry tuple. Derive their offset
    # from the flattened type lists rather than `1 + len(avals_in)`, which is
    # only correct when every input aval lowers to exactly one IR type.
    output_offset = len(flat_loop_carry_types) - len(flat_output_types)
    outputs = [
        mhlo.GetTupleElementOp(
            output_type, while_op.result, i32_attr(output_offset + idx)).result
        for idx, output_type in enumerate(flat_output_types)
    ]
    return util.unflatten(outputs, map(len, output_types))
def _minmax_mhlo(op, cmp, x, y):
    """Min/max that compares complex values lexicographically as pairs.

    For non-complex element types this simply returns ``op(x, y)``. For
    complex element types the real parts are compared first with direction
    ``cmp``; where the real parts are equal, the imaginary parts are compared
    instead. The resulting mask selects between ``x`` and ``y``.

    Args:
      op: callable applied as ``op(x, y)`` for non-complex inputs.
      cmp: comparison direction string forwarded to ``compare_mhlo``.
      x: first operand; ``x.type`` must be convertible to
        ``ir.RankedTensorType``.
      y: second operand.

    Returns:
      ``op(x, y)`` for non-complex inputs, otherwise the final
      ``mhlo.SelectOp`` operation (not its ``.result``).
    """
    tensor_type = ir.RankedTensorType(x.type)
    if not ir.ComplexType.isinstance(tensor_type.element_type):
        return op(x, y)

    rx = mhlo.RealOp(x).result
    ry = mhlo.RealOp(y).result
    # Build the comparisons through the shared helper so the CompareOp
    # signature matches the installed MLIR API version.
    real_eq = compare_mhlo(rx, ry, "EQ", "FLOAT").result
    real_cmp = compare_mhlo(rx, ry, cmp, "FLOAT").result
    imag_cmp = compare_mhlo(
        mhlo.ImagOp(x).result, mhlo.ImagOp(y).result, cmp, "FLOAT").result
    # Where the real parts tie, fall back to the imaginary-part comparison.
    which = mhlo.SelectOp(real_eq, imag_cmp, real_cmp).result
    return mhlo.SelectOp(which, x, y)
def convert_mhlo(x, aval_in, aval_out):
    """Variant of convert that has XLA HLO semantics.

    In particular, treat casts to boolean as x != 0, rather than truncating
    integer values (b/209440332).

    Args:
      x: the value to convert.
      aval_in: abstract value describing ``x``; its dtype picks the comparison
        type used for casts to boolean.
      aval_out: abstract value describing the result; its dtype picks between
        the boolean path and a plain ``mhlo.ConvertOp``.

    Returns:
      The result value of the emitted compare or convert op.
    """
    if aval_out.dtype == np.dtype(np.bool_):
        if dtypes.issubdtype(aval_in.dtype, np.inexact):
            compare_type = "FLOAT"
        elif dtypes.issubdtype(aval_in.dtype, np.signedinteger):
            compare_type = "SIGNED"
        else:
            compare_type = "UNSIGNED"
        # Use the shared helper so the CompareOp is built with the signature
        # matching the installed MLIR API version. On older APIs the helper
        # derives the result type as a 1-bit tensor shaped like `x`.
        # NOTE(review): assumes that equals aval_to_ir_type(aval_out) for a
        # boolean aval_out -- confirm.
        zero = full_like_aval(0, aval_in)
        return compare_mhlo(x, zero, "NE", compare_type).result
    return mhlo.ConvertOp(aval_to_ir_type(aval_out), x).result