def test_is_valid_false(self):
    """Checking validity of the SPN.

    Builds a small graph in which ``P2`` multiplies two values of the
    first variable of ``V12`` (indices 2 and 3) together with a value of
    the second one. ``P2`` and every node whose input graph contains it
    (through ``N2``) are expected to be invalid; all other nodes valid.
    """
    # Leaves
    v12 = spn.IVs(num_vars=2, num_vals=4, name="V12")
    v34 = spn.ContVars(num_vars=2, name="V34")
    # First layer
    s1 = spn.Sum((v12, [0, 1, 2, 3]), name="S1")
    s2 = spn.Sum((v12, [4, 5, 6, 7]), name="S2")
    p1 = spn.Product((v12, [0, 7]), name="P1")
    p2 = spn.Product((v12, [2, 3, 4]), name="P2")
    p3 = spn.Product(v34, name="P3")
    # Concatenations of the first layer
    n1 = spn.Concat(s1, s2, p3, name="N1")
    n2 = spn.Concat(p1, p2, name="N2")
    # Upper layers
    p4 = spn.Product((n1, [0]), (n1, [1]), name="P4")
    p5 = spn.Product((n2, [0]), (n1, [2]), name="P5")
    s3 = spn.Sum(p4, n2, name="S3")
    p6 = spn.Product(s3, (n1, [2]), name="P6")
    s4 = spn.Sum(p5, p6, name="S4")

    # Checked in this exact order: all valid nodes first, then invalid ones
    expected_valid = (v12, v34, s1, s2, p1, p3, p4, n1)
    expected_invalid = (p2, n2, s3, s4, p5, p6)
    for node in expected_valid:
        self.assertTrue(node.is_valid())
    for node in expected_invalid:
        self.assertFalse(node.is_valid())
def test_get_out_size(self):
    """Computing the sizes of the outputs of nodes in SPN graph.

    Each Sum below is expected to have a single output. Each Concat is
    expected to have as many outputs as all of its (possibly repeated)
    inputs combined.
    """
    # Three independent sets of 5 continuous variables
    v1 = spn.ContVars(num_vars=5)
    v2 = spn.ContVars(num_vars=5)
    v3 = spn.ContVars(num_vars=5)
    s1 = spn.Sum((v1, [1, 3]), (v1, [1, 4]), v2)  # v1 included twice
    s2 = spn.Sum(v1, (v3, [0, 1, 2, 3, 4]))
    s3 = spn.Sum(v2, v3, v3)  # v3 included twice
    n4 = spn.Concat(s1, v1)  # 1 + 5
    n5 = spn.Concat((v3, [0, 4]), s3)  # 2 + 1
    # 6 + 1 + 3 + 1 + 1; n4 and n5 included twice
    n6 = spn.Concat(n4, s2, n5, (n4, [0]), (n5, [1]))

    expected_sizes = [
        (v1, 5), (v2, 5), (v3, 5),
        (s1, 1), (s2, 1), (s3, 1),
        (n4, 6), (n5, 3), (n6, 12),
    ]
    for node, size in expected_sizes:
        self.assertEqual(node.get_out_size(), size)
def test_get_scope(self):
    """Computing the scope of nodes of the SPN graph.

    ``V12`` has 2 variables with 4 values each, so its outputs 0-3 belong
    to variable 0 and outputs 4-7 to variable 1. ``V34`` has 2 variables.
    Expected scopes are written as unions of single-variable scopes.
    """
    def union(*pairs):
        # Left-to-right union of freshly built single-variable scopes:
        # union((a, 0), (b, 1)) is spn.Scope(a, 0) | spn.Scope(b, 1)
        merged = spn.Scope(*pairs[0])
        for node, var_id in pairs[1:]:
            merged = merged | spn.Scope(node, var_id)
        return merged

    # Create graph
    v12 = spn.IVs(num_vars=2, num_vals=4, name="V12")
    v34 = spn.ContVars(num_vars=2, name="V34")
    s1 = spn.Sum((v12, [0, 1, 2, 3]), name="S1")
    s2 = spn.Sum((v12, [4, 5, 6, 7]), name="S2")
    p1 = spn.Product((v12, [0, 7]), name="P1")
    p2 = spn.Product((v12, [3, 4]), name="P2")
    p3 = spn.Product(v34, name="P3")
    n1 = spn.Concat(s1, s2, p3, name="N1")
    n2 = spn.Concat(p1, p2, name="N2")
    p4 = spn.Product((n1, [0]), (n1, [1]), name="P4")
    p5 = spn.Product((n2, [0]), (n1, [2]), name="P5")
    s3 = spn.Sum(p4, n2, name="S3")
    p6 = spn.Product(s3, (n1, [2]), name="P6")
    s4 = spn.Sum(p5, p6, name="S4")

    # All four variables of the graph, in the order used by the unions
    every_var = ((v12, 0), (v12, 1), (v34, 0), (v34, 1))

    # Test
    self.assertListEqual(v12.get_scope(),
                         [spn.Scope(v12, out // 4) for out in range(8)])
    self.assertListEqual(v34.get_scope(),
                         [spn.Scope(v34, var) for var in range(2)])
    self.assertListEqual(s1.get_scope(), [union((v12, 0))])
    self.assertListEqual(s2.get_scope(), [union((v12, 1))])
    self.assertListEqual(p1.get_scope(), [union((v12, 0), (v12, 1))])
    self.assertListEqual(p2.get_scope(), [union((v12, 0), (v12, 1))])
    self.assertListEqual(p3.get_scope(), [union((v34, 0), (v34, 1))])
    self.assertListEqual(n1.get_scope(), [union((v12, 0)),
                                          union((v12, 1)),
                                          union((v34, 0), (v34, 1))])
    self.assertListEqual(n2.get_scope(), [union((v12, 0), (v12, 1)),
                                          union((v12, 0), (v12, 1))])
    self.assertListEqual(p4.get_scope(), [union((v12, 0), (v12, 1))])
    self.assertListEqual(p5.get_scope(), [union(*every_var)])
    self.assertListEqual(s3.get_scope(), [union((v12, 0), (v12, 1))])
    self.assertListEqual(p6.get_scope(), [union(*every_var)])
    self.assertListEqual(s4.get_scope(), [union(*every_var)])
def test(inpt, feed, true_output):
    """Gather ``inpt`` through a one-input Concat and compare the result.

    Uses ``self`` from the enclosing test method's scope. The value of the
    Concat's only input node is gathered via ``_gather_input_tensors``,
    evaluated with ``feed`` and compared exactly against ``true_output``.
    """
    with self.subTest(inputs=inpt, feed=feed):
        concat = spn.Concat(inpt)
        source_value = concat.inputs[0].node.get_value()
        gathered, = concat._gather_input_tensors(source_value)
        with self.test_session() as sess:
            result = sess.run(gathered, feed_dict=feed)
        expected = np.array(true_output)
        np.testing.assert_array_equal(result, expected)
def test_compute_mpe_path(self):
    """Concat should route MPE-path counts back onto each of its inputs.

    Each input receives the columns of ``counts`` that correspond to its
    part of the Concat output. For inputs given with explicit indices, the
    counts are expected scattered into a zero tensor the width of the
    input node (8 for ``v12``).
    """
    v12 = spn.IndicatorLeaf(num_vars=2, num_vals=4)
    v34 = spn.RawLeaf(num_vars=2)
    v5 = spn.RawLeaf(num_vars=1)
    # Output columns: v12[0], v12[5], v34[0], v34[1], v12[3], v5[0]
    p = spn.Concat((v12, [0, 5]), v34, (v12, [3]), v5)
    counts = tf.placeholder(tf.float32, shape=(None, 6))
    counts_in = tf.identity(counts)
    op = p._compute_log_mpe_path(counts_in,
                                 v12.get_value(), v34.get_value(),
                                 v12.get_value(), v5.get_value())
    # Three rows of counts: 0..5, 6..11 and 12..17
    feed = np.r_[:18].reshape(-1, 6)
    with self.test_session() as sess:
        out = sess.run(op, feed_dict={counts: feed})

    expected = [
        # (v12, [0, 5]) -> columns 0 and 1 scattered to positions 0 and 5
        np.array([[0., 0., 0., 0., 0., 1., 0., 0.],
                  [6., 0., 0., 0., 0., 7., 0., 0.],
                  [12., 0., 0., 0., 0., 13., 0., 0.]], dtype=np.float32),
        # v34 -> columns 2 and 3 passed through
        np.array([[2., 3.],
                  [8., 9.],
                  [14., 15.]], dtype=np.float32),
        # (v12, [3]) -> column 4 scattered to position 3
        np.array([[0., 0., 0., 4., 0., 0., 0., 0.],
                  [0., 0., 0., 10., 0., 0., 0., 0.],
                  [0., 0., 0., 16., 0., 0., 0., 0.]], dtype=np.float32),
        # v5 -> column 5 passed through
        np.array([[5.],
                  [11.],
                  [17.]], dtype=np.float32),
    ]
    # Indexing (not zip) so that a short ``out`` still fails loudly
    for position, wanted in enumerate(expected):
        np.testing.assert_array_almost_equal(out[position], wanted)
def concat_layer_and_test(inputs, name):
    """Create a Concat node, assert that its scope is correct and record it.

    The expected scope is assembled from ``scopes_per_node``, a mapping
    (expected to be defined in the enclosing scope) from already-built
    nodes to their expected per-output scope lists. Uses ``self`` from
    the enclosing test method.

    Args:
        inputs: Sequence of Concat inputs. Each is either a node (all of
            its outputs are used) or a ``(node, indices)`` tuple where
            ``indices`` is an int, a list of ints, or ``None``. ``None``
            selects all outputs, as for a bare node.
        name: Name given to the created Concat node.

    Returns:
        The created Concat node. Its expected scope is also stored in
        ``scopes_per_node`` so later layers can be built on top of it.
    """
    scope = []
    for inp in inputs:
        if isinstance(inp, tuple):
            node, indices = inp[0], inp[1]
            node_scopes = scopes_per_node[node]
            if indices is None:
                # Same as passing the bare node: every output is used.
                # Previously this crashed with a TypeError on iteration.
                scope.extend(node_scopes)
                continue
            if isinstance(indices, int):
                indices = [indices]
            scope.extend(node_scopes[i] for i in indices)
        else:
            scope.extend(scopes_per_node[inp])
    concat = spn.Concat(*inputs, name=name)
    self.assertListEqual(concat.get_scope(), scope)
    scopes_per_node[concat] = scope
    return concat
def test_generte_set_errors(self):
    """Detecting structure errors in __generate_set.

    S1 sums over ContVars2 *and* ContVars1, while ContVars1 is also passed
    directly and through N1. The generator is expected to reject this set
    of inputs with ``spn.StructureError`` (presumably because the scopes
    overlap).
    """
    gen = spn.DenseSPNGenerator(num_decomps=2, num_subsets=3,
                                num_mixtures=2)
    v1 = spn.IVs(num_vars=2, num_vals=4)
    v2 = spn.ContVars(num_vars=3, name="ContVars1")
    v3 = spn.ContVars(num_vars=2, name="ContVars2")
    s1 = spn.Sum(v3, v2)  # covers ContVars1 as well
    n1 = spn.Concat(v2)
    with self.assertRaises(spn.StructureError):
        # Name-mangled access to the private DenseSPNGenerator method
        generate_set = gen._DenseSPNGenerator__generate_set
        generate_set([
            spn.Input(v1, [0, 3, 2, 6, 7]),
            spn.Input(v2, [1, 2]),
            spn.Input(s1, None),
            spn.Input(n1, None),
        ])
def test_generte_set(self):
    """Generation of sets of inputs with __generate_set.

    The generator is expected to group the given input values by scope.
    With these inputs the grouping is:

      IVs var 0          -> (v1, 0), (v1, 2), (v1, 3)
      IVs var 1          -> (v1, 6), (v1, 7)
      ContVars1 var 0    -> (n1, 0)
      ContVars1 var 1    -> (v2, 1), (n1, 1)
      ContVars1 var 2    -> (v2, 2), (n1, 2)
      ContVars2 vars 0,1 -> (s1, 0)
    """
    gen = spn.DenseSPNGenerator(num_decomps=2, num_subsets=3,
                                num_mixtures=2)
    v1 = spn.IVs(num_vars=2, num_vals=4)
    v2 = spn.ContVars(num_vars=3, name="ContVars1")
    v3 = spn.ContVars(num_vars=2, name="ContVars2")
    s1 = spn.Sum(v3)
    n1 = spn.Concat(v2)
    generate_set = gen._DenseSPNGenerator__generate_set
    out = generate_set([
        spn.Input(v1, [0, 3, 2, 6, 7]),
        spn.Input(v2, [1, 2]),
        spn.Input(s1, None),
        spn.Input(n1, None),
    ])

    expected_sets = [
        [(v2, 1), (n1, 1)],
        [(v2, 2), (n1, 2)],
        [(n1, 0)],
        [(v1, 0), (v1, 2), (v1, 3)],
        [(v1, 6), (v1, 7)],
        [(s1, 0)],
    ]
    # The order of the generated sets is not determined, so check the
    # number of sets and then membership of each sorted set
    self.assertEqual(len(out), len(expected_sets))
    for group in expected_sets:
        self.assertIn(tuple(sorted(group)), out)
def test(inputs, feed, output):
    """Evaluate a Concat of ``inputs`` in all four inference modes.

    Uses ``self`` from the enclosing test method. The marginal and MPE
    values are computed both directly and in log space. Log values are
    exponentiated before comparison, so all four results are checked
    against the same expected ``output``.
    """
    with self.subTest(inputs=inputs, feed=feed):
        n = spn.Concat(*inputs)
        marginal = n.get_value(spn.InferenceType.MARGINAL)
        marginal_log = n.get_log_value(spn.InferenceType.MARGINAL)
        mpe = n.get_value(spn.InferenceType.MPE)
        mpe_log = n.get_log_value(spn.InferenceType.MPE)
        with self.test_session() as sess:
            results = [
                sess.run(marginal, feed_dict=feed),
                sess.run(tf.exp(marginal_log), feed_dict=feed),
                sess.run(mpe, feed_dict=feed),
                sess.run(tf.exp(mpe_log), feed_dict=feed),
            ]
        expected = np.array(output,
                            dtype=spn.conf.dtype.as_numpy_dtype())
        for result in results:
            np.testing.assert_array_almost_equal(result, expected)
def test_generte_set(self):
    """Generation of sets of inputs with __generate_set.

    The generated sets are expected to group the input values by scope.
    Since the order of the sets is not determined, only their count and
    their membership are checked.
    """
    gen = spn.DenseSPNGenerator(num_decomps=2, num_subsets=3,
                                num_mixtures=2)
    v1 = spn.IVs(num_vars=2, num_vals=4)
    v2 = spn.ContVars(num_vars=3, name="ContVars1")
    v3 = spn.ContVars(num_vars=2, name="ContVars2")
    s1 = spn.Sum(v3)
    n1 = spn.Concat(v2)
    inputs = [
        spn.Input(v1, [0, 3, 2, 6, 7]),
        spn.Input(v2, [1, 2]),
        spn.Input(s1, None),
        spn.Input(n1, None),
    ]
    out = gen._DenseSPNGenerator__generate_set(inputs)

    self.assertEqual(len(out), 6)
    for members in (
            [(v2, 1), (n1, 1)],            # ContVars1, variable 1
            [(v2, 2), (n1, 2)],            # ContVars1, variable 2
            [(n1, 0)],                     # ContVars1, variable 0
            [(v1, 0), (v1, 2), (v1, 3)],   # IVs, variable 0
            [(v1, 6), (v1, 7)],            # IVs, variable 1
            [(s1, 0)],                     # ContVars2, both variables
    ):
        self.assertIn(tuple(sorted(members)), out)
def test_gather_input_tensors(self):
    """Gathering of input tensors by Concat._gather_input_tensors."""
    def test(inpt, feed, true_output):
        # Gather the only input of a fresh Concat and compare exactly
        with self.subTest(inputs=inpt, feed=feed):
            concat = spn.Concat(inpt)
            gathered, = concat._gather_input_tensors(
                concat.inputs[0].node.get_value())
            with self.test_session() as sess:
                result = sess.run(gathered, feed_dict=feed)
            np.testing.assert_array_equal(result, np.array(true_output))

    v1 = spn.ContVars(num_vars=3)
    v2 = spn.ContVars(num_vars=1)

    # Disconnected input
    n = spn.Concat(None)
    op, = n._gather_input_tensors(3)
    self.assertIs(op, None)
    # None input tensor
    n = spn.Concat((v1, 1))
    op, = n._gather_input_tensors(None)
    self.assertIs(op, None)

    # Each case is (input, feed_dict, expected gathered value)
    list_index_cases = [
        ((v1, [0, 2, 1]), {v1: [[1, 2, 3], [4, 5, 6]]},
         [[1.0, 3.0, 2.0], [4.0, 6.0, 5.0]]),
        ((v1, [0, 2]), {v1: [[1, 2, 3], [4, 5, 6]]},
         [[1.0, 3.0], [4.0, 6.0]]),
        ((v1, [1]), {v1: [[1, 2, 3], [4, 5, 6]]}, [[2.0], [5.0]]),
        ((v1, [0, 2, 1]), {v1: [[1, 2, 3]]}, [[1.0, 3.0, 2.0]]),
        ((v1, [0, 2]), {v1: [[1, 2, 3]]}, [[1.0, 3.0]]),
        ((v1, [1]), {v1: [[1, 2, 3]]}, [[2.0]]),
    ]
    for case in list_index_cases:
        test(*case)

    # With None indices the input tensor is passed through unchanged
    n = spn.Concat(v1)
    t = tf.constant([1, 2, 3])
    op, = n._gather_input_tensors(t)
    self.assertIs(op, t)

    remaining_cases = [
        # None indices
        (v1, {v1: [[1, 2, 3], [4, 5, 6]]},
         [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
        ((v1, None), {v1: [[1, 2, 3], [4, 5, 6]]},
         [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
        (v1, {v1: [[1, 2, 3]]}, [[1.0, 2.0, 3.0]]),
        ((v1, None), {v1: [[1, 2, 3]]}, [[1.0, 2.0, 3.0]]),
        # Single index, as an int or a one-element list
        ((v1, 1), {v1: [[1, 2, 3], [4, 5, 6]]}, [[2.0], [5.0]]),
        ((v1, [1]), {v1: [[1, 2, 3], [4, 5, 6]]}, [[2.0], [5.0]]),
        ((v1, 1), {v1: [[1, 2, 3]]}, [[2.0]]),
        ((v1, [1]), {v1: [[1, 2, 3]]}, [[2.0]]),
        # One-element input, index specified
        ((v2, 0), {v2: [[1], [4]]}, [[1.0], [4.0]]),
        ((v2, [0]), {v2: [[1], [4]]}, [[1.0], [4.0]]),
        ((v2, 0), {v2: [[1]]}, [[1.0]]),
        ((v2, [0]), {v2: [[1]]}, [[1.0]]),
        # One-element input, None indices
        (v2, {v2: [[1], [4]]}, [[1.0], [4.0]]),
        ((v2, None), {v2: [[1], [4]]}, [[1.0], [4.0]]),
        (v2, {v2: [[1]]}, [[1.0]]),
        ((v2, None), {v2: [[1]]}, [[1.0]]),
    ]
    for inpt, feed, expected in remaining_cases:
        test(inpt, feed, expected)
def test_comput_scope(self):
    """Calculating scope of Product.

    Builds a graph of Sum, Product and Concat nodes over an IndicatorLeaf
    ``V12`` (2 variables x 4 values) and a RawLeaf ``V34`` (2 variables),
    then checks ``get_scope()`` of every node. ``S1`` and ``S4`` get
    latent indicators. The variable of ``S1``'s latent indicators must
    appear in the scope of ``S1`` and of every node built on top of it.
    """
    # Create a graph
    v12 = spn.IndicatorLeaf(num_vars=2, num_vals=4, name="V12")
    v34 = spn.RawLeaf(num_vars=2, name="V34")
    s1 = spn.Sum((v12, [0, 1, 2, 3]), name="S1")
    s1.generate_latent_indicators()
    s2 = spn.Sum((v12, [4, 5, 6, 7]), name="S2")
    p1 = spn.Product((v12, [0, 7]), name="P1")
    # Named "P2" so that it does not clash with p1 ("P1")
    p2 = spn.Product((v12, [3, 4]), name="P2")
    p3 = spn.Product(v34, name="P3")
    n1 = spn.Concat(s1, s2, p3, name="N1")
    n2 = spn.Concat(p1, p2, name="N2")
    p4 = spn.Product((n1, [0]), (n1, [1]), name="P4")
    p5 = spn.Product((n2, [0]), (n1, [2]), name="P5")
    s3 = spn.Sum(p4, n2, name="S3")
    p6 = spn.Product(s3, (n1, [2]), name="P6")
    s4 = spn.Sum(p5, p6, name="S4")
    s4.generate_latent_indicators()
    # Test
    self.assertListEqual(v12.get_scope(),
                         [spn.Scope(v12, 0), spn.Scope(v12, 0),
                          spn.Scope(v12, 0), spn.Scope(v12, 0),
                          spn.Scope(v12, 1), spn.Scope(v12, 1),
                          spn.Scope(v12, 1), spn.Scope(v12, 1)])
    self.assertListEqual(v34.get_scope(),
                         [spn.Scope(v34, 0), spn.Scope(v34, 1)])
    self.assertListEqual(s1.get_scope(),
                         [spn.Scope(v12, 0) |
                          spn.Scope(s1.latent_indicators.node, 0)])
    self.assertListEqual(s2.get_scope(), [spn.Scope(v12, 1)])
    self.assertListEqual(p1.get_scope(),
                         [spn.Scope(v12, 0) | spn.Scope(v12, 1)])
    self.assertListEqual(p2.get_scope(),
                         [spn.Scope(v12, 0) | spn.Scope(v12, 1)])
    self.assertListEqual(p3.get_scope(),
                         [spn.Scope(v34, 0) | spn.Scope(v34, 1)])
    self.assertListEqual(n1.get_scope(),
                         [spn.Scope(v12, 0) |
                          spn.Scope(s1.latent_indicators.node, 0),
                          spn.Scope(v12, 1),
                          spn.Scope(v34, 0) | spn.Scope(v34, 1)])
    self.assertListEqual(n2.get_scope(),
                         [spn.Scope(v12, 0) | spn.Scope(v12, 1),
                          spn.Scope(v12, 0) | spn.Scope(v12, 1)])
    self.assertListEqual(p4.get_scope(),
                         [spn.Scope(v12, 0) | spn.Scope(v12, 1) |
                          spn.Scope(s1.latent_indicators.node, 0)])
    self.assertListEqual(p5.get_scope(),
                         [spn.Scope(v12, 0) | spn.Scope(v12, 1) |
                          spn.Scope(v34, 0) | spn.Scope(v34, 1)])
    self.assertListEqual(s3.get_scope(),
                         [spn.Scope(v12, 0) | spn.Scope(v12, 1) |
                          spn.Scope(s1.latent_indicators.node, 0)])
    self.assertListEqual(p6.get_scope(),
                         [spn.Scope(v12, 0) | spn.Scope(v12, 1) |
                          spn.Scope(v34, 0) | spn.Scope(v34, 1) |
                          spn.Scope(s1.latent_indicators.node, 0)])
    self.assertListEqual(s4.get_scope(),
                         [spn.Scope(v12, 0) | spn.Scope(v12, 1) |
                          spn.Scope(v34, 0) | spn.Scope(v34, 1) |
                          spn.Scope(s1.latent_indicators.node, 0) |
                          spn.Scope(s4.latent_indicators.node, 0)])
def test_comput_scope(self):
    """Calculating scope of PermuteProducts.

    Builds PermuteProducts nodes over Concat, Product and Sum nodes (some
    using only selected outputs) and checks ``get_scope()`` for every node.
    With several inputs, the expected scope list has one entry per
    combination of one selected output from each input. The ``num_prods``
    comments give the expected count, and the first input's output varies
    slowest (see PP1). With a single input (PP5, PP6, PP8, PP9), all
    selected outputs are expected to form one product.
    """
    # Create graph
    v12 = spn.IndicatorLeaf(num_vars=2, num_vals=4, name="V12")
    v34 = spn.RawLeaf(num_vars=2, name="V34")
    s1 = spn.Sum((v12, [0, 1, 2, 3]), name="S1")
    s1.generate_latent_indicators()
    s2 = spn.Sum((v12, [4, 5, 6, 7]), name="S2")
    p1 = spn.Product((v12, [0, 7]), name="P1")
    p2 = spn.Product((v12, [3, 4]), name="P2")
    p3 = spn.Product(v34, name="P3")
    n1 = spn.Concat(s1, s2, p3, name="N1")
    n2 = spn.Concat(p1, p2, name="N2")
    pp1 = spn.PermuteProducts(n1, n2, name="PP1")  # num_prods = 6
    pp2 = spn.PermuteProducts((n1, [0, 1]), (n2, [0]), name="PP2")  # num_prods = 2
    pp3 = spn.PermuteProducts(n2, p3, name="PP3")  # num_prods = 2
    pp4 = spn.PermuteProducts(p2, p3, name="PP4")  # num_prods = 1
    pp5 = spn.PermuteProducts((n2, [0, 1]), name="PP5")  # num_prods = 1
    pp6 = spn.PermuteProducts(p3, name="PP6")  # num_prods = 1
    n3 = spn.Concat((pp1, [0, 2, 3]), pp2, pp4, name="N3")
    s3 = spn.Sum((pp1, [0, 2, 4]), (pp1, [1, 3, 5]), pp2, pp3, (pp4, 0), pp5, pp6, name="S3")
    s3.generate_latent_indicators()
    n4 = spn.Concat((pp3, [0, 1]), pp5, (pp6, 0), name="N4")
    # 6 (N3) x 1 (S3) x 4 (N4) combinations
    pp7 = spn.PermuteProducts(n3, s3, n4, name="PP7")  # num_prods = 24
    pp8 = spn.PermuteProducts(n3, name="PP8")  # num_prods = 1
    pp9 = spn.PermuteProducts((n4, [0, 1, 2, 3]), name="PP9")  # num_prods = 1
    # Test
    # Leaves: V12 outputs 0-3 belong to variable 0, outputs 4-7 to variable 1
    self.assertListEqual(v12.get_scope(), [spn.Scope(v12, 0), spn.Scope(v12, 0),
                                           spn.Scope(v12, 0), spn.Scope(v12, 0),
                                           spn.Scope(v12, 1), spn.Scope(v12, 1),
                                           spn.Scope(v12, 1), spn.Scope(v12, 1)])
    self.assertListEqual(v34.get_scope(), [spn.Scope(v34, 0), spn.Scope(v34, 1)])
    # S1 has latent indicators, whose variable joins its scope
    self.assertListEqual(s1.get_scope(), [spn.Scope(v12, 0) |
                                          spn.Scope(s1.latent_indicators.node, 0)])
    self.assertListEqual(s2.get_scope(), [spn.Scope(v12, 1)])
    self.assertListEqual(p1.get_scope(), [spn.Scope(v12, 0) | spn.Scope(v12, 1)])
    self.assertListEqual(p2.get_scope(), [spn.Scope(v12, 0) | spn.Scope(v12, 1)])
    self.assertListEqual(p3.get_scope(), [spn.Scope(v34, 0) | spn.Scope(v34, 1)])
    self.assertListEqual(n1.get_scope(), [spn.Scope(v12, 0) |
                                          spn.Scope(s1.latent_indicators.node, 0),
                                          spn.Scope(v12, 1),
                                          spn.Scope(v34, 0) | spn.Scope(v34, 1)])
    self.assertListEqual(n2.get_scope(), [spn.Scope(v12, 0) | spn.Scope(v12, 1),
                                          spn.Scope(v12, 0) | spn.Scope(v12, 1)])
    # PP1: N1[0]xN2[0], N1[0]xN2[1], N1[1]xN2[0], N1[1]xN2[1], N1[2]xN2[0], N1[2]xN2[1]
    self.assertListEqual(pp1.get_scope(), [spn.Scope(v12, 0) | spn.Scope(v12, 1) |
                                           spn.Scope(s1.latent_indicators.node, 0),
                                           spn.Scope(v12, 0) | spn.Scope(v12, 1) |
                                           spn.Scope(s1.latent_indicators.node, 0),
                                           spn.Scope(v12, 0) | spn.Scope(v12, 1),
                                           spn.Scope(v12, 0) | spn.Scope(v12, 1),
                                           spn.Scope(v12, 0) | spn.Scope(v12, 1) |
                                           spn.Scope(v34, 0) | spn.Scope(v34, 1),
                                           spn.Scope(v12, 0) | spn.Scope(v12, 1) |
                                           spn.Scope(v34, 0) | spn.Scope(v34, 1)])
    # PP2: N1[0]xN2[0], N1[1]xN2[0]
    self.assertListEqual(pp2.get_scope(), [spn.Scope(v12, 0) | spn.Scope(v12, 1) |
                                           spn.Scope(s1.latent_indicators.node, 0),
                                           spn.Scope(v12, 0) | spn.Scope(v12, 1)])
    self.assertListEqual(pp3.get_scope(), [spn.Scope(v12, 0) | spn.Scope(v12, 1) |
                                           spn.Scope(v34, 0) | spn.Scope(v34, 1),
                                           spn.Scope(v12, 0) | spn.Scope(v12, 1) |
                                           spn.Scope(v34, 0) | spn.Scope(v34, 1)])
    self.assertListEqual(pp4.get_scope(), [spn.Scope(v12, 0) | spn.Scope(v12, 1) |
                                           spn.Scope(v34, 0) | spn.Scope(v34, 1)])
    # Single-input PermuteProducts: one product over all selected outputs
    self.assertListEqual(pp5.get_scope(), [spn.Scope(v12, 0) | spn.Scope(v12, 1)])
    self.assertListEqual(pp6.get_scope(), [spn.Scope(v34, 0) | spn.Scope(v34, 1)])
    # N3 = PP1[0], PP1[2], PP1[3], PP2[0], PP2[1], PP4[0]
    self.assertListEqual(n3.get_scope(), [spn.Scope(v12, 0) | spn.Scope(v12, 1) |
                                          spn.Scope(s1.latent_indicators.node, 0),
                                          spn.Scope(v12, 0) | spn.Scope(v12, 1),
                                          spn.Scope(v12, 0) | spn.Scope(v12, 1),
                                          spn.Scope(v12, 0) | spn.Scope(v12, 1) |
                                          spn.Scope(s1.latent_indicators.node, 0),
                                          spn.Scope(v12, 0) | spn.Scope(v12, 1),
                                          spn.Scope(v12, 0) | spn.Scope(v12, 1) |
                                          spn.Scope(v34, 0) | spn.Scope(v34, 1)])
    # S3 covers every variable, plus both latent-indicator variables
    self.assertListEqual(s3.get_scope(), [spn.Scope(v12, 0) | spn.Scope(v12, 1) |
                                          spn.Scope(v34, 0) | spn.Scope(v34, 1) |
                                          spn.Scope(s1.latent_indicators.node, 0) |
                                          spn.Scope(s3.latent_indicators.node, 0)])
    # N4 = PP3[0], PP3[1], PP5[0], PP6[0]
    self.assertListEqual(n4.get_scope(), [spn.Scope(v12, 0) | spn.Scope(v12, 1) |
                                          spn.Scope(v34, 0) | spn.Scope(v34, 1),
                                          spn.Scope(v12, 0) | spn.Scope(v12, 1) |
                                          spn.Scope(v34, 0) | spn.Scope(v34, 1),
                                          spn.Scope(v12, 0) | spn.Scope(v12, 1),
                                          spn.Scope(v34, 0) | spn.Scope(v34, 1)])
    # Every PP7 product includes S3, whose scope already covers everything
    self.assertListEqual(pp7.get_scope(), [spn.Scope(v12, 0) | spn.Scope(v12, 1) |
                                           spn.Scope(v34, 0) | spn.Scope(v34, 1) |
                                           spn.Scope(s1.latent_indicators.node, 0) |
                                           spn.Scope(s3.latent_indicators.node, 0)] * 24)
    self.assertListEqual(pp8.get_scope(), [spn.Scope(v12, 0) | spn.Scope(v12, 1) |
                                           spn.Scope(s1.latent_indicators.node, 0) |
                                           spn.Scope(v34, 0) | spn.Scope(v34, 1)])
    self.assertListEqual(pp9.get_scope(), [spn.Scope(v12, 0) | spn.Scope(v12, 1) |
                                           spn.Scope(v34, 0) | spn.Scope(v34, 1)])
def test_comput_scope(self):
    """Calculating scope of ProductsLayer.

    A ProductsLayer is expected to concatenate the selected outputs of all
    of its inputs, in order, and split that sequence into consecutive
    products. ``num_or_size_prods`` gives the product sizes as a list
    (PL1, PL2) or, apparently, the number of equally sized products as an
    int (PL3, PL4). ``get_scope()`` is checked for every node, including
    the latent-indicator variables of S1, S3 and S6.
    """
    # Create graph
    v12 = spn.IndicatorLeaf(num_vars=2, num_vals=4, name="V12")
    v34 = spn.RawLeaf(num_vars=2, name="V34")
    s1 = spn.Sum((v12, [0, 1, 2, 3]), name="S1")
    s1.generate_latent_indicators()
    s2 = spn.Sum((v12, [4, 5, 6, 7]), name="S2")
    # 8 concatenated values split into products of sizes 4, 3 and 1:
    # [V12[0], V12[5], V12[6], V12[7]], [V12[3], V12[4], V34[0]], [V34[1]]
    pl1 = spn.ProductsLayer((v12, [0, 5, 6, 7]), (v12, [3, 4]), v34,
                            num_or_size_prods=[4, 3, 1], name="PL1")
    n1 = spn.Concat(s1, s2, (pl1, [2]), name="N1")
    n2 = spn.Concat((pl1, [0]), (pl1, [1]), name="N2")
    s3 = spn.Sum(pl1, name="S3")
    s3.generate_latent_indicators()
    # 13 concatenated values split into products of sizes 2, 3, 3 and 5
    pl2 = spn.ProductsLayer((n1, [0, 1]), (n1, 2), (n2, 0), (pl1, [1]),
                            n2, s3, (n2, 1), s3, pl1,
                            num_or_size_prods=[2, 3, 3, 5], name="PL2")
    s4 = spn.Sum((pl2, 0), n2, name="S4")
    s5 = spn.Sum(pl2, name="S5")
    s6 = spn.Sum((pl2, [1, 3]), name="S6")
    s6.generate_latent_indicators()
    pl3 = spn.ProductsLayer(s4, (n1, 2), num_or_size_prods=1, name="PL3")
    pl4 = spn.ProductsLayer(s4, s5, s6, s4, s5, s6,
                            num_or_size_prods=2, name="PL4")
    # Test
    self.assertListEqual(v12.get_scope(), [
        spn.Scope(v12, 0), spn.Scope(v12, 0), spn.Scope(v12, 0),
        spn.Scope(v12, 0), spn.Scope(v12, 1), spn.Scope(v12, 1),
        spn.Scope(v12, 1), spn.Scope(v12, 1)
    ])
    self.assertListEqual(
        v34.get_scope(), [spn.Scope(v34, 0), spn.Scope(v34, 1)])
    self.assertListEqual(
        s1.get_scope(),
        [spn.Scope(v12, 0) | spn.Scope(s1.latent_indicators.node, 0)])
    self.assertListEqual(s2.get_scope(), [spn.Scope(v12, 1)])
    # NOTE(review): the repeated ``spn.Scope(v12, 1)`` terms mirror the
    # repeated V12 variable-1 inputs of the first product; they should be
    # redundant in a union
    self.assertListEqual(pl1.get_scope(), [
        spn.Scope(v12, 0) | spn.Scope(v12, 1) | spn.Scope(v12, 1) |
        spn.Scope(v12, 1),
        spn.Scope(v12, 0) | spn.Scope(v12, 1) | spn.Scope(v34, 0),
        spn.Scope(v34, 1)
    ])
    self.assertListEqual(n1.get_scope(), [
        spn.Scope(v12, 0) | spn.Scope(s1.latent_indicators.node, 0),
        spn.Scope(v12, 1),
        spn.Scope(v34, 1)
    ])
    self.assertListEqual(n2.get_scope(), [
        spn.Scope(v12, 0) | spn.Scope(v12, 1) | spn.Scope(v12, 1) |
        spn.Scope(v12, 1),
        spn.Scope(v12, 0) | spn.Scope(v12, 1) | spn.Scope(v34, 0)
    ])
    self.assertListEqual(s3.get_scope(), [
        spn.Scope(v12, 0) | spn.Scope(v12, 1) | spn.Scope(v34, 0) |
        spn.Scope(v34, 1) | spn.Scope(s3.latent_indicators.node, 0)
    ])
    # PL2 products: [N1[0], N1[1]], [N1[2], N2[0], PL1[1]],
    # [N2[0], N2[1], S3], [N2[1], S3, PL1[0], PL1[1], PL1[2]]
    self.assertListEqual(pl2.get_scope(), [
        spn.Scope(v12, 0) | spn.Scope(v12, 1) |
        spn.Scope(s1.latent_indicators.node, 0),
        spn.Scope(v12, 0) | spn.Scope(v12, 1) | spn.Scope(v34, 0) |
        spn.Scope(v34, 1),
        spn.Scope(v12, 0) | spn.Scope(v12, 1) | spn.Scope(v34, 0) |
        spn.Scope(v34, 1) | spn.Scope(s3.latent_indicators.node, 0),
        spn.Scope(v12, 0) | spn.Scope(v12, 1) | spn.Scope(v34, 0) |
        spn.Scope(v34, 1) | spn.Scope(s3.latent_indicators.node, 0)
    ])
    self.assertListEqual(s4.get_scope(), [
        spn.Scope(v12, 0) | spn.Scope(v12, 1) |
        spn.Scope(s1.latent_indicators.node, 0) | spn.Scope(v34, 0)
    ])
    self.assertListEqual(s5.get_scope(), [
        spn.Scope(v12, 0) | spn.Scope(v12, 1) |
        spn.Scope(s1.latent_indicators.node, 0) | spn.Scope(v34, 0) |
        spn.Scope(v34, 1) | spn.Scope(s3.latent_indicators.node, 0)
    ])
    self.assertListEqual(s6.get_scope(), [
        spn.Scope(v12, 0) | spn.Scope(v12, 1) | spn.Scope(v34, 0) |
        spn.Scope(v34, 1) | spn.Scope(s3.latent_indicators.node, 0) |
        spn.Scope(s6.latent_indicators.node, 0)
    ])
    # PL3 is a single product of S4 and N1[2]
    self.assertListEqual(pl3.get_scope(), [
        spn.Scope(v12, 0) | spn.Scope(v12, 1) |
        spn.Scope(s1.latent_indicators.node, 0) | spn.Scope(v34, 0) |
        spn.Scope(v34, 1)
    ])
    # PL4: two products, each presumably [S4, S5, S6]
    self.assertListEqual(pl4.get_scope(), [
        spn.Scope(v12, 0) | spn.Scope(v12, 1) | spn.Scope(v34, 0) |
        spn.Scope(v34, 1) | spn.Scope(s1.latent_indicators.node, 0) |
        spn.Scope(s3.latent_indicators.node, 0) |
        spn.Scope(s6.latent_indicators.node, 0),
        spn.Scope(v12, 0) | spn.Scope(v12, 1) | spn.Scope(v34, 0) |
        spn.Scope(v34, 1) | spn.Scope(s1.latent_indicators.node, 0) |
        spn.Scope(s3.latent_indicators.node, 0) |
        spn.Scope(s6.latent_indicators.node, 0)
    ])