def test_compute_valid(self):
    """Calculating validity of Sum"""
    # --- Validity decided by the sum's inputs alone (no IVs attached) ---
    ivs_node = spn.IVs(num_vars=2, num_vals=4)
    cont_node = spn.ContVars(num_vars=2)
    sum_single_var = spn.Sum((ivs_node, [0, 1, 2, 3]))
    sum_cross_var = spn.Sum((ivs_node, [0, 1, 2, 4]))
    sum_mixed_nodes = spn.Sum((ivs_node, [0, 1, 2, 3]), (cont_node, 0))
    prod_a = spn.Product((ivs_node, [0, 5]), (cont_node, 0))
    prod_b = spn.Product((ivs_node, [1, 6]), (cont_node, 0))
    prod_c = spn.Product((ivs_node, [1, 6]), (cont_node, 1))
    sum_same_scope = spn.Sum(prod_a, prod_b)
    sum_diff_scope = spn.Sum(prod_a, prod_c)
    # (node, expected validity), checked in this exact order
    checks = [
        (ivs_node, True),
        (cont_node, True),
        (sum_single_var, True),
        (sum_cross_var, False),
        (sum_mixed_nodes, False),
        (sum_same_scope, True),
        (sum_diff_scope, False),
    ]
    for node, expected_valid in checks:
        assert_fn = self.assertTrue if expected_valid else self.assertFalse
        assert_fn(node.is_valid())

    # --- Validity when IVs are attached to the sum ---
    generated = spn.Sum(prod_a, prod_b)
    generated.generate_ivs()
    self.assertTrue(generated.is_valid())

    cont_ivs = spn.Sum(prod_a, prod_b)
    cont_ivs.set_ivs(spn.ContVars(num_vars=2))
    self.assertFalse(cont_ivs.is_valid())

    two_var_ivs = spn.Sum(prod_a, prod_b)
    two_var_ivs.set_ivs(spn.IVs(num_vars=2, num_vals=2))
    with self.assertRaises(spn.StructureError):
        two_var_ivs.is_valid()

    shared_ivs = spn.Sum(prod_a, prod_b)
    shared_ivs.set_ivs((ivs_node, [0, 3]))
    self.assertTrue(shared_ivs.is_valid())
def test_get_out_size(self):
    """Computing the sizes of the outputs of nodes in SPN graph"""
    # Build the graph
    cv1 = spn.ContVars(num_vars=5)
    cv2 = spn.ContVars(num_vars=5)
    cv3 = spn.ContVars(num_vars=5)
    sum1 = spn.Sum((cv1, [1, 3]), (cv1, [1, 4]), cv2)  # cv1 used twice
    sum2 = spn.Sum(cv1, (cv3, [0, 1, 2, 3, 4]))
    sum3 = spn.Sum(cv2, cv3, cv3)  # cv3 used twice
    cat4 = spn.Concat(sum1, cv1)
    cat5 = spn.Concat((cv3, [0, 4]), sum3)
    # cat4 and cat5 are each used twice
    cat6 = spn.Concat(cat4, sum2, cat5, (cat4, [0]), (cat5, [1]))

    # Expected output width of every node, in graph-construction order
    expected_sizes = [
        (cv1, 5), (cv2, 5), (cv3, 5),
        (sum1, 1), (sum2, 1), (sum3, 1),
        (cat4, 6), (cat5, 3), (cat6, 12),
    ]
    for node, size in expected_sizes:
        self.assertEqual(node.get_out_size(), size)
def test_compute_mpe_path(self):
    """Concat's MPE path splits the incoming counts back onto its inputs."""
    ivs_node = spn.IVs(num_vars=2, num_vals=4)
    cont_pair = spn.ContVars(num_vars=2)
    cont_single = spn.ContVars(num_vars=1)
    concat = spn.Concat((ivs_node, [0, 5]), cont_pair, (ivs_node, [3]),
                        cont_single)
    counts = tf.placeholder(tf.float32, shape=(None, 6))
    op = concat._compute_log_mpe_path(tf.identity(counts),
                                      ivs_node.get_value(),
                                      cont_pair.get_value(),
                                      ivs_node.get_value(),
                                      cont_single.get_value())
    feed = np.r_[:18].reshape(-1, 6)
    with self.test_session() as sess:
        out = sess.run(op, feed_dict={counts: feed})

    # Columns of `feed` map onto: ivs[0], ivs[5], cont_pair[0:2], ivs[3], cont_single
    first_ivs = np.zeros((3, 8), dtype=np.float32)
    first_ivs[:, [0, 5]] = feed[:, 0:2]
    second_ivs = np.zeros((3, 8), dtype=np.float32)
    second_ivs[:, 3] = feed[:, 4]
    expected = [
        first_ivs,
        feed[:, 2:4].astype(np.float32),
        second_ivs,
        feed[:, 5:6].astype(np.float32),
    ]
    for idx, want in enumerate(expected):
        np.testing.assert_array_almost_equal(out[idx], want)
def test_traverse_graph_nostop_noparams(self):
    """Traversing the whole graph excluding param nodes"""
    visited = [None] * 10
    visit_count = [0]  # list so the closure can mutate it

    def record(node):
        visited[visit_count[0]] = node
        visit_count[0] += 1

    # Generate graph
    v1 = spn.ContVars(num_vars=1)
    v2 = spn.ContVars(num_vars=1)
    v3 = spn.ContVars(num_vars=1)
    s1 = spn.Sum(v1, v1, v2)  # v1 included twice
    s2 = spn.Sum(v1, v3)
    s3 = spn.Sum(v2, v3, v3)  # v3 included twice
    s4 = spn.Sum(s1, v1)
    s5 = spn.Sum(s2, v3, s3)
    s6 = spn.Sum(s4, s2, s5, s4, s5)  # s4 and s5 included twice
    spn.generate_weights(s6)

    spn.traverse_graph(s6, fun=record, skip_params=True)

    # Each node appears exactly once, in breadth-first order; no weight nodes
    expected_order = [s6, s4, s2, s5, s1, v1, v3, s3, v2]
    self.assertEqual(visit_count[0], len(expected_order))
    for position, node in enumerate(expected_order):
        self.assertIs(visited[position], node)
def test_compute_graph_up_noconst(self):
    """Computing value assuming no constant functions"""
    # How many times val_fun ran; a list lets the closure update it in place
    calls = [0]

    def val_fun(node, *inputs):
        calls[0] += 1
        if isinstance(node, spn.graph.node.VarNode):
            return 1
        if isinstance(node, spn.graph.node.ParamNode):
            return 0.1
        # Op node: inputs are (weights, ivs, *children)
        weight_val, _iv_val, *child_vals = inputs
        return weight_val + sum(child_vals) + 1

    # Generate graph
    v1 = spn.ContVars(num_vars=1)
    v2 = spn.ContVars(num_vars=1)
    v3 = spn.ContVars(num_vars=1)
    s1 = spn.Sum(v1, v1, v2)  # v1 included twice
    s2 = spn.Sum(v1, v3)
    s3 = spn.Sum(v2, v3, v3)  # v3 included twice
    s4 = spn.Sum(s1, v1)
    s5 = spn.Sum(s2, v3, s3)
    s6 = spn.Sum(s4, s2, s5, s4, s5)  # s4 and s5 included twice
    spn.generate_weights(s6)

    result = spn.compute_graph_up(s6, val_fun)

    self.assertAlmostEqual(result, 35.2)
    # 9 sum/var nodes + 6 weight nodes, each evaluated once
    self.assertEqual(calls[0], 15)
def test_compute_mpe_path(self):
    """Product's MPE path copies the incoming count onto every input element."""
    ivs_node = spn.IVs(num_vars=2, num_vals=4)
    cont_pair = spn.ContVars(num_vars=2)
    cont_single = spn.ContVars(num_vars=1)
    prod = spn.Product((ivs_node, [0, 5]), cont_pair, (ivs_node, [3]),
                       cont_single)
    counts = tf.placeholder(tf.float32, shape=(None, 1))
    op = prod._compute_mpe_path(tf.identity(counts),
                                ivs_node.get_value(),
                                cont_pair.get_value(),
                                ivs_node.get_value(),
                                cont_single.get_value())
    feed = [[0], [1], [2]]
    with tf.Session() as sess:
        out = sess.run(op, feed_dict={counts: feed})

    count_col = np.array(feed, dtype=np.float32)  # shape (3, 1)
    first_ivs = np.zeros((3, 8), dtype=np.float32)
    first_ivs[:, [0, 5]] = count_col
    second_ivs = np.zeros((3, 8), dtype=np.float32)
    second_ivs[:, [3]] = count_col
    expected = [first_ivs, np.repeat(count_col, 2, axis=1), second_ivs,
                count_col]
    for idx, want in enumerate(expected):
        np.testing.assert_array_almost_equal(out[idx], want)
def sumslayer_prepare_common(batch_size, factor, indices, input_sizes, ivs,
                             same_inputs, sum_sizes):
    """Build random inputs, indices, IV values and weights for a SumsLayer test.

    Args:
        batch_size: Number of rows of random input values.
        factor: Multiplier on each input width when ``indices`` is truthy;
            forced to 1 otherwise.
        indices: Flag. If truthy, each input gets ``size`` distinct random
            indices drawn from ``range(size * factor)``; otherwise identity
            indices ``arange(size)``.
        input_sizes: Number of selected elements per input.
        ivs: Flag. If truthy and no sum has size 1, a random IV value per
            sample is drawn for each sum.
        same_inputs: If truthy, all inputs index one shared ``ContVars`` node
            of width ``max(input_sizes) * factor``.
        sum_sizes: Size of each sum.

    Returns:
        ``(feed_dict, indices, input_nodes, input_tuples, ivs, values,
        weights, root_weights)``. ``ivs`` is returned unchanged (or ``False``)
        when no IV values are drawn.
    """
    if not indices:
        factor = 1
        indices = [np.arange(size) for size in input_sizes]
    else:
        indices = [
            np.random.choice(list(range(size * factor)), size=size,
                             replace=False)
            for size in input_sizes
        ]

    if same_inputs:
        shared_width = max(input_sizes) * factor
        input_nodes = [spn.ContVars(num_vars=shared_width)]
        # Same array object repeated: every input sees identical data
        values = [np.random.rand(batch_size, shared_width)] * len(input_sizes)
        shared_node = input_nodes[0]
        input_tuples = [(shared_node, ind.tolist()) for ind in indices]
        feed_dict = {shared_node: values[0]}
    else:
        input_nodes = [spn.ContVars(num_vars=size * factor)
                       for size in input_sizes]
        values = [np.random.rand(batch_size, size * factor)
                  for size in input_sizes]
        input_tuples = [(node, ind.tolist())
                        for node, ind in zip(input_nodes, indices)]
        feed_dict = dict(zip(input_nodes, values))

    # IVs are disabled whenever any sum has a single input
    if 1 in sum_sizes:
        ivs = False
    elif ivs:
        ivs = [np.random.randint(size, size=batch_size) for size in sum_sizes]

    weights = np.random.rand(sum(sum_sizes))
    root_weights = np.random.rand(len(sum_sizes))
    return (feed_dict, indices, input_nodes, input_tuples, ivs, values,
            weights, root_weights)
def test_compute_value(self):
    """Calculating value of Product"""
    def test(inputs, feed, output):
        with self.subTest(inputs=inputs, feed=feed):
            product = spn.Product(*inputs)
            marginal = product.get_value(spn.InferenceType.MARGINAL)
            marginal_log = product.get_log_value(spn.InferenceType.MARGINAL)
            mpe = product.get_value(spn.InferenceType.MPE)
            mpe_log = product.get_log_value(spn.InferenceType.MPE)
            with tf.Session() as sess:
                # Evaluated left to right: value, exp(log value), MPE, exp(log MPE)
                results = [
                    sess.run(marginal, feed_dict=feed),
                    sess.run(tf.exp(marginal_log), feed_dict=feed),
                    sess.run(mpe, feed_dict=feed),
                    sess.run(tf.exp(mpe_log), feed_dict=feed),
                ]
            expected = np.array(output,
                                dtype=spn.conf.dtype.as_numpy_dtype())
            for result in results:
                np.testing.assert_array_almost_equal(result, expected)

    # Create inputs
    v1 = spn.ContVars(num_vars=3)
    v2 = spn.ContVars(num_vars=1)

    cases = [
        # Multiple inputs, multi-element batch
        ([v1, v2],
         {v1: [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], v2: [[0.7], [0.8]]},
         [[0.1 * 0.2 * 0.3 * 0.7], [0.4 * 0.5 * 0.6 * 0.8]]),
        ([(v1, [0, 2]), (v2, [0])],
         {v1: [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], v2: [[0.7], [0.8]]},
         [[0.1 * 0.3 * 0.7], [0.4 * 0.6 * 0.8]]),
        # Single input with 1 value, multi-element batch
        ([v2], {v2: [[0.1], [0.2]]}, [[0.1], [0.2]]),
        ([(v1, [1])], {v1: [[0.01, 0.1, 0.03], [0.02, 0.2, 0.04]]},
         [[0.1], [0.2]]),
        # Multiple inputs, single-element batch
        ([v1, v2], {v1: [[0.1, 0.2, 0.3]], v2: [[0.7]]},
         [[0.1 * 0.2 * 0.3 * 0.7]]),
        ([(v1, [0, 2]), (v2, [0])], {v1: [[0.1, 0.2, 0.3]], v2: [[0.7]]},
         [[0.1 * 0.3 * 0.7]]),
        # Single input with 1 value, single-element batch
        ([v2], {v2: [[0.1]]}, [[0.1]]),
        ([(v1, [1])], {v1: [[0.01, 0.1, 0.03]]}, [[0.1]]),
    ]
    for inputs, feed, output in cases:
        test(inputs, feed, output)
def test_get_num_nodes(self):
    """Computing the number of nodes in the SPN graph"""
    # Generate graph
    v1 = spn.ContVars(num_vars=1)
    v2 = spn.ContVars(num_vars=1)
    v3 = spn.ContVars(num_vars=1)
    s1 = spn.Sum(v1, v1, v2)  # v1 included twice
    s2 = spn.Sum(v1, v3)
    s3 = spn.Sum(v2, v3, v3)  # v3 included twice
    s4 = spn.Sum(s1, v1)
    s5 = spn.Sum(s2, v3, s3)
    s6 = spn.Sum(s4, s2, s5, s4, s5)  # s4 and s5 included twice
    spn.generate_weights(s6)

    # (node, count without param nodes, count including param nodes)
    expected_counts = [
        (v1, 1, 1), (v2, 1, 1), (v3, 1, 1),
        (s1, 3, 4), (s2, 3, 4), (s3, 3, 4),
        (s4, 4, 6), (s5, 6, 9), (s6, 9, 15),
    ]
    for node, no_params, with_params in expected_counts:
        self.assertEqual(node.get_num_nodes(skip_params=True), no_params)
        self.assertEqual(node.get_num_nodes(skip_params=False), with_params)
def test_compute_mpe_path_noivs(self):
    """Log-MPE path of a Sum without IVs, with spn.conf.argmax_zero enabled."""
    # argmax_zero is a process-wide config flag: restore the previous value
    # when the test ends so the setting does not leak into other tests.
    prev_argmax_zero = spn.conf.argmax_zero
    self.addCleanup(setattr, spn.conf, "argmax_zero", prev_argmax_zero)
    spn.conf.argmax_zero = True

    v12 = spn.IVs(num_vars=2, num_vals=4)
    v34 = spn.ContVars(num_vars=2)
    v5 = spn.ContVars(num_vars=1)
    s = spn.Sum((v12, [0, 5]), v34, (v12, [3]), v5)
    w = s.generate_weights()
    counts = tf.placeholder(tf.float32, shape=(None, 1))
    # The IVs argument is None: this sum has no IVs attached
    op = s._compute_log_mpe_path(tf.identity(counts), w.get_log_value(), None,
                                 v12.get_log_value(), v34.get_log_value(),
                                 v12.get_log_value(), v5.get_log_value())
    init = w.initialize()
    counts_feed = [[10], [11], [12], [13]]
    v12_feed = [[0, 1], [1, 1], [0, 0], [3, 3]]
    v34_feed = [[0.1, 0.2], [1.2, 0.2], [0.1, 0.2], [0.9, 0.8]]
    v5_feed = [[0.5], [0.5], [1.2], [0.9]]

    with self.test_session() as sess:
        sess.run(init)
        # Skip the IVs op (index 1), which corresponds to the None argument
        out = sess.run(op[:1] + op[2:], feed_dict={
            counts: counts_feed,
            v12: v12_feed,
            v34: v34_feed,
            v5: v5_feed
        })

    # Weights: each row routes its count to exactly one of the 6 inputs
    np.testing.assert_array_almost_equal(
        np.squeeze(out[0]),
        np.array([[10., 0., 0., 0., 0., 0.],
                  [0., 0., 11., 0., 0., 0.],
                  [0., 0., 0., 0., 0., 12.],
                  [0., 0., 0., 0., 13., 0.]],
                 dtype=np.float32))
    # First (v12, [0, 5]) input, scattered back into v12's 8 columns
    np.testing.assert_array_almost_equal(
        out[1],
        np.array([[10., 0., 0., 0., 0., 0., 0., 0.],
                  [0., 0., 0., 0., 0., 0., 0., 0.],
                  [0., 0., 0., 0., 0., 0., 0., 0.],
                  [0., 0., 0., 0., 0., 0., 0., 0.]],
                 dtype=np.float32))
    # v34
    np.testing.assert_array_almost_equal(
        out[2],
        np.array([[0., 0.],
                  [11., 0.],
                  [0., 0.],
                  [0., 0.]],
                 dtype=np.float32))
    # Second (v12, [3]) input, scattered back into v12's 8 columns
    np.testing.assert_array_almost_equal(
        out[3],
        np.array([[0., 0., 0., 0., 0., 0., 0., 0.],
                  [0., 0., 0., 0., 0., 0., 0., 0.],
                  [0., 0., 0., 0., 0., 0., 0., 0.],
                  [0., 0., 0., 13., 0., 0., 0., 0.]],
                 dtype=np.float32))
    # v5
    np.testing.assert_array_almost_equal(
        out[4],
        np.array([[0.],
                  [0.],
                  [12.],
                  [0.]],
                 dtype=np.float32))
def test_input_flags(self):
    """Detection of different types of inputs"""
    # (input, truthiness, is_op, is_var, is_param)
    cases = [
        (spn.Input(), False, False, False, False),
        (spn.Input(spn.Sum()), True, True, False, False),
        (spn.Input(spn.ContVars()), True, False, True, False),
        (spn.Input(spn.Weights()), True, False, False, True),
    ]
    for inpt, truthy, is_op, is_var, is_param in cases:
        observed = [
            (inpt, truthy),
            (inpt.is_op, is_op),
            (inpt.is_var, is_var),
            (inpt.is_param, is_param),
        ]
        for actual, expected in observed:
            assert_fn = self.assertTrue if expected else self.assertFalse
            assert_fn(actual)
def test_is_valid_false(self):
    """Checking validity of the SPN"""
    # Create graph
    v12 = spn.IVs(num_vars=2, num_vals=4, name="V12")
    v34 = spn.ContVars(num_vars=2, name="V34")
    s1 = spn.Sum((v12, [0, 1, 2, 3]), name="S1")
    s2 = spn.Sum((v12, [4, 5, 6, 7]), name="S2")
    p1 = spn.Product((v12, [0, 7]), name="P1")
    # Indices 2 and 3 both belong to the first IV variable
    p2 = spn.Product((v12, [2, 3, 4]), name="P2")
    p3 = spn.Product(v34, name="P3")
    n1 = spn.Concat(s1, s2, p3, name="N1")
    n2 = spn.Concat(p1, p2, name="N2")
    p4 = spn.Product((n1, [0]), (n1, [1]), name="P4")
    p5 = spn.Product((n2, [0]), (n1, [2]), name="P5")
    s3 = spn.Sum(p4, n2, name="S3")
    p6 = spn.Product(s3, (n1, [2]), name="P6")
    s4 = spn.Sum(p5, p6, name="S4")

    # Nodes that do not depend on p2 stay valid ...
    for node in (v12, v34, s1, s2, p1, p3, p4, n1):
        self.assertTrue(node.is_valid())
    # ... while p2 and every node above it are invalid
    for node in (p2, n2, s3, s4, p5, p6):
        self.assertFalse(node.is_valid())
def test_masked_weights(self):
    """SumsLayer with unequal sum sizes: padded weights are zero, rows sum to 1."""
    v12 = spn.IVs(num_vars=2, num_vals=4)
    v34 = spn.ContVars(num_vars=2)
    v5 = spn.ContVars(num_vars=1)
    s = spn.SumsLayer((v12, [0, 5]), v34, (v12, [3]), v5, (v12, [0, 5]), v34,
                      (v12, [3]), v5, num_or_size_sums=[3, 1, 3, 4, 1])
    s.generate_weights(initializer=tf.initializers.random_uniform(0.0, 1.0))
    with self.test_session() as sess:
        sess.run(s.weights.node.initialize())
        weights = sess.run(s.weights.node.variable)

        # One row per sum, columns padded to the largest sum size (4)
        shape = [5, 4]
        self.assertEqual(shape, s.weights.node.variable.shape.as_list())
        # Entries beyond each sum's size (sizes 3, 1, 3, 4, 1) must be masked.
        # A plain loop is used: the assertions are run for their side effect.
        masked_positions = [(0, -1), (1, 1), (1, 2), (1, 3), (2, -1),
                            (4, 1), (4, 2), (4, 3)]
        for row, col in masked_positions:
            self.assertEqual(weights[row, col], 0.0)
        # The remaining weights of every sum are normalized
        self.assertAllClose(np.sum(weights, axis=1), np.ones(5))
def test_generte_set_errors(self):
    """Detecting structure errors in __generate_set"""
    generator = spn.DenseSPNGenerator(num_decomps=2, num_subsets=3,
                                      num_mixtures=2)
    iv_vars = spn.IVs(num_vars=2, num_vals=4)
    cont_a = spn.ContVars(num_vars=3, name="ContVars1")
    cont_b = spn.ContVars(num_vars=2, name="ContVars2")
    # The sum mixes both ContVars nodes, and cont_a also appears directly and
    # through the Concat; this combination must be rejected.
    mixed_sum = spn.Sum(cont_b, cont_a)
    concat = spn.Concat(cont_a)
    # Inputs are built inside the context so any StructureError they raise
    # is also accepted, as before.
    with self.assertRaises(spn.StructureError):
        generate_set = generator._DenseSPNGenerator__generate_set
        generate_set([
            spn.Input(iv_vars, [0, 3, 2, 6, 7]),
            spn.Input(cont_a, [1, 2]),
            spn.Input(mixed_sum, None),
            spn.Input(concat, None),
        ])
def test_compute_graph_up_const(self):
    """Computing value with constant function detection"""
    # How many times val_fun ran; a list lets the closure update it in place
    calls = [0]

    # Generate graph
    v1 = spn.ContVars(num_vars=1)
    v2 = spn.ContVars(num_vars=1)
    v3 = spn.ContVars(num_vars=1)
    s1 = spn.Sum(v1, v1, v2)  # v1 included twice
    s2 = spn.Sum(v1, v3)
    s3 = spn.Sum(v2, v3, v3)  # v3 included twice
    s4 = spn.Sum(s1, v1)
    s5 = spn.Sum(s2, v3, s3)
    s6 = spn.Sum(s4, s2, s5, s4, s5)  # s4 and s5 included twice

    def val_fun(node, *inputs):
        calls[0] += 1
        # s3's only parent is s5, whose value is constant, so s3 is never needed
        self.assertIsNot(node, s3)
        if node == s5:
            return 16
        if isinstance(node, spn.graph.node.VarNode):
            return 1
        _weight_val, _iv_val, *child_vals = inputs
        return sum(child_vals) + 1

    def const_fun(node):
        # Only s5 is treated as a constant
        return bool(node == s5)

    result = spn.compute_graph_up(s6, val_fun, const_fun)

    self.assertEqual(result, 48)
    self.assertEqual(calls[0], 8)
def test_get_scope(self):
    """Computing the scope of nodes of the SPN graph"""
    # Create graph
    v12 = spn.IVs(num_vars=2, num_vals=4, name="V12")
    v34 = spn.ContVars(num_vars=2, name="V34")
    s1 = spn.Sum((v12, [0, 1, 2, 3]), name="S1")
    s2 = spn.Sum((v12, [4, 5, 6, 7]), name="S2")
    p1 = spn.Product((v12, [0, 7]), name="P1")
    p2 = spn.Product((v12, [3, 4]), name="P2")
    p3 = spn.Product(v34, name="P3")
    n1 = spn.Concat(s1, s2, p3, name="N1")
    n2 = spn.Concat(p1, p2, name="N2")
    p4 = spn.Product((n1, [0]), (n1, [1]), name="P4")
    p5 = spn.Product((n2, [0]), (n1, [2]), name="P5")
    s3 = spn.Sum(p4, n2, name="S3")
    p6 = spn.Product(s3, (n1, [2]), name="P6")
    s4 = spn.Sum(p5, p6, name="S4")

    # Building-block scopes, combined with the union operator
    v12_first = spn.Scope(v12, 0)
    v12_second = spn.Scope(v12, 1)
    v34_first = spn.Scope(v34, 0)
    v34_second = spn.Scope(v34, 1)
    v12_both = v12_first | v12_second
    v34_both = v34_first | v34_second
    all_vars = v12_both | v34_both

    # Expected scopes, checked in the same node order as the graph above
    expected_scopes = [
        (v12, [v12_first] * 4 + [v12_second] * 4),
        (v34, [v34_first, v34_second]),
        (s1, [v12_first]),
        (s2, [v12_second]),
        (p1, [v12_both]),
        (p2, [v12_both]),
        (p3, [v34_both]),
        (n1, [v12_first, v12_second, v34_both]),
        (n2, [v12_both, v12_both]),
        (p4, [v12_both]),
        (p5, [all_vars]),
        (s3, [v12_both]),
        (p6, [all_vars]),
        (s4, [all_vars]),
    ]
    for node, scope in expected_scopes:
        self.assertListEqual(node.get_scope(), scope)
def test_gather_input_scopes(self):
    """Sum._gather_input_scopes picks per-input scopes using each input's indices."""
    v12 = spn.IVs(num_vars=2, num_vals=4, name="V12")
    v34 = spn.ContVars(num_vars=2, name="V34")
    s1 = spn.Sum(v12, v12, v34, (v12, [7, 3, 1, 0]), (v34, 0), name="S1")
    scopes_v12 = v12._compute_scope()
    scopes_v34 = v34._compute_scope()
    # The leading None arguments stand for the disconnected weights/IVs
    # slots, so None is expected back in those positions.
    scopes = s1._gather_input_scopes(None, None, None, scopes_v12, scopes_v34,
                                     scopes_v12, scopes_v34)
    reordered_v12 = [scopes_v12[i] for i in (7, 3, 1, 0)]
    expected = (None, None, None,
                scopes_v12,
                scopes_v34,
                reordered_v12,
                [scopes_v34[0]])
    self.assertTupleEqual(scopes, expected)
def test_generte_set(self):
    """Generation of sets of inputs with __generate_set"""
    gen = spn.DenseSPNGenerator(num_decomps=2, num_subsets=3, num_mixtures=2)
    v1 = spn.IVs(num_vars=2, num_vals=4)
    v2 = spn.ContVars(num_vars=3, name="ContVars1")
    v3 = spn.ContVars(num_vars=2, name="ContVars2")
    s1 = spn.Sum(v3)
    n1 = spn.Concat(v2)
    out = gen._DenseSPNGenerator__generate_set([
        spn.Input(v1, [0, 3, 2, 6, 7]),
        spn.Input(v2, [1, 2]),
        spn.Input(s1, None),
        spn.Input(n1, None)
    ])
    # Input elements grouped by scope:
    #   {v1 var 0}             -> (v1, 0), (v1, 2), (v1, 3)
    #   {v1 var 1}             -> (v1, 6), (v1, 7)
    #   {ContVars1 var 1}      -> (v2, 1), (n1, 1)
    #   {ContVars1 var 2}      -> (v2, 2), (n1, 2)
    #   {ContVars1 var 0}      -> (n1, 0)
    #   {ContVars2 vars 0, 1}  -> (s1, 0)
    expected_groups = [
        [(v2, 1), (n1, 1)],
        [(v2, 2), (n1, 2)],
        [(n1, 0)],
        [(v1, 0), (v1, 2), (v1, 3)],
        [(v1, 6), (v1, 7)],
        [(s1, 0)],
    ]
    # The order of the sets is not defined, so only membership is checked
    self.assertEqual(len(out), len(expected_groups))
    for group in expected_groups:
        self.assertIn(tuple(sorted(group)), out)
def test(num_vars, value):
    with self.subTest(num_vars=num_vars, value=value):
        node = spn.ContVars(num_vars=num_vars)
        value_op = node.get_value()
        log_value_op = node.get_log_value()
        feed = {node: value}
        with self.test_session() as sess:
            out = sess.run(value_op, feed_dict=feed)
            # exp(log value) must reproduce the raw value
            out_log = sess.run(tf.exp(log_value_op), feed_dict=feed)
        expected = np.array(value, dtype=spn.conf.dtype.as_numpy_dtype())
        for actual in (out, out_log):
            np.testing.assert_array_almost_equal(actual, expected)
def test_compute_valid(self):
    """Calculating validity of Product"""
    iv_vars = spn.IVs(num_vars=2, num_vals=4)
    cont_vars = spn.ContVars(num_vars=2)
    # Indices 0 and 5 are different IV variables; 0 and 3 are the same one
    prod_distinct = spn.Product((iv_vars, [0, 5]))
    prod_repeated = spn.Product((iv_vars, [0, 3]))
    prod_distinct_cont = spn.Product((iv_vars, [0, 5]), cont_vars)
    prod_repeated_cont = spn.Product((iv_vars, [0, 3]), cont_vars)
    prod_overlap = spn.Product((iv_vars, [0, 5]), cont_vars, (iv_vars, 2))
    checks = [
        (prod_distinct, True),
        (prod_repeated, False),
        (prod_distinct_cont, True),
        (prod_repeated_cont, False),
        (prod_overlap, False),
    ]
    for node, expected_valid in checks:
        assert_fn = self.assertTrue if expected_valid else self.assertFalse
        assert_fn(node.is_valid())
def test_traverse_graph_stop(self):
    """Traversing the graph until fun returns True"""
    visited = [None] * 9
    visit_count = [0]  # list so the closure can mutate it
    stop_after = 4  # the 4th visited node is s5

    def record(node):
        visited[visit_count[0]] = node
        visit_count[0] += 1
        if visit_count[0] == stop_after:
            return True

    # Generate graph
    v1 = spn.ContVars(num_vars=1)
    v2 = spn.ContVars(num_vars=1)
    v3 = spn.ContVars(num_vars=1)
    s1 = spn.Sum(v1, v1, v2)  # v1 included twice
    s2 = spn.Sum(v1, v3)
    s3 = spn.Sum(v2, v3, v3)  # v3 included twice
    s4 = spn.Sum(s1, v1)
    s5 = spn.Sum(s2, v3, s3)
    s6 = spn.Sum(s4, s2, s5, s4, s5)  # s4 and s5 included twice

    spn.traverse_graph(s6, fun=record, skip_params=True)

    # Traversal stops right after the 4th node; remaining slots stay empty
    self.assertEqual(visit_count[0], 4)
    expected = [s6, s4, s2, s5] + [None] * 5
    for position, node in enumerate(expected):
        self.assertIs(visited[position], node)
def test_compute_valid(self):
    """Calculating validity of PermProducts"""
    v12 = spn.IVs(num_vars=2, num_vals=3)
    v345 = spn.IVs(num_vars=3, num_vals=3)
    v678 = spn.ContVars(num_vars=3)
    v910 = spn.ContVars(num_vars=2)

    # Every input group covers a scope disjoint from the other groups
    valid_nodes = [
        spn.PermProducts((v12, [0, 1]), (v12, [4, 5])),
        spn.PermProducts((v12, [3, 5]), (v345, [0, 1, 2])),
        spn.PermProducts((v345, [0, 1, 2]), (v345, [3, 4, 5]),
                         (v345, [6, 7, 8])),
        spn.PermProducts((v345, [6, 8]), (v678, [0, 1])),
        spn.PermProducts((v678, [1]), v910),
        spn.PermProducts(v678, v910),
        spn.PermProducts((v678, [0, 1, 2])),
        spn.PermProducts((v910, [0]), (v910, [1])),
    ]
    for node in valid_nodes:
        self.assertTrue(node.is_valid())

    # Some pair of input groups shares a variable
    invalid_nodes = [
        spn.PermProducts((v12, [0, 1]), (v12, [1, 2])),
        spn.PermProducts((v12, [3, 4, 5]), (v345, [0]), (v345, [0, 1, 2])),
        spn.PermProducts((v345, [3, 5]), (v678, [0]), (v678, [0])),
        spn.PermProducts((v910, [1]), (v910, [1])),
        spn.PermProducts(v910, v910),
    ]
    single_prod = spn.PermProducts((v12, [0]), (v12, [1]))
    for node in invalid_nodes:
        self.assertFalse(node.is_valid())
    self.assertEqual(single_prod.num_prods, 1)
    self.assertFalse(single_prod.is_valid())
def test(num_vars, value):
    with self.subTest(num_vars=num_vars, value=value):
        # ContVars wired to an external placeholder instead of its own one
        feed_pl = tf.placeholder(spn.conf.dtype, [None, num_vars])
        node = spn.ContVars(feed=feed_pl, num_vars=num_vars)
        value_op = node.get_value()
        log_value_op = node.get_log_value()
        feed = {feed_pl: value}
        with tf.Session() as sess:
            out = sess.run(value_op, feed_dict=feed)
            out_log = sess.run(tf.exp(log_value_op), feed_dict=feed)
        expected = np.array(value, dtype=spn.conf.dtype.as_numpy_dtype())
        for actual in (out, out_log):
            np.testing.assert_array_almost_equal(actual, expected)
def test_generte_set(self):
    """Generation of sets of inputs with __generate_set"""
    generator = spn.DenseSPNGenerator(num_decomps=2, num_subsets=3,
                                      num_mixtures=2)
    iv_vars = spn.IVs(num_vars=2, num_vals=4)
    cont_a = spn.ContVars(num_vars=3, name="ContVars1")
    cont_b = spn.ContVars(num_vars=2, name="ContVars2")
    sum_b = spn.Sum(cont_b)
    concat_a = spn.Concat(cont_a)
    out = generator._DenseSPNGenerator__generate_set([
        spn.Input(iv_vars, [0, 3, 2, 6, 7]),
        spn.Input(cont_a, [1, 2]),
        spn.Input(sum_b, None),
        spn.Input(concat_a, None)
    ])
    # Sets come back in no particular order; check size and membership only
    self.assertEqual(len(out), 6)
    for group in ([(cont_a, 1), (concat_a, 1)],
                  [(cont_a, 2), (concat_a, 2)],
                  [(concat_a, 0)],
                  [(iv_vars, 0), (iv_vars, 2), (iv_vars, 3)],
                  [(iv_vars, 6), (iv_vars, 7)],
                  [(sum_b, 0)]):
        self.assertIn(tuple(sorted(group)), out)
def _run_op_test(self, op_fun, inputs, indices=None, ivs=None,
                 inf_type=spn.InferenceType.MARGINAL, log=False, on_gpu=True):
    """Run a single test for a single op.

    Builds ``self.num_ops`` chained copies of ``op_fun`` over ``inputs``, runs
    them ``self.num_runs`` times and compares every output with
    ``self._true_output`` (in log space when ``log`` is set). When
    ``self.profile`` is set, one extra traced run is written as a Chrome
    timeline JSON file into ``self.profiles_dir``.

    Returns:
        OpTestResult with the op name, device flag, graph size, indices/IVs
        flags, setup time, weight-initialization time, per-run times and
        whether every run produced the expected output.

    Raises:
        ValueError: if ``ivs`` is given for an op other than ``Ops.sum``,
            ``Ops.par_sums`` or ``Ops.sums``.
    """
    # Preparations
    op_name = op_fun.__name__
    device_name = '/gpu:0' if on_gpu else '/cpu:0'
    indices_flag = "No" if indices is None else "Yes"
    ivs_flag = "No" if ivs is None else "Yes"

    # Print
    print2(
        "--> %s: on_gpu=%s, inputs_shape=%s, indices=%s, ivs=%s, inference=%s, log=%s"
        % (op_name, on_gpu, inputs.shape, indices_flag, ivs_flag,
           ("MPE" if inf_type == spn.InferenceType.MPE else "MARGINAL"), log),
        self.file)

    input_size = inputs.shape[1]

    # Compute true output (computed once; every run is compared against it)
    true_out = self._true_output(op_fun, inputs, indices, ivs, inf_type)
    expected_out = np.log(true_out) if log else true_out

    # Create graph
    tf.reset_default_graph()
    with tf.device(device_name):
        # Create input
        inputs_pl = spn.ContVars(num_vars=input_size)
        # Create IVs
        if ivs is None:
            ivs_pl = [None for _ in range(self.num_sums)]
        elif op_fun is Ops.sum:
            # One single-variable IVs node per separate sum
            ivs_pl = [spn.IVs(num_vars=1, num_vals=input_size)
                      for _ in range(self.num_sums)]
        elif op_fun in (Ops.par_sums, Ops.sums):
            # Previously written as `op_fun is Ops.par_sums or Ops.sums`,
            # which is always true; one IVs node covers all sums here.
            ivs_pl = [spn.IVs(num_vars=self.num_sums, num_vals=input_size)]
        else:
            raise ValueError(
                "IVs are only supported for Ops.sum, Ops.par_sums and "
                "Ops.sums, got %s" % op_name)
        # Create ops
        start_time = time.time()
        init_ops, ops = op_fun(inputs_pl, indices, ivs_pl, self.num_sums,
                               inf_type, log)
        for _ in range(self.num_ops - 1):
            # The tuple ensures that the next op waits for the output
            # of the previous op, effectively stacking the ops
            # but using the original input every time
            init_ops, ops = op_fun(inputs_pl, indices, ivs_pl, self.num_sums,
                                   inf_type, log, tf.tuple([ops])[0])
        setup_time = time.time() - start_time
    # Get num of graph ops
    graph_size = len(tf.get_default_graph().get_operations())

    # Run op multiple times
    output_correct = True
    with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=False,
            log_device_placement=self.log_devs)) as sess:
        # Initialize weights of all the sum nodes in the graph
        start_time = time.time()
        init_ops.run()
        weights_init_time = time.time() - start_time

        # Create feed dictionary
        feed = {inputs_pl: inputs}
        if ivs is not None:
            for iv_pl in ivs_pl:
                feed[iv_pl] = ivs

        run_times = []
        for _ in range(self.num_runs):
            # Run
            start_time = time.time()
            out = sess.run(ops, feed_dict=feed)
            run_times.append(time.time() - start_time)
            # Test value
            try:
                np.testing.assert_array_almost_equal(out, expected_out)
            except AssertionError:
                output_correct = False
                self.test_failed = True

        if self.profile:
            # Add additional options to trace the session execution
            options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()
            sess.run(ops, feed_dict=feed, options=options,
                     run_metadata=run_metadata)

            # Create the Timeline object, and write it to a json file
            fetched_timeline = timeline.Timeline(run_metadata.step_stats)
            chrome_trace = fetched_timeline.generate_chrome_trace_format()
            os.makedirs(self.profiles_dir, exist_ok=True)

            if inf_type == spn.InferenceType.MPE:
                inf_suffix = "_MPE-LOG" if log else "_MPE"
            else:
                inf_suffix = "_MARGINAL-LOG" if log else "_MARGINAL"
            file_name = op_name + ("_GPU" if on_gpu else "_CPU") + inf_suffix
            if indices is not None:
                file_name += "_Indices"
            if ivs is not None:
                file_name += "_IVS"

            trace_path = os.path.join(self.profiles_dir,
                                      'timeline_value_%s.json' % file_name)
            with open(trace_path, 'w') as f:
                f.write(chrome_trace)

    # Return stats
    return OpTestResult(op_name, on_gpu, graph_size, indices_flag, ivs_flag,
                        setup_time, weights_init_time, run_times,
                        output_correct)
def test_gather_input_tensors(self):
    def check(inpt, feed, true_output):
        with self.subTest(inputs=inpt, feed=feed):
            node = spn.Concat(inpt)
            gathered, = node._gather_input_tensors(
                node.inputs[0].node.get_value())
            with self.test_session() as sess:
                result = sess.run(gathered, feed_dict=feed)
            np.testing.assert_array_equal(result, np.array(true_output))

    v1 = spn.ContVars(num_vars=3)
    v2 = spn.ContVars(num_vars=1)

    # Disconnected input: nothing to gather
    disconnected = spn.Concat(None)
    gathered, = disconnected._gather_input_tensors(3)
    self.assertIs(gathered, None)

    # Connected input, but no tensor supplied
    without_tensor = spn.Concat((v1, 1))
    gathered, = without_tensor._gather_input_tensors(None)
    self.assertIs(gathered, None)

    # Gathering for indices specified
    for inpt, feed, expected in [
            ((v1, [0, 2, 1]), {v1: [[1, 2, 3], [4, 5, 6]]},
             [[1.0, 3.0, 2.0], [4.0, 6.0, 5.0]]),
            ((v1, [0, 2]), {v1: [[1, 2, 3], [4, 5, 6]]},
             [[1.0, 3.0], [4.0, 6.0]]),
            ((v1, [1]), {v1: [[1, 2, 3], [4, 5, 6]]}, [[2.0], [5.0]]),
            ((v1, [0, 2, 1]), {v1: [[1, 2, 3]]}, [[1.0, 3.0, 2.0]]),
            ((v1, [0, 2]), {v1: [[1, 2, 3]]}, [[1.0, 3.0]]),
            ((v1, [1]), {v1: [[1, 2, 3]]}, [[2.0]])]:
        check(inpt, feed, expected)

    # With None indices the input tensor is passed through as-is
    pass_through = spn.Concat(v1)
    const = tf.constant([1, 2, 3])
    gathered, = pass_through._gather_input_tensors(const)
    self.assertIs(gathered, const)

    for inpt, feed, expected in [
            # Gathering for None indices
            (v1, {v1: [[1, 2, 3], [4, 5, 6]]},
             [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
            ((v1, None), {v1: [[1, 2, 3], [4, 5, 6]]},
             [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
            (v1, {v1: [[1, 2, 3]]}, [[1.0, 2.0, 3.0]]),
            ((v1, None), {v1: [[1, 2, 3]]}, [[1.0, 2.0, 3.0]]),
            # Gathering for single index specified
            ((v1, 1), {v1: [[1, 2, 3], [4, 5, 6]]}, [[2.0], [5.0]]),
            ((v1, [1]), {v1: [[1, 2, 3], [4, 5, 6]]}, [[2.0], [5.0]]),
            ((v1, 1), {v1: [[1, 2, 3]]}, [[2.0]]),
            ((v1, [1]), {v1: [[1, 2, 3]]}, [[2.0]]),
            # Gathering for one element input, index specified
            ((v2, 0), {v2: [[1], [4]]}, [[1.0], [4.0]]),
            ((v2, [0]), {v2: [[1], [4]]}, [[1.0], [4.0]]),
            ((v2, 0), {v2: [[1]]}, [[1.0]]),
            ((v2, [0]), {v2: [[1]]}, [[1.0]]),
            # Gathering for one element input, None indices
            (v2, {v2: [[1], [4]]}, [[1.0], [4.0]]),
            ((v2, None), {v2: [[1], [4]]}, [[1.0], [4.0]]),
            (v2, {v2: [[1]]}, [[1.0]]),
            ((v2, None), {v2: [[1]]}, [[1.0]])]:
        check(inpt, feed, expected)
def test_compute_scope(self):
    """Calculating scope of Sums"""
    # Create a graph
    v12 = spn.IVs(num_vars=2, num_vals=4, name="V12")
    v34 = spn.ContVars(num_vars=3, name="V34")
    # Expected scope of every output of every node built in this test. Entries
    # for new layers are added as they are created, so that later layers can
    # derive their expected scopes from them.
    scopes_per_node = {
        v12: [spn.Scope(v12, 0), spn.Scope(v12, 0), spn.Scope(v12, 0),
              spn.Scope(v12, 0), spn.Scope(v12, 1), spn.Scope(v12, 1),
              spn.Scope(v12, 1), spn.Scope(v12, 1)],
        v34: [spn.Scope(v34, 0), spn.Scope(v34, 1), spn.Scope(v34, 2)]
    }

    def input_scopes(inp):
        """Return the expected scopes of the outputs selected by one input spec.

        ``inp`` is a node (all outputs), ``(node, None)`` (all outputs),
        ``(node, index)`` or ``(node, [index, ...])``.
        """
        if not isinstance(inp, tuple):
            return list(scopes_per_node[inp])
        node, indices = inp
        # Compare against None explicitly: index 0 is falsy but selects
        # only the first output, not all of them.
        if indices is None:
            return list(scopes_per_node[node])
        if isinstance(indices, int):
            indices = [indices]
        return [scopes_per_node[node][i] for i in indices]

    def generate_scopes_from_inputs(node, inputs, num_or_size_sums, ivs=False):
        # Create a flat list of scopes, where the scope elements of a single input
        # node are subsequent in the list
        flat_scopes = [scope for inp in inputs for scope in input_scopes(inp)]
        size = len(flat_scopes)
        if isinstance(num_or_size_sums, int):
            # An int means that many equally sized sums
            num_or_size_sums = num_or_size_sums * [size // num_or_size_sums]
        new_scope = []
        offset = 0
        # For each sum generate the scope based on its size
        for i, s in enumerate(num_or_size_sums):
            scope = flat_scopes[offset]
            for j in range(1, s):
                scope |= flat_scopes[j + offset]
            offset += s
            if ivs:
                # Sum i additionally depends on the i-th IVs variable
                scope |= spn.Scope(node.ivs.node, i)
            new_scope.append(scope)
        scopes_per_node[node] = new_scope

    def sums_layer_and_test(inputs, num_or_size_sums, name, ivs=False):
        """ Create a sums layer, generate its correct scope and test """
        sums_layer = spn.SumsLayer(*inputs, num_or_size_sums=num_or_size_sums,
                                   name=name)
        if ivs:
            sums_layer.generate_ivs()
        generate_scopes_from_inputs(sums_layer, inputs, num_or_size_sums, ivs=ivs)
        self.assertListEqual(sums_layer.get_scope(), scopes_per_node[sums_layer])
        return sums_layer

    def concat_layer_and_test(inputs, name):
        """ Create a concat node, generate its scopes and assert whether it is
        correct """
        scope = [s for inp in inputs for s in input_scopes(inp)]
        concat = spn.Concat(*inputs, name=name)
        self.assertListEqual(concat.get_scope(), scope)
        scopes_per_node[concat] = scope
        return concat

    ss1 = sums_layer_and_test(
        [(v12, [0, 1, 2, 3]), (v12, [1, 2, 5, 6]), (v12, [4, 5, 6, 7])],
        3, "Ss1", ivs=True)
    ss2 = sums_layer_and_test([(v12, [6, 7]), (v34, 0)],
                              num_or_size_sums=[1, 2], name="Ss2")
    ss3 = sums_layer_and_test([(v12, [3, 7]), (v34, 1), (v12, [4, 5, 6]), v34],
                              num_or_size_sums=[1, 2, 2, 2, 2], name="Ss3")
    s1 = sums_layer_and_test([(v34, [1, 2])], num_or_size_sums=1, name="S1",
                             ivs=True)
    # Index 0 must select only the first output of v34
    sums_layer_and_test([(v34, 0)], num_or_size_sums=1, name="S2")
    concat_layer_and_test([(ss1, [0, 2]), (ss2, 0)], name="N1")
    concat_layer_and_test([(ss1, 1), ss3, s1], name="N2")
    n = concat_layer_and_test([(ss1, 0), ss2, (ss3, [0, 1]), s1], name="N3")
    sums_layer_and_test([(ss1, [1, 2]), ss2], num_or_size_sums=[2, 1, 1],
                        name="Ss4")
    sums_layer_and_test([(ss1, [0, 2]), (n, [0, 1]), (ss3, [4, 2])],
                        num_or_size_sums=[3, 2, 1], name="Ss5")
def test_compute_valid(self):
    """Calculating validity of Sums (SumsLayer), without and with IVs."""
    check = {True: self.assertTrue, False: self.assertFalse}

    # --- Without IVs ---
    indicators = spn.IVs(num_vars=2, num_vals=4)
    continuous = spn.ContVars(num_vars=2)

    # Three sums, each over indices 0-3 of the IVs node
    layer = spn.SumsLayer(*[(indicators, [0, 1, 2, 3]) for _ in range(3)],
                          num_or_size_sums=3)
    check[True](layer.is_valid())

    # Indices 0-2 and 4 are expected invalid with the default sizing, but valid
    # when split into sums of sizes [3, 1]
    default_sized = spn.SumsLayer((indicators, [0, 1, 2, 4]), name="S2")
    split_sized = spn.SumsLayer((indicators, [0, 1, 2, 4]),
                                num_or_size_sums=[3, 1], name="S2b")
    check[True](split_sized.is_valid())
    check[False](default_sized.is_valid())

    # IVs indices interleaved with a continuous variable, with several sizings
    def interleaved():
        return ((indicators, [0, 1, 2, 3]), (continuous, 0),
                (indicators, [0, 1, 2, 3]), (continuous, 0))

    sized_layers = [spn.SumsLayer(*interleaved(), num_or_size_sums=sizes)
                    for sizes in (2, [4, 1, 4, 1], [4, 1, 5])]
    for layer, expected in zip(sized_layers, (False, True, False)):
        check[expected](layer.is_valid())

    p1 = spn.Product((indicators, [0, 5]), (continuous, 0))
    p2 = spn.Product((indicators, [1, 6]), (continuous, 0))
    p3 = spn.Product((indicators, [1, 6]), (continuous, 1))
    product_layers = [
        (spn.SumsLayer(p1, p2, p1, p2, num_or_size_sums=2), True),
        # p1 and p3 different scopes
        (spn.SumsLayer(p1, p3, p1, p3, p1, p3, num_or_size_sums=3), False),
        (spn.SumsLayer(p1, p2, p3, num_or_size_sums=[2, 1]), True),
        # p2 and p3 different scopes
        (spn.SumsLayer(p1, p2, p3, num_or_size_sums=[1, 2]), False),
        (spn.SumsLayer(p1, p2, p3, p2, p1, num_or_size_sums=[2, 1, 2]), True),
    ]
    for layer, expected in product_layers:
        check[expected](layer.is_valid())

    # --- With IVs ---
    generated = spn.SumsLayer(p1, p2, p1, p2, p1, p2, num_or_size_sums=3)
    generated.generate_ivs()
    check[True](generated.is_valid())

    raises = spn.StructureError
    pair_twice = (p1, p2, p1, p2)
    # (layer inputs, num_or_size_sums, IVs factory, expected outcome); the IVs
    # are created only after their layer, just as in a direct sequence
    ivs_cases = [
        ((p1, p2), 1, lambda: spn.ContVars(num_vars=2), False),
        ((p1, p2, p3), 3, lambda: spn.ContVars(num_vars=3), True),
        ((p1, p2, p3), [2, 1], lambda: spn.ContVars(num_vars=3), False),
        (pair_twice, 2, lambda: spn.IVs(num_vars=3, num_vals=2), raises),
        (pair_twice, [1, 3], lambda: spn.ContVars(num_vars=2), raises),
        (pair_twice, [1, 3], lambda: spn.ContVars(num_vars=3), raises),
        (pair_twice, 2, lambda: spn.IVs(num_vars=1, num_vals=4), True),
        (pair_twice, [1, 3], lambda: spn.IVs(num_vars=1, num_vals=4), True),
        (pair_twice, [1, 3], lambda: spn.IVs(num_vars=2, num_vals=2), False),
        (pair_twice, 2, lambda: spn.IVs(num_vars=2, num_vals=2), True),
        (pair_twice, [1, 2, 1], lambda: spn.IVs(num_vars=2, num_vals=2), False),
        (pair_twice, 2, lambda: (indicators, [0, 3, 5, 7]), True),
        (pair_twice, [1, 2, 1], lambda: (indicators, [0, 3, 5, 7]), False),
    ]
    for inputs, sizes, make_ivs, outcome in ivs_cases:
        layer = spn.SumsLayer(*inputs, num_or_size_sums=sizes)
        layer.set_ivs(make_ivs())
        if outcome is raises:
            with self.assertRaises(spn.StructureError):
                layer.is_valid()
        else:
            check[outcome](layer.is_valid())
def test_get_nodes(self):
    """Obtaining the list of nodes in the SPN graph"""
    # Leaves, several of which feed more than one sum
    x1 = spn.ContVars(num_vars=1)
    x2 = spn.ContVars(num_vars=1)
    x3 = spn.ContVars(num_vars=1)
    s1 = spn.Sum(x1, x1, x2)  # x1 included twice
    s2 = spn.Sum(x1, x3)
    s3 = spn.Sum(x2, x3, x3)  # x3 included twice
    s4 = spn.Sum(s1, x1)
    s5 = spn.Sum(s2, x3, s3)
    s6 = spn.Sum(s4, s2, s5, s4, s5)  # s4 and s5 included twice
    spn.generate_weights(s6)

    def weights_of(node):
        return node.weights.node

    # (root, expected with skip_params=True, expected with skip_params=False)
    expectations = [
        (x1, [x1], [x1]),
        (x2, [x2], [x2]),
        (x3, [x3], [x3]),
        (s1, [s1, x1, x2], [s1, weights_of(s1), x1, x2]),
        (s2, [s2, x1, x3], [s2, weights_of(s2), x1, x3]),
        (s3, [s3, x2, x3], [s3, weights_of(s3), x2, x3]),
        (s4, [s4, s1, x1, x2],
         [s4, weights_of(s4), s1, x1, weights_of(s1), x2]),
        (s5, [s5, s2, x3, s3, x1, x2],
         [s5, weights_of(s5), s2, x3, s3, weights_of(s2), x1,
          weights_of(s3), x2]),
        (s6, [s6, s4, s2, s5, s1, x1, x3, s3, x2],
         [s6, weights_of(s6), s4, s2, s5, weights_of(s4), s1, x1,
          weights_of(s2), x3, weights_of(s5), s3, weights_of(s1), x2,
          weights_of(s3)]),
    ]
    for root, without_params, with_params in expectations:
        self.assertListEqual(root.get_nodes(skip_params=True), without_params)
        self.assertListEqual(root.get_nodes(skip_params=False), with_params)
def test_input_conversion(self):
    """Conversion and verification of input specs in Input"""
    v1 = spn.ContVars(num_vars=5)

    def assert_input(inpt, node, indices):
        """Check the node/indices of a converted Input and its truthiness.

        An Input is expected to be falsy exactly when it has no node.
        """
        self.assertIs(inpt.node, node)
        if indices is None:
            self.assertIs(inpt.indices, None)
        else:
            self.assertListEqual(inpt.indices, indices)
        if node is None:
            self.assertFalse(inpt)
        else:
            self.assertTrue(inpt)

    # None (indices are dropped when there is no node)
    assert_input(spn.Input(), None, None)
    assert_input(spn.Input(None), None, None)
    assert_input(spn.Input(None, [1, 2, 3]), None, None)
    assert_input(spn.Input.as_input(None), None, None)
    assert_input(spn.Input.as_input((None, [1, 2, 3])), None, None)
    # Node
    assert_input(spn.Input(v1), v1, None)
    assert_input(spn.Input.as_input(v1), v1, None)
    # (Node, None)
    assert_input(spn.Input(v1, None), v1, None)
    assert_input(spn.Input.as_input((v1, None)), v1, None)
    # (Node, index): a bare index is stored as a one-element list
    assert_input(spn.Input(v1, 10), v1, [10])
    assert_input(spn.Input.as_input((v1, 10)), v1, [10])
    # (Node, indices)
    assert_input(spn.Input(v1, [10]), v1, [10])
    assert_input(spn.Input.as_input((v1, [10])), v1, [10])
    assert_input(spn.Input(v1, [10, 1, 20]), v1, [10, 1, 20])
    assert_input(spn.Input.as_input((v1, [10, 1, 20])), v1, [10, 1, 20])

    # Checking type of input
    self.assertRaises(TypeError, spn.Input, set())
    self.assertRaises(TypeError, spn.Input.as_input, set())
    # A one-element tuple; the trailing comma matters, `(set())` is a bare set
    self.assertRaises(TypeError, spn.Input.as_input, (set(),))
    self.assertRaises(TypeError, spn.Input.as_input, tuple())
    self.assertRaises(TypeError, spn.Input, v1, set())
    self.assertRaises(TypeError, spn.Input.as_input, (v1, set()))
    self.assertRaises(TypeError, spn.Input.as_input, (v1,))
    # Detecting empty list
    self.assertRaises(ValueError, spn.Input, v1, [])
    self.assertRaises(ValueError, spn.Input.as_input, (v1, []))
    # Detecting incorrect indices (non-integer and negative)
    self.assertRaises(ValueError, spn.Input, v1, [1, set(), 2])
    self.assertRaises(ValueError, spn.Input.as_input, (v1, [1, set(), 2]))
    self.assertRaises(ValueError, spn.Input, v1, [1, -1, 2])
    self.assertRaises(ValueError, spn.Input.as_input, (v1, [1, -1, 2]))
    # Detecting duplicate indices
    self.assertRaises(ValueError, spn.Input, v1, [0, 1, 3, 1])
    self.assertRaises(ValueError, spn.Input.as_input, (v1, [0, 1, 3, 1]))