def test_iterative_max_product_chain():
    """Check iterative_max_product against AD3 on random chain graphs.

    Builds a 10-node chain (edges 0-1, 1-2, ..., 8-9) with 3 states per
    node. Ten times it draws random unary and pairwise potentials and
    asserts that iterative_max_product returns the same labeling as
    inference_ad3 run with branch_and_bound=True.
    """
    rnd = np.random.RandomState(0)
    n_nodes, n_states = 10, 3
    # Edge list of a simple chain: rows are (i, i + 1) for i in 0..n_nodes-2.
    chain = np.c_[np.arange(n_nodes - 1), np.arange(1, n_nodes)]
    # `range` instead of `xrange` so the test runs on Python 2 and Python 3.
    for _ in range(10):
        unary_potentials = rnd.normal(size=(n_nodes, n_states))
        # One (n_states x n_states) potential table per edge; sized from the
        # edge list so the two cannot get out of sync.
        pairwise_potentials = rnd.normal(
            size=(len(chain), n_states, n_states))
        # NOTE(review): branch_and_bound=True presumably forces AD3 to an
        # exact integral solution, making it a reference result -- confirm.
        result_ad3 = inference_ad3(unary_potentials, pairwise_potentials,
                                   chain, branch_and_bound=True)
        result_mp = iterative_max_product(unary_potentials,
                                          pairwise_potentials, chain)
        assert_array_equal(result_ad3, result_mp)
def test_iterative_max_product_tree():
    """Check iterative_max_product against AD3 on random tree graphs.

    The test is skipped when scipy.sparse.csgraph.minimum_spanning_tree
    cannot be imported. Otherwise, 100 times it builds a random tree over
    10 nodes as the minimum spanning tree of a random dense weight matrix,
    draws random unary and pairwise potentials (3 states per node), and
    asserts that iterative_max_product matches inference_ad3 run with
    branch_and_bound=True.
    """
    try:
        from scipy.sparse.csgraph import minimum_spanning_tree
    except ImportError:
        # Only a missing/too-old scipy should turn into a skip; a bare
        # `except:` would also hide unrelated errors (e.g. KeyboardInterrupt).
        raise SkipTest("Not testing trees, scipy version >= 0.11 required")

    rnd = np.random.RandomState(0)
    n_nodes, n_states = 10, 3
    # `range` instead of `xrange` so the test runs on Python 2 and Python 3.
    for _ in range(100):
        # generate random tree using mst
        graph = rnd.uniform(size=(n_nodes, n_nodes))
        tree = minimum_spanning_tree(sparse.csr_matrix(graph))
        # (n_edges, 2) array of (row, col) index pairs of the tree's edges.
        tree_edges = np.c_[tree.nonzero()]
        unary_potentials = rnd.normal(size=(n_nodes, n_states))
        # One potential table per edge actually present in the tree, rather
        # than a hard-coded edge count.
        pairwise_potentials = rnd.normal(
            size=(len(tree_edges), n_states, n_states))
        # NOTE(review): branch_and_bound=True presumably forces AD3 to an
        # exact integral solution, making it a reference result -- confirm.
        result_ad3 = inference_ad3(unary_potentials, pairwise_potentials,
                                   tree_edges, branch_and_bound=True)
        result_mp = iterative_max_product(unary_potentials,
                                          pairwise_potentials, tree_edges)
        assert_array_equal(result_ad3, result_mp)