Example #1
0
 def test_is_valid_false(self):
     """Check that is_valid() flags every node that depends on product P2.

     P2 multiplies V12 indices 2 and 3 (both belonging to V12's first
     variable) with index 4. Every node expected to be invalid below has P2
     somewhere beneath it (directly or through N2); no node expected to be
     valid does.
     """
     # Leaves
     v12 = spn.IVs(num_vars=2, num_vals=4, name="V12")
     v34 = spn.ContVars(num_vars=2, name="V34")
     # First layer over the leaves
     s1 = spn.Sum((v12, [0, 1, 2, 3]), name="S1")
     s2 = spn.Sum((v12, [4, 5, 6, 7]), name="S2")
     p1 = spn.Product((v12, [0, 7]), name="P1")
     p2 = spn.Product((v12, [2, 3, 4]), name="P2")
     p3 = spn.Product(v34, name="P3")
     # Concatenations and the layers built on top of them
     n1 = spn.Concat(s1, s2, p3, name="N1")
     n2 = spn.Concat(p1, p2, name="N2")
     p4 = spn.Product((n1, [0]), (n1, [1]), name="P4")
     p5 = spn.Product((n2, [0]), (n1, [2]), name="P5")
     s3 = spn.Sum(p4, n2, name="S3")
     p6 = spn.Product(s3, (n1, [2]), name="P6")
     s4 = spn.Sum(p5, p6, name="S4")

     expected_valid = (v12, v34, s1, s2, p1, p3, p4, n1)
     expected_invalid = (p2, n2, s3, s4, p5, p6)
     for node in expected_valid:
         self.assertTrue(node.is_valid())
     for node in expected_invalid:
         self.assertFalse(node.is_valid())
Example #2
0
    def test_get_out_size(self):
        """Check get_out_size() for leaves, sums and nested concatenations."""
        # Three independent leaves, five outputs each
        v1, v2, v3 = (spn.ContVars(num_vars=5) for _ in range(3))
        s1 = spn.Sum((v1, [1, 3]), (v1, [1, 4]), v2)  # v1 used twice
        s2 = spn.Sum(v1, (v3, [0, 1, 2, 3, 4]))
        s3 = spn.Sum(v2, v3, v3)  # v3 used twice
        n4 = spn.Concat(s1, v1)  # 1 + 5 = 6 outputs
        n5 = spn.Concat((v3, [0, 4]), s3)  # 2 + 1 = 3 outputs
        # 6 + 1 + 3 + 1 + 1 = 12 outputs; n4 and n5 each appear twice
        n6 = spn.Concat(n4, s2, n5, (n4, [0]), (n5, [1]))

        expected_sizes = [
            (v1, 5), (v2, 5), (v3, 5),
            (s1, 1), (s2, 1), (s3, 1),
            (n4, 6), (n5, 3), (n6, 12),
        ]
        for node, size in expected_sizes:
            self.assertEqual(node.get_out_size(), size)
Example #3
0
 def test_get_scope(self):
     """Check get_scope() for every node of a small hand-built SPN.

     Expected scopes are written with ``union``, which ORs together fresh
     ``spn.Scope(node, var)`` objects for the given ``(node, var)`` pairs.
     """
     def union(*pairs):
         # Fold left to right, exactly like chaining ``a | b | c``
         first, *rest = pairs
         merged = spn.Scope(*first)
         for node, var in rest:
             merged = merged | spn.Scope(node, var)
         return merged

     # Create graph
     v12 = spn.IVs(num_vars=2, num_vals=4, name="V12")
     v34 = spn.ContVars(num_vars=2, name="V34")
     s1 = spn.Sum((v12, [0, 1, 2, 3]), name="S1")
     s2 = spn.Sum((v12, [4, 5, 6, 7]), name="S2")
     p1 = spn.Product((v12, [0, 7]), name="P1")
     p2 = spn.Product((v12, [3, 4]), name="P2")
     p3 = spn.Product(v34, name="P3")
     n1 = spn.Concat(s1, s2, p3, name="N1")
     n2 = spn.Concat(p1, p2, name="N2")
     p4 = spn.Product((n1, [0]), (n1, [1]), name="P4")
     p5 = spn.Product((n2, [0]), (n1, [2]), name="P5")
     s3 = spn.Sum(p4, n2, name="S3")
     p6 = spn.Product(s3, (n1, [2]), name="P6")
     s4 = spn.Sum(p5, p6, name="S4")

     v12_both = ((v12, 0), (v12, 1))
     v34_both = ((v34, 0), (v34, 1))
     everything = v12_both + v34_both
     expectations = [
         # V12 indices 0-3 belong to variable 0, indices 4-7 to variable 1
         (v12, [union((v12, index // 4)) for index in range(8)]),
         (v34, [union((v34, 0)), union((v34, 1))]),
         (s1, [union((v12, 0))]),
         (s2, [union((v12, 1))]),
         (p1, [union(*v12_both)]),
         (p2, [union(*v12_both)]),
         (p3, [union(*v34_both)]),
         (n1, [union((v12, 0)), union((v12, 1)), union(*v34_both)]),
         (n2, [union(*v12_both), union(*v12_both)]),
         (p4, [union(*v12_both)]),
         (p5, [union(*everything)]),
         (s3, [union(*v12_both)]),
         (p6, [union(*everything)]),
         (s4, [union(*everything)]),
     ]
     for node, expected in expectations:
         self.assertListEqual(node.get_scope(), expected)
Example #4
0
 def test(inpt, feed, true_output):
     """Gather ``inpt`` through a one-input Concat and compare to ``true_output``."""
     with self.subTest(inputs=inpt, feed=feed):
         concat = spn.Concat(inpt)
         source_value = concat.inputs[0].node.get_value()
         (gathered,) = concat._gather_input_tensors(source_value)
         with self.test_session() as session:
             actual = session.run(gathered, feed_dict=feed)
         expected = np.array(true_output)
         np.testing.assert_array_equal(actual, expected)
 def test_compute_mpe_path(self):
     """Check that Concat routes MPE-path counts back onto its inputs.

     The Concat output has six columns, taken in order from V12 indices
     [0, 5], V34, V12 index [3] and V5. Each count column must land at the
     matching position of the input it came from.
     """
     ivs = spn.IndicatorLeaf(num_vars=2, num_vals=4)
     raw_pair = spn.RawLeaf(num_vars=2)
     raw_single = spn.RawLeaf(num_vars=1)
     concat = spn.Concat((ivs, [0, 5]), raw_pair, (ivs, [3]), raw_single)
     counts = tf.placeholder(tf.float32, shape=(None, 6))
     counts_in = tf.identity(counts)
     # One value tensor per Concat input, in input order
     input_values = (ivs.get_value(), raw_pair.get_value(),
                     ivs.get_value(), raw_single.get_value())
     op = concat._compute_log_mpe_path(counts_in, *input_values)
     feed = np.r_[:18].reshape(-1, 6)
     with self.test_session() as sess:
         out = sess.run(op, feed_dict={counts: feed})

     # Expected per-input counts, built directly from the fed columns
     cols = feed.astype(np.float32)
     expected_ivs_first = np.zeros((3, 8), dtype=np.float32)
     expected_ivs_first[:, [0, 5]] = cols[:, 0:2]  # columns 0-1 -> V12 [0, 5]
     expected_ivs_second = np.zeros((3, 8), dtype=np.float32)
     expected_ivs_second[:, 3] = cols[:, 4]  # column 4 -> V12 [3]
     expected = (expected_ivs_first, cols[:, 2:4], expected_ivs_second,
                 cols[:, 5:6])
     for position, want in enumerate(expected):
         np.testing.assert_array_almost_equal(out[position], want)
Example #6
0
 def concat_layer_and_test(inputs, name):
     """Create a Concat over ``inputs`` and check its scope against a reference.

     The reference scope is assembled from ``scopes_per_node`` (the scopes
     already recorded for each input node). The new node's scope is then
     recorded there too, so later layers can build on it.

     Args:
         inputs: Concat inputs. Each is either a node, or a tuple whose first
             item is a node and whose second item is an int index, a list of
             indices, or None (meaning every output of the node).
         name: Name given to the created Concat node.

     Returns:
         The created Concat node.
     """
     scope = []
     for inp in inputs:
         if isinstance(inp, tuple):
             node, indices = inp[0], inp[1]
             node_scopes = scopes_per_node[node]
             if indices is None:
                 # None indices select every output of the node; iterating
                 # over None here previously raised TypeError.
                 scope.extend(node_scopes)
             elif isinstance(indices, int):
                 scope.append(node_scopes[indices])
             else:
                 scope.extend(node_scopes[i] for i in indices)
         else:
             scope.extend(scopes_per_node[inp])
     concat = spn.Concat(*inputs, name=name)
     self.assertListEqual(concat.get_scope(), scope)
     scopes_per_node[concat] = scope
     return concat
Example #7
0
    def test_generte_set_errors(self):
        """Expect __generate_set to raise StructureError for this input set.

        The only difference from the passing ``test_generte_set`` inputs is
        that the Sum here takes both ContVars2 and ContVars1.
        """
        generator = spn.DenseSPNGenerator(num_decomps=2, num_subsets=3,
                                          num_mixtures=2)
        ivs = spn.IVs(num_vars=2, num_vals=4)
        cont1 = spn.ContVars(num_vars=3, name="ContVars1")
        cont2 = spn.ContVars(num_vars=2, name="ContVars2")
        mixed_sum = spn.Sum(cont2, cont1)
        concat = spn.Concat(cont1)

        with self.assertRaises(spn.StructureError):
            inputs = [
                spn.Input(ivs, [0, 3, 2, 6, 7]),
                spn.Input(cont1, [1, 2]),
                spn.Input(mixed_sum, None),
                spn.Input(concat, None),
            ]
            generator._DenseSPNGenerator__generate_set(inputs)
Example #8
0
    def test_generte_set(self):
        """Check how __generate_set groups inputs that share a scope."""
        generator = spn.DenseSPNGenerator(num_decomps=2, num_subsets=3,
                                          num_mixtures=2)
        ivs = spn.IVs(num_vars=2, num_vals=4)
        cont1 = spn.ContVars(num_vars=3, name="ContVars1")
        cont2 = spn.ContVars(num_vars=2, name="ContVars2")
        cont2_sum = spn.Sum(cont2)
        cont1_concat = spn.Concat(cont1)
        out = generator._DenseSPNGenerator__generate_set([
            spn.Input(ivs, [0, 3, 2, 6, 7]),
            spn.Input(cont1, [1, 2]),
            spn.Input(cont2_sum, None),
            spn.Input(cont1_concat, None),
        ])

        # Expected grouping, one group per scope:
        #   IVs variable 0          -> IVs indices 0, 2, 3
        #   IVs variable 1          -> IVs indices 6, 7
        #   ContVars1 variable 0    -> Concat index 0 only
        #   ContVars1 variable 1    -> Concat index 1 and ContVars1 index 1
        #   ContVars1 variable 2    -> Concat index 2 and ContVars1 index 2
        #   ContVars2 variables 0+1 -> Sum index 0
        # Group order is not specified, so only membership is checked.
        self.assertEqual(len(out), 6)
        expected_groups = [
            [(cont1, 1), (cont1_concat, 1)],
            [(cont1, 2), (cont1_concat, 2)],
            [(cont1_concat, 0)],
            [(ivs, 0), (ivs, 2), (ivs, 3)],
            [(ivs, 6), (ivs, 7)],
            [(cont2_sum, 0)],
        ]
        for group in expected_groups:
            self.assertIn(tuple(sorted(group)), out)
 def test(inputs, feed, output):
     """Check a Concat's value and log-value, marginal and MPE, against ``output``."""
     with self.subTest(inputs=inputs, feed=feed):
         concat = spn.Concat(*inputs)
         # (tensor, is_log) pairs: marginal, log-marginal, MPE, log-MPE
         graphs = []
         for inference in (spn.InferenceType.MARGINAL,
                           spn.InferenceType.MPE):
             graphs.append((concat.get_value(inference), False))
             graphs.append((concat.get_log_value(inference), True))
         with self.test_session() as sess:
             results = [
                 sess.run(tf.exp(tensor) if is_log else tensor,
                          feed_dict=feed)
                 for tensor, is_log in graphs
             ]
         expected = np.array(output, dtype=spn.conf.dtype.as_numpy_dtype())
         for result in results:
             np.testing.assert_array_almost_equal(result, expected)
Example #10
0
 def test_generte_set(self):
     """Check that __generate_set builds one input group per distinct scope."""
     dense_gen = spn.DenseSPNGenerator(num_decomps=2, num_subsets=3,
                                       num_mixtures=2)
     v1 = spn.IVs(num_vars=2, num_vals=4)
     v2 = spn.ContVars(num_vars=3, name="ContVars1")
     v3 = spn.ContVars(num_vars=2, name="ContVars2")
     s1 = spn.Sum(v3)
     n1 = spn.Concat(v2)
     generator_inputs = [
         spn.Input(v1, [0, 3, 2, 6, 7]),
         spn.Input(v2, [1, 2]),
         spn.Input(s1, None),
         spn.Input(n1, None),
     ]
     out = dense_gen._DenseSPNGenerator__generate_set(generator_inputs)

     # Groups may come back in any order, so check membership only
     self.assertEqual(len(out), 6)
     expected = (
         ((v2, 1), (n1, 1)),
         ((v2, 2), (n1, 2)),
         ((n1, 0),),
         ((v1, 0), (v1, 2), (v1, 3)),
         ((v1, 6), (v1, 7)),
         ((s1, 0),),
     )
     for pairs in expected:
         self.assertIn(tuple(sorted(pairs)), out)
Example #11
0
    def test_gather_input_tensors(self):
        """Check Concat._gather_input_tensors for every supported index form."""
        def test(inpt, feed, true_output):
            # Build a one-input Concat, gather its input and compare
            with self.subTest(inputs=inpt, feed=feed):
                concat = spn.Concat(inpt)
                source = concat.inputs[0].node.get_value()
                (gathered,) = concat._gather_input_tensors(source)
                with self.test_session() as session:
                    actual = session.run(gathered, feed_dict=feed)
                np.testing.assert_array_equal(actual, np.array(true_output))

        v1 = spn.ContVars(num_vars=3)
        v2 = spn.ContVars(num_vars=1)
        v1_two_rows = [[1, 2, 3], [4, 5, 6]]
        v1_one_row = [[1, 2, 3]]
        v2_two_rows = [[1], [4]]
        v2_one_row = [[1]]

        # Disconnected input: nothing to gather
        concat = spn.Concat(None)
        (gathered,) = concat._gather_input_tensors(3)
        self.assertIs(gathered, None)

        # None input tensor: nothing to gather
        concat = spn.Concat((v1, 1))
        (gathered,) = concat._gather_input_tensors(None)
        self.assertIs(gathered, None)

        # Gathering for explicitly listed indices
        listed_index_cases = [
            ((v1, [0, 2, 1]), {v1: v1_two_rows},
             [[1.0, 3.0, 2.0], [4.0, 6.0, 5.0]]),
            ((v1, [0, 2]), {v1: v1_two_rows}, [[1.0, 3.0], [4.0, 6.0]]),
            ((v1, [1]), {v1: v1_two_rows}, [[2.0], [5.0]]),
            ((v1, [0, 2, 1]), {v1: v1_one_row}, [[1.0, 3.0, 2.0]]),
            ((v1, [0, 2]), {v1: v1_one_row}, [[1.0, 3.0]]),
            ((v1, [1]), {v1: v1_one_row}, [[2.0]]),
        ]
        for case in listed_index_cases:
            test(*case)

        # With None indices the input tensor is passed through unchanged
        concat = spn.Concat(v1)
        t = tf.constant([1, 2, 3])
        (gathered,) = concat._gather_input_tensors(t)
        self.assertIs(gathered, t)

        full_v1_two = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
        full_v1_one = [[1.0, 2.0, 3.0]]
        remaining_cases = [
            # None indices
            (v1, {v1: v1_two_rows}, full_v1_two),
            ((v1, None), {v1: v1_two_rows}, full_v1_two),
            (v1, {v1: v1_one_row}, full_v1_one),
            ((v1, None), {v1: v1_one_row}, full_v1_one),
            # Single index, given as an int and as a one-element list
            ((v1, 1), {v1: v1_two_rows}, [[2.0], [5.0]]),
            ((v1, [1]), {v1: v1_two_rows}, [[2.0], [5.0]]),
            ((v1, 1), {v1: v1_one_row}, [[2.0]]),
            ((v1, [1]), {v1: v1_one_row}, [[2.0]]),
            # One-element input, index specified
            ((v2, 0), {v2: v2_two_rows}, [[1.0], [4.0]]),
            ((v2, [0]), {v2: v2_two_rows}, [[1.0], [4.0]]),
            ((v2, 0), {v2: v2_one_row}, [[1.0]]),
            ((v2, [0]), {v2: v2_one_row}, [[1.0]]),
            # One-element input, None indices
            (v2, {v2: v2_two_rows}, [[1.0], [4.0]]),
            ((v2, None), {v2: v2_two_rows}, [[1.0], [4.0]]),
            (v2, {v2: v2_one_row}, [[1.0]]),
            ((v2, None), {v2: v2_one_row}, [[1.0]]),
        ]
        for case in remaining_cases:
            test(*case)
 def test_comput_scope(self):
     """Check get_scope() of Product nodes and the Sum/Concat nodes around them.

     S1 and S4 get latent indicators, so their scopes (and the scopes of
     every node above S1) also include the latent-indicator node's variable.
     """
     # Create a graph
     v12 = spn.IndicatorLeaf(num_vars=2, num_vals=4, name="V12")
     v34 = spn.RawLeaf(num_vars=2, name="V34")
     s1 = spn.Sum((v12, [0, 1, 2, 3]), name="S1")
     s1.generate_latent_indicators()
     s2 = spn.Sum((v12, [4, 5, 6, 7]), name="S2")
     p1 = spn.Product((v12, [0, 7]), name="P1")
     # Named "P2" (was a copy-pasted "P1" that clashed with p1's name)
     p2 = spn.Product((v12, [3, 4]), name="P2")
     p3 = spn.Product(v34, name="P3")
     n1 = spn.Concat(s1, s2, p3, name="N1")
     n2 = spn.Concat(p1, p2, name="N2")
     p4 = spn.Product((n1, [0]), (n1, [1]), name="P4")
     p5 = spn.Product((n2, [0]), (n1, [2]), name="P5")
     s3 = spn.Sum(p4, n2, name="S3")
     p6 = spn.Product(s3, (n1, [2]), name="P6")
     s4 = spn.Sum(p5, p6, name="S4")
     s4.generate_latent_indicators()
     # Test
     self.assertListEqual(v12.get_scope(),
                          [spn.Scope(v12, 0), spn.Scope(v12, 0),
                           spn.Scope(v12, 0), spn.Scope(v12, 0),
                           spn.Scope(v12, 1), spn.Scope(v12, 1),
                           spn.Scope(v12, 1), spn.Scope(v12, 1)])
     self.assertListEqual(v34.get_scope(),
                          [spn.Scope(v34, 0), spn.Scope(v34, 1)])
     self.assertListEqual(s1.get_scope(),
                          [spn.Scope(v12, 0) | spn.Scope(s1.latent_indicators.node, 0)])
     self.assertListEqual(s2.get_scope(),
                          [spn.Scope(v12, 1)])
     self.assertListEqual(p1.get_scope(),
                          [spn.Scope(v12, 0) | spn.Scope(v12, 1)])
     self.assertListEqual(p2.get_scope(),
                          [spn.Scope(v12, 0) | spn.Scope(v12, 1)])
     self.assertListEqual(p3.get_scope(),
                          [spn.Scope(v34, 0) | spn.Scope(v34, 1)])
     self.assertListEqual(n1.get_scope(),
                          [spn.Scope(v12, 0) | spn.Scope(s1.latent_indicators.node, 0),
                           spn.Scope(v12, 1),
                           spn.Scope(v34, 0) | spn.Scope(v34, 1)])
     self.assertListEqual(n2.get_scope(),
                          [spn.Scope(v12, 0) | spn.Scope(v12, 1),
                           spn.Scope(v12, 0) | spn.Scope(v12, 1)])
     self.assertListEqual(p4.get_scope(),
                          [spn.Scope(v12, 0) | spn.Scope(v12, 1) |
                           spn.Scope(s1.latent_indicators.node, 0)])
     self.assertListEqual(p5.get_scope(),
                          [spn.Scope(v12, 0) | spn.Scope(v12, 1) |
                           spn.Scope(v34, 0) | spn.Scope(v34, 1)])
     self.assertListEqual(s3.get_scope(),
                          [spn.Scope(v12, 0) | spn.Scope(v12, 1) |
                           spn.Scope(s1.latent_indicators.node, 0)])
     self.assertListEqual(p6.get_scope(),
                          [spn.Scope(v12, 0) | spn.Scope(v12, 1) |
                           spn.Scope(v34, 0) | spn.Scope(v34, 1) |
                           spn.Scope(s1.latent_indicators.node, 0)])
     self.assertListEqual(s4.get_scope(),
                          [spn.Scope(v12, 0) | spn.Scope(v12, 1) |
                           spn.Scope(v34, 0) | spn.Scope(v34, 1) |
                           spn.Scope(s1.latent_indicators.node, 0) |
                           spn.Scope(s4.latent_indicators.node, 0)])
 def test_comput_scope(self):
     """Check get_scope() of PermuteProducts nodes and the nodes below them.

     Expected scopes are written with ``scope_of``, which ORs together fresh
     ``spn.Scope(node, var)`` objects for every ``(node, var)`` pair in the
     given groups, in order.
     """
     def scope_of(*groups):
         pairs = [pair for group in groups for pair in group]
         merged = spn.Scope(*pairs[0])
         for node, var in pairs[1:]:
             merged = merged | spn.Scope(node, var)
         return merged

     # Create graph
     v12 = spn.IndicatorLeaf(num_vars=2, num_vals=4, name="V12")
     v34 = spn.RawLeaf(num_vars=2, name="V34")
     s1 = spn.Sum((v12, [0, 1, 2, 3]), name="S1")
     s1.generate_latent_indicators()
     s2 = spn.Sum((v12, [4, 5, 6, 7]), name="S2")
     p1 = spn.Product((v12, [0, 7]), name="P1")
     p2 = spn.Product((v12, [3, 4]), name="P2")
     p3 = spn.Product(v34, name="P3")
     n1 = spn.Concat(s1, s2, p3, name="N1")
     n2 = spn.Concat(p1, p2, name="N2")
     pp1 = spn.PermuteProducts(n1, n2, name="PP1")  # 3 x 2 = 6 products
     pp2 = spn.PermuteProducts((n1, [0, 1]), (n2, [0]), name="PP2")  # 2
     pp3 = spn.PermuteProducts(n2, p3, name="PP3")  # 2 products
     pp4 = spn.PermuteProducts(p2, p3, name="PP4")  # 1 product
     pp5 = spn.PermuteProducts((n2, [0, 1]), name="PP5")  # 1 product
     pp6 = spn.PermuteProducts(p3, name="PP6")  # 1 product
     n3 = spn.Concat((pp1, [0, 2, 3]), pp2, pp4, name="N3")
     s3 = spn.Sum((pp1, [0, 2, 4]), (pp1, [1, 3, 5]), pp2, pp3, (pp4, 0),
                  pp5, pp6, name="S3")
     s3.generate_latent_indicators()
     n4 = spn.Concat((pp3, [0, 1]), pp5, (pp6, 0), name="N4")
     pp7 = spn.PermuteProducts(n3, s3, n4, name="PP7")  # 6 x 1 x 4 = 24
     pp8 = spn.PermuteProducts(n3, name="PP8")  # 1 product
     pp9 = spn.PermuteProducts((n4, [0, 1, 2, 3]), name="PP9")  # 1 product

     v12_var0 = ((v12, 0),)
     v12_var1 = ((v12, 1),)
     v12_all = v12_var0 + v12_var1
     v34_all = ((v34, 0), (v34, 1))
     s1_latent = ((s1.latent_indicators.node, 0),)
     s3_latent = ((s3.latent_indicators.node, 0),)

     expectations = [
         (v12, [scope_of(v12_var0) for _ in range(4)] +
               [scope_of(v12_var1) for _ in range(4)]),
         (v34, [scope_of(((v34, 0),)), scope_of(((v34, 1),))]),
         (s1, [scope_of(v12_var0, s1_latent)]),
         (s2, [scope_of(v12_var1)]),
         (p1, [scope_of(v12_all)]),
         (p2, [scope_of(v12_all)]),
         (p3, [scope_of(v34_all)]),
         (n1, [scope_of(v12_var0, s1_latent), scope_of(v12_var1),
               scope_of(v34_all)]),
         (n2, [scope_of(v12_all), scope_of(v12_all)]),
         # Each output of N1 paired with each output of N2
         (pp1, [scope_of(v12_all, s1_latent), scope_of(v12_all, s1_latent),
                scope_of(v12_all), scope_of(v12_all),
                scope_of(v12_all, v34_all), scope_of(v12_all, v34_all)]),
         (pp2, [scope_of(v12_all, s1_latent), scope_of(v12_all)]),
         (pp3, [scope_of(v12_all, v34_all), scope_of(v12_all, v34_all)]),
         (pp4, [scope_of(v12_all, v34_all)]),
         (pp5, [scope_of(v12_all)]),
         (pp6, [scope_of(v34_all)]),
         (n3, [scope_of(v12_all, s1_latent), scope_of(v12_all),
               scope_of(v12_all), scope_of(v12_all, s1_latent),
               scope_of(v12_all), scope_of(v12_all, v34_all)]),
         (s3, [scope_of(v12_all, v34_all, s1_latent, s3_latent)]),
         (n4, [scope_of(v12_all, v34_all), scope_of(v12_all, v34_all),
               scope_of(v12_all), scope_of(v34_all)]),
         (pp7, [scope_of(v12_all, v34_all, s1_latent, s3_latent)] * 24),
         (pp8, [scope_of(v12_all, s1_latent, v34_all)]),
         (pp9, [scope_of(v12_all, v34_all)]),
     ]
     for node, expected in expectations:
         self.assertListEqual(node.get_scope(), expected)
Example #14
0
 def test_comput_scope(self):
     """Check get_scope() of ProductsLayer nodes and the nodes around them.

     Expected scopes are written with ``scope_of``, which ORs together fresh
     ``spn.Scope(node, var)`` objects for every ``(node, var)`` pair in the
     given groups, in order.
     """
     def scope_of(*groups):
         pairs = [pair for group in groups for pair in group]
         merged = spn.Scope(*pairs[0])
         for node, var in pairs[1:]:
             merged = merged | spn.Scope(node, var)
         return merged

     # Create graph
     v12 = spn.IndicatorLeaf(num_vars=2, num_vals=4, name="V12")
     v34 = spn.RawLeaf(num_vars=2, name="V34")
     s1 = spn.Sum((v12, [0, 1, 2, 3]), name="S1")
     s1.generate_latent_indicators()
     s2 = spn.Sum((v12, [4, 5, 6, 7]), name="S2")
     pl1 = spn.ProductsLayer((v12, [0, 5, 6, 7]), (v12, [3, 4]), v34,
                             num_or_size_prods=[4, 3, 1], name="PL1")
     n1 = spn.Concat(s1, s2, (pl1, [2]), name="N1")
     n2 = spn.Concat((pl1, [0]), (pl1, [1]), name="N2")
     s3 = spn.Sum(pl1, name="S3")
     s3.generate_latent_indicators()
     pl2 = spn.ProductsLayer((n1, [0, 1]), (n1, 2), (n2, 0), (pl1, [1]),
                             n2, s3, (n2, 1), s3, pl1,
                             num_or_size_prods=[2, 3, 3, 5], name="PL2")
     s4 = spn.Sum((pl2, 0), n2, name="S4")
     s5 = spn.Sum(pl2, name="S5")
     s6 = spn.Sum((pl2, [1, 3]), name="S6")
     s6.generate_latent_indicators()
     pl3 = spn.ProductsLayer(s4, (n1, 2), num_or_size_prods=1, name="PL3")
     pl4 = spn.ProductsLayer(s4, s5, s6, s4, s5, s6,
                             num_or_size_prods=2, name="PL4")

     v12_var0 = ((v12, 0),)
     v12_var1 = ((v12, 1),)
     v34_var0 = ((v34, 0),)
     v34_var1 = ((v34, 1),)
     v12_all = v12_var0 + v12_var1
     v34_all = v34_var0 + v34_var1
     s1_latent = ((s1.latent_indicators.node, 0),)
     s3_latent = ((s3.latent_indicators.node, 0),)
     s6_latent = ((s6.latent_indicators.node, 0),)
     # PL1 product 0 multiplies V12 indices 0, 5, 6, 7
     pl1_first = v12_all + v12_var1 + v12_var1
     # PL1 product 1 multiplies V12 indices 3, 4 and V34 index 0
     pl1_second = v12_all + v34_var0

     expectations = [
         (v12, [scope_of(v12_var0) for _ in range(4)] +
               [scope_of(v12_var1) for _ in range(4)]),
         (v34, [scope_of(v34_var0), scope_of(v34_var1)]),
         (s1, [scope_of(v12_var0, s1_latent)]),
         (s2, [scope_of(v12_var1)]),
         (pl1, [scope_of(pl1_first), scope_of(pl1_second),
                scope_of(v34_var1)]),
         (n1, [scope_of(v12_var0, s1_latent), scope_of(v12_var1),
               scope_of(v34_var1)]),
         (n2, [scope_of(pl1_first), scope_of(pl1_second)]),
         (s3, [scope_of(v12_all, v34_all, s3_latent)]),
         (pl2, [scope_of(v12_all, s1_latent),
                scope_of(v12_all, v34_all),
                scope_of(v12_all, v34_all, s3_latent),
                scope_of(v12_all, v34_all, s3_latent)]),
         (s4, [scope_of(v12_all, s1_latent, v34_var0)]),
         (s5, [scope_of(v12_all, s1_latent, v34_all, s3_latent)]),
         (s6, [scope_of(v12_all, v34_all, s3_latent, s6_latent)]),
         (pl3, [scope_of(v12_all, s1_latent, v34_all)]),
         (pl4, [scope_of(v12_all, v34_all, s1_latent, s3_latent, s6_latent)
                for _ in range(2)]),
     ]
     for node, expected in expectations:
         self.assertListEqual(node.get_scope(), expected)