def test_tree_max_product_chain():
    """AD3 (branch and bound) and max-product must agree on chain graphs.

    Ten random problems (10 nodes, 3 states) are each solved on a chain whose
    edges point forward (i -> i + 1) and on the same chain with every edge
    reversed (i + 1 -> i); both solvers must return identical labelings.
    """
    rng = np.random.RandomState(0)
    heads = np.arange(9)
    tails = np.arange(1, 10)
    # Same chain, once with edges oriented forward and once backward.
    chains = (np.c_[heads, tails], np.c_[tails, heads])
    for _ in xrange(10):
        unaries = rng.normal(size=(10, 3))
        pairwise = rng.normal(size=(9, 3, 3))
        for edges in chains:
            ad3_labels = inference_ad3(unaries, pairwise, edges,
                                       branch_and_bound=True)
            mp_labels = inference_max_product(unaries, pairwise, edges)
            assert_array_equal(ad3_labels, mp_labels)
def test_tree_max_product_tree():
    """AD3 (branch and bound) and max-product must agree on random trees.

    One hundred random trees on 10 nodes are generated as minimum spanning
    trees of dense random weight matrices; for each, random unary (10 x 3)
    and pairwise (one 3 x 3 table per edge) potentials are drawn and both
    solvers must return identical labelings.

    Skipped when ``scipy.sparse.csgraph.minimum_spanning_tree`` cannot be
    imported.
    """
    try:
        from scipy.sparse.csgraph import minimum_spanning_tree
    except ImportError:
        # Only a missing import means "old scipy"; a bare except would also
        # turn KeyboardInterrupt/SystemExit or real errors into a skip.
        raise SkipTest("Not testing trees, scipy version >= 0.11 required")
    rnd = np.random.RandomState(0)
    for _ in xrange(100):
        # generate random tree using mst
        graph = rnd.uniform(size=(10, 10))
        tree = minimum_spanning_tree(sparse.csr_matrix(graph))
        tree_edges = np.c_[tree.nonzero()]
        unary_potentials = rnd.normal(size=(10, 3))
        # One pairwise table per tree edge (9 for a connected 10-node graph);
        # sized from the edges so a degenerate draw cannot mismatch shapes.
        pairwise_potentials = rnd.normal(size=(len(tree_edges), 3, 3))
        result_ad3 = inference_ad3(unary_potentials, pairwise_potentials,
                                   tree_edges, branch_and_bound=True)
        result_mp = inference_max_product(unary_potentials,
                                          pairwise_potentials, tree_edges)
        assert_array_equal(result_ad3, result_mp)
def batch_marginals(self, param_x, param_y, w, num_classes=None):
    """Run relaxed inference on every sample and collect labels and marginals.

    Parameters
    ----------
    param_x : iterable
        Input samples; each one is passed (with ``w``) to
        ``self._get_unary_potentials`` and ``self._get_pairwise_potentials``,
        and on its own to ``self._get_edges``.
    param_y : object
        Not used by this method.
    w : object
        Parameter vector forwarded to the potential getters.
    num_classes : int or None
        Forwarded to ``self.turn_dense_to_one_matrix`` on the
        ``"max-product"`` path only.

    Returns
    -------
    ys, marginals : numpy.ndarray
        ``np.array`` of the per-sample inference results and ``np.array`` of
        the per-sample marginals, both in the order of ``param_x``.

    Raises
    ------
    ValueError
        If ``self.inference_method`` is neither ``"max-product"`` nor
        ``"ad3"``.
    """
    method = self.inference_method
    # Validate before looping: an unsupported method would otherwise leave
    # ``y``/``marginals`` unassigned and fail with an unrelated error.
    if method not in ("max-product", "ad3"):
        raise ValueError("Unknown inference method: %r" % (method,))

    ys = []
    marginalss = []
    for x in param_x:
        unary_potentials = self._get_unary_potentials(x, w)
        pairwise_potentials = self._get_pairwise_potentials(x, w)
        edges = self._get_edges(x)
        if method == "max-product":
            y = inference_max_product(unary_potentials, pairwise_potentials,
                                      edges, relaxed=True)
            # Max-product yields no marginals directly; derive them from y.
            marginals = self.turn_dense_to_one_matrix(y, num_classes)
        else:  # "ad3"
            y, marginals = inference_ad3_local(
                unary_potentials, pairwise_potentials, edges,
                relaxed=True, branch_and_bound=False, return_marginals=True)
        ys.append(y)
        marginalss.append(marginals)
    # NOTE(review): if samples produce results of different shapes, np.array
    # builds an object array (or raises on recent NumPy) — confirm callers
    # always pass equally-sized graphs.
    return np.array(ys), np.array(marginalss)