示例#1
0
def test_scan_prims_disabled():
    """Check that the pure-Python `scan` fallback produces the same result
    as `lax.scan` when JAX control-flow primitives are disabled."""
    def body(carry, pair):
        shift, scale = pair
        return tree_map(lambda leaf: (leaf + shift) * scale, carry)

    Tree = laxtuple("Tree", ["x", "y", "z"])
    init = Tree(
        np.array([1., 2.]),
        np.array(3., dtype=np.float32),
        np.array(4., dtype=np.float32),
    )
    xs = (np.array([1., 2., 3., 4.]),
          np.array([4., 3., 2., 1.]))

    expected = lax.scan(body, init, xs)
    with control_flow_prims_disabled():
        actual = scan(body, init, xs)
    # Compare every field of the scanned tuple against the reference.
    for field in ("x", "y", "z"):
        assert_allclose(getattr(actual, field), getattr(expected, field))
示例#2
0
import math

import jax.numpy as np
from jax import partial, random
from jax.flatten_util import ravel_pytree
from jax.random import PRNGKey

import numpyro.distributions as dist
from numpyro.hmc_util import IntegratorState, build_tree, find_reasonable_step_size, velocity_verlet, warmup_adapter
from numpyro.util import cond, fori_loop, laxtuple

# Immutable container holding the full state of one HMC iteration
# (position `z`, its potential-energy gradient, step bookkeeping, the
# current mass matrix, and the PRNG key). Built with `laxtuple` so it
# can be carried through JAX control-flow primitives.
HMCState = laxtuple('HMCState', [
    'z', 'z_grad', 'potential_energy', 'num_steps', 'accept_prob', 'step_size',
    'inverse_mass_matrix', 'rng'
])


def _get_num_steps(step_size, trajectory_length):
    num_steps = np.array(trajectory_length / step_size, dtype=np.int32)
    return np.where(num_steps < 1, np.array(1, dtype=np.int32), num_steps)


def _sample_momentum(unpack_fn, inverse_mass_matrix, rng):
    """Sample a momentum variable for HMC and return it as a pytree.

    :param unpack_fn: callable mapping a flat momentum vector back to the
        model's pytree structure.
    :param inverse_mass_matrix: 1D array holding the diagonal of the inverse
        mass matrix (2D dense matrices are not yet supported).
    :param rng: random state forwarded to the sampler.
    :return: sampled momentum, unpacked via ``unpack_fn``.
    :raises NotImplementedError: for a dense (2D) inverse mass matrix.
    :raises ValueError: for any other dimensionality.
    """
    if inverse_mass_matrix.ndim == 1:
        # Diagonal mass matrix: momentum_i ~ N(0, m_i), so the std-dev is
        # sqrt(1 / inverse_mass_i).
        r = dist.norm(0., np.sqrt(
            np.reciprocal(inverse_mass_matrix))).rvs(random_state=rng)
        return unpack_fn(r)
    elif inverse_mass_matrix.ndim == 2:
        raise NotImplementedError
    else:
        # Previously this fell through and silently returned None; fail
        # loudly on unsupported inputs instead.
        raise ValueError('inverse_mass_matrix must be 1D (diagonal) or 2D '
                         '(dense); got ndim={}'.format(inverse_mass_matrix.ndim))

示例#3
0
import jax
import jax.numpy as np
from jax import grad, jit, partial, random, value_and_grad, vmap
from jax.flatten_util import ravel_pytree
from jax.ops import index_update
from jax.scipy.special import expit
from jax.tree_util import tree_multimap

from numpyro.distributions.constraints import biject_to
from numpyro.distributions.util import cholesky_inverse
from numpyro.handlers import seed, substitute, trace
from numpyro.util import cond, laxtuple, while_loop

# [start, end] step indices of one adaptation window during warmup.
AdaptWindow = laxtuple("AdaptWindow", ["start", "end"])
# State threaded through the warmup adaptation loop: current step size,
# mass-matrix estimates, the internal states of the step-size (`ss_state`)
# and mass-matrix (`mm_state`) adapters, the active window index, and the
# PRNG key.
AdaptState = laxtuple("AdaptState", ["step_size", "inverse_mass_matrix", "mass_matrix_sqrt",
                                     "ss_state", "mm_state", "window_idx", "rng"])
# Snapshot of the leapfrog integrator: position `z`, momentum `r`, the
# potential energy at `z`, and its gradient.
IntegratorState = laxtuple("IntegratorState", ["z", "r", "potential_energy", "z_grad"])

# Bookkeeping for a NUTS doubling tree: the leftmost/rightmost leaf states
# (position, momentum, gradient), the current proposal and its energy and
# gradient, tree depth and weight, the running momentum sum, U-turn and
# divergence flags, and accept-probability statistics used for step-size
# adaptation.
TreeInfo = laxtuple('TreeInfo', ['z_left', 'r_left', 'z_left_grad',
                                 'z_right', 'r_right', 'z_right_grad',
                                 'z_proposal', 'z_proposal_pe', 'z_proposal_grad',
                                 'depth', 'weight', 'r_sum', 'turning', 'diverging',
                                 'sum_accept_probs', 'num_proposals'])


def dual_averaging(t0=10, kappa=0.75, gamma=0.05):
    """
    Dual Averaging is a scheme to solve convex optimization problems. It
    belongs to a class of subgradient methods which uses subgradients (which
    lie in a dual space) to update states (in primal space) of a model. Under
    some conditions, the averages of generated parameters during the scheme are