Example #1
 def bc_flat(aug_t0, aug_t1):
     error_flat, _ = ravel_pytree(bc(unravel(aug_t0), unravel(aug_t1)))
     return error_flat
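The `unravel` used by `bc_flat` above comes from an earlier call to `jax.flatten_util.ravel_pytree`; a minimal, self-contained round-trip sketch (the nested dict is made up for illustration):

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
flat, unravel = ravel_pytree(params)  # flat: 1D array with 9 entries
restored = unravel(flat)              # same structure and shapes as `params`
assert flat.shape == (9,)
assert restored["w"].shape == (2, 3)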
Example #2
    def body_fn(state):
        i, key, _, _ = state
        key, subkey = random.split(key)

        if radius is None or prototype_params is None:
            # XXX: we don't want to apply enum to draw latent samples
            model_ = model
            if enum:
                from numpyro.contrib.funsor import enum as enum_handler

                if isinstance(model, substitute) and isinstance(
                        model.fn, enum_handler):
                    model_ = substitute(model.fn.fn, data=model.data)
                elif isinstance(model, enum_handler):
                    model_ = model.fn

            # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
            seeded_model = substitute(seed(model_, subkey),
                                      substitute_fn=init_strategy)
            model_trace = trace(seeded_model).get_trace(
                *model_args, **model_kwargs)
            constrained_values, inv_transforms = {}, {}
            for k, v in model_trace.items():
                if (v["type"] == "sample" and not v["is_observed"]
                        and not v["fn"].support.is_discrete):
                    constrained_values[k] = v["value"]
                    with helpful_support_errors(v):
                        inv_transforms[k] = biject_to(v["fn"].support)
            params = transform_fn(
                inv_transforms,
                {k: v
                 for k, v in constrained_values.items()},
                invert=True,
            )
        else:  # this branch doesn't require tracing the model
            params = {}
            for k, v in prototype_params.items():
                if k in init_values:
                    params[k] = init_values[k]
                else:
                    params[k] = random.uniform(subkey,
                                               jnp.shape(v),
                                               minval=-radius,
                                               maxval=radius)
                    key, subkey = random.split(key)

        potential_fn = partial(potential_energy,
                               model,
                               model_args,
                               model_kwargs,
                               enum=enum)
        if validate_grad:
            if forward_mode_differentiation:
                pe = potential_fn(params)
                z_grad = jacfwd(potential_fn)(params)
            else:
                pe, z_grad = value_and_grad(potential_fn)(params)
            z_grad_flat = ravel_pytree(z_grad)[0]
            is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
        else:
            pe = potential_fn(params)
            is_valid = jnp.isfinite(pe)
            z_grad = None

        return i + 1, key, (params, pe, z_grad), is_valid
Example #3
 def testEmpty(self):
     tree = []
     raveled, unravel = flatten_util.ravel_pytree(tree)
     self.assertEqual(raveled.dtype, jnp.float32)  # convention
     tree_ = unravel(raveled)
     self.assertAllClose(tree, tree_, atol=0., rtol=0.)
Example #4
File: mcmc.py Project: juvu/numpyro
    def init_kernel(init_params,
                    num_warmup,
                    step_size=1.0,
                    adapt_step_size=True,
                    adapt_mass_matrix=True,
                    dense_mass=False,
                    target_accept_prob=0.8,
                    trajectory_length=2 * math.pi,
                    max_tree_depth=10,
                    run_warmup=True,
                    progbar=True,
                    rng=PRNGKey(0)):
        """
        Initializes the HMC sampler.

        :param init_params: Initial parameters to begin sampling. The type must
            be consistent with the input type to `potential_fn`.
        :param int num_warmup: Number of warmup steps; samples generated
            during warmup are discarded.
        :param float step_size: Determines the size of a single step taken by the
            verlet integrator while computing the trajectory using Hamiltonian
            dynamics. If not specified, it will be set to 1.
        :param bool adapt_step_size: A flag to decide if we want to adapt step_size
            during warm-up phase using Dual Averaging scheme.
        :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
            matrix during warm-up phase using Welford scheme.
        :param bool dense_mass: A flag to decide if the mass matrix is dense (full
            rank) or diagonal (the default, ``dense_mass=False``).
        :param float target_accept_prob: Target acceptance probability for step size
            adaptation using Dual Averaging. Increasing this value will lead to a smaller
            step size, hence the sampling will be slower but more robust. Default to 0.8.
        :param float trajectory_length: Length of a MCMC trajectory for HMC. Default
            value is :math:`2\\pi`.
        :param int max_tree_depth: Max depth of the binary tree created during the doubling
            scheme of NUTS sampler. Defaults to 10.
        :param bool run_warmup: Flag to decide whether warmup is run. If ``True``,
            `init_kernel` returns an initial :data:`HMCState` that can be used to
            generate samples using MCMC. Else, returns the arguments and callable
            that does the initial adaptation.
        :param bool progbar: Whether to enable progress bar updates. Defaults to
            ``True``.
        :param bool heuristic_step_size: If ``True``, a coarse grained adjustment of
            step size is done at the beginning of each adaptation window to achieve
            `target_acceptance_prob`.
        :param jax.random.PRNGKey rng: random key to be used as the source of
            randomness.
        """
        step_size = float(step_size)
        nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth, wa_steps
        wa_steps = num_warmup
        trajectory_len = float(trajectory_length)
        max_treedepth = max_tree_depth
        z = init_params
        z_flat, unravel_fn = ravel_pytree(z)
        momentum_generator = partial(_sample_momentum, unravel_fn)

        find_reasonable_ss = partial(find_reasonable_step_size, potential_fn,
                                     kinetic_fn, momentum_generator)

        wa_init, wa_update = warmup_adapter(
            num_warmup,
            adapt_step_size=adapt_step_size,
            adapt_mass_matrix=adapt_mass_matrix,
            dense_mass=dense_mass,
            target_accept_prob=target_accept_prob,
            find_reasonable_step_size=find_reasonable_ss)

        rng_hmc, rng_wa = random.split(rng)
        wa_state = wa_init(z,
                           rng_wa,
                           step_size,
                           mass_matrix_size=np.size(z_flat))
        r = momentum_generator(wa_state.mass_matrix_sqrt, rng)
        vv_state = vv_init(z, r)
        hmc_state = HMCState(0, vv_state.z, vv_state.z_grad,
                             vv_state.potential_energy, 0, 0., 0., wa_state,
                             rng_hmc)

        if run_warmup and num_warmup > 0:
            # JIT if progress bar updates not required
            if not progbar:
                hmc_state = jit(fori_loop, static_argnums=(2, ))(
                    0, num_warmup, lambda *args: sample_kernel(args[1]),
                    hmc_state)
            else:
                with tqdm.trange(num_warmup, desc='warmup') as t:
                    for i in t:
                        hmc_state = sample_kernel(hmc_state)
                        t.set_postfix_str(get_diagnostics_str(hmc_state),
                                          refresh=False)
        return hmc_state
Example #5
File: util.py Project: juvu/numpyro
def fori_collect(lower,
                 upper,
                 body_fun,
                 init_val,
                 transform=identity,
                 progbar=True,
                 **progbar_opts):
    """
    This looping construct works like :func:`~jax.lax.fori_loop` but with the additional
    effect of collecting values from the loop body. In addition, this allows for
    post-processing of these samples via `transform`, and progress bar updates.
    Note that, `progbar=False` will be faster, especially when collecting a
    lot of samples. Refer to example usage in :func:`~numpyro.mcmc.hmc`.

    :param int lower: the index to start the collective work. In other words,
        we will skip collecting the first `lower` values.
    :param int upper: number of times to run the loop body.
    :param body_fun: a callable that takes a collection of
        `np.ndarray` and returns a collection with the same shape and
        `dtype`.
    :param init_val: initial value to pass as argument to `body_fun`. Can
        be any Python collection type containing `np.ndarray` objects.
    :param transform: a callable to post-process the values returned by `body_fun`.
    :param progbar: whether to post progress bar updates.
    :param `**progbar_opts`: optional additional progress bar arguments. A
        `diagnostics_fn` can be supplied which when passed the current value
        from `body_fun` returns a string that is used to update the progress
        bar postfix. Also a `progbar_desc` keyword argument can be supplied
        which is used to label the progress bar.
    :return: collection with the same type as `init_val` with values
        collected along the leading axis of `np.ndarray` objects.
    """
    assert lower < upper
    init_val_flat, unravel_fn = ravel_pytree(transform(init_val))
    ravel_fn = lambda x: ravel_pytree(transform(x))[0]  # noqa: E731

    if not progbar:
        collection = np.zeros((upper - lower, ) + init_val_flat.shape)

        def _body_fn(i, vals):
            val, collection = vals
            val = body_fun(val)
            i = np.where(i >= lower, i - lower, 0)
            collection = ops.index_update(collection, i, ravel_fn(val))
            return val, collection

        _, collection = jit(fori_loop,
                            static_argnums=(2, ))(0, upper, _body_fn,
                                                  (init_val, collection))
    else:
        diagnostics_fn = progbar_opts.pop('diagnostics_fn', None)
        progbar_desc = progbar_opts.pop('progbar_desc', '')
        collection = []

        val = init_val
        with tqdm.trange(upper, desc=progbar_desc) as t:
            for i in t:
                val = body_fun(val)
                if i >= lower:
                    collection.append(jit(ravel_fn)(val))
                if diagnostics_fn:
                    t.set_postfix_str(diagnostics_fn(val), refresh=False)

        # XXX: jax.numpy.stack/concatenate is currently slow
        collection = onp.stack(collection)

    return vmap(unravel_fn)(collection)
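A hypothetical usage sketch of `fori_collect`, collecting five states of a toy counter pytree (assuming the function above is importable, e.g. from `numpyro.util`):

import jax.numpy as jnp
from numpyro.util import fori_collect

body_fun = lambda state: {"count": state["count"] + 1, "twice": state["twice"] + 2}
init_val = {"count": jnp.zeros(()), "twice": jnp.zeros(())}
samples = fori_collect(0, 5, body_fun, init_val, progbar=False)
# samples["count"] is a length-5 array: [1., 2., 3., 4., 5.]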
Example #6
def augmented_post_drift(rng, state_unravel_fn, param_unravel_fn,
                         driftx_apply_fn, prior_driftw_apply_fn,
                         driftw0_apply_fn, driftw_apply_fn, diff_apply_fn,
                         W_driftwt, W_priorwt, W_diffusion, W0_post, setting,
                         t, aug, eps, **kwargs):  # last 3 partial
    """aug: [x, w, kl]
  """
    flat_w, x, kl = state_unravel_fn(aug)  # flat_w inputs
    w = param_unravel_fn(flat_w)  # hierarchical w params

    # TODO: does this need to be bayesian? No, we sample once only.
    _, dxt = driftx_apply_fn(w, (t, x), **kwargs)

    _, dwtkl = driftw_apply_fn(W_driftwt, (t, flat_w), **kwargs)

    # pw_drift = 0. # assume prior has zero drift
    _, pw_drift = prior_driftw_apply_fn(W_priorwt, (t, flat_w), **kwargs)  # account for prior
    _, diffwt = diff_apply_fn(W_diffusion, (t, flat_w), **kwargs)  # with time dependence

    # stop gradient only to compute the KL
    _dwtkl = dwtkl  # already "includes the w0kl", no need to add again.
    u = (_dwtkl - pw_drift) * 0
    if setting['diffw']:
        u = (_dwtkl - pw_drift) / diffwt
    dkl = np.sum((u**2) / 2)  # dt = 1
    u2 = u * eps
    dkl_stl = np.sum(u2)

    # _w0kl, _w0_prior = 0, 0 # for gaussian mixture on W0
    if setting['stop_grad']:
        print("applying sticking the landing")
        if setting['priorw0_sigma'] >= 0:
            print("stopping gradient on w0 posterior")
            # stop grad on posterior over w0
            W0_post = tree_map(lambda w: lax.stop_gradient(w), W0_post)
            # _w0kl = tree_map(lambda w: lax.stop_gradient(w), _w0kl) # TODO: alternative?
            # w0out, _w0, _w0kl = driftw0_apply_fn(W0_post, (0, x), rng=rng, logsigma2=np.log(setting['priorw0_sigma']))
        # approx posterior over weights
        if setting['diff_drift']:
            print("stopping gradient on difference parameterization")
            W_driftwtkl = tree_map(lambda w: lax.stop_gradient(w), W_driftwt[-1])
            # first slot is params for t, which are ()
            W_driftwtkl = (W_driftwt[0], W_driftwtkl)
            _, _dwtkl = driftw_apply_fn(W_driftwtkl, (t, flat_w), **kwargs)  # dwtkl
        else:
            W_driftwtkl = tree_map(lambda w: lax.stop_gradient(w), W_driftwt)
            _, _dwtkl = driftw_apply_fn(W_driftwtkl, (t, flat_w), **kwargs)  # dwtkl

    # for original (but wrong) formulation of STL
    # u = (_dwtkl - pw_drift) * 0
    # if setting['diffw']:
    #   u = (_dwtkl - pw_drift) / diffwt
    # dkl = np.sum((u ** 2) / 2)
    ## dkl = dkl + _w0kl # account for bayesian w0 parameters

    # New STL: second term stop gradient (not first term - moved above)
    if setting['stop_grad']:
        print("stopping gradient on second STL term")
        u2 = (_dwtkl - pw_drift) * 0
        if setting['diffw']:
            u2 = (_dwtkl - pw_drift) / diffwt
        u2 = u2 * eps  # u dBt
        dkl_stl = np.sum(u2)
    dkl = dkl + dkl_stl

    # outputs = [dxt, dwtkl, dkl]
    outputs = [dwtkl, dxt, dkl]

    return ravel_pytree(outputs)[0]
Example #7
def _initialize_mass_matrix(z, inverse_mass_matrix, dense_mass):
    if isinstance(dense_mass, list):
        if inverse_mass_matrix is None:
            inverse_mass_matrix = {}
        # if user specifies an ndarray mass matrix, then we convert it to a dict
        elif not isinstance(inverse_mass_matrix, dict):
            inverse_mass_matrix = {tuple(sorted(z)): inverse_mass_matrix}
        mass_matrix_sqrt = {}
        mass_matrix_sqrt_inv = {}
        for site_names in dense_mass:
            inverse_mm = inverse_mass_matrix.get(site_names)
            z_block = tuple(z[k] for k in site_names)
            inverse_mm, mm_sqrt, mm_sqrt_inv = _initialize_mass_matrix(
                z_block, inverse_mm, True
            )
            inverse_mass_matrix[site_names] = inverse_mm
            mass_matrix_sqrt[site_names] = mm_sqrt
            mass_matrix_sqrt_inv[site_names] = mm_sqrt_inv
        # NB: this branch only happens when users want to use block diagonal
        # inverse_mass_matrix, for example, {("a",): jnp.ones(3), ("b",): jnp.ones(3)}.
        for site_names, inverse_mm in inverse_mass_matrix.items():
            if site_names in dense_mass:
                continue
            z_block = tuple(z[k] for k in site_names)
            inverse_mm, mm_sqrt, mm_sqrt_inv = _initialize_mass_matrix(
                z_block, inverse_mm, False
            )
            inverse_mass_matrix[site_names] = inverse_mm
            mass_matrix_sqrt[site_names] = mm_sqrt
            mass_matrix_sqrt_inv[site_names] = mm_sqrt_inv
        remaining_sites = tuple(sorted(set(z) - set().union(*inverse_mass_matrix)))
        if len(remaining_sites) > 0:
            z_block = tuple(z[k] for k in remaining_sites)
            inverse_mm, mm_sqrt, mm_sqrt_inv = _initialize_mass_matrix(
                z_block, None, False
            )
            inverse_mass_matrix[remaining_sites] = inverse_mm
            mass_matrix_sqrt[remaining_sites] = mm_sqrt
            mass_matrix_sqrt_inv[remaining_sites] = mm_sqrt_inv
        expected_site_names = sorted(z)
        actual_site_names = sorted(
            [k for site_names in inverse_mass_matrix for k in site_names]
        )
        assert actual_site_names == expected_site_names, (
            "There seems to be a conflict of sites names specified in the initial"
            " `inverse_mass_matrix` and in `dense_mass` argument."
        )
        return inverse_mass_matrix, mass_matrix_sqrt, mass_matrix_sqrt_inv

    mass_matrix_size = jnp.size(ravel_pytree(z)[0])
    if inverse_mass_matrix is None:
        if dense_mass:
            inverse_mass_matrix = jnp.identity(mass_matrix_size)
        else:
            inverse_mass_matrix = jnp.ones(mass_matrix_size)
        mass_matrix_sqrt = mass_matrix_sqrt_inv = inverse_mass_matrix
    else:
        if dense_mass:
            if jnp.ndim(inverse_mass_matrix) == 1:
                inverse_mass_matrix = jnp.diag(inverse_mass_matrix)
            mass_matrix_sqrt_inv = jnp.swapaxes(
                jnp.linalg.cholesky(inverse_mass_matrix[..., ::-1, ::-1])[
                    ..., ::-1, ::-1
                ],
                -2,
                -1,
            )
            identity = jnp.identity(inverse_mass_matrix.shape[-1])
            mass_matrix_sqrt = solve_triangular(
                mass_matrix_sqrt_inv, identity, lower=True
            )
        else:
            if jnp.ndim(inverse_mass_matrix) == 2:
                inverse_mass_matrix = jnp.diag(inverse_mass_matrix)
            mass_matrix_sqrt_inv = jnp.sqrt(inverse_mass_matrix)
            mass_matrix_sqrt = jnp.reciprocal(mass_matrix_sqrt_inv)
    return inverse_mass_matrix, mass_matrix_sqrt, mass_matrix_sqrt_inv
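A quick usage check of `_initialize_mass_matrix` above (hypothetical input; assumes the usual imports, `jax.numpy as jnp` and `ravel_pytree`): with no user-supplied matrix and `dense_mass=False`, the function falls back to a diagonal-of-ones inverse mass matrix.

z = {"x": jnp.zeros(3)}
inv_mm, mm_sqrt, mm_sqrt_inv = _initialize_mass_matrix(z, None, dense_mass=False)
# inv_mm, mm_sqrt and mm_sqrt_inv are all jnp.ones(3) in this case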
Example #8
def fista_step(data, loss_and_prox_op, model_param, options):
    """Fista optimization step for solving regularized problem.

  Args:
    data: A tuple of inputs and labels passed to the loss function.
    loss_and_prox_op: Tuple of (loss_f, prox_g)
      loss_f is the loss function that takes in model_param, inputs, and labels.
      prox_g is the proximity operator for g.
    model_param: Current model parameters to be passed to loss_f.
    options: A dictionary of optimizer specific hyper-parameters.

  Returns:
    Updated model parameters and updated step size.
  """
    options = dict(options)
    step_size = options.get('step_size', 1.0)
    acceleration = options.get('acceleration', True)
    t = options.get('t', 1.0)
    verbose = options.get('verbose', False)
    reuse_last_step = options.get('reuse_last_step', False)

    loss_f, prox_g = loss_and_prox_op
    inputs, labels = data[0], data[1]
    fun_f = lambda param: loss_f(param, inputs, labels)
    value_and_grad_f = jax.value_and_grad(fun_f)
    x, unravel_fn = ravel_pytree(model_param)
    y = options.get('y', x)
    value_f, grad_f = value_and_grad_f(unravel_fn(y))
    grad_f, unravel_fn = ravel_pytree(grad_f)

    def next_candidate(step_size):
        return prox_g(y - grad_f * step_size, step_size)

    def stop_cond(step_size, next_iter):
        diff = next_iter - y
        sqdist = jnp.sum(diff**2)

        # We do not compute the non-smooth term (g in the paper)
        # as it cancels out from value_F and value_Q.
        value_bigf = fun_f(next_iter)
        value_bigq = value_f + jnp.sum(
            diff * grad_f) + 0.5 / step_size * sqdist
        return value_bigf <= value_bigq

    x_old = x

    step_size, x = backtracking(next_candidate, stop_cond, step_size, options)

    # Acceleration.
    if acceleration:
        t_next = (1 + jnp.sqrt(1 + 4 * t**2)) / 2.
        y = x + (t - 1) / t_next * (x - x_old)
        t = t_next
        options['y'] = y
        options['t'] = t
    else:
        y = x

    if reuse_last_step:
        options['step_size'] = step_size
    if verbose:
        logging.info('Step size: %f', step_size)

    return unravel_fn(x), options
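`backtracking` is not defined in these snippets; a minimal sketch consistent with how it is called here and in the next example (shrink the step until `stop_cond` accepts the candidate) might look like:

def backtracking(next_candidate, stop_cond, step_size, options, shrink=0.5, max_iter=30):
    # `options` is accepted for signature compatibility but unused in this sketch.
    # Repeatedly shrink step_size until the candidate satisfies stop_cond.
    res = next_candidate(step_size)
    for _ in range(max_iter):
        if stop_cond(step_size, res):
            break
        step_size = shrink * step_size
        res = next_candidate(step_size)
    return step_size, res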
Example #9
def gradient_descent_line_search_step(data, loss_f, model_param, options):
    """Gradient Descent optimization with line search step.

  Args:
    data: A tuple of inputs and labels passed to the loss function.
    loss_f: The loss function that takes in model_param, inputs, and labels.
    model_param: Current model parameters to be passed to loss_f.
    options: A dictionary of optimizer specific hyper-parameters.

  Returns:
    Updated model parameters and updated step size.
  """
    options = dict(options)
    beta = options.get('beta', 0.9)
    beta_prime = options.get('beta_prime', 1e-4)
    step_size = options.get('step_size', 10000.0)
    verbose = options.get('verbose', False)
    reuse_last_step = options.get('reuse_last_step', False)

    inputs, labels = data[0], data[1]
    loss_with_data_f = lambda param: loss_f(param, inputs, labels)
    value_and_grad_f = jax.value_and_grad(loss_with_data_f)
    value, grad = value_and_grad_f(model_param)

    # Maximum learning rate allowed from Theorem 5 in Gunasekar et al. 2017
    if options['bound_step']:
        # Bound by dual of L2
        b_const = jnp.max(jnp.linalg.norm(inputs, ord=2, axis=0))
        step_size = min(step_size, 1 / (b_const * b_const * value))

    grad, unravel_fn = ravel_pytree(grad)
    x, unravel_fn = ravel_pytree(model_param)

    # If we normalize step_size will be harder to tune.
    direction = -grad

    # TODO(fartash): consider using the condition in FISTA
    def next_candidate(step_size):
        next_iter = x + step_size * direction
        next_value, next_grad = value_and_grad_f(unravel_fn(next_iter))
        next_grad, _ = ravel_pytree(next_grad)
        return next_iter, next_value, next_grad

    def stop_cond(step_size, res):
        _, next_value, next_grad = res
        gd = jnp.sum(grad * direction)

        # Strong Wolfe condition.
        cond1 = next_value <= value + beta_prime * step_size * gd
        cond2 = jnp.sum(jnp.abs(next_grad * direction)) >= beta * gd
        return cond1 and cond2

    step_size, res = backtracking(next_candidate,
                                  stop_cond,
                                  step_size,
                                  options=options)
    next_param = res[0]

    if reuse_last_step:
        options['step_size'] = step_size
    if verbose:
        logging.info('Step size: %f', step_size)

    return unravel_fn(next_param), options
Example #10
 def update(i, g, w):
     g_flat, unflatten = ravel_pytree(g)
     w_flat = ravel_pytree_jit(w)
     updated_params = soft_thresholding(w_flat - step_size * g_flat,
                                        step_size * lambd)
     return unflatten(updated_params)
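`soft_thresholding` is not shown here; a common definition consistent with this proximal-gradient (ISTA) update is the proximal operator of the L1 norm:

import jax.numpy as jnp

def soft_thresholding(x, threshold):
    # elementwise sign(x) * max(|x| - threshold, 0)
    return jnp.sign(x) * jnp.maximum(jnp.abs(x) - threshold, 0.0)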
Example #11
 def next_candidate(step_size):
     next_iter = x + step_size * direction
     next_value, next_grad = value_and_grad_f(unravel_fn(next_iter))
     next_grad, _ = ravel_pytree(next_grad)
     return next_iter, next_value, next_grad
Example #12
    def update(i, g, w):
        g_flat, unflatten = ravel_pytree(g)
        w_flat = ravel_pytree_jit(w)
        updated_params = soft_thresholding(w_flat - step_size * g_flat,
                                           step_size * lambd)
        return unflatten(updated_params)

    def get_params(w):
        return w

    def set_step_size(lr):
        step_size = lr

    return init, update, get_params, soft_thresholding, set_step_size


ravel_pytree_jit = jit(lambda tree: ravel_pytree(tree)[0])


@jit
def line_search(w, g, batch, beta):
    lr_i = 1
    g_flat, unflatten_g = ravel_pytree(g)
    w_flat = ravel_pytree_jit(w)
    z_flat = soft_thresholding(w_flat - lr_i * g_flat, lr_i * lambd)
    z = unflatten_g(z_flat)
    for i in range(20):
        is_converged = loss(z, batch) > (loss(w, batch)
                                         + g_flat @ (z_flat - w_flat)
                                         + np.sum((z_flat - w_flat)**2) / (2 * lr_i))
        lr_i = jnp.where(is_converged, lr_i, beta * lr_i)
    return lr_i
Example #13
def ravel_first_arg_(unravel, y_flat, *args):
    y = unravel(y_flat)
    ans = yield (y, ) + args, {}
    ans_flat, _ = ravel_pytree(ans)
    yield ans_flat
Example #14
def _odeint_wrapper(func, rtol, atol, mxstep, y0, ts, *args):
    y0, unravel = ravel_pytree(y0)
    func = ravel_first_arg(func, unravel)
    out = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
    return jax.vmap(unravel)(out)
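The pattern above (flatten the initial state, integrate in flat space, then `vmap` the unravel function over the leading time axis) can be illustrated with a toy Euler loop in place of `_odeint`; the solver and dynamics below are made up for illustration:

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def euler_solve(f, y0_flat, ts):
    # fixed-step Euler integration on the flat state vector
    def step(y, t_pair):
        t0, t1 = t_pair
        y_next = y + (t1 - t0) * f(y, t0)
        return y_next, y_next
    _, ys = jax.lax.scan(step, y0_flat, (ts[:-1], ts[1:]))
    return jnp.concatenate([y0_flat[None], ys])  # shape (len(ts), flat_dim)

y0 = {"pos": jnp.ones(2), "vel": jnp.zeros(2)}
y0_flat, unravel = ravel_pytree(y0)
out = euler_solve(lambda y, t: -y, y0_flat, jnp.linspace(0.0, 1.0, 5))
traj = jax.vmap(unravel)(out)  # {"pos": (5, 2), "vel": (5, 2)}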
Example #15
def aug_init(W0, _inputs):  # partial over last
    flat_aug, _ = ravel_pytree([W0, _inputs, 0.])
    flat_w_dim = len(ravel_pytree([W0])[0])
    return flat_aug, flat_w_dim
Example #16
def _weight_fn(params):
    flat_params, _ = ravel_pytree(params)
    return 0.5 * jnp.sum(jnp.square(flat_params))
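A quick usage check of `_weight_fn` (hypothetical parameter values; assumes the usual imports for the definition above): the gradient of 0.5 * ||params||^2 is the parameters themselves, leaf by leaf.

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree  # needed by _weight_fn above

params = {"w": jnp.arange(3.0), "b": jnp.array(2.0)}
g = jax.grad(_weight_fn)(params)
assert jnp.allclose(g["w"], params["w"]) and jnp.allclose(g["b"], params["b"])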
Example #17
    def apply_fun(params, inputs, rng, **kwargs):
        """(aka predict_apply_fn)
    Args:
      inputs: x_t0 ... x_tn
      params: (w0, w_driftwt, w_difft)
    Return: 
      x_t1, log p(x) / q(x)
    """
        W0, W_driftwt, W_diffusion, W_priorwt, W0_post = params

        # _, W0 = driftx_init_fn(rng, (-1, x_dim)) # W0 is randomized each call
        # Bayesian about W0 (mean field)
        W0kl = 0.
        if W0_post is not None:
            tw0 = 0  # NOTE: this is just a dummy variable to be compatible with `shape_dependent`
            _, W0, W0kl = driftw0_apply_fn(W0_post, (tw0, inputs),
                                           rng=rng,
                                           logsigma2=np.log(
                                               setting['priorw0_sigma']))
        flat_W0, param_unravel_fn = ravel_pytree(W0)
        _, state_unravel_fn = ravel_pytree([flat_W0, inputs, W0kl])

        aug_init_ = lambda inputs_: aug_init(W0, inputs_)
        aug_post_drift = lambda t_, aug_, eps_: augmented_post_drift(
            rng, state_unravel_fn, param_unravel_fn, driftx_apply_fn,
            prior_driftw_apply_fn, driftw0_apply_fn, driftw_apply_fn,
            diff_apply_fn, W_driftwt, W_priorwt, W_diffusion, W0_post, setting,
            t_, aug_, eps_, **kwargs)
        aug_prior_drift = lambda t_, aug_, eps_: augmented_prior_drift(
            state_unravel_fn, param_unravel_fn, driftx_apply_fn,
            prior_driftw_apply_fn, diff_apply_fn, W_priorwt, W_diffusion,
            setting, t_, aug_, eps_, **kwargs)
        aug_diffusion = lambda t_, aug_: augmented_diffusion(
            state_unravel_fn, param_unravel_fn, diff_apply_fn, W_diffusion, t_,
            aug_, **kwargs)

        if kwargs.get('entire', False):
            print("returning entire trajectory...")
            # outputs = sdeint(aug_post_drift, aug_diffusion, aug_init_(inputs), np.linspace(0, 1, 100), rng)
            outputs = sdeint(aug_post_drift, aug_diffusion,
                             aug_init_(inputs)[0],
                             aug_init_(inputs)[1], np.linspace(0, 1, 90),
                             rng)  # TODO: to prevent oom problems
            # list(w_dim), list(x_dim), list(1)
            ws, xs, kls = zip(*[state_unravel_fn(out) for out in outputs])
            # prior_outputs = sdeint(aug_prior_drift, aug_diffusion, aug_init_(inputs), np.linspace(0, 1, 100), rng)
            prior_outputs = sdeint(aug_prior_drift, aug_diffusion,
                                   aug_init_(inputs)[0],
                                   aug_init_(inputs)[1], np.linspace(0, 1, 90),
                                   rng)  # TODO: 100 steps
            # list(w_dim), list(x_dim), list(1)
            prior_ws, prior_xs, prior_kls = zip(
                *[state_unravel_fn(out) for out in prior_outputs])
            return xs, ws, kls, prior_xs, prior_ws, prior_kls

        solution = _sdeint(aug_post_drift, aug_diffusion,
                           aug_init_(inputs)[0],
                           aug_init_(inputs)[1], np.linspace(0, 1, 20),
                           rng)  # was sdeint[-1]
        w, x, kl = state_unravel_fn(solution)
        prior_outputs = _sdeint(
            aug_prior_drift, aug_diffusion,
            aug_init_(inputs)[0],
            aug_init_(inputs)[1], np.linspace(0, 1, 20),
            rng)  # TODO: ramp up 100 steps for better performance
        prior_ws, prior_xs, prior_kls = state_unravel_fn(prior_outputs)
        return x, w, kl, prior_xs, prior_ws, prior_kls
Example #18
def scipy_minimize_with_jax(fun, x0,
                            method=None,
                            args=(),
                            bounds=None,
                            constraints=(),
                            tol=None,
                            callback=None,
                            options=None):
  """
  A simple wrapper for scipy.optimize.minimize using JAX.

  Parameters
  ----------
  fun: function
    The objective function to be minimized, written in JAX code
    so that it is automatically differentiable. It has the signature
    ``fun(x, *args) -> float``, where `x` is a PyTree and `args`
    is a tuple of the fixed parameters needed to completely specify the function.

  x0: PyTree
    Initial guess represented as a JAX PyTree.

  args: tuple, optional.
    Extra arguments passed to the objective function
    and its derivative.  Must consist of valid JAX types; e.g. the leaves
    of the PyTree must be floats.

  method : str or callable, optional
    Type of solver.  Should be one of
        - 'Nelder-Mead' :ref:`(see here) <optimize.minimize-neldermead>`
        - 'Powell'      :ref:`(see here) <optimize.minimize-powell>`
        - 'CG'          :ref:`(see here) <optimize.minimize-cg>`
        - 'BFGS'        :ref:`(see here) <optimize.minimize-bfgs>`
        - 'Newton-CG'   :ref:`(see here) <optimize.minimize-newtoncg>`
        - 'L-BFGS-B'    :ref:`(see here) <optimize.minimize-lbfgsb>`
        - 'TNC'         :ref:`(see here) <optimize.minimize-tnc>`
        - 'COBYLA'      :ref:`(see here) <optimize.minimize-cobyla>`
        - 'SLSQP'       :ref:`(see here) <optimize.minimize-slsqp>`
        - 'trust-constr':ref:`(see here) <optimize.minimize-trustconstr>`
        - 'dogleg'      :ref:`(see here) <optimize.minimize-dogleg>`
        - 'trust-ncg'   :ref:`(see here) <optimize.minimize-trustncg>`
        - 'trust-exact' :ref:`(see here) <optimize.minimize-trustexact>`
        - 'trust-krylov' :ref:`(see here) <optimize.minimize-trustkrylov>`
        - custom - a callable object (added in version 0.14.0),
          see below for description.
    If not given, chosen to be one of ``BFGS``, ``L-BFGS-B``, ``SLSQP``,
    depending on if the problem has constraints or bounds.

  bounds : sequence or `Bounds`, optional
    Bounds on variables for L-BFGS-B, TNC, SLSQP, Powell, and
    trust-constr methods. There are two ways to specify the bounds:
        1. Instance of `Bounds` class.
        2. Sequence of ``(min, max)`` pairs for each element in `x`. None
        is used to specify no bound.
    Note that in order to use `bounds` you will need to manually flatten
    them in the same order as your inputs `x0`.

  constraints : {Constraint, dict} or List of {Constraint, dict}, optional
    Constraints definition (only for COBYLA, SLSQP and trust-constr).
    Constraints for 'trust-constr' are defined as a single object or a
    list of objects specifying constraints to the optimization problem.
    Available constraints are:
        - `LinearConstraint`
        - `NonlinearConstraint`
    Constraints for COBYLA, SLSQP are defined as a list of dictionaries.
    Each dictionary with fields:
        type : str
            Constraint type: 'eq' for equality, 'ineq' for inequality.
        fun : callable
            The function defining the constraint.
        jac : callable, optional
            The Jacobian of `fun` (only for SLSQP).
        args : sequence, optional
            Extra arguments to be passed to the function and Jacobian.
    Equality constraint means that the constraint function result is to
    be zero whereas inequality means that it is to be non-negative.
    Note that COBYLA only supports inequality constraints.

    Note that in order to use `constraints` you will need to manually flatten
    them in the same order as your inputs `x0`.

  tol : float, optional
    Tolerance for termination. For detailed control, use solver-specific
    options.

  options : dict, optional
      A dictionary of solver options. All methods accept the following
      generic options:
          maxiter : int
              Maximum number of iterations to perform. Depending on the
              method each iteration may use several function evaluations.
          disp : bool
              Set to True to print convergence messages.
      For method-specific options, see :func:`show_options()`.

  callback : callable, optional
      Called after each iteration. For 'trust-constr' it is a callable with
      the signature:
          ``callback(xk, OptimizeResult state) -> bool``
      where ``xk`` is the current parameter vector represented as a PyTree,
       and ``state`` is an `OptimizeResult` object, with the same fields
      as the ones from the return. If callback returns True the algorithm
      execution is terminated.

      For all the other methods, the signature is:
          ```callback(xk)```
      where `xk` is the current parameter vector, represented as a PyTree.

  Returns
  -------
  res : The optimization result represented as a ``OptimizeResult`` object.
    Important attributes are:
        ``x``: the solution array, represented as a JAX PyTree
        ``success``: a Boolean flag indicating if the optimizer exited successfully
        ``message``: describes the cause of the termination.
    See `scipy.optimize.OptimizeResult` for a description of other attributes.

  """
  if soptimize is None:
    raise errors.PackageMissingError(f'"scipy" must be installed to use '
                                     f'function: {scipy_minimize_with_jax}')

  # Use tree flatten and unflatten to convert params x0 from PyTrees to flat arrays
  x0_flat, unravel = ravel_pytree(x0)

  # Wrap the objective function to consume flat, plain NumPy arrays
  # (as passed by scipy) and produce scalar outputs.
  def fun_wrapper(x_flat, *args):
    x = unravel(x_flat)
    r = fun(x, *args)
    r = r.value if isinstance(r, bm.JaxArray) else r
    return float(r)

  # Wrap the gradient in a similar manner
  jac = jit(grad(fun))

  def jac_wrapper(x_flat, *args):
    x = unravel(x_flat)
    g_flat, _ = ravel_pytree(jac(x, *args))
    return np.array(g_flat)

  # Wrap the callback to consume a pytree
  def callback_wrapper(x_flat, *args):
    if callback is not None:
      x = unravel(x_flat)
      return callback(x, *args)

  # Minimize with scipy
  results = soptimize.minimize(fun_wrapper,
                               x0_flat,
                               args=args,
                               method=method,
                               jac=jac_wrapper,
                               callback=callback_wrapper,
                               bounds=bounds,
                               constraints=constraints,
                               tol=tol,
                               options=options)

  # pack the output back into a PyTree
  results["x"] = unravel(results["x"])
  return results
Example #19
def build_tree(
    verlet_update,
    kinetic_fn,
    verlet_state,
    inverse_mass_matrix,
    step_size,
    rng_key,
    max_delta_energy=1000.0,
    max_tree_depth=10,
):
    """
    Builds a binary tree from the `verlet_state`. This is used in NUTS sampler.

    **References:**

    1. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo*,
       Matthew D. Hoffman, Andrew Gelman
    2. *A Conceptual Introduction to Hamiltonian Monte Carlo*,
       Michael Betancourt

    :param verlet_update: A callable to get a new integrator state given a current
        integrator state.
    :param kinetic_fn: A callable to compute kinetic energy.
    :param verlet_state: Initial integrator state.
    :param inverse_mass_matrix: Inverse of the mass matrix.
    :param float step_size: Step size for the current trajectory.
    :param jax.random.PRNGKey rng_key: random key to be used as the source of
        randomness.
    :param float max_delta_energy: A threshold to decide if the new state diverges
        (based on the energy difference) too much from the initial integrator state.
    :param int max_tree_depth: Max depth of the binary tree created during the doubling
        scheme of NUTS sampler. Defaults to 10. This argument also accepts a tuple of
        integers `(d1, d2)`, where `d1` is the max tree depth at the current MCMC
        step and `d2` is the global max tree depth for all MCMC steps.
    :return: information of the tree.
    :rtype: :data:`TreeInfo`
    """
    if isinstance(max_tree_depth, tuple):
        max_tree_depth_current, max_tree_depth = max_tree_depth
    else:
        max_tree_depth_current = max_tree_depth
    z, r, potential_energy, z_grad = verlet_state
    energy_current = potential_energy + kinetic_fn(inverse_mass_matrix, r)
    latent_size = jnp.size(ravel_pytree(r)[0])
    r_ckpts = jnp.zeros((max_tree_depth, latent_size))
    r_sum_ckpts = jnp.zeros((max_tree_depth, latent_size))

    tree = TreeInfo(
        z,
        r,
        z_grad,
        z,
        r,
        z_grad,
        z,
        potential_energy,
        z_grad,
        energy_current,
        depth=0,
        weight=jnp.zeros(()),
        r_sum=r,
        turning=jnp.array(False),
        diverging=jnp.array(False),
        sum_accept_probs=jnp.zeros(()),
        num_proposals=jnp.array(0, dtype=jnp.result_type(int)),
    )

    def _cond_fn(state):
        tree, _ = state
        return (tree.depth < max_tree_depth_current) & ~tree.turning & ~tree.diverging

    def _body_fn(state):
        tree, key = state
        key, direction_key, doubling_key = random.split(key, 3)
        going_right = random.bernoulli(direction_key)
        tree = _double_tree(
            tree,
            verlet_update,
            kinetic_fn,
            inverse_mass_matrix,
            step_size,
            going_right,
            doubling_key,
            energy_current,
            max_delta_energy,
            r_ckpts,
            r_sum_ckpts,
        )
        return tree, key

    state = (tree, rng_key)
    tree, _ = while_loop(_cond_fn, _body_fn, state)
    return tree
Example #20
 def jac_wrapper(x_flat, *args):
   x = unravel(x_flat)
   g_flat, _ = ravel_pytree(jac(x, *args))
   return np.array(g_flat)
Example #21
def fori_collect(
    lower,
    upper,
    body_fun,
    init_val,
    transform=identity,
    progbar=True,
    return_last_val=False,
    collection_size=None,
    thinning=1,
    **progbar_opts,
):
    """
    This looping construct works like :func:`~jax.lax.fori_loop` but with the additional
    effect of collecting values from the loop body. In addition, this allows for
    post-processing of these samples via `transform`, and progress bar updates.
    Note that, `progbar=False` will be faster, especially when collecting a
    lot of samples. Refer to example usage in :func:`~numpyro.infer.mcmc.hmc`.

    :param int lower: the index to start the collective work. In other words,
        we will skip collecting the first `lower` values.
    :param int upper: number of times to run the loop body.
    :param body_fun: a callable that takes a collection of
        `np.ndarray` and returns a collection with the same shape and
        `dtype`.
    :param init_val: initial value to pass as argument to `body_fun`. Can
        be any Python collection type containing `np.ndarray` objects.
    :param transform: a callable to post-process the values returned by `body_fun`.
    :param progbar: whether to post progress bar updates.
    :param bool return_last_val: If `True`, the last value is also returned.
        This has the same type as `init_val`.
    :param thinning: Positive integer that controls the thinning ratio for retained
        values. Defaults to 1, i.e. no thinning.
    :param int collection_size: Size of the returned collection. If not
        specified, the size will be ``(upper - lower) // thinning``. If the
        size is larger than ``(upper - lower) // thinning``, only the top
        ``(upper - lower) // thinning`` entries will be non-zero.
    :param `**progbar_opts`: optional additional progress bar arguments. A
        `diagnostics_fn` can be supplied which when passed the current value
        from `body_fun` returns a string that is used to update the progress
        bar postfix. Also a `progbar_desc` keyword argument can be supplied
        which is used to label the progress bar.
    :return: collection with the same type as `init_val` with values
        collected along the leading axis of `np.ndarray` objects.
    """
    assert lower <= upper
    assert thinning >= 1
    collection_size = ((upper - lower) // thinning
                       if collection_size is None else collection_size)
    assert collection_size >= (upper - lower) // thinning
    init_val_flat, unravel_fn = ravel_pytree(transform(init_val))
    start_idx = lower + (upper - lower) % thinning
    num_chains = progbar_opts.pop("num_chains", 1)
    # host_callback does not work yet with multi-GPU platforms
    # See: https://github.com/google/jax/issues/6447
    if num_chains > 1 and jax.default_backend() == "gpu":
        warnings.warn(
            "We will disable progress bar because it does not work yet on multi-GPUs platforms."
        )
        progbar = False

    @cached_by(fori_collect, body_fun, transform)
    def _body_fn(i, vals):
        val, collection, start_idx, thinning = vals
        val = body_fun(val)
        idx = (i - start_idx) // thinning
        collection = cond(
            idx >= 0,
            collection,
            lambda x: x.at[idx].set(ravel_pytree(transform(val))[0]),
            collection,
            identity,
        )
        return val, collection, start_idx, thinning

    collection = jnp.zeros((collection_size, ) + init_val_flat.shape)
    if not progbar:
        last_val, collection, _, _ = fori_loop(
            0, upper, _body_fn, (init_val, collection, start_idx, thinning))
    elif num_chains > 1:
        progress_bar_fori_loop = progress_bar_factory(upper, num_chains)
        _body_fn_pbar = progress_bar_fori_loop(_body_fn)
        last_val, collection, _, _ = fori_loop(
            0, upper, _body_fn_pbar,
            (init_val, collection, start_idx, thinning))
    else:
        diagnostics_fn = progbar_opts.pop("diagnostics_fn", None)
        progbar_desc = progbar_opts.pop("progbar_desc", lambda x: "")

        vals = (init_val, collection, device_put(start_idx),
                device_put(thinning))
        if upper == 0:
            # special case, only compiling
            jit(_body_fn)(0, vals)
        else:
            with tqdm.trange(upper) as t:
                for i in t:
                    vals = jit(_body_fn)(i, vals)
                    t.set_description(progbar_desc(i), refresh=False)
                    if diagnostics_fn:
                        t.set_postfix_str(diagnostics_fn(vals[0]),
                                          refresh=False)

        last_val, collection, _, _ = vals

    unravel_collection = vmap(unravel_fn)(collection)
    return (unravel_collection,
            last_val) if return_last_val else unravel_collection
Example #22
 def __call__(self, params):
     param_flat, _ = flatten_util.ravel_pytree(params)
     return self.initialize(param_flat)
Example #23
    def sample_kernel(sa_state, model_args=(), model_kwargs=None):
        pe_fn = potential_fn
        if potential_fn_gen:
            pe_fn = potential_fn_gen(*model_args, **model_kwargs)
        zs, pes, loc, scale = sa_state.adapt_state
        # we recompute loc/scale after each iteration to avoid precision loss
        # XXX: consider to expose a setting to do this job periodically
        # to save some computations
        loc = jnp.mean(zs, 0)
        if scale.ndim == 2:
            cov = jnp.cov(zs, rowvar=False, bias=True)
            if cov.shape == ():  # JAX returns scalar for 1D input
                cov = cov.reshape((1, 1))
            cholesky = jnp.linalg.cholesky(cov)
            scale = jnp.where(jnp.any(jnp.isnan(cholesky)), scale, cholesky)
        else:
            scale = jnp.std(zs, 0)

        rng_key, rng_key_z, rng_key_reject, rng_key_accept = random.split(
            sa_state.rng_key, 4)
        _, unravel_fn = ravel_pytree(sa_state.z)

        z = loc + _sample_proposal(scale, rng_key_z)
        pe = pe_fn(unravel_fn(z))
        pe = jnp.where(jnp.isnan(pe), jnp.inf, pe)
        diverging = (pe - sa_state.potential_energy) > max_delta_energy

        # NB: all terms having the pattern *s will have shape N x ...
        # and all terms having the pattern *s_ will have shape (N + 1) x ...
        locs, scales = _get_proposal_loc_and_scale(zs, loc, scale, z)
        zs_ = jnp.concatenate([zs, z[None, :]])
        pes_ = jnp.concatenate([pes, pe[None]])
        locs_ = jnp.concatenate([locs, loc[None, :]])
        scales_ = jnp.concatenate([scales, scale[None, ...]])
        if scale.ndim == 2:  # dense_mass
            log_weights_ = dist.MultivariateNormal(
                locs_, scale_tril=scales_).log_prob(zs_) + pes_
        else:
            log_weights_ = dist.Normal(locs_,
                                       scales_).log_prob(zs_).sum(-1) + pes_
        # mask invalid values (nan, +inf) by -inf
        log_weights_ = jnp.where(jnp.isfinite(log_weights_), log_weights_,
                                 -jnp.inf)
        # get rejecting index
        j = random.categorical(rng_key_reject, log_weights_)
        zs = _numpy_delete(zs_, j)
        pes = _numpy_delete(pes_, j)
        loc = locs_[j]
        scale = scales_[j]
        adapt_state = SAAdaptState(zs, pes, loc, scale)

        # NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
        accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
        itr = sa_state.i + 1
        n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
        mean_accept_prob = sa_state.mean_accept_prob + (
            accept_prob - sa_state.mean_accept_prob) / n

        # XXX: we make a modification of SA sampler in [1]
        # in [1], each MCMC state contains N points `zs`
        # here we do resampling to pick randomly a point from those N points
        k = random.categorical(rng_key_accept, jnp.zeros(zs.shape[0]))
        z = unravel_fn(zs[k])
        pe = pes[k]
        return SAState(itr, z, pe, accept_prob, mean_accept_prob, diverging,
                       adapt_state, rng_key)
Example #24
def lfads_losses(params, lfads_hps, key, x_bxt, kl_scale, keep_rate):
    """Compute the training loss of the LFADS autoencoder

  Arguments:
    params: a dictionary of LFADS parameters
    lfads_hps: a dictionary of LFADS hyperparameters
    key: random.PRNGKey for random bits
    x_bxt: np array of input with leading dims being batch and time
    kl_scale: scale on KL
    keep_rate: dropout keep rate

  Returns:
    a dictionary of all losses, including the key 'total' used for optimization
  """

    B = lfads_hps['batch_size']
    key, skeys = utils.keygen(key, 2)
    keys = random.split(next(skeys), B)
    lfads = batch_lfads(params, lfads_hps, keys, x_bxt, keep_rate)

    # Sum over time and state dims, average over batch.
    # KL - g0
    ic_post_mean_b = lfads['ic_mean']
    ic_post_logvar_b = lfads['ic_logvar']
    kl_loss_g0_b = dists.batch_kl_gauss_gauss(ic_post_mean_b, ic_post_logvar_b,
                                              params['ic_prior'],
                                              lfads_hps['var_min'])
    kl_loss_g0_prescale = np.sum(kl_loss_g0_b) / B
    kl_loss_g0 = kl_scale * kl_loss_g0_prescale

    # KL - Inferred input
    ii_post_mean_bxt = lfads['ii_mean_t']
    ii_post_var_bxt = lfads['ii_logvar_t']
    keys = random.split(next(skeys), B)
    kl_loss_ii_b = dists.batch_kl_gauss_ar1(keys, ii_post_mean_bxt,
                                            ii_post_var_bxt,
                                            params['ii_prior'],
                                            lfads_hps['var_min'])
    kl_loss_ii_prescale = np.sum(kl_loss_ii_b) / B
    kl_loss_ii = kl_scale * kl_loss_ii_prescale

    # Log-likelihood of data given latents.
    lograte_bxt = lfads['lograte_t']
    log_p_xgz = np.sum(dists.poisson_log_likelihood(x_bxt, lograte_bxt)) / B

    # L2
    l2reg = lfads_hps['l2reg']
    flatten_lfads = lambda params: flatten_util.ravel_pytree(params)[0]
    l2_loss = l2reg * np.sum(flatten_lfads(params)**2)

    loss = -log_p_xgz + kl_loss_g0 + kl_loss_ii + l2_loss
    all_losses = {
        'total': loss,
        'nlog_p_xgz': -log_p_xgz,
        'kl_g0': kl_loss_g0,
        'kl_g0_prescale': kl_loss_g0_prescale,
        'kl_ii': kl_loss_ii,
        'kl_ii_prescale': kl_loss_ii_prescale,
        'l2': l2_loss
    }
    return all_losses
Example #25
    def construct_proxy_fn(
        prototype_trace,
        subsample_plate_sizes,
        model,
        model_args,
        model_kwargs,
        num_blocks=1,
    ):
        ref_params = {
            name: biject_to(prototype_trace[name]["fn"].support).inv(value)
            for name, value in reference_params.items()
        }

        ref_params_flat, unravel_fn = ravel_pytree(ref_params)

        def log_likelihood(params_flat, subsample_indices=None):
            if subsample_indices is None:
                subsample_indices = {
                    k: jnp.arange(v[0])
                    for k, v in subsample_plate_sizes.items()
                }
            params = unravel_fn(params_flat)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                params = {
                    name: biject_to(prototype_trace[name]["fn"].support)(value)
                    for name, value in params.items()
                }
                with block(), trace() as tr, substitute(
                        data=subsample_indices), substitute(data=params):
                    model(*model_args, **model_kwargs)

            log_lik = {}
            for site in tr.values():
                if site["type"] == "sample" and site["is_observed"]:
                    for frame in site["cond_indep_stack"]:
                        if frame.name in log_lik:
                            log_lik[frame.name] += _sum_all_except_at_dim(
                                site["fn"].log_prob(site["value"]), frame.dim)
                        else:
                            log_lik[frame.name] = _sum_all_except_at_dim(
                                site["fn"].log_prob(site["value"]), frame.dim)
            return log_lik

        def log_likelihood_sum(params_flat, subsample_indices=None):
            return {
                k: v.sum()
                for k, v in log_likelihood(params_flat,
                                           subsample_indices).items()
            }

        # those stats are dict keyed by subsample names
        ref_log_likelihoods_sum = log_likelihood_sum(ref_params_flat)
        ref_log_likelihood_grads_sum = jacobian(log_likelihood_sum)(
            ref_params_flat)
        ref_log_likelihood_hessians_sum = hessian(log_likelihood_sum)(
            ref_params_flat)

        def gibbs_init(rng_key, gibbs_sites):
            ref_subsample_log_liks = log_likelihood(ref_params_flat,
                                                    gibbs_sites)
            ref_subsample_log_lik_grads = jacfwd(log_likelihood)(
                ref_params_flat, gibbs_sites)
            ref_subsample_log_lik_hessians = jacfwd(jacfwd(log_likelihood))(
                ref_params_flat, gibbs_sites)
            return TaylorProxyState(
                ref_subsample_log_liks,
                ref_subsample_log_lik_grads,
                ref_subsample_log_lik_hessians,
            )

        def gibbs_update(rng_key, gibbs_sites, gibbs_state):
            u_new, pads, new_idxs, starts = _block_update_proxy(
                num_blocks, rng_key, gibbs_sites, subsample_plate_sizes)

            new_states = defaultdict(dict)
            ref_subsample_log_liks = log_likelihood(ref_params_flat, new_idxs)
            ref_subsample_log_lik_grads = jacfwd(log_likelihood)(
                ref_params_flat, new_idxs)
            ref_subsample_log_lik_hessians = jacfwd(jacfwd(log_likelihood))(
                ref_params_flat, new_idxs)
            for stat, new_block_values, last_values in zip(
                ["log_liks", "grads", "hessians"],
                [
                    ref_subsample_log_liks,
                    ref_subsample_log_lik_grads,
                    ref_subsample_log_lik_hessians,
                ],
                [
                    gibbs_state.ref_subsample_log_liks,
                    gibbs_state.ref_subsample_log_lik_grads,
                    gibbs_state.ref_subsample_log_lik_hessians,
                ],
            ):
                for name, subsample_idx in gibbs_sites.items():
                    size, subsample_size = subsample_plate_sizes[name]
                    pad, start = pads[name], starts[name]
                    new_value = jnp.pad(
                        last_values[name],
                        [(0, pad)] + [(0, 0)] *
                        (jnp.ndim(last_values[name]) - 1),
                    )
                    new_value = lax.dynamic_update_slice_in_dim(
                        new_value, new_block_values[name], start, 0)
                    new_states[stat][name] = new_value[:subsample_size]
            gibbs_state = TaylorProxyState(new_states["log_liks"],
                                           new_states["grads"],
                                           new_states["hessians"])
            return u_new, gibbs_state

        def proxy_fn(params, subsample_lik_sites, gibbs_state):
            params_flat, _ = ravel_pytree(params)
            params_diff = params_flat - ref_params_flat

            ref_subsample_log_liks = gibbs_state.ref_subsample_log_liks
            ref_subsample_log_lik_grads = gibbs_state.ref_subsample_log_lik_grads
            ref_subsample_log_lik_hessians = gibbs_state.ref_subsample_log_lik_hessians

            proxy_sum = defaultdict(float)
            proxy_subsample = defaultdict(float)
            for name in subsample_lik_sites:
                proxy_subsample[name] = (
                    ref_subsample_log_liks[name] +
                    jnp.dot(ref_subsample_log_lik_grads[name], params_diff) +
                    0.5 * jnp.dot(
                        jnp.dot(ref_subsample_log_lik_hessians[name],
                                params_diff),
                        params_diff,
                    ))

                proxy_sum[name] = (
                    ref_log_likelihoods_sum[name] +
                    jnp.dot(ref_log_likelihood_grads_sum[name], params_diff) +
                    0.5 * jnp.dot(
                        jnp.dot(ref_log_likelihood_hessians_sum[name],
                                params_diff),
                        params_diff,
                    ))
            return proxy_sum, proxy_subsample

        return proxy_fn, gibbs_init, gibbs_update
Example #26
 def init_fn(z, rng_key):
     z_flat, _ = ravel_pytree(z)
     results = kernel.bootstrap_results(z_flat)
     return TFPKernelState(z, results, rng_key)
Example #27
 def _assertAllClose(self, x, y, rtol):
     x = ravel_pytree(x)[0]
     y = ravel_pytree(y)[0]
     diff = 2 * np.sum(
         np.abs(x - y)) / (np.sum(np.abs(x)) + np.sum(np.abs(y)) + 1e-4)
     self.assertLess(diff, rtol)
Example #28
          "epoch": epoch,
          "wallclock": time.time() - start_time
      })

    return get_params(opt_state)

  # See https://github.com/google/jax/issues/7809.
  binarize = lambda arr: tree_map(lambda x: x > 0.5, arr)

  print("Training normal model...")
  everything_mask = tree_map(lambda x: jnp.ones_like(x, dtype=jnp.dtype("bool")), init_params)
  final_params = train(init_params, everything_mask, "no_mask")

  # Mask as was implemented in the original paper
  print("Training lottery ticket model...")
  final_params_flat, unravel = ravel_pytree(final_params)
  cutoff = jnp.percentile(jnp.abs(final_params_flat), config.remove_percentile)
  mask = binarize(unravel(jnp.abs(final_params_flat) > cutoff))
  train(init_params, mask, "lottery_mask")

  print("Training lottery ticket sign model...")
  # The lottery ticket mask but instead of using the initial weights, just use
  # the sign of the initial weights.
  mask = binarize(unravel(jnp.abs(final_params_flat) > cutoff))
  w0 = tree_map(lambda x, m: 0.01 * jnp.sign(x) * m, final_params, mask)
  train(w0, mask, "lottery_sign_mask")

  # Totally random mask
  print("Training random mask model...")
  mask = binarize(
      unravel(random.uniform(rp.poop(), final_params_flat.shape) > config.remove_percentile / 100))
Example #29
 def jacobian_builder(losses, params):
     grads_tree = jax.jacrev(losses)(params)
     grads, _ = flatten_util.ravel_pytree(grads_tree)
     grads = jnp.reshape(grads,
                         (batch_size, int(grads.shape[0] / batch_size)))
     return grads
Example #30
 def dynamics_one_flat(t, aug, args):
     flat, _ = ravel_pytree(dynamics_one(t, unravel(aug), args))
     return flat