def threefry_seed(seed: int) -> jnp.ndarray:
    """Build one raw threefry PRNG key from an integer seed.

    Args:
      seed: a 64- or 32-bit integer whose bits become the key.

    Returns:
      The key contents as an array of shape (2,) and dtype uint32. The first
      entry is the seed shifted right by 32 bits and the second is the seed
      masked with 0xFFFFFFFF, i.e. a 64-bit seed is effectively bit-cast to a
      pair of uint32 values (a 32-bit seed is padded out with zeros).

    Raises:
      TypeError: if the seed is not a scalar or is not of integer dtype.
    """
    # Python ints go through int64 first to avoid an OverflowError in X32 mode.
    # This breaks JIT invariance for large ints, but supports the common
    # use-case of instantiating with Python hashes in X32 mode.
    seed_arr = jnp.asarray(np.int64(seed) if isinstance(seed, int) else seed)

    if seed_arr.shape:
        raise TypeError(f"PRNG key seed must be a scalar; got {seed!r}.")
    if not np.issubdtype(seed_arr.dtype, np.integer):
        raise TypeError(f"PRNG key seed must be an integer; got {seed!r}")

    def _as_uint32_vector(value):
        # Cast to uint32 and give the scalar a length-1 leading axis so the
        # two halves can be concatenated.
        return lax.reshape(lax.convert_element_type(value, np.uint32), [1])

    shift = lax_internal._const(seed_arr, 32)
    high_word = _as_uint32_vector(lax.shift_right_logical(seed_arr, shift))
    low_word = _as_uint32_vector(
        jnp.bitwise_and(seed_arr, np.uint32(0xFFFFFFFF)))
    return lax.concatenate([high_word, low_word], 0)
def _cond_batching_rule(axis_size, axis_name, main_type, args, dims, branches,
                        linear):
    """Batching rule for ``cond_p``.

    ``args`` is the branch index followed by the branch operands and ``dims``
    holds the matching batch dimensions. Two cases are handled:

    * The index is not batched: each branch jaxpr is batched and the result is
      still a single ``cond_p`` bind.
    * The index is batched: every branch is evaluated on fully batched
      operands and the outputs are combined element-wise with a select keyed
      on the index.

    Returns:
      A pair ``(outputs, output_batch_dims)``.
    """
    index, *ops = args
    index_dim, *op_dims = dims

    if index_dim is batching.not_mapped:
        # Unbatched index: keep a cond, with branches batched over whichever
        # operands carry a batch dimension.
        mapped = [d is not batching.not_mapped for d in op_dims]
        ops = [
            batching.moveaxis(operand, d, 0) if is_mapped else operand
            for is_mapped, operand, d in zip(mapped, ops, op_dims)
        ]

        # First pass only discovers which outputs each branch would batch; an
        # output is treated as batched if any branch batches it.
        per_branch_out_mapped = [
            batching.batch_jaxpr(jaxpr, axis_size, mapped, False, axis_name,
                                 main_type)[1]
            for jaxpr in branches
        ]
        out_mapped = [any(flags) for flags in zip(*per_branch_out_mapped)]

        # Second pass forces all branches to agree on the batched outputs.
        batched_branches = tuple(
            batching.batch_jaxpr(jaxpr, axis_size, mapped, out_mapped,
                                 axis_name, main_type)[0]
            for jaxpr in branches)

        out_dims = [0 if flag else batching.not_mapped for flag in out_mapped]
        out = cond_p.bind(index, *ops, branches=batched_branches,
                          linear=linear)
        return out, out_dims

    # Batched index: convert to a lax.select. While we could get away with not
    # broadcasting some operands yet, because all outputs must be broadcast
    # together anyway for the select we broadcast the input operands for
    # simplicity and leave optimizations to XLA.
    # TODO(mattjj,frostig): assumes branches are side-effect-free, revise!
    index, *ops = (
        batching.bdim_at_front(arg, d, axis_size)
        for arg, d in zip(args, dims))

    all_inputs_mapped = [True] * len(branches[0].in_avals)
    all_outputs_mapped = [True] * len(branches[0].out_avals)

    batched_branches = [
        batching.batch_jaxpr(jaxpr, axis_size, all_inputs_mapped,
                             all_outputs_mapped, axis_name, main_type)[0]
        for jaxpr in branches
    ]

    branch_outs = []
    for branch_number, jaxpr in enumerate(batched_branches):
        # Perform a select on the inputs for safety of reverse-mode autodiff;
        # see https://github.com/google/jax/issues/1052
        selected = lax.eq(index, lax._const(index, branch_number))
        guarded_ops = [
            _bcast_select(selected, operand, lax.stop_gradient(operand))
            for operand in ops
        ]
        branch_outs.append(core.jaxpr_as_fun(jaxpr)(*guarded_ops))

    out = [_bcast_select_n(index, *outs) for outs in zip(*branch_outs)]
    # Every output was requested as batched, so each lives on axis 0.
    return out, [0] * len(all_outputs_mapped)
def _eye(cls, N, M, k, *, dtype=None, index_dtype='int32'):
    """Build an (N, M) matrix with ones on the k-th diagonal.

    The result is constructed as ``cls((data, row, col), ...)`` with one stored
    entry per diagonal element.

    Args:
      N: number of rows.
      M: number of columns.
      k: diagonal offset; positive values shift the column indices, negative
        values shift the row indices.
      dtype: dtype of the stored ones.
      index_dtype: dtype of the row/column index arrays.

    Returns:
      ``cls._empty(...)`` when the requested diagonal lies outside the matrix,
      otherwise a ``cls`` instance flagged with sorted rows and columns.
    """
    diag_size = min(N, M - k) if k > 0 else min(N + k, M)

    if diag_size <= 0:
        # if k is out of range, return an empty matrix.
        return cls._empty((N, M), dtype=dtype, index_dtype=index_dtype)

    k_arr = jnp.asarray(k)
    data = jnp.ones(diag_size, dtype=dtype)
    idx = jnp.arange(diag_size, dtype=index_dtype)

    # Constants expressed in the index dtype so lax.sub / lax.add see
    # operands of matching type.
    zero = _const(idx, 0)
    offset = _const(idx, k_arr)

    def _no_shift():
        return zero

    def _shift():
        return offset

    # A negative offset moves the rows down; a positive one moves the columns
    # right.
    row = lax.sub(idx, lax.cond(offset >= 0, _no_shift, _shift))
    col = lax.add(idx, lax.cond(offset <= 0, _no_shift, _shift))
    return cls((data, row, col), shape=(N, M),
               rows_sorted=True, cols_sorted=True)
def _eye(cls, N, M, k, *, dtype=None, index_dtype='int32'):
    """Build an (N, M) matrix with ones on the k-th diagonal.

    The result is constructed as ``cls((data, indices, indptr), ...)``:
    ``indices`` holds the column of every stored one and ``indptr`` has
    ``N + 1`` entries giving, for each row boundary, how many stored entries
    precede it.

    Args:
      N: number of rows.
      M: number of columns.
      k: diagonal offset. It must be a concrete integer because it is used in
        Python control flow and to size the arrays.
      dtype: dtype of the stored ones.
      index_dtype: dtype of ``indices`` and ``indptr``.

    Returns:
      ``cls._empty(...)`` when the requested diagonal lies outside the matrix,
      otherwise a ``cls`` instance of shape (N, M).
    """
    if k > 0:
        diag_size = min(N, M - k)
    else:
        diag_size = min(N + k, M)

    if diag_size <= 0:
        # if k is out of range, return an empty matrix.
        return cls._empty((N, M), dtype=dtype, index_dtype=index_dtype)

    # The diagonal starts at (first_row, first_col) and occupies diag_size
    # consecutive rows with exactly one entry each.
    first_row = max(-k, 0)
    first_col = max(k, 0)

    data = jnp.ones(diag_size, dtype=dtype)
    idx = jnp.arange(diag_size, dtype=index_dtype)
    indices = lax.add(idx, _const(idx, first_col))

    # Rows before first_row are empty, the next diag_size rows hold one entry
    # each and the remaining rows are empty again, so the running count of
    # entries is a clipped ramp. This closed form replaces the earlier
    # bincount + cumsum + scatter construction and gives the same values.
    # The ramp is computed in the default integer dtype and cast afterwards so
    # the subtraction cannot wrap for an unsigned index_dtype.
    indptr = jnp.clip(
        jnp.arange(N + 1) - first_row, 0, diag_size).astype(index_dtype)
    return cls((data, indices, indptr), shape=(N, M))
def _atan2_taylor(primals_in, series_in):
    """Jet (Taylor-mode) rule for ``lax.atan2``.

    The primal output is ``lax.atan2(x, y)``. The series terms are obtained by
    first propagating the inputs through ``lax.div`` and then combining that
    quotient's series with the series of ``1 / (1 + t**2)`` evaluated at the
    quotient.

    Args:
      primals_in: pair ``(x, y)`` of primal inputs.
      series_in: the matching pair of input series.

    Returns:
      ``(primal_out, series_out)`` where ``series_out`` is a list with one
      term per input series order.
    """
    x, y = primals_in
    primal_out = lax.atan2(x, y)

    ratio, ratio_series = jet(lax.div, primals_in, series_in)
    one = lax_internal._const(ratio, 1)

    def _reciprocal_one_plus_square(t):
        return lax.div(one, 1 + lax.square(t))

    d0, d_rest = jet(_reciprocal_one_plus_square, (ratio, ), (ratio_series, ))
    c = [d0] + d_rest
    u = [ratio] + ratio_series

    # Each output term of order k only depends on c and u, never on earlier
    # output terms, so the terms can be produced independently.
    series_out = [
        fact(k - 1) * sum(
            _scale(k, j) * c[k - j] * u[j] for j in range(1, k + 1))
        for k in range(1, len(ratio_series) + 1)
    ]
    return primal_out, series_out
def _reduce_window_prod(
        operand: Array,
        window_dimensions: core.Shape,
        window_strides: Sequence[int],
        padding: Sequence[Tuple[int, int]],
        base_dilation: Optional[Sequence[int]] = None,
        window_dilation: Optional[Sequence[int]] = None) -> Array:
    """Windowed product reduction over ``operand``.

    Binds ``reduce_window_p`` with ``lax.mul`` as the reduction and an initial
    value of 1 in the operand's dtype.

    Args:
      operand: array to reduce.
      window_dimensions: window size per dimension.
      window_strides: window stride per dimension.
      padding: ``(low, high)`` padding pair per dimension.
      base_dilation: optional dilation of the operand; defaults to all ones.
      window_dilation: optional dilation of the window; defaults to all ones.

    Returns:
      The single output of the ``reduce_window_p`` bind.
    """
    no_dilation = (1,) * len(window_dimensions)
    if base_dilation is None:
        base_dilation = no_dilation
    if window_dilation is None:
        window_dilation = no_dilation

    # Multiplicative identity, used as the reduction's starting value.
    unit = lax._const(operand, 1)
    jaxpr, consts = lax._reduction_jaxpr(lax.mul, lax._abstractify(unit))

    # The primitive returns a one-element sequence; unpack its only member.
    (result,) = reduce_window_p.bind(
        operand,
        unit,
        jaxpr=jaxpr,
        consts=consts,
        window_dimensions=tuple(window_dimensions),
        window_strides=tuple(window_strides),
        padding=tuple(padding),
        base_dilation=tuple(base_dilation),
        window_dilation=tuple(window_dilation))
    return result
x, = primals_in series, = series_in primal_out = prim.bind(x) c0, cs = jet(deriv, primals_in, series_in) c = [c0] + cs u = [x] + series v = [primal_out] + [None] * len(series) for k in range(1, len(v)): v[k] = fact(k - 1) * sum( _scale(k, j) * c[k - j] * u[j] for j in range(1, k + 1)) primal_out, *series_out = v return primal_out, series_out def_deriv( lax.erf_p, lambda x: lax.mul(lax._const(x, 2. / np.sqrt(np.pi)), lax.exp(lax.neg(lax.square(x))))) def def_comp(prim, comp): """ Define the jet rule for a primitive in terms of a composition of simpler primitives. """ jet_rules[prim] = partial(jet, comp) def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1) def_comp(lax.log1p_p, lambda x: lax.log(1 + x)) def_comp(lax.sqrt_p, lambda x: x**0.5) def_comp(lax.rsqrt_p, lambda x: x**-0.5) def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
def deriv_prop(prim, deriv, primals_in, series_in):
    """Jet rule for a unary primitive expressed through its derivative.

    Args:
      prim: the primitive; its primal output is ``prim.bind(x)``.
      deriv: callable giving the derivative of ``prim``; it is itself
        propagated with ``jet`` to obtain its series at the input.
      primals_in: one-element sequence holding the primal input.
      series_in: one-element sequence holding the input series.

    Returns:
      ``(primal_out, series_out)`` where ``series_out`` is a list with one
      term per input series order.
    """
    (x,) = primals_in
    (series,) = series_in
    primal_out = prim.bind(x)

    d0, d_rest = jet(deriv, primals_in, series_in)
    c = [d0] + d_rest
    u = [x] + series

    # Each output term of order k only depends on c and u, never on earlier
    # output terms, so the terms can be produced independently.
    series_out = [
        fact(k - 1) * sum(
            _scale(k, j) * c[k - j] * u[j] for j in range(1, k + 1))
        for k in range(1, len(series) + 1)
    ]
    return primal_out, series_out


# erf'(x) = 2 / sqrt(pi) * exp(-x**2)
def_deriv(
    lax.erf_p,
    lambda x: lax.mul(
        lax._const(x, 2. / np.sqrt(np.pi)),
        lax.exp(lax.neg(lax.square(x)))))


def def_comp(prim, comp):
    """Register a jet rule for ``prim`` as a composition of simpler primitives.

    The rule stored in ``jet_rules`` is ``jet`` partially applied to ``comp``,
    so ``comp`` must compute the same function as ``prim``.
    """
    rule = partial(jet, comp)
    jet_rules[prim] = rule


def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1)
def_comp(lax.log1p_p, lambda x: lax.log(1 + x))
def_comp(lax.sqrt_p, lambda x: x ** 0.5)
def_comp(lax.rsqrt_p, lambda x: x ** -0.5)
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))
def norm(x, ord=None, axis: Union[None, Tuple[int, ...], int] = None, keepdims=False):
    """Compute a vector or matrix norm of ``x``.

    Args:
      x: array-like input; it is converted with ``jnp.asarray`` and passed
        through ``_promote_arg_dtypes`` before use.
      ord: order of the norm. With a single axis: ``None``/2, ``inf``,
        ``-inf``, 0, 1, or any other number ``p`` (computed as
        ``sum(|x|**p) ** (1/p)``). With two axes: ``None``/``'f'``/``'fro'``,
        1, -1, ``inf``, ``-inf``, 2, -2 or ``'nuc'``.
      axis: ``None``, an int (vector norm along that axis) or a tuple of ints.
        ``None`` with ``ord=None`` reduces over every element of an input of
        any rank; otherwise ``None`` means all axes of ``x``.
      keepdims: forwarded to the reductions so the reduced axes are kept with
        size one.

    Returns:
      The norm as an array.

    Raises:
      ValueError: if ``ord`` is a string for a vector norm, if ``ord`` is not
        a supported matrix order, or if the number of axes is neither one nor
        two.
    """
    x = _promote_arg_dtypes(jnp.asarray(x))
    x_shape = jnp.shape(x)
    ndim = len(x_shape)

    if axis is None:
        # NumPy has an undocumented behavior that admits arbitrary rank inputs
        # if `ord` is None: https://github.com/numpy/numpy/issues/14215
        if ord is None:
            return jnp.sqrt(
                jnp.sum(jnp.real(x * jnp.conj(x)), keepdims=keepdims))
        axis = tuple(range(ndim))
    elif isinstance(axis, tuple):
        axis = tuple(canonicalize_axis(ax, ndim) for ax in axis)
    else:
        axis = (canonicalize_axis(axis, ndim), )

    num_axes = len(axis)
    if num_axes == 1:
        if ord is None or ord == 2:
            return jnp.sqrt(
                jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis,
                        keepdims=keepdims))
        elif ord == jnp.inf:
            return jnp.amax(jnp.abs(x), axis=axis, keepdims=keepdims)
        elif ord == -jnp.inf:
            return jnp.amin(jnp.abs(x), axis=axis, keepdims=keepdims)
        elif ord == 0:
            return jnp.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype,
                           axis=axis, keepdims=keepdims)
        elif ord == 1:
            # Numpy has a special case for ord == 1 as an optimization. We
            # don't really need the optimization (XLA could do it for us), but
            # the Numpy code has slightly different type promotion semantics,
            # so we need a special case too.
            return jnp.sum(jnp.abs(x), axis=axis, keepdims=keepdims)
        elif isinstance(ord, str):
            # String orders only name matrix norms; reject them here instead
            # of failing later while converting the string to a number.
            raise ValueError(f"Invalid order '{ord}' for vector norm.")
        else:
            abs_x = jnp.abs(x)
            ord_arr = lax_internal._const(abs_x, ord)
            # Build the reciprocal exponent in the dtype of abs_x as well.
            # A plain `1. / ord_arr` can come back as a wider, strongly-typed
            # NumPy scalar and promote the result past the input dtype.
            ord_inv = lax_internal._const(abs_x, 1. / ord_arr)
            out = jnp.sum(abs_x**ord_arr, axis=axis, keepdims=keepdims)
            return jnp.power(out, ord_inv)

    elif num_axes == 2:
        row_axis, col_axis = cast(Tuple[int, ...], axis)
        if ord is None or ord in ('f', 'fro'):
            return jnp.sqrt(
                jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis,
                        keepdims=keepdims))
        elif ord == 1:
            # Without keepdims the inner sum removes row_axis, which shifts a
            # later col_axis down by one.
            if not keepdims and col_axis > row_axis:
                col_axis -= 1
            return jnp.amax(
                jnp.sum(jnp.abs(x), axis=row_axis, keepdims=keepdims),
                axis=col_axis, keepdims=keepdims)
        elif ord == -1:
            if not keepdims and col_axis > row_axis:
                col_axis -= 1
            return jnp.amin(
                jnp.sum(jnp.abs(x), axis=row_axis, keepdims=keepdims),
                axis=col_axis, keepdims=keepdims)
        elif ord == jnp.inf:
            # Same index adjustment, with the roles of the axes swapped.
            if not keepdims and row_axis > col_axis:
                row_axis -= 1
            return jnp.amax(
                jnp.sum(jnp.abs(x), axis=col_axis, keepdims=keepdims),
                axis=row_axis, keepdims=keepdims)
        elif ord == -jnp.inf:
            if not keepdims and row_axis > col_axis:
                row_axis -= 1
            return jnp.amin(
                jnp.sum(jnp.abs(x), axis=col_axis, keepdims=keepdims),
                axis=row_axis, keepdims=keepdims)
        elif ord in ('nuc', 2, -2):
            # These orders reduce over the singular values, so move the matrix
            # axes to the end before calling svd.
            x = jnp.moveaxis(x, axis, (-2, -1))
            if ord == 2:
                reducer = jnp.amax
            elif ord == -2:
                reducer = jnp.amin
            else:
                reducer = jnp.sum
            y = reducer(svd(x, compute_uv=False), axis=-1)
            if keepdims:
                y = jnp.expand_dims(y, axis)
            return y
        else:
            raise ValueError(f"Invalid order '{ord}' for matrix norm.")
    else:
        raise ValueError(
            f"Invalid axis values ({axis}) for jnp.linalg.norm.")