def pack(a, b):
    """Pack two floating-point tensors into a single ``word_type`` tensor.

    Each input is first rounded with ``mhlo.ReducePrecisionOp`` to ``nexp``
    exponent bits and ``nmant`` mantissa bits, then reinterpreted as
    ``word_type`` via ``mhlo.BitcastConvertOp``. ``b``'s word is shifted right
    (logically) by ``r_nbits`` and OR-ed into ``a``'s word.

    NOTE(review): correctness presumably relies on the precision reduction
    leaving ``a``'s low ``r_nbits`` bits zero so the OR does not mix the two
    values — confirm against the definitions of ``nexp``/``nmant``/``r_nbits``.

    Args:
      a: tensor value whose (reduced) bits are kept in place.
      b: tensor value whose (reduced) bits are shifted into the low bits.

    Returns:
      The ``mhlo.OrOp`` combining both packed words.
    """
    a_dims = ir.RankedTensorType(a.type).shape
    b_dims = ir.RankedTensorType(b.type).shape
    exponent_bits = mlir.i32_attr(nexp)
    mantissa_bits = mlir.i32_attr(nmant)

    def reduce_precision(value):
        # From MLIR API version 21 the result type is inferred; before that
        # it has to be passed explicitly (same as the operand's type).
        if jax._src.lib.mlir_api_version >= 21:
            return mhlo.ReducePrecisionOp(
                value, exponent_bits=exponent_bits,
                mantissa_bits=mantissa_bits)
        return mhlo.ReducePrecisionOp(
            value.type, value, exponent_bits=exponent_bits,
            mantissa_bits=mantissa_bits)

    reduced_a = reduce_precision(a)
    reduced_b = reduce_precision(b)

    a_word = mhlo.BitcastConvertOp(
        ir.RankedTensorType.get(a_dims, word_type), reduced_a)
    b_word = mhlo.BitcastConvertOp(
        ir.RankedTensorType.get(b_dims, word_type), reduced_b)

    shift_amount = _broadcast(const(word_dtype, r_nbits), b_dims)
    b_low = mhlo.ShiftRightLogicalOp(b_word, shift_amount)
    return mhlo.OrOp(a_word, b_low)
def pack(a, b):
    """Pack two same-width values into one ``double_word_type`` tensor.

    Both inputs are reinterpreted as ``word_type`` with
    ``mhlo.BitcastConvertOp`` and then widened to ``double_word_type`` with
    ``mhlo.ConvertOp``. ``a`` is shifted left by ``nbits`` and OR-ed with
    ``b``, so ``a``'s bits sit above ``b``'s in the result.

    NOTE(review): this assumes ``word_type`` is unsigned so that widening
    ``b`` zero-extends it and does not clobber ``a``'s bits — confirm.

    Args:
      a: tensor value stored in the upper bits of the result.
      b: tensor value stored in the lower bits of the result.

    Returns:
      The ``mhlo.OrOp`` holding the packed pair.
    """
    a_dims = ir.RankedTensorType(a.type).shape
    b_dims = ir.RankedTensorType(b.type).shape

    a_as_word = mhlo.BitcastConvertOp(
        ir.RankedTensorType.get(a_dims, word_type), a)
    b_as_word = mhlo.BitcastConvertOp(
        ir.RankedTensorType.get(b_dims, word_type), b)

    a_wide = mhlo.ConvertOp(
        ir.RankedTensorType.get(a_dims, double_word_type), a_as_word)
    b_wide = mhlo.ConvertOp(
        ir.RankedTensorType.get(b_dims, double_word_type), b_as_word)

    high_shift = _broadcast(const(double_word_dtype, nbits), a_dims)
    a_high = mhlo.ShiftLeftOp(a_wide, high_shift)
    return mhlo.OrOp(a_high, b_wide)
def fst(t):
    """Extract the value stored above bit ``nbits`` of a packed double word.

    ``t`` is shifted right (logically) by ``nbits``, narrowed to
    ``word_type`` with ``mhlo.ConvertOp``, and reinterpreted as ``etype``
    with ``mhlo.BitcastConvertOp``.

    Args:
      t: packed tensor value of double-word integer type.

    Returns:
      The MLIR value of the bitcast, with element type ``etype``.
    """
    dims = ir.RankedTensorType(t.type).shape
    # NOTE(review): the shift amount is ``const(...)`` used directly, with no
    # broadcast to ``dims``; this presumably only type-checks when ``t`` is
    # rank-0 (e.g. a reducer block argument) — confirm before reusing it on
    # higher-rank tensors.
    shifted = mhlo.ShiftRightLogicalOp(t, const(double_word_dtype, nbits))
    narrowed = mhlo.ConvertOp(ir.RankedTensorType.get(dims, word_type), shifted)
    unpacked = mhlo.BitcastConvertOp(
        ir.RankedTensorType.get(dims, etype), narrowed)
    return unpacked.result
def compare_mhlo(x, y, direction, type):
    """Creates mhlo.CompareOp.

    The ``mhlo.CompareOp`` constructor changed across MLIR API versions, so
    the call is shaped by ``jax._src.lib.mlir_api_version``:

    * ``>= 5``: result type inferred; attributes passed as
      ``ComparisonDirectionAttr`` / ``ComparisonTypeAttr``.
    * ``3``-``4``: same attribute kinds, but an explicit i1 result tensor
      type (same shape as ``x``) is required.
    * ``< 3``: explicit i1 result type, attributes passed as ``StringAttr``.

    Args:
      x: left operand (ranked tensor value).
      y: right operand.
      direction: comparison direction name, e.g. ``"EQ"``.
      type: comparison type name, e.g. ``"FLOAT"``.

    Returns:
      The constructed ``mhlo.CompareOp``.
    """
    api_version = jax._src.lib.mlir_api_version
    if api_version >= 5:
        return mhlo.CompareOp(
            x, y,
            mhlo.ComparisonDirectionAttr.get(direction),
            mhlo.ComparisonTypeAttr.get(type))

    # Older constructors do not infer the result: spell out a tensor of i1
    # with the same shape as the operand.
    result_type = ir.RankedTensorType.get(
        ir.RankedTensorType(x.type).shape, ir.IntegerType.get_signless(1))
    if api_version < 3:
        direction_attr = ir.StringAttr.get(direction)
        type_attr = ir.StringAttr.get(type)
    else:
        direction_attr = mhlo.ComparisonDirectionAttr.get(direction)
        type_attr = mhlo.ComparisonTypeAttr.get(type)
    return mhlo.CompareOp(result_type, x, y, direction_attr, type_attr)
def _minmax_mhlo(op, cmp, x, y):
    """Min/max that compares complex values lexicographically as pairs.

    Non-complex inputs are delegated straight to ``op(x, y)``. For complex
    element types, ``x`` is chosen where its real part compares ``cmp``
    against ``y``'s. When the real parts are equal, the imaginary parts are
    compared with ``cmp`` instead.

    Args:
      op: builder used for the non-complex case, called as ``op(x, y)``.
      cmp: comparison direction name (passed to ``compare_mhlo``).
      x: first operand (ranked tensor value).
      y: second operand.

    Returns:
      ``op(x, y)`` for non-complex inputs, otherwise an ``mhlo.SelectOp``.
    """
    element_type = ir.RankedTensorType(x.type).element_type
    if not ir.ComplexType.isinstance(element_type):
        return op(x, y)

    x_re = mhlo.RealOp(x).result
    y_re = mhlo.RealOp(y).result
    re_equal = compare_mhlo(x_re, y_re, "EQ", "FLOAT")
    re_wins = compare_mhlo(x_re, y_re, cmp, "FLOAT")
    im_wins = compare_mhlo(
        mhlo.ImagOp(x).result, mhlo.ImagOp(y).result, cmp, "FLOAT")
    # Tie on the real part -> fall back to the imaginary comparison.
    take_x = mhlo.SelectOp(re_equal, im_wins, re_wins).result
    return mhlo.SelectOp(take_x, x, y)
def _minmax_mhlo(op, cmp, x, y):
    """Min/max that compares complex values lexicographically as pairs.

    Non-complex inputs are delegated to ``op(x, y)``. For complex element
    types, ``x`` is selected where its real part compares ``cmp`` against
    ``y``'s. On equal real parts, the imaginary parts decide, again using
    ``cmp``.

    The comparisons go through ``compare_mhlo``. That helper picks the
    ``mhlo.CompareOp`` constructor form matching the installed MLIR API
    version, instead of hard-coding the legacy
    ``CompareOp(result_type, x, y, StringAttr, StringAttr)`` form, which is
    wrong on newer API versions.

    Args:
      op: builder used for the non-complex case, called as ``op(x, y)``.
      cmp: comparison direction name, e.g. ``"GE"``/``"LE"``-style strings.
      x: first operand (ranked tensor value).
      y: second operand.

    Returns:
      ``op(x, y)`` for non-complex inputs, otherwise an ``mhlo.SelectOp``.
    """
    tensor_type = ir.RankedTensorType(x.type)
    if not ir.ComplexType.isinstance(tensor_type.element_type):
        return op(x, y)

    rx = mhlo.RealOp(x).result
    ry = mhlo.RealOp(y).result
    real_eq = compare_mhlo(rx, ry, "EQ", "FLOAT")
    real_cmp = compare_mhlo(rx, ry, cmp, "FLOAT")
    imag_cmp = compare_mhlo(
        mhlo.ImagOp(x).result, mhlo.ImagOp(y).result, cmp, "FLOAT")
    # If the real parts tie, the imaginary comparison decides.
    which = mhlo.SelectOp(real_eq, imag_cmp, real_cmp).result
    return mhlo.SelectOp(which, x, y)
def snd(t):
    """Recover the value packed into the low bits of a single word.

    ``t`` is shifted left by ``r_nbits`` (with the shift amount broadcast to
    ``t``'s shape) and the shifted word is reinterpreted as ``etype`` with
    ``mhlo.BitcastConvertOp``.

    Args:
      t: packed tensor value of word integer type.

    Returns:
      The MLIR value of the bitcast, with element type ``etype``.
    """
    dims = ir.RankedTensorType(t.type).shape
    shift_amount = _broadcast(const(word_dtype, r_nbits), dims)
    realigned = mhlo.ShiftLeftOp(t, shift_amount)
    unpacked = mhlo.BitcastConvertOp(
        ir.RankedTensorType.get(dims, etype), realigned)
    return unpacked.result
def snd(t):
    """Extract the value stored in the lower word of a packed double word.

    ``t`` is narrowed to ``word_type`` with ``mhlo.ConvertOp`` (presumably
    keeping the low bits, per XLA integer-narrowing behaviour — confirm), then
    reinterpreted as ``etype`` with ``mhlo.BitcastConvertOp``.

    Args:
      t: packed tensor value of double-word integer type.

    Returns:
      The MLIR value of the bitcast, with element type ``etype``.
    """
    dims = ir.RankedTensorType(t.type).shape
    low_word = mhlo.ConvertOp(ir.RankedTensorType.get(dims, word_type), t)
    unpacked = mhlo.BitcastConvertOp(
        ir.RankedTensorType.get(dims, etype), low_word)
    return unpacked.result
def _broadcast(x, dims):
    """Broadcast ``x`` by prepending the dimensions ``dims`` to its shape.

    Emits an ``mhlo.BroadcastOp`` whose ``broadcast_sizes`` attribute is
    ``dims``. Under MHLO/XLA ``broadcast`` semantics the new dimensions are
    inserted on the left, so the result shape is ``dims + shape(x)``. For a
    rank-0 ``x`` (e.g. a scalar constant) that is simply ``dims``.

    Args:
      x: MLIR value of ranked tensor type.
      dims: sequence of ints, the leading dimensions to add.

    Returns:
      The ``mhlo.BroadcastOp`` producing the broadcast tensor.
    """
    operand_type = ir.RankedTensorType(x.type)
    # Use the ``RankedTensorType.get`` factory to build a new type from a
    # shape. Calling ``ir.RankedTensorType(...)`` only casts an existing Type
    # and does not accept a (shape, element_type) pair.
    result_shape = list(dims) + list(operand_type.shape)
    result_type = ir.RankedTensorType.get(
        result_shape, operand_type.element_type)
    return mhlo.BroadcastOp(result_type, x, mlir.dense_int_elements(dims))