def _build_tree(self, depth, state, sum_mom, sum_weight, stats, h_init, rng):
    """Recursively build a subtree of ``2**depth`` integrator steps.

    A depth-0 subtree is a single call to ``self.integrator.step``. A
    subtree of depth ``d > 0`` is built in two halves:
    - an inner subtree of depth ``d - 1`` starting from ``state``;
    - an outer subtree of depth ``d - 1`` starting from the inner subtree's
      outermost state.
    A proposal is then drawn from the two halves in proportion to their
    total weights, where each state contributes weight ``exp(-h)``.

    Args:
        depth (int): Depth of the subtree to build.
        state: State to start integrating from. Presumably its ``dir``
            attribute (set in ``sample``) fixes the integration direction;
            confirm against the integrator implementation.
        sum_mom: Mutable array updated IN PLACE with the sum of the momenta
            of the states in the subtree.
        sum_weight (LogRepFloat): Updated IN PLACE via ``+=`` with the
            summed weights of the states in the subtree. This assumes
            ``LogRepFloat.__iadd__`` mutates its operand, since the caller
            reads the accumulated value afterwards.
        stats (dict): Sampler statistics.
            - ``'n_step'`` and ``'sum_acc_prob'`` are incremented for every
              completed step.
            - ``'diverging'`` is set when a step diverges.
            - ``'non_reversible_step'`` and ``'convergence_error'`` are set
              when a step raises ``IntegratorError``.
        h_init (float): Hamiltonian of the state the whole trajectory
            started from. Used for the acceptance statistic and the
            divergence check.
        rng: Random number generator exposing ``uniform()``.

    Returns:
        tuple: ``(terminate, state_inner, state_outer, state_proposal)``.

        ``terminate`` is True when the subtree must be discarded. That
        happens on a divergence, an integrator error, termination inside
        either half, or when the termination criterion holds across this
        subtree. The accumulators are only meaningful when ``terminate`` is
        False.

        If an integrator error occurred or either half terminated, the three
        states are ``None``.
    """
    if depth == 0:
        # Recursion base case: a single integrator step.
        try:
            state = self.integrator.step(state)
            h = self.system.h(state)
            # A NaN Hamiltonian is treated as infinite. The state then gets
            # zero weight and, for a finite max_delta_h, is flagged as
            # diverging below.
            h = np.inf if np.isnan(h) else h
            sum_mom += state.mom
            sum_weight += LogRepFloat(log_val=-h)
            # Equal to min(1, exp(h_init - h)), but clipping the exponent
            # first avoids overflow warnings when h_init - h is large.
            stats['sum_acc_prob'] += np.exp(min(0, h_init - h))
            stats['n_step'] += 1
            terminate = h - h_init > self.max_delta_h
            if terminate:
                stats['diverging'] = True
                logger.info(
                    f'Terminating build_tree due to integrator divergence '
                    f'(delta_h = {h - h_init:.1e}).')
        except IntegratorError as e:
            logger.info(
                f'Terminating build_tree due to integrator error:\n{e!s}')
            stats['non_reversible_step'] = isinstance(
                e, NonReversibleStepError)
            stats['convergence_error'] = isinstance(e, ConvergenceError)
            state = None
            terminate = True
        # In a single-step subtree, the inner edge, outer edge and proposal
        # are all the one new state (or None after an integrator error).
        return terminate, state, state, state
    # Zero accumulators for the two halves. Two separate arrays are
    # allocated, matching ``sample``, rather than unpacking
    # ``np.zeros((2,) + shape)``. For a 0-d momentum the unpacked items
    # would be immutable NumPy scalars, so the in-place ``+=`` in the
    # recursive calls would silently never reach them.
    sum_mom_i = np.zeros(state.mom.shape)
    sum_mom_o = np.zeros(state.mom.shape)
    sum_weight_i, sum_weight_o = LogRepFloat(0.), LogRepFloat(0.)
    # Build the inner subsubtree. Its outermost state seeds the outer one.
    terminate_i, state_i, state, state_pi = self._build_tree(
        depth - 1, state, sum_mom_i, sum_weight_i, stats, h_init, rng)
    if terminate_i:
        return True, None, None, None
    # Build the outer subsubtree. Its inner edge is not needed.
    terminate_o, _, state_o, state_po = self._build_tree(
        depth - 1, state, sum_mom_o, sum_weight_o, stats, h_init, rng)
    if terminate_o:
        return True, None, None, None
    # Independently sample a proposal from the 2 subsubtrees by their
    # relative weights.
    sum_weight_s = sum_weight_i + sum_weight_o
    accept_o_prob = sum_weight_o / sum_weight_s
    state_p = state_po if rng.uniform() < accept_o_prob else state_pi
    # Update the overall tree weight.
    sum_weight += sum_weight_s
    # Evaluate the termination criterion across this subtree's two edges.
    sum_mom_s = sum_mom_i + sum_mom_o
    terminate_s = self.termination_criterion(state_i, state_o, sum_mom_s)
    # Update the overall tree summed momentum.
    sum_mom += sum_mom_s
    return terminate_s, state_i, state_o, state_p
def sample(self, state, rng):
    """Sample a new chain state by dynamically building a trajectory.

    The trajectory starts as the single ``state``. On each iteration:
    - a direction is chosen uniformly at random;
    - ``_build_tree`` builds a subtree of ``2**depth`` steps on that edge,
      doubling the trajectory.

    Iteration stops at the first of:
    - the new subtree terminating (the subtree is then discarded);
    - ``self.termination_criterion`` holding for the whole trajectory;
    - ``self.max_tree_depth`` iterations having run.

    The returned state is chosen progressively. After each accepted
    subtree, the current choice is replaced by the subtree's proposal with
    probability ``sum_weight_s / sum_weight``, which favours the new
    subtree.

    Args:
        state: Current chain state. Its ``mom`` attribute and ``copy()``
            method are used.
        rng: Random number generator exposing ``uniform()``.

    Returns:
        tuple: ``(state_n, stats)``. ``state_n`` is the selected state,
        possibly ``state`` itself. ``stats`` is a dict containing:
        - ``'n_step'`` and ``'sum_acc_prob'``;
        - ``'accept_prob'``: the mean acceptance statistic over all steps
          taken, or 0 if none were taken;
        - ``'hamiltonian'``: the Hamiltonian of ``state_n``;
        - ``'tree_depth'``: the index of the last doubling iteration run,
          or 0 if none ran;
        - any failure flags set by ``_build_tree``.
    """
    h_init = self.system.h(state)
    sum_mom = state.mom.copy()
    sum_weight = LogRepFloat(log_val=-h_init)
    stats = {'n_step': 0, 'sum_acc_prob': 0.}
    state_n, state_l, state_r = state, state.copy(), state.copy()
    # Set the integration directions of the initial left and right leaves.
    state_l.dir = -1
    state_r.dir = +1
    # Bind depth before the loop. Otherwise it would be unbound, and the
    # 'tree_depth' statistic below would raise UnboundLocalError, when
    # max_tree_depth is 0 and the loop body never runs.
    depth = 0
    for depth in range(self.max_tree_depth):
        # Uniformly sample the direction to expand the tree in.
        direction = 2 * (rng.uniform() < 0.5) - 1
        sum_mom_s = np.zeros(state.mom.shape)
        sum_weight_s = LogRepFloat(0.)
        if direction == 1:
            # Expand the tree by adding a subtree to the right edge.
            terminate_s, _, state_r, state_p = self._build_tree(
                depth, state_r, sum_mom_s, sum_weight_s, stats, h_init, rng)
        else:
            # Expand the tree by adding a subtree to the left edge.
            terminate_s, _, state_l, state_p = self._build_tree(
                depth, state_l, sum_mom_s, sum_weight_s, stats, h_init, rng)
        if terminate_s:
            # The new subtree is discarded. Its weight and momentum are
            # never merged.
            break
        # Progressively sample the new state by choosing between the
        # current new state and the proposal from the new subtree, biased
        # towards the new subtree proposal.
        if rng.uniform() < sum_weight_s / sum_weight:
            state_n = state_p
        sum_weight += sum_weight_s
        sum_mom += sum_mom_s
        if self.termination_criterion(state_l, state_r, sum_mom):
            break
    if stats['n_step'] > 0:
        stats['accept_prob'] = stats['sum_acc_prob'] / stats['n_step']
    else:
        stats['accept_prob'] = 0.
    stats['hamiltonian'] = self.system.h(state_n)
    stats['tree_depth'] = depth
    return state_n, stats