def test_scan_prims_disabled():
    """``scan`` under ``control_flow_prims_disabled`` must agree with ``lax.scan``.

    The same step function, initial ``Tree`` and scanned inputs are run
    through both implementations, and every field of the two results is
    compared with ``assert_allclose``.
    """
    def step(carry, yz):
        # Each scanned element is a ``(y, z)`` pair applied to every leaf.
        y, z = yz
        return tree_map(lambda leaf: (leaf + y) * z, carry)

    Tree = laxtuple("Tree", ["x", "y", "z"])
    init = Tree(np.array([1., 2.]),
                np.array(3., dtype=np.float32),
                np.array(4., dtype=np.float32))
    xs = (np.array([1., 2., 3., 4.]), np.array([4., 3., 2., 1.]))

    # Reference result from ``lax.scan``, computed outside the context manager.
    expected_tree = lax.scan(step, init, xs)
    with control_flow_prims_disabled():
        actual_tree = scan(step, init, xs)

    for field in ("x", "y", "z"):
        assert_allclose(getattr(actual_tree, field), getattr(expected_tree, field))
import math

import jax.numpy as np
from jax import partial, random
from jax.flatten_util import ravel_pytree
from jax.random import PRNGKey

import numpyro.distributions as dist
from numpyro.hmc_util import (
    IntegratorState,
    build_tree,
    find_reasonable_step_size,
    velocity_verlet,
    warmup_adapter,
)
from numpyro.util import cond, fori_loop, laxtuple

# Named-tuple type (built with ``laxtuple``) for the HMC sampler state.
HMCState = laxtuple('HMCState', [
    'z',
    'z_grad',
    'potential_energy',
    'num_steps',
    'accept_prob',
    'step_size',
    'inverse_mass_matrix',
    'rng',
])


def _get_num_steps(step_size, trajectory_length):
    """Return the number of integration steps needed to cover a trajectory.

    :param step_size: size of a single integration step.
    :param trajectory_length: total length of the trajectory.
    :return: ``trajectory_length / step_size`` converted to an ``int32`` array
        (fractional part truncated), clamped so that it is never below 1.
    """
    num_steps = np.array(trajectory_length / step_size, dtype=np.int32)
    # Truncation yields 0 when step_size > trajectory_length; always take >= 1 step.
    return np.where(num_steps < 1, np.array(1, dtype=np.int32), num_steps)


def _sample_momentum(unpack_fn, inverse_mass_matrix, rng):
    """Draw a random momentum vector.

    :param unpack_fn: callable applied to the flat sampled array before it is
        returned (presumably an unravel function such as the one produced by
        ``ravel_pytree`` -- confirm against callers).
    :param inverse_mass_matrix: array whose ``ndim`` selects the sampling
        scheme; only the 1-D case is implemented.
    :param rng: random key forwarded to ``rvs(random_state=...)``.
    :return: ``unpack_fn(r)`` where ``r`` is drawn from
        ``dist.norm(0., sqrt(1 / inverse_mass_matrix))``.
    :raises NotImplementedError: if ``inverse_mass_matrix`` is 2-D.
    :raises ValueError: if ``inverse_mass_matrix`` is neither 1-D nor 2-D.
    """
    if inverse_mass_matrix.ndim == 1:
        scale = np.sqrt(np.reciprocal(inverse_mass_matrix))
        r = dist.norm(0., scale).rvs(random_state=rng)
        return unpack_fn(r)
    elif inverse_mass_matrix.ndim == 2:
        raise NotImplementedError(
            'Sampling momentum for a 2-D inverse_mass_matrix is not implemented.')
    # Previously this case fell through and silently returned None.
    raise ValueError('inverse_mass_matrix must be 1-D or 2-D, got ndim={}.'.format(
        inverse_mass_matrix.ndim))
import jax import jax.numpy as np from jax import grad, jit, partial, random, value_and_grad, vmap from jax.flatten_util import ravel_pytree from jax.ops import index_update from jax.scipy.special import expit from jax.tree_util import tree_multimap from numpyro.distributions.constraints import biject_to from numpyro.distributions.util import cholesky_inverse from numpyro.handlers import seed, substitute, trace from numpyro.util import cond, laxtuple, while_loop AdaptWindow = laxtuple("AdaptWindow", ["start", "end"]) AdaptState = laxtuple("AdaptState", ["step_size", "inverse_mass_matrix", "mass_matrix_sqrt", "ss_state", "mm_state", "window_idx", "rng"]) IntegratorState = laxtuple("IntegratorState", ["z", "r", "potential_energy", "z_grad"]) TreeInfo = laxtuple('TreeInfo', ['z_left', 'r_left', 'z_left_grad', 'z_right', 'r_right', 'z_right_grad', 'z_proposal', 'z_proposal_pe', 'z_proposal_grad', 'depth', 'weight', 'r_sum', 'turning', 'diverging', 'sum_accept_probs', 'num_proposals']) def dual_averaging(t0=10, kappa=0.75, gamma=0.05): """ Dual Averaging is a scheme to solve convex optimization problems. It belongs to a class of subgradient methods which uses subgradients (which lie in a dual space) to update states (in primal space) of a model. Under some conditions, the averages of generated parameters during the scheme are