Ejemplo n.º 1
0
def _normalize_float(x):
  """Rescale values below ``finfo.tiny`` and expose the float's raw bits.

  Elements whose magnitude is below ``finfo(x.dtype).tiny`` (zeros and
  subnormals) are multiplied by ``2**nmant`` so their exponent field carries
  information; for those elements an exponent correction of ``-nmant`` is
  reported alongside.

  Args:
    x: floating-point array.

  Returns:
    ``(bits, exp_offset)`` where ``bits`` is the (possibly rescaled) input
    bitcast to ``_INT_DTYPES[finfo.bits]`` (an integer type of the same bit
    width) and ``exp_offset`` is an int32 array holding ``-nmant`` where the
    rescaling was applied and 0 elsewhere.
  """
  finfo = dtypes.finfo(dtypes.dtype(x))
  below_tiny = lax.abs(x) < finfo.tiny
  scaled = _where(below_tiny, x * _lax_const(x, 1 << finfo.nmant), x)
  offset_if_scaled = lax.full_like(x, -finfo.nmant, dtype=np.int32)
  no_offset = lax.full_like(x, 0, dtype=np.int32)
  exp_offset = _where(below_tiny, offset_if_scaled, no_offset)
  bits = lax.bitcast_convert_type(scaled, _INT_DTYPES[finfo.bits])
  return bits, exp_offset
Ejemplo n.º 2
0
def _normalize_float(x):
    """Rescale values below ``finfo.tiny`` and expose the float's raw bits.

    Elements with ``|x| < finfo(x.dtype).tiny`` (zeros and subnormals) are
    multiplied by ``2**nmant``; an exponent correction of ``-nmant`` is
    reported for exactly those elements.

    Args:
      x: floating-point array.

    Returns:
      ``(bits, exp_offset)``: the (possibly rescaled) input bitcast to
      ``_INT_DTYPES[finfo.bits]``, and an array of that same integer type
      holding ``-nmant`` where rescaling happened and 0 elsewhere.
    """
    finfo = dtypes.finfo(dtypes.dtype(x))
    as_int = _INT_DTYPES[finfo.bits]
    needs_scaling = lax.abs(x) < finfo.tiny
    rescaled = _where(needs_scaling, x * _lax_const(x, 1 << finfo.nmant), x)
    offset = _where(needs_scaling, as_int(-finfo.nmant), as_int(0))
    bits = lax.bitcast_convert_type(rescaled, as_int)
    return bits, offset
Ejemplo n.º 3
0
def ldexp(x1, x2):
    """Compute ``x1 * 2**x2`` elementwise by editing the exponent bits of ``x1``.

    Args:
      x1: array-like of real values. Integer inputs are converted to the
        inexact dtype that ``dtypes._to_inexact_dtype`` maps them to.
      x2: array-like of exponents; must not be floating or complex.

    Returns:
      An array of that (canonicalized) inexact dtype. Results too large to
      represent become ``sign(x1) * inf``; results smaller than
      ``2**-(bias + nmant)`` become 0. Where ``x1`` is ``inf``, ``-inf``,
      ``nan`` or 0, ``x1`` is returned unchanged.

    Raises:
      ValueError: if ``x1`` is complex or ``x2`` is inexact.
    """
    _check_arraylike("ldexp", x1, x2)
    x1_dtype = dtypes.dtype(x1)
    x2_dtype = dtypes.dtype(x2)
    if (dtypes.issubdtype(x1_dtype, np.complexfloating)
            or dtypes.issubdtype(x2_dtype, np.inexact)):
        raise ValueError(
            f"ldexp not supported for input types {(x1_dtype, x2_dtype)}")

    x1, x2 = _promote_shapes("ldexp", x1, x2)

    dtype = dtypes.canonicalize_dtype(dtypes._to_inexact_dtype(x1_dtype))
    info = dtypes.finfo(dtype)
    int_type = _INT_DTYPES[info.bits]
    # Track the exponent in the widest enabled integer type rather than in
    # int_type. For narrow floats int_type is narrow too (16 bits for float16),
    # and converting a large exponent into it would wrap around: e.g.
    # ldexp(float16(1), 65536) would see exponent 0 and return 1 instead of
    # inf. The overflow/underflow tests below must see the unwrapped value.
    exp_type = dtypes.canonicalize_dtype(np.int64)

    x1 = lax.convert_element_type(x1, dtype)
    x2 = lax.convert_element_type(x2, exp_type)

    mask = (1 << info.nexp) - 1         # all-ones exponent field
    bias = ((1 << info.nexp) - 1) >> 1  # IEEE exponent bias, 2**(nexp-1) - 1
    # x: integer bit pattern of x1 (tiny values pre-scaled by 2**nmant);
    # e: the exponent correction for that pre-scaling.
    x, e = _normalize_float(x1)
    # x2 becomes the unbiased exponent of the result written as 1.m * 2**x2.
    x2 += e + ((x >> info.nmant) & mask) - bias

    # find underflow/overflow before denormalization
    underflow_cond = x2 < -(bias + info.nmant)
    overflow_cond = x2 > bias

    m = lax.full_like(x, 1, dtype=dtype)

    # denormals: build the result as a normal number 2**(nmant + 1) times too
    # large, then scale it down with m so the final multiply rounds once.
    # Using nmant + 1 (not nmant) keeps the biased exponent >= 1 even at
    # x2 == -(bias + nmant); with nmant it would be 0 there, and the bitcast
    # would read the bits as a subnormal lacking the implicit leading 1
    # (e.g. float32 ldexp(1.25, -150) would give 0 instead of 2**-149).
    cond = x2 < -bias + 1
    x2 = _where(cond, x2 + info.nmant + 1, x2)
    m = _where(cond, m / (1 << (info.nmant + 1)), m)

    # In lanes that are neither underflow nor overflow, x2 now lies in
    # [-bias + 1, bias], so these narrowing conversions are lossless there;
    # the other lanes are overwritten below.
    x2 = lax.convert_element_type(x2, np.int32)
    # Replace the exponent field of x with the biased result exponent.
    x &= ~(mask << info.nmant)
    x |= ((lax.convert_element_type(x2, int_type) + bias) << info.nmant)

    x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype)

    # underflow
    x = _where(underflow_cond, lax.full_like(x, 0, dtype=dtype), x)
    # overflow
    x = _where(overflow_cond, lax.sign(x1) * lax.full_like(x, np.inf), x)
    # ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0
    return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x)
Ejemplo n.º 4
0
def frexp(x):
    """Split ``x`` into a mantissa and a power-of-two exponent.

    Args:
      x: array-like of real values; it is first passed through
        ``_promote_dtypes_inexact``.

    Returns:
      A pair ``(mantissa, exponent)``. For finite nonzero ``x``, ``mantissa``
      keeps the sign and fraction bits of ``x`` with its exponent field set to
      ``bias - 1``, so ``0.5 <= |mantissa| < 1``, and ``exponent`` (int32)
      satisfies ``x == mantissa * 2**exponent``. Where ``x`` is ``inf``,
      ``-inf``, ``nan`` or 0, ``mantissa`` is ``x`` and ``exponent`` is 0.

    Raises:
      TypeError: if the promoted input is complex.
    """
    _check_arraylike("frexp", x)
    x, = _promote_dtypes_inexact(x)
    if dtypes.issubdtype(x.dtype, np.complexfloating):
        raise TypeError("frexp does not support complex-valued inputs")

    dtype = dtypes.dtype(x)
    info = dtypes.finfo(dtype)
    # All-ones mask covering the nexp-bit exponent field.
    mask = (1 << info.nexp) - 1
    # IEEE exponent bias, 2**(nexp - 1) - 1.
    bias = ((1 << info.nexp) - 1) >> 1

    # x1: integer bit pattern of x (values below finfo.tiny are pre-scaled by
    # _normalize_float); x2: the exponent correction for that pre-scaling.
    x1, x2 = _normalize_float(x)
    # Add the unbiased exponent. The extra +1 accounts for placing the
    # mantissa in [0.5, 1) rather than [1, 2).
    x2 += ((x1 >> info.nmant) & mask) - bias + 1
    # Overwrite the exponent field with bias - 1, i.e. a scale of 2**-1.
    x1 &= ~(mask << info.nmant)
    x1 |= (bias - 1) << info.nmant
    x1 = lax.bitcast_convert_type(x1, dtype)

    # Special values pass through unchanged with a zero exponent.
    cond = isinf(x) | isnan(x) | (x == 0)
    x2 = _where(cond, lax_internal._zeros(x2), x2)
    return _where(cond, x, x1), lax.convert_element_type(x2, np.int32)
Ejemplo n.º 5
0
def ldexp(x1, x2):
  """Compute ``x1 * 2**x2`` elementwise by editing the exponent bits of ``x1``.

  Args:
    x1: array-like of values to scale.
    x2: array-like of exponents.

  Returns:
    An array whose dtype is the canonicalized result dtype ``_result_dtype``
    reports for ``np.ldexp`` on these inputs. Results too large to represent
    become ``sign(x1) * inf``; results smaller than ``2**-(bias + nmant)``
    become 0. Where ``x1`` is ``inf``, ``-inf``, ``nan`` or 0, ``x1`` is
    returned unchanged.
  """
  _check_arraylike("ldexp", x1, x2)
  dtype = dtypes.canonicalize_dtype(_result_dtype(np.ldexp, x1, x2))
  x1, x2 = _promote_shapes("ldexp", x1, x2)
  x1 = lax.convert_element_type(x1, dtype)

  info = dtypes.finfo(dtype)
  mask = (1 << info.nexp) - 1         # all-ones exponent field
  bias = ((1 << info.nexp) - 1) >> 1  # IEEE exponent bias, 2**(nexp-1) - 1

  int_type = _INT_DTYPES[info.bits]

  # x: integer bit pattern of x1 (tiny values pre-scaled by 2**nmant);
  # e: the exponent correction for that pre-scaling.
  x, e = _normalize_float(x1)
  # x2 becomes the unbiased exponent of the result written as 1.m * 2**x2.
  x2 += e + ((x >> info.nmant) & mask) - bias

  # find underflow/overflow before denormalization
  underflow_cond = x2 < -(bias + info.nmant)
  overflow_cond = x2 > bias

  m = lax.full_like(x, 1, dtype=dtype)

  # denormals: build the result as a normal number 2**(nmant + 1) times too
  # large, then scale it down with m so the final multiply rounds once.
  # Using nmant + 1 (not nmant) keeps the biased exponent >= 1 even at
  # x2 == -(bias + nmant); with nmant it would be 0 there, and the bitcast
  # would read the bits as a subnormal lacking the implicit leading 1
  # (e.g. float32 ldexp(1.25, -150) would give 0 instead of 2**-149).
  cond = x2 < -bias + 1
  x2 = _where(cond, x2 + info.nmant + 1, x2)
  m = _where(cond, m / (1 << (info.nmant + 1)), m)

  x2 = lax.convert_element_type(x2, np.int32)
  # Replace the exponent field of x with the biased result exponent.
  x &= ~(mask << info.nmant)
  x |= ((lax.convert_element_type(x2, int_type) + bias) << info.nmant)

  x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype)

  # underflow
  x = _where(underflow_cond, lax.full_like(x, 0, dtype=dtype), x)
  # overflow
  x = _where(overflow_cond, lax.sign(x1) * lax.full_like(x, np.inf), x)
  # ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0
  return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x)
Ejemplo n.º 6
0
def eig_abstract_eval(operand, *, compute_left_eigenvectors,
                      compute_right_eigenvectors):
  """Shape/dtype rule for the nonsymmetric eigendecomposition.

  Args:
    operand: abstract input value; only ``ShapedArray`` is handled.
    compute_left_eigenvectors: whether the left eigenvectors are returned.
    compute_right_eigenvectors: whether the right eigenvectors are returned.

  Returns:
    A tuple starting with the eigenvalue aval of shape ``batch + (n,)``,
    followed (if requested) by the left and then the right eigenvector avals
    of shape ``batch + (n, n)``. All outputs share one complex dtype.

  Raises:
    NotImplementedError: if ``operand`` is not a ``ShapedArray``.
    ValueError: if ``operand`` is not shaped ``[..., n, n]``.
  """
  if not isinstance(operand, ShapedArray):
    raise NotImplementedError

  shape = operand.shape
  if operand.ndim < 2 or shape[-2] != shape[-1]:
    raise ValueError("Argument to nonsymmetric eigendecomposition must have "
                     "shape [..., n, n], got shape {}".format(shape))

  batch_dims, n = shape[:-2], shape[-1]
  # 32-bit inputs pair with complex64, all others with complex128; the choice
  # is then passed through dtypes.canonicalize_dtype.
  is_32bit = dtypes.finfo(operand.dtype).bits == 32
  out_dtype = dtypes.canonicalize_dtype(
      np.complex64 if is_32bit else np.complex128)
  eigvecs = operand.update(shape=batch_dims + (n, n), dtype=out_dtype)
  eigvals = operand.update(shape=batch_dims + (n,), dtype=out_dtype)

  requested = (compute_left_eigenvectors, compute_right_eigenvectors)
  return (eigvals,) + tuple(eigvecs for wanted in requested if wanted)
Ejemplo n.º 7
0
def signbit(x):
    """Return a boolean array that is True where the sign of ``x`` is negative.

    Integers report ``x < 0`` and booleans are always False. Floating inputs
    are inspected bitwise, so the raw sign bit decides (``-0.0`` reports True).

    Raises:
      ValueError: for dtypes that are neither integer, bool nor floating.
      NotImplementedError: for floating types whose bit width has no entry
        in ``_INT_DTYPES``.
    """
    x, = _promote_args("signbit", x)
    dtype = dtypes.dtype(x)

    if dtypes.issubdtype(dtype, np.integer):
        return lax.lt(x, _constant_like(x, 0))
    if dtypes.issubdtype(dtype, np.bool_):
        return lax.full_like(x, False, dtype=np.bool_)
    if not dtypes.issubdtype(dtype, np.floating):
        raise ValueError("jax.numpy.signbit is not well defined for %s" %
                         dtype)

    if dtype == dtypes.bfloat16:
        # Workaround: TPUs handle BF16 but not S16, so widen BF16 inputs to
        # F32 before bitcasting them to an integer type.
        x = lax.convert_element_type(x, np.float32)
        dtype = np.float32

    info = dtypes.finfo(dtype)
    if info.bits not in _INT_DTYPES:
        raise NotImplementedError(
            "jax.numpy.signbit only supports 16, 32, and 64-bit types.")

    # Shifting right past the exponent and mantissa fields leaves only the
    # sign bit's contribution: the result is nonzero exactly when it is set.
    sign_shift = info.nexp + info.nmant
    as_int = lax.bitcast_convert_type(x, _INT_DTYPES[info.bits])
    return lax.convert_element_type(as_int >> sign_shift, np.bool_)
Ejemplo n.º 8
0
def _select_and_gather_add_translation(ctx,
                                       avals_in,
                                       avals_out,
                                       tangents,
                                       operand,
                                       *,
                                       select_prim,
                                       window_dimensions,
                                       window_strides,
                                       padding,
                                       base_dilation,
                                       window_dilation,
                                       max_bits=64):
    """XLA translation that gathers ``tangents`` at the reduce-window winner.

    For each window, the ``operand`` element that wins under ``select_prim``
    (``lax.ge_p`` -> max, ``lax.le_p`` -> min) is chosen, and the ``tangents``
    element at the same position is returned. Per the TODO below, tuple
    reductions are not available, so each (operand, tangent) pair is packed
    into one unsigned integer. That integer is reduced with a custom
    comparator that only looks at the operand half.

    Args:
      ctx: translation context; only ``ctx.builder`` is used.
      avals_in: ``(tangents_aval, operand_aval)``; the operand dtype selects
        the packing scheme.
      avals_out: not used by this rule.
      tangents: XLA op whose values are gathered.
      operand: XLA op whose values are compared.
      select_prim: ``lax.ge_p`` or ``lax.le_p`` (asserted).
      window_dimensions, window_strides, padding, base_dilation,
      window_dilation: forwarded to ``xops.ReduceWindowWithGeneralPadding``.
      max_bits: widest unsigned-integer width assumed available for packing.

    Returns:
      A one-element list holding the gathered tangent values.
    """
    c = ctx.builder
    tangents_aval, operand_aval, = avals_in
    dtype = operand_aval.dtype
    etype = xla.dtype_to_primitive_type(dtype)
    nbits = dtypes.finfo(dtype).bits

    assert nbits <= max_bits
    # Lossless packing needs an unsigned type twice as wide as the operand.
    double_word_reduction = nbits * 2 <= max_bits

    # Builds an XLA constant of the given numpy dtype on builder c.
    const = lambda c, dtype, x: xops.Constant(c, np.array(x, dtype=dtype))

    if double_word_reduction:
        # TODO(b/73062247): XLA doesn't yet implement ReduceWindow on tuples, so
        # we implement a pair-wise ReduceWindow by packing two k-bit values into
        # 2k-bit unsigned integer using bit tricks.
        # Layout: operand bits in the high word, tangent bits in the low word.
        word_dtype = lax._UINT_DTYPES[nbits]
        double_word_dtype = lax._UINT_DTYPES[nbits * 2]
        word_type = xla.dtype_to_primitive_type(word_dtype)
        double_word_type = xla.dtype_to_primitive_type(double_word_dtype)

        # Packs two values into a tuple.
        def pack(a, b):
            a = xops.BitcastConvertType(a, word_type)
            b = xops.BitcastConvertType(b, word_type)
            a = xops.ConvertElementType(a, double_word_type)
            b = xops.ConvertElementType(b, double_word_type)
            # Move a into the high word; b stays in the low word.
            a = xops.ShiftLeft(a, const(c, double_word_dtype, nbits))
            return xops.Or(a, b)

        # Unpacks the first element of a tuple.
        def fst(c, t):
            st = xops.ShiftRightLogical(t, const(c, double_word_dtype, nbits))
            return xops.BitcastConvertType(
                xops.ConvertElementType(st, word_type), etype)

        # Unpacks the second element of a tuple.
        def snd(t):
            # Converting to the narrow word type keeps only the low word.
            return xops.BitcastConvertType(
                xops.ConvertElementType(t, word_type), etype)

    else:
        # The double-word trick above only works if we have a sufficiently large
        # type. As an alternative, we can pack two half words into a single word,
        # at the cost of precision.
        # TODO(b/73062247): add support for tuple reductions and remove this case.
        warnings.warn(
            "Using reduced precision for gradient of reduce-window "
            "min/max operator to work around missing XLA support for "
            "pair-reductions. This is likely from a second or "
            "higher derivative of a max-pooling operation.")
        r_nbits = nbits // 2
        # Drop/round the bottom mantissa bits.
        nexp = dtypes.finfo(dtype).nexp
        # Mantissa width left once a value must fit in r_nbits bits
        # (1 sign bit + nexp exponent bits + nmant mantissa bits).
        nmant = r_nbits - nexp - 1

        double_word_dtype = word_dtype = lax._UINT_DTYPES[nbits]
        word_type = xla.dtype_to_primitive_type(word_dtype)

        # Packs two values into a tuple.
        # Layout: reduced-precision a keeps its position (high half);
        # reduced-precision b is shifted right into the low half.
        def pack(a, b):
            a = xops.ReducePrecision(a,
                                     exponent_bits=nexp,
                                     mantissa_bits=nmant)
            b = xops.ReducePrecision(b,
                                     exponent_bits=nexp,
                                     mantissa_bits=nmant)
            a = xops.BitcastConvertType(a, word_type)
            b = xops.BitcastConvertType(b, word_type)
            b = xops.ShiftRightLogical(b, const(c, word_dtype, r_nbits))
            return xops.Or(a, b)

        # Unpacks the first element of a tuple.
        def fst(c, t):
            # Keep only the high half; the low half is zeroed.
            st = xops.And(
                t, const(c, word_dtype, ((1 << r_nbits) - 1) << r_nbits))
            return xops.BitcastConvertType(st, etype)

        # Unpacks the second element of a tuple.
        def snd(t):
            # Shift the low half back up into the high half.
            return xops.BitcastConvertType(
                xops.ShiftLeft(t, const(c, word_dtype, r_nbits)), etype)

    # Scalar comparator for ReduceWindow: keep whichever packed value has the
    # winning first (operand) element.
    def reducer():
        c = xc.XlaBuilder("select_and_gather_pair_reducer")
        x = xla.parameter(
            c, 0, xla_client.Shape.array_shape(np.dtype(double_word_dtype),
                                               ()))
        y = xla.parameter(
            c, 1, xla_client.Shape.array_shape(np.dtype(double_word_dtype),
                                               ()))
        assert select_prim is lax.ge_p or select_prim is lax.le_p
        which = xops.Ge if select_prim is lax.ge_p else xops.Le
        xops.Select(which(fst(c, x), fst(c, y)), x, y)
        return c.build()

    assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim
    # Identity for the reduction: -inf for max, +inf for min, paired with a
    # zero tangent.
    init = -np.inf if select_prim is lax.ge_p else np.inf
    out = xops.ReduceWindowWithGeneralPadding(
        pack(operand, tangents), pack(const(c, dtype,
                                            init), const(c, dtype, 0)),
        reducer(), window_dimensions, window_strides, base_dilation,
        window_dilation, padding)
    return [snd(out)]
Ejemplo n.º 9
0
def num_float_bits(dtype):
    """Return the bit width of ``dtype`` after JAX dtype canonicalization.

    ``finfo`` is used, so ``dtype`` is expected to canonicalize to a
    floating-point type.
    """
    canonical = _dtypes.canonicalize_dtype(dtype)
    float_info = _dtypes.finfo(canonical)
    return float_info.bits
Ejemplo n.º 10
0
def _select_and_gather_add_lowering(ctx,
                                    tangents,
                                    operand,
                                    *,
                                    select_prim,
                                    window_dimensions,
                                    window_strides,
                                    padding,
                                    base_dilation,
                                    window_dilation,
                                    max_bits=64):
    """MHLO lowering that gathers ``tangents`` at the reduce-window winner.

    For each window, the ``operand`` element that wins under ``select_prim``
    (``lax.ge_p`` -> max, ``lax.le_p`` -> min) is chosen, and the ``tangents``
    element at the same position is returned. Per the TODO below, tuple
    reductions are not available, so each (operand, tangent) pair is packed
    into one unsigned integer. That integer is reduced by an ``mhlo``
    ``ReduceWindowOp`` whose body compares only the operand half.

    Args:
      ctx: lowering context; ``ctx.avals_in`` supplies the operand dtype and
        ``ctx.avals_out`` the output shape.
      tangents: MLIR value whose elements are gathered.
      operand: MLIR value whose elements are compared.
      select_prim: ``lax.ge_p`` or ``lax.le_p`` (asserted).
      window_dimensions, window_strides, padding, base_dilation,
      window_dilation: attributes of the generated ``ReduceWindowOp``.
      max_bits: widest unsigned-integer width assumed available for packing.

    Returns:
      A one-element list holding the gathered tangent values.
    """
    _, operand_aval, = ctx.avals_in
    out_aval, = ctx.avals_out
    dtype = operand_aval.dtype
    etype = mlir.dtype_to_ir_type(dtype)
    nbits = dtypes.finfo(dtype).bits

    assert nbits <= max_bits
    # Lossless packing needs an unsigned type twice as wide as the operand.
    double_word_reduction = nbits * 2 <= max_bits

    # Builds an MLIR constant of the given numpy dtype, with
    # canonicalize_types=False.
    const = lambda dtype, x: mlir.ir_constant(np.array(x, dtype=dtype),
                                              canonicalize_types=False)

    # The mhlo.BroadcastOp builder signature depends on the MLIR API version.
    if jax._src.lib.mlir_api_version >= 9:

        def _broadcast(x, dims):
            return mhlo.BroadcastOp(x, mlir.dense_int_elements(dims))
    else:

        def _broadcast(x, dims):
            etype = ir.RankedTensorType(x.type).element_type
            return mhlo.BroadcastOp(ir.RankedTensorType(dims, etype), x,
                                    mlir.dense_int_elements(dims))

    if double_word_reduction:
        # TODO(b/73062247): XLA doesn't yet implement ReduceWindow on tuples, so
        # we implement a pair-wise ReduceWindow by packing two k-bit values into
        # 2k-bit unsigned integer using bit tricks.
        # Layout: operand bits in the high word, tangent bits in the low word.
        word_dtype = lax._UINT_DTYPES[nbits]
        double_word_dtype = lax._UINT_DTYPES[nbits * 2]
        word_type = mlir.dtype_to_ir_type(word_dtype)
        double_word_type = mlir.dtype_to_ir_type(double_word_dtype)

        # Packs two values into a tuple.
        def pack(a, b):
            a_dims = ir.RankedTensorType(a.type).shape
            b_dims = ir.RankedTensorType(b.type).shape
            a = mhlo.BitcastConvertOp(
                ir.RankedTensorType.get(a_dims, word_type), a)
            b = mhlo.BitcastConvertOp(
                ir.RankedTensorType.get(b_dims, word_type), b)
            a = mhlo.ConvertOp(
                ir.RankedTensorType.get(a_dims, double_word_type), a)
            b = mhlo.ConvertOp(
                ir.RankedTensorType.get(b_dims, double_word_type), b)
            # Move a into the high word; b stays in the low word.
            a = mhlo.ShiftLeftOp(
                a, _broadcast(const(double_word_dtype, nbits), a_dims))
            return mhlo.OrOp(a, b)

        # Unpacks the first element of a tuple.
        # The shift constant is not broadcast: fst is only applied to the
        # scalar reducer-block arguments below.
        def fst(t):
            dims = ir.RankedTensorType(t.type).shape
            st = mhlo.ShiftRightLogicalOp(t, const(double_word_dtype, nbits))
            return mhlo.BitcastConvertOp(
                ir.RankedTensorType.get(dims, etype),
                mhlo.ConvertOp(ir.RankedTensorType.get(dims, word_type),
                               st)).result

        # Unpacks the second element of a tuple.
        def snd(t):
            dims = ir.RankedTensorType(t.type).shape
            # Converting to the narrow word type keeps only the low word.
            return mhlo.BitcastConvertOp(
                ir.RankedTensorType.get(dims, etype),
                mhlo.ConvertOp(ir.RankedTensorType.get(dims, word_type),
                               t)).result

    else:
        # The double-word trick above only works if we have a sufficiently large
        # type. As an alternative, we can pack two half words into a single word,
        # at the cost of precision.
        # TODO(b/73062247): add support for tuple reductions and remove this case.
        warnings.warn(
            "Using reduced precision for gradient of reduce-window "
            "min/max operator to work around missing XLA support for "
            "pair-reductions. This is likely from a second or "
            "higher derivative of a max-pooling operation.")
        r_nbits = nbits // 2
        # Drop/round the bottom mantissa bits.
        nexp = dtypes.finfo(dtype).nexp
        # Mantissa width left once a value must fit in r_nbits bits
        # (1 sign bit + nexp exponent bits + nmant mantissa bits).
        nmant = r_nbits - nexp - 1

        double_word_dtype = word_dtype = lax._UINT_DTYPES[nbits]
        double_word_type = word_type = mlir.dtype_to_ir_type(word_dtype)

        # Packs two values into a tuple.
        # Layout: reduced-precision a keeps its position (high half);
        # reduced-precision b is shifted right into the low half.
        def pack(a, b):
            a_dims = ir.RankedTensorType(a.type).shape
            b_dims = ir.RankedTensorType(b.type).shape
            # The ReducePrecisionOp builder signature depends on the MLIR API
            # version.
            if jax._src.lib.mlir_api_version >= 21:
                a = mhlo.ReducePrecisionOp(a,
                                           exponent_bits=mlir.i32_attr(nexp),
                                           mantissa_bits=mlir.i32_attr(nmant))
                b = mhlo.ReducePrecisionOp(b,
                                           exponent_bits=mlir.i32_attr(nexp),
                                           mantissa_bits=mlir.i32_attr(nmant))
            else:
                a = mhlo.ReducePrecisionOp(a.type,
                                           a,
                                           exponent_bits=mlir.i32_attr(nexp),
                                           mantissa_bits=mlir.i32_attr(nmant))
                b = mhlo.ReducePrecisionOp(b.type,
                                           b,
                                           exponent_bits=mlir.i32_attr(nexp),
                                           mantissa_bits=mlir.i32_attr(nmant))
            a = mhlo.BitcastConvertOp(
                ir.RankedTensorType.get(a_dims, word_type), a)
            b = mhlo.BitcastConvertOp(
                ir.RankedTensorType.get(b_dims, word_type), b)
            b = mhlo.ShiftRightLogicalOp(
                b, _broadcast(const(word_dtype, r_nbits), b_dims))
            return mhlo.OrOp(a, b)

        # Unpacks the first element of a tuple.
        # Keeps only the high half. The scalar result type and unbroadcast
        # mask rely on fst only being applied to the scalar reducer-block
        # arguments below.
        def fst(t):
            st = mhlo.AndOp(t,
                            const(word_dtype, ((1 << r_nbits) - 1) << r_nbits))
            return mhlo.BitcastConvertOp(ir.RankedTensorType.get([], etype),
                                         st).result

        # Unpacks the second element of a tuple.
        def snd(t):
            dims = ir.RankedTensorType(t.type).shape
            # Shift the low half back up into the high half.
            return mhlo.BitcastConvertOp(
                ir.RankedTensorType.get(dims, etype),
                mhlo.ShiftLeftOp(t, _broadcast(const(word_dtype, r_nbits),
                                               dims))).result

    assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim
    # Identity for the reduction: -inf for max, +inf for min, paired with a
    # zero tangent.
    init = -np.inf if select_prim is lax.ge_p else np.inf
    rw = mhlo.ReduceWindowOp(
        [ir.RankedTensorType.get(out_aval.shape, double_word_type)],
        pack(operand, tangents),
        pack(const(dtype, init), const(dtype, 0)),
        mlir.dense_int_elements(window_dimensions),
        window_strides=mlir.dense_int_elements(window_strides),
        base_dilations=mlir.dense_int_elements(base_dilation),
        window_dilations=mlir.dense_int_elements(window_dilation),
        padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
                                            shape=(len(padding), 2)))
    # Reducer body: takes two scalar packed values and keeps the one whose
    # first (operand) element wins the comparison.
    scalar_type = ir.RankedTensorType.get([], double_word_type)
    reducer = rw.regions[0].blocks.append(scalar_type, scalar_type)
    with ir.InsertionPoint(reducer):
        x, y = reducer.arguments
        assert select_prim is lax.ge_p or select_prim is lax.le_p
        which = "GE" if select_prim is lax.ge_p else "LE"
        out = mhlo.SelectOp(mlir.compare_mhlo(fst(x), fst(y), which), x, y)
        mhlo.ReturnOp(out)
    return [snd(rw.result)]