Пример #1
0
    def test_compute_graph_up_noconst(self):
        """Evaluate a graph bottom-up without constant-function detection.

        Builds a small graph of ``spn.Sum`` nodes over three single-variable
        ``spn.RawLeaf`` nodes (several inputs are deliberately repeated),
        generates weights for it, and then checks both the value returned by
        ``spn.compute_graph_up`` and how many times the value function ran.
        """
        # One-element list so the nested function can bump the tally in
        # place rather than rebinding a name in its own scope
        calls = [0]

        def val_fun(node, *inputs):
            calls[0] += 1
            # VarNode instances evaluate to 1, ParamNode instances to 0.1
            if isinstance(node, spn.graph.node.VarNode):
                return 1
            if isinstance(node, spn.graph.node.ParamNode):
                return 0.1
            # Any other node: the first two inputs are unpacked on their own
            # (named weight and IV values here; the IV value is unused) and
            # the remaining inputs are summed
            weight_val, _iv_val, *child_vals = inputs
            return weight_val + sum(child_vals) + 1

        # Build the graph
        leaf1 = spn.RawLeaf(num_vars=1)
        leaf2 = spn.RawLeaf(num_vars=1)
        leaf3 = spn.RawLeaf(num_vars=1)
        sum1 = spn.Sum(leaf1, leaf1, leaf2)  # leaf1 appears twice
        sum2 = spn.Sum(leaf1, leaf3)
        sum3 = spn.Sum(leaf2, leaf3, leaf3)  # leaf3 appears twice
        sum4 = spn.Sum(sum1, leaf1)
        sum5 = spn.Sum(sum2, leaf3, sum3)
        # sum4 and sum5 each appear twice in the root
        root = spn.Sum(sum4, sum2, sum5, sum4, sum5)
        spn.generate_weights(root)

        # Evaluate from the root
        result = spn.compute_graph_up(root, val_fun)

        # Verify the value and the number of val_fun invocations
        # NOTE(review): 15 is consistent with 3 leaves + 6 sums + 6 weight
        # nodes each being visited once - confirm against generate_weights
        self.assertAlmostEqual(result, 35.2)
        self.assertEqual(calls[0], 15)
Пример #2
0
    def test_traversing_on_dense(self):
        """Check that three traversal algorithms call back equally often.

        Generates a dense SPN over two indicator leaves, then runs
        ``spn.compute_graph_up_down``, ``spn.compute_graph_up`` and
        ``spn.traverse_graph`` on it, counting callback invocations for each
        run and asserting that all three counts agree.
        """
        # Shared tally mutated by both callbacks; reset in place between runs
        visits = [0]

        def count_only(node, *args):
            visits[0] += 1

        def count_and_fan_out(node, *args):
            visits[0] += 1
            # Op nodes return one None per input; other nodes return None
            return [None] * len(node.inputs) if node.is_op else None

        # Build the dense graph
        leaf_a = spn.IndicatorLeaf(num_vars=3, num_vals=2, name="IndicatorLeaf1")
        leaf_b = spn.IndicatorLeaf(num_vars=3, num_vals=2, name="IndicatorLeaf2")

        generator = spn.DenseSPNGenerator(
            num_decomps=2,
            num_subsets=3,
            num_mixtures=2,
            input_dist=spn.DenseSPNGenerator.InputDist.MIXTURE,
            num_input_mixtures=None)
        root = generator.generate(leaf_a, leaf_b)
        spn.generate_weights(root)

        # Run each traversal and record how many callbacks it triggered
        visits[0] = 0
        spn.compute_graph_up_down(root, down_fun=count_and_fan_out,
                                  graph_input=1)
        up_down_count = visits[0]

        visits[0] = 0
        spn.compute_graph_up(root, val_fun=count_only)
        up_count = visits[0]

        visits[0] = 0
        spn.traverse_graph(root, fun=count_only, skip_params=False)
        traverse_count = visits[0]

        # All three counts must match
        self.assertEqual(up_down_count, traverse_count)
        self.assertEqual(up_count, traverse_count)
Пример #3
0
    def test_compute_graph_up_const(self):
        """Evaluate a graph bottom-up with constant-function detection.

        Builds a graph of ``spn.Sum`` nodes over three ``spn.RawLeaf`` nodes
        and passes a ``const_fun`` to ``spn.compute_graph_up`` that flags one
        inner sum as constant. The value function asserts that the node
        reachable only through that constant sum is never evaluated, and the
        test checks the returned value and the number of invocations.
        """
        # One-element list so the nested function can bump the tally in
        # place rather than rebinding a name in its own scope
        calls = [0]

        # Build the graph
        leaf1 = spn.RawLeaf(num_vars=1)
        leaf2 = spn.RawLeaf(num_vars=1)
        leaf3 = spn.RawLeaf(num_vars=1)
        sum1 = spn.Sum(leaf1, leaf1, leaf2)  # leaf1 appears twice
        sum2 = spn.Sum(leaf1, leaf3)
        sum3 = spn.Sum(leaf2, leaf3, leaf3)  # leaf3 appears twice
        sum4 = spn.Sum(sum1, leaf1)
        sum5 = spn.Sum(sum2, leaf3, sum3)
        # sum4 and sum5 each appear twice in the root
        root = spn.Sum(sum4, sum2, sum5, sum4, sum5)

        def val_fun(node, *inputs):
            calls[0] += 1
            # sum3 must never be evaluated: its only parent is sum5, which
            # is reported as constant
            self.assertIsNot(node, sum3)
            # The constant node has a fixed value
            if node == sum5:
                return 16
            if isinstance(node, spn.graph.node.VarNode):
                return 1
            # Any other node: skip the first two inputs (named weight and IV
            # values here) and sum the rest
            _weight_val, _iv_val, *child_vals = inputs
            return sum(child_vals) + 1

        def const_fun(node):
            # Only sum5 is treated as a constant function
            return bool(node == sum5)

        # Evaluate from the root
        result = spn.compute_graph_up(root, val_fun, const_fun)

        # Verify the value and the number of val_fun invocations
        self.assertEqual(result, 48)
        self.assertEqual(calls[0], 8)