def pack(a, b):
    """Pack ``a`` and ``b`` into one ``word_type`` value at reduced precision.

    Both inputs are passed through ``mhlo.ReducePrecisionOp`` with ``nexp``
    exponent bits and ``nmant`` mantissa bits, then bitcast to ``word_type``.
    The bits of ``b`` are logically shifted right by ``r_nbits`` and OR-ed
    with the bits of ``a``.

    Args:
      a: value whose bit pattern keeps its position in the packed word.
      b: value whose bit pattern is shifted right by ``r_nbits`` before the
        merge.

    Returns:
      The ``mhlo.OrOp`` that combines the two bit patterns.
    """
    # Shapes are taken from the incoming values.
    shape_a = ir.RankedTensorType(a.type).shape
    shape_b = ir.RankedTensorType(b.type).shape

    def _round(x):
        # Bindings with mlir_api_version >= 21 are called without a result
        # type; older ones receive it as the first positional argument.
        attrs = dict(exponent_bits=mlir.i32_attr(nexp),
                     mantissa_bits=mlir.i32_attr(nmant))
        if jax._src.lib.mlir_api_version < 21:
            return mhlo.ReducePrecisionOp(x.type, x, **attrs)
        return mhlo.ReducePrecisionOp(x, **attrs)

    rounded_a = _round(a)
    rounded_b = _round(b)

    # Reinterpret the rounded values as unsigned words of the same shape.
    bits_a = mhlo.BitcastConvertOp(
        ir.RankedTensorType.get(shape_a, word_type), rounded_a)
    bits_b = mhlo.BitcastConvertOp(
        ir.RankedTensorType.get(shape_b, word_type), rounded_b)

    # Move b's bits down by r_nbits so they can be merged with a's bits.
    # NOTE(review): presumably r_nbits is half the word width -- confirm
    # against the enclosing scope.
    shift = _broadcast(const(word_dtype, r_nbits), shape_b)
    shifted_b = mhlo.ShiftRightLogicalOp(bits_b, shift)
    return mhlo.OrOp(bits_a, shifted_b)
# Example #2
0
 def pack(a, b):
     """Pack ``a`` and ``b`` into a single ``double_word_type`` value.

     Each input is bitcast to ``word_type`` and then converted to
     ``double_word_type``. The widened ``a`` is shifted left by ``nbits`` and
     OR-ed with the widened ``b``.

     Args:
       a: value whose bits are shifted left by ``nbits`` in the result.
       b: value whose bits are merged in without shifting.

     Returns:
       The ``mhlo.OrOp`` that combines the two widened bit patterns.
     """
     shape_a = ir.RankedTensorType(a.type).shape
     shape_b = ir.RankedTensorType(b.type).shape

     def _tensor(shape, element_type):
         return ir.RankedTensorType.get(shape, element_type)

     # Reinterpret the inputs as words, then widen them to double words.
     word_a = mhlo.BitcastConvertOp(_tensor(shape_a, word_type), a)
     word_b = mhlo.BitcastConvertOp(_tensor(shape_b, word_type), b)
     wide_a = mhlo.ConvertOp(_tensor(shape_a, double_word_type), word_a)
     wide_b = mhlo.ConvertOp(_tensor(shape_b, double_word_type), word_b)

     # Shift a by nbits so that b's bits can be OR-ed in below it.
     shift = _broadcast(const(double_word_dtype, nbits), shape_a)
     shifted_a = mhlo.ShiftLeftOp(wide_a, shift)
     return mhlo.OrOp(shifted_a, wide_b)
 def fst(t):
     """Extract the first element from a packed double-word value ``t``.

     ``t`` is logically shifted right by ``nbits``, converted to
     ``word_type``, and bitcast to ``etype``.

     Args:
       t: packed value to unpack.

     Returns:
       The result value of the final ``mhlo.BitcastConvertOp``.
     """
     shape = ir.RankedTensorType(t.type).shape

     # NOTE(review): the shift amount is passed without `_broadcast`, so this
     # presumably expects a scalar `t` -- confirm against the call sites.
     shift = const(double_word_dtype, nbits)
     upper = mhlo.ShiftRightLogicalOp(t, shift)

     # Narrow to a single word, then reinterpret the bits as the element type.
     narrowed = mhlo.ConvertOp(ir.RankedTensorType.get(shape, word_type), upper)
     as_element = mhlo.BitcastConvertOp(
         ir.RankedTensorType.get(shape, etype), narrowed)
     return as_element.result
 def snd(t):
     """Extract the second element from a packed single-word value ``t``.

     ``t`` is shifted left by ``r_nbits`` (the shift amount is broadcast to
     the shape of ``t``) and the result is bitcast to ``etype``.

     Args:
       t: packed value to unpack.

     Returns:
       The result value of the ``mhlo.BitcastConvertOp``.
     """
     shape = ir.RankedTensorType(t.type).shape

     # Shifting left by r_nbits discards the upper r_nbits bits of t.
     shift = _broadcast(const(word_dtype, r_nbits), shape)
     realigned = mhlo.ShiftLeftOp(t, shift)

     element_tensor = ir.RankedTensorType.get(shape, etype)
     return mhlo.BitcastConvertOp(element_tensor, realigned).result
 def fst(t):
     """Extract the first element from a packed single-word value ``t``.

     ``t`` is AND-ed with a mask of ``r_nbits`` one-bits shifted left by
     ``r_nbits``, and the masked value is bitcast to a rank-0 ``etype``
     tensor.

     Args:
       t: packed value to unpack.

     Returns:
       The result value of the ``mhlo.BitcastConvertOp``.
     """
     # r_nbits ones followed by r_nbits zeros.
     mask = ((1 << r_nbits) - 1) << r_nbits
     kept = mhlo.AndOp(t, const(word_dtype, mask))

     # NOTE(review): the result shape is hard-coded to [] and the mask is not
     # broadcast, so this presumably expects a scalar `t` -- confirm against
     # the call sites.
     scalar_element = ir.RankedTensorType.get([], etype)
     return mhlo.BitcastConvertOp(scalar_element, kept).result
 def snd(t):
     """Extract the second element from a packed double-word value ``t``.

     ``t`` is converted to ``word_type`` and the result is bitcast to
     ``etype``, keeping the shape of ``t``.

     Args:
       t: packed value to unpack.

     Returns:
       The result value of the ``mhlo.BitcastConvertOp``.
     """
     shape = ir.RankedTensorType(t.type).shape

     # Convert to the single-word integer type, then reinterpret the bits.
     narrowed = mhlo.ConvertOp(ir.RankedTensorType.get(shape, word_type), t)
     as_element = mhlo.BitcastConvertOp(
         ir.RankedTensorType.get(shape, etype), narrowed)
     return as_element.result