"""Fit a Dirichlet diffusion tree to synthetic Gaussian data with MCMC and plot it.

Draws points from a standard multivariate normal, runs a Metropolis-Hastings
sampler over tree structures, then shows three figures: the sampler's recorded
``likelihoods``, the sampled tree, and the tree drawn over the data.
"""
from trees.util import plot_tree, plot_tree_2d
from trees.ddt import DirichletDiffusionTree, Inverse, GaussianLikelihoodModel
from trees.mcmc import MetropolisHastingsSampler
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt


def main(dims=2, num_points=100, num_iterations=1000):
    """Run the sampler on synthetic data and display diagnostic plots.

    Args:
        dims: dimensionality of each synthetic data point.
        num_points: number of data points to draw.
        num_iterations: number of ``mh.sample()`` calls to make.
    """
    # num_points i.i.d. draws from N(0, I) in `dims` dimensions, cast to float32.
    X = np.random.multivariate_normal(
        mean=np.zeros(dims), cov=np.eye(dims), size=num_points
    ).astype(np.float32)

    # NOTE(review): `Inverse(c=1)` is presumably the DDT divergence function;
    # confirm its meaning in trees.ddt.
    df = Inverse(c=1)
    lm = GaussianLikelihoodModel(
        sigma=np.eye(dims) / 4.0,
        mu0=np.zeros(dims),
        sigma0=np.eye(dims),
    )
    ddt = DirichletDiffusionTree(df=df, likelihood_model=lm)

    mh = MetropolisHastingsSampler(ddt, X)
    mh.initialize_assignments()
    # `range` rather than the Python-2-only `xrange`, which is a NameError on Python 3.
    for _ in tqdm(range(num_iterations)):
        mh.sample()

    plt.figure()
    plt.plot(mh.likelihoods)  # trace of the likelihoods the sampler recorded
    plt.figure()
    plot_tree(mh.tree)
    plt.figure()
    # NOTE(review): a 2-D scatter/tree overlay presumably assumes dims == 2.
    plot_tree_2d(mh.tree, X)
    plt.show()


if __name__ == "__main__":
    main()
"""Visual check of detaching and re-attaching nodes in a ``trees.Tree``.

Builds a tree whose root has two children: an inner node holding leaves 1 and 2,
and leaf 3. It is plotted three times: as built, after ``leaf1.detach()``, and
after passing the detached result to ``leaf3.attach``.
"""
from trees.util import plot_tree
from trees import Tree, TreeNode, TreeLeaf
import matplotlib.pyplot as plt


def _show(tree):
    """Draw ``tree`` with ``plot_tree`` and display it via ``plt.show()``."""
    plot_tree(tree)
    plt.show()


if __name__ == "__main__":
    first, second, third = TreeLeaf(1), TreeLeaf(2), TreeLeaf(3)

    # Inner node grouping leaves 1 and 2 (children added in this order).
    inner = TreeNode()
    for leaf in (first, second):
        inner.add_child(leaf)

    # Root: the inner node, then leaf 3.
    root = TreeNode()
    for child in (inner, third):
        root.add_child(child)

    tree = Tree(root)
    _show(tree)

    # NOTE(review): what `detach()` returns (and what `attach` expects) is
    # defined in trees; it is only threaded through here. Confirm semantics.
    detached = first.detach()
    _show(tree)

    third.attach(detached)
    _show(tree)