def extend(self, direction):
    """Double the tree size by extending the tree in the given direction.

    If direction is larger than 0, extend it to the right, otherwise
    extend it to the left.

    Return a tuple `(diverging, turning)` of type (DivergenceInfo, bool).
    `diverging` indicates that the tree extension was aborted because
    the energy change exceeded `self.Emax`. `turning` indicates that
    the tree extension was stopped because the termination criterion
    was reached (the trajectory is turning back).

    Side effects: the extended endpoint (`self.right` or `self.left`),
    `self.depth` and `self.n_proposals` are always updated. The proposal,
    `log_size`, `log_weighted_accept_sum` and `p_sum` are only merged
    when the new subtree neither diverged nor turned.
    """
    if direction > 0:
        tree, diverging, turning = self._build_subtree(
            self.right, self.depth, floatX(np.asarray(self.step_size))
        )
        leftmost_begin, leftmost_end = self.left, self.right
        rightmost_begin, rightmost_end = tree.left, tree.right
        # Snapshot the old tree's momentum sum: `self.p_sum` is updated in
        # place further down, and an alias would silently turn into the sum
        # over the merged tree, corrupting the sub-trajectory checks.
        leftmost_p_sum = self.p_sum.copy()
        rightmost_p_sum = tree.p_sum
        self.right = tree.right
    else:
        tree, diverging, turning = self._build_subtree(
            self.left, self.depth, floatX(np.asarray(-self.step_size))
        )
        # The new subtree was built starting at `self.left` with a negated
        # step, so its `right` end is the new outermost-left state.
        leftmost_begin, leftmost_end = tree.right, tree.left
        rightmost_begin, rightmost_end = self.left, self.right
        leftmost_p_sum = tree.p_sum
        # Same aliasing hazard as above, for the right-hand half.
        rightmost_p_sum = self.p_sum.copy()
        self.left = tree.right

    self.depth += 1
    self.n_proposals += tree.n_proposals

    # An aborted subtree contributes no proposal and no statistics.
    if diverging or turning:
        return diverging, turning

    # Switch to the new subtree's proposal based on a logbern draw on the
    # difference of log sizes (new subtree minus the tree before merging).
    size1, size2 = self.log_size, tree.log_size
    if logbern(size2 - size1):
        self.proposal = tree.proposal

    self.log_size = np.logaddexp(self.log_size, tree.log_size)
    self.log_weighted_accept_sum = np.logaddexp(
        self.log_weighted_accept_sum, tree.log_weighted_accept_sum
    )
    self.p_sum[:] += tree.p_sum

    # Additional turning check only when tree depth > 0 to avoid redundant work
    # (`self.depth` was incremented above).
    if self.depth > 0:
        left, right = self.left, self.right
        p_sum = self.p_sum
        # U-turn check across the whole merged trajectory.
        turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0)
        # Left half extended by the first state of the right half.
        p_sum1 = leftmost_p_sum + rightmost_begin.p.data
        turning1 = (p_sum1.dot(leftmost_begin.v) <= 0) or (
            p_sum1.dot(rightmost_begin.v) <= 0
        )
        # Right half extended by the last state of the left half.
        p_sum2 = leftmost_end.p.data + rightmost_p_sum
        turning2 = (p_sum2.dot(leftmost_end.v) <= 0) or (
            p_sum2.dot(rightmost_end.v) <= 0
        )
        turning = turning | turning1 | turning2

    return diverging, turning
def _build_subtree(self, left, depth, epsilon):
    """Recursively build a subtree of the given depth starting from `left`.

    A subtree of depth 0 is whatever `self._single_step(left, epsilon)`
    returns. For larger depths two subtrees of depth `depth - 1` are built
    one after the other, the second starting at the right end of the first,
    and then merged.

    Returns a tuple `(tree, diverging, turning)`. If the first half already
    diverged or turned it is returned unchanged. If the second half did,
    the result keeps the first half's momentum sum, proposal and weights
    but spans both halves' end points and counts both halves' proposals.
    """
    if depth == 0:
        return self._single_step(left, epsilon)

    child_depth = depth - 1

    first, diverging, turning = self._build_subtree(left, child_depth, epsilon)
    if diverging or turning:
        return first, diverging, turning

    second, diverging, turning = self._build_subtree(
        first.right, child_depth, epsilon
    )

    outer_left, outer_right = first.left, second.right
    total_proposals = first.n_proposals + second.n_proposals

    if diverging or turning:
        # Second half was aborted: fall back to the first half's statistics.
        aborted = Subtree(
            outer_left,
            outer_right,
            first.p_sum,
            first.proposal,
            first.log_size,
            first.log_weighted_accept_sum,
            total_proposals,
        )
        return aborted, diverging, turning

    # U-turn check across the full merged span.
    momentum_sum = first.p_sum + second.p_sum
    turning = (momentum_sum.dot(outer_left.v) <= 0) or (
        momentum_sum.dot(outer_right.v) <= 0
    )

    # Additional U turn check only when depth > 1 to avoid redundant work.
    if depth > 1:
        # First half plus the first state of the second half.
        sum_a = first.p_sum + second.left.p.data
        turning_a = (sum_a.dot(first.left.v) <= 0) or (
            sum_a.dot(second.left.v) <= 0
        )
        # Last state of the first half plus the second half.
        sum_b = first.right.p.data + second.p_sum
        turning_b = (sum_b.dot(first.right.v) <= 0) or (
            sum_b.dot(second.right.v) <= 0
        )
        turning = turning | turning_a | turning_b

    merged_log_size = np.logaddexp(first.log_size, second.log_size)
    merged_accept_sum = np.logaddexp(
        first.log_weighted_accept_sum, second.log_weighted_accept_sum
    )

    # Keep the second half's proposal based on a logbern draw on its log
    # size relative to the merged log size.
    take_second = logbern(second.log_size - merged_log_size)
    chosen = second.proposal if take_second else first.proposal

    merged = Subtree(
        outer_left,
        outer_right,
        momentum_sum,
        chosen,
        merged_log_size,
        merged_accept_sum,
        total_proposals,
    )
    return merged, diverging, turning
def _hamiltonian_step(self, start, p0, step_size):
    """Run one tree-building transition and package the result.

    The tree is extended in a freshly drawn direction (+1 or -1) until it
    reports a divergence or a turn, or until the depth limit is reached.
    While tuning and during the first 200 iterations the limit is
    `self.early_max_treedepth`, otherwise `self.max_treedepth`. Hitting
    the limit outside of tuning increments `self._reached_max_treedepth`.

    Returns an `HMCStepData` built from the tree's proposal, the
    "mean_tree_accept" statistic, the last divergence info and the full
    stats dict.
    """
    use_early_limit = self.tune and self.iter_count < 200
    depth_limit = (
        self.early_max_treedepth if use_early_limit else self.max_treedepth
    )

    tree = _Tree(len(p0), self.integrator, start, step_size, self.Emax)

    stopped_early = False
    for _ in range(depth_limit):
        # Map the draw onto -1 / +1; presumably a fair coin, confirm
        # against `logbern`.
        coin = logbern(np.log(0.5))
        direction = coin * 2 - 1
        divergence_info, turning = tree.extend(direction)
        if divergence_info or turning:
            stopped_early = True
            break

    # The loop ran to completion: the tree hit the depth limit.
    if not stopped_early and not self.tune:
        self._reached_max_treedepth += 1

    stats = tree.stats()
    return HMCStepData(
        tree.proposal, stats["mean_tree_accept"], divergence_info, stats
    )