def test_ravel_pytree(pytree):
    """Check that ravel_pytree followed by its unravel function round-trips.

    Both the leaf values and the canonicalized leaf dtypes of the rebuilt
    pytree must match those of ``pytree``.
    """
    flat, unravel_fn = ravel_pytree(pytree)
    restored = unravel_fn(flat)

    # assert_allclose raises on mismatch; tree_multimap applies it per leaf.
    tree_flatten(tree_multimap(lambda got, want: assert_allclose(got, want),
                               restored, pytree))

    def same_dtype(got, want):
        return (canonicalize_dtype(lax.dtype(got)) ==
                canonicalize_dtype(lax.dtype(want)))

    dtype_matches, _ = tree_flatten(tree_multimap(same_dtype, restored, pytree))
    assert all(dtype_matches)
def _uniform(key, shape, dtype, minval, maxval):
    """Draw floats uniformly from ``[minval, maxval)`` (private implementation).

    Args:
      key: a PRNGKey used as the random key.
      shape: shape of the result.
      dtype: a floating point dtype; canonicalized before use.
      minval: lower bound, converted to the canonical ``dtype``.
      maxval: upper bound, converted to the canonical ``dtype``.

    Returns:
      A random array of shape ``shape``, clamped below at ``minval``.

    Raises:
      TypeError: if ``dtype`` is not floating point or is not 32/64 bits wide.
    """
    if not onp.issubdtype(dtype, onp.floating):
        raise TypeError("uniform only accepts floating point dtypes.")

    dtype = xla_bridge.canonicalize_dtype(dtype)
    minval = lax.convert_element_type(minval, dtype)
    maxval = lax.convert_element_type(maxval, dtype)

    info = onp.finfo(dtype)
    nbits = info.bits
    nmant = info.nmant
    if nbits not in (32, 64):
        raise TypeError("uniform only accepts 32- or 64-bit dtypes.")

    raw_bits = _random_bits(key, nbits, shape)

    # Shift the random word right so only `nmant` random bits remain, then OR
    # them into the bit pattern of 1.0. The result reinterprets as a float in
    # [1, 2), and subtracting 1.0 gives [0, 1). This relies on NumPy and XLA
    # sharing bit-for-bit identical float layouts, which might not hold on
    # every platform.
    uint_dtype = onp.uint32 if nbits == 32 else onp.uint64
    mantissa = lax.shift_right_logical(
        raw_bits, onp.array(nbits - nmant, lax.dtype(raw_bits)))
    one_bits = onp.array(1., dtype).view(uint_dtype)
    float_bits = lax.bitwise_or(mantissa, one_bits)
    unit = lax.bitcast_convert_type(float_bits, dtype) - onp.array(1., dtype)

    scaled = unit * (maxval - minval) + minval
    return lax.max(minval, lax.reshape(scaled, shape))
def prepare_single_layer_model(input_size, output_size, width, key):
    """Build a one-hidden-layer MLP and initialise its parameters.

    The network is Dense(width) -> Relu -> Dense(output_size) -> LogSoftmax.
    Every parameter leaf is cast to ``canonicalize_dtype(onp.float64)``.

    Returns:
      ``(predict, params, key)``, where ``key`` is a fresh PRNG key that was
      not used for initialisation.
    """
    layers = (Dense(width), Relu, Dense(output_size), LogSoftmax)
    init_random_params, predict = stax.serial(*layers)

    key, init_key = random.split(key)
    _, params = init_random_params(init_key, (-1, input_size))

    target_dtype = canonicalize_dtype(onp.float64)
    params = tree_util.tree_map(lambda leaf: leaf.astype(target_dtype), params)
    return predict, params, key
def __init__(self, value=0., log_density=0., event_ndim=0, validate_args=None):
    """Point-mass (Delta) distribution located at ``value``.

    Args:
      value: location of the point mass; converted to
        ``xla_bridge.canonicalize_dtype(np.float64)``.
      log_density: log density at ``value``; broadcast to the batch shape.
      event_ndim: number of trailing dimensions of ``value`` that form the
        event shape; the leading dimensions form the batch shape.
      validate_args: forwarded to the base class constructor.

    Raises:
      ValueError: if ``event_ndim`` exceeds the number of dimensions of
        ``value``.
    """
    value_ndim = np.ndim(value)
    if event_ndim > value_ndim:
        raise ValueError('Expected event_dim <= v.dim(), actual {} vs {}'
                         .format(event_ndim, value_ndim))

    value_shape = np.shape(value)
    split_at = value_ndim - event_ndim
    batch_shape, event_shape = value_shape[:split_at], value_shape[split_at:]

    self.value = lax.convert_element_type(
        value, xla_bridge.canonicalize_dtype(np.float64))
    # Following the Pyro implementation, log_density is broadcast to batch_shape.
    self.log_density = promote_shapes(log_density, shape=batch_shape)[0]
    super(Delta, self).__init__(batch_shape, event_shape,
                                validate_args=validate_args)
def get_batch(input_size, output_size, batch_size, key):
    """Sample a random batch of inputs and one-hot encoded targets.

    Returns:
      ``((xs, ys), key)``. ``xs`` has shape ``(batch_size, input_size)``.
      ``ys`` is ``to_onehot`` applied to integer labels drawn from
      ``[0, output_size)``. ``key`` is a fresh PRNG key for the next call.
    """
    key, x_key = random.split(key)
    # The dtype is passed explicitly; per the original note, jax.random will
    # always generate float32 even if jax_enable_x64 == True.
    xs = random.normal(x_key, shape=(batch_size, input_size),
                       dtype=canonicalize_dtype(onp.float64))

    key, y_key = random.split(key)
    labels = random.randint(y_key, minval=0, maxval=output_size,
                            shape=(batch_size,))
    ys = to_onehot(labels, output_size)
    return (xs, ys), key
def randint(key, shape, minval, maxval, dtype=onp.int32):
    """Sample uniform random values in [minval, maxval) with given shape/dtype.

    Args:
      key: a PRNGKey used as the random key.
      shape: a tuple of nonnegative integers representing the shape.
      minval: a minimum (inclusive) value for the range. Required.
      maxval: a maximum (exclusive) value for the range. Required. If
        ``maxval <= minval``, every sample equals ``minval``.
      dtype: optional, a 32- or 64-bit int dtype for the returned values
        (default int32).

    Returns:
      A random array with the specified shape and dtype.

    Raises:
      TypeError: if ``dtype`` is not an integer dtype or is not 32/64 bits wide.
    """
    if not onp.issubdtype(dtype, onp.integer):
        raise TypeError("randint only accepts integer dtypes.")

    dtype = xla_bridge.canonicalize_dtype(dtype)
    minval = lax.convert_element_type(minval, dtype)
    maxval = lax.convert_element_type(maxval, dtype)
    nbits = onp.iinfo(dtype).bits

    if nbits not in (32, 64):
        raise TypeError("randint only accepts 32- or 64-bit dtypes.")

    # If we don't have minval < maxval, just always return minval.
    # https://github.com/google/jax/issues/222
    maxval = lax.max(lax.add(minval, onp.array(1, dtype)), maxval)

    # This algorithm is biased whenever (maxval - minval) is not a power of 2.
    # We generate double the number of random bits required by the dtype so as
    # to reduce that bias.
    k1, k2 = split(key)
    rbits = lambda key: _random_bits(key, nbits, shape)
    higher_bits, lower_bits = rbits(k1), rbits(k2)

    unsigned_dtype = onp.uint32 if nbits == 32 else onp.uint64
    span = lax.convert_element_type(maxval - minval, unsigned_dtype)

    # To compute a remainder operation on an integer that might have twice as
    # many bits as we can represent in the native unsigned dtype, we compute a
    # multiplier equal to 2**nbits % span (using that nbits is 32 or 64).
    multiplier = lax.rem(onp.array(2**16, unsigned_dtype), span)
    multiplier = lax.rem(lax.mul(multiplier, multiplier), span)
    if nbits == 64:
        multiplier = lax.rem(lax.mul(multiplier, multiplier), span)

    # NOTE(review): both factors below are < span, so for large spans the
    # unsigned product can wrap around, which adds bias. Confirm this is
    # acceptable.
    random_offset = lax.add(lax.mul(lax.rem(higher_bits, span), multiplier),
                            lax.rem(lower_bits, span))
    random_offset = lax.rem(random_offset, span)
    return lax.add(minval, lax.convert_element_type(random_offset, dtype))
def normal(key, shape=(), dtype=onp.float64):
    """Draw standard normal samples of the given shape and float dtype.

    Args:
      key: a PRNGKey used as the random key.
      shape: a tuple of nonnegative integers giving the result shape
        (default ``()``).
      dtype: optional float dtype, canonicalized before sampling (default
        float64 if jax_enable_x64 is true, otherwise float32).

    Returns:
      A random array with the requested shape and dtype.
    """
    return _normal(key, shape, xla_bridge.canonicalize_dtype(dtype))
def _ravel_list(*leaves): leaves_metadata = tree_map( lambda l: pytree_metadata(np.ravel(l), np.shape(l), np.size(l), canonicalize_dtype(lax.dtype(l))), leaves) leaves_idx = np.cumsum( np.array((0, ) + tuple(d.size for d in leaves_metadata))) def unravel_list(arr): return [ np.reshape(lax.dynamic_slice_in_dim(arr, leaves_idx[i], m.size), m.shape).astype(m.dtype) for i, m in enumerate(leaves_metadata) ] return np.concatenate([m.flat for m in leaves_metadata]), unravel_list
def eig_abstract_eval(operand):
    """Abstract evaluation rule for the nonsymmetric eigendecomposition.

    Args:
      operand: a ShapedArray of shape ``[..., n, n]``.

    Returns:
      ``(w, vl, vr)`` as ShapedArrays: ``w`` has shape ``[..., n]`` and
      ``vl``/``vr`` have shape ``[..., n, n]``. All share one complex dtype:
      complex64 when ``onp.finfo(operand.dtype).bits == 32``, else complex128,
      then canonicalized.

    Raises:
      NotImplementedError: if ``operand`` is not a ShapedArray.
      ValueError: if ``operand`` is not a (batch of) square matrix.
    """
    if not isinstance(operand, ShapedArray):
        raise NotImplementedError

    shape = operand.shape
    if operand.ndim < 2 or shape[-2] != shape[-1]:
        raise ValueError("Argument to nonsymmetric eigendecomposition must have "
                         "shape [..., n, n], got shape {}".format(shape))

    batch_dims, n = shape[:-2], shape[-1]
    is_single = onp.finfo(operand.dtype).bits == 32
    dtype = xb.canonicalize_dtype(onp.complex64 if is_single else onp.complex128)

    vl = vr = ShapedArray(batch_dims + (n, n), dtype)
    w = ShapedArray(batch_dims + (n,), dtype)
    return w, vl, vr
def uniform(key, shape=(), dtype=onp.float64, minval=0., maxval=1.):
    """Draw floats uniformly from ``[minval, maxval)``.

    Args:
      key: a PRNGKey used as the random key.
      shape: a tuple of nonnegative integers giving the result shape
        (default ``()``).
      dtype: optional float dtype, canonicalized before sampling (default
        float64 if jax_enable_x64 is true, otherwise float32).
      minval: optional inclusive lower bound (default 0).
      maxval: optional exclusive upper bound (default 1).

    Returns:
      A random array with the requested shape and dtype.
    """
    float_dtype = xla_bridge.canonicalize_dtype(dtype)
    return _uniform(key, shape, float_dtype, minval, maxval)
def t(key, df, shape=(), dtype=onp.float64):
    """Draw samples from Student's t distribution.

    Args:
      key: a PRNGKey used as the random key.
      df: array-like broadcastable to ``shape``; the distribution's ``df``
        parameter.
      shape: optional result shape as a tuple of nonnegative integers
        (default ``()``, a scalar).
      dtype: optional float dtype, canonicalized before sampling (default
        float64 if jax_enable_x64 is true, otherwise float32).

    Returns:
      A random array with the requested shape and dtype.
    """
    float_dtype = xla_bridge.canonicalize_dtype(dtype)
    return _t(key, df, shape, float_dtype)
def dirichlet(key, alpha, shape=(), dtype=onp.float64):
    """Sample Dirichlet random values with given shape and float dtype.

    Args:
      key: a PRNGKey used as the random key.
      alpha: an array-like with ``alpha.shape[:-1]`` broadcastable to ``shape``
        and used as the concentration parameter of the random variables.
      shape: optional, a tuple of nonnegative integers representing the batch
        shape (defaults to ``alpha.shape[:-1]``).
      dtype: optional, a float dtype for the returned values (default float64
        if jax_enable_x64 is true, otherwise float32).

    Returns:
      A random array with the specified shape and dtype.
    """
    dtype = xla_bridge.canonicalize_dtype(dtype)
    return _dirichlet(key, alpha, shape, dtype)
def randint(key, shape, minval, maxval, dtype=onp.int64):
    """Draw integers uniformly from ``[minval, maxval)``.

    Args:
      key: a PRNGKey used as the random key.
      shape: a tuple of nonnegative integers giving the result shape.
      minval: int or array of ints broadcast-compatible with ``shape``; the
        inclusive lower bound.
      maxval: int or array of ints broadcast-compatible with ``shape``; the
        exclusive upper bound.
      dtype: optional int dtype, canonicalized before sampling (default int64
        if jax_enable_x64 is true, otherwise int32).

    Returns:
      A random array with the requested shape and dtype.
    """
    int_dtype = xla_bridge.canonicalize_dtype(dtype)
    return _randint(key, shape, minval, maxval, int_dtype)
def bernoulli(key, p=onp.float32(0.5), shape=()):
    """Draw boolean Bernoulli samples with success probability ``p``.

    Args:
      key: a PRNGKey used as the random key.
      p: optional floating array-like broadcastable to ``shape`` giving the
        mean of each variable (default 0.5).
      shape: optional result shape as a tuple of nonnegative integers
        (default ``()``, a scalar).

    Returns:
      A boolean random array of the requested shape.

    Raises:
      TypeError: if the canonicalized dtype of ``p`` is not floating point.
    """
    p_dtype = xla_bridge.canonicalize_dtype(lax.dtype(p))
    if onp.issubdtype(p_dtype, onp.floating):
        return _bernoulli(key, lax.convert_element_type(p, p_dtype), shape)

    msg = "bernoulli probability `p` must have a floating dtype, got {}."
    raise TypeError(msg.format(p_dtype))
def t(key, df, shape=(), dtype=onp.float64):
    """Sample Student's t random values with given shape and float dtype.

    Args:
      key: a PRNGKey used as the random key.
      df: a float or array of floats broadcast-compatible with ``shape``
        representing the parameter of the distribution.
      shape: optional, a tuple of nonnegative integers specifying the result
        shape. Must be broadcast-compatible with ``df``. Defaults to ``()``,
        a scalar result. Passing ``None`` produces a result shape equal to
        ``df.shape``.
      dtype: optional, a float dtype for the returned values (default float64
        if jax_enable_x64 is true, otherwise float32).

    Returns:
      A random array with the specified dtype and with shape given by
      ``shape`` if ``shape`` is not None, or else by ``df.shape``.
    """
    dtype = xla_bridge.canonicalize_dtype(dtype)
    return _t(key, df, shape, dtype)
def bernoulli(key, p=onp.float32(0.5), shape=None):
    """Draw boolean Bernoulli samples with success probability ``p``.

    Args:
      key: a PRNGKey used as the random key.
      p: optional float or array of floats giving the mean of each variable;
        must be broadcast-compatible with ``shape`` (default 0.5).
      shape: optional tuple of nonnegative integers for the result shape;
        must be broadcast-compatible with ``p.shape``. ``None`` (the default)
        means the result takes ``p.shape``.

    Returns:
      A boolean random array whose shape is ``shape``, or ``p.shape`` when
      ``shape`` is None.

    Raises:
      TypeError: if the canonicalized dtype of ``p`` is not floating point.
    """
    p_dtype = xla_bridge.canonicalize_dtype(lax.dtype(p))
    if onp.issubdtype(p_dtype, onp.floating):
        return _bernoulli(key, lax.convert_element_type(p, p_dtype), shape)

    msg = "bernoulli probability `p` must have a floating dtype, got {}."
    raise TypeError(msg.format(p_dtype))
def beta(key, a, b, shape=None, dtype=onp.float64):
    """Sample Beta random values with given shape and float dtype.

    Args:
      key: a PRNGKey used as the random key.
      a: a float or array of floats broadcast-compatible with ``shape``
        representing the first parameter "alpha".
      b: a float or array of floats broadcast-compatible with ``shape``
        representing the second parameter "beta".
      shape: optional, a tuple of nonnegative integers specifying the result
        shape. Must be broadcast-compatible with ``a`` and ``b``. The default
        (None) produces a result shape by broadcasting ``a`` and ``b``.
      dtype: optional, a float dtype for the returned values (default float64
        if jax_enable_x64 is true, otherwise float32).

    Returns:
      A random array with the specified dtype and shape given by ``shape`` if
      ``shape`` is not None, or else by broadcasting ``a`` and ``b``.
    """
    dtype = xla_bridge.canonicalize_dtype(dtype)
    return _beta(key, a, b, shape, dtype)
def _randint(key, shape, minval, maxval, dtype=onp.int32):
    """Private implementation of ``randint``: integers in ``[minval, maxval)``.

    Args:
      key: a PRNGKey used as the random key.
      shape: result shape, validated by ``_check_shape``.
      minval: inclusive lower bound, converted to the canonical ``dtype``.
      maxval: exclusive upper bound, converted to the canonical ``dtype``. If
        it is not greater than ``minval``, every sample equals ``minval``.
      dtype: a 32- or 64-bit integer dtype (default int32).

    Raises:
      TypeError: if ``dtype`` is not an integer dtype or is not 32/64 bits wide.
    """
    _check_shape("randint", shape)
    if not onp.issubdtype(dtype, onp.integer):
        raise TypeError("randint only accepts integer dtypes.")

    dtype = xla_bridge.canonicalize_dtype(dtype)
    lo = lax.convert_element_type(minval, dtype)
    hi = lax.convert_element_type(maxval, dtype)
    width = onp.iinfo(dtype).bits
    if width not in (32, 64):
        raise TypeError("randint only accepts 32- or 64-bit dtypes.")

    # Force hi >= lo + 1 so a degenerate range always yields lo
    # (https://github.com/google/jax/issues/222).
    hi = lax.max(lax.add(lo, onp.array(1, dtype)), hi)

    # Draw two words of random bits per element to reduce the modulo bias that
    # arises whenever the range size is not a power of two.
    key_hi, key_lo = split(key)
    hi_word = _random_bits(key_hi, width, shape)
    lo_word = _random_bits(key_lo, width, shape)

    udtype = onp.uint32 if width == 32 else onp.uint64
    span = lax.convert_element_type(hi - lo, udtype)

    # Combine the two words as (hi_word * 2**width + lo_word) % span without a
    # 2*width-bit integer. First build factor = 2**width % span by squaring
    # 2**16 once (32-bit) or twice (64-bit), reducing mod span each time.
    factor = lax.rem(onp.array(2**16, udtype), span)
    squarings = 1 if width == 32 else 2
    for _ in range(squarings):
        factor = lax.rem(lax.mul(factor, factor), span)

    offset = lax.add(lax.mul(lax.rem(hi_word, span), factor),
                     lax.rem(lo_word, span))
    offset = lax.rem(offset, span)
    return lax.add(lo, lax.convert_element_type(offset, dtype))
def dirichlet(key, alpha, shape=None, dtype=onp.float64):
    """Sample Dirichlet random values with given shape and float dtype.

    Args:
      key: a PRNGKey used as the random key.
      alpha: an array of shape ``(..., n)`` used as the concentration
        parameter of the random variables.
      shape: optional, a tuple of nonnegative integers specifying the result
        batch shape; that is, the prefix of the result shape excluding the
        last axis of size ``n``. Must be broadcast-compatible with
        ``alpha.shape[:-1]``. The default (None) produces a result shape equal
        to ``alpha.shape``.
      dtype: optional, a float dtype for the returned values (default float64
        if jax_enable_x64 is true, otherwise float32).

    Returns:
      A random array with the specified dtype and shape given by
      ``shape + (alpha.shape[-1],)`` if ``shape`` is not None, or else
      ``alpha.shape``.
    """
    dtype = xla_bridge.canonicalize_dtype(dtype)
    return _dirichlet(key, alpha, shape, dtype)
def truncated_normal(key, lower, upper, shape=None, dtype=onp.float64):
    """Draw standard normal samples truncated to ``[lower, upper]``.

    Args:
      key: a PRNGKey used as the random key.
      lower: float or array of floats; the lower truncation bound. Must be
        broadcast-compatible with ``upper``.
      upper: float or array of floats; the upper truncation bound. Must be
        broadcast-compatible with ``lower``.
      shape: optional tuple of nonnegative integers for the result shape;
        must be broadcast-compatible with ``lower`` and ``upper``. ``None``
        (the default) means the broadcast shape of ``lower`` and ``upper``.
      dtype: optional float dtype, canonicalized before sampling (default
        float64 if jax_enable_x64 is true, otherwise float32).

    Returns:
      A random array whose shape is ``shape``, or the broadcast shape of
      ``lower`` and ``upper`` when ``shape`` is None.
    """
    float_dtype = xla_bridge.canonicalize_dtype(dtype)
    return _truncated_normal(key, lower, upper, shape, float_dtype)
def uniform(key, shape, dtype=onp.float32, minval=0., maxval=1.):
    """Draw floats uniformly from ``[minval, maxval)``.

    Args:
      key: a PRNGKey used as the random key.
      shape: a tuple of nonnegative integers giving the result shape.
      dtype: optional 32- or 64-bit float dtype (default float32).
      minval: optional inclusive lower bound (default 0).
      maxval: optional exclusive upper bound (default 1).

    Returns:
      A random array with the requested shape and dtype.

    Raises:
      TypeError: if ``dtype`` is not floating point or is not 32/64 bits wide.
    """
    if not onp.issubdtype(dtype, onp.floating):
        raise TypeError("uniform only accepts floating point dtypes.")

    dtype = xla_bridge.canonicalize_dtype(dtype)
    lo = lax.convert_element_type(minval, dtype)
    hi = lax.convert_element_type(maxval, dtype)

    info = onp.finfo(dtype)
    width, mant_bits = info.bits, info.nmant
    if width not in (32, 64):
        raise TypeError("uniform only accepts 32- or 64-bit dtypes.")

    raw = _random_bits(key, width, shape)

    # Shift the random word right so only `mant_bits` random bits remain, then
    # OR them into the bit pattern of 1.0. The result reinterprets as a float
    # in [1, 2), and subtracting 1.0 gives [0, 1). This relies on NumPy and
    # XLA sharing bit-for-bit identical float layouts, which might not hold
    # on every platform.
    as_uint = onp.uint32 if width == 32 else onp.uint64
    mantissa = lax.shift_right_logical(
        raw, onp.array(width - mant_bits, lax._dtype(raw)))
    one_pattern = onp.array(1., dtype).view(as_uint)
    in_one_two = lax.bitcast_convert_type(
        lax.bitwise_or(mantissa, one_pattern), dtype)
    unit = in_one_two - onp.array(1., dtype)

    return lax.max(lo, lax.reshape(unit * (hi - lo) + lo, shape))
def multivariate_normal(key, mean, cov, shape=None, dtype=onp.float64):
    """Draw multivariate normal samples with the given mean and covariance.

    Args:
      key: a PRNGKey used as the random key.
      mean: mean vectors of shape ``(..., n)``.
      cov: positive definite covariance matrices of shape ``(..., n, n)``;
        the leading batch dims must be broadcast-compatible with those of
        ``mean``.
      shape: optional batch shape of the result (everything except the last
        axis). It must be broadcast-compatible with ``mean.shape[:-1]`` and
        ``cov.shape[:-2]``. ``None`` (the default) means the two batch shapes
        broadcast together.
      dtype: optional float dtype, canonicalized before sampling (default
        float64 if jax_enable_x64 is true, otherwise float32).

    Returns:
      A random array of shape ``shape + mean.shape[-1:]``. When ``shape`` is
      None, the shape is
      ``broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:]``.
    """
    float_dtype = xla_bridge.canonicalize_dtype(dtype)
    return _multivariate_normal(key, mean, cov, shape, float_dtype)
def float_types():
    """Return the set of canonical dtypes for float32 and float64.

    If canonicalization maps both to the same dtype, the set has one element.
    """
    candidates = (onp.float32, onp.float64)
    return {onp.dtype(xla_bridge.canonicalize_dtype(t)) for t in candidates}
def get_dtype(x):
    """Return the canonicalized form of ``lax.dtype(x)``."""
    raw_dtype = lax.dtype(x)
    return canonicalize_dtype(raw_dtype)
def poisson(key, rate, shape, dtype=np.int64):
    """Draw Poisson samples with the given ``rate`` and ``shape``.

    ``dtype`` is canonicalized before being forwarded to ``_poisson``.
    """
    return _poisson(key, rate, shape, canonicalize_dtype(dtype))
def complex_types():
    """Return a sorted list of the distinct canonical complex dtypes."""
    canonical = {onp.dtype(xla_bridge.canonicalize_dtype(t))
                 for t in (onp.complex64, onp.complex128)}
    return sorted(canonical)
def standard_gamma(key, alpha, shape=(), dtype=np.float64):
    """Draw standard gamma samples with concentration ``alpha``.

    ``dtype`` is canonicalized before being forwarded to ``_standard_gamma``.
    """
    float_dtype = xla_bridge.canonicalize_dtype(dtype)
    return _standard_gamma(key, alpha, shape, float_dtype)
def _inputs_to_kernel(x1, x2, use_pooling, compute_ntk):
    """Transforms (batches of) inputs to a `Kernel`.

    This is a private method. Docstring and example are for internal reference.

    The kernel contains the empirical covariances between different inputs and
    their entries (pixels) necessary to compute the covariance of the Gaussian
    Process corresponding to an infinite Bayesian or gradient-flow-trained
    neural network.

    The smallest necessary number of covariance entries is tracked. For
    example, all networks are assumed to have i.i.d. weights along the
    channel / feature / logits dimensions, hence covariance between different
    entries along these dimensions is known to be 0 and is not tracked.

    Args:
      x1: a 2D `np.ndarray` of shape `[batch_size_1, n_features]` (dense
        network) or 4D of shape `[batch_size_1, height, width, channels]`
        (conv-nets).
      x2: an optional `np.ndarray` with the same shape as `x1` apart from
        possibly different leading batch size. `None` means `x2 == x1`.
      use_pooling: a boolean, indicating whether pooling will be used
        somewhere in the model. If so, more covariance entries need to be
        tracked. It is set automatically based on the network topology.
        Specifically, it is set to `False` if a `serial` or `parallel`
        network contains a `Flatten` layer and no pooling layers (`AvgPool`
        or `GlobalAvgPool`). Has no effect for non-convolutional models.
      compute_ntk: a boolean, `True` to compute both NTK and NNGP kernels,
        `False` to only compute NNGP.

    Example:
      ```python
          >>> x = np.ones((10, 32, 16, 3))
          >>> _inputs_to_kernel(x, None, use_pooling=True,
          ...                   compute_ntk=True).nngp.shape
          (10, 10, 32, 32, 16, 16)
          >>> _inputs_to_kernel(x, None, use_pooling=False,
          ...                   compute_ntk=True).nngp.shape
          (10, 10, 32, 16)
          >>> x1 = np.ones((10, 128))
          >>> x2 = np.ones((20, 128))
          >>> _inputs_to_kernel(x1, x2, use_pooling=True,
          ...                   compute_ntk=False).nngp.shape
          (10, 20)
          >>> _inputs_to_kernel(x1, x2, use_pooling=False,
          ...                   compute_ntk=False).nngp.shape
          (10, 20)
          >>> _inputs_to_kernel(x1, x2, use_pooling=False,
          ...                   compute_ntk=False).ntk is None
          True
      ```

    Returns:
      a `Kernel` object. Its `ntk` field is the scalar `0.` when
      `compute_ntk` is `True` and `None` otherwise.

    Raises:
      ValueError: if `x1` and `x2` differ in non-batch shape, or if the inputs
        are neither 2D nor 4D.
    """
    x1 = x1.astype(xla_bridge.canonicalize_dtype(np.float64))
    var1 = _get_variance(x1)

    if x2 is None:
        x2 = x1
        var2 = None
    else:
        if x1.shape[1:] != x2.shape[1:]:
            raise ValueError(
                '`x1` and `x2` are expected to be batches of'
                ' inputs with the same shape (apart from the batch size),'
                ' got %s and %s.' % (str(x1.shape), str(x2.shape)))

        x2 = x2.astype(xla_bridge.canonicalize_dtype(np.float64))
        var2 = _get_variance(x2)

    if use_pooling and x1.ndim == 4:
        # With x2 expanded to [N2, H, W, C, 1], np.dot contracts the channel
        # axis and yields [N1, H, W, N2, H, W, 1]. After the squeeze, the
        # transpose reorders this to [N1, N2, H, H, W, W].
        x2 = np.expand_dims(x2, -1)
        nngp = np.dot(x1, x2) / x1.shape[-1]
        nngp = np.transpose(np.squeeze(nngp, -1), (0, 3, 1, 4, 2, 5))
    elif x1.ndim == 4 or x1.ndim == 2:
        nngp = _batch_uncentered_covariance(x1, x2)
    else:
        raise ValueError('Inputs must be 2D or 4D `np.ndarray`s of shape '
                         '`[batch_size, n_features]` or '
                         '`[batch_size, height, width, channels]`, '
                         'got %s.' % str(x1.shape))

    # The NTK is initialised to the scalar 0. (None when it is not computed).
    ntk = 0. if compute_ntk else None
    is_gaussian = False
    is_height_width = True
    return Kernel(var1, nngp, var2, ntk, is_gaussian, is_height_width)
def init_kernel(init_params, num_warmup, step_size=1.0, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=2*math.pi, max_tree_depth=10, run_warmup=True, progbar=True, rng=PRNGKey(0)): """ Initializes the HMC sampler. :param init_params: Initial parameters to begin sampling. The type must be consistent with the input type to `potential_fn`. :param int num_warmup: Number of warmup steps; samples generated during warmup are discarded. :param float step_size: Determines the size of a single step taken by the verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1. :param bool adapt_step_size: A flag to decide if we want to adapt step_size during warm-up phase using Dual Averaging scheme. :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass matrix during warm-up phase using Welford scheme. :param bool dense_mass: A flag to decide if mass matrix is dense or diagonal (default when ``dense_mass=False``) :param float target_accept_prob: Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence the sampling will be slower but more robust. Default to 0.8. :param float trajectory_length: Length of a MCMC trajectory for HMC. Default value is :math:`2\\pi`. :param int max_tree_depth: Max depth of the binary tree created during the doubling scheme of NUTS sampler. Defaults to 10. :param bool run_warmup: Flag to decide whether warmup is run. If ``True``, `init_kernel` returns an initial :data:`~numpyro.mcmc.HMCState` that can be used to generate samples using MCMC. Else, returns the arguments and callable that does the initial adaptation. :param bool progbar: Whether to enable progress bar updates. Defaults to ``True``. :param jax.random.PRNGKey rng: random key to be used as the source of randomness. 
""" step_size = lax.convert_element_type(step_size, xla_bridge.canonicalize_dtype(np.float64)) nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth, wa_steps wa_steps = num_warmup trajectory_len = trajectory_length max_treedepth = max_tree_depth z = init_params z_flat, unravel_fn = ravel_pytree(z) momentum_generator = partial(_sample_momentum, unravel_fn) find_reasonable_ss = partial(find_reasonable_step_size, potential_fn, kinetic_fn, momentum_generator) wa_init, wa_update = warmup_adapter(num_warmup, adapt_step_size=adapt_step_size, adapt_mass_matrix=adapt_mass_matrix, dense_mass=dense_mass, target_accept_prob=target_accept_prob, find_reasonable_step_size=find_reasonable_ss) rng_hmc, rng_wa = random.split(rng) wa_state = wa_init(z, rng_wa, step_size, mass_matrix_size=np.size(z_flat)) r = momentum_generator(wa_state.mass_matrix_sqrt, rng) vv_state = vv_init(z, r) hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, 0, 0., 0., False, wa_state, rng_hmc) # TODO: Remove; this should be the responsibility of the MCMC class. if run_warmup and num_warmup > 0: # JIT if progress bar updates not required if not progbar: hmc_state = fori_loop(0, num_warmup, lambda *args: sample_kernel(args[1]), hmc_state) else: with tqdm.trange(num_warmup, desc='warmup') as t: for i in t: hmc_state = jit(sample_kernel)(hmc_state) t.set_postfix_str(get_diagnostics_str(hmc_state), refresh=False) return hmc_state
def _get_num_steps(step_size, trajectory_length):
    """Return ``trajectory_length / step_size`` as an integer step count.

    The ratio is clipped below at 1 and then cast to
    ``xla_bridge.canonicalize_dtype(np.int64)``.
    """
    ratio = trajectory_length / step_size
    steps = np.clip(ratio, a_min=1)
    # NB: casting to np.int64 does not take effect (returns np.int32 instead)
    # if jax_enable_x64 is False.
    int_dtype = xla_bridge.canonicalize_dtype(np.int64)
    return steps.astype(int_dtype)