def pack(a, b):
  """Pack two floating-point tensors into a single word after rounding.

  Both inputs are rounded with ReducePrecisionOp to `nexp` exponent bits and
  `nmant` mantissa bits, then bitcast to `word_type`. The second operand is
  shifted right (logically) by `r_nbits` and OR-ed into the first. Presumably
  the rounding clears the low bits of `a` so the two halves do not collide;
  confirm against how `nexp`/`nmant`/`r_nbits` are chosen by the caller.
  """
  # ReducePrecisionOp's Python builder changed signature at MLIR API v21:
  # newer versions infer the result type, older ones take it explicitly.
  infers_result_type = jax._src.lib.mlir_api_version >= 21

  def rounded(x):
    attrs = dict(exponent_bits=mlir.i32_attr(nexp),
                 mantissa_bits=mlir.i32_attr(nmant))
    if infers_result_type:
      return mhlo.ReducePrecisionOp(x, **attrs)
    return mhlo.ReducePrecisionOp(x.type, x, **attrs)

  shape_a = ir.RankedTensorType(a.type).shape
  shape_b = ir.RankedTensorType(b.type).shape

  high = rounded(a)
  low = rounded(b)
  high = mhlo.BitcastConvertOp(ir.RankedTensorType.get(shape_a, word_type), high)
  low = mhlo.BitcastConvertOp(ir.RankedTensorType.get(shape_b, word_type), low)
  shift_amount = _broadcast(const(word_dtype, r_nbits), shape_b)
  low = mhlo.ShiftRightLogicalOp(low, shift_amount)
  return mhlo.OrOp(high, low)
def pack(a, b):
  """Pack two word-sized tensors into one double-word tensor.

  Each input is bitcast to `word_type` and then converted to
  `double_word_type`. The first is shifted left by `nbits` and OR-ed with the
  second, so `a` sits above bit `nbits` and `b` below it. This assumes
  `word_type` is unsigned so the widening zero-extends, which is not visible
  here; confirm.
  """
  operands = (a, b)
  shapes = [ir.RankedTensorType(v.type).shape for v in operands]

  # Same op order as before: both bitcasts, then both widenings.
  as_words = [
      mhlo.BitcastConvertOp(ir.RankedTensorType.get(s, word_type), v)
      for v, s in zip(operands, shapes)
  ]
  widened = [
      mhlo.ConvertOp(ir.RankedTensorType.get(s, double_word_type), w)
      for w, s in zip(as_words, shapes)
  ]

  first_wide, second_wide = widened
  shift_amount = _broadcast(const(double_word_dtype, nbits), shapes[0])
  first_high = mhlo.ShiftLeftOp(first_wide, shift_amount)
  return mhlo.OrOp(first_high, second_wide)
def fst(t):
  """Unpack the first element of a double-word packed tensor.

  Shifts `t` right (logically) by `nbits`, which undoes the left shift that
  the double-word `pack` applies to its first argument. The result is then
  narrowed to `word_type` and bitcast to `etype`.

  The shift amount is broadcast to `t`'s shape. MHLO elementwise binary ops
  require operands of identical shape, so the previous bare scalar constant
  only verified when `t` was rank-0. Broadcasting leaves the scalar case
  unchanged (an empty broadcast) and makes non-scalar `t` legal. This matches
  how `pack` builds its shift operand.
  """
  dims = ir.RankedTensorType(t.type).shape
  shift_amount = _broadcast(const(double_word_dtype, nbits), dims)
  st = mhlo.ShiftRightLogicalOp(t, shift_amount)
  narrowed = mhlo.ConvertOp(ir.RankedTensorType.get(dims, word_type), st)
  return mhlo.BitcastConvertOp(
      ir.RankedTensorType.get(dims, etype), narrowed).result
def snd(t):
  """Unpack the second element of a reduced-precision packed word.

  Shifts `t` left by `r_nbits`, which moves back the bits that the
  reduced-precision `pack` shifted right by the same amount. The shifted
  word is then reinterpreted as `etype`.
  """
  shape = ir.RankedTensorType(t.type).shape
  shift_amount = _broadcast(const(word_dtype, r_nbits), shape)
  restored = mhlo.ShiftLeftOp(t, shift_amount)
  result_type = ir.RankedTensorType.get(shape, etype)
  return mhlo.BitcastConvertOp(result_type, restored).result
def fst(t):
  """Unpack the first element of a reduced-precision packed word.

  ANDs `t` with a mask of `r_nbits` one-bits shifted left by `r_nbits`,
  i.e. bits [r_nbits, 2*r_nbits). That is where the reduced-precision `pack`
  leaves its first argument, assuming the word is exactly 2*r_nbits wide;
  confirm. The masked word is then bitcast to `etype`.

  The result shape and mask are derived from `t` instead of being hard-coded
  to a scalar. Previously the result type was `tensor<etype>` (rank 0) and
  the mask was an unbroadcast scalar, so any non-scalar `t` produced invalid
  IR. Sibling `snd` already takes its shape from `t`. For scalar `t` the
  behaviour is unchanged.
  """
  dims = ir.RankedTensorType(t.type).shape
  mask = _broadcast(
      const(word_dtype, ((1 << r_nbits) - 1) << r_nbits), dims)
  st = mhlo.AndOp(t, mask)
  return mhlo.BitcastConvertOp(ir.RankedTensorType.get(dims, etype), st).result
def snd(t):
  """Unpack the second element of a double-word packed tensor.

  Converts `t` down to `word_type` and reinterprets that word as `etype`.
  This relies on the integer ConvertOp narrowing to keep the low-order bits,
  which is where the double-word `pack` OR-ed its second argument (XLA
  integer-convert semantics; confirm).
  """
  shape = ir.RankedTensorType(t.type).shape
  low_word = mhlo.ConvertOp(ir.RankedTensorType.get(shape, word_type), t)
  result_type = ir.RankedTensorType.get(shape, etype)
  return mhlo.BitcastConvertOp(result_type, low_word).result