Exemplo n.º 1
0
 def pack(a, b):
     """Pack two values into a single unsigned word at reduced precision.

     Both operands are first rounded with ``mhlo.ReducePrecisionOp`` to
     ``nexp`` exponent bits and ``nmant`` mantissa bits. They are then
     bitcast to ``word_type``. ``b`` is logically shifted right by
     ``r_nbits`` (the shift amount is broadcast to ``b``'s shape) and OR-ed
     with ``a``.

     NOTE(review): this presumably relies on the reduced precision leaving
     the low ``r_nbits`` bits of ``a`` clear so the two halves don't
     overlap; confirm against the enclosing function.

     Returns the ``mhlo.OrOp`` producing the packed word.
     """
     a_dims = ir.RankedTensorType(a.type).shape
     b_dims = ir.RankedTensorType(b.type).shape
     # API >= 21 infers the result type; older bindings take it explicitly
     # as the first positional argument.
     infers_result_type = jax._src.lib.mlir_api_version >= 21

     def round_operand(value):
         precision = dict(exponent_bits=mlir.i32_attr(nexp),
                          mantissa_bits=mlir.i32_attr(nmant))
         if infers_result_type:
             return mhlo.ReducePrecisionOp(value, **precision)
         return mhlo.ReducePrecisionOp(value.type, value, **precision)

     rounded_a = round_operand(a)
     rounded_b = round_operand(b)
     word_a = mhlo.BitcastConvertOp(
         ir.RankedTensorType.get(a_dims, word_type), rounded_a)
     word_b = mhlo.BitcastConvertOp(
         ir.RankedTensorType.get(b_dims, word_type), rounded_b)
     shift_amount = _broadcast(const(word_dtype, r_nbits), b_dims)
     shifted_b = mhlo.ShiftRightLogicalOp(word_b, shift_amount)
     return mhlo.OrOp(word_a, shifted_b)
Exemplo n.º 2
0
 def pack(a, b):
   """Pack ``a`` and ``b`` into one double-width unsigned integer value.

   Each operand is bitcast to ``word_type`` and then converted to
   ``double_word_type``. ``a`` is shifted left by ``nbits`` (the shift amount
   is broadcast to ``a``'s shape) and OR-ed with ``b``. Assuming
   ``double_word_type`` is twice as wide as ``word_type`` and both are
   unsigned (not visible here; confirm), ``a`` occupies the high half and
   ``b`` the low half.

   Returns the ``mhlo.OrOp`` producing the packed value.
   """
   a_shape = ir.RankedTensorType(a.type).shape
   b_shape = ir.RankedTensorType(b.type).shape
   a_bits = mhlo.BitcastConvertOp(ir.RankedTensorType.get(a_shape, word_type),
                                  a)
   b_bits = mhlo.BitcastConvertOp(ir.RankedTensorType.get(b_shape, word_type),
                                  b)
   a_wide = mhlo.ConvertOp(
       ir.RankedTensorType.get(a_shape, double_word_type), a_bits)
   b_wide = mhlo.ConvertOp(
       ir.RankedTensorType.get(b_shape, double_word_type), b_bits)
   offset = _broadcast(const(double_word_dtype, nbits), a_shape)
   a_high = mhlo.ShiftLeftOp(a_wide, offset)
   return mhlo.OrOp(a_high, b_wide)
Exemplo n.º 3
0
 def fst(t):
     """Unpack the first (high-half) element of a packed double word.

     ``t`` is logically shifted right by ``nbits``, converted down to
     ``word_type``, and bitcast to ``etype``.

     Args:
       t: the packed double-word MLIR value.

     Returns:
       The MLIR value holding the unpacked element.
     """
     dims = ir.RankedTensorType(t.type).shape
     # The elementwise shift needs both operands to have the same shape. The
     # constant is broadcast to ``t``'s shape, in the same way the sibling
     # pack/snd helpers broadcast their shift amounts, so this also
     # type-checks when ``t`` is not rank-0.
     shift_amount = _broadcast(const(double_word_dtype, nbits), dims)
     high_half = mhlo.ShiftRightLogicalOp(t, shift_amount)
     # Narrowing to word_type keeps the low bits, which after the shift are
     # the former high half (assuming double_word is 2 * nbits wide).
     narrowed = mhlo.ConvertOp(ir.RankedTensorType.get(dims, word_type),
                               high_half)
     return mhlo.BitcastConvertOp(
         ir.RankedTensorType.get(dims, etype), narrowed).result
Exemplo n.º 4
0
def compare_mhlo(x, y, direction, type):
    """Build an ``mhlo.CompareOp`` using the installed jaxlib's MLIR API.

    Args:
      x, y: the MLIR values to compare.
      direction: comparison direction name (e.g. ``"EQ"``).
      type: comparison type name (e.g. ``"FLOAT"``).

    For ``mlir_api_version >= 5`` the op infers its result type. Older
    versions need an explicit boolean (``i1``) tensor type with ``x``'s
    shape. Versions below 3 additionally take plain ``StringAttr`` values
    instead of the dedicated comparison attributes.
    """
    api_version = jax._src.lib.mlir_api_version
    if api_version >= 5:
        return mhlo.CompareOp(x, y,
                              mhlo.ComparisonDirectionAttr.get(direction),
                              mhlo.ComparisonTypeAttr.get(type))

    result_type = ir.RankedTensorType.get(ir.RankedTensorType(x.type).shape,
                                          ir.IntegerType.get_signless(1))
    if api_version < 3:
        return mhlo.CompareOp(result_type, x, y,
                              ir.StringAttr.get(direction),
                              ir.StringAttr.get(type))
    return mhlo.CompareOp(result_type, x, y,
                          mhlo.ComparisonDirectionAttr.get(direction),
                          mhlo.ComparisonTypeAttr.get(type))
Exemplo n.º 5
0
def _minmax_mhlo(op, cmp, x, y):
    """Min/max that compares complex values lexicographically as pairs.

    Non-complex inputs are delegated directly to ``op(x, y)``. For complex
    element types, the real parts are compared with ``cmp``. When the real
    parts are equal, the imaginary-part comparison (also with ``cmp``)
    decides instead. The resulting predicate selects ``x`` (true) or ``y``
    (false).
    """
    element_type = ir.RankedTensorType(x.type).element_type
    if not ir.ComplexType.isinstance(element_type):
        return op(x, y)

    x_re = mhlo.RealOp(x).result
    y_re = mhlo.RealOp(y).result
    re_equal = compare_mhlo(x_re, y_re, "EQ", "FLOAT")
    re_order = compare_mhlo(x_re, y_re, cmp, "FLOAT")
    x_im = mhlo.ImagOp(x).result
    y_im = mhlo.ImagOp(y).result
    im_order = compare_mhlo(x_im, y_im, cmp, "FLOAT")
    # Tie-break on the imaginary part only where the real parts match.
    take_x = mhlo.SelectOp(re_equal, im_order, re_order).result
    return mhlo.SelectOp(take_x, x, y)
Exemplo n.º 6
0
Arquivo: mlir.py Projeto: GJBoth/jax
def _minmax_mhlo(op, cmp, x, y):
  """Min/max that compares complex values lexicographically as pairs.

  Non-complex inputs go straight to ``op(x, y)``. For complex inputs, the
  real parts are compared with ``cmp``, and where they are equal the
  imaginary-part comparison is used instead. The final predicate selects
  ``x`` (true) or ``y`` (false). Comparisons are emitted as
  ``mhlo.CompareOp`` with an explicit ``i1`` result type and ``StringAttr``
  direction/type arguments.
  """
  tensor_type = ir.RankedTensorType(x.type)
  if not ir.ComplexType.isinstance(tensor_type.element_type):
    return op(x, y)

  re_x = mhlo.RealOp(x).result
  re_y = mhlo.RealOp(y).result
  dims = [tensor_type.get_dim_size(i) for i in range(tensor_type.rank)]
  bool_shape = ir.RankedTensorType.get(dims, ir.IntegerType.get_signless(1))

  def float_compare(lhs, rhs, direction):
    # Floating-point comparison producing a boolean tensor of x's shape.
    return mhlo.CompareOp(bool_shape, lhs, rhs, ir.StringAttr.get(direction),
                          ir.StringAttr.get("FLOAT"))

  real_equal = float_compare(re_x, re_y, "EQ")
  real_order = float_compare(re_x, re_y, cmp)
  imag_order = float_compare(mhlo.ImagOp(x).result, mhlo.ImagOp(y).result,
                             cmp)
  take_x = mhlo.SelectOp(real_equal, imag_order, real_order).result
  return mhlo.SelectOp(take_x, x, y)
Exemplo n.º 7
0
 def snd(t):
     """Unpack the second element from a reduced-precision packed word.

     ``t`` is shifted left by ``r_nbits`` (the shift amount is broadcast to
     ``t``'s shape). The result is bitcast to ``etype`` and the underlying
     MLIR value is returned.
     """
     shape = ir.RankedTensorType(t.type).shape
     out_type = ir.RankedTensorType.get(shape, etype)
     amount = _broadcast(const(word_dtype, r_nbits), shape)
     restored = mhlo.ShiftLeftOp(t, amount)
     return mhlo.BitcastConvertOp(out_type, restored).result
Exemplo n.º 8
0
 def snd(t):
     """Unpack the second element from a packed double word.

     ``t`` is converted down to ``word_type`` and then bitcast to ``etype``.
     The underlying MLIR value is returned.

     NOTE(review): this presumably relies on the narrowing conversion keeping
     the low half, where the packed second element lives; confirm.
     """
     shape = ir.RankedTensorType(t.type).shape
     result_type = ir.RankedTensorType.get(shape, etype)
     low_word = mhlo.ConvertOp(ir.RankedTensorType.get(shape, word_type), t)
     return mhlo.BitcastConvertOp(result_type, low_word).result
Exemplo n.º 9
0
 def _broadcast(x, dims):
     """Broadcast ``x`` with ``mhlo.BroadcastOp`` to a tensor of shape ``dims``.

     The result keeps ``x``'s element type, and ``dims`` is passed as the
     op's broadcast sizes. The declared result type is ``dims`` x element
     type. Under XLA Broadcast semantics that is only consistent when ``x``
     is rank-0, as the ``const(...)`` values the callers pass presumably are
     (confirm).

     Returns the ``mhlo.BroadcastOp``.
     """
     etype = ir.RankedTensorType(x.type).element_type
     # ``ir.RankedTensorType(...)`` only casts an existing Type. Building a new
     # tensor type from a shape and element type requires the ``.get``
     # factory; the original ``ir.RankedTensorType(dims, etype)`` call fails.
     result_type = ir.RankedTensorType.get(dims, etype)
     return mhlo.BroadcastOp(result_type, x, mlir.dense_int_elements(dims))