コード例 #1
0
 def _build_tree(self, depth, state, sum_mom, sum_weight, stats, h_init,
                 rng):
     """Recursively extend the trajectory by a subtree of up to 2**depth steps.

     A subtree of depth 0 is a single integrator step from ``state``. A
     subtree of depth ``d > 0`` is an inner subtree of depth ``d - 1``
     followed by an outer subtree of depth ``d - 1`` started from the outer
     edge of the inner one.

     Args:
         depth: Non-negative depth of the subtree to build.
         state: State to integrate from. It is passed to
             ``self.integrator.step`` and its ``mom`` attribute is read.
         sum_mom: Array accumulator. The momenta of the new states are
             added to it in place.
         sum_weight: Weight accumulator updated with ``+=``.
             NOTE(review): the caller only sees the update if
             ``LogRepFloat`` implements in-place addition - confirm.
         stats: Dictionary updated in place. ``'sum_acc_prob'`` and
             ``'n_step'`` are incremented for every successful step.
             ``'diverging'`` is set to ``True`` when the energy increase
             exceeds ``self.max_delta_h``. ``'non_reversible_step'`` and
             ``'convergence_error'`` are set when the integrator raises an
             ``IntegratorError``.
         h_init: Reference Hamiltonian value against which the energy
             change of every new state is measured.
         rng: Random number generator providing a ``uniform()`` method.

     Returns:
         Tuple ``(terminate, state_i, state_o, state_p)`` holding the
         termination flag, the inner edge state, the outer edge state and
         the state proposed from the subtree. For a leaf all three states
         are the newly integrated state, or ``None`` after an integrator
         error. For ``depth > 0`` all three are ``None`` whenever either
         child subtree signalled termination, and in that case ``sum_mom``
         and ``sum_weight`` are left untouched by this level.
     """
     if depth == 0:
         # recursion base case: take a single integrator step
         try:
             state = self.integrator.step(state)
             h = self.system.h(state)
             # map an undefined energy to +inf so the comparison against
             # max_delta_h below evaluates as a divergence instead of
             # silently comparing False
             h = np.inf if np.isnan(h) else h
             sum_mom += state.mom
             sum_weight += LogRepFloat(log_val=-h)
             # min(1, exp(x)) == exp(min(0, x)); clamping the exponent
             # first avoids a floating point overflow in np.exp when the
             # energy decreases by a large amount
             stats['sum_acc_prob'] += np.exp(min(0, h_init - h))
             stats['n_step'] += 1
             terminate = h - h_init > self.max_delta_h
             if terminate:
                 stats['diverging'] = True
                 logger.info(
                     f'Terminating build_tree due to integrator divergence '
                     f'(delta_h = {h - h_init:.1e}).')
         except IntegratorError as e:
             logger.info(
                 f'Terminating build_tree due to integrator error:\n{e!s}')
             stats['non_reversible_step'] = isinstance(
                 e, NonReversibleStepError)
             stats['convergence_error'] = isinstance(e, ConvergenceError)
             # no valid new state exists after a failed step
             state = None
             terminate = True
         # a leaf is simultaneously its own inner edge, outer edge and
         # proposal
         return terminate, state, state, state
     # separate accumulators per child so the relative weights of the two
     # halves are available when choosing the proposal below
     sum_mom_i, sum_mom_o = np.zeros((2, ) + state.mom.shape)
     sum_weight_i, sum_weight_o = LogRepFloat(0.), LogRepFloat(0.)
     # build inner subsubtree
     terminate_i, state_i, state, state_pi = self._build_tree(
         depth - 1, state, sum_mom_i, sum_weight_i, stats, h_init, rng)
     if terminate_i:
         return True, None, None, None
     # build outer subsubtree, continuing from the outer edge of the inner
     # one
     terminate_o, _, state_o, state_po = self._build_tree(
         depth - 1, state, sum_mom_o, sum_weight_o, stats, h_init, rng)
     if terminate_o:
         return True, None, None, None
     # independently sample proposal from 2 subsubtrees by relative weights
     sum_weight_s = sum_weight_i + sum_weight_o
     accept_o_prob = sum_weight_o / sum_weight_s
     state_p = state_po if rng.uniform() < accept_o_prob else state_pi
     # update overall tree weight
     sum_weight += sum_weight_s
     # calculate termination criteria for subtree
     sum_mom_s = sum_mom_i + sum_mom_o
     terminate_s = self.termination_criterion(state_i, state_o, sum_mom_s)
     # update overall tree summed momentum
     sum_mom += sum_mom_s
     return terminate_s, state_i, state_o, state_p
コード例 #2
0
 def sample(self, state, rng):
     """Sample a new state from a trajectory tree grown around ``state``.

     Starting from ``state``, up to ``self.max_tree_depth`` subtrees are
     added, each to a uniformly chosen edge (left or right) of the current
     tree. The subtree added on iteration ``depth`` is built with
     ``self._build_tree(depth, ...)``. Growth stops early when a new
     subtree signals termination, in which case that subtree is discarded,
     or when ``self.termination_criterion`` holds for the whole tree.

     Args:
         state: Current state. Its ``mom`` attribute and ``copy()`` method
             are used, and the copies have their ``dir`` attribute set to
             ``-1`` (left edge) and ``+1`` (right edge).
         rng: Random number generator providing a ``uniform()`` method.

     Returns:
         Tuple ``(state_n, stats)``. ``state_n`` is the sampled state,
         which is ``state`` itself if no proposal was accepted. ``stats``
         is a dictionary with the keys ``'n_step'``, ``'sum_acc_prob'``,
         ``'accept_prob'`` (``sum_acc_prob / n_step``, or ``0.`` when no
         step was taken), ``'hamiltonian'`` (``self.system.h(state_n)``)
         and ``'tree_depth'`` (index of the last expansion iteration
         entered, or ``0`` if there was none), plus any flags recorded by
         ``_build_tree``.
     """
     h_init = self.system.h(state)
     sum_mom = state.mom.copy()
     sum_weight = LogRepFloat(log_val=-h_init)
     stats = {'n_step': 0, 'sum_acc_prob': 0.}
     state_n, state_l, state_r = state, state.copy(), state.copy()
     # set integration directions of initial left and right tree leaves
     state_l.dir = -1
     state_r.dir = +1
     # the loop never binds ``depth`` when max_tree_depth <= 0, so bind it
     # here to keep the 'tree_depth' statistic below well defined
     depth = 0
     for depth in range(self.max_tree_depth):
         # uniformly sample direction to expand tree in
         expand_right = rng.uniform() < 0.5
         # fresh accumulators so the new subtree's totals can be compared
         # against the existing tree before being merged into it
         sum_mom_s = np.zeros(state.mom.shape)
         sum_weight_s = LogRepFloat(0.)
         if expand_right:
             # expand tree by adding subtree to right edge
             terminate_s, _, state_r, state_p = self._build_tree(
                 depth, state_r, sum_mom_s, sum_weight_s, stats, h_init,
                 rng)
         else:
             # expand tree by adding subtree to left edge
             terminate_s, _, state_l, state_p = self._build_tree(
                 depth, state_l, sum_mom_s, sum_weight_s, stats, h_init,
                 rng)
         if terminate_s:
             # discard the new subtree; the edge state may be None here but
             # is not used again after leaving the loop
             break
         # progressively sample new state by choosing between
         # current new state and proposal from new subtree, biasing
         # towards the new subtree proposal
         if rng.uniform() < sum_weight_s / sum_weight:
             state_n = state_p
         sum_weight += sum_weight_s
         sum_mom += sum_mom_s
         if self.termination_criterion(state_l, state_r, sum_mom):
             break
     if stats['n_step'] > 0:
         stats['accept_prob'] = stats['sum_acc_prob'] / stats['n_step']
     else:
         stats['accept_prob'] = 0.
     stats['hamiltonian'] = self.system.h(state_n)
     stats['tree_depth'] = depth
     return state_n, stats