예제 #1
0
파일: nuts.py 프로젝트: gdupret/pymc3
    def extend(self, direction):
        """Double the tree size by extending the tree in the given direction.

        If `direction` is larger than 0, extend the tree to the right,
        otherwise extend it to the left. The new subtree is built with
        `self._build_subtree` at the current depth, using `self.step_size`
        (negated when extending to the left).

        Side effects: the outer edge (`self.right` or `self.left`),
        `self.depth` and `self.n_proposals` are always updated. The proposal,
        `self.log_size`, `self.log_weighted_accept_sum` and `self.p_sum`
        (updated in place) are only merged when the new subtree neither
        diverged nor turned.

        Return a tuple `(diverging, turning)` of type (DivergenceInfo, bool).
        `diverging` indicates that the tree extension was aborted because
        the energy change exceeded `self.Emax`. `turning` indicates that
        the tree extension was stopped because the termination criterion
        was reached (the trajectory is turning back).
        """
        if direction > 0:
            tree, diverging, turning = self._build_subtree(
                self.right, self.depth, floatX(np.asarray(self.step_size))
            )
            leftmost_begin, leftmost_end = self.left, self.right
            rightmost_begin, rightmost_end = tree.left, tree.right
            # Snapshot the pre-merge sum: self.p_sum is updated in place
            # further down, and an alias would silently pick up the merged
            # total before the sub-trajectory checks use it.
            leftmost_p_sum = self.p_sum.copy()
            rightmost_p_sum = tree.p_sum
            self.right = tree.right
        else:
            tree, diverging, turning = self._build_subtree(
                self.left, self.depth, floatX(np.asarray(-self.step_size))
            )
            # The subtree was built leftwards, so its `right` edge is the
            # one farthest to the left of the whole trajectory.
            leftmost_begin, leftmost_end = tree.right, tree.left
            rightmost_begin, rightmost_end = self.left, self.right
            leftmost_p_sum = tree.p_sum
            # Same snapshot as above, for the half that is the old tree.
            rightmost_p_sum = self.p_sum.copy()
            self.left = tree.right

        self.depth += 1
        self.n_proposals += tree.n_proposals

        # A diverging or turning subtree is not merged into the tree.
        if diverging or turning:
            return diverging, turning

        # Switch to the new subtree's proposal based on the log-size
        # difference between the new subtree and the old tree.
        size1, size2 = self.log_size, tree.log_size
        if logbern(size2 - size1):
            self.proposal = tree.proposal

        self.log_size = np.logaddexp(self.log_size, tree.log_size)
        self.log_weighted_accept_sum = np.logaddexp(
            self.log_weighted_accept_sum, tree.log_weighted_accept_sum
        )
        # In-place update: keeps the identity of the self.p_sum array.
        self.p_sum[:] += tree.p_sum

        # Additional turning check only when tree depth > 0 to avoid redundant work
        if self.depth > 0:
            left, right = self.left, self.right
            p_sum = self.p_sum
            # Check across the full, merged trajectory.
            turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0)
            # Check across the left half plus the first state of the right half.
            p_sum1 = leftmost_p_sum + rightmost_begin.p
            turning1 = (p_sum1.dot(leftmost_begin.v) <= 0) or (p_sum1.dot(rightmost_begin.v) <= 0)
            # Check across the last state of the left half plus the right half.
            p_sum2 = leftmost_end.p + rightmost_p_sum
            turning2 = (p_sum2.dot(leftmost_end.v) <= 0) or (p_sum2.dot(rightmost_end.v) <= 0)
            turning = turning | turning1 | turning2

        return diverging, turning
예제 #2
0
    def _build_subtree(self, left, depth, epsilon):
        """Recursively build a subtree of the given depth starting at `left`.

        A depth of 0 delegates to `self._single_step(left, epsilon)`.
        Otherwise two subtrees of depth `depth - 1` are built one after the
        other, the second starting from the right edge of the first, and are
        combined into a single `Subtree`.

        Return a tuple `(tree, diverging, turning)`. If the first half
        already diverges or turns, it is returned as is and the second half
        is never built. If the second half diverges or turns, the returned
        tree spans both halves and counts the proposals of both, but keeps
        the momentum sum, proposal, log size and weighted accept sum of the
        first half only.
        """
        if depth == 0:
            return self._single_step(left, epsilon)

        sub_depth = depth - 1

        first, diverging, turning = self._build_subtree(left, sub_depth, epsilon)
        if diverging or turning:
            return first, diverging, turning

        second, diverging, turning = self._build_subtree(
            first.right, sub_depth, epsilon)

        edge_l = first.left
        edge_r = second.right

        # Second half failed: keep the statistics of the first half only.
        if diverging or turning:
            merged = Subtree(edge_l, edge_r, first.p_sum, first.proposal,
                             first.log_size, first.log_weighted_accept_sum,
                             first.n_proposals + second.n_proposals)
            return merged, diverging, turning

        # U-turn check across the whole combined subtree.
        momentum_total = first.p_sum + second.p_sum
        turning = (momentum_total.dot(edge_l.v) <= 0) or (
            momentum_total.dot(edge_r.v) <= 0)

        # The two extra checks are skipped when both halves are single
        # steps (depth == 1) to avoid redundant work.
        if sub_depth > 0:
            # First half plus the first state of the second half.
            span_a = first.p_sum + second.left.p.data
            turning_a = (span_a.dot(first.left.v) <= 0) or (
                span_a.dot(second.left.v) <= 0)
            # Last state of the first half plus the second half.
            span_b = first.right.p.data + second.p_sum
            turning_b = (span_b.dot(first.right.v) <= 0) or (
                span_b.dot(second.right.v) <= 0)
            turning = turning | turning_a | turning_b

        combined_log_size = np.logaddexp(first.log_size, second.log_size)
        combined_accept_sum = np.logaddexp(
            first.log_weighted_accept_sum, second.log_weighted_accept_sum)

        # Pick the proposal of the second half based on its log size
        # relative to the combined log size.
        chosen = (
            second.proposal
            if logbern(second.log_size - combined_log_size)
            else first.proposal
        )

        merged = Subtree(edge_l, edge_r, momentum_total, chosen,
                         combined_log_size, combined_accept_sum,
                         first.n_proposals + second.n_proposals)
        return merged, diverging, turning
예제 #3
0
    def _hamiltonian_step(self, start, p0, step_size):
        """Grow a trajectory tree from `start` and return the step data.

        While `self.tune` is set and `self.iter_count` is below 200, the
        number of tree doublings is capped by `self.early_max_treedepth`;
        otherwise `self.max_treedepth` applies. The tree is extended in a
        direction of +1 or -1 drawn via `logbern(log(0.5))` (presumably a
        fair coin flip -- confirm against `logbern`) until an extension
        reports a divergence or a turn, or the cap is exhausted. If the cap
        is exhausted while not tuning, `self._reached_max_treedepth` is
        incremented.

        Only the length of `p0` is used here; it is passed to `_Tree`.

        Return an `HMCStepData` built from the tree's proposal, the
        "mean_tree_accept" entry of the tree stats, the divergence info of
        the last extension (None if the depth cap is 0 and no extension
        ran) and the stats themselves.
        """
        if self.tune and self.iter_count < 200:
            max_treedepth = self.early_max_treedepth
        else:
            max_treedepth = self.max_treedepth

        tree = _Tree(len(p0), self.integrator, start, step_size, self.Emax)

        # Bound up front so the return below is well defined even when the
        # loop body never runs (depth cap of 0).
        divergence_info = None
        # Loop-invariant log-probability for the direction draw.
        log_half = np.log(0.5)

        for _ in range(max_treedepth):
            # Map the 0/1 draw onto -1 / +1.
            direction = logbern(log_half) * 2 - 1
            divergence_info, turning = tree.extend(direction)

            if divergence_info or turning:
                break
        else:
            # No break: every allowed doubling was used.
            if not self.tune:
                self._reached_max_treedepth += 1

        stats = tree.stats()
        accept_stat = stats["mean_tree_accept"]
        return HMCStepData(tree.proposal, accept_stat, divergence_info, stats)