Exemplo n.º 1
0
 def sample(self, key, sample_shape=()):
     """Draw exponential samples scaled by ``1 / self.rate``.

     :param key: PRNG key passed to ``random.exponential``.
     :param tuple sample_shape: leading shape of the draw; it is
         concatenated in front of ``self.batch_shape``.
     :return: unit-rate exponential draws of shape
         ``sample_shape + self.batch_shape`` divided by ``self.rate``.
     """
     full_shape = sample_shape + self.batch_shape
     unit_rate_draws = random.exponential(key, shape=full_shape)
     return unit_rate_draws / self.rate
Exemplo n.º 2
0
    def sample(self, state, model_args, model_kwargs):
        """Perform one transition and return the resulting ``MixedHMCState``.

        The step interleaves HMC sub-trajectories on the continuous sites
        (delegated to ``self.inner_kernel``) with single-site updates of the
        discrete sites, then applies one Metropolis-Hastings correction to
        the whole proposal. The ``Algo 1, line N`` comments refer to the
        algorithm listing the implementation follows.

        :param state: current state; this method reads ``state.z``,
            ``state.hmc_state`` and ``state.rng_key``. Entries of ``state.z``
            whose keys are absent from ``state.hmc_state.z`` are treated as
            the discrete sites.
        :param model_args: positional arguments forwarded to the inner
            kernel's potential-function generator and ``sample`` method.
        :param model_kwargs: keyword arguments forwarded the same way;
            ``None`` is treated as an empty dict.
        :return: ``MixedHMCState(z, hmc_state, rng_key, accept_prob)`` where
            ``z`` merges the discrete and continuous site values.
        """
        model_kwargs = {} if model_kwargs is None else model_kwargs
        num_discretes = self._support_sizes_flat.shape[0]

        def potential_fn(z_gibbs, z_hmc):
            # potential energy of the continuous sites `z_hmc` given the
            # discrete sites `z_gibbs`
            return self.inner_kernel._potential_fn_gen(*model_args,
                                                       _gibbs_sites=z_gibbs,
                                                       **model_kwargs)(z_hmc)

        def update_discrete(idx, rng_key, hmc_state, z_discrete, ke_discrete,
                            delta_pe_sum):
            # Algo 1, line 19: get a new discrete proposal
            rng_key, z_discrete_new, pe_new, log_accept_ratio = self._discrete_proposal_fn(
                rng_key, z_discrete, hmc_state.potential_energy,
                partial(potential_fn, z_hmc=hmc_state.z), idx,
                self._support_sizes_flat[idx])
            # Algo 1, line 20: depending on reject or refract, we will update
            # the discrete variable and its corresponding kinetic energy. In case of
            # refract, we will need to update the potential energy and its grad w.r.t. hmc_state.z
            ke_discrete_i_new = ke_discrete[idx] + log_accept_ratio
            grad_ = jacfwd if self.inner_kernel._forward_mode_differentiation else grad
            # true branch (enough kinetic energy left): take the proposal and
            # recompute the gradient at the new discrete values;
            # false branch: keep the current values unchanged.
            z_discrete, pe, ke_discrete_i, z_grad = lax.cond(
                ke_discrete_i_new > 0,
                (z_discrete_new, pe_new, ke_discrete_i_new),
                lambda vals: vals + (grad_(partial(potential_fn, vals[0]))
                                     (hmc_state.z), ),
                (z_discrete, hmc_state.potential_energy, ke_discrete[idx],
                 hmc_state.z_grad), identity)

            # accumulate the potential-energy change caused by discrete moves;
            # it is subtracted from the energy difference in the MH correction
            delta_pe_sum = delta_pe_sum + pe - hmc_state.potential_energy
            ke_discrete = ops.index_update(ke_discrete, idx, ke_discrete_i)
            hmc_state = hmc_state._replace(potential_energy=pe, z_grad=z_grad)
            return rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum

        def update_continuous(hmc_state, z_discrete):
            # copy so the caller's kwargs dict is not mutated
            model_kwargs_ = model_kwargs.copy()
            model_kwargs_["_gibbs_sites"] = z_discrete
            hmc_state_new = self.inner_kernel.sample(hmc_state, model_args,
                                                     model_kwargs_)

            # each time a sub-trajectory is performed, we need to reset i and adapt_state
            # (we will only update them at the end of HMCGibbs step)
            # For `num_steps`, we will record its cumulative sum for diagnostics
            hmc_state = hmc_state_new._replace(
                i=hmc_state.i,
                adapt_state=hmc_state.adapt_state,
                num_steps=hmc_state.num_steps + hmc_state_new.num_steps)
            return hmc_state

        def body_fn(i, vals):
            rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum, arrival_times = vals
            # the discrete site with the earliest arrival time is updated next
            idx = jnp.argmin(arrival_times)
            # NB: length of each sub-trajectory is scaled from the current min(arrival_times)
            # (see the note at total_time below)
            trajectory_length = arrival_times[idx] * time_unit
            # advance the clock to that arrival and schedule the same site
            # again one time unit later
            arrival_times = arrival_times - arrival_times[idx]
            arrival_times = ops.index_update(arrival_times, idx, 1.)

            # this is a trick, so that in a sub-trajectory of HMC, we always accept the new proposal
            pe = jnp.inf
            hmc_state = hmc_state._replace(trajectory_length=trajectory_length,
                                           potential_energy=pe)
            # Algo 1, line 7: perform a sub-trajectory
            hmc_state = update_continuous(hmc_state, z_discrete)
            # Algo 1, line 8: perform a discrete update
            rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum = update_discrete(
                idx, rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum)
            return rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum, arrival_times

        z_discrete = {
            k: v
            for k, v in state.z.items() if k not in state.hmc_state.z
        }
        # one independent key per random draw below; `rng_key` itself is
        # threaded through the update loop and handed back in the new state
        rng_key, rng_ke, rng_time, rng_r, rng_accept = random.split(
            state.rng_key, 5)
        # Algo 1, line 2: sample discrete kinetic energy
        ke_discrete = random.exponential(rng_ke, (num_discretes, ))
        # Algo 1, line 4 and 5: sample the initial amount of time that each discrete site visits
        # the point 0/1. The logic in GetStepSizesNSteps(...) is more complicated but does
        # the same job: the sub-trajectory length eta_t * M_t is the lag between two arrival time.
        arrival_times = random.uniform(rng_time, (num_discretes, ))
        # compute the amount of time to make `num_discrete_updates` discrete updates
        total_time = (self._num_discrete_updates - 1) // num_discretes \
            + jnp.sort(arrival_times)[(self._num_discrete_updates - 1) % num_discretes]
        # NB: total_time can be different from the HMC trajectory length, so we need to scale
        # the time unit so that total_time * time_unit = hmc_trajectory_length
        time_unit = state.hmc_state.trajectory_length / total_time

        # Algo 1, line 2: sample hmc momentum
        r = momentum_generator(state.hmc_state.r,
                               state.hmc_state.adapt_state.mass_matrix_sqrt,
                               rng_r)
        hmc_state = state.hmc_state._replace(r=r, num_steps=0)
        hmc_ke = euclidean_kinetic_energy(
            hmc_state.adapt_state.inverse_mass_matrix, r)
        # Algo 1, line 10: compute the initial energy
        energy_old = hmc_ke + hmc_state.potential_energy

        # Algo 1, line 3: set initial values
        delta_pe_sum = 0.
        init_val = (rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum,
                    arrival_times)
        # Algo 1, line 6-9: perform the update loop
        rng_key, hmc_state_new, z_discrete_new, _, delta_pe_sum, _ = fori_loop(
            0, self._num_discrete_updates, body_fn, init_val)
        # Algo 1, line 10: compute the proposal energy
        hmc_ke = euclidean_kinetic_energy(
            hmc_state.adapt_state.inverse_mass_matrix, hmc_state_new.r)
        energy_new = hmc_ke + hmc_state_new.potential_energy
        # Algo 1, line 11: perform MH correction
        delta_energy = energy_new - energy_old - delta_pe_sum
        # a NaN energy difference is treated as +inf, i.e. accept_prob = 0
        delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf,
                                 delta_energy)
        accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)

        # record the correct new num_steps
        hmc_state = hmc_state._replace(num_steps=hmc_state_new.num_steps)
        # reset the trajectory length
        hmc_state_new = hmc_state_new._replace(
            trajectory_length=hmc_state.trajectory_length)
        # Draw the accept/reject decision with the dedicated `rng_accept` key.
        # `rng_key` is returned in the new state below, so consuming it here
        # as well would reuse the same key for two random operations.
        hmc_state, z_discrete = cond(random.bernoulli(rng_accept, accept_prob),
                                     (hmc_state_new, z_discrete_new), identity,
                                     (hmc_state, z_discrete), identity)

        # perform hmc adapting (similar to the implementation in hmc)
        adapt_state = cond(hmc_state.i < self._num_warmup,
                           (hmc_state.i, accept_prob,
                            (hmc_state.z, ), hmc_state.adapt_state),
                           lambda args: self._wa_update(*args),
                           hmc_state.adapt_state, identity)

        itr = hmc_state.i + 1
        # running-mean denominator: counts iterations within the current
        # phase (warmup, or post-warmup), so it is always >= 1
        n = jnp.where(hmc_state.i < self._num_warmup, itr,
                      itr - self._num_warmup)
        mean_accept_prob_prev = state.hmc_state.mean_accept_prob
        mean_accept_prob = mean_accept_prob_prev + (accept_prob -
                                                    mean_accept_prob_prev) / n
        hmc_state = hmc_state._replace(i=itr,
                                       accept_prob=accept_prob,
                                       mean_accept_prob=mean_accept_prob,
                                       adapt_state=adapt_state)

        z = {**z_discrete, **hmc_state.z}
        return MixedHMCState(z, hmc_state, rng_key, accept_prob)