def div(lhs, rhs):
  """Reference division with C-style (truncate-toward-zero) integer semantics.

  For integer inputs, NumPy's floor division is corrected upward by one
  wherever the operands have opposite signs and the division is inexact, which
  turns flooring into truncation toward zero. Other inputs use true division.
  """
  if not dtypes.issubdtype(dtypes.result_type(lhs), np.integer):
    return np.divide(lhs, rhs)
  floored = np.floor_divide(lhs, rhs)
  signs_differ = np.sign(lhs) != np.sign(rhs)
  inexact = np.remainder(lhs, rhs) != 0
  needs_bump = np.logical_and(signs_differ, inexact)
  return np.where(needs_bump, floored + 1, floored)
def _reduce_window_sum_shape_rule(operand, *, window_dimensions, window_strides, padding, base_dilation, window_dilation): if not dtypes.issubdtype(operand.dtype, np.number): msg = "operand to reduce_window_sum must have a number dtype, got {}" raise TypeError(msg.format(np.dtype(operand.dtype).name)) return _common_reduce_window_shape_rule(operand, window_dimensions, window_strides, padding, base_dilation, window_dilation)
def _check_special(name, xla_shape, buf): assert not xla_shape.is_tuple() if dtypes.issubdtype(xla_shape.element_type(), np.inexact): if config.jax_debug_nans and np.any(np.isnan(buf.to_py())): raise FloatingPointError( f"invalid value (nan) encountered in {name}") if config.jax_debug_infs and np.any(np.isinf(buf.to_py())): raise FloatingPointError( f"invalid value (inf) encountered in {name}")
def fn(x1, x2):
  """Compare ``x1`` and ``x2`` with ``lax_fn`` after NumPy-style promotion.

  Complex operands are ordered lexicographically on the (real, imag) pair: the
  real parts decide unless they are equal, in which case the imaginary parts do.
  """
  x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
  if not dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating):
    return lax_fn(x1, x2)
  re1, re2 = lax.real(x1), lax.real(x2)
  by_imag = lax_fn(lax.imag(x1), lax.imag(x2))
  by_real = lax_fn(re1, re2)
  return lax.select(lax.eq(re1, re2), by_imag, by_real)
def _var_promote_types(a_dtype, dtype): if dtype: if (not dtypes.issubdtype(dtype, np.complexfloating) and dtypes.issubdtype(a_dtype, np.complexfloating)): msg = ( "jax.numpy.var does not yet support real dtype parameters when " "computing the variance of an array of complex values. The " "semantics of numpy.var seem unclear in this case. Please comment " "on https://github.com/google/jax/issues/2283 if this behavior is " "important to you.") raise ValueError(msg) a_dtype = dtypes.promote_types(a_dtype, dtype) else: if not dtypes.issubdtype(a_dtype, np.inexact): dtype = a_dtype = dtypes.canonicalize_dtype(dtypes.float_) else: dtype = _complex_elem_type(a_dtype) a_dtype = dtypes.promote_types(a_dtype, np.float32) return a_dtype, dtype
def _average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None,
             returned=False):
  """Weighted average of ``a`` along ``axis``.

  Args:
    a: array-like input.
    axis: axis to average over; ``None`` averages over all elements.
    weights: optional weights, either the same shape as ``a`` or 1-D with
      length ``a.shape[axis]``. ``None`` treats every weight as 1.
    returned: if true, also return the sum of the weights, broadcast to the
      shape of the average when the shapes differ.

  Returns:
    ``avg``, or ``(avg, weights_sum)`` when ``returned`` is true.

  Raises:
    ValueError: if ``weights`` and ``a`` differ in shape and ``axis`` is None,
      ``weights`` is not 1-D, or its length does not match ``a.shape[axis]``.
  """
  a = _asarray(a)

  if weights is None: # Treat all weights as 1
    avg = mean(a, axis=axis)
    # With unit weights, the weight sum is just the number of averaged elements.
    if axis is None:
      weights_sum = lax.full((), core.dimension_as_value(np.size(a)),
                             dtype=avg.dtype)
    else:
      # NOTE(review): the annotation allows a tuple axis, but indexing
      # ``a.shape`` (a tuple) with a tuple raises TypeError, so tuple axes look
      # unsupported on this path; confirm.
      weights_sum = lax.full_like(avg, core.dimension_as_value(a.shape[axis]),
                                  dtype=avg.dtype)
  else:
    weights = _asarray(weights)

    # Non-inexact inputs (ints, bools) are averaged in at least the default
    # float type.
    if dtypes.issubdtype(a.dtype, np.inexact):
      out_dtype = dtypes.result_type(a.dtype, weights.dtype)
    else:
      out_dtype = dtypes.result_type(a.dtype, weights.dtype, dtypes.float_)
    out_dtype = dtypes.canonicalize_dtype(out_dtype)

    a_shape = np.shape(a)
    a_ndim = len(a_shape)
    weights_shape = np.shape(weights)
    axis = None if axis is None else _canonicalize_axis(axis, a_ndim)

    if a_shape != weights_shape:
      # Make sure the dimensions work out
      if axis is None:
        raise ValueError("Axis must be specified when shapes of a and "
                         "weights differ.")
      if len(weights_shape) != 1:
        raise ValueError("1D weights expected when shapes of a and "
                         "weights differ.")
      if not core.symbolic_equal_dim(weights_shape[0], a_shape[axis]):
        raise ValueError("Length of weights not "
                         "compatible with specified axis.")

      # Reshape the 1-D weights to (1, ..., 1, n) and move the n-axis to
      # ``axis`` so they broadcast against ``a``.
      weights = _broadcast_to(weights, (a_ndim - 1) * (1, ) + weights_shape)
      weights = _moveaxis(weights, -1, axis)

    weights_sum = sum(weights, axis=axis, dtype=out_dtype)
    avg = sum(a * weights, axis=axis, dtype=out_dtype) / weights_sum

  if returned:
    if avg.shape != weights_sum.shape:
      weights_sum = _broadcast_to(weights_sum, avg.shape)
    return avg, weights_sum
  return avg
def _reduction_init_val(a, init_val): # This function uses np.* functions because lax pattern matches against the # specific concrete values of the reduction inputs. a_dtype = dtypes.canonicalize_dtype(dtypes.dtype(a)) if a_dtype == 'bool': return np.array(init_val > 0, dtype=a_dtype) try: return np.array(init_val, dtype=a_dtype) except OverflowError: assert dtypes.issubdtype(a_dtype, np.integer) sign, info = np.sign(init_val), dtypes.iinfo(a_dtype) return np.array(info.min if sign < 0 else info.max, dtype=a_dtype)
def rand(shape, dtype):
  """The random sampler function.

  For inexact dtypes, returns ``base_rand(shape, dtype)`` with about 10% of
  entries set to +inf and about 10% set to -inf (the -inf mask is applied last,
  so it wins where both fire). Complex dtypes get independent samples of this
  kind for the real and imaginary parts. Other dtypes are returned from
  ``base_rand`` unchanged.
  """
  # np.complexfloating is not a subtype of np.floating, so guarding on
  # np.floating here would make the complex branch below unreachable.
  if not _dtypes.issubdtype(dtype, np.inexact):
    # only inexact (floating and complex) types have inf
    return base_rand(shape, dtype)

  if _dtypes.issubdtype(dtype, np.complexfloating):
    base_dtype = np.real(np.array(0, dtype=dtype)).dtype
    real_part = rand(shape, base_dtype)
    imag_part = rand(shape, base_dtype)
    # Assign components directly: `real + 1j * imag` would turn an infinite
    # imaginary part into nan + inf*j (since 0 * inf is nan).
    out = np.empty(np.shape(real_part), dtype=dtype)
    out.real = real_part
    out.imag = imag_part
    return _cast_to_shape(out, shape, dtype)

  dims = _dims_of_shape(shape)
  posinf_flips = rng.rand(*dims) < 0.1
  neginf_flips = rng.rand(*dims) < 0.1

  vals = base_rand(shape, dtype)
  vals = np.where(posinf_flips, np.array(np.inf, dtype=dtype), vals)
  vals = np.where(neginf_flips, np.array(-np.inf, dtype=dtype), vals)

  return _cast_to_shape(np.asarray(vals, dtype=dtype), shape, dtype)
def testBinaryNonPromotion(self, dtype, weak_type, promotion):
  """Check the dtype and weak-typedness of ``x + x`` under ``promotion``.

  Strong types, bools, and 'standard' promotion keep ``dtype``. Otherwise a weak
  ``x`` yields the default complex/float/int type of its category.
  """
  # Regression test for https://github.com/google/jax/issues/6051
  x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
  with jax.numpy_dtype_promotion(promotion):
    y = (x + x)

  if promotion == 'standard' or not weak_type or dtype == dtypes.bool_:
    expected_dtype = dtype
  elif dtypes.issubdtype(dtype, np.complexfloating):
    expected_dtype = dtypes.complex_
  elif dtypes.issubdtype(dtype, np.floating):
    expected_dtype = dtypes.float_
  else:
    expected_dtype = dtypes.int_

  # No boolean weak types. Compare against dtypes.bool_ (as above) rather than
  # the builtin `bool`, so the np.bool_ scalar type is recognized as well as
  # np.dtype('bool').
  expected_weak_type = weak_type and dtype != dtypes.bool_

  expected_dtype = dtypes.canonicalize_dtype(expected_dtype)
  self.assertEqual(y.dtype, expected_dtype)
  self.assertEqual(dtypes.is_weakly_typed(y), expected_weak_type)
def frexp(x):
  """Split ``x`` into a mantissa and an integer power-of-two exponent.

  Non-floating (non-complex) inputs are first converted to the default float
  type. Where ``x`` is inf, nan, or zero, the mantissa is ``x`` itself and the
  exponent is 0.

  Returns:
    ``(mantissa, exponent)``, with ``exponent`` as int32.

  Raises:
    TypeError: for complex-valued inputs.
  """
  _check_arraylike("frexp", x)
  # Use dtypes.dtype() rather than x.dtype so Python scalars (accepted by
  # _check_arraylike) work too.
  x_dtype = dtypes.dtype(x)
  if dtypes.issubdtype(x_dtype, np.complexfloating):
    raise TypeError("frexp does not support complex-valued inputs")
  elif not dtypes.issubdtype(x_dtype, np.floating):
    # np.float_ was removed in NumPy 2.0; dtypes.float_ is the file's own
    # spelling of the default float, and convert_element_type canonicalizes it.
    x = lax.convert_element_type(x, dtypes.float_)

  dtype = dtypes.dtype(x)
  info = dtypes.finfo(dtype)
  # All-ones exponent field, and the IEEE exponent bias (2**(nexp-1) - 1).
  mask = (1 << info.nexp) - 1
  bias = ((1 << info.nexp) - 1) >> 1

  # NOTE(review): _normalize_float presumably returns (integer bit view,
  # exponent correction), e.g. for subnormals; confirm against its definition.
  x1, x2 = _normalize_float(x)
  # Add the unbiased exponent (+1 so the mantissa lands in [0.5, 1)).
  x2 += ((x1 >> info.nmant) & mask) - bias + 1
  # Replace the exponent field with (bias - 1), i.e. scale into [0.5, 1).
  x1 &= ~(mask << info.nmant)
  x1 |= (bias - 1) << info.nmant
  x1 = lax.bitcast_convert_type(x1, dtype)

  cond = isinf(x) | isnan(x) | (x == 0)
  x2 = _where(cond, lax_internal._zeros(x2), x2)
  return _where(cond, x, x1), lax.convert_element_type(x2, np.int32)
def logaddexp(x1, x2):
  """Compute ``log(exp(x1) + exp(x2))`` elementwise in a numerically stable way.

  Inputs are promoted to an inexact type. The real path factors out the larger
  operand; the complex path also wraps the imaginary part of the result into
  the range handled by ``_wrap_between(..., pi)``.
  """
  x1, x2 = _promote_args_inexact("logaddexp", x1, x2)
  amax = lax.max(x1, x2)
  if not dtypes.issubdtype(x1.dtype, np.floating):
    # x1 + x2 - 2 * amax, i.e. the other operand minus amax.
    two = _constant_like(amax, 2)
    shifted = lax.sub(lax.add(x1, x2), lax.mul(amax, two))
    out = lax.add(amax, lax.log1p(lax.exp(shifted)))
    return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi))

  diff = lax.sub(x1, x2)
  stable = lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(diff)))))
  # diff is NaN for NaN inputs or infinities of the same sign; x1 + x2 then
  # gives the expected NaN or that infinity.
  return lax.select(lax_internal._isnan(diff), lax.add(x1, x2), stable)
def floor_divide(x1, x2):
  """Elementwise ``floor(x1 / x2)`` after NumPy-style promotion.

  Integers correct lax's truncating division. Complex inputs return
  ``floor(Re(x1 / x2))`` cast back to the complex dtype. Other inputs use
  ``_float_divmod``.
  """
  x1, x2 = _promote_args("floor_divide", x1, x2)
  dtype = dtypes.dtype(x1)

  if dtypes.issubdtype(dtype, np.integer):
    truncated = lax.div(x1, x2)
    # lax.div truncates toward zero; step down by one when signs differ and
    # the division is inexact, turning truncation into flooring.
    adjust = logical_and(lax.sign(x1) != lax.sign(x2), lax.rem(x1, x2) != 0)
    # TODO(mattjj): investigate why subtracting a scalar was causing promotion
    return _where(adjust, truncated - 1, truncated)

  if dtypes.issubdtype(dtype, np.complexfloating):
    re1, im1 = lax.real(x1), lax.imag(x1)
    re2, im2 = lax.real(x2), lax.imag(x2)
    # Scale numerator and denominator by the ratio of x2's smaller component
    # to its larger one; num / den then equals Re(x1 / x2).
    real_dominant = lax.ge(lax.abs(re2), lax.abs(im2))
    rat1 = _where(real_dominant, lax.full_like(im2, 1), lax.div(re2, im2))
    rat2 = _where(real_dominant, lax.div(im2, re2), _lax_const(im2, 1))
    num = lax.add(lax.mul(re1, rat1), lax.mul(im1, rat2))
    den = lax.add(lax.mul(re2, rat1), lax.mul(im2, rat2))
    return lax.convert_element_type(lax.floor(lax.div(num, den)), dtype)

  return _float_divmod(x1, x2)[0]
def signbit(x):
  """Elementwise test of whether the sign bit of ``x`` is set.

  Integers test ``x < 0``; booleans are always False. Floats inspect the raw
  sign bit, so -0.0 is reported as negative.

  Raises:
    ValueError: for dtypes that are not integer, bool, or floating.
    NotImplementedError: for float widths without a matching entry in
      ``_INT_DTYPES``.
  """
  x, = _promote_args("signbit", x)
  dtype = dtypes.dtype(x)
  if dtypes.issubdtype(dtype, np.integer):
    return lax.lt(x, _constant_like(x, 0))
  if dtypes.issubdtype(dtype, np.bool_):
    return lax.full_like(x, False, dtype=np.bool_)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError("jax.numpy.signbit is not well defined for %s" % dtype)

  # TPU supports BF16 but not S16 types, so as a workaround, convert BF16 to
  # F32.
  if dtype == dtypes.bfloat16:
    dtype = np.float32
    x = lax.convert_element_type(x, np.float32)

  info = dtypes.finfo(dtype)
  nbits = info.bits
  if nbits not in _INT_DTYPES:
    raise NotImplementedError(
        "jax.numpy.signbit only supports 16, 32, and 64-bit types.")
  as_int = lax.bitcast_convert_type(x, _INT_DTYPES[nbits])
  # Shifting past the exponent and mantissa fields leaves only the sign bit,
  # which is nonzero exactly when the sign is set.
  sign_shift = info.nexp + info.nmant
  return lax.convert_element_type(as_int >> sign_shift, np.bool_)
def testCsdWithSameParamAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
                                     noverlap, nfft, detrend, scaling, timeaxis,
                                     average):
  """Compare jsp_signal.csd(x, x) with SciPy when both inputs are the same array."""
  is_complex = dtypes.issubdtype(dtype, np.complexfloating)
  if is_complex and detrend is not None:
    self.skipTest(
        "Complex signal is not supported in lax-backed `signal.detrend`."
    )

  kwds = {
      'fs': fs,
      'window': window,
      'nperseg': nperseg,
      'noverlap': noverlap,
      'nfft': nfft,
      'detrend': detrend,
      'return_onesided': not is_complex,
      'scaling': scaling,
      'axis': timeaxis,
      'average': average,
  }

  def osp_fun(x, y):
    # With identical arguments, the jsp version behaves as if the second
    # argument were a copy, so hand SciPy an explicit copy.
    freqs, Pxy = osp_signal.csd(x, y.copy(), **kwds)
    # Make type-casting the same as JAX.
    return (freqs.astype(_real_dtype(dtype)),
            Pxy.astype(_complex_dtype(dtype)))

  jsp_fun = partial(jsp_signal.csd, **kwds)

  if jtu.device_under_test() == 'tpu':
    tol = _TPU_FFT_TOL
  else:
    tol = {np.float32: 1e-5, np.float64: 1e-12,
           np.complex64: 1e-5, np.complex128: 1e-12}

  rng = jtu.rand_default(self.rng())

  def args_maker():
    # The very same array object is passed as both x and y.
    sample = rng(shape, dtype)
    return [sample, sample]

  self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
  self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
def testIsSubdtype(self):
  """dtypes.issubdtype is consistent across spellings of a type and with NumPy."""
  for t in scalar_types:
    self.assertTrue(dtypes.issubdtype(t, t))
    self.assertTrue(dtypes.issubdtype(np.dtype(t).type, t))
    self.assertTrue(dtypes.issubdtype(t, np.dtype(t).type))
    self.assertTrue(dtypes.issubdtype(t, np.dtype(t)))
    # bfloat16 is excluded from the NumPy comparison (presumably NumPy's
    # hierarchy does not classify it the same way).
    if t != jnp.bfloat16:
      for category in [np.generic, jnp.inexact, jnp.integer, jnp.signedinteger,
                       jnp.unsignedinteger, jnp.floating, jnp.complexfloating]:
        expected = np.issubdtype(np.dtype(t).type, category)
        self.assertEqual(dtypes.issubdtype(t, category), expected)
        # The second assertion used to duplicate the first verbatim; check
        # the np.dtype(t) spelling of the type instead so it adds coverage.
        self.assertEqual(dtypes.issubdtype(np.dtype(t), category), expected)
def testStftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg, noverlap,
                         nfft, detrend, boundary, padded, timeaxis):
  """Compare jsp_signal.stft with scipy.signal.stft for one parameter set."""
  is_complex = dtypes.issubdtype(dtype, np.complexfloating)
  if is_complex and detrend is not None:
    self.skipTest(
        "Complex signal is not supported in lax-backed `signal.detrend`."
    )

  kwds = {
      'fs': fs,
      'window': window,
      'nfft': nfft,
      'boundary': boundary,
      'padded': padded,
      'detrend': detrend,
      'nperseg': nperseg,
      'noverlap': noverlap,
      'axis': timeaxis,
      'return_onesided': not is_complex,
  }

  def osp_fun(x):
    freqs, time, Pxx = osp_signal.stft(x, **kwds)
    # Cast to the dtypes JAX produces.
    return (freqs.astype(_real_dtype(dtype)),
            time.astype(_real_dtype(dtype)),
            Pxx.astype(_complex_dtype(dtype)))

  jsp_fun = partial(jsp_signal.stft, **kwds)

  if jtu.device_under_test() == 'tpu':
    tol = _TPU_FFT_TOL
  else:
    tol = {np.float32: 1e-5, np.float64: 1e-12,
           np.complex64: 1e-5, np.complex128: 1e-12}

  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
  self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
def _power(x1, x2):
  """Elementwise ``x1 ** x2`` after promotion.

  Non-integer dtypes use ``lax.pow``. Integer dtypes use square-and-multiply
  over the low ``bits`` (6) bits of the exponent.
  """
  x1, x2 = _promote_args("power", x1, x2)
  dtype = dtypes.dtype(x1)
  if not dtypes.issubdtype(dtype, np.integer):
    return lax.pow(x1, x2)

  # Integer power => use binary exponentiation.
  # TODO(phawkins): add integer pow support to XLA.
  bits = 6  # Anything more would overflow for any x1 > 1
  zero = _constant_like(x2, 0)
  one = _constant_like(x2, 1)
  # Start from 0 where x1 == 0 and x2 != 0 so that pow(0, x2) is 0; else 1.
  zero_base = lax.eq(x1, zero)
  nonzero_exp = lax.ne(x2, zero)
  acc = _where(lax.bitwise_and(zero_base, nonzero_exp), zero, one)

  base, exponent = x1, x2
  for _ in range(bits):
    low_bit = lax.bitwise_and(exponent, one)
    acc = _where(low_bit, lax.mul(acc, base), acc)
    base = lax.mul(base, base)
    exponent = lax.shift_right_logical(exponent, one)
  return acc
def nanvar(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
           out=None, ddof=0, keepdims=False, where=None):
  """Variance of ``a`` along ``axis``, ignoring NaN entries.

  Where the count of non-NaN entries minus ``ddof`` is <= 0, the result is NaN.

  Raises:
    NotImplementedError: if ``out`` is given.
  """
  _check_arraylike("nanvar", a)
  lax_internal._check_user_dtype_supported(dtype, "nanvar")
  if out is not None:
    raise NotImplementedError(
        "The 'out' argument to jnp.nanvar is not supported.")

  computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
  a = a.astype(computation_dtype)
  is_nan = lax_internal._isnan(a)
  a_mean = nanmean(a, axis, dtype=computation_dtype, keepdims=True, where=where)

  # double-where trick for gradients: NaN entries become 0 rather than NaN.
  deviations = _where(is_nan, 0, lax.sub(a, a_mean))
  if dtypes.issubdtype(deviations.dtype, np.complexfloating):
    squared = lax.real(lax.mul(deviations, lax.conj(deviations)))
  else:
    squared = lax.square(deviations)

  count = sum(lax_internal.bitwise_not(is_nan),
              axis=axis, keepdims=keepdims, where=where)
  count = count - ddof
  too_few = lax.le(count, 0)

  total = sum(squared, axis, keepdims=keepdims, where=where)
  total = _where(too_few, np.nan, total)
  divisor = _where(too_few, 1, count)
  result = lax.div(total, lax.convert_element_type(divisor, total.dtype))
  return lax.convert_element_type(result, dtype)
def _rand_dtype(rand, shape, dtype, scale=1., post=lambda x: x):
  """Produce random values given shape, dtype, scale, and post-processor.

  Args:
    rand: callable producing random values for the given dimensions, e.g. a
      bound np.RandomState.randn or np.RandomState.rand.
    shape: a shape value as a tuple of positive integers.
    dtype: a numpy dtype.
    scale: optional multiplicative scale for the random values (default 1).
    post: optional callable applied to the values before the final cast
      (default identity).

  Returns:
    An ndarray of the given shape and dtype built from scaled ``rand`` output.
    Complex dtypes combine two independent draws as real + 1j * imag.
  """
  def draw():
    return np.asarray(scale * rand(*_dims_of_shape(shape)), dtype)

  vals = draw()
  if _dtypes.issubdtype(dtype, np.complexfloating):
    vals = vals + 1.0j * draw()
  return _cast_to_shape(np.asarray(post(vals), dtype), shape, dtype)
def _var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
         out=None, ddof=0, keepdims=False, *, where=None):
  """Variance of ``a`` along ``axis`` with ``ddof`` degrees-of-freedom correction.

  Complex inputs use ``|x - mean|**2``. With ``where``, only selected entries
  are counted.

  Raises:
    NotImplementedError: if ``out`` is given.
  """
  _check_arraylike("var", a)
  lax_internal._check_user_dtype_supported(dtype, "var")
  if out is not None:
    raise NotImplementedError(
        "The 'out' argument to jnp.var is not supported.")

  computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
  a = a.astype(computation_dtype)
  a_mean = mean(a, axis, dtype=computation_dtype, keepdims=True, where=where)
  deviations = lax.sub(a, a_mean)
  if dtypes.issubdtype(deviations.dtype, np.complexfloating):
    squared = lax.real(lax.mul(deviations, lax.conj(deviations)))
  else:
    squared = lax.square(deviations)

  if where is not None:
    count = sum(_broadcast_to(where, np.shape(a)), axis, dtype=dtype,
                keepdims=keepdims)
  elif axis is None:
    count = core.dimension_as_value(np.size(a))
  else:
    count = core.dimension_as_value(_axis_size(a, axis))
  count = count - ddof

  total = sum(squared, axis, keepdims=keepdims, where=where)
  result = lax.div(total, lax.convert_element_type(count, total.dtype))
  return lax.convert_element_type(result, dtype)
def _closure_convert_for_avals(fun, in_tree, in_avals):
  """Trace ``fun`` and hoist its inexact-dtype constants into explicit arguments.

  Returns:
    ``(converted_fun, hoisted_consts)``. ``converted_fun(*args, *hoisted)``
    evaluates the traced jaxpr, merging the hoisted constants back with the
    ones kept in the closure.
  """
  wrapped_fun, out_tree_thunk = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
  out_tree = out_tree_thunk()

  # We only want to closure convert for constants with respect to which we're
  # differentiating. As a proxy for that, we hoist consts with float dtype.
  # TODO(frostig,mattjj): revise this approach
  from jax.numpy import inexact

  def is_float(c):
    return dtypes.issubdtype(dtypes.dtype(c), inexact)

  (closure_consts, hoisted_consts), merge = partition_list(is_float, consts)
  num_consts = len(hoisted_consts)

  def converted_fun(*args_hconsts):
    split_at = len(args_hconsts) - num_consts
    call_args, call_hconsts = split_list(args_hconsts, [split_at])
    all_consts = merge(closure_consts, call_hconsts)
    flat_args, in_tree2 = tree_flatten(tuple(call_args))
    assert in_tree == in_tree2
    out_flat = core.eval_jaxpr(jaxpr, all_consts, *flat_args)
    return tree_unflatten(out_tree, out_flat)

  return converted_fun, hoisted_consts
def _nan_reduction(a, name, jnp_reduction, init_val, nan_if_all_nan,
                   axis=None, keepdims=None, **kwargs):
  """Apply ``jnp_reduction`` to ``a`` while ignoring NaN entries.

  NaNs are replaced with ``init_val`` (converted via ``_reduction_init_val``)
  before reducing. When ``nan_if_all_nan`` is true, positions where every
  reduced entry was NaN yield NaN. Non-inexact inputs are reduced directly.
  """
  _check_arraylike(name, a)
  if not dtypes.issubdtype(dtypes.dtype(a), np.inexact):
    return jnp_reduction(a, axis=axis, keepdims=keepdims, **kwargs)

  masked = _where(lax_internal._isnan(a), _reduction_init_val(a, init_val), a)
  out = jnp_reduction(masked, axis=axis, keepdims=keepdims, **kwargs)
  if not nan_if_all_nan:
    return out
  all_nan = all(lax_internal._isnan(a), axis=axis, keepdims=keepdims)
  return _where(all_nan, _lax_const(a, np.nan), out)
def _get_min_identity(dt): return np.inf if dtypes.issubdtype(dt, np.floating) else np.iinfo(dt).max
def _to_inexact_dtype(dtype): """Promotes a dtype into an inexact dtype, if it is not already one.""" return dtype if dtypes.issubdtype( dtype, np.inexact) else dtypes.promote_types(dtype, dtypes.float_)
def op(*args):
  """Apply ``bitwise_op`` to the truth values of ``args``.

  Boolean arguments are used as-is; others are converted to ``x != 0``.
  """
  def as_truth(x):
    if dtypes.issubdtype(dtypes.dtype(x), np.bool_):
      return x
    return lax.ne(x, lax.full_like(x, shape=(), fill_value=0))

  return bitwise_op(*_promote_args(np_op.__name__, *map(as_truth, args)))
def _reduction(a, name, np_fun, op, init_val, has_identity=True,
               preproc=None, bool_op=None, upcast_f16_for_computation=False,
               axis=None, dtype=None, out=None, keepdims=False, initial=None,
               where_=None, parallel_reduce=None):
  """Shared implementation of jnp reductions built on ``lax.reduce``.

  Args:
    a: array-like input.
    name: public name used in error messages (``jnp.{name}``).
    np_fun: NumPy counterpart, called on a 0-d array only to infer the default
      result dtype.
    op: binary reduction op; ``bool_op`` (default ``op``) is used instead
      when the computation dtype is bool.
    init_val: identity value for ``op``; converted via ``_reduction_init_val``.
    has_identity: whether ``op`` has an identity. If not, empty reductions and
      ``where_`` without ``initial`` are rejected.
    preproc: optional transform applied to ``a`` first.
    upcast_f16_for_computation: compute inexact results in
      ``_upcast_f16(result_dtype)``.
    axis, dtype, keepdims, initial, where_: NumPy-style reduction arguments.
    out: must be None; accepted only so NumPy's method delegation works.
    parallel_reduce: callable used when ``_reduction_dims`` reports
      dimensions beyond the positional ones (named reductions).

  Raises:
    NotImplementedError: if ``out`` is given, or a named reduction is
      requested without ``parallel_reduce``.
    ValueError: for ``where_``/empty-input reductions of an op with no identity
      and no ``initial``.
  """
  bool_op = bool_op or op
  # Note: we must accept out=None as an argument, because numpy reductions delegate to
  # object methods. For example `np.sum(x)` will call `x.sum()` if the `sum()` method
  # exists, passing along all its arguments.
  if out is not None:
    raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.")
  _check_arraylike(name, a)
  lax_internal._check_user_dtype_supported(dtype, name)
  axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().")

  if initial is None and not has_identity and where_ is not None:
    raise ValueError(f"reduction operation {name} does not have an identity, so to use a "
                     f"where mask one has to specify 'initial'")

  a = a if isinstance(a, ndarray) else _asarray(a)
  a = preproc(a) if preproc else a
  pos_dims, dims = _reduction_dims(a, axis)

  if initial is None and not has_identity:
    shape = np.shape(a)
    # NOTE(review): _all is presumably the builtin all (this module shadows
    # `all` with jnp.all); confirm.
    if not _all(core.greater_equal_dim(shape[d], 1) for d in pos_dims):
      raise ValueError(f"zero-size array to reduction operation {name} which has no identity")

  # Default result dtype: whatever NumPy's reduction yields on a 0-d array of
  # a's dtype, unless the caller requested one.
  result_dtype = dtypes.canonicalize_dtype(dtype or dtypes.dtype(np_fun(np.ones((), dtype=dtypes.dtype(a)))))
  if upcast_f16_for_computation and dtypes.issubdtype(result_dtype, np.inexact):
    computation_dtype = _upcast_f16(result_dtype)
  else:
    computation_dtype = result_dtype
  a = lax.convert_element_type(a, computation_dtype)
  op = op if computation_dtype != np.bool_ else bool_op
  # NB: in XLA, init_val must be an identity for the op, so the user-specified
  # initial value must be applied afterward.
  init_val = _reduction_init_val(a, init_val)
  if where_ is not None:
    # Masked-out entries become the identity so they do not affect the result.
    a = _where(where_, a, init_val)
  # Identity (not equality) test. NOTE(review): presumably _reduction_dims
  # returns the same object for both only when no named axes are involved;
  # confirm before changing this comparison.
  if pos_dims is not dims:
    if parallel_reduce is None:
      raise NotImplementedError(f"Named reductions not implemented for jnp.{name}()")
    result = parallel_reduce(a, dims)
  else:
    result = lax.reduce(a, init_val, op, dims)
  if initial is not None:
    result = op(lax.convert_element_type(initial, a.dtype), result)
  if keepdims:
    result = lax.expand_dims(result, pos_dims)
  return lax.convert_element_type(result, dtype or result_dtype)
def absolute(x):
  """Elementwise absolute value; bool and unsigned-integer inputs are returned unchanged."""
  _check_arraylike('absolute', x)
  x_dtype = dtypes.dtype(x)
  if x_dtype == np.bool_ or dtypes.issubdtype(x_dtype, np.unsignedinteger):
    return x
  return lax.abs(x)
def copysign(x1, x2):
  """Return ``|x1|`` with the sign bit of ``x2``, after inexact promotion.

  Raises:
    TypeError: for complex-valued inputs.
  """
  x1, x2 = _promote_args_inexact("copysign", x1, x2)
  if dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating):
    raise TypeError("copysign does not support complex-valued inputs")
  negative = signbit(x2).astype(bool)
  magnitude = lax.abs(x1)
  return _where(negative, -magnitude, magnitude)
def divmod(x1, x2):
  """Return ``(floor_divide(x1, x2), remainder(x1, x2))`` after promotion.

  Non-integer inputs are delegated to ``_float_divmod``.
  """
  x1, x2 = _promote_args("divmod", x1, x2)
  if not dtypes.issubdtype(dtypes.dtype(x1), np.integer):
    return _float_divmod(x1, x2)
  quotient = floor_divide(x1, x2)
  return quotient, remainder(x1, x2)
def fmod(x1, x2):
  """C-style remainder (sign follows ``x1``) via ``lax.rem``.

  For integer inputs, zero divisors are replaced by one, so those positions
  yield ``x1 % 1``, i.e. 0, instead of an integer remainder-by-zero.
  """
  _check_arraylike("fmod", x1, x2)
  integer_inputs = dtypes.issubdtype(dtypes.result_type(x1, x2), np.integer)
  if integer_inputs:
    x2 = _where(x2 == 0, lax_internal._ones(x2), x2)
  lhs, rhs = _promote_args("fmod", x1, x2)
  return lax.rem(lhs, rhs)