# Example 1
    def update_fn(gradient, state, params=None):
        """Compute preconditioned SGLD updates and advance the optimizer state.

        Args:
            gradient: pytree of minibatch gradients.
            state: OptaxSGLDState carrying the step count, RNG key, momentum,
                and preconditioner state.
            params: unused; kept for the optax update-fn interface.

        Returns:
            A ``(updates, new_state)`` pair in the optax convention.
        """
        del params  # Required by the optax gradient-transformation signature.
        step_size_sqrt = jnp.sqrt(step_size_fn(state.count))
        noise_scale = jnp.sqrt(2 * (1 - momentum_decay))

        # Refresh the preconditioner from the current gradient first; every
        # multiplication below uses the refreshed state.
        precond_state = preconditioner.update_preconditioner(
            gradient, state.preconditioner_state)

        raw_noise, next_key = tree_utils.normal_like_tree(gradient,
                                                          state.rng_key)
        scaled_noise = preconditioner.multiply_by_m_sqrt(raw_noise,
                                                         precond_state)

        # Per leaf: m <- decay * m + sqrt(lr) * g + noise_scale * n
        new_momentum = jax.tree_map(
            lambda m, g, n: momentum_decay * m + g * step_size_sqrt
            + n * noise_scale,
            state.momentum, gradient, scaled_noise)

        deltas = preconditioner.multiply_by_m_inv(new_momentum, precond_state)
        deltas = jax.tree_map(lambda leaf: leaf * step_size_sqrt, deltas)

        next_state = OptaxSGLDState(
            count=state.count + 1,
            rng_key=next_key,
            momentum=new_momentum,
            preconditioner_state=precond_state)
        return deltas, next_state
# Example 2
 def sample_parms_fn(params, state):
     """Draw one parameter sample from the mean-field Gaussian posterior.

     Args:
         params: dict with pytrees ``"mean"`` and ``"inv_softplus_std"``
             (the std stored in unconstrained inverse-softplus form).
     state: dict with ``"mfvi_key"`` (PRNG key) and ``"net_state"``.

     Returns:
         ``(params_sampled, new_mfvi_state)`` where the new state holds a
         fresh PRNG key and a copy of the network state.
     """
     mean = params["mean"]
     # softplus maps the unconstrained parameterization back to (0, inf).
     std = jax.tree_map(jax.nn.softplus, params["inv_softplus_std"])
     noise, new_key = tree_utils.normal_like_tree(mean, state["mfvi_key"])
     # Reparameterization trick: sample = mean + std * eps, eps ~ N(0, I).
     # jax.tree_map in its multi-tree form replaces jax.tree_multimap, which
     # was deprecated and removed from JAX; this also matches the rest of
     # the file, which already uses jax.tree_map.
     params_sampled = jax.tree_map(lambda m, s, n: m + n * s, mean, std,
                                   noise)
     new_mfvi_state = {
         "net_state": copy.deepcopy(state["net_state"]),
         "mfvi_key": new_key
     }
     return params_sampled, new_mfvi_state
# Example 3
    def adaptive_hmc_update(dataset,
                            params,
                            net_state,
                            log_likelihood,
                            state_grad,
                            key,
                            step_size,
                            n_leapfrog_steps,
                            target_accept_rate,
                            step_size_adaptation_speed,
                            do_mh_correction=True):
        """Run one HMC transition with on-line step-size adaptation.

        Resamples momentum, simulates a leapfrog trajectory, adapts the step
        size toward ``target_accept_rate``, and (optionally) applies the
        Metropolis-Hastings accept/reject correction.

        Returns:
            ``(params, net_state, log_likelihood, state_grad, step_size,
            accept_prob, accepted)``.
        """
        # Three-way split keeps the key stream identical to the original;
        # jitter_key is not consumed in this function.
        normal_key, uniform_key, jitter_key = jax.random.split(key, 3)
        # Fresh momentum ~ N(0, I) at the start of every trajectory.
        momentum, _ = tree_utils.normal_like_tree(params, normal_key)

        proposal = leapfrog(dataset, params, net_state, momentum, state_grad,
                            step_size, n_leapfrog_steps)
        (new_params, net_state, new_momentum, new_grad,
         new_log_likelihood) = proposal

        accept_prob = get_accept_prob(log_likelihood, params, momentum,
                                      new_log_likelihood, new_params,
                                      new_momentum)
        accepted = jax.random.uniform(uniform_key) < accept_prob

        # Adapt on every step, accepted or not, using the raw accept prob.
        step_size = adapt_step_size(step_size, target_accept_rate, accept_prob,
                                    step_size_adaptation_speed)

        if not do_mh_correction:
            # Uncorrected variant: always move to the end of the trajectory.
            params, log_likelihood, state_grad = (new_params,
                                                  new_log_likelihood, new_grad)
        else:
            # MH correction: keep the proposal only when accepted.
            params = jax.lax.cond(accepted, _first, _second,
                                  (new_params, params))
            log_likelihood = jnp.where(accepted, new_log_likelihood,
                                       log_likelihood)
            state_grad = jax.lax.cond(accepted, _first, _second,
                                      (new_grad, state_grad))

        return (params, net_state, log_likelihood, state_grad, step_size,
                accept_prob, accepted)