class InferenceTests(unittest.TestCase):
    """Checks ``LinearMRF.inference_itr`` on a ``LinearMRF(1, 2)`` model.

    The fixtures below have two rows (presumably one per node) and two
    columns (presumably one per label) -- confirm against LinearMRF.
    """

    def setUp(self):
        # unittest runs setUp before every test method, so each test gets
        # its own model instance.
        self.model = LinearMRF(1, 2)

    def test_inf(self):
        """inference_itr moves both rows from column 1 to column 0."""
        # Starting beliefs select column 1 in both rows.
        start_beliefs = np.array([[0, 1], [0, 1]])
        # The largest unary and pairwise entries all sit at index 0, and the
        # expected beliefs select column 0 in both rows.
        unary_potentials = np.array([[1, 0], [1, 0]])
        pairwise_potentials = np.array([[2, 1, 1, 0]])
        expected = np.array([[1, 0], [1, 0]])

        actual = self.model.inference_itr(start_beliefs, unary_potentials,
                                          pairwise_potentials)

        # numpy documents the order as (actual, desired); passing them that
        # way keeps the labels in a failure report the right way round.
        np.testing.assert_array_equal(actual, expected)
Example #2
0
class InferenceTests(unittest.TestCase):
    """Tests ``LinearMRF.inference_itr`` and ``calculate_local_score``.

    Fixture layout, as far as these tests show (confirm against LinearMRF):
    ``unary_beliefs`` and ``unary_potentials`` have one row per node and one
    column per label; ``pairwise_potentials`` has one row of four values per
    edge, presumably a flattened 2x2 table over the labels of its endpoints.
    """

    def setUp(self):
        # unittest runs setUp before every test method, so each test gets
        # its own model instance.
        self.model = LinearMRF(1, 2)

    def test_inf(self):
        """inference_itr moves both rows from column 1 to column 0."""
        start_beliefs = np.array([[0, 1], [0, 1]])
        # The largest unary and pairwise entries all sit at index 0, and the
        # expected beliefs select column 0 in both rows.
        unary_potentials = np.array([[1, 0], [1, 0]])
        pairwise_potentials = np.array([[2, 1, 1, 0]])
        expected = np.array([[1, 0], [1, 0]])

        actual = self.model.inference_itr(start_beliefs, unary_potentials,
                                          pairwise_potentials)

        # numpy documents the order as (actual, desired); passing them that
        # way keeps the labels in a failure report the right way round.
        np.testing.assert_array_equal(actual, expected)

    def test_local_score_1x2_1(self):
        """Score node 1 under both assignments with mixed beliefs."""
        unary_beliefs = np.array([[1, 0], [0, 1]])
        unary_potentials = np.array([[1, 0], [1, 0]])
        pairwise_potentials = np.array([[2, 1, 1, 0]])

        # Node 1, assignment 0: consistent with unary 1 + pairwise 2.
        result = self.model.calculate_local_score(1, 0, unary_beliefs,
                                                  unary_potentials,
                                                  pairwise_potentials)
        self.assertEqual(3, result)

        # Node 1, assignment 1: consistent with unary 0 + pairwise 1.
        result = self.model.calculate_local_score(1, 1, unary_beliefs,
                                                  unary_potentials,
                                                  pairwise_potentials)
        self.assertEqual(1, result)

    def test_local_score_1x2_2(self):
        """Score node 0 under both assignments with a fractional pairwise."""
        unary_beliefs = np.array([[0, 1], [0, 1]])
        unary_potentials = np.array([[1, 0], [1, 0]])
        pairwise_potentials = np.array([[2, 1, 1, 0.5]])

        # Node 0, assignment 0: consistent with unary 1 + pairwise 1.
        result = self.model.calculate_local_score(0, 0, unary_beliefs,
                                                  unary_potentials,
                                                  pairwise_potentials)
        self.assertEqual(1 + 1, result)

        # Node 0, assignment 1: consistent with unary 0 + pairwise 0.5.
        # 0.5 is exactly representable, so exact equality is safe here.
        result = self.model.calculate_local_score(0, 1, unary_beliefs,
                                                  unary_potentials,
                                                  pairwise_potentials)
        self.assertEqual(0.5, result)

    def test_local_score(self):
        """Score node 1, assignment 1 on a ``LinearMRF(2, 2)`` model."""
        # Replaces the LinearMRF(1, 2) built in setUp, for this test only.
        self.model = LinearMRF(2, 2)
        node = 1
        assignment = 1
        unary_beliefs = np.array([[0, 1], [0, 1], [1, 0], [0, 1]])
        unary_potentials = np.array([[0.1, 0.8], [0.5, 0.71], [0.6, 0.2],
                                     [1, 0]])

        # Each case pairs a pairwise table with its expected score. Every
        # expected value is 0.71 (unary_potentials[1][1]) plus the last
        # entries of pairwise rows 0 and 2; rows 1 and 3 change between
        # cases without affecting it (presumably they are edges that do not
        # touch node 1 -- confirm against LinearMRF).
        cases = [
            (np.array([[1, 1, 1, 0.2], [0, 0, 0, 0],
                       [1, 1, 0.3, 0.5], [1, 1, 1.2, 1]]),
             0.71 + 0.2 + 0.5),
            (np.array([[1, 1, 1, 0.7], [0, 0, 0, 0],
                       [1, 1, 1, 7], [1, 1, 1, 2]]),
             0.71 + 0.7 + 7),
            (np.array([[1, 1, 1, 5], [2, 2, 2, 2],
                       [1, 1, 1, 7], [1, 1, 1, -1]]),
             0.71 + 5 + 7),
        ]

        for index, (pairwise_potentials, expected) in enumerate(cases):
            result = self.model.calculate_local_score(node, assignment,
                                                      unary_beliefs,
                                                      unary_potentials,
                                                      pairwise_potentials)
            # These sums are not exactly representable in binary floating
            # point, so a different summation order inside the model could
            # differ in the last bit; compare with a tolerance instead of
            # exact equality.
            self.assertAlmostEqual(expected, result,
                                   msg="pairwise case %d" % index)