def jvpfun(instantiate, transform_stack, primals, tangents):
  """Run the wrapped function under a fresh ``JVPTrace`` main trace.

  Written in the two-``yield`` generator style used by linear_util
  transformations (the decorator is not visible here -- presumably
  ``@lu.transformation``; confirm). The first ``yield`` hands
  ``(main, primals, tangents), {}`` to the wrapped function and receives
  ``(out_primals, out_tangents)`` back. The second ``yield`` produces the
  final result.

  Args:
    instantiate: a bool applied to every output, or a per-output sequence of
      bools. Output tangents marked True are passed through
      ``instantiate_zeros`` (presumably materializing symbolic zeros).
    transform_stack: if true, run inside
      ``source_info_util.transform_name_stack('jvp')``; otherwise inside a
      ``contextlib.nullcontext()``.
    primals: primal inputs, forwarded unchanged.
    tangents: tangent inputs. Any tangent that is not already a ``Zero`` and
      whose dtype is ``float0`` is replaced with ``Zero.from_value(t)``.

  Yields:
    First ``((main, primals, tangents), {})``, then
    ``(out_primals, out_tangents)``.
  """
  # Replace float0-dtype tangents with symbolic Zero values.
  tangents = [Zero.from_value(t) if not isinstance(t, Zero)
              and dtype(t) is float0 else t for t in tangents]
  ctx = (source_info_util.transform_name_stack('jvp') if transform_stack
         else contextlib.nullcontext())
  with core.new_main(JVPTrace) as main, ctx:
    out_primals, out_tangents = yield (main, primals, tangents), {}
    # Drop this frame's reference to `main` before the context exits.
    # NOTE(review): presumably so new_main's cleanup/leak check does not see
    # a live reference -- confirm before reordering.
    del main
  # A bare bool applies the same choice to every output tangent.
  if type(instantiate) is bool:
    instantiate = [instantiate] * len(out_tangents)
  out_tangents = [instantiate_zeros(t) if inst else t for t, inst
                  in zip(out_tangents, instantiate)]
  yield out_primals, out_tangents
def _power(x1, x2):
  """Element-wise ``x1 ** x2`` after NumPy-style argument promotion.

  Non-integer dtypes are delegated to ``lax.pow``. Integer dtypes use
  square-and-multiply over the six lowest bits of the exponent.
  """
  x1, x2 = _promote_args("power", x1, x2)
  if not dtypes.issubdtype(dtypes.dtype(x1), np.integer):
    return lax.pow(x1, x2)

  # Integer power => use binary exponentiation.
  # TODO(phawkins): add integer pow support to XLA.
  num_bits = 6  # Anything more would overflow for any x1 > 1
  base, exponent = x1, x2
  zero = _constant_like(exponent, 0)
  one = _constant_like(exponent, 1)
  # Seed the accumulator with 0 where base == 0 and exponent != 0 (so that
  # 0 ** e == 0), and with 1 everywhere else (so that b ** 0 == 1).
  base_is_zero = lax.eq(base, zero)
  exponent_nonzero = lax.ne(exponent, zero)
  result = _where(lax.bitwise_and(base_is_zero, exponent_nonzero), zero, one)
  for _ in range(num_bits):
    # Fold in the current power of the base when the low exponent bit is set.
    low_bit = lax.bitwise_and(exponent, one)
    result = _where(low_bit, lax.mul(result, base), result)
    base = lax.mul(base, base)
    exponent = lax.shift_right_logical(exponent, one)
  return result
def _ravel_list(lst): if not lst: return jnp.array([], jnp.float32), lambda _: [] from_dtypes = [dtypes.dtype(l) for l in lst] to_dtype = dtypes.result_type(*from_dtypes) sizes, shapes = unzip2((jnp.size(x), jnp.shape(x)) for x in lst) indices = np.cumsum(sizes) def unravel(arr): chunks = jnp.split(arr, indices[:-1]) with warnings.catch_warnings(): warnings.simplefilter( "ignore") # ignore complex-to-real cast warning return [ lax.convert_element_type(chunk.reshape(shape), dtype) for chunk, shape, dtype in zip(chunks, shapes, from_dtypes) ] ravel = lambda e: jnp.ravel(lax.convert_element_type(e, to_dtype)) raveled = jnp.concatenate([ravel(e) for e in lst]) return raveled, unravel
def frexp(x):
  """Split ``x`` into ``(mantissa, exponent)``, in the manner of ``numpy.frexp``.

  ``mantissa`` keeps the (inexact-promoted) dtype of ``x`` and ``exponent``
  is int32. The mantissa's exponent field is overwritten with ``bias - 1``,
  so for normal inputs its magnitude lies in [0.5, 1) and its sign matches
  ``x``. Infinities, NaNs and zeros are returned unchanged as the mantissa,
  with exponent 0.

  Raises:
    TypeError: if ``x`` is complex-valued.
  """
  _check_arraylike("frexp", x)
  x, = _promote_dtypes_inexact(x)
  if dtypes.issubdtype(x.dtype, np.complexfloating):
    raise TypeError("frexp does not support complex-valued inputs")

  float_dtype = dtypes.dtype(x)
  info = dtypes.finfo(float_dtype)
  exp_mask = (1 << info.nexp) - 1  # all ones across the exponent field
  exp_bias = exp_mask >> 1

  # Integer bit pattern of x plus a starting exponent offset.
  # NOTE(review): _normalize_float is not visible here; presumably it rescales
  # subnormals and returns the matching exponent correction -- confirm.
  bits, exponent = _normalize_float(x)
  biased_exp = (bits >> info.nmant) & exp_mask
  exponent = exponent + (biased_exp - exp_bias + 1)

  # Replace the exponent field with bias - 1, keeping sign and fraction bits.
  bits = bits & ~(exp_mask << info.nmant)
  bits = bits | ((exp_bias - 1) << info.nmant)
  mantissa = lax.bitcast_convert_type(bits, float_dtype)

  special = isinf(x) | isnan(x) | (x == 0)
  exponent = _where(special, lax_internal._zeros(exponent), exponent)
  return _where(special, x, mantissa), lax.convert_element_type(exponent, np.int32)
def _ravel_list(lst): if not lst: return jnp.array([], jnp.float32), lambda _: [] from_dtypes = [dtypes.dtype(l) for l in lst] to_dtype = dtypes.result_type(*from_dtypes) sizes, shapes = unzip2((jnp.size(x), jnp.shape(x)) for x in lst) indices = np.cumsum(sizes) if all(dt == to_dtype for dt in from_dtypes): # Skip any dtype conversion, resulting in a dtype-polymorphic `unravel`. # See https://github.com/google/jax/issues/7809. del from_dtypes, to_dtype def unravel(arr): chunks = jnp.split(arr, indices[:-1]) return [ chunk.reshape(shape) for chunk, shape in zip(chunks, shapes) ] raveled = jnp.concatenate([jnp.ravel(e) for e in lst]) return raveled, unravel # When there is more than one distinct input dtype, we perform type # conversions and produce a dtype-specific unravel function. def unravel(arr): arr_dtype = dtypes.dtype(arr) if arr_dtype != to_dtype: raise TypeError( f"unravel function given array of dtype {arr_dtype}, " f"but expected dtype {to_dtype}") chunks = jnp.split(arr, indices[:-1]) with warnings.catch_warnings(): warnings.simplefilter( "ignore") # ignore complex-to-real cast warning return [ lax.convert_element_type(chunk.reshape(shape), dtype) for chunk, shape, dtype in zip(chunks, shapes, from_dtypes) ] ravel = lambda e: jnp.ravel(lax.convert_element_type(e, to_dtype)) raveled = jnp.concatenate([ravel(e) for e in lst]) return raveled, unravel
def floor_divide(x1, x2):
  """Element-wise floor division ``x1 // x2`` after NumPy-style promotion.

  * Integers: truncating ``lax.div``, decremented by one where the operands
    have different signs and the remainder is nonzero.
  * Complex: the floor of the real part of ``x1 / x2``, computed with
    Smith-style scaling by the larger-magnitude component of ``x2`` and
    returned in the complex dtype.
  * Otherwise: the quotient from ``_float_divmod``.
  """
  x1, x2 = _promote_args("floor_divide", x1, x2)
  dtype = dtypes.dtype(x1)

  if dtypes.issubdtype(dtype, np.integer):
    truncated = lax.div(x1, x2)
    needs_adjust = logical_and(lax.sign(x1) != lax.sign(x2),
                               lax.rem(x1, x2) != 0)
    # TODO(mattjj): investigate why subtracting a scalar was causing promotion
    return _where(needs_adjust, truncated - 1, truncated)

  if not dtypes.issubdtype(dtype, np.complexfloating):
    return _float_divmod(x1, x2)[0]

  a, b = lax.real(x1), lax.imag(x1)
  c, d = lax.real(x2), lax.imag(x2)
  # Re((a+bi)/(c+di)) = (ac+bd)/(c^2+d^2); dividing numerator and denominator
  # by the larger of |c|, |d| limits overflow.
  c_dominates = lax.ge(lax.abs(c), lax.abs(d))
  rat1 = _where(c_dominates, lax.full_like(d, 1), lax.div(c, d))
  rat2 = _where(c_dominates, lax.div(d, c), _lax_const(d, 1))
  numerator = lax.add(lax.mul(a, rat1), lax.mul(b, rat2))
  denominator = lax.add(lax.mul(c, rat1), lax.mul(d, rat2))
  floored = lax.floor(lax.div(numerator, denominator))
  return lax.convert_element_type(floored, dtype)
def nanvar(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, out=None,
           ddof=0, keepdims=False, where=None):
  """Variance of ``a`` along ``axis``, ignoring NaNs (``numpy.nanvar`` semantics).

  Computes ``sum(|a - mean|**2) / (n - ddof)``, where ``mean`` comes from
  ``nanmean`` and ``n`` counts the non-NaN entries selected by ``where``.
  Positions where ``n - ddof <= 0`` produce NaN.

  Args:
    a: array-like input. Python scalars are accepted.
    axis: axis or axes to reduce over; ``None`` reduces over all axes.
    dtype: optional output dtype.
    out: unsupported; must be ``None``.
    ddof: delta degrees of freedom subtracted from the non-NaN count.
    keepdims: keep reduced axes as size-1 dimensions.
    where: optional mask of elements to include.

  Returns:
    The variance, in the output dtype chosen by ``_var_promote_types``.

  Raises:
    NotImplementedError: if ``out`` is not ``None``.
  """
  _check_arraylike("nanvar", a)
  lax_internal._check_user_dtype_supported(dtype, "nanvar")
  if out is not None:
    raise NotImplementedError(
        "The 'out' argument to jnp.nanvar is not supported.")

  computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
  # `a.astype(...)` fails for Python scalars, which _check_arraylike accepts;
  # lax.convert_element_type handles both scalars and arrays.
  a = lax.convert_element_type(a, computation_dtype)
  a_mean = nanmean(a, axis, dtype=computation_dtype, keepdims=True, where=where)

  # Zero the NaN positions before squaring so they never enter the sum.
  centered = _where(lax_internal._isnan(a), 0, lax.sub(a, a_mean))
  if dtypes.issubdtype(centered.dtype, np.complexfloating):
    centered = lax.real(lax.mul(centered, lax.conj(centered)))
  else:
    centered = lax.square(centered)

  # Count of non-NaN entries (restricted to `where`), minus ddof.
  normalizer = sum(lax_internal.bitwise_not(lax_internal._isnan(a)),
                   axis=axis, keepdims=keepdims, where=where)
  normalizer = normalizer - ddof
  normalizer_mask = lax.le(normalizer, 0)
  result = sum(centered, axis, keepdims=keepdims, where=where)
  # Too few samples -> NaN; divide by 1 there to avoid a non-positive divisor.
  result = _where(normalizer_mask, np.nan, result)
  divisor = _where(normalizer_mask, 1, normalizer)
  out = lax.div(result, lax.convert_element_type(divisor, result.dtype))
  return lax.convert_element_type(out, dtype)
def _var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, out=None,
         ddof=0, keepdims=False, *, where=None):
  """Variance of ``a`` along ``axis`` (``numpy.var`` semantics).

  Computes ``sum(|a - mean|**2) / (n - ddof)``. Here ``n`` is the number of
  reduced elements, or the count of True entries in ``where`` when a mask is
  given. No special handling is applied when ``n - ddof <= 0``.

  Args:
    a: array-like input. Python scalars are accepted.
    axis: axis or axes to reduce over; ``None`` reduces over all axes.
    dtype: optional output dtype.
    out: unsupported; must be ``None``.
    ddof: delta degrees of freedom.
    keepdims: keep reduced axes as size-1 dimensions.
    where: optional mask of elements to include (keyword-only).

  Returns:
    The variance, in the output dtype chosen by ``_var_promote_types``.

  Raises:
    NotImplementedError: if ``out`` is not ``None``.
  """
  _check_arraylike("var", a)
  lax_internal._check_user_dtype_supported(dtype, "var")
  if out is not None:
    raise NotImplementedError(
        "The 'out' argument to jnp.var is not supported.")

  computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
  # `a.astype(...)` fails for Python scalars, which _check_arraylike accepts;
  # lax.convert_element_type handles both scalars and arrays.
  a = lax.convert_element_type(a, computation_dtype)
  a_mean = mean(a, axis, dtype=computation_dtype, keepdims=True, where=where)
  centered = lax.sub(a, a_mean)
  if dtypes.issubdtype(centered.dtype, np.complexfloating):
    centered = lax.real(lax.mul(centered, lax.conj(centered)))
  else:
    centered = lax.square(centered)

  if where is None:
    if axis is None:
      normalizer = core.dimension_as_value(np.size(a))
    else:
      normalizer = core.dimension_as_value(_axis_size(a, axis))
  else:
    normalizer = sum(_broadcast_to(where, np.shape(a)), axis, dtype=dtype,
                     keepdims=keepdims)
  normalizer = normalizer - ddof

  result = sum(centered, axis, keepdims=keepdims, where=where)
  out = lax.div(result, lax.convert_element_type(normalizer, result.dtype))
  return lax.convert_element_type(out, dtype)
def _closure_convert_for_avals(fun, in_tree, in_avals):
  """Trace ``fun`` and hoist its inexact-dtype constants into explicit inputs.

  Returns ``(converted_fun, hoisted_consts)``. ``converted_fun`` takes the
  original positional arguments followed by ``hoisted_consts``. It flattens
  the original arguments (asserting that their tree structure equals
  ``in_tree``), evaluates the traced jaxpr, and unflattens the outputs with
  the traced output tree.
  """
  flat_fun, out_tree_thunk = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
  out_tree = out_tree_thunk()

  # We only want to closure convert for constants with respect to which we're
  # differentiating. As a proxy for that, we hoist consts with float dtype.
  # TODO(frostig,mattjj): revise this approach
  from jax.numpy import inexact

  def is_float(c):
    return dtypes.issubdtype(dtypes.dtype(c), inexact)

  (closure_consts, hoisted_consts), merge = partition_list(is_float, consts)
  num_hoisted = len(hoisted_consts)

  def converted_fun(*args_hconsts):
    split_at = len(args_hconsts) - num_hoisted
    args, hconsts = split_list(args_hconsts, [split_at])
    all_consts = merge(closure_consts, hconsts)
    flat_args, in_tree2 = tree_flatten(tuple(args))
    assert in_tree == in_tree2
    out_flat = core.eval_jaxpr(jaxpr, all_consts, *flat_args)
    return tree_unflatten(out_tree, out_flat)

  return converted_fun, hoisted_consts
def _nan_reduction(a, name, jnp_reduction, init_val, nan_if_all_nan,
                   axis=None, keepdims=None, **kwargs):
  """Apply ``jnp_reduction`` to ``a`` while skipping NaN entries.

  Non-inexact inputs are reduced directly. For inexact inputs, NaNs are first
  replaced by ``_reduction_init_val(a, init_val)``. If ``nan_if_all_nan`` is
  true, positions where every reduced element was NaN yield NaN.
  """
  _check_arraylike(name, a)
  if not dtypes.issubdtype(dtypes.dtype(a), np.inexact):
    return jnp_reduction(a, axis=axis, keepdims=keepdims, **kwargs)

  nan_mask = lax_internal._isnan(a)
  filled = _where(nan_mask, _reduction_init_val(a, init_val), a)
  out = jnp_reduction(filled, axis=axis, keepdims=keepdims, **kwargs)
  if not nan_if_all_nan:
    return out
  all_nan = all(nan_mask, axis=axis, keepdims=keepdims)
  return _where(all_nan, _lax_const(a, np.nan), out)
def signbit(x):
  """Element-wise boolean test of whether the sign bit of ``x`` is set.

  Integers compare against zero, and bools are always False. Floating types
  are bit-cast to the same-width integer type and shifted right by
  ``nexp + nmant`` bits.

  Raises:
    ValueError: for dtypes that are not integer, bool or floating.
    NotImplementedError: for float widths not present in ``_INT_DTYPES``.
  """
  x, = _promote_args("signbit", x)
  dtype = dtypes.dtype(x)

  if dtypes.issubdtype(dtype, np.integer):
    return lax.lt(x, _constant_like(x, 0))
  if dtypes.issubdtype(dtype, np.bool_):
    return lax.full_like(x, False, dtype=np.bool_)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError("jax.numpy.signbit is not well defined for %s" % dtype)

  # TPU supports BF16 but not S16 types, so as a workaround, convert BF16 to
  # F32.
  if dtype == dtypes.bfloat16:
    dtype = np.float32
    x = lax.convert_element_type(x, np.float32)

  info = dtypes.finfo(dtype)
  if info.bits not in _INT_DTYPES:
    raise NotImplementedError(
        "jax.numpy.signbit only supports 16, 32, and 64-bit types.")
  as_int = lax.bitcast_convert_type(x, _INT_DTYPES[info.bits])
  # Only the sign bit lies above nexp + nmant bits: the shift is nonzero
  # exactly when it was set.
  sign_part = as_int >> (info.nexp + info.nmant)
  return lax.convert_element_type(sign_part, np.bool_)
def testDtypeFromValue(self, dtype): self.assertEqual(dtypes.dtype(dtype.type(0)), dtype)
def testDtypeFromScalarValue(self, typ): self.assertEqual(dtypes.dtype(typ(0)), dtypes.python_scalar_dtypes[typ])
def _result_dtype(op, *args): """Compute result dtype of applying op to arguments with given dtypes.""" args = [np.ones((0,) * np.ndim(arg), dtypes.dtype(arg)) for arg in args] return dtypes.dtype(op(*args))
def _reduction(a, name, np_fun, op, init_val, has_identity=True,
               preproc=None, bool_op=None, upcast_f16_for_computation=False,
               axis=None, dtype=None, out=None, keepdims=False, initial=None,
               where_=None, parallel_reduce=None):
  """Shared implementation of ``jnp`` reductions built on ``lax.reduce``.

  Args:
    a: array-like input; converted with ``_asarray`` unless it is already an
      ``ndarray``.
    name: public function name, used in error messages and argument checks.
    np_fun: NumPy counterpart. It is only used to infer the default result
      dtype, by applying it to a 0-d ``np.ones`` array of ``a``'s dtype.
    op: binary reduction op, passed to ``lax.reduce`` and also used to fold
      in ``initial``.
    init_val: reduction init value, resolved via ``_reduction_init_val``.
    has_identity: if False, ``initial`` is required when ``where_`` is given
      and when any reduced dimension has size 0.
    preproc: optional callable applied to ``a`` before the reduction.
    bool_op: op used instead of ``op`` when computing in bool; defaults to
      ``op``.
    upcast_f16_for_computation: if True and the result dtype is inexact,
      compute in ``_upcast_f16(result_dtype)``.
    axis: axis or axes to reduce; must be concrete (``core.concrete_or_error``).
    dtype: optional result dtype; otherwise inferred from ``np_fun``.
    out: unsupported; must be ``None``.
    keepdims: re-insert the reduced positional axes with size 1.
    initial: optional starting value, combined with the result via ``op``.
    where_: optional mask; masked-out elements are replaced by ``init_val``.
    parallel_reduce: callable used when ``_reduction_dims`` returns ``dims``
      different from ``pos_dims`` (presumably named axes -- confirm).

  Raises:
    NotImplementedError: if ``out`` is given, or if a parallel reduction is
      needed but ``parallel_reduce`` is None.
    ValueError: for an identity-less reduction with ``where_`` but no
      ``initial``, or over a zero-size dimension without ``initial``.
  """
  bool_op = bool_op or op
  # Note: we must accept out=None as an argument, because numpy reductions delegate to
  # object methods. For example `np.sum(x)` will call `x.sum()` if the `sum()` method
  # exists, passing along all its arguments.
  if out is not None:
    raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.")
  _check_arraylike(name, a)
  lax_internal._check_user_dtype_supported(dtype, name)
  axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().")

  if initial is None and not has_identity and where_ is not None:
    raise ValueError(f"reduction operation {name} does not have an identity, so to use a "
                     f"where mask one has to specify 'initial'")

  a = a if isinstance(a, ndarray) else _asarray(a)
  a = preproc(a) if preproc else a
  pos_dims, dims = _reduction_dims(a, axis)

  # Without an identity, reducing over an empty dimension has no defined value.
  if initial is None and not has_identity:
    shape = np.shape(a)
    if not _all(core.greater_equal_dim(shape[d], 1) for d in pos_dims):
      raise ValueError(f"zero-size array to reduction operation {name} which has no identity")

  # NOTE(review): `dtype or ...` relies on the truthiness of `dtype`; confirm
  # that np.dtype instances are truthy in the supported NumPy versions.
  result_dtype = dtypes.canonicalize_dtype(dtype or dtypes.dtype(np_fun(np.ones((), dtype=dtypes.dtype(a)))))
  if upcast_f16_for_computation and dtypes.issubdtype(result_dtype, np.inexact):
    computation_dtype = _upcast_f16(result_dtype)
  else:
    computation_dtype = result_dtype
  a = lax.convert_element_type(a, computation_dtype)
  op = op if computation_dtype != np.bool_ else bool_op
  # NB: in XLA, init_val must be an identity for the op, so the user-specified
  # initial value must be applied afterward.
  init_val = _reduction_init_val(a, init_val)
  # Masked-out elements become the identity so they do not affect the result.
  if where_ is not None:
    a = _where(where_, a, init_val)
  # Identity (not equality) check: a different object from _reduction_dims
  # signals the parallel-reduction path.
  if pos_dims is not dims:
    if parallel_reduce is None:
      raise NotImplementedError(f"Named reductions not implemented for jnp.{name}()")
    result = parallel_reduce(a, dims)
  else:
    result = lax.reduce(a, init_val, op, dims)
  if initial is not None:
    result = op(lax.convert_element_type(initial, a.dtype), result)
  if keepdims:
    result = lax.expand_dims(result, pos_dims)
  return lax.convert_element_type(result, dtype or result_dtype)
def divmod(x1, x2):
  """Return ``(quotient, remainder)`` for ``x1`` and ``x2`` after promotion.

  Integer inputs use ``floor_divide`` and ``remainder``; all other dtypes
  delegate to ``_float_divmod``.
  """
  x1, x2 = _promote_args("divmod", x1, x2)
  if not dtypes.issubdtype(dtypes.dtype(x1), np.integer):
    return _float_divmod(x1, x2)
  quotient = floor_divide(x1, x2)
  rem = remainder(x1, x2)
  return quotient, rem
def testDtypeFromNone(self):
  """dtypes.dtype(None) raises ValueError with an 'Invalid argument' message."""
  self.assertRaisesRegex(ValueError, "Invalid argument to dtype",
                         dtypes.dtype, None)
def replace_float0s(primal, tangent):
  """Return ``zeros_like_jaxval(primal)`` if ``tangent`` has dtype float0.

  Any other tangent is returned unchanged.
  """
  if dtype(tangent) is not float0:
    return tangent
  return zeros_like_jaxval(primal)
def dtype(self):
  """The dtype of the wrapped value ``self.val``, as reported by ``dtypes.dtype``."""
  wrapped = self.val
  return dtypes.dtype(wrapped)
def op(*args):
  """Coerce each argument to bool, then apply ``bitwise_op`` after promotion.

  Bool arguments pass through. Any other argument ``x`` becomes
  ``lax.ne(x, 0)``, with the scalar zero built by ``lax.full_like``.
  ``np_op`` and ``bitwise_op`` are free variables from an enclosing scope not
  visible here.
  """
  def as_bool(x):
    if dtypes.issubdtype(dtypes.dtype(x), np.bool_):
      return x
    return lax.ne(x, lax.full_like(x, shape=(), fill_value=0))

  bool_args = [as_bool(x) for x in args]
  return bitwise_op(*_promote_args(np_op.__name__, *bool_args))
def absolute(x):
  """Element-wise absolute value.

  Bool and unsigned-integer inputs are returned unchanged; all other inputs
  go through ``lax.abs``.
  """
  _check_arraylike('absolute', x)
  dt = dtypes.dtype(x)
  already_nonnegative = dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger)
  if already_nonnegative:
    return x
  return lax.abs(x)
def _constant_like(x, const): return np.array(const, dtype=dtypes.dtype(x))
def testDtypeFromDtype(self, dtype): self.assertEqual(dtypes.dtype(dtype), dtype)
def recast_to_float0(primal, tangent):
  """Replace ``tangent`` by a symbolic ``Zero`` when ``primal``'s tangent dtype is float0.

  The replacement ``Zero`` is built from
  ``get_aval(primal).at_least_vspace()``. Otherwise ``tangent`` is returned
  unchanged.
  """
  tangent_dtype = core.primal_dtype_to_tangent_dtype(dtype(primal))
  if tangent_dtype == float0:
    return Zero(get_aval(primal).at_least_vspace())
  return tangent
def testDtypeFromString(self, dtype): self.assertEqual(dtypes.dtype(str(dtype)), dtype)
def _canonicalize_arg(arg): return np.asarray(arg, dtype=dtypes.dtype(arg, canonicalize=True), order='C')
def _bint_ir_types(aval: core.AbstractBInt) -> Sequence[ir.Type]:
  """IR types used for a ``core.AbstractBInt`` aval.

  The result is always a single rank-0 int32 ``RankedTensorType``; ``aval``
  itself is not inspected.
  """
  int32_ir_type = dtype_to_ir_type(dtypes.dtype('int32'))
  scalar_tensor_type = ir.RankedTensorType.get((), int32_ir_type)
  return (scalar_tensor_type,)
def copysign(x1, x2):
  """Element-wise magnitude of ``x1`` with the sign of ``x2``.

  Both arguments are promoted to an inexact dtype first.

  Raises:
    TypeError: if the promoted dtype is complex.
  """
  x1, x2 = _promote_args_inexact("copysign", x1, x2)
  if dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating):
    raise TypeError("copysign does not support complex-valued inputs")
  magnitude = lax.abs(x1)
  return _where(signbit(x2).astype(bool), -magnitude, magnitude)