def _normalize_float(x):
  """Rescale tiny inputs and expose their raw bit pattern.

  Elements whose magnitude is below ``finfo.tiny`` are multiplied by
  ``2**nmant`` so that their exponent can be read from the exponent bit
  field; every other element passes through unchanged.

  Args:
    x: floating-point array.

  Returns:
    A pair ``(bits, shift)``. ``bits`` is the (possibly rescaled) input bitcast
    to the same-width integer type looked up in ``_INT_DTYPES``. ``shift`` is
    an int32 array shaped like ``x`` holding ``-nmant`` where the rescaling was
    applied and ``0`` elsewhere, i.e. the correction a caller must add to the
    exponent it reads from ``bits``.
  """
  finfo = dtypes.finfo(dtypes.dtype(x))
  nmant = finfo.nmant
  # Note: this is also true for exact zeros, not only for subnormals.
  below_tiny = lax.abs(x) < finfo.tiny
  rescaled = _where(below_tiny, x * _lax_const(x, 1 << nmant), x)
  shift = _where(
      below_tiny,
      lax.full_like(x, -nmant, dtype=np.int32),
      lax.full_like(x, 0, dtype=np.int32))
  bits = lax.bitcast_convert_type(rescaled, _INT_DTYPES[finfo.bits])
  return bits, shift
def _normalize_float(x):
  """Rescale tiny inputs and expose their raw bit pattern.

  Elements whose magnitude is below ``finfo.tiny`` are multiplied by
  ``2**nmant`` so that their exponent can be read from the exponent bit
  field; every other element passes through unchanged.

  Args:
    x: floating-point array.

  Returns:
    A pair ``(bits, shift)``. ``bits`` is the (possibly rescaled) input bitcast
    to the same-width integer type looked up in ``_INT_DTYPES``. ``shift`` is
    ``-nmant`` where the rescaling was applied and ``0`` elsewhere, with both
    constants created as scalars of that same integer type.
  """
  finfo = dtypes.finfo(dtypes.dtype(x))
  nmant = finfo.nmant
  bits_dtype = _INT_DTYPES[finfo.bits]
  # Note: this is also true for exact zeros, not only for subnormals.
  below_tiny = lax.abs(x) < finfo.tiny
  rescaled = _where(below_tiny, x * _lax_const(x, 1 << nmant), x)
  shift = _where(below_tiny, bits_dtype(-nmant), bits_dtype(0))
  bits = lax.bitcast_convert_type(rescaled, bits_dtype)
  return bits, shift
def ldexp(x1, x2):
  """Compute ``x1 * 2**x2`` by rewriting the exponent bit field of ``x1``.

  Args:
    x1: array-like significand. Complex dtypes are rejected.
    x2: array-like power of two. Inexact (floating/complex) dtypes are
      rejected.

  Returns:
    An array whose dtype is the canonicalized inexact counterpart of ``x1``'s
    dtype. Results whose exponent falls below the denormal range become 0,
    results whose exponent exceeds the bias become ``sign(x1) * inf``, and
    ``x1`` values that are inf, nan or 0 are returned unchanged.

  Raises:
    ValueError: if ``x1`` is complex or ``x2`` is inexact.
  """
  _check_arraylike("ldexp", x1, x2)
  x1_dtype = dtypes.dtype(x1)
  x2_dtype = dtypes.dtype(x2)
  if (dtypes.issubdtype(x1_dtype, np.complexfloating)
      or dtypes.issubdtype(x2_dtype, np.inexact)):
    raise ValueError(
        f"ldexp not supported for input types {(x1_dtype, x2_dtype)}")

  x1, x2 = _promote_shapes("ldexp", x1, x2)

  dtype = dtypes.canonicalize_dtype(dtypes._to_inexact_dtype(x1_dtype))
  finfo = dtypes.finfo(dtype)
  nmant = finfo.nmant
  int_type = _INT_DTYPES[finfo.bits]

  x1 = lax.convert_element_type(x1, dtype)
  exponent = lax.convert_element_type(x2, int_type)

  exp_mask = (1 << finfo.nexp) - 1
  exp_bias = exp_mask >> 1

  # Total unbiased exponent = requested power + exponent already stored in x1
  # (plus the correction for inputs that _normalize_float had to rescale).
  bits, shift = _normalize_float(x1)
  exponent = exponent + (shift + ((bits >> nmant) & exp_mask) - exp_bias)

  # find underflow/overflow before denormalization
  underflows = exponent < -(exp_bias + nmant)
  overflows = exponent > exp_bias

  # denormals: store an exponent that is nmant larger and compensate by
  # multiplying the reassembled float with 2**-nmant.
  scale = lax.full_like(bits, 1, dtype=dtype)
  is_denormal = exponent < -exp_bias + 1
  exponent = _where(is_denormal, exponent + nmant, exponent)
  scale = _where(is_denormal, scale / (1 << nmant), scale)

  # Clear the old exponent field and write the new biased exponent into it.
  exponent = lax.convert_element_type(exponent, np.int32)
  bits = bits & ~(exp_mask << nmant)
  bits = bits | (
      (lax.convert_element_type(exponent, int_type) + exp_bias) << nmant)

  result = (lax.convert_element_type(scale, dtype)
            * lax.bitcast_convert_type(bits, dtype))

  # underflow
  result = _where(underflows, lax.full_like(result, 0, dtype=dtype), result)
  # overflow
  result = _where(
      overflows, lax.sign(x1) * lax.full_like(result, np.inf), result)
  # ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0
  passthrough = isinf(x1) | isnan(x1) | (x1 == 0)
  return _where(passthrough, x1, result)
def frexp(x):
  """Decompose ``x`` into a significand and an integer power of two.

  The exponent is read from the bit pattern of ``x`` and the exponent field of
  the returned significand is overwritten with ``bias - 1``.

  Args:
    x: array-like; promoted to an inexact dtype. Complex inputs are rejected.

  Returns:
    A pair ``(mantissa, exponent)``. ``mantissa`` has the promoted dtype of
    ``x`` and ``exponent`` is int32. Where ``x`` is inf, nan or 0 the mantissa
    is ``x`` itself and the exponent is 0.

  Raises:
    TypeError: if ``x`` is complex-valued.
  """
  _check_arraylike("frexp", x)
  x, = _promote_dtypes_inexact(x)
  if dtypes.issubdtype(x.dtype, np.complexfloating):
    raise TypeError("frexp does not support complex-valued inputs")

  dtype = dtypes.dtype(x)
  finfo = dtypes.finfo(dtype)
  nmant = finfo.nmant
  exp_mask = (1 << finfo.nexp) - 1
  exp_bias = exp_mask >> 1

  bits, exponent = _normalize_float(x)
  # Stored (biased) exponent -> unbiased exponent, +1 because the returned
  # significand carries the exponent field ``bias - 1`` rather than ``bias``.
  exponent = exponent + (((bits >> nmant) & exp_mask) - exp_bias + 1)

  # Replace the exponent field, keeping sign and fraction bits.
  bits = bits & ~(exp_mask << nmant)
  bits = bits | ((exp_bias - 1) << nmant)
  mantissa = lax.bitcast_convert_type(bits, dtype)

  # inf, nan and zero are returned as-is with a zero exponent.
  special = isinf(x) | isnan(x) | (x == 0)
  exponent = _where(special, lax_internal._zeros(exponent), exponent)
  return (_where(special, x, mantissa),
          lax.convert_element_type(exponent, np.int32))
def ldexp(x1, x2):
  """Compute ``x1 * 2**x2`` by rewriting the exponent bit field of ``x1``.

  Args:
    x1: array-like significand.
    x2: array-like power of two.

  Returns:
    An array whose dtype is the canonicalized result dtype that
    ``_result_dtype(np.ldexp, x1, x2)`` reports. Results whose exponent falls
    below the denormal range become 0, results whose exponent exceeds the bias
    become ``sign(x1) * inf``, and ``x1`` values that are inf, nan or 0 are
    returned unchanged.
  """
  _check_arraylike("ldexp", x1, x2)
  dtype = dtypes.canonicalize_dtype(_result_dtype(np.ldexp, x1, x2))
  x1, x2 = _promote_shapes("ldexp", x1, x2)
  x1 = lax.convert_element_type(x1, dtype)

  info = dtypes.finfo(dtype)
  mask = (1 << info.nexp) - 1
  bias = mask >> 1
  int_type = _INT_DTYPES[info.bits]

  x, e = _normalize_float(x1)
  # Rebind instead of using ``+=``: x2 has not been converted at this point and
  # may still be the caller's own object. If that is a NumPy array, an in-place
  # add would write the intermediate exponent into the caller's buffer.
  x2 = x2 + (e + ((x >> info.nmant) & mask) - bias)

  # find underflow/overflow before denormalization
  underflow_cond = x2 < -(bias + info.nmant)
  overflow_cond = x2 > bias

  m = lax.full_like(x, 1, dtype=dtype)

  # denormals: store an exponent that is nmant larger and compensate by
  # multiplying the reassembled float with 2**-nmant.
  cond = x2 < -bias + 1
  x2 = _where(cond, x2 + info.nmant, x2)
  m = _where(cond, m / (1 << info.nmant), m)

  # Clear the old exponent field and write the new biased exponent into it.
  x2 = lax.convert_element_type(x2, np.int32)
  x = x & ~(mask << info.nmant)
  x = x | ((lax.convert_element_type(x2, int_type) + bias) << info.nmant)

  x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype)

  # underflow
  x = _where(underflow_cond, lax.full_like(x, 0, dtype=dtype), x)
  # overflow
  x = _where(overflow_cond, lax.sign(x1) * lax.full_like(x, np.inf), x)
  # ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0
  return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x)
def eig_abstract_eval(operand, *, compute_left_eigenvectors,
                      compute_right_eigenvectors):
  """Abstract evaluation rule for the nonsymmetric eigendecomposition.

  Args:
    operand: abstract value of the input; must be a ``ShapedArray`` of shape
      ``[..., n, n]``.
    compute_left_eigenvectors: whether a left-eigenvector output is produced.
    compute_right_eigenvectors: whether a right-eigenvector output is produced.

  Returns:
    A tuple of abstract values: eigenvalues of shape ``[..., n]`` first,
    followed by ``[..., n, n]`` entries for the left and then the right
    eigenvectors when requested. All outputs are complex: complex64 when the
    operand's ``finfo`` reports 32 bits, otherwise complex128 (then
    canonicalized).

  Raises:
    NotImplementedError: if ``operand`` is not a ``ShapedArray``.
    ValueError: if ``operand`` is not a batch of square matrices.
  """
  if not isinstance(operand, ShapedArray):
    raise NotImplementedError

  if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
    raise ValueError("Argument to nonsymmetric eigendecomposition must have "
                     "shape [..., n, n], got shape {}".format(operand.shape))

  n = operand.shape[-1]
  batch_dims = operand.shape[:-2]

  if dtypes.finfo(operand.dtype).bits == 32:
    complex_dtype = np.complex64
  else:
    complex_dtype = np.complex128
  complex_dtype = dtypes.canonicalize_dtype(complex_dtype)

  # Left and right eigenvectors share one abstract value.
  eigvecs = operand.update(shape=batch_dims + (n, n), dtype=complex_dtype)
  eigvals = operand.update(shape=batch_dims + (n,), dtype=complex_dtype)

  outputs = (eigvals,)
  if compute_left_eigenvectors:
    outputs += (eigvecs,)
  if compute_right_eigenvectors:
    outputs += (eigvecs,)
  return outputs
def signbit(x):
  """Return a boolean array that is True where the sign bit of ``x`` is set.

  Integers are compared against zero, booleans are always False, and floats
  are bitcast to a same-width integer whose top bit is extracted.

  Raises:
    ValueError: if the dtype is neither integer, boolean nor floating.
    NotImplementedError: if no integer type of matching width is registered in
      ``_INT_DTYPES``.
  """
  x, = _promote_args("signbit", x)
  dtype = dtypes.dtype(x)

  if dtypes.issubdtype(dtype, np.integer):
    return lax.lt(x, _constant_like(x, 0))
  if dtypes.issubdtype(dtype, np.bool_):
    return lax.full_like(x, False, dtype=np.bool_)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError("jax.numpy.signbit is not well defined for %s" % dtype)

  # TPU supports BF16 but not S16 types, so as a workaround, convert BF16 to
  # F32.
  if dtype == dtypes.bfloat16:
    dtype = np.float32
    x = lax.convert_element_type(x, np.float32)

  finfo = dtypes.finfo(dtype)
  if finfo.bits not in _INT_DTYPES:
    raise NotImplementedError(
        "jax.numpy.signbit only supports 16, 32, and 64-bit types.")

  as_int = lax.bitcast_convert_type(x, _INT_DTYPES[finfo.bits])
  # Shifting out exponent and mantissa leaves only the sign bit; any nonzero
  # remainder converts to True.
  sign_shift = finfo.nexp + finfo.nmant
  return lax.convert_element_type(as_int >> sign_shift, np.bool_)
def _select_and_gather_add_translation(
    ctx, avals_in, avals_out, tangents, operand, *, select_prim,
    window_dimensions, window_strides, padding, base_dilation, window_dilation,
    max_bits=64):
  """XLA-builder translation rule for select-and-gather-add.

  Each (operand, tangent) pair is packed into a single unsigned integer with
  the operand in the high bits. A ReduceWindow then keeps, per window, the
  packed word whose operand part wins under ``select_prim``, and the tangent
  part of the winner is unpacked as the result.

  Args:
    ctx: translation context; only ``ctx.builder`` is used.
    avals_in: pair of input abstract values ``(tangents, operand)``; only the
      operand's dtype is read.
    avals_out: unused.
    tangents: XLA op holding the tangent values.
    operand: XLA op holding the values the selection is based on.
    select_prim: ``lax.ge_p`` or ``lax.le_p``.
    window_dimensions, window_strides, padding, base_dilation,
      window_dilation: forwarded unchanged to ``ReduceWindowWithGeneralPadding``.
    max_bits: widest integer width available for packing. If two full words
      do not fit, both values are reduced to half precision and packed into a
      single word (a warning is emitted).

  Returns:
    A one-element list containing the selected tangents.
  """
  c = ctx.builder
  # The tangent aval is not needed; both inputs are handled with the operand's
  # dtype.
  _, operand_aval = avals_in
  dtype = operand_aval.dtype
  etype = xla.dtype_to_primitive_type(dtype)
  nbits = dtypes.finfo(dtype).bits

  assert nbits <= max_bits
  double_word_reduction = nbits * 2 <= max_bits

  const = lambda c, dtype, x: xops.Constant(c, np.array(x, dtype=dtype))

  if double_word_reduction:
    # TODO(b/73062247): XLA doesn't yet implement ReduceWindow on tuples, so
    # we implement a pair-wise ReduceWindow by packing two k-bit values into
    # 2k-bit unsigned integer using bit tricks.
    word_dtype = lax._UINT_DTYPES[nbits]
    double_word_dtype = lax._UINT_DTYPES[nbits * 2]
    word_type = xla.dtype_to_primitive_type(word_dtype)
    double_word_type = xla.dtype_to_primitive_type(double_word_dtype)

    # Packs two values into a tuple.
    def pack(a, b):
      a = xops.BitcastConvertType(a, word_type)
      b = xops.BitcastConvertType(b, word_type)
      a = xops.ConvertElementType(a, double_word_type)
      b = xops.ConvertElementType(b, double_word_type)
      a = xops.ShiftLeft(a, const(c, double_word_dtype, nbits))
      return xops.Or(a, b)

    # Unpacks the first element of a tuple.
    def fst(c, t):
      st = xops.ShiftRightLogical(t, const(c, double_word_dtype, nbits))
      return xops.BitcastConvertType(
          xops.ConvertElementType(st, word_type), etype)

    # Unpacks the second element of a tuple.
    def snd(t):
      return xops.BitcastConvertType(
          xops.ConvertElementType(t, word_type), etype)

  else:
    # The double-word trick above only works if we have a sufficiently large
    # type. As an alternative, we can pack two half words into a single word,
    # at the cost of precision.
    # TODO(b/73062247): add support for tuple reductions and remove this case.
    warnings.warn(
        "Using reduced precision for gradient of reduce-window "
        "min/max operator to work around missing XLA support for "
        "pair-reductions. This is likely from a second or "
        "higher derivative of a max-pooling operation.")
    r_nbits = nbits // 2
    # Drop/round the bottom mantissa bits.
    nexp = dtypes.finfo(dtype).nexp
    nmant = r_nbits - nexp - 1

    double_word_dtype = word_dtype = lax._UINT_DTYPES[nbits]
    word_type = xla.dtype_to_primitive_type(word_dtype)

    # Packs two values into a tuple.
    def pack(a, b):
      a = xops.ReducePrecision(a, exponent_bits=nexp, mantissa_bits=nmant)
      b = xops.ReducePrecision(b, exponent_bits=nexp, mantissa_bits=nmant)
      a = xops.BitcastConvertType(a, word_type)
      b = xops.BitcastConvertType(b, word_type)
      b = xops.ShiftRightLogical(b, const(c, word_dtype, r_nbits))
      return xops.Or(a, b)

    # Unpacks the first element of a tuple.
    def fst(c, t):
      st = xops.And(
          t, const(c, word_dtype, ((1 << r_nbits) - 1) << r_nbits))
      return xops.BitcastConvertType(st, etype)

    # Unpacks the second element of a tuple.
    def snd(t):
      return xops.BitcastConvertType(
          xops.ShiftLeft(t, const(c, word_dtype, r_nbits)), etype)

  def reducer():
    # The reduction computation lives in its own builder; a distinct name keeps
    # it from being confused with the enclosing builder ``c``.
    reducer_builder = xc.XlaBuilder("select_and_gather_pair_reducer")
    x = xla.parameter(
        reducer_builder, 0,
        xla_client.Shape.array_shape(np.dtype(double_word_dtype), ()))
    y = xla.parameter(
        reducer_builder, 1,
        xla_client.Shape.array_shape(np.dtype(double_word_dtype), ()))
    assert select_prim is lax.ge_p or select_prim is lax.le_p
    which = xops.Ge if select_prim is lax.ge_p else xops.Le
    # Compare on the operand half only, but carry the whole packed word along.
    xops.Select(
        which(fst(reducer_builder, x), fst(reducer_builder, y)), x, y)
    return reducer_builder.build()

  assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim
  # Identity of the selection: -inf for max-like (ge), +inf for min-like (le).
  init = -np.inf if select_prim is lax.ge_p else np.inf
  out = xops.ReduceWindowWithGeneralPadding(
      pack(operand, tangents),
      pack(const(c, dtype, init), const(c, dtype, 0)),
      reducer(), window_dimensions, window_strides, base_dilation,
      window_dilation, padding)
  return [snd(out)]
def num_float_bits(dtype):
  """Return the ``finfo`` bit width of ``dtype`` after canonicalization."""
  canonical_dtype = _dtypes.canonicalize_dtype(dtype)
  float_info = _dtypes.finfo(canonical_dtype)
  return float_info.bits
def _select_and_gather_add_lowering(
    ctx, tangents, operand, *, select_prim,
    window_dimensions, window_strides, padding, base_dilation, window_dilation,
    max_bits=64):
  """MHLO lowering rule for select-and-gather-add.

  Each (operand, tangent) pair is packed into a single unsigned integer with
  the operand in the high bits. A ``ReduceWindowOp`` then keeps, per window,
  the packed word whose operand part wins under ``select_prim``, and the
  tangent part of the winner is unpacked as the result.

  Args:
    ctx: lowering context; ``ctx.avals_in`` must be a ``(tangents, operand)``
      pair and ``ctx.avals_out`` a single output aval.
    tangents: IR value holding the tangent values.
    operand: IR value holding the values the selection is based on.
    select_prim: ``lax.ge_p`` or ``lax.le_p``.
    window_dimensions, window_strides, padding, base_dilation,
      window_dilation: window configuration passed to ``ReduceWindowOp``.
    max_bits: widest integer width available for packing. If two full words
      do not fit, both values are reduced to half precision and packed into a
      single word (a warning is emitted).

  Returns:
    A one-element list containing the selected tangents.
  """
  _, operand_aval, = ctx.avals_in
  out_aval, = ctx.avals_out
  dtype = operand_aval.dtype
  etype = mlir.dtype_to_ir_type(dtype)
  nbits = dtypes.finfo(dtype).bits

  assert nbits <= max_bits
  double_word_reduction = nbits * 2 <= max_bits

  const = lambda dtype, x: mlir.ir_constant(np.array(x, dtype=dtype),
                                            canonicalize_types=False)

  # Broadcasts a scalar constant to shape ``dims``; every call site below
  # passes the result of ``const(...)``.
  if jax._src.lib.mlir_api_version >= 9:
    def _broadcast(x, dims):
      return mhlo.BroadcastOp(x, mlir.dense_int_elements(dims))
  else:
    def _broadcast(x, dims):
      elt_type = ir.RankedTensorType(x.type).element_type
      # The result type must be *built* with ``RankedTensorType.get``; calling
      # ``RankedTensorType(...)`` directly is the cast-from-existing-type
      # constructor and does not accept a shape.
      return mhlo.BroadcastOp(ir.RankedTensorType.get(dims, elt_type), x,
                              mlir.dense_int_elements(dims))

  if double_word_reduction:
    # TODO(b/73062247): XLA doesn't yet implement ReduceWindow on tuples, so
    # we implement a pair-wise ReduceWindow by packing two k-bit values into
    # 2k-bit unsigned integer using bit tricks.
    word_dtype = lax._UINT_DTYPES[nbits]
    double_word_dtype = lax._UINT_DTYPES[nbits * 2]
    word_type = mlir.dtype_to_ir_type(word_dtype)
    double_word_type = mlir.dtype_to_ir_type(double_word_dtype)

    # Packs two values into a tuple.
    def pack(a, b):
      a_dims = ir.RankedTensorType(a.type).shape
      b_dims = ir.RankedTensorType(b.type).shape
      a = mhlo.BitcastConvertOp(
          ir.RankedTensorType.get(a_dims, word_type), a)
      b = mhlo.BitcastConvertOp(
          ir.RankedTensorType.get(b_dims, word_type), b)
      a = mhlo.ConvertOp(
          ir.RankedTensorType.get(a_dims, double_word_type), a)
      b = mhlo.ConvertOp(
          ir.RankedTensorType.get(b_dims, double_word_type), b)
      a = mhlo.ShiftLeftOp(
          a, _broadcast(const(double_word_dtype, nbits), a_dims))
      return mhlo.OrOp(a, b)

    # Unpacks the first element of a tuple.
    # Only applied to the scalar reducer arguments below, so the scalar shift
    # constant is used without broadcasting.
    def fst(t):
      dims = ir.RankedTensorType(t.type).shape
      st = mhlo.ShiftRightLogicalOp(t, const(double_word_dtype, nbits))
      return mhlo.BitcastConvertOp(
          ir.RankedTensorType.get(dims, etype),
          mhlo.ConvertOp(ir.RankedTensorType.get(dims, word_type), st)).result

    # Unpacks the second element of a tuple.
    def snd(t):
      dims = ir.RankedTensorType(t.type).shape
      return mhlo.BitcastConvertOp(
          ir.RankedTensorType.get(dims, etype),
          mhlo.ConvertOp(ir.RankedTensorType.get(dims, word_type), t)).result

  else:
    # The double-word trick above only works if we have a sufficiently large
    # type. As an alternative, we can pack two half words into a single word,
    # at the cost of precision.
    # TODO(b/73062247): add support for tuple reductions and remove this case.
    warnings.warn(
        "Using reduced precision for gradient of reduce-window "
        "min/max operator to work around missing XLA support for "
        "pair-reductions. This is likely from a second or "
        "higher derivative of a max-pooling operation.")
    r_nbits = nbits // 2
    # Drop/round the bottom mantissa bits.
    nexp = dtypes.finfo(dtype).nexp
    nmant = r_nbits - nexp - 1

    double_word_dtype = word_dtype = lax._UINT_DTYPES[nbits]
    double_word_type = word_type = mlir.dtype_to_ir_type(word_dtype)

    # Packs two values into a tuple.
    def pack(a, b):
      a_dims = ir.RankedTensorType(a.type).shape
      b_dims = ir.RankedTensorType(b.type).shape
      if jax._src.lib.mlir_api_version >= 21:
        a = mhlo.ReducePrecisionOp(a, exponent_bits=mlir.i32_attr(nexp),
                                   mantissa_bits=mlir.i32_attr(nmant))
        b = mhlo.ReducePrecisionOp(b, exponent_bits=mlir.i32_attr(nexp),
                                   mantissa_bits=mlir.i32_attr(nmant))
      else:
        a = mhlo.ReducePrecisionOp(a.type, a,
                                   exponent_bits=mlir.i32_attr(nexp),
                                   mantissa_bits=mlir.i32_attr(nmant))
        b = mhlo.ReducePrecisionOp(b.type, b,
                                   exponent_bits=mlir.i32_attr(nexp),
                                   mantissa_bits=mlir.i32_attr(nmant))
      a = mhlo.BitcastConvertOp(
          ir.RankedTensorType.get(a_dims, word_type), a)
      b = mhlo.BitcastConvertOp(
          ir.RankedTensorType.get(b_dims, word_type), b)
      b = mhlo.ShiftRightLogicalOp(
          b, _broadcast(const(word_dtype, r_nbits), b_dims))
      return mhlo.OrOp(a, b)

    # Unpacks the first element of a tuple.
    # Only applied to the scalar reducer arguments below, hence the rank-0
    # result type.
    def fst(t):
      st = mhlo.AndOp(t, const(word_dtype, ((1 << r_nbits) - 1) << r_nbits))
      return mhlo.BitcastConvertOp(ir.RankedTensorType.get([], etype),
                                   st).result

    # Unpacks the second element of a tuple.
    def snd(t):
      dims = ir.RankedTensorType(t.type).shape
      return mhlo.BitcastConvertOp(
          ir.RankedTensorType.get(dims, etype),
          mhlo.ShiftLeftOp(t, _broadcast(const(word_dtype, r_nbits), dims))
      ).result

  assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim
  # Identity of the selection: -inf for max-like (ge), +inf for min-like (le).
  init = -np.inf if select_prim is lax.ge_p else np.inf
  rw = mhlo.ReduceWindowOp(
      [ir.RankedTensorType.get(out_aval.shape, double_word_type)],
      pack(operand, tangents),
      pack(const(dtype, init), const(dtype, 0)),
      mlir.dense_int_elements(window_dimensions),
      window_strides=mlir.dense_int_elements(window_strides),
      base_dilations=mlir.dense_int_elements(base_dilation),
      window_dilations=mlir.dense_int_elements(window_dilation),
      padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
                                          shape=(len(padding), 2)))
  scalar_type = ir.RankedTensorType.get([], double_word_type)
  reducer = rw.regions[0].blocks.append(scalar_type, scalar_type)
  with ir.InsertionPoint(reducer):
    x, y = reducer.arguments
    assert select_prim is lax.ge_p or select_prim is lax.le_p
    which = "GE" if select_prim is lax.ge_p else "LE"
    # Compare on the operand half only, but carry the whole packed word along.
    out = mhlo.SelectOp(mlir.compare_mhlo(fst(x), fst(y), which), x, y)
    mhlo.ReturnOp(out)
  return [snd(rw.result)]