Exemple #1
0
def _minmax_mhlo(op, cmp, x, y):
    """Emit an MHLO min/max of ``x`` and ``y``, ordering complex values as pairs.

    Complex elements are compared lexicographically on (real, imaginary): the
    real parts decide unless they are equal, in which case the imaginary
    parts decide. Every other element type is handed straight to ``op``.

    Args:
        op: Callable invoked as ``op(x, y)`` for non-complex element types
            (presumably the MHLO min or max op constructor; confirm at the
            call sites).
        cmp: Comparison direction forwarded to ``compare_mhlo`` (presumably
            an "LT"/"GT"-style string; confirm at the call sites).
        x: First operand; ``x.type`` is cast to ``ir.RankedTensorType``.
        y: Second operand.

    Returns:
        ``op(x, y)`` for non-complex inputs. For complex inputs, the
        ``mhlo.SelectOp`` (the op object, not its ``.result``) that picks
        ``x`` where the comparison holds and ``y`` elsewhere.
    """
    element_type = ir.RankedTensorType(x.type).element_type
    if not ir.ComplexType.isinstance(element_type):
        return op(x, y)

    # Real parts: they decide the ordering whenever they differ.
    real_x = mhlo.RealOp(x).result
    real_y = mhlo.RealOp(y).result
    reals_equal = compare_mhlo(real_x, real_y, "EQ", "FLOAT")
    by_real = compare_mhlo(real_x, real_y, cmp, "FLOAT")

    # Imaginary parts: the tie-breaker used only where the real parts match.
    imag_x = mhlo.ImagOp(x).result
    imag_y = mhlo.ImagOp(y).result
    by_imag = compare_mhlo(imag_x, imag_y, cmp, "FLOAT")

    take_x = mhlo.SelectOp(reals_equal, by_imag, by_real).result
    return mhlo.SelectOp(take_x, x, y)
Exemple #2
0
def _minmax_mhlo(op, cmp, x, y):
  """Emit an MHLO min/max of ``x`` and ``y``, ordering complex values as pairs.

  Complex elements are compared lexicographically on (real, imaginary): the
  real parts decide unless they are equal, in which case the imaginary parts
  decide. Every other element type is handed straight to ``op``.

  Args:
    op: Callable invoked as ``op(x, y)`` for non-complex element types
      (presumably the MHLO min or max op constructor; confirm at the call
      sites).
    cmp: Comparison direction, wrapped in an ``ir.StringAttr`` for
      ``mhlo.CompareOp`` (presumably an "LT"/"GT"-style string; confirm at
      the call sites).
    x: First operand; ``x.type`` is cast to ``ir.RankedTensorType``.
    y: Second operand.

  Returns:
    ``op(x, y)`` for non-complex inputs. For complex inputs, the
    ``mhlo.SelectOp`` (the op object, not its ``.result``) that picks ``x``
    where the comparison holds and ``y`` elsewhere.
  """
  operand_type = ir.RankedTensorType(x.type)
  if not ir.ComplexType.isinstance(operand_type.element_type):
    return op(x, y)

  real_x = mhlo.RealOp(x).result
  real_y = mhlo.RealOp(y).result

  # Comparison results are 1-bit signless integers with the operand's dims.
  shape = [operand_type.get_dim_size(axis) for axis in range(operand_type.rank)]
  predicate_type = ir.RankedTensorType.get(
      shape, ir.IntegerType.get_signless(1))

  def float_compare(lhs, rhs, direction):
    # One "FLOAT"-typed comparison in the requested direction.
    return mhlo.CompareOp(predicate_type, lhs, rhs,
                          ir.StringAttr.get(direction),
                          ir.StringAttr.get("FLOAT"))

  reals_equal = float_compare(real_x, real_y, "EQ")
  by_real = float_compare(real_x, real_y, cmp)

  # Imaginary parts: the tie-breaker used only where the real parts match.
  imag_x = mhlo.ImagOp(x).result
  imag_y = mhlo.ImagOp(y).result
  by_imag = float_compare(imag_x, imag_y, cmp)

  take_x = mhlo.SelectOp(reals_equal, by_imag, by_real).result
  return mhlo.SelectOp(take_x, x, y)