예제 #1
0
    def test_input_flags(self):
        """Input reports is_op / is_var / is_param according to its node.

        An empty Input is falsy and reports none of the three kinds; an Input
        wrapping a Sum, ContVars or Weights node is truthy and reports exactly
        one of them.
        """
        blank = spn.Input()
        self.assertFalse(blank)
        for flag in (blank.is_op, blank.is_var, blank.is_param):
            self.assertFalse(flag)

        # node class -> expected truthiness of (is_op, is_var, is_param)
        expectations = (
            (spn.Sum, (True, False, False)),
            (spn.ContVars, (False, True, False)),
            (spn.Weights, (False, False, True)),
        )
        for node_class, expected in expectations:
            wrapped = spn.Input(node_class())
            self.assertTrue(wrapped)
            actual = (bool(wrapped.is_op),
                      bool(wrapped.is_var),
                      bool(wrapped.is_param))
            self.assertEqual(actual, expected)
예제 #2
0
    def test_generte_set_errors(self):
        """Detecting structure errors in __generate_set

        The input set below must be rejected with ``spn.StructureError``.
        """
        generator = spn.DenseSPNGenerator(num_decomps=2,
                                          num_subsets=3,
                                          num_mixtures=2)
        ivs = spn.IVs(num_vars=2, num_vals=4)
        cont1 = spn.ContVars(num_vars=3, name="ContVars1")
        cont2 = spn.ContVars(num_vars=2, name="ContVars2")
        # NOTE(review): this Sum mixes ContVars2 with ContVars1, and ContVars1
        # is also fed in directly and via the Concat. That overlap is
        # presumably what makes the structure invalid; confirm against
        # __generate_set.
        mixed_sum = spn.Sum(cont2, cont1)
        concat = spn.Concat(cont1)

        with self.assertRaises(spn.StructureError):
            inputs = [
                spn.Input(ivs, [0, 3, 2, 6, 7]),
                spn.Input(cont1, [1, 2]),
                spn.Input(mixed_sum, None),
                spn.Input(concat, None),
            ]
            generator._DenseSPNGenerator__generate_set(inputs)
예제 #3
0
    def test_generte_set(self):
        """Generation of sets of inputs with __generate_set

        Inputs are expected to be grouped by scope; each group is returned as
        a sorted tuple of (node, index) pairs.
        """
        generator = spn.DenseSPNGenerator(num_decomps=2,
                                          num_subsets=3,
                                          num_mixtures=2)
        ivs = spn.IVs(num_vars=2, num_vals=4)
        cont1 = spn.ContVars(num_vars=3, name="ContVars1")
        cont2 = spn.ContVars(num_vars=2, name="ContVars2")
        sum_node = spn.Sum(cont2)
        concat = spn.Concat(cont1)
        inputs = [
            spn.Input(ivs, [0, 3, 2, 6, 7]),
            spn.Input(cont1, [1, 2]),
            spn.Input(sum_node, None),
            spn.Input(concat, None),
        ]
        out = generator._DenseSPNGenerator__generate_set(inputs)

        # Expected grouping by scope (from a debug dump of scope_dict):
        #   IVs var 0            -> (IVs, 0), (IVs, 2), (IVs, 3)
        #   IVs var 1            -> (IVs, 6), (IVs, 7)
        #   ContVars1 var 1      -> (Concat, 1), (ContVars1, 1)
        #   ContVars1 var 2      -> (Concat, 2), (ContVars1, 2)
        #   ContVars1 var 0      -> (Concat, 0)
        #   ContVars2 vars 0, 1  -> (Sum, 0)
        expected_groups = [
            [(cont1, 1), (concat, 1)],
            [(cont1, 2), (concat, 2)],
            [(concat, 0)],
            [(ivs, 0), (ivs, 2), (ivs, 3)],
            [(ivs, 6), (ivs, 7)],
            [(sum_node, 0)],
        ]

        # Group order is not deterministic, so check membership only.
        self.assertEqual(len(out), len(expected_groups))
        for group in expected_groups:
            self.assertIn(tuple(sorted(group)), out)
예제 #4
0
 def test_generte_set(self):
     """Generation of sets of inputs with __generate_set

     Inputs are expected to be grouped by scope; each group is returned as a
     sorted tuple of (node, index) pairs.
     """
     generator = spn.DenseSPNGenerator(num_decomps=2,
                                       num_subsets=3,
                                       num_mixtures=2)
     ivs = spn.IVs(num_vars=2, num_vals=4)
     cont1 = spn.ContVars(num_vars=3, name="ContVars1")
     cont2 = spn.ContVars(num_vars=2, name="ContVars2")
     sum_node = spn.Sum(cont2)
     concat = spn.Concat(cont1)
     inputs = [
         spn.Input(ivs, [0, 3, 2, 6, 7]),
         spn.Input(cont1, [1, 2]),
         spn.Input(sum_node, None),
         spn.Input(concat, None),
     ]
     out = generator._DenseSPNGenerator__generate_set(inputs)

     expected_groups = [
         [(cont1, 1), (concat, 1)],
         [(cont1, 2), (concat, 2)],
         [(concat, 0)],
         [(ivs, 0), (ivs, 2), (ivs, 3)],
         [(ivs, 6), (ivs, 7)],
         [(sum_node, 0)],
     ]

     # Group order is not deterministic, so check membership only.
     self.assertEqual(len(out), len(expected_groups))
     for group in expected_groups:
         self.assertIn(tuple(sorted(group)), out)
예제 #5
0
    def test_input_conversion(self):
        """Conversion and verification of input specs in Input

        Exercises both the ``spn.Input(node, indices)`` constructor and the
        ``spn.Input.as_input(spec)`` converter for every accepted spec form.
        Then checks that malformed specs raise ``TypeError`` (wrong spec
        type) or ``ValueError`` (empty, invalid or duplicate indices).
        """
        v1 = spn.ContVars(num_vars=5)

        def check(inpt, node, indices):
            """Assert ``inpt`` wraps ``node`` with ``indices``.

            An Input is expected to be falsy exactly when it wraps no node.
            """
            self.assertIs(inpt.node, node)
            if indices is None:
                self.assertIs(inpt.indices, None)
            else:
                self.assertListEqual(inpt.indices, indices)
            if node is None:
                self.assertFalse(inpt)
            else:
                self.assertTrue(inpt)

        # None: an empty Input; indices are dropped when there is no node
        check(spn.Input(), None, None)
        check(spn.Input(None), None, None)
        check(spn.Input(None, [1, 2, 3]), None, None)
        check(spn.Input.as_input(None), None, None)
        check(spn.Input.as_input((None, [1, 2, 3])), None, None)
        # Node
        check(spn.Input(v1), v1, None)
        check(spn.Input.as_input(v1), v1, None)
        # (Node, None)
        check(spn.Input(v1, None), v1, None)
        check(spn.Input.as_input((v1, None)), v1, None)
        # (Node, index): a single int is normalized to a one-element list
        check(spn.Input(v1, 10), v1, [10])
        check(spn.Input.as_input((v1, 10)), v1, [10])
        # (Node, indices)
        check(spn.Input(v1, [10]), v1, [10])
        check(spn.Input.as_input((v1, [10])), v1, [10])
        check(spn.Input(v1, [10, 1, 20]), v1, [10, 1, 20])
        check(spn.Input.as_input((v1, [10, 1, 20])), v1, [10, 1, 20])

        # Checking type of input
        with self.assertRaises(TypeError):
            spn.Input(set())
        with self.assertRaises(TypeError):
            spn.Input.as_input(set())
        # Tuple wrapping an invalid node. The trailing comma matters:
        # ``(set())`` is just ``set()`` and would repeat the case above.
        with self.assertRaises(TypeError):
            spn.Input.as_input((set(),))
        with self.assertRaises(TypeError):
            spn.Input.as_input(tuple())
        with self.assertRaises(TypeError):
            spn.Input(v1, set())
        with self.assertRaises(TypeError):
            spn.Input.as_input((v1, set()))
        with self.assertRaises(TypeError):
            spn.Input.as_input((v1,))
        # Detecting empty list
        with self.assertRaises(ValueError):
            spn.Input(v1, [])
        with self.assertRaises(ValueError):
            spn.Input.as_input((v1, []))
        # Detecting incorrect indices (non-integer and negative)
        with self.assertRaises(ValueError):
            spn.Input(v1, [1, set(), 2])
        with self.assertRaises(ValueError):
            spn.Input.as_input((v1, [1, set(), 2]))
        with self.assertRaises(ValueError):
            spn.Input(v1, [1, -1, 2])
        with self.assertRaises(ValueError):
            spn.Input.as_input((v1, [1, -1, 2]))
        # Detecting duplicate indices
        with self.assertRaises(ValueError):
            spn.Input(v1, [0, 1, 3, 1])
        with self.assertRaises(ValueError):
            spn.Input.as_input((v1, [0, 1, 3, 1]))