Ejemplo n.º 1
0
    def test_traversing_on_dense(self):
        """Compare traversal algs on dense SPN.

        Runs ``compute_graph_up_down``, ``compute_graph_up`` and
        ``traverse_graph`` over the same densely generated SPN and checks
        that each of them invokes its callback the same number of times.
        """
        visits = [0]

        def count_up(node, *args):
            # Upward / plain-traversal callback: only records the visit.
            visits[0] += 1

        def count_down(node, *args):
            # Downward callback: records the visit and, for operation nodes,
            # hands one (dummy) value to each input so traversal continues.
            visits[0] += 1
            if node.is_op:
                return [None] * len(node.inputs)

        # Build a dense SPN over two indicator leaves
        indicators_a = spn.IndicatorLeaf(num_vars=3, num_vals=2,
                                         name="IndicatorLeaf1")
        indicators_b = spn.IndicatorLeaf(num_vars=3, num_vals=2,
                                         name="IndicatorLeaf2")
        dense_gen = spn.DenseSPNGenerator(
            num_decomps=2,
            num_subsets=3,
            num_mixtures=2,
            input_dist=spn.DenseSPNGenerator.InputDist.MIXTURE,
            num_input_mixtures=None)
        root = dense_gen.generate(indicators_a, indicators_b)
        spn.generate_weights(root)

        def visits_during(traversal, **options):
            # Reset the shared counter, run one traversal, report the count.
            visits[0] = 0
            traversal(root, **options)
            return visits[0]

        up_down_visits = visits_during(spn.compute_graph_up_down,
                                       down_fun=count_down, graph_input=1)
        up_visits = visits_during(spn.compute_graph_up, val_fun=count_up)
        traverse_visits = visits_during(spn.traverse_graph, fun=count_up,
                                        skip_params=False)

        # Every algorithm must have touched the same set of nodes
        self.assertEqual(up_down_visits, traverse_visits)
        self.assertEqual(up_visits, traverse_visits)
Ejemplo n.º 2
0
    def test_compute_graph_down(self):
        """Check values propagated by ``spn.compute_graph_up_down``.

        A recording ``down_fun`` stores, for every node, the list of values
        it received from its parents, and emits one value per input of an
        operation node (``sum(parent_vals) + 0.01 + input_index``). The
        expected numbers below are derived by hand from that rule on a small
        graph in which some children are connected to the same parent more
        than once. A second run checks that a lone leaf node works and that a
        callable ``graph_input`` is invoked.
        """
        counter = [0]
        parent_vals_saved = {}

        def fun(node, parent_vals):
            # Looks up ``parent_vals_saved`` at call time, so the rebinding
            # of that name further below is seen by the second run.
            parent_vals_saved[node] = parent_vals
            val = sum(parent_vals) + 0.01
            counter[0] += 1
            if node.is_op:
                # One value per input; the input index makes every emitted
                # value distinct so the assertions can tell which input
                # position each child was reached through.
                return [val + i for i, _ in enumerate(node.inputs)]
            else:
                return 101

        # Generate graph
        v1 = spn.RawLeaf(num_vars=1, name="v1")
        v2 = spn.RawLeaf(num_vars=1, name="v2")
        v3 = spn.RawLeaf(num_vars=1, name="v3")
        s1 = spn.Sum(v1, v1, v2, name="s1")  # v1 included twice
        s2 = spn.Sum(v1, v3, name="s2")
        s3 = spn.Sum(v2, v3, v3, name="s3")  # v3 included twice
        s4 = spn.Sum(s1, v1, name="s4")
        s5 = spn.Sum(s2, v3, s3, name="s5")
        s6 = spn.Sum(s4, s2, s5, s4, s5, name="s6")  # s4 and s5 included twice
        spn.generate_weights(s6)

        down_values = {}
        spn.compute_graph_up_down(s6, down_fun=fun, graph_input=5,
                                  down_values=down_values)

        # 6 sums + 3 leaves = 9; the remaining 6 visits are presumably the
        # weight nodes added by generate_weights (one per Sum) — confirm.
        self.assertEqual(counter[0], 15)
        # Using sorted since order is not guaranteed
        # Each Sum emits len(inputs) values. Going by the assertions below,
        # the value children start at index 2; indices 0 and 1 go to other
        # inputs (presumably the weights node and the IVs slot — TODO confirm
        # against spn.Sum.inputs).
        # s6: parent 5 -> 5.01; 7 inputs -> 5.01..11.01.
        # s4 gets indices 2, 5; s2 gets index 3; s5 gets indices 4, 6.
        self.assertListAlmostEqual(sorted(parent_vals_saved[s6]), [5])
        self.assertListAlmostEqual(down_values[s6], [5.01, 6.01, 7.01, 8.01,
                                                     9.01, 10.01, 11.01])
        # s5: 9.01 + 11.01 + 0.01 = 20.03; s2, v3, s3 get 22.03, 23.03, 24.03.
        self.assertListAlmostEqual(sorted(parent_vals_saved[s5]), [9.01, 11.01])
        self.assertListAlmostEqual(down_values[s5], [20.03, 21.03, 22.03,
                                                     23.03, 24.03])
        # s4: 7.01 + 10.01 + 0.01 = 17.03; s1 gets 19.03 and v1 gets 20.03.
        self.assertListAlmostEqual(sorted(parent_vals_saved[s4]), [7.01, 10.01])
        self.assertListAlmostEqual(down_values[s4], [17.03, 18.03, 19.03, 20.03])
        # s3: 24.03 + 0.01 = 24.04; v2 gets 26.04 and v3 gets 27.04, 28.04.
        self.assertListAlmostEqual(sorted(parent_vals_saved[s3]), [24.03])
        self.assertListAlmostEqual(down_values[s3], [24.04, 25.04, 26.04,
                                                     27.04, 28.04])
        # s2: 8.01 + 22.03 + 0.01 = 30.05; v1 gets 32.05 and v3 gets 33.05.
        self.assertListAlmostEqual(sorted(parent_vals_saved[s2]), [8.01, 22.03])
        self.assertListAlmostEqual(down_values[s2], [30.05, 31.05, 32.05, 33.05])
        # s1: 19.03 + 0.01 = 19.04; v1 gets 21.04, 22.04 and v2 gets 23.04.
        self.assertListAlmostEqual(sorted(parent_vals_saved[s1]), [19.03])
        self.assertListAlmostEqual(down_values[s1], [19.04, 20.04, 21.04,
                                                     22.04, 23.04])

        # Leaves collect one parent value per connection (duplicated
        # connections included) and return the constant 101 from ``fun``.
        self.assertListAlmostEqual(sorted(parent_vals_saved[v1]),
                                   [20.03, 21.04, 22.04, 32.05])
        self.assertEqual(down_values[v1], 101)
        self.assertListAlmostEqual(sorted(parent_vals_saved[v2]),
                                   [23.04, 26.04])
        self.assertEqual(down_values[v2], 101)
        self.assertListAlmostEqual(sorted(parent_vals_saved[v3]),
                                   [23.03, 27.04, 28.04, 33.05])
        self.assertEqual(down_values[v3], 101)

        # Test if the algorithm works on a VarNode and calls graph_input function
        # ``counter`` is not reset here; the second run's visit count is not
        # asserted.
        down_values = {}
        parent_vals_saved = {}
        spn.compute_graph_up_down(v1, down_fun=fun, graph_input=lambda: 5,
                                  down_values=down_values)
        # 5 can only reach v1 if the lambda passed as graph_input was called.
        self.assertEqual(parent_vals_saved[v1][0], 5)
        self.assertEqual(down_values[v1], 101)