Esempio n. 1
0
def test_tree_max_product_chain():
    """AD3 with branch-and-bound must agree with max-product on a chain.

    Ten random problems (10 nodes, 3 states each) are solved on the same
    chain, once with its edges oriented forward and once backward.
    """
    rng = np.random.RandomState(0)
    sources = np.arange(9)
    targets = np.arange(1, 10)
    # Edge (k, k + 1) in the forward chain, (k + 1, k) in the backward one.
    forward = np.c_[sources, targets]
    backward = np.c_[targets, sources]
    for _ in xrange(10):
        unaries = rng.normal(size=(10, 3))
        pairwise = rng.normal(size=(9, 3, 3))
        for edges in (forward, backward):
            labels_ad3 = inference_ad3(unaries, pairwise, edges,
                                       branch_and_bound=True)
            labels_mp = inference_max_product(unaries, pairwise, edges)
            assert_array_equal(labels_ad3, labels_mp)
Esempio n. 2
0
def test_tree_max_product_chain():
    """Compare AD3 (branch-and-bound) and max-product on random chains.

    For each of ten random 10-node / 3-state problems, the chain is tried
    with both edge orientations and the two labelings must coincide.
    """
    random_state = np.random.RandomState(0)
    lower = np.arange(9)
    upper = np.arange(1, 10)
    orientations = [np.c_[lower, upper],   # forward: k -> k + 1
                    np.c_[upper, lower]]   # backward: k + 1 -> k
    for _ in xrange(10):
        unary = random_state.normal(size=(10, 3))
        pairwise = random_state.normal(size=(9, 3, 3))
        for chain_edges in orientations:
            from_ad3 = inference_ad3(unary,
                                     pairwise,
                                     chain_edges,
                                     branch_and_bound=True)
            from_mp = inference_max_product(unary,
                                            pairwise,
                                            chain_edges)
            assert_array_equal(from_ad3, from_mp)
Esempio n. 3
0
def test_tree_max_product_tree():
    """AD3 with branch-and-bound must agree with max-product on random trees.

    Each of 100 trees over 10 nodes is the minimum spanning tree of a random
    dense weight matrix; with 3 states per node both inference routines must
    return the same labeling.
    """
    try:
        from scipy.sparse.csgraph import minimum_spanning_tree
    except ImportError:
        # Only a missing/too-old scipy should turn into a skip; a bare
        # ``except`` would also hide unrelated errors (even KeyboardInterrupt).
        raise SkipTest("Not testing trees, scipy version >= 0.11 required")

    rnd = np.random.RandomState(0)
    for _ in xrange(100):
        # generate random tree using mst
        graph = rnd.uniform(size=(10, 10))
        tree = minimum_spanning_tree(sparse.csr_matrix(graph))
        tree_edges = np.c_[tree.nonzero()]

        unary_potentials = rnd.normal(size=(10, 3))
        # One 3x3 table per edge (9 for a spanning tree on 10 nodes); tying
        # it to the edge count keeps the shapes consistent.
        pairwise_potentials = rnd.normal(size=(len(tree_edges), 3, 3))
        result_ad3 = inference_ad3(unary_potentials, pairwise_potentials,
                                   tree_edges, branch_and_bound=True)
        result_mp = inference_max_product(unary_potentials,
                                          pairwise_potentials, tree_edges)
        assert_array_equal(result_ad3, result_mp)
Esempio n. 4
0
def test_tree_max_product_tree():
    """Compare AD3 (branch-and-bound) and max-product on random trees.

    100 random trees over 10 nodes are generated as minimum spanning trees of
    random dense weight matrices; with 3 states per node the labelings from
    both routines must be identical.
    """
    try:
        from scipy.sparse.csgraph import minimum_spanning_tree
    except ImportError:
        # Skip only when scipy's csgraph is unavailable; a bare ``except``
        # would silently skip on any other error as well.
        raise SkipTest("Not testing trees, scipy version >= 0.11 required")

    rnd = np.random.RandomState(0)
    for _ in xrange(100):
        # generate random tree using mst
        graph = rnd.uniform(size=(10, 10))
        tree = minimum_spanning_tree(sparse.csr_matrix(graph))
        tree_edges = np.c_[tree.nonzero()]

        unary_potentials = rnd.normal(size=(10, 3))
        # One 3x3 pairwise table per tree edge.
        pairwise_potentials = rnd.normal(size=(len(tree_edges), 3, 3))
        result_ad3 = inference_ad3(unary_potentials,
                                   pairwise_potentials,
                                   tree_edges,
                                   branch_and_bound=True)
        result_mp = inference_max_product(unary_potentials,
                                          pairwise_potentials, tree_edges)
        assert_array_equal(result_ad3, result_mp)
Esempio n. 5
0
 def batch_marginals(self, param_x, param_y, w, num_classes=None):
     """Run relaxed inference on every sample and collect results and marginals.

     Parameters
     ----------
     param_x : iterable
         Samples; each is passed to ``self._get_unary_potentials``,
         ``self._get_pairwise_potentials`` and ``self._get_edges``.
     param_y : ignored
         Unused; kept for interface compatibility.
     w : array-like
         Model parameters used to compute the potentials.
     num_classes : optional
         Forwarded to ``self.turn_dense_to_one_matrix``; only used by the
         ``"max-product"`` method.

     Returns
     -------
     ys : ndarray
         ``np.array`` of the per-sample inference results.
     marginals : ndarray
         ``np.array`` of the per-sample marginals.

     Raises
     ------
     ValueError
         If ``self.inference_method`` is neither ``"max-product"`` nor
         ``"ad3"``.
     """
     method = self.inference_method
     if method not in ("max-product", "ad3"):
         # Fail fast: an unknown method used to only print a message and then
         # crash with an UnboundLocalError on ``y``.
         raise ValueError("Unknown inference method: %r" % (method,))

     ys = []
     marginalss = []
     for x in param_x:
         # Same call order as before: unaries, pairwise, edges.
         unary = self._get_unary_potentials(x, w)
         pairwise = self._get_pairwise_potentials(x, w)
         edges = self._get_edges(x)
         if method == "max-product":
             y = inference_max_product(unary, pairwise, edges, relaxed=True)
             marginals = self.turn_dense_to_one_matrix(y, num_classes)
         else:  # "ad3"
             y, marginals = inference_ad3_local(
                 unary,
                 pairwise,
                 edges,
                 relaxed=True,
                 branch_and_bound=False,
                 return_marginals=True)
         ys.append(y)
         marginalss.append(marginals)
     return np.array(ys), np.array(marginalss)